use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::types::OllamaError;
use crate::Result;
use bytes::Bytes;
use futures::Stream;
use serde::de::DeserializeOwned;
/// Conversion hooks that [`GenericStreamParser`] uses to turn each decoded line
/// of a newline-delimited JSON byte stream into a caller-defined event type.
pub trait StreamEventExt<M>: Sized {
/// Builds an event from a line that deserialized successfully as `M`.
fn from_message(msg: M) -> Self;
/// Builds an event from a line that deserialized as an `OllamaError`;
/// `err` is that error's `error` field.
fn from_error(err: String) -> Self;
/// Builds an event for text that could not be decoded as either `M` or an
/// `OllamaError`. `partial` is the raw text. `error` carries the `M`
/// deserialization error message when the text was a complete line; it may be
/// `None` for unterminated data left over when the stream ended.
fn partial(partial: String, error: Option<String>) -> Self;
}
pub struct GenericStreamParser<S, M, E>
where
S: Stream<Item = Result<Bytes>> + Send + Unpin,
M: DeserializeOwned,
E: StreamEventExt<M>,
{
inner: S,
buffer: Vec<u8>,
_marker: PhantomData<(M, E)>,
}
impl<S, M, E> GenericStreamParser<S, M, E>
where
S: Stream<Item = Result<Bytes>> + Send + Unpin,
M: DeserializeOwned,
E: StreamEventExt<M>,
{
pub fn new(stream: S) -> Self {
Self {
inner: stream,
buffer: Vec::new(),
_marker: PhantomData,
}
}
fn parse_lines(&mut self) -> Option<Result<E>> {
loop {
let newline_pos = self.buffer.iter().position(|&b| b == b'\n')?;
let line_bytes = self.buffer.drain(..=newline_pos).collect::<Vec<u8>>();
let line_str = String::from_utf8_lossy(&line_bytes);
let line_str = line_str.trim();
if line_str.is_empty() {
continue; }
match serde_json::from_str::<M>(line_str) {
Ok(msg) => return Some(Ok(E::from_message(msg))),
Err(e_msg) => {
match serde_json::from_str::<OllamaError>(line_str) {
Ok(err) => return Some(Ok(E::from_error(err.error))),
Err(_) => {
return Some(Ok(E::partial(
line_str.to_string(),
Some(e_msg.to_string()),
)));
}
}
}
}
}
}
}
impl<S, M, E> Stream for GenericStreamParser<S, M, E>
where
S: Stream<Item = Result<Bytes>> + Send + Unpin,
M: DeserializeOwned + Unpin,
E: StreamEventExt<M> + Unpin,
{
type Item = Result<E>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
if let Some(event) = this.parse_lines() {
return Poll::Ready(Some(event));
}
if this.buffer.is_empty() {
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.buffer.extend_from_slice(&bytes);
continue; }
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(None) => return Poll::Ready(None), Poll::Pending => return Poll::Pending,
}
} else {
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.buffer.extend_from_slice(&bytes);
continue;
}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(None) => {
let content = String::from_utf8_lossy(&this.buffer).to_string();
this.buffer.clear();
if !content.trim().is_empty() {
return Poll::Ready(Some(Ok(E::partial(content, None))));
}
return Poll::Ready(None);
}
Poll::Pending => return Poll::Pending,
}
}
}
}
}