diself 0.3.0

A Discord self-bot library for Rust
Documentation
use crate::error::{Error, Result};
use crate::gateway::{Connection, Identify};
use rand::Rng;
use serde_json::{json, Value};
use tokio::time::{self, Duration, Interval, Instant};

const DEFAULT_GATEWAY_URL: &str = "wss://gateway.discord.gg/?v=10&encoding=json";
const INVALID_SESSION_RETRY_DELAY: Duration = Duration::from_secs(1);
const MAX_RECONNECT_BACKOFF: Duration = Duration::from_secs(60);

pub struct Gateway {
    token: String,
    connection: Option<Connection>,
    heartbeat: Option<Interval>,
    heartbeat_interval_ms: u64,
    awaiting_heartbeat_ack: bool,
    pending_heartbeat: bool,
    sequence: Option<u64>,
    session_id: Option<String>,
    resume_gateway_url: Option<String>,
    reconnect_attempts: u32,
    max_reconnect_attempts: u32,
}

impl Gateway {
    pub async fn connect(token: impl Into<String>) -> Result<Self> {
        let mut gateway = Self {
            token: token.into(),
            connection: None,
            heartbeat: None,
            heartbeat_interval_ms: 0,
            awaiting_heartbeat_ack: false,
            pending_heartbeat: false,
            sequence: None,
            session_id: None,
            resume_gateway_url: None,
            reconnect_attempts: 0,
            max_reconnect_attempts: 10,
        };

        gateway.reconnect(true).await?;
        Ok(gateway)
    }

    pub async fn next_event(&mut self) -> Result<Option<Value>> {
        loop {
            if self.pending_heartbeat {
                self.send_heartbeat().await?;
                self.pending_heartbeat = false;
            }

            let heartbeat = self.heartbeat.as_mut().ok_or(Error::InvalidPayload)?;
            let connection = self.connection.as_mut().ok_or(Error::InvalidPayload)?;

            tokio::select! {
                _ = heartbeat.tick() => {
                    if self.awaiting_heartbeat_ack {
                        tracing::warn!("Heartbeat ACK timeout, reconnecting gateway");
                        self.reconnect(true).await?;
                        continue;
                    }
                    self.pending_heartbeat = true;
                }
                payload = connection.receive() => {
                    let Some(payload) = payload? else {
                        tracing::warn!("Gateway connection closed, reconnecting");
                        self.reconnect(true).await?;
                        continue;
                    };

                    if let Some(seq) = payload.get("s").and_then(|s| s.as_u64()) {
                        self.sequence = Some(seq);
                    }

                    if let Some(next) = self.handle_control_opcode(&payload).await? {
                        return Ok(Some(next));
                    }
                }
            }
        }
    }

    pub async fn shutdown(&mut self) -> Result<()> {
        self.awaiting_heartbeat_ack = false;
        self.pending_heartbeat = false;
        self.heartbeat = None;

        if let Some(mut connection) = self.connection.take() {
            connection.close().await?;
        }

        Ok(())
    }

    async fn handle_control_opcode(&mut self, payload: &Value) -> Result<Option<Value>> {
        let op = payload.get("op").and_then(|op| op.as_u64());

        match op {
            Some(0) => {
                if let Some(event_type) = payload.get("t").and_then(|t| t.as_str()) {
                    match event_type {
                        "READY" => {
                            self.session_id = payload["d"]["session_id"]
                                .as_str()
                                .map(ToOwned::to_owned);
                            self.resume_gateway_url = payload["d"]["resume_gateway_url"]
                                .as_str()
                                .map(|url| format!("{url}/?v=10&encoding=json"));
                            tracing::info!(
                                "Gateway READY received (session resumable: {})",
                                self.can_resume()
                            );
                            if let Some(guilds) = payload["d"]["guilds"].as_array() {
                                self.subscribe_to_guilds(guilds).await;
                            }
                        }
                        "RESUMED" => {
                            tracing::info!("Gateway session resumed successfully");
                        }
                        _ => {}
                    }
                }
                Ok(Some(payload.clone()))
            }
            Some(1) => {
                self.pending_heartbeat = true;
                Ok(None)
            }
            Some(7) => {
                tracing::info!("Gateway requested reconnect");
                self.reconnect(true).await?;
                Ok(None)
            }
            Some(9) => {
                let can_resume = payload["d"].as_bool().unwrap_or(false);
                if !can_resume {
                    self.session_id = None;
                    self.sequence = None;
                }
                tracing::warn!(
                    "Received INVALID_SESSION (resumable: {}), reconnecting",
                    can_resume
                );
                time::sleep(INVALID_SESSION_RETRY_DELAY).await;
                self.reconnect(can_resume).await?;
                Ok(None)
            }
            Some(10) => {
                tracing::debug!("Received unexpected HELLO after handshake");
                Ok(None)
            }
            Some(11) => {
                self.awaiting_heartbeat_ack = false;
                tracing::trace!("Heartbeat ACK received");
                Ok(None)
            }
            _ => Ok(None),
        }
    }

