1use std::collections::{HashMap, HashSet};
12use std::time::{Duration, Instant};
13
14use axum::Extension;
15use axum::extract::State;
16use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
17use axum::response::Response;
18use parking_lot::RwLock;
19use serde_json::{Map, Value};
20use uuid::Uuid;
21
22use haystack_core::codecs::json::v3 as json_v3;
23use haystack_core::data::HDict;
24use haystack_core::graph::SharedGraph;
25
26use crate::auth::AuthUser;
27use crate::state::SharedState;
28
/// Upper bound on the number of watches a `WatchManager` holds across all users.
const MAX_WATCHES: usize = 100;
/// Upper bound on the number of entity IDs accepted for a single watch.
const MAX_ENTITY_IDS_PER_WATCH: usize = 1_000;
/// Upper bound on the entity IDs one user may watch, summed over all of that user's watches.
const MAX_TOTAL_WATCHED_IDS: usize = 5_000;
/// Upper bound on the number of watches owned by a single user.
const MAX_WATCHES_PER_USER: usize = 20;

/// Encode-cache size above which roughly a quarter of the entries are evicted.
const MAX_ENCODE_CACHE_ENTRIES: usize = 50_000;

/// Period between server-initiated WebSocket pings.
const PING_INTERVAL: Duration = Duration::from_secs(30);

/// Inactivity threshold used, at ping time, to decide that an outstanding ping went unanswered.
const PONG_TIMEOUT: Duration = Duration::from_secs(10);

/// Capacity of the per-connection outbound message channel.
const CHANNEL_CAPACITY: usize = 64;

