1use std::path::Path;
32use std::sync::Arc;
33use std::time::Duration;
34
35use futures_util::StreamExt;
36use rustls::ClientConfig;
37use rustls_pki_types::CertificateDer;
38use serde::{Deserialize, Serialize};
39use tokio::sync::broadcast;
40use tokio_tungstenite::Connector;
41use tokio_tungstenite::tungstenite::{self, ClientRequestBuilder};
42use tokio_util::sync::CancellationToken;
43use url::Url;
44
45use crate::error::Error;
46use crate::transport::TlsMode;
47
/// Number of events the broadcast channel buffers. A subscriber that falls
/// more than this many events behind loses the overwritten ones and receives
/// `RecvError::Lagged` on its next `recv`.
const EVENT_CHANNEL_CAPACITY: usize = 1 << 10;
51
/// A single event received over the controller WebSocket.
///
/// Values are produced in one of two ways:
/// - by deserializing an `"events"` payload entry directly; unmatched
///   fields then land in `extra`;
/// - leniently, from arbitrary JSON, by building the struct by hand with
///   placeholder fallbacks.
///
/// NOTE(review): because `extra` is `#[serde(flatten)]`, serializing an event
/// whose `extra` is not a JSON object will fail. An event built from the raw
/// object also repeats the named fields inside `extra`. Confirm that
/// consumers which re-serialize events tolerate this.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnifiEvent {
    /// Event key, e.g. `EVT_WU_Connected`. For non-event messages built from
    /// raw data this may instead be the envelope's message type.
    pub key: String,

    /// Subsystem the event belongs to, e.g. `wlan` or `lan`.
    pub subsystem: String,

    /// Identifier of the site that emitted the event.
    pub site_id: String,

    /// Human-readable event text, if any. The derived deserializer only reads
    /// a `message` field; payloads may also carry the text under `msg`.
    #[serde(default)]
    pub message: Option<String>,

    /// Timestamp string exactly as sent by the controller (not parsed here).
    #[serde(default)]
    pub datetime: Option<String>,

    /// Every remaining field of the payload, kept as raw JSON.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
81
/// Backoff policy the background WebSocket task follows when it reconnects.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Delay before the first retry. Each further failed attempt doubles it.
    pub initial_delay: Duration,

    /// Upper bound on the exponential delay, applied before jitter. The
    /// jitter multiplies the capped value by a factor in `[0.75, 1.25]`, so the
    /// actual wait can exceed this by up to 25%.
    pub max_delay: Duration,

    /// How many retries to make before giving up.
    /// - `None` retries forever.
    /// - `Some(n)` stops once `n` retries have failed since the last clean
    ///   disconnect.
    pub max_retries: Option<u32>,
}

