use super::types::ContextBackend;
use crate::error::AgentError;
/// A response stream that also keeps a copy of every successful item it yields.
///
/// Items are forwarded from the wrapped stream unchanged; each `Ok` response is
/// additionally cloned into an internal buffer that can be drained with
/// `take_chunks`.
pub struct AgentSendStream<B>
where
    B: ContextBackend,
{
    /// The wrapped stream, boxed and pinned so any concrete stream type fits.
    inner: std::pin::Pin<
        Box<dyn futures_core::Stream<Item = Result<B::Response, AgentError>> + Send>,
    >,
    /// Copies of the `Ok` responses yielded so far and not yet taken.
    chunks: Vec<B::Response>,
}
impl<B> AgentSendStream<B>
where
    B: ContextBackend,
{
    /// Wraps `inner` in a new stream with an empty chunk buffer.
    ///
    /// The stream is boxed and pinned, which is why it must be `Send + 'static`.
    pub(crate) fn new(
        inner: impl futures_core::Stream<Item = Result<B::Response, AgentError>> + Send + 'static,
    ) -> Self {
        let boxed = Box::pin(inner);
        Self {
            // Unsized coercion to the boxed trait object happens at this field.
            inner: boxed,
            chunks: Vec::new(),
        }
    }

    /// Returns every response buffered so far and leaves the buffer empty.
    pub fn take_chunks(&mut self) -> Vec<B::Response> {
        std::mem::replace(&mut self.chunks, Vec::new())
    }
}
// `AgentSendStream` never relies on its own address staying fixed: `inner` is an
// independently pinned heap allocation (`Pin<Box<..>>`) and `chunks` is a plain
// `Vec` that is never pinned. Declaring `Unpin` explicitly (the auto impl would
// otherwise depend on `B::Response: Unpin`) lets `poll_next` obtain `&mut Self`
// through the safe `Pin::get_mut` instead of `unsafe`.
impl<B: ContextBackend> Unpin for AgentSendStream<B> {}

/// Forwards items from the wrapped stream, recording each successful response.
impl<B: ContextBackend> futures_core::Stream for AgentSendStream<B> {
    type Item = Result<B::Response, AgentError>;

    /// Polls the wrapped stream once and returns its result unchanged.
    ///
    /// When the result is `Ready(Some(Ok(resp)))`, a clone of `resp` is appended
    /// to the chunk buffer first. Errors, end-of-stream and `Pending` are passed
    /// through without touching the buffer.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Safe because `Self: Unpin` (see the impl above).
        let this = self.get_mut();
        let polled = this.inner.as_mut().poll_next(cx);
        if let std::task::Poll::Ready(Some(Ok(resp))) = &polled {
            this.chunks.push(resp.clone());
        }
        polled
    }
}
/// Lets the stream be used directly as a `kameo` reply value.
///
/// Every conversion hands the stream back unchanged; there is no error case,
/// which is why `Error` is `Infallible`.
impl<B> kameo::Reply for AgentSendStream<B>
where
    B: ContextBackend,
{
    type Ok = Self;
    type Error = std::convert::Infallible;
    type Value = Self;

    /// Always succeeds, wrapping the stream itself in `Ok`.
    fn to_result(self) -> Result<Self::Ok, Self::Error> {
        Result::Ok(self)
    }

    /// Always `None`: this reply type never carries an error.
    fn into_any_err(self) -> Option<Box<dyn kameo::reply::ReplyError>> {
        Option::None
    }

    /// Returns the stream itself as the reply value.
    fn into_value(self) -> Self::Value {
        self
    }
}