ambi 0.3.0

A flexible, multi-backend, customizable AI agent framework written entirely in Rust.
// src/agent/tool/parser.rs

use crate::types::ToolCallParser;
use serde_json::Value;

/// Built-in tool-call parser that locates calls by literal wrapper tags.
///
/// It parses LLM output such as `[TOOL_CALL]{"name": "xxx", "args": {}}[/TOOL_CALL]`,
/// where both the opening and the closing tag are configurable.
#[derive(Debug, Clone)]
pub struct TagToolParser {
    /// Opening tag that marks the start of a tool-call payload.
    pub start_tag: String,
    /// Closing tag that marks the end of a tool-call payload.
    pub end_tag: String,
}

impl TagToolParser {
    /// Creates a customized tag parser.
    pub fn new(start_tag: &str, end_tag: &str) -> Self {
        Self {
            start_tag: start_tag.to_string(),
            end_tag: end_tag.to_string(),
        }
    }

    /// Remove the Markdown JSON syntax block wrapper that the model sometimes automatically adds
    fn clean_markdown_json(raw: &str) -> &str {
        let mut s = raw.trim();
        if s.starts_with("```json") {
            s = &s[7..];
        } else if s.starts_with("```") {
            s = &s[3..];
        }
        s = s.trim();
        if s.ends_with("```") {
            s = &s[..s.len() - 3];
        }
        s.trim()
    }

    /// Attempt to deserialize JSON and extract standard fields.
    /// Supports single objects and arrays of objects.
    fn try_parse_json(calls: &mut Vec<(String, Value)>, str: &str) -> bool {
        if let Ok(val) = serde_json::from_str::<Value>(str) {
            let mut process_item = |item: &Value| {
                if let (Some(name), Some(args)) =
                    (item.get("name").and_then(|n| n.as_str()), item.get("args"))
                {
                    calls.push((name.to_string(), args.clone()));
                }
            };

            if val.is_object() {
                process_item(&val);
            } else if let Some(arr) = val.as_array() {
                for item in arr {
                    process_item(item);
                }
            }
            true
        } else {
            false
        }
    }

    /// Extract the call and implement a truncation recovery mechanism (to prevent the JSON tail from being incomplete due to token limit)
    fn extract_and_push_call(json_str: &str, calls: &mut Vec<(String, Value)>) {
        let trimmed = json_str.trim();
        if trimmed.is_empty() {
            return;
        }

        // Try a complete analysis
        if Self::try_parse_json(calls, trimmed) {
            return;
        }

        // Truncation Recovery Mechanism: Find the last right curly brace and discard any garbled text or
        // truncated fields that may have been generated by the LLM afterwards.
        if let Some(last_brace) = trimmed.rfind('}') {
            let truncated = &trimmed[..=last_brace];
            if Self::try_parse_json(calls, truncated) {
                log::info!("Successfully recovered a truncated Tool JSON object.");
                return;
            }
        }

        // Complete failure: notify the pipeline to prompt the LLM to correct the format in the next round
        log::warn!("Failed to parse Tool JSON syntax: {}", trimmed);
        calls.push((
            "__format_error__".to_string(),
            serde_json::json!({
                "error": "Invalid JSON syntax",
                "raw": trimmed
            }),
        ));
    }
}

/// [`ToolCallParser`] implementation for tool calls written as JSON payloads
/// wrapped in `start_tag` … `end_tag`.
impl ToolCallParser for TagToolParser {
    /// Returns owned copies of the `(start_tag, end_tag)` pair.
    fn get_tags(&self) -> (String, String) {
        (self.start_tag.clone(), self.end_tag.clone())
    }

    /// Builds the instruction text that teaches the model the call format.
    ///
    /// The example call renders as `<start_tag>{"name":"tool_name","args":{...}}<end_tag>`.
    /// It is followed by the caller-supplied `tools_json` listing.
    fn format_instruction(&self, tools_json: &str) -> String {
        format!(
            "You can use tools. Call format:\n{}{{\"name\":\"tool_name\",\"args\":{{...}}}}{}\nAvailable tools:\n{}",
            self.start_tag, self.end_tag, tools_json
        )
    }

    /// Extracts every tagged tool call from `text`, in order of appearance.
    ///
    /// Each `start_tag … end_tag` span is unwrapped from any Markdown fence and then
    /// decoded as JSON. A start tag without a matching end tag consumes the rest of
    /// the text; this typically happens when output is cut off at the token limit.
    /// Tags are matched literally and case-sensitively.
    ///
    /// If both tags are empty, no call can be delimited and an empty list is returned.
    fn parse(&self, text: &str) -> Vec<(String, Value)> {
        let mut calls = Vec::new();

        // With two empty tags every match is zero-width at offset 0, so the scan below
        // would never advance and would loop forever.
        if self.start_tag.is_empty() && self.end_tag.is_empty() {
            return calls;
        }

        let mut current_text = text;
        while let Some(start) = current_text.find(self.start_tag.as_str()) {
            let body = &current_text[start + self.start_tag.len()..];

            match body.find(self.end_tag.as_str()) {
                Some(end_offset) => {
                    // Found the complete closing tag.
                    let clean_json = Self::clean_markdown_json(&body[..end_offset]);
                    Self::extract_and_push_call(clean_json, &mut calls);
                    current_text = &body[end_offset + self.end_tag.len()..];
                }
                None => {
                    // No end tag: usually the model reached max_tokens mid-generation.
                    let clean_json = Self::clean_markdown_json(body);
                    Self::extract_and_push_call(clean_json, &mut calls);
                    break;
                }
            }
        }
        calls
    }
}

/// Stateless factory for the framework's stock tool-call parser.
///
/// This is a unit struct carrying no data; call [`DefaultToolParser::make`] to get a
/// ready-configured [`TagToolParser`].
pub struct DefaultToolParser;

impl DefaultToolParser {
    /// Returns a [`TagToolParser`] that expects calls wrapped as
    /// `[TOOL_CALL]{...}[/TOOL_CALL]`.
    pub fn make() -> TagToolParser {
        let (open, close) = ("[TOOL_CALL]", "[/TOOL_CALL]");
        TagToolParser::new(open, close)
    }
}