/// Number of consecutive failed response sends after which the connection is closed.
const MAX_SEND_FAILURES: u32 = 3;
55
56#[derive(serde::Deserialize, Debug)]
62struct WsRequest {
63 op: String,
64 #[serde(rename = "reqId")]
65 req_id: Option<String>,
66 #[serde(rename = "watchId")]
67 watch_id: Option<String>,
68 ids: Option<Vec<String>>,
69}
70
71#[derive(serde::Serialize, Debug)]
73struct WsResponse {
74 #[serde(rename = "reqId", skip_serializing_if = "Option::is_none")]
75 req_id: Option<String>,
76 #[serde(skip_serializing_if = "Option::is_none")]
77 error: Option<String>,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 rows: Option<Vec<Value>>,
80 #[serde(rename = "watchId", skip_serializing_if = "Option::is_none")]
81 watch_id: Option<String>,
82}
83
84impl WsResponse {
85 fn error(req_id: Option<String>, msg: impl Into<String>) -> Self {
87 Self {
88 req_id,
89 error: Some(msg.into()),
90 rows: None,
91 watch_id: None,
92 }
93 }
94
95 fn ok(req_id: Option<String>, rows: Vec<Value>, watch_id: Option<String>) -> Self {
97 Self {
98 req_id,
99 error: None,
100 rows: Some(rows),
101 watch_id,
102 }
103 }
104}
105
106fn encode_entity(entity: &HDict) -> Value {
113 let mut m = Map::new();
114 let mut keys: Vec<&String> = entity.tags().keys().collect();
115 keys.sort();
116 for k in keys {
117 let v = &entity.tags()[k];
118 if let Ok(encoded) = json_v3::encode_kind(v) {
119 m.insert(k.clone(), encoded);
120 }
121 }
122 Value::Object(m)
123}
124
125fn handle_ws_request(req: &WsRequest, username: &str, state: &SharedState) -> String {
131 let resp = match req.op.as_str() {
132 "watchSub" => handle_watch_sub(req, username, state),
133 "watchPoll" => handle_watch_poll(req, username, state),
134 "watchUnsub" => handle_watch_unsub(req, username, state),
135 other => WsResponse::error(req.req_id.clone(), format!("unknown op: {other}")),
136 };
137 serde_json::to_string(&resp).unwrap_or_else(|e| {
138 let fallback = WsResponse::error(req.req_id.clone(), format!("serialization error: {e}"));
139 serde_json::to_string(&fallback).unwrap()
140 })
141}
142
143fn handle_watch_sub(req: &WsRequest, username: &str, state: &SharedState) -> WsResponse {
144 let ids = match &req.ids {
145 Some(ids) if !ids.is_empty() => ids.clone(),
146 _ => {
147 return WsResponse::error(
148 req.req_id.clone(),
149 "watchSub requires non-empty 'ids' array",
150 );
151 }
152 };
153
154 let ids: Vec<String> = ids
156 .into_iter()
157 .map(|id| id.strip_prefix('@').unwrap_or(&id).to_string())
158 .collect();
159
160 let graph_version = state.graph.version();
161 let watch_id = match state
162 .watches
163 .subscribe(username, ids.clone(), graph_version)
164 {
165 Ok(wid) => wid,
166 Err(e) => return WsResponse::error(req.req_id.clone(), e),
167 };
168
169 let rows: Vec<Value> = ids
170 .iter()
171 .filter_map(|id| state.graph.get(id).map(|e| encode_entity(&e)))
172 .collect();
173
174 WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
175}
176
177fn handle_watch_poll(req: &WsRequest, username: &str, state: &SharedState) -> WsResponse {
178 let watch_id = match &req.watch_id {
179 Some(wid) => wid.clone(),
180 None => {
181 return WsResponse::error(req.req_id.clone(), "watchPoll requires 'watchId'");
182 }
183 };
184
185 match state.watches.poll(&watch_id, username, &state.graph) {
186 Some(changed) => {
187 let rows: Vec<Value> = changed.iter().map(encode_entity).collect();
188 WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
189 }
190 None => WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}")),
191 }
192}
193
194fn handle_watch_unsub(req: &WsRequest, username: &str, state: &SharedState) -> WsResponse {
195 let watch_id = match &req.watch_id {
196 Some(wid) => wid.clone(),
197 None => {
198 return WsResponse::error(req.req_id.clone(), "watchUnsub requires 'watchId'");
199 }
200 };
201
202 if let Some(ids) = &req.ids
203 && !ids.is_empty()
204 {
205 let clean: Vec<String> = ids
206 .iter()
207 .map(|id| id.strip_prefix('@').unwrap_or(id).to_string())
208 .collect();
209 if !state.watches.remove_ids(&watch_id, username, &clean) {
210 return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
211 }
212 return WsResponse::ok(req.req_id.clone(), vec![], Some(watch_id));
213 }
214
215 if !state.watches.unsubscribe(&watch_id, username) {
216 return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
217 }
218 WsResponse::ok(req.req_id.clone(), vec![], None)
219}
220
/// One subscription: the set of watched entity IDs plus poll bookkeeping.
struct Watch {
    /// IDs of the entities covered by this watch.
    entity_ids: HashSet<String>,
    /// Graph version recorded at subscribe time and advanced by each poll
    /// that observes a newer graph version.
    last_version: u64,
    /// Username that created the watch; other users' poll, unsubscribe,
    /// add and remove calls are rejected.
    owner: String,
}
230
/// Registry of watches together with a cache of JSON-encoded entities.
pub struct WatchManager {
    /// Watches keyed by watch ID (a UUID string generated in `subscribe`).
    watches: RwLock<HashMap<String, Watch>>,
    /// Encoded entities keyed by `(ref_val, graph_version)`.
    encode_cache: RwLock<HashMap<(String, u64), Value>>,
    /// Highest graph version passed to `encode_cached` so far; a newer
    /// version clears `encode_cache`.
    cache_version: RwLock<u64>,
}
239
240impl WatchManager {
241 pub fn new() -> Self {
243 Self {
244 watches: RwLock::new(HashMap::new()),
245 encode_cache: RwLock::new(HashMap::new()),
246 cache_version: RwLock::new(0),
247 }
248 }
249
250 pub fn subscribe(
252 &self,
253 username: &str,
254 ids: Vec<String>,
255 graph_version: u64,
256 ) -> Result<String, String> {
257 let mut watches = self.watches.write();
258 if watches.len() >= MAX_WATCHES {
259 return Err("maximum number of watches reached".to_string());
260 }
261 let user_count = watches.values().filter(|w| w.owner == username).count();
262 if user_count >= MAX_WATCHES_PER_USER {
263 return Err(format!(
264 "user '{}' has reached the maximum of {} watches",
265 username, MAX_WATCHES_PER_USER
266 ));
267 }
268 if ids.len() > MAX_ENTITY_IDS_PER_WATCH {
269 return Err(format!(
270 "too many entity IDs (max {})",
271 MAX_ENTITY_IDS_PER_WATCH
272 ));
273 }
274 let user_total: usize = watches
275 .values()
276 .filter(|w| w.owner == username)
277 .map(|w| w.entity_ids.len())
278 .sum();
279 if user_total + ids.len() > MAX_TOTAL_WATCHED_IDS {
280 return Err(format!(
281 "user '{}' would exceed the maximum of {} total watched IDs",
282 username, MAX_TOTAL_WATCHED_IDS
283 ));
284 }
285 let watch_id = Uuid::new_v4().to_string();
286 let watch = Watch {
287 entity_ids: ids.into_iter().collect(),
288 last_version: graph_version,
289 owner: username.to_string(),
290 };
291 watches.insert(watch_id.clone(), watch);
292 Ok(watch_id)
293 }
294
295 pub fn poll(&self, watch_id: &str, username: &str, graph: &SharedGraph) -> Option<Vec<HDict>> {
297 let (entity_ids, last_version) = {
298 let mut watches = self.watches.write();
299 let watch = watches.get_mut(watch_id)?;
300 if watch.owner != username {
301 return None;
302 }
303
304 let current_version = graph.version();
305 if current_version == watch.last_version {
306 return Some(Vec::new());
307 }
308
309 let ids = watch.entity_ids.clone();
310 let last = watch.last_version;
311 watch.last_version = current_version;
312 (ids, last)
313 };
314
315 let changes = match graph.changes_since(last_version) {
316 Ok(c) => c,
317 Err(_gap) => {
318 return Some(entity_ids.iter().filter_map(|id| graph.get(id)).collect());
319 }
320 };
321 let changed_refs: HashSet<&str> = changes.iter().map(|d| d.ref_val.as_str()).collect();
322
323 Some(
324 entity_ids
325 .iter()
326 .filter(|id| changed_refs.contains(id.as_str()))
327 .filter_map(|id| graph.get(id))
328 .collect(),
329 )
330 }
331
332 pub fn unsubscribe(&self, watch_id: &str, username: &str) -> bool {
334 let mut watches = self.watches.write();
335 match watches.get(watch_id) {
336 Some(watch) if watch.owner == username => {
337 watches.remove(watch_id);
338 true
339 }
340 _ => false,
341 }
342 }
343
344 pub fn add_ids(&self, watch_id: &str, username: &str, ids: Vec<String>) -> bool {
346 let mut watches = self.watches.write();
347
348 let (owner_ok, per_watch_ok, user_total) = match watches.get(watch_id) {
349 Some(watch) => (
350 watch.owner == username,
351 watch.entity_ids.len() + ids.len() <= MAX_ENTITY_IDS_PER_WATCH,
352 watches
353 .values()
354 .filter(|w| w.owner == username)
355 .map(|w| w.entity_ids.len())
356 .sum::<usize>(),
357 ),
358 None => return false,
359 };
360
361 if !owner_ok || !per_watch_ok {
362 return false;
363 }
364 if user_total + ids.len() > MAX_TOTAL_WATCHED_IDS {
365 return false;
366 }
367
368 if let Some(watch) = watches.get_mut(watch_id) {
369 watch.entity_ids.extend(ids);
370 }
371 true
372 }
373
374 pub fn remove_ids(&self, watch_id: &str, username: &str, ids: &[String]) -> bool {
376 let mut watches = self.watches.write();
377 if let Some(watch) = watches.get_mut(watch_id) {
378 if watch.owner != username {
379 return false;
380 }
381 for id in ids {
382 watch.entity_ids.remove(id);
383 }
384 true
385 } else {
386 false
387 }
388 }
389
390 pub fn remove_by_owner(&self, owner: &str) {
392 let mut watches = self.watches.write();
393 watches.retain(|_, w| w.owner != owner);
394 }
395
396 pub fn get_ids(&self, watch_id: &str) -> Option<Vec<String>> {
398 let watches = self.watches.read();
399 watches
400 .get(watch_id)
401 .map(|w| w.entity_ids.iter().cloned().collect())
402 }
403
404 pub fn len(&self) -> usize {
406 self.watches.read().len()
407 }
408
409 pub fn is_empty(&self) -> bool {
411 self.watches.read().is_empty()
412 }
413
414 pub fn encode_cached(&self, ref_val: &str, graph_version: u64, entity: &HDict) -> Value {
416 {
417 let mut cv = self.cache_version.write();
418 if graph_version > *cv {
419 self.encode_cache.write().clear();
420 *cv = graph_version;
421 }
422 }
423
424 let key = (ref_val.to_string(), graph_version);
425 if let Some(cached) = self.encode_cache.read().get(&key) {
426 return cached.clone();
427 }
428
429 let encoded = encode_entity(entity);
430 let mut cache = self.encode_cache.write();
431 cache.insert(key, encoded.clone());
432 if cache.len() > MAX_ENCODE_CACHE_ENTRIES {
433 let to_remove = cache.len() / 4;
434 let keys: Vec<_> = cache.keys().take(to_remove).cloned().collect();
435 for k in keys {
436 cache.remove(&k);
437 }
438 }
439 encoded
440 }
441
442 pub fn all_watched_ids(&self) -> HashSet<String> {
444 let watches = self.watches.read();
445 watches
446 .values()
447 .flat_map(|w| w.entity_ids.iter().cloned())
448 .collect()
449 }
450
451 pub fn watches_affected_by(
453 &self,
454 changed_refs: &HashSet<&str>,
455 ) -> Vec<(String, String, Vec<String>)> {
456 let watches = self.watches.read();
457 let mut affected = Vec::new();
458 for (wid, watch) in watches.iter() {
459 let matched: Vec<String> = watch
460 .entity_ids
461 .iter()
462 .filter(|id| changed_refs.contains(id.as_str()))
463 .cloned()
464 .collect();
465 if !matched.is_empty() {
466 affected.push((wid.clone(), watch.owner.clone(), matched));
467 }
468 }
469 affected
470 }
471}
472
/// `Default` is equivalent to [`WatchManager::new`].
impl Default for WatchManager {
    fn default() -> Self {
        Self::new()
    }
}
478
479pub async fn ws_handler(
485 ws: WebSocketUpgrade,
486 State(state): State<SharedState>,
487 auth: Option<Extension<AuthUser>>,
488) -> Response {
489 let username = auth
490 .map(|Extension(u)| u.username)
491 .unwrap_or_else(|| "anonymous".into());
492 ws.on_upgrade(move |socket| handle_socket(socket, username, state))
493}
494
/// Drives one WebSocket connection until it closes.
///
/// A single `select!` loop multiplexes three event sources:
/// * incoming client messages — text frames are parsed as `WsRequest` and answered;
/// * a ping timer (`PING_INTERVAL`) — sends pings and closes connections
///   that showed no activity after the previous ping;
/// * a push timer (500 ms) — sends `"type": "push"` messages for this user's
///   watches whose entities changed since the last push.
///
/// Outbound messages go through a bounded channel drained by a separate
/// writer task. When the loop ends, every watch owned by `username` is removed.
async fn handle_socket(socket: WebSocket, username: String, state: SharedState) {
    use tokio::sync::mpsc;

    // Bounded queue between this loop and the writer task below.
    let (tx, mut rx) = mpsc::channel::<Message>(CHANNEL_CAPACITY);

    use futures_util::{SinkExt, StreamExt};

    let (mut ws_sender, mut ws_receiver) = socket.split();

    // Writer task: forwards queued messages to the socket and stops on the
    // first send error or once `tx` has been dropped.
    tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            if ws_sender.send(msg).await.is_err() {
                break;
            }
        }
    });

    let mut last_activity = Instant::now();
    let mut ping_interval = tokio::time::interval(PING_INTERVAL);
    // Consume the interval's first tick (tokio completes it immediately) so
    // the first ping is sent one full period from now.
    ping_interval.tick().await;
    let mut awaiting_pong = false;
    // Consecutive failed attempts to enqueue a response.
    let mut send_failures: u32 = 0;

    // Graph version up to which changes have already been pushed (or skipped).
    let mut last_push_version = state.graph.version();

    let mut push_interval = tokio::time::interval(Duration::from_millis(500));
    // Same as above: skip the immediate first tick.
    push_interval.tick().await;

    loop {
        tokio::select! {
            msg = ws_receiver.next() => {
                // Stream end or a receive error terminates the connection.
                let Some(Ok(msg)) = msg else { break };
                // Any inbound frame counts as proof of life.
                last_activity = Instant::now();
                awaiting_pong = false;

                match msg {
                    Message::Text(text) => {
                        let response_text = match serde_json::from_str::<WsRequest>(&text) {
                            Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
                            Err(e) => {
                                let err = WsResponse::error(None, format!("invalid request: {e}"));
                                serde_json::to_string(&err).unwrap()
                            }
                        };
                        // Non-blocking enqueue: a full (or closed) channel counts
                        // as a failure and the response is dropped.
                        if tx.try_send(Message::Text(response_text.into())).is_err() {
                            send_failures += 1;
                            if send_failures >= MAX_SEND_FAILURES {
                                log::warn!("closing slow WS client ({})", username);
                                break;
                            }
                        } else {
                            send_failures = 0;
                        }
                    }
                    Message::Ping(_) | Message::Pong(_) => {
                        awaiting_pong = false;
                    }
                    Message::Close(_) => {
                        break;
                    }
                    // Other frame kinds are ignored.
                    _ => {}
                }
            }

            _ = ping_interval.tick() => {
                // The previous ping is still unanswered and nothing has been
                // received for longer than PONG_TIMEOUT: treat as dead.
                if awaiting_pong && last_activity.elapsed() > PONG_TIMEOUT {
                    log::info!("closing stale WS connection ({}): no pong", username);
                    break;
                }
                if tx.try_send(Message::Ping(vec![].into())).is_err() {
                    break;
                }
                awaiting_pong = true;
            }

            _ = push_interval.tick() => {
                let current_version = state.graph.version();
                if current_version > last_push_version {
                    let changes = match state.graph.changes_since(last_push_version) {
                        Ok(c) => c,
                        Err(_gap) => {
                            // NOTE(review): when the change history is unavailable the
                            // push is skipped entirely, so changes in that range are
                            // never pushed (unlike `poll`, which resends everything).
                            last_push_version = current_version;
                            continue;
                        }
                    };
                    let changed_refs: HashSet<&str> =
                        changes.iter().map(|d| d.ref_val.as_str()).collect();

                    let affected = state.watches.watches_affected_by(&changed_refs);
                    for (watch_id, owner, changed_ids) in &affected {
                        // Only watches owned by this connection's user are pushed here.
                        if owner != &username {
                            continue;
                        }
                        let rows: Vec<Value> = changed_ids
                            .iter()
                            .filter_map(|id| {
                                let entity = state.graph.get(id)?;
                                Some(state.watches.encode_cached(id, current_version, &entity))
                            })
                            .collect();
                        if !rows.is_empty() {
                            let push_msg = serde_json::json!({
                                "type": "push",
                                "watchId": watch_id,
                                "rows": rows,
                            });
                            if let Ok(text) = serde_json::to_string(&push_msg) {
                                // Best effort: a push that cannot be enqueued is dropped
                                // and does not count toward `send_failures`.
                                let _ = tx.try_send(Message::Text(text.into()));
                            }
                        }
                    }
                    last_push_version = current_version;
                }
            }
        }
    }

    // NOTE(review): this removes every watch owned by `username`, which would
    // include watches created through other connections of the same user
    // (or other "anonymous" clients) — confirm this is intended.
    state.watches.remove_by_owner(&username);
}
623
#[cfg(test)]
mod tests {
    use super::*;
    use haystack_core::graph::{EntityGraph, SharedGraph};
    use haystack_core::kinds::{HRef, Kind};

