refluxer 0.2.0

Rust API wrapper for Fluxer
Documentation
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll};

use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, Stream, StreamExt};
use tokio::sync::{Mutex, mpsc};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};

use super::event::*;
use super::heartbeat::{self, HeartbeatFailure, HeartbeatState};
use crate::error::GatewayError;

/// The concrete WebSocket stream type produced by `connect_async`:
/// a TCP stream that may or may not be wrapped in TLS.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

/// State needed to resume a gateway session after disconnect.
///
/// Obtain one from [`GatewayConnection::session_state`] and pass it to
/// [`GatewayConnection::resume`].
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Session identifier, as taken from the READY dispatch payload.
    pub session_id: String,
    /// Last sequence number observed on the connection; sent back to the
    /// server in the Resume payload.
    pub sequence: u64,
}

/// A gateway WebSocket connection.
///
/// Implements `Stream`, yielding `Result<GatewayEvent, GatewayError>` items.
/// Dropping the connection aborts the heartbeat task.
pub struct GatewayConnection {
    /// Read half of the split WebSocket.
    rx: SplitStream<WsStream>,
    /// Write half of the split WebSocket; shared (behind a mutex) with the
    /// heartbeat task, which receives a clone of this `Arc`.
    tx: Arc<Mutex<SplitSink<WsStream, WsMessage>>>,
    /// State shared with the heartbeat task: the last sequence number and the
    /// heartbeat-acknowledged flag are updated here as messages arrive.
    state: Arc<HeartbeatState>,
    /// Receives failure notifications from the heartbeat task.
    heartbeat_rx: mpsc::UnboundedReceiver<HeartbeatFailure>,
    /// Sender side of the failure channel; cloned into the heartbeat task
    /// when it is spawned.
    heartbeat_tx: mpsc::UnboundedSender<HeartbeatFailure>,
    /// Handle of the heartbeat task; `None` until Hello has been processed.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    /// Token sent in the Identify and Resume payloads.
    token: String,
    // The URL as originally passed in (before the version parameter is
    // appended). Stored but not read anywhere in this type, hence the allow.
    #[allow(dead_code)]
    gateway_url: String,
    /// Session id, set when a READY dispatch is parsed or when resuming.
    session_id: Option<String>,
}

impl GatewayConnection {
    /// Establish a brand-new gateway session.
    ///
    /// Opens the WebSocket, waits for the server's Hello (which starts the
    /// heartbeat task) and then sends Identify with `token`.
    ///
    /// # Errors
    ///
    /// Returns a [`GatewayError`] if the socket cannot be opened, the Hello
    /// handshake fails, or the Identify payload cannot be sent.
    pub async fn connect(url: &str, token: &str) -> Result<Self, GatewayError> {
        let mut gateway = Self::connect_ws(url, token).await?;
        gateway.handle_hello().await?;
        gateway.identify().await?;
        Ok(gateway)
    }

    /// Open a connection and attempt to Resume an existing session.
    /// If the server responds with InvalidSession, returns
    /// `Err(GatewayError::InvalidSession)` so the caller can fall back to
    /// a fresh Identify.
    pub async fn resume(
        url: &str,
        token: &str,
        session: &SessionState,
    ) -> Result<Self, GatewayError> {
        let mut conn = Self::connect_ws(url, token).await?;
        conn.session_id = Some(session.session_id.clone());
        conn.state
            .sequence
            .store(session.sequence, Ordering::SeqCst);
        conn.handle_hello().await?;
        conn.send_resume(&session.session_id, session.sequence)
            .await?;
        Ok(conn)
    }

    /// Snapshot the information needed for a later Resume.
    ///
    /// Returns `None` while no session id is known (it is set when a READY
    /// dispatch is parsed, or carried over by [`GatewayConnection::resume`]).
    /// Otherwise returns the session id together with the most recently
    /// recorded sequence number.
    pub fn session_state(&self) -> Option<SessionState> {
        let session_id = self.session_id.clone()?;
        let sequence = self.state.sequence.load(Ordering::SeqCst);
        Some(SessionState {
            session_id,
            sequence,
        })
    }

    // -- internal helpers --

