refluxer 0.1.0

Rust API wrapper for Fluxer
Documentation
use super::cache::Cache;
use super::context::Context;
use super::handler::EventHandler;
use crate::error::{Error, GatewayError};
use crate::gateway::connection::{GatewayConnection, SessionState};
use crate::gateway::event::GatewayEvent;
use crate::http::client::HttpClient;
use futures_util::{FutureExt, StreamExt};
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use tokio::sync::Semaphore;

/// Gateway URL used when no override was configured and `GET /gateway/bot`
/// discovery fails.
const DEFAULT_GATEWAY_URL: &str = "wss://gateway.fluxer.app";
/// Default number of event-handler tasks allowed to run concurrently; used as
/// the initial value of `ClientBuilder::max_concurrent_events`.
const DEFAULT_MAX_CONCURRENT_EVENTS: usize = 1024;

/// Bot client: owns the HTTP client, the shared cache and the user's event
/// handler, and drives the gateway connection via [`Client::start`].
///
/// Construct one with [`Client::builder`].
pub struct Client {
    /// Bot token, passed to the gateway when connecting or resuming.
    token: String,
    /// Gateway URL set via `ClientBuilder::gateway_url`; when `None`, the URL
    /// is discovered with `GET /gateway/bot` at startup.
    gateway_url_override: Option<String>,
    /// User callbacks invoked for gateway events.
    handler: Arc<dyn EventHandler>,
    /// REST client; a clone is handed to every `Context`.
    http: HttpClient,
    /// Cache updated from gateway events; a clone is handed to every `Context`.
    cache: Cache,
    /// Limits how many event-handler tasks may be in flight at once.
    event_semaphore: Arc<Semaphore>,
}

impl Client {
    pub fn builder() -> ClientBuilder {
        ClientBuilder::default()
    }

    pub async fn start(&self) -> Result<(), Error> {
        let mut backoff = 1u64;
        let mut session: Option<SessionState> = None;
        let gateway_url = self.resolve_gateway_url().await;

        loop {
            let connect_result = if let Some(ref s) = session {
                tracing::info!(
                    url = %gateway_url,
                    session_id = %s.session_id,
                    seq = s.sequence,
                    "resuming gateway session",
                );
                GatewayConnection::resume(&gateway_url, &self.token, s).await
            } else {
                tracing::info!(url = %gateway_url, "connecting to gateway");
                GatewayConnection::connect(&gateway_url, &self.token).await
            };

            match connect_result {
                Ok(mut gw) => {
                    backoff = 1;
                    let ctx = Context::new(self.http.clone(), self.cache.clone());

                    while let Some(result) = gw.next().await {
                        match result {
                            Ok(event) => {
                                self.dispatch(ctx.clone(), event).await;
                            }
                            Err(GatewayError::InvalidSession { resumable }) => {
                                tracing::warn!(resumable, "session invalidated");
                                if !resumable {
                                    session = None;
                                }
                                break;
                            }
                            Err(GatewayError::Closed { code, reason }) => {
                                tracing::warn!(code, %reason, "gateway closed");
                                // Save session for resume on recoverable closes
                                session = gw.session_state();
                                break;
                            }
                            Err(e) => {
                                tracing::error!(?e, "gateway error");
                                session = gw.session_state();
                                break;
                            }
                        }
                    }

                    // Update session state if not already cleared
                    if session.is_some() || gw.session_state().is_some() {
                        session = session.or_else(|| gw.session_state());
                    }
                }
                Err(GatewayError::InvalidSession { .. }) => {
                    tracing::warn!("resume rejected, falling back to fresh identify");
                    session = None;
                }
                Err(e) => {
                    tracing::error!(?e, "failed to connect to gateway");
                }
            }

            let wait = std::time::Duration::from_secs(backoff);
            tracing::info!(?wait, "reconnecting");
            tokio::time::sleep(wait).await;
            backoff = (backoff * 2).min(60);
        }
    }

    /// Resolve the gateway WebSocket URL: if the user explicitly set one via
    /// `ClientBuilder::gateway_url`, honour it; otherwise query
    /// `GET /gateway/bot` and log the session-start limits. Falls back to
    /// [`DEFAULT_GATEWAY_URL`] on discovery failure so bots can still connect
    /// in degraded networks.
    ///
    /// TODO(sharding): `GatewayBotInfo::shards` and `max_concurrency` are
    /// currently logged but not used. A future shard manager should spawn
    /// `shards` connections identifying with `[shard_id, shard_count]`.
    async fn resolve_gateway_url(&self) -> String {
        if let Some(url) = &self.gateway_url_override {
            tracing::debug!(%url, "using gateway URL override");
            return url.clone();
        }
        match self.http.get_gateway_bot().await {
            Ok(info) => {
                tracing::info!(
                    url = %info.url,
                    shards = info.shards,
                    session_total = info.session_start_limit.total,
                    session_remaining = info.session_start_limit.remaining,
                    session_reset_after_ms = info.session_start_limit.reset_after,
                    max_concurrency = info.session_start_limit.max_concurrency,
                    "discovered gateway URL",
                );
                info.url
            }
            Err(e) => {
                tracing::warn!(
                    error = ?e,
                    fallback = DEFAULT_GATEWAY_URL,
                    "get_gateway_bot failed, falling back to default URL",
                );
                DEFAULT_GATEWAY_URL.to_string()
            }
        }
    }

