use {
crate::error::{LlmWebError, Result},
async_openai::types::chat::ChatCompletionResponseStream,
async_stream::try_stream,
futures::{Stream, StreamExt},
serde::de::DeserializeOwned,
std::pin::Pin,
};
/// Boxed, pinned, `Send` stream whose items are `Result<R>`; this is the
/// return type of [`partial_stream`].
pub type PartialStream<R> = Pin<Box<dyn Stream<Item = Result<R>> + Send>>;
/// Adapts a chat-completion chunk stream into a stream of progressively more
/// complete `R` values.
///
/// For every chunk, the text delta of the first choice is appended to an
/// internal buffer. The buffer is then passed through [`repair_partial_json`]
/// and deserialized as `R`:
///
/// * chunks without a first choice, without content, or with empty content
///   are skipped;
/// * buffers that still do not deserialize into `R` are skipped silently;
/// * a value equal to the previously yielded one is not yielded again.
///
/// # Errors
///
/// An error from the underlying chunk stream is converted to
/// [`LlmWebError::ModelClient`] (carrying the error's `Display` text) and
/// propagated with `?` inside `try_stream!`.
pub fn partial_stream<R>(mut chat: ChatCompletionResponseStream) -> PartialStream<R>
where
    R: DeserializeOwned + Send + 'static + PartialEq,
{
    let stream = try_stream! {
        // Everything the model has emitted so far.
        let mut accumulated = String::new();
        // Most recently yielded value, kept to suppress duplicates.
        let mut previous: Option<R> = None;

        while let Some(next) = chat.next().await {
            let chunk = next.map_err(|err| LlmWebError::ModelClient(err.to_string()))?;

            let fragment = chunk
                .choices
                .first()
                .and_then(|choice| choice.delta.content.as_deref())
                .filter(|text| !text.is_empty());
            let Some(fragment) = fragment else {
                continue;
            };

            accumulated.push_str(fragment);
            let candidate = repair_partial_json(&accumulated);

            match serde_json::from_str::<R>(&candidate) {
                // Nothing new to report.
                Ok(parsed) if previous.as_ref() == Some(&parsed) => continue,
                Ok(parsed) => previous = Some(parsed),
                // Not yet a valid `R`; wait for more input.
                Err(_) => continue,
            }

            // `R` is not required to be `Clone`, so the value handed to the
            // consumer is deserialized a second time from the same text.
            yield serde_json::from_str::<R>(&candidate)?;
        }
    };
    Box::pin(stream)
}
/// Completes a truncated JSON document on a best-effort basis so that a
/// prefix of a streamed response can be parsed.
///
/// The input is scanned once to learn which containers are still open,
/// whether the text ends inside a string, and whether the last string token
/// is an object key that has not yet been followed by `:`. The returned
/// string is the input with:
///
/// * trailing whitespace removed;
/// * an unterminated string closed (a half-written escape, i.e. a trailing
///   lone backslash, is dropped first so the closing quote is not escaped);
/// * a dangling object key (open or closed, but without `:`) removed;
/// * trailing commas removed;
/// * `null` appended after a trailing `:`;
/// * every still-open `{` / `[` closed in the right order.
///
/// Only the trailing token is ever rewritten: completed string values that
/// precede a trailing comma or a dangling key are left untouched.
///
/// Input that is already complete is returned unchanged (apart from trailing
/// whitespace). Text cut in the middle of a literal such as `tru`, a number
/// such as `1.`, or a `\uXXXX` escape is not repaired; the result will simply
/// fail to parse.
pub fn repair_partial_json(input: &str) -> String {
    // One frame per open container: its closing delimiter, and whether the
    // next string token inside it would be an object key.
    let mut frames: Vec<(char, bool)> = Vec::new();
    let mut in_string = false;
    let mut escape = false;
    // True while the most recent string token sits in key position and no
    // structural character has followed it yet.
    let mut dangling_key = false;

    // Scanning bytes is safe: every structural character is ASCII and can
    // never appear inside a multi-byte UTF-8 sequence.
    for ch in input.bytes() {
        if in_string {
            if escape {
                escape = false;
            } else if ch == b'\\' {
                escape = true;
            } else if ch == b'"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            b'"' => {
                in_string = true;
                dangling_key = matches!(frames.last(), Some(&(_, true)));
            }
            b'{' => {
                frames.push(('}', true));
                dangling_key = false;
            }
            b'[' => {
                frames.push((']', false));
                dangling_key = false;
            }
            b'}' | b']' => {
                frames.pop();
                dangling_key = false;
            }
            b':' => {
                if let Some(frame) = frames.last_mut() {
                    frame.1 = false;
                }
                dangling_key = false;
            }
            b',' => {
                if let Some(frame) = frames.last_mut() {
                    // After a comma an object expects a key again; an array
                    // never does.
                    frame.1 = frame.0 == '}';
                }
                dangling_key = false;
            }
            _ => {}
        }
    }

    let mut out = input.trim_end().to_string();

    if in_string {
        // An odd run of trailing backslashes means the text stops in the
        // middle of an escape; keeping it would escape the closing quote.
        let trailing_backslashes = out.bytes().rev().take_while(|&b| b == b'\\').count();
        if trailing_backslashes % 2 == 1 {
            out.pop();
        }
        out.push('"');
    }

    // Remove a key that has no value yet. This happens at most once and only
    // when the scan identified the trailing string as a key, so completed
    // string values are never mistaken for keys.
    if dangling_key {
        if let Some(start) = find_unescaped_quote_from_end(&out) {
            out.truncate(start);
        }
    }

    // Strip trailing commas (and the whitespace around them).
    loop {
        let kept = out.trim_end().len();
        out.truncate(kept);
        if out.ends_with(',') {
            out.pop();
        } else {
            break;
        }
    }

    // A key followed by ':' but no value gets an explicit null.
    if out.ends_with(':') {
        out.push_str("null");
    }

    // Close the innermost container first.
    out.extend(frames.iter().rev().map(|&(closer, _)| closer));
    out
}