    /// Open the WebSocket and build a connection with no heartbeat running yet.
    ///
    /// The gateway version parameter `v=1` is appended to `url` unless the
    /// query string already carries a `v` parameter. The original `url` (not
    /// the versioned one) is stored in the returned connection.
    ///
    /// # Errors
    ///
    /// Returns a [`GatewayError`] if the WebSocket handshake fails.
    async fn connect_ws(url: &str, token: &str) -> Result<Self, GatewayError> {
        // Look for an actual `v` query parameter rather than the substring
        // "v=", which would also match unrelated parameters such as `dev=1`.
        let query = url.split_once('?').map(|(_, query)| query);
        let has_version = query.is_some_and(|q| {
            q.split('&')
                .any(|pair| pair == "v" || pair.starts_with("v="))
        });
        let connect_url = match query {
            _ if has_version => url.to_string(),
            // Empty query or trailing separator: no extra `&` needed.
            Some(q) if q.is_empty() || q.ends_with('&') => format!("{url}v=1"),
            Some(_) => format!("{url}&v=1"),
            // Avoid producing `//?v=1` when the URL already ends in a slash.
            None if url.ends_with('/') => format!("{url}?v=1"),
            None => format!("{url}/?v=1"),
        };

        let (ws, _) = connect_async(&connect_url).await?;
        let (tx, rx) = ws.split();
        let tx = Arc::new(Mutex::new(tx));
        let state = Arc::new(HeartbeatState::new());
        let (heartbeat_tx, heartbeat_rx) = mpsc::unbounded_channel();

        Ok(Self {
            rx,
            tx,
            state,
            heartbeat_rx,
            heartbeat_tx,
            heartbeat_handle: None,
            token: token.to_string(),
            gateway_url: url.to_string(),
            session_id: None,
        })
    }

    /// Wait for the server's Hello and start the heartbeat task.
    ///
    /// Ping/Pong control frames that arrive before Hello are skipped instead
    /// of failing the handshake. On Hello, a heartbeat task is spawned with
    /// the advertised interval and its handle is stored on the connection.
    ///
    /// # Errors
    ///
    /// * [`GatewayError::Closed`] if the server sends a Close frame before
    ///   Hello (code and reason are taken from the frame, matching how close
    ///   frames are reported once the connection is streaming).
    /// * [`GatewayError::Deserialize`] if the first text frame or the Hello
    ///   data cannot be parsed.
    /// * [`GatewayError::AuthFailed`] if the stream ends, or the first
    ///   payload is anything other than a Hello.
    /// * Any transport error from the socket.
    async fn handle_hello(&mut self) -> Result<(), GatewayError> {
        while let Some(msg) = self.rx.next().await {
            let text = match msg? {
                WsMessage::Text(text) => text,
                // Control frames are not part of the gateway protocol; keep
                // waiting for the actual Hello.
                WsMessage::Ping(_) | WsMessage::Pong(_) => continue,
                WsMessage::Close(frame) => {
                    let (code, reason) = frame
                        .map(|f| (f.code.into(), f.reason.to_string()))
                        .unwrap_or((1000, "connection closed".into()));
                    return Err(GatewayError::Closed { code, reason });
                }
                // Any other frame type in place of Hello is a failed handshake.
                _ => break,
            };

            let payload: GatewayPayload =
                serde_json::from_str(&text).map_err(GatewayError::Deserialize)?;
            if payload.op != Opcode::Hello {
                break;
            }

            let hello: HelloPayload = serde_json::from_value(payload.d.unwrap_or_default())
                .map_err(GatewayError::Deserialize)?;
            tracing::info!(interval = hello.heartbeat_interval, "received Hello");
            self.heartbeat_handle = Some(heartbeat::spawn_heartbeat(
                hello.heartbeat_interval,
                self.tx.clone(),
                self.state.clone(),
                self.heartbeat_tx.clone(),
            ));
            return Ok(());
        }
        Err(GatewayError::AuthFailed)
    }

    /// Send the Identify payload that starts a new session.
    ///
    /// The payload carries the connection's token plus client properties
    /// (the compile-target OS and the library name as browser/device).
    ///
    /// # Errors
    ///
    /// A serialization failure is reported as [`GatewayError::Deserialize`];
    /// a failure to write to the socket as [`GatewayError::Connection`].
    async fn identify(&self) -> Result<(), GatewayError> {
        let properties = IdentifyProperties {
            os: std::env::consts::OS.into(),
            browser: "refluxer".into(),
            device: "refluxer".into(),
        };
        let body = IdentifyPayload {
            token: self.token.clone(),
            properties,
        };
        let encoded = serde_json::to_value(body).map_err(GatewayError::Deserialize)?;

        self.send_payload(&GatewayPayload {
            op: Opcode::Identify,
            d: Some(encoded),
            s: None,
            t: None,
        })
        .await?;
        tracing::info!("identify sent");
        Ok(())
    }

    /// Send a Resume payload for `session_id`, replaying from sequence `seq`.
    ///
    /// # Errors
    ///
    /// A serialization failure is reported as [`GatewayError::Deserialize`];
    /// a failure to write to the socket as [`GatewayError::Connection`].
    async fn send_resume(&self, session_id: &str, seq: u64) -> Result<(), GatewayError> {
        let body = ResumePayload {
            token: self.token.clone(),
            session_id: session_id.to_string(),
            seq,
        };
        let encoded = serde_json::to_value(body).map_err(GatewayError::Deserialize)?;

        self.send_payload(&GatewayPayload {
            op: Opcode::Resume,
            d: Some(encoded),
            s: None,
            t: None,
        })
        .await?;
        tracing::info!(session_id, seq, "resume sent");
        Ok(())
    }

