Skip to main content

zero_engine_client/
ws.rs

1//! WebSocket subscriber for the engine's `/ws` push surface.
2//!
3//! Subscribes to the engine's broadcast channel, decodes typed
4//! events, and folds them into an `EngineState` mirror. Handles
5//! reconnection with exponential backoff; the TUI status bar reads
6//! `EngineState::connection` to render a DEGRADED banner during
7//! partition.
8//!
9//! Mirrors `ConnectionManager.broadcast()` in the engine's FastAPI
10//! server — event shape is `{event: string, ts: iso8601, data: object}`.
11//! Unknown event kinds are preserved in [`EngineEvent::Unknown`] so
12//! the engine can evolve its push surface without breaking the
13//! subscriber.
14
15use std::sync::Arc;
16use std::time::Duration;
17
18use chrono::{DateTime, Utc};
19use futures::StreamExt;
20use parking_lot::RwLock;
21use serde::{Deserialize, Serialize};
22use tokio::sync::{broadcast, watch};
23use tokio::task::JoinHandle;
24use tokio_tungstenite::tungstenite;
25
26use crate::models::{Positions, Regime, Risk, V2Status};
27use crate::stat::Source;
28use crate::state::EngineState;
29
30/// Raw event shape pushed by the engine's bus poller.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32struct RawEvent {
33    event: String,
34    #[serde(default)]
35    ts: Option<String>,
36    #[serde(default)]
37    data: serde_json::Value,
38}
39
40/// Typed event the subscriber emits to consumers. Known events are
41/// decoded into strong types; anything else lands in
42/// [`EngineEvent::Unknown`] with the raw payload.
43#[derive(Debug, Clone)]
44pub enum EngineEvent {
45    Heartbeat(DateTime<Utc>),
46    Status(Box<V2Status>),
47    Positions(Box<Positions>),
48    Risk(Box<Risk>),
49    Regime(Box<Regime>),
50    Unknown {
51        event: String,
52        ts: DateTime<Utc>,
53        data: serde_json::Value,
54    },
55}
56
57/// Errors the subscriber can surface to its caller. Reconnectable
58/// errors are handled internally via backoff and never bubble out;
59/// only construction-time and shutdown errors reach the caller.
60#[derive(Debug, thiserror::Error)]
61pub enum WsError {
62    #[error("invalid websocket url: {0}")]
63    InvalidUrl(String),
64    #[error("subscriber shutdown failed: {0}")]
65    Shutdown(String),
66}
67
/// Strategy for randomizing a reconnect delay once its exponential
/// cap is known.
///
/// "Full jitter" — described by Marc Brooker on the AWS Architecture
/// Blog — sleeps for a random duration drawn from `[0, exp_backoff]`
/// instead of exactly `exp_backoff`. Clients knocked offline by the
/// same partition then spread their retries out rather than arriving
/// in synchronized waves, which keeps recovery quick even when the
/// partition heals for everyone at the same moment.
///
/// With one Zero CLI per operator the herd is small, but jitter costs
/// nothing and keeps behaviour aligned with hosted deployments.
#[derive(Default, Eq, PartialEq, Copy, Clone, Debug)]
pub enum JitterMode {
    /// No randomization: sleep for exactly the cap. Deterministic, so
    /// tests can assert exact reconnect timing.
    None,
    /// Sleep for a random duration in `[0, cap]` ("full jitter").
    /// This is the production default.
    #[default]
    Full,
}
90
91/// Configuration for the subscriber's reconnect behavior.
92///
93/// The backoff sequence is `min(initial * multiplier^attempt, max)`,
94/// then passed through [`JitterMode`] to produce the actual sleep
95/// duration. On a successful read (see `ReadOutcome::Connected`),
96/// the attempt counter resets to zero.
97#[derive(Debug, Clone, Copy)]
98pub struct ReconnectConfig {
99    pub initial_backoff: Duration,
100    pub max_backoff: Duration,
101    pub multiplier: u32,
102    pub jitter: JitterMode,
103}
104
105impl Default for ReconnectConfig {
106    fn default() -> Self {
107        Self {
108            initial_backoff: Duration::from_millis(500),
109            max_backoff: Duration::from_secs(30),
110            multiplier: 2,
111            jitter: JitterMode::default(),
112        }
113    }
114}
115
/// Compute the exponential-backoff *cap* for a given attempt count,
/// clamped to `max`.
///
/// `attempt` is 0-based: attempt 0 yields `initial`, attempt 1 yields
/// `initial * multiplier`, and so on. A `multiplier` of 0 is treated
/// as 1, so the cap then stays at `initial`. An `initial` larger than
/// `max` is clamped to `max`.
///
/// The arithmetic is done on [`Duration`] directly, so sub-millisecond
/// inputs keep full precision (a 500 µs `initial` doubles to 1 ms
/// instead of truncating to a zero delay). Overflow saturates at
/// `max` rather than wrapping or panicking, and the loop stops as
/// soon as the cap reaches `max`, so even a runaway `attempt` costs
/// at most about a hundred iterations.
///
/// Split out as a free function so tests can exercise the full
/// sequence without spinning up a subscriber.
#[must_use]
pub fn exp_backoff_cap(
    initial: Duration,
    max: Duration,
    multiplier: u32,
    attempt: u32,
) -> Duration {
    let mul = multiplier.max(1);
    let mut cap = initial.min(max);

    // With a factor of 1 or a zero base the value can never grow, so
    // return now instead of iterating `attempt` times for nothing.
    if mul == 1 || cap.is_zero() {
        return cap;
    }

    for _ in 0..attempt {
        if cap >= max {
            break;
        }
        // `checked_mul` yields `None` on overflow; that can only mean
        // we are past `max`, so pin to it.
        cap = cap.checked_mul(mul).map_or(max, |next| next.min(max));
    }
    cap
}
151
152/// Apply `mode` to a computed backoff cap.
153///
154/// Pure + seedable for tests: `rng` produces the next random `u64`
155/// used to scale the cap when `mode` is [`JitterMode::Full`]. The
156/// result is always `<= cap` so the `max_backoff` invariant holds.
157#[must_use]
158pub fn apply_jitter(cap: Duration, mode: JitterMode, rng: &mut dyn FnMut() -> u64) -> Duration {
159    match mode {
160        JitterMode::None => cap,
161        JitterMode::Full => {
162            let ms = u64::try_from(cap.as_millis()).unwrap_or(u64::MAX);
163            if ms == 0 {
164                return Duration::ZERO;
165            }
166            // `rng() % (ms + 1)` so the range is [0, ms] inclusive;
167            // saturating the +1 keeps the upper bound at u64::MAX.
168            let modulus = ms.saturating_add(1);
169            Duration::from_millis(rng() % modulus)
170        }
171    }
172}
173
/// Tiny xorshift64 RNG used for jitter. Cryptographic quality is
/// not required — we just want to decorrelate reconnect waves.
/// Kept inline to avoid pulling the `rand` crate (and its
/// transitive deps) into the release-small binary-size budget.
#[derive(Debug, Clone, Copy)]
struct XorshiftRng {
    /// Generator state. Must never be zero: zero is a fixed point of
    /// xorshift and would make every draw zero.
    state: u64,
}

