1use std::collections::HashMap;
15
16use async_trait::async_trait;
17use futures::StreamExt;
18
19use crate::error::Result;
20use crate::language_model::{
21 BoxStream, CallOptions, Content, GenerateResult, LanguageModel, StreamPart, StreamResult,
22 TextPart,
23};
24use crate::middleware::language_model::LanguageModelMiddleware;
25use crate::shared::ProviderMetadata;
26
/// Number of bytes held back at the end of a streaming text block so that a closing
/// code fence (`\n` followed by three backticks, plus a little trailing whitespace) split
/// across several deltas can still be removed when the block ends.
const SUFFIX_BUFFER_SIZE: usize = 12;

/// Shared, thread-safe text rewriting callback. It is reference-counted so the middleware
/// can hand a cheap clone to every stream it wraps.
type TransformFn = std::sync::Arc<dyn Fn(&str) -> String + Sync + Send>;
36
37pub struct ExtractJsonMiddleware {
39 transform: Option<TransformFn>,
40}
41
42impl std::fmt::Debug for ExtractJsonMiddleware {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("ExtractJsonMiddleware")
45 .field("transform", &self.transform.is_some().then_some("<fn>"))
46 .finish()
47 }
48}
49
50impl Default for ExtractJsonMiddleware {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl ExtractJsonMiddleware {
57 #[must_use]
60 pub fn new() -> Self {
61 Self { transform: None }
62 }
63
64 #[must_use]
69 pub fn with_transform<F>(mut self, transform: F) -> Self
70 where
71 F: Fn(&str) -> String + Send + Sync + 'static,
72 {
73 self.transform = Some(std::sync::Arc::new(transform));
74 self
75 }
76
77 fn apply_transform(&self, text: &str) -> String {
78 match self.transform.as_ref() {
79 Some(f) => f(text),
80 None => default_transform(text),
81 }
82 }
83}
84
85#[async_trait]
86impl LanguageModelMiddleware for ExtractJsonMiddleware {
87 async fn wrap_generate(
88 &self,
89 next: &dyn LanguageModel,
90 params: CallOptions,
91 ) -> Result<GenerateResult> {
92 let mut result = next.do_generate(params).await?;
93 for content in &mut result.content {
94 if let Content::Text(part) = content {
95 part.text = self.apply_transform(&part.text);
96 }
97 }
98 Ok(result)
99 }
100
101 async fn wrap_stream(
102 &self,
103 next: &dyn LanguageModel,
104 params: CallOptions,
105 ) -> Result<StreamResult> {
106 let upstream = next.do_stream(params).await?;
107 let StreamResult {
108 stream,
109 request,
110 response,
111 } = upstream;
112 let transform = self.transform.clone();
113 let cleaned = transform_stream(stream, transform);
114 Ok(StreamResult {
115 stream: cleaned,
116 request,
117 response,
118 })
119 }
120}
121
/// Processing phase of one streamed text block.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Phase {
    /// Default transform: text is held back until it is clear whether the block opens
    /// with a triple-backtick fence line.
    Prefix,
    /// Text is forwarded as it arrives, except for a trailing window that is kept so a
    /// closing fence can be removed at the end of the block.
    Streaming,
    /// Custom transform: the whole block is buffered and transformed when it ends.
    Buffering,
}
138
139#[derive(Debug)]
140struct BlockState {
141 start_event: StreamPart,
144 phase: Phase,
145 buffer: String,
146 prefix_stripped: bool,
149}
150
151fn transform_stream(
152 stream: BoxStream<Result<StreamPart>>,
153 transform: Option<TransformFn>,
154) -> BoxStream<Result<StreamPart>> {
155 let has_custom_transform = transform.is_some();
156 let state: HashMap<String, BlockState> = HashMap::new();
157 let pending: std::collections::VecDeque<Result<StreamPart>> = std::collections::VecDeque::new();
158
159 let init = StreamCtx {
160 stream,
161 state,
162 pending,
163 transform,
164 has_custom_transform,
165 };
166
167 let mapped = futures::stream::unfold(init, |mut ctx| async move {
168 loop {
169 if let Some(item) = ctx.pending.pop_front() {
172 return Some((item, ctx));
173 }
174 let next = ctx.stream.next().await?;
175 match next {
176 Err(e) => return Some((Err(e), ctx)),
177 Ok(part) => {
178 ctx.handle(part);
179 }
182 }
183 }
184 });
185 Box::pin(mapped)
186}
187
188struct StreamCtx {
189 stream: BoxStream<Result<StreamPart>>,
190 state: HashMap<String, BlockState>,
191 pending: std::collections::VecDeque<Result<StreamPart>>,
192 transform: Option<TransformFn>,
193 has_custom_transform: bool,
194}
195
196impl StreamCtx {
197 fn apply_transform(&self, text: &str) -> String {
198 match self.transform.as_ref() {
199 Some(f) => f(text),
200 None => default_transform(text),
201 }
202 }
203
204 fn handle(&mut self, part: StreamPart) {
205 match part {
206 StreamPart::TextStart {
207 id,
208 provider_metadata,
209 } => self.on_text_start(id, provider_metadata),
210 StreamPart::TextDelta { id, delta, .. } => self.on_text_delta(id, delta),
211 StreamPart::TextEnd {
212 id,
213 provider_metadata,
214 } => self.on_text_end(id, provider_metadata),
215 other => self.pending.push_back(Ok(other)),
216 }
217 }
218
219 fn on_text_start(&mut self, id: String, provider_metadata: Option<ProviderMetadata>) {
220 let start_event = StreamPart::TextStart {
221 id: id.clone(),
222 provider_metadata,
223 };
224 let phase = if self.has_custom_transform {
225 Phase::Buffering
226 } else {
227 Phase::Prefix
228 };
229 self.state.insert(
230 id,
231 BlockState {
232 start_event,
233 phase,
234 buffer: String::new(),
235 prefix_stripped: false,
236 },
237 );
238 }
243
244 fn on_text_delta(&mut self, id: String, delta: String) {
245 let Some(block) = self.state.get_mut(&id) else {
246 self.pending.push_back(Ok(StreamPart::TextDelta {
248 id,
249 delta,
250 provider_metadata: None,
251 }));
252 return;
253 };
254 block.buffer.push_str(&delta);
255
256 if block.phase == Phase::Buffering {
258 return;
259 }
260
261 if block.phase == Phase::Prefix {
262 if !block.buffer.is_empty() && !block.buffer.starts_with('`') {
265 block.phase = Phase::Streaming;
267 let start = block.start_event.clone();
268 self.pending.push_back(Ok(start));
269 } else if block.buffer.starts_with("```") {
270 if block.buffer.contains('\n') {
271 if let Some(prefix_len) = match_opening_fence_len(&block.buffer) {
272 block.buffer = block.buffer[prefix_len..].to_owned();
273 block.prefix_stripped = true;
274 block.phase = Phase::Streaming;
275 let start = block.start_event.clone();
276 self.pending.push_back(Ok(start));
277 } else {
278 block.phase = Phase::Streaming;
280 let start = block.start_event.clone();
281 self.pending.push_back(Ok(start));
282 }
283 }
284 } else if block.buffer.len() >= 3 && !block.buffer.starts_with("```") {
286 block.phase = Phase::Streaming;
288 let start = block.start_event.clone();
289 self.pending.push_back(Ok(start));
290 }
291 }
292
293 if block.phase == Phase::Streaming && block.buffer.len() > SUFFIX_BUFFER_SIZE {
295 let cut = floor_char_boundary(&block.buffer, block.buffer.len() - SUFFIX_BUFFER_SIZE);
299 let to_stream = block.buffer[..cut].to_owned();
300 block.buffer = block.buffer[cut..].to_owned();
301 if !to_stream.is_empty() {
302 self.pending.push_back(Ok(StreamPart::TextDelta {
303 id: id.clone(),
304 delta: to_stream,
305 provider_metadata: None,
306 }));
307 }
308 }
309 let _ = id;
310 }
311
312 fn on_text_end(&mut self, id: String, provider_metadata: Option<ProviderMetadata>) {
313 let Some(block) = self.state.remove(&id) else {
314 self.pending.push_back(Ok(StreamPart::TextEnd {
315 id,
316 provider_metadata,
317 }));
318 return;
319 };
320 let BlockState {
321 start_event,
322 phase,
323 buffer,
324 prefix_stripped,
325 } = block;
326
327 if matches!(phase, Phase::Prefix | Phase::Buffering) {
330 self.pending.push_back(Ok(start_event));
331 }
332
333 let remaining = match phase {
334 Phase::Buffering => self.apply_transform(&buffer),
335 _ if prefix_stripped => strip_trailing_fence_replace(&buffer),
336 _ => self.apply_transform(&buffer),
337 };
338
339 if !remaining.is_empty() {
340 self.pending.push_back(Ok(StreamPart::TextDelta {
341 id: id.clone(),
342 delta: remaining,
343 provider_metadata: None,
344 }));
345 }
346 self.pending.push_back(Ok(StreamPart::TextEnd {
347 id,
348 provider_metadata,
349 }));
350 }
351}
352
353fn default_transform(text: &str) -> String {
357 let after_prefix = strip_leading_fence(text);
358 let after_suffix = strip_trailing_fence_replace(after_prefix);
359 after_suffix.trim().to_owned()
360}
361
/// Removes a leading triple-backtick fence, an optional lowercase `json` tag right after
/// it, and any ASCII whitespace that follows. Input without a leading fence is returned
/// unchanged.
fn strip_leading_fence(s: &str) -> &str {
    match s.strip_prefix("```") {
        None => s,
        Some(rest) => {
            let body = rest.strip_prefix("json").unwrap_or(rest);
            // ASCII whitespace only (space, \t, \r, \n, VT, FF), matching the byte set
            // used elsewhere in this module.
            body.trim_start_matches(|c: char| {
                matches!(c, ' ' | '\t' | '\r' | '\n' | '\u{0b}' | '\u{0c}')
            })
        }
    }
}
381
/// If `buf` begins with a complete opening fence line (three backticks, an optional
/// lowercase `json` tag, optional horizontal whitespace, then `\n`), returns that line's
/// length in bytes including the newline. Returns `None` when the line is incomplete or
/// carries anything else (e.g. another language tag).
fn match_opening_fence_len(buf: &str) -> Option<usize> {
    let after_ticks = buf.strip_prefix("```")?;
    let after_tag = after_ticks.strip_prefix("json").unwrap_or(after_ticks);
    let header_len = buf.len() - after_tag.len();

    // First byte that is not horizontal whitespace must be the newline.
    let stop = after_tag
        .bytes()
        .position(|b| !matches!(b, b' ' | b'\t' | b'\r' | 0x0b | 0x0c))?;
    (after_tag.as_bytes()[stop] == b'\n').then_some(header_len + stop + 1)
}
408
/// Removes a trailing triple-backtick fence, together with any ASCII whitespace after
/// it and at most one newline before it, then trims trailing whitespace. Without a
/// trailing fence, only trailing whitespace is removed.
fn strip_trailing_fence_replace(s: &str) -> String {
    let is_ascii_ws = |c: char| matches!(c, ' ' | '\t' | '\r' | '\n' | '\u{0b}' | '\u{0c}');
    let trimmed = s.trim_end_matches(is_ascii_ws);
    match trimmed.strip_suffix("```") {
        Some(before_fence) => {
            let body = before_fence.strip_suffix('\n').unwrap_or(before_fence);
            body.trim_end().to_owned()
        }
        None => s.trim_end().to_owned(),
    }
}
428
/// Returns the largest char boundary of `s` that is `<= index`, or `s.len()` when `index`
/// is at or past the end. Always yields a valid slicing position.
fn floor_char_boundary(s: &str, index: usize) -> usize {
    if index >= s.len() {
        return s.len();
    }
    // Position 0 is always a boundary, so the search cannot come up empty.
    (0..=index)
        .rev()
        .find(|&candidate| s.is_char_boundary(candidate))
        .unwrap_or(0)
}
442
443#[allow(dead_code, reason = "kept for symmetry with ai-sdk imports")]
446type _Unused = TextPart;
447
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use futures::stream;

    use super::*;
    use crate::language_model::{FinishReason, FinishReasonKind, Usage};
    use crate::middleware::wrap_language_model;

    /// Scripted model: `do_generate` returns `gen_text`; `do_stream` emits a single text
    /// block `"b1"` made of `stream_deltas`, followed by a `Finish` part.
    #[derive(Debug)]
    struct Fake {
        gen_text: String,
        stream_deltas: Vec<String>,
    }

    #[async_trait]
    impl LanguageModel for Fake {
        fn provider(&self) -> &'static str {
            "fake"
        }

        fn model_id(&self) -> &'static str {
            "fake"
        }

        async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResult> {
            let text = TextPart {
                text: self.gen_text.clone(),
                provider_options: None,
            };
            Ok(GenerateResult {
                content: vec![Content::Text(text)],
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }

        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
            let block_id = || String::from("b1");
            let mut parts = Vec::with_capacity(self.stream_deltas.len() + 3);
            parts.push(StreamPart::TextStart {
                id: block_id(),
                provider_metadata: None,
            });
            parts.extend(self.stream_deltas.iter().map(|delta| StreamPart::TextDelta {
                id: block_id(),
                delta: delta.clone(),
                provider_metadata: None,
            }));
            parts.push(StreamPart::TextEnd {
                id: block_id(),
                provider_metadata: None,
            });
            parts.push(StreamPart::Finish {
                usage: Usage::default(),
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                provider_metadata: None,
            });
            let items: Vec<Result<StreamPart>> = parts.into_iter().map(Ok).collect();
            Ok(StreamResult {
                stream: Box::pin(stream::iter(items)),
                request: None,
                response: None,
            })
        }
    }

    /// A fake that only streams the given deltas.
    fn streaming_fake(deltas: &[&str]) -> Arc<dyn LanguageModel> {
        Arc::new(Fake {
            gen_text: String::new(),
            stream_deltas: deltas.iter().map(|d| (*d).to_owned()).collect(),
        })
    }

    /// The middleware under test with its built-in transform.
    fn default_middleware() -> Arc<dyn LanguageModelMiddleware> {
        Arc::new(ExtractJsonMiddleware::new())
    }

    /// Drains a stream, panicking on the first error part.
    async fn collect(stream: BoxStream<Result<StreamPart>>) -> Vec<StreamPart> {
        stream.map(|item| item.unwrap()).collect::<Vec<_>>().await
    }

    /// Concatenates the text of every `TextDelta` frame.
    fn delta_text(frames: &[StreamPart]) -> String {
        frames
            .iter()
            .filter_map(|frame| match frame {
                StreamPart::TextDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect()
    }

    #[tokio::test]
    async fn generate_strips_fence() {
        let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
            gen_text: "```json\n{\"x\":1}\n```".into(),
            stream_deltas: vec![],
        });
        let wrapped = wrap_language_model(inner, [default_middleware()]);
        let result = wrapped
            .do_generate(CallOptions::default())
            .await
            .expect("gen");
        let Content::Text(part) = &result.content[0] else {
            panic!("text");
        };
        assert_eq!(part.text, "{\"x\":1}");
    }

    #[tokio::test]
    async fn stream_no_fence_passes_through_incrementally() {
        let inner = streaming_fake(&["hello ", "world ", "of streams"]);
        let wrapped = wrap_language_model(inner, [default_middleware()]);
        let result = wrapped.do_stream(CallOptions::default()).await.unwrap();
        let frames = collect(result.stream).await;

        assert_eq!(delta_text(&frames), "hello world of streams");
        assert!(matches!(frames.first(), Some(StreamPart::TextStart { .. })));
        assert!(frames
            .iter()
            .any(|frame| matches!(frame, StreamPart::TextEnd { .. })));
    }

    #[tokio::test]
    async fn stream_strips_fence_split_across_deltas() {
        let inner = streaming_fake(&["```json\n", "{\"city\":\"Tokyo\"}", "\n", "```"]);
        let wrapped = wrap_language_model(inner, [default_middleware()]);
        let result = wrapped.do_stream(CallOptions::default()).await.unwrap();
        let frames = collect(result.stream).await;

        assert_eq!(delta_text(&frames), "{\"city\":\"Tokyo\"}");
    }

    #[tokio::test]
    async fn stream_buffering_phase_with_custom_transform() {
        let middleware: Arc<dyn LanguageModelMiddleware> = Arc::new(
            ExtractJsonMiddleware::new().with_transform(|s| s.replace("alpha", "ALPHA")),
        );
        let inner = streaming_fake(&["al", "pha-beta"]);
        let wrapped = wrap_language_model(inner, [middleware]);
        let result = wrapped.do_stream(CallOptions::default()).await.unwrap();
        let frames = collect(result.stream).await;

        assert_eq!(delta_text(&frames), "ALPHA-beta");
    }

    #[tokio::test]
    async fn stream_emits_incremental_frames_past_suffix_window() {
        let inner = streaming_fake(&["{\"alpha\":\"some-long-value-that-exceeds-buffer\"}"]);
        let wrapped = wrap_language_model(inner, [default_middleware()]);
        let result = wrapped.do_stream(CallOptions::default()).await.unwrap();
        let frames = collect(result.stream).await;

        let delta_count = frames
            .iter()
            .filter(|frame| matches!(frame, StreamPart::TextDelta { .. }))
            .count();
        assert!(
            delta_count >= 2,
            "expected incremental streaming (>=2 deltas), got {delta_count}: {frames:?}"
        );
    }

    #[test]
    fn default_transform_strips_lower_case_fence_only() {
        assert_eq!(default_transform("```json\n{\"a\":1}\n```"), "{\"a\":1}");
        assert_eq!(default_transform("```\n{\"a\":1}\n```"), "{\"a\":1}");
        // Only the lowercase `json` tag is recognised; other tags are kept as text.
        assert_eq!(
            default_transform("```JSON\n{\"a\":1}\n```"),
            "JSON\n{\"a\":1}"
        );
    }

    #[test]
    fn match_opening_fence_len_partial_buffer_returns_none() {
        assert_eq!(match_opening_fence_len(""), None);
        assert_eq!(match_opening_fence_len("``"), None);
        assert_eq!(match_opening_fence_len("```"), None);
        assert_eq!(match_opening_fence_len("```json"), None);
        assert_eq!(match_opening_fence_len("```json "), None);
        assert_eq!(
            match_opening_fence_len("```json \n"),
            Some("```json \n".len())
        );
        assert_eq!(match_opening_fence_len("```\n"), Some(4));
        assert_eq!(match_opening_fence_len("```xml\n"), None);
    }

    #[test]
    fn strip_trailing_fence_handles_optional_leading_newline() {
        assert_eq!(strip_trailing_fence_replace("{}\n```"), "{}");
        assert_eq!(strip_trailing_fence_replace("{}```"), "{}");
        assert_eq!(strip_trailing_fence_replace("{}```\n "), "{}");
        assert_eq!(strip_trailing_fence_replace("{}\n "), "{}");
    }
}