Skip to main content

unifly_api/
websocket.rs

1//! WebSocket event stream with auto-reconnect.
2//!
3//! Connects to a UniFi controller's legacy WebSocket endpoint and streams
4//! parsed events through a [`tokio::sync::broadcast`] channel. Handles
5//! reconnection with exponential backoff + jitter automatically.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use unifly_api::websocket::{WebSocketHandle, ReconnectConfig};
11//! use unifly_api::transport::TlsMode;
12//! use tokio_util::sync::CancellationToken;
13//! use url::Url;
14//!
15//! let cancel = CancellationToken::new();
16//! let ws_url = Url::parse("wss://192.168.1.1/proxy/network/wss/s/default/events")?;
17//!
18//! let handle = WebSocketHandle::connect(
19//!     ws_url, ReconnectConfig::default(), cancel.clone(), None,
20//!     TlsMode::DangerAcceptInvalid,
21//! )?;
22//! let mut rx = handle.subscribe();
23//!
24//! while let Ok(event) = rx.recv().await {
25//!     println!("{}: {}", event.key, event.message.as_deref().unwrap_or(""));
26//! }
27//!
28//! handle.shutdown();
29//! ```
30
31use std::path::Path;
32use std::sync::Arc;
33use std::time::Duration;
34
35use futures_util::StreamExt;
36use rustls::ClientConfig;
37use rustls_pki_types::CertificateDer;
38use serde::{Deserialize, Serialize};
39use tokio::sync::broadcast;
40use tokio_tungstenite::Connector;
41use tokio_tungstenite::tungstenite::{self, ClientRequestBuilder};
42use tokio_util::sync::CancellationToken;
43use url::Url;
44
45use crate::error::Error;
46use crate::transport::TlsMode;
47
48// ── Broadcast channel capacity ───────────────────────────────────────
49
/// Capacity handed to `broadcast::channel` when `WebSocketHandle::connect`
/// creates the event channel.
const EVENT_CHANNEL_CAPACITY: usize = 1024;
51
52// ── UnifiEvent ───────────────────────────────────────────────────────
53
/// A parsed event from the UniFi WebSocket stream.
///
/// `key`, `subsystem` and `site_id` carry no `#[serde(default)]`, so typed
/// deserialization fails when any of them is missing. Uses
/// `#[serde(flatten)]` to capture all fields beyond the core set,
/// so nothing from the controller is silently dropped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnifiEvent {
    /// Event key, e.g. `"EVT_WU_Connected"`, `"EVT_SW_Disconnected"`.
    pub key: String,

    /// Subsystem that emitted the event: `"wlan"`, `"lan"`, `"sta"`, `"gw"`, etc.
    pub subsystem: String,

    /// Site ID this event belongs to.
    pub site_id: String,

    /// Human-readable event message, if present.
    ///
    /// NOTE(review): typed deserialization only fills this from a `message`
    /// field. A payload that carries the text under `msg` leaves this `None`
    /// and keeps the text in `extra`, whereas `event_from_raw` looks at `msg`
    /// first -- confirm which spelling the controller actually sends.
    #[serde(default)]
    pub message: Option<String>,

    /// ISO-8601 timestamp from the controller.
    #[serde(default)]
    pub datetime: Option<String>,

    /// All remaining fields the controller sends.
    ///
    /// Events built by `event_from_raw` store a copy of the complete raw
    /// payload here instead of only the leftover fields.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
81
82// ── ReconnectConfig ──────────────────────────────────────────────────
83
/// Exponential backoff settings used when the WebSocket has to reconnect.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Delay applied before the first reconnection attempt. Default: 1s.
    pub initial_delay: Duration,

    /// Ceiling for the exponentially growing delay. Default: 30s.
    pub max_delay: Duration,

    /// Number of reconnection attempts allowed before the loop gives up.
    /// `None` (the default) retries forever.
    pub max_retries: Option<u32>,
}