    /// Builds a graph containing a single `site` entity with the given id.
    fn make_graph_with_entity(id: &str) -> SharedGraph {
        let mut site = HDict::new();
        site.set("id", Kind::Ref(HRef::from_val(id)));
        site.set("site", Kind::Marker);
        site.set("dis", Kind::Str(format!("Site {id}")));

        let graph = SharedGraph::new(EntityGraph::new());
        graph.add(site).unwrap();
        graph
    }

    /// Converts string literals into the owned ID list the manager expects.
    fn owned(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| id.to_string()).collect()
    }

    /// Fresh manager with one watch owned by "admin" at graph version 0.
    fn admin_watch(ids: &[&str]) -> (WatchManager, String) {
        let manager = WatchManager::new();
        let watch_id = manager.subscribe("admin", owned(ids), 0).unwrap();
        (manager, watch_id)
    }

    /// IDs of a watch as a set, for order-independent comparison.
    fn id_set(manager: &WatchManager, watch_id: &str) -> HashSet<String> {
        manager.get_ids(watch_id).unwrap().into_iter().collect()
    }

    // ----- WatchManager: subscribe / poll -----

    #[test]
    fn subscribe_returns_watch_id() {
        let (_manager, watch_id) = admin_watch(&["site-1"]);
        assert!(!watch_id.is_empty());
    }

    #[test]
    fn poll_no_changes() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let watch_id = manager
            .subscribe("admin", owned(&["site-1"]), graph.version())
            .unwrap();

        let polled = manager.poll(&watch_id, "admin", &graph).unwrap();
        assert!(polled.is_empty());
    }

    #[test]
    fn poll_with_changes() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let watch_id = manager
            .subscribe("admin", owned(&["site-1"]), graph.version())
            .unwrap();

        let mut patch = HDict::new();
        patch.set("dis", Kind::Str("Updated".into()));
        graph.update("site-1", patch).unwrap();

        let polled = manager.poll(&watch_id, "admin", &graph).unwrap();
        assert_eq!(polled.len(), 1);
    }

    #[test]
    fn poll_unknown_watch() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        assert!(manager.poll("unknown", "admin", &graph).is_none());
    }

    #[test]
    fn poll_wrong_owner() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let watch_id = manager
            .subscribe("admin", owned(&["site-1"]), graph.version())
            .unwrap();

        assert!(manager.poll(&watch_id, "other-user", &graph).is_none());
    }

    // ----- WatchManager: unsubscribe / remove_ids / add_ids -----

    #[test]
    fn unsubscribe_removes_watch() {
        let (manager, watch_id) = admin_watch(&["site-1"]);
        assert!(manager.unsubscribe(&watch_id, "admin"));
        // A second attempt finds nothing to remove.
        assert!(!manager.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn unsubscribe_wrong_owner() {
        let (manager, watch_id) = admin_watch(&["site-1"]);
        assert!(!manager.unsubscribe(&watch_id, "other-user"));
        assert!(manager.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn remove_ids_selective() {
        let (manager, watch_id) = admin_watch(&["site-1", "site-2", "site-3"]);

        assert!(manager.remove_ids(&watch_id, "admin", &owned(&["site-2"])));

        let expected: HashSet<String> = owned(&["site-1", "site-3"]).into_iter().collect();
        assert_eq!(manager.get_ids(&watch_id).unwrap().len(), 2);
        assert_eq!(id_set(&manager, &watch_id), expected);
    }

    #[test]
    fn remove_ids_nonexistent_watch() {
        let manager = WatchManager::new();
        assert!(!manager.remove_ids("no-such-watch", "admin", &owned(&["site-1"])));
    }

    #[test]
    fn remove_ids_wrong_owner() {
        let (manager, watch_id) = admin_watch(&["site-1", "site-2"]);

        assert!(!manager.remove_ids(&watch_id, "other-user", &owned(&["site-1"])));

        // Nothing was removed.
        assert_eq!(manager.get_ids(&watch_id).unwrap().len(), 2);
    }

    #[test]
    fn remove_ids_leaves_watch_alive() {
        let (manager, watch_id) = admin_watch(&["site-1", "site-2"]);

        assert!(manager.remove_ids(&watch_id, "admin", &owned(&["site-1", "site-2"])));

        // The watch is empty but still registered.
        assert!(manager.get_ids(&watch_id).unwrap().is_empty());
        assert!(manager.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn unsubscribe_full_removal() {
        let (manager, watch_id) = admin_watch(&["site-1", "site-2"]);

        assert!(manager.unsubscribe(&watch_id, "admin"));
        assert!(manager.get_ids(&watch_id).is_none());
        assert!(!manager.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn add_ids_ownership_check() {
        let (manager, watch_id) = admin_watch(&["site-1"]);

        assert!(!manager.add_ids(&watch_id, "other-user", owned(&["site-2"])));
        assert!(manager.add_ids(&watch_id, "admin", owned(&["site-2"])));

        let expected: HashSet<String> = owned(&["site-1", "site-2"]).into_iter().collect();
        assert_eq!(manager.get_ids(&watch_id).unwrap().len(), 2);
        assert_eq!(id_set(&manager, &watch_id), expected);
    }

    #[test]
    fn get_ids_returns_none_for_unknown_watch() {
        let manager = WatchManager::new();
        assert!(manager.get_ids("nonexistent").is_none());
    }

    // ----- Wire format: WsRequest / WsResponse -----

    #[test]
    fn ws_request_deserialization() {
        let json = r#"{
            "op": "watchSub",
            "reqId": "abc-123",
            "ids": ["@ref1", "@ref2"]
        }"#;
        let parsed: WsRequest = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.op, "watchSub");
        assert_eq!(parsed.req_id.as_deref(), Some("abc-123"));
        assert!(parsed.watch_id.is_none());
        assert_eq!(parsed.ids.unwrap(), vec!["@ref1", "@ref2"]);
    }

    #[test]
    fn ws_request_deserialization_minimal() {
        let parsed: WsRequest =
            serde_json::from_str(r#"{"op": "watchPoll", "watchId": "w-1"}"#).unwrap();
        assert_eq!(parsed.op, "watchPoll");
        assert!(parsed.req_id.is_none());
        assert_eq!(parsed.watch_id.as_deref(), Some("w-1"));
        assert!(parsed.ids.is_none());
    }

    #[test]
    fn ws_response_serialization() {
        let rows = vec![serde_json::json!({"id": "r:site-1"})];
        let resp = WsResponse::ok(Some("r-1".into()), rows, Some("w-1".into()));
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "r-1");
        assert_eq!(json["watchId"], "w-1");
        assert!(json["rows"].is_array());
        assert_eq!(json["rows"][0]["id"], "r:site-1");
        assert!(json.get("error").is_none());
    }

    #[test]
    fn ws_response_omits_none_fields() {
        let json = serde_json::to_value(&WsResponse::ok(None, vec![], None)).unwrap();
        for absent in ["reqId", "error", "watchId"] {
            assert!(json.get(absent).is_none());
        }
        assert!(json["rows"].is_array());
    }

    #[test]
    fn ws_response_includes_req_id() {
        let resp = WsResponse::error(Some("req-42".into()), "something went wrong");
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "req-42");
        assert_eq!(json["error"], "something went wrong");
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }

    #[test]
    fn ws_error_response_format() {
        let json = serde_json::to_value(&WsResponse::error(None, "bad request")).unwrap();
        assert_eq!(json["error"], "bad request");
        for absent in ["reqId", "rows", "watchId"] {
            assert!(json.get(absent).is_none());
        }
    }
}
867}