    async fn dispatch(&self, ctx: Context, event: GatewayEvent) {
        let handler = self.handler.clone();
        let cache = self.cache.clone();
        let permit = match self.event_semaphore.clone().acquire_owned().await {
            Ok(permit) => permit,
            Err(_) => {
                tracing::error!("event dispatch semaphore closed");
                return;
            }
        };

        tokio::spawn(async move {
            let _permit = permit;
            let result = AssertUnwindSafe(Self::handle_event(handler, cache, ctx, event))
                .catch_unwind()
                .await;
            if result.is_err() {
                tracing::error!("event handler panicked");
            }
        });
    }

    async fn handle_event(
        handler: Arc<dyn EventHandler>,
        cache: Cache,
        ctx: Context,
        event: GatewayEvent,
    ) {
        match event {
            GatewayEvent::Ready(ready) => {
                cache.set_current_user(ready.user.clone()).await;
                handler.ready(ctx, ready.user).await;
            }
            GatewayEvent::Resumed(_) => {
                tracing::info!("session resumed successfully");
                handler.resumed(ctx).await;
            }
            GatewayEvent::MessageCreate(msg) => handler.message_create(ctx, msg).await,
            GatewayEvent::MessageUpdate(msg) => handler.message_update(ctx, msg).await,
            GatewayEvent::MessageDelete(p) => handler.message_delete(ctx, p.channel_id, p.id).await,
            GatewayEvent::GuildCreate(guild) => {
                cache.insert_guild(guild.clone()).await;
                handler.guild_create(ctx, guild).await;
            }
            GatewayEvent::GuildUpdate(guild) => {
                cache.insert_guild(guild.clone()).await;
                handler.guild_update(ctx, guild).await;
            }
            GatewayEvent::GuildDelete(p) => {
                cache.remove_guild(p.id).await;
                handler.guild_delete(ctx, p.id).await;
            }
            GatewayEvent::GuildMemberAdd(p) => {
                handler.guild_member_add(ctx, p.guild_id, p.member).await
            }
            GatewayEvent::GuildMemberRemove(p) => {
                handler.guild_member_remove(ctx, p.guild_id, p.user).await
            }
            GatewayEvent::ChannelCreate(ch) => {
                cache.insert_channel(ch.clone()).await;
                handler.channel_create(ctx, ch).await;
            }
            GatewayEvent::ChannelUpdate(ch) => {
                cache.insert_channel(ch.clone()).await;
                handler.channel_update(ctx, ch).await;
            }
            GatewayEvent::ChannelDelete(ch) => {
                cache.remove_channel(ch.id).await;
                handler.channel_delete(ctx, ch).await;
            }
            GatewayEvent::TypingStart(p) => {
                handler.typing_start(ctx, p.channel_id, p.user_id).await
            }
            GatewayEvent::UserUpdate(user) => handler.user_update(ctx, user).await,
            GatewayEvent::MessageDeleteBulk(p) => handler.message_delete_bulk(ctx, p).await,
            GatewayEvent::MessageReactionAdd(p) => handler.message_reaction_add(ctx, p).await,
            GatewayEvent::MessageReactionRemove(p) => handler.message_reaction_remove(ctx, p).await,
            GatewayEvent::MessageReactionRemoveAll(p) => {
                handler.message_reaction_remove_all(ctx, p).await
            }
            GatewayEvent::MessageReactionRemoveEmoji(p) => {
                handler.message_reaction_remove_emoji(ctx, p).await
            }
            GatewayEvent::GuildMemberUpdate(p) => handler.guild_member_update(ctx, p).await,
            GatewayEvent::GuildRoleCreate(p) => handler.guild_role_create(ctx, p).await,
            GatewayEvent::GuildRoleUpdate(p) => handler.guild_role_update(ctx, p).await,
            GatewayEvent::GuildRoleDelete(p) => handler.guild_role_delete(ctx, p).await,
            GatewayEvent::GuildBanAdd(p) => handler.guild_ban_add(ctx, p).await,
            GatewayEvent::GuildBanRemove(p) => handler.guild_ban_remove(ctx, p).await,
            GatewayEvent::ChannelPinsUpdate(p) => handler.channel_pins_update(ctx, p).await,
            GatewayEvent::InviteCreate(p) => handler.invite_create(ctx, p).await,
            GatewayEvent::InviteDelete(p) => handler.invite_delete(ctx, p).await,
            GatewayEvent::WebhooksUpdate(p) => handler.webhooks_update(ctx, p).await,
            GatewayEvent::GuildEmojisUpdate(p) => handler.guild_emojis_update(ctx, p).await,
            GatewayEvent::GuildStickersUpdate(p) => handler.guild_stickers_update(ctx, p).await,
            GatewayEvent::Unknown { .. } => {}
        }
    }
}

