1use std::sync::Arc;
16use std::time::Duration;
17
18use chrono::{DateTime, Utc};
19use futures::StreamExt;
20use parking_lot::RwLock;
21use serde::{Deserialize, Serialize};
22use tokio::sync::{broadcast, watch};
23use tokio::task::JoinHandle;
24use tokio_tungstenite::tungstenite;
25
26use crate::models::{Positions, Regime, Risk, V2Status};
27use crate::stat::Source;
28use crate::state::EngineState;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32struct RawEvent {
33 event: String,
34 #[serde(default)]
35 ts: Option<String>,
36 #[serde(default)]
37 data: serde_json::Value,
38}
39
40#[derive(Debug, Clone)]
44pub enum EngineEvent {
45 Heartbeat(DateTime<Utc>),
46 Status(Box<V2Status>),
47 Positions(Box<Positions>),
48 Risk(Box<Risk>),
49 Regime(Box<Regime>),
50 Unknown {
51 event: String,
52 ts: DateTime<Utc>,
53 data: serde_json::Value,
54 },
55}
56
57#[derive(Debug, thiserror::Error)]
61pub enum WsError {
62 #[error("invalid websocket url: {0}")]
63 InvalidUrl(String),
64 #[error("subscriber shutdown failed: {0}")]
65 Shutdown(String),
66}
67
/// How the computed backoff cap is turned into the actual sleep duration.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum JitterMode {
    /// Sleep for exactly the cap.
    None,
    /// Sleep for a pseudo-random duration in `[0, cap]` (the default).
    #[default]
    Full,
}
90
91#[derive(Debug, Clone, Copy)]
98pub struct ReconnectConfig {
99 pub initial_backoff: Duration,
100 pub max_backoff: Duration,
101 pub multiplier: u32,
102 pub jitter: JitterMode,
103}
104
105impl Default for ReconnectConfig {
106 fn default() -> Self {
107 Self {
108 initial_backoff: Duration::from_millis(500),
109 max_backoff: Duration::from_secs(30),
110 multiplier: 2,
111 jitter: JitterMode::default(),
112 }
113 }
114}
115
/// Computes the capped exponential backoff for a given retry attempt.
///
/// The result is `min(initial * multiplier^attempt, max)` at millisecond
/// resolution, with every intermediate step saturating instead of
/// overflowing.
///
/// * `initial` - cap for attempt 0; clamped to `u64::MAX` milliseconds.
/// * `max` - upper bound of the returned duration.
/// * `multiplier` - growth factor per attempt; `0` is treated as `1`.
/// * `attempt` - zero-based retry counter.
///
/// The power is computed with `saturating_pow`, so the cost does not grow
/// linearly with `attempt` even when the value can never reach `max`
/// (multiplier of 0 or 1, or a zero `initial`).
#[must_use]
pub fn exp_backoff_cap(
    initial: Duration,
    max: Duration,
    multiplier: u32,
    attempt: u32,
) -> Duration {
    let max_ms = max.as_millis();
    let base_ms = u128::from(u64::try_from(initial.as_millis()).unwrap_or(u64::MAX));
    // A multiplier of 0 would collapse the delay to zero; treat it as "no growth".
    let factor = u128::from(multiplier.max(1)).saturating_pow(attempt);
    let capped_ms = factor.saturating_mul(base_ms).min(max_ms);
    // `max` itself may exceed u64 milliseconds; clamp rather than truncate.
    Duration::from_millis(u64::try_from(capped_ms).unwrap_or(u64::MAX))
}
151
152#[must_use]
158pub fn apply_jitter(cap: Duration, mode: JitterMode, rng: &mut dyn FnMut() -> u64) -> Duration {
159 match mode {
160 JitterMode::None => cap,
161 JitterMode::Full => {
162 let ms = u64::try_from(cap.as_millis()).unwrap_or(u64::MAX);
163 if ms == 0 {
164 return Duration::ZERO;
165 }
166 let modulus = ms.saturating_add(1);
169 Duration::from_millis(rng() % modulus)
170 }
171 }
172}
173
/// Minimal xorshift pseudo-random generator used only for backoff jitter.
#[derive(Clone, Copy, Debug)]
struct XorshiftRng {
    /// Current generator state; the constructor keeps it non-zero.
    state: u64,
}

