Skip to main content

haystack_server/
ws.rs

1//! WebSocket handler and watch subscription manager.
2//!
3//! This module provides two major components:
4//!
5//! 1. **`WatchManager`** — a thread-safe subscription registry that manages
6//!    watch lifecycles (subscribe, poll, unsubscribe, add/remove IDs). Each
7//!    watch tracks a set of entity IDs and the graph version at last poll,
8//!    enabling efficient change detection.
9//!
10//! 2. **`ws_handler`** — an Actix-Web WebSocket upgrade endpoint (`GET /api/ws`)
11//!    that handles Haystack watch operations over JSON messages. Supports
12//!    server-initiated ping/pong liveness, deflate compression for large
13//!    payloads, and automatic server-push of graph changes to watching clients.
14
15use std::collections::{HashMap, HashSet};
16use std::time::{Duration, Instant};
17
18use actix_web::{HttpMessage, HttpRequest, HttpResponse, web};
19use parking_lot::RwLock;
20use serde_json::{Map, Value};
21use uuid::Uuid;
22
23use haystack_core::codecs::json::v3 as json_v3;
24use haystack_core::data::HDict;
25use haystack_core::graph::SharedGraph;
26
27use crate::state::AppState;
28
29// ---------------------------------------------------------------------------
30// Tuning constants
31// ---------------------------------------------------------------------------
32
// -- Watch registry limits --------------------------------------------------

/// Upper bound on concurrently registered watches, across all users.
const MAX_WATCHES: usize = 100;
/// Upper bound on watches held by any single user at one time.
const MAX_WATCHES_PER_USER: usize = 20;
/// Upper bound on entity IDs tracked by a single watch.
const MAX_ENTITY_IDS_PER_WATCH: usize = 10_000;

// -- Connection liveness ----------------------------------------------------

/// How often the server pings each client to check that it is still alive.
const PING_INTERVAL: Duration = Duration::from_secs(30);

/// Grace period for answering a server ping. The WebSocket loop evaluates it
/// only when the next ping tick fires, so an unresponsive peer is dropped
/// roughly one `PING_INTERVAL` after the unanswered ping.
const PONG_TIMEOUT: Duration = Duration::from_secs(10);

// -- Outbound flow control --------------------------------------------------

/// Capacity of the bounded per-connection queue of outbound frames.
const CHANNEL_CAPACITY: usize = 64;

/// Consecutive failed non-blocking enqueues (`try_send`) tolerated before a
/// slow client is disconnected.
const MAX_SEND_FAILURES: u32 = 3;

