use async_tungstenite::{
stream::Stream,
tokio::{connect_async, TokioAdapter},
tungstenite::Message as TungsteniteMessage,
WebSocketStream,
};
use futures::FutureExt;
use robespierre_models::{
channel::{Channel, ChannelField, Message, PartialChannel, PartialMessage},
id::{ChannelId, MemberId, MessageId, RoleId, ServerId, UserId},
server::{
Member, MemberField, PartialMember, PartialRole, PartialServer, RoleField, Server,
ServerField,
},
user::{PartialUser, RelationshipStatus, User, UserField},
};
use std::result::Result as StdResult;
use tokio::{net::TcpStream, sync::mpsc::UnboundedSender};
use tokio_rustls::client::TlsStream;
use serde::{Deserialize, Serialize};
/// Errors produced while connecting to, authenticating with, or talking to
/// the events websocket.
#[derive(Debug, thiserror::Error)]
pub enum EventsError {
    /// A websocket transport / protocol error reported by tungstenite
    /// (handshake, send or receive failure).
    #[error("tungstenite error: {0}")]
    WsError(#[from] async_tungstenite::tungstenite::Error),
    /// A frame could not be serialized to, or deserialized from, JSON.
    #[error("serialization / deserialization error: {0}")]
    DeserializationError(#[from] serde_json::Error),
    /// The server answered the authentication request with an `Error` event;
    /// the payload is the server-provided error string.
    #[error("error while authenticating: {0}")]
    AuthError(String),
    /// The websocket connection is closed; no further events can be read.
    #[error("websocket closed")]
    Closed,
}
/// Shorthand result type for this module; the success type defaults to `()`.
pub type Result<T = ()> = StdResult<T, EventsError>;
/// Messages this client sends to the server over the events websocket.
///
/// Serialized as internally tagged JSON: the variant name is written to a
/// `"type"` field alongside the variant's own fields.
// NOTE: the derived `Ord`/`PartialOrd` depend on variant declaration order.
#[derive(Serialize, Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[serde(tag = "type")]
pub enum ClientToServerEvent {
    /// Authenticates as a user account using a session token.
    Authenticate {
        user_id: UserId,
        session_token: String,
    },
    /// Authenticates as a bot.
    ///
    /// Renamed so it is also sent with `"type": "Authenticate"`; only the field
    /// set (`token` instead of `user_id` + `session_token`) differs. The shared
    /// tag is harmless here because this enum only derives `Serialize`.
    #[serde(rename = "Authenticate")]
    AuthenticateBot {
        token: String,
    },
    /// Typing indicator start for `channel` (sent by `start_typing`).
    BeginTyping {
        channel: ChannelId,
    },
    /// Typing indicator end for `channel` (sent by `stop_typing`).
    EndTyping {
        channel: ChannelId,
    },
    /// Heartbeat ping; the connection's `hb` always sends `time: 0`.
    /// NOTE(review): presumably answered with a `Pong` carrying the same `time` — confirm.
    Ping {
        time: u32,
    },
}
/// Payload of the `Ready` server event (flattened into that event's JSON object).
///
/// NOTE(review): presumably the initial state snapshot the server sends after
/// authentication — confirm against the Revolt events API documentation.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ReadyEvent {
    /// Users included in the snapshot.
    pub users: Vec<User>,
    /// Servers included in the snapshot.
    pub servers: Vec<Server>,
    /// Channels included in the snapshot.
    pub channels: Vec<Channel>,
    /// Server memberships included in the snapshot.
    pub members: Vec<Member>,
}
/// Messages received from the server over the events websocket.
///
/// Deserialized from internally tagged JSON: the `"type"` field selects the
/// variant. Variants whose single field is `#[serde(flatten)]` read their
/// payload from the same JSON object as the tag.
///
/// For the `*Update` variants, `data` carries partial (changed) fields and
/// `clear`, when present, names one field.
/// NOTE(review): `clear` presumably means "this field was removed/reset" — confirm.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum ServerToClientEvent {
    /// The server reported an error; during authentication this is turned into
    /// `EventsError::AuthError`.
    Error {
        error: String,
    },
    /// Authentication was accepted.
    Authenticated,
    /// Reply to a client `Ping`.
    Pong {
        time: u32,
    },
    /// State snapshot; see [`ReadyEvent`].
    Ready {
        #[serde(flatten)]
        event: ReadyEvent,
    },
    /// A message, with the message object inlined in the event.
    Message {
        #[serde(flatten)]
        message: Message,
    },
    /// Message `id` in `channel` was edited; `data` holds the changed fields.
    MessageUpdate {
        id: MessageId,
        channel: ChannelId,
        data: PartialMessage,
    },
    /// Message `id` in `channel` was deleted.
    MessageDelete {
        id: MessageId,
        channel: ChannelId,
    },
    /// A channel, with the channel object inlined in the event.
    ChannelCreate {
        #[serde(flatten)]
        channel: Channel,
    },
    /// Channel `id` changed.
    ChannelUpdate {
        id: ChannelId,
        data: PartialChannel,
        #[serde(default)]
        clear: Option<ChannelField>,
    },
    /// Channel `id` was deleted.
    ChannelDelete {
        id: ChannelId,
    },
    /// `user` joined group channel `id`.
    ChannelGroupJoin {
        id: ChannelId,
        user: UserId,
    },
    /// `user` left group channel `id`.
    ChannelGroupLeave {
        id: ChannelId,
        user: UserId,
    },
    /// `user` started typing in channel `id`.
    ChannelStartTyping {
        id: ChannelId,
        user: UserId,
    },
    /// `user` stopped typing in channel `id`.
    ChannelStopTyping {
        id: ChannelId,
        user: UserId,
    },
    /// `user` acknowledged channel `id` up to `message_id`.
    ChannelAck {
        id: ChannelId,
        user: UserId,
        message_id: MessageId,
    },
    /// Server `id` changed.
    ServerUpdate {
        id: ServerId,
        data: PartialServer,
        #[serde(default)]
        clear: Option<ServerField>,
    },
    /// Server `id` was deleted.
    ServerDelete {
        id: ServerId,
    },
    /// Membership `id` changed.
    ServerMemberUpdate {
        id: MemberId,
        data: PartialMember,
        #[serde(default)]
        clear: Option<MemberField>,
    },
    /// `user` joined server `id`.
    ServerMemberJoin {
        id: ServerId,
        user: UserId,
    },
    /// `user` left server `id`.
    ServerMemberLeave {
        id: ServerId,
        user: UserId,
    },
    /// Role `role_id` of server `id` changed.
    ServerRoleUpdate {
        id: ServerId,
        role_id: RoleId,
        data: PartialRole,
        #[serde(default)]
        clear: Option<RoleField>,
    },
    /// Role `role_id` of server `id` was deleted.
    ServerRoleDelete {
        id: ServerId,
        role_id: RoleId,
    },
    /// User `id` changed.
    UserUpdate {
        id: UserId,
        data: PartialUser,
        #[serde(default)]
        clear: Option<UserField>,
    },
    /// The relationship between `id` and `user` now has `status`.
    UserRelationship {
        id: UserId,
        user: UserId,
        status: RelationshipStatus,
    },
}
/// Websocket state wrapped by the public [`Connection`] newtype.
struct ConnectionInternal {
    /// The websocket to the events server, over either a plain or a TLS TCP stream.
    stream: WebSocketStream<Stream<TokioAdapter<TcpStream>, TokioAdapter<TlsStream<TcpStream>>>>,
    /// Set once the socket has been observed closed; later reads then return
    /// `EventsError::Closed` without touching the stream.
    closed: bool,
}
/// A connection to the events websocket; `Connection::connect` authenticates
/// it before handing it out.
pub struct Connection(ConnectionInternal);
/// Credentials used to authenticate a new [`Connection`].
///
/// Borrowed strings are converted into an owned `ClientToServerEvent` when the
/// authentication request is sent.
pub enum Authentication<'a> {
    /// Bot token; sent as `ClientToServerEvent::AuthenticateBot`.
    Bot {
        token: &'a str,
    },
    /// User account credentials; sent as `ClientToServerEvent::Authenticate`.
    User {
        user_id: UserId,
        session_token: &'a str,
    },
}
/// Receives every raw [`ServerToClientEvent`] read by `Connection::run`.
///
/// `run` clones the handler for each event and calls `handle` in a freshly
/// spawned tokio task, which is why the trait requires `Send + Sync + Clone + 'static`.
#[async_trait::async_trait]
pub trait RawEventHandler: Send + Sync + Clone + 'static {
    /// Per-event context passed alongside each event.
    type Context: 'static;
    /// Handles one event; consumes this clone of the handler.
    async fn handle(self, ctx: Self::Context, event: ServerToClientEvent);
}
/// Commands that event handlers can send back to the `Connection::run` loop
/// through a [`ConnectionMessanger`].
#[derive(Debug, Copy, Clone)]
pub enum ConnectionMessage {
    /// Send a `BeginTyping` event for `channel`.
    StartTyping { channel: ChannelId },
    /// Send an `EndTyping` event for `channel`.
    StopTyping { channel: ChannelId },
    /// Close the websocket and make `run` return `Ok(())`.
    Close,
}
#[derive(Clone, Debug)]
pub struct ConnectionMessanger(UnboundedSender<ConnectionMessage>);
impl ConnectionMessanger {
pub fn send(&self, message: ConnectionMessage) {
self.0
.send(message)
.expect("Something went terribly wrong and the receiver closed");
}
}
/// User-supplied context handed to a [`RawEventHandler`] with every event.
///
/// `Connection::run` clones the context once per event and calls
/// `set_messanger` on the clone before passing it to the handler.
pub trait Context: Sized + Clone + Send + 'static {
    /// Returns the context updated to carry `messanger`, through which the
    /// handler can send [`ConnectionMessage`]s to the running connection.
    fn set_messanger(self, messanger: ConnectionMessanger) -> Self;
}
impl Connection {
    /// Opens the events websocket at `wss://ws.revolt.chat` and authenticates it.
    ///
    /// # Errors
    /// Fails if the websocket handshake fails, if the authentication request
    /// cannot be sent or its reply cannot be read, or with
    /// [`EventsError::AuthError`] when the server answers with an `Error` event.
    pub async fn connect<'a>(auth: impl Into<Authentication<'a>>) -> Result<Self> {
        tracing::debug!("Connecting to websocket");
        let (ws, _http_response) = connect_async("wss://ws.revolt.chat").await?;
        let mut inner = ConnectionInternal {
            stream: ws,
            closed: false,
        };
        let credentials = auth.into();
        inner.authenticate(credentials).await?;
        Ok(Self(inner))
    }

    /// Drives the connection until it is closed or an error occurs.
    ///
    /// Three sources are multiplexed:
    /// * server events: each is handed to a clone of `handler` in its own
    ///   spawned task, together with a clone of `ctx` that carries a
    ///   [`ConnectionMessanger`] for this loop;
    /// * [`ConnectionMessage`]s sent through those messangers;
    /// * a 15-second timer that sends a heartbeat ping (failures are only logged).
    ///
    /// # Errors
    /// Returns `Ok(())` after a [`ConnectionMessage::Close`] has closed the
    /// socket. Otherwise it returns the first error from reading an event,
    /// sending a typing indicator, or closing the socket.
    pub async fn run<C, H>(mut self, ctx: C, handler: H) -> Result
    where
        C: Context,
        H: RawEventHandler<Context = C>,
    {
        // What woke the loop up on a given iteration.
        enum Step {
            Incoming(Result<ServerToClientEvent>),
            Command(Option<ConnectionMessage>),
            Heartbeat,
        }

        let mut heartbeat = tokio::time::interval(std::time::Duration::from_secs(15));
        let (command_tx, mut command_rx) =
            tokio::sync::mpsc::unbounded_channel::<ConnectionMessage>();

        loop {
            let step = futures::select! {
                incoming = self.0.get_event().fuse() => Step::Incoming(incoming),
                command = command_rx.recv().fuse() => Step::Command(command),
                _ = heartbeat.tick().fuse() => Step::Heartbeat,
            };

            match step {
                Step::Incoming(incoming) => {
                    let event = incoming?;
                    let task_handler = handler.clone();
                    let task_ctx = ctx
                        .clone()
                        .set_messanger(ConnectionMessanger(command_tx.clone()));
                    tokio::spawn(task_handler.handle(task_ctx, event));
                }
                Step::Command(Some(ConnectionMessage::StartTyping { channel })) => {
                    self.start_typing(channel).await?;
                }
                Step::Command(Some(ConnectionMessage::StopTyping { channel })) => {
                    self.stop_typing(channel).await?;
                }
                Step::Command(Some(ConnectionMessage::Close)) => {
                    self.0.stream.close(None).await?;
                    return Ok(());
                }
                // `command_tx` lives for the whole loop, so the channel can
                // never report that every sender has been dropped.
                Step::Command(None) => unreachable!(),
                Step::Heartbeat => {
                    if let Err(err) = self.0.hb().await {
                        tracing::error!("hb error: {}", err);
                    }
                }
            }
        }
    }

    /// Sends a single heartbeat ping to the server.
    pub async fn hb(&mut self) -> Result {
        self.0.hb().await
    }

    /// Waits for the next event from the server.
    pub async fn get_event(&mut self) -> Result<ServerToClientEvent> {
        self.0.get_event().await
    }

    /// Sends a `BeginTyping` event for `channel`.
    pub async fn start_typing(&mut self, channel: ChannelId) -> Result {
        let begin = ClientToServerEvent::BeginTyping { channel };
        self.0.send_event(begin).await
    }

    /// Sends an `EndTyping` event for `channel`.
    pub async fn stop_typing(&mut self, channel: ChannelId) -> Result {
        let end = ClientToServerEvent::EndTyping { channel };
        self.0.send_event(end).await
    }
}
impl ConnectionInternal {
    /// Sends one heartbeat `Ping` (with a fixed `time` of 0) to the server.
    async fn hb(&mut self) -> Result {
        tracing::debug!("sending Ping message");
        self.send_event(ClientToServerEvent::Ping { time: 0 }).await
    }

