use std::pin::Pin;
use futures::{Stream, StreamExt, stream};
use tracing::info;
use crate::{
client::http::HttpClient,
model::{chat_stream_response::ChatStreamResponse, traits::SseStreamable},
};
pub trait StreamChatLikeExt: SseStreamable + HttpClient {
fn stream_for_each<'a, F, Fut>(
&'a mut self,
mut on_chunk: F,
) -> impl core::future::Future<Output = crate::ZaiResult<()>> + 'a
where
F: FnMut(ChatStreamResponse) -> Fut + 'a,
Fut: core::future::Future<Output = crate::ZaiResult<()>> + 'a,
{
async move {
let resp = self.post().await?;
let mut stream = resp.bytes_stream();
let mut buf: Vec<u8> = Vec::new();
while let Some(next) = stream.next().await {
let bytes = match next {
Ok(b) => b,
Err(e) => {
return Err(crate::client::error::ZaiError::NetworkError(
std::sync::Arc::new(e),
));
},
};
let lines = crate::model::sse_parser::extract_sse_data_lines(&mut buf, &bytes);
for rest in lines {
info!("SSE data: {}", String::from_utf8_lossy(&rest));
if rest == b"[DONE]" {
return Ok(());
}
if let Ok(chunk) = serde_json::from_slice::<ChatStreamResponse>(&rest) {
on_chunk(chunk).await?;
}
}
}
Ok(())
}
}
fn to_stream<'a>(
&'a mut self,
) -> impl core::future::Future<
Output = crate::ZaiResult<
Pin<Box<dyn Stream<Item = crate::ZaiResult<ChatStreamResponse>> + Send + 'static>>,
>,
> + 'a {
async move {
let resp = self.post().await?;
let byte_stream = resp.bytes_stream();
let s = byte_stream;
let out = stream::unfold((s, Vec::<u8>::new()), |(mut s, mut buf)| async move {
loop {
match s.next().await {
Some(Ok(bytes)) => {
let lines =
crate::model::sse_parser::extract_sse_data_lines(&mut buf, &bytes);
for rest in lines {
info!("SSE data: {}", String::from_utf8_lossy(&rest));
if rest == b"[DONE]" {
return None; }
if let Ok(item) =
serde_json::from_slice::<ChatStreamResponse>(&rest)
{
return Some((Ok(item), (s, buf)));
}
}
},
Some(Err(e)) => {
return Some((
Err(crate::client::error::ZaiError::NetworkError(
std::sync::Arc::new(e),
)),
(s, buf),
));
},
None => return None,
}
}
})
.boxed();
Ok(out)
}
}
}