rive_gateway/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::time::Duration;
4
5use async_channel::{self, Receiver, Sender};
6use futures::{SinkExt, Stream, StreamExt};
7use rive_models::{
8    authentication::Authentication,
9    event::{ClientEvent, ServerEvent},
10};
11use tokio::{net::TcpStream, select, spawn, time::sleep};
12use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
13
/// WebSocket endpoint of the official Revolt instance.
///
/// Used by [`Gateway::connect`] and as the default [`GatewayConfig::base_url`].
pub const BASE_URL: &str = "wss://ws.revolt.chat";
16
17/// Gateway client error
18#[derive(Debug, thiserror::Error)]
19pub enum Error {
20    /// WebSocket error
21    #[error("Tungstenite error: {0}")]
22    WsError(#[from] tokio_tungstenite::tungstenite::Error),
23
24    /// Data serialization/deserialization error
25    #[error("Serde JSON deserialization/serialization error: {0}")]
26    SerializationError(#[from] serde_json::Error),
27
28    /// Internal client event channel sender error
29    #[error("Client event sender error: {0}")]
30    ClientSenderError(#[from] async_channel::SendError<ClientEvent>),
31
32    /// Internal server event channel sender error
33    #[error("Server event sender error: {0}")]
34    ServerSenderError(#[from] Box<async_channel::SendError<Result<ServerEvent, Error>>>),
35}
36
37/// Gateway configuration
38#[derive(Debug, Clone)]
39pub struct GatewayConfig {
40    /// Auth token. If it is not [`Authentication::None`] then the event will be sent automatically.
41    pub auth: Authentication,
42    /// WebSocket API base URL
43    pub base_url: String,
44    /// Whether auto heartbeat is enabled
45    pub heartbeat: bool,
46}
47
48impl Default for GatewayConfig {
49    fn default() -> Self {
50        Self {
51            auth: Authentication::None,
52            base_url: BASE_URL.to_string(),
53            heartbeat: true,
54        }
55    }
56}
57
58impl GatewayConfig {
59    /// Creates a new [`GatewayConfig`].
60    pub fn new(auth: Authentication, base_url: String, heartbeat: bool) -> Self {
61        Self {
62            auth,
63            base_url,
64            heartbeat,
65        }
66    }
67}
68
69/// A wrapper for Revolt WebSocket API
70#[derive(Debug, Clone)]
71pub struct Gateway {
72    client_sender: Sender<ClientEvent>,
73    server_receiver: Receiver<Result<ServerEvent, Error>>,
74}
75
76impl Gateway {
77    /// Connect to gateway with default Revolt WebSocket URL ([`BASE_URL`])
78    pub async fn connect(auth: Authentication) -> Result<Self, Error> {
79        Gateway::connect_with_url(BASE_URL, auth).await
80    }
81
82    /// Connect to gateway with specified URL
83    pub async fn connect_with_url(
84        url: impl Into<String>,
85        auth: Authentication,
86    ) -> Result<Self, Error> {
87        Self::connect_with_config(GatewayConfig::new(auth, url.into(), true)).await
88    }
89
90    pub async fn connect_with_config(config: GatewayConfig) -> Result<Self, Error> {
91        let (socket, _) = tokio_tungstenite::connect_async(&config.base_url).await?;
92        let (client_sender, client_receiver) = async_channel::unbounded();
93        let (server_sender, server_receiver) = async_channel::unbounded();
94
95        let revolt = Gateway {
96            client_sender: client_sender.clone(),
97            server_receiver,
98        };
99
100        spawn(Gateway::handle(client_receiver, server_sender, socket));
101
102        if config.heartbeat {
103            spawn(Self::heartbeat(client_sender));
104        }
105
106        if !matches!(config.auth, Authentication::None) {
107            let event = ClientEvent::Authenticate {
108                token: config.auth.value(),
109            };
110            revolt.send(event).await?;
111        }
112
113        Ok(revolt)
114    }
115
116    /// Send an event to server
117    pub async fn send(&self, event: ClientEvent) -> Result<(), Error> {
118        self.client_sender.send(event).await.map_err(Error::from)?;
119
120        Ok(())
121    }
122
123    async fn heartbeat(client_sender: Sender<ClientEvent>) -> Result<(), Error> {
124        loop {
125            // TODO: an ability to send custom value somehow
126            // it can be useful for ping measure for example
127            client_sender.send(ClientEvent::Ping { data: 0 }).await?;
128            sleep(Duration::from_secs(15)).await;
129        }
130    }
131
132    async fn handle(
133        mut client_receiver: Receiver<ClientEvent>,
134        server_sender: Sender<Result<ServerEvent, Error>>,
135        mut socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
136    ) -> Result<(), Error> {
137        loop {
138            select! {
139                Some(event) = client_receiver.next() => {
140                    let msg = Self::encode_client_event(event)?;
141                    socket.send(msg).await?;
142                },
143                Some(msg) = socket.next() => {
144                    let msg = msg.map_err(Error::from)?;
145                    let event = Self::decode_server_event(msg);
146                    server_sender.send(event).await.map_err(|err| Error::from(Box::new(err)))?;
147                },
148                else => break,
149            };
150        }
151
152        Ok(())
153    }
154
155    fn encode_client_event(event: ClientEvent) -> Result<Message, Error> {
156        let json = serde_json::to_string(&event).map_err(Error::from)?;
157        let msg = Message::Text(json);
158
159        Ok(msg)
160    }
161
162    fn decode_server_event(msg: Message) -> Result<ServerEvent, Error> {
163        let text = msg.to_text().map_err(Error::from)?;
164        let event = serde_json::from_str(text).map_err(Error::from)?;
165
166        Ok(event)
167    }
168}
169
170impl Stream for Gateway {
171    type Item = Result<ServerEvent, Error>;
172
173    fn poll_next(
174        mut self: std::pin::Pin<&mut Self>,
175        cx: &mut std::task::Context<'_>,
176    ) -> std::task::Poll<Option<Self::Item>> {
177        self.server_receiver.poll_next_unpin(cx)
178    }
179}