Skip to main content

unifly_api/
websocket.rs

//! WebSocket event stream with auto-reconnect.
//!
//! Connects to a UniFi controller's legacy WebSocket endpoint and streams
//! parsed events through a [`tokio::sync::broadcast`] channel. Dropped
//! connections are re-established automatically using exponential backoff
//! with jitter.
//!
//! # Example
//!
//! ```rust,ignore
//! use unifly_api::websocket::{WebSocketHandle, ReconnectConfig};
//! use unifly_api::transport::TlsMode;
//! use tokio_util::sync::CancellationToken;
//! use url::Url;
//!
//! let cancel = CancellationToken::new();
//! let ws_url = Url::parse("wss://192.168.1.1/proxy/network/wss/s/default/events")?;
//!
//! let handle = WebSocketHandle::connect(
//!     ws_url, ReconnectConfig::default(), cancel.clone(), None,
//!     TlsMode::DangerAcceptInvalid,
//! )?;
//! let mut rx = handle.subscribe();
//!
//! // Note: `recv` also returns `Err(RecvError::Lagged(_))` when this consumer
//! // falls behind; match on it explicitly to keep reading after a lag.
//! while let Ok(event) = rx.recv().await {
//!     println!("{}: {}", event.key, event.message.as_deref().unwrap_or(""));
//! }
//!
//! handle.shutdown();
//! ```
30
31use std::path::Path;
32use std::sync::Arc;
33use std::time::Duration;
34
35use futures_util::StreamExt;
36use rustls::ClientConfig;
37use rustls_pki_types::CertificateDer;
38use serde::{Deserialize, Serialize};
39use tokio::sync::broadcast;
40use tokio_tungstenite::Connector;
41use tokio_tungstenite::tungstenite::{self, ClientRequestBuilder};
42use tokio_util::sync::CancellationToken;
43use url::Url;
44
45use crate::error::Error;
46use crate::transport::TlsMode;
47
48// ── Broadcast channel capacity ───────────────────────────────────────
49
/// Number of events the broadcast channel buffers for its subscribers.
///
/// A subscriber that falls further behind than this receives
/// `RecvError::Lagged` and skips the oldest buffered events.
const EVENT_CHANNEL_CAPACITY: usize = 1 << 10;
51
52// ── UnifiEvent ───────────────────────────────────────────────────────
53
54/// A parsed event from the UniFi WebSocket stream.
55///
56/// Uses `#[serde(flatten)]` to capture all fields beyond the core set,
57/// so nothing from the controller is silently dropped.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct UnifiEvent {
60    /// Event key, e.g. `"EVT_WU_Connected"`, `"EVT_SW_Disconnected"`.
61    pub key: String,
62
63    /// Subsystem that emitted the event: `"wlan"`, `"lan"`, `"sta"`, `"gw"`, etc.
64    pub subsystem: String,
65
66    /// Site ID this event belongs to.
67    pub site_id: String,
68
69    /// Human-readable event message, if present.
70    #[serde(default)]
71    pub message: Option<String>,
72
73    /// ISO-8601 timestamp from the controller.
74    #[serde(default)]
75    pub datetime: Option<String>,
76
77    /// All remaining fields the controller sends.
78    #[serde(flatten)]
79    pub extra: serde_json::Value,
80}
81
82// ── ReconnectConfig ──────────────────────────────────────────────────
83
/// Backoff policy applied between WebSocket reconnection attempts.
///
/// Delays grow exponentially starting from `initial_delay` and are capped
/// at `max_delay`.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Wait before the first retry. Defaults to one second.
    pub initial_delay: Duration,
    /// Ceiling for the exponential backoff delay. Defaults to thirty seconds.
    pub max_delay: Duration,
    /// Upper limit on reconnection attempts before the stream gives up.
    /// `None` (the default) keeps retrying forever.
    pub max_retries: Option<u32>,
}

