Skip to main content

pushwire_client/
session.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
4use std::time::Duration;
5
6use dashmap::DashMap;
7use pushwire_core::{ChannelKind, Frame, SystemOp};
8use tokio::sync::{Notify, mpsc};
9use tracing::{debug, info, warn};
10use uuid::Uuid;
11
12use crate::connection::{ActiveTransport, InboundMsg, connect_with_preference};
13use crate::cursor::{CursorResult, CursorTracker};
14use crate::dispatch::ChannelReceiver;
15use crate::reconnect::ReconnectPolicy;
16use crate::subscription::SubscriptionTracker;
17
18pub use crate::connection::TransportPreference;
19
20// ---------------------------------------------------------------------------
21// Connection state
22// ---------------------------------------------------------------------------
23
/// Lifecycle of the client's connection, as observed through
/// `PushClient::state`.
///
/// The discriminants are fixed (`repr(u8)`) because the value is stored in an
/// `AtomicU8` and decoded back by numeric value.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No live transport; the initial state and the state after a disconnect.
    Disconnected = 0,
    /// An initial connection attempt is in progress.
    Connecting = 1,
    /// The handshake succeeded and frames can be sent.
    Connected = 2,
    /// The transport closed and a reconnect attempt is under way.
    Resuming = 3,
}
33
34// ---------------------------------------------------------------------------
35// Configuration
36// ---------------------------------------------------------------------------
37
38/// Configuration for a push-wire client connection.
39#[non_exhaustive]
40pub struct ClientConfig {
41    pub url: String,
42    pub client_id: Uuid,
43    pub token: Option<String>,
44    pub reconnect: ReconnectPolicy,
45    pub transport_preference: TransportPreference,
46    pub binary_mode: bool,
47}
48
49impl ClientConfig {
50    pub fn new(url: impl Into<String>) -> Self {
51        Self {
52            url: url.into(),
53            client_id: Uuid::new_v4(),
54            token: None,
55            reconnect: ReconnectPolicy::default(),
56            transport_preference: TransportPreference::WsFirst,
57            binary_mode: false,
58        }
59    }
60}
61
62// ---------------------------------------------------------------------------
63// Errors
64// ---------------------------------------------------------------------------
65
66/// Error types for client operations.
67#[derive(Debug, thiserror::Error)]
68pub enum ConnectError {
69    #[error("transport error: {0}")]
70    Transport(String),
71    #[error("auth rejected: {0}")]
72    AuthRejected(String),
73    #[error("timeout")]
74    Timeout,
75}
76
77#[derive(Debug, thiserror::Error)]
78pub enum SendError {
79    #[error("not connected")]
80    NotConnected,
81    #[error("channel closed")]
82    ChannelClosed,
83    #[error("serialization error: {0}")]
84    Serialize(#[from] serde_json::Error),
85}
86
87// ---------------------------------------------------------------------------
88// PushClient
89// ---------------------------------------------------------------------------
90
91/// Generic multiplexed push client.
92///
93/// Parameterized by `C: ChannelKind` — the consumer defines their own channel
94/// taxonomy. Register handlers with [`on`](PushClient::on), then call
95/// [`connect`](PushClient::connect) to start receiving frames.
96pub struct PushClient<C: ChannelKind> {
97    config: ClientConfig,
98    cursors: Arc<CursorTracker<C>>,
99    receivers: Arc<DashMap<C, Arc<dyn ChannelReceiver<C>>>>,
100    subscriptions: Arc<SubscriptionTracker<C>>,
101    state: Arc<AtomicU8>,
102    transport: Option<ActiveTransport<C>>,
103    shutdown: Arc<Notify>,
104    processor_handle: Option<tokio::task::JoinHandle<()>>,
105}
106
107impl<C: ChannelKind> PushClient<C> {
108    pub fn new(config: ClientConfig) -> Self {
109        Self {
110            config,
111            cursors: Arc::new(CursorTracker::new()),
112            receivers: Arc::new(DashMap::new()),
113            subscriptions: Arc::new(SubscriptionTracker::new()),
114            state: Arc::new(AtomicU8::new(ConnectionState::Disconnected as u8)),
115            transport: None,
116            shutdown: Arc::new(Notify::new()),
117            processor_handle: None,
118        }
119    }
120
121    /// Register a handler for a channel. Must be called before `connect()`.
122    pub fn on(&mut self, channel: C, receiver: impl ChannelReceiver<C>) {
123        self.subscriptions.subscribe(&[channel]);
124        self.receivers.insert(channel, Arc::new(receiver));
125    }
126
127    /// Connect to the server. Performs auth handshake and starts the
128    /// receive loop.
129    pub async fn connect(&mut self) -> Result<(), ConnectError> {
130        self.set_state(ConnectionState::Connecting);
131
132        let capabilities = self.subscriptions.active();
133        let resume_cursors = self.cursors.export();
134
135        let (transport, inbound_rx) = connect_with_preference(
136            self.config.transport_preference,
137            &self.config.url,
138            self.config.client_id,
139            self.config.token.as_deref(),
140            &capabilities,
141            resume_cursors,
142        )
143        .await?;
144
145        self.transport = Some(transport);
146        self.set_state(ConnectionState::Connected);
147
148        // Spawn the processor task that dispatches inbound messages.
149        self.spawn_processor(inbound_rx);
150
151        info!(client_id = ?self.config.client_id, "connected");
152        Ok(())
153    }
154
155    /// Send a frame to the server (client → server).
156    pub async fn send(&self, frame: Frame<C>) -> Result<(), SendError> {
157        if self.state() != ConnectionState::Connected {
158            return Err(SendError::NotConnected);
159        }
160        match &self.transport {
161            Some(t) => t.send_frame(frame).await,
162            None => Err(SendError::NotConnected),
163        }
164    }
165
166    /// Subscribe to additional channels after connect.
167    pub async fn subscribe(&self, channels: &[C]) -> Result<(), SendError> {
168        if let Some(op) = self.subscriptions.subscribe(channels)
169            && let Some(t) = &self.transport
170        {
171            t.send_system(op).await?;
172        }
173        Ok(())
174    }
175
176    /// Unsubscribe from channels.
177    pub async fn unsubscribe(&self, channels: &[C]) -> Result<(), SendError> {
178        if let Some(op) = self.subscriptions.unsubscribe(channels)
179            && let Some(t) = &self.transport
180        {
181            t.send_system(op).await?;
182        }
183        Ok(())
184    }
185
186    /// Graceful disconnect.
187    pub async fn disconnect(&mut self) -> Result<(), SendError> {
188        self.shutdown.notify_waiters();
189
190        if let Some(t) = &self.transport {
191            let _ = t.send_system(SystemOp::Goodbye { reason: None }).await;
192        }
193
194        if let Some(transport) = self.transport.take() {
195            transport.close().await;
196        }
197
198        if let Some(handle) = self.processor_handle.take() {
199            handle.abort();
200        }
201
202        self.set_state(ConnectionState::Disconnected);
203        info!(client_id = ?self.config.client_id, "disconnected");
204        Ok(())
205    }
206
207    /// Current connection state.
208    pub fn state(&self) -> ConnectionState {
209        match self.state.load(Ordering::SeqCst) {
210            0 => ConnectionState::Disconnected,
211            1 => ConnectionState::Connecting,
212            2 => ConnectionState::Connected,
213            3 => ConnectionState::Resuming,
214            _ => ConnectionState::Disconnected,
215        }
216    }
217
218    /// Per-channel cursor values (for diagnostics / resume).
219    pub fn cursors(&self) -> HashMap<C, u64> {
220        self.cursors.export()
221    }
222
223    // -----------------------------------------------------------------------
224    // Internal
225    // -----------------------------------------------------------------------
226
227    fn set_state(&self, state: ConnectionState) {
228        self.state.store(state as u8, Ordering::SeqCst);
229    }
230
231    fn spawn_processor(&mut self, mut inbound_rx: mpsc::Receiver<InboundMsg<C>>) {
232        let cursors = self.cursors.clone();
233        let receivers = self.receivers.clone();
234        let state = self.state.clone();
235        let shutdown = self.shutdown.clone();
236
237        // Reconnect state — shared with the processor so it can trigger
238        // reconnection on transport close.
239        let reconnect_policy = self.config.reconnect.clone();
240        let url = self.config.url.clone();
241        let client_id = self.config.client_id;
242        let token = self.config.token.clone();
243        let transport_pref = self.config.transport_preference;
244        let subscriptions = self.subscriptions.clone();
245        let attempt_count = Arc::new(AtomicU32::new(0));
246
247        self.processor_handle = Some(tokio::spawn(async move {
248            loop {
249                tokio::select! {
250                    _ = shutdown.notified() => {
251                        debug!("processor: shutdown signal received");
252                        break;
253                    }
254                    msg = inbound_rx.recv() => {
255                        match msg {
256                            Some(InboundMsg::Frame(frame)) => {
257                                // Track cursor and send ACK.
258                                if let Some(cursor) = frame.cursor {
259                                    let result = cursors.advance(frame.channel, cursor);
260                                    if let CursorResult::GapDetected { expected, got } = result {
261                                        warn!(
262                                            channel = frame.channel.name(),
263                                            expected, got,
264                                            "cursor gap detected"
265                                        );
266                                    }
267                                }
268
269                                // Dispatch to registered receiver.
270                                if let Some(receiver) = receivers.get(&frame.channel) {
271                                    receiver.on_frame(frame);
272                                } else {
273                                    debug!(
274                                        channel = frame.channel.name(),
275                                        "no receiver for channel, dropping"
276                                    );
277                                }
278
279                                // Reset reconnect attempt counter on successful data.
280                                attempt_count.store(0, Ordering::SeqCst);
281                            }
282                            Some(InboundMsg::System(op)) => {
283                                handle_system_op(&op);
284                                attempt_count.store(0, Ordering::SeqCst);
285                            }
286                            Some(InboundMsg::Closed) | None => {
287                                info!("transport closed");
288                                state.store(
289                                    ConnectionState::Disconnected as u8,
290                                    Ordering::SeqCst,
291                                );
292
293                                // Attempt reconnect.
294                                let attempts = attempt_count.load(Ordering::SeqCst);
295                                if !reconnect_policy.should_retry(attempts) {
296                                    info!("reconnect exhausted, staying disconnected");
297                                    break;
298                                }
299
300                                state.store(
301                                    ConnectionState::Resuming as u8,
302                                    Ordering::SeqCst,
303                                );
304
305                                let delay = reconnect_policy.delay_for_attempt(attempts);
306                                let jittered = if reconnect_policy.jitter {
307                                    add_jitter(delay)
308                                } else {
309                                    delay
310                                };
311                                info!(
312                                    attempt = attempts + 1,
313                                    delay_ms = jittered.as_millis(),
314                                    "reconnecting"
315                                );
316                                tokio::time::sleep(jittered).await;
317
318                                let capabilities = subscriptions.active();
319                                let resume = cursors.export();
320
321                                match connect_with_preference(
322                                    transport_pref,
323                                    &url,
324                                    client_id,
325                                    token.as_deref(),
326                                    &capabilities,
327                                    resume,
328                                )
329                                .await
330                                {
331                                    Ok((_transport, new_rx)) => {
332                                        // Reconnected — swap the inbound receiver
333                                        // and continue processing. The transport
334                                        // handle is dropped here (the spawned tasks
335                                        // keep running via their JoinHandles).
336                                        inbound_rx = new_rx;
337                                        attempt_count.store(0, Ordering::SeqCst);
338                                        state.store(
339                                            ConnectionState::Connected as u8,
340                                            Ordering::SeqCst,
341                                        );
342                                        info!("reconnected successfully");
343                                    }
344                                    Err(e) => {
345                                        warn!(?e, "reconnect failed");
346                                        attempt_count.fetch_add(1, Ordering::SeqCst);
347                                        // Loop back — will hit Closed/None again
348                                        // immediately and retry.
349                                        inbound_rx.close();
350                                        continue;
351                                    }
352                                }
353                            }
354                        }
355                    }
356                }
357            }
358        }));
359    }
360}
361
362fn handle_system_op<C: ChannelKind>(op: &SystemOp<C>) {
363    match op {
364        SystemOp::Ping => {
365            // Server pings are handled at the transport level for WebSocket
366            // (tungstenite auto-responds). Application-level Ping is logged.
367            debug!("received application-level Ping");
368        }
369        SystemOp::Pong => {
370            debug!("received Pong");
371        }
372        SystemOp::Error { message } => {
373            warn!(message, "server error");
374        }
375        SystemOp::ResumeRequired {
376            channel,
377            from_cursor,
378        } => {
379            warn!(
380                channel = channel.name(),
381                from_cursor, "server requires full resync from cursor"
382            );
383        }
384        SystemOp::Goodbye { reason } => {
385            info!(?reason, "server goodbye");
386        }
387        SystemOp::Health { status, detail } => {
388            debug!(?status, ?detail, "server health");
389        }
390        other => {
391            debug!(?other, "unhandled system op");
392        }
393    }
394}
395
396fn add_jitter(delay: Duration) -> Duration {
397    use rand::Rng;
398    let jitter_range = delay.as_millis() as f64 * 0.25;
399    let jitter = rand::thread_rng().gen_range(-jitter_range..jitter_range);
400    let ms = (delay.as_millis() as f64 + jitter).max(0.0);
401    Duration::from_millis(ms as u64)
402}