    /// Serialize `payload` to JSON and write it to the socket as a text frame.
    ///
    /// The write half is locked for the duration of the send, since it is
    /// shared with the heartbeat task.
    ///
    /// # Errors
    ///
    /// A serialization failure is reported as [`GatewayError::Deserialize`];
    /// a socket write failure as [`GatewayError::Connection`].
    async fn send_payload(&self, payload: &GatewayPayload) -> Result<(), GatewayError> {
        let json = serde_json::to_string(payload).map_err(GatewayError::Deserialize)?;
        let frame = WsMessage::Text(json.into());

        let mut sink = self.tx.lock().await;
        sink.send(frame).await.map_err(GatewayError::Connection)
    }

    fn process_message(&mut self, msg: WsMessage) -> Option<Result<GatewayEvent, GatewayError>> {
        match msg {
            WsMessage::Text(text) => {
                let payload: GatewayPayload = match serde_json::from_str(&text) {
                    Ok(p) => p,
                    Err(e) => return Some(Err(GatewayError::Deserialize(e))),
                };
                if let Some(seq) = payload.s {
                    self.state.sequence.store(seq, Ordering::SeqCst);
                }
                match payload.op {
                    Opcode::Dispatch => {
                        let event_name = payload.t.as_deref().unwrap_or("UNKNOWN");
                        let data = payload.d.unwrap_or(serde_json::Value::Null);
                        tracing::debug!(event = event_name, "gateway dispatch");
                        if event_name == "READY"
                            && let Ok(ready) = serde_json::from_value::<ReadyPayload>(data.clone())
                        {
                            self.session_id = Some(ready.session_id.clone());
                        }
                        Some(Ok(GatewayEvent::from_dispatch(event_name, data)))
                    }
                    Opcode::HeartbeatAck => {
                        self.state.acknowledged.store(true, Ordering::SeqCst);
                        tracing::trace!("heartbeat ack");
                        None
                    }
                    Opcode::Reconnect => {
                        tracing::info!("server requested reconnect");
                        Some(Err(GatewayError::Closed {
                            code: 4000,
                            reason: "reconnect requested".into(),
                        }))
                    }
                    Opcode::InvalidSession => {
                        let resumable = payload.d.and_then(|v| v.as_bool()).unwrap_or(false);
                        tracing::warn!(resumable, "received InvalidSession");
                        Some(Err(GatewayError::InvalidSession { resumable }))
                    }
                    _ => None,
                }
            }
            WsMessage::Close(frame) => {
                let (code, reason) = frame
                    .map(|f| (f.code.into(), f.reason.to_string()))
                    .unwrap_or((1000, "connection closed".into()));
                Some(Err(GatewayError::Closed { code, reason }))
            }
            _ => None,
        }
    }

    /// Map a failure reported by the heartbeat task to the error surfaced on
    /// the event stream: a missed acknowledgement becomes
    /// `HeartbeatTimeout`, a failed send becomes `Closed` with code 1006.
    fn heartbeat_error(failure: HeartbeatFailure) -> GatewayError {
        match failure {
            HeartbeatFailure::SendFailed => GatewayError::Closed {
                reason: "heartbeat send failed".into(),
                code: 1006,
            },
            HeartbeatFailure::Timeout => GatewayError::HeartbeatTimeout,
        }
    }
}

/// Yields gateway events (or errors) as they arrive on the socket.
impl Stream for GatewayConnection {
    type Item = Result<GatewayEvent, GatewayError>;

    /// Poll for the next event.
    ///
    /// On every pass the heartbeat failure channel is checked before the
    /// socket, so a reported heartbeat failure takes priority over incoming
    /// frames. Frames that produce no item (acks, ignored opcodes, control
    /// frames) are consumed and polling continues. The stream ends when the
    /// socket's read half ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        loop {
            if let Poll::Ready(Some(failure)) = this.heartbeat_rx.poll_recv(cx) {
                return Poll::Ready(Some(Err(Self::heartbeat_error(failure))));
            }

            let frame = match Pin::new(&mut this.rx).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(source))) => {
                    return Poll::Ready(Some(Err(GatewayError::Connection(source))));
                }
                Poll::Ready(Some(Ok(frame))) => frame,
            };

            // Frames that map to nothing are swallowed; go around again.
            if let Some(item) = this.process_message(frame) {
                return Poll::Ready(Some(item));
            }
        }
    }
}

/// Stops the heartbeat task when the connection goes away.
impl Drop for GatewayConnection {
    fn drop(&mut self) {
        // No handle means Hello was never processed, so nothing was spawned.
        let Some(task) = self.heartbeat_handle.take() else {
            return;
        };
        task.abort();
    }
}