1use std::collections::HashMap;
15
16use async_trait::async_trait;
17use futures::StreamExt;
18
19use crate::error::Result;
20use crate::language_model::{
21 BoxStream, CallOptions, Content, GenerateResult, LanguageModel, ReasoningPart, StreamPart,
22 StreamResult, TextPart,
23};
24use crate::middleware::language_model::LanguageModelMiddleware;
25
/// Language-model middleware that moves `<tag>…</tag>` sections of a model's
/// text output into reasoning content, both for `do_generate` results and for
/// streamed parts.
///
/// Build it with [`ExtractReasoningMiddleware::new`] and optionally override
/// the separator with [`ExtractReasoningMiddleware::with_separator`].
#[derive(Debug, Clone)]
pub struct ExtractReasoningMiddleware {
    /// Tag name without angle brackets, e.g. `"think"` for `<think>…</think>`.
    tag_name: String,
    /// When `true`, model output is treated as if it began with an opening
    /// tag, i.e. everything up to the first closing tag is reasoning.
    start_with_reasoning: bool,
    /// Joins multiple reasoning sections and glues together the text on both
    /// sides of a removed section; `"\n"` unless overridden.
    separator: String,
}
33
34impl ExtractReasoningMiddleware {
35 #[must_use]
45 pub fn new(tag_name: impl Into<String>, start_with_reasoning: bool) -> Self {
46 Self {
47 tag_name: tag_name.into(),
48 start_with_reasoning,
49 separator: "\n".to_owned(),
50 }
51 }
52
53 #[must_use]
57 pub fn with_separator(mut self, separator: impl Into<String>) -> Self {
58 self.separator = separator.into();
59 self
60 }
61}
62
/// Finds every non-overlapping `<tag>…</tag>` section in `input`, scanning
/// left to right and pairing each opening tag with the nearest closing tag
/// after it.
///
/// Each element is `(start, len, body)`: the byte offset of the opening tag,
/// the byte length of the whole section including both tags, and the text
/// between the tags. Scanning stops at the first opening tag that has no
/// closing tag after it.
fn find_matches(input: &str, tag: &str) -> Vec<(usize, usize, String)> {
    let opening = format!("<{tag}>");
    let closing = format!("</{tag}>");
    let mut search_from = 0;
    std::iter::from_fn(|| {
        let start = search_from + input[search_from..].find(&opening)?;
        let body_start = start + opening.len();
        let body_end = body_start + input[body_start..].find(&closing)?;
        let end = body_end + closing.len();
        search_from = end;
        Some((start, end - start, input[body_start..body_end].to_owned()))
    })
    .collect()
}
86
87fn extract_reasoning_join(
93 input: &str,
94 tag: &str,
95 start_with_reasoning: bool,
96 separator: &str,
97) -> Option<(String, String)> {
98 let owned;
99 let text: &str = if start_with_reasoning {
100 owned = format!("<{tag}>{input}");
101 &owned
102 } else {
103 input
104 };
105 let matches = find_matches(text, tag);
106 if matches.is_empty() {
107 return None;
108 }
109 let reasoning = matches
110 .iter()
111 .map(|m| m.2.as_str())
112 .collect::<Vec<_>>()
113 .join(separator);
114
115 let mut text_without = text.to_owned();
118 for (start, len, _) in matches.iter().rev() {
119 let before = text_without[..*start].to_owned();
120 let after = text_without[start + len..].to_owned();
121 text_without = if !before.is_empty() && !after.is_empty() {
122 format!("{before}{separator}{after}")
123 } else {
124 format!("{before}{after}")
125 };
126 }
127 Some((reasoning, text_without))
128}
129
130#[async_trait]
131impl LanguageModelMiddleware for ExtractReasoningMiddleware {
132 async fn wrap_generate(
133 &self,
134 next: &dyn LanguageModel,
135 params: CallOptions,
136 ) -> Result<GenerateResult> {
137 let mut result = next.do_generate(params).await?;
138 let mut new_content: Vec<Content> = Vec::with_capacity(result.content.len());
139 for c in result.content.drain(..) {
140 match c {
141 Content::Text(t) => {
142 if let Some((reasoning, text_without)) = extract_reasoning_join(
143 &t.text,
144 &self.tag_name,
145 self.start_with_reasoning,
146 &self.separator,
147 ) {
148 new_content.push(Content::Reasoning(ReasoningPart {
152 text: reasoning,
153 provider_options: t.provider_options.clone(),
154 }));
155 new_content.push(Content::Text(TextPart {
156 text: text_without,
157 provider_options: t.provider_options,
158 }));
159 } else {
160 new_content.push(Content::Text(t));
163 }
164 }
165 other => new_content.push(other),
166 }
167 }
168 result.content = new_content;
169 Ok(result)
170 }
171
172 async fn wrap_stream(
173 &self,
174 next: &dyn LanguageModel,
175 params: CallOptions,
176 ) -> Result<StreamResult> {
177 let upstream = next.do_stream(params).await?;
178 let StreamResult {
179 stream,
180 request,
181 response,
182 } = upstream;
183
184 let cleaned = transform_stream(
185 stream,
186 self.tag_name.clone(),
187 self.start_with_reasoning,
188 self.separator.clone(),
189 );
190 Ok(StreamResult {
191 stream: cleaned,
192 request,
193 response,
194 })
195 }
196}
197
/// Returns the byte index in `haystack` where `needle` starts, or where a
/// *partial* occurrence starts: a suffix of `haystack` that is a proper
/// prefix of `needle`, so the tag might be completed by later input.
///
/// A complete occurrence takes priority. Otherwise the *earliest* partial
/// occurrence is returned, so no byte that might belong to the tag is
/// released early. This matters for needles whose prefix repeats inside
/// them: `"xaa"` vs `"aab"` yields `Some(1)`, not `Some(2)`.
///
/// Returns `None` for an empty `needle` or when neither kind of occurrence
/// exists. The returned index is always a char boundary of `haystack`.
fn potential_start_index(haystack: &str, needle: &str) -> Option<usize> {
    if needle.is_empty() {
        return None;
    }
    if let Some(direct) = haystack.find(needle) {
        return Some(direct);
    }
    // Only suffixes strictly shorter than `needle` can be a proper prefix of
    // it; anything at least as long would have been found by `find` above.
    let first_candidate = haystack.len().saturating_sub(needle.len() - 1);
    (first_candidate..haystack.len())
        .filter(|&i| haystack.is_char_boundary(i))
        .find(|&i| needle.starts_with(&haystack[i..]))
}
222
/// Incremental extraction state for one streamed text block (one text id).
#[derive(Debug)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "Mirrors upstream `reasoningExtractions[chunk.id]` shape — four independent boolean phase flags; collapsing them obscures the upstream comparison."
)]
struct Extraction {
    /// `true` until the first non-empty reasoning text has been published.
    is_first_reasoning: bool,
    /// `true` until the first non-empty text has been published.
    is_first_text: bool,
    /// Set when a tag is consumed and the mode flips; cleared by the next
    /// non-empty publish.
    after_switch: bool,
    /// Whether buffered input is currently classified as reasoning; starts
    /// as the middleware's `start_with_reasoning`.
    is_reasoning: bool,
    /// Input received but not yet published; may end in a partial tag.
    buffer: String,
    /// Number `n` in the current `reasoning-{n}` id; incremented after each
    /// `ReasoningEnd`.
    id_counter: u32,
    /// Id of the upstream text block, reused for emitted `TextDelta`s.
    text_id: String,
}
239
/// State threaded through the `unfold` that drives `transform_stream`.
struct StreamCtx {
    /// The wrapped model's stream of parts.
    stream: BoxStream<Result<StreamPart>>,
    /// Per text-block extraction state, keyed by text id. An entry is created
    /// on the block's first delta and removed at its `TextEnd`.
    extractions: HashMap<String, Extraction>,
    /// Tag name without angle brackets (`"think"` for `<think>…</think>`).
    tag: String,
    /// Initial `is_reasoning` value for every new extraction.
    start_with_reasoning: bool,
    /// Separator prepended to output that resumes after a tag switch.
    separator: String,
    /// A `TextStart` held back until the first text delta or the `TextEnd`,
    /// so reasoning parts published before any text are emitted ahead of it.
    /// This is one slot shared by all text ids.
    delayed_text_start: Option<StreamPart>,
}
248
249fn transform_stream(
255 stream: BoxStream<Result<StreamPart>>,
256 tag: String,
257 start_with_reasoning: bool,
258 separator: String,
259) -> BoxStream<Result<StreamPart>> {
260 let ctx = StreamCtx {
261 stream,
262 extractions: HashMap::new(),
263 tag,
264 start_with_reasoning,
265 separator,
266 delayed_text_start: None,
267 };
268 let mapped = futures::stream::unfold(ctx, |mut ctx| async move {
269 loop {
270 match ctx.stream.next().await {
271 None => return None,
272 Some(Err(e)) => return Some((vec![Err(e)], ctx)),
273 Some(Ok(part)) => {
274 let out = process_part(&mut ctx, part);
275 if !out.is_empty() {
276 return Some((out, ctx));
277 }
278 }
281 }
282 }
283 })
284 .flat_map(futures::stream::iter);
285 Box::pin(mapped)
286}
287
288fn process_part(ctx: &mut StreamCtx, part: StreamPart) -> Vec<Result<StreamPart>> {
289 match part {
290 StreamPart::TextStart { .. } => {
295 ctx.delayed_text_start = Some(part);
296 Vec::new()
297 }
298 StreamPart::TextDelta { id, delta, .. } => process_text_delta(ctx, &id, &delta),
299 StreamPart::TextEnd {
300 id,
301 provider_metadata,
302 } => {
303 let mut out: Vec<Result<StreamPart>> = Vec::new();
304 if let Some(start) = ctx.delayed_text_start.take() {
305 out.push(Ok(start));
306 }
307 ctx.extractions.remove(&id);
310 out.push(Ok(StreamPart::TextEnd {
311 id,
312 provider_metadata,
313 }));
314 out
315 }
316 other => vec![Ok(other)],
317 }
318}
319
320fn process_text_delta(ctx: &mut StreamCtx, id: &str, delta: &str) -> Vec<Result<StreamPart>> {
321 let opening_tag = format!("<{}>", ctx.tag);
322 let closing_tag = format!("</{}>", ctx.tag);
323
324 let extraction = ctx
325 .extractions
326 .entry(id.to_owned())
327 .or_insert_with(|| Extraction {
328 is_first_reasoning: true,
329 is_first_text: true,
330 after_switch: false,
331 is_reasoning: ctx.start_with_reasoning,
332 buffer: String::new(),
333 id_counter: 0,
334 text_id: id.to_owned(),
335 });
336 extraction.buffer.push_str(delta);
337
338 let mut out: Vec<Result<StreamPart>> = Vec::new();
339 loop {
340 let next_tag: &str = if extraction.is_reasoning {
341 &closing_tag
342 } else {
343 &opening_tag
344 };
345
346 let start_index = potential_start_index(&extraction.buffer, next_tag);
347 let Some(start_idx) = start_index else {
348 let snapshot = std::mem::take(&mut extraction.buffer);
350 publish(
351 extraction,
352 &snapshot,
353 &ctx.separator,
354 &mut ctx.delayed_text_start,
355 &mut out,
356 );
357 break;
358 };
359
360 let before = extraction.buffer[..start_idx].to_owned();
362 publish(
363 extraction,
364 &before,
365 &ctx.separator,
366 &mut ctx.delayed_text_start,
367 &mut out,
368 );
369
370 let after_tag = start_idx + next_tag.len();
371 let full_match = after_tag <= extraction.buffer.len();
372 if !full_match {
373 extraction.buffer = extraction.buffer[start_idx..].to_owned();
375 break;
376 }
377
378 extraction.buffer = extraction.buffer[after_tag..].to_owned();
379 if extraction.is_reasoning {
380 if extraction.is_first_reasoning {
384 out.push(Ok(StreamPart::ReasoningStart {
385 id: format!("reasoning-{}", extraction.id_counter),
386 provider_metadata: None,
387 }));
388 }
389 out.push(Ok(StreamPart::ReasoningEnd {
390 id: format!("reasoning-{}", extraction.id_counter),
391 provider_metadata: None,
392 }));
393 extraction.id_counter += 1;
394 }
395 extraction.is_reasoning = !extraction.is_reasoning;
396 extraction.after_switch = true;
397 }
398 out
399}
400
401fn publish(
402 extraction: &mut Extraction,
403 text: &str,
404 separator: &str,
405 delayed_text_start: &mut Option<StreamPart>,
406 out: &mut Vec<Result<StreamPart>>,
407) {
408 if text.is_empty() {
409 return;
410 }
411 let needs_prefix = extraction.after_switch
412 && (if extraction.is_reasoning {
413 !extraction.is_first_reasoning
414 } else {
415 !extraction.is_first_text
416 });
417 let payload = if needs_prefix {
418 format!("{separator}{text}")
419 } else {
420 text.to_owned()
421 };
422
423 if extraction.is_reasoning {
424 if extraction.after_switch || extraction.is_first_reasoning {
425 out.push(Ok(StreamPart::ReasoningStart {
426 id: format!("reasoning-{}", extraction.id_counter),
427 provider_metadata: None,
428 }));
429 }
430 out.push(Ok(StreamPart::ReasoningDelta {
431 id: format!("reasoning-{}", extraction.id_counter),
432 delta: payload,
433 provider_metadata: None,
434 }));
435 } else {
436 if let Some(start) = delayed_text_start.take() {
437 out.push(Ok(start));
438 }
439 out.push(Ok(StreamPart::TextDelta {
440 id: extraction.text_id.clone(),
441 delta: payload,
442 provider_metadata: None,
443 }));
444 }
445
446 extraction.after_switch = false;
447 if extraction.is_reasoning {
448 extraction.is_first_reasoning = false;
449 } else {
450 extraction.is_first_text = false;
451 }
452}
453
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use futures::stream;

    use super::*;
    use crate::language_model::{FinishReason, FinishReasonKind, Usage};
    use crate::middleware::wrap_language_model;

    /// Builds the upstream stream shared by both fakes: one text block `"b"`
    /// made of a `TextStart`, one `TextDelta` per chunk, a `TextEnd`, and a
    /// closing `Finish`.
    fn text_block_stream(chunks: &[String]) -> StreamResult {
        let mut parts: Vec<Result<StreamPart>> = vec![Ok(StreamPart::TextStart {
            id: "b".into(),
            provider_metadata: None,
        })];
        parts.extend(chunks.iter().map(|chunk| {
            Ok(StreamPart::TextDelta {
                id: "b".into(),
                delta: chunk.clone(),
                provider_metadata: None,
            })
        }));
        parts.push(Ok(StreamPart::TextEnd {
            id: "b".into(),
            provider_metadata: None,
        }));
        parts.push(Ok(StreamPart::Finish {
            usage: Usage::default(),
            finish_reason: FinishReason::new(FinishReasonKind::Stop),
            provider_metadata: None,
        }));
        StreamResult {
            stream: Box::pin(stream::iter(parts)),
            request: None,
            response: None,
        }
    }

    /// Model that answers with `text`: whole from `do_generate`, and as a
    /// single delta from `do_stream`.
    #[derive(Debug)]
    struct Fake {
        text: String,
    }

    #[async_trait]
    impl LanguageModel for Fake {
        fn provider(&self) -> &'static str {
            "fake"
        }
        fn model_id(&self) -> &'static str {
            "fake"
        }
        async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResult> {
            Ok(GenerateResult {
                content: vec![Content::Text(TextPart {
                    text: self.text.clone(),
                    provider_options: None,
                })],
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }
        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
            Ok(text_block_stream(std::slice::from_ref(&self.text)))
        }
    }

    /// Streaming-only model that emits each entry of `chunks` as its own delta.
    #[derive(Debug)]
    struct MultiChunkFake {
        chunks: Vec<String>,
    }

    #[async_trait]
    impl LanguageModel for MultiChunkFake {
        fn provider(&self) -> &'static str {
            "fake"
        }
        fn model_id(&self) -> &'static str {
            "fake"
        }
        async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResult> {
            unimplemented!()
        }
        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
            Ok(text_block_stream(&self.chunks))
        }
    }

    /// Wraps `inner` in `middleware` and returns the generated result.
    async fn generate_with(
        inner: Arc<dyn LanguageModel>,
        middleware: ExtractReasoningMiddleware,
    ) -> GenerateResult {
        let wrapped = wrap_language_model(
            inner,
            [Arc::new(middleware) as Arc<dyn LanguageModelMiddleware>],
        );
        let result = wrapped.do_generate(CallOptions::default()).await;
        result.expect("gen")
    }

    /// Wraps `inner` in a `"think"` extractor (`start_with_reasoning = false`)
    /// and drains the resulting stream, panicking on any error item.
    async fn stream_think_parts(inner: Arc<dyn LanguageModel>) -> Vec<StreamPart> {
        let wrapped = wrap_language_model(
            inner,
            [Arc::new(ExtractReasoningMiddleware::new("think", false))
                as Arc<dyn LanguageModelMiddleware>],
        );
        let mut streamed = wrapped.do_stream(CallOptions::default()).await.unwrap();
        let mut parts = Vec::new();
        while let Some(item) = streamed.stream.next().await {
            parts.push(item.unwrap());
        }
        parts
    }

    #[tokio::test]
    async fn generate_splits_single_think_tag() {
        let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
            text: "<think>analyzing the request</think>Here is the response".into(),
        });
        let r = generate_with(inner, ExtractReasoningMiddleware::new("think", false)).await;
        assert_eq!(r.content.len(), 2, "always reasoning + text");
        match (&r.content[0], &r.content[1]) {
            (Content::Reasoning(reasoning), Content::Text(answer)) => {
                assert_eq!(reasoning.text, "analyzing the request");
                assert_eq!(answer.text, "Here is the response");
            }
            other => panic!("unexpected split: {other:?}"),
        }
    }

    #[tokio::test]
    async fn generate_joins_multiple_think_tags_with_separator() {
        let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
            text: "<think>analyzing the request</think>Here is the response<think>thinking about the response</think>more".into(),
        });
        let r = generate_with(inner, ExtractReasoningMiddleware::new("think", false)).await;
        assert_eq!(r.content.len(), 2);
        match (&r.content[0], &r.content[1]) {
            (Content::Reasoning(reasoning), Content::Text(answer)) => {
                assert_eq!(
                    reasoning.text,
                    "analyzing the request\nthinking about the response"
                );
                assert_eq!(answer.text, "Here is the response\nmore");
            }
            other => panic!("unexpected split: {other:?}"),
        }
    }

    #[tokio::test]
    async fn generate_preserves_text_when_tag_absent() {
        let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
            text: "no tags here".into(),
        });
        let r = generate_with(inner, ExtractReasoningMiddleware::new("think", false)).await;
        assert_eq!(r.content.len(), 1);
        assert!(matches!(&r.content[0], Content::Text(t) if t.text == "no tags here"));
    }

    #[tokio::test]
    async fn generate_custom_separator_overrides_default() {
        let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
            text: "<t>a</t>mid<t>b</t>".into(),
        });
        let middleware = ExtractReasoningMiddleware::new("t", false).with_separator(" | ");
        let r = generate_with(inner, middleware).await;
        match (&r.content[0], &r.content[1]) {
            (Content::Reasoning(reasoning), Content::Text(answer)) => {
                assert_eq!(reasoning.text, "a | b");
                assert_eq!(answer.text, "mid");
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[tokio::test]
    async fn stream_emits_incrementally_across_chunks() {
        let inner: Arc<dyn LanguageModel> = Arc::new(MultiChunkFake {
            chunks: vec![
                "<thi".into(),
                "nk>analyzing ".into(),
                "the request</th".into(),
                "ink>Here is ".into(),
                "the response".into(),
            ],
        });
        let parts = stream_think_parts(inner).await;

        let reasoning_starts = parts
            .iter()
            .filter(|p| matches!(p, StreamPart::ReasoningStart { .. }))
            .count();
        let reasoning_ends = parts
            .iter()
            .filter(|p| matches!(p, StreamPart::ReasoningEnd { .. }))
            .count();
        let reasoning_deltas: Vec<&str> = parts
            .iter()
            .filter_map(|p| match p {
                StreamPart::ReasoningDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        let text_deltas: Vec<&str> = parts
            .iter()
            .filter_map(|p| match p {
                StreamPart::TextDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();

        assert_eq!(reasoning_starts, 1, "one reasoning block opened");
        assert_eq!(reasoning_ends, 1, "one reasoning block closed");
        assert!(
            reasoning_deltas.len() >= 2,
            "expected >=2 reasoning-delta ticks, got {reasoning_deltas:?}"
        );
        assert_eq!(reasoning_deltas.concat(), "analyzing the request");
        assert_eq!(text_deltas.concat(), "Here is the response");
    }

    #[tokio::test]
    async fn stream_emits_reasoning_then_text() {
        let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
            text: "<think>x</think>y".into(),
        });
        let events: Vec<String> = stream_think_parts(inner)
            .await
            .into_iter()
            .filter_map(|part| match part {
                StreamPart::TextDelta { delta, .. } => Some(format!("text:{delta}")),
                StreamPart::ReasoningDelta { delta, .. } => Some(format!("reason:{delta}")),
                StreamPart::TextEnd { .. } => Some("end".into()),
                _ => None,
            })
            .collect();
        assert!(
            events.contains(&"reason:x".to_owned()),
            "saw reason: {events:?}"
        );
        assert!(
            events.contains(&"text:y".to_owned()),
            "saw text: {events:?}"
        );
    }
}