dynamo_parsers/reasoning/
gpt_oss_parser.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Debug;
5
6use crate::ParserResult;
7use crate::ReasoningParser;
8
9use openai_harmony::StreamableParser;
10use openai_harmony::chat::TextContent;
11use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, chat::Role, load_harmony_encoding};
12
// Static initialization of the harmony encoder so parser creation does not pay the load cost
// every time: load_harmony_encoding downloads tiktoken files into a directory, and we don't
// want to do that on every parser construction.
15use std::sync::OnceLock;
16
// Process-wide, lazily-initialized Harmony encoding shared by all parsers.
// Note: the `Result` itself is cached, so a failed load is NOT retried later.
static GLOBAL_HARMONY_GPTOSS_ENCODING: OnceLock<Result<HarmonyEncoding, anyhow::Error>> =
    OnceLock::new();

/// Returns the cached Harmony GPT-OSS encoding, loading it on first use.
fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> {
    GLOBAL_HARMONY_GPTOSS_ENCODING
        .get_or_init(|| load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss))
}
24
/// Reasoning parser for GPT-OSS model output, backed by Harmony's
/// `StreamableParser`, which tracks channel state (`analysis` / `final` /
/// `commentary`) across incrementally processed tokens.
pub struct GptOssReasoningParser {
    // Stateful incremental parser; retains channel/message state between
    // streaming calls.
    parser: StreamableParser,
}
28
29/// Implement Debug for GptOssReasoningParser separately because StreamableParser does not implement Debug
30impl Debug for GptOssReasoningParser {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("GptOssReasoningParser")
33            .field("parser", &self.parser.state_json())
34            .finish()
35    }
36}
37
38impl GptOssReasoningParser {
39    pub fn new() -> anyhow::Result<Self> {
40        let parser = match get_harmony_encoding().as_ref() {
41            Ok(enc) => match StreamableParser::new(enc.clone(), Some(Role::Assistant)) {
42                Ok(p) => p,
43                Err(e) => {
44                    tracing::warn!("Harmony StreamableParser init failed for GPT OSS: {e}");
45                    return Err(anyhow::anyhow!(
46                        "Failed to load Harmony StreamableParser: {e}"
47                    ));
48                }
49            },
50            Err(e) => {
51                tracing::warn!("Failed to load Harmony encoding for GPT OSS: {e}");
52                return Err(anyhow::anyhow!("Failed to load Harmony encoding: {e}"));
53            }
54        };
55        Ok(Self { parser })
56    }
57}
58
59fn encode_text_to_tokens(text: &str) -> anyhow::Result<Vec<u32>> {
60    let enc = get_harmony_encoding()
61        .as_ref()
62        .map_err(|e| anyhow::anyhow!("Failed to get harmony encoding: {e}"))?;
63    Ok(enc.tokenizer().encode_with_special_tokens(text))
64}
65
66impl ReasoningParser for GptOssReasoningParser {
67    fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
68        let token_ids = if token_ids.is_empty() {
69            // WAR: Since we are moving to just text based reasoning parsing, converting to token_ids now using harmony encoding
70            let encoded_tokens = match encode_text_to_tokens(text) {
71                Ok(tokens) => tokens,
72                Err(err) => {
73                    tracing::warn!("Failed to encode Harmony tokens: {err}");
74                    return ParserResult::default();
75                }
76            };
77            &encoded_tokens.to_vec()
78        } else {
79            token_ids
80        };
81
82        let parser = &mut self.parser;
83
84        for (i, token_id) in token_ids.iter().enumerate() {
85            tracing::debug!(
86                "Processing token {} of {}: {}",
87                i + 1,
88                token_ids.len(),
89                token_id
90            );
91            if let Err(e) = parser.process(*token_id) {
92                tracing::warn!("Harmony parse error for token_id {token_id}: {e}");
93                return ParserResult::default();
94            }
95        }
96
97        let output_msgs = parser.messages();
98        tracing::debug!("Parser has {} output messages", output_msgs.len());
99
100        match output_msgs.len() {
101            0 => {
102                tracing::debug!("No output messages, using current content");
103                let current = parser.current_content().unwrap_or_default();
104                tracing::debug!("Current content length: {}", current.len());
105                ParserResult {
106                    normal_text: String::new(),
107                    reasoning_text: current,
108                }
109            }
110            1 => {
111                tracing::debug!("Single output message detected");
112                let mut reasoning_text = String::new();
113                if let Some(openai_harmony::chat::Content::Text(TextContent { text })) =
114                    output_msgs[0].content.first()
115                {
116                    reasoning_text.push_str(text);
117                    tracing::debug!("Extracted reasoning text length: {}", reasoning_text.len());
118                }
119                let current = parser.current_content().unwrap_or_default();
120                tracing::debug!("Current content length: {}", current.len());
121                ParserResult {
122                    normal_text: current,
123                    reasoning_text,
124                }
125            }
126            _ => {
127                tracing::debug!("Multiple output messages detected: {}", output_msgs.len());
128                let mut reasoning_text = String::new();
129                let mut normal_text = String::new();
130
131                // Loop until second last message
132                for (i, parse_msg) in output_msgs.iter().take(output_msgs.len() - 1).enumerate() {
133                    tracing::debug!("Processing reasoning message {}", i + 1);
134                    if let Some(openai_harmony::chat::Content::Text(TextContent { text })) =
135                        parse_msg.content.first()
136                    {
137                        reasoning_text.push_str(text);
138                        tracing::debug!("Added {} chars to reasoning text", text.len());
139                    }
140                }
141
142                let last_msg = &output_msgs[output_msgs.len() - 1];
143                tracing::debug!("Processing final message");
144
145                // Handle the last message
146                if let Some(openai_harmony::chat::Content::Text(TextContent { text })) =
147                    last_msg.content.first()
148                {
149                    normal_text.push_str(text);
150                    tracing::debug!("Added {} chars to normal text", text.len());
151                }
152
153                tracing::debug!(
154                    "Final result - normal_text: {} chars, reasoning_text: {} chars",
155                    normal_text.len(),
156                    reasoning_text.len()
157                );
158
159                ParserResult {
160                    normal_text,
161                    reasoning_text,
162                }
163            }
164        }
165    }
166
167    fn parse_reasoning_streaming_incremental(
168        &mut self,
169        text: &str,
170        token_ids: &[u32],
171    ) -> ParserResult {
172        let token_ids = if token_ids.is_empty() {
173            // WAR: Since we are moving to just text based reasoning parsing, converting to token_ids now using harmony encoding
174            let encoded_tokens = match encode_text_to_tokens(text) {
175                Ok(tokens) => tokens,
176                Err(err) => {
177                    tracing::warn!("Failed to encode Harmony tokens: {err}");
178                    return ParserResult::default();
179                }
180            };
181            &encoded_tokens.to_vec()
182        } else {
183            token_ids
184        };
185
186        let parser: &mut StreamableParser = &mut self.parser;
187        let mut normal_delta = String::new();
188        let mut reasoning_delta = String::new();
189
190        for (i, token_id) in token_ids.iter().enumerate() {
191            tracing::debug!(
192                "Processing streaming token {} of {}: {}",
193                i + 1,
194                token_ids.len(),
195                token_id
196            );
197            if let Err(e) = parser.process(*token_id) {
198                tracing::warn!("Harmony parse error for token_id {token_id}: {e}");
199                return ParserResult::default();
200            }
201
202            if let (Some(delta), Some(channel)) = (
203                parser.last_content_delta().unwrap_or_default(),
204                parser.current_channel(),
205            ) {
206                // `last_content_delta` only exposes the newest token slice, so we forward
207                // `final`/`analysis` chunks immediately; commentary is reconstructed in the
208                // fallback path below because it needs the stripped metadata.
209                match channel.as_str() {
210                    "final" => normal_delta.push_str(&delta),
211                    "analysis" => reasoning_delta.push_str(&delta),
212                    "commentary" => {}
213                    _ => {}
214                }
215            }
216        }
217
218        if !normal_delta.is_empty() || !reasoning_delta.is_empty() {
219            tracing::debug!(
220                "Returning aggregated deltas: normal: {} chars, reasoning: {} chars",
221                normal_delta.len(),
222                reasoning_delta.len()
223            );
224            return ParserResult {
225                normal_text: normal_delta,
226                reasoning_text: reasoning_delta,
227            };
228        }
229
230        if let Some(channel) = parser.current_channel() {
231            if channel == "commentary" {
232                tracing::debug!("In commentary channel, recovering full content");
233                // If we're in the commentary channel, we should return raw token content and recover content that has been consumed by the parser
234                // so that the tool parser can process it properly
235                if let Ok(enc) = get_harmony_encoding() {
236                    let current_content = parser.current_content().unwrap_or_default();
237                    let mut final_text = text.to_string();
238
239                    // Restore commentary metadata consumed by the parser so the tool-call parser can
240                    // process it correctly.
241                    //
242                    // Example:
243                    //   Before parsing:
244                    //   "<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\"format\":\"celsius\",\"location\":\"San Francisco\"}<|call|>"
245                    //   After parsing, the header is stripped, so we must reconstruct it:
246                    //   "<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>"
247                    //
248                    // This ensures downstream tool-call parsing receives the channel, target, and
249                    // constraint metadata together with the message payload.
250
251                    // Recovery should only happen once, and only when `current_content` is empty.
252                    if current_content.is_empty() {
253                        let tokens = parser.tokens();
254
255                        // Get the token id for " <|channel|>"
256                        let channel_token_id = enc
257                            .tokenizer()
258                            .encode_with_special_tokens("<|channel|>")
259                            .last()
260                            .copied();
261
262                        // Find the last occurrence of the <|channel|> token (id 20005) in the tokens vector
263                        let last_channel_token_idx = channel_token_id
264                            .and_then(|token_id| {
265                                tokens.iter().rposition(|token| *token == token_id)
266                            })
267                            .unwrap_or(0);
268
269                        // Then get the generated text from the last <|channel|> to the end of parser.tokens()
270                        let end_token_idx = parser.tokens().len();
271                        // Use Harmony's decode_utf8 to decode tokens into text
272                        let generated_text = enc
273                            .tokenizer()
274                            .decode_utf8(&parser.tokens()[last_channel_token_idx..end_token_idx])
275                            .unwrap_or_default();
276
277                        final_text = generated_text;
278                    }
279
280                    return ParserResult {
281                        normal_text: final_text,
282                        reasoning_text: String::new(),
283                    };
284                }
285            } else {
286                tracing::warn!("Shouldn't be delta content after in channel: {}", channel);
287            }
288        }
289        tracing::debug!("No deltas to return, returning empty result");
290        ParserResult::default()
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-streaming parse: analysis channel becomes reasoning text, final
    /// channel becomes normal text.
    #[test]
    fn test_gpt_oss_reasoning_parser() {
        let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
        let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";
        let result = parser.detect_and_parse_reasoning(text, &[]);
        // assert_eq! (rather than assert!(a == b)) so failures show both sides.
        assert_eq!(result.normal_text, "The capital of Brazil is Brasília.");
        assert_eq!(
            result.reasoning_text,
            "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
        );
    }

    /// Streaming parse over text chunks: concatenated deltas must reproduce
    /// the full reasoning and normal text.
    #[test]
    fn test_gpt_oss_reasoning_parser_streaming() {
        let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
        let chunks = vec![
            "<|channel|>",
            "analysis<|message|>The user asks a simple factual question: capital of Brazil.",
            " The answer is Brasília. No additional explanation needed.",
            "<|end|><|start|>assistant<|channel|>final<|message|>",
            "The capital of Brazil is Brasília.",
        ];
        let mut reasoning_text_incr = String::new();
        let mut normal_text_incr = String::new();
        for chunk in chunks {
            let result = parser.parse_reasoning_streaming_incremental(chunk, &[]);
            normal_text_incr.push_str(&result.normal_text);
            reasoning_text_incr.push_str(&result.reasoning_text);
        }
        assert_eq!(normal_text_incr, "The capital of Brazil is Brasília.");
        assert_eq!(
            reasoning_text_incr,
            "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
        );
    }

    /// Streaming parse over fixed-size token-id chunks; token ids take
    /// precedence over the (dummy) text argument.
    #[test]
    fn test_gpt_oss_reasoning_parser_streaming_chunked() {
        let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
        let enc = get_harmony_encoding()
            .as_ref()
            .expect("Failed to get encoding");
        let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";
        let token_ids = enc.tokenizer().encode_with_special_tokens(text);
        let mut reasoning_text_incr = String::new();
        let mut normal_text_incr = String::new();

        let mut idx = 0;
        let chunk_size = 4;
        while idx < token_ids.len() {
            let end = (idx + chunk_size).min(token_ids.len());
            let result =
                parser.parse_reasoning_streaming_incremental("Test text", &token_ids[idx..end]);
            normal_text_incr.push_str(&result.normal_text);
            reasoning_text_incr.push_str(&result.reasoning_text);
            idx = end;
        }

        assert_eq!(normal_text_incr, "The capital of Brazil is Brasília.");
        assert_eq!(
            reasoning_text_incr,
            "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
        );
    }

    /// Commentary-channel recovery: the stripped "<|channel|>commentary ..."
    /// header must be reconstructed in normal_text so the tool-call parser can
    /// consume it — regardless of how the tokens are chunked.
    #[test]
    fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
        let text = "<|channel|>analysis<|message|>User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>{}";
        let enc = get_harmony_encoding()
            .as_ref()
            .expect("Failed to get encoding");
        let token_ids = enc.tokenizer().encode_with_special_tokens(text);

        // Send token one by one
        {
            let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
            let mut reasoning_text_incr = String::new();
            let mut normal_text_incr = String::new();
            for token in token_ids.iter() {
                let result = parser.parse_reasoning_streaming_incremental("", &[(*token)]);
                normal_text_incr.push_str(&result.normal_text);
                reasoning_text_incr.push_str(&result.reasoning_text);
            }
            assert_eq!(
                reasoning_text_incr,
                "User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function."
            );
            // [gluo TODO] missing "<|start|>assistant" and "{}" from original message
            assert_eq!(
                normal_text_incr,
                "<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
            );
        }

        // Send token in chunks (chunking obtained from actual model output)
        {
            let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
            let mut reasoning_text_incr = String::new();
            let mut normal_text_incr = String::new();
            let chunk_tokens = [
                vec![200005],
                vec![35644, 200008, 1844, 31064, 25, 392, 25216, 11, 4853],
                vec![2371, 25, 382, 5519, 869, 326, 6788, 16842, 1416, 1757],
                vec![2371, 2420, 3230, 2360, 290, 5181, 1114, 717, 39303, 126214],
                vec![
                    13, 7649, 1114, 13, 200007, 200006, 173781, 200005, 12606, 815,
                ],
                vec![
                    316, 28, 44580, 775, 39303, 126214, 220, 200003, 4108, 200008,
                ],
                vec![12083],
            ];
            // Concatenate chunk tokens and verify they match original token_ids
            let concatenated: Vec<u32> = chunk_tokens.iter().flatten().copied().collect();
            assert_eq!(concatenated, token_ids);

            for token in chunk_tokens.iter() {
                let result = parser.parse_reasoning_streaming_incremental("", token);
                normal_text_incr.push_str(&result.normal_text);
                reasoning_text_incr.push_str(&result.reasoning_text);
            }
            assert_eq!(
                reasoning_text_incr,
                "User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function."
            );
            assert_eq!(
                normal_text_incr,
                "<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
            );
        }
    }
}