impl Default for ReconnectConfig {
    fn default() -> Self {
        const ONE_SECOND: Duration = Duration::from_secs(1);
        const THIRTY_SECONDS: Duration = Duration::from_secs(30);

        ReconnectConfig {
            max_retries: None,
            initial_delay: ONE_SECOND,
            max_delay: THIRTY_SECONDS,
        }
    }
}
107
108// ── WebSocketHandle ──────────────────────────────────────────────────
109
110/// Handle to a running WebSocket event stream.
111///
112/// Cheaply cloneable via the inner broadcast sender. Drop all handles
113/// and call [`shutdown`](Self::shutdown) to tear down the background task.
114pub struct WebSocketHandle {
115    event_rx: broadcast::Receiver<Arc<UnifiEvent>>,
116    cancel: CancellationToken,
117}
118
119impl WebSocketHandle {
120    /// Connect to the controller WebSocket and spawn the reconnection loop.
121    ///
122    /// Returns immediately once the background task is spawned.
123    /// The first connection attempt happens asynchronously -- subscribe to
124    /// the event receiver to start consuming events.
125    pub fn connect(
126        ws_url: Url,
127        reconnect: ReconnectConfig,
128        cancel: CancellationToken,
129        cookie: Option<String>,
130        tls_mode: TlsMode,
131    ) -> Result<Self, Error> {
132        let (event_tx, event_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
133
134        let task_cancel = cancel.clone();
135        tokio::spawn(async move {
136            ws_loop(ws_url, event_tx, reconnect, task_cancel, cookie, tls_mode).await;
137        });
138
139        Ok(Self { event_rx, cancel })
140    }
141
142    /// Get a new broadcast receiver for the event stream.
143    ///
144    /// Multiple consumers can subscribe concurrently. If a consumer falls
145    /// behind, it receives [`broadcast::error::RecvError::Lagged`].
146    pub fn subscribe(&self) -> broadcast::Receiver<Arc<UnifiEvent>> {
147        self.event_rx.resubscribe()
148    }
149
150    /// Signal the background task to shut down gracefully.
151    pub fn shutdown(&self) {
152        self.cancel.cancel();
153    }
154}
155
156// ── Background reconnection loop ─────────────────────────────────────
157
158/// Main loop: connect → read → on error, backoff → reconnect.
159async fn ws_loop(
160    ws_url: Url,
161    event_tx: broadcast::Sender<Arc<UnifiEvent>>,
162    reconnect: ReconnectConfig,
163    cancel: CancellationToken,
164    cookie: Option<String>,
165    tls_mode: TlsMode,
166) {
167    let mut attempt: u32 = 0;
168
169    loop {
170        tokio::select! {
171            biased;
172            () = cancel.cancelled() => break,
173            result = connect_and_read(&ws_url, &event_tx, &cancel, cookie.as_deref(), &tls_mode) => {
174                match result {
175                    // Clean disconnect (server close frame or stream ended).
176                    // Reset attempt counter and reconnect immediately.
177                    Ok(()) => {
178                        tracing::info!("WebSocket disconnected cleanly, reconnecting");
179                        attempt = 0;
180                    }
181                    Err(e) => {
182                        tracing::warn!(error = %e, attempt, "WebSocket error");
183
184                        if let Some(max) = reconnect.max_retries {
185                            if attempt >= max {
186                                tracing::error!(
187                                    max_retries = max,
188                                    "WebSocket reconnection limit reached, giving up"
189                                );
190                                break;
191                            }
192                        }
193
194                        let delay = calculate_backoff(attempt, &reconnect);
195                        let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
196                        tracing::info!(
197                            delay_ms,
198                            attempt,
199                            "Waiting before reconnect"
200                        );
201
202                        tokio::select! {
203                            biased;
204                            () = cancel.cancelled() => break,
205                            () = tokio::time::sleep(delay) => {}
206                        }
207
208                        attempt += 1;
209                    }
210                }
211            }
212        }
213    }
214
215    // Note: tracing after the loop is technically reachable (via break)
216    // but the compiler's macro expansion for select! can't prove it.
217    #[allow(unreachable_code)]
218    {
219        tracing::debug!("WebSocket loop exiting");
220    }
221}
222
223// ── Single connection lifecycle ──────────────────────────────────────
224
225/// Establish a single WebSocket connection, read messages until it drops.
226///
227/// If `cookie` is provided, it's injected as a `Cookie` header on the
228/// WebSocket upgrade request (required for legacy cookie-based auth).
229async fn connect_and_read(
230    url: &Url,
231    event_tx: &broadcast::Sender<Arc<UnifiEvent>>,
232    cancel: &CancellationToken,
233    cookie: Option<&str>,
234    tls_mode: &TlsMode,
235) -> Result<(), Error> {
236    tracing::info!(url = %url, "Connecting to WebSocket");
237
238    let uri: tungstenite::http::Uri = url
239        .as_str()
240        .parse()
241        .map_err(|e: tungstenite::http::uri::InvalidUri| Error::WebSocketConnect(e.to_string()))?;
242
243    let mut request = ClientRequestBuilder::new(uri);
244    if let Some(cookie_val) = cookie {
245        request = request.with_header("Cookie", cookie_val);
246    }
247
248    let connector = build_tls_connector(tls_mode)?;
249
250    let (ws_stream, _response) =
251        tokio_tungstenite::connect_async_tls_with_config(request, None, false, connector)
252            .await
253            .map_err(|e| Error::WebSocketConnect(e.to_string()))?;
254
255    tracing::info!("WebSocket connected");
256
257    let (_write, mut read) = ws_stream.split();
258
259    loop {
260        tokio::select! {
261            biased;
262            () = cancel.cancelled() => return Ok(()),
263            frame = read.next() => {
264                match frame {
265                    Some(Ok(tungstenite::Message::Text(text))) => {
266                        parse_and_broadcast(&text, event_tx);
267                    }
268                    Some(Ok(tungstenite::Message::Ping(_))) => {
269                        // tungstenite handles pong replies automatically
270                        tracing::trace!("WebSocket ping");
271                    }
272                    Some(Ok(tungstenite::Message::Close(frame))) => {
273                        if let Some(ref cf) = frame {
274                            tracing::info!(
275                                code = %cf.code,
276                                reason = %cf.reason,
277                                "WebSocket close frame received"
278                            );
279                        } else {
280                            tracing::info!("WebSocket close frame received (no payload)");
281                        }
282                        return Ok(());
283                    }
284                    Some(Err(e)) => {
285                        return Err(Error::WebSocketConnect(e.to_string()));
286                    }
287                    None => {
288                        // Stream ended without a close frame
289                        tracing::info!("WebSocket stream ended");
290                        return Ok(());
291                    }
292                    _ => {
293                        // Binary, Pong, Frame -- ignore
294                    }
295                }
296            }
297        }
298    }
299}
300
301// ── Message parsing ──────────────────────────────────────────────────
302
303/// Raw envelope the controller sends over the WebSocket.
304///
305/// All messages have the shape `{ "meta": { "rc": "ok", ... }, "data": [...] }`.
306#[derive(Debug, Deserialize)]
307struct WsEnvelope {
308    #[allow(dead_code)]
309    meta: WsMeta,
310    data: Vec<serde_json::Value>,
311}
312
313#[derive(Debug, Deserialize)]
314struct WsMeta {
315    #[allow(dead_code)]
316    rc: String,
317    #[serde(default)]
318    message: Option<String>,
319}
320
321/// Parse a WebSocket text frame and broadcast any events found inside.
322fn parse_and_broadcast(text: &str, event_tx: &broadcast::Sender<Arc<UnifiEvent>>) {
323    let envelope: WsEnvelope = match serde_json::from_str(text) {
324        Ok(e) => e,
325        Err(e) => {
326            tracing::debug!(error = %e, "Failed to parse WebSocket envelope");
327            return;
328        }
329    };
330
331    let msg_type = envelope.meta.message.as_deref().unwrap_or("");
332
333    // Only "events" messages contain discrete events with a `key` field.
334    // Sync messages ("device:sync", "sta:sync", etc.) are state dumps --
335    // we surface them as events too, using the message type as the key.
336    for data in envelope.data {
337        let event = match msg_type {
338            "events" => match serde_json::from_value::<UnifiEvent>(data.clone()) {
339                Ok(evt) => evt,
340                Err(e) => {
341                    tracing::debug!(
342                        error = %e,
343                        msg_type,
344                        "Could not deserialize event, constructing from raw data"
345                    );
346                    event_from_raw(msg_type, &data)
347                }
348            },
349            // Sync and other message types -- construct a synthetic event
350            _ => event_from_raw(msg_type, &data),
351        };
352
353        // Ignore send errors -- just means no active subscribers right now
354        let _ = event_tx.send(Arc::new(event));
355    }
356}
357
358/// Build a [`UnifiEvent`] from raw JSON when typed deserialization fails
359/// or the message is a sync/unknown type.
360fn event_from_raw(msg_type: &str, data: &serde_json::Value) -> UnifiEvent {
361    UnifiEvent {
362        key: data["key"].as_str().unwrap_or(msg_type).to_string(),
363        subsystem: data["subsystem"].as_str().unwrap_or("unknown").to_string(),
364        site_id: data["site_id"].as_str().unwrap_or("").to_string(),
365        message: data["msg"]
366            .as_str()
367            .or_else(|| data["message"].as_str())
368            .map(String::from),
369        datetime: data["datetime"].as_str().map(String::from),
370        extra: data.clone(),
371    }
372}
373
374// ── TLS connector ────────────────────────────────────────────────────
375
376/// Build a [`Connector`] matching the given [`TlsMode`].
377///
378/// - `System`: returns `None` (uses default webpki-roots verification).
379/// - `CustomCa`: loads a PEM CA file into a custom root store.
380/// - `DangerAcceptInvalid`: disables all certificate verification.
381fn build_tls_connector(tls_mode: &TlsMode) -> Result<Option<Connector>, Error> {
382    match tls_mode {
383        TlsMode::System => Ok(None),
384        TlsMode::CustomCa(path) => {
385            let root_store = load_root_store(path)?;
386            let tls_config = ClientConfig::builder()
387                .with_root_certificates(root_store)
388                .with_no_client_auth();
389            Ok(Some(Connector::Rustls(Arc::new(tls_config))))
390        }
391        TlsMode::DangerAcceptInvalid => {
392            let tls_config = ClientConfig::builder()
393                .dangerous()
394                .with_custom_certificate_verifier(Arc::new(NoVerifier))
395                .with_no_client_auth();
396            Ok(Some(Connector::Rustls(Arc::new(tls_config))))
397        }
398    }
399}
400
401/// Load a PEM CA file into a [`rustls::RootCertStore`].
402fn load_root_store(path: &Path) -> Result<rustls::RootCertStore, Error> {
403    use rustls_pki_types::pem::PemObject;
404
405    let mut root_store = rustls::RootCertStore::empty();
406    for cert in CertificateDer::pem_file_iter(path)
407        .map_err(|e| Error::Tls(format!("failed to read CA cert: {e}")))?
408    {
409        let cert = cert.map_err(|e| Error::Tls(format!("invalid PEM in CA file: {e}")))?;
410        root_store
411            .add(cert)
412            .map_err(|e| Error::Tls(format!("invalid CA cert: {e}")))?;
413    }
414    Ok(root_store)
415}
416
417/// Certificate verifier that accepts any server certificate.
418///
419/// Only used when `TlsMode::DangerAcceptInvalid` is selected (self-signed
420/// controllers). This is intentionally insecure.
421#[derive(Debug)]
422struct NoVerifier;
423
424impl rustls::client::danger::ServerCertVerifier for NoVerifier {
425    fn verify_server_cert(
426        &self,
427        _end_entity: &CertificateDer<'_>,
428        _intermediates: &[CertificateDer<'_>],
429        _server_name: &rustls::pki_types::ServerName<'_>,
430        _ocsp_response: &[u8],
431        _now: rustls::pki_types::UnixTime,
432    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
433        Ok(rustls::client::danger::ServerCertVerified::assertion())
434    }
435
436    fn verify_tls12_signature(
437        &self,
438        _message: &[u8],
439        _cert: &CertificateDer<'_>,
440        _dss: &rustls::DigitallySignedStruct,
441    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
442        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
443    }
444
445    fn verify_tls13_signature(
446        &self,
447        _message: &[u8],
448        _cert: &CertificateDer<'_>,
449        _dss: &rustls::DigitallySignedStruct,
450    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
451        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
452    }
453
454    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
455        rustls::crypto::ring::default_provider()
456            .signature_verification_algorithms
457            .supported_schemes()
458    }
459}
460
461// ── Backoff calculation ──────────────────────────────────────────────
462
463/// Exponential backoff with jitter.
464///
465/// `delay = min(initial * 2^attempt, max) + jitter`
466///
467/// Jitter is +-25% to spread out reconnection storms from multiple clients.
468fn calculate_backoff(attempt: u32, config: &ReconnectConfig) -> Duration {
469    let base = config.initial_delay.as_secs_f64()
470        * 2.0_f64.powi(i32::try_from(attempt).unwrap_or(i32::MAX));
471    let capped = base.min(config.max_delay.as_secs_f64());
472
473    // Deterministic "jitter" seeded from the attempt number.
474    // Not cryptographically random, but good enough for backoff spread.
475    let jitter_factor = 1.0 + 0.25 * ((f64::from(attempt) * 7.3).sin());
476    let with_jitter = (capped * jitter_factor).max(0.0);
477
478    Duration::from_secs_f64(with_jitter)
479}
480
481// ── Tests ────────────────────────────────────────────────────────────
482
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Feed `raw` through [`parse_and_broadcast`] and return the first event
    /// it publishes. Panics if nothing was published.
    fn first_broadcast(raw: &serde_json::Value) -> Arc<UnifiEvent> {
        let (tx, mut rx) = broadcast::channel(16);
        parse_and_broadcast(&raw.to_string(), &tx);
        rx.try_recv().unwrap()
    }

    #[test]
    fn default_reconnect_config() {
        let ReconnectConfig {
            initial_delay,
            max_delay,
            max_retries,
        } = ReconnectConfig::default();

        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(max_delay, Duration::from_secs(30));
        assert!(max_retries.is_none());
    }

    #[test]
    fn backoff_increases_exponentially() {
        let config = ReconnectConfig::default();
        let delays: Vec<Duration> = (0..3)
            .map(|attempt| calculate_backoff(attempt, &config))
            .collect();

        // Doubling per step outweighs the ±25% jitter, so the sequence must rise.
        for step in delays.windows(2) {
            assert!(
                step[1] > step[0],
                "{:?} should be greater than {:?} (all delays: {delays:?})",
                step[1],
                step[0]
            );
        }
    }

    #[test]
    fn backoff_caps_at_max_delay() {
        let capped_config = ReconnectConfig {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
            max_retries: None,
        };

        let late = calculate_backoff(10, &capped_config);
        // A jitter factor of at most 1.25 keeps the delay within 12.5s.
        assert!(
            late <= Duration::from_secs(13),
            "delay at attempt 10 ({late:?}) should be capped near max_delay"
        );
    }

    #[test]
    fn parse_event_from_raw_json() {
        let payload = serde_json::json!({
            "key": "EVT_WU_Connected",
            "subsystem": "wlan",
            "site_id": "abc123",
            "msg": "User[aa:bb:cc:dd:ee:ff] connected",
            "datetime": "2026-02-10T12:00:00Z",
            "user": "aa:bb:cc:dd:ee:ff",
            "ssid": "MyNetwork"
        });

        let parsed = event_from_raw("events", &payload);

        assert_eq!(parsed.key, "EVT_WU_Connected");
        assert_eq!(parsed.subsystem, "wlan");
        assert_eq!(parsed.site_id, "abc123");
        assert_eq!(
            parsed.message.as_deref(),
            Some("User[aa:bb:cc:dd:ee:ff] connected")
        );
        assert_eq!(parsed.datetime.as_deref(), Some("2026-02-10T12:00:00Z"));
    }

    #[test]
    fn parse_sync_event_from_raw_json() {
        let payload = serde_json::json!({
            "mac": "aa:bb:cc:dd:ee:ff",
            "state": 1,
            "site_id": "site1"
        });

        let parsed = event_from_raw("device:sync", &payload);

        assert_eq!(parsed.key, "device:sync");
        assert_eq!(parsed.subsystem, "unknown");
        assert_eq!(parsed.site_id, "site1");
    }

    #[test]
    fn deserialize_unifi_event() {
        let json = r#"{
            "key": "EVT_SW_Disconnected",
            "subsystem": "lan",
            "site_id": "default",
            "message": "Switch lost contact",
            "datetime": "2026-02-10T13:00:00Z",
            "sw": "aa:bb:cc:dd:ee:ff",
            "port": 4
        }"#;

        let parsed: UnifiEvent = serde_json::from_str(json).unwrap();

        assert_eq!(parsed.key, "EVT_SW_Disconnected");
        assert_eq!(parsed.subsystem, "lan");
        assert_eq!(parsed.site_id, "default");
        assert_eq!(parsed.message.as_deref(), Some("Switch lost contact"));
        // Fields without a dedicated struct member land in `extra`.
        assert_eq!(parsed.extra["sw"], "aa:bb:cc:dd:ee:ff");
        assert_eq!(parsed.extra["port"], 4);
    }

    #[test]
    fn parse_and_broadcast_events_message() {
        let frame = serde_json::json!({
            "meta": { "rc": "ok", "message": "events" },
            "data": [{
                "key": "EVT_WU_Connected",
                "subsystem": "wlan",
                "site_id": "default",
                "msg": "Client connected",
                "user": "aa:bb:cc:dd:ee:ff"
            }]
        });

        let received = first_broadcast(&frame);

        assert_eq!(received.key, "EVT_WU_Connected");
        assert_eq!(received.subsystem, "wlan");
    }

    #[test]
    fn parse_and_broadcast_sync_message() {
        let frame = serde_json::json!({
            "meta": { "rc": "ok", "message": "device:sync" },
            "data": [{
                "mac": "aa:bb:cc:dd:ee:ff",
                "state": 1,
                "site_id": "site1"
            }]
        });

        let received = first_broadcast(&frame);

        assert_eq!(received.key, "device:sync");
        assert_eq!(received.site_id, "site1");
    }

    #[test]
    fn parse_and_broadcast_malformed_json() {
        let (tx, mut rx) = broadcast::channel::<Arc<UnifiEvent>>(16);

        // Garbage input must be logged and skipped, never panic.
        parse_and_broadcast("not json at all", &tx);

        assert!(rx.try_recv().is_err());
    }
}