use bytes::Bytes;
use futures::stream::TryStreamExt;
use futures::{Future, Stream};
use reqwest::{RequestBuilder, Response};
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::error::{BoxError, Error as GenaiError};
/// Stream of text messages extracted from the body of an HTTP response.
///
/// The fields form a small state machine driven by `poll_next`: the request
/// builder is consumed to create the response future, the response future is
/// replaced by the body byte stream, and decoded chunks are split into
/// messages according to `stream_mode`.
// The boxed future/stream field types below are long by necessity, hence the allow.
#[allow(clippy::type_complexity)]
pub struct WebStream {
/// How decoded text is split into individual messages.
stream_mode: StreamMode,
/// Request that has not been sent yet; taken (left as `None`) when the request is started.
reqwest_builder: Option<RequestBuilder>,
/// In-flight future resolving to the HTTP response, or to an error built from a non-success response.
response_future: Option<Pin<Box<dyn Future<Output = Result<Response, BoxError>> + Send>>>,
/// Body byte stream of a successful response; `None` before the response arrives and after the body ends.
bytes_stream: Option<Pin<Box<dyn Stream<Item = Result<Bytes, BoxError>> + Send>>>,
/// Trailing text of the previous chunk that did not yet form a complete message.
partial_message: Option<String>,
/// Messages already extracted from a chunk but not yet yielded to the consumer.
remaining_messages: Option<VecDeque<String>>,
}
/// Strategy used to split the decoded response text into messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamMode {
    /// Messages are separated by the given delimiter string.
    Delimiter(&'static str),
    /// The body is a (possibly pretty-printed) JSON array; each top-level
    /// `{...}` object is a message, and the `[` / `]` brackets are emitted as
    /// messages of their own.
    PrettyJsonArray,
}
impl WebStream {
pub fn new_with_delimiter(reqwest_builder: RequestBuilder, message_delimiter: &'static str) -> Self {
Self {
stream_mode: StreamMode::Delimiter(message_delimiter),
reqwest_builder: Some(reqwest_builder),
response_future: None,
bytes_stream: None,
partial_message: None,
remaining_messages: None,
}
}
pub fn new_with_pretty_json_array(reqwest_builder: RequestBuilder) -> Self {
Self {
stream_mode: StreamMode::PrettyJsonArray,
reqwest_builder: Some(reqwest_builder),
response_future: None,
bytes_stream: None,
partial_message: None,
remaining_messages: None,
}
}
}
impl Stream for WebStream {
    type Item = Result<String, BoxError>;

    /// Yields the next message extracted from the HTTP response body.
    ///
    /// Each call walks the state machine as far as it can:
    /// 1. messages buffered by an earlier chunk are handed out first;
    /// 2. a pending response future is polled; a non-success status is turned
    ///    into an `HttpError` carrying the response body;
    /// 3. the body stream is polled and each chunk is split into messages
    ///    according to `stream_mode`;
    /// 4. if nothing has been started yet, the stored request is sent.
    ///
    /// Returns `Poll::Ready(None)` once the request has been consumed and no
    /// future, body stream or buffered message is left.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Hand out messages queued by a previous chunk before touching the network.
        if let Some(msg) = this.remaining_messages.as_mut().and_then(VecDeque::pop_front) {
            return Poll::Ready(Some(Ok(msg)));
        }

