dynamo_parsers/tool_calling/harmony/
harmony_parser.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::config::JsonParserConfig;
5use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
6use openai_harmony::chat::{Content::Text, Role};
7use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, load_harmony_encoding};
8use serde_json::Value;
9
10static GLOBAL_HARMONY_GPTOSS_ENCODING: tokio::sync::OnceCell<
11    Result<HarmonyEncoding, anyhow::Error>,
12> = tokio::sync::OnceCell::const_new();
13
14pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> {
15    GLOBAL_HARMONY_GPTOSS_ENCODING
16        .get_or_init(|| async {
17            tokio::task::spawn_blocking(|| {
18                load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)
19            })
20            .await
21            .map_err(anyhow::Error::msg)
22            .flatten()
23        })
24        .await
25}
26
27/// Parse tool calls from a complete Harmony Format text chunk using direct token parsing.
28///
29/// This function is optimized for parsing complete text chunks where the entire content
30/// is available at once. It uses `parse_messages_from_completion_tokens` to directly
31/// parse all tokens into Harmony Format messages, then extracts tool calls from messages
32/// with the "commentary" channel and "functions.*" recipients.
33///
34/// This function doesn't perform start token detection
35/// or token-by-token streaming, making it more efficient for complete chunks.
36///
37/// # Arguments
38/// * `text` - The full Harmony-format string to be parsed, excluding any trailing stop tokens.
39///   Example:
40///   `<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}`
41/// * `_config` - Parser configuration (currently unused but kept for API consistency)
42///
43/// # Returns
44/// * `Ok((tool_calls, normal_text))` - Tuple containing extracted tool calls and any normal text
45/// * `Err(e)` - If parsing fails due to encoding or tokenization errors
46pub async fn parse_tool_calls_harmony_complete(
47    text: &str,
48    _config: &JsonParserConfig,
49) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
50    let enc = match get_harmony_encoding().await.as_ref() {
51        Ok(e) => e,
52        Err(e) => {
53            tracing::debug!("Failed to load harmony encoding: {e}. Tool calls will not be parsed.");
54            return Ok((vec![], Some(text.to_string())));
55        }
56    };
57
58    // // Encode the text into tokens using harmony encoding
59    let tokens: Vec<u32> = enc.tokenizer().encode_with_special_tokens(text);
60    let messages = match enc.parse_messages_from_completion_tokens(tokens, Some(Role::Assistant)) {
61        Ok(messages) => messages,
62        Err(e) => {
63            tracing::debug!(
64                "Failed to parse messages from completion tokens: {e}. Tool calls will not be parsed."
65            );
66            return Ok((vec![], Some(text.to_string())));
67        }
68    };
69
70    let mut normal_text = String::new();
71
72    let mut res = Vec::with_capacity(messages.len());
73    let mut call_idx = 0; // Index of the tool call
74
75    for message in messages.iter() {
76        if message.author.role != Role::Assistant {
77            continue;
78        }
79
80        let channel = message.channel.as_deref();
81        let recipient = message.recipient.as_deref().unwrap_or_default();
82
83        // Handle commentary channel
84        if channel == Some("commentary") && recipient.starts_with("functions.") {
85            let Some(fname) = message
86                .recipient
87                .as_ref()
88                .and_then(|r| r.split('.').nth(1))
89                .filter(|s| !s.is_empty())
90                .map(|s| s.to_string())
91            else {
92                continue;
93            };
94
95            let args = match message.content.first() {
96                Some(Text(text)) => match serde_json::from_str::<Value>(text.text.trim()) {
97                    Ok(value) => value,
98                    Err(_) => {
99                        Value::Null // Set args to null if it's not valid JSON
100                    }
101                },
102                _ => {
103                    Value::Null // Set args to null if it's not a text content
104                }
105            };
106            // Add tool call to result if args is valid JSON
107            if !args.is_null() {
108                call_idx += 1;
109                res.push(ToolCallResponse {
110                    id: format!("call-{}", call_idx),
111                    tp: ToolCallType::Function,
112                    function: CalledFunction {
113                        name: fname.to_string(),
114                        // Safety: `Value::Object` is always valid JSON, so serialization cannot fail
115                        arguments: serde_json::to_string(&args).unwrap(),
116                    },
117                });
118            }
119        // Handle reasoning(analysis) channel
120        } else if channel == Some("analysis") {
121            normal_text.push_str(match &message.content[0] {
122                Text(t) => &t.text,
123                _ => "",
124            });
125        }
126    }
127    Ok((res, Some(normal_text.to_string())))
128}
129
130pub fn detect_tool_call_start_harmony(
131    chunk: &str,
132    config: &JsonParserConfig,
133    strict: bool,
134) -> bool {
135    let trimmed = chunk.trim();
136    if trimmed.is_empty() {
137        return false;
138    }
139
140    if strict {
141        // Check for complete start tokens first
142        let has_complete_token = config
143            .tool_call_start_tokens
144            .iter()
145            .any(|token| !token.is_empty() && trimmed.contains(token));
146
147        if has_complete_token {
148            return true;
149        }
150
151        // Check for partial start tokens (streaming scenario)
152        // This handles cases where start tokens are split across multiple chunks
153        config.tool_call_start_tokens.iter().any(|token| {
154            if token.is_empty() {
155                return false;
156            }
157            // Check if the chunk could be a prefix of this start token
158            // Handle Unicode character boundaries properly
159            for i in 1..=token.chars().count() {
160                if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
161                    let prefix_str = &prefix[..prefix.len()];
162                    if trimmed == prefix_str || trimmed.ends_with(prefix_str) {
163                        return true;
164                    }
165                }
166            }
167            false
168        })
169    } else {
170        // Non-strict mode: check complete tokens and some heuristics
171        let has_complete_token = config
172            .tool_call_start_tokens
173            .iter()
174            .any(|token| !token.is_empty() && trimmed.contains(token));
175
176        if has_complete_token {
177            return true;
178        }
179
180        // Check for partial start tokens or known patterns
181        let has_partial_token = config.tool_call_start_tokens.iter().any(|token| {
182            if token.is_empty() {
183                return false;
184            }
185            // Check if the chunk could be a prefix of this start token
186            // Handle Unicode character boundaries properly
187            for i in 1..=token.chars().count() {
188                if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
189                    let prefix_str = &prefix[..prefix.len()];
190                    if trimmed == prefix_str || trimmed.ends_with(prefix_str) {
191                        return true;
192                    }
193                }
194            }
195            false
196        });
197
198        has_partial_token || trimmed.contains("<|channel|>")
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a tool call into its function name and its arguments decoded as JSON.
    fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
        let function = call.function;
        let decoded: serde_json::Value = serde_json::from_str(&function.arguments).unwrap();
        (function.name, decoded)
    }

    /// Runs the complete-chunk parser with the default configuration and unwraps the result.
    async fn parse_complete(input: &str) -> (Vec<ToolCallResponse>, Option<String>) {
        parse_tool_calls_harmony_complete(input, &Default::default())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_parse_tool_calls_harmony_complete_basic() {
        let input = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#;
        let (calls, normal) = parse_complete(input).await;

        assert_eq!(normal, Some("".to_string()));
        let (fn_name, fn_args) = extract_name_and_args(calls[0].clone());
        assert_eq!(fn_name, "get_current_weather");
        assert_eq!(fn_args["location"], "San Francisco");
        assert_eq!(fn_args["format"], "celsius");
    }

    #[tokio::test]
    async fn test_parse_tools_harmony_without_start_token() {
        let input = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|message|>{"location":"San Francisco"}<|call|>"#;
        let (calls, normal) = parse_complete(input).await;

        // The input is expected back untouched and no tool call is reported.
        assert_eq!(normal, Some(input.trim().to_string()));
        assert_eq!(calls.len(), 0);
    }

    #[tokio::test]
    async fn test_parse_tool_calls_harmony_with_multi_args() {
        let input = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|>"#;
        let (calls, normal) = parse_complete(input).await;

        assert_eq!(
            normal,
            Some("Need to use function get_current_weather.".to_string())
        );
        assert_eq!(calls.len(), 1);
        let (fn_name, fn_args) = extract_name_and_args(calls[0].clone());
        assert_eq!(fn_name, "get_current_weather");
        assert_eq!(fn_args["location"], "San Francisco");
        assert_eq!(fn_args["unit"], "fahrenheit");
    }

    #[tokio::test]
    async fn test_parse_tool_calls_harmony_with_normal_text() {
        let input = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>"#;
        let (calls, normal) = parse_complete(input).await;

        assert_eq!(
            normal,
            Some("Need to use function get_current_weather.".to_string())
        );
        assert_eq!(calls.len(), 1);
        let (fn_name, fn_args) = extract_name_and_args(calls[0].clone());
        assert_eq!(fn_name, "get_current_weather");
        assert_eq!(fn_args["location"], "San Francisco");
    }

    #[tokio::test]
    async fn test_parse_tool_calls_harmony_without_call_token() {
        let input = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#;
        let (calls, normal) = parse_complete(input).await;

        assert_eq!(normal, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string()));
        assert_eq!(calls.len(), 1);
        let (fn_name, fn_args) = extract_name_and_args(calls[0].clone());
        assert_eq!(fn_name, "get_weather");
        assert_eq!(fn_args["location"], "San Francisco, CA");
        assert_eq!(fn_args["unit"], "celsius");
    }
}
286
#[cfg(test)]
mod detect_parser_tests {
    use super::*;

