#[allow(unused_imports)]
use crate::sync_util::LockExt;
use std::sync::Arc;
use rig::OneOrMany;
#[cfg(test)]
use rig::completion::CompletionError;
use rig::completion::message::{AssistantContent, Message, Reasoning, ToolCall, ToolFunction};
use rig::completion::{CompletionModel, CompletionRequestBuilder, GetTokenUsage, ToolDefinition};
use serde_json::Value;
use super::message::StreamEvent;
use super::rig_stream::wrap_rig_stream;
use super::stream::{LlmContext, StreamFn};
use super::tool::LoopTool;
use futures::Stream;
use std::pin::Pin;
/// Test-only shorthand: builds a [`StreamFn`] for `model` with no provider
/// name and no tool-definition filter.
#[cfg(test)]
pub fn rig_stream_fn_from_model<M>(
    model: M,
    tools: Vec<ToolDefinition>,
    chunk_timeout: Option<std::time::Duration>,
) -> StreamFn
where
    M: CompletionModel + Clone + Send + Sync + 'static,
    M::StreamingResponse: Clone + Unpin + Send + Sync + GetTokenUsage + 'static,
{
    let unnamed_provider: Option<String> = None;
    rig_stream_fn_from_model_with_provider(model, tools, chunk_timeout, unnamed_provider)
}
/// Builds a [`StreamFn`] for `model` tagged with `provider_name`, advertising
/// every tool in `tools` (no tool-definition filter).
// The only caller visible in this module (`rig_stream_fn_from_model`) is
// `#[cfg(test)]`, so non-test builds may see this as unused.
#[allow(dead_code)]
pub fn rig_stream_fn_from_model_with_provider<M>(
    model: M,
    tools: Vec<ToolDefinition>,
    chunk_timeout: Option<std::time::Duration>,
    provider_name: Option<String>,
) -> StreamFn
where
    M: CompletionModel + Clone + Send + Sync + 'static,
    M::StreamingResponse: Clone + Unpin + Send + Sync + GetTokenUsage + 'static,
{
    let no_filter = None;
    rig_stream_fn_from_model_with_filter(model, tools, chunk_timeout, provider_name, no_filter)
}
/// Builds a [`StreamFn`] that streams completions from `model`.
///
/// - `tools`: every tool definition the model may be offered.
/// - `chunk_timeout`: forwarded to the stream wrapper on every call.
/// - `provider_name`: selects provider-specific request parameters
///   (see [`build_provider_additional_params`]).
/// - `tool_def_filter`: optional shared set of loaded tool names. When present,
///   only those tools plus the always-on ones are sent (see
///   [`filter_tool_defs`]). The set is consulted on each invocation, so names
///   added between turns show up on the next request.
pub fn rig_stream_fn_from_model_with_filter<M>(
    model: M,
    tools: Vec<ToolDefinition>,
    chunk_timeout: Option<std::time::Duration>,
    provider_name: Option<String>,
    tool_def_filter: Option<std::sync::Arc<std::sync::Mutex<std::collections::HashSet<String>>>>,
) -> StreamFn
where
    M: CompletionModel + Clone + Send + Sync + 'static,
    M::StreamingResponse: Clone + Unpin + Send + Sync + GetTokenUsage + 'static,
{
    // Shared state captured once by the factory; each call gets cheap Arc clones.
    let shared_tools = Arc::new(tools);
    let shared_provider = Arc::new(provider_name);
    let shared_filter = Arc::new(tool_def_filter);
    Arc::new(move |ctx: LlmContext, opts: super::stream::StreamOptions| {
        invoke_one_stream(
            model.clone(),
            Arc::clone(&shared_tools),
            ctx,
            chunk_timeout,
            opts,
            Arc::clone(&shared_provider),
            Arc::clone(&shared_filter),
        )
    })
}
/// Runs one streaming completion against `model` for a single agent-loop turn.
///
/// Steps, in order:
/// 1. Convert `ctx.messages` with [`value_to_rig_message`]. Messages that do
///    not convert are skipped silently.
/// 2. If nothing converted, yield a single [`StreamEvent::Error`] and stop.
///    Otherwise the last converted message is the prompt and the rest are
///    history.
/// 3. Narrow `tools` with [`filter_tool_defs`] and log the prompt-cache prefix
///    via [`emit_cache_prefix_event`].
/// 4. Build the request. The preamble is set only for a non-empty system
///    prompt, tools only when at least one survives filtering, plus any
///    [`build_provider_additional_params`] output.
/// 5. Forward every event produced by [`wrap_rig_stream`], which receives
///    `chunk_timeout` and `opts.signal`. If the request itself fails, yield a
///    single [`StreamEvent::Error`] instead.
fn invoke_one_stream<M>(
    model: M,
    tools: Arc<Vec<ToolDefinition>>,
    ctx: LlmContext,
    chunk_timeout: Option<std::time::Duration>,
    opts: super::stream::StreamOptions,
    provider_name: Arc<Option<String>>,
    tool_def_filter: Arc<
        Option<std::sync::Arc<std::sync::Mutex<std::collections::HashSet<String>>>>,
    >,
) -> Pin<Box<dyn Stream<Item = StreamEvent> + Send>>
where
    M: CompletionModel + Clone + Send + Sync + 'static,
    M::StreamingResponse: Clone + Unpin + Send + Sync + GetTokenUsage + 'static,
{
    Box::pin(async_stream::stream! {
        // Unconvertible messages (unknown roles, malformed shapes) are dropped here.
        let rig_messages: Vec<Message> = ctx
            .messages
            .iter()
            .filter_map(value_to_rig_message)
            .collect();
        // rig needs a prompt message; the final message plays that role and
        // everything before it becomes history.
        let (prompt, history) = if rig_messages.is_empty() {
            yield StreamEvent::Error {
                error: "rig_stream_fn: empty message list — no prompt to send".to_string(),
            };
            return;
        } else {
            let mut messages = rig_messages;
            // Non-empty was checked just above, so `pop` cannot return `None`.
            let last = messages.pop().unwrap();
            (last, messages)
        };
        // The builder consumes a model; the original is kept for `stream` below.
        let mut builder = CompletionRequestBuilder::new(model.clone(), prompt);
        let system_prompt = ctx.system_prompt;
        let history_len = history.len();
        // The filter is re-read on every call so newly loaded tools appear next turn.
        let outgoing_tools: Vec<ToolDefinition> =
            filter_tool_defs(&tools, tool_def_filter.as_ref().as_ref());
        let provider: Option<&str> = provider_name.as_ref().as_deref();
        emit_cache_prefix_event(
            provider,
            &system_prompt,
            &outgoing_tools,
            history_len,
        );
        if !system_prompt.is_empty() {
            builder = builder.preamble(system_prompt);
        }
        builder = builder.messages(history);
        if !outgoing_tools.is_empty() {
            builder = builder.tools(outgoing_tools);
        }
        let additional = build_provider_additional_params(provider, &opts);
        if let Some(v) = additional {
            builder = builder.additional_params(v);
        }
        let request = builder.build();
        match model.stream(request).await {
            Ok(response) => {
                // Chunk timeout and abort handling are delegated to the wrapper.
                let mut wrapped = wrap_rig_stream(response, chunk_timeout, Some(opts.signal.clone()));
                use futures::stream::StreamExt;
                while let Some(evt) = wrapped.next().await {
                    yield evt;
                }
            }
            Err(e) => {
                // A failure to start the stream surfaces as a single Error event.
                yield StreamEvent::Error {
                    error: format!("rig stream call failed: {e}"),
                };
            }
        }
    })
}
/// Logs a `prompt_cache_prefix` debug event (target `dirge::prompt_cache`)
/// that summarizes the cacheable request prefix.
///
/// The event carries a hash of `system_prompt`, a hash of the tool names
/// (sorted, so registration order does not affect it), the tool count, the
/// system prompt length in bytes, and `history_len`.
///
/// Only tool *names* feed `tools_hash`; changes to a tool's description or
/// parameter schema are not reflected. Both hashes use `DefaultHasher`, whose
/// algorithm the standard library leaves unspecified across releases, so
/// compare values only within a single build.
fn emit_cache_prefix_event(
    provider: Option<&str>,
    system_prompt: &str,
    tools: &[ToolDefinition],
    history_len: usize,
) {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let system_hash = {
        let mut hasher = DefaultHasher::new();
        system_prompt.hash(&mut hasher);
        hasher.finish()
    };

    let mut sorted_names: Vec<&str> = Vec::with_capacity(tools.len());
    sorted_names.extend(tools.iter().map(|def| def.name.as_str()));
    sorted_names.sort_unstable();

    let tools_hash = {
        let mut hasher = DefaultHasher::new();
        for name in sorted_names.iter() {
            name.hash(&mut hasher);
            // Separator byte between consecutive names.
            0u8.hash(&mut hasher);
        }
        hasher.finish()
    };

    tracing::debug!(
        target: "dirge::prompt_cache",
        provider = provider.unwrap_or("unknown"),
        system_hash = format!("{system_hash:016x}"),
        tools_hash = format!("{tools_hash:016x}"),
        tool_count = tools.len(),
        system_bytes = system_prompt.len(),
        history_len = history_len,
        "prompt_cache_prefix"
    );
}
/// Selects which tool definitions to advertise on the next request.
///
/// - `filter == None`: every definition in `tools` is returned.
/// - `filter == Some(set)`: only definitions whose name is in
///   `ALWAYS_ON_TOOLS` or in the shared `set` are returned, in their original
///   order. Names in `set` with no matching definition are ignored.
///
/// The mutex is locked through `lock_ignore_poison` (a project helper) for
/// the duration of the selection.
pub fn filter_tool_defs(
    tools: &[ToolDefinition],
    filter: Option<&std::sync::Arc<std::sync::Mutex<std::collections::HashSet<String>>>>,
) -> Vec<ToolDefinition> {
    let Some(loaded_names) = filter else {
        // No filter configured: advertise everything.
        return tools.to_vec();
    };
    let always_on: std::collections::HashSet<&str> =
        crate::agent::tools::tool_search::ALWAYS_ON_TOOLS
            .iter()
            .copied()
            .collect();
    let loaded = loaded_names.lock_ignore_poison();
    let mut selected = Vec::new();
    for def in tools {
        let name = def.name.as_str();
        if always_on.contains(name) || loaded.contains(name) {
            selected.push(def.clone());
        }
    }
    selected
}
pub fn value_to_rig_message(value: &Value) -> Option<Message> {
let role = value.get("role").and_then(|r| r.as_str())?;
match role {
"user" => {
let content = value.get("content").and_then(|c| c.as_str())?;
Some(Message::user(content))
}
"assistant" => {
let blocks = value.get("content").and_then(|c| c.as_array())?;
let assistant_contents: Vec<AssistantContent> = blocks
.iter()
.filter_map(value_to_assistant_content)
.collect();
let content = OneOrMany::many(assistant_contents).ok()?;
Some(Message::Assistant { id: None, content })
}
"tool" | "toolResult" => {
let tool_call_id = value
.get("toolCallId")
.or_else(|| value.get("tool_call_id"))
.and_then(|c| c.as_str())?;
let text = value
.get("content")
.and_then(|c| {
if let Some(s) = c.as_str() {
Some(s.to_string())
} else if let Some(blocks) = c.as_array() {
let joined = blocks
.iter()
.filter_map(|b| {
b.as_object().and_then(|o| {
if o.get("type").and_then(|t| t.as_str()) == Some("text") {
o.get("text").and_then(|t| t.as_str()).map(String::from)
} else {
None
}
})
})
.collect::<Vec<_>>()
.join("\n");
Some(joined)
} else {
None
}
})
.unwrap_or_default();
Some(Message::tool_result(tool_call_id, text))
}
_ => None,
}
}
fn value_to_assistant_content(block: &Value) -> Option<AssistantContent> {
let obj = block.as_object()?;
let kind = obj.get("type").and_then(|t| t.as_str())?;
match kind {
"text" => {
let text = obj.get("text").and_then(|t| t.as_str())?;
Some(AssistantContent::text(text))
}
"thinking" => {
let text = obj.get("text").and_then(|t| t.as_str())?;
Some(AssistantContent::Reasoning(Reasoning::new(text)))
}
"toolCall" => {
let id = obj.get("id").and_then(|t| t.as_str())?.to_string();
let name = obj.get("name").and_then(|t| t.as_str())?.to_string();
let arguments = obj.get("arguments").cloned().unwrap_or(Value::Null);
Some(AssistantContent::ToolCall(ToolCall {
id,
call_id: None,
function: ToolFunction { name, arguments },
signature: None,
additional_params: None,
}))
}
_ => None,
}
}
/// Converts a [`LoopTool`] into the rig [`ToolDefinition`] sent to the model.
///
/// Uses the tool's flattened parameter schema when it provides one, otherwise
/// its regular `parameters()` schema. The tool's display label is not part of
/// the definition.
pub fn loop_tool_to_rig_definition(tool: &dyn LoopTool) -> ToolDefinition {
    let parameters = match tool.flat_parameters() {
        Some(flat_schema) => flat_schema.clone(),
        None => tool.parameters().clone(),
    };
    let name = tool.name().to_string();
    let description = tool.description().to_string();
    ToolDefinition {
        name,
        description,
        parameters,
    }
}
/// Builds the provider-specific `additional_params` JSON object for a request.
///
/// When `opts.reasoning` is set, the level is translated per provider:
/// - `anthropic`: `{"thinking": {"type": "enabled", "budget_tokens": N}}`
/// - OpenAI-compatible (`openai`, `deepseek`, `glm`, `custom`, `openrouter`):
///   `{"reasoning": {"effort": "low" | "medium" | "high"}}`
/// - `gemini`: `{"thinking_config": {"thinking_budget": N}}`
/// - anything else, including `ollama` and no provider at all:
///   `{"reasoning_level": <serialized ThinkingLevel>}`
///
/// For the anthropic, OpenAI-compatible and gemini families,
/// `ThinkingLevel::Off` adds no reasoning key. Budgets come from
/// `opts.thinking_budgets`, with built-in defaults.
///
/// Non-empty `opts.headers` and `opts.metadata` are passed through under
/// `"headers"` and `"metadata"` for every provider.
///
/// Returns `None` when nothing was added, so the caller can omit
/// `additional_params` entirely.
pub fn build_provider_additional_params(
    provider_name: Option<&str>,
    opts: &super::stream::StreamOptions,
) -> Option<serde_json::Value> {
    let mut additional = serde_json::Map::new();
    if let Some(level) = opts.reasoning {
        match provider_name {
            Some("anthropic") => {
                let budget = budget_for_level(level, opts.thinking_budgets.as_ref());
                if budget > 0 {
                    additional.insert(
                        "thinking".to_string(),
                        serde_json::json!({
                            "type": "enabled",
                            "budget_tokens": budget,
                        }),
                    );
                }
            }
            Some("openai" | "deepseek" | "glm" | "custom" | "openrouter") => {
                if let Some(effort) = thinking_level_to_openai_effort(level) {
                    additional.insert(
                        "reasoning".to_string(),
                        serde_json::json!({ "effort": effort }),
                    );
                }
            }
            Some("gemini") => {
                let budget = budget_for_level(level, opts.thinking_budgets.as_ref());
                if budget > 0 {
                    additional.insert(
                        "thinking_config".to_string(),
                        serde_json::json!({ "thinking_budget": budget }),
                    );
                }
            }
            // `ollama`, unnamed and unrecognized providers all receive the
            // serialized level under the generic `reasoning_level` key.
            _ => {
                additional.insert(
                    "reasoning_level".to_string(),
                    serde_json::to_value(level).unwrap_or(serde_json::Value::Null),
                );
            }
        }
    }
    if !opts.headers.is_empty()
        && let Ok(v) = serde_json::to_value(&opts.headers)
    {
        additional.insert("headers".to_string(), v);
    }
    if !opts.metadata.is_empty() {
        additional.insert(
            "metadata".to_string(),
            serde_json::Value::Object(
                opts.metadata
                    .iter()
                    .map(|(k, v)| (k.clone(), v.clone()))
                    .collect(),
            ),
        );
    }
    (!additional.is_empty()).then(|| serde_json::Value::Object(additional))
}
/// Maps a [`ThinkingLevel`](super::types::ThinkingLevel) onto the three-step
/// `effort` scale used by OpenAI-style reasoning parameters.
///
/// `Off` maps to `None` (send nothing). `Minimal` is rounded up to `"low"` and
/// `Xhigh` is clamped down to `"high"`.
fn thinking_level_to_openai_effort(level: super::types::ThinkingLevel) -> Option<&'static str> {
    use super::types::ThinkingLevel;
    let effort = match level {
        ThinkingLevel::Off => return None,
        ThinkingLevel::Minimal | ThinkingLevel::Low => "low",
        ThinkingLevel::Medium => "medium",
        ThinkingLevel::High | ThinkingLevel::Xhigh => "high",
    };
    Some(effort)
}
/// Returns the thinking-token budget for `level`.
///
/// A caller-supplied value in `budgets` wins. Otherwise the built-in defaults
/// apply: `Minimal` 1024, `Low` 2048, `Medium` 4096, `High`/`Xhigh` 16384.
/// `Off` always yields 0, which callers treat as "disable thinking". `Xhigh`
/// shares the `high` override slot.
fn budget_for_level(
    level: super::types::ThinkingLevel,
    budgets: Option<&super::types::ThinkingBudgets>,
) -> u32 {
    use super::types::ThinkingLevel as Level;
    let (configured, fallback) = match level {
        Level::Off => return 0,
        Level::Minimal => (budgets.and_then(|b| b.minimal), 1024),
        Level::Low => (budgets.and_then(|b| b.low), 2048),
        Level::Medium => (budgets.and_then(|b| b.medium), 4096),
        Level::High | Level::Xhigh => (budgets.and_then(|b| b.high), 16384),
    };
    configured.unwrap_or(fallback)
}
/// Unit tests for message conversion, tool-definition mapping/filtering,
/// provider parameter construction, and the stream factory.
#[cfg(test)]
mod tests {
    use super::*;
    use rig::completion::message::UserContent;
    /// A string user message becomes a single text `UserContent`.
    #[test]
    fn user_value_converts_to_user_message() {
        let v = serde_json::json!({"role": "user", "content": "hello"});
        let msg = value_to_rig_message(&v).expect("must convert");
        match msg {
            Message::User { content } => {
                let first = content.first();
                match first {
                    UserContent::Text(t) => assert_eq!(t.text, "hello"),
                    _ => panic!("expected text"),
                }
            }
            _ => panic!("expected User"),
        }
    }
    /// An assistant `text` block converts to `AssistantContent::Text` with no id.
    #[test]
    fn assistant_text_block_converts() {
        let v = serde_json::json!({
            "role": "assistant",
            "content": [{"type": "text", "text": "hi there"}],
        });
        let msg = value_to_rig_message(&v).expect("must convert");
        match msg {
            Message::Assistant { id, content } => {
                assert!(id.is_none());
                match content.first() {
                    AssistantContent::Text(t) => assert_eq!(t.text, "hi there"),
                    _ => panic!("expected text"),
                }
            }
            _ => panic!("expected Assistant"),
        }
    }
    /// A `toolCall` block keeps its id, name and arguments.
    #[test]
    fn assistant_tool_call_block_converts() {
        let v = serde_json::json!({
            "role": "assistant",
            "content": [{
                "type": "toolCall",
                "id": "call_1",
                "name": "echo",
                "arguments": {"value": "x"},
            }],
        });
        let msg = value_to_rig_message(&v).expect("must convert");
        match msg {
            Message::Assistant { content, .. } => match content.first() {
                AssistantContent::ToolCall(tc) => {
                    assert_eq!(tc.id, "call_1");
                    assert_eq!(tc.function.name, "echo");
                    assert_eq!(tc.function.arguments["value"], "x");
                }
                _ => panic!("expected ToolCall"),
            },
            _ => panic!("expected Assistant"),
        }
    }
    /// A `thinking` block becomes `AssistantContent::Reasoning`.
    #[test]
    fn assistant_thinking_block_converts_to_reasoning() {
        let v = serde_json::json!({
            "role": "assistant",
            "content": [{"type": "thinking", "text": "let me think"}],
        });
        let msg = value_to_rig_message(&v).expect("must convert");
        match msg {
            Message::Assistant { content, .. } => match content.first() {
                AssistantContent::Reasoning(_) => {}
                _ => panic!("expected Reasoning"),
            },
            _ => panic!("expected Assistant"),
        }
    }
    /// A camelCase `toolResult` with block content becomes a user-side tool result.
    #[test]
    fn tool_result_value_converts() {
        let v = serde_json::json!({
            "role": "toolResult",
            "toolCallId": "call_1",
            "toolName": "echo",
            "content": [
                {"type": "text", "text": "line 1"},
                {"type": "text", "text": "line 2"},
            ],
            "details": {},
            "isError": false,
        });
        let msg = value_to_rig_message(&v).expect("must convert");
        match msg {
            Message::User { content } => match content.first() {
                UserContent::ToolResult(tr) => {
                    assert_eq!(tr.id, "call_1");
                }
                _ => panic!("expected ToolResult"),
            },
            _ => panic!("expected User"),
        }
    }
    /// The snake_case `tool` / `tool_call_id` shape with string content also converts.
    #[test]
    fn tool_role_snake_case_converts() {
        let v = serde_json::json!({
            "role": "tool",
            "tool_call_id": "call_abc",
            "content": "tool output text",
        });
        let msg = value_to_rig_message(&v).expect("must convert");
        match msg {
            Message::User { content } => match content.first() {
                UserContent::ToolResult(tr) => {
                    assert_eq!(tr.id, "call_abc");
                }
                other => panic!("expected ToolResult, got {other:?}"),
            },
            other => panic!("expected User, got {other:?}"),
        }
    }
    /// Unrecognized roles are skipped (`None`).
    #[test]
    fn custom_role_returns_none() {
        let v = serde_json::json!({"role": "custom", "content": "x"});
        assert!(value_to_rig_message(&v).is_none());
    }
    /// Messages without a `role` are skipped (`None`).
    #[test]
    fn missing_role_returns_none() {
        let v = serde_json::json!({"content": "x"});
        assert!(value_to_rig_message(&v).is_none());
    }
    /// The display label is not part of the rig `ToolDefinition`.
    #[test]
    fn loop_tool_definition_strips_label() {
        #[derive(Debug)]
        struct Stub;
        impl LoopTool for Stub {
            fn name(&self) -> &str {
                "stub"
            }
            fn description(&self) -> &str {
                "stub description"
            }
            fn label(&self) -> &str {
                "Stub Label"
            }
            fn parameters(&self) -> &Value {
                static P: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
                P.get_or_init(|| serde_json::json!({"type": "object"}))
            }
            fn execute<'a>(
                &'a self,
                _id: &'a str,
                _args: Value,
                _signal: AbortSignal,
                _on_update: super::super::tool::LoopToolUpdate,
            ) -> Pin<
                Box<
                    dyn Future<Output = Result<super::super::result::LoopToolResult, String>>
                        + Send
                        + 'a,
                >,
            > {
                Box::pin(async { unreachable!("not called in conversion test") })
            }
        }
        let def = loop_tool_to_rig_definition(&Stub);
        assert_eq!(def.name, "stub");
        assert_eq!(def.description, "stub description");
        assert_eq!(def.parameters["type"], "object");
    }
    /// Compile-time check: `NopModel` satisfies the factory's trait bounds.
    #[test]
    fn stream_fn_is_send_sync_static() {
        fn assert_constraints<M>(_model: M)
        where
            M: CompletionModel + Clone + Send + Sync + 'static,
            M::StreamingResponse: Clone + Unpin + Send + Sync + GetTokenUsage + 'static,
        {
        }
        let _: fn(_) = assert_constraints::<NopModel>;
    }
    /// Minimal model whose `stream` yields an empty rig stream.
    #[derive(Clone)]
    struct NopModel;
    impl GetTokenUsage for NopStreamResponse {
        fn token_usage(&self) -> Option<rig::completion::Usage> {
            None
        }
    }
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, PartialEq)]
    struct NopStreamResponse;
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, PartialEq)]
    struct NopResponse;
    impl CompletionModel for NopModel {
        type Response = NopResponse;
        type StreamingResponse = NopStreamResponse;
        type Client = ();
        fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
            NopModel
        }
        async fn completion(
            &self,
            _request: rig::completion::CompletionRequest,
        ) -> Result<rig::completion::CompletionResponse<Self::Response>, CompletionError> {
            unreachable!("completion() not used in stream factory tests")
        }
        async fn stream(
            &self,
            _request: rig::completion::CompletionRequest,
        ) -> Result<
            rig::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
            CompletionError,
        > {
            let inner: rig::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(futures::stream::empty());
            Ok(rig::streaming::StreamingCompletionResponse::stream(inner))
        }
    }
    /// Even an empty provider stream is framed by start and done events.
    #[tokio::test]
    async fn factory_invocation_produces_start_and_done() {
        use futures::stream::StreamExt;
        let factory = rig_stream_fn_from_model::<NopModel>(NopModel, vec![], None);
        let ctx = LlmContext {
            system_prompt: "test preamble".to_string(),
            messages: vec![serde_json::json!({"role": "user", "content": "hi"})],
        };
        let mut stream = factory(
            ctx,
            crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new()),
        );
        let mut kinds = Vec::new();
        while let Some(evt) = stream.next().await {
            kinds.push(match &evt {
                StreamEvent::Start { .. } => "start",
                StreamEvent::Delta { .. } => "delta",
                StreamEvent::Done { .. } => "done",
                StreamEvent::Error { error } => {
                    panic!("unexpected error: {error}");
                }
                StreamEvent::Retry { .. } => {
                    panic!("unexpected retry event in non-retried stream");
                }
            });
        }
        assert!(kinds.contains(&"start"));
        assert!(kinds.contains(&"done"));
    }
    /// With no convertible messages the factory yields an Error event instead of calling the model.
    #[tokio::test]
    async fn factory_empty_messages_emits_error() {
        use futures::stream::StreamExt;
        let factory = rig_stream_fn_from_model::<NopModel>(NopModel, vec![], None);
        let ctx = LlmContext {
            system_prompt: String::new(),
            messages: Vec::new(),
        };
        let mut stream = factory(
            ctx,
            crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new()),
        );
        let mut found_error = false;
        while let Some(evt) = stream.next().await {
            if matches!(evt, StreamEvent::Error { .. }) {
                found_error = true;
            }
        }
        assert!(found_error, "empty messages must produce an Error event");
    }
    use crate::agent::agent_loop::stream::StreamOptions;
    use crate::agent::agent_loop::tool::AbortSignal;
    use crate::agent::agent_loop::types::{ThinkingBudgets, ThinkingLevel};
    /// Default stream options with only `reasoning` set.
    fn opts_with_reasoning(level: ThinkingLevel) -> StreamOptions {
        let mut o = StreamOptions::from_signal(AbortSignal::new());
        o.reasoning = Some(level);
        o
    }
    #[test]
    fn anthropic_reasoning_maps_to_thinking_budget() {
        let opts = opts_with_reasoning(ThinkingLevel::Medium);
        let v = build_provider_additional_params(Some("anthropic"), &opts).unwrap();
        assert_eq!(v["thinking"]["type"], "enabled");
        assert_eq!(v["thinking"]["budget_tokens"], 4096);
    }
    #[test]
    fn anthropic_off_omits_thinking_key() {
        let opts = opts_with_reasoning(ThinkingLevel::Off);
        let v = build_provider_additional_params(Some("anthropic"), &opts);
        assert!(v.is_none(), "Off should produce empty additional_params");
    }
    #[test]
    fn anthropic_respects_caller_budget_override() {
        let mut opts = opts_with_reasoning(ThinkingLevel::High);
        opts.thinking_budgets = Some(ThinkingBudgets {
            high: Some(32_000),
            ..Default::default()
        });
        let v = build_provider_additional_params(Some("anthropic"), &opts).unwrap();
        assert_eq!(v["thinking"]["budget_tokens"], 32_000);
    }
    #[test]
    fn openai_reasoning_maps_to_effort() {
        for (level, expected) in [
            (ThinkingLevel::Low, "low"),
            (ThinkingLevel::Medium, "medium"),
            (ThinkingLevel::High, "high"),
        ] {
            let opts = opts_with_reasoning(level);
            let v = build_provider_additional_params(Some("openai"), &opts).unwrap();
            assert_eq!(
                v["reasoning"]["effort"], expected,
                "level {level:?} should map to {expected}"
            );
        }
    }
    #[test]
    fn openai_compat_providers_share_effort_shape() {
        let opts = opts_with_reasoning(ThinkingLevel::Medium);
        for provider in ["deepseek", "glm", "custom", "openrouter"] {
            let v = build_provider_additional_params(Some(provider), &opts).unwrap();
            assert_eq!(
                v["reasoning"]["effort"], "medium",
                "provider {provider} should use effort=medium"
            );
        }
    }
    /// Minimal and Xhigh have no OpenAI equivalent and are clamped to low/high.
    #[test]
    fn openai_clamps_unsupported_levels() {
        let opts_min = opts_with_reasoning(ThinkingLevel::Minimal);
        let v = build_provider_additional_params(Some("openai"), &opts_min).unwrap();
        assert_eq!(v["reasoning"]["effort"], "low");
        let opts_x = opts_with_reasoning(ThinkingLevel::Xhigh);
        let v = build_provider_additional_params(Some("openai"), &opts_x).unwrap();
        assert_eq!(v["reasoning"]["effort"], "high");
    }
    #[test]
    fn openai_off_omits_reasoning_key() {
        let opts = opts_with_reasoning(ThinkingLevel::Off);
        let v = build_provider_additional_params(Some("openai"), &opts);
        assert!(v.is_none());
    }
    #[test]
    fn gemini_reasoning_maps_to_thinking_config() {
        let opts = opts_with_reasoning(ThinkingLevel::High);
        let v = build_provider_additional_params(Some("gemini"), &opts).unwrap();
        assert_eq!(v["thinking_config"]["thinking_budget"], 16384);
    }
    #[test]
    fn headers_and_metadata_pass_through_for_all_providers() {
        let mut opts = StreamOptions::from_signal(AbortSignal::new());
        opts.headers
            .insert("X-Tenant".to_string(), "acme".to_string());
        opts.metadata
            .insert("user_id".to_string(), serde_json::json!("u-42"));
        for provider in ["anthropic", "openai", "gemini", "ollama", "unknown"] {
            let v = build_provider_additional_params(Some(provider), &opts).unwrap();
            assert_eq!(v["headers"]["X-Tenant"], "acme", "provider {provider}");
            assert_eq!(v["metadata"]["user_id"], "u-42", "provider {provider}");
        }
    }
    #[test]
    fn empty_options_produces_none() {
        let opts = StreamOptions::from_signal(AbortSignal::new());
        assert!(build_provider_additional_params(Some("anthropic"), &opts).is_none());
        assert!(build_provider_additional_params(None, &opts).is_none());
    }
    #[test]
    fn unknown_provider_uses_generic_key() {
        let opts = opts_with_reasoning(ThinkingLevel::High);
        let v = build_provider_additional_params(Some("future-provider"), &opts).unwrap();
        assert!(v.get("reasoning_level").is_some());
        assert!(v.get("reasoning").is_none());
        assert!(v.get("thinking").is_none());
    }
    /// Tool definition with an empty schema, for filter tests.
    fn mk_def(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("desc for {name}"),
            parameters: serde_json::json!({}),
        }
    }
    #[test]
    fn tool_search_filter_none_passes_all_tools() {
        let defs = vec![mk_def("read"), mk_def("write"), mk_def("custom_mcp")];
        let out = filter_tool_defs(&defs, None);
        assert_eq!(out.len(), 3);
        let names: Vec<&str> = out.iter().map(|d| d.name.as_str()).collect();
        assert_eq!(names, vec!["read", "write", "custom_mcp"]);
    }
    /// Relies on `ALWAYS_ON_TOOLS` containing tool_search, write_todo_list and task_status.
    #[test]
    fn tool_search_filter_empty_set_keeps_only_always_on() {
        let defs = vec![
            mk_def("read"),
            mk_def("write"),
            mk_def("tool_search"),
            mk_def("write_todo_list"),
            mk_def("task_status"),
            mk_def("custom_mcp"),
        ];
        let filter = std::sync::Arc::new(std::sync::Mutex::new(
            std::collections::HashSet::<String>::new(),
        ));
        let out = filter_tool_defs(&defs, Some(&filter));
        let names: std::collections::HashSet<String> = out.iter().map(|d| d.name.clone()).collect();
        assert!(names.contains("tool_search"));
        assert!(names.contains("write_todo_list"));
        assert!(names.contains("task_status"));
        assert!(!names.contains("read"), "read must be filtered out");
        assert!(!names.contains("write"));
        assert!(!names.contains("custom_mcp"));
    }
    #[test]
    fn tool_search_filter_only_tool_search_suppresses_others() {
        let defs = vec![
            mk_def("read"),
            mk_def("write"),
            mk_def("tool_search"),
            mk_def("custom_mcp"),
        ];
        let mut set = std::collections::HashSet::new();
        set.insert("tool_search".to_string());
        let filter = std::sync::Arc::new(std::sync::Mutex::new(set));
        let out = filter_tool_defs(&defs, Some(&filter));
        let names: Vec<&str> = out.iter().map(|d| d.name.as_str()).collect();
        assert_eq!(names, vec!["tool_search"]);
    }
    /// Inserting a name into the shared set makes that tool visible on the next filtering call.
    #[test]
    fn tool_search_filter_loaded_tool_surfaces_on_next_turn() {
        let defs = vec![
            mk_def("read"),
            mk_def("write"),
            mk_def("tool_search"),
            mk_def("custom_mcp"),
        ];
        let filter = std::sync::Arc::new(std::sync::Mutex::new(
            std::collections::HashSet::<String>::new(),
        ));
        let out1 = filter_tool_defs(&defs, Some(&filter));
        assert!(!out1.iter().any(|d| d.name == "read"));
        filter.lock().unwrap().insert("read".to_string());
        let out2 = filter_tool_defs(&defs, Some(&filter));
        assert!(out2.iter().any(|d| d.name == "read"));
    }
    #[test]
    fn tool_search_filter_ignores_unknown_names_in_set() {
        let defs = vec![mk_def("read"), mk_def("tool_search")];
        let mut set = std::collections::HashSet::new();
        set.insert("read".to_string());
        set.insert("phantom_tool".to_string());
        let filter = std::sync::Arc::new(std::sync::Mutex::new(set));
        let out = filter_tool_defs(&defs, Some(&filter));
        let names: std::collections::HashSet<String> = out.iter().map(|d| d.name.clone()).collect();
        assert!(names.contains("read"));
        assert!(names.contains("tool_search"));
        assert!(!names.contains("phantom_tool"));
        assert_eq!(out.len(), 2);
    }
    /// Smoke test only.
    // NOTE(review): this checks that the calls (including a permuted tool
    // list) do not panic. It does not capture the emitted hashes, so
    // determinism itself is not asserted.
    #[test]
    fn emit_cache_prefix_event_is_deterministic() {
        let defs = vec![mk_def("write"), mk_def("read")];
        emit_cache_prefix_event(Some("anthropic"), "preamble-x", &defs, 3);
        emit_cache_prefix_event(Some("anthropic"), "preamble-x", &defs, 3);
        let permuted = vec![mk_def("read"), mk_def("write")];
        emit_cache_prefix_event(Some("anthropic"), "preamble-x", &permuted, 3);
    }
}