impl XorshiftRng {
    /// Seeds the generator from the low 64 bits of the wall-clock nanoseconds
    /// since the Unix epoch, falling back to `1` so the state is never zero.
    fn seeded_from_now() -> Self {
        let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_nanos(),
            // Clock set before the epoch: use a fixed fallback.
            Err(_) => 0,
        };
        let low_bits = nanos & u128::from(u64::MAX);
        let seed = u64::try_from(low_bits).unwrap_or(1);
        Self { state: seed.max(1) }
    }

    /// Advances the state with the 13 / 7 / 17 shift triple and returns it.
    fn next_u64(&mut self) -> u64 {
        let step1 = self.state ^ (self.state << 13);
        let step2 = step1 ^ (step1 >> 7);
        let step3 = step2 ^ (step2 << 17);
        self.state = step3;
        step3
    }
}
205
206#[derive(Debug)]
213pub struct WsSubscriber {
214 state: Arc<RwLock<EngineState>>,
215 events: broadcast::Sender<EngineEvent>,
216 shutdown_tx: watch::Sender<bool>,
217 task: JoinHandle<()>,
218}
219
220impl WsSubscriber {
221 pub fn spawn(
232 url: &str,
233 token: Option<String>,
234 state: Arc<RwLock<EngineState>>,
235 ) -> Result<Self, WsError> {
236 Self::spawn_with_config(url, token, state, ReconnectConfig::default())
237 }
238
239 pub fn spawn_with_config(
246 url: &str,
247 token: Option<String>,
248 state: Arc<RwLock<EngineState>>,
249 reconnect: ReconnectConfig,
250 ) -> Result<Self, WsError> {
251 let url = url::Url::parse(url).map_err(|e| WsError::InvalidUrl(e.to_string()))?;
252 if !matches!(url.scheme(), "ws" | "wss") {
253 return Err(WsError::InvalidUrl(format!(
254 "unexpected scheme: {}",
255 url.scheme()
256 )));
257 }
258
259 let (events, _) = broadcast::channel(128);
260 let (shutdown_tx, shutdown_rx) = watch::channel(false);
261
262 let task = tokio::spawn(run_loop(
263 url,
264 token,
265 state.clone(),
266 events.clone(),
267 shutdown_rx,
268 reconnect,
269 ));
270
271 Ok(Self {
272 state,
273 events,
274 shutdown_tx,
275 task,
276 })
277 }
278
279 #[must_use]
282 pub fn state(&self) -> Arc<RwLock<EngineState>> {
283 self.state.clone()
284 }
285
286 #[must_use]
290 pub fn events(&self) -> broadcast::Receiver<EngineEvent> {
291 self.events.subscribe()
292 }
293
294 pub async fn shutdown(self) -> Result<(), WsError> {
300 let _ = self.shutdown_tx.send(true);
301 self.task
302 .await
303 .map_err(|e| WsError::Shutdown(e.to_string()))
304 }
305}
306
307async fn run_loop(
308 url: url::Url,
309 token: Option<String>,
310 state: Arc<RwLock<EngineState>>,
311 events: broadcast::Sender<EngineEvent>,
312 mut shutdown: watch::Receiver<bool>,
313 reconnect: ReconnectConfig,
314) {
315 let mut attempt: u32 = 0;
319 let mut rng = XorshiftRng::seeded_from_now();
320
321 loop {
322 if *shutdown.borrow() {
323 break;
324 }
325
326 state.write().on_reconnect_attempt(Utc::now());
327
328 match connect_and_read(&url, token.as_deref(), &state, &events, &mut shutdown).await {
329 ReadOutcome::Shutdown => break,
330 ReadOutcome::Disconnected => {
331 state.write().on_ws_disconnected();
332
333 let cap = exp_backoff_cap(
334 reconnect.initial_backoff,
335 reconnect.max_backoff,
336 reconnect.multiplier,
337 attempt,
338 );
339 let sleep = apply_jitter(cap, reconnect.jitter, &mut || rng.next_u64());
340 let sleep_ms = u64::try_from(sleep.as_millis()).unwrap_or(u64::MAX);
341 let cap_ms = u64::try_from(cap.as_millis()).unwrap_or(u64::MAX);
342 tracing::warn!(
343 attempt,
344 cap_ms,
345 sleep_ms,
346 "ws disconnected, retrying with jittered backoff"
347 );
348
349 tokio::select! {
350 () = tokio::time::sleep(sleep) => {}
351 _ = shutdown.changed() => break,
352 }
353
354 attempt = attempt.saturating_add(1);
355 }
356 ReadOutcome::Connected => {
357 attempt = 0;
361 }
362 }
363 }
364
365 tracing::debug!("ws subscriber task exited");
366}
367
/// Result of one `connect_and_read` session, consumed by `run_loop`.
enum ReadOutcome {
    /// The stream ended or the peer closed it after at least one frame arrived.
    Connected,
    /// The request could not be built, the connect or a read failed, or the
    /// stream ended before any frame arrived.
    Disconnected,
    /// Shutdown was requested while the session was active.
    Shutdown,
}
376
377async fn connect_and_read(
378 url: &url::Url,
379 token: Option<&str>,
380 state: &Arc<RwLock<EngineState>>,
381 events: &broadcast::Sender<EngineEvent>,
382 shutdown: &mut watch::Receiver<bool>,
383) -> ReadOutcome {
384 let request = match build_request(url, token) {
385 Ok(r) => r,
386 Err(e) => {
387 tracing::warn!(err = %e, "invalid ws request");
388 return ReadOutcome::Disconnected;
389 }
390 };
391
392 let (ws, _resp) = match tokio_tungstenite::connect_async(request).await {
393 Ok(pair) => pair,
394 Err(e) => {
395 tracing::warn!(err = %e, "ws connect failed");
396 return ReadOutcome::Disconnected;
397 }
398 };
399
400 state.write().on_ws_connected();
401 tracing::info!(url = %url, "ws connected");
402
403 let (_sink, mut stream) = ws.split();
404 let mut any_frame = false;
405
406 loop {
407 tokio::select! {
408 _ = shutdown.changed() => {
409 if *shutdown.borrow() {
410 tracing::debug!("shutdown requested during read");
411 return ReadOutcome::Shutdown;
412 }
413 }
414 frame = stream.next() => {
415 match frame {
416 Some(Ok(tungstenite::Message::Text(text))) => {
417 any_frame = true;
418 dispatch_frame(&text, state, events);
419 }
420 Some(Ok(tungstenite::Message::Binary(bin))) => {
421 any_frame = true;
422 if let Ok(text) = std::str::from_utf8(&bin) {
423 dispatch_frame(text, state, events);
424 }
425 }
426 Some(Ok(tungstenite::Message::Ping(_) | tungstenite::Message::Pong(_))) => {
427 any_frame = true;
430 state.write().apply_heartbeat(Utc::now());
431 }
432 Some(Ok(tungstenite::Message::Close(_))) | None => {
433 tracing::info!("ws closed by peer");
434 state.write().on_ws_disconnected();
435 return if any_frame {
436 ReadOutcome::Connected
437 } else {
438 ReadOutcome::Disconnected
439 };
440 }
441 Some(Ok(tungstenite::Message::Frame(_))) => {
442 }
445 Some(Err(e)) => {
446 tracing::warn!(err = %e, "ws read error");
447 state.write().on_ws_disconnected();
448 return ReadOutcome::Disconnected;
449 }
450 }
451 }
452 }
453 }
454}
455
456fn build_request(
457 url: &url::Url,
458 token: Option<&str>,
459) -> Result<tungstenite::handshake::client::Request, String> {
460 use tungstenite::client::IntoClientRequest as _;
461
462 let mut request = url
463 .as_str()
464 .into_client_request()
465 .map_err(|e| e.to_string())?;
466
467 if let Some(t) = token {
468 let value = format!("Bearer {t}")
469 .parse::<tungstenite::http::HeaderValue>()
470 .map_err(|e| e.to_string())?;
471 request.headers_mut().insert("Authorization", value);
472 }
473
474 Ok(request)
475}
476
477fn dispatch_frame(
478 text: &str,
479 state: &Arc<RwLock<EngineState>>,
480 events: &broadcast::Sender<EngineEvent>,
481) {
482 let raw: RawEvent = match serde_json::from_str(text) {
483 Ok(raw) => raw,
484 Err(e) => {
485 tracing::debug!(err = %e, preview = %text.chars().take(80).collect::<String>(), "ws decode error");
486 return;
487 }
488 };
489
490 let ts = raw
491 .ts
492 .as_deref()
493 .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
494 .map_or_else(Utc::now, |dt| dt.with_timezone(&Utc));
495
496 let evt = match raw.event.as_str() {
497 "heartbeat" => {
498 state.write().apply_heartbeat(ts);
499 EngineEvent::Heartbeat(ts)
500 }
501 "status" | "v2_status" => match serde_json::from_value::<V2Status>(raw.data.clone()) {
502 Ok(s) => {
503 state.write().apply_status(s.clone(), ts, Source::Ws);
504 EngineEvent::Status(Box::new(s))
505 }
506 Err(e) => {
507 tracing::debug!(err = %e, "status decode error");
508 EngineEvent::Unknown {
509 event: raw.event,
510 ts,
511 data: raw.data,
512 }
513 }
514 },
515 "positions" | "positions_update" => {
516 match serde_json::from_value::<Positions>(raw.data.clone()) {
517 Ok(p) => {
518 state.write().apply_positions(p.clone(), ts, Source::Ws);
519 EngineEvent::Positions(Box::new(p))
520 }
521 Err(e) => {
522 tracing::debug!(err = %e, "positions decode error");
523 EngineEvent::Unknown {
524 event: raw.event,
525 ts,
526 data: raw.data,
527 }
528 }
529 }
530 }
531 "risk" | "risk_update" => match serde_json::from_value::<Risk>(raw.data.clone()) {
532 Ok(r) => {
533 state.write().apply_risk(r.clone(), ts, Source::Ws);
534 EngineEvent::Risk(Box::new(r))
535 }
536 Err(e) => {
537 tracing::debug!(err = %e, "risk decode error");
538 EngineEvent::Unknown {
539 event: raw.event,
540 ts,
541 data: raw.data,
542 }
543 }
544 },
545 "regime" | "regime_update" => match serde_json::from_value::<Regime>(raw.data.clone()) {
546 Ok(r) => {
547 state.write().apply_regime(r.clone(), ts, Source::Ws);
548 EngineEvent::Regime(Box::new(r))
549 }
550 Err(e) => {
551 tracing::debug!(err = %e, "regime decode error");
552 EngineEvent::Unknown {
553 event: raw.event,
554 ts,
555 data: raw.data,
556 }
557 }
558 },
559 _ => EngineEvent::Unknown {
560 event: raw.event,
561 ts,
562 data: raw.data,
563 },
564 };
565
566 let _ = events.send(evt);
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed, non-zero seed so any failure of the jitter tests is reproducible.
    const TEST_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

    // --- exp_backoff_cap ---------------------------------------------------

    #[test]
    fn exp_backoff_cap_starts_at_initial_on_attempt_zero() {
        let d = exp_backoff_cap(Duration::from_millis(500), Duration::from_secs(30), 2, 0);
        assert_eq!(d, Duration::from_millis(500));
    }

    #[test]
    fn exp_backoff_cap_doubles_each_attempt_until_max() {
        let initial = Duration::from_millis(500);
        let max = Duration::from_secs(30);
        let seq: Vec<u128> = (0..8)
            .map(|a| exp_backoff_cap(initial, max, 2, a).as_millis())
            .collect();
        assert_eq!(
            seq,
            vec![500, 1_000, 2_000, 4_000, 8_000, 16_000, 30_000, 30_000]
        );
    }

    #[test]
    fn exp_backoff_cap_saturates_on_runaway_attempt() {
        let d = exp_backoff_cap(
            Duration::from_millis(500),
            Duration::from_secs(30),
            2,
            1_000_000,
        );
        assert_eq!(d, Duration::from_secs(30));
    }

    #[test]
    fn exp_backoff_cap_with_multiplier_one_stays_at_initial() {
        let d = exp_backoff_cap(Duration::from_millis(500), Duration::from_secs(30), 1, 5);
        assert_eq!(d, Duration::from_millis(500));
    }

    // --- apply_jitter ------------------------------------------------------

    #[test]
    fn jitter_none_returns_cap_unchanged() {
        let mut rng = || 0_u64;
        let out = apply_jitter(Duration::from_millis(1_234), JitterMode::None, &mut rng);
        assert_eq!(out, Duration::from_millis(1_234));
    }

    #[test]
    fn jitter_full_is_bounded_by_cap() {
        let mut rng = XorshiftRng { state: TEST_SEED };
        let cap = Duration::from_millis(5_000);
        for _ in 0..10_000 {
            let d = apply_jitter(cap, JitterMode::Full, &mut || rng.next_u64());
            assert!(d <= cap, "jitter produced {d:?} > cap {cap:?}");
        }
    }

    #[test]
    fn jitter_full_varies_across_draws() {
        let mut rng = XorshiftRng { state: TEST_SEED };
        let cap = Duration::from_millis(5_000);
        let samples: Vec<_> = (0..100)
            .map(|_| apply_jitter(cap, JitterMode::Full, &mut || rng.next_u64()))
            .collect();
        let unique: std::collections::BTreeSet<_> = samples.iter().collect();
        assert!(
            unique.len() > 1,
            "expected at least two distinct jitter values, got {}",
            unique.len()
        );
    }

    #[test]
    fn jitter_full_with_zero_cap_returns_zero() {
        let mut rng = || 0xDEAD_BEEF_u64;
        let out = apply_jitter(Duration::ZERO, JitterMode::Full, &mut rng);
        assert_eq!(out, Duration::ZERO);
    }

    // --- XorshiftRng -------------------------------------------------------

    #[test]
    fn xorshift_seeded_from_now_has_non_zero_state() {
        // A zero state would make xorshift emit zeros forever.
        assert_ne!(XorshiftRng::seeded_from_now().state, 0);
    }

    #[test]
    fn xorshift_is_deterministic_and_non_trivial() {
        let mut a = XorshiftRng { state: 0x1234_5678 };
        let mut b = XorshiftRng { state: 0x1234_5678 };
        let seq_a: Vec<u64> = (0..16).map(|_| a.next_u64()).collect();
        let seq_b: Vec<u64> = (0..16).map(|_| b.next_u64()).collect();
        assert_eq!(seq_a, seq_b, "same seed must produce same sequence");
        let unique: std::collections::BTreeSet<_> = seq_a.iter().collect();
        assert!(
            unique.len() >= 15,
            "xorshift should not cycle in 16 draws, got {}",
            unique.len()
        );
    }
}