use futures::Stream;
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
/// Shared state behind every clone of an [`EventStream`].
struct EventStreamInner<T, R> {
    /// Events pushed by producers and not yet yielded by the `Stream` impl.
    events: Mutex<VecDeque<T>>,
    /// Set once a completion event has been pushed or `end` was called.
    done: AtomicBool,
    /// Final result; moved out by the first `result()` call that observes it.
    result: Mutex<Option<R>>,
    /// Waker of the task currently polling the stream (a single slot).
    waker: Mutex<Option<Waker>>,
    /// Wakers of tasks waiting in `EventStream::result`.
    result_wakers: Mutex<Vec<Waker>>,
}

/// A cloneable, multi-producer event stream that finishes when an event
/// matching `is_complete` is pushed (or `end` is called) and exposes a final
/// result of type `R`.
pub struct EventStream<T, R = T> {
    inner: Arc<EventStreamInner<T, R>>,
    /// Predicate selecting the event that finishes the stream.
    is_complete: fn(&T) -> bool,
    /// Converts the finishing event into the stream's final result.
    extract_result: fn(T) -> R,
}

impl<T, R> EventStream<T, R>
where
    T: Clone + Send + 'static,
    R: Send + 'static,
{
    /// Creates an empty, unfinished stream.
    ///
    /// `is_complete` decides which pushed event finishes the stream, and
    /// `extract_result` turns that event into the value returned by
    /// [`Self::result`].
    pub fn new(is_complete: fn(&T) -> bool, extract_result: fn(T) -> R) -> Self {
        Self {
            inner: Arc::new(EventStreamInner {
                events: Mutex::new(VecDeque::new()),
                done: AtomicBool::new(false),
                result: Mutex::new(None),
                waker: Mutex::new(None),
                result_wakers: Mutex::new(Vec::new()),
            }),
            is_complete,
            extract_result,
        }
    }

    /// Wakes the task polling the stream, if one is registered.
    fn wake(&self) {
        if let Some(waker) = self.inner.waker.lock().take() {
            waker.wake();
        }
    }

    /// Wakes every task waiting in [`Self::result`].
    fn wake_result_waiters(&self) {
        // Drain under the lock, then wake after it is released so woken tasks
        // do not contend on this mutex.
        let waiters = std::mem::take(&mut *self.inner.result_wakers.lock());
        for waker in waiters {
            waker.wake();
        }
    }

    /// Enqueues `event` and wakes the stream consumer.
    ///
    /// If `event` satisfies `is_complete`, it is still enqueued (so consumers
    /// see it), its extracted result is stored, the stream is marked done and
    /// any `result()` waiters are woken. Events pushed after the stream is
    /// done are silently dropped.
    pub fn push(&self, event: T) {
        let completed = {
            // `done` is checked and the event enqueued under the same lock that
            // is held whenever `done` is set, so no event can be queued after
            // the stream has finished.
            let mut events = self.inner.events.lock();
            if self.inner.done.load(Ordering::SeqCst) {
                return;
            }
            let is_complete = (self.is_complete)(&event);
            if is_complete {
                events.push_back(event.clone());
                *self.inner.result.lock() = Some((self.extract_result)(event));
                self.inner.done.store(true, Ordering::SeqCst);
            } else {
                events.push_back(event);
            }
            is_complete
        };
        self.wake();
        if completed {
            self.wake_result_waiters();
        }
    }

    /// Marks the stream as finished and wakes the consumer and all result waiters.
    ///
    /// A `Some` result replaces any stored result; `None` leaves the stored
    /// result untouched. If no result was ever stored, [`Self::result`] keeps
    /// waiting, so use [`Self::try_result`] to bound the wait.
    pub fn end(&self, result: Option<R>) {
        {
            // Same lock as `push`, so `done` flips atomically with respect to enqueueing.
            let _events = self.inner.events.lock();
            if let Some(result) = result {
                *self.inner.result.lock() = Some(result);
            }
            self.inner.done.store(true, Ordering::SeqCst);
        }
        self.wake();
        self.wake_result_waiters();
    }

    /// Returns whether the stream has been marked done.
    pub fn is_done(&self) -> bool {
        self.inner.done.load(Ordering::SeqCst)
    }

    /// Waits until a result is available and takes it.
    ///
    /// The result is moved out, so only the first caller (across all clones)
    /// receives it. Later calls wait until another result is stored, e.g. via
    /// `end(Some(..))`. If the stream ends without a result, this never
    /// resolves; see [`Self::try_result`].
    pub async fn result(&self) -> R {
        // Event-driven rather than timer polling: the waiting task parks its
        // waker and is woken by the completing `push` or by `end`.
        std::future::poll_fn(|cx| {
            let mut result = self.inner.result.lock();
            if let Some(r) = result.take() {
                return Poll::Ready(r);
            }
            // Registered while still holding `result`. A producer storing a
            // value concurrently must wait for this lock, and only drains the
            // waiter list afterwards, so this waker cannot be missed.
            let mut wakers = self.inner.result_wakers.lock();
            if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
                wakers.push(cx.waker().clone());
            }
            Poll::Pending
        })
        .await
    }

    /// Like [`Self::result`], but gives up after `timeout` and returns `None`.
    pub async fn try_result(&self, timeout: std::time::Duration) -> Option<R> {
        tokio::time::timeout(timeout, self.result()).await.ok()
    }
}
impl<T, R> Stream for EventStream<T, R>
where
    T: Send + Unpin,
{
    type Item = T;

    /// Yields queued events in FIFO order and reports the end (`None`) once
    /// the queue is empty and the stream is done.
    ///
    /// Queued events are always drained before the end is reported. There is a
    /// single waker slot, so this replaces any previously registered waker with
    /// the current task's.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let shared = &self.inner;

        // `Some(item)` when this poll can complete now: a queued event, or,
        // with an empty queue and `done` set, the end of the stream.
        let ready_now = || -> Option<Option<T>> {
            let next = shared.events.lock().pop_front();
            if next.is_some() {
                Some(next)
            } else if shared.done.load(Ordering::SeqCst) {
                Some(None)
            } else {
                None
            }
        };

        if let Some(item) = ready_now() {
            return Poll::Ready(item);
        }

        // Register before re-checking. A producer that enqueues between the
        // first check and the registration is caught by the second check.
        *shared.waker.lock() = Some(cx.waker().clone());

        ready_now().map_or(Poll::Pending, Poll::Ready)
    }
}
/// Clones share the same queue, completion flag, result and stream-waker slot.
///
/// Each event is popped by exactly one clone. Because the waker slot is
/// shared, polling several clones as `Stream`s at once keeps only the most
/// recently registered waker.
impl<T, R> Clone for EventStream<T, R> {
    fn clone(&self) -> Self {
        // Only the `Arc` needs an explicit clone; the `fn` pointers are `Copy`.
        Self {
            inner: Arc::clone(&self.inner),
            ..*self
        }
    }
}
pub type AssistantMessageEventStream =
EventStream<crate::types::AssistantMessageEvent, crate::types::AssistantMessage>;
impl AssistantMessageEventStream {
pub fn new_assistant_stream() -> Self {
Self::new(
|event| event.is_complete(),
|event| match event {
crate::types::AssistantMessageEvent::Done { message, .. } => message.clone(),
crate::types::AssistantMessageEvent::Error { error, .. } => error.clone(),
_ => unreachable!("is_complete should only return true for Done/Error"),
},
)
}
}