pushwire-client 0.1.1

Generic multiplexed push client with WebSocket and SSE transports
Documentation
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
use std::time::Duration;

use dashmap::DashMap;
use pushwire_core::{ChannelKind, Frame, SystemOp};
use tokio::sync::{Notify, mpsc};
use tracing::{debug, info, warn};
use uuid::Uuid;

use crate::connection::{ActiveTransport, InboundMsg, connect_with_preference};
use crate::cursor::{CursorResult, CursorTracker};
use crate::dispatch::ChannelReceiver;
use crate::reconnect::ReconnectPolicy;
use crate::subscription::SubscriptionTracker;

pub use crate::connection::TransportPreference;

// ---------------------------------------------------------------------------
// Connection state
// ---------------------------------------------------------------------------

/// Connection state machine.
///
/// The discriminants are stored in an `AtomicU8` shared between the client
/// and its background processor task, and decoded again by
/// `PushClient::state`, so the numeric values must stay in sync with that
/// decoder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ConnectionState {
    /// No live connection: the initial state, the state after
    /// `disconnect()`, and the state the processor stores when the
    /// transport reports closure.
    Disconnected = 0,
    /// `connect()` has been called and the handshake is in progress.
    Connecting = 1,
    /// The initial connect or a background reconnect succeeded.
    Connected = 2,
    /// The transport closed and the processor is backing off / attempting
    /// to reconnect with the exported cursors.
    Resuming = 3,
}

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Configuration for a push-wire client connection.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must build it via
/// [`ClientConfig::new`] and then adjust the public fields.
#[non_exhaustive]
pub struct ClientConfig {
    /// Server URL handed to the transport layer when connecting.
    pub url: String,
    /// Identifier sent with every connect / reconnect attempt.
    /// `ClientConfig::new` generates a random v4 UUID.
    pub client_id: Uuid,
    /// Optional auth token forwarded to the transport on connect.
    pub token: Option<String>,
    /// Retry / backoff policy consulted by the background processor when
    /// the transport closes.
    pub reconnect: ReconnectPolicy,
    /// Which transport to try first; `ClientConfig::new` uses `WsFirst`.
    pub transport_preference: TransportPreference,
    /// Defaults to `false`.
    // NOTE(review): not read anywhere in this file — presumably selects a
    // binary wire encoding in the transport layer; confirm against callers.
    pub binary_mode: bool,
}

impl ClientConfig {
    /// Builds a configuration for `url` with defaults for everything else:
    /// a freshly generated v4 client id, no auth token, the default
    /// reconnect policy, WebSocket-first transport and `binary_mode` off.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        let client_id = Uuid::new_v4();
        let reconnect = ReconnectPolicy::default();

        Self {
            url,
            client_id,
            token: None,
            reconnect,
            transport_preference: TransportPreference::WsFirst,
            binary_mode: false,
        }
    }
}

// ---------------------------------------------------------------------------
// Errors
// ---------------------------------------------------------------------------

/// Error returned when establishing a connection fails
/// (see `PushClient::connect`).
///
/// The variants are constructed outside this file; the descriptions below
/// follow the variant names and `#[error]` messages.
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// The underlying transport failed; the payload is a human-readable
    /// description.
    #[error("transport error: {0}")]
    Transport(String),
    /// Authentication was rejected; the payload is the reported reason.
    // NOTE(review): presumably produced during the auth handshake — confirm
    // in the connection module.
    #[error("auth rejected: {0}")]
    AuthRejected(String),
    /// The connection attempt timed out.
    #[error("timeout")]
    Timeout,
}

/// Error returned by outbound operations on the client (`send`,
/// `subscribe`, `unsubscribe`, `disconnect`).
#[derive(Debug, thiserror::Error)]
pub enum SendError {
    /// The client is not in the `Connected` state or has no transport
    /// handle (this is what `PushClient::send` returns in those cases).
    #[error("not connected")]
    NotConnected,
    /// The outbound channel is closed.
    // NOTE(review): constructed outside this file, presumably by the
    // transport when its writer side has gone away — confirm.
    #[error("channel closed")]
    ChannelClosed,
    /// Serializing the outgoing message failed; converts automatically from
    /// `serde_json::Error` via `#[from]`.
    #[error("serialization error: {0}")]
    Serialize(#[from] serde_json::Error),
}

// ---------------------------------------------------------------------------
// PushClient
// ---------------------------------------------------------------------------

/// Generic multiplexed push client.
///
/// Parameterized by `C: ChannelKind` — the consumer defines their own channel
/// taxonomy. Register handlers with [`on`](PushClient::on), then call
/// [`connect`](PushClient::connect) to start receiving frames.
pub struct PushClient<C: ChannelKind> {
    /// Connection settings; also copied into the background processor so it
    /// can reconnect on its own.
    config: ClientConfig,
    /// Per-channel cursor positions, advanced by the processor and exported
    /// as resume points on connect / reconnect.
    cursors: Arc<CursorTracker<C>>,
    /// Channel → handler map consulted by the processor for every inbound
    /// frame.
    receivers: Arc<DashMap<C, Arc<dyn ChannelReceiver<C>>>>,
    /// Set of channels the client wants; its `active()` list is sent as the
    /// capabilities on connect / reconnect.
    subscriptions: Arc<SubscriptionTracker<C>>,
    /// Current `ConnectionState` stored as its `u8` discriminant, shared
    /// with the processor task.
    state: Arc<AtomicU8>,
    /// Outbound handle from the most recent `connect()`; `None` before the
    /// first connect and after `disconnect()`.
    transport: Option<ActiveTransport<C>>,
    /// Signalled by `disconnect()` to ask the processor loop to exit.
    shutdown: Arc<Notify>,
    /// Handle of the spawned processor task, aborted on `disconnect()`.
    processor_handle: Option<tokio::task::JoinHandle<()>>,
}

impl<C: ChannelKind> PushClient<C> {
    pub fn new(config: ClientConfig) -> Self {
        Self {
            config,
            cursors: Arc::new(CursorTracker::new()),
            receivers: Arc::new(DashMap::new()),
            subscriptions: Arc::new(SubscriptionTracker::new()),
            state: Arc::new(AtomicU8::new(ConnectionState::Disconnected as u8)),
            transport: None,
            shutdown: Arc::new(Notify::new()),
            processor_handle: None,
        }
    }

