use std::future::Future;
use std::pin::Pin;
use async_stream::try_stream;
use futures_util::{Stream, StreamExt};
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::error::{RStructorError, Result};
use crate::model::Instructor;
/// Boxed, pinned stream whose items are `Result<String>` text fragments.
///
/// The stream is `Send` and may borrow data for the lifetime `'a`.
pub type TextStream<'a> = Pin<Box<dyn Stream<Item = Result<String>> + Send + 'a>>;
/// Boxed, pinned stream whose items are `Result<StreamedObject<T>>`.
///
/// The stream is `Send` and may borrow data for the lifetime `'a`.
pub type ObjectStream<'a, T> = Pin<Box<dyn Stream<Item = Result<StreamedObject<T>>> + Send + 'a>>;
/// One item of an [`ObjectStream`]: either an intermediate JSON snapshot or the final value.
#[derive(Debug, Clone)]
pub enum StreamedObject<T> {
/// A JSON value reconstructed from the text received so far.
Partial(Value),
/// The finished value of type `T`.
Complete(T),
}
impl<T> StreamedObject<T> {
    /// Consumes the item and returns the final value, or `None` for a
    /// [`StreamedObject::Partial`] snapshot.
    pub fn complete(self) -> Option<T> {
        if let StreamedObject::Complete(finished) = self {
            Some(finished)
        } else {
            None
        }
    }
}
/// An event decoded from a server-sent-events byte stream.
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum SseEvent {
    /// The trimmed, non-empty payload of a `data:` line.
    Data(String),
    /// A `data:` line whose payload is exactly `[DONE]`.
    Done,
}

/// Incremental, line-oriented decoder for `data:` lines of an SSE stream.
///
/// Bytes are buffered until a `\n` arrives, so lines (and multi-byte UTF-8
/// sequences) may be split arbitrarily across chunks.
#[derive(Default)]
pub(crate) struct SseDecoder {
    /// Bytes received so far that do not yet form a complete line.
    buf: Vec<u8>,
}

