Skip to main content

haystack_server/
ws.rs

1//! WebSocket handler and watch subscription manager.
2//!
3//! This module provides two major components:
4//!
5//! 1. **`WatchManager`** — a thread-safe subscription registry that manages
6//!    watch lifecycles (subscribe, poll, unsubscribe, add/remove IDs).
7//!
//! 2. **`ws_handler`** — an Axum WebSocket upgrade endpoint (`GET /api/ws`)
//!    that handles Haystack watch operations over JSON messages and pushes
//!    changes to watched entities back to the client.
10
11use std::collections::{HashMap, HashSet};
12use std::time::{Duration, Instant};
13
14use axum::Extension;
15use axum::extract::State;
16use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
17use axum::response::Response;
18use parking_lot::RwLock;
19use serde_json::{Map, Value};
20use uuid::Uuid;
21
22use haystack_core::codecs::json::v3 as json_v3;
23use haystack_core::data::HDict;
24use haystack_core::graph::SharedGraph;
25
26use crate::auth::AuthUser;
27use crate::state::SharedState;
28
29// ---------------------------------------------------------------------------
30// Tuning constants
31// ---------------------------------------------------------------------------
32
/// Maximum number of watches across all users.
const MAX_WATCHES: usize = 100;
/// Maximum entity IDs a single watch may hold.
const MAX_ENTITY_IDS_PER_WATCH: usize = 1_000;
/// Maximum total entity IDs a single user can watch across all watches.
const MAX_TOTAL_WATCHED_IDS: usize = 5_000;
/// Maximum watches a single user can hold at once.
const MAX_WATCHES_PER_USER: usize = 20;

/// Maximum entries in the `WatchManager` encode cache.
///
/// The cache is owned by the shared `WatchManager` rather than by any single
/// connection, so this bound applies across all connections combined.
const MAX_ENCODE_CACHE_ENTRIES: usize = 50_000;

/// Server-initiated ping interval for liveness detection.
const PING_INTERVAL: Duration = Duration::from_secs(30);

/// If no pong is received within this duration after a ping, the connection
/// is considered dead and will be closed.
const PONG_TIMEOUT: Duration = Duration::from_secs(10);

/// mpsc channel capacity for outbound messages.
const CHANNEL_CAPACITY: usize = 64;

