use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll};
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, Stream, StreamExt};
use tokio::sync::{Mutex, mpsc};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
use super::event::*;
use super::heartbeat::{self, HeartbeatFailure, HeartbeatState};
use crate::error::GatewayError;
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Snapshot of the data needed to resume a gateway session.
///
/// Produced by `GatewayConnection::session_state` and consumed by
/// `GatewayConnection::resume`. Derives `PartialEq`/`Eq` so snapshots can be
/// compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionState {
    /// Session identifier sent back to the gateway in the resume payload.
    pub session_id: String,
    /// Sequence number sent back to the gateway in the resume payload.
    pub sequence: u64,
}
/// A gateway WebSocket connection.
///
/// Gateway events are read from it through its `Stream` implementation, and
/// the heartbeat task it spawned is aborted when the value is dropped.
pub struct GatewayConnection {
/// Read half of the split WebSocket.
rx: SplitStream<WsStream>,
/// Write half of the split WebSocket; behind `Arc<Mutex<..>>` because a clone
/// is handed to the heartbeat task.
tx: Arc<Mutex<SplitSink<WsStream, WsMessage>>>,
/// State shared with the heartbeat task (last sequence number and the
/// heartbeat-acknowledged flag).
state: Arc<HeartbeatState>,
/// Receives failure reports from the heartbeat task.
heartbeat_rx: mpsc::UnboundedReceiver<HeartbeatFailure>,
/// Sender side of the failure channel; a clone is given to the heartbeat task.
heartbeat_tx: mpsc::UnboundedSender<HeartbeatFailure>,
/// Handle of the spawned heartbeat task; `None` until Hello has been handled.
heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
/// Token sent in the identify and resume payloads.
token: String,
// Stored at construction but never read in this file, hence the lint allowance.
// NOTE(review): presumably kept for reconnect logic elsewhere — confirm or remove.
#[allow(dead_code)]
gateway_url: String,
/// Session id, set from a READY dispatch or supplied by `resume`.
session_id: Option<String>,
}
impl GatewayConnection {
/// Opens a new gateway session.
///
/// Connects the WebSocket, waits for the gateway's Hello (which starts the
/// heartbeat task), then sends the identify payload.
///
/// # Errors
///
/// Returns the first error raised by any of the three steps.
pub async fn connect(url: &str, token: &str) -> Result<Self, GatewayError> {
    let mut connection = Self::connect_ws(url, token).await?;
    connection.handle_hello().await?;
    connection.identify().await?;
    Ok(connection)
}
pub async fn resume(
url: &str,
token: &str,
session: &SessionState,
) -> Result<Self, GatewayError> {
let mut conn = Self::connect_ws(url, token).await?;
conn.session_id = Some(session.session_id.clone());
conn.state
.sequence
.store(session.sequence, Ordering::SeqCst);
conn.handle_hello().await?;
conn.send_resume(&session.session_id, session.sequence)
.await?;
Ok(conn)
}
/// Returns the data needed to resume this session later.
///
/// Yields `None` while no session id is known; otherwise pairs the session
/// id with the sequence number currently held in the shared state.
pub fn session_state(&self) -> Option<SessionState> {
    let session_id = self.session_id.clone()?;
    let sequence = self.state.sequence.load(Ordering::SeqCst);
    Some(SessionState {
        session_id,
        sequence,
    })
}
/// Opens the WebSocket and builds a connection that has not yet handled Hello.
///
/// A `v=1` query parameter is appended unless the URL's query string already
/// carries a `v` parameter. The heartbeat task is not started here;
/// `heartbeat_handle` is left as `None`.
///
/// # Errors
///
/// Returns an error if the WebSocket connection cannot be established.
async fn connect_ws(url: &str, token: &str) -> Result<Self, GatewayError> {
    // Look at the actual query parameters: a plain substring search for "v="
    // would also match unrelated keys such as `env=` or `dev=`.
    let has_version = url.split_once('?').is_some_and(|(_, query)| {
        query
            .split('&')
            .any(|pair| pair.split('=').next() == Some("v"))
    });
    let connect_url = if has_version {
        url.to_string()
    } else if url.contains('?') {
        format!("{url}&v=1")
    } else {
        // Trim trailing slashes so "wss://host/" does not become "wss://host//?v=1".
        format!("{}/?v=1", url.trim_end_matches('/'))
    };
    let (ws, _) = connect_async(&connect_url).await?;
    let (tx, rx) = ws.split();
    let tx = Arc::new(Mutex::new(tx));
    let state = Arc::new(HeartbeatState::new());
    let (heartbeat_tx, heartbeat_rx) = mpsc::unbounded_channel();
    Ok(Self {
        rx,
        tx,
        state,
        heartbeat_rx,
        heartbeat_tx,
        heartbeat_handle: None,
        token: token.to_string(),
        // The caller's original URL is stored, not the versioned one.
        gateway_url: url.to_string(),
        session_id: None,
    })
}
/// Reads the first frame from the socket and starts the heartbeat task.
///
/// The frame must be a text frame carrying a Hello payload; its
/// `heartbeat_interval` is passed to the spawned heartbeat task.
///
/// # Errors
///
/// * `GatewayError::AuthFailed` if the stream ends, the frame is not text,
///   or the payload's opcode is not Hello.
/// * `GatewayError::Deserialize` if the payload or its Hello data cannot be
///   parsed.
/// * The converted WebSocket error if reading the frame fails.
async fn handle_hello(&mut self) -> Result<(), GatewayError> {
    let Some(frame) = self.rx.next().await else {
        return Err(GatewayError::AuthFailed);
    };
    let WsMessage::Text(text) = frame? else {
        return Err(GatewayError::AuthFailed);
    };
    let payload: GatewayPayload =
        serde_json::from_str(&text).map_err(GatewayError::Deserialize)?;
    if payload.op != Opcode::Hello {
        return Err(GatewayError::AuthFailed);
    }
    let hello: HelloPayload = serde_json::from_value(payload.d.unwrap_or_default())
        .map_err(GatewayError::Deserialize)?;
    tracing::info!(interval = hello.heartbeat_interval, "received Hello");
    let task = heartbeat::spawn_heartbeat(
        hello.heartbeat_interval,
        self.tx.clone(),
        self.state.clone(),
        self.heartbeat_tx.clone(),
    );
    self.heartbeat_handle = Some(task);
    Ok(())
}
/// Sends the identify payload carrying the token and client properties.
///
/// # Errors
///
/// Returns `GatewayError::Deserialize` if the payload cannot be converted to
/// JSON, or the error from `send_payload` if sending fails.
async fn identify(&self) -> Result<(), GatewayError> {
    let properties = IdentifyProperties {
        os: std::env::consts::OS.into(),
        browser: "refluxer".into(),
        device: "refluxer".into(),
    };
    let body = IdentifyPayload {
        token: self.token.clone(),
        properties,
    };
    let d = serde_json::to_value(body).map_err(GatewayError::Deserialize)?;
    let frame = GatewayPayload {
        op: Opcode::Identify,
        d: Some(d),
        s: None,
        t: None,
    };
    self.send_payload(&frame).await?;
    tracing::info!("identify sent");
    Ok(())
}
/// Sends a resume payload for `session_id` at sequence number `seq`.
///
/// # Errors
///
/// Returns `GatewayError::Deserialize` if the payload cannot be converted to
/// JSON, or the error from `send_payload` if sending fails.
async fn send_resume(&self, session_id: &str, seq: u64) -> Result<(), GatewayError> {
    let body = ResumePayload {
        token: self.token.clone(),
        session_id: session_id.to_string(),
        seq,
    };
    let d = serde_json::to_value(body).map_err(GatewayError::Deserialize)?;
    let frame = GatewayPayload {
        op: Opcode::Resume,
        d: Some(d),
        s: None,
        t: None,
    };
    self.send_payload(&frame).await?;
    tracing::info!(session_id, seq, "resume sent");
    Ok(())
}
/// Serializes `payload` to JSON and sends it as a text frame.
///
/// The write half is locked for the duration of the send because it is
/// shared with the heartbeat task.
///
/// # Errors
///
/// Returns `GatewayError::Deserialize` if serialization fails, or
/// `GatewayError::Connection` if the frame cannot be sent.
async fn send_payload(&self, payload: &GatewayPayload) -> Result<(), GatewayError> {
    let encoded = serde_json::to_string(payload).map_err(GatewayError::Deserialize)?;
    let mut sink = self.tx.lock().await;
    sink.send(WsMessage::Text(encoded.into()))
        .await
        .map_err(GatewayError::Connection)
}
/// Turns one WebSocket frame into the next stream item, if it produces one.
///
/// Side effects: the shared sequence number is updated whenever a payload
/// carries `s`, a HeartbeatAck sets the shared `acknowledged` flag, and a
/// READY dispatch records the session id.
///
/// Returns:
/// * `Some(Ok(event))` for a Dispatch payload.
/// * `Some(Err(..))` for an unparsable text frame, a Reconnect or
///   InvalidSession opcode, or a Close frame.
/// * `None` for a HeartbeatAck, any other opcode, and any frame that is
///   neither text nor close.
fn process_message(&mut self, msg: WsMessage) -> Option<Result<GatewayEvent, GatewayError>> {
    match msg {
        WsMessage::Text(text) => {
            let payload: GatewayPayload = match serde_json::from_str(&text) {
                Ok(p) => p,
                Err(e) => return Some(Err(GatewayError::Deserialize(e))),
            };
            if let Some(seq) = payload.s {
                self.state.sequence.store(seq, Ordering::SeqCst);
            }
            match payload.op {
                Opcode::Dispatch => {
                    let event_name = payload.t.as_deref().unwrap_or("UNKNOWN");
                    let data = payload.d.unwrap_or(serde_json::Value::Null);
                    tracing::debug!(event = event_name, "gateway dispatch");
                    if event_name == "READY" {
                        // The clone is needed because `data` is still handed to
                        // `from_dispatch` below.
                        match serde_json::from_value::<ReadyPayload>(data.clone()) {
                            Ok(ready) => self.session_id = Some(ready.session_id),
                            // Without a session id the connection cannot be
                            // resumed later, so do not hide this failure.
                            Err(err) => tracing::warn!(
                                error = %err,
                                "failed to parse READY payload; session id not recorded"
                            ),
                        }
                    }
                    Some(Ok(GatewayEvent::from_dispatch(event_name, data)))
                }
                Opcode::HeartbeatAck => {
                    self.state.acknowledged.store(true, Ordering::SeqCst);
                    tracing::trace!("heartbeat ack");
                    None
                }
                Opcode::Reconnect => {
                    tracing::info!("server requested reconnect");
                    // Surfaced as a close with a fixed code so the caller sees
                    // it as a connection end.
                    Some(Err(GatewayError::Closed {
                        code: 4000,
                        reason: "reconnect requested".into(),
                    }))
                }
                Opcode::InvalidSession => {
                    // A missing or non-boolean `d` is treated as not resumable.
                    let resumable = payload.d.and_then(|v| v.as_bool()).unwrap_or(false);
                    tracing::warn!(resumable, "received InvalidSession");
                    Some(Err(GatewayError::InvalidSession { resumable }))
                }
                _ => None,
            }
        }
        WsMessage::Close(frame) => {
            // A close without a frame is reported as code 1000.
            let (code, reason) = frame
                .map(|f| (f.code.into(), f.reason.to_string()))
                .unwrap_or((1000, "connection closed".into()));
            Some(Err(GatewayError::Closed { code, reason }))
        }
        _ => None,
    }
}
/// Maps a failure reported by the heartbeat task to the error yielded on the
/// stream: a timeout becomes `HeartbeatTimeout`, a failed send becomes a
/// `Closed` error with code 1006.
fn heartbeat_error(failure: HeartbeatFailure) -> GatewayError {
    match failure {
        HeartbeatFailure::SendFailed => GatewayError::Closed {
            code: 1006,
            reason: String::from("heartbeat send failed"),
        },
        HeartbeatFailure::Timeout => GatewayError::HeartbeatTimeout,
    }
}
}
/// Yields gateway events, interleaved with heartbeat failures.
impl Stream for GatewayConnection {
    type Item = Result<GatewayEvent, GatewayError>;

    /// Polls for the next event or error.
    ///
    /// Each pass first checks the heartbeat failure channel, then polls the
    /// socket. Frames that produce no item are skipped and the loop polls
    /// again. The stream ends when the socket's read half ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            // Heartbeat failures take priority over pending socket frames.
            if let Poll::Ready(Some(failure)) = this.heartbeat_rx.poll_recv(cx) {
                return Poll::Ready(Some(Err(Self::heartbeat_error(failure))));
            }
            let frame = match Pin::new(&mut this.rx).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(err))) => {
                    return Poll::Ready(Some(Err(GatewayError::Connection(err))));
                }
                Poll::Ready(Some(Ok(frame))) => frame,
            };
            if let Some(item) = this.process_message(frame) {
                return Poll::Ready(Some(item));
            }
        }
    }
}
/// Stops the heartbeat task when the connection goes away.
impl Drop for GatewayConnection {
    fn drop(&mut self) {
        // Nothing to abort if Hello was never handled.
        let Some(task) = self.heartbeat_handle.take() else {
            return;
        };
        task.abort();
    }
}