/// Given text that ends with a `"`, returns the byte offset of the nearest
/// earlier `"` that is not escaped (preceded by an even number of
/// backslashes), i.e. the opening quote of the trailing string token.
///
/// Returns `None` when `s` is empty, does not end with `"`, or contains no
/// such earlier quote.
fn find_unescaped_quote_from_end(s: &str) -> Option<usize> {
    let (&last, body) = s.as_bytes().split_last()?;
    if last != b'"' {
        return None;
    }
    (0..body.len()).rev().find(|&i| {
        body[i] == b'"'
            && body[..i]
                .iter()
                .rev()
                .take_while(|&&b| b == b'\\')
                .count()
                % 2
                == 0
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Runs `repair_partial_json` on `input` and parses the result, panicking
    /// with the repaired text if it is still not valid JSON.
    fn parse(input: &str) -> Value {
        let repaired = repair_partial_json(input);
        match serde_json::from_str(&repaired) {
            Ok(value) => value,
            Err(e) => panic!("repaired {repaired:?} failed to parse: {e}"),
        }
    }

    // An object missing its closing brace gets one appended.
    #[test]
    fn closes_open_object() {
        let repaired = parse(r#"{"a": 1"#);
        assert_eq!(repaired, json!({"a": 1}));
    }

    // An array missing its closing bracket gets one appended.
    #[test]
    fn closes_open_array() {
        let repaired = parse(r#"[1, 2, 3"#);
        assert_eq!(repaired, json!([1, 2, 3]));
    }

    // A trailing comma is removed in arrays and in objects alike.
    #[test]
    fn drops_trailing_comma() {
        let from_array = parse(r#"[1, 2,"#);
        assert_eq!(from_array, json!([1, 2]));

        let from_object = parse(r#"{"a": 1,"#);
        assert_eq!(from_object, json!({"a": 1}));
    }

    // A key followed by ':' but no value is given a null value.
    #[test]
    fn fills_dangling_colon_with_null() {
        let repaired = parse(r#"{"a":"#);
        assert_eq!(repaired, json!({"a": null}));
    }

    // An unterminated string in value position is closed and kept.
    #[test]
    fn closes_open_string_in_value_position() {
        let repaired = parse(r#"{"a": "hel"#);
        assert_eq!(repaired, json!({"a": "hel"}));
    }

    // An unterminated key is removed together with the comma before it.
    #[test]
    fn drops_dangling_partial_key() {
        let repaired = parse(r#"{"a": 1, "b"#);
        assert_eq!(repaired, json!({"a": 1}));
    }

    // A fully quoted key that has no ':' yet is removed as well.
    #[test]
    fn drops_dangling_complete_key_without_colon() {
        let repaired = parse(r#"{"a": 1, "b""#);
        assert_eq!(repaired, json!({"a": 1}));
    }

    // Open containers are closed innermost-first.
    #[test]
    fn handles_nested() {
        let repaired = parse(r#"{"top": [{"title": "hi", "n": 1}, {"title": "two"#);
        let expected = json!({"top": [{"title": "hi", "n": 1}, {"title": "two"}]});
        assert_eq!(repaired, expected);
    }

    // A '{' inside a string is content, not an opened container.
    #[test]
    fn escapes_inside_strings_ignored() {
        let repaired = parse(r#"{"a": "x{y"#);
        assert_eq!(repaired, json!({"a": "x{y"}));
    }

    // Complete input comes back byte-for-byte identical.
    #[test]
    fn passes_through_already_valid_json() {
        let s = r#"{"a": 1, "b": [2, 3]}"#;
        let repaired = repair_partial_json(s);
        assert_eq!(repaired, s);
    }
}