codetether_agent/cognition/
tool_router.rs1use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
14use anyhow::{Result, anyhow};
15use std::sync::{Arc, Mutex};
16use uuid::Uuid;
17
18use super::thinker::{CandleThinker, ThinkerBackend, ThinkerConfig};
19
/// Configuration for the FunctionGemma tool-call router.
///
/// The `Default` value leaves the router disabled; `ToolRouterConfig::from_env`
/// builds one from `CODETETHER_TOOL_ROUTER_*` environment variables.
#[derive(Debug, Clone)]
pub struct ToolRouterConfig {
    /// Whether to build the router at all. `ToolCallRouter::from_config`
    /// returns `Ok(None)` when this is `false`.
    pub enabled: bool,
    /// Model path forwarded as `ThinkerConfig::candle_model_path`; required
    /// when `enabled` is `true`.
    pub model_path: Option<String>,
    /// Tokenizer path forwarded as `ThinkerConfig::candle_tokenizer_path`;
    /// required when `enabled` is `true`.
    pub tokenizer_path: Option<String>,
    /// Architecture name forwarded as `ThinkerConfig::candle_arch`
    /// (default `"gemma3"`).
    pub arch: String,
    /// Device preference forwarded as `ThinkerConfig::candle_device`
    /// (default `Auto`).
    pub device: super::thinker::CandleDevicePreference,
    /// Generation limit forwarded as `ThinkerConfig::max_tokens` (default 128).
    pub max_tokens: usize,
    /// Sampling temperature forwarded as `ThinkerConfig::temperature`
    /// (default 0.1).
    pub temperature: f32,
}
41
42impl Default for ToolRouterConfig {
43 fn default() -> Self {
44 Self {
45 enabled: false,
46 model_path: None,
47 tokenizer_path: None,
48 arch: "gemma3".to_string(),
49 device: super::thinker::CandleDevicePreference::Auto,
50 max_tokens: 128,
51 temperature: 0.1,
52 }
53 }
54}
55
56impl ToolRouterConfig {
57 pub fn from_env() -> Self {
70 let enabled_requested = std::env::var("CODETETHER_TOOL_ROUTER_ENABLED")
71 .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
72 .unwrap_or(false);
73
74 let disabled = std::env::var("CODETETHER_FUNCTIONGEMMA_DISABLED")
77 .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
78 .unwrap_or(true);
79
80 let enabled = enabled_requested && !disabled;
81
82 Self {
83 enabled,
84 model_path: std::env::var("CODETETHER_TOOL_ROUTER_MODEL_PATH").ok(),
85 tokenizer_path: std::env::var("CODETETHER_TOOL_ROUTER_TOKENIZER_PATH").ok(),
86 arch: std::env::var("CODETETHER_TOOL_ROUTER_ARCH")
87 .unwrap_or_else(|_| "gemma3".to_string()),
88 device: std::env::var("CODETETHER_TOOL_ROUTER_DEVICE")
89 .map(|v| super::thinker::CandleDevicePreference::from_env(&v))
90 .unwrap_or(super::thinker::CandleDevicePreference::Auto),
91 max_tokens: std::env::var("CODETETHER_TOOL_ROUTER_MAX_TOKENS")
92 .ok()
93 .and_then(|v| v.parse().ok())
94 .unwrap_or(128),
95 temperature: std::env::var("CODETETHER_TOOL_ROUTER_TEMPERATURE")
96 .ok()
97 .and_then(|v| v.parse().ok())
98 .unwrap_or(0.1),
99 }
100 }
101}
102
103fn build_functiongemma_prompt(assistant_text: &str, tools: &[ToolDefinition]) -> String {
110 let tool_defs: Vec<serde_json::Value> = tools
112 .iter()
113 .map(|t| {
114 serde_json::json!({
115 "name": t.name,
116 "description": t.description,
117 "parameters": t.parameters,
118 })
119 })
120 .collect();
121
122 let tools_json = serde_json::to_string_pretty(&tool_defs).unwrap_or_else(|_| "[]".to_string());
123
124 format!(
133 "<start_of_turn>system\n\
134 You are a function calling AI model. You are provided with function \
135 signatures within <tools></tools> XML tags. You may call one or more \
136 functions to assist with the user query. Don't make assumptions about \
137 what values to plug into functions.\n\n\
138 <tools>\n{tools_json}\n</tools>\n\n\
139 For each function call return a JSON object with function name and \
140 arguments within <tool_call></tool_call> XML tags as follows:\n\
141 <tool_call>\n{{\"name\": \"function_name\", \"arguments\": {{\"arg1\": \"value1\"}}}}\n</tool_call>\n\
142 <end_of_turn>\n\
143 <start_of_turn>user\n\
144 {assistant_text}\n\
145 <end_of_turn>\n\
146 <start_of_turn>model\n"
147 )
148}
149
/// One tool invocation extracted from a FunctionGemma `<tool_call>` block.
#[derive(Debug, Clone)]
struct ParsedToolCall {
    /// Tool name from the block's `name` field. The parser never emits empty
    /// names.
    name: String,
    /// The block's `arguments` field, re-serialised as a JSON string (or
    /// `"{}"` when absent). Used verbatim as `ContentPart::ToolCall::arguments`.
    arguments: String,
}
158
159fn parse_functiongemma_response(text: &str) -> Vec<ParsedToolCall> {
170 let mut calls = Vec::new();
171
172 let mut remaining = text;
174 while let Some(start) = remaining.find("<tool_call>") {
175 remaining = &remaining[start + "<tool_call>".len()..];
176 if let Some(end) = remaining.find("</tool_call>") {
177 let block = remaining[..end].trim();
178 remaining = &remaining[end + "</tool_call>".len()..];
179
180 if let Ok(value) = serde_json::from_str::<serde_json::Value>(block) {
182 let name = value
183 .get("name")
184 .and_then(|n| n.as_str())
185 .unwrap_or("")
186 .to_string();
187 let arguments = value
188 .get("arguments")
189 .map(|a| serde_json::to_string(a).unwrap_or_else(|_| "{}".to_string()))
190 .unwrap_or_else(|| "{}".to_string());
191
192 if !name.is_empty() {
193 calls.push(ParsedToolCall { name, arguments });
194 }
195 } else {
196 tracing::warn!(
197 block = %block,
198 "FunctionGemma produced unparseable tool_call block"
199 );
200 }
201 } else {
202 break; }
204 }
205
206 calls
207}
208
209pub struct ToolCallRouter {
215 runtime: Arc<Mutex<CandleThinker>>,
216}
217
218impl std::fmt::Debug for ToolCallRouter {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 f.debug_struct("ToolCallRouter").finish()
221 }
222}
223
224impl ToolCallRouter {
225 pub fn from_config(config: &ToolRouterConfig) -> Result<Option<Self>> {
229 if !config.enabled {
230 tracing::debug!("FunctionGemma tool router is disabled");
231 return Ok(None);
232 }
233
234 let model_path = config.model_path.as_ref().ok_or_else(|| {
235 anyhow!("CODETETHER_TOOL_ROUTER_MODEL_PATH is required when the tool router is enabled")
236 })?;
237 let tokenizer_path = config.tokenizer_path.as_ref().ok_or_else(|| {
238 anyhow!(
239 "CODETETHER_TOOL_ROUTER_TOKENIZER_PATH is required when the tool router is enabled"
240 )
241 })?;
242
243 let thinker_config = ThinkerConfig {
245 enabled: true,
246 backend: ThinkerBackend::Candle,
247 candle_model_path: Some(model_path.clone()),
248 candle_tokenizer_path: Some(tokenizer_path.clone()),
249 candle_arch: Some(config.arch.clone()),
250 candle_device: config.device,
251 max_tokens: config.max_tokens,
252 temperature: config.temperature,
253 ..ThinkerConfig::default()
254 };
255
256 let runtime = CandleThinker::new(&thinker_config)?;
257 tracing::info!(
258 model_path = %model_path,
259 arch = %config.arch,
260 "FunctionGemma tool-call router initialised"
261 );
262
263 Ok(Some(Self {
264 runtime: Arc::new(Mutex::new(runtime)),
265 }))
266 }
267
268 pub async fn maybe_reformat(
281 &self,
282 response: CompletionResponse,
283 tools: &[ToolDefinition],
284 model_supports_tools: bool,
285 ) -> CompletionResponse {
286 if model_supports_tools {
290 tracing::trace!("Skipping FunctionGemma: model supports native tool calling");
291 return response;
292 }
293
294 let has_tool_calls = response
296 .message
297 .content
298 .iter()
299 .any(|p| matches!(p, ContentPart::ToolCall { .. }));
300
301 if has_tool_calls {
302 return response;
303 }
304
305 if tools.is_empty() {
307 return response;
308 }
309
310 let assistant_text: String = response
312 .message
313 .content
314 .iter()
315 .filter_map(|p| match p {
316 ContentPart::Text { text } => Some(text.as_str()),
317 _ => None,
318 })
319 .collect::<Vec<_>>()
320 .join("\n");
321
322 if assistant_text.trim().is_empty() {
323 return response;
324 }
325
326 let text_lower = assistant_text.to_lowercase();
330 let mentions_tool = tools
331 .iter()
332 .any(|t| text_lower.contains(&t.name.to_lowercase()));
333 if !mentions_tool {
334 tracing::trace!("Skipping FunctionGemma: assistant text mentions no tool names");
335 return response;
336 }
337
338 match self.run_functiongemma(&assistant_text, tools).await {
340 Ok(parsed) if !parsed.is_empty() => {
341 tracing::info!(
342 num_calls = parsed.len(),
343 "FunctionGemma router produced tool calls from text-only response"
344 );
345 self.rewrite_response(response, parsed)
346 }
347 Ok(_) => {
348 response
350 }
351 Err(e) => {
352 tracing::warn!(
353 error = %e,
354 "FunctionGemma router failed; returning original response"
355 );
356 response
357 }
358 }
359 }
360
361 async fn run_functiongemma(
363 &self,
364 assistant_text: &str,
365 tools: &[ToolDefinition],
366 ) -> Result<Vec<ParsedToolCall>> {
367 let prompt = build_functiongemma_prompt(assistant_text, tools);
368 let runtime = Arc::clone(&self.runtime);
369
370 let output = tokio::task::spawn_blocking(move || {
371 let mut guard = runtime
372 .lock()
373 .map_err(|_| anyhow!("FunctionGemma mutex poisoned"))?;
374 guard.think("", &prompt)
379 })
380 .await
381 .map_err(|e| anyhow!("FunctionGemma task join failed: {e}"))??;
382
383 Ok(parse_functiongemma_response(&output.text))
384 }
385
386 fn rewrite_response(
393 &self,
394 mut response: CompletionResponse,
395 calls: Vec<ParsedToolCall>,
396 ) -> CompletionResponse {
397 response
400 .message
401 .content
402 .retain(|p| !matches!(p, ContentPart::Text { .. }));
403
404 for call in calls {
405 response.message.content.push(ContentPart::ToolCall {
406 id: format!("fc_{}", Uuid::new_v4()),
407 name: call.name,
408 arguments: call.arguments,
409 });
410 }
411
412 response.finish_reason = FinishReason::ToolCalls;
414 response
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    // Collects the parsed call names, in order, for compact assertions.
    fn names(calls: &[ParsedToolCall]) -> Vec<&str> {
        calls.iter().map(|call| call.name.as_str()).collect()
    }

    #[test]
    fn parse_single_tool_call() {
        let calls = parse_functiongemma_response(
            r#"<tool_call>
{"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
</tool_call>"#,
        );
        assert_eq!(names(&calls), ["read_file"]);
        assert!(calls[0].arguments.contains("/tmp/foo.rs"));
    }

    #[test]
    fn parse_multiple_tool_calls() {
        let calls = parse_functiongemma_response(
            r#"I'll read both files.
<tool_call>
{"name": "read_file", "arguments": {"path": "a.rs"}}
</tool_call>
<tool_call>
{"name": "read_file", "arguments": {"path": "b.rs"}}
</tool_call>"#,
        );
        assert_eq!(names(&calls), ["read_file", "read_file"]);
    }

    #[test]
    fn parse_no_tool_calls() {
        assert!(parse_functiongemma_response("I cannot help with that request.").is_empty());
    }

    #[test]
    fn parse_malformed_json_skipped() {
        let calls = parse_functiongemma_response(
            r#"<tool_call>
not valid json
</tool_call>
<tool_call>
{"name": "list_dir", "arguments": {"path": "."}}
</tool_call>"#,
        );
        assert_eq!(names(&calls), ["list_dir"]);
    }

    #[test]
    fn parse_empty_name_skipped() {
        let calls = parse_functiongemma_response(
            r#"<tool_call>
{"name": "", "arguments": {}}
</tool_call>"#,
        );
        assert!(calls.is_empty());
    }

    #[test]
    fn prompt_contains_tool_definitions() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        }];
        let prompt = build_functiongemma_prompt("Please read foo.rs", &tools);

        for needle in [
            "<start_of_turn>system",
            "read_file",
            "<tools>",
            "Please read foo.rs",
            "<start_of_turn>model",
        ] {
            assert!(prompt.contains(needle), "prompt is missing {needle:?}");
        }
    }

    #[test]
    fn config_defaults_disabled() {
        let ToolRouterConfig {
            enabled,
            arch,
            max_tokens,
            ..
        } = ToolRouterConfig::default();
        assert!(!enabled);
        assert_eq!(arch, "gemma3");
        assert_eq!(max_tokens, 128);
    }
}