use futures_core::Stream;
use serde::de::DeserializeOwned;
use tokio_stream::StreamExt;
use crate::error::OllamaError;
const MAX_LINE_SIZE: usize = 10 * 1024 * 1024;
pub fn ndjson_stream<T: DeserializeOwned>(
response: reqwest::Response,
) -> impl Stream<Item = crate::error::Result<T>> {
async_stream::try_stream! {
let mut byte_stream = response.bytes_stream();
let mut buffer = String::new();
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk?;
let chunk_str = String::from_utf8_lossy(&chunk);
if buffer.len() + chunk_str.len() > MAX_LINE_SIZE {
Err(OllamaError::LineTooLarge { max_bytes: MAX_LINE_SIZE })?;
}
buffer.push_str(&chunk_str);
while let Some(newline_pos) = buffer.find('\n') {
let line = buffer[..newline_pos].trim().to_string();
buffer.drain(..=newline_pos);
if line.is_empty() {
continue;
}
let item: T = serde_json::from_str(&line)?;
yield item;
}
}
let remaining = buffer.trim().to_string();
if !remaining.is_empty() {
let item: T = serde_json::from_str(&remaining)?;
yield item;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::chat::ChatStreamChunk;

    /// Wraps `body` in a 200 response (delivered as a single chunk), runs it through
    /// `ndjson_stream`, and collects every item the stream yields.
    async fn collect_chunks(body: &'static str) -> Vec<crate::error::Result<ChatStreamChunk>> {
        let response = http::Response::builder().status(200).body(body).unwrap();
        let response = reqwest::Response::from(response);
        tokio_stream::StreamExt::collect(ndjson_stream::<ChatStreamChunk>(response)).await
    }

    #[tokio::test]
    async fn ndjson_stream_parses_lines() {
        let body = concat!(
            r#"{"model":"m","created_at":"t","message":{"role":"assistant","content":"hi"},"done":false}"#,
            "\n",
            r#"{"model":"m","created_at":"t","message":{"role":"assistant","content":"!"},"done":true}"#,
            "\n",
        );
        let items = collect_chunks(body).await;
        assert_eq!(items.len(), 2);
        assert!(items[0].is_ok());
        assert!(items[1].is_ok());
        assert!(items[1].as_ref().unwrap().done);
    }

    #[tokio::test]
    async fn ndjson_stream_skips_blank_lines_and_parses_unterminated_tail() {
        // Leading blank line, CRLF endings, an empty line in the middle, and no trailing newline.
        let body = concat!(
            "\r\n",
            r#"{"model":"m","created_at":"t","message":{"role":"assistant","content":"hi"},"done":false}"#,
            "\r\n\n",
            r#"{"model":"m","created_at":"t","message":{"role":"assistant","content":"!"},"done":true}"#,
        );
        let items = collect_chunks(body).await;
        assert_eq!(items.len(), 2);
        assert!(!items[0].as_ref().unwrap().done);
        assert!(items[1].as_ref().unwrap().done);
    }

    #[tokio::test]
    async fn ndjson_stream_yields_error_on_malformed_line() {
        let items = collect_chunks("not json\n").await;
        assert_eq!(items.len(), 1);
        assert!(items[0].is_err());
    }
}