use async_tungstenite::{
stream::Stream,
tokio::{connect_async, TokioAdapter},
tungstenite::Message as TungsteniteMessage,
WebSocketStream,
};
use futures::FutureExt;
use robespierre_models::{
auth::Session,
events::{ClientToServerEvent, ServerToClientEvent},
id::ChannelId,
};
use std::result::Result as StdResult;
use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
use tokio_rustls::client::TlsStream;
pub mod typing;
/// Errors that can occur while connecting to, or talking over, the events websocket.
#[derive(Debug, thiserror::Error)]
pub enum EventsError {
/// An error reported by the underlying tungstenite websocket.
#[error("tungstenite error: {0}")]
WsError(#[from] async_tungstenite::tungstenite::Error),
/// A `serde_json` failure, raised either while encoding an outgoing event or
/// while decoding an incoming one.
#[error("serialization / deserialization error: {0}")]
DeserializationError(#[from] serde_json::Error),
/// The server answered the authentication request with an error event; the
/// payload is the error string it sent.
#[error("error while authenticating: {0}")]
AuthError(String),
/// The websocket has been closed, so no further events can be read from it.
#[error("websocket closed")]
Closed,
}
/// Result type used throughout this module: the success type defaults to `()`
/// and the error type is always `EventsError`.
pub type Result<T = ()> = StdResult<T, EventsError>;
/// Private state behind `Connection`: the websocket stream plus a flag
/// recording whether it has been closed.
struct ConnectionInternal {
// The websocket itself; the transport is either a bare `TcpStream` or a
// rustls `TlsStream` wrapped around one.
stream: WebSocketStream<Stream<TokioAdapter<TcpStream>, TokioAdapter<TlsStream<TcpStream>>>>,
// Set once the socket is closed locally or a close frame is received; reading
// events afterwards fails with `EventsError::Closed`.
closed: bool,
}
/// A websocket connection to the events server.
///
/// Obtained from `Connection::connect` or `Connection::connect_with_url`, both
/// of which send the authentication event before handing the connection back.
pub struct Connection(ConnectionInternal);
/// Credentials presented to the server right after the websocket is opened.
///
/// Both variants only borrow their token (for `'a`), which keeps the type `Copy`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Authentication<'a> {
/// Authenticate with a bot token.
Bot { token: &'a str },
/// Authenticate with the token of an existing user session.
User { session_token: &'a str },
}
impl<'a> From<&'a Session> for Authentication<'a> {
fn from(s: &'a Session) -> Self {
Self::User {
session_token: &s.token.0,
}
}
}
/// Callback invoked for every event received from the server.
///
/// `Connection::run` clones the handler for each incoming event and runs
/// `handle` on a spawned tokio task.
#[async_trait::async_trait]
pub trait RawEventHandler: Send + Sync + Clone + 'static {
/// Context value handed to every `handle` call.
type Context: 'static;
/// Handles one event from the server; both the handler and the context are
/// taken by value.
async fn handle(self, ctx: Self::Context, event: ServerToClientEvent);
}
/// Commands that can be sent to a running connection through a
/// `ConnectionMessanger`.
#[derive(Debug, Copy, Clone)]
pub enum ConnectionMessage {
/// Register a typing session for `channel` and send a begin-typing event.
StartTyping { channel: ChannelId },
/// Stop typing in `channel`; the end-typing event is only sent when the typing
/// session manager's `stop_typing` returns `true`.
StopTyping { channel: ChannelId },
/// Close the websocket and make `Connection::run` return `Ok(())`.
Close,
}
#[derive(Clone, Debug)]
pub struct ConnectionMessanger(UnboundedSender<ConnectionMessage>);
impl ConnectionMessanger {
pub fn send(&self, message: ConnectionMessage) {
self.0
.send(message)
.expect("Something went terribly wrong and the receiver closed");
}
}
/// Context value handed to event handlers by `Connection::run`.
pub trait Context: Sized + Clone + Send + 'static {
/// Returns this context with `messanger` attached. `Connection::run` calls
/// this on a fresh clone of the context for every incoming event.
fn set_messanger(self, messanger: ConnectionMessanger) -> Self;
}
impl Connection {
    /// Connects to the default endpoint, `wss://ws.revolt.chat`, and
    /// authenticates with `auth`.
    ///
    /// # Errors
    ///
    /// Same as [`Connection::connect_with_url`].
    pub async fn connect<'a>(auth: impl Into<Authentication<'a>>) -> Result<Self> {
        Self::connect_with_url(auth, "wss://ws.revolt.chat").await
    }

    /// Opens a websocket to `url` and performs the authentication exchange
    /// before returning the connection.
    ///
    /// # Errors
    ///
    /// Fails if the websocket cannot be established or if the authentication
    /// exchange returns an error.
    pub async fn connect_with_url<'a>(
        auth: impl Into<Authentication<'a>>,
        url: &str,
    ) -> Result<Self> {
        tracing::debug!("Connecting to websocket on {}", url);
        let (stream, _) = connect_async(url).await?;
        let mut inner = ConnectionInternal {
            stream,
            closed: false,
        };
        inner.authenticate(auth.into()).await?;
        Ok(Connection(inner))
    }

    /// Drives the connection until it is closed or an error occurs.
    ///
    /// Each pass through the loop waits for whichever of these happens first:
    ///
    /// * an event arrives from the server: `handler` is cloned and run on a
    ///   spawned tokio task, with a clone of `ctx` carrying a messanger that
    ///   feeds back into this loop;
    /// * a [`ConnectionMessage`] arrives from a messanger: typing is started
    ///   or stopped, or the socket is closed and `Ok(())` is returned;
    /// * the 15 second heartbeat interval ticks: a ping is sent;
    /// * the typing session manager's `tick` completes: a begin-typing event
    ///   is re-sent for every current typing session.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while reading an event or sending one.
    pub async fn run<C, H>(mut self, ctx: C, handler: H) -> Result
    where
        C: Context,
        H: RawEventHandler<Context = C>,
    {
        // The reason the loop woke up on a given iteration.
        enum Wake {
            Server(Result<ServerToClientEvent>),
            Command(Option<ConnectionMessage>),
            Heartbeat,
            TypingRefresh,
        }

        let mut heartbeat = tokio::time::interval(std::time::Duration::from_secs(15));
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<ConnectionMessage>();
        let mut typing_sessions = typing::TypingSessionManager::default();

        loop {
            let wake = futures::select! {
                incoming = self.get_event().fuse() => Wake::Server(incoming),
                command = receiver.recv().fuse() => Wake::Command(command),
                _ = heartbeat.tick().fuse() => Wake::Heartbeat,
                _ = typing_sessions.tick().fuse() => Wake::TypingRefresh,
            };

            match wake {
                Wake::Server(incoming) => {
                    let event = incoming?;
                    let task_handler = handler.clone();
                    let task_ctx = ctx
                        .clone()
                        .set_messanger(ConnectionMessanger(sender.clone()));
                    // The handler runs detached so a slow handler never stalls the loop.
                    tokio::spawn(task_handler.handle(task_ctx, event));
                }
                Wake::Command(Some(ConnectionMessage::StartTyping { channel })) => {
                    typing_sessions.start_typing(channel);
                    self.start_typing(channel).await?;
                }
                Wake::Command(Some(ConnectionMessage::StopTyping { channel })) => {
                    let should_notify = typing_sessions.stop_typing(channel);
                    if should_notify {
                        self.stop_typing(channel).await?;
                    }
                }
                Wake::Command(Some(ConnectionMessage::Close)) => {
                    self.close().await?;
                    return Ok(());
                }
                // `sender` lives for the whole loop, so the channel can never
                // report that every sender is gone.
                Wake::Command(None) => unreachable!(),
                Wake::Heartbeat => self.hb().await?,
                Wake::TypingRefresh => {
                    for session in typing_sessions.current_sessions() {
                        self.start_typing(*session).await?;
                    }
                }
            }
        }
    }

    /// Sends a heartbeat ping to the server.
    pub async fn hb(&mut self) -> Result {
        let inner = &mut self.0;
        inner.hb().await
    }

    /// Waits for the next event from the server.
    pub async fn get_event(&mut self) -> Result<ServerToClientEvent> {
        let inner = &mut self.0;
        inner.get_event().await
    }

    /// Sends a begin-typing event for `channel`.
    pub async fn start_typing(&mut self, channel: ChannelId) -> Result {
        let event = ClientToServerEvent::BeginTyping { channel };
        self.0.send_event(event).await
    }

    /// Sends an end-typing event for `channel`.
    pub async fn stop_typing(&mut self, channel: ChannelId) -> Result {
        let event = ClientToServerEvent::EndTyping { channel };
        self.0.send_event(event).await
    }

    /// Closes the websocket, consuming the connection.
    pub async fn close(mut self) -> Result {
        self.0.close().await
    }
}
impl ConnectionInternal {
async fn hb(&mut self) -> Result {
self.send_event(ClientToServerEvent::Ping { data: 0 })
.await?;
Ok(())
}
async fn authenticate(&mut self, auth: Authentication<'_>) -> Result {
tracing::debug!("Authenticating");
self.send_event(match &auth {
Authentication::Bot { token } => ClientToServerEvent::Authenticate {
token: token.to_string(),
},
Authentication::User { session_token } => ClientToServerEvent::Authenticate {
token: session_token.to_string(),
},
})
.await?;
let msg = self.get_event().await?;
match msg {
ServerToClientEvent::Authenticated => {}
ServerToClientEvent::Error { error } => {
tracing::error!("Error while authenticating: {}", error);
return Err(EventsError::AuthError(error));
}
msg => {
tracing::info!("Unexpected message after auth: {:?}", msg);
}
}
Ok(())
}
async fn send_event(&mut self, message: ClientToServerEvent) -> Result {
use futures::sink::SinkExt;
let json = serde_json::to_string(&message)?;
tracing::debug!("[>] {}", &json);
self.stream.send(TungsteniteMessage::text(json)).await?;
Ok(())
}
async fn close(&mut self) -> Result {
self.stream.close(None).await?;
self.closed = true;
Ok(())
}
async fn get_event(&mut self) -> Result<ServerToClientEvent> {
if self.closed {
return Err(EventsError::Closed);
}
use async_std::stream::StreamExt;
let msg: TungsteniteMessage = self
.stream
.next()
.await
.expect("Last message in ws without closing")?;
match msg {
TungsteniteMessage::Text(json) => {
tracing::debug!("[<] {}", &json);
return Ok(serde_json::from_str(&json)?);
}
TungsteniteMessage::Binary(b) => tracing::debug!("Got binary: {:?}", &b),
TungsteniteMessage::Ping(ping) => tracing::debug!("Got ping: {:?}", &ping),
TungsteniteMessage::Pong(pong) => tracing::debug!("Got pong: {:?}", &pong),
TungsteniteMessage::Close(close) => {
tracing::debug!("Got close: {:?}", close);
self.closed = true;
return Err(EventsError::Closed);
}
};
unimplemented!()
}
}