    /// Register a handler for a channel. Must be called before `connect()`.
    pub fn on(&mut self, channel: C, receiver: impl ChannelReceiver<C>) {
        self.subscriptions.subscribe(&[channel]);
        self.receivers.insert(channel, Arc::new(receiver));
    }

    /// Connect to the server. Performs auth handshake and starts the
    /// receive loop.
    pub async fn connect(&mut self) -> Result<(), ConnectError> {
        self.set_state(ConnectionState::Connecting);

        let capabilities = self.subscriptions.active();
        let resume_cursors = self.cursors.export();

        let (transport, inbound_rx) = connect_with_preference(
            self.config.transport_preference,
            &self.config.url,
            self.config.client_id,
            self.config.token.as_deref(),
            &capabilities,
            resume_cursors,
        )
        .await?;

        self.transport = Some(transport);
        self.set_state(ConnectionState::Connected);

        // Spawn the processor task that dispatches inbound messages.
        self.spawn_processor(inbound_rx);

        info!(client_id = ?self.config.client_id, "connected");
        Ok(())
    }

    /// Send a frame to the server (client → server).
    pub async fn send(&self, frame: Frame<C>) -> Result<(), SendError> {
        if self.state() != ConnectionState::Connected {
            return Err(SendError::NotConnected);
        }
        match &self.transport {
            Some(t) => t.send_frame(frame).await,
            None => Err(SendError::NotConnected),
        }
    }

    /// Subscribe to additional channels after connect.
    pub async fn subscribe(&self, channels: &[C]) -> Result<(), SendError> {
        if let Some(op) = self.subscriptions.subscribe(channels)
            && let Some(t) = &self.transport
        {
            t.send_system(op).await?;
        }
        Ok(())
    }

    /// Unsubscribe from channels.
    pub async fn unsubscribe(&self, channels: &[C]) -> Result<(), SendError> {
        if let Some(op) = self.subscriptions.unsubscribe(channels)
            && let Some(t) = &self.transport
        {
            t.send_system(op).await?;
        }
        Ok(())
    }

    /// Graceful disconnect.
    pub async fn disconnect(&mut self) -> Result<(), SendError> {
        self.shutdown.notify_waiters();

        if let Some(t) = &self.transport {
            let _ = t.send_system(SystemOp::Goodbye { reason: None }).await;
        }

        if let Some(transport) = self.transport.take() {
            transport.close().await;
        }

        if let Some(handle) = self.processor_handle.take() {
            handle.abort();
        }

        self.set_state(ConnectionState::Disconnected);
        info!(client_id = ?self.config.client_id, "disconnected");
        Ok(())
    }

    /// Current connection state.
    pub fn state(&self) -> ConnectionState {
        match self.state.load(Ordering::SeqCst) {
            0 => ConnectionState::Disconnected,
            1 => ConnectionState::Connecting,
            2 => ConnectionState::Connected,
            3 => ConnectionState::Resuming,
            _ => ConnectionState::Disconnected,
        }
    }

    /// Per-channel cursor values (for diagnostics / resume).
    pub fn cursors(&self) -> HashMap<C, u64> {
        self.cursors.export()
    }

    // -----------------------------------------------------------------------
    // Internal
    // -----------------------------------------------------------------------

    fn set_state(&self, state: ConnectionState) {
        self.state.store(state as u8, Ordering::SeqCst);
    }

