Skip to main content

codetether_agent/cognition/
tool_router.rs

//! FunctionGemma-powered hybrid tool-call router.
//!
//! Sits between the primary LLM response and the tool-extraction step in the
//! session agentic loop.  When the primary LLM returns text-only output that
//! *describes* tool calls without using structured `ContentPart::ToolCall`
//! entries, the router passes the text and the available tool definitions
//! through a local FunctionGemma model (run via Candle) and emits properly
//! formatted `ContentPart::ToolCall` entries.
//!
//! **Feature-gated**: this module only compiles when the `functiongemma` cargo
//! feature is enabled, so default builds do not pay for it in binary size.
12
13use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
14use anyhow::{Result, anyhow};
15use std::sync::{Arc, Mutex};
16use uuid::Uuid;
17
18use super::thinker::{CandleThinker, ThinkerBackend, ThinkerConfig};
19
20// ── Configuration ────────────────────────────────────────────────────────────
21
22/// Environment-variable driven configuration for the tool-call router.
23#[derive(Debug, Clone)]
24pub struct ToolRouterConfig {
25    /// Whether the router is active.  Default: `false`.
26    pub enabled: bool,
27    /// Filesystem path to the FunctionGemma GGUF model.
28    pub model_path: Option<String>,
29    /// Filesystem path to the matching tokenizer.json.
30    pub tokenizer_path: Option<String>,
31    /// Architecture hint (default: `"gemma3"`).
32    pub arch: String,
33    /// Device preference (auto / cpu / cuda).
34    pub device: super::thinker::CandleDevicePreference,
35    /// Max tokens for the FunctionGemma response.
36    pub max_tokens: usize,
37    /// Temperature for FunctionGemma sampling.
38    pub temperature: f32,
39}
40
41impl Default for ToolRouterConfig {
42    fn default() -> Self {
43        Self {
44            enabled: false,
45            model_path: None,
46            tokenizer_path: None,
47            arch: "gemma3".to_string(),
48            device: super::thinker::CandleDevicePreference::Auto,
49            max_tokens: 512,
50            temperature: 0.1,
51        }
52    }
53}
54
55impl ToolRouterConfig {
56    /// Build from environment variables.
57    ///
58    /// | Variable | Description |
59    /// |----------|-------------|
60    /// | `CODETETHER_TOOL_ROUTER_ENABLED` | `true` / `1` to activate |
61    /// | `CODETETHER_TOOL_ROUTER_MODEL_PATH` | Path to `.gguf` model |
62    /// | `CODETETHER_TOOL_ROUTER_TOKENIZER_PATH` | Path to `tokenizer.json` |
63    /// | `CODETETHER_TOOL_ROUTER_ARCH` | Architecture hint (default: `gemma3`) |
64    /// | `CODETETHER_TOOL_ROUTER_DEVICE` | `auto` / `cpu` / `cuda` |
65    /// | `CODETETHER_TOOL_ROUTER_MAX_TOKENS` | Max decode tokens (default: 512) |
66    /// | `CODETETHER_TOOL_ROUTER_TEMPERATURE` | Sampling temp (default: 0.1) |
67    pub fn from_env() -> Self {
68        let enabled = std::env::var("CODETETHER_TOOL_ROUTER_ENABLED")
69            .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
70            .unwrap_or(false);
71
72        Self {
73            enabled,
74            model_path: std::env::var("CODETETHER_TOOL_ROUTER_MODEL_PATH").ok(),
75            tokenizer_path: std::env::var("CODETETHER_TOOL_ROUTER_TOKENIZER_PATH").ok(),
76            arch: std::env::var("CODETETHER_TOOL_ROUTER_ARCH")
77                .unwrap_or_else(|_| "gemma3".to_string()),
78            device: std::env::var("CODETETHER_TOOL_ROUTER_DEVICE")
79                .map(|v| super::thinker::CandleDevicePreference::from_env(&v))
80                .unwrap_or(super::thinker::CandleDevicePreference::Auto),
81            max_tokens: std::env::var("CODETETHER_TOOL_ROUTER_MAX_TOKENS")
82                .ok()
83                .and_then(|v| v.parse().ok())
84                .unwrap_or(512),
85            temperature: std::env::var("CODETETHER_TOOL_ROUTER_TEMPERATURE")
86                .ok()
87                .and_then(|v| v.parse().ok())
88                .unwrap_or(0.1),
89        }
90    }
91}
92
93// ── Prompt formatting ────────────────────────────────────────────────────────
94
95/// Serialize tool definitions into FunctionGemma's expected chat template.
96///
97/// FunctionGemma expects tools as a JSON list in the system turn, followed by
98/// the user's intent.  The model produces structured JSON function call output.
99fn build_functiongemma_prompt(assistant_text: &str, tools: &[ToolDefinition]) -> String {
100    // Build tool descriptions as a JSON array for the system section.
101    let tool_defs: Vec<serde_json::Value> = tools
102        .iter()
103        .map(|t| {
104            serde_json::json!({
105                "name": t.name,
106                "description": t.description,
107                "parameters": t.parameters,
108            })
109        })
110        .collect();
111
112    let tools_json = serde_json::to_string_pretty(&tool_defs).unwrap_or_else(|_| "[]".to_string());
113
114    // FunctionGemma chat template:
115    //   <start_of_turn>system
116    //   You are a function calling AI model. ...
117    //   <end_of_turn>
118    //   <start_of_turn>user
119    //   <user intent text>
120    //   <end_of_turn>
121    //   <start_of_turn>model
122    format!(
123        "<start_of_turn>system\n\
124         You are a function calling AI model. You are provided with function \
125         signatures within <tools></tools> XML tags. You may call one or more \
126         functions to assist with the user query. Don't make assumptions about \
127         what values to plug into functions.\n\n\
128         <tools>\n{tools_json}\n</tools>\n\n\
129         For each function call return a JSON object with function name and \
130         arguments within <tool_call></tool_call> XML tags as follows:\n\
131         <tool_call>\n{{\"name\": \"function_name\", \"arguments\": {{\"arg1\": \"value1\"}}}}\n</tool_call>\n\
132         <end_of_turn>\n\
133         <start_of_turn>user\n\
134         {assistant_text}\n\
135         <end_of_turn>\n\
136         <start_of_turn>model\n"
137    )
138}
139
// ── Response parsing ─────────────────────────────────────────────────────────