impl XorshiftRng {
    /// Seed a generator that is decorrelated from every other one
    /// created in this process.
    ///
    /// Wall-clock nanoseconds alone are not enough: two subscribers
    /// created in the same clock tick (or on a platform with a coarse
    /// clock) would share a seed and retry in lockstep. The clock
    /// reading is therefore hashed together with a process-wide
    /// counter using std's randomly keyed `RandomState` hasher.
    fn seeded_from_now() -> Self {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::sync::atomic::{AtomicU64, Ordering};

        static SEED_COUNTER: AtomicU64 = AtomicU64::new(0);

        let ns = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_nanos());
        // The counter only has to hand out distinct values; it does not
        // order any other memory, so `Relaxed` is sufficient.
        let ticket = SEED_COUNTER.fetch_add(1, Ordering::Relaxed);

        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u128(ns);
        hasher.write_u64(ticket);
        // Steer clear of the all-zero state (see the field doc).
        Self {
            state: hasher.finish().max(1),
        }
    }

    /// Advance the generator and return the new state as the output.
    fn next_u64(&mut self) -> u64 {
        // Plain xorshift64 with Marsaglia's (13, 7, 17) shift triple
        // (Marsaglia, "Xorshift RNGs", 2003). There is no output
        // multiply, so this is not the xorshift64* variant.
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }
}
205
206/// Handle to a running WS subscriber task.
207///
208/// Dropping the handle does not stop the task — use
209/// [`WsSubscriber::shutdown`] for a clean exit. This is deliberate:
210/// the TUI passes the handle around via `Arc` and the subscriber
211/// outlives widgets that only need to subscribe to events.
212#[derive(Debug)]
213pub struct WsSubscriber {
214    state: Arc<RwLock<EngineState>>,
215    events: broadcast::Sender<EngineEvent>,
216    shutdown_tx: watch::Sender<bool>,
217    task: JoinHandle<()>,
218}
219
220impl WsSubscriber {
221    /// Spawn a subscriber against `url`, authenticating with
222    /// `token` if provided.
223    ///
224    /// Returns immediately; the connect attempt happens in the
225    /// background task. Consumers poll [`EngineState::connection`]
226    /// on the shared state to learn whether the first connection
227    /// has landed.
228    ///
229    /// # Errors
230    /// Returns [`WsError::InvalidUrl`] if the url cannot be parsed.
231    pub fn spawn(
232        url: &str,
233        token: Option<String>,
234        state: Arc<RwLock<EngineState>>,
235    ) -> Result<Self, WsError> {
236        Self::spawn_with_config(url, token, state, ReconnectConfig::default())
237    }
238
239    /// Like [`Self::spawn`] with a custom reconnect policy; used by
240    /// tests that want to exercise backoff without real wall-clock
241    /// delay.
242    ///
243    /// # Errors
244    /// Returns [`WsError::InvalidUrl`] if the url cannot be parsed.
245    pub fn spawn_with_config(
246        url: &str,
247        token: Option<String>,
248        state: Arc<RwLock<EngineState>>,
249        reconnect: ReconnectConfig,
250    ) -> Result<Self, WsError> {
251        let url = url::Url::parse(url).map_err(|e| WsError::InvalidUrl(e.to_string()))?;
252        if !matches!(url.scheme(), "ws" | "wss") {
253            return Err(WsError::InvalidUrl(format!(
254                "unexpected scheme: {}",
255                url.scheme()
256            )));
257        }
258
259        let (events, _) = broadcast::channel(128);
260        let (shutdown_tx, shutdown_rx) = watch::channel(false);
261
262        let task = tokio::spawn(run_loop(
263            url,
264            token,
265            state.clone(),
266            events.clone(),
267            shutdown_rx,
268            reconnect,
269        ));
270
271        Ok(Self {
272            state,
273            events,
274            shutdown_tx,
275            task,
276        })
277    }
278
279    /// Shared handle to the engine-state mirror. Widgets clone this
280    /// and acquire `.read()` for the duration of a render pass.
281    #[must_use]
282    pub fn state(&self) -> Arc<RwLock<EngineState>> {
283        self.state.clone()
284    }
285
286    /// Subscribe to the raw typed event stream. Each subscriber
287    /// gets its own receiver; slow consumers are dropped by tokio's
288    /// broadcast channel, which is appropriate for a push firehose.
289    #[must_use]
290    pub fn events(&self) -> broadcast::Receiver<EngineEvent> {
291        self.events.subscribe()
292    }
293
294    /// Signal the task to exit and wait for it.
295    ///
296    /// # Errors
297    /// Returns [`WsError::Shutdown`] only when the task panicked or
298    /// was cancelled externally — a clean exit always returns `Ok`.
299    pub async fn shutdown(self) -> Result<(), WsError> {
300        let _ = self.shutdown_tx.send(true);
301        self.task
302            .await
303            .map_err(|e| WsError::Shutdown(e.to_string()))
304    }
305}
306
307async fn run_loop(
308    url: url::Url,
309    token: Option<String>,
310    state: Arc<RwLock<EngineState>>,
311    events: broadcast::Sender<EngineEvent>,
312    mut shutdown: watch::Receiver<bool>,
313    reconnect: ReconnectConfig,
314) {
315    // Attempt counter drives the exponential cap. Reset on a clean
316    // connection (`ReadOutcome::Connected`) so a one-off disconnect
317    // doesn't leave us sleeping 30 s after it heals.
318    let mut attempt: u32 = 0;
319    let mut rng = XorshiftRng::seeded_from_now();
320
321    loop {
322        if *shutdown.borrow() {
323            break;
324        }
325
326        state.write().on_reconnect_attempt(Utc::now());
327
328        match connect_and_read(&url, token.as_deref(), &state, &events, &mut shutdown).await {
329            ReadOutcome::Shutdown => break,
330            ReadOutcome::Disconnected => {
331                state.write().on_ws_disconnected();
332
333                let cap = exp_backoff_cap(
334                    reconnect.initial_backoff,
335                    reconnect.max_backoff,
336                    reconnect.multiplier,
337                    attempt,
338                );
339                let sleep = apply_jitter(cap, reconnect.jitter, &mut || rng.next_u64());
340                let sleep_ms = u64::try_from(sleep.as_millis()).unwrap_or(u64::MAX);
341                let cap_ms = u64::try_from(cap.as_millis()).unwrap_or(u64::MAX);
342                tracing::warn!(
343                    attempt,
344                    cap_ms,
345                    sleep_ms,
346                    "ws disconnected, retrying with jittered backoff"
347                );
348
349                tokio::select! {
350                    () = tokio::time::sleep(sleep) => {}
351                    _ = shutdown.changed() => break,
352                }
353
354                attempt = attempt.saturating_add(1);
355            }
356            ReadOutcome::Connected => {
357                // Only reset after a frame actually landed — a hang
358                // that ended before any read wouldn't count as
359                // successful recovery.
360                attempt = 0;
361            }
362        }
363    }
364
365    tracing::debug!("ws subscriber task exited");
366}
367
368enum ReadOutcome {
369    /// Reached after a full handshake + at least one frame.
370    Connected,
371    /// Connection failed or was lost; reconnect loop should retry.
372    Disconnected,
373    /// Shutdown channel fired; reconnect loop should exit.
374    Shutdown,
375}
376
377async fn connect_and_read(
378    url: &url::Url,
379    token: Option<&str>,
380    state: &Arc<RwLock<EngineState>>,
381    events: &broadcast::Sender<EngineEvent>,
382    shutdown: &mut watch::Receiver<bool>,
383) -> ReadOutcome {
384    let request = match build_request(url, token) {
385        Ok(r) => r,
386        Err(e) => {
387            tracing::warn!(err = %e, "invalid ws request");
388            return ReadOutcome::Disconnected;
389        }
390    };
391
392    let (ws, _resp) = match tokio_tungstenite::connect_async(request).await {
393        Ok(pair) => pair,
394        Err(e) => {
395            tracing::warn!(err = %e, "ws connect failed");
396            return ReadOutcome::Disconnected;
397        }
398    };
399
400    state.write().on_ws_connected();
401    tracing::info!(url = %url, "ws connected");
402
403    let (_sink, mut stream) = ws.split();
404    let mut any_frame = false;
405
406    loop {
407        tokio::select! {
408            _ = shutdown.changed() => {
409                if *shutdown.borrow() {
410                    tracing::debug!("shutdown requested during read");
411                    return ReadOutcome::Shutdown;
412                }
413            }
414            frame = stream.next() => {
415                match frame {
416                    Some(Ok(tungstenite::Message::Text(text))) => {
417                        any_frame = true;
418                        dispatch_frame(&text, state, events);
419                    }
420                    Some(Ok(tungstenite::Message::Binary(bin))) => {
421                        any_frame = true;
422                        if let Ok(text) = std::str::from_utf8(&bin) {
423                            dispatch_frame(text, state, events);
424                        }
425                    }
426                    Some(Ok(tungstenite::Message::Ping(_) | tungstenite::Message::Pong(_))) => {
427                        // tungstenite autoresponds to pings; pongs
428                        // still bump freshness.
429                        any_frame = true;
430                        state.write().apply_heartbeat(Utc::now());
431                    }
432                    Some(Ok(tungstenite::Message::Close(_))) | None => {
433                        tracing::info!("ws closed by peer");
434                        state.write().on_ws_disconnected();
435                        return if any_frame {
436                            ReadOutcome::Connected
437                        } else {
438                            ReadOutcome::Disconnected
439                        };
440                    }
441                    Some(Ok(tungstenite::Message::Frame(_))) => {
442                        // Raw frames are not emitted by the default
443                        // tungstenite reader config.
444                    }
445                    Some(Err(e)) => {
446                        tracing::warn!(err = %e, "ws read error");
447                        state.write().on_ws_disconnected();
448                        return ReadOutcome::Disconnected;
449                    }
450                }
451            }
452        }
453    }
454}
455
456fn build_request(
457    url: &url::Url,
458    token: Option<&str>,
459) -> Result<tungstenite::handshake::client::Request, String> {
460    use tungstenite::client::IntoClientRequest as _;
461
462    let mut request = url
463        .as_str()
464        .into_client_request()
465        .map_err(|e| e.to_string())?;
466
467    if let Some(t) = token {
468        let value = format!("Bearer {t}")
469            .parse::<tungstenite::http::HeaderValue>()
470            .map_err(|e| e.to_string())?;
471        request.headers_mut().insert("Authorization", value);
472    }
473
474    Ok(request)
475}
476
477fn dispatch_frame(
478    text: &str,
479    state: &Arc<RwLock<EngineState>>,
480    events: &broadcast::Sender<EngineEvent>,
481) {
482    let raw: RawEvent = match serde_json::from_str(text) {
483        Ok(raw) => raw,
484        Err(e) => {
485            tracing::debug!(err = %e, preview = %text.chars().take(80).collect::<String>(), "ws decode error");
486            return;
487        }
488    };
489
490    let ts = raw
491        .ts
492        .as_deref()
493        .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
494        .map_or_else(Utc::now, |dt| dt.with_timezone(&Utc));
495
496    let evt = match raw.event.as_str() {
497        "heartbeat" => {
498            state.write().apply_heartbeat(ts);
499            EngineEvent::Heartbeat(ts)
500        }
501        "status" | "v2_status" => match serde_json::from_value::<V2Status>(raw.data.clone()) {
502            Ok(s) => {
503                state.write().apply_status(s.clone(), ts, Source::Ws);
504                EngineEvent::Status(Box::new(s))
505            }
506            Err(e) => {
507                tracing::debug!(err = %e, "status decode error");
508                EngineEvent::Unknown {
509                    event: raw.event,
510                    ts,
511                    data: raw.data,
512                }
513            }
514        },
515        "positions" | "positions_update" => {
516            match serde_json::from_value::<Positions>(raw.data.clone()) {
517                Ok(p) => {
518                    state.write().apply_positions(p.clone(), ts, Source::Ws);
519                    EngineEvent::Positions(Box::new(p))
520                }
521                Err(e) => {
522                    tracing::debug!(err = %e, "positions decode error");
523                    EngineEvent::Unknown {
524                        event: raw.event,
525                        ts,
526                        data: raw.data,
527                    }
528                }
529            }
530        }
531        "risk" | "risk_update" => match serde_json::from_value::<Risk>(raw.data.clone()) {
532            Ok(r) => {
533                state.write().apply_risk(r.clone(), ts, Source::Ws);
534                EngineEvent::Risk(Box::new(r))
535            }
536            Err(e) => {
537                tracing::debug!(err = %e, "risk decode error");
538                EngineEvent::Unknown {
539                    event: raw.event,
540                    ts,
541                    data: raw.data,
542                }
543            }
544        },
545        "regime" | "regime_update" => match serde_json::from_value::<Regime>(raw.data.clone()) {
546            Ok(r) => {
547                state.write().apply_regime(r.clone(), ts, Source::Ws);
548                EngineEvent::Regime(Box::new(r))
549            }
550            Err(e) => {
551                tracing::debug!(err = %e, "regime decode error");
552                EngineEvent::Unknown {
553                    event: raw.event,
554                    ts,
555                    data: raw.data,
556                }
557            }
558        },
559        _ => EngineEvent::Unknown {
560            event: raw.event,
561            ts,
562            data: raw.data,
563        },
564    };
565
566    // Best-effort send; dropped receivers are fine. The state mirror
567    // is the durable copy; broadcast events are a convenience tap.
568    let _ = events.send(evt);
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed RNG seed so jitter tests are reproducible: a failing draw
    /// can be replayed exactly instead of depending on the clock.
    const TEST_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

    // ── Pure backoff math ──────────────────────────────────────────

    #[test]
    fn exp_backoff_cap_starts_at_initial_on_attempt_zero() {
        let d = exp_backoff_cap(Duration::from_millis(500), Duration::from_secs(30), 2, 0);
        assert_eq!(d, Duration::from_millis(500));
    }

    #[test]
    fn exp_backoff_cap_doubles_each_attempt_until_max() {
        let initial = Duration::from_millis(500);
        let max = Duration::from_secs(30);
        // 500 → 1000 → 2000 → 4000 → 8000 → 16000 → 30000 (cap)
        let seq: Vec<u128> = (0..8)
            .map(|a| exp_backoff_cap(initial, max, 2, a).as_millis())
            .collect();
        assert_eq!(
            seq,
            vec![500, 1_000, 2_000, 4_000, 8_000, 16_000, 30_000, 30_000]
        );
    }

    #[test]
    fn exp_backoff_cap_saturates_on_runaway_attempt() {
        // A truly pathological attempt count must not panic or
        // overflow; it just pins at max_backoff.
        let d = exp_backoff_cap(
            Duration::from_millis(500),
            Duration::from_secs(30),
            2,
            1_000_000,
        );
        assert_eq!(d, Duration::from_secs(30));
    }

    #[test]
    fn exp_backoff_cap_with_multiplier_one_stays_at_initial() {
        let d = exp_backoff_cap(Duration::from_millis(500), Duration::from_secs(30), 1, 5);
        assert_eq!(d, Duration::from_millis(500));
    }

    #[test]
    fn exp_backoff_cap_treats_zero_multiplier_as_one() {
        let d = exp_backoff_cap(Duration::from_millis(500), Duration::from_secs(30), 0, 5);
        assert_eq!(d, Duration::from_millis(500));
    }

    #[test]
    fn exp_backoff_cap_clamps_initial_above_max() {
        let d = exp_backoff_cap(Duration::from_secs(60), Duration::from_secs(30), 2, 0);
        assert_eq!(d, Duration::from_secs(30));
    }

    // ── Jitter ─────────────────────────────────────────────────────

    #[test]
    fn jitter_none_returns_cap_unchanged() {
        // Also pins that `None` never consumes randomness.
        let mut draws = 0_u32;
        let mut rng = || {
            draws += 1;
            0_u64
        };
        let out = apply_jitter(Duration::from_millis(1_234), JitterMode::None, &mut rng);
        assert_eq!(out, Duration::from_millis(1_234));
        assert_eq!(draws, 0, "JitterMode::None must not draw from the rng");
    }

    #[test]
    fn jitter_full_is_bounded_by_cap() {
        // 10 000 draws from the real xorshift — every sample must
        // land in [0, cap]. A one-off violation here would break
        // the max_backoff invariant.
        let mut rng = XorshiftRng { state: TEST_SEED };
        let cap = Duration::from_millis(5_000);
        for _ in 0..10_000 {
            let d = apply_jitter(cap, JitterMode::Full, &mut || rng.next_u64());
            assert!(d <= cap, "jitter produced {d:?} > cap {cap:?}");
        }
    }

    #[test]
    fn jitter_full_varies_across_draws() {
        // Sanity check that jitter actually jitters — if the RNG
        // were constant the sequence would collapse to one value.
        let mut rng = XorshiftRng { state: TEST_SEED };
        let cap = Duration::from_millis(5_000);
        let samples: Vec<_> = (0..100)
            .map(|_| apply_jitter(cap, JitterMode::Full, &mut || rng.next_u64()))
            .collect();
        let unique: std::collections::BTreeSet<_> = samples.iter().collect();
        assert!(
            unique.len() > 1,
            "expected at least two distinct jitter values, got {}",
            unique.len()
        );
    }

    #[test]
    fn jitter_full_zero_draw_yields_zero() {
        let out = apply_jitter(Duration::from_millis(5_000), JitterMode::Full, &mut || 0_u64);
        assert_eq!(out, Duration::ZERO);
    }

    #[test]
    fn jitter_full_with_zero_cap_returns_zero() {
        let mut rng = || 0xDEAD_BEEF_u64;
        let out = apply_jitter(Duration::ZERO, JitterMode::Full, &mut rng);
        assert_eq!(out, Duration::ZERO);
    }

    // ── xorshift ───────────────────────────────────────────────────

    #[test]
    fn xorshift_is_deterministic_and_non_trivial() {
        let mut a = XorshiftRng { state: 0x1234_5678 };
        let mut b = XorshiftRng { state: 0x1234_5678 };
        let seq_a: Vec<u64> = (0..16).map(|_| a.next_u64()).collect();
        let seq_b: Vec<u64> = (0..16).map(|_| b.next_u64()).collect();
        assert_eq!(seq_a, seq_b, "same seed must produce same sequence");
        let unique: std::collections::BTreeSet<_> = seq_a.iter().collect();
        assert!(
            unique.len() >= 15,
            "xorshift should not cycle in 16 draws, got {}",
            unique.len()
        );
    }

    #[test]
    fn xorshift_matches_known_vector() {
        // xorshift64 with the (13, 7, 17) triple, one step from state 1.
        let mut r = XorshiftRng { state: 1 };
        assert_eq!(r.next_u64(), 1_082_269_761);
    }

    #[test]
    fn xorshift_seed_is_never_zero() {
        // Zero is a fixed point of xorshift; the seeder must avoid it.
        assert_ne!(XorshiftRng::seeded_from_now().state, 0);
    }
}