use std::time::Duration;
use error_stack::{Report, ResultExt};
use eventsource_stream::{Event, Eventsource};
use futures::StreamExt;
use crate::{
format::{ResponseInfo, StreamingChatResponse, StreamingResponse, StreamingResponseSender},
providers::{ProviderError, ProviderErrorKind},
};
/// Spawns a Tokio task that reads `response`'s body as a server-sent-event stream and
/// forwards mapped chunks to `chunk_tx`.
///
/// Each SSE event is handed to `mapper`:
/// - `Ok(None)`: the event is skipped.
/// - `Ok(Some(chunk))`: the chunk is sent as [`StreamingResponse::Chunk`]. The first
///   non-`None` `chunk.model` seen is remembered for the final summary.
/// - `Err(report)`: the report is sent to the channel unchanged and the task stops.
///
/// An error yielded by the event stream itself (rather than by the mapper) is wrapped in a
/// [`ProviderError`] of kind [`ProviderErrorKind::ProviderClosedConnection`]. That error
/// records no status code, no body and a zero latency. It is sent to the channel, and then
/// the task stops.
///
/// When the event stream ends without error, a final [`StreamingResponse::ResponseInfo`] is
/// sent. It carries the remembered model name (empty if no chunk reported one) and no
/// metadata.
///
/// If sending a chunk fails, a warning is logged and the task stops. Failures to deliver an
/// error or the final summary are ignored.
///
/// # Returns
/// The `JoinHandle` of the spawned task.
///
/// # Panics
/// `tokio::task::spawn` panics when called outside the context of a Tokio runtime.
pub fn stream_sse_to_channel(
    response: reqwest::Response,
    chunk_tx: StreamingResponseSender,
    mut mapper: impl StreamingChunkMapper,
) -> tokio::task::JoinHandle<()> {
    tokio::task::spawn(async move {
        let mut events = response.bytes_stream().eventsource();
        // Model name from the first chunk that reports one; echoed in the final summary.
        let mut first_model: Option<String> = None;

        while let Some(next) = events.next().await {
            // A failure from the event stream itself is reported and ends the task.
            let event = match next {
                Ok(event) => event,
                Err(err) => {
                    chunk_tx
                        .send_async(Err(err).change_context(ProviderError {
                            kind: ProviderErrorKind::ProviderClosedConnection,
                            status_code: None,
                            body: None,
                            latency: Duration::ZERO,
                        }))
                        .await
                        .ok();
                    return;
                }
            };

            let mapped = mapper.process_chunk(&event);
            tracing::trace!(chunk = ?mapped);

            let chunk = match mapped {
                Ok(Some(chunk)) => chunk,
                // The mapper produced nothing for this event.
                Ok(None) => continue,
                // Mapper errors are forwarded untouched and end the task.
                Err(err) => {
                    chunk_tx.send_async(Err(err)).await.ok();
                    return;
                }
            };

            // Keep an already-known model; otherwise adopt whatever this chunk reports.
            first_model = first_model.or_else(|| chunk.model.clone());

            let delivered = chunk_tx
                .send_async(Ok(StreamingResponse::Chunk(chunk)))
                .await;
            if delivered.is_err() {
                // The channel rejected the send (logged as closed); stop reading the body.
                tracing::warn!("channel closed early");
                return;
            }
        }

        // The event stream ended without error: emit the trailing summary.
        let summary = ResponseInfo {
            meta: None,
            model: first_model.unwrap_or_default(),
        };
        chunk_tx
            .send_async(Ok(StreamingResponse::ResponseInfo(summary)))
            .await
            .ok();
    })
}
/// Converts individual server-sent events into streaming chat chunks.
///
/// [`stream_sse_to_channel`] moves a mapper into a spawned Tokio task. It then calls
/// [`process_chunk`](Self::process_chunk) once for every event successfully read from the
/// response body.
pub trait StreamingChunkMapper: Send + Sync + 'static {
/// Maps one SSE `event` to an optional chunk.
///
/// Taking `&mut self` allows an implementation to keep state between events of the same
/// stream.
///
/// # Returns
/// - `Ok(Some(chunk))`: [`stream_sse_to_channel`] forwards the chunk to the consumer.
/// - `Ok(None)`: [`stream_sse_to_channel`] skips this event.
///
/// # Errors
/// Returns a [`ProviderError`] report when the implementation fails to map the event.
/// [`stream_sse_to_channel`] forwards that report to the consumer and stops processing the
/// stream.
fn process_chunk(
&mut self,
event: &Event,
) -> Result<Option<StreamingChatResponse>, Report<ProviderError>>;
}