dynamo_parsers/tool_calling/harmony/
harmony_parser.rs1use super::config::JsonParserConfig;
5use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
6use openai_harmony::chat::{Content::Text, Role};
7use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, load_harmony_encoding};
8use serde_json::Value;
9
10static GLOBAL_HARMONY_GPTOSS_ENCODING: tokio::sync::OnceCell<
11 Result<HarmonyEncoding, anyhow::Error>,
12> = tokio::sync::OnceCell::const_new();
13
14pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> {
15 GLOBAL_HARMONY_GPTOSS_ENCODING
16 .get_or_init(|| async {
17 tokio::task::spawn_blocking(|| {
18 load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)
19 })
20 .await
21 .map_err(anyhow::Error::msg)
22 .flatten()
23 })
24 .await
25}
26
27pub async fn parse_tool_calls_harmony_complete(
47 text: &str,
48 _config: &JsonParserConfig,
49) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
50 let enc = match get_harmony_encoding().await.as_ref() {
51 Ok(e) => e,
52 Err(e) => {
53 tracing::debug!("Failed to load harmony encoding: {e}. Tool calls will not be parsed.");
54 return Ok((vec![], Some(text.to_string())));
55 }
56 };
57
58 let tokens: Vec<u32> = enc.tokenizer().encode_with_special_tokens(text);
60 let messages = match enc.parse_messages_from_completion_tokens(tokens, Some(Role::Assistant)) {
61 Ok(messages) => messages,
62 Err(e) => {
63 tracing::debug!(
64 "Failed to parse messages from completion tokens: {e}. Tool calls will not be parsed."
65 );
66 return Ok((vec![], Some(text.to_string())));
67 }
68 };
69
70 let mut normal_text = String::new();
71
72 let mut res = Vec::with_capacity(messages.len());
73 let mut call_idx = 0; for message in messages.iter() {
76 if message.author.role != Role::Assistant {
77 continue;
78 }
79
80 let channel = message.channel.as_deref();
81 let recipient = message.recipient.as_deref().unwrap_or_default();
82
83 if channel == Some("commentary") && recipient.starts_with("functions.") {
85 let Some(fname) = message
86 .recipient
87 .as_ref()
88 .and_then(|r| r.split('.').nth(1))
89 .filter(|s| !s.is_empty())
90 .map(|s| s.to_string())
91 else {
92 continue;
93 };
94
95 let args = match message.content.first() {
96 Some(Text(text)) => match serde_json::from_str::<Value>(text.text.trim()) {
97 Ok(value) => value,
98 Err(_) => {
99 Value::Null }
101 },
102 _ => {
103 Value::Null }
105 };
106 if !args.is_null() {
108 call_idx += 1;
109 res.push(ToolCallResponse {
110 id: format!("call-{}", call_idx),
111 tp: ToolCallType::Function,
112 function: CalledFunction {
113 name: fname.to_string(),
114 arguments: serde_json::to_string(&args).unwrap(),
116 },
117 });
118 }
119 } else if channel == Some("analysis") {
121 normal_text.push_str(match &message.content[0] {
122 Text(t) => &t.text,
123 _ => "",
124 });
125 }
126 }
127 Ok((res, Some(normal_text.to_string())))
128}
129
130pub fn detect_tool_call_start_harmony(
131 chunk: &str,
132 config: &JsonParserConfig,
133 strict: bool,
134) -> bool {
135 let trimmed = chunk.trim();
136 if trimmed.is_empty() {
137 return false;
138 }
139
140 if strict {
141 let has_complete_token = config
143 .tool_call_start_tokens
144 .iter()
145 .any(|token| !token.is_empty() && trimmed.contains(token));
146
147 if has_complete_token {
148 return true;
149 }
150
151 config.tool_call_start_tokens.iter().any(|token| {
154 if token.is_empty() {
155 return false;
156 }
157 for i in 1..=token.chars().count() {
160 if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
161 let prefix_str = &prefix[..prefix.len()];
162 if trimmed == prefix_str || trimmed.ends_with(prefix_str) {
163 return true;
164 }
165 }
166 }
167 false
168 })
169 } else {
170 let has_complete_token = config
172 .tool_call_start_tokens
173 .iter()
174 .any(|token| !token.is_empty() && trimmed.contains(token));
175
176 if has_complete_token {
177 return true;
178 }
179
180 let has_partial_token = config.tool_call_start_tokens.iter().any(|token| {
182 if token.is_empty() {
183 return false;
184 }
185 for i in 1..=token.chars().count() {
188 if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
189 let prefix_str = &prefix[..prefix.len()];
190 if trimmed == prefix_str || trimmed.ends_with(prefix_str) {
191 return true;
192 }
193 }
194 }
195 false
196 });
197
198 has_partial_token || trimmed.contains("<|channel|>")
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a parsed tool call into its function name and its arguments decoded as JSON.
    fn extract_name_and_args(call: &ToolCallResponse) -> (String, serde_json::Value) {
        let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
        (call.function.name.clone(), args)
    }

    /// Runs the complete-text Harmony parser with a default config, panicking on `Err`.
    async fn parse(text: &str) -> (Vec<ToolCallResponse>, Option<String>) {
        parse_tool_calls_harmony_complete(text, &Default::default())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_parse_tool_calls_harmony_complete_basic() {
        let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#;
        let (tool_calls, normal_content) = parse(text).await;
        assert_eq!(normal_content, Some("".to_string()));
        // Assert the count first so a parse failure reports clearly instead of an index panic.
        assert_eq!(tool_calls.len(), 1);
        let (name, args) = extract_name_and_args(&tool_calls[0]);
        assert_eq!(name, "get_current_weather");
        assert_eq!(args["location"], "San Francisco");
        assert_eq!(args["format"], "celsius");
    }

    #[tokio::test]
    async fn test_parse_tools_harmony_without_start_token() {
        let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|message|>{"location":"San Francisco"}<|call|>"#;
        let (tool_calls, normal_content) = parse(text).await;
        assert_eq!(normal_content, Some(text.trim().to_string()));
        assert_eq!(tool_calls.len(), 0);
    }

    #[tokio::test]
    async fn test_parse_tool_calls_harmony_with_multi_args() {
        let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|>"#;
        let (tool_calls, normal_content) = parse(text).await;
        assert_eq!(
            normal_content,
            Some("Need to use function get_current_weather.".to_string())
        );
        assert_eq!(tool_calls.len(), 1);
        let (name, args) = extract_name_and_args(&tool_calls[0]);
        assert_eq!(name, "get_current_weather");
        assert_eq!(args["location"], "San Francisco");
        assert_eq!(args["unit"], "fahrenheit");
    }

    #[tokio::test]
    async fn test_parse_tool_calls_harmony_with_normal_text() {
        let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>"#;
        let (tool_calls, normal_content) = parse(text).await;
        assert_eq!(
            normal_content,
            Some("Need to use function get_current_weather.".to_string())
        );
        assert_eq!(tool_calls.len(), 1);
        let (name, args) = extract_name_and_args(&tool_calls[0]);
        assert_eq!(name, "get_current_weather");
        assert_eq!(args["location"], "San Francisco");
    }

    #[tokio::test]
    async fn test_parse_tool_calls_harmony_without_call_token() {
        let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#;
        let (tool_calls, normal_content) = parse(text).await;
        assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string()));
        assert_eq!(tool_calls.len(), 1);
        let (name, args) = extract_name_and_args(&tool_calls[0]);
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], "San Francisco, CA");
        assert_eq!(args["unit"], "celsius");
    }
}
286
#[cfg(test)]
mod detect_parser_tests {
    use super::*;