/// Responses shorter than this many bytes are never deflate-compressed.
const COMPRESSION_THRESHOLD: usize = 512;
53
54// ---------------------------------------------------------------------------
55// WebSocket message types
56// ---------------------------------------------------------------------------
57
58/// Incoming JSON message from a WebSocket client.
59#[derive(serde::Deserialize, Debug)]
60struct WsRequest {
61    op: String,
62    #[serde(rename = "reqId")]
63    req_id: Option<String>,
64    #[serde(rename = "watchDis")]
65    #[allow(dead_code)]
66    watch_dis: Option<String>,
67    #[serde(rename = "watchId")]
68    watch_id: Option<String>,
69    ids: Option<Vec<String>>,
70}
71
72/// Outgoing JSON message sent to a WebSocket client.
73#[derive(serde::Serialize, Debug)]
74struct WsResponse {
75    #[serde(rename = "reqId", skip_serializing_if = "Option::is_none")]
76    req_id: Option<String>,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    error: Option<String>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    rows: Option<Vec<Value>>,
81    #[serde(rename = "watchId", skip_serializing_if = "Option::is_none")]
82    watch_id: Option<String>,
83}
84
85impl WsResponse {
86    /// Build an error response, preserving the request ID for correlation.
87    fn error(req_id: Option<String>, msg: impl Into<String>) -> Self {
88        Self {
89            req_id,
90            error: Some(msg.into()),
91            rows: None,
92            watch_id: None,
93        }
94    }
95
96    /// Build a success response with rows and an optional watch ID.
97    fn ok(req_id: Option<String>, rows: Vec<Value>, watch_id: Option<String>) -> Self {
98        Self {
99            req_id,
100            error: None,
101            rows: Some(rows),
102            watch_id,
103        }
104    }
105}
106
107// ---------------------------------------------------------------------------
108// Entity encoding helper
109// ---------------------------------------------------------------------------
110
111/// Encode an `HDict` entity as a JSON object using the Haystack JSON v3
112/// encoding for individual tag values.
113fn encode_entity(entity: &HDict) -> Value {
114    let mut m = Map::new();
115    let mut keys: Vec<&String> = entity.tags().keys().collect();
116    keys.sort();
117    for k in keys {
118        let v = &entity.tags()[k];
119        if let Ok(encoded) = json_v3::encode_kind(v) {
120            m.insert(k.clone(), encoded);
121        }
122    }
123    Value::Object(m)
124}
125
126// ---------------------------------------------------------------------------
127// WebSocket op dispatch
128// ---------------------------------------------------------------------------
129
130/// Handle a parsed `WsRequest` by dispatching to the appropriate watch op.
131///
132/// Returns the serialized JSON response string.
133fn handle_ws_request(req: &WsRequest, username: &str, state: &AppState) -> String {
134    let resp = match req.op.as_str() {
135        "watchSub" => handle_watch_sub(req, username, state),
136        "watchPoll" => handle_watch_poll(req, username, state),
137        "watchUnsub" => handle_watch_unsub(req, username, state),
138        other => WsResponse::error(req.req_id.clone(), format!("unknown op: {other}")),
139    };
140    // Serialization of WsResponse should never fail in practice.
141    serde_json::to_string(&resp).unwrap_or_else(|e| {
142        let fallback = WsResponse::error(req.req_id.clone(), format!("serialization error: {e}"));
143        serde_json::to_string(&fallback).unwrap()
144    })
145}
146
147/// Handle the `watchSub` op: create a new watch and return the initial
148/// state of the subscribed entities.
149fn handle_watch_sub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
150    let ids = match &req.ids {
151        Some(ids) if !ids.is_empty() => ids.clone(),
152        _ => {
153            return WsResponse::error(
154                req.req_id.clone(),
155                "watchSub requires non-empty 'ids' array",
156            );
157        }
158    };
159
160    // Strip leading '@' from ref strings if present.
161    let ids: Vec<String> = ids
162        .into_iter()
163        .map(|id| id.strip_prefix('@').unwrap_or(&id).to_string())
164        .collect();
165
166    let graph_version = state.graph.version();
167    let watch_id = match state
168        .watches
169        .subscribe(username, ids.clone(), graph_version)
170    {
171        Ok(wid) => wid,
172        Err(e) => return WsResponse::error(req.req_id.clone(), e),
173    };
174
175    // Resolve initial entity values.
176    let rows: Vec<Value> = ids
177        .iter()
178        .filter_map(|id| state.graph.get(id).map(|e| encode_entity(&e)))
179        .collect();
180
181    WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
182}
183
184/// Handle the `watchPoll` op: poll an existing watch for changes.
185fn handle_watch_poll(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
186    let watch_id = match &req.watch_id {
187        Some(wid) => wid.clone(),
188        None => {
189            return WsResponse::error(req.req_id.clone(), "watchPoll requires 'watchId'");
190        }
191    };
192
193    match state.watches.poll(&watch_id, username, &state.graph) {
194        Some(changed) => {
195            let rows: Vec<Value> = changed.iter().map(encode_entity).collect();
196            WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
197        }
198        None => WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}")),
199    }
200}
201
202/// Handle the `watchUnsub` op: remove a watch or specific IDs from it.
203fn handle_watch_unsub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
204    let watch_id = match &req.watch_id {
205        Some(wid) => wid.clone(),
206        None => {
207            return WsResponse::error(req.req_id.clone(), "watchUnsub requires 'watchId'");
208        }
209    };
210
211    // If specific IDs are provided, remove only those; otherwise remove the
212    // entire watch.
213    if let Some(ids) = &req.ids
214        && !ids.is_empty()
215    {
216        let clean: Vec<String> = ids
217            .iter()
218            .map(|id| id.strip_prefix('@').unwrap_or(id).to_string())
219            .collect();
220        if !state.watches.remove_ids(&watch_id, username, &clean) {
221            return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
222        }
223        return WsResponse::ok(req.req_id.clone(), vec![], Some(watch_id));
224    }
225
226    if !state.watches.unsubscribe(&watch_id, username) {
227        return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
228    }
229    WsResponse::ok(req.req_id.clone(), vec![], None)
230}
231
/// State of one watch subscription.
struct Watch {
    /// User who created the watch; only this user may poll or modify it.
    owner: String,
    /// Graph version observed at the most recent poll (or at subscribe time).
    last_version: u64,
    /// Watched entity IDs, kept as a set for O(1) membership checks.
    entity_ids: HashSet<String>,
}
241
242/// Manages watch subscriptions for change polling.
243///
244/// Watches are keyed by a UUID watch ID and owned by a specific user.
245/// The manager enforces global and per-user watch limits, per-watch entity
246/// ID limits, and provides an entity encoding cache for efficient
247/// WebSocket server-push serialization.
248pub struct WatchManager {
249    watches: RwLock<HashMap<String, Watch>>,
250    /// Cached entity encodings keyed by (ref_val, version) for watch poll.
251    encode_cache: RwLock<HashMap<(String, u64), Value>>,
252    /// Graph version at which the encode cache was last validated.
253    cache_version: RwLock<u64>,
254}
255
256impl WatchManager {
257    /// Create a new empty WatchManager.
258    pub fn new() -> Self {
259        Self {
260            watches: RwLock::new(HashMap::new()),
261            encode_cache: RwLock::new(HashMap::new()),
262            cache_version: RwLock::new(0),
263        }
264    }
265
266    /// Subscribe to changes on a set of entity IDs.
267    ///
268    /// Returns the watch ID, or an error if a growth cap would be exceeded.
269    pub fn subscribe(
270        &self,
271        username: &str,
272        ids: Vec<String>,
273        graph_version: u64,
274    ) -> Result<String, String> {
275        let mut watches = self.watches.write();
276        if watches.len() >= MAX_WATCHES {
277            return Err("maximum number of watches reached".to_string());
278        }
279        let user_count = watches.values().filter(|w| w.owner == username).count();
280        if user_count >= MAX_WATCHES_PER_USER {
281            return Err(format!(
282                "user '{}' has reached the maximum of {} watches",
283                username, MAX_WATCHES_PER_USER
284            ));
285        }
286        if ids.len() > MAX_ENTITY_IDS_PER_WATCH {
287            return Err(format!(
288                "too many entity IDs (max {})",
289                MAX_ENTITY_IDS_PER_WATCH
290            ));
291        }
292        let watch_id = Uuid::new_v4().to_string();
293        let watch = Watch {
294            entity_ids: ids.into_iter().collect(),
295            last_version: graph_version,
296            owner: username.to_string(),
297        };
298        watches.insert(watch_id.clone(), watch);
299        Ok(watch_id)
300    }
301
302    /// Poll for changes since the last poll.
303    ///
304    /// Returns the current state of watched entities that have changed,
305    /// or `None` if the watch ID is not found or the caller is not the owner.
306    pub fn poll(&self, watch_id: &str, username: &str, graph: &SharedGraph) -> Option<Vec<HDict>> {
307        // Acquire the write lock only long enough to read watch state and
308        // update last_version.  Graph reads happen outside the lock to
309        // avoid holding it during potentially expensive I/O.
310        let (entity_ids, last_version) = {
311            let mut watches = self.watches.write();
312            let watch = watches.get_mut(watch_id)?;
313            if watch.owner != username {
314                return None;
315            }
316
317            let current_version = graph.version();
318            if current_version == watch.last_version {
319                // No changes
320                return Some(Vec::new());
321            }
322
323            let ids = watch.entity_ids.clone();
324            let last = watch.last_version;
325            watch.last_version = current_version;
326            (ids, last)
327        }; // write lock released here
328
329        // Graph reads happen without the WatchManager write lock held.
330        let changes = graph.changes_since(last_version);
331        let changed_refs: HashSet<&str> = changes.iter().map(|d| d.ref_val.as_str()).collect();
332
333        Some(
334            entity_ids
335                .iter()
336                .filter(|id| changed_refs.contains(id.as_str()))
337                .filter_map(|id| graph.get(id))
338                .collect(),
339        )
340    }
341
342    /// Unsubscribe a watch by ID.
343    ///
344    /// Returns `true` if the watch existed, was owned by `username`, and was removed.
345    pub fn unsubscribe(&self, watch_id: &str, username: &str) -> bool {
346        let mut watches = self.watches.write();
347        match watches.get(watch_id) {
348            Some(watch) if watch.owner == username => {
349                watches.remove(watch_id);
350                true
351            }
352            _ => false,
353        }
354    }
355
356    /// Add entity IDs to an existing watch.
357    ///
358    /// Returns `true` if the watch exists, is owned by `username`, and
359    /// the addition would not exceed the per-watch entity ID limit.
360    pub fn add_ids(&self, watch_id: &str, username: &str, ids: Vec<String>) -> bool {
361        let mut watches = self.watches.write();
362        if let Some(watch) = watches.get_mut(watch_id) {
363            if watch.owner != username {
364                return false;
365            }
366            if watch.entity_ids.len() + ids.len() > MAX_ENTITY_IDS_PER_WATCH {
367                return false;
368            }
369            watch.entity_ids.extend(ids);
370            true
371        } else {
372            false
373        }
374    }
375
376    /// Remove specific entity IDs from an existing watch.
377    ///
378    /// Returns `true` if the watch exists and is owned by `username`.
379    /// If all IDs are removed, the watch remains but with an empty entity set.
380    pub fn remove_ids(&self, watch_id: &str, username: &str, ids: &[String]) -> bool {
381        let mut watches = self.watches.write();
382        if let Some(watch) = watches.get_mut(watch_id) {
383            if watch.owner != username {
384                return false;
385            }
386            for id in ids {
387                watch.entity_ids.remove(id);
388            }
389            true
390        } else {
391            false
392        }
393    }
394
395    /// Return the list of entity IDs for a given watch.
396    ///
397    /// Returns `None` if the watch does not exist.
398    pub fn get_ids(&self, watch_id: &str) -> Option<Vec<String>> {
399        let watches = self.watches.read();
400        watches
401            .get(watch_id)
402            .map(|w| w.entity_ids.iter().cloned().collect())
403    }
404
405    /// Return the number of active watches.
406    pub fn len(&self) -> usize {
407        self.watches.read().len()
408    }
409
410    /// Return whether there are no active watches.
411    pub fn is_empty(&self) -> bool {
412        self.watches.read().is_empty()
413    }
414
415    /// Encode an entity using the cache. Returns cached value if the entity
416    /// version hasn't changed; otherwise encodes and caches the result.
417    pub fn encode_cached(&self, ref_val: &str, graph_version: u64, entity: &HDict) -> Value {
418        // Invalidate entire cache when graph version advances beyond what we've seen.
419        {
420            let mut cv = self.cache_version.write();
421            if graph_version > *cv {
422                self.encode_cache.write().clear();
423                *cv = graph_version;
424            }
425        }
426
427        let key = (ref_val.to_string(), graph_version);
428        if let Some(cached) = self.encode_cache.read().get(&key) {
429            return cached.clone();
430        }
431
432        let encoded = encode_entity(entity);
433        self.encode_cache.write().insert(key, encoded.clone());
434        encoded
435    }
436
437    /// Get the IDs of all entities watched by any watch, for server-push
438    /// change detection.
439    pub fn all_watched_ids(&self) -> HashSet<String> {
440        let watches = self.watches.read();
441        watches
442            .values()
443            .flat_map(|w| w.entity_ids.iter().cloned())
444            .collect()
445    }
446
447    /// Find watches that contain any of the given changed ref_vals.
448    /// Returns (watch_id, owner, changed_entity_ids) tuples.
449    pub fn watches_affected_by(
450        &self,
451        changed_refs: &HashSet<&str>,
452    ) -> Vec<(String, String, Vec<String>)> {
453        let watches = self.watches.read();
454        let mut affected = Vec::new();
455        for (wid, watch) in watches.iter() {
456            let matched: Vec<String> = watch
457                .entity_ids
458                .iter()
459                .filter(|id| changed_refs.contains(id.as_str()))
460                .cloned()
461                .collect();
462            if !matched.is_empty() {
463                affected.push((wid.clone(), watch.owner.clone(), matched));
464            }
465        }
466        affected
467    }
468}
469
470impl Default for WatchManager {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
476// ---------------------------------------------------------------------------
477// Compression helpers (application-level deflate)
478// ---------------------------------------------------------------------------
479
480/// Compress a response string with deflate if it exceeds the threshold.
481/// Returns the original text if compression doesn't save space.
482fn maybe_compress_response(text: &str) -> WsPayload {
483    if text.len() < COMPRESSION_THRESHOLD {
484        return WsPayload::Text(text.to_string());
485    }
486    let compressed = compress_deflate(text.as_bytes());
487    if compressed.len() < text.len() {
488        WsPayload::Binary(compressed)
489    } else {
490        WsPayload::Text(text.to_string())
491    }
492}
493
494fn compress_deflate(data: &[u8]) -> Vec<u8> {
495    use flate2::Compression;
496    use flate2::write::DeflateEncoder;
497    use std::io::Write;
498
499    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
500    let _ = encoder.write_all(data);
501    encoder.finish().unwrap_or_else(|_| data.to_vec())
502}
503
504/// Maximum decompressed payload size (10 MB) to prevent zip bomb attacks.
505const MAX_DECOMPRESSED_SIZE: u64 = 10 * 1024 * 1024;
506
507fn decompress_deflate(data: &[u8]) -> Result<String, std::io::Error> {
508    use flate2::read::DeflateDecoder;
509    use std::io::Read;
510
511    let decoder = DeflateDecoder::new(data);
512    let mut limited = decoder.take(MAX_DECOMPRESSED_SIZE);
513    let mut output = String::new();
514    limited.read_to_string(&mut output)?;
515    Ok(output)
516}
517
/// One outbound frame queued for the session-writer task.
enum WsPayload {
    /// Uncompressed JSON, written as a text frame.
    Text(String),
    /// Deflate-compressed JSON, written as a binary frame.
    Binary(Vec<u8>),
}
522
523// ---------------------------------------------------------------------------
524// WebSocket handler
525// ---------------------------------------------------------------------------
526
527/// WebSocket upgrade handler for `/api/ws`.
528///
529/// Upgrades the HTTP connection to a WebSocket and handles Haystack
530/// watch operations (watchSub, watchPoll, watchUnsub) over JSON
531/// messages.  Each client request may include a `reqId` field which
532/// is echoed back in the response for correlation.
533///
534/// Features:
535/// - Server-initiated ping every [`PING_INTERVAL`] for liveness detection
536/// - Backpressure: slow clients are disconnected after [`MAX_SEND_FAILURES`]
537/// - Deflate compression for large responses (binary frames)
538/// - Server-push: graph changes are pushed to watching clients automatically
539pub async fn ws_handler(
540    req: HttpRequest,
541    stream: web::Payload,
542    state: web::Data<AppState>,
543) -> actix_web::Result<HttpResponse> {
544    // Require authenticated user when auth is enabled
545    let username = if state.auth.is_enabled() {
546        match req.extensions().get::<crate::auth::AuthUser>() {
547            Some(u) => u.username.clone(),
548            None => {
549                return Err(crate::error::HaystackError::new(
550                    "authentication required for WebSocket connections",
551                    actix_web::http::StatusCode::UNAUTHORIZED,
552                )
553                .into());
554            }
555        }
556    } else {
557        req.extensions()
558            .get::<crate::auth::AuthUser>()
559            .map(|u| u.username.clone())
560            .unwrap_or_else(|| "anonymous".to_string())
561    };
562
563    let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
564
565    actix_rt::spawn(async move {
566        use actix_ws::Message;
567        use tokio::sync::mpsc;
568
569        let (tx, mut rx) = mpsc::channel::<WsPayload>(CHANNEL_CAPACITY);
570
571        // Spawn a task to forward messages from the channel to the WS session.
572        let mut session_clone = session.clone();
573        actix_rt::spawn(async move {
574            while let Some(payload) = rx.recv().await {
575                let result = match payload {
576                    WsPayload::Text(text) => session_clone.text(text).await,
577                    WsPayload::Binary(data) => session_clone.binary(data).await,
578                };
579                if result.is_err() {
580                    break;
581                }
582            }
583        });
584
585        // Track connection liveness.
586        let mut last_activity = Instant::now();
587        let mut ping_interval = tokio::time::interval(PING_INTERVAL);
588        ping_interval.tick().await; // consume the immediate first tick
589        let mut awaiting_pong = false;
590        let mut send_failures: u32 = 0;
591
592        // Track graph version for server-push change detection.
593        let mut last_push_version = state.graph.version();
594
595        // Server-push check interval (faster than ping but slower than a busy loop).
596        let mut push_interval = tokio::time::interval(Duration::from_millis(500));
597        push_interval.tick().await;
598
599        use futures_util::StreamExt as _;
600
601        loop {
602            tokio::select! {
603                // ── Incoming WS messages ──
604                msg = msg_stream.next() => {
605                    let Some(Ok(msg)) = msg else { break };
606                    last_activity = Instant::now();
607                    awaiting_pong = false;
608
609                    match msg {
610                        Message::Text(text) => {
611                            let response_text = match serde_json::from_str::<WsRequest>(&text) {
612                                Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
613                                Err(e) => {
614                                    let err = WsResponse::error(None, format!("invalid request: {e}"));
615                                    serde_json::to_string(&err).unwrap()
616                                }
617                            };
618                            let payload = maybe_compress_response(&response_text);
619                            if tx.try_send(payload).is_err() {
620                                send_failures += 1;
621                                if send_failures >= MAX_SEND_FAILURES {
622                                    log::warn!("closing slow WS client ({})", username);
623                                    break;
624                                }
625                            } else {
626                                send_failures = 0;
627                            }
628                        }
629                        Message::Binary(data) => {
630                            // Compressed request from client.
631                            if let Ok(text) = decompress_deflate(&data) {
632                                let response_text = match serde_json::from_str::<WsRequest>(&text) {
633                                    Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
634                                    Err(e) => {
635                                        let err = WsResponse::error(None, format!("invalid request: {e}"));
636                                        serde_json::to_string(&err).unwrap()
637                                    }
638                                };
639                                let payload = maybe_compress_response(&response_text);
640                                let _ = tx.try_send(payload);
641                            }
642                        }
643                        Message::Ping(bytes) => {
644                            let _ = session.pong(&bytes).await;
645                        }
646                        Message::Pong(_) => {
647                            awaiting_pong = false;
648                        }
649                        Message::Close(_) => {
650                            break;
651                        }
652                        _ => {}
653                    }
654                }
655
656                // ── Server-initiated ping for liveness ──
657                _ = ping_interval.tick() => {
658                    if awaiting_pong && last_activity.elapsed() > PONG_TIMEOUT {
659                        log::info!("closing stale WS connection ({}): no pong", username);
660                        break;
661                    }
662                    let _ = session.ping(b"").await;
663                    awaiting_pong = true;
664                }
665
666                // ── Server-push: check for graph changes ──
667                _ = push_interval.tick() => {
668                    let current_version = state.graph.version();
669                    if current_version > last_push_version {
670                        let changes = state.graph.changes_since(last_push_version);
671                        let changed_refs: HashSet<&str> =
672                            changes.iter().map(|d| d.ref_val.as_str()).collect();
673
674                        let affected = state.watches.watches_affected_by(&changed_refs);
675                        for (watch_id, owner, changed_ids) in &affected {
676                            if owner != &username {
677                                continue;
678                            }
679                            let rows: Vec<Value> = changed_ids
680                                .iter()
681                                .filter_map(|id| {
682                                    let entity = state.graph.get(id)?;
683                                    Some(state.watches.encode_cached(id, current_version, &entity))
684                                })
685                                .collect();
686                            if !rows.is_empty() {
687                                let push_msg = serde_json::json!({
688                                    "type": "push",
689                                    "watchId": watch_id,
690                                    "rows": rows,
691                                });
692                                if let Ok(text) = serde_json::to_string(&push_msg) {
693                                    let payload = maybe_compress_response(&text);
694                                    let _ = tx.try_send(payload);
695                                }
696                            }
697                        }
698                        last_push_version = current_version;
699                    }
700                }
701            }
702        }
703
704        let _ = session.close(None).await;
705    });
706
707    Ok(response)
708}
709
710#[cfg(test)]
711mod tests {
712    use super::*;
713    use haystack_core::graph::{EntityGraph, SharedGraph};
714    use haystack_core::kinds::{HRef, Kind};
715
716    fn make_graph_with_entity(id: &str) -> SharedGraph {
717        let graph = SharedGraph::new(EntityGraph::new());
718        let mut entity = HDict::new();
719        entity.set("id", Kind::Ref(HRef::from_val(id)));
720        entity.set("site", Kind::Marker);
721        entity.set("dis", Kind::Str(format!("Site {id}")));
722        graph.add(entity).unwrap();
723        graph
724    }
725
726    #[test]
727    fn subscribe_returns_watch_id() {
728        let wm = WatchManager::new();
729        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
730        assert!(!watch_id.is_empty());
731    }
732
733    #[test]
734    fn poll_no_changes() {
735        let graph = make_graph_with_entity("site-1");
736        let wm = WatchManager::new();
737        let version = graph.version();
738        let watch_id = wm
739            .subscribe("admin", vec!["site-1".into()], version)
740            .unwrap();
741
742        let changes = wm.poll(&watch_id, "admin", &graph).unwrap();
743        assert!(changes.is_empty());
744    }
745
746    #[test]
747    fn poll_with_changes() {
748        let graph = make_graph_with_entity("site-1");
749        let wm = WatchManager::new();
750        let version = graph.version();
751        let watch_id = wm
752            .subscribe("admin", vec!["site-1".into()], version)
753            .unwrap();
754
755        // Modify the entity
756        let mut changes = HDict::new();
757        changes.set("dis", Kind::Str("Updated".into()));
758        graph.update("site-1", changes).unwrap();
759
760        let result = wm.poll(&watch_id, "admin", &graph).unwrap();
761        assert_eq!(result.len(), 1);
762    }
763
764    #[test]
765    fn poll_unknown_watch() {
766        let graph = make_graph_with_entity("site-1");
767        let wm = WatchManager::new();
768        assert!(wm.poll("unknown", "admin", &graph).is_none());
769    }
770
771    #[test]
772    fn poll_wrong_owner() {
773        let graph = make_graph_with_entity("site-1");
774        let wm = WatchManager::new();
775        let version = graph.version();
776        let watch_id = wm
777            .subscribe("admin", vec!["site-1".into()], version)
778            .unwrap();
779
780        // Different user cannot poll the watch
781        assert!(wm.poll(&watch_id, "other-user", &graph).is_none());
782    }
783
784    #[test]
785    fn unsubscribe_removes_watch() {
786        let wm = WatchManager::new();
787        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
788        assert!(wm.unsubscribe(&watch_id, "admin"));
789        assert!(!wm.unsubscribe(&watch_id, "admin")); // already removed
790    }
791
792    #[test]
793    fn unsubscribe_wrong_owner() {
794        let wm = WatchManager::new();
795        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
796        // Different user cannot unsubscribe
797        assert!(!wm.unsubscribe(&watch_id, "other-user"));
798        // Original owner can still unsubscribe
799        assert!(wm.unsubscribe(&watch_id, "admin"));
800    }
801
802    #[test]
803    fn remove_ids_selective() {
804        let wm = WatchManager::new();
805        let watch_id = wm
806            .subscribe(
807                "admin",
808                vec!["site-1".into(), "site-2".into(), "site-3".into()],
809                0,
810            )
811            .unwrap();
812
813        // Remove only site-2
814        assert!(wm.remove_ids(&watch_id, "admin", &["site-2".into()]));
815
816        let remaining = wm.get_ids(&watch_id).unwrap();
817        assert_eq!(remaining.len(), 2);
818        assert!(remaining.contains(&"site-1".to_string()));
819        assert!(remaining.contains(&"site-3".to_string()));
820        assert!(!remaining.contains(&"site-2".to_string()));
821    }
822
823    #[test]
824    fn remove_ids_nonexistent_watch() {
825        let wm = WatchManager::new();
826        assert!(!wm.remove_ids("no-such-watch", "admin", &["site-1".into()]));
827    }
828
829    #[test]
830    fn remove_ids_wrong_owner() {
831        let wm = WatchManager::new();
832        let watch_id = wm
833            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
834            .unwrap();
835
836        // Different user cannot remove IDs
837        assert!(!wm.remove_ids(&watch_id, "other-user", &["site-1".into()]));
838
839        // IDs remain unchanged
840        let remaining = wm.get_ids(&watch_id).unwrap();
841        assert_eq!(remaining.len(), 2);
842    }
843
844    #[test]
845    fn remove_ids_leaves_watch_alive() {
846        let wm = WatchManager::new();
847        let watch_id = wm
848            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
849            .unwrap();
850
851        // Remove all IDs selectively — watch still exists with empty entity set
852        assert!(wm.remove_ids(&watch_id, "admin", &["site-1".into(), "site-2".into()]));
853
854        let remaining = wm.get_ids(&watch_id).unwrap();
855        assert!(remaining.is_empty());
856
857        // The watch itself still exists (unsubscribe should succeed)
858        assert!(wm.unsubscribe(&watch_id, "admin"));
859    }
860
861    #[test]
862    fn unsubscribe_full_removal() {
863        let wm = WatchManager::new();
864        let watch_id = wm
865            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
866            .unwrap();
867
868        // Full unsubscribe removes the watch entirely
869        assert!(wm.unsubscribe(&watch_id, "admin"));
870
871        // Watch no longer exists — get_ids returns None
872        assert!(wm.get_ids(&watch_id).is_none());
873
874        // Second unsubscribe returns false
875        assert!(!wm.unsubscribe(&watch_id, "admin"));
876    }
877
878    #[test]
879    fn add_ids_ownership_check() {
880        let wm = WatchManager::new();
881        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
882
883        // Different user cannot add IDs
884        assert!(!wm.add_ids(&watch_id, "other-user", vec!["site-2".into()]));
885
886        // Original owner can add IDs
887        assert!(wm.add_ids(&watch_id, "admin", vec!["site-2".into()]));
888
889        let ids = wm.get_ids(&watch_id).unwrap();
890        assert_eq!(ids.len(), 2);
891        assert!(ids.contains(&"site-1".to_string()));
892        assert!(ids.contains(&"site-2".to_string()));
893    }
894
895    #[test]
896    fn get_ids_returns_none_for_unknown_watch() {
897        let wm = WatchManager::new();
898        assert!(wm.get_ids("nonexistent").is_none());
899    }
900
901    // -----------------------------------------------------------------------
902    // WebSocket message format tests
903    // -----------------------------------------------------------------------
904
905    #[test]
906    fn ws_request_deserialization() {
907        let json = r#"{
908            "op": "watchSub",
909            "reqId": "abc-123",
910            "watchDis": "my-watch",
911            "ids": ["@ref1", "@ref2"]
912        }"#;
913        let req: WsRequest = serde_json::from_str(json).unwrap();
914        assert_eq!(req.op, "watchSub");
915        assert_eq!(req.req_id.as_deref(), Some("abc-123"));
916        assert_eq!(req.watch_dis.as_deref(), Some("my-watch"));
917        assert!(req.watch_id.is_none());
918        let ids = req.ids.unwrap();
919        assert_eq!(ids, vec!["@ref1", "@ref2"]);
920    }
921
922    #[test]
923    fn ws_request_deserialization_minimal() {
924        // Only `op` is required; all other fields are optional.
925        let json = r#"{"op": "watchPoll", "watchId": "w-1"}"#;
926        let req: WsRequest = serde_json::from_str(json).unwrap();
927        assert_eq!(req.op, "watchPoll");
928        assert!(req.req_id.is_none());
929        assert!(req.watch_dis.is_none());
930        assert_eq!(req.watch_id.as_deref(), Some("w-1"));
931        assert!(req.ids.is_none());
932    }
933
934    #[test]
935    fn ws_response_serialization() {
936        let resp = WsResponse::ok(
937            Some("r-1".into()),
938            vec![serde_json::json!({"id": "r:site-1"})],
939            Some("w-1".into()),
940        );
941        let json = serde_json::to_value(&resp).unwrap();
942        assert_eq!(json["reqId"], "r-1");
943        assert_eq!(json["watchId"], "w-1");
944        assert!(json["rows"].is_array());
945        assert_eq!(json["rows"][0]["id"], "r:site-1");
946        // `error` should be absent (not null) when None
947        assert!(json.get("error").is_none());
948    }
949
950    #[test]
951    fn ws_response_omits_none_fields() {
952        let resp = WsResponse::ok(None, vec![], None);
953        let json = serde_json::to_value(&resp).unwrap();
954        // reqId, error, and watchId should all be absent
955        assert!(json.get("reqId").is_none());
956        assert!(json.get("error").is_none());
957        assert!(json.get("watchId").is_none());
958        // rows is present (empty array)
959        assert!(json["rows"].is_array());
960    }
961
962    #[test]
963    fn ws_response_includes_req_id() {
964        let resp = WsResponse::error(Some("req-42".into()), "something went wrong");
965        let json = serde_json::to_value(&resp).unwrap();
966        assert_eq!(json["reqId"], "req-42");
967        assert_eq!(json["error"], "something went wrong");
968        // rows and watchId should be absent on error
969        assert!(json.get("rows").is_none());
970        assert!(json.get("watchId").is_none());
971    }
972
973    #[test]
974    fn ws_error_response_format() {
975        let resp = WsResponse::error(None, "bad request");
976        let json = serde_json::to_value(&resp).unwrap();
977        assert_eq!(json["error"], "bad request");
978        // reqId should be absent when not provided
979        assert!(json.get("reqId").is_none());
980        // rows and watchId should be absent
981        assert!(json.get("rows").is_none());
982        assert!(json.get("watchId").is_none());
983    }
984}