Skip to main content

codetether_agent/cognition/
tool_router.rs

1//! FunctionGemma-powered hybrid tool-call router.
2//!
3//! Sits between the primary LLM response and the tool-extraction step in the
4//! session agentic loop.  When the primary LLM returns text-only output that
5//! *describes* tool calls without using structured `ContentPart::ToolCall`
6//! entries, the router passes the text + available tool definitions through a
7//! local FunctionGemma model (via Candle) and emits properly-formatted
8//! `ContentPart::ToolCall` entries.
9//!
10//! **Feature-gated**: this module only compiles when the `functiongemma` cargo
11//! feature is enabled.  The binary size is unaffected in default builds.
12
13use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
14use anyhow::{Result, anyhow};
15use std::sync::{Arc, Mutex};
16use uuid::Uuid;
17
18use super::thinker::{CandleThinker, ThinkerBackend, ThinkerConfig};
19
20// ── Configuration ────────────────────────────────────────────────────────────
21
/// Environment-variable driven configuration for the tool-call router.
///
/// Usually built with `ToolRouterConfig::from_env`. `Default` yields a
/// disabled configuration with no model or tokenizer path set.
#[derive(Debug, Clone)]
pub struct ToolRouterConfig {
    /// Whether the router is active.  Default: `false`.
    /// When `false`, `ToolCallRouter::from_config` returns `Ok(None)`.
    pub enabled: bool,
    /// Filesystem path to the FunctionGemma GGUF model.
    /// Required when `enabled` is `true`; `ToolCallRouter::from_config` errors otherwise.
    pub model_path: Option<String>,
    /// Filesystem path to the matching tokenizer.json.
    /// Required when `enabled` is `true`; `ToolCallRouter::from_config` errors otherwise.
    pub tokenizer_path: Option<String>,
    /// Architecture hint forwarded to the Candle runtime (default: `"gemma3"`).
    pub arch: String,
    /// Device preference (auto / cpu / cuda). Default: auto.
    pub device: super::thinker::CandleDevicePreference,
    /// Max tokens for the FunctionGemma response (default: 128).
    /// The prompt asks the model to answer only with `<tool_call>` JSON
    /// blocks, so a small budget is usually enough. Raise it if responses
    /// containing many tool calls get truncated.
    pub max_tokens: usize,
    /// Temperature for FunctionGemma sampling (default: 0.1).
    pub temperature: f32,
}
41
42impl Default for ToolRouterConfig {
43    fn default() -> Self {
44        Self {
45            enabled: false,
46            model_path: None,
47            tokenizer_path: None,
48            arch: "gemma3".to_string(),
49            device: super::thinker::CandleDevicePreference::Auto,
50            max_tokens: 128,
51            temperature: 0.1,
52        }
53    }
54}
55
56impl ToolRouterConfig {
57    /// Build from environment variables.
58    ///
59    /// | Variable | Description |
60    /// |----------|-------------|
61    /// | `CODETETHER_TOOL_ROUTER_ENABLED` | `true` / `1` to activate |
62    /// | `CODETETHER_TOOL_ROUTER_MODEL_PATH` | Path to `.gguf` model |
63    /// | `CODETETHER_TOOL_ROUTER_TOKENIZER_PATH` | Path to `tokenizer.json` |
64    /// | `CODETETHER_TOOL_ROUTER_ARCH` | Architecture hint (default: `gemma3`) |
65    /// | `CODETETHER_TOOL_ROUTER_DEVICE` | `auto` / `cpu` / `cuda` |
66    /// | `CODETETHER_TOOL_ROUTER_MAX_TOKENS` | Max decode tokens (default: 512) |
67    /// | `CODETETHER_TOOL_ROUTER_TEMPERATURE` | Sampling temp (default: 0.1) |
68    /// | `CODETETHER_FUNCTIONGEMMA_DISABLED` | Emergency kill switch. Defaults to `true` |
69    pub fn from_env() -> Self {
70        let enabled_requested = std::env::var("CODETETHER_TOOL_ROUTER_ENABLED")
71            .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
72            .unwrap_or(false);
73
74        // Temporary safety default: keep FunctionGemma disabled unless explicitly
75        // unblocked. This prevents local CPU/GPU contention in normal CLI/TUI runs.
76        let disabled = std::env::var("CODETETHER_FUNCTIONGEMMA_DISABLED")
77            .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
78            .unwrap_or(true);
79
80        let enabled = enabled_requested && !disabled;
81
82        Self {
83            enabled,
84            model_path: std::env::var("CODETETHER_TOOL_ROUTER_MODEL_PATH").ok(),
85            tokenizer_path: std::env::var("CODETETHER_TOOL_ROUTER_TOKENIZER_PATH").ok(),
86            arch: std::env::var("CODETETHER_TOOL_ROUTER_ARCH")
87                .unwrap_or_else(|_| "gemma3".to_string()),
88            device: std::env::var("CODETETHER_TOOL_ROUTER_DEVICE")
89                .map(|v| super::thinker::CandleDevicePreference::from_env(&v))
90                .unwrap_or(super::thinker::CandleDevicePreference::Auto),
91            max_tokens: std::env::var("CODETETHER_TOOL_ROUTER_MAX_TOKENS")
92                .ok()
93                .and_then(|v| v.parse().ok())
94                .unwrap_or(128),
95            temperature: std::env::var("CODETETHER_TOOL_ROUTER_TEMPERATURE")
96                .ok()
97                .and_then(|v| v.parse().ok())
98                .unwrap_or(0.1),
99        }
100    }
101}
102
103// ── Prompt formatting ────────────────────────────────────────────────────────
104
105/// Serialize tool definitions into FunctionGemma's expected chat template.
106///
107/// FunctionGemma expects tools as a JSON list in the system turn, followed by
108/// the user's intent.  The model produces structured JSON function call output.
109fn build_functiongemma_prompt(assistant_text: &str, tools: &[ToolDefinition]) -> String {
110    // Build tool descriptions as a JSON array for the system section.
111    let tool_defs: Vec<serde_json::Value> = tools
112        .iter()
113        .map(|t| {
114            serde_json::json!({
115                "name": t.name,
116                "description": t.description,
117                "parameters": t.parameters,
118            })
119        })
120        .collect();
121
122    let tools_json = serde_json::to_string_pretty(&tool_defs).unwrap_or_else(|_| "[]".to_string());
123
124    // FunctionGemma chat template:
125    //   <start_of_turn>system
126    //   You are a function calling AI model. ...
127    //   <end_of_turn>
128    //   <start_of_turn>user
129    //   <user intent text>
130    //   <end_of_turn>
131    //   <start_of_turn>model
132    format!(
133        "<start_of_turn>system\n\
134         You are a function calling AI model. You are provided with function \
135         signatures within <tools></tools> XML tags. You may call one or more \
136         functions to assist with the user query. Don't make assumptions about \
137         what values to plug into functions.\n\n\
138         <tools>\n{tools_json}\n</tools>\n\n\
139         For each function call return a JSON object with function name and \
140         arguments within <tool_call></tool_call> XML tags as follows:\n\
141         <tool_call>\n{{\"name\": \"function_name\", \"arguments\": {{\"arg1\": \"value1\"}}}}\n</tool_call>\n\
142         <end_of_turn>\n\
143         <start_of_turn>user\n\
144         {assistant_text}\n\
145         <end_of_turn>\n\
146         <start_of_turn>model\n"
147    )
148}
149
150// ── Response parsing ─────────────────────────────────────────────────────────
151
/// A single parsed tool call from FunctionGemma output.
///
/// Produced by `parse_functiongemma_response` and turned into a
/// `ContentPart::ToolCall` by `ToolCallRouter::rewrite_response`.
#[derive(Debug, Clone)]
struct ParsedToolCall {
    /// Tool name from the block's `"name"` field (the parser skips empty names).
    name: String,
    /// Serialised JSON text of the block's `"arguments"` value (`"{}"` when absent).
    arguments: String, // JSON string
}
158
159/// Parse FunctionGemma output into zero or more structured tool calls.
160///
161/// Expected format:
162/// ```text
163/// <tool_call>
164/// {"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
165/// </tool_call>
166/// ```
167///
168/// Handles multiple `<tool_call>` blocks in a single response.
169fn parse_functiongemma_response(text: &str) -> Vec<ParsedToolCall> {
170    let mut calls = Vec::new();
171
172    // Extract everything between <tool_call> and </tool_call>
173    let mut remaining = text;
174    while let Some(start) = remaining.find("<tool_call>") {
175        remaining = &remaining[start + "<tool_call>".len()..];
176        if let Some(end) = remaining.find("</tool_call>") {
177            let block = remaining[..end].trim();
178            remaining = &remaining[end + "</tool_call>".len()..];
179
180            // Try to parse the JSON block
181            if let Ok(value) = serde_json::from_str::<serde_json::Value>(block) {
182                let name = value
183                    .get("name")
184                    .and_then(|n| n.as_str())
185                    .unwrap_or("")
186                    .to_string();
187                let arguments = value
188                    .get("arguments")
189                    .map(|a| serde_json::to_string(a).unwrap_or_else(|_| "{}".to_string()))
190                    .unwrap_or_else(|| "{}".to_string());
191
192                if !name.is_empty() {
193                    calls.push(ParsedToolCall { name, arguments });
194                }
195            } else {
196                tracing::warn!(
197                    block = %block,
198                    "FunctionGemma produced unparseable tool_call block"
199                );
200            }
201        } else {
202            break; // Unclosed <tool_call> — stop
203        }
204    }
205
206    calls
207}
208
209// ── Router ───────────────────────────────────────────────────────────────────
210
/// Hybrid tool-call router backed by a local FunctionGemma model.
///
/// Created once at session start; shared via `Arc` across prompt calls.
/// Build it with [`ToolCallRouter::from_config`] and pass each primary-LLM
/// response through [`ToolCallRouter::maybe_reformat`].
pub struct ToolCallRouter {
    /// The loaded FunctionGemma runtime. Inference takes this mutex and runs
    /// on a blocking thread (see `run_functiongemma`).
    runtime: Arc<Mutex<CandleThinker>>,
}
217
218impl std::fmt::Debug for ToolCallRouter {
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        f.debug_struct("ToolCallRouter").finish()
221    }
222}
223
224impl ToolCallRouter {
225    /// Construct from a [`ToolRouterConfig`].
226    ///
227    /// Returns `None` if the router is disabled or missing required paths.
228    pub fn from_config(config: &ToolRouterConfig) -> Result<Option<Self>> {
229        if !config.enabled {
230            tracing::debug!("FunctionGemma tool router is disabled");
231            return Ok(None);
232        }
233
234        let model_path = config.model_path.as_ref().ok_or_else(|| {
235            anyhow!("CODETETHER_TOOL_ROUTER_MODEL_PATH is required when the tool router is enabled")
236        })?;
237        let tokenizer_path = config.tokenizer_path.as_ref().ok_or_else(|| {
238            anyhow!(
239                "CODETETHER_TOOL_ROUTER_TOKENIZER_PATH is required when the tool router is enabled"
240            )
241        })?;
242
243        // Build a ThinkerConfig configured for the FunctionGemma model
244        let thinker_config = ThinkerConfig {
245            enabled: true,
246            backend: ThinkerBackend::Candle,
247            candle_model_path: Some(model_path.clone()),
248            candle_tokenizer_path: Some(tokenizer_path.clone()),
249            candle_arch: Some(config.arch.clone()),
250            candle_device: config.device,
251            max_tokens: config.max_tokens,
252            temperature: config.temperature,
253            ..ThinkerConfig::default()
254        };
255
256        let runtime = CandleThinker::new(&thinker_config)?;
257        tracing::info!(
258            model_path = %model_path,
259            arch = %config.arch,
260            "FunctionGemma tool-call router initialised"
261        );
262
263        Ok(Some(Self {
264            runtime: Arc::new(Mutex::new(runtime)),
265        }))
266    }
267
268    /// Conditionally reformat a `CompletionResponse`.
269    ///
270    /// - If the model natively supports tool calling, return **unchanged**
271    ///   (FunctionGemma is only useful for models that lack native tool support).
272    /// - If the response already contains `ContentPart::ToolCall` entries,
273    ///   return it **unchanged** (zero overhead path).
274    /// - If the assistant text doesn't look like it's describing tool usage,
275    ///   return **unchanged** (cheap heuristic avoids expensive inference).
276    /// - Otherwise, run FunctionGemma to convert the text into structured
277    ///   tool calls.
278    /// - On any internal error, return the **original** response unchanged
279    ///   (safe degradation — the router never breaks existing functionality).
280    pub async fn maybe_reformat(
281        &self,
282        response: CompletionResponse,
283        tools: &[ToolDefinition],
284        model_supports_tools: bool,
285    ) -> CompletionResponse {
286        // Fast path: model already handles tool calling natively.
287        // FunctionGemma is only needed for models that return text descriptions
288        // of tool calls instead of structured ContentPart::ToolCall entries.
289        if model_supports_tools {
290            tracing::trace!("Skipping FunctionGemma: model supports native tool calling");
291            return response;
292        }
293
294        // Fast path: if the response already has structured tool calls, pass through.
295        let has_tool_calls = response
296            .message
297            .content
298            .iter()
299            .any(|p| matches!(p, ContentPart::ToolCall { .. }));
300
301        if has_tool_calls {
302            return response;
303        }
304
305        // No tools were provided — nothing for FunctionGemma to match against.
306        if tools.is_empty() {
307            return response;
308        }
309
310        // Collect assistant text from the response.
311        let assistant_text: String = response
312            .message
313            .content
314            .iter()
315            .filter_map(|p| match p {
316                ContentPart::Text { text } => Some(text.as_str()),
317                _ => None,
318            })
319            .collect::<Vec<_>>()
320            .join("\n");
321
322        if assistant_text.trim().is_empty() {
323            return response;
324        }
325
326        // Cheap heuristic: skip FunctionGemma if the text doesn't mention any
327        // available tool name.  This avoids expensive CPU inference for pure
328        // conversational / final-answer responses.
329        let text_lower = assistant_text.to_lowercase();
330        let mentions_tool = tools
331            .iter()
332            .any(|t| text_lower.contains(&t.name.to_lowercase()));
333        if !mentions_tool {
334            tracing::trace!("Skipping FunctionGemma: assistant text mentions no tool names");
335            return response;
336        }
337
338        // Run FunctionGemma in a blocking thread (CPU-bound).
339        match self.run_functiongemma(&assistant_text, tools).await {
340            Ok(parsed) if !parsed.is_empty() => {
341                tracing::info!(
342                    num_calls = parsed.len(),
343                    "FunctionGemma router produced tool calls from text-only response"
344                );
345                self.rewrite_response(response, parsed)
346            }
347            Ok(_) => {
348                // FunctionGemma decided no tool calls are needed — pass through.
349                response
350            }
351            Err(e) => {
352                tracing::warn!(
353                    error = %e,
354                    "FunctionGemma router failed; returning original response"
355                );
356                response
357            }
358        }
359    }
360
361    /// Run the FunctionGemma model in a blocking thread.
362    async fn run_functiongemma(
363        &self,
364        assistant_text: &str,
365        tools: &[ToolDefinition],
366    ) -> Result<Vec<ParsedToolCall>> {
367        let prompt = build_functiongemma_prompt(assistant_text, tools);
368        let runtime = Arc::clone(&self.runtime);
369
370        let output = tokio::task::spawn_blocking(move || {
371            let mut guard = runtime
372                .lock()
373                .map_err(|_| anyhow!("FunctionGemma mutex poisoned"))?;
374            // Use the raw prompt — we've already formatted it with the Gemma chat template.
375            // The thinker's `think()` wraps in System/User/Assistant roles; we need direct
376            // access to the generation loop.  We pass the full prompt as the user message
377            // and an empty system prompt so the thinker doesn't re-wrap it.
378            guard.think("", &prompt)
379        })
380        .await
381        .map_err(|e| anyhow!("FunctionGemma task join failed: {e}"))??;
382
383        Ok(parse_functiongemma_response(&output.text))
384    }
385
386    /// Rewrite the `CompletionResponse` to replace text with structured tool calls.
387    ///
388    /// The original text is **removed** so the model sees a pure tool-call
389    /// assistant turn.  On the follow-up turn it will receive the tool results
390    /// and compose a proper answer – rather than ignoring them because it
391    /// already gave a complete text response.
392    fn rewrite_response(
393        &self,
394        mut response: CompletionResponse,
395        calls: Vec<ParsedToolCall>,
396    ) -> CompletionResponse {
397        // Strip all text parts — the model should see only tool calls so it
398        // properly processes the tool results on the next iteration.
399        response
400            .message
401            .content
402            .retain(|p| !matches!(p, ContentPart::Text { .. }));
403
404        for call in calls {
405            response.message.content.push(ContentPart::ToolCall {
406                id: format!("fc_{}", Uuid::new_v4()),
407                name: call.name,
408                arguments: call.arguments,
409            });
410        }
411
412        // Signal the session loop that tool calls are present.
413        response.finish_reason = FinishReason::ToolCalls;
414        response
415    }
416}
417
418// ── Tests ────────────────────────────────────────────────────────────────────
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Names of all calls parsed from `text`, in order of appearance.
    fn names_of(text: &str) -> Vec<String> {
        parse_functiongemma_response(text)
            .into_iter()
            .map(|call| call.name)
            .collect()
    }

    #[test]
    fn parse_single_tool_call() {
        let output = r#"<tool_call>
{"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
</tool_call>"#;
        let parsed = parse_functiongemma_response(output);
        assert_eq!(parsed.len(), 1);
        let only = &parsed[0];
        assert_eq!(only.name, "read_file");
        assert!(only.arguments.contains("/tmp/foo.rs"));
    }

    #[test]
    fn parse_multiple_tool_calls() {
        let output = r#"I'll read both files.
<tool_call>
{"name": "read_file", "arguments": {"path": "a.rs"}}
</tool_call>
<tool_call>
{"name": "read_file", "arguments": {"path": "b.rs"}}
</tool_call>"#;
        assert_eq!(names_of(output), ["read_file", "read_file"]);
    }

    #[test]
    fn parse_no_tool_calls() {
        assert!(names_of("I cannot help with that request.").is_empty());
    }

    #[test]
    fn parse_malformed_json_skipped() {
        let output = r#"<tool_call>
not valid json
</tool_call>
<tool_call>
{"name": "list_dir", "arguments": {"path": "."}}
</tool_call>"#;
        assert_eq!(names_of(output), ["list_dir"]);
    }

    #[test]
    fn parse_empty_name_skipped() {
        let output = r#"<tool_call>
{"name": "", "arguments": {}}
</tool_call>"#;
        assert!(names_of(output).is_empty());
    }

    #[test]
    fn prompt_contains_tool_definitions() {
        let read_file = ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        };
        let prompt = build_functiongemma_prompt("Please read foo.rs", &[read_file]);
        for needle in [
            "<start_of_turn>system",
            "read_file",
            "<tools>",
            "Please read foo.rs",
            "<start_of_turn>model",
        ] {
            assert!(prompt.contains(needle), "prompt is missing {needle:?}");
        }
    }

    #[test]
    fn config_defaults_disabled() {
        let ToolRouterConfig {
            enabled,
            arch,
            max_tokens,
            ..
        } = ToolRouterConfig::default();
        assert!(!enabled);
        assert_eq!(arch, "gemma3");
        assert_eq!(max_tokens, 128);
    }
}