use bytes::Bytes;
use futures::stream::TryStreamExt;
use futures::{Future, Stream};
use reqwest::{RequestBuilder, Response};
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::error::{BoxError, Error as GenaiError};
/// Type-erased future resolving to the HTTP response, or to a boxed error.
type ResponseFuture = Pin<Box<dyn Future<Output = Result<Response, BoxError>> + Send>>;

/// Type-erased stream of raw response-body chunks.
type BodyChunkStream = Pin<Box<dyn Stream<Item = Result<Bytes, BoxError>> + Send>>;

/// Streams text messages out of an HTTP response body.
///
/// The wrapped request is only sent once the stream is polled. The body is then decoded as
/// UTF-8 and split into messages according to its [`StreamMode`].
pub struct WebStream {
    /// How the decoded body text is split into messages.
    stream_mode: StreamMode,
    /// Request still to be sent; taken (left `None`) when it is sent.
    reqwest_builder: Option<RequestBuilder>,
    /// In-flight future: either the response itself or the reading of an error body.
    response_future: Option<ResponseFuture>,
    /// Body chunks, available once a successful response has arrived.
    bytes_stream: Option<BodyChunkStream>,
    /// Text after the last delimiter, waiting for more data to complete it.
    partial_message: Option<String>,
    /// Complete messages already split out of a chunk but not yet yielded.
    remaining_messages: Option<VecDeque<String>>,
    /// Leading bytes of a UTF-8 character cut off at the end of the previous chunk.
    utf8_carry: Vec<u8>,
}
/// How a [`WebStream`] splits its decoded response body into messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamMode {
    /// Split on the given delimiter string; empty pieces between delimiters are skipped.
    Delimiter(&'static str),
    /// Server-Sent Events: `\r\n` and lone `\r` line endings are normalized to `\n`, and
    /// events are split on blank lines (`\n\n`).
    Sse,
}
impl WebStream {
pub fn new_with_delimiter(reqwest_builder: RequestBuilder, message_delimiter: &'static str) -> Self {
Self {
stream_mode: StreamMode::Delimiter(message_delimiter),
reqwest_builder: Some(reqwest_builder),
response_future: None,
bytes_stream: None,
partial_message: None,
remaining_messages: None,
utf8_carry: Vec::new(),
}
}
pub fn new_with_sse(reqwest_builder: RequestBuilder) -> Self {
Self {
stream_mode: StreamMode::Sse,
reqwest_builder: Some(reqwest_builder),
response_future: None,
bytes_stream: None,
partial_message: None,
remaining_messages: None,
utf8_carry: Vec::new(),
}
}
}
impl Stream for WebStream {
    type Item = Result<String, BoxError>;

    /// Yields the next complete message from the HTTP response body.
    ///
    /// Messages already split out of an earlier chunk are returned first. Otherwise the loop
    /// advances the state machine:
    /// - a pending response future is polled. A non-success status is replaced by a future that
    ///   reads the body and resolves to `GenaiError::HttpError`;
    /// - the body stream is polled. Each chunk is decoded as UTF-8 and split into messages
    ///   according to `stream_mode`;
    /// - if the request has not been sent yet, it is sent now;
    /// - when nothing is left to poll, the last unterminated message is yielded (if non-empty).
    ///   Any incomplete UTF-8 bytes left at the end are then reported as an error, and finally
    ///   the stream ends with `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        if let Some(ref mut remaining_messages) = this.remaining_messages
            && let Some(msg) = remaining_messages.pop_front()
        {
            return Poll::Ready(Some(Ok(msg)));
        }