    /// Configuration shared by every test in this module.
    fn harmony_config() -> JsonParserConfig {
        JsonParserConfig {
            tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
            tool_call_end_tokens: vec!["<|call|>".to_string()],
            ..Default::default()
        }
    }

    #[test]
    fn test_detect_tool_call_start_harmony_chunk_with_tool_call_start_token() {
        let chunk = r#"<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json"#;
        assert!(detect_tool_call_start_harmony(
            chunk,
            &harmony_config(),
            false
        ));
    }

    #[test]
    fn test_detect_tool_call_start_harmony_chunk_without_tool_call_start_token() {
        // Workaround for now: in non-strict mode a chunk without the configured start
        // token is still treated as a tool call start. To be improved in the future.
        let chunk = r#"<|channel|>commentary to=functions.get_current_weather"#;
        assert!(detect_tool_call_start_harmony(
            chunk,
            &harmony_config(),
            false
        ));
    }

    #[test]
    fn test_detect_tool_call_start_harmony_partial_tokens() {
        // Partial token detection for streaming scenarios.
        let config = harmony_config();

        // Prefixes of the start token must be detected in strict mode.
        for partial in ["<", "<|", "<|start|>", "<|start|>assistant"] {
            assert!(
                detect_tool_call_start_harmony(partial, &config, true),
                "'{partial}' should be detected as potential start"
            );
        }

        // Unrelated text must not be detected in strict mode.
        for unrelated in ["hello world", "xyz"] {
            assert!(
                !detect_tool_call_start_harmony(unrelated, &config, true),
                "'{unrelated}' should not be detected in strict mode"
            );
        }
    }
}
354}