impl SseDecoder {
    /// Feeds `chunk` into the decoder and returns the events completed by it, in order.
    ///
    /// Only lines starting with `data:` produce events; every other line (blank lines,
    /// `event:` fields, `:` comments) is ignored, as are `data:` lines whose payload is
    /// empty after trimming. Decoding does not stop at [`SseEvent::Done`]: later lines in
    /// the same chunk are still reported. Invalid UTF-8 is replaced lossily.
    pub(crate) fn push(&mut self, chunk: &[u8]) -> Vec<SseEvent> {
        self.buf.extend_from_slice(chunk);
        let mut events = Vec::new();
        // Offset of the first byte that has not been consumed as part of a full line.
        // Scanning with an offset and draining once keeps this linear in the buffer size,
        // instead of shifting the remaining bytes after every single line.
        let mut consumed = 0;
        while let Some(offset) = self.buf[consumed..].iter().position(|&b| b == b'\n') {
            let end = consumed + offset;
            let line = String::from_utf8_lossy(&self.buf[consumed..end]);
            consumed = end + 1;
            // The `\n` itself is excluded by the slice; drop the `\r` of CRLF endings.
            let line = line.trim_end_matches('\r');
            if let Some(rest) = line.strip_prefix("data:") {
                let data = rest.trim();
                if data == "[DONE]" {
                    events.push(SseEvent::Done);
                } else if !data.is_empty() {
                    events.push(SseEvent::Data(data.to_string()));
                }
            }
        }
        // Keep only the trailing partial line for the next call.
        self.buf.drain(..consumed);
        events
    }
}
/// Turns a streaming SSE HTTP response into a stream of text fragments.
///
/// `send` is awaited to obtain the response, whose body is decoded with [`SseDecoder`].
/// Each `data:` payload that parses as JSON is passed to `extract`; every non-empty string
/// it returns is yielded. Payloads that are not valid JSON, or for which `extract` returns
/// `None`, are skipped. The stream ends at the `[DONE]` sentinel or when the body ends.
///
/// # Errors
/// Yields an error if `send` fails or if reading a body chunk fails.
pub(crate) fn sse_text_stream<'a, Fut, F>(send: Fut, extract: F) -> TextStream<'a>
where
    Fut: Future<Output = Result<reqwest::Response>> + Send + 'a,
    F: Fn(&Value) -> Option<String> + Send + 'a,
{
    Box::pin(try_stream! {
        let response = send.await?;
        let mut body = response.bytes_stream();
        let mut decoder = SseDecoder::default();
        'outer: while let Some(chunk) = body.next().await {
            let chunk = chunk.map_err(RStructorError::from)?;
            for event in decoder.push(chunk.as_ref()) {
                let payload = match event {
                    SseEvent::Done => break 'outer,
                    SseEvent::Data(payload) => payload,
                };
                // Computed in its own statement so no borrow outlives the `yield` below.
                let fragment = serde_json::from_str::<Value>(&payload)
                    .ok()
                    .and_then(|json| extract(&json))
                    .filter(|text| !text.is_empty());
                if let Some(text) = fragment {
                    yield text;
                }
            }
        }
    })
}
/// Streams a structured object of type `T` from an SSE response.
///
/// Delegates to [`object_stream_with`], finalizing the accumulated text with
/// `super::utils::parse_and_validate_response::<T>` and discarding the extra context that
/// function attaches to its error.
pub(crate) fn object_stream<'a, T, Fut, F>(send: Fut, extract: F) -> ObjectStream<'a, T>
where
    T: Instructor + DeserializeOwned + Send + 'a,
    Fut: Future<Output = Result<reqwest::Response>> + Send + 'a,
    F: Fn(&Value) -> Option<String> + Send + 'a,
{
    let finalize = |raw: &str| -> Result<T> {
        match super::utils::parse_and_validate_response::<T>(raw) {
            Ok(value) => Ok(value),
            Err((err, _ctx)) => Err(err),
        }
    };
    object_stream_with(send, extract, finalize)
}
/// Streams partial JSON snapshots followed by one finalized value.
///
/// `send` is awaited to obtain the response, whose body is decoded with [`SseDecoder`].
/// Text returned by `extract` for each JSON `data:` payload is appended to an accumulator.
/// After every append the accumulator is passed to [`complete_json`]; when that yields a
/// value different from the previously emitted one, it is emitted as
/// [`StreamedObject::Partial`]. Once the `[DONE]` sentinel is seen or the body ends,
/// `finalize` receives the trimmed accumulated text and its result is emitted as
/// [`StreamedObject::Complete`].
///
/// # Errors
/// Yields an error if `send` fails, if reading a body chunk fails, or if `finalize` fails.
pub(crate) fn object_stream_with<'a, T, Fut, F, Fin>(
    send: Fut,
    extract: F,
    finalize: Fin,
) -> ObjectStream<'a, T>
where
    T: Send + 'a,
    Fut: Future<Output = Result<reqwest::Response>> + Send + 'a,
    F: Fn(&Value) -> Option<String> + Send + 'a,
    Fin: FnOnce(&str) -> Result<T> + Send + 'a,
{
    Box::pin(try_stream! {
        let response = send.await?;
        let mut body = response.bytes_stream();
        let mut decoder = SseDecoder::default();
        let mut accumulated = String::new();
        let mut previous: Option<Value> = None;
        'outer: while let Some(chunk) = body.next().await {
            let chunk = chunk.map_err(RStructorError::from)?;
            for event in decoder.push(chunk.as_ref()) {
                let payload = match event {
                    SseEvent::Done => break 'outer,
                    SseEvent::Data(payload) => payload,
                };
                let delta = serde_json::from_str::<Value>(&payload)
                    .ok()
                    .and_then(|json| extract(&json));
                if let Some(delta) = delta {
                    accumulated.push_str(&delta);
                    if let Some(snapshot) = complete_json(&accumulated) {
                        // Only emit a snapshot when it differs from the last one emitted.
                        if previous.as_ref() != Some(&snapshot) {
                            previous = Some(snapshot.clone());
                            yield StreamedObject::Partial(snapshot);
                        }
                    }
                }
            }
        }
        let finished: T = finalize(accumulated.trim())?;
        yield StreamedObject::Complete(finished);
    })
}
/// Returns `choices[0].delta.content` of `event` as an owned string.
///
/// Yields `None` when any step of that path is missing or the content is not a string.
pub(crate) fn openai_delta(event: &Value) -> Option<String> {
    // Indexing a `Value` yields `Null` for a missing key, an out-of-range index or a
    // mismatched container type, so any broken step ends in `as_str()` returning `None`.
    let content = &event["choices"][0]["delta"]["content"];
    content.as_str().map(str::to_owned)
}
/// Returns the text carried by a `content_block_delta` event.
///
/// Events whose `type` is anything else (or missing / not a string) yield `None`. Within
/// `delta`, a string `text` field is preferred; otherwise a string `partial_json` field is
/// used.
pub(crate) fn anthropic_delta(event: &Value) -> Option<String> {
    let kind = event.get("type").and_then(Value::as_str);
    if kind != Some("content_block_delta") {
        return None;
    }
    let delta = event.get("delta")?;
    let fragment = match delta.get("text").and_then(Value::as_str) {
        Some(text) => text,
        None => delta.get("partial_json")?.as_str()?,
    };
    Some(fragment.to_owned())
}
/// Concatenates the string `text` fields of `candidates[0].content.parts`.
///
/// Returns `None` when the path is missing, `parts` is not an array, or the concatenation
/// is empty.
pub(crate) fn gemini_delta(event: &Value) -> Option<String> {
    let parts = event
        .get("candidates")
        .and_then(|candidates| candidates.get(0))
        .and_then(|candidate| candidate.get("content"))
        .and_then(|content| content.get("parts"))
        .and_then(Value::as_array)?;
    let mut joined = String::new();
    for part in parts {
        if let Some(piece) = part.get("text").and_then(Value::as_str) {
            joined.push_str(piece);
        }
    }
    if joined.is_empty() {
        return None;
    }
    Some(joined)
}
/// Parses a possibly truncated JSON text by first closing it with [`repair_json`].
///
/// Returns `None` when the text cannot be repaired or the repaired text is still not valid
/// JSON.
pub(crate) fn complete_json(s: &str) -> Option<Value> {
    repair_json(s).and_then(|closed| serde_json::from_str(&closed).ok())
}
/// Closes a truncated JSON text so that it has a chance of parsing.
///
/// The input is scanned once, tracking string/escape state and a stack of open `{` / `[`.
/// At the end of the input:
/// * a lone trailing backslash inside a string is dropped and the string is closed;
/// * trailing whitespace and commas are removed;
/// * a dangling `"key":` with no value is removed together with its key;
/// * every still-open container is closed in reverse order.
///
/// Returns `None` for empty input, for a closer that does not match the innermost open
/// container, and for a dangling `:` that has no quoted key inside a container and no
/// earlier `{` or `,` to cut back to. The result is not guaranteed to be valid JSON
/// (for example a truncated literal such as `tr` is left as is); callers must still parse
/// it.
fn repair_json(s: &str) -> Option<String> {
    let s = s.trim();
    if s.is_empty() {
        return None;
    }
    let mut out = String::with_capacity(s.len() + 8);
    let mut stack: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;
    // Byte offset in `out` of the opening quote of the most recently started string.
    let mut last_string_start: Option<usize> = None;
    for c in s.chars() {
        if in_string {
            out.push(c);
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
        } else {
            match c {
                '"' => {
                    in_string = true;
                    last_string_start = Some(out.len());
                    out.push(c);
                }
                '{' | '[' => {
                    stack.push(c);
                    out.push(c);
                }
                '}' => {
                    if stack.pop() != Some('{') {
                        return None;
                    }
                    out.push(c);
                }
                ']' => {
                    if stack.pop() != Some('[') {
                        return None;
                    }
                    out.push(c);
                }
                _ => out.push(c),
            }
        }
    }
    // A backslash that starts an escape but has no following character cannot be kept.
    if in_string && escaped {
        out.pop();
    }
    if in_string {
        out.push('"');
    }
    loop {
        let trimmed_len = out.trim_end().len();
        out.truncate(trimmed_len);
        if out.ends_with(',') {
            out.pop();
            continue;
        }
        if out.ends_with(':') {
            // The colon is outside any string, so a `"` right before it (ignoring
            // whitespace) is the closing quote of the last string scanned: the key.
            // Cutting at that string's opening quote is exact even when the key itself
            // contains `{` or `,`, which a plain backwards search would stop at.
            let before_colon = out[..out.len() - 1].trim_end();
            let key_start = last_string_start
                .take()
                .filter(|_| !stack.is_empty() && before_colon.ends_with('"'));
            let cut = match key_start {
                Some(start) => start,
                // No quoted key inside a container: fall back to cutting just after the
                // nearest `{` or `,`, or give up when there is none.
                None => out.rfind(['{', ','])? + 1,
            };
            out.truncate(cut);
            continue;
        }
        break;
    }
    for &opener in stack.iter().rev() {
        out.push(if opener == '{' { '}' } else { ']' });
    }
    Some(out)
}
// Unit tests for the SSE decoder, the provider delta extractors, partial-JSON completion
// and the `StreamedObject` accessor.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// A full `data:` line followed by a blank line yields exactly one event.
#[test]
fn decoder_emits_complete_data_event() {
let mut d = SseDecoder::default();
assert_eq!(
d.push(b"data: {\"a\":1}\n\n"),
vec![SseEvent::Data("{\"a\":1}".to_string())]
);
}
// Nothing is emitted until the terminating newline arrives.
#[test]
fn decoder_buffers_across_chunk_boundary() {
let mut d = SseDecoder::default();
assert_eq!(d.push(b"data: {\"hel"), vec![]);
assert_eq!(d.push(b"lo\":1"), vec![]);
assert_eq!(
d.push(b"}\n"),
vec![SseEvent::Data("{\"hello\":1}".to_string())]
);
}
// CRLF line endings are accepted; `event:` fields and `:` comments produce no events.
#[test]
fn decoder_handles_crlf_and_ignores_non_data_lines() {
let mut d = SseDecoder::default();
assert_eq!(
d.push(b"event: message\r\ndata: {\"x\":1}\r\n\r\n: keep-alive\r\n"),
vec![SseEvent::Data("{\"x\":1}".to_string())]
);
}
#[test]
fn decoder_recognizes_done_sentinel() {
let mut d = SseDecoder::default();
assert_eq!(d.push(b"data: [DONE]\n\n"), vec![SseEvent::Done]);
}
// A delta without `content` (e.g. a role-only delta) yields `None`.
#[test]
fn openai_delta_extracts_content() {
assert_eq!(
openai_delta(&json!({"choices":[{"delta":{"content":"Hi"}}]})),
Some("Hi".to_string())
);
assert_eq!(
openai_delta(&json!({"choices":[{"delta":{"role":"assistant"}}]})),
None
);
}
// Both `text` and `partial_json` deltas are extracted; other event types are ignored.
#[test]
fn anthropic_delta_extracts_text_and_partial_json() {
assert_eq!(
anthropic_delta(
&json!({"type":"content_block_delta","delta":{"type":"text_delta","text":"Hi"}})
),
Some("Hi".to_string())
);
assert_eq!(
anthropic_delta(
&json!({"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"a\":"}})
),
Some("{\"a\":".to_string())
);
assert_eq!(anthropic_delta(&json!({"type":"message_start"})), None);
}
#[test]
fn gemini_delta_concatenates_parts() {
assert_eq!(
gemini_delta(
&json!({"candidates":[{"content":{"parts":[{"text":"a"},{"text":"b"}]}}]})
),
Some("ab".to_string())
);
}
#[test]
fn complete_json_closes_open_string_and_object() {
assert_eq!(
complete_json(r#"{"name": "Ali"#).unwrap(),
json!({"name": "Ali"})
);
}
// A key without a value and a trailing comma are both removed before closing.
#[test]
fn complete_json_drops_dangling_key_and_comma() {
assert_eq!(complete_json(r#"{"a": 1, "b":"#).unwrap(), json!({"a": 1}));
assert_eq!(complete_json(r#"{"a": 1, "#).unwrap(), json!({"a": 1}));
assert_eq!(complete_json(r#"{"a": 1,"#).unwrap(), json!({"a": 1}));
}
#[test]
fn complete_json_closes_nested_and_arrays() {
assert_eq!(
complete_json(r#"{"items":[{"x":1},{"x":2"#).unwrap(),
json!({"items":[{"x":1},{"x":2}]})
);
assert_eq!(complete_json(r#"[1, 2, 3"#).unwrap(), json!([1, 2, 3]));
assert_eq!(complete_json(r#"[1, 2, "#).unwrap(), json!([1, 2]));
}
// Truncated literals/numbers are not guessed at: the repaired text fails to parse.
#[test]
fn complete_json_skips_incomplete_primitive() {
assert!(complete_json(r#"{"a": tr"#).is_none());
assert!(complete_json(r#"{"a": 12."#).is_none());
assert!(complete_json("").is_none());
}
// A lone trailing backslash is dropped; a complete escape sequence is kept.
#[test]
fn complete_json_handles_escapes() {
assert_eq!(
complete_json(r#"{"s": "line\"#).unwrap(),
json!({"s": "line"})
);
assert_eq!(
complete_json(r#"{"s": "a\nb"#).unwrap(),
json!({"s": "a\nb"})
);
}
// Every prefix either fails to complete or completes to a container; the full text
// completes to the exact document.
#[test]
fn complete_json_progressive_prefixes_converge() {
let full = r#"{"name":"Alice","age":30,"tags":["x","y"]}"#;
for i in 1..=full.len() {
if let Some(v) = complete_json(&full[..i]) {
assert!(v.is_object() || v.is_array());
}
}
assert_eq!(
complete_json(full).unwrap(),
json!({"name":"Alice","age":30,"tags":["x","y"]})
);
}
// The `\r` and `\n` of one CRLF may arrive in different chunks.
#[test]
fn decoder_splits_crlf_across_chunks() {
let mut d = SseDecoder::default();
assert_eq!(d.push(b"data:{a}\r"), vec![]);
assert_eq!(d.push(b"\n\r\n"), vec![SseEvent::Data("{a}".to_string())]);
}
#[test]
fn decoder_emits_multiple_data_lines_in_one_chunk_in_order() {
let mut d = SseDecoder::default();
assert_eq!(
d.push(b"data:{a}\ndata:{b}\n"),
vec![
SseEvent::Data("{a}".to_string()),
SseEvent::Data("{b}".to_string()),
]
);
}
// The three bytes of U+20AC are split over two pushes and must be decoded together.
#[test]
fn decoder_reassembles_utf8_multibyte_split_across_chunks() {
let mut d = SseDecoder::default();
assert_eq!(d.push(b"data:\xe2\x82"), vec![]);
let events = d.push(b"\xac\n");
assert_eq!(events, vec![SseEvent::Data("\u{20AC}".to_string())]);
if let SseEvent::Data(s) = &events[0] {
assert_eq!(s.chars().next(), Some('\u{20AC}'));
}
}
// The decoder itself keeps decoding after `[DONE]`; stopping is left to the caller.
#[test]
fn decoder_handles_done_and_data_in_same_push() {
let mut d = SseDecoder::default();
assert_eq!(
d.push(b"data:[DONE]\ndata:{a}\n"),
vec![SseEvent::Done, SseEvent::Data("{a}".to_string())]
);
}
#[test]
fn decoder_ignores_empty_and_whitespace_data_lines() {
let mut d = SseDecoder::default();
assert_eq!(d.push(b"data:\n"), vec![]);
let mut d = SseDecoder::default();
assert_eq!(d.push(b"data: \n"), vec![]);
}
#[test]
fn decoder_done_sentinel_is_case_sensitive() {
let mut d = SseDecoder::default();
assert_eq!(
d.push(b"data:[done]\n"),
vec![SseEvent::Data("[done]".to_string())]
);
}
#[test]
fn complete_json_rejects_truncated_unicode_escape() {
assert!(complete_json(r#"{"s":"\u00"#).is_none());
}
#[test]
fn complete_json_rejects_unbalanced_or_extra_closers() {
assert!(complete_json(r#"{"a":1}}"#).is_none());
assert!(complete_json(r#"{"a":1]"#).is_none());
}
// An even run of backslashes is a complete escape; an odd run loses its last backslash.
#[test]
fn complete_json_handles_odd_and_even_trailing_backslashes() {
assert_eq!(complete_json(r#"{"s":"a\\"#).unwrap(), json!({"s": "a\\"}));
assert_eq!(complete_json(r#"{"s":"a\\\"#).unwrap(), json!({"s": "a\\"}));
}
#[test]
fn complete_json_rejects_dangling_minus_but_allows_negative_exponent() {
assert!(complete_json(r#"{"a":-"#).is_none());
assert_eq!(
complete_json(r#"{"a":-1.2e10"#).unwrap(),
json!({"a": -1.2e10})
);
}
#[test]
fn complete_json_rejects_dangling_colon_without_container() {
assert!(complete_json(r#""key":"#).is_none());
assert!(complete_json("x:").is_none());
}
#[test]
fn complete_json_passes_top_level_scalars_through() {
assert_eq!(complete_json("42").unwrap(), json!(42));
assert_eq!(complete_json("true").unwrap(), json!(true));
assert_eq!(complete_json(r#""hello"#).unwrap(), json!("hello"));
assert_eq!(complete_json("[1,2,3]").unwrap(), json!([1, 2, 3]));
}
#[test]
fn complete_json_trims_trailing_whitespace_and_comma() {
assert_eq!(complete_json("{\"a\":1, \n ").unwrap(), json!({"a": 1}));
}
#[test]
fn streamed_object_complete_accessor() {
assert_eq!(StreamedObject::Complete(42).complete(), Some(42));
assert_eq!(
StreamedObject::<i32>::Partial(json!({"a": 1})).complete(),
None
);
}
}
/// Boxed, pinned stream whose items are `Result<T>` values.
///
/// The stream is `Send` and may borrow data for the lifetime `'a`.
pub type ItemStream<'a, T> = Pin<Box<dyn Stream<Item = Result<T>> + Send + 'a>>;
/// Incremental scanner that extracts the elements of a JSON array from text that arrives
/// in pieces.
#[derive(Default)]
pub(crate) struct JsonArrayStreamer {
// True while the scanner is between the array's opening `[` and its closing `]`.
in_array: bool,
// Nesting depth of `{` / `[` opened inside the element currently being collected.
depth: i32,
// True while the scanner is inside a JSON string literal.
in_string: bool,
// True when the previous character inside a string was an unconsumed backslash.
escaped: bool,
// True once the current element has received its first character; whitespace seen
// before that is skipped.
started_element: bool,
// Raw text of the element currently being collected.
current: String,
}
impl JsonArrayStreamer {
    /// Feeds the next piece of JSON text and returns every array element completed by it.
    ///
    /// Text before the first `[` is skipped. String literals in that skipped text are
    /// tracked, so a `[` inside a string such as `{"note":"a[b","items":[1]}` is not
    /// mistaken for the start of the array. Inside the array, an element ends at a `,` or
    /// at the closing `]` at nesting depth zero. Elements that are empty or fail to parse
    /// as JSON are dropped silently. After the array closes, a later `[` outside a string
    /// starts a new array.
    ///
    /// State is kept across calls, so elements, strings and escape sequences may be split
    /// between calls.
    pub(crate) fn push_str(&mut self, s: &str) -> Vec<Value> {
        let mut out = Vec::new();
        for c in s.chars() {
            if !self.in_array {
                self.scan_outside_array(c);
                continue;
            }
            if self.in_string {
                self.current.push(c);
                self.step_string(c);
                continue;
            }
            match c {
                '"' => {
                    self.started_element = true;
                    self.in_string = true;
                    self.current.push(c);
                }
                '{' | '[' => {
                    self.started_element = true;
                    self.depth += 1;
                    self.current.push(c);
                }
                '}' | ']' if self.depth > 0 => {
                    self.depth -= 1;
                    self.current.push(c);
                }
                // A `]` at depth zero closes the streamed array itself.
                ']' => {
                    if let Some(v) = self.finish_element() {
                        out.push(v);
                    }
                    self.in_array = false;
                }
                ',' if self.depth == 0 => {
                    if let Some(v) = self.finish_element() {
                        out.push(v);
                    }
                }
                // Skip whitespace that precedes the first character of an element.
                c if c.is_whitespace() && !self.started_element => {}
                _ => {
                    self.started_element = true;
                    self.current.push(c);
                }
            }
        }
        out
    }

    /// Consumes one character located outside the array, looking for the opening `[`.
    ///
    /// `in_string` / `escaped` are reused for the skipped text; both are `false` whenever
    /// the array is entered, which is what the in-array scanner expects.
    fn scan_outside_array(&mut self, c: char) {
        if self.in_string {
            self.step_string(c);
        } else if c == '"' {
            self.in_string = true;
        } else if c == '[' {
            self.in_array = true;
        }
    }

    /// Advances the string/escape state for a character read inside a string literal.
    fn step_string(&mut self, c: char) {
        if self.escaped {
            self.escaped = false;
        } else if c == '\\' {
            self.escaped = true;
        } else if c == '"' {
            self.in_string = false;
        }
    }

    /// Takes the collected element text, resets the per-element state and parses the text.
    ///
    /// Returns `None` when the text is empty after trimming or is not valid JSON.
    fn finish_element(&mut self) -> Option<Value> {
        let text = std::mem::take(&mut self.current);
        self.started_element = false;
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return None;
        }
        serde_json::from_str(trimmed).ok()
    }
}
/// Streams the elements of a JSON array as soon as each one is complete.
///
/// `send` is awaited to obtain the response, whose body is decoded with [`SseDecoder`].
/// Text returned by `extract` for each JSON `data:` payload is fed to a
/// [`JsonArrayStreamer`]; every element it completes is converted with `finalize_item` and
/// yielded. The stream ends at the `[DONE]` sentinel or when the body ends.
///
/// # Errors
/// Yields an error if `send` fails, if reading a body chunk fails, or if `finalize_item`
/// rejects an element.
pub(crate) fn iter_stream<'a, T, Fut, F, Fin>(
    send: Fut,
    extract: F,
    finalize_item: Fin,
) -> ItemStream<'a, T>
where
    T: Send + 'a,
    Fut: Future<Output = Result<reqwest::Response>> + Send + 'a,
    F: Fn(&Value) -> Option<String> + Send + 'a,
    Fin: Fn(Value) -> Result<T> + Send + 'a,
{
    Box::pin(try_stream! {
        let response = send.await?;
        let mut body = response.bytes_stream();
        let mut decoder = SseDecoder::default();
        let mut elements = JsonArrayStreamer::default();
        'outer: while let Some(chunk) = body.next().await {
            let chunk = chunk.map_err(RStructorError::from)?;
            for event in decoder.push(chunk.as_ref()) {
                let payload = match event {
                    SseEvent::Done => break 'outer,
                    SseEvent::Data(payload) => payload,
                };
                let fragment = serde_json::from_str::<Value>(&payload)
                    .ok()
                    .and_then(|json| extract(&json));
                if let Some(fragment) = fragment {
                    for element in elements.push_str(&fragment) {
                        yield finalize_item(element)?;
                    }
                }
            }
        }
    })
}
/// Deserializes `value` into `T` and runs `T`'s `validate` method.
///
/// # Errors
/// Returns `RStructorError::SerializationError` carrying the serde message when
/// deserialization fails, or the error produced by `validate`.
pub(crate) fn finalize_item<T: Instructor + DeserializeOwned>(value: Value) -> Result<T> {
    match serde_json::from_value::<T>(value) {
        Ok(item) => {
            item.validate()?;
            Ok(item)
        }
        Err(e) => Err(RStructorError::SerializationError(e.to_string())),
    }
}
/// Wraps `item_schema` in an object schema with a single required `items` array property.
///
/// When `strict` is true the wrapper also sets `additionalProperties` to `false`.
pub(crate) fn array_wrapper_schema(item_schema: Value, strict: bool) -> Value {
    // Keys are inserted in the same order as before so that output is unchanged even if
    // the map preserves insertion order.
    let mut wrapper = serde_json::Map::new();
    wrapper.insert("type".to_owned(), Value::from("object"));
    wrapper.insert(
        "properties".to_owned(),
        serde_json::json!({ "items": { "type": "array", "items": item_schema } }),
    );
    wrapper.insert("required".to_owned(), serde_json::json!(["items"]));
    if strict {
        wrapper.insert("additionalProperties".to_owned(), Value::Bool(false));
    }
    Value::Object(wrapper)
}
// Unit tests for `JsonArrayStreamer`, `array_wrapper_schema` and `finalize_item`.
#[cfg(test)]
mod array_tests {
use super::*;
use serde_json::json;
// An element is only emitted once its terminating `,` or `]` has been seen.
#[test]
fn streams_object_elements_as_they_complete() {
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":["#), Vec::<Value>::new());
assert_eq!(s.push_str(r#"{"n":1},{"n":2}"#), vec![json!({"n":1})]);
assert_eq!(
s.push_str(r#",{"n":3}]}"#),
vec![json!({"n":2}), json!({"n":3})]
);
}
// Commas inside strings and nested containers do not split elements.
#[test]
fn handles_scalars_strings_and_nesting() {
let mut s = JsonArrayStreamer::default();
let got = s.push_str(r#"{"items":[1, "a,b", {"x":[1,2]}, true]}"#);
assert_eq!(
got,
vec![json!(1), json!("a,b"), json!({"x":[1,2]}), json!(true)]
);
}
// NOTE(review): the input here has no bracket inside a string before the array, so this
// only checks that an unfinished element yields nothing — confirm the intended coverage.
#[test]
fn ignores_strings_containing_brackets_before_array() {
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":[{"v":"#), Vec::<Value>::new());
}
#[test]
fn handles_escaped_quotes_in_string_element() {
let mut s = JsonArrayStreamer::default();
assert_eq!(
s.push_str(r#"{"items":["he said \"hi\"","x"]}"#),
vec![json!("he said \"hi\""), json!("x")]
);
}
#[test]
fn handles_string_containing_closing_bracket() {
let mut s = JsonArrayStreamer::default();
assert_eq!(
s.push_str(r#"{"items":["a]b","c"]}"#),
vec![json!("a]b"), json!("c")]
);
}
#[test]
fn handles_array_of_arrays_elements() {
let mut s = JsonArrayStreamer::default();
assert_eq!(
s.push_str(r#"{"items":[[1,2],[3,4]]}"#),
vec![json!([1, 2]), json!([3, 4])]
);
}
// Works both with the `{"items": [...]}` wrapper and with a bare top-level array.
#[test]
fn handles_null_and_bare_top_level_array() {
let mut s = JsonArrayStreamer::default();
assert_eq!(
s.push_str(r#"{"items":[null,1]}"#),
vec![json!(null), json!(1)]
);
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"[1,2,3]"#), vec![json!(1), json!(2), json!(3)]);
}
// Unparseable or empty elements are skipped without affecting later elements.
#[test]
fn drops_invalid_element_and_recovers() {
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":[1abc,2]}"#), vec![json!(2)]);
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":[,1]}"#), vec![json!(1)]);
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":[1,]}"#), vec![json!(1)]);
}
#[test]
fn empty_items_array_yields_nothing() {
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":[]}"#), Vec::<Value>::new());
}
#[test]
fn element_split_across_push_str_calls() {
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":[12"#), Vec::<Value>::new());
assert_eq!(s.push_str(r#"34,5]}"#), vec![json!(1234), json!(5)]);
}
// The first call ends right after a backslash; the second call supplies the escaped char.
#[test]
fn escape_flag_persists_across_push_str_calls() {
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":["a\"#), Vec::<Value>::new());
assert_eq!(s.push_str(r#"\b","c"]}"#), vec![json!("a\\b"), json!("c")]);
}
// After one array has closed, a later `[` starts a new one on the same streamer.
#[test]
fn re_entry_after_array_close() {
let mut s = JsonArrayStreamer::default();
assert_eq!(s.push_str(r#"{"items":[1]}"#), vec![json!(1)]);
assert_eq!(s.push_str(r#"[9,9]"#), vec![json!(9), json!(9)]);
}
#[test]
fn array_wrapper_schema_strict_adds_additional_properties_and_required() {
let item = json!({"type": "object"});
let wrapper = array_wrapper_schema(item.clone(), true);
assert_eq!(wrapper["additionalProperties"], json!(false));
assert_eq!(wrapper["required"], json!(["items"]));
assert_eq!(wrapper["type"], json!("object"));
assert_eq!(wrapper["properties"]["items"]["type"], json!("array"));
assert_eq!(wrapper["properties"]["items"]["items"], item);
}
#[test]
fn array_wrapper_schema_non_strict_omits_additional_properties() {
let item = json!({"type": "string"});
let wrapper = array_wrapper_schema(item.clone(), false);
assert!(wrapper.get("additionalProperties").is_none());
assert_eq!(wrapper["required"], json!(["items"]));
assert_eq!(wrapper["properties"]["items"]["items"], item);
}
// Covers the three outcomes of `finalize_item`: validation failure, deserialization
// failure, and success. Only compiled when the `derive` feature is enabled.
#[cfg(feature = "derive")]
#[test]
fn finalize_item_validation_and_deserialize_failures() {
use crate::Instructor;
use serde::{Deserialize, Serialize};
#[derive(Instructor, Serialize, Deserialize, Debug, PartialEq)]
#[llm(validate = "validate_ticket")]
struct Ticket {
title: String,
priority: u8,
}
fn validate_ticket(t: &Ticket) -> crate::Result<()> {
if !(1..=5).contains(&t.priority) {
return Err(RStructorError::ValidationError(format!(
"priority must be 1-5, got {}",
t.priority
)));
}
Ok(())
}
let err = finalize_item::<Ticket>(json!({"title": "x", "priority": 99}))
.expect_err("priority 99 should fail validation");
assert!(
matches!(err, RStructorError::ValidationError(_)),
"expected ValidationError, got {err:?}"
);
let err = finalize_item::<Ticket>(json!({"title": 123, "priority": 1}))
.expect_err("non-string title should fail deserialization");
assert!(
matches!(err, RStructorError::SerializationError(_)),
"expected SerializationError, got {err:?}"
);
let ok = finalize_item::<Ticket>(json!({"title": "x", "priority": 3})).unwrap();
assert_eq!(
ok,
Ticket {
title: "x".into(),
priority: 3
}
);
}
}