// dynamo_parsers/reasoning/gpt_oss_parser.rs

use std::fmt::Debug;
6use crate::ParserResult;
7use crate::ReasoningParser;
8
9use openai_harmony::StreamableParser;
10use openai_harmony::chat::TextContent;
11use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, chat::Role, load_harmony_encoding};
12
13use std::sync::OnceLock;
16
17static GLOBAL_HARMONY_GPTOSS_ENCODING: OnceLock<Result<HarmonyEncoding, anyhow::Error>> =
18 OnceLock::new();
19
20fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> {
21 GLOBAL_HARMONY_GPTOSS_ENCODING
22 .get_or_init(|| load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss))
23}
24
25pub struct GptOssReasoningParser {
26 parser: StreamableParser,
27}
28
29impl Debug for GptOssReasoningParser {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("GptOssReasoningParser")
33 .field("parser", &self.parser.state_json())
34 .finish()
35 }
36}
37
38impl GptOssReasoningParser {
39 pub fn new() -> anyhow::Result<Self> {
40 let parser = match get_harmony_encoding().as_ref() {
41 Ok(enc) => match StreamableParser::new(enc.clone(), Some(Role::Assistant)) {
42 Ok(p) => p,
43 Err(e) => {
44 tracing::warn!("Harmony StreamableParser init failed for GPT OSS: {e}");
45 return Err(anyhow::anyhow!(
46 "Failed to load Harmony StreamableParser: {e}"
47 ));
48 }
49 },
50 Err(e) => {
51 tracing::warn!("Failed to load Harmony encoding for GPT OSS: {e}");
52 return Err(anyhow::anyhow!("Failed to load Harmony encoding: {e}"));
53 }
54 };
55 Ok(Self { parser })
56 }
57}
58
59fn encode_text_to_tokens(text: &str) -> anyhow::Result<Vec<u32>> {
60 let enc = get_harmony_encoding()
61 .as_ref()
62 .map_err(|e| anyhow::anyhow!("Failed to get harmony encoding: {e}"))?;
63 Ok(enc.tokenizer().encode_with_special_tokens(text))
64}
65
66impl ReasoningParser for GptOssReasoningParser {
67 fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
68 let token_ids = if token_ids.is_empty() {
69 let encoded_tokens = match encode_text_to_tokens(text) {
71 Ok(tokens) => tokens,
72 Err(err) => {
73 tracing::warn!("Failed to encode Harmony tokens: {err}");
74 return ParserResult::default();
75 }
76 };
77 &encoded_tokens.to_vec()
78 } else {
79 token_ids
80 };
81
82 let parser = &mut self.parser;
83
84 for (i, token_id) in token_ids.iter().enumerate() {
85 tracing::debug!(
86 "Processing token {} of {}: {}",
87 i + 1,
88 token_ids.len(),
89 token_id
90 );
91 if let Err(e) = parser.process(*token_id) {
92 tracing::warn!("Harmony parse error for token_id {token_id}: {e}");
93 return ParserResult::default();
94 }
95 }
96
97 let output_msgs = parser.messages();
98 tracing::debug!("Parser has {} output messages", output_msgs.len());
99
100 match output_msgs.len() {
101 0 => {
102 tracing::debug!("No output messages, using current content");
103 let current = parser.current_content().unwrap_or_default();
104 tracing::debug!("Current content length: {}", current.len());
105 ParserResult {
106 normal_text: String::new(),
107 reasoning_text: current,
108 }
109 }
110 1 => {
111 tracing::debug!("Single output message detected");
112 let mut reasoning_text = String::new();
113 if let Some(openai_harmony::chat::Content::Text(TextContent { text })) =
114 output_msgs[0].content.first()
115 {
116 reasoning_text.push_str(text);
117 tracing::debug!("Extracted reasoning text length: {}", reasoning_text.len());
118 }
119 let current = parser.current_content().unwrap_or_default();
120 tracing::debug!("Current content length: {}", current.len());
121 ParserResult {
122 normal_text: current,
123 reasoning_text,
124 }
125 }
126 _ => {
127 tracing::debug!("Multiple output messages detected: {}", output_msgs.len());
128 let mut reasoning_text = String::new();
129 let mut normal_text = String::new();
130
131 for (i, parse_msg) in output_msgs.iter().take(output_msgs.len() - 1).enumerate() {
133 tracing::debug!("Processing reasoning message {}", i + 1);
134 if let Some(openai_harmony::chat::Content::Text(TextContent { text })) =
135 parse_msg.content.first()
136 {
137 reasoning_text.push_str(text);
138 tracing::debug!("Added {} chars to reasoning text", text.len());
139 }
140 }
141
142 let last_msg = &output_msgs[output_msgs.len() - 1];
143 tracing::debug!("Processing final message");
144
145 if let Some(openai_harmony::chat::Content::Text(TextContent { text })) =
147 last_msg.content.first()
148 {
149 normal_text.push_str(text);
150 tracing::debug!("Added {} chars to normal text", text.len());
151 }
152
153 tracing::debug!(
154 "Final result - normal_text: {} chars, reasoning_text: {} chars",
155 normal_text.len(),
156 reasoning_text.len()
157 );
158
159 ParserResult {
160 normal_text,
161 reasoning_text,
162 }
163 }
164 }
165 }
166
167 fn parse_reasoning_streaming_incremental(
168 &mut self,
169 text: &str,
170 token_ids: &[u32],
171 ) -> ParserResult {
172 let token_ids = if token_ids.is_empty() {
173 let encoded_tokens = match encode_text_to_tokens(text) {
175 Ok(tokens) => tokens,
176 Err(err) => {
177 tracing::warn!("Failed to encode Harmony tokens: {err}");
178 return ParserResult::default();
179 }
180 };
181 &encoded_tokens.to_vec()
182 } else {
183 token_ids
184 };
185
186 let parser: &mut StreamableParser = &mut self.parser;
187 let mut normal_delta = String::new();
188 let mut reasoning_delta = String::new();
189
190 for (i, token_id) in token_ids.iter().enumerate() {
191 tracing::debug!(
192 "Processing streaming token {} of {}: {}",
193 i + 1,
194 token_ids.len(),
195 token_id
196 );
197 if let Err(e) = parser.process(*token_id) {
198 tracing::warn!("Harmony parse error for token_id {token_id}: {e}");
199 return ParserResult::default();
200 }
201
202 if let (Some(delta), Some(channel)) = (
203 parser.last_content_delta().unwrap_or_default(),
204 parser.current_channel(),
205 ) {
206 match channel.as_str() {
210 "final" => normal_delta.push_str(&delta),
211 "analysis" => reasoning_delta.push_str(&delta),
212 "commentary" => {}
213 _ => {}
214 }
215 }
216 }
217
218 if !normal_delta.is_empty() || !reasoning_delta.is_empty() {
219 tracing::debug!(
220 "Returning aggregated deltas: normal: {} chars, reasoning: {} chars",
221 normal_delta.len(),
222 reasoning_delta.len()
223 );
224 return ParserResult {
225 normal_text: normal_delta,
226 reasoning_text: reasoning_delta,
227 };
228 }
229
230 if let Some(channel) = parser.current_channel() {
231 if channel == "commentary" {
232 tracing::debug!("In commentary channel, recovering full content");
233 if let Ok(enc) = get_harmony_encoding() {
236 let current_content = parser.current_content().unwrap_or_default();
237 let mut final_text = text.to_string();
238
239 if current_content.is_empty() {
253 let tokens = parser.tokens();
254
255 let channel_token_id = enc
257 .tokenizer()
258 .encode_with_special_tokens("<|channel|>")
259 .last()
260 .copied();
261
262 let last_channel_token_idx = channel_token_id
264 .and_then(|token_id| {
265 tokens.iter().rposition(|token| *token == token_id)
266 })
267 .unwrap_or(0);
268
269 let end_token_idx = parser.tokens().len();
271 let generated_text = enc
273 .tokenizer()
274 .decode_utf8(&parser.tokens()[last_channel_token_idx..end_token_idx])
275 .unwrap_or_default();
276
277 final_text = generated_text;
278 }
279
280 return ParserResult {
281 normal_text: final_text,
282 reasoning_text: String::new(),
283 };
284 }
285 } else {
286 tracing::warn!("Shouldn't be delta content after in channel: {}", channel);
287 }
288 }
289 tracing::debug!("No deltas to return, returning empty result");
290 ParserResult::default()
291 }
292}
293
294#[cfg(test)]
295mod tests {
296 use super::*;
297
298 #[test]
299 fn test_gpt_oss_reasoning_parser() {
300 let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
301 let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";
302 let result = parser.detect_and_parse_reasoning(text, &[]);
303 assert!(result.normal_text == "The capital of Brazil is Brasília.");
304 assert!(
305 result.reasoning_text
306 == "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
307 );
308 }
309
310 #[test]
311 fn test_gpt_oss_reasoning_parser_streaming() {
312 let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
313 let chunks = vec![
314 "<|channel|>",
315 "analysis<|message|>The user asks a simple factual question: capital of Brazil.",
316 " The answer is Brasília. No additional explanation needed.",
317 "<|end|><|start|>assistant<|channel|>final<|message|>",
318 "The capital of Brazil is Brasília.",
319 ];
320 let mut reasoning_text_incr = String::new();
321 let mut normal_text_incr = String::new();
322 for chunk in chunks {
323 let result = parser.parse_reasoning_streaming_incremental(chunk, &[]);
324 normal_text_incr.push_str(&result.normal_text);
325 reasoning_text_incr.push_str(&result.reasoning_text);
326 }
327 assert!(normal_text_incr == "The capital of Brazil is Brasília.");
328 assert!(
329 reasoning_text_incr
330 == "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
331 );
332 }
333
334 #[test]
335 fn test_gpt_oss_reasoning_parser_streaming_chunked() {
336 let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
337 let enc = get_harmony_encoding()
338 .as_ref()
339 .expect("Failed to get encoding");
340 let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";
341 let token_ids = enc.tokenizer().encode_with_special_tokens(text);
342 let mut reasoning_text_incr = String::new();
343 let mut normal_text_incr = String::new();
344
345 let mut idx = 0;
346 let chunk_size = 4;
347 while idx < token_ids.len() {
348 let end = (idx + chunk_size).min(token_ids.len());
349 let result =
350 parser.parse_reasoning_streaming_incremental("Test text", &token_ids[idx..end]);
351 normal_text_incr.push_str(&result.normal_text);
352 reasoning_text_incr.push_str(&result.reasoning_text);
353 idx = end;
354 }
355
356 assert_eq!(normal_text_incr, "The capital of Brazil is Brasília.");
357 assert_eq!(
358 reasoning_text_incr,
359 "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
360 );
361 }
362
363 #[test]
364 fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
365 let text = "<|channel|>analysis<|message|>User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>{}";
366 let enc = get_harmony_encoding()
367 .as_ref()
368 .expect("Failed to get encoding");
369 let token_ids = enc.tokenizer().encode_with_special_tokens(text);
370
371 {
373 let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
374 let mut reasoning_text_incr = String::new();
375 let mut normal_text_incr = String::new();
376 for token in token_ids.iter() {
377 let result = parser.parse_reasoning_streaming_incremental("", &[(*token)]);
378 normal_text_incr.push_str(&result.normal_text);
379 reasoning_text_incr.push_str(&result.reasoning_text);
380 }
381 assert_eq!(
382 reasoning_text_incr,
383 "User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function."
384 );
385 assert_eq!(
387 normal_text_incr,
388 "<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
389 );
390 }
391
392 {
394 let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
395 let mut reasoning_text_incr = String::new();
396 let mut normal_text_incr = String::new();
397 let chunk_tokens = [
398 vec![200005],
399 vec![35644, 200008, 1844, 31064, 25, 392, 25216, 11, 4853],
400 vec![2371, 25, 382, 5519, 869, 326, 6788, 16842, 1416, 1757],
401 vec![2371, 2420, 3230, 2360, 290, 5181, 1114, 717, 39303, 126214],
402 vec![
403 13, 7649, 1114, 13, 200007, 200006, 173781, 200005, 12606, 815,
404 ],
405 vec![
406 316, 28, 44580, 775, 39303, 126214, 220, 200003, 4108, 200008,
407 ],
408 vec![12083],
409 ];
410 let concatenated: Vec<u32> = chunk_tokens.iter().flatten().copied().collect();
412 assert_eq!(concatenated, token_ids);
413
414 for token in chunk_tokens.iter() {
415 let result = parser.parse_reasoning_streaming_incremental("", token);
416 normal_text_incr.push_str(&result.normal_text);
417 reasoning_text_incr.push_str(&result.reasoning_text);
418 }
419 assert_eq!(
420 reasoning_text_incr,
421 "User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function."
422 );
423 assert_eq!(
424 normal_text_incr,
425 "<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
426 );
427 }
428 }
429}