Skip to main content

hanzo_engine/tools/
mod.rs

1pub(crate) mod grammar;
2pub(crate) mod parsers;
3mod request;
4mod response;
5
6use hanzo_ml::Result;
7pub use request::*;
8pub use response::*;
9use serde::de::{self, Deserializer, MapAccess, Visitor};
10use serde_json::{Map, Value};
11use std::collections::HashMap;
12use std::fmt;
13use std::sync::Arc;
14use uuid::Uuid;
15
16use hanzo_llm_mcp::CalledFunction;
17
18pub use hanzo_llm_mcp::{ToolCallback, ToolCallbackWithTool};
19
20/// Collection of callbacks keyed by tool name.
21pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;
22
23/// Collection of callbacks with their tool definitions keyed by tool name.
24pub type ToolCallbacksWithTools = HashMap<String, ToolCallbackWithTool>;
25
26fn contains_tool_call_prefix(prefix: &str) -> bool {
27    parsers::contains_tool_call_prefix(prefix)
28}
29
30fn process_model_specific_message(message: &str) -> Result<String> {
31    parsers::process_model_specific_message(message)
32}
33
34pub struct ToolCallingMatcher {
35    tool_choice: ToolChoice,
36    known_tool_names: Option<std::collections::HashSet<String>>,
37    tools: Option<Arc<Vec<crate::Tool>>>,
38}
39
40// Same as CalledFunction, but has different cases for variations on the names
41#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
42pub struct CalledFunctionParameters {
43    #[serde(alias = "function")]
44    pub name: String,
45    #[serde(alias = "arguments", deserialize_with = "flexible_args")]
46    pub parameters: Value,
47}
48
49// Accept either `{...}` **or** a `"stringified { ... }"`
50fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
51where
52    D: Deserializer<'de>,
53{
54    struct ArgVisitor;
55
56    impl<'de> Visitor<'de> for ArgVisitor {
57        type Value = Value;
58
59        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
60            f.write_str("an object or a JSON-encoded string containing an object")
61        }
62
63        // Case 1 – the good case: already a JSON object
64        fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
65        where
66            M: MapAccess<'de>,
67        {
68            let mut map = Map::new();
69            while let Some((k, v)) = m.next_entry()? {
70                map.insert(k, v);
71            }
72            Ok(Value::Object(map))
73        }
74
75        // Case 2 – got a *string*; try parsing it as JSON
76        fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
77        where
78            E: de::Error,
79        {
80            serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
81        }
82    }
83
84    d.deserialize_any(ArgVisitor)
85}
86
87/// Fixup potentially broken JSON
88/// 1) allow/handle arguments as maps in quotations
89fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
90    // Only apply the fix if the first pattern matches - otherwise we might corrupt valid JSON
91    // where arguments is a properly escaped string containing `}`
92    if raw.contains(r#""arguments":"{"#) {
93        // 1) Delete the opening quote that shouldn't be there
94        let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
95        // 2) Delete the closing quote that matches it
96        let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
97        Ok(fixed)
98    } else {
99        Ok(raw.to_string())
100    }
101}
102
103impl ToolCallingMatcher {
104    pub fn new(tool_choice: ToolChoice, tools: Option<&[crate::Tool]>) -> anyhow::Result<Self> {
105        let known_tool_names = tools.map(|t| {
106            t.iter()
107                .map(|tool| tool.function.name.clone())
108                .collect::<std::collections::HashSet<_>>()
109        });
110        let tools_arc = tools.map(|t| Arc::new(t.to_vec()));
111        Ok(Self {
112            tool_choice,
113            known_tool_names,
114            tools: tools_arc,
115        })
116    }
117
118    /// Build a tool call grammar if a known format prefix is detected in
119    /// `text` and tools are available.  Returns `None` when tool choice is
120    /// `None`, no format matches, or the format is not yet ready (e.g.
121    /// DeepSeek before the JSON fence).
122    pub fn build_tool_call_grammar(&self, text: &str) -> Option<llguidance::api::TopLevelGrammar> {
123        if matches!(self.tool_choice, ToolChoice::None) {
124            return None;
125        }
126        let tools = self.tools.as_ref()?;
127        parsers::build_tool_call_grammar(text, tools)
128    }
129
130    /// Build a pure JSON object grammar for Harmony tool call arguments.
131    /// When `tool_name` identifies a tool with `strict: true`, its
132    /// parameters schema is used for constrained decoding.
133    /// Returns `None` when tool choice is `None` or no tools are defined.
134    pub fn build_harmony_tool_grammar(
135        &self,
136        tool_name: Option<&str>,
137    ) -> Option<llguidance::api::TopLevelGrammar> {
138        if matches!(self.tool_choice, ToolChoice::None) {
139            return None;
140        }
141        let tools = self.tools.as_ref()?;
142        Some(parsers::harmony::tool_call_grammar_for_tool(
143            tool_name,
144            Some(tools),
145        ))
146    }
147
148    // Checks if the `message_prefix` could be a tool call. If false, either
149    // [`ToolChoice::None`] was selected, or the prefix could not match.
150    //
151    // If the start of a message could be a tool call, then it looks like an incomplete JSON of a given structure, e.g. `{"name": "foo", "param`.
152    //
153    // Returns a tuple of `(could_be_tool, is_complete_tool)`.
154    pub fn prefix_could_be_tool(&self, message_prefix: &str) -> Result<(bool, bool)> {
155        if matches!(self.tool_choice, ToolChoice::None) {
156            return Ok((false, false));
157        }
158        let message_prefix = process_model_specific_message(message_prefix)?;
159        let message_prefix = fix_broken_json(&message_prefix).map_err(hanzo_ml::Error::msg)?;
160
161        // Check if the prefix could be a JSON serialization of any of the following types.
162        Ok([
163            could_be_json::<CalledFunctionParameters>,
164            could_be_json::<Vec<CalledFunctionParameters>>,
165        ]
166        .iter()
167        .find_map(|check| {
168            let (could_be_tool, is_complete_tool) = check(&message_prefix);
169            if could_be_tool || is_complete_tool {
170                Some((could_be_tool, is_complete_tool))
171            } else {
172                None
173            }
174        })
175        .unwrap_or((contains_tool_call_prefix(&message_prefix), false)))
176    }
177
178    pub fn get_call(&self, message: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
179        if matches!(self.tool_choice, ToolChoice::None) {
180            return Ok(Vec::new());
181        }
182        let message = process_model_specific_message(message)?;
183        let message = fix_broken_json(&message)?;
184
185        let mut calls = if let Ok(deser) =
186            serde_json::from_str::<CalledFunctionParameters>(&message)
187        {
188            let id = format!("call-{}", Uuid::new_v4());
189            vec![ToolCallResponse {
190                index: 0,
191                id,
192                tp: ToolCallType::Function,
193                function: CalledFunction {
194                    name: deser.name,
195                    arguments: serde_json::to_string(&deser.parameters)?,
196                },
197            }]
198        } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
199            deser
200                .into_iter()
201                .enumerate()
202                .map(|(idx, deser)| {
203                    let id = format!("call-{}", Uuid::new_v4());
204                    Ok(ToolCallResponse {
205                        index: idx,
206                        id,
207                        tp: ToolCallType::Function,
208                        function: CalledFunction {
209                            name: deser.name,
210                            arguments: serde_json::to_string(&deser.parameters)?,
211                        },
212                    })
213                })
214                .collect::<anyhow::Result<Vec<_>>>()?
215        } else {
216            if matches!(self.tool_choice, ToolChoice::Tool(_)) {
217                anyhow::bail!("Tool choice was required but no tools were called.")
218            }
219            return Ok(Vec::new());
220        };
221
222        // Filter out hallucinated tool names.
223        if let Some(ref known) = self.known_tool_names {
224            let before = calls.len();
225            calls.retain(|tc| {
226                let valid = known.contains(&tc.function.name);
227                if !valid {
228                    tracing::warn!(
229                        "Dropping hallucinated tool call `{}` (not in defined tools: {:?})",
230                        tc.function.name,
231                        known
232                    );
233                }
234                valid
235            });
236            if calls.is_empty() && before > 0 && matches!(self.tool_choice, ToolChoice::Tool(_)) {
237                anyhow::bail!("Tool choice was required but model called unknown tools.");
238            }
239        }
240
241        Ok(calls)
242    }
243}
244
245/// Checks if the given prefix could be the start of, or the entire JSON serialization of a given type, `T`.
246///
247/// Returns a tuple of `(could_be_tool, is_entire_tool)`.
248fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
249where
250    T: serde::de::DeserializeOwned,
251{
252    if text_prefix.trim().is_empty() {
253        return (false, false);
254    }
255    match serde_json::from_str::<T>(text_prefix) {
256        Ok(_) => (false, true),
257        // EOF show that JSON parsing was successful up to the end of the entire string.
258        Err(e) if e.is_eof() => (true, false),
259        _ => (false, false),
260    }
261}
262
263/// Takes raw UTf8 text and parses any possible tool calls from it.
264pub fn parse_text_tools(
265    raw_text: &str,
266    matcher: Option<Arc<ToolCallingMatcher>>,
267) -> anyhow::Result<(Option<&str>, Vec<ToolCallResponse>)> {
268    let mut tool_calls = Vec::new();
269    let mut text_new = Some(raw_text);
270
271    if let Some(ref matcher) = matcher {
272        let calls = matcher.get_call(raw_text).map_err(hanzo_ml::Error::msg)?;
273        if !calls.is_empty() {
274            text_new = None;
275            tool_calls = calls;
276        }
277    };
278    Ok((text_new, tool_calls))
279}