1#![doc = include_str!("../README.md")]
2
3use std::time::Duration;
4
5use async_channel::{self, Receiver, Sender};
6use futures::{SinkExt, Stream, StreamExt};
7use rive_models::{
8 authentication::Authentication,
9 event::{ClientEvent, ServerEvent},
10};
11use tokio::{net::TcpStream, select, spawn, time::sleep};
12use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
13
/// Default WebSocket URL of the gateway.
///
/// Used by [`Gateway::connect`] and by [`GatewayConfig::default`] when no
/// other URL is supplied.
pub const BASE_URL: &str = "wss://ws.revolt.chat";
16
#[derive(Debug, thiserror::Error)]
/// Errors produced while connecting to or talking with the gateway.
pub enum Error {
    /// Error reported by the WebSocket layer (connecting, sending or
    /// receiving a frame, or reading a frame as text).
    #[error("Tungstenite error: {0}")]
    WsError(#[from] tokio_tungstenite::tungstenite::Error),

    /// A client event could not be serialized to JSON, or a received
    /// message could not be deserialized into a server event.
    #[error("Serde JSON deserialization/serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// A client event could not be queued on the internal client channel.
    #[error("Client event sender error: {0}")]
    ClientSenderError(#[from] async_channel::SendError<ClientEvent>),

    /// A server event could not be queued on the internal server channel.
    ///
    /// The payload is boxed because the channel item type itself contains
    /// [`Error`], which would otherwise make this enum infinitely sized.
    #[error("Server event sender error: {0}")]
    ServerSenderError(#[from] Box<async_channel::SendError<Result<ServerEvent, Error>>>),
}
36
#[derive(Debug, Clone)]
/// Settings used by [`Gateway::connect_with_config`] to open a connection.
pub struct GatewayConfig {
    /// Credentials for the connection. Unless this is
    /// `Authentication::None`, an `Authenticate` client event carrying
    /// `auth.value()` is sent right after connecting.
    pub auth: Authentication,
    /// WebSocket URL to connect to.
    pub base_url: String,
    /// Whether to spawn the background task that periodically sends `Ping`
    /// client events.
    pub heartbeat: bool,
}
47
48impl Default for GatewayConfig {
49 fn default() -> Self {
50 Self {
51 auth: Authentication::None,
52 base_url: BASE_URL.to_string(),
53 heartbeat: true,
54 }
55 }
56}
57
58impl GatewayConfig {
59 pub fn new(auth: Authentication, base_url: String, heartbeat: bool) -> Self {
61 Self {
62 auth,
63 base_url,
64 heartbeat,
65 }
66 }
67}
68
#[derive(Debug, Clone)]
/// Handle to a gateway connection.
///
/// Client events are pushed through [`Gateway::send`]; server events are
/// read by polling the handle as a `Stream`.
///
/// NOTE(review): clones share the same underlying channels, so each server
/// event is presumably delivered to only one clone - confirm against the
/// `async_channel` semantics before relying on fan-out.
pub struct Gateway {
    /// Sending half of the queue of client events awaiting transmission.
    client_sender: Sender<ClientEvent>,
    /// Receiving half of the queue of decoded server events (or errors).
    server_receiver: Receiver<Result<ServerEvent, Error>>,
}
75
76impl Gateway {
77 pub async fn connect(auth: Authentication) -> Result<Self, Error> {
79 Gateway::connect_with_url(BASE_URL, auth).await
80 }
81
82 pub async fn connect_with_url(
84 url: impl Into<String>,
85 auth: Authentication,
86 ) -> Result<Self, Error> {
87 Self::connect_with_config(GatewayConfig::new(auth, url.into(), true)).await
88 }
89
90 pub async fn connect_with_config(config: GatewayConfig) -> Result<Self, Error> {
91 let (socket, _) = tokio_tungstenite::connect_async(&config.base_url).await?;
92 let (client_sender, client_receiver) = async_channel::unbounded();
93 let (server_sender, server_receiver) = async_channel::unbounded();
94
95 let revolt = Gateway {
96 client_sender: client_sender.clone(),
97 server_receiver,
98 };
99
100 spawn(Gateway::handle(client_receiver, server_sender, socket));
101
102 if config.heartbeat {
103 spawn(Self::heartbeat(client_sender));
104 }
105
106 if !matches!(config.auth, Authentication::None) {
107 let event = ClientEvent::Authenticate {
108 token: config.auth.value(),
109 };
110 revolt.send(event).await?;
111 }
112
113 Ok(revolt)
114 }
115
116 pub async fn send(&self, event: ClientEvent) -> Result<(), Error> {
118 self.client_sender.send(event).await.map_err(Error::from)?;
119
120 Ok(())
121 }
122
123 async fn heartbeat(client_sender: Sender<ClientEvent>) -> Result<(), Error> {
124 loop {
125 client_sender.send(ClientEvent::Ping { data: 0 }).await?;
128 sleep(Duration::from_secs(15)).await;
129 }
130 }
131
132 async fn handle(
133 mut client_receiver: Receiver<ClientEvent>,
134 server_sender: Sender<Result<ServerEvent, Error>>,
135 mut socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
136 ) -> Result<(), Error> {
137 loop {
138 select! {
139 Some(event) = client_receiver.next() => {
140 let msg = Self::encode_client_event(event)?;
141 socket.send(msg).await?;
142 },
143 Some(msg) = socket.next() => {
144 let msg = msg.map_err(Error::from)?;
145 let event = Self::decode_server_event(msg);
146 server_sender.send(event).await.map_err(|err| Error::from(Box::new(err)))?;
147 },
148 else => break,
149 };
150 }
151
152 Ok(())
153 }
154
155 fn encode_client_event(event: ClientEvent) -> Result<Message, Error> {
156 let json = serde_json::to_string(&event).map_err(Error::from)?;
157 let msg = Message::Text(json);
158
159 Ok(msg)
160 }
161
162 fn decode_server_event(msg: Message) -> Result<ServerEvent, Error> {
163 let text = msg.to_text().map_err(Error::from)?;
164 let event = serde_json::from_str(text).map_err(Error::from)?;
165
166 Ok(event)
167 }
168}
169
170impl Stream for Gateway {
171 type Item = Result<ServerEvent, Error>;
172
173 fn poll_next(
174 mut self: std::pin::Pin<&mut Self>,
175 cx: &mut std::task::Context<'_>,
176 ) -> std::task::Poll<Option<Self::Item>> {
177 self.server_receiver.poll_next_unpin(cx)
178 }
179}