impl Default for ReconnectConfig {
    /// One-second initial delay, a 30-second cap and unlimited retries.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(1_000);
        let max_delay = Duration::from_millis(30_000);
        Self {
            initial_delay,
            max_delay,
            max_retries: None,
        }
    }
}
107
108pub struct WebSocketHandle {
115 event_rx: broadcast::Receiver<Arc<UnifiEvent>>,
116 cancel: CancellationToken,
117}
118
119impl WebSocketHandle {
120 pub fn connect(
126 ws_url: Url,
127 reconnect: ReconnectConfig,
128 cancel: CancellationToken,
129 cookie: Option<String>,
130 tls_mode: TlsMode,
131 ) -> Result<Self, Error> {
132 let (event_tx, event_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
133
134 let task_cancel = cancel.clone();
135 tokio::spawn(async move {
136 ws_loop(ws_url, event_tx, reconnect, task_cancel, cookie, tls_mode).await;
137 });
138
139 Ok(Self { event_rx, cancel })
140 }
141
142 pub fn subscribe(&self) -> broadcast::Receiver<Arc<UnifiEvent>> {
147 self.event_rx.resubscribe()
148 }
149
150 pub fn shutdown(&self) {
152 self.cancel.cancel();
153 }
154}
155
156async fn ws_loop(
160 ws_url: Url,
161 event_tx: broadcast::Sender<Arc<UnifiEvent>>,
162 reconnect: ReconnectConfig,
163 cancel: CancellationToken,
164 cookie: Option<String>,
165 tls_mode: TlsMode,
166) {
167 let mut attempt: u32 = 0;
168
169 loop {
170 tokio::select! {
171 biased;
172 () = cancel.cancelled() => break,
173 result = connect_and_read(&ws_url, &event_tx, &cancel, cookie.as_deref(), &tls_mode) => {
174 match result {
175 Ok(()) => {
178 tracing::info!("WebSocket disconnected cleanly, reconnecting");
179 attempt = 0;
180 }
181 Err(e) => {
182 tracing::warn!(error = %e, attempt, "WebSocket error");
183
184 if let Some(max) = reconnect.max_retries
185 && attempt >= max {
186 tracing::error!(
187 max_retries = max,
188 "WebSocket reconnection limit reached, giving up"
189 );
190 break;
191 }
192
193 let delay = calculate_backoff(attempt, &reconnect);
194 let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
195 tracing::info!(
196 delay_ms,
197 attempt,
198 "Waiting before reconnect"
199 );
200
201 tokio::select! {
202 biased;
203 () = cancel.cancelled() => break,
204 () = tokio::time::sleep(delay) => {}
205 }
206
207 attempt += 1;
208 }
209 }
210 }
211 }
212 }
213
214 #[allow(unreachable_code)]
217 {
218 tracing::debug!("WebSocket loop exiting");
219 }
220}
221
222async fn connect_and_read(
229 url: &Url,
230 event_tx: &broadcast::Sender<Arc<UnifiEvent>>,
231 cancel: &CancellationToken,
232 cookie: Option<&str>,
233 tls_mode: &TlsMode,
234) -> Result<(), Error> {
235 tracing::info!(url = %url, "Connecting to WebSocket");
236
237 let uri: tungstenite::http::Uri = url
238 .as_str()
239 .parse()
240 .map_err(|e: tungstenite::http::uri::InvalidUri| Error::WebSocketConnect(e.to_string()))?;
241
242 let mut request = ClientRequestBuilder::new(uri);
243 if let Some(cookie_val) = cookie {
244 request = request.with_header("Cookie", cookie_val);
245 }
246
247 let connector = if url.scheme() == "wss" {
249 build_tls_connector(tls_mode)?
250 } else {
251 Some(Connector::Plain)
252 };
253
254 let (ws_stream, _response) =
255 tokio_tungstenite::connect_async_tls_with_config(request, None, false, connector)
256 .await
257 .map_err(|e| Error::WebSocketConnect(e.to_string()))?;
258
259 tracing::info!("WebSocket connected");
260
261 let (_write, mut read) = ws_stream.split();
262
263 loop {
264 tokio::select! {
265 biased;
266 () = cancel.cancelled() => return Ok(()),
267 frame = read.next() => {
268 match frame {
269 Some(Ok(tungstenite::Message::Text(text))) => {
270 parse_and_broadcast(&text, event_tx);
271 }
272 Some(Ok(tungstenite::Message::Ping(_))) => {
273 tracing::trace!("WebSocket ping");
275 }
276 Some(Ok(tungstenite::Message::Close(frame))) => {
277 if let Some(ref cf) = frame {
278 tracing::info!(
279 code = %cf.code,
280 reason = %cf.reason,
281 "WebSocket close frame received"
282 );
283 } else {
284 tracing::info!("WebSocket close frame received (no payload)");
285 }
286 return Ok(());
287 }
288 Some(Err(e)) => {
289 return Err(Error::WebSocketConnect(e.to_string()));
290 }
291 None => {
292 tracing::info!("WebSocket stream ended");
294 return Ok(());
295 }
296 _ => {
297 }
299 }
300 }
301 }
302 }
303}
304
305#[derive(Debug, Deserialize)]
311struct WsEnvelope {
312 #[allow(dead_code)]
313 meta: WsMeta,
314 data: Vec<serde_json::Value>,
315}
316
317#[derive(Debug, Deserialize)]
318struct WsMeta {
319 #[allow(dead_code)]
320 rc: String,
321 #[serde(default)]
322 message: Option<String>,
323}
324
325fn parse_and_broadcast(text: &str, event_tx: &broadcast::Sender<Arc<UnifiEvent>>) {
327 let envelope: WsEnvelope = match serde_json::from_str(text) {
328 Ok(e) => e,
329 Err(e) => {
330 tracing::debug!(error = %e, "Failed to parse WebSocket envelope");
331 return;
332 }
333 };
334
335 let msg_type = envelope.meta.message.as_deref().unwrap_or("");
336
337 for data in envelope.data {
341 let event = match msg_type {
342 "events" => match serde_json::from_value::<UnifiEvent>(data.clone()) {
343 Ok(evt) => evt,
344 Err(e) => {
345 tracing::debug!(
346 error = %e,
347 msg_type,
348 "Could not deserialize event, constructing from raw data"
349 );
350 event_from_raw(msg_type, &data)
351 }
352 },
353 _ => event_from_raw(msg_type, &data),
355 };
356
357 let _ = event_tx.send(Arc::new(event));
359 }
360}
361
362fn event_from_raw(msg_type: &str, data: &serde_json::Value) -> UnifiEvent {
365 UnifiEvent {
366 key: data["key"].as_str().unwrap_or(msg_type).to_string(),
367 subsystem: data["subsystem"].as_str().unwrap_or("unknown").to_string(),
368 site_id: data["site_id"].as_str().unwrap_or("").to_string(),
369 message: data["msg"]
370 .as_str()
371 .or_else(|| data["message"].as_str())
372 .map(String::from),
373 datetime: data["datetime"].as_str().map(String::from),
374 extra: data.clone(),
375 }
376}
377
378fn build_tls_connector(tls_mode: &TlsMode) -> Result<Option<Connector>, Error> {
386 let _ = rustls::crypto::ring::default_provider().install_default();
388
389 match tls_mode {
390 TlsMode::System => Ok(None),
391 TlsMode::CustomCa(path) => {
392 let root_store = load_root_store(path)?;
393 let tls_config = ClientConfig::builder()
394 .with_root_certificates(root_store)
395 .with_no_client_auth();
396 Ok(Some(Connector::Rustls(Arc::new(tls_config))))
397 }
398 TlsMode::DangerAcceptInvalid => {
399 let tls_config = ClientConfig::builder()
400 .dangerous()
401 .with_custom_certificate_verifier(Arc::new(NoVerifier))
402 .with_no_client_auth();
403 Ok(Some(Connector::Rustls(Arc::new(tls_config))))
404 }
405 }
406}
407
408fn load_root_store(path: &Path) -> Result<rustls::RootCertStore, Error> {
410 use rustls_pki_types::pem::PemObject;
411
412 let mut root_store = rustls::RootCertStore::empty();
413 for cert in CertificateDer::pem_file_iter(path)
414 .map_err(|e| Error::Tls(format!("failed to read CA cert: {e}")))?
415 {
416 let cert = cert.map_err(|e| Error::Tls(format!("invalid PEM in CA file: {e}")))?;
417 root_store
418 .add(cert)
419 .map_err(|e| Error::Tls(format!("invalid CA cert: {e}")))?;
420 }
421 Ok(root_store)
422}
423
424#[derive(Debug)]
429struct NoVerifier;
430
431impl rustls::client::danger::ServerCertVerifier for NoVerifier {
432 fn verify_server_cert(
433 &self,
434 _end_entity: &CertificateDer<'_>,
435 _intermediates: &[CertificateDer<'_>],
436 _server_name: &rustls::pki_types::ServerName<'_>,
437 _ocsp_response: &[u8],
438 _now: rustls::pki_types::UnixTime,
439 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
440 Ok(rustls::client::danger::ServerCertVerified::assertion())
441 }
442
443 fn verify_tls12_signature(
444 &self,
445 _message: &[u8],
446 _cert: &CertificateDer<'_>,
447 _dss: &rustls::DigitallySignedStruct,
448 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
449 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
450 }
451
452 fn verify_tls13_signature(
453 &self,
454 _message: &[u8],
455 _cert: &CertificateDer<'_>,
456 _dss: &rustls::DigitallySignedStruct,
457 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
458 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
459 }
460
461 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
462 rustls::crypto::ring::default_provider()
463 .signature_verification_algorithms
464 .supported_schemes()
465 }
466}
467
468fn calculate_backoff(attempt: u32, config: &ReconnectConfig) -> Duration {
476 let base = config.initial_delay.as_secs_f64()
477 * 2.0_f64.powi(i32::try_from(attempt).unwrap_or(i32::MAX));
478 let capped = base.min(config.max_delay.as_secs_f64());
479
480 let jitter_factor = 1.0 + 0.25 * ((f64::from(attempt) * 7.3).sin());
483 let with_jitter = (capped * jitter_factor).max(0.0);
484
485 Duration::from_secs_f64(with_jitter)
486}
487
// Unit tests for the pure helpers. No network, no runtime: only backoff,
// event construction and envelope parsing are exercised.
#[cfg(test)]
// `unwrap` is acceptable in tests: a panic is simply a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Pins the documented defaults (1 s initial, 30 s cap, unlimited retries).
    #[test]
    fn default_reconnect_config() {
        let config = ReconnectConfig::default();
        assert_eq!(config.initial_delay, Duration::from_secs(1));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert!(config.max_retries.is_none());
    }

    // With the default config the first few delays must strictly increase,
    // even after the +/-25% pseudo-jitter is applied.
    #[test]
    fn backoff_increases_exponentially() {
        let config = ReconnectConfig::default();

        let d0 = calculate_backoff(0, &config);
        let d1 = calculate_backoff(1, &config);
        let d2 = calculate_backoff(2, &config);

        assert!(d1 > d0, "d1 ({d1:?}) should be greater than d0 ({d0:?})");
        assert!(d2 > d1, "d2 ({d2:?}) should be greater than d1 ({d1:?})");
    }

    // 2^10 s far exceeds the 10 s cap. The bound is 13 s rather than 10 s
    // because jitter may scale the capped value by up to 1.25.
    #[test]
    fn backoff_caps_at_max_delay() {
        let config = ReconnectConfig {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
            max_retries: None,
        };

        let d10 = calculate_backoff(10, &config);
        assert!(
            d10 <= Duration::from_secs(13),
            "delay at attempt 10 ({d10:?}) should be capped near max_delay"
        );
    }

    // A well-formed event object: every known field is copied, and the text
    // is taken from `msg`.
    #[test]
    fn parse_event_from_raw_json() {
        let data = serde_json::json!({
            "key": "EVT_WU_Connected",
            "subsystem": "wlan",
            "site_id": "abc123",
            "msg": "User[aa:bb:cc:dd:ee:ff] connected",
            "datetime": "2026-02-10T12:00:00Z",
            "user": "aa:bb:cc:dd:ee:ff",
            "ssid": "MyNetwork"
        });

        let event = event_from_raw("events", &data);
        assert_eq!(event.key, "EVT_WU_Connected");
        assert_eq!(event.subsystem, "wlan");
        assert_eq!(event.site_id, "abc123");
        assert_eq!(
            event.message.as_deref(),
            Some("User[aa:bb:cc:dd:ee:ff] connected")
        );
        assert_eq!(event.datetime.as_deref(), Some("2026-02-10T12:00:00Z"));
    }

    // A non-event payload: the message type becomes the key and the
    // subsystem falls back to "unknown".
    #[test]
    fn parse_sync_event_from_raw_json() {
        let data = serde_json::json!({
            "mac": "aa:bb:cc:dd:ee:ff",
            "state": 1,
            "site_id": "site1"
        });

        let event = event_from_raw("device:sync", &data);
        assert_eq!(event.key, "device:sync");
        assert_eq!(event.subsystem, "unknown");
        assert_eq!(event.site_id, "site1");
    }

    // Derived deserialization: named fields are mapped, and unknown fields
    // ("sw", "port") are captured in `extra` through `#[serde(flatten)]`.
    #[test]
    fn deserialize_unifi_event() {
        let json = r#"{
            "key": "EVT_SW_Disconnected",
            "subsystem": "lan",
            "site_id": "default",
            "message": "Switch lost contact",
            "datetime": "2026-02-10T13:00:00Z",
            "sw": "aa:bb:cc:dd:ee:ff",
            "port": 4
        }"#;

        let event: UnifiEvent = serde_json::from_str(json).unwrap();
        assert_eq!(event.key, "EVT_SW_Disconnected");
        assert_eq!(event.subsystem, "lan");
        assert_eq!(event.site_id, "default");
        assert_eq!(event.message.as_deref(), Some("Switch lost contact"));
        assert_eq!(event.extra["sw"], "aa:bb:cc:dd:ee:ff");
        assert_eq!(event.extra["port"], 4);
    }

    // End-to-end envelope handling for an "events" frame: one data entry
    // produces exactly one broadcast event.
    #[test]
    fn parse_and_broadcast_events_message() {
        let (tx, mut rx) = broadcast::channel(16);

        let raw = serde_json::json!({
            "meta": { "rc": "ok", "message": "events" },
            "data": [{
                "key": "EVT_WU_Connected",
                "subsystem": "wlan",
                "site_id": "default",
                "msg": "Client connected",
                "user": "aa:bb:cc:dd:ee:ff"
            }]
        });

        parse_and_broadcast(&raw.to_string(), &tx);

        let event = rx.try_recv().unwrap();
        assert_eq!(event.key, "EVT_WU_Connected");
        assert_eq!(event.subsystem, "wlan");
    }

    // Non-event frames go through the raw path, keyed by the message type.
    #[test]
    fn parse_and_broadcast_sync_message() {
        let (tx, mut rx) = broadcast::channel(16);

        let raw = serde_json::json!({
            "meta": { "rc": "ok", "message": "device:sync" },
            "data": [{
                "mac": "aa:bb:cc:dd:ee:ff",
                "state": 1,
                "site_id": "site1"
            }]
        });

        parse_and_broadcast(&raw.to_string(), &tx);

        let event = rx.try_recv().unwrap();
        assert_eq!(event.key, "device:sync");
        assert_eq!(event.site_id, "site1");
    }

    // Garbage input is dropped silently: nothing is broadcast and nothing
    // panics.
    #[test]
    fn parse_and_broadcast_malformed_json() {
        let (tx, mut rx) = broadcast::channel::<Arc<UnifiEvent>>(16);

        parse_and_broadcast("not json at all", &tx);

        assert!(rx.try_recv().is_err());
    }
}