use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::Stream;
use pin_project_lite::pin_project;
use reqwest::Response;
use crate::error::Error;
use crate::types::ChatCompletionChunk;
pin_project! {
    /// Stream of [`ChatCompletionChunk`]s decoded from a server-sent-events (SSE)
    /// HTTP response body.
    ///
    /// Body bytes are decoded as UTF-8 and accumulated in a text buffer, which is
    /// handed to `try_parse_sse_event` to extract one event at a time. The stream
    /// ends when the parser reports the `[DONE]` sentinel or when the underlying
    /// body ends, whichever comes first.
    ///
    /// Items are `Err(Error::Http)` for transport failures, `Err(Error::Stream)`
    /// for bytes that are not valid UTF-8, and otherwise whatever the parser
    /// produced for the event.
    pub struct ChatCompletionStream {
        // Decoded text that has been received but not yet consumed by the parser.
        buffer: String,
        // Leading bytes of a UTF-8 character that was cut off at the end of the
        // previous network chunk; completed by the next chunk.
        pending: Vec<u8>,
        // Set once `[DONE]` was parsed or the body ended; afterwards the stream
        // only ever yields `None`.
        done: bool,
        #[pin]
        byte_stream: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
    }
}

impl ChatCompletionStream {
    /// Creates a stream that decodes the body of `response`.
    pub(crate) fn new(response: Response) -> Self {
        Self {
            buffer: String::new(),
            pending: Vec::new(),
            done: false,
            byte_stream: Box::pin(response.bytes_stream()),
        }
    }
}

impl Stream for ChatCompletionStream {
    type Item = Result<ChatCompletionChunk, Error>;

    /// Yields the next decoded event, pulling more bytes from the body only when
    /// the buffered text does not yet contain a complete event.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        if *this.done {
            return Poll::Ready(None);
        }
        loop {
            // Drain already-buffered events before touching the network again.
            if let Some(result) = try_parse_sse_event(this.buffer, this.done) {
                return Poll::Ready(Some(result));
            }
            // The parser signals the `[DONE]` sentinel by setting `done` and
            // returning `None`. Stop right here: polling the body again could
            // return `Pending` or surface events sent after the sentinel.
            if *this.done {
                return Poll::Ready(None);
            }
            match this.byte_stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => {
                    if let Err(e) = push_utf8_chunk(this.buffer, this.pending, &bytes) {
                        return Poll::Ready(Some(Err(Error::Stream(format!(
                            "invalid UTF-8 in SSE stream: {e}"
                        )))));
                    }
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(Error::Http(e))));
                }
                Poll::Ready(None) => {
                    // Body ended. Any unterminated trailing line (and any
                    // partial character still in `pending`) is dropped.
                    *this.done = true;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

/// Appends `chunk` to `buffer` as UTF-8 text.
///
/// Network chunk boundaries are arbitrary, so a multi-byte character may be
/// split across two chunks. An incomplete sequence at the end of `chunk` is
/// parked in `pending` and prepended to the next chunk instead of being
/// reported as an error.
///
/// # Errors
///
/// Returns the decoding error when the bytes contain a sequence that can never
/// become valid UTF-8, no matter what follows.
fn push_utf8_chunk(
    buffer: &mut String,
    pending: &mut Vec<u8>,
    chunk: &[u8],
) -> Result<(), std::str::Utf8Error> {
    // Only copy when there is a carried-over partial character to prepend.
    let joined: Vec<u8>;
    let data: &[u8] = if pending.is_empty() {
        chunk
    } else {
        pending.extend_from_slice(chunk);
        joined = std::mem::take(pending);
        joined.as_slice()
    };

    match std::str::from_utf8(data) {
        Ok(text) => buffer.push_str(text),
        // `error_len() == None` means the input simply ended in the middle of a
        // character; everything before `valid_up_to()` is valid text.
        Err(e) if e.error_len().is_none() => {
            let (complete, partial) = data.split_at(e.valid_up_to());
            buffer.push_str(std::str::from_utf8(complete)?);
            pending.extend_from_slice(partial);
        }
        Err(e) => return Err(e),
    }
    Ok(())
}
fn try_parse_sse_event(
buffer: &mut String,
done: &mut bool,
) -> Option<Result<ChatCompletionChunk, Error>> {
loop {
let newline_pos = buffer.find('\n')?;
let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
buffer.drain(..=newline_pos);
if line.is_empty() || line.starts_with(':') {
continue;
}
if let Some(payload) = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:")) {
let payload = payload.trim();
if payload == "[DONE]" {
*done = true;
return None;
}
match serde_json::from_str::<ChatCompletionChunk>(payload) {
Ok(chunk) => return Some(Ok(chunk)),
Err(e) => return Some(Err(Error::Json(e))),
}
}
}
}