use std::collections::HashMap;
use async_trait::async_trait;
use futures::StreamExt;
use crate::error::Result;
use crate::language_model::{
BoxStream, CallOptions, Content, GenerateResult, LanguageModel, StreamPart, StreamResult,
TextPart,
};
use crate::middleware::language_model::LanguageModelMiddleware;
use crate::shared::ProviderMetadata;
/// Number of trailing bytes of a streaming text block that are held back until the block ends,
/// so that a closing code fence can still be removed from the final chunk.
const SUFFIX_BUFFER_SIZE: usize = 12;
/// Shared, thread-safe callback that maps a block's raw text to its cleaned form.
type TransformFn = std::sync::Arc<dyn Fn(&str) -> String + Send + Sync>;
/// Middleware that rewrites the text produced by a wrapped language model.
///
/// By default the text is stripped of a leading ```` ``` ````/```` ```json ```` fence, a trailing
/// ```` ``` ```` fence and surrounding whitespace; a custom transform can be installed instead.
pub struct ExtractJsonMiddleware {
/// Custom transform; `None` selects the built-in fence-stripping transform.
transform: Option<TransformFn>,
}
impl std::fmt::Debug for ExtractJsonMiddleware {
    /// Formats the middleware, printing `Some("<fn>")` in place of an installed transform
    /// and `None` when the built-in transform is in use.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The callback itself cannot be printed, so only record whether one is present.
        let placeholder: Option<&str> = self.transform.as_ref().map(|_| "<fn>");
        let mut out = f.debug_struct("ExtractJsonMiddleware");
        out.field("transform", &placeholder);
        out.finish()
    }
}
impl Default for ExtractJsonMiddleware {
/// Equivalent to [`ExtractJsonMiddleware::new`]: no custom transform is installed.
fn default() -> Self {
Self::new()
}
}
impl ExtractJsonMiddleware {
#[must_use]
pub fn new() -> Self {
Self { transform: None }
}
#[must_use]
pub fn with_transform<F>(mut self, transform: F) -> Self
where
F: Fn(&str) -> String + Send + Sync + 'static,
{
self.transform = Some(std::sync::Arc::new(transform));
self
}
fn apply_transform(&self, text: &str) -> String {
match self.transform.as_ref() {
Some(f) => f(text),
None => default_transform(text),
}
}
}
#[async_trait]
impl LanguageModelMiddleware for ExtractJsonMiddleware {
    /// Calls the wrapped model and rewrites the text of every `Content::Text` part of the
    /// result with the configured transform. Other content parts are returned untouched.
    ///
    /// # Errors
    /// Propagates any error returned by `next.do_generate`.
    async fn wrap_generate(
        &self,
        next: &dyn LanguageModel,
        params: CallOptions,
    ) -> Result<GenerateResult> {
        let mut generated = next.do_generate(params).await?;
        let text_parts = generated.content.iter_mut().filter_map(|item| {
            if let Content::Text(part) = item {
                Some(part)
            } else {
                None
            }
        });
        for part in text_parts {
            part.text = self.apply_transform(&part.text);
        }
        Ok(generated)
    }

    /// Starts a stream on the wrapped model and returns it with its event stream routed
    /// through [`transform_stream`]; `request` and `response` are passed along unchanged.
    ///
    /// # Errors
    /// Propagates any error returned by `next.do_stream`.
    async fn wrap_stream(
        &self,
        next: &dyn LanguageModel,
        params: CallOptions,
    ) -> Result<StreamResult> {
        let StreamResult {
            stream,
            request,
            response,
        } = next.do_stream(params).await?;
        Ok(StreamResult {
            stream: transform_stream(stream, self.transform.clone()),
            request,
            response,
        })
    }
}
/// Handling mode of a text block while its deltas are being received.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
// Initial mode without a custom transform: text is collected until it is clear whether the
// block opens with a code fence.
Prefix,
// Text is forwarded downstream, except for a short tail that is held back until the block ends.
Streaming,
// Initial (and only) mode with a custom transform: all text is collected until the block ends.
Buffering,
}
#[derive(Debug)]
struct BlockState {
start_event: StreamPart,
phase: Phase,
buffer: String,
prefix_stripped: bool,
}
fn transform_stream(
stream: BoxStream<Result<StreamPart>>,
transform: Option<TransformFn>,
) -> BoxStream<Result<StreamPart>> {
let has_custom_transform = transform.is_some();
let state: HashMap<String, BlockState> = HashMap::new();
let pending: std::collections::VecDeque<Result<StreamPart>> = std::collections::VecDeque::new();
let init = StreamCtx {
stream,
state,
pending,
transform,
has_custom_transform,
};
let mapped = futures::stream::unfold(init, |mut ctx| async move {
loop {
if let Some(item) = ctx.pending.pop_front() {
return Some((item, ctx));
}
let next = ctx.stream.next().await?;
match next {
Err(e) => return Some((Err(e), ctx)),
Ok(part) => {
ctx.handle(part);
}
}
}
});
Box::pin(mapped)
}
struct StreamCtx {
stream: BoxStream<Result<StreamPart>>,
state: HashMap<String, BlockState>,
pending: std::collections::VecDeque<Result<StreamPart>>,
transform: Option<TransformFn>,
has_custom_transform: bool,
}
impl StreamCtx {
fn apply_transform(&self, text: &str) -> String {
match self.transform.as_ref() {
Some(f) => f(text),
None => default_transform(text),
}
}
fn handle(&mut self, part: StreamPart) {
match part {
StreamPart::TextStart {
id,
provider_metadata,
} => self.on_text_start(id, provider_metadata),
StreamPart::TextDelta { id, delta, .. } => self.on_text_delta(id, delta),
StreamPart::TextEnd {
id,
provider_metadata,
} => self.on_text_end(id, provider_metadata),
other => self.pending.push_back(Ok(other)),
}
}
fn on_text_start(&mut self, id: String, provider_metadata: Option<ProviderMetadata>) {
let start_event = StreamPart::TextStart {
id: id.clone(),
provider_metadata,
};
let phase = if self.has_custom_transform {
Phase::Buffering
} else {
Phase::Prefix
};
self.state.insert(
id,
BlockState {
start_event,
phase,
buffer: String::new(),
prefix_stripped: false,
},
);
}
fn on_text_delta(&mut self, id: String, delta: String) {
let Some(block) = self.state.get_mut(&id) else {
self.pending.push_back(Ok(StreamPart::TextDelta {
id,
delta,
provider_metadata: None,
}));
return;
};
block.buffer.push_str(&delta);
if block.phase == Phase::Buffering {
return;
}
if block.phase == Phase::Prefix {
if !block.buffer.is_empty() && !block.buffer.starts_with('`') {
block.phase = Phase::Streaming;
let start = block.start_event.clone();
self.pending.push_back(Ok(start));
} else if block.buffer.starts_with("```") {
if block.buffer.contains('\n') {
if let Some(prefix_len) = match_opening_fence_len(&block.buffer) {
block.buffer = block.buffer[prefix_len..].to_owned();
block.prefix_stripped = true;
block.phase = Phase::Streaming;
let start = block.start_event.clone();
self.pending.push_back(Ok(start));
} else {
block.phase = Phase::Streaming;
let start = block.start_event.clone();
self.pending.push_back(Ok(start));
}
}
} else if block.buffer.len() >= 3 && !block.buffer.starts_with("```") {
block.phase = Phase::Streaming;
let start = block.start_event.clone();
self.pending.push_back(Ok(start));
}
}
if block.phase == Phase::Streaming && block.buffer.len() > SUFFIX_BUFFER_SIZE {
let cut = floor_char_boundary(&block.buffer, block.buffer.len() - SUFFIX_BUFFER_SIZE);
let to_stream = block.buffer[..cut].to_owned();
block.buffer = block.buffer[cut..].to_owned();
if !to_stream.is_empty() {
self.pending.push_back(Ok(StreamPart::TextDelta {
id: id.clone(),
delta: to_stream,
provider_metadata: None,
}));
}
}
let _ = id;
}
fn on_text_end(&mut self, id: String, provider_metadata: Option<ProviderMetadata>) {
let Some(block) = self.state.remove(&id) else {
self.pending.push_back(Ok(StreamPart::TextEnd {
id,
provider_metadata,
}));
return;
};
let BlockState {
start_event,
phase,
buffer,
prefix_stripped,
} = block;
if matches!(phase, Phase::Prefix | Phase::Buffering) {
self.pending.push_back(Ok(start_event));
}
let remaining = match phase {
Phase::Buffering => self.apply_transform(&buffer),
_ if prefix_stripped => strip_trailing_fence_replace(&buffer),
_ => self.apply_transform(&buffer),
};
if !remaining.is_empty() {
self.pending.push_back(Ok(StreamPart::TextDelta {
id: id.clone(),
delta: remaining,
provider_metadata: None,
}));
}
self.pending.push_back(Ok(StreamPart::TextEnd {
id,
provider_metadata,
}));
}
}
/// Built-in transform: removes a leading code fence, then a trailing code fence, then
/// whitespace at both ends of what is left.
fn default_transform(text: &str) -> String {
    let unfenced = strip_trailing_fence_replace(strip_leading_fence(text));
    String::from(unfenced.trim())
}
/// Removes a leading ```` ``` ```` fence, an optional lower-case `json` tag directly after it,
/// and any ASCII whitespace that follows. Input without a leading fence is returned as is.
fn strip_leading_fence(s: &str) -> &str {
    // Space, tab, CR, LF, vertical tab and form feed.
    const FENCE_WS: &[char] = &[' ', '\t', '\r', '\n', '\x0b', '\x0c'];
    match s.strip_prefix("```") {
        None => s,
        Some(rest) => rest
            .strip_prefix("json")
            .unwrap_or(rest)
            .trim_start_matches(FENCE_WS),
    }
}
/// Returns the byte length of the opening fence at the start of `buf`, or `None` when `buf`
/// does not (yet) start with a complete fence line.
///
/// A fence is ```` ``` ````, an optional lower-case `json` tag, then ASCII whitespace that
/// contains at least one newline. The match extends through the *last* newline of that
/// whitespace run, so blank lines directly after the fence line are consumed as well,
/// while indentation of the first content line is kept. Any other character before the
/// first newline (for example a different language tag) yields `None`.
fn match_opening_fence_len(buf: &str) -> Option<usize> {
    let after_ticks = buf.strip_prefix("```")?;
    let after_tag = after_ticks.strip_prefix("json").unwrap_or(after_ticks);
    // Length of the run of space, tab, CR, LF, vertical tab and form feed after the tag.
    let ws_len = after_tag
        .bytes()
        .take_while(|b| matches!(b, b' ' | b'\t' | b'\r' | b'\n' | 0x0b | 0x0c))
        .count();
    let last_newline = after_tag.as_bytes()[..ws_len]
        .iter()
        .rposition(|&b| b == b'\n')?;
    let header_len = buf.len() - after_tag.len();
    Some(header_len + last_newline + 1)
}
/// Removes a trailing ```` ``` ```` fence (ignoring ASCII whitespace after it) together with
/// one newline directly before it, then trims trailing whitespace from the result.
/// Input without a trailing fence only has its trailing whitespace trimmed.
fn strip_trailing_fence_replace(s: &str) -> String {
    // Space, tab, CR, LF, vertical tab and form feed.
    const FENCE_WS: &[char] = &[' ', '\t', '\r', '\n', '\x0b', '\x0c'];
    let body = match s.trim_end_matches(FENCE_WS).strip_suffix("```") {
        Some(before_fence) => before_fence.strip_suffix('\n').unwrap_or(before_fence),
        None => s,
    };
    body.trim_end().to_owned()
}
/// Returns the largest byte offset `<= index` that lies on a character boundary of `s`.
/// An `index` at or past the end of `s` yields `s.len()`.
fn floor_char_boundary(s: &str, index: usize) -> usize {
    if index >= s.len() {
        return s.len();
    }
    // Offset 0 is always a boundary, so the search cannot come up empty.
    (0..=index)
        .rev()
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(0)
}
// Private alias that references `TextPart` outside the test module. Per the `reason` below it
// is retained on purpose, so it should not be removed as dead code.
#[allow(dead_code, reason = "kept for symmetry with ai-sdk imports")]
type _Unused = TextPart;
// Unit tests for the middleware (through `wrap_language_model`) and for the fence helpers.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use futures::stream;
use super::*;
use crate::language_model::{FinishReason, FinishReasonKind, Usage};
use crate::middleware::wrap_language_model;
// Stub model whose generate result and stream deltas are fixed by its fields.
#[derive(Debug)]
struct Fake {
// Text returned as the single text part of `do_generate`.
gen_text: String,
// Deltas emitted, in order, for the single text block of `do_stream`.
stream_deltas: Vec<String>,
}
#[async_trait]
impl LanguageModel for Fake {
fn provider(&self) -> &'static str {
"fake"
}
fn model_id(&self) -> &'static str {
"fake"
}
// Returns one text part containing `gen_text` with a `Stop` finish reason.
async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResult> {
Ok(GenerateResult {
content: vec![Content::Text(TextPart {
text: self.gen_text.clone(),
provider_options: None,
})],
finish_reason: FinishReason::new(FinishReasonKind::Stop),
usage: Usage::default(),
provider_metadata: None,
request: None,
response: None,
warnings: vec![],
})
}
// Emits `TextStart`, one `TextDelta` per entry of `stream_deltas`, `TextEnd` and `Finish`,
// all for the block id "b1".
async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
let mut parts: Vec<Result<StreamPart>> = vec![Ok(StreamPart::TextStart {
id: "b1".into(),
provider_metadata: None,
})];
for d in &self.stream_deltas {
parts.push(Ok(StreamPart::TextDelta {
id: "b1".into(),
delta: d.clone(),
provider_metadata: None,
}));
}
parts.push(Ok(StreamPart::TextEnd {
id: "b1".into(),
provider_metadata: None,
}));
parts.push(Ok(StreamPart::Finish {
usage: Usage::default(),
finish_reason: FinishReason::new(FinishReasonKind::Stop),
provider_metadata: None,
}));
Ok(StreamResult {
stream: Box::pin(stream::iter(parts)),
request: None,
response: None,
})
}
}
// Drains `stream` into a vector; panics if the stream yields an error.
async fn collect(stream: BoxStream<Result<StreamPart>>) -> Vec<StreamPart> {
let mut out = Vec::new();
let mut s = stream;
while let Some(item) = s.next().await {
out.push(item.unwrap());
}
out
}
// Non-streaming path: a fenced JSON answer comes back without the fence.
#[tokio::test]
async fn generate_strips_fence() {
let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
gen_text: "```json\n{\"x\":1}\n```".into(),
stream_deltas: vec![],
});
let wrapped = wrap_language_model(
inner,
[Arc::new(ExtractJsonMiddleware::new()) as Arc<dyn LanguageModelMiddleware>],
);
let r = wrapped
.do_generate(CallOptions::default())
.await
.expect("gen");
let Content::Text(p) = &r.content[0] else {
panic!("text");
};
assert_eq!(p.text, "{\"x\":1}");
}
// Streaming path without any fence: the concatenated deltas equal the input text, the first
// frame is the `TextStart` and a `TextEnd` is present.
#[tokio::test]
async fn stream_no_fence_passes_through_incrementally() {
let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
gen_text: String::new(),
stream_deltas: vec!["hello ".into(), "world ".into(), "of streams".into()],
});
let wrapped = wrap_language_model(
inner,
[Arc::new(ExtractJsonMiddleware::new()) as Arc<dyn LanguageModelMiddleware>],
);
let s = wrapped.do_stream(CallOptions::default()).await.unwrap();
let frames = collect(s.stream).await;
let text: String = frames
.iter()
.filter_map(|f| match f {
StreamPart::TextDelta { delta, .. } => Some(delta.clone()),
_ => None,
})
.collect();
assert_eq!(text, "hello world of streams");
assert!(matches!(frames.first(), Some(StreamPart::TextStart { .. })));
assert!(
frames
.iter()
.any(|f| matches!(f, StreamPart::TextEnd { .. }))
);
}
// Streaming path where the opening fence, the body, the newline and the closing fence arrive
// in separate deltas: only the body remains.
#[tokio::test]
async fn stream_strips_fence_split_across_deltas() {
let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
gen_text: String::new(),
stream_deltas: vec![
"```json\n".into(),
"{\"city\":\"Tokyo\"}".into(),
"\n".into(),
"```".into(),
],
});
let wrapped = wrap_language_model(
inner,
[Arc::new(ExtractJsonMiddleware::new()) as Arc<dyn LanguageModelMiddleware>],
);
let s = wrapped.do_stream(CallOptions::default()).await.unwrap();
let frames = collect(s.stream).await;
let text: String = frames
.iter()
.filter_map(|f| match f {
StreamPart::TextDelta { delta, .. } => Some(delta.clone()),
_ => None,
})
.collect();
assert_eq!(text, "{\"city\":\"Tokyo\"}");
}
// With a custom transform the block is buffered, so a word split across two deltas ("al" +
// "pha") is still seen whole by the transform.
#[tokio::test]
async fn stream_buffering_phase_with_custom_transform() {
let mw: Arc<dyn LanguageModelMiddleware> =
Arc::new(ExtractJsonMiddleware::new().with_transform(|s| s.replace("alpha", "ALPHA")));
let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
gen_text: String::new(),
stream_deltas: vec!["al".into(), "pha-beta".into()],
});
let wrapped = wrap_language_model(inner, [mw]);
let s = wrapped.do_stream(CallOptions::default()).await.unwrap();
let frames = collect(s.stream).await;
let text: String = frames
.iter()
.filter_map(|f| match f {
StreamPart::TextDelta { delta, .. } => Some(delta.clone()),
_ => None,
})
.collect();
assert_eq!(text, "ALPHA-beta");
}
// A single delta longer than the held-back suffix must produce at least two output deltas:
// one while streaming and one for the remainder at `TextEnd`.
#[tokio::test]
async fn stream_emits_incremental_frames_past_suffix_window() {
let inner: Arc<dyn LanguageModel> = Arc::new(Fake {
gen_text: String::new(),
stream_deltas: vec!["{\"alpha\":\"some-long-value-that-exceeds-buffer\"}".into()],
});
let wrapped = wrap_language_model(
inner,
[Arc::new(ExtractJsonMiddleware::new()) as Arc<dyn LanguageModelMiddleware>],
);
let s = wrapped.do_stream(CallOptions::default()).await.unwrap();
let frames = collect(s.stream).await;
let delta_count = frames
.iter()
.filter(|f| matches!(f, StreamPart::TextDelta { .. }))
.count();
assert!(
delta_count >= 2,
"expected incremental streaming (>=2 deltas), got {delta_count}: {frames:?}"
);
}
// The `json` tag is matched case-sensitively: with an upper-case tag only the backticks are
// removed and the tag stays in the output.
#[test]
fn default_transform_strips_lower_case_fence_only() {
assert_eq!(default_transform("```json\n{\"a\":1}\n```"), "{\"a\":1}");
assert_eq!(default_transform("```\n{\"a\":1}\n```"), "{\"a\":1}");
assert_eq!(
default_transform("```JSON\n{\"a\":1}\n```"),
"JSON\n{\"a\":1}"
);
}
// An opening fence only matches once its terminating newline is present; a tag other than
// `json` never matches.
#[test]
fn match_opening_fence_len_partial_buffer_returns_none() {
assert_eq!(match_opening_fence_len(""), None);
assert_eq!(match_opening_fence_len("``"), None);
assert_eq!(match_opening_fence_len("```"), None); assert_eq!(match_opening_fence_len("```json"), None); assert_eq!(match_opening_fence_len("```json "), None);
assert_eq!(
match_opening_fence_len("```json \n"),
Some("```json \n".len())
);
assert_eq!(match_opening_fence_len("```\n"), Some(4));
assert_eq!(match_opening_fence_len("```xml\n"), None);
}
// The newline before the closing fence is optional, whitespace after the fence is ignored,
// and input without a fence only loses its trailing whitespace.
#[test]
fn strip_trailing_fence_handles_optional_leading_newline() {
assert_eq!(strip_trailing_fence_replace("{}\n```"), "{}");
assert_eq!(strip_trailing_fence_replace("{}```"), "{}");
assert_eq!(strip_trailing_fence_replace("{}```\n "), "{}");
assert_eq!(strip_trailing_fence_replace("{}\n "), "{}");
}
}