        loop {
            if let Some(ref mut fut) = this.response_future {
                match Pin::new(fut).poll(cx) {
                    Poll::Ready(Ok(response)) => {
                        let status = response.status();
                        if !status.is_success() {
                            // Read the error body asynchronously. The resulting `Err` is
                            // surfaced by the `Ready(Err(_))` arm on the next loop iteration.
                            let error_future = async move {
                                let body = response
                                    .text()
                                    .await
                                    .unwrap_or_else(|e| format!("Failed to read error body: {}", e));
                                Err::<Response, BoxError>(Box::new(GenaiError::HttpError {
                                    status,
                                    canonical_reason: status.canonical_reason().unwrap_or("Unknown").to_string(),
                                    body,
                                }))
                            };
                            this.response_future = Some(Box::pin(error_future));
                            continue;
                        }
                        let bytes_stream = response.bytes_stream().map_err(|e| Box::new(e) as BoxError);
                        this.bytes_stream = Some(Box::pin(bytes_stream));
                        this.response_future = None;
                    }
                    Poll::Ready(Err(e)) => {
                        this.response_future = None;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }

            if let Some(ref mut stream) = this.bytes_stream {
                match stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(Ok(bytes))) => {
                        let buff_string = decode_utf8_chunk(&mut this.utf8_carry, &bytes)?;
                        let buff_response = match this.stream_mode {
                            StreamMode::Delimiter(delimiter) => {
                                process_buff_string_delimited(buff_string, &mut this.partial_message, delimiter)
                            }
                            StreamMode::Sse => process_buff_string_sse(buff_string, &mut this.partial_message),
                        };
                        let BuffResponse {
                            first_message,
                            next_messages,
                            candidate_message,
                        } = buff_response?;
                        if let Some(next_messages) = next_messages {
                            this.remaining_messages.get_or_insert_with(VecDeque::new).extend(next_messages);
                        }
                        if let Some(candidate_message) = candidate_message {
                            if this.partial_message.is_some() {
                                tracing::warn!("GENAI - WARNING - partial_message is not none");
                            }
                            this.partial_message = Some(candidate_message);
                        }
                        match first_message {
                            Some(first_message) => return Poll::Ready(Some(Ok(first_message))),
                            // Only partial text so far: keep reading.
                            None => continue,
                        }
                    }
                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                    Poll::Ready(None) => {
                        // A stream that has returned `Ready(None)` must not be polled again.
                        // Drop it before anything else is returned. Leftover text and bytes
                        // are flushed below.
                        this.bytes_stream = None;
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }

            if let Some(reqwest_builder) = this.reqwest_builder.take() {
                let fut = async move { reqwest_builder.send().await.map_err(|e| Box::new(e) as BoxError) };
                this.response_future = Some(Box::pin(fut));
                continue;
            }

            // Nothing left to poll: flush leftovers, then end the stream.
            if let Some(partial) = this.partial_message.take()
                && !partial.is_empty()
            {
                return Poll::Ready(Some(Ok(partial)));
            }
            if !this.utf8_carry.is_empty() {
                // The body ended in the middle of a multi-byte UTF-8 character. Report it
                // rather than silently dropping the bytes.
                let carry = std::mem::take(&mut this.utf8_carry);
                return Poll::Ready(Some(String::from_utf8(carry).map_err(|e| Box::new(e) as BoxError)));
            }
            return Poll::Ready(None);
        }
    }
}

/// Decodes one body chunk as UTF-8, prefixed by the incomplete sequence held in `carry`.
///
/// A multi-byte character split across two chunks is not an error. Its leading bytes are moved
/// back into `carry` and decoded together with the next chunk. Bytes that can never form valid
/// UTF-8 produce an error, and `carry` is left empty.
fn decode_utf8_chunk(carry: &mut Vec<u8>, bytes: &[u8]) -> Result<String, BoxError> {
    let mut raw = std::mem::take(carry);
    raw.extend_from_slice(bytes);
    // `error_len() == None` means the input merely ends partway through a character.
    let incomplete_at = std::str::from_utf8(&raw)
        .err()
        .filter(|e| e.error_len().is_none())
        .map(|e| e.valid_up_to());
    if let Some(at) = incomplete_at {
        *carry = raw.split_off(at);
    }
    String::from_utf8(raw).map_err(|e| Box::new(e) as BoxError)
}
/// Result of splitting one decoded chunk (plus any carried-over partial text) into messages.
#[derive(Debug, PartialEq, Eq)]
struct BuffResponse {
    /// The first complete message found in the chunk, if any.
    first_message: Option<String>,
    /// Further complete messages found after the first one, if any.
    next_messages: Option<Vec<String>>,
    /// Text after the last delimiter (possibly empty). It is kept as the partial message and
    /// completed by later chunks.
    candidate_message: Option<String>,
}
/// Splits a chunk of Server-Sent Events text into complete events.
///
/// The carried-over `partial_message` is prepended. `\r\n` and lone `\r` line endings are
/// normalized to `\n`, and the text is split on the blank-line delimiter `\n\n` by
/// `process_buff_string_delimited`.
///
/// A `\r` at the very end of the text is not normalized yet. It may be the first half of a
/// `\r\n` pair whose `\n` arrives in the next chunk. Converting it to `\n` early would make that
/// single line break look like `\n\n` and split one event in two. Instead it is appended, raw,
/// to the returned `candidate_message` and normalized together with the following chunk. As a
/// consequence the candidate may end in a lone `\r`.
fn process_buff_string_sse(
    buff_string: String,
    partial_message: &mut Option<String>,
) -> Result<BuffResponse, crate::webc::Error> {
    let mut full_string = match partial_message.take() {
        Some(partial) => format!("{partial}{buff_string}"),
        None => buff_string,
    };
    let trailing_cr = full_string.ends_with('\r');
    if trailing_cr {
        full_string.pop();
    }
    let normalized = full_string.replace("\r\n", "\n").replace('\r', "\n");
    let mut response = process_buff_string_delimited(normalized, partial_message, "\n\n")?;
    if trailing_cr {
        // `split` always yields a final piece, so the candidate is normally present already.
        response.candidate_message.get_or_insert_with(String::new).push('\r');
    }
    Ok(response)
}
/// Splits `buff_string`, prefixed by any carried-over `partial_message`, on `delimiter`.
///
/// Every piece before the last delimiter is a complete message; empty pieces are dropped. The
/// text after the last delimiter (possibly empty) is returned as `candidate_message` so it can
/// be carried into the next chunk. `partial_message` is always left as `None`.
///
/// # Errors
/// The current implementation never returns `Err`.
fn process_buff_string_delimited(
    buff_string: String,
    partial_message: &mut Option<String>,
    delimiter: &str,
) -> Result<BuffResponse, crate::webc::Error> {
    let full_string = match partial_message.take() {
        Some(prefix) => format!("{prefix}{buff_string}"),
        None => buff_string,
    };

    let mut pieces: Vec<&str> = full_string.split(delimiter).collect();
    // `split` always yields at least one piece, so this is always `Some`.
    let candidate_message = pieces.pop().map(str::to_owned);

    let mut complete = pieces
        .into_iter()
        .filter(|piece| !piece.is_empty())
        .map(str::to_owned);
    let first_message = complete.next();
    let rest: Vec<String> = complete.collect();
    let next_messages = (!rest.is_empty()).then_some(rest);

    Ok(BuffResponse {
        first_message,
        next_messages,
        candidate_message,
    })
}