/// Number of consecutive `try_send` failures before closing a slow client.
const MAX_SEND_FAILURES: u32 = 3;
55
56// ---------------------------------------------------------------------------
57// WebSocket message types
58// ---------------------------------------------------------------------------
59
60/// Incoming JSON message from a WebSocket client.
61#[derive(serde::Deserialize, Debug)]
62struct WsRequest {
63    op: String,
64    #[serde(rename = "reqId")]
65    req_id: Option<String>,
66    #[serde(rename = "watchId")]
67    watch_id: Option<String>,
68    ids: Option<Vec<String>>,
69}
70
71/// Outgoing JSON message sent to a WebSocket client.
72#[derive(serde::Serialize, Debug)]
73struct WsResponse {
74    #[serde(rename = "reqId", skip_serializing_if = "Option::is_none")]
75    req_id: Option<String>,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    error: Option<String>,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    rows: Option<Vec<Value>>,
80    #[serde(rename = "watchId", skip_serializing_if = "Option::is_none")]
81    watch_id: Option<String>,
82}
83
84impl WsResponse {
85    /// Build an error response, preserving the request ID for correlation.
86    fn error(req_id: Option<String>, msg: impl Into<String>) -> Self {
87        Self {
88            req_id,
89            error: Some(msg.into()),
90            rows: None,
91            watch_id: None,
92        }
93    }
94
95    /// Build a success response with rows and an optional watch ID.
96    fn ok(req_id: Option<String>, rows: Vec<Value>, watch_id: Option<String>) -> Self {
97        Self {
98            req_id,
99            error: None,
100            rows: Some(rows),
101            watch_id,
102        }
103    }
104}
105
106// ---------------------------------------------------------------------------
107// Entity encoding helper
108// ---------------------------------------------------------------------------
109
110/// Encode an `HDict` entity as a JSON object using the Haystack JSON v3
111/// encoding for individual tag values.
112fn encode_entity(entity: &HDict) -> Value {
113    let mut m = Map::new();
114    let mut keys: Vec<&String> = entity.tags().keys().collect();
115    keys.sort();
116    for k in keys {
117        let v = &entity.tags()[k];
118        if let Ok(encoded) = json_v3::encode_kind(v) {
119            m.insert(k.clone(), encoded);
120        }
121    }
122    Value::Object(m)
123}
124
125// ---------------------------------------------------------------------------
126// WebSocket op dispatch
127// ---------------------------------------------------------------------------
128
129/// Handle a parsed `WsRequest` by dispatching to the appropriate watch op.
130fn handle_ws_request(req: &WsRequest, username: &str, state: &SharedState) -> String {
131    let resp = match req.op.as_str() {
132        "watchSub" => handle_watch_sub(req, username, state),
133        "watchPoll" => handle_watch_poll(req, username, state),
134        "watchUnsub" => handle_watch_unsub(req, username, state),
135        other => WsResponse::error(req.req_id.clone(), format!("unknown op: {other}")),
136    };
137    serde_json::to_string(&resp).unwrap_or_else(|e| {
138        let fallback = WsResponse::error(req.req_id.clone(), format!("serialization error: {e}"));
139        serde_json::to_string(&fallback).unwrap()
140    })
141}
142
143fn handle_watch_sub(req: &WsRequest, username: &str, state: &SharedState) -> WsResponse {
144    let ids = match &req.ids {
145        Some(ids) if !ids.is_empty() => ids.clone(),
146        _ => {
147            return WsResponse::error(
148                req.req_id.clone(),
149                "watchSub requires non-empty 'ids' array",
150            );
151        }
152    };
153
154    // Strip leading '@' from ref strings if present.
155    let ids: Vec<String> = ids
156        .into_iter()
157        .map(|id| id.strip_prefix('@').unwrap_or(&id).to_string())
158        .collect();
159
160    let graph_version = state.graph.version();
161    let watch_id = match state
162        .watches
163        .subscribe(username, ids.clone(), graph_version)
164    {
165        Ok(wid) => wid,
166        Err(e) => return WsResponse::error(req.req_id.clone(), e),
167    };
168
169    let rows: Vec<Value> = ids
170        .iter()
171        .filter_map(|id| state.graph.get(id).map(|e| encode_entity(&e)))
172        .collect();
173
174    WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
175}
176
177fn handle_watch_poll(req: &WsRequest, username: &str, state: &SharedState) -> WsResponse {
178    let watch_id = match &req.watch_id {
179        Some(wid) => wid.clone(),
180        None => {
181            return WsResponse::error(req.req_id.clone(), "watchPoll requires 'watchId'");
182        }
183    };
184
185    match state.watches.poll(&watch_id, username, &state.graph) {
186        Some(changed) => {
187            let rows: Vec<Value> = changed.iter().map(encode_entity).collect();
188            WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
189        }
190        None => WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}")),
191    }
192}
193
194fn handle_watch_unsub(req: &WsRequest, username: &str, state: &SharedState) -> WsResponse {
195    let watch_id = match &req.watch_id {
196        Some(wid) => wid.clone(),
197        None => {
198            return WsResponse::error(req.req_id.clone(), "watchUnsub requires 'watchId'");
199        }
200    };
201
202    if let Some(ids) = &req.ids
203        && !ids.is_empty()
204    {
205        let clean: Vec<String> = ids
206            .iter()
207            .map(|id| id.strip_prefix('@').unwrap_or(id).to_string())
208            .collect();
209        if !state.watches.remove_ids(&watch_id, username, &clean) {
210            return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
211        }
212        return WsResponse::ok(req.req_id.clone(), vec![], Some(watch_id));
213    }
214
215    if !state.watches.unsubscribe(&watch_id, username) {
216        return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
217    }
218    WsResponse::ok(req.req_id.clone(), vec![], None)
219}
220
/// One watch subscription: who owns it, which entities it covers, and how
/// far it has been polled.
struct Watch {
    /// Username of the owner; `WatchManager` ownership checks compare
    /// against this.
    owner: String,
    /// Entity IDs being watched.
    entity_ids: HashSet<String>,
    /// Graph version at the last poll (the subscription baseline before the
    /// first poll).
    last_version: u64,
}
230
231/// Manages watch subscriptions for change polling.
232pub struct WatchManager {
233    watches: RwLock<HashMap<String, Watch>>,
234    /// Cached entity encodings keyed by (ref_val, version) for watch poll.
235    encode_cache: RwLock<HashMap<(String, u64), Value>>,
236    /// Graph version at which the encode cache was last validated.
237    cache_version: RwLock<u64>,
238}
239
240impl WatchManager {
241    /// Create a new empty WatchManager.
242    pub fn new() -> Self {
243        Self {
244            watches: RwLock::new(HashMap::new()),
245            encode_cache: RwLock::new(HashMap::new()),
246            cache_version: RwLock::new(0),
247        }
248    }
249
250    /// Subscribe to changes on a set of entity IDs.
251    pub fn subscribe(
252        &self,
253        username: &str,
254        ids: Vec<String>,
255        graph_version: u64,
256    ) -> Result<String, String> {
257        let mut watches = self.watches.write();
258        if watches.len() >= MAX_WATCHES {
259            return Err("maximum number of watches reached".to_string());
260        }
261        let user_count = watches.values().filter(|w| w.owner == username).count();
262        if user_count >= MAX_WATCHES_PER_USER {
263            return Err(format!(
264                "user '{}' has reached the maximum of {} watches",
265                username, MAX_WATCHES_PER_USER
266            ));
267        }
268        if ids.len() > MAX_ENTITY_IDS_PER_WATCH {
269            return Err(format!(
270                "too many entity IDs (max {})",
271                MAX_ENTITY_IDS_PER_WATCH
272            ));
273        }
274        let user_total: usize = watches
275            .values()
276            .filter(|w| w.owner == username)
277            .map(|w| w.entity_ids.len())
278            .sum();
279        if user_total + ids.len() > MAX_TOTAL_WATCHED_IDS {
280            return Err(format!(
281                "user '{}' would exceed the maximum of {} total watched IDs",
282                username, MAX_TOTAL_WATCHED_IDS
283            ));
284        }
285        let watch_id = Uuid::new_v4().to_string();
286        let watch = Watch {
287            entity_ids: ids.into_iter().collect(),
288            last_version: graph_version,
289            owner: username.to_string(),
290        };
291        watches.insert(watch_id.clone(), watch);
292        Ok(watch_id)
293    }
294
295    /// Poll for changes since the last poll.
296    pub fn poll(&self, watch_id: &str, username: &str, graph: &SharedGraph) -> Option<Vec<HDict>> {
297        let (entity_ids, last_version) = {
298            let mut watches = self.watches.write();
299            let watch = watches.get_mut(watch_id)?;
300            if watch.owner != username {
301                return None;
302            }
303
304            let current_version = graph.version();
305            if current_version == watch.last_version {
306                return Some(Vec::new());
307            }
308
309            let ids = watch.entity_ids.clone();
310            let last = watch.last_version;
311            watch.last_version = current_version;
312            (ids, last)
313        };
314
315        let changes = match graph.changes_since(last_version) {
316            Ok(c) => c,
317            Err(_gap) => {
318                return Some(entity_ids.iter().filter_map(|id| graph.get(id)).collect());
319            }
320        };
321        let changed_refs: HashSet<&str> = changes.iter().map(|d| d.ref_val.as_str()).collect();
322
323        Some(
324            entity_ids
325                .iter()
326                .filter(|id| changed_refs.contains(id.as_str()))
327                .filter_map(|id| graph.get(id))
328                .collect(),
329        )
330    }
331
332    /// Unsubscribe a watch by ID.
333    pub fn unsubscribe(&self, watch_id: &str, username: &str) -> bool {
334        let mut watches = self.watches.write();
335        match watches.get(watch_id) {
336            Some(watch) if watch.owner == username => {
337                watches.remove(watch_id);
338                true
339            }
340            _ => false,
341        }
342    }
343
344    /// Add entity IDs to an existing watch.
345    pub fn add_ids(&self, watch_id: &str, username: &str, ids: Vec<String>) -> bool {
346        let mut watches = self.watches.write();
347
348        let (owner_ok, per_watch_ok, user_total) = match watches.get(watch_id) {
349            Some(watch) => (
350                watch.owner == username,
351                watch.entity_ids.len() + ids.len() <= MAX_ENTITY_IDS_PER_WATCH,
352                watches
353                    .values()
354                    .filter(|w| w.owner == username)
355                    .map(|w| w.entity_ids.len())
356                    .sum::<usize>(),
357            ),
358            None => return false,
359        };
360
361        if !owner_ok || !per_watch_ok {
362            return false;
363        }
364        if user_total + ids.len() > MAX_TOTAL_WATCHED_IDS {
365            return false;
366        }
367
368        if let Some(watch) = watches.get_mut(watch_id) {
369            watch.entity_ids.extend(ids);
370        }
371        true
372    }
373
374    /// Remove specific entity IDs from an existing watch.
375    pub fn remove_ids(&self, watch_id: &str, username: &str, ids: &[String]) -> bool {
376        let mut watches = self.watches.write();
377        if let Some(watch) = watches.get_mut(watch_id) {
378            if watch.owner != username {
379                return false;
380            }
381            for id in ids {
382                watch.entity_ids.remove(id);
383            }
384            true
385        } else {
386            false
387        }
388    }
389
390    /// Remove all watches owned by a given user.
391    pub fn remove_by_owner(&self, owner: &str) {
392        let mut watches = self.watches.write();
393        watches.retain(|_, w| w.owner != owner);
394    }
395
396    /// Return the list of entity IDs for a given watch.
397    pub fn get_ids(&self, watch_id: &str) -> Option<Vec<String>> {
398        let watches = self.watches.read();
399        watches
400            .get(watch_id)
401            .map(|w| w.entity_ids.iter().cloned().collect())
402    }
403
404    /// Return the number of active watches.
405    pub fn len(&self) -> usize {
406        self.watches.read().len()
407    }
408
409    /// Return whether there are no active watches.
410    pub fn is_empty(&self) -> bool {
411        self.watches.read().is_empty()
412    }
413
414    /// Encode an entity using the cache.
415    pub fn encode_cached(&self, ref_val: &str, graph_version: u64, entity: &HDict) -> Value {
416        {
417            let mut cv = self.cache_version.write();
418            if graph_version > *cv {
419                self.encode_cache.write().clear();
420                *cv = graph_version;
421            }
422        }
423
424        let key = (ref_val.to_string(), graph_version);
425        if let Some(cached) = self.encode_cache.read().get(&key) {
426            return cached.clone();
427        }
428
429        let encoded = encode_entity(entity);
430        let mut cache = self.encode_cache.write();
431        cache.insert(key, encoded.clone());
432        if cache.len() > MAX_ENCODE_CACHE_ENTRIES {
433            let to_remove = cache.len() / 4;
434            let keys: Vec<_> = cache.keys().take(to_remove).cloned().collect();
435            for k in keys {
436                cache.remove(&k);
437            }
438        }
439        encoded
440    }
441
442    /// Get the IDs of all entities watched by any watch.
443    pub fn all_watched_ids(&self) -> HashSet<String> {
444        let watches = self.watches.read();
445        watches
446            .values()
447            .flat_map(|w| w.entity_ids.iter().cloned())
448            .collect()
449    }
450
451    /// Find watches that contain any of the given changed ref_vals.
452    pub fn watches_affected_by(
453        &self,
454        changed_refs: &HashSet<&str>,
455    ) -> Vec<(String, String, Vec<String>)> {
456        let watches = self.watches.read();
457        let mut affected = Vec::new();
458        for (wid, watch) in watches.iter() {
459            let matched: Vec<String> = watch
460                .entity_ids
461                .iter()
462                .filter(|id| changed_refs.contains(id.as_str()))
463                .cloned()
464                .collect();
465            if !matched.is_empty() {
466                affected.push((wid.clone(), watch.owner.clone(), matched));
467            }
468        }
469        affected
470    }
471}
472
473impl Default for WatchManager {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
479// ---------------------------------------------------------------------------
480// WebSocket handler
481// ---------------------------------------------------------------------------
482
483/// WebSocket upgrade handler for `/api/ws`.
484pub async fn ws_handler(
485    ws: WebSocketUpgrade,
486    State(state): State<SharedState>,
487    auth: Option<Extension<AuthUser>>,
488) -> Response {
489    let username = auth
490        .map(|Extension(u)| u.username)
491        .unwrap_or_else(|| "anonymous".into());
492    ws.on_upgrade(move |socket| handle_socket(socket, username, state))
493}
494
495/// Handle a WebSocket connection after upgrade.
496async fn handle_socket(socket: WebSocket, username: String, state: SharedState) {
497    use tokio::sync::mpsc;
498
499    let (tx, mut rx) = mpsc::channel::<Message>(CHANNEL_CAPACITY);
500
501    // Spawn a task to forward messages from the channel to the WS session.
502    use futures_util::{SinkExt, StreamExt};
503
504    let (mut ws_sender, mut ws_receiver) = socket.split();
505
506    tokio::spawn(async move {
507        while let Some(msg) = rx.recv().await {
508            if ws_sender.send(msg).await.is_err() {
509                break;
510            }
511        }
512    });
513
514    // Track connection liveness.
515    let mut last_activity = Instant::now();
516    let mut ping_interval = tokio::time::interval(PING_INTERVAL);
517    ping_interval.tick().await; // consume the immediate first tick
518    let mut awaiting_pong = false;
519    let mut send_failures: u32 = 0;
520
521    // Track graph version for server-push change detection.
522    let mut last_push_version = state.graph.version();
523
524    // Server-push check interval.
525    let mut push_interval = tokio::time::interval(Duration::from_millis(500));
526    push_interval.tick().await;
527
528    loop {
529        tokio::select! {
530            // Incoming WS messages
531            msg = ws_receiver.next() => {
532                let Some(Ok(msg)) = msg else { break };
533                last_activity = Instant::now();
534                awaiting_pong = false;
535
536                match msg {
537                    Message::Text(text) => {
538                        let response_text = match serde_json::from_str::<WsRequest>(&text) {
539                            Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
540                            Err(e) => {
541                                let err = WsResponse::error(None, format!("invalid request: {e}"));
542                                serde_json::to_string(&err).unwrap()
543                            }
544                        };
545                        if tx.try_send(Message::Text(response_text.into())).is_err() {
546                            send_failures += 1;
547                            if send_failures >= MAX_SEND_FAILURES {
548                                log::warn!("closing slow WS client ({})", username);
549                                break;
550                            }
551                        } else {
552                            send_failures = 0;
553                        }
554                    }
555                    Message::Ping(_) | Message::Pong(_) => {
556                        awaiting_pong = false;
557                    }
558                    Message::Close(_) => {
559                        break;
560                    }
561                    _ => {}
562                }
563            }
564
565            // Server-initiated ping for liveness
566            _ = ping_interval.tick() => {
567                if awaiting_pong && last_activity.elapsed() > PONG_TIMEOUT {
568                    log::info!("closing stale WS connection ({}): no pong", username);
569                    break;
570                }
571                if tx.try_send(Message::Ping(vec![].into())).is_err() {
572                    break;
573                }
574                awaiting_pong = true;
575            }
576
577            // Server-push: check for graph changes
578            _ = push_interval.tick() => {
579                let current_version = state.graph.version();
580                if current_version > last_push_version {
581                    let changes = match state.graph.changes_since(last_push_version) {
582                        Ok(c) => c,
583                        Err(_gap) => {
584                            last_push_version = current_version;
585                            continue;
586                        }
587                    };
588                    let changed_refs: HashSet<&str> =
589                        changes.iter().map(|d| d.ref_val.as_str()).collect();
590
591                    let affected = state.watches.watches_affected_by(&changed_refs);
592                    for (watch_id, owner, changed_ids) in &affected {
593                        if owner != &username {
594                            continue;
595                        }
596                        let rows: Vec<Value> = changed_ids
597                            .iter()
598                            .filter_map(|id| {
599                                let entity = state.graph.get(id)?;
600                                Some(state.watches.encode_cached(id, current_version, &entity))
601                            })
602                            .collect();
603                        if !rows.is_empty() {
604                            let push_msg = serde_json::json!({
605                                "type": "push",
606                                "watchId": watch_id,
607                                "rows": rows,
608                            });
609                            if let Ok(text) = serde_json::to_string(&push_msg) {
610                                let _ = tx.try_send(Message::Text(text.into()));
611                            }
612                        }
613                    }
614                    last_push_version = current_version;
615                }
616            }
617        }
618    }
619
620    // Cleanup: remove all watches owned by this user on disconnect.
621    state.watches.remove_by_owner(&username);
622}
623
#[cfg(test)]
mod tests {
    use super::*;
    use haystack_core::graph::{EntityGraph, SharedGraph};
    use haystack_core::kinds::{HRef, Kind};

    /// Owned ID list built from string literals.
    fn ids(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| s.to_string()).collect()
    }

    /// A graph holding a single `site` entity with the given id.
    fn make_graph_with_entity(id: &str) -> SharedGraph {
        let mut site = HDict::new();
        site.set("id", Kind::Ref(HRef::from_val(id)));
        site.set("site", Kind::Marker);
        site.set("dis", Kind::Str(format!("Site {id}")));

        let graph = SharedGraph::new(EntityGraph::new());
        graph.add(site).unwrap();
        graph
    }

    /// Subscribe `admin` to `watched`, baselined at the graph's current version.
    fn admin_watch(wm: &WatchManager, graph: &SharedGraph, watched: &[&str]) -> String {
        wm.subscribe("admin", ids(watched), graph.version()).unwrap()
    }

    /// Whether `list` contains `id`.
    fn has(list: &[String], id: &str) -> bool {
        list.iter().any(|x| x == id)
    }

    #[test]
    fn subscribe_returns_watch_id() {
        let wm = WatchManager::new();
        let wid = wm.subscribe("admin", ids(&["site-1"]), 0).unwrap();
        assert!(!wid.is_empty());
    }

    #[test]
    fn poll_no_changes() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let wid = admin_watch(&wm, &graph, &["site-1"]);

        let polled = wm.poll(&wid, "admin", &graph).unwrap();
        assert!(polled.is_empty());
    }

    #[test]
    fn poll_with_changes() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let wid = admin_watch(&wm, &graph, &["site-1"]);

        let mut patch = HDict::new();
        patch.set("dis", Kind::Str("Updated".into()));
        graph.update("site-1", patch).unwrap();

        let polled = wm.poll(&wid, "admin", &graph).unwrap();
        assert_eq!(polled.len(), 1);
    }

    #[test]
    fn poll_unknown_watch() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        assert!(wm.poll("unknown", "admin", &graph).is_none());
    }

    #[test]
    fn poll_wrong_owner() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let wid = admin_watch(&wm, &graph, &["site-1"]);

        assert!(wm.poll(&wid, "other-user", &graph).is_none());
    }

    #[test]
    fn unsubscribe_removes_watch() {
        let wm = WatchManager::new();
        let wid = wm.subscribe("admin", ids(&["site-1"]), 0).unwrap();
        assert!(wm.unsubscribe(&wid, "admin"));
        assert!(!wm.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn unsubscribe_wrong_owner() {
        let wm = WatchManager::new();
        let wid = wm.subscribe("admin", ids(&["site-1"]), 0).unwrap();
        assert!(!wm.unsubscribe(&wid, "other-user"));
        assert!(wm.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn remove_ids_selective() {
        let wm = WatchManager::new();
        let wid = wm
            .subscribe("admin", ids(&["site-1", "site-2", "site-3"]), 0)
            .unwrap();

        assert!(wm.remove_ids(&wid, "admin", &ids(&["site-2"])));

        let left = wm.get_ids(&wid).unwrap();
        assert_eq!(left.len(), 2);
        assert!(has(&left, "site-1"));
        assert!(has(&left, "site-3"));
        assert!(!has(&left, "site-2"));
    }

    #[test]
    fn remove_ids_nonexistent_watch() {
        let wm = WatchManager::new();
        assert!(!wm.remove_ids("no-such-watch", "admin", &ids(&["site-1"])));
    }

    #[test]
    fn remove_ids_wrong_owner() {
        let wm = WatchManager::new();
        let wid = wm.subscribe("admin", ids(&["site-1", "site-2"]), 0).unwrap();

        assert!(!wm.remove_ids(&wid, "other-user", &ids(&["site-1"])));
        assert_eq!(wm.get_ids(&wid).unwrap().len(), 2);
    }

    #[test]
    fn remove_ids_leaves_watch_alive() {
        let wm = WatchManager::new();
        let wid = wm.subscribe("admin", ids(&["site-1", "site-2"]), 0).unwrap();

        assert!(wm.remove_ids(&wid, "admin", &ids(&["site-1", "site-2"])));
        assert!(wm.get_ids(&wid).unwrap().is_empty());

        // Emptied but still registered, so it can still be unsubscribed.
        assert!(wm.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn unsubscribe_full_removal() {
        let wm = WatchManager::new();
        let wid = wm.subscribe("admin", ids(&["site-1", "site-2"]), 0).unwrap();

        assert!(wm.unsubscribe(&wid, "admin"));
        assert!(wm.get_ids(&wid).is_none());
        assert!(!wm.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn add_ids_ownership_check() {
        let wm = WatchManager::new();
        let wid = wm.subscribe("admin", ids(&["site-1"]), 0).unwrap();

        assert!(!wm.add_ids(&wid, "other-user", ids(&["site-2"])));
        assert!(wm.add_ids(&wid, "admin", ids(&["site-2"])));

        let all = wm.get_ids(&wid).unwrap();
        assert_eq!(all.len(), 2);
        assert!(has(&all, "site-1"));
        assert!(has(&all, "site-2"));
    }

    #[test]
    fn get_ids_returns_none_for_unknown_watch() {
        let wm = WatchManager::new();
        assert!(wm.get_ids("nonexistent").is_none());
    }

    #[test]
    fn ws_request_deserialization() {
        let raw = r#"{
            "op": "watchSub",
            "reqId": "abc-123",
            "ids": ["@ref1", "@ref2"]
        }"#;
        let req: WsRequest = serde_json::from_str(raw).unwrap();
        assert_eq!(req.op, "watchSub");
        assert_eq!(req.req_id.as_deref(), Some("abc-123"));
        assert!(req.watch_id.is_none());
        assert_eq!(req.ids.unwrap(), ["@ref1", "@ref2"]);
    }

    #[test]
    fn ws_request_deserialization_minimal() {
        let raw = r#"{"op": "watchPoll", "watchId": "w-1"}"#;
        let req: WsRequest = serde_json::from_str(raw).unwrap();
        assert_eq!(req.op, "watchPoll");
        assert!(req.req_id.is_none());
        assert_eq!(req.watch_id.as_deref(), Some("w-1"));
        assert!(req.ids.is_none());
    }

    #[test]
    fn ws_response_serialization() {
        let row = serde_json::json!({"id": "r:site-1"});
        let resp = WsResponse::ok(Some("r-1".into()), vec![row], Some("w-1".into()));

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "r-1");
        assert_eq!(json["watchId"], "w-1");
        assert!(json["rows"].is_array());
        assert_eq!(json["rows"][0]["id"], "r:site-1");
        assert!(json.get("error").is_none());
    }

    #[test]
    fn ws_response_omits_none_fields() {
        let json = serde_json::to_value(&WsResponse::ok(None, vec![], None)).unwrap();
        assert!(json.get("reqId").is_none());
        assert!(json.get("error").is_none());
        assert!(json.get("watchId").is_none());
        assert!(json["rows"].is_array());
    }

    #[test]
    fn ws_response_includes_req_id() {
        let resp = WsResponse::error(Some("req-42".into()), "something went wrong");
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "req-42");
        assert_eq!(json["error"], "something went wrong");
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }

    #[test]
    fn ws_error_response_format() {
        let json = serde_json::to_value(&WsResponse::error(None, "bad request")).unwrap();
        assert_eq!(json["error"], "bad request");
        assert!(json.get("reqId").is_none());
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }
}