    async fn subscribe_to_guilds(&mut self, guilds: &[Value]) {
        let Some(conn) = self.connection.as_mut() else {
            return;
        };
        for guild in guilds {
            let Some(guild_id) = guild["id"].as_str() else {
                continue;
            };
            let op14 = json!({
                "op": 14,
                "d": {
                    "guild_id": guild_id,
                    "typing": true,
                    "threads": true,
                    "activities": true,
                }
            });
            if let Err(e) = conn.send(&op14).await {
                tracing::warn!("Failed to subscribe to guild {}: {}", guild_id, e);
            }
        }
    }

    async fn send_heartbeat(&mut self) -> Result<()> {
        let payload = json!({
            "op": 1,
            "d": self.sequence,
        });

        let connection = self.connection.as_mut().ok_or(Error::InvalidPayload)?;
        connection.send(&payload).await?;
        self.awaiting_heartbeat_ack = true;
        tracing::trace!("Heartbeat sent (seq: {:?})", self.sequence);
        Ok(())
    }

    async fn reconnect(&mut self, prefer_resume: bool) -> Result<()> {
        self.connection = None;
        self.heartbeat = None;
        self.awaiting_heartbeat_ack = false;
        self.pending_heartbeat = false;

        let mut use_resume = prefer_resume && self.can_resume();

        loop {
            if self.reconnect_attempts >= self.max_reconnect_attempts {
                return Err(Error::GatewayConnection(
                    "max reconnect attempts exceeded".to_string(),
                ));
            }

            if self.reconnect_attempts > 0 {
                let backoff = self.backoff_with_jitter(self.reconnect_attempts);
                tracing::warn!(
                    "Reconnect attempt {}/{} in {:?}",
                    self.reconnect_attempts,
                    self.max_reconnect_attempts,
                    backoff
                );
                time::sleep(backoff).await;
            }

            match self.open_session(use_resume).await {
                Ok(()) => {
                    self.reconnect_attempts = 0;
                    return Ok(());
                }
                Err(err) => {
                    tracing::error!("Failed to reconnect gateway: {}", err);
                    self.reconnect_attempts = self.reconnect_attempts.saturating_add(1);

                    if use_resume {
                        tracing::warn!("Resume failed, falling back to fresh IDENTIFY");
                        use_resume = false;
                    }
                }
            }
        }
    }

    async fn open_session(&mut self, resume: bool) -> Result<()> {
        let url = self
            .resume_gateway_url
            .as_deref()
            .unwrap_or(DEFAULT_GATEWAY_URL);
        let mut connection = Connection::connect(url).await?;

        let hello = connection.receive().await?.ok_or(Error::InvalidPayload)?;
        if hello.get("op") != Some(&json!(10)) {
            return Err(Error::InvalidPayload);
        }

        self.heartbeat_interval_ms = hello["d"]["heartbeat_interval"]
            .as_u64()
            .ok_or(Error::InvalidPayload)?;

        let heartbeat_interval = Duration::from_millis(self.heartbeat_interval_ms);
        // `interval` ticks immediately once; start at +interval to avoid false ACK timeout loops.
        let mut heartbeat =
            time::interval_at(Instant::now() + heartbeat_interval, heartbeat_interval);
        heartbeat.set_missed_tick_behavior(time::MissedTickBehavior::Delay);

        if resume {
            self.send_resume(&mut connection).await?;
        } else {
            self.send_identify(&mut connection).await?;
        }

        tracing::info!(
            "Gateway connected with heartbeat={}ms (resume={})",
            self.heartbeat_interval_ms,
            resume
        );

        self.connection = Some(connection);
        self.heartbeat = Some(heartbeat);
        self.awaiting_heartbeat_ack = false;
        self.pending_heartbeat = false;
        Ok(())
    }

    async fn send_identify(&self, connection: &mut Connection) -> Result<()> {
        let identify_payload = json!({
            "op": 2,
            "d": Identify::new(self.token.clone()),
        });
        connection.send(&identify_payload).await
    }

    async fn send_resume(&self, connection: &mut Connection) -> Result<()> {
        let payload = json!({
            "op": 6,
            "d": {
                "token": self.token,
                "session_id": self.session_id,
                "seq": self.sequence,
            }
        });
        connection.send(&payload).await
    }

    fn can_resume(&self) -> bool {
        self.session_id.is_some() && self.sequence.is_some()
    }

    fn backoff_with_jitter(&self, attempt: u32) -> Duration {
        let capped = attempt.min(6);
        let base_secs = 2_u64.saturating_pow(capped);
        let base = Duration::from_secs(base_secs).min(MAX_RECONNECT_BACKOFF);

        let max_jitter_ms = (base.as_millis() / 5) as u64;
        let jitter_ms = if max_jitter_ms == 0 {
            0
        } else {
            rand::thread_rng().gen_range(0..=max_jitter_ms)
        };

        base + Duration::from_millis(jitter_ms)
    }
}