    /// Shared fixture: a Harmony config whose tool calls start with
    /// `<|start|>assistant<|channel|>commentary` and end with `<|call|>`.
    fn harmony_config() -> JsonParserConfig {
        JsonParserConfig {
            tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
            tool_call_end_tokens: vec!["<|call|>".to_string()],
            ..Default::default()
        }
    }

    #[test]
    fn test_detect_tool_call_start_harmony_chunk_with_tool_call_start_token() {
        let chunk = r#"<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json"#;
        assert!(detect_tool_call_start_harmony(chunk, &harmony_config(), false));
    }

    #[test]
    fn test_detect_tool_call_start_harmony_chunk_without_tool_call_start_token() {
        // The `<|start|>assistant` header is missing, so this relies on the non-strict
        // `<|channel|>` fallback.
        let chunk = r#"<|channel|>commentary to=functions.get_current_weather"#;
        assert!(detect_tool_call_start_harmony(chunk, &harmony_config(), false));
    }

    #[test]
    fn test_detect_tool_call_start_harmony_partial_tokens() {
        let config = harmony_config();

        // Growing prefixes of the start token must all be flagged in strict mode.
        let potential_starts = [
            ("<", "'<' should be detected as potential start"),
            ("<|", "'<|' should be detected as potential start"),
            ("<|start|>", "'<|start|>' should be detected as potential start"),
            (
                "<|start|>assistant",
                "'<|start|>assistant' should be detected as potential start",
            ),
        ];
        for (chunk, message) in potential_starts {
            assert!(detect_tool_call_start_harmony(chunk, &config, true), "{message}");
        }

        // Ordinary text shares no prefix with the start token.
        let plain_text = [
            ("hello world", "'hello world' should not be detected in strict mode"),
            ("xyz", "'xyz' should not be detected in strict mode"),
        ];
        for (chunk, message) in plain_text {
            assert!(!detect_tool_call_start_harmony(chunk, &config, true), "{message}");
        }
    }
}