impl Default for ReconnectConfig {
    /// Start at one second, cap at thirty seconds, and never stop retrying.
    fn default() -> Self {
        let initial_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(30);

        Self {
            initial_delay,
            max_delay,
            max_retries: None,
        }
    }
}
107
108// ── WebSocketHandle ──────────────────────────────────────────────────
109
110/// Handle to a running WebSocket event stream.
111///
112/// Cheaply cloneable via the inner broadcast sender. Drop all handles
113/// and call [`shutdown`](Self::shutdown) to tear down the background task.
114pub struct WebSocketHandle {
115    event_rx: broadcast::Receiver<Arc<UnifiEvent>>,
116    cancel: CancellationToken,
117}
118
119impl WebSocketHandle {
120    /// Connect to the controller WebSocket and spawn the reconnection loop.
121    ///
122    /// Returns immediately once the background task is spawned.
123    /// The first connection attempt happens asynchronously -- subscribe to
124    /// the event receiver to start consuming events.
125    pub fn connect(
126        ws_url: Url,
127        reconnect: ReconnectConfig,
128        cancel: CancellationToken,
129        cookie: Option<String>,
130        tls_mode: TlsMode,
131    ) -> Result<Self, Error> {
132        let (event_tx, event_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
133
134        let task_cancel = cancel.clone();
135        tokio::spawn(async move {
136            ws_loop(ws_url, event_tx, reconnect, task_cancel, cookie, tls_mode).await;
137        });
138
139        Ok(Self { event_rx, cancel })
140    }
141
142    /// Get a new broadcast receiver for the event stream.
143    ///
144    /// Multiple consumers can subscribe concurrently. If a consumer falls
145    /// behind, it receives [`broadcast::error::RecvError::Lagged`].
146    pub fn subscribe(&self) -> broadcast::Receiver<Arc<UnifiEvent>> {
147        self.event_rx.resubscribe()
148    }
149
150    /// Signal the background task to shut down gracefully.
151    pub fn shutdown(&self) {
152        self.cancel.cancel();
153    }
154}
155
156// ── Background reconnection loop ─────────────────────────────────────
157
158/// Main loop: connect → read → on error, backoff → reconnect.
159async fn ws_loop(
160    ws_url: Url,
161    event_tx: broadcast::Sender<Arc<UnifiEvent>>,
162    reconnect: ReconnectConfig,
163    cancel: CancellationToken,
164    cookie: Option<String>,
165    tls_mode: TlsMode,
166) {
167    let mut attempt: u32 = 0;
168
169    loop {
170        tokio::select! {
171            biased;
172            () = cancel.cancelled() => break,
173            result = connect_and_read(&ws_url, &event_tx, &cancel, cookie.as_deref(), &tls_mode) => {
174                match result {
175                    // Clean disconnect (server close frame or stream ended).
176                    // Reset attempt counter and reconnect immediately.
177                    Ok(()) => {
178                        tracing::info!("WebSocket disconnected cleanly, reconnecting");
179                        attempt = 0;
180                    }
181                    Err(e) => {
182                        tracing::warn!(error = %e, attempt, "WebSocket error");
183
184                        if let Some(max) = reconnect.max_retries
185                            && attempt >= max {
186                                tracing::error!(
187                                    max_retries = max,
188                                    "WebSocket reconnection limit reached, giving up"
189                                );
190                                break;
191                            }
192
193                        let delay = calculate_backoff(attempt, &reconnect);
194                        let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
195                        tracing::info!(
196                            delay_ms,
197                            attempt,
198                            "Waiting before reconnect"
199                        );
200
201                        tokio::select! {
202                            biased;
203                            () = cancel.cancelled() => break,
204                            () = tokio::time::sleep(delay) => {}
205                        }
206
207                        attempt += 1;
208                    }
209                }
210            }
211        }
212    }
213
214    // Note: tracing after the loop is technically reachable (via break)
215    // but the compiler's macro expansion for select! can't prove it.
216    #[allow(unreachable_code)]
217    {
218        tracing::debug!("WebSocket loop exiting");
219    }
220}
221
222// ── Single connection lifecycle ──────────────────────────────────────
223
224/// Establish a single WebSocket connection, read messages until it drops.
225///
226/// If `cookie` is provided, it's injected as a `Cookie` header on the
227/// WebSocket upgrade request (required for legacy cookie-based auth).
228async fn connect_and_read(
229    url: &Url,
230    event_tx: &broadcast::Sender<Arc<UnifiEvent>>,
231    cancel: &CancellationToken,
232    cookie: Option<&str>,
233    tls_mode: &TlsMode,
234) -> Result<(), Error> {
235    tracing::info!(url = %url, "Connecting to WebSocket");
236
237    let uri: tungstenite::http::Uri = url
238        .as_str()
239        .parse()
240        .map_err(|e: tungstenite::http::uri::InvalidUri| Error::WebSocketConnect(e.to_string()))?;
241
242    let mut request = ClientRequestBuilder::new(uri);
243    if let Some(cookie_val) = cookie {
244        request = request.with_header("Cookie", cookie_val);
245    }
246
247    // Use plain connector for ws://, TLS connector for wss://
248    let connector = if url.scheme() == "wss" {
249        build_tls_connector(tls_mode)?
250    } else {
251        Some(Connector::Plain)
252    };
253
254    let (ws_stream, _response) =
255        tokio_tungstenite::connect_async_tls_with_config(request, None, false, connector)
256            .await
257            .map_err(|e| Error::WebSocketConnect(e.to_string()))?;
258
259    tracing::info!("WebSocket connected");
260
261    let (_write, mut read) = ws_stream.split();
262
263    loop {
264        tokio::select! {
265            biased;
266            () = cancel.cancelled() => return Ok(()),
267            frame = read.next() => {
268                match frame {
269                    Some(Ok(tungstenite::Message::Text(text))) => {
270                        parse_and_broadcast(&text, event_tx);
271                    }
272                    Some(Ok(tungstenite::Message::Ping(_))) => {
273                        // tungstenite handles pong replies automatically
274                        tracing::trace!("WebSocket ping");
275                    }
276                    Some(Ok(tungstenite::Message::Close(frame))) => {
277                        if let Some(ref cf) = frame {
278                            tracing::info!(
279                                code = %cf.code,
280                                reason = %cf.reason,
281                                "WebSocket close frame received"
282                            );
283                        } else {
284                            tracing::info!("WebSocket close frame received (no payload)");
285                        }
286                        return Ok(());
287                    }
288                    Some(Err(e)) => {
289                        return Err(Error::WebSocketConnect(e.to_string()));
290                    }
291                    None => {
292                        // Stream ended without a close frame
293                        tracing::info!("WebSocket stream ended");
294                        return Ok(());
295                    }
296                    _ => {
297                        // Binary, Pong, Frame -- ignore
298                    }
299                }
300            }
301        }
302    }
303}
304
305// ── Message parsing ──────────────────────────────────────────────────
306
/// Raw envelope the controller sends over the WebSocket.
///
/// All messages have the shape `{ "meta": { "rc": "ok", ... }, "data": [...] }`.
/// Neither field has a serde default, so a frame without `meta` or `data`
/// fails to deserialize.
#[derive(Debug, Deserialize)]
struct WsEnvelope {
    /// Envelope metadata; `parse_and_broadcast` reads `meta.message` from it.
    #[allow(dead_code)]
    meta: WsMeta,
    /// Payload items; `parse_and_broadcast` turns each element into one event.
    data: Vec<serde_json::Value>,
}
316
/// The `meta` object of a [`WsEnvelope`].
#[derive(Debug, Deserialize)]
struct WsMeta {
    /// Result code such as `"ok"`. Required when deserializing, but not read
    /// by the code in this file.
    #[allow(dead_code)]
    rc: String,
    /// Message type, e.g. `"events"` or `"device:sync"`; `None` when absent.
    #[serde(default)]
    message: Option<String>,
}
324
325/// Parse a WebSocket text frame and broadcast any events found inside.
326fn parse_and_broadcast(text: &str, event_tx: &broadcast::Sender<Arc<UnifiEvent>>) {
327    let envelope: WsEnvelope = match serde_json::from_str(text) {
328        Ok(e) => e,
329        Err(e) => {
330            tracing::debug!(error = %e, "Failed to parse WebSocket envelope");
331            return;
332        }
333    };
334
335    let msg_type = envelope.meta.message.as_deref().unwrap_or("");
336
337    // Only "events" messages contain discrete events with a `key` field.
338    // Sync messages ("device:sync", "sta:sync", etc.) are state dumps --
339    // we surface them as events too, using the message type as the key.
340    for data in envelope.data {
341        let event = match msg_type {
342            "events" => match serde_json::from_value::<UnifiEvent>(data.clone()) {
343                Ok(evt) => evt,
344                Err(e) => {
345                    tracing::debug!(
346                        error = %e,
347                        msg_type,
348                        "Could not deserialize event, constructing from raw data"
349                    );
350                    event_from_raw(msg_type, &data)
351                }
352            },
353            // Sync and other message types -- construct a synthetic event
354            _ => event_from_raw(msg_type, &data),
355        };
356
357        // Ignore send errors -- just means no active subscribers right now
358        let _ = event_tx.send(Arc::new(event));
359    }
360}
361
362/// Build a [`UnifiEvent`] from raw JSON when typed deserialization fails
363/// or the message is a sync/unknown type.
364fn event_from_raw(msg_type: &str, data: &serde_json::Value) -> UnifiEvent {
365    UnifiEvent {
366        key: data["key"].as_str().unwrap_or(msg_type).to_string(),
367        subsystem: data["subsystem"].as_str().unwrap_or("unknown").to_string(),
368        site_id: data["site_id"].as_str().unwrap_or("").to_string(),
369        message: data["msg"]
370            .as_str()
371            .or_else(|| data["message"].as_str())
372            .map(String::from),
373        datetime: data["datetime"].as_str().map(String::from),
374        extra: data.clone(),
375    }
376}
377
378// ── TLS connector ────────────────────────────────────────────────────
379
380/// Build a [`Connector`] matching the given [`TlsMode`].
381///
382/// - `System`: returns `None` (uses default webpki-roots verification).
383/// - `CustomCa`: loads a PEM CA file into a custom root store.
384/// - `DangerAcceptInvalid`: disables all certificate verification.
385fn build_tls_connector(tls_mode: &TlsMode) -> Result<Option<Connector>, Error> {
386    // Ensure a rustls crypto provider is available (rustls 0.23+ requires this)
387    let _ = rustls::crypto::ring::default_provider().install_default();
388
389    match tls_mode {
390        TlsMode::System => Ok(None),
391        TlsMode::CustomCa(path) => {
392            let root_store = load_root_store(path)?;
393            let tls_config = ClientConfig::builder()
394                .with_root_certificates(root_store)
395                .with_no_client_auth();
396            Ok(Some(Connector::Rustls(Arc::new(tls_config))))
397        }
398        TlsMode::DangerAcceptInvalid => {
399            let tls_config = ClientConfig::builder()
400                .dangerous()
401                .with_custom_certificate_verifier(Arc::new(NoVerifier))
402                .with_no_client_auth();
403            Ok(Some(Connector::Rustls(Arc::new(tls_config))))
404        }
405    }
406}
407
408/// Load a PEM CA file into a [`rustls::RootCertStore`].
409fn load_root_store(path: &Path) -> Result<rustls::RootCertStore, Error> {
410    use rustls_pki_types::pem::PemObject;
411
412    let mut root_store = rustls::RootCertStore::empty();
413    for cert in CertificateDer::pem_file_iter(path)
414        .map_err(|e| Error::Tls(format!("failed to read CA cert: {e}")))?
415    {
416        let cert = cert.map_err(|e| Error::Tls(format!("invalid PEM in CA file: {e}")))?;
417        root_store
418            .add(cert)
419            .map_err(|e| Error::Tls(format!("invalid CA cert: {e}")))?;
420    }
421    Ok(root_store)
422}
423
/// Certificate verifier that accepts any server certificate.
///
/// Only used when `TlsMode::DangerAcceptInvalid` is selected (self-signed
/// controllers). This is intentionally insecure: neither the certificate
/// chain nor the handshake signatures are checked.
#[derive(Debug)]
struct NoVerifier;

impl rustls::client::danger::ServerCertVerifier for NoVerifier {
    /// Accept the server certificate without looking at any of the arguments.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Report every TLS 1.2 handshake signature as valid without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Report every TLS 1.3 handshake signature as valid without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Advertise the signature schemes supported by the ring crypto provider.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        rustls::crypto::ring::default_provider()
            .signature_verification_algorithms
            .supported_schemes()
    }
}
467
468// ── Backoff calculation ──────────────────────────────────────────────
469
470/// Exponential backoff with jitter.
471///
472/// `delay = min(initial * 2^attempt, max) + jitter`
473///
474/// Jitter is +-25% to spread out reconnection storms from multiple clients.
475fn calculate_backoff(attempt: u32, config: &ReconnectConfig) -> Duration {
476    let base = config.initial_delay.as_secs_f64()
477        * 2.0_f64.powi(i32::try_from(attempt).unwrap_or(i32::MAX));
478    let capped = base.min(config.max_delay.as_secs_f64());
479
480    // Deterministic "jitter" seeded from the attempt number.
481    // Not cryptographically random, but good enough for backoff spread.
482    let jitter_factor = 1.0 + 0.25 * ((f64::from(attempt) * 7.3).sin());
483    let with_jitter = (capped * jitter_factor).max(0.0);
484
485    Duration::from_secs_f64(with_jitter)
486}
487
488// ── Tests ────────────────────────────────────────────────────────────
489
// Unit tests for the pure parts of this module: default configuration,
// backoff calculation, and message parsing. No network access is involved.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // The documented defaults: 1s initial delay, 30s cap, unlimited retries.
    #[test]
    fn default_reconnect_config() {
        let config = ReconnectConfig::default();
        assert_eq!(config.initial_delay, Duration::from_secs(1));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert!(config.max_retries.is_none());
    }

    // The first three attempts must produce strictly increasing delays.
    #[test]
    fn backoff_increases_exponentially() {
        let config = ReconnectConfig::default();

        let d0 = calculate_backoff(0, &config);
        let d1 = calculate_backoff(1, &config);
        let d2 = calculate_backoff(2, &config);

        // Each step should roughly double (within jitter bounds)
        assert!(d1 > d0, "d1 ({d1:?}) should be greater than d0 ({d0:?})");
        assert!(d2 > d1, "d2 ({d2:?}) should be greater than d1 ({d1:?})");
    }

    // Far past the cap, the delay stays near `max_delay` (jitter included).
    #[test]
    fn backoff_caps_at_max_delay() {
        let config = ReconnectConfig {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
            max_retries: None,
        };

        let d10 = calculate_backoff(10, &config);
        // With jitter factor up to 1.25, max effective is 12.5s
        assert!(
            d10 <= Duration::from_secs(13),
            "delay at attempt 10 ({d10:?}) should be capped near max_delay"
        );
    }

    // `event_from_raw` copies the core fields and reads the text from `msg`.
    #[test]
    fn parse_event_from_raw_json() {
        let data = serde_json::json!({
            "key": "EVT_WU_Connected",
            "subsystem": "wlan",
            "site_id": "abc123",
            "msg": "User[aa:bb:cc:dd:ee:ff] connected",
            "datetime": "2026-02-10T12:00:00Z",
            "user": "aa:bb:cc:dd:ee:ff",
            "ssid": "MyNetwork"
        });

        let event = event_from_raw("events", &data);
        assert_eq!(event.key, "EVT_WU_Connected");
        assert_eq!(event.subsystem, "wlan");
        assert_eq!(event.site_id, "abc123");
        assert_eq!(
            event.message.as_deref(),
            Some("User[aa:bb:cc:dd:ee:ff] connected")
        );
        assert_eq!(event.datetime.as_deref(), Some("2026-02-10T12:00:00Z"));
    }

    // Without `key`/`subsystem`, the message type and "unknown" are used.
    #[test]
    fn parse_sync_event_from_raw_json() {
        let data = serde_json::json!({
            "mac": "aa:bb:cc:dd:ee:ff",
            "state": 1,
            "site_id": "site1"
        });

        let event = event_from_raw("device:sync", &data);
        assert_eq!(event.key, "device:sync");
        assert_eq!(event.subsystem, "unknown");
        assert_eq!(event.site_id, "site1");
    }

    // Typed deserialization fills the core fields and flattens the rest.
    #[test]
    fn deserialize_unifi_event() {
        let json = r#"{
            "key": "EVT_SW_Disconnected",
            "subsystem": "lan",
            "site_id": "default",
            "message": "Switch lost contact",
            "datetime": "2026-02-10T13:00:00Z",
            "sw": "aa:bb:cc:dd:ee:ff",
            "port": 4
        }"#;

        let event: UnifiEvent = serde_json::from_str(json).unwrap();
        assert_eq!(event.key, "EVT_SW_Disconnected");
        assert_eq!(event.subsystem, "lan");
        assert_eq!(event.site_id, "default");
        assert_eq!(event.message.as_deref(), Some("Switch lost contact"));
        // Extra fields should be captured in `extra`
        assert_eq!(event.extra["sw"], "aa:bb:cc:dd:ee:ff");
        assert_eq!(event.extra["port"], 4);
    }

    // An "events" envelope with one element broadcasts one typed event.
    #[test]
    fn parse_and_broadcast_events_message() {
        let (tx, mut rx) = broadcast::channel(16);

        let raw = serde_json::json!({
            "meta": { "rc": "ok", "message": "events" },
            "data": [{
                "key": "EVT_WU_Connected",
                "subsystem": "wlan",
                "site_id": "default",
                "msg": "Client connected",
                "user": "aa:bb:cc:dd:ee:ff"
            }]
        });

        parse_and_broadcast(&raw.to_string(), &tx);

        let event = rx.try_recv().unwrap();
        assert_eq!(event.key, "EVT_WU_Connected");
        assert_eq!(event.subsystem, "wlan");
    }

    // A sync envelope is surfaced as an event keyed by its message type.
    #[test]
    fn parse_and_broadcast_sync_message() {
        let (tx, mut rx) = broadcast::channel(16);

        let raw = serde_json::json!({
            "meta": { "rc": "ok", "message": "device:sync" },
            "data": [{
                "mac": "aa:bb:cc:dd:ee:ff",
                "state": 1,
                "site_id": "site1"
            }]
        });

        parse_and_broadcast(&raw.to_string(), &tx);

        let event = rx.try_recv().unwrap();
        assert_eq!(event.key, "device:sync");
        assert_eq!(event.site_id, "site1");
    }

    // Input that is not JSON must be dropped without broadcasting anything.
    #[test]
    fn parse_and_broadcast_malformed_json() {
        let (tx, mut rx) = broadcast::channel::<Arc<UnifiEvent>>(16);

        parse_and_broadcast("not json at all", &tx);

        // Should not panic, should just log and skip
        assert!(rx.try_recv().is_err());
    }
}