use bytes::{Buf, BytesMut};
use futures_core::Stream;
use pin_project::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::chat::{ChatChunkError, ChatChunkResult, ChatCompletionChunkObject};
/// Adapter that wraps a raw byte stream (e.g. a `reqwest` response body) and
/// yields parsed chat-completion chunks, splitting the bytes on `\n`.
#[pin_project]
pub(crate) struct ChunkStream<S>
where
S: Stream<Item = ReqwestStreamItem> + Unpin,
{
// Underlying byte stream; pinned so it can be polled via `project()`.
#[pin]
stream: S,
// Bytes received so far that do not yet form a complete line.
buffer: BytesMut,
}
/// Item type produced by a `reqwest` byte stream.
type ReqwestStreamItem = Result<bytes::Bytes, reqwest::Error>;
impl<S> ChunkStream<S>
where
    S: Stream<Item = ReqwestStreamItem> + Unpin,
{
    /// Wraps `stream` with an initially empty line buffer.
    pub(crate) fn new(stream: S) -> Self {
        Self {
            buffer: BytesMut::new(),
            stream,
        }
    }
}
impl<S> Stream for ChunkStream<S>
where
S: Stream<Item = ReqwestStreamItem> + Unpin,
{
type Item = ChatChunkResult;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<ChatChunkResult>> {
let mut this = self.project();
loop {
if let Some(position) = this
.buffer
.iter()
.position(|b| *b == b'\n')
{
let line = this.buffer.split_to(position);
this.buffer.advance(1); let line = String::from_utf8(line.to_vec())
.map_err(ChatChunkError::StringDecodingError)?;
if line == "data: [DONE]" {
return Poll::Ready(None);
}
if line.is_empty() {
continue;
}
let data = line
.strip_prefix("data: ")
.ok_or_else(|| {
ChatChunkError::DataPrefixMissing(line.clone())
})?;
let chunk =
serde_json::from_str::<ChatCompletionChunkObject>(&data)
.map_err(|error| {
ChatChunkError::DeserializeFailed(
error,
data.to_string(),
)
})?;
return Poll::Ready(Some(Ok(chunk)));
}
match this
.stream
.as_mut()
.poll_next(cx)
{
| Poll::Ready(Some(Ok(chunk))) => {
this.buffer.extend(&chunk);
},
| Poll::Ready(Some(Err(error))) => {
return Poll::Ready(Some(Err(
ChatChunkError::StreamError(error),
)));
},
| Poll::Ready(None) => {
return if this.buffer.is_empty() {
Poll::Ready(None)
} else {
let line = this.buffer.split_off(0);
let line = String::from_utf8(line.to_vec())
.map_err(ChatChunkError::StringDecodingError)?;
if line == "data: [DONE]" {
return Poll::Ready(None);
}
if line.is_empty() {
return Poll::Ready(None);
}
let data = line
.strip_prefix("data: ")
.ok_or_else(|| {
ChatChunkError::DataPrefixMissing(line.clone())
})?;
let chunk = serde_json::from_str::<
ChatCompletionChunkObject,
>(&data)
.map_err(|error| {
ChatChunkError::DeserializeFailed(
error,
data.to_string(),
)
})?;
Poll::Ready(Some(Ok(chunk)))
};
},
| Poll::Pending => return Poll::Pending,
}
}
}
}
#[cfg(test)]
mod tests {
    use crate::chat::chat_completion_chunk_object::{
        ChatCompletionChunkChoice, ChatCompletionDelta,
    };
    use crate::chat::{ChatModel, Role};
    use bytes::Bytes;
    use tokio_stream::StreamExt;
    use super::*;

    /// Builds an expected chunk object; the four streamed chunks differ only
    /// in delta role/content and finish reason, so everything else is shared.
    fn expected_chunk(
        role: Option<Role>,
        content: Option<&str>,
        finish_reason: Option<&str>,
    ) -> ChatCompletionChunkObject {
        ChatCompletionChunkObject {
            id: "chatcmpl-123".to_string(),
            object: "chat.completion.chunk".to_string(),
            created: 1694268190,
            model: ChatModel::Gpt35Turbo0125,
            system_fingerprint: Some("fp_44709d6fcb".to_string()),
            choices: vec![ChatCompletionChunkChoice {
                index: 0,
                delta: Some(ChatCompletionDelta {
                    role,
                    content: content.map(str::to_string),
                    tool_calls: None,
                }),
                logprobs: None,
                finish_reason: finish_reason.map(str::to_string),
            }],
        }
    }

    #[tokio::test]
    async fn test_stream_line_reader() {
        let source = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
"#;
        let input_stream =
            tokio_stream::iter(vec![Ok(Bytes::from(source))]);
        let mut stream = ChunkStream::new(input_stream);
        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            expected_chunk(Some(Role::Assistant), Some(""), None),
        );
        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            expected_chunk(None, Some("Hello"), None),
        );
        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            expected_chunk(None, Some("!"), None),
        );
        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            expected_chunk(None, None, Some("stop")),
        );
        assert!(stream.next().await.is_none());
    }
}