Skip to main content

codetether_agent/cognition/
tool_router.rs

//! FunctionGemma-powered hybrid tool-call router.
//!
//! Sits between the primary LLM response and the tool-extraction step in the
//! session agentic loop. When the primary LLM returns text-only output that
//! *describes* tool calls without using structured `ContentPart::ToolCall`
//! entries, the router passes the text plus the available tool definitions
//! through a local FunctionGemma model (via Candle) and emits properly formatted
//! `ContentPart::ToolCall` entries.
//!
//! **Feature-gated**: this module only compiles when the `functiongemma` cargo
//! feature is enabled, so default builds are unaffected in binary size.
12
13use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
14use anyhow::{Result, anyhow};
15use std::sync::{Arc, Mutex};
16use uuid::Uuid;
17
18use super::thinker::{CandleThinker, ThinkerConfig, ThinkerBackend};
19
20// ── Configuration ────────────────────────────────────────────────────────────
21
22/// Environment-variable driven configuration for the tool-call router.
23#[derive(Debug, Clone)]
24pub struct ToolRouterConfig {
25    /// Whether the router is active.  Default: `false`.
26    pub enabled: bool,
27    /// Filesystem path to the FunctionGemma GGUF model.
28    pub model_path: Option<String>,
29    /// Filesystem path to the matching tokenizer.json.
30    pub tokenizer_path: Option<String>,
31    /// Architecture hint (default: `"gemma3"`).
32    pub arch: String,
33    /// Device preference (auto / cpu / cuda).
34    pub device: super::thinker::CandleDevicePreference,
35    /// Max tokens for the FunctionGemma response.
36    pub max_tokens: usize,
37    /// Temperature for FunctionGemma sampling.
38    pub temperature: f32,
39}
40
41impl Default for ToolRouterConfig {
42    fn default() -> Self {
43        Self {
44            enabled: false,
45            model_path: None,
46            tokenizer_path: None,
47            arch: "gemma3".to_string(),
48            device: super::thinker::CandleDevicePreference::Auto,
49            max_tokens: 512,
50            temperature: 0.1,
51        }
52    }
53}
54
55impl ToolRouterConfig {
56    /// Build from environment variables.
57    ///
58    /// | Variable | Description |
59    /// |----------|-------------|
60    /// | `CODETETHER_TOOL_ROUTER_ENABLED` | `true` / `1` to activate |
61    /// | `CODETETHER_TOOL_ROUTER_MODEL_PATH` | Path to `.gguf` model |
62    /// | `CODETETHER_TOOL_ROUTER_TOKENIZER_PATH` | Path to `tokenizer.json` |
63    /// | `CODETETHER_TOOL_ROUTER_ARCH` | Architecture hint (default: `gemma3`) |
64    /// | `CODETETHER_TOOL_ROUTER_DEVICE` | `auto` / `cpu` / `cuda` |
65    /// | `CODETETHER_TOOL_ROUTER_MAX_TOKENS` | Max decode tokens (default: 512) |
66    /// | `CODETETHER_TOOL_ROUTER_TEMPERATURE` | Sampling temp (default: 0.1) |
67    pub fn from_env() -> Self {
68        let enabled = std::env::var("CODETETHER_TOOL_ROUTER_ENABLED")
69            .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
70            .unwrap_or(false);
71
72        Self {
73            enabled,
74            model_path: std::env::var("CODETETHER_TOOL_ROUTER_MODEL_PATH").ok(),
75            tokenizer_path: std::env::var("CODETETHER_TOOL_ROUTER_TOKENIZER_PATH").ok(),
76            arch: std::env::var("CODETETHER_TOOL_ROUTER_ARCH")
77                .unwrap_or_else(|_| "gemma3".to_string()),
78            device: std::env::var("CODETETHER_TOOL_ROUTER_DEVICE")
79                .map(|v| super::thinker::CandleDevicePreference::from_env(&v))
80                .unwrap_or(super::thinker::CandleDevicePreference::Auto),
81            max_tokens: std::env::var("CODETETHER_TOOL_ROUTER_MAX_TOKENS")
82                .ok()
83                .and_then(|v| v.parse().ok())
84                .unwrap_or(512),
85            temperature: std::env::var("CODETETHER_TOOL_ROUTER_TEMPERATURE")
86                .ok()
87                .and_then(|v| v.parse().ok())
88                .unwrap_or(0.1),
89        }
90    }
91}
92
93// ── Prompt formatting ────────────────────────────────────────────────────────
94
95/// Serialize tool definitions into FunctionGemma's expected chat template.
96///
97/// FunctionGemma expects tools as a JSON list in the system turn, followed by
98/// the user's intent.  The model produces structured JSON function call output.
99fn build_functiongemma_prompt(
100    assistant_text: &str,
101    tools: &[ToolDefinition],
102) -> String {
103    // Build tool descriptions as a JSON array for the system section.
104    let tool_defs: Vec<serde_json::Value> = tools
105        .iter()
106        .map(|t| {
107            serde_json::json!({
108                "name": t.name,
109                "description": t.description,
110                "parameters": t.parameters,
111            })
112        })
113        .collect();
114
115    let tools_json = serde_json::to_string_pretty(&tool_defs).unwrap_or_else(|_| "[]".to_string());
116
117    // FunctionGemma chat template:
118    //   <start_of_turn>system
119    //   You are a function calling AI model. ...
120    //   <end_of_turn>
121    //   <start_of_turn>user
122    //   <user intent text>
123    //   <end_of_turn>
124    //   <start_of_turn>model
125    format!(
126        "<start_of_turn>system\n\
127         You are a function calling AI model. You are provided with function \
128         signatures within <tools></tools> XML tags. You may call one or more \
129         functions to assist with the user query. Don't make assumptions about \
130         what values to plug into functions.\n\n\
131         <tools>\n{tools_json}\n</tools>\n\n\
132         For each function call return a JSON object with function name and \
133         arguments within <tool_call></tool_call> XML tags as follows:\n\
134         <tool_call>\n{{\"name\": \"function_name\", \"arguments\": {{\"arg1\": \"value1\"}}}}\n</tool_call>\n\
135         <end_of_turn>\n\
136         <start_of_turn>user\n\
137         {assistant_text}\n\
138         <end_of_turn>\n\
139         <start_of_turn>model\n"
140    )
141}
142
143// ── Response parsing ─────────────────────────────────────────────────────────
144
/// One tool invocation recovered from FunctionGemma's text output.
#[derive(Clone, Debug)]
struct ParsedToolCall {
    /// Name of the tool the model wants to invoke.
    name: String,
    /// The call's arguments, serialized as JSON text.
    arguments: String,
}
151
152/// Parse FunctionGemma output into zero or more structured tool calls.
153///
154/// Expected format:
155/// ```text
156/// <tool_call>
157/// {"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
158/// </tool_call>
159/// ```
160///
161/// Handles multiple `<tool_call>` blocks in a single response.
162fn parse_functiongemma_response(text: &str) -> Vec<ParsedToolCall> {
163    let mut calls = Vec::new();
164
165    // Extract everything between <tool_call> and </tool_call>
166    let mut remaining = text;
167    while let Some(start) = remaining.find("<tool_call>") {
168        remaining = &remaining[start + "<tool_call>".len()..];
169        if let Some(end) = remaining.find("</tool_call>") {
170            let block = remaining[..end].trim();
171            remaining = &remaining[end + "</tool_call>".len()..];
172
173            // Try to parse the JSON block
174            if let Ok(value) = serde_json::from_str::<serde_json::Value>(block) {
175                let name = value
176                    .get("name")
177                    .and_then(|n| n.as_str())
178                    .unwrap_or("")
179                    .to_string();
180                let arguments = value
181                    .get("arguments")
182                    .map(|a| serde_json::to_string(a).unwrap_or_else(|_| "{}".to_string()))
183                    .unwrap_or_else(|| "{}".to_string());
184
185                if !name.is_empty() {
186                    calls.push(ParsedToolCall { name, arguments });
187                }
188            } else {
189                tracing::warn!(
190                    block = %block,
191                    "FunctionGemma produced unparseable tool_call block"
192                );
193            }
194        } else {
195            break; // Unclosed <tool_call> — stop
196        }
197    }
198
199    calls
200}
201
202// ── Router ───────────────────────────────────────────────────────────────────
203
204/// Hybrid tool-call router backed by a local FunctionGemma model.
205///
206/// Created once at session start; shared via `Arc` across prompt calls.
207pub struct ToolCallRouter {
208    runtime: Arc<Mutex<CandleThinker>>,
209}
210
211impl std::fmt::Debug for ToolCallRouter {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        f.debug_struct("ToolCallRouter").finish()
214    }
215}
216
217impl ToolCallRouter {
218    /// Construct from a [`ToolRouterConfig`].
219    ///
220    /// Returns `None` if the router is disabled or missing required paths.
221    pub fn from_config(config: &ToolRouterConfig) -> Result<Option<Self>> {
222        if !config.enabled {
223            tracing::debug!("FunctionGemma tool router is disabled");
224            return Ok(None);
225        }
226
227        let model_path = config.model_path.as_ref().ok_or_else(|| {
228            anyhow!(
229                "CODETETHER_TOOL_ROUTER_MODEL_PATH is required when the tool router is enabled"
230            )
231        })?;
232        let tokenizer_path = config.tokenizer_path.as_ref().ok_or_else(|| {
233            anyhow!(
234                "CODETETHER_TOOL_ROUTER_TOKENIZER_PATH is required when the tool router is enabled"
235            )
236        })?;
237
238        // Build a ThinkerConfig configured for the FunctionGemma model
239        let thinker_config = ThinkerConfig {
240            enabled: true,
241            backend: ThinkerBackend::Candle,
242            candle_model_path: Some(model_path.clone()),
243            candle_tokenizer_path: Some(tokenizer_path.clone()),
244            candle_arch: Some(config.arch.clone()),
245            candle_device: config.device,
246            max_tokens: config.max_tokens,
247            temperature: config.temperature,
248            ..ThinkerConfig::default()
249        };
250
251        let runtime = CandleThinker::new(&thinker_config)?;
252        tracing::info!(
253            model_path = %model_path,
254            arch = %config.arch,
255            "FunctionGemma tool-call router initialised"
256        );
257
258        Ok(Some(Self {
259            runtime: Arc::new(Mutex::new(runtime)),
260        }))
261    }
262
263    /// Conditionally reformat a `CompletionResponse`.
264    ///
265    /// - If the response already contains `ContentPart::ToolCall` entries,
266    ///   return it **unchanged** (zero overhead path).
267    /// - If the response is text-only, run FunctionGemma to convert the text
268    ///   into structured tool calls.
269    /// - On any internal error, return the **original** response unchanged
270    ///   (safe degradation — the router never breaks existing functionality).
271    pub async fn maybe_reformat(
272        &self,
273        response: CompletionResponse,
274        tools: &[ToolDefinition],
275    ) -> CompletionResponse {
276        // Fast path: if the response already has structured tool calls, pass through.
277        let has_tool_calls = response
278            .message
279            .content
280            .iter()
281            .any(|p| matches!(p, ContentPart::ToolCall { .. }));
282
283        if has_tool_calls {
284            return response;
285        }
286
287        // No tools were provided — nothing for FunctionGemma to match against.
288        if tools.is_empty() {
289            return response;
290        }
291
292        // Collect assistant text from the response.
293        let assistant_text: String = response
294            .message
295            .content
296            .iter()
297            .filter_map(|p| match p {
298                ContentPart::Text { text } => Some(text.as_str()),
299                _ => None,
300            })
301            .collect::<Vec<_>>()
302            .join("\n");
303
304        if assistant_text.trim().is_empty() {
305            return response;
306        }
307
308        // Run FunctionGemma in a blocking thread (CPU-bound).
309        match self.run_functiongemma(&assistant_text, tools).await {
310            Ok(parsed) if !parsed.is_empty() => {
311                tracing::info!(
312                    num_calls = parsed.len(),
313                    "FunctionGemma router produced tool calls from text-only response"
314                );
315                self.rewrite_response(response, parsed)
316            }
317            Ok(_) => {
318                // FunctionGemma decided no tool calls are needed — pass through.
319                response
320            }
321            Err(e) => {
322                tracing::warn!(
323                    error = %e,
324                    "FunctionGemma router failed; returning original response"
325                );
326                response
327            }
328        }
329    }
330
331    /// Run the FunctionGemma model in a blocking thread.
332    async fn run_functiongemma(
333        &self,
334        assistant_text: &str,
335        tools: &[ToolDefinition],
336    ) -> Result<Vec<ParsedToolCall>> {
337        let prompt = build_functiongemma_prompt(assistant_text, tools);
338        let runtime = Arc::clone(&self.runtime);
339
340        let output = tokio::task::spawn_blocking(move || {
341            let mut guard = runtime
342                .lock()
343                .map_err(|_| anyhow!("FunctionGemma mutex poisoned"))?;
344            // Use the raw prompt — we've already formatted it with the Gemma chat template.
345            // The thinker's `think()` wraps in System/User/Assistant roles; we need direct
346            // access to the generation loop.  We pass the full prompt as the user message
347            // and an empty system prompt so the thinker doesn't re-wrap it.
348            guard.think("", &prompt)
349        })
350        .await
351        .map_err(|e| anyhow!("FunctionGemma task join failed: {e}"))??;
352
353        Ok(parse_functiongemma_response(&output.text))
354    }
355
356    /// Rewrite the `CompletionResponse` to include parsed tool calls.
357    fn rewrite_response(
358        &self,
359        mut response: CompletionResponse,
360        calls: Vec<ParsedToolCall>,
361    ) -> CompletionResponse {
362        // Keep existing text parts; append tool call parts.
363        for call in calls {
364            response.message.content.push(ContentPart::ToolCall {
365                id: format!("fc_{}", Uuid::new_v4()),
366                name: call.name,
367                arguments: call.arguments,
368            });
369        }
370
371        // Signal the session loop that tool calls are present.
372        response.finish_reason = FinishReason::ToolCalls;
373        response
374    }
375}
376
377// ── Tests ────────────────────────────────────────────────────────────────────
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Function names of `calls`, in order, for compact assertions.
    fn names(calls: &[ParsedToolCall]) -> Vec<&str> {
        calls.iter().map(|c| c.name.as_str()).collect()
    }

    #[test]
    fn parse_single_tool_call() {
        let output = r#"<tool_call>
{"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
</tool_call>"#;
        let calls = parse_functiongemma_response(output);
        assert_eq!(names(&calls), ["read_file"]);
        assert!(calls[0].arguments.contains("/tmp/foo.rs"));
    }

    #[test]
    fn parse_multiple_tool_calls() {
        let output = r#"I'll read both files.
<tool_call>
{"name": "read_file", "arguments": {"path": "a.rs"}}
</tool_call>
<tool_call>
{"name": "read_file", "arguments": {"path": "b.rs"}}
</tool_call>"#;
        let calls = parse_functiongemma_response(output);
        assert_eq!(names(&calls), ["read_file", "read_file"]);
    }

    #[test]
    fn parse_no_tool_calls() {
        assert!(parse_functiongemma_response("I cannot help with that request.").is_empty());
    }

    #[test]
    fn parse_malformed_json_skipped() {
        let output = r#"<tool_call>
not valid json
</tool_call>
<tool_call>
{"name": "list_dir", "arguments": {"path": "."}}
</tool_call>"#;
        let calls = parse_functiongemma_response(output);
        assert_eq!(names(&calls), ["list_dir"]);
    }

    #[test]
    fn parse_empty_name_skipped() {
        let output = r#"<tool_call>
{"name": "", "arguments": {}}
</tool_call>"#;
        assert!(parse_functiongemma_response(output).is_empty());
    }

    #[test]
    fn prompt_contains_tool_definitions() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" }
            },
            "required": ["path"]
        });
        let tools = [ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: schema,
        }];
        let prompt = build_functiongemma_prompt("Please read foo.rs", &tools);
        for fragment in [
            "<start_of_turn>system",
            "read_file",
            "<tools>",
            "Please read foo.rs",
            "<start_of_turn>model",
        ] {
            assert!(prompt.contains(fragment), "prompt is missing {fragment:?}");
        }
    }

    #[test]
    fn config_defaults_disabled() {
        let ToolRouterConfig {
            enabled,
            arch,
            max_tokens,
            ..
        } = ToolRouterConfig::default();
        assert!(!enabled);
        assert_eq!(arch, "gemma3");
        assert_eq!(max_tokens, 512);
    }
}