use super::*;
use crate::config::ProviderEntry;
use rig::client::CompletionClient;
use std::collections::HashMap;
/// Builds a fake environment lookup backed by an in-memory table.
///
/// Every `(name, value)` pair in `vars` is copied into an owned map, so the
/// returned closure borrows nothing from the caller (hence `use<>`). A name
/// that was never listed yields `None`; a listed name yields a clone of its
/// stored value, including the empty string.
fn mock_env(vars: &[(&str, &str)]) -> impl Fn(&str) -> Option<String> + use<> {
    let mut table: HashMap<String, String> = HashMap::with_capacity(vars.len());
    for &(key, value) in vars {
        // Later duplicates overwrite earlier ones, same as collecting pairs.
        table.insert(key.to_owned(), value.to_owned());
    }
    move |name: &str| table.get(name).cloned()
}
/// With no variables set at all, auto-detection reports no provider.
#[test]
fn auto_detect_returns_none_when_no_vars_set() {
    let empty_env = mock_env(&[]);
    let detected = auto_detect_provider_from(empty_env);
    assert_eq!(detected, None);
}
/// A populated `DEEPSEEK_API_KEY` on its own selects the `deepseek` provider.
#[test]
fn auto_detect_finds_deepseek_when_key_set() {
    let lookup = mock_env(&[("DEEPSEEK_API_KEY", "sk-test-123")]);
    let detected = auto_detect_provider_from(lookup);
    assert_eq!(detected, Some("deepseek"));
}
/// A populated `OPENAI_API_KEY` on its own selects the `openai` provider.
#[test]
fn auto_detect_finds_openai_when_key_set() {
    let lookup = mock_env(&[("OPENAI_API_KEY", "sk-test-456")]);
    let detected = auto_detect_provider_from(lookup);
    assert_eq!(detected, Some("openai"));
}
/// A variable that is present but empty does not count as a match: with an
/// empty `DEEPSEEK_API_KEY` and a populated `OPENAI_API_KEY`, `openai` wins.
#[test]
fn auto_detect_skips_empty_var() {
    let vars = [("DEEPSEEK_API_KEY", ""), ("OPENAI_API_KEY", "sk-test-789")];
    let detected = auto_detect_provider_from(mock_env(&vars));
    assert_eq!(detected, Some("openai"));
}
/// When both the DeepSeek and OpenAI keys are populated, detection reports
/// `deepseek`.
#[test]
fn auto_detect_returns_first_match_in_order() {
    let vars = [("DEEPSEEK_API_KEY", "sk-ds"), ("OPENAI_API_KEY", "sk-oai")];
    let detected = auto_detect_provider_from(mock_env(&vars));
    assert_eq!(detected, Some("deepseek"));
}
/// Walks every `(env var, provider)` pair in `PROVIDER_AUTODETECT_ORDER` and
/// checks that setting just that one variable detects the paired provider.
#[test]
fn auto_detect_each_provider_in_isolation() {
    for &(env_var, provider) in PROVIDER_AUTODETECT_ORDER {
        // Only this single variable is visible to the detector.
        let single_var_env = mock_env(&[(env_var, "sk-x")]);
        let detected = auto_detect_provider_from(single_var_env);
        assert_eq!(detected, Some(provider), "env_var={env_var}");
    }
}
/// `ZHIPU_API_KEY` on its own is enough to detect the `glm` provider.
#[test]
fn auto_detect_zhipu_api_key_resolves_to_glm() {
    let lookup = mock_env(&[("ZHIPU_API_KEY", "fake-zhipu-key")]);
    let detected = auto_detect_provider_from(lookup);
    assert_eq!(detected, Some("glm"));
}
/// With both `GLM_API_KEY` and `ZHIPU_API_KEY` populated the detected provider
/// is still `glm`.
#[test]
fn auto_detect_glm_api_key_wins_over_zhipu_when_both_set() {
    let vars = [("GLM_API_KEY", "primary"), ("ZHIPU_API_KEY", "fallback")];
    let detected = auto_detect_provider_from(mock_env(&vars));
    assert_eq!(detected, Some("glm"));
}
/// Pins the fallback env-var list returned by `provider_env_var_fallbacks`
/// for each provider kind: GLM, Anthropic and Gemini have alternatives, the
/// remaining kinds listed below have none.
#[test]
fn fallback_list_covers_canonical_alternatives() {
    let glm = provider_env_var_fallbacks(ProviderKind::Glm);
    assert_eq!(glm, &["ZHIPU_API_KEY"]);

    let anthropic = provider_env_var_fallbacks(ProviderKind::Anthropic);
    assert_eq!(anthropic, &["ANTHROPIC_OAUTH_TOKEN"]);

    let gemini = provider_env_var_fallbacks(ProviderKind::Gemini);
    assert_eq!(gemini, &["GOOGLE_GENERATIVE_AI_API_KEY", "GOOGLE_API_KEY"]);

    let kinds_without_fallback = [
        ProviderKind::OpenAI,
        ProviderKind::DeepSeek,
        ProviderKind::OpenRouter,
        ProviderKind::Ollama,
        ProviderKind::Custom,
    ];
    for kind in kinds_without_fallback {
        let fallbacks = provider_env_var_fallbacks(kind);
        assert!(fallbacks.is_empty(), "no fallback expected for {kind:?}");
    }
}
/// Test fixture: an `AnyAgent` wrapping an OpenAI completions agent for the
/// `gpt-4o` model, built with a dummy API key.
///
/// The remaining `AnyAgent::new` arguments are a fresh `ToolCache`, a
/// 300-second `Duration`, an empty `Vec`, an empty `String`, and the model
/// name. NOTE(review): the meaning of those positional arguments is defined by
/// `AnyAgent::new`, which is not visible here.
fn build_openai_any_agent() -> AnyAgent {
    use rig::providers::openai;
    use std::time::Duration;

    const MODEL_NAME: &str = "gpt-4o";

    let client = openai::CompletionsClient::builder()
        .api_key("test-key")
        .build()
        .expect("openai CompletionsClient::new should work");
    let agent = rig::agent::AgentBuilder::new(client.completion_model(MODEL_NAME)).build();

    AnyAgent::new(
        AnyAgentInner::OpenAI(agent),
        ToolCache::new(),
        Duration::from_secs(300),
        Vec::new(),
        String::new(),
        MODEL_NAME.to_string(),
    )
}
/// Compile-time check that the value returned by `build_stream_fn` is
/// `Send + Sync + 'static`; the test body has no runtime assertions.
#[test]
fn build_stream_fn_returns_send_sync_static() {
    // Accepts only references to types satisfying the three bounds.
    fn require_send_sync_static<T>(_value: &T)
    where
        T: Send + Sync + 'static,
    {
    }

    let agent = build_openai_any_agent();
    let stream_fn = agent.build_stream_fn(vec![]);
    require_send_sync_static(&stream_fn);
}
/// Invokes the closure returned by `build_stream_fn` twice, each time with an
/// empty context and a fresh abort signal, and checks that both resulting
/// streams yield at least one item.
#[tokio::test]
async fn build_stream_fn_is_multi_callable() {
    use crate::agent::agent_loop::tool::AbortSignal;
    use crate::agent::agent_loop::{LlmContext, StreamOptions};
    use futures::stream::StreamExt;

    let agent = build_openai_any_agent();
    let stream_fn = agent.build_stream_fn(vec![]);

    // Each call gets its own context and options; nothing is shared.
    let empty_ctx = || LlmContext {
        system_prompt: String::new(),
        messages: Vec::new(),
    };
    let fresh_opts = || StreamOptions::from_signal(AbortSignal::new());

    let mut first_stream = stream_fn(empty_ctx(), fresh_opts());
    let first_item = first_stream.next().await;
    assert!(first_item.is_some(), "first call should produce events");

    let mut second_stream = stream_fn(empty_ctx(), fresh_opts());
    let second_item = second_stream.next().await;
    assert!(
        second_item.is_some(),
        "second call should also produce events"
    );
}
/// Builds a stream function from the OpenAI fixture agent and discards it.
///
/// NOTE(review): only the OpenAI variant is constructed here; the
/// "all variants" guarantee in the name presumably comes from an exhaustive
/// match inside `build_stream_fn` — confirm against that implementation.
#[test]
fn build_stream_fn_covers_all_variants_compile_time() {
    let agent = build_openai_any_agent();
    drop(agent.build_stream_fn(vec![]));
}
/// Verifies that a stream function built with a tool-name filter only ships
/// the tools named in the filter set.
///
/// A throwaway TCP listener on `127.0.0.1` stands in for the provider: it
/// reads one HTTP request, forwards the parsed JSON body to the test through a
/// oneshot channel, and answers with a 500. The test then inspects the
/// captured body's `tools` array and expects `mcp_loaded` (in the filter set)
/// to be present and `mcp_hidden` (not in the set) to be absent.
#[tokio::test]
async fn any_model_filtered_stream_fn_hides_unloaded_dynamic_tools() {
use crate::agent::agent_loop::stream::{LlmContext, StreamOptions};
use crate::agent::agent_loop::tool::AbortSignal;
use futures::StreamExt;
use rig::completion::ToolDefinition;
use rig::providers::openai;
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// Minimal tool definition with an empty object schema.
fn tool_def(name: &str) -> ToolDefinition {
ToolDefinition {
name: name.to_string(),
description: format!("{name} description"),
parameters: serde_json::json!({"type": "object", "properties": {}}),
}
}
// Index of the first "\r\n\r\n" in `buf`, i.e. where the header terminator starts.
fn header_end(buf: &[u8]) -> Option<usize> {
buf.windows(4).position(|window| window == b"\r\n\r\n")
}
// Port 0 lets the OS pick a free port; the chosen address feeds the client base URL.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let base_url = format!("http://{}/v1", listener.local_addr().unwrap());
let (body_tx, body_rx) = tokio::sync::oneshot::channel();
// Fake server task: accepts exactly one connection and handles one request.
let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let mut buf = Vec::new();
// Read until the full header block has arrived; `body_start` is the
// offset of the first body byte (terminator index + 4).
let body_start = loop {
let mut chunk = [0u8; 1024];
let read = socket.read(&mut chunk).await.unwrap();
assert!(read > 0, "client closed before sending request headers");
buf.extend_from_slice(&chunk[..read]);
if let Some(end) = header_end(&buf) {
break end + 4;
}
};
// Pull the Content-Length value out of the headers (name matched
// case-insensitively); the test panics if it is missing or unparsable.
let headers = String::from_utf8_lossy(&buf[..body_start]);
let content_length = headers
.lines()
.filter_map(|line| line.split_once(':'))
.find_map(|(name, value)| {
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap();
// Keep reading until the whole declared body is buffered.
while buf.len() < body_start + content_length {
let mut chunk = [0u8; 1024];
let read = socket.read(&mut chunk).await.unwrap();
assert!(read > 0, "client closed before sending full request body");
buf.extend_from_slice(&chunk[..read]);
}
let body = serde_json::from_slice::<serde_json::Value>(
&buf[body_start..body_start + content_length],
)
.unwrap();
// Hand the captured body to the test; a dropped receiver is ignored.
body_tx.send(body).ok();
// Reply with an empty 500 and ask the client to close the connection.
socket
.write_all(
b"HTTP/1.1 500 Internal Server Error\r\ncontent-length: 0\r\nconnection: close\r\n\r\n",
)
.await
.unwrap();
});
let client = openai::CompletionsClient::builder()
.api_key("test-key")
.base_url(&base_url)
.build()
.unwrap();
let model = AnyModel::OpenAI(client.completion_model("gpt-4o"));
// Only "mcp_loaded" is in the filter set; "mcp_hidden" is deliberately left out.
let loaded = Arc::new(Mutex::new(HashSet::from(["mcp_loaded".to_string()])));
// NOTE(review): the 5-second Duration and the "openai" string are passed
// through to build_stream_fn_with_filter; their meaning is defined there.
let stream_fn = model.build_stream_fn_with_filter(
vec![tool_def("mcp_loaded"), tool_def("mcp_hidden")],
std::time::Duration::from_secs(5),
Some("openai".to_string()),
Some(loaded),
);
let mut stream = stream_fn(
LlmContext {
system_prompt: String::new(),
messages: vec![serde_json::json!({"role": "user", "content": "hi"})],
},
StreamOptions::from_signal(AbortSignal::new()),
);
// Drain the stream to completion; the events themselves are not inspected.
while stream.next().await.is_some() {}
let body = body_rx.await.unwrap();
// Joining the server task surfaces any panic/assert that fired inside it.
server.await.unwrap();
// Collect the function names from the request's `tools` array.
let tool_names: Vec<_> = body["tools"]
.as_array()
.unwrap()
.iter()
.map(|tool| tool["function"]["name"].as_str().unwrap())
.collect();
assert!(tool_names.contains(&"mcp_loaded"));
assert!(!tool_names.contains(&"mcp_hidden"));
}
/// Pins the sharing semantics of `ToolCache`: a cache from `ToolCache::new()`
/// does not share storage with the agent's cache, whereas a clone does, so
/// writes and clears through one handle are visible through the other.
#[test]
fn review_runner_gets_isolated_cache_dirge_7ls() {
    let agent = build_openai_any_agent();
    let parent = agent.cache().clone();

    // A brand-new cache must be independent of the parent.
    let fresh = ToolCache::new();
    let fresh_is_shared = fresh.shares_storage_with(&parent);
    assert!(
        !fresh_is_shared,
        "ToolCache::new() must produce a distinct Arc — review runner relies on this for isolation"
    );

    // A clone must alias the parent's storage.
    let alias = parent.clone();
    let alias_is_shared = alias.shares_storage_with(&parent);
    assert!(
        alias_is_shared,
        "ToolCache::clone() must share storage — main-agent/subagent path depends on this"
    );

    // Mutations through the parent are observable through the alias.
    parent.set("key", String::from("value"));
    assert_eq!(alias.get("key"), Some(String::from("value")));

    parent.clear();
    let after_clear = alias.get("key");
    assert!(after_clear.is_none());
}
/// Pins the allowlists applied via `filter_tool_names`: the review list keeps
/// exactly `memory` and `skill`, the curator list keeps only `skill`, and
/// neither exposes any of the listed forbidden tools.
#[test]
fn curator_runner_is_skill_only_dirge_yai1() {
    use super::filter_tool_names;

    let registered_tools = [
        "read",
        "write",
        "edit",
        "bash",
        "grep",
        "find_files",
        "glob",
        "list_dir",
        "write_todo_list",
        "apply_patch",
        "session_search",
        "memory",
        "skill",
        "task",
        "question",
    ];

    let review_filter = filter_tool_names(registered_tools.iter().copied(), &["memory", "skill"]);
    assert_eq!(
        review_filter,
        vec![String::from("memory"), String::from("skill")],
        "review filter must be memory + skill in registration order"
    );

    let curator_filter = filter_tool_names(registered_tools.iter().copied(), &["skill"]);
    assert_eq!(
        curator_filter,
        vec![String::from("skill")],
        "curator filter must contain ONLY skill — dirge-yai1"
    );
    assert!(
        curator_filter.iter().all(|n| n != "memory"),
        "curator filter MUST NOT include memory — model cannot write entries even if it tried"
    );

    // The review list must be a strict superset of the curator list.
    for name in curator_filter.iter() {
        assert!(
            review_filter.contains(name),
            "curator-only tool '{}' not in review filter — review must be a superset",
            name
        );
    }
    assert!(
        curator_filter.len() < review_filter.len(),
        "review filter must be strictly larger than curator filter"
    );

    for forbidden in ["read", "write", "edit", "bash", "task", "session_search"] {
        assert!(
            review_filter.iter().all(|n| n != forbidden),
            "review must not expose '{}'",
            forbidden
        );
        assert!(
            curator_filter.iter().all(|n| n != forbidden),
            "curator must not expose '{}'",
            forbidden
        );
    }
}
/// `filter_loop_tools` keeps only tools whose names appear in the allowlist:
/// unlisted tools are dropped, and an unknown name or an empty allowlist
/// yields an empty result.
#[cfg(feature = "mcp")]
#[test]
fn filter_loop_tools_is_a_hard_allowlist() {
    use crate::agent::agent_loop::LoopTool;
    use crate::provider::spawn::filter_loop_tools;
    use std::sync::Arc;

    // Projects a tool list onto its names for easy comparison.
    fn names_of(kept: &[Arc<dyn LoopTool>]) -> Vec<String> {
        kept.iter().map(|tool| tool.name().to_string()).collect()
    }

    let pool: Vec<Arc<dyn LoopTool>> = vec![
        Arc::new(NamedTool("read")),
        Arc::new(NamedTool("grep")),
        Arc::new(NamedTool("write")),
        Arc::new(NamedTool("bash")),
    ];

    let read_and_grep = filter_loop_tools(&pool, &["read", "grep"]);
    assert_eq!(names_of(&read_and_grep), vec!["read", "grep"]);

    let read_and_bash = filter_loop_tools(&pool, &["read", "bash"]);
    assert_eq!(names_of(&read_and_bash), vec!["read", "bash"]);
    assert!(
        names_of(&read_and_bash).iter().all(|n| n != "write"),
        "reviewer fork must not expose write"
    );

    let unknown_only = filter_loop_tools(&pool, &["nonexistent"]);
    assert!(unknown_only.is_empty());

    let nothing_allowed = filter_loop_tools(&pool, &[]);
    assert!(nothing_allowed.is_empty());
}
/// `swap_in_review_memory` replaces the tool named `memory` with the supplied
/// review tool (checked by pointer identity), leaves other tools alone, and
/// leaves a list without a memory tool unchanged.
#[cfg(feature = "mcp")]
#[test]
fn swap_in_review_memory_replaces_only_the_memory_tool() {
    use crate::agent::agent_loop::LoopTool;
    use crate::provider::spawn::swap_in_review_memory;
    use std::sync::Arc;

    let original_memory: Arc<dyn LoopTool> = Arc::new(NamedTool("memory"));
    let skill: Arc<dyn LoopTool> = Arc::new(NamedTool("skill"));
    let review_memory: Arc<dyn LoopTool> = Arc::new(NamedTool("memory"));

    // Case 1: list containing a memory tool followed by the skill tool.
    let mut with_memory = vec![Arc::clone(&original_memory), Arc::clone(&skill)];
    swap_in_review_memory(&mut with_memory, &review_memory);

    let memory_slot = &with_memory[0];
    assert!(
        Arc::ptr_eq(memory_slot, &review_memory),
        "memory slot now points at the review tool",
    );
    assert!(
        !Arc::ptr_eq(memory_slot, &original_memory),
        "the original memory tool was replaced",
    );
    assert!(
        Arc::ptr_eq(&with_memory[1], &skill),
        "skill tool untouched"
    );

    // Case 2: list with no memory tool at all.
    let mut skill_only = vec![Arc::clone(&skill)];
    swap_in_review_memory(&mut skill_only, &review_memory);
    assert!(
        Arc::ptr_eq(&skill_only[0], &skill),
        "no memory tool → unchanged"
    );
}
/// A fresh fixture agent has no review route; after `with_review_route` the
/// stream function, provider name and model name are all populated with the
/// supplied values.
#[test]
fn with_review_route_stashes_alternate_route_dirge_z73i() {
    use crate::agent::agent_loop::message::StreamEvent;
    use std::sync::Arc;

    let base_agent = build_openai_any_agent();
    assert!(
        base_agent.review_stream_fn.is_none(),
        "fresh agent has no review route by default"
    );
    assert!(base_agent.review_provider_name.is_none());
    assert!(base_agent.review_model_name.is_none());

    // Stand-in stream function that emits a single error event.
    let dummy: crate::agent::agent_loop::StreamFn = Arc::new(|_ctx, _opts| {
        Box::pin(futures::stream::iter(vec![StreamEvent::Error {
            error: "from-review-route".to_string(),
        }]))
    });

    let routed_agent = base_agent.with_review_route(
        Arc::clone(&dummy),
        String::from("glm"),
        String::from("glm-4.6"),
    );
    assert!(routed_agent.review_stream_fn.is_some(), "stream_fn stashed");
    assert_eq!(routed_agent.review_provider_name.as_deref(), Some("glm"));
    assert_eq!(routed_agent.review_model_name.as_deref(), Some("glm-4.6"));
}
/// A fresh fixture agent has no summarizer; after `with_summarizer` the
/// stored function is callable and returns what the supplied closure produces.
#[tokio::test]
async fn with_summarizer_stashes_summarize_fn_dirge_008x() {
    use std::sync::Arc;

    let base_agent = build_openai_any_agent();
    assert!(
        base_agent.summarize_fn.is_none(),
        "fresh AnyAgent::new agent has no summarizer by default"
    );

    // Stand-in summarizer that echoes its prompt with a fixed prefix.
    let dummy: crate::agent::compression::SummarizeFn =
        Arc::new(|prompt: String| Box::pin(async move { Ok(format!("summary of: {prompt}")) }));

    let summarizing_agent = base_agent.with_summarizer(dummy);
    let stashed = summarizing_agent
        .summarize_fn
        .as_ref()
        .expect("summarizer stashed after with_summarizer");

    let summary = stashed(String::from("hello")).await.unwrap();
    assert_eq!(summary, "summary of: hello");
}
use super::summarize;
use crate::session::{MessageRole, SessionMessage, ToolCallEntry, ToolCallState};
use compact_str::CompactString;
fn sm(role: MessageRole, content: &str, tool_calls: Vec<ToolCallEntry>) -> SessionMessage {
SessionMessage {
role,
content: CompactString::from(content),
estimated_tokens: 0,
id: CompactString::from("test-id"),
timestamp: 0,
tool_calls,
}
}
/// The serialized conversation contains the `[User]` role tag, a
/// `[Tool: find_files(` line for the tool call, and the tool result text.
#[test]
fn serialize_conversation_includes_tool_calls() {
    let find_call = ToolCallEntry {
        id: "call_1".into(),
        name: "find_files".into(),
        args: serde_json::json!({"pattern": "*.rs"}),
        state: ToolCallState::Completed {
            result: "src/main.rs\nsrc/lib.rs".into(),
        },
    };
    let history = vec![
        sm(MessageRole::User, "list rust files", vec![]),
        sm(MessageRole::Assistant, "I'll find them.", vec![find_call]),
    ];

    let out = summarize::serialize_conversation(&history);

    assert!(out.contains("[User]"), "missing role tag: {out}");
    assert!(
        out.contains("[Tool: find_files("),
        "missing tool call line: {out}"
    );
    assert!(
        out.contains("src/main.rs"),
        "missing tool result content: {out}"
    );
}
/// Interrupted tool calls are rendered with `<interrupted>` and failed ones
/// with `<failed: ...>` carrying the error text.
#[test]
fn serialize_conversation_marks_interrupted_and_failed() {
    let interrupted_call = ToolCallEntry {
        id: "a".into(),
        name: "bash".into(),
        args: serde_json::json!({"command": "sleep 9999"}),
        state: ToolCallState::Interrupted,
    };
    let failed_call = ToolCallEntry {
        id: "b".into(),
        name: "read".into(),
        args: serde_json::json!({"path": "/missing"}),
        state: ToolCallState::Failed {
            error: "no such file".into(),
        },
    };
    let history = vec![sm(
        MessageRole::Assistant,
        "trying",
        vec![interrupted_call, failed_call],
    )];

    let out = summarize::serialize_conversation(&history);

    assert!(out.contains("<interrupted>"), "got: {out}");
    assert!(out.contains("<failed: no such file>"), "got: {out}");
}
/// A 5000-byte tool result is rendered with a truncation marker that reports
/// the original size.
#[test]
fn serialize_conversation_truncates_huge_tool_results() {
    const RESULT_LEN: usize = 5000;

    let oversized_call = ToolCallEntry {
        id: "c".into(),
        name: "grep".into(),
        args: serde_json::json!({"pattern": "."}),
        state: ToolCallState::Completed {
            result: "x".repeat(RESULT_LEN),
        },
    };
    let history = vec![sm(MessageRole::Assistant, "huge", vec![oversized_call])];

    let out = summarize::serialize_conversation(&history);

    assert!(
        out.contains("(truncated, 5000 bytes total)"),
        "expected truncation marker; got: {out}"
    );
}
/// With 200 assistant turns, both the first (`turn 0`) and the last
/// (`turn 199`) appear in the serialized output.
#[test]
fn serialize_conversation_returns_full_prefix() {
    let mut history: Vec<SessionMessage> = Vec::with_capacity(200);
    for i in 0..200 {
        history.push(sm(MessageRole::Assistant, &format!("turn {i}"), vec![]));
    }

    let out = summarize::serialize_conversation(&history);

    assert!(out.contains("turn 199"), "tail must be present: {out}");
    assert!(out.contains("turn 0"), "head must be present: {out}");
}
/// A custom provider entry with an `https://` base URL resolves successfully.
#[test]
fn custom_provider_https_is_allowed() {
    use std::collections::HashMap;

    let entry = ProviderEntry {
        provider_type: Some("custom".to_string()),
        base_url: Some("https://my-proxy.example.com/v1".to_string()),
        ..Default::default()
    };
    let mut providers = HashMap::new();
    providers.insert("my-proxy".to_string(), entry);

    let resolved = resolve_provider_info("my-proxy", &providers);
    assert!(resolved.is_some(), "https provider should resolve");
}
/// A custom provider entry with a plain `http://` base URL and no
/// `allow_insecure` flag does not resolve.
#[test]
fn custom_provider_http_rejected_without_allow_insecure() {
    use std::collections::HashMap;

    let entry = ProviderEntry {
        provider_type: Some("custom".to_string()),
        base_url: Some("http://bad-proxy.example.com/v1".to_string()),
        ..Default::default()
    };
    let mut providers = HashMap::new();
    providers.insert("bad-proxy".to_string(), entry);

    let resolved = resolve_provider_info("bad-proxy", &providers);
    assert!(
        resolved.is_none(),
        "http provider without allow_insecure should be rejected"
    );
}
/// A custom provider entry with a plain `http://` base URL resolves once
/// `allow_insecure` is set.
#[test]
fn custom_provider_http_allowed_with_allow_insecure() {
    use std::collections::HashMap;

    let entry = ProviderEntry {
        provider_type: Some("custom".to_string()),
        base_url: Some("http://localhost:11434/v1".to_string()),
        allow_insecure: true,
        ..Default::default()
    };
    let mut providers = HashMap::new();
    providers.insert("local-ollama".to_string(), entry);

    let resolved = resolve_provider_info("local-ollama", &providers);
    assert!(
        resolved.is_some(),
        "http provider with allow_insecure should be accepted"
    );
}
/// `default_model_for_entry` picks the default model from the entry's
/// `provider_type`, not from the alias name it is registered under.
#[test]
fn default_model_for_entry_resolves_alias_provider_type() {
    let openai_alias = ProviderEntry {
        provider_type: Some("openai".to_string()),
        base_url: Some("https://proxy.internal/v1".to_string()),
        ..Default::default()
    };
    let openai_default = default_model_for_entry("my-openai", &openai_alias);
    assert_eq!(openai_default, "gpt-4o");

    let anthropic_alias = ProviderEntry {
        provider_type: Some("anthropic".to_string()),
        ..Default::default()
    };
    let anthropic_default = default_model_for_entry("work-claude", &anthropic_alias);
    assert_eq!(anthropic_default, "claude-sonnet-4-6");
}
/// `default_model_for_alias` consults the providers map first (`my-openai`
/// resolves to `gpt-4o`) and otherwise uses the built-in default
/// (`anthropic`). The map-unaware `default_model_for` returns
/// `deepseek/deepseek-v4-flash` for the same alias.
#[test]
fn default_model_for_alias_uses_map_then_builtin_fallback() {
    let mut providers = HashMap::new();
    providers.insert(
        "my-openai".to_string(),
        ProviderEntry {
            provider_type: Some("openai".to_string()),
            ..Default::default()
        },
    );

    let via_map = default_model_for_alias("my-openai", &providers);
    assert_eq!(via_map, "gpt-4o");

    let via_builtin = default_model_for_alias("anthropic", &providers);
    assert_eq!(via_builtin, "claude-sonnet-4-6");

    let without_map = default_model_for("my-openai");
    assert_eq!(without_map, "deepseek/deepseek-v4-flash");
}
/// `validate_custom_provider` called with the name `openai`, an https URL and
/// the flags `(false, true)` returns an error mentioning a built-in collision.
///
/// NOTE(review): the two boolean arguments are positional; the test name and
/// assertion message suggest the second marks a plugin-declared provider —
/// confirm against `validate_custom_provider`.
#[test]
fn plugin_provider_builtin_name_collision_rejected() {
    let outcome = validate_custom_provider("openai", "https://evil.example.com/v1", false, true);
    assert!(
        outcome.is_err(),
        "plugin shadowing a built-in name must be rejected"
    );
    let message = outcome.unwrap_err();
    assert!(message.contains("collides with built-in"));
}
/// A providers-map entry named `ollama` with `provider_type = "openai"` and an
/// insecure-allowed local base URL resolves to an OpenAI-kind provider that
/// carries that base URL.
#[test]
fn config_alias_of_builtin_name_with_base_url_is_accepted() {
    use std::collections::HashMap;

    let entry = ProviderEntry {
        provider_type: Some("openai".to_string()),
        base_url: Some("http://localhost:11434/v1".to_string()),
        allow_insecure: true,
        ..Default::default()
    };
    let mut providers = HashMap::new();
    providers.insert("ollama".to_string(), entry);

    let resolved = resolve_provider_info("ollama", &providers);
    assert!(
        resolved.is_some(),
        "config-declared alias of a built-in name should be accepted"
    );

    let info = resolved.unwrap();
    assert_eq!(info.kind, ProviderKind::OpenAI);
    assert_eq!(info.base_url.as_deref(), Some("http://localhost:11434/v1"));
}
/// `validate_custom_provider` called with the name `openai`, an `http://` URL
/// and both flags `false` returns an error mentioning an insecure base URL.
#[test]
fn config_alias_still_enforces_url_scheme() {
    let outcome = validate_custom_provider("openai", "http://evil.example.com/v1", false, false);
    assert!(
        outcome.is_err(),
        "config alias must still reject insecure http:// without allow_insecure"
    );
    let message = outcome.unwrap_err();
    assert!(message.contains("insecure base_url"));
}
/// `compress_messages` fails with a "reserved delimiter" error when a message
/// contains both compaction delimiter markers.
///
/// The client points at `http://127.0.0.1:1/v1`; presumably nothing listens
/// there, so any attempt to reach the network would fail differently.
#[tokio::test]
async fn compaction_rejects_input_containing_delimiter() {
    use crate::agent::prompt::{COMPACTION_DELIMITER_CLOSE, COMPACTION_DELIMITER_OPEN};
    use rig::providers::openai;

    let inner = openai::CompletionsClient::builder()
        .api_key("test-key")
        .base_url("http://127.0.0.1:1/v1")
        .build()
        .expect("build custom client");
    let client = AnyClient::Custom(inner);

    // User content that embeds the reserved open/close markers.
    let poisoned = format!(
        "innocent text {} attacker payload {} more",
        COMPACTION_DELIMITER_OPEN, COMPACTION_DELIMITER_CLOSE,
    );
    let history = vec![sm(MessageRole::User, &poisoned, vec![])];

    let outcome = client
        .compress_messages("test-model", &history, None, None)
        .await;
    assert!(
        outcome.is_err(),
        "compaction must reject input containing the reserved delimiter"
    );

    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("reserved delimiter"),
        "error should mention the reserved-delimiter reason, got: {err}"
    );
}
/// With marker-free input, `compress_messages` still fails (the client points
/// at `http://127.0.0.1:1/v1`), but the error must not be the
/// "reserved delimiter" rejection.
#[tokio::test]
async fn compaction_passes_check_on_clean_input() {
    use rig::providers::openai;

    let inner = openai::CompletionsClient::builder()
        .api_key("test-key")
        .base_url("http://127.0.0.1:1/v1")
        .build()
        .expect("build custom client");
    let client = AnyClient::Custom(inner);

    let clean_message = sm(MessageRole::User, "ordinary message, no markers", vec![]);
    let history = vec![clean_message];

    let outcome = client
        .compress_messages("test-model", &history, None, None)
        .await;
    assert!(outcome.is_err(), "expected network/auth failure");

    let err = outcome.unwrap_err().to_string();
    assert!(
        !err.contains("reserved delimiter"),
        "clean input must NOT trip the delimiter check, got: {err}"
    );
}
/// Stub tool for tests: carries only a static name.
#[cfg(feature = "mcp")]
#[derive(Debug)]
struct NamedTool(&'static str);

/// `LoopTool` implementation whose only varying property is the name; the
/// description and label are the fixed string `"test"`, the parameter schema
/// is a bare `{"type": "object"}`, and `execute` immediately succeeds with a
/// default result.
#[cfg(feature = "mcp")]
impl crate::agent::agent_loop::LoopTool for NamedTool {
    fn name(&self) -> &str {
        let NamedTool(name) = self;
        name
    }

    fn description(&self) -> &str {
        "test"
    }

    fn label(&self) -> &str {
        "test"
    }

    fn parameters(&self) -> &serde_json::Value {
        use std::sync::OnceLock;

        // Built on first use and shared by every `NamedTool` instance.
        static SCHEMA: OnceLock<serde_json::Value> = OnceLock::new();
        SCHEMA.get_or_init(|| serde_json::json!({"type": "object"}))
    }

    fn execute<'a>(
        &'a self,
        _id: &'a str,
        _args: serde_json::Value,
        _signal: crate::agent::agent_loop::tool::AbortSignal,
        _on_update: crate::agent::agent_loop::tool::LoopToolUpdate,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<crate::agent::agent_loop::LoopToolResult, String>,
                > + Send
                + 'a,
        >,
    > {
        use crate::agent::agent_loop::LoopToolResult;

        // All arguments are ignored; the call always succeeds.
        Box::pin(async move { Ok(LoopToolResult::default()) })
    }
}
/// With dynamic tool search enabled, `extend_loop_tools` appends the tools to
/// `loop_tools` and records them in the search registry, but leaves the
/// loaded-tools filter set empty.
#[cfg(feature = "mcp")]
#[test]
fn extend_loop_tools_adds_injected_to_search_registry_not_loaded() {
    use crate::agent::tools::tool_search::ToolMeta;
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};

    let filter: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
    let registry: Arc<Mutex<Vec<ToolMeta>>> = Arc::new(Mutex::new(Vec::new()));
    let mut agent = build_openai_any_agent()
        .with_dynamic_tool_search(Arc::clone(&filter), Arc::clone(&registry));

    let injected: Vec<Arc<dyn crate::agent::agent_loop::LoopTool>> = vec![
        Arc::new(NamedTool("mcp_alpha")),
        Arc::new(NamedTool("mcp_beta")),
    ];
    agent.extend_loop_tools(injected);
    assert_eq!(agent.loop_tools.len(), 2);

    let reg = registry.lock().unwrap();
    let has_alpha = reg.iter().any(|m| m.name == "mcp_alpha");
    assert!(has_alpha, "reg missing alpha");
    let has_beta = reg.iter().any(|m| m.name == "mcp_beta");
    assert!(has_beta, "reg missing beta");

    let nothing_loaded = filter.lock().unwrap().is_empty();
    assert!(
        nothing_loaded,
        "injected tools must not be pre-loaded — discovered via tool_search"
    );
}
/// Without dynamic tool search, `extend_loop_tools` appends to `loop_tools`
/// and leaves both `tool_def_filter` and `tool_search_registry` unset.
#[cfg(feature = "mcp")]
#[test]
fn extend_loop_tools_without_dynamic_search_only_grows_registry() {
    use std::sync::Arc;

    let mut agent = build_openai_any_agent();
    let injected: Vec<Arc<dyn crate::agent::agent_loop::LoopTool>> =
        vec![Arc::new(NamedTool("x"))];

    agent.extend_loop_tools(injected);

    assert_eq!(agent.loop_tools.len(), 1);
    assert!(agent.tool_def_filter.is_none());
    assert!(agent.tool_search_registry.is_none());
}
/// Walks an injected tool through its lifecycle under dynamic tool search:
///
/// 1. after `extend_loop_tools` it has a rig definition,
/// 2. `filter_tool_defs` hides it while the filter set is empty,
/// 3. `rank_tools` can find it in the search registry,
/// 4. once its name is inserted into the filter set it passes the filter,
/// 5. it can be found by name in `agent.loop_tools`.
#[cfg(feature = "mcp")]
#[test]
fn injected_tool_is_gated_then_visible_then_dispatchable() {
    use crate::agent::agent_loop::loop_tool_to_rig_definition;
    use crate::agent::agent_loop::rig_stream_factory::filter_tool_defs;
    use crate::agent::tools::tool_search::{ToolMeta, rank_tools};
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};

    let filter: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
    let registry: Arc<Mutex<Vec<ToolMeta>>> = Arc::new(Mutex::new(Vec::new()));
    let mut agent =
        build_openai_any_agent().with_dynamic_tool_search(filter.clone(), registry.clone());
    agent.extend_loop_tools(vec![Arc::new(NamedTool("mcp_demo"))]);

    let defs: Vec<_> = agent
        .loop_tools
        .iter()
        .map(|t| loop_tool_to_rig_definition(t.as_ref()))
        .collect();
    assert!(
        defs.iter().any(|d| d.name == "mcp_demo"),
        "injected tool must be in the def list"
    );

    // Nothing has been loaded yet, so the filter must hide the tool.
    let before = filter_tool_defs(&defs, Some(&filter));
    assert!(
        !before.iter().any(|d| d.name == "mcp_demo"),
        "must be hidden until discovered via tool_search"
    );

    {
        // Scoped so the registry lock is released before the next step.
        let reg = registry.lock().unwrap();
        // The original text read `®` here: an HTML-entity corruption of
        // `&reg`, which is not valid Rust.
        let hits = rank_tools(&reg, "mcp_demo", 5);
        assert!(
            hits.iter().any(|m| m.name == "mcp_demo"),
            "tool_search must be able to discover the injected tool"
        );
    }

    // Simulate discovery by marking the tool as loaded.
    filter.lock().unwrap().insert("mcp_demo".to_string());
    let after = filter_tool_defs(&defs, Some(&filter));
    assert!(
        after.iter().any(|d| d.name == "mcp_demo"),
        "must ship in the request once discovered"
    );
    assert!(
        agent.loop_tools.iter().any(|t| t.name() == "mcp_demo"),
        "dispatch must find the tool by name"
    );
}