    /// Sends the credentials and reads the server's first reply.
    ///
    /// # Errors
    /// [`EventsError::AuthError`] if the reply is an `Error` event. Also any
    /// send, receive or JSON error. Any other unexpected reply is logged and
    /// treated as success.
    async fn authenticate<'a>(&mut self, auth: Authentication<'a>) -> Result {
        tracing::debug!("Authenticating");
        self.send_event(match &auth {
            Authentication::Bot { token } => ClientToServerEvent::AuthenticateBot {
                token: token.to_string(),
            },
            Authentication::User {
                user_id,
                session_token,
            } => ClientToServerEvent::Authenticate {
                user_id: user_id.clone(),
                session_token: session_token.to_string(),
            },
        })
        .await?;
        let msg = self.get_event().await?;
        match msg {
            ServerToClientEvent::Authenticated => {}
            ServerToClientEvent::Error { error } => {
                tracing::error!("Error while authenticating: {}", error);
                return Err(EventsError::AuthError(error));
            }
            msg => {
                tracing::info!("Unexpected message after auth: {:?}", msg);
            }
        }
        Ok(())
    }

    /// Serializes `message` to JSON and sends it as a single text frame.
    async fn send_event(&mut self, message: ClientToServerEvent) -> Result {
        use futures::sink::SinkExt;
        self.stream
            .send(TungsteniteMessage::text(serde_json::to_string(&message)?))
            .await?;
        Ok(())
    }

    /// Reads frames until a JSON text frame arrives and returns it as an event.
    ///
    /// Binary, ping and pong frames are logged and skipped. tungstenite
    /// surfaces the server's keep-alive pings to the reader, so they must not
    /// end the read. A close frame, or the stream ending without one, marks the
    /// connection closed.
    ///
    /// # Errors
    /// [`EventsError::Closed`] once the socket is closed. Websocket and JSON
    /// errors are propagated.
    async fn get_event(&mut self) -> Result<ServerToClientEvent> {
        use async_std::stream::StreamExt;
        if self.closed {
            return Err(EventsError::Closed);
        }
        loop {
            let msg: TungsteniteMessage = match self.stream.next().await {
                Some(frame) => frame?,
                None => {
                    // The stream ended without a close frame. This is still a
                    // closed connection, not a reason to panic.
                    tracing::debug!("websocket stream ended without a close frame");
                    self.closed = true;
                    return Err(EventsError::Closed);
                }
            };
            match msg {
                TungsteniteMessage::Text(json) => {
                    tracing::debug!("Got json: {}", &json);
                    return Ok(serde_json::from_str(&json)?);
                }
                TungsteniteMessage::Binary(b) => tracing::debug!("Got binary: {:?}", &b),
                TungsteniteMessage::Ping(ping) => tracing::debug!("Got ping: {:?}", &ping),
                TungsteniteMessage::Pong(pong) => tracing::debug!("Got pong: {:?}", &pong),
                TungsteniteMessage::Close(close) => {
                    tracing::debug!("Got close: {:?}", close);
                    self.closed = true;
                    return Err(EventsError::Closed);
                }
            }
        }
    }
}