robespierre_events/
lib.rs1use async_tungstenite::{
2 stream::Stream,
3 tokio::{connect_async, TokioAdapter},
4 tungstenite::Message as TungsteniteMessage,
5 WebSocketStream,
6};
7use futures::FutureExt;
8use robespierre_models::{
9 auth::Session,
10 events::{ClientToServerEvent, ServerToClientEvent},
11 id::ChannelId,
12};
13use std::result::Result as StdResult;
14use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
15use tokio_rustls::client::TlsStream;
16
17pub mod typing;
18
19#[derive(Debug, thiserror::Error)]
21pub enum EventsError {
22 #[error("tungstenite error: {0}")]
23 WsError(#[from] async_tungstenite::tungstenite::Error),
24
25 #[error("serialization / deserialization error: {0}")]
26 DeserializationError(#[from] serde_json::Error),
27
28 #[error("error while authenticating: {0}")]
29 AuthError(String),
30
31 #[error("websocket closed")]
32 Closed,
33}
34
35pub type Result<T = ()> = StdResult<T, EventsError>;
36
37struct ConnectionInternal {
38 stream: WebSocketStream<Stream<TokioAdapter<TcpStream>, TokioAdapter<TlsStream<TcpStream>>>>,
39 closed: bool,
40}
41
42pub struct Connection(ConnectionInternal);
44
/// Credentials used to authenticate a websocket connection.
///
/// Both variants borrow their token; either one is sent to the server as
/// the `token` of an `Authenticate` event.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Authentication<'a> {
    /// Authenticate with a bot token.
    Bot { token: &'a str },
    /// Authenticate with the token of a user session.
    User { session_token: &'a str },
}
51
52impl<'a> From<&'a Session> for Authentication<'a> {
53 fn from(s: &'a Session) -> Self {
54 Self::User {
55 session_token: &s.token.0,
56 }
57 }
58}
59
60#[async_trait::async_trait]
61pub trait RawEventHandler: Send + Sync + Clone + 'static {
62 type Context: 'static;
63 async fn handle(self, ctx: Self::Context, event: ServerToClientEvent);
64}
65
66#[derive(Debug, Copy, Clone)]
68pub enum ConnectionMessage {
69 StartTyping { channel: ChannelId },
71 StopTyping { channel: ChannelId },
73 Close,
75}
76
77#[derive(Clone, Debug)]
78pub struct ConnectionMessanger(UnboundedSender<ConnectionMessage>);
79
80impl ConnectionMessanger {
81 pub fn send(&self, message: ConnectionMessage) {
83 self.0
84 .send(message)
85 .expect("Something went terribly wrong and the receiver closed");
86 }
87}
88
89pub trait Context: Sized + Clone + Send + 'static {
92 fn set_messanger(self, messanger: ConnectionMessanger) -> Self;
96}
97
98impl Connection {
99 pub async fn connect<'a>(auth: impl Into<Authentication<'a>>) -> Result<Self> {
101 Self::connect_with_url(auth, "wss://ws.revolt.chat").await
102 }
103
104 pub async fn connect_with_url<'a>(
108 auth: impl Into<Authentication<'a>>,
109 url: &str,
110 ) -> Result<Self> {
111 tracing::debug!("Connecting to websocket on {}", url);
112 let (stream, _response) = connect_async(url).await?;
113 let mut internal = ConnectionInternal {
114 stream,
115 closed: false,
116 };
117 internal.authenticate(auth.into()).await?;
118
119 let connection = Self(internal);
120
121 Ok(connection)
122 }
123
124 pub async fn run<C, H>(mut self, ctx: C, handler: H) -> Result
133 where
134 C: Context,
135 H: RawEventHandler<Context = C>,
136 {
137 let mut int = tokio::time::interval(std::time::Duration::from_secs(15));
138
139 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<ConnectionMessage>();
140
141 enum Event {
142 FromServer(Result<ServerToClientEvent>),
143 ConnectionMessage(Option<ConnectionMessage>),
144 Tick,
145 TypingManagerTick,
146 }
147
148 let mut typing_session_manager = typing::TypingSessionManager::default();
149
150 loop {
151 let event = futures::select! {
156 event = self.get_event().fuse() => Event::FromServer(event),
157 connection_message = rx.recv().fuse() => Event::ConnectionMessage(connection_message),
158 _ = int.tick().fuse() => Event::Tick,
159 _ = typing_session_manager.tick().fuse() => Event::TypingManagerTick,
160 };
161
162 match event {
163 Event::FromServer(event) => {
164 let event = event?;
165
166 let handler = handler.clone();
167 let ctx = ctx.clone().set_messanger(ConnectionMessanger(tx.clone()));
168
169 let fut = handler.handle(ctx, event);
170 tokio::spawn(fut);
171 }
172 Event::ConnectionMessage(Some(message)) => match message {
173 ConnectionMessage::StartTyping { channel } => {
174 typing_session_manager.start_typing(channel);
175 self.start_typing(channel).await?;
176 }
177 ConnectionMessage::StopTyping { channel } => {
178 if typing_session_manager.stop_typing(channel) {
179 self.stop_typing(channel).await?;
181 }
182 }
183 ConnectionMessage::Close => {
184 self.close().await?;
185 return Ok(()); }
187 },
188 Event::ConnectionMessage(None) => {
189 unreachable!()
195 }
196 Event::Tick => {
197 self.hb().await?;
198 }
199 Event::TypingManagerTick => {
200 for session in typing_session_manager.current_sessions() {
201 self.start_typing(*session).await?;
202 }
203 }
204 }
205 }
206 }
207
208 pub async fn hb(&mut self) -> Result {
210 self.0.hb().await
211 }
212
213 pub async fn get_event(&mut self) -> Result<ServerToClientEvent> {
215 self.0.get_event().await
216 }
217
218 pub async fn start_typing(&mut self, channel: ChannelId) -> Result {
223 self.0
224 .send_event(ClientToServerEvent::BeginTyping { channel })
225 .await
226 }
227
228 pub async fn stop_typing(&mut self, channel: ChannelId) -> Result {
230 self.0
231 .send_event(ClientToServerEvent::EndTyping { channel })
232 .await
233 }
234
235 pub async fn close(mut self) -> Result {
237 self.0.close().await?;
238
239 Ok(())
240 }
241}
242
243impl ConnectionInternal {
244 async fn hb(&mut self) -> Result {
245 self.send_event(ClientToServerEvent::Ping { data: 0 })
246 .await?;
247
248 Ok(())
249 }
250
251 async fn authenticate(&mut self, auth: Authentication<'_>) -> Result {
252 tracing::debug!("Authenticating");
253 self.send_event(match &auth {
254 Authentication::Bot { token } => ClientToServerEvent::Authenticate {
255 token: token.to_string(),
256 },
257 Authentication::User { session_token } => ClientToServerEvent::Authenticate {
258 token: session_token.to_string(),
259 },
260 })
261 .await?;
262
263 let msg = self.get_event().await?;
264
265 match msg {
266 ServerToClientEvent::Authenticated => {}
267 ServerToClientEvent::Error { error } => {
268 tracing::error!("Error while authenticating: {}", error);
269
270 return Err(EventsError::AuthError(error));
271 }
272 msg => {
273 tracing::info!("Unexpected message after auth: {:?}", msg);
274 }
275 }
276
277 Ok(())
278 }
279
280 async fn send_event(&mut self, message: ClientToServerEvent) -> Result {
281 use futures::sink::SinkExt;
282
283 let json = serde_json::to_string(&message)?;
284
285 tracing::debug!("[>] {}", &json);
286
287 self.stream.send(TungsteniteMessage::text(json)).await?;
288
289 Ok(())
290 }
291
292 async fn close(&mut self) -> Result {
293 self.stream.close(None).await?;
294 self.closed = true;
295
296 Ok(())
297 }
298
299 async fn get_event(&mut self) -> Result<ServerToClientEvent> {
300 if self.closed {
301 return Err(EventsError::Closed);
302 }
303
304 use async_std::stream::StreamExt;
305 let msg: TungsteniteMessage = self
306 .stream
307 .next()
308 .await
309 .expect("Last message in ws without closing")?;
310
311 match msg {
312 TungsteniteMessage::Text(json) => {
313 tracing::debug!("[<] {}", &json);
314 return Ok(serde_json::from_str(&json)?);
315 }
316 TungsteniteMessage::Binary(b) => tracing::debug!("Got binary: {:?}", &b),
317 TungsteniteMessage::Ping(ping) => tracing::debug!("Got ping: {:?}", &ping),
318 TungsteniteMessage::Pong(pong) => tracing::debug!("Got pong: {:?}", &pong),
319 TungsteniteMessage::Close(close) => {
320 tracing::debug!("Got close: {:?}", close);
321 self.closed = true;
322
323 return Err(EventsError::Closed);
324 }
325 };
326
327 unimplemented!()
328 }
329}