// battleware_client/events.rs

1use crate::{Error, Result};
2use battleware_types::{
3    api::{Events, Update},
4    Identity, Seed, NAMESPACE,
5};
6use commonware_codec::ReadExt;
7use futures_util::{Stream as FutStream, StreamExt};
8use tokio::sync::mpsc;
9use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
10use tracing::{debug, error};
11
12/// Stream of events from the WebSocket connection
13pub struct Stream<T: ReadExt + Send + Sync + 'static> {
14    receiver: mpsc::UnboundedReceiver<Result<T>>,
15    _handle: tokio::task::JoinHandle<()>,
16}
17
18/// Trait for verifying consensus messages
19pub trait Verifiable {
20    fn verify(&self, identity: &Identity) -> bool;
21}
22
23impl Verifiable for Seed {
24    fn verify(&self, identity: &Identity) -> bool {
25        self.verify(NAMESPACE, identity)
26    }
27}
28
29impl Verifiable for Events {
30    fn verify(&self, identity: &Identity) -> bool {
31        self.verify(identity)
32    }
33}
34
35impl Verifiable for Update {
36    fn verify(&self, identity: &Identity) -> bool {
37        match self {
38            Update::Seed(seed) => seed.verify(NAMESPACE, identity),
39            Update::Events(events) => events.verify(identity),
40            Update::FilteredEvents(events) => events.verify(identity),
41        }
42    }
43}
44
45impl<T: ReadExt + Send + Sync + 'static> Stream<T> {
46    pub(crate) fn new<S>(mut ws: WebSocketStream<S>) -> Self
47    where
48        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
49    {
50        let (tx, rx) = mpsc::unbounded_channel();
51
52        let handle = tokio::spawn(async move {
53            while let Some(msg) = ws.next().await {
54                match msg {
55                    Ok(Message::Binary(data)) => {
56                        debug!("Received binary message: {} bytes", data.len());
57                        let mut buf = data.as_slice();
58                        match T::read(&mut buf) {
59                            Ok(event) => {
60                                if tx.send(Ok(event)).is_err() {
61                                    break; // Receiver dropped
62                                }
63                            }
64                            Err(e) => {
65                                error!("Failed to decode event: {}", e);
66                                let err = Error::InvalidData(e);
67                                if tx.send(Err(err)).is_err() {
68                                    break;
69                                }
70                            }
71                        }
72                    }
73                    Ok(Message::Close(_)) => {
74                        debug!("WebSocket closed");
75                        let _ = tx.send(Err(Error::ConnectionClosed));
76                        break;
77                    }
78                    Ok(_) => {} // Ignore other message types
79                    Err(e) => {
80                        error!("WebSocket error: {}", e);
81                        let _ = tx.send(Err(e.into()));
82                        break;
83                    }
84                }
85            }
86        });
87
88        Self {
89            receiver: rx,
90            _handle: handle,
91        }
92    }
93
94    pub(crate) fn new_with_verifier<S>(mut ws: WebSocketStream<S>, identity: Identity) -> Self
95    where
96        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
97        T: Verifiable,
98    {
99        let (tx, rx) = mpsc::unbounded_channel();
100
101        let handle = tokio::spawn(async move {
102            while let Some(msg) = ws.next().await {
103                match msg {
104                    Ok(Message::Binary(data)) => {
105                        debug!("Received binary message: {} bytes", data.len());
106                        let mut buf = data.as_slice();
107                        match T::read(&mut buf) {
108                            Ok(event) => {
109                                // Verify the message
110                                if !event.verify(&identity) {
111                                    error!("Failed to verify consensus message");
112                                    let err = Error::InvalidSignature;
113                                    if tx.send(Err(err)).is_err() {
114                                        break;
115                                    }
116                                    continue;
117                                }
118
119                                if tx.send(Ok(event)).is_err() {
120                                    break; // Receiver dropped
121                                }
122                            }
123                            Err(e) => {
124                                error!("Failed to decode event: {}", e);
125                                let err = Error::InvalidData(e);
126                                if tx.send(Err(err)).is_err() {
127                                    break;
128                                }
129                            }
130                        }
131                    }
132                    Ok(Message::Close(_)) => {
133                        debug!("WebSocket closed");
134                        let _ = tx.send(Err(Error::ConnectionClosed));
135                        break;
136                    }
137                    Ok(_) => {} // Ignore other message types
138                    Err(e) => {
139                        error!("WebSocket error: {}", e);
140                        let _ = tx.send(Err(e.into()));
141                        break;
142                    }
143                }
144            }
145        });
146
147        Self {
148            receiver: rx,
149            _handle: handle,
150        }
151    }
152
153    /// Receive the next event from the stream
154    pub async fn next(&mut self) -> Option<Result<T>> {
155        self.receiver.recv().await
156    }
157}
158
159impl<T: ReadExt + Send + Sync + 'static> FutStream for Stream<T> {
160    type Item = Result<T>;
161
162    fn poll_next(
163        mut self: std::pin::Pin<&mut Self>,
164        cx: &mut std::task::Context<'_>,
165    ) -> std::task::Poll<Option<Self::Item>> {
166        self.receiver.poll_recv(cx)
167    }
168}