/// Builder for [`Client`]; obtain one with [`Client::builder`].
///
/// A token and an event handler are required; everything else has defaults.
pub struct ClientBuilder {
    /// Bot token; `build` fails with `Error::MissingToken` if unset.
    token: Option<String>,
    /// Optional gateway URL override; `None` means discover it at startup.
    gateway_url: Option<String>,
    /// Event handler; `build` fails with `Error::MissingEventHandler` if unset.
    handler: Option<Arc<dyn EventHandler>>,
    /// Optional base URL forwarded to the HTTP client builder.
    base_url: Option<String>,
    /// Forwarded to the HTTP client builder's `auto_retry` (default `true`).
    auto_retry: bool,
    /// Permit count for the event-dispatch semaphore.
    max_concurrent_events: usize,
}

impl Default for ClientBuilder {
    fn default() -> Self {
        Self {
            token: None,
            gateway_url: None,
            handler: None,
            base_url: None,
            auto_retry: true,
            max_concurrent_events: DEFAULT_MAX_CONCURRENT_EVENTS,
        }
    }
}

impl ClientBuilder {
    /// Sets the bot token used for the HTTP API and the gateway. Required.
    pub fn token(mut self, token: &str) -> Self {
        self.token = Some(token.into());
        self
    }
    /// Override the Gateway URL. If unset, the client calls `GET /gateway/bot`
    /// on startup to discover it.
    pub fn gateway_url(mut self, url: &str) -> Self {
        self.gateway_url = Some(url.into());
        self
    }
    /// Overrides the base URL of the HTTP API client.
    pub fn base_url(mut self, url: &str) -> Self {
        self.base_url = Some(url.into());
        self
    }
    /// Enables or disables automatic retries in the HTTP client (on by default).
    pub fn auto_retry(mut self, enabled: bool) -> Self {
        self.auto_retry = enabled;
        self
    }
    /// Caps how many event-handler tasks may run concurrently (default: 1024).
    ///
    /// The value is clamped to `1..=Semaphore::MAX_PERMITS`, so passing `0`
    /// behaves like `1` and passing `usize::MAX` means "as many as possible".
    pub fn max_concurrent_events(mut self, limit: usize) -> Self {
        // `Semaphore::new` panics above `MAX_PERMITS`; clamp here so `build`
        // cannot panic on large values.
        self.max_concurrent_events = limit.clamp(1, Semaphore::MAX_PERMITS);
        self
    }
    /// Sets the handler that receives gateway events. Required.
    pub fn event_handler(mut self, handler: impl EventHandler) -> Self {
        self.handler = Some(Arc::new(handler));
        self
    }
    /// Builds the [`Client`].
    ///
    /// # Errors
    ///
    /// - `Error::MissingToken` if no token was set.
    /// - `Error::MissingEventHandler` if no event handler was set.
    /// - Any error returned while building the HTTP client.
    // The crate-wide `Error` type trips `result_large_err`; boxing it would
    // change this public signature.
    #[allow(clippy::result_large_err)]
    pub fn build(self) -> Result<Client, Error> {
        // Validate the cheap required fields before doing any real work.
        let token = self.token.ok_or(Error::MissingToken)?;
        let handler = self.handler.ok_or(Error::MissingEventHandler)?;

        let mut http_builder = HttpClient::builder()
            .token(&token)
            .auto_retry(self.auto_retry);
        if let Some(base_url) = &self.base_url {
            http_builder = http_builder.base_url(base_url);
        }
        let http = http_builder.build()?;

        Ok(Client {
            token,
            gateway_url_override: self.gateway_url,
            handler,
            http,
            cache: Cache::new(),
            event_semaphore: Arc::new(Semaphore::new(self.max_concurrent_events)),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Handler relying entirely on the trait's default methods.
    struct NoopHandler;

    impl EventHandler for NoopHandler {}

    /// A builder with a handler but no token must be rejected.
    #[test]
    fn builder_requires_token() {
        assert!(matches!(
            Client::builder().event_handler(NoopHandler).build(),
            Err(Error::MissingToken)
        ));
    }

    /// A builder with a token but no handler must be rejected.
    #[test]
    fn builder_requires_event_handler() {
        assert!(matches!(
            Client::builder().token("test-token").build(),
            Err(Error::MissingEventHandler)
        ));
    }
}