Skip to main content

robespierre_events/
lib.rs

1use async_tungstenite::{
2    stream::Stream,
3    tokio::{connect_async, TokioAdapter},
4    tungstenite::Message as TungsteniteMessage,
5    WebSocketStream,
6};
7use futures::FutureExt;
8use robespierre_models::{
9    auth::Session,
10    events::{ClientToServerEvent, ServerToClientEvent},
11    id::ChannelId,
12};
13use std::result::Result as StdResult;
14use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
15use tokio_rustls::client::TlsStream;
16
17pub mod typing;
18
19/// Errors that can occur while working with ws messages / events.
20#[derive(Debug, thiserror::Error)]
21pub enum EventsError {
22    #[error("tungstenite error: {0}")]
23    WsError(#[from] async_tungstenite::tungstenite::Error),
24
25    #[error("serialization / deserialization error: {0}")]
26    DeserializationError(#[from] serde_json::Error),
27
28    #[error("error while authenticating: {0}")]
29    AuthError(String),
30
31    #[error("websocket closed")]
32    Closed,
33}
34
35pub type Result<T = ()> = StdResult<T, EventsError>;
36
37struct ConnectionInternal {
38    stream: WebSocketStream<Stream<TokioAdapter<TcpStream>, TokioAdapter<TlsStream<TcpStream>>>>,
39    closed: bool,
40}
41
42/// A websocket connection.
43pub struct Connection(ConnectionInternal);
44
45/// A value that can be used to authenticate on the websocket, either as a bot or as a non-bot user.
46#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
47pub enum Authentication<'a> {
48    Bot { token: &'a str },
49    User { session_token: &'a str },
50}
51
52impl<'a> From<&'a Session> for Authentication<'a> {
53    fn from(s: &'a Session) -> Self {
54        Self::User {
55            session_token: &s.token.0,
56        }
57    }
58}
59
60#[async_trait::async_trait]
61pub trait RawEventHandler: Send + Sync + Clone + 'static {
62    type Context: 'static;
63    async fn handle(self, ctx: Self::Context, event: ServerToClientEvent);
64}
65
66/// A message to a [`Connection`]
67#[derive(Debug, Copy, Clone)]
68pub enum ConnectionMessage {
69    /// Tells the [`Connection`] to emit a [`ClientToServerEvent::BeginTyping`] event in the given channel.
70    StartTyping { channel: ChannelId },
71    /// Tells the [`Connection`] to emit a [`ClientToServerEvent::EndTyping`] event in the given channel.
72    StopTyping { channel: ChannelId },
73    /// Tells the [`Connection`] to close itself, and return from the loop.
74    Close,
75}
76
77#[derive(Clone, Debug)]
78pub struct ConnectionMessanger(UnboundedSender<ConnectionMessage>);
79
80impl ConnectionMessanger {
81    /// Sends a message to the [`Connection`], describing something it should do.
82    pub fn send(&self, message: ConnectionMessage) {
83        self.0
84            .send(message)
85            .expect("Something went terribly wrong and the receiver closed");
86    }
87}
88
89/// Trait implemented on types that can be passed as a context to [`Connection::run`],
90/// but not necessary for [`RawEventHandler::Context`]
91pub trait Context: Sized + Clone + Send + 'static {
92    /// Gives the context a messanger it can communicate to the [`Connection`] with,
93    /// allowing it to send messages like "BeginTyping", "EndTyping", or tell the connection
94    /// to close itself, for a clean shutdown.
95    fn set_messanger(self, messanger: ConnectionMessanger) -> Self;
96}
97
98impl Connection {
99    /// Connects to the websocket, and authenticates, returning the socket or an error if it failed.
100    pub async fn connect<'a>(auth: impl Into<Authentication<'a>>) -> Result<Self> {
101        Self::connect_with_url(auth, "wss://ws.revolt.chat").await
102    }
103
104    /// Connects to the websocket on the specified url, and authenticates, returning the socket or an error if it failed.
105    ///
106    /// Use if connecting to a self-hosted instance of revolt; otherwise use [Self::connect].
107    pub async fn connect_with_url<'a>(
108        auth: impl Into<Authentication<'a>>,
109        url: &str,
110    ) -> Result<Self> {
111        tracing::debug!("Connecting to websocket on {}", url);
112        let (stream, _response) = connect_async(url).await?;
113        let mut internal = ConnectionInternal {
114            stream,
115            closed: false,
116        };
117        internal.authenticate(auth.into()).await?;
118
119        let connection = Self(internal);
120
121        Ok(connection)
122    }
123
124    /// Runs the "main loop", listening for events on the websocket and
125    /// spawning tokio tasks to handle them, cloning the context, and giving
126    /// it a messanger.
127    ///
128    /// If you intend to implement this yourself, you can
129    /// use [`Connection::get_event`] to get events and [`Connection::hb`]
130    /// to "heartbeat"(send a ping message to the server so it doesn't
131    /// close the socket).
132    pub async fn run<C, H>(mut self, ctx: C, handler: H) -> Result
133    where
134        C: Context,
135        H: RawEventHandler<Context = C>,
136    {
137        let mut int = tokio::time::interval(std::time::Duration::from_secs(15));
138
139        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<ConnectionMessage>();
140
141        enum Event {
142            FromServer(Result<ServerToClientEvent>),
143            ConnectionMessage(Option<ConnectionMessage>),
144            Tick,
145            TypingManagerTick,
146        }
147
148        let mut typing_session_manager = typing::TypingSessionManager::default();
149
150        loop {
151            // Event::FromServer = we got an event from the server, which we should pass to the handler
152            // Event::ConnectionMessage = we got a message from a handler, can be something like send "BeginTyping", "EndTyping" to the ws, try to close the socket
153            // Event::Tick = we didn't get any event, but we have to ping the server or it will close the connection
154            // Event::TypingManagerTick = we didn't get any event, but we have to send all the "BeginTyping" events to the server or it will timeout and close them.
155            let event = futures::select! {
156                event = self.get_event().fuse() => Event::FromServer(event),
157                connection_message = rx.recv().fuse() => Event::ConnectionMessage(connection_message),
158                _ = int.tick().fuse() => Event::Tick,
159                _ = typing_session_manager.tick().fuse() => Event::TypingManagerTick,
160            };
161
162            match event {
163                Event::FromServer(event) => {
164                    let event = event?;
165
166                    let handler = handler.clone();
167                    let ctx = ctx.clone().set_messanger(ConnectionMessanger(tx.clone()));
168
169                    let fut = handler.handle(ctx, event);
170                    tokio::spawn(fut);
171                }
172                Event::ConnectionMessage(Some(message)) => match message {
173                    ConnectionMessage::StartTyping { channel } => {
174                        typing_session_manager.start_typing(channel);
175                        self.start_typing(channel).await?;
176                    }
177                    ConnectionMessage::StopTyping { channel } => {
178                        if typing_session_manager.stop_typing(channel) {
179                            // was removed
180                            self.stop_typing(channel).await?;
181                        }
182                    }
183                    ConnectionMessage::Close => {
184                        self.close().await?;
185                        return Ok(()); // will drop self
186                    }
187                },
188                Event::ConnectionMessage(None) => {
189                    // can never happen as the tx is never moved outside of this function,
190                    // only cloned, and therefore at least one sender is not dropped
191                    // also, the receiver is never dropped / closed
192
193                    // (unless ? propagates the error in which case this block shouldn't be reached)
194                    unreachable!()
195                }
196                Event::Tick => {
197                    self.hb().await?;
198                }
199                Event::TypingManagerTick => {
200                    for session in typing_session_manager.current_sessions() {
201                        self.start_typing(*session).await?;
202                    }
203                }
204            }
205        }
206    }
207
208    /// Sends a ping message to the server, so it doesn't close the connection.
209    pub async fn hb(&mut self) -> Result {
210        self.0.hb().await
211    }
212
213    /// Gets the next event from the server.
214    pub async fn get_event(&mut self) -> Result<ServerToClientEvent> {
215        self.0.get_event().await
216    }
217
218    /// Sends a [`ClientToServerEvent::BeginTyping`] event, for the given channel.
219    ///
220    /// Has a timeout of ~3 seconds, so if you want it to display "... is typing"
221    /// for longer than that, you have to call it again.
222    pub async fn start_typing(&mut self, channel: ChannelId) -> Result {
223        self.0
224            .send_event(ClientToServerEvent::BeginTyping { channel })
225            .await
226    }
227
228    /// Sends a [`ClientToServerEvent::EndTyping`] event, for the given channel.
229    pub async fn stop_typing(&mut self, channel: ChannelId) -> Result {
230        self.0
231            .send_event(ClientToServerEvent::EndTyping { channel })
232            .await
233    }
234
235    /// Closes the websocket.
236    pub async fn close(mut self) -> Result {
237        self.0.close().await?;
238
239        Ok(())
240    }
241}
242
243impl ConnectionInternal {
244    async fn hb(&mut self) -> Result {
245        self.send_event(ClientToServerEvent::Ping { data: 0 })
246            .await?;
247
248        Ok(())
249    }
250
251    async fn authenticate(&mut self, auth: Authentication<'_>) -> Result {
252        tracing::debug!("Authenticating");
253        self.send_event(match &auth {
254            Authentication::Bot { token } => ClientToServerEvent::Authenticate {
255                token: token.to_string(),
256            },
257            Authentication::User { session_token } => ClientToServerEvent::Authenticate {
258                token: session_token.to_string(),
259            },
260        })
261        .await?;
262
263        let msg = self.get_event().await?;
264
265        match msg {
266            ServerToClientEvent::Authenticated => {}
267            ServerToClientEvent::Error { error } => {
268                tracing::error!("Error while authenticating: {}", error);
269
270                return Err(EventsError::AuthError(error));
271            }
272            msg => {
273                tracing::info!("Unexpected message after auth: {:?}", msg);
274            }
275        }
276
277        Ok(())
278    }
279
280    async fn send_event(&mut self, message: ClientToServerEvent) -> Result {
281        use futures::sink::SinkExt;
282
283        let json = serde_json::to_string(&message)?;
284
285        tracing::debug!("[>] {}", &json);
286
287        self.stream.send(TungsteniteMessage::text(json)).await?;
288
289        Ok(())
290    }
291
292    async fn close(&mut self) -> Result {
293        self.stream.close(None).await?;
294        self.closed = true;
295
296        Ok(())
297    }
298
299    async fn get_event(&mut self) -> Result<ServerToClientEvent> {
300        if self.closed {
301            return Err(EventsError::Closed);
302        }
303
304        use async_std::stream::StreamExt;
305        let msg: TungsteniteMessage = self
306            .stream
307            .next()
308            .await
309            .expect("Last message in ws without closing")?;
310
311        match msg {
312            TungsteniteMessage::Text(json) => {
313                tracing::debug!("[<] {}", &json);
314                return Ok(serde_json::from_str(&json)?);
315            }
316            TungsteniteMessage::Binary(b) => tracing::debug!("Got binary: {:?}", &b),
317            TungsteniteMessage::Ping(ping) => tracing::debug!("Got ping: {:?}", &ping),
318            TungsteniteMessage::Pong(pong) => tracing::debug!("Got pong: {:?}", &pong),
319            TungsteniteMessage::Close(close) => {
320                tracing::debug!("Got close: {:?}", close);
321                self.closed = true;
322
323                return Err(EventsError::Closed);
324            }
325        };
326
327        unimplemented!()
328    }
329}