    fn spawn_processor(&mut self, mut inbound_rx: mpsc::Receiver<InboundMsg<C>>) {
        let cursors = self.cursors.clone();
        let receivers = self.receivers.clone();
        let state = self.state.clone();
        let shutdown = self.shutdown.clone();

        // Reconnect state — shared with the processor so it can trigger
        // reconnection on transport close.
        let reconnect_policy = self.config.reconnect.clone();
        let url = self.config.url.clone();
        let client_id = self.config.client_id;
        let token = self.config.token.clone();
        let transport_pref = self.config.transport_preference;
        let subscriptions = self.subscriptions.clone();
        let attempt_count = Arc::new(AtomicU32::new(0));

        self.processor_handle = Some(tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = shutdown.notified() => {
                        debug!("processor: shutdown signal received");
                        break;
                    }
                    msg = inbound_rx.recv() => {
                        match msg {
                            Some(InboundMsg::Frame(frame)) => {
                                // Track cursor and send ACK.
                                if let Some(cursor) = frame.cursor {
                                    let result = cursors.advance(frame.channel, cursor);
                                    if let CursorResult::GapDetected { expected, got } = result {
                                        warn!(
                                            channel = frame.channel.name(),
                                            expected, got,
                                            "cursor gap detected"
                                        );
                                    }
                                }

                                // Dispatch to registered receiver.
                                if let Some(receiver) = receivers.get(&frame.channel) {
                                    receiver.on_frame(frame);
                                } else {
                                    debug!(
                                        channel = frame.channel.name(),
                                        "no receiver for channel, dropping"
                                    );
                                }

                                // Reset reconnect attempt counter on successful data.
                                attempt_count.store(0, Ordering::SeqCst);
                            }
                            Some(InboundMsg::System(op)) => {
                                handle_system_op(&op);
                                attempt_count.store(0, Ordering::SeqCst);
                            }
                            Some(InboundMsg::Closed) | None => {
                                info!("transport closed");
                                state.store(
                                    ConnectionState::Disconnected as u8,
                                    Ordering::SeqCst,
                                );

                                // Attempt reconnect.
                                let attempts = attempt_count.load(Ordering::SeqCst);
                                if !reconnect_policy.should_retry(attempts) {
                                    info!("reconnect exhausted, staying disconnected");
                                    break;
                                }

                                state.store(
                                    ConnectionState::Resuming as u8,
                                    Ordering::SeqCst,
                                );

                                let delay = reconnect_policy.delay_for_attempt(attempts);
                                let jittered = if reconnect_policy.jitter {
                                    add_jitter(delay)
                                } else {
                                    delay
                                };
                                info!(
                                    attempt = attempts + 1,
                                    delay_ms = jittered.as_millis(),
                                    "reconnecting"
                                );
                                tokio::time::sleep(jittered).await;

                                let capabilities = subscriptions.active();
                                let resume = cursors.export();

                                match connect_with_preference(
                                    transport_pref,
                                    &url,
                                    client_id,
                                    token.as_deref(),
                                    &capabilities,
                                    resume,
                                )
                                .await
                                {
                                    Ok((_transport, new_rx)) => {
                                        // Reconnected — swap the inbound receiver
                                        // and continue processing. The transport
                                        // handle is dropped here (the spawned tasks
                                        // keep running via their JoinHandles).
                                        inbound_rx = new_rx;
                                        attempt_count.store(0, Ordering::SeqCst);
                                        state.store(
                                            ConnectionState::Connected as u8,
                                            Ordering::SeqCst,
                                        );
                                        info!("reconnected successfully");
                                    }
                                    Err(e) => {
                                        warn!(?e, "reconnect failed");
                                        attempt_count.fetch_add(1, Ordering::SeqCst);
                                        // Loop back — will hit Closed/None again
                                        // immediately and retry.
                                        inbound_rx.close();
                                        continue;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }));
    }
}

fn handle_system_op<C: ChannelKind>(op: &SystemOp<C>) {
    match op {
        SystemOp::Ping => {
            // Server pings are handled at the transport level for WebSocket
            // (tungstenite auto-responds). Application-level Ping is logged.
            debug!("received application-level Ping");
        }
        SystemOp::Pong => {
            debug!("received Pong");
        }
        SystemOp::Error { message } => {
            warn!(message, "server error");
        }
        SystemOp::ResumeRequired {
            channel,
            from_cursor,
        } => {
            warn!(
                channel = channel.name(),
                from_cursor, "server requires full resync from cursor"
            );
        }
        SystemOp::Goodbye { reason } => {
            info!(?reason, "server goodbye");
        }
        SystemOp::Health { status, detail } => {
            debug!(?status, ?detail, "server health");
        }
        other => {
            debug!(?other, "unhandled system op");
        }
    }
}

/// Returns `delay` shifted by a uniformly random offset of up to ±25%.
///
/// The computation runs at millisecond resolution, so the result is a whole
/// number of milliseconds. A `delay` shorter than one millisecond
/// (including zero) has no room for jitter and is returned unchanged.
fn add_jitter(delay: Duration) -> Duration {
    use rand::Rng;

    let base_ms = delay.as_millis() as f64;
    let jitter_range = base_ms * 0.25;

    // `gen_range` panics on an empty range, and `-0.0..0.0` is empty. A
    // zero or sub-millisecond backoff would otherwise crash the processor
    // task in the middle of a reconnect.
    if jitter_range <= 0.0 {
        return delay;
    }

    let jitter = rand::thread_rng().gen_range(-jitter_range..jitter_range);
    let ms = (base_ms + jitter).max(0.0);
    Duration::from_millis(ms as u64)
}