1use async_trait::async_trait;
2use crb_agent::{Address, Agent, AgentContext, AgentSession, Envelope, Event, OnEvent, TheEvent};
3use crb_runtime::{ManagedContext, ReachableContext};
4use derive_more::{Deref, DerefMut};
5use futures::{Stream, StreamExt, future::select, stream::BoxStream};
6use futures_util::{future::Either, stream::SelectAll};
7
8#[derive(Deref, DerefMut)]
9pub struct StreamSession<A: Agent> {
10 #[deref]
11 #[deref_mut]
12 session: AgentSession<A>,
13 streams: SelectAll<BoxStream<'static, Envelope<A>>>,
14}
15
16impl<A: Agent> Default for StreamSession<A> {
17 fn default() -> Self {
18 Self {
19 session: AgentSession::default(),
20 streams: SelectAll::default(),
21 }
22 }
23}
24
25impl<A: Agent> ReachableContext for StreamSession<A> {
26 type Address = Address<A>;
27
28 fn address(&self) -> &Self::Address {
29 self.session.address()
30 }
31}
32
33impl<A: Agent> ManagedContext for StreamSession<A> {
34 fn is_alive(&self) -> bool {
35 self.session.is_alive()
36 }
37
38 fn shutdown(&mut self) {
39 self.session.shutdown()
40 }
41
42 fn stop(&mut self) {
43 self.session.stop();
44 }
45}
46
47#[async_trait]
48impl<A: Agent> AgentContext<A> for StreamSession<A> {
49 fn session(&mut self) -> &mut AgentSession<A> {
50 &mut self.session
51 }
52
53 async fn next_envelope(&mut self) -> Option<Envelope<A>> {
54 let next_fut = self.session.next_envelope();
55 if self.streams.is_empty() {
56 next_fut.await
57 } else {
58 let event = self.streams.next();
59 let either = select(next_fut, event).await;
60 match either {
61 Either::Left((None, _)) => {
62 self.streams.clear();
63 None
64 }
65 Either::Right((None, next_fut)) => next_fut.await,
66 Either::Left((Some(event), _)) => Some(event),
67 Either::Right((Some(event), _)) => Some(event),
68 }
69 }
70 }
71}
72
73impl<A: Agent> StreamSession<A> {
74 pub fn consume<E, S>(&mut self, stream: S)
75 where
76 A: OnEvent<E>,
77 E: TheEvent,
78 S: Stream<Item = E> + Send + Unpin + 'static,
79 {
80 let stream = stream.map(Event::envelope::<A>);
81 self.streams.push(stream.boxed());
82 }
83
84 pub fn consume_events<S>(&mut self, stream: S)
85 where
86 S: Stream<Item = Envelope<A>> + Send + Unpin + 'static,
87 {
88 self.streams.push(stream.boxed());
89 }
90}