/// One tool invocation recovered from FunctionGemma's generated text.
#[derive(Clone, Debug)]
struct ParsedToolCall {
    /// Tool name exactly as the model emitted it.
    name: String,
    /// Tool arguments, already serialised to a JSON string.
    arguments: String,
}
148
149/// Parse FunctionGemma output into zero or more structured tool calls.
150///
151/// Expected format:
152/// ```text
153/// <tool_call>
154/// {"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
155/// </tool_call>
156/// ```
157///
158/// Handles multiple `<tool_call>` blocks in a single response.
159fn parse_functiongemma_response(text: &str) -> Vec<ParsedToolCall> {
160    let mut calls = Vec::new();
161
162    // Extract everything between <tool_call> and </tool_call>
163    let mut remaining = text;
164    while let Some(start) = remaining.find("<tool_call>") {
165        remaining = &remaining[start + "<tool_call>".len()..];
166        if let Some(end) = remaining.find("</tool_call>") {
167            let block = remaining[..end].trim();
168            remaining = &remaining[end + "</tool_call>".len()..];
169
170            // Try to parse the JSON block
171            if let Ok(value) = serde_json::from_str::<serde_json::Value>(block) {
172                let name = value
173                    .get("name")
174                    .and_then(|n| n.as_str())
175                    .unwrap_or("")
176                    .to_string();
177                let arguments = value
178                    .get("arguments")
179                    .map(|a| serde_json::to_string(a).unwrap_or_else(|_| "{}".to_string()))
180                    .unwrap_or_else(|| "{}".to_string());
181
182                if !name.is_empty() {
183                    calls.push(ParsedToolCall { name, arguments });
184                }
185            } else {
186                tracing::warn!(
187                    block = %block,
188                    "FunctionGemma produced unparseable tool_call block"
189                );
190            }
191        } else {
192            break; // Unclosed <tool_call> — stop
193        }
194    }
195
196    calls
197}
198
199// ── Router ───────────────────────────────────────────────────────────────────
200
201/// Hybrid tool-call router backed by a local FunctionGemma model.
202///
203/// Created once at session start; shared via `Arc` across prompt calls.
204pub struct ToolCallRouter {
205    runtime: Arc<Mutex<CandleThinker>>,
206}
207
208impl std::fmt::Debug for ToolCallRouter {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        f.debug_struct("ToolCallRouter").finish()
211    }
212}
213
214impl ToolCallRouter {
215    /// Construct from a [`ToolRouterConfig`].
216    ///
217    /// Returns `None` if the router is disabled or missing required paths.
218    pub fn from_config(config: &ToolRouterConfig) -> Result<Option<Self>> {
219        if !config.enabled {
220            tracing::debug!("FunctionGemma tool router is disabled");
221            return Ok(None);
222        }
223
224        let model_path = config.model_path.as_ref().ok_or_else(|| {
225            anyhow!("CODETETHER_TOOL_ROUTER_MODEL_PATH is required when the tool router is enabled")
226        })?;
227        let tokenizer_path = config.tokenizer_path.as_ref().ok_or_else(|| {
228            anyhow!(
229                "CODETETHER_TOOL_ROUTER_TOKENIZER_PATH is required when the tool router is enabled"
230            )
231        })?;
232
233        // Build a ThinkerConfig configured for the FunctionGemma model
234        let thinker_config = ThinkerConfig {
235            enabled: true,
236            backend: ThinkerBackend::Candle,
237            candle_model_path: Some(model_path.clone()),
238            candle_tokenizer_path: Some(tokenizer_path.clone()),
239            candle_arch: Some(config.arch.clone()),
240            candle_device: config.device,
241            max_tokens: config.max_tokens,
242            temperature: config.temperature,
243            ..ThinkerConfig::default()
244        };
245
246        let runtime = CandleThinker::new(&thinker_config)?;
247        tracing::info!(
248            model_path = %model_path,
249            arch = %config.arch,
250            "FunctionGemma tool-call router initialised"
251        );
252
253        Ok(Some(Self {
254            runtime: Arc::new(Mutex::new(runtime)),
255        }))
256    }
257
258    /// Conditionally reformat a `CompletionResponse`.
259    ///
260    /// - If the response already contains `ContentPart::ToolCall` entries,
261    ///   return it **unchanged** (zero overhead path).
262    /// - If the response is text-only, run FunctionGemma to convert the text
263    ///   into structured tool calls.
264    /// - On any internal error, return the **original** response unchanged
265    ///   (safe degradation — the router never breaks existing functionality).
266    pub async fn maybe_reformat(
267        &self,
268        response: CompletionResponse,
269        tools: &[ToolDefinition],
270    ) -> CompletionResponse {
271        // Fast path: if the response already has structured tool calls, pass through.
272        let has_tool_calls = response
273            .message
274            .content
275            .iter()
276            .any(|p| matches!(p, ContentPart::ToolCall { .. }));
277
278        if has_tool_calls {
279            return response;
280        }
281
282        // No tools were provided — nothing for FunctionGemma to match against.
283        if tools.is_empty() {
284            return response;
285        }
286
287        // Collect assistant text from the response.
288        let assistant_text: String = response
289            .message
290            .content
291            .iter()
292            .filter_map(|p| match p {
293                ContentPart::Text { text } => Some(text.as_str()),
294                _ => None,
295            })
296            .collect::<Vec<_>>()
297            .join("\n");
298
299        if assistant_text.trim().is_empty() {
300            return response;
301        }
302
303        // Run FunctionGemma in a blocking thread (CPU-bound).
304        match self.run_functiongemma(&assistant_text, tools).await {
305            Ok(parsed) if !parsed.is_empty() => {
306                tracing::info!(
307                    num_calls = parsed.len(),
308                    "FunctionGemma router produced tool calls from text-only response"
309                );
310                self.rewrite_response(response, parsed)
311            }
312            Ok(_) => {
313                // FunctionGemma decided no tool calls are needed — pass through.
314                response
315            }
316            Err(e) => {
317                tracing::warn!(
318                    error = %e,
319                    "FunctionGemma router failed; returning original response"
320                );
321                response
322            }
323        }
324    }
325
326    /// Run the FunctionGemma model in a blocking thread.
327    async fn run_functiongemma(
328        &self,
329        assistant_text: &str,
330        tools: &[ToolDefinition],
331    ) -> Result<Vec<ParsedToolCall>> {
332        let prompt = build_functiongemma_prompt(assistant_text, tools);
333        let runtime = Arc::clone(&self.runtime);
334
335        let output = tokio::task::spawn_blocking(move || {
336            let mut guard = runtime
337                .lock()
338                .map_err(|_| anyhow!("FunctionGemma mutex poisoned"))?;
339            // Use the raw prompt — we've already formatted it with the Gemma chat template.
340            // The thinker's `think()` wraps in System/User/Assistant roles; we need direct
341            // access to the generation loop.  We pass the full prompt as the user message
342            // and an empty system prompt so the thinker doesn't re-wrap it.
343            guard.think("", &prompt)
344        })
345        .await
346        .map_err(|e| anyhow!("FunctionGemma task join failed: {e}"))??;
347
348        Ok(parse_functiongemma_response(&output.text))
349    }
350
351    /// Rewrite the `CompletionResponse` to include parsed tool calls.
352    fn rewrite_response(
353        &self,
354        mut response: CompletionResponse,
355        calls: Vec<ParsedToolCall>,
356    ) -> CompletionResponse {
357        // Keep existing text parts; append tool call parts.
358        for call in calls {
359            response.message.content.push(ContentPart::ToolCall {
360                id: format!("fc_{}", Uuid::new_v4()),
361                name: call.name,
362                arguments: call.arguments,
363            });
364        }
365
366        // Signal the session loop that tool calls are present.
367        response.finish_reason = FinishReason::ToolCalls;
368        response
369    }
370}
371
// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_single_tool_call() {
        let text = r#"<tool_call>
{"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "read_file");
        assert!(calls[0].arguments.contains("/tmp/foo.rs"));
    }

    #[test]
    fn parse_multiple_tool_calls() {
        let text = r#"I'll read both files.
<tool_call>
{"name": "read_file", "arguments": {"path": "a.rs"}}
</tool_call>
<tool_call>
{"name": "read_file", "arguments": {"path": "b.rs"}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "read_file");
        assert_eq!(calls[1].name, "read_file");
    }

    #[test]
    fn parse_no_tool_calls() {
        let text = "I cannot help with that request.";
        let calls = parse_functiongemma_response(text);
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_malformed_json_skipped() {
        let text = r#"<tool_call>
not valid json
</tool_call>
<tool_call>
{"name": "list_dir", "arguments": {"path": "."}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "list_dir");
    }

    #[test]
    fn parse_empty_name_skipped() {
        let text = r#"<tool_call>
{"name": "", "arguments": {}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_missing_arguments_defaults_to_empty_object() {
        let text = r#"<tool_call>
{"name": "list_dir"}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "list_dir");
        assert_eq!(calls[0].arguments, "{}");
    }

    #[test]
    fn parse_unclosed_tool_call_ignored() {
        let text = r#"<tool_call>
{"name": "read_file", "arguments": {"path": "a.rs"}}"#;
        assert!(parse_functiongemma_response(text).is_empty());
    }

    #[test]
    fn parse_keeps_closed_calls_before_unclosed_one() {
        let text = "<tool_call>{\"name\": \"a\", \"arguments\": {}}</tool_call>\
                    <tool_call>{\"name\": \"b\"";
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "a");
        assert_eq!(calls[0].arguments, "{}");
    }

    #[test]
    fn prompt_contains_tool_definitions() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        }];
        let prompt = build_functiongemma_prompt("Please read foo.rs", &tools);
        assert!(prompt.contains("<start_of_turn>system"));
        assert!(prompt.contains("read_file"));
        assert!(prompt.contains("<tools>"));
        assert!(prompt.contains("Please read foo.rs"));
        assert!(prompt.contains("<start_of_turn>model"));
    }

    #[test]
    fn prompt_with_no_tools_embeds_empty_array() {
        let prompt = build_functiongemma_prompt("hello", &[]);
        assert!(prompt.contains("<tools>\n[]\n</tools>"));
        assert!(prompt.ends_with("<start_of_turn>model\n"));
    }

    #[test]
    fn config_defaults_disabled() {
        let config = ToolRouterConfig::default();
        assert!(!config.enabled);
        assert!(config.model_path.is_none());
        assert!(config.tokenizer_path.is_none());
        assert_eq!(config.arch, "gemma3");
        assert_eq!(config.max_tokens, 512);
        assert!((config.temperature - 0.1).abs() < f32::EPSILON);
    }
}
457        assert!(!config.enabled);
458        assert_eq!(config.arch, "gemma3");
459        assert_eq!(config.max_tokens, 512);
460    }
461}