codetether_agent/cognition/
tool_router.rs1use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
14use anyhow::{Result, anyhow};
15use std::sync::{Arc, Mutex};
16use uuid::Uuid;
17
18use super::thinker::{CandleThinker, ThinkerBackend, ThinkerConfig};
19
20#[derive(Debug, Clone)]
24pub struct ToolRouterConfig {
25 pub enabled: bool,
27 pub model_path: Option<String>,
29 pub tokenizer_path: Option<String>,
31 pub arch: String,
33 pub device: super::thinker::CandleDevicePreference,
35 pub max_tokens: usize,
38 pub temperature: f32,
40}
41
42impl Default for ToolRouterConfig {
43 fn default() -> Self {
44 Self {
45 enabled: false,
46 model_path: None,
47 tokenizer_path: None,
48 arch: "gemma3".to_string(),
49 device: super::thinker::CandleDevicePreference::Auto,
50 max_tokens: 128,
51 temperature: 0.1,
52 }
53 }
54}
55
56impl ToolRouterConfig {
57 pub fn from_env() -> Self {
69 let enabled = std::env::var("CODETETHER_TOOL_ROUTER_ENABLED")
70 .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
71 .unwrap_or(false);
72
73 Self {
74 enabled,
75 model_path: std::env::var("CODETETHER_TOOL_ROUTER_MODEL_PATH").ok(),
76 tokenizer_path: std::env::var("CODETETHER_TOOL_ROUTER_TOKENIZER_PATH").ok(),
77 arch: std::env::var("CODETETHER_TOOL_ROUTER_ARCH")
78 .unwrap_or_else(|_| "gemma3".to_string()),
79 device: std::env::var("CODETETHER_TOOL_ROUTER_DEVICE")
80 .map(|v| super::thinker::CandleDevicePreference::from_env(&v))
81 .unwrap_or(super::thinker::CandleDevicePreference::Auto),
82 max_tokens: std::env::var("CODETETHER_TOOL_ROUTER_MAX_TOKENS")
83 .ok()
84 .and_then(|v| v.parse().ok())
85 .unwrap_or(128),
86 temperature: std::env::var("CODETETHER_TOOL_ROUTER_TEMPERATURE")
87 .ok()
88 .and_then(|v| v.parse().ok())
89 .unwrap_or(0.1),
90 }
91 }
92}
93
94fn build_functiongemma_prompt(assistant_text: &str, tools: &[ToolDefinition]) -> String {
101 let tool_defs: Vec<serde_json::Value> = tools
103 .iter()
104 .map(|t| {
105 serde_json::json!({
106 "name": t.name,
107 "description": t.description,
108 "parameters": t.parameters,
109 })
110 })
111 .collect();
112
113 let tools_json = serde_json::to_string_pretty(&tool_defs).unwrap_or_else(|_| "[]".to_string());
114
115 format!(
124 "<start_of_turn>system\n\
125 You are a function calling AI model. You are provided with function \
126 signatures within <tools></tools> XML tags. You may call one or more \
127 functions to assist with the user query. Don't make assumptions about \
128 what values to plug into functions.\n\n\
129 <tools>\n{tools_json}\n</tools>\n\n\
130 For each function call return a JSON object with function name and \
131 arguments within <tool_call></tool_call> XML tags as follows:\n\
132 <tool_call>\n{{\"name\": \"function_name\", \"arguments\": {{\"arg1\": \"value1\"}}}}\n</tool_call>\n\
133 <end_of_turn>\n\
134 <start_of_turn>user\n\
135 {assistant_text}\n\
136 <end_of_turn>\n\
137 <start_of_turn>model\n"
138 )
139}
140
/// One tool invocation extracted from FunctionGemma output.
#[derive(Debug, Clone)]
struct ParsedToolCall {
    /// Name of the function to call; the parser never emits an empty name.
    name: String,
    /// The call's `arguments` value re-serialised as a JSON string
    /// (`"{}"` when the model supplied none).
    arguments: String,
}
149
150fn parse_functiongemma_response(text: &str) -> Vec<ParsedToolCall> {
161 let mut calls = Vec::new();
162
163 let mut remaining = text;
165 while let Some(start) = remaining.find("<tool_call>") {
166 remaining = &remaining[start + "<tool_call>".len()..];
167 if let Some(end) = remaining.find("</tool_call>") {
168 let block = remaining[..end].trim();
169 remaining = &remaining[end + "</tool_call>".len()..];
170
171 if let Ok(value) = serde_json::from_str::<serde_json::Value>(block) {
173 let name = value
174 .get("name")
175 .and_then(|n| n.as_str())
176 .unwrap_or("")
177 .to_string();
178 let arguments = value
179 .get("arguments")
180 .map(|a| serde_json::to_string(a).unwrap_or_else(|_| "{}".to_string()))
181 .unwrap_or_else(|| "{}".to_string());
182
183 if !name.is_empty() {
184 calls.push(ParsedToolCall { name, arguments });
185 }
186 } else {
187 tracing::warn!(
188 block = %block,
189 "FunctionGemma produced unparseable tool_call block"
190 );
191 }
192 } else {
193 break; }
195 }
196
197 calls
198}
199
200pub struct ToolCallRouter {
206 runtime: Arc<Mutex<CandleThinker>>,
207}
208
209impl std::fmt::Debug for ToolCallRouter {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("ToolCallRouter").finish()
212 }
213}
214
215impl ToolCallRouter {
216 pub fn from_config(config: &ToolRouterConfig) -> Result<Option<Self>> {
220 if !config.enabled {
221 tracing::debug!("FunctionGemma tool router is disabled");
222 return Ok(None);
223 }
224
225 let model_path = config.model_path.as_ref().ok_or_else(|| {
226 anyhow!("CODETETHER_TOOL_ROUTER_MODEL_PATH is required when the tool router is enabled")
227 })?;
228 let tokenizer_path = config.tokenizer_path.as_ref().ok_or_else(|| {
229 anyhow!(
230 "CODETETHER_TOOL_ROUTER_TOKENIZER_PATH is required when the tool router is enabled"
231 )
232 })?;
233
234 let thinker_config = ThinkerConfig {
236 enabled: true,
237 backend: ThinkerBackend::Candle,
238 candle_model_path: Some(model_path.clone()),
239 candle_tokenizer_path: Some(tokenizer_path.clone()),
240 candle_arch: Some(config.arch.clone()),
241 candle_device: config.device,
242 max_tokens: config.max_tokens,
243 temperature: config.temperature,
244 ..ThinkerConfig::default()
245 };
246
247 let runtime = CandleThinker::new(&thinker_config)?;
248 tracing::info!(
249 model_path = %model_path,
250 arch = %config.arch,
251 "FunctionGemma tool-call router initialised"
252 );
253
254 Ok(Some(Self {
255 runtime: Arc::new(Mutex::new(runtime)),
256 }))
257 }
258
259 pub async fn maybe_reformat(
272 &self,
273 response: CompletionResponse,
274 tools: &[ToolDefinition],
275 model_supports_tools: bool,
276 ) -> CompletionResponse {
277 if model_supports_tools {
281 tracing::trace!("Skipping FunctionGemma: model supports native tool calling");
282 return response;
283 }
284
285 let has_tool_calls = response
287 .message
288 .content
289 .iter()
290 .any(|p| matches!(p, ContentPart::ToolCall { .. }));
291
292 if has_tool_calls {
293 return response;
294 }
295
296 if tools.is_empty() {
298 return response;
299 }
300
301 let assistant_text: String = response
303 .message
304 .content
305 .iter()
306 .filter_map(|p| match p {
307 ContentPart::Text { text } => Some(text.as_str()),
308 _ => None,
309 })
310 .collect::<Vec<_>>()
311 .join("\n");
312
313 if assistant_text.trim().is_empty() {
314 return response;
315 }
316
317 let text_lower = assistant_text.to_lowercase();
321 let mentions_tool = tools
322 .iter()
323 .any(|t| text_lower.contains(&t.name.to_lowercase()));
324 if !mentions_tool {
325 tracing::trace!("Skipping FunctionGemma: assistant text mentions no tool names");
326 return response;
327 }
328
329 match self.run_functiongemma(&assistant_text, tools).await {
331 Ok(parsed) if !parsed.is_empty() => {
332 tracing::info!(
333 num_calls = parsed.len(),
334 "FunctionGemma router produced tool calls from text-only response"
335 );
336 self.rewrite_response(response, parsed)
337 }
338 Ok(_) => {
339 response
341 }
342 Err(e) => {
343 tracing::warn!(
344 error = %e,
345 "FunctionGemma router failed; returning original response"
346 );
347 response
348 }
349 }
350 }
351
352 async fn run_functiongemma(
354 &self,
355 assistant_text: &str,
356 tools: &[ToolDefinition],
357 ) -> Result<Vec<ParsedToolCall>> {
358 let prompt = build_functiongemma_prompt(assistant_text, tools);
359 let runtime = Arc::clone(&self.runtime);
360
361 let output = tokio::task::spawn_blocking(move || {
362 let mut guard = runtime
363 .lock()
364 .map_err(|_| anyhow!("FunctionGemma mutex poisoned"))?;
365 guard.think("", &prompt)
370 })
371 .await
372 .map_err(|e| anyhow!("FunctionGemma task join failed: {e}"))??;
373
374 Ok(parse_functiongemma_response(&output.text))
375 }
376
377 fn rewrite_response(
384 &self,
385 mut response: CompletionResponse,
386 calls: Vec<ParsedToolCall>,
387 ) -> CompletionResponse {
388 response
391 .message
392 .content
393 .retain(|p| !matches!(p, ContentPart::Text { .. }));
394
395 for call in calls {
396 response.message.content.push(ContentPart::ToolCall {
397 id: format!("fc_{}", Uuid::new_v4()),
398 name: call.name,
399 arguments: call.arguments,
400 });
401 }
402
403 response.finish_reason = FinishReason::ToolCalls;
405 response
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a payload in `<tool_call>` tags, with the tags on their own lines.
    fn tagged(payload: &str) -> String {
        format!("<tool_call>\n{payload}\n</tool_call>")
    }

    /// Collects the names of the parsed calls, in order.
    fn names(calls: &[ParsedToolCall]) -> Vec<&str> {
        calls.iter().map(|call| call.name.as_str()).collect()
    }

    /// A single well-formed block yields one call with its arguments intact.
    #[test]
    fn parse_single_tool_call() {
        let text = tagged(r#"{"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}"#);

        let calls = parse_functiongemma_response(&text);

        assert_eq!(names(&calls), ["read_file"]);
        assert!(calls[0].arguments.contains("/tmp/foo.rs"));
    }

    /// Several blocks are returned in order; surrounding prose is ignored.
    #[test]
    fn parse_multiple_tool_calls() {
        let text = [
            "I'll read both files.".to_string(),
            tagged(r#"{"name": "read_file", "arguments": {"path": "a.rs"}}"#),
            tagged(r#"{"name": "read_file", "arguments": {"path": "b.rs"}}"#),
        ]
        .join("\n");

        let calls = parse_functiongemma_response(&text);

        assert_eq!(names(&calls), ["read_file", "read_file"]);
    }

    /// Plain text without any tags produces no calls.
    #[test]
    fn parse_no_tool_calls() {
        let calls = parse_functiongemma_response("I cannot help with that request.");

        assert!(calls.is_empty());
    }

    /// An unparseable block is skipped without affecting later blocks.
    #[test]
    fn parse_malformed_json_skipped() {
        let text = [
            tagged("not valid json"),
            tagged(r#"{"name": "list_dir", "arguments": {"path": "."}}"#),
        ]
        .join("\n");

        let calls = parse_functiongemma_response(&text);

        assert_eq!(names(&calls), ["list_dir"]);
    }

    /// A block with an empty function name is dropped.
    #[test]
    fn parse_empty_name_skipped() {
        let text = tagged(r#"{"name": "", "arguments": {}}"#);

        let calls = parse_functiongemma_response(&text);

        assert!(calls.is_empty());
    }

    /// The prompt carries the turn markers, the tool list and the user text.
    #[test]
    fn prompt_contains_tool_definitions() {
        let read_file = ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        };

        let prompt = build_functiongemma_prompt("Please read foo.rs", &[read_file]);

        for expected in [
            "<start_of_turn>system",
            "read_file",
            "<tools>",
            "Please read foo.rs",
            "<start_of_turn>model",
        ] {
            assert!(prompt.contains(expected));
        }
    }

    /// The default configuration is disabled and uses the documented defaults.
    #[test]
    fn config_defaults_disabled() {
        let config = ToolRouterConfig::default();

        assert!(!config.enabled);
        assert_eq!(config.arch, "gemma3");
        assert_eq!(config.max_tokens, 128);
    }
}