use std::pin::Pin;
use std::task::{Context, Poll};
use futures::{Stream, channel::mpsc, stream::BoxStream};
use crate::{
error::ChatFailure,
types::{
messages::{content::Content, parts::PartEnum},
response::StreamEvent,
},
};
/// A single message sent through an [`InputStream`]'s channel to whatever
/// consumes the receiving end.
// NOTE(review): the size lint is presumably silenced so that `Input::Item`
// can carry a `PartEnum` by value (boxing it would change the public
// constructor); confirm and record the actual reason here.
#[allow(clippy::large_enum_variant)]
pub enum Input {
    /// A single message part.
    Item(PartEnum),
    /// A whole [`Content`] value.
    Content(Content),
    /// Cancellation request, as queued by [`InputStream::cancel`].
    // NOTE(review): how the receiver reacts to this is not visible here.
    Cancel,
}
/// Conversion into an [`Input`] message.
///
/// This is the bound accepted by [`InputStream::send`] and
/// [`ChatStream::send`], so any implementor can be passed to them directly.
pub trait IntoInput {
    /// Converts `self` into the [`Input`] that will be queued on the channel.
    fn into_input(self) -> Input;
}
/// An [`Input`] is already in its final form and passes through untouched.
impl IntoInput for Input {
    fn into_input(self) -> Self {
        self
    }
}

/// A bare message part is wrapped as [`Input::Item`].
impl IntoInput for PartEnum {
    fn into_input(self) -> Input {
        Input::Item(self)
    }
}

/// A whole [`Content`] value is wrapped as [`Input::Content`].
impl IntoInput for Content {
    fn into_input(self) -> Input {
        Input::Content(self)
    }
}

/// Owned text is turned into a [`PartEnum`] and then reuses the part conversion.
impl IntoInput for String {
    fn into_input(self) -> Input {
        PartEnum::from(self).into_input()
    }
}

/// Borrowed text follows exactly the same path as owned text.
impl IntoInput for &str {
    fn into_input(self) -> Input {
        PartEnum::from(self).into_input()
    }
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum SendError {
#[error("input is disconnected: the output stream was dropped or finished")]
Disconnected,
}
/// Cloneable handle for pushing [`Input`] messages into a chat stream.
///
/// Every clone holds a sender for the same underlying unbounded channel.
#[derive(Clone)]
pub struct InputStream {
    /// Sending half of the unbounded input channel.
    pub(crate) tx: mpsc::UnboundedSender<Input>,
}
impl InputStream {
    /// Converts `input` into an [`Input`] and queues it on the channel.
    ///
    /// The channel is unbounded, so this never waits for capacity.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Disconnected`] when the channel is closed.
    pub fn send(&self, input: impl IntoInput) -> Result<(), SendError> {
        let message = input.into_input();
        self.tx
            .unbounded_send(message)
            .map_err(|_rejected| SendError::Disconnected)
    }

    /// Queues an [`Input::Cancel`] request.
    pub fn cancel(&self) {
        // Best-effort: a closed channel has nothing left to cancel, so the
        // failure is deliberately discarded.
        let _ = self.send(Input::Cancel);
    }

    /// Reports whether the input channel is still open.
    ///
    /// This is only a snapshot: the channel may close right after the check,
    /// so a later [`send`](Self::send) can still fail.
    pub fn is_connected(&self) -> bool {
        let closed = self.tx.is_closed();
        !closed
    }
}
/// Stream of chat results: each item is a [`StreamEvent`] or a [`ChatFailure`].
pub struct OutputStream<'a> {
    /// Type-erased, pinned inner stream that all polling is delegated to.
    pub(crate) inner: BoxStream<'a, Result<StreamEvent, ChatFailure>>,
}
/// Every `Stream` method is forwarded unchanged to the boxed inner stream.
impl Stream for OutputStream<'_> {
    type Item = Result<StreamEvent, ChatFailure>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `OutputStream` is `Unpin` (its only field is an already-pinned box),
        // so `get_mut` is available and the inner stream is re-pinned via `as_mut`.
        self.get_mut().inner.as_mut().poll_next(cx)
    }

    /// Reports the inner stream's estimate rather than the trait default
    /// `(0, None)`, which would hide any information the inner stream has.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Pairs an [`InputStream`] for sending with an [`OutputStream`] for receiving.
///
/// Polling a `ChatStream` as a [`Stream`] yields the output items directly.
pub struct ChatStream<'a> {
    /// Handle used to push inputs.
    pub(crate) input: InputStream,
    /// Stream of output items.
    pub(crate) output: OutputStream<'a>,
}
impl<'a> ChatStream<'a> {
    /// Sends `input` on this stream's input channel.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Disconnected`] when the input channel is closed.
    pub fn send(&self, input: impl IntoInput) -> Result<(), SendError> {
        self.input.send(input)
    }

    /// Requests cancellation; does nothing if the input channel is already closed.
    pub fn cancel(&self) {
        self.input.cancel();
    }

    /// Returns an additional handle to the input channel, leaving `self` usable.
    #[must_use]
    pub fn input(&self) -> InputStream {
        self.input.clone()
    }

    /// Consumes the chat stream and returns only its output half.
    ///
    /// The input handle held by `self` is dropped. If no other senders for
    /// the channel remain, the channel's receiver will observe end of input.
    #[must_use]
    pub fn output(self) -> OutputStream<'a> {
        self.output
    }

    /// Consumes the chat stream and returns its input and output halves.
    #[must_use]
    pub fn split(self) -> (InputStream, OutputStream<'a>) {
        // `self` is consumed, so both fields can simply be moved out; cloning
        // the sender would only create a redundant temporary handle.
        (self.input, self.output)
    }
}
/// Polling a `ChatStream` polls its output half.
impl Stream for ChatStream<'_> {
    type Item = Result<StreamEvent, ChatFailure>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `ChatStream` is `Unpin`, so the output field can be re-pinned with `Pin::new`.
        Pin::new(&mut self.get_mut().output).poll_next(cx)
    }

    /// Delegates to the output half instead of the trait default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.output.size_hint()
    }
}