        loop {
            // -- Waiting for the response (or for the error body of a failed response).
            if let Some(fut) = this.response_future.as_mut() {
                match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(response)) => {
                        let status = response.status();
                        if !status.is_success() {
                            // Swap in a future that reads the body and resolves to an error,
                            // then loop so it gets polled (and registers the waker) right away.
                            let error_future = async move {
                                let body = response
                                    .text()
                                    .await
                                    .unwrap_or_else(|e| format!("Failed to read error body: {}", e));
                                Err::<Response, BoxError>(Box::new(GenaiError::HttpError {
                                    status,
                                    canonical_reason: status.canonical_reason().unwrap_or("Unknown").to_string(),
                                    body,
                                }))
                            };
                            this.response_future = Some(Box::pin(error_future));
                            continue;
                        }
                        let bytes_stream = response.bytes_stream().map_err(|e| Box::new(e) as BoxError);
                        this.bytes_stream = Some(Box::pin(bytes_stream));
                        this.response_future = None;
                    }
                    Poll::Ready(Err(e)) => {
                        this.response_future = None;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }

            // -- Reading the response body.
            if let Some(stream) = this.bytes_stream.as_mut() {
                match stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(Ok(bytes))) => {
                        // NOTE(review): every chunk is decoded on its own, so a multi-byte
                        // UTF-8 sequence split across two chunks is reported as an error
                        // here. Fixing that needs a byte-level carry-over buffer on the struct.
                        let buff_string = match String::from_utf8(bytes.to_vec()) {
                            Ok(s) => s,
                            Err(e) => return Poll::Ready(Some(Err(Box::new(e) as BoxError))),
                        };

                        let buff_response = match this.stream_mode {
                            StreamMode::Delimiter(delimiter) => {
                                process_buff_string_delimited(buff_string, &mut this.partial_message, delimiter)
                            }
                            StreamMode::PrettyJsonArray => {
                                new_with_pretty_json_array(buff_string, &mut this.partial_message)
                            }
                        };
                        let BuffResponse {
                            first_message,
                            next_messages,
                            candidate_message,
                        } = buff_response?;

                        if let Some(next_messages) = next_messages {
                            this.remaining_messages
                                .get_or_insert_with(VecDeque::new)
                                .extend(next_messages);
                        }

                        if let Some(candidate_message) = candidate_message {
                            // The split helpers take the previous partial, so it should be gone by now.
                            if this.partial_message.is_some() {
                                tracing::warn!("GENAI - WARNING - partial_message is not none");
                            }
                            this.partial_message = Some(candidate_message);
                        }

                        if let Some(first_message) = first_message {
                            return Poll::Ready(Some(Ok(first_message)));
                        }
                        // No complete message in this chunk yet: read the next one.
                        continue;
                    }
                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                    Poll::Ready(None) => {
                        // Drop the finished stream first so it is never polled again after
                        // it has reported completion, even when a trailing partial is flushed.
                        this.bytes_stream = None;
                        if let Some(partial) = this.partial_message.take().filter(|p| !p.is_empty()) {
                            return Poll::Ready(Some(Ok(partial)));
                        }
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }

            // -- Nothing in flight: start the request if it has not been sent yet.
            if let Some(reqwest_builder) = this.reqwest_builder.take() {
                let fut = async move { reqwest_builder.send().await.map_err(|e| Box::new(e) as BoxError) };
                this.response_future = Some(Box::pin(fut));
                continue;
            }

            return Poll::Ready(None);
        }
    }
}
/// Result of splitting one decoded chunk (plus any carried-over text) into messages.
struct BuffResponse {
/// First complete message found in the chunk, if any.
first_message: Option<String>,
/// Complete messages after the first one, in order; `None` when there are none.
next_messages: Option<Vec<String>>,
/// Trailing text that does not yet form a complete message and must be carried into the next chunk.
candidate_message: Option<String>,
}
/// Splits a chunk of a JSON-array body into its top-level `{...}` objects.
///
/// Any text carried over in `partial_message` is taken (leaving `None`) and
/// prepended to `buff_string` before scanning. The scan tracks string
/// literals and escapes so that braces and brackets inside strings are ignored.
///
/// Emitted messages, in order of appearance:
/// - each top-level object that passes `serde_json` validation (objects that
///   fail validation are logged and skipped);
/// - `"["` and `"]"` for array brackets found outside of any object.
///
/// Whatever follows the last emitted item is returned as `candidate_message`
/// unless it is only whitespace. A `}` with no matching `{` is ignored rather
/// than allowed to unbalance the depth counter, which would otherwise prevent
/// every later object from being recognized.
///
/// This function currently never returns `Err`.
fn new_with_pretty_json_array(
    buff_string: String,
    partial_message: &mut Option<String>,
) -> Result<BuffResponse, crate::webc::Error> {
    // Rebuild the scan buffer: carried-over text first, then the new chunk.
    let buffer = match partial_message.take() {
        Some(mut carried) => {
            carried.push_str(&buff_string);
            carried
        }
        None => buff_string,
    };

    let mut messages: Vec<String> = Vec::new();
    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escape = false;
    // Byte offset of the `{` opening the current top-level object.
    let mut start_idx = 0;
    // Byte offset just past the last emitted item; everything after it is unconsumed.
    let mut last_idx = 0;

    for (idx, c) in buffer.char_indices() {
        if in_string {
            if escape {
                escape = false;
            } else if c == '\\' {
                escape = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }

        match c {
            '"' => in_string = true,
            '{' => {
                if depth == 0 {
                    start_idx = idx;
                }
                depth += 1;
            }
            // Only close an object that was actually opened; a stray `}` is skipped.
            '}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    let end = idx + c.len_utf8();
                    let json_str = &buffer[start_idx..end];
                    if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
                        messages.push(json_str.to_string());
                    } else {
                        tracing::warn!("WebStream: Extracted block failed JSON validation: {}", json_str);
                    }
                    last_idx = end;
                }
            }
            // Array brackets outside of any object are surfaced as their own messages.
            '[' | ']' if depth == 0 => {
                messages.push(c.to_string());
                last_idx = idx + c.len_utf8();
            }
            _ => {}
        }
    }

    // Unconsumed tail becomes the carry-over, unless it is only whitespace.
    let candidate_message = buffer
        .get(last_idx..)
        .filter(|rest| !rest.trim().is_empty())
        .map(str::to_owned);

    let mut extracted = messages.into_iter();
    let first_message = extracted.next();
    let following: Vec<String> = extracted.collect();
    let next_messages = (!following.is_empty()).then_some(following);

    Ok(BuffResponse {
        first_message,
        next_messages,
        candidate_message,
    })
}
/// Splits a chunk into messages separated by `delimiter`.
///
/// Any text carried over in `partial_message` is taken (leaving `None`) and
/// prepended to `buff_string`. The text after the last delimiter is always
/// returned as `candidate_message` (an empty string when the buffer ends with
/// the delimiter); the segments before it are the complete messages, with
/// empty segments dropped.
///
/// This function currently never returns `Err`.
fn process_buff_string_delimited(
    buff_string: String,
    partial_message: &mut Option<String>,
    delimiter: &str,
) -> Result<BuffResponse, crate::webc::Error> {
    let buffer = match partial_message.take() {
        Some(carried) => format!("{carried}{buff_string}"),
        None => buff_string,
    };

    let mut segments: Vec<&str> = buffer.split(delimiter).collect();

    // The last segment has no delimiter after it yet, so it may still be incomplete.
    let candidate_message = segments.pop().map(str::to_owned);

    let mut complete = segments
        .into_iter()
        .filter(|segment| !segment.is_empty())
        .map(str::to_owned);

    let first_message = complete.next();
    let following: Vec<String> = complete.collect();
    let next_messages = if following.is_empty() { None } else { Some(following) };

    Ok(BuffResponse {
        first_message,
        next_messages,
        candidate_message,
    })
}