1pub(crate) mod grammar;
2pub(crate) mod parsers;
3mod request;
4mod response;
5
6use hanzo_ml::Result;
7pub use request::*;
8pub use response::*;
9use serde::de::{self, Deserializer, MapAccess, Visitor};
10use serde_json::{Map, Value};
11use std::collections::HashMap;
12use std::fmt;
13use std::sync::Arc;
14use uuid::Uuid;
15
16use hanzo_llm_mcp::CalledFunction;
17
18pub use hanzo_llm_mcp::{ToolCallback, ToolCallbackWithTool};
19
/// Map from a `String` key to a shared [`ToolCallback`].
///
/// NOTE(review): the key is presumably the tool name — confirm against callers.
pub type ToolCallbacks = HashMap<String, Arc<ToolCallback>>;

/// Map from a `String` key to a [`ToolCallbackWithTool`].
///
/// NOTE(review): presumably the same keying as the plain callback map, with
/// each entry also carrying its tool definition — confirm in `hanzo_llm_mcp`.
pub type ToolCallbacksWithTools = HashMap<String, ToolCallbackWithTool>;
25
/// Thin forwarding wrapper around `parsers::contains_tool_call_prefix`.
///
/// Returns whatever the parser module reports for `prefix`.
/// NOTE(review): presumably `true` means the text contains a model-specific
/// tool-call marker — confirm in `parsers`.
fn contains_tool_call_prefix(prefix: &str) -> bool {
    parsers::contains_tool_call_prefix(prefix)
}
29
/// Thin forwarding wrapper around `parsers::process_model_specific_message`.
///
/// Returns the parser module's result (including any error) unchanged.
/// NOTE(review): presumably rewrites model-specific tool-call syntax into the
/// plain JSON form the matcher parses — confirm in `parsers`.
fn process_model_specific_message(message: &str) -> Result<String> {
    parsers::process_model_specific_message(message)
}
33
/// Detects and extracts tool calls from model output, subject to a
/// tool-choice policy.
pub struct ToolCallingMatcher {
    /// Tool-choice policy. With `ToolChoice::None` every matcher method
    /// returns its "nothing found" result immediately.
    tool_choice: ToolChoice,
    /// Function names of the tools given to `new`; `None` when no tool list
    /// was supplied. Used by `get_call` to drop calls to unknown tools.
    known_tool_names: Option<std::collections::HashSet<String>>,
    /// Shared copy of the tools given to `new`; consumed by the grammar
    /// builders.
    tools: Option<Arc<Vec<crate::Tool>>>,
}
39
/// A parsed function call: the function name plus its JSON arguments.
///
/// On deserialization the keys `function` and `arguments` are accepted as
/// aliases for `name` and `parameters` respectively.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CalledFunctionParameters {
    /// Name of the function being called (also accepted under the key `function`).
    #[serde(alias = "function")]
    pub name: String,
    /// Call arguments (also accepted under the key `arguments`). Deserialized
    /// through `flexible_args`, so the input may be a JSON object or a string
    /// that itself contains JSON.
    #[serde(alias = "arguments", deserialize_with = "flexible_args")]
    pub parameters: Value,
}
48
49fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
51where
52 D: Deserializer<'de>,
53{
54 struct ArgVisitor;
55
56 impl<'de> Visitor<'de> for ArgVisitor {
57 type Value = Value;
58
59 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
60 f.write_str("an object or a JSON-encoded string containing an object")
61 }
62
63 fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
65 where
66 M: MapAccess<'de>,
67 {
68 let mut map = Map::new();
69 while let Some((k, v)) = m.next_entry()? {
70 map.insert(k, v);
71 }
72 Ok(Value::Object(map))
73 }
74
75 fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
77 where
78 E: de::Error,
79 {
80 serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
81 }
82 }
83
84 d.deserialize_any(ArgVisitor)
85}
86
87fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
90 if raw.contains(r#""arguments":"{"#) {
93 let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
95 let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
97 Ok(fixed)
98 } else {
99 Ok(raw.to_string())
100 }
101}
102
103impl ToolCallingMatcher {
104 pub fn new(tool_choice: ToolChoice, tools: Option<&[crate::Tool]>) -> anyhow::Result<Self> {
105 let known_tool_names = tools.map(|t| {
106 t.iter()
107 .map(|tool| tool.function.name.clone())
108 .collect::<std::collections::HashSet<_>>()
109 });
110 let tools_arc = tools.map(|t| Arc::new(t.to_vec()));
111 Ok(Self {
112 tool_choice,
113 known_tool_names,
114 tools: tools_arc,
115 })
116 }
117
118 pub fn build_tool_call_grammar(&self, text: &str) -> Option<llguidance::api::TopLevelGrammar> {
123 if matches!(self.tool_choice, ToolChoice::None) {
124 return None;
125 }
126 let tools = self.tools.as_ref()?;
127 parsers::build_tool_call_grammar(text, tools)
128 }
129
130 pub fn build_harmony_tool_grammar(
135 &self,
136 tool_name: Option<&str>,
137 ) -> Option<llguidance::api::TopLevelGrammar> {
138 if matches!(self.tool_choice, ToolChoice::None) {
139 return None;
140 }
141 let tools = self.tools.as_ref()?;
142 Some(parsers::harmony::tool_call_grammar_for_tool(
143 tool_name,
144 Some(tools),
145 ))
146 }
147
148 pub fn prefix_could_be_tool(&self, message_prefix: &str) -> Result<(bool, bool)> {
155 if matches!(self.tool_choice, ToolChoice::None) {
156 return Ok((false, false));
157 }
158 let message_prefix = process_model_specific_message(message_prefix)?;
159 let message_prefix = fix_broken_json(&message_prefix).map_err(hanzo_ml::Error::msg)?;
160
161 Ok([
163 could_be_json::<CalledFunctionParameters>,
164 could_be_json::<Vec<CalledFunctionParameters>>,
165 ]
166 .iter()
167 .find_map(|check| {
168 let (could_be_tool, is_complete_tool) = check(&message_prefix);
169 if could_be_tool || is_complete_tool {
170 Some((could_be_tool, is_complete_tool))
171 } else {
172 None
173 }
174 })
175 .unwrap_or((contains_tool_call_prefix(&message_prefix), false)))
176 }
177
178 pub fn get_call(&self, message: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
179 if matches!(self.tool_choice, ToolChoice::None) {
180 return Ok(Vec::new());
181 }
182 let message = process_model_specific_message(message)?;
183 let message = fix_broken_json(&message)?;
184
185 let mut calls = if let Ok(deser) =
186 serde_json::from_str::<CalledFunctionParameters>(&message)
187 {
188 let id = format!("call-{}", Uuid::new_v4());
189 vec![ToolCallResponse {
190 index: 0,
191 id,
192 tp: ToolCallType::Function,
193 function: CalledFunction {
194 name: deser.name,
195 arguments: serde_json::to_string(&deser.parameters)?,
196 },
197 }]
198 } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
199 deser
200 .into_iter()
201 .enumerate()
202 .map(|(idx, deser)| {
203 let id = format!("call-{}", Uuid::new_v4());
204 Ok(ToolCallResponse {
205 index: idx,
206 id,
207 tp: ToolCallType::Function,
208 function: CalledFunction {
209 name: deser.name,
210 arguments: serde_json::to_string(&deser.parameters)?,
211 },
212 })
213 })
214 .collect::<anyhow::Result<Vec<_>>>()?
215 } else {
216 if matches!(self.tool_choice, ToolChoice::Tool(_)) {
217 anyhow::bail!("Tool choice was required but no tools were called.")
218 }
219 return Ok(Vec::new());
220 };
221
222 if let Some(ref known) = self.known_tool_names {
224 let before = calls.len();
225 calls.retain(|tc| {
226 let valid = known.contains(&tc.function.name);
227 if !valid {
228 tracing::warn!(
229 "Dropping hallucinated tool call `{}` (not in defined tools: {:?})",
230 tc.function.name,
231 known
232 );
233 }
234 valid
235 });
236 if calls.is_empty() && before > 0 && matches!(self.tool_choice, ToolChoice::Tool(_)) {
237 anyhow::bail!("Tool choice was required but model called unknown tools.");
238 }
239 }
240
241 Ok(calls)
242 }
243}
244
245fn could_be_json<T>(text_prefix: &str) -> (bool, bool)
249where
250 T: serde::de::DeserializeOwned,
251{
252 if text_prefix.trim().is_empty() {
253 return (false, false);
254 }
255 match serde_json::from_str::<T>(text_prefix) {
256 Ok(_) => (false, true),
257 Err(e) if e.is_eof() => (true, false),
259 _ => (false, false),
260 }
261}
262
263pub fn parse_text_tools(
265 raw_text: &str,
266 matcher: Option<Arc<ToolCallingMatcher>>,
267) -> anyhow::Result<(Option<&str>, Vec<ToolCallResponse>)> {
268 let mut tool_calls = Vec::new();
269 let mut text_new = Some(raw_text);
270
271 if let Some(ref matcher) = matcher {
272 let calls = matcher.get_call(raw_text).map_err(hanzo_ml::Error::msg)?;
273 if !calls.is_empty() {
274 text_new = None;
275 tool_calls = calls;
276 }
277 };
278 Ok((text_new, tool_calls))
279}