use std::future::Future;
use std::pin::Pin;
use async_stream::try_stream;
use futures_util::{Stream, StreamExt};
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::error::{RStructorError, Result};
use crate::model::Instructor;
/// Boxed, `Send` stream of text deltas, as returned by `sse_text_stream`.
pub type TextStream<'a> = Pin<Box<dyn Stream<Item = Result<String>> + Send + 'a>>;
/// Boxed, `Send` stream of [`StreamedObject`] items: partial JSON snapshots followed by
/// the final value, as returned by `object_stream` / `object_stream_with`.
pub type ObjectStream<'a, T> = Pin<Box<dyn Stream<Item = Result<StreamedObject<T>>> + Send + 'a>>;
/// One item of an [`ObjectStream`].
#[derive(Debug, Clone)]
pub enum StreamedObject<T> {
    /// Best-effort JSON snapshot of the output received so far, repaired so that it
    /// parses. It has not been converted into `T` and may be missing fields.
    Partial(Value),
    /// The final value produced by the stream's finalize step once input has ended.
    Complete(T),
}
impl<T> StreamedObject<T> {
    /// Consumes the item, returning the final value or `None` for a partial snapshot.
    pub fn complete(self) -> Option<T> {
        if let StreamedObject::Complete(finished) = self {
            Some(finished)
        } else {
            None
        }
    }
}
/// A single event surfaced by [`SseDecoder`].
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum SseEvent {
    /// The whitespace-trimmed payload of a non-empty `data:` line.
    Data(String),
    /// The `data: [DONE]` sentinel.
    Done,
}

/// Incremental decoder turning raw server-sent-events bytes into [`SseEvent`]s.
///
/// Bytes that do not yet form a complete line are kept until a later chunk supplies
/// the terminating `\n`, so events split across network chunks are reassembled.
#[derive(Default)]
pub(crate) struct SseDecoder {
    /// Undecoded input: everything after the last `\n` seen so far.
    buf: Vec<u8>,
}

impl SseDecoder {
    /// Appends `chunk` and returns the events of every line it completes, in order.
    ///
    /// Lines are split on `\n` and trailing `\r`s are removed. Only `data:` lines
    /// produce events: the payload is trimmed, `[DONE]` maps to [`SseEvent::Done`],
    /// and empty payloads are skipped. All other lines (`event:`, `id:`, `:` comments,
    /// blank separators) are ignored. Each complete line is decoded as lossy UTF-8.
    /// Because `\n` never occurs inside a multi-byte UTF-8 sequence, a character split
    /// across chunks still decodes correctly.
    pub(crate) fn push(&mut self, chunk: &[u8]) -> Vec<SseEvent> {
        self.buf.extend_from_slice(chunk);
        let mut events = Vec::new();
        // Walk complete lines with a cursor and remove them with a single `drain`.
        // Draining per line would shift the remaining buffer once for every line.
        let mut consumed = 0;
        while let Some(offset) = self.buf[consumed..].iter().position(|&b| b == b'\n') {
            let line_end = consumed + offset;
            let raw = String::from_utf8_lossy(&self.buf[consumed..line_end]);
            consumed = line_end + 1;
            let Some(rest) = raw.trim_end_matches('\r').strip_prefix("data:") else {
                continue;
            };
            match rest.trim() {
                "[DONE]" => events.push(SseEvent::Done),
                "" => {}
                payload => events.push(SseEvent::Data(payload.to_string())),
            }
        }
        self.buf.drain(..consumed);
        events
    }
}
/// Streams incremental text deltas from a server-sent-events HTTP response.
///
/// `send` performs the request. Each SSE `data:` payload is parsed as JSON and passed
/// to `extract`, and every non-empty string it returns is yielded in arrival order.
/// Payloads that are not valid JSON, or for which `extract` returns `None`, are skipped.
/// The stream ends at `data: [DONE]` or when the body ends. Errors from `send` or from
/// reading the body are propagated with `?`.
pub(crate) fn sse_text_stream<'a, Fut, F>(send: Fut, extract: F) -> TextStream<'a>
where
    Fut: Future<Output = Result<reqwest::Response>> + Send + 'a,
    F: Fn(&Value) -> Option<String> + Send + 'a,
{
    Box::pin(try_stream! {
        let mut body = send.await?.bytes_stream();
        let mut sse = SseDecoder::default();
        'read: while let Some(piece) = body.next().await {
            let piece = piece.map_err(RStructorError::from)?;
            for event in sse.push(&piece) {
                let SseEvent::Data(payload) = event else {
                    break 'read;
                };
                let Ok(json) = serde_json::from_str::<Value>(&payload) else {
                    continue;
                };
                if let Some(delta) = extract(&json).filter(|d| !d.is_empty()) {
                    yield delta;
                }
            }
        }
    })
}
/// Streams partial JSON snapshots of a structured response, then the final `T`.
///
/// This is [`object_stream_with`] with a finalize step of
/// `super::utils::parse_and_validate_response::<T>`. Only the error part of that
/// helper's failure value is kept; its context is discarded.
pub(crate) fn object_stream<'a, T, Fut, F>(send: Fut, extract: F) -> ObjectStream<'a, T>
where
    T: Instructor + DeserializeOwned + Send + 'a,
    Fut: Future<Output = Result<reqwest::Response>> + Send + 'a,
    F: Fn(&Value) -> Option<String> + Send + 'a,
{
    /// Parses and validates the full accumulated response text.
    fn parse_final<U: Instructor + DeserializeOwned + Send>(raw: &str) -> Result<U> {
        super::utils::parse_and_validate_response::<U>(raw).map_err(|(err, _ctx)| err)
    }
    object_stream_with(send, extract, parse_final::<T>)
}
/// Streams partial JSON snapshots of a structured response, then its final value.
///
/// `send` performs the request. `extract` pulls the text delta out of each parsed SSE
/// `data:` payload. Deltas are concatenated, and after each one the accumulated text
/// is repaired with [`complete_json`]. Whenever that yields a value different from the
/// previous snapshot, it is emitted as [`StreamedObject::Partial`].
///
/// When input ends (`data: [DONE]` or end of body), `finalize` is called once with the
/// whitespace-trimmed accumulated text, and its value is emitted as
/// [`StreamedObject::Complete`].
///
/// Errors from `send`, from reading the body, and from `finalize` are propagated with
/// `?`. Payloads that are not valid JSON, or for which `extract` returns `None`, are
/// skipped.
pub(crate) fn object_stream_with<'a, T, Fut, F, Fin>(
    send: Fut,
    extract: F,
    finalize: Fin,
) -> ObjectStream<'a, T>
where
    T: Send + 'a,
    Fut: Future<Output = Result<reqwest::Response>> + Send + 'a,
    F: Fn(&Value) -> Option<String> + Send + 'a,
    Fin: FnOnce(&str) -> Result<T> + Send + 'a,
{
    Box::pin(try_stream! {
        let mut body = send.await?.bytes_stream();
        let mut sse = SseDecoder::default();
        let mut accumulated = String::new();
        let mut previous: Option<Value> = None;
        'read: while let Some(piece) = body.next().await {
            let piece = piece.map_err(RStructorError::from)?;
            for event in sse.push(&piece) {
                let SseEvent::Data(payload) = event else {
                    break 'read;
                };
                let Some(delta) = serde_json::from_str::<Value>(&payload)
                    .ok()
                    .and_then(|json| extract(&json))
                else {
                    continue;
                };
                accumulated.push_str(&delta);
                let Some(snapshot) = complete_json(&accumulated) else {
                    continue;
                };
                // Only emit when the visible object actually changed.
                if previous.as_ref() == Some(&snapshot) {
                    continue;
                }
                previous = Some(snapshot.clone());
                yield StreamedObject::Partial(snapshot);
            }
        }
        let finished: T = finalize(accumulated.trim())?;
        yield StreamedObject::Complete(finished);
    })
}
/// Returns `choices[0].delta.content` of an OpenAI-style streaming chunk.
///
/// Returns `None` if any step of that path is missing or the content is not a string.
pub(crate) fn openai_delta(event: &Value) -> Option<String> {
    let first_choice = event.get("choices").and_then(|choices| choices.get(0))?;
    let content = first_choice
        .get("delta")
        .and_then(|delta| delta.get("content"))?;
    content.as_str().map(String::from)
}
/// Returns the text carried by an Anthropic-style `content_block_delta` event.
///
/// Prefers `delta.text` and falls back to `delta.partial_json`. Any other event type,
/// or a delta with neither string field, yields `None`.
pub(crate) fn anthropic_delta(event: &Value) -> Option<String> {
    let is_block_delta =
        event.get("type").and_then(Value::as_str) == Some("content_block_delta");
    if !is_block_delta {
        return None;
    }
    let delta = event.get("delta")?;
    ["text", "partial_json"]
        .into_iter()
        .find_map(|field| delta.get(field).and_then(Value::as_str))
        .map(String::from)
}
pub(crate) fn gemini_delta(event: &Value) -> Option<String> {
let parts = event
.get("candidates")?
.get(0)?
.get("content")?
.get("parts")?
.as_array()?;
let text: String = parts
.iter()
.filter_map(|p| p.get("text").and_then(Value::as_str))
.collect();
if text.is_empty() { None } else { Some(text) }
}
/// Best-effort parse of a possibly truncated JSON document.
///
/// The text is first closed up by `repair_json`. Returns `None` if that repair fails,
/// or if the repaired text still is not valid JSON (for example when it ends inside a
/// number or literal such as `12.` or `tr`).
pub(crate) fn complete_json(s: &str) -> Option<Value> {
    repair_json(s).and_then(|closed| serde_json::from_str(&closed).ok())
}
/// Closes a truncated JSON document so that it has a chance to parse.
///
/// Returns `None` in three cases:
/// - blank input;
/// - a `}`/`]` that does not match the innermost open container;
/// - a trailing `:` with no object key before it.
///
/// Otherwise the returned text:
/// - terminates an unterminated string, dropping a cut-off escape backslash;
/// - drops trailing commas;
/// - drops an object key whose value has not started yet, with or without its `:`;
/// - appends the closers for every container still open.
///
/// Incomplete scalars (`tr`, `12.`) are left as-is, so the result may still fail to
/// parse.
fn repair_json(s: &str) -> Option<String> {
    let s = s.trim();
    if s.is_empty() {
        return None;
    }
    let mut out = String::with_capacity(s.len() + 8);
    // Open containers, innermost last.
    let mut stack: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;
    // Byte offset in `out` of the opening quote of the most recent string, when that
    // string sits in object-key position (right after `{` or `,` inside an object).
    // This lets a dangling key be cut exactly, even if the key text contains `{` or `,`.
    let mut last_key_start: Option<usize> = None;
    for c in s.chars() {
        if in_string {
            out.push(c);
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => {
                let after_separator =
                    matches!(out.trim_end().chars().next_back(), Some('{' | ','));
                last_key_start = if after_separator && stack.last() == Some(&'{') {
                    Some(out.len())
                } else {
                    None
                };
                in_string = true;
            }
            '{' | '[' => stack.push(c),
            '}' | ']' => {
                let expected = if c == '}' { '{' } else { '[' };
                if stack.pop() != Some(expected) {
                    return None;
                }
            }
            _ => {}
        }
        out.push(c);
    }
    if in_string {
        if escaped {
            // Drop the lone backslash of an escape sequence that was cut off.
            out.pop();
        }
        out.push('"');
    }
    // Strip whatever trails the last complete value, repeating until nothing changes.
    loop {
        let trimmed_len = out.trim_end().len();
        out.truncate(trimmed_len);
        let cut = if out.ends_with(',') {
            out.len() - 1
        } else if out.ends_with(':') {
            // `"key":` with no value yet: drop the key and its colon.
            last_key_start.take()?
        } else if out.ends_with('"') {
            // A string ends the text. If it is a key (no colon yet), drop it;
            // a value string is complete and stays.
            match last_key_start.take() {
                Some(start) => start,
                None => break,
            }
        } else {
            break;
        };
        out.truncate(cut);
    }
    for &opener in stack.iter().rev() {
        out.push(if opener == '{' { '}' } else { ']' });
    }
    Some(out)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the SSE decoder, the provider delta extractors and the
    //! partial-JSON completion used for streaming snapshots.
    use super::*;
    use serde_json::json;

    /// Shorthand for the `Data` event the decoder yields for `payload`.
    fn data(payload: &str) -> SseEvent {
        SseEvent::Data(payload.to_string())
    }

    /// Asserts that `prefix` completes into exactly `expected`.
    fn assert_completes(prefix: &str, expected: Value) {
        assert_eq!(complete_json(prefix), Some(expected), "prefix: {prefix:?}");
    }

    #[test]
    fn decoder_emits_complete_data_event() {
        let mut decoder = SseDecoder::default();
        let events = decoder.push(b"data: {\"a\":1}\n\n");
        assert_eq!(events, vec![data("{\"a\":1}")]);
    }

    #[test]
    fn decoder_buffers_across_chunk_boundary() {
        let mut decoder = SseDecoder::default();
        let fragments: [&[u8]; 3] = [b"data: {\"hel", b"lo\":1", b"}\n"];
        let per_fragment: Vec<Vec<SseEvent>> =
            fragments.iter().map(|f| decoder.push(f)).collect();
        assert_eq!(
            per_fragment,
            vec![vec![], vec![], vec![data("{\"hello\":1}")]]
        );
    }

    #[test]
    fn decoder_handles_crlf_and_ignores_non_data_lines() {
        let mut decoder = SseDecoder::default();
        let events =
            decoder.push(b"event: message\r\ndata: {\"x\":1}\r\n\r\n: keep-alive\r\n");
        assert_eq!(events, vec![data("{\"x\":1}")]);
    }

    #[test]
    fn decoder_recognizes_done_sentinel() {
        let mut decoder = SseDecoder::default();
        assert_eq!(decoder.push(b"data: [DONE]\n\n"), vec![SseEvent::Done]);
    }

    #[test]
    fn openai_delta_extracts_content() {
        let with_content = json!({"choices":[{"delta":{"content":"Hi"}}]});
        let role_only = json!({"choices":[{"delta":{"role":"assistant"}}]});
        assert_eq!(openai_delta(&with_content).as_deref(), Some("Hi"));
        assert_eq!(openai_delta(&role_only), None);
    }

    #[test]
    fn anthropic_delta_extracts_text_and_partial_json() {
        let cases = [
            (
                json!({"type":"content_block_delta","delta":{"type":"text_delta","text":"Hi"}}),
                Some("Hi"),
            ),
            (
                json!({"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"a\":"}}),
                Some("{\"a\":"),
            ),
            (json!({"type":"message_start"}), None),
        ];
        for (event, expected) in cases {
            assert_eq!(anthropic_delta(&event).as_deref(), expected);
        }
    }

    #[test]
    fn gemini_delta_concatenates_parts() {
        let event = json!({"candidates":[{"content":{"parts":[{"text":"a"},{"text":"b"}]}}]});
        assert_eq!(gemini_delta(&event).as_deref(), Some("ab"));
    }

    #[test]
    fn complete_json_closes_open_string_and_object() {
        assert_completes(r#"{"name": "Ali"#, json!({"name": "Ali"}));
    }

    #[test]
    fn complete_json_drops_dangling_key_and_comma() {
        for prefix in [r#"{"a": 1, "b":"#, r#"{"a": 1, "#, r#"{"a": 1,"#] {
            assert_completes(prefix, json!({"a": 1}));
        }
    }

    #[test]
    fn complete_json_closes_nested_and_arrays() {
        assert_completes(
            r#"{"items":[{"x":1},{"x":2"#,
            json!({"items":[{"x":1},{"x":2}]}),
        );
        assert_completes(r#"[1, 2, 3"#, json!([1, 2, 3]));
        assert_completes(r#"[1, 2, "#, json!([1, 2]));
    }

    #[test]
    fn complete_json_skips_incomplete_primitive() {
        for prefix in [r#"{"a": tr"#, r#"{"a": 12."#, ""] {
            assert!(complete_json(prefix).is_none(), "prefix: {prefix:?}");
        }
    }

    #[test]
    fn complete_json_handles_escapes() {
        assert_completes(r#"{"s": "line\"#, json!({"s": "line"}));
        assert_completes(r#"{"s": "a\nb"#, json!({"s": "a\nb"}));
    }

    #[test]
    fn complete_json_progressive_prefixes_converge() {
        let full = r#"{"name":"Alice","age":30,"tags":["x","y"]}"#;
        let snapshots = (1..=full.len()).filter_map(|end| complete_json(&full[..end]));
        for snapshot in snapshots {
            assert!(snapshot.is_object() || snapshot.is_array());
        }
        assert_completes(full, json!({"name":"Alice","age":30,"tags":["x","y"]}));
    }
}
/// Boxed, `Send` stream of individual array items, as returned by `iter_stream`.
pub type ItemStream<'a, T> = Pin<Box<dyn Stream<Item = Result<T>> + Send + 'a>>;
/// Incrementally extracts the elements of the first JSON array in a stream of text
/// fragments.
///
/// Text before the array's opening `[` is skipped, for example `{"items":` in the
/// wrapper object described by `array_wrapper_schema`. Strings in that prefix are
/// tracked, so a `[` inside a string does not start the array.
///
/// Each element is emitted as soon as the `,` or `]` that terminates it is seen. Once
/// the array closes, all further input is ignored.
#[derive(Default)]
pub(crate) struct JsonArrayStreamer {
    /// True while inside the array being extracted.
    in_array: bool,
    /// True once the array's closing `]` has been consumed; later input is ignored.
    finished: bool,
    /// Nesting depth of `{`/`[` inside the element currently being collected.
    depth: i32,
    /// Whether the scanner is inside a JSON string literal (tracked before the array too).
    in_string: bool,
    /// Whether the previous character inside a string was an escaping backslash.
    escaped: bool,
    /// Whether a non-whitespace character of the current element has been seen.
    started_element: bool,
    /// Raw text of the element currently being collected.
    current: String,
}

impl JsonArrayStreamer {
    /// Feeds the next text fragment and returns every array element it completes.
    ///
    /// Completed elements whose text does not parse as JSON are skipped.
    pub(crate) fn push_str(&mut self, s: &str) -> Vec<Value> {
        let mut out = Vec::new();
        for c in s.chars() {
            if self.finished {
                break;
            }
            if self.in_string {
                // String contents are opaque: brackets and commas inside them are data.
                if self.in_array {
                    self.current.push(c);
                }
                if self.escaped {
                    self.escaped = false;
                } else if c == '\\' {
                    self.escaped = true;
                } else if c == '"' {
                    self.in_string = false;
                }
                continue;
            }
            if !self.in_array {
                // Before the array: skip text, but follow strings so that a `[` inside
                // e.g. a preceding property value is not mistaken for the array start.
                match c {
                    '[' => self.in_array = true,
                    '"' => self.in_string = true,
                    _ => {}
                }
                continue;
            }
            match c {
                '"' => {
                    self.started_element = true;
                    self.in_string = true;
                    self.current.push(c);
                }
                '{' | '[' => {
                    self.started_element = true;
                    self.depth += 1;
                    self.current.push(c);
                }
                '}' | ']' if self.depth > 0 => {
                    self.depth -= 1;
                    self.current.push(c);
                }
                ']' => {
                    // End of the array itself: flush the last element and stop.
                    if let Some(v) = self.finish_element() {
                        out.push(v);
                    }
                    self.in_array = false;
                    self.finished = true;
                }
                ',' if self.depth == 0 => {
                    if let Some(v) = self.finish_element() {
                        out.push(v);
                    }
                }
                c if c.is_whitespace() && !self.started_element => {}
                _ => {
                    self.started_element = true;
                    self.current.push(c);
                }
            }
        }
        out
    }

    /// Takes the buffered element text, resets per-element state, and parses it.
    ///
    /// Returns `None` for an empty element (e.g. in `[]`) or text that is not valid JSON.
    fn finish_element(&mut self) -> Option<Value> {
        let text = std::mem::take(&mut self.current);
        self.started_element = false;
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return None;
        }
        serde_json::from_str(trimmed).ok()
    }
}
/// Streams the elements of a JSON array from an SSE response, one item at a time.
///
/// Text deltas obtained via `extract` are fed to a [`JsonArrayStreamer`]. Every element
/// it completes is passed to `finalize_item`, and the result is yielded. A failing
/// `finalize_item` is propagated with `?`.
///
/// Reading stops at `data: [DONE]` or when the body ends. An element still incomplete
/// at that point is never emitted. Payloads that are not valid JSON, or for which
/// `extract` returns `None`, are skipped.
pub(crate) fn iter_stream<'a, T, Fut, F, Fin>(
    send: Fut,
    extract: F,
    finalize_item: Fin,
) -> ItemStream<'a, T>
where
    T: Send + 'a,
    Fut: Future<Output = Result<reqwest::Response>> + Send + 'a,
    F: Fn(&Value) -> Option<String> + Send + 'a,
    Fin: Fn(Value) -> Result<T> + Send + 'a,
{
    Box::pin(try_stream! {
        let mut body = send.await?.bytes_stream();
        let mut sse = SseDecoder::default();
        let mut elements = JsonArrayStreamer::default();
        'read: while let Some(piece) = body.next().await {
            let piece = piece.map_err(RStructorError::from)?;
            for event in sse.push(&piece) {
                let SseEvent::Data(payload) = event else {
                    break 'read;
                };
                let Some(delta) = serde_json::from_str::<Value>(&payload)
                    .ok()
                    .and_then(|json| extract(&json))
                else {
                    continue;
                };
                for element in elements.push_str(&delta) {
                    let item = finalize_item(element)?;
                    yield item;
                }
            }
        }
    })
}
/// Deserializes one streamed array element into `T` and runs its `validate` check.
///
/// # Errors
/// Returns `RStructorError::SerializationError` if `value` does not deserialize into
/// `T`; otherwise returns any error from `validate`.
pub(crate) fn finalize_item<T: Instructor + DeserializeOwned>(value: Value) -> Result<T> {
    let item = match serde_json::from_value::<T>(value) {
        Ok(item) => item,
        Err(err) => return Err(RStructorError::SerializationError(err.to_string())),
    };
    item.validate()?;
    Ok(item)
}
/// Wraps `item_schema` in an object schema with a single required `items` array property.
///
/// With `strict`, `"additionalProperties": false` is added to the wrapper object.
pub(crate) fn array_wrapper_schema(item_schema: Value, strict: bool) -> Value {
    let array_schema = serde_json::json!({ "type": "array", "items": item_schema });
    let mut wrapper = serde_json::json!({
        "type": "object",
        "properties": { "items": array_schema },
        "required": ["items"],
    });
    if strict {
        // `wrapper` is built as an object literal above, so this branch always applies.
        if let Some(object) = wrapper.as_object_mut() {
            object.insert("additionalProperties".to_owned(), Value::Bool(false));
        }
    }
    wrapper
}
#[cfg(test)]
mod array_tests {
    //! Unit tests for incremental JSON array element extraction.
    use super::*;
    use serde_json::json;

    /// Feeds `fragments` to a fresh streamer and collects what each fragment completes.
    fn feed(fragments: &[&str]) -> Vec<Vec<Value>> {
        let mut streamer = JsonArrayStreamer::default();
        fragments
            .iter()
            .map(|fragment| streamer.push_str(fragment))
            .collect()
    }

    #[test]
    fn streams_object_elements_as_they_complete() {
        let emitted = feed(&[r#"{"items":["#, r#"{"n":1},{"n":2}"#, r#",{"n":3}]}"#]);
        assert_eq!(
            emitted,
            vec![
                vec![],
                vec![json!({"n":1})],
                vec![json!({"n":2}), json!({"n":3})],
            ]
        );
    }

    #[test]
    fn handles_scalars_strings_and_nesting() {
        let emitted = feed(&[r#"{"items":[1, "a,b", {"x":[1,2]}, true]}"#]);
        assert_eq!(
            emitted,
            vec![vec![json!(1), json!("a,b"), json!({"x":[1,2]}), json!(true)]]
        );
    }

    #[test]
    fn ignores_strings_containing_brackets_before_array() {
        assert_eq!(feed(&[r#"{"items":[{"v":"#]), vec![Vec::<Value>::new()]);
    }
}