1use std::collections::HashMap;
4
5use actix_web::{HttpMessage, HttpRequest, HttpResponse, web};
6use parking_lot::RwLock;
7use serde_json::{Map, Value};
8use uuid::Uuid;
9
10use haystack_core::codecs::json::v3 as json_v3;
11use haystack_core::data::HDict;
12use haystack_core::graph::SharedGraph;
13
14use crate::state::AppState;
15
/// Upper bound on the total number of watches held by a [`WatchManager`],
/// summed over all owners; `subscribe` rejects new watches once reached.
const MAX_WATCHES: usize = 100;
/// Upper bound on how many entity ids a single watch may track; enforced by
/// both `subscribe` and `add_ids`.
const MAX_ENTITY_IDS_PER_WATCH: usize = 10_000;
22
23#[derive(serde::Deserialize, Debug)]
29struct WsRequest {
30 op: String,
31 #[serde(rename = "reqId")]
32 req_id: Option<String>,
33 #[serde(rename = "watchDis")]
34 #[allow(dead_code)]
35 watch_dis: Option<String>,
36 #[serde(rename = "watchId")]
37 watch_id: Option<String>,
38 ids: Option<Vec<String>>,
39}
40
41#[derive(serde::Serialize, Debug)]
43struct WsResponse {
44 #[serde(rename = "reqId", skip_serializing_if = "Option::is_none")]
45 req_id: Option<String>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 error: Option<String>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 rows: Option<Vec<Value>>,
50 #[serde(rename = "watchId", skip_serializing_if = "Option::is_none")]
51 watch_id: Option<String>,
52}
53
54impl WsResponse {
55 fn error(req_id: Option<String>, msg: impl Into<String>) -> Self {
57 Self {
58 req_id,
59 error: Some(msg.into()),
60 rows: None,
61 watch_id: None,
62 }
63 }
64
65 fn ok(req_id: Option<String>, rows: Vec<Value>, watch_id: Option<String>) -> Self {
67 Self {
68 req_id,
69 error: None,
70 rows: Some(rows),
71 watch_id,
72 }
73 }
74}
75
76fn encode_entity(entity: &HDict) -> Value {
83 let mut m = Map::new();
84 let mut keys: Vec<&String> = entity.tags().keys().collect();
85 keys.sort();
86 for k in keys {
87 let v = &entity.tags()[k];
88 if let Ok(encoded) = json_v3::encode_kind(v) {
89 m.insert(k.clone(), encoded);
90 }
91 }
92 Value::Object(m)
93}
94
95fn handle_ws_request(req: &WsRequest, username: &str, state: &AppState) -> String {
103 let resp = match req.op.as_str() {
104 "watchSub" => handle_watch_sub(req, username, state),
105 "watchPoll" => handle_watch_poll(req, username, state),
106 "watchUnsub" => handle_watch_unsub(req, username, state),
107 other => WsResponse::error(req.req_id.clone(), format!("unknown op: {other}")),
108 };
109 serde_json::to_string(&resp).unwrap_or_else(|e| {
111 let fallback = WsResponse::error(req.req_id.clone(), format!("serialization error: {e}"));
112 serde_json::to_string(&fallback).unwrap()
113 })
114}
115
116fn handle_watch_sub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
119 let ids = match &req.ids {
120 Some(ids) if !ids.is_empty() => ids.clone(),
121 _ => {
122 return WsResponse::error(
123 req.req_id.clone(),
124 "watchSub requires non-empty 'ids' array",
125 );
126 }
127 };
128
129 let ids: Vec<String> = ids
131 .into_iter()
132 .map(|id| id.strip_prefix('@').unwrap_or(&id).to_string())
133 .collect();
134
135 let graph_version = state.graph.version();
136 let watch_id = match state
137 .watches
138 .subscribe(username, ids.clone(), graph_version)
139 {
140 Ok(wid) => wid,
141 Err(e) => return WsResponse::error(req.req_id.clone(), e),
142 };
143
144 let rows: Vec<Value> = ids
146 .iter()
147 .filter_map(|id| state.graph.get(id).map(|e| encode_entity(&e)))
148 .collect();
149
150 WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
151}
152
153fn handle_watch_poll(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
155 let watch_id = match &req.watch_id {
156 Some(wid) => wid.clone(),
157 None => {
158 return WsResponse::error(req.req_id.clone(), "watchPoll requires 'watchId'");
159 }
160 };
161
162 match state.watches.poll(&watch_id, username, &state.graph) {
163 Some(changed) => {
164 let rows: Vec<Value> = changed.iter().map(encode_entity).collect();
165 WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
166 }
167 None => WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}")),
168 }
169}
170
171fn handle_watch_unsub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
173 let watch_id = match &req.watch_id {
174 Some(wid) => wid.clone(),
175 None => {
176 return WsResponse::error(req.req_id.clone(), "watchUnsub requires 'watchId'");
177 }
178 };
179
180 if let Some(ids) = &req.ids
183 && !ids.is_empty()
184 {
185 let clean: Vec<String> = ids
186 .iter()
187 .map(|id| id.strip_prefix('@').unwrap_or(id).to_string())
188 .collect();
189 if !state.watches.remove_ids(&watch_id, username, &clean) {
190 return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
191 }
192 return WsResponse::ok(req.req_id.clone(), vec![], Some(watch_id));
193 }
194
195 if !state.watches.unsubscribe(&watch_id, username) {
196 return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
197 }
198 WsResponse::ok(req.req_id.clone(), vec![], None)
199}
200
/// Server-side state for a single watch.
struct Watch {
    /// Username that created the watch. `WatchManager` refuses poll/modify
    /// requests from any other user.
    owner: String,
    /// Entity ids the watch tracks, stored as passed by callers (the WS
    /// handlers strip a leading `@` before subscribing).
    entity_ids: Vec<String>,
    /// Graph version recorded at subscribe time or at the last poll that
    /// observed a version change.
    last_version: u64,
}
210
211pub struct WatchManager {
213 watches: RwLock<HashMap<String, Watch>>,
214}
215
216impl WatchManager {
217 pub fn new() -> Self {
219 Self {
220 watches: RwLock::new(HashMap::new()),
221 }
222 }
223
224 pub fn subscribe(
228 &self,
229 username: &str,
230 ids: Vec<String>,
231 graph_version: u64,
232 ) -> Result<String, String> {
233 let mut watches = self.watches.write();
234 if watches.len() >= MAX_WATCHES {
235 return Err("maximum number of watches reached".to_string());
236 }
237 if ids.len() > MAX_ENTITY_IDS_PER_WATCH {
238 return Err(format!(
239 "too many entity IDs (max {})",
240 MAX_ENTITY_IDS_PER_WATCH
241 ));
242 }
243 let watch_id = Uuid::new_v4().to_string();
244 let watch = Watch {
245 entity_ids: ids,
246 last_version: graph_version,
247 owner: username.to_string(),
248 };
249 watches.insert(watch_id.clone(), watch);
250 Ok(watch_id)
251 }
252
253 pub fn poll(&self, watch_id: &str, username: &str, graph: &SharedGraph) -> Option<Vec<HDict>> {
258 let (entity_ids, last_version) = {
262 let mut watches = self.watches.write();
263 let watch = watches.get_mut(watch_id)?;
264 if watch.owner != username {
265 return None;
266 }
267
268 let current_version = graph.version();
269 if current_version == watch.last_version {
270 return Some(Vec::new());
272 }
273
274 let ids = watch.entity_ids.clone();
275 let last = watch.last_version;
276 watch.last_version = current_version;
277 (ids, last)
278 }; let changes = graph.changes_since(last_version);
282 let changed_refs: std::collections::HashSet<&str> =
283 changes.iter().map(|d| d.ref_val.as_str()).collect();
284
285 Some(
286 entity_ids
287 .iter()
288 .filter(|id| changed_refs.contains(id.as_str()))
289 .filter_map(|id| graph.get(id))
290 .collect(),
291 )
292 }
293
294 pub fn unsubscribe(&self, watch_id: &str, username: &str) -> bool {
298 let mut watches = self.watches.write();
299 match watches.get(watch_id) {
300 Some(watch) if watch.owner == username => {
301 watches.remove(watch_id);
302 true
303 }
304 _ => false,
305 }
306 }
307
308 pub fn add_ids(&self, watch_id: &str, username: &str, ids: Vec<String>) -> bool {
313 let mut watches = self.watches.write();
314 if let Some(watch) = watches.get_mut(watch_id) {
315 if watch.owner != username {
316 return false;
317 }
318 if watch.entity_ids.len() + ids.len() > MAX_ENTITY_IDS_PER_WATCH {
319 return false;
320 }
321 watch.entity_ids.extend(ids);
322 true
323 } else {
324 false
325 }
326 }
327
328 pub fn remove_ids(&self, watch_id: &str, username: &str, ids: &[String]) -> bool {
333 let mut watches = self.watches.write();
334 if let Some(watch) = watches.get_mut(watch_id) {
335 if watch.owner != username {
336 return false;
337 }
338 let to_remove: std::collections::HashSet<&String> = ids.iter().collect();
339 watch.entity_ids.retain(|id| !to_remove.contains(id));
340 true
341 } else {
342 false
343 }
344 }
345
346 pub fn get_ids(&self, watch_id: &str) -> Option<Vec<String>> {
350 let watches = self.watches.read();
351 watches.get(watch_id).map(|w| w.entity_ids.clone())
352 }
353
354 pub fn len(&self) -> usize {
356 self.watches.read().len()
357 }
358
359 pub fn is_empty(&self) -> bool {
361 self.watches.read().is_empty()
362 }
363}
364
365impl Default for WatchManager {
366 fn default() -> Self {
367 Self::new()
368 }
369}
370
371pub async fn ws_handler(
378 req: HttpRequest,
379 stream: web::Payload,
380 state: web::Data<AppState>,
381) -> actix_web::Result<HttpResponse> {
382 let username = req
383 .extensions()
384 .get::<crate::auth::AuthUser>()
385 .map(|u| u.username.clone())
386 .unwrap_or_else(|| "anonymous".to_string());
387
388 let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
389
390 actix_rt::spawn(async move {
391 use actix_ws::Message;
392 use tokio::sync::mpsc;
393
394 let (tx, mut rx) = mpsc::channel::<String>(32);
395
396 let mut session_clone = session.clone();
398 actix_rt::spawn(async move {
399 while let Some(msg) = rx.recv().await {
400 let _ = session_clone.text(msg).await;
401 }
402 });
403
404 use futures_util::StreamExt as _;
406 while let Some(Ok(msg)) = msg_stream.next().await {
407 match msg {
408 Message::Text(text) => {
409 let response_text = match serde_json::from_str::<WsRequest>(&text) {
410 Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
411 Err(e) => {
412 let err = WsResponse::error(None, format!("invalid request: {e}"));
413 serde_json::to_string(&err).unwrap()
414 }
415 };
416 let _ = tx.send(response_text).await;
417 }
418 Message::Ping(bytes) => {
419 let _ = session.pong(&bytes).await;
420 }
421 Message::Close(_) => {
422 break;
423 }
424 _ => {}
425 }
426 }
427
428 let _ = session.close(None).await;
429 });
430
431 Ok(response)
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use haystack_core::graph::{EntityGraph, SharedGraph};
    use haystack_core::kinds::{HRef, Kind};

    /// Adds a `site` entity with the given id to `graph`.
    fn add_site(graph: &SharedGraph, id: &str) {
        let mut entity = HDict::new();
        entity.set("id", Kind::Ref(HRef::from_val(id)));
        entity.set("site", Kind::Marker);
        entity.set("dis", Kind::Str(format!("Site {id}")));
        graph.add(entity).unwrap();
    }

    /// Builds a graph containing one `site` entity with the given id.
    fn make_graph_with_entity(id: &str) -> SharedGraph {
        let graph = SharedGraph::new(EntityGraph::new());
        add_site(&graph, id);
        graph
    }

    /// Updates the `dis` tag of entity `id`, producing a new graph version.
    fn touch(graph: &SharedGraph, id: &str) {
        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        graph.update(id, changes).unwrap();
    }

    /// Builds `n` distinct entity ids (`e-0`, `e-1`, ...).
    fn many_ids(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("e-{i}")).collect()
    }

    #[test]
    fn subscribe_returns_watch_id() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        assert!(!watch_id.is_empty());
    }

    #[test]
    fn subscribe_rejects_too_many_ids() {
        let wm = WatchManager::new();
        let result = wm.subscribe("admin", many_ids(MAX_ENTITY_IDS_PER_WATCH + 1), 0);
        assert!(result.is_err());
        assert!(wm.is_empty());
    }

    #[test]
    fn subscribe_accepts_exactly_max_ids() {
        let wm = WatchManager::new();
        assert!(wm.subscribe("admin", many_ids(MAX_ENTITY_IDS_PER_WATCH), 0).is_ok());
    }

    #[test]
    fn subscribe_rejects_when_watch_limit_reached() {
        let wm = WatchManager::new();
        for _ in 0..MAX_WATCHES {
            wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        }
        assert_eq!(wm.len(), MAX_WATCHES);
        // The cap is global, so it applies to other users too.
        assert!(wm.subscribe("other-user", vec!["site-1".into()], 0).is_err());
    }

    #[test]
    fn default_is_empty() {
        let wm = WatchManager::default();
        assert!(wm.is_empty());
        assert_eq!(wm.len(), 0);
    }

    #[test]
    fn poll_no_changes() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let version = graph.version();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into()], version)
            .unwrap();

        let changes = wm.poll(&watch_id, "admin", &graph).unwrap();
        assert!(changes.is_empty());
    }

    #[test]
    fn poll_with_changes() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let version = graph.version();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into()], version)
            .unwrap();

        touch(&graph, "site-1");

        let result = wm.poll(&watch_id, "admin", &graph).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn poll_reports_each_change_once() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into()], graph.version())
            .unwrap();

        touch(&graph, "site-1");

        assert_eq!(wm.poll(&watch_id, "admin", &graph).unwrap().len(), 1);
        assert!(wm.poll(&watch_id, "admin", &graph).unwrap().is_empty());
    }

    #[test]
    fn poll_ignores_changes_to_unwatched_entities() {
        let graph = make_graph_with_entity("site-1");
        add_site(&graph, "site-2");
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into()], graph.version())
            .unwrap();

        touch(&graph, "site-2");

        assert!(wm.poll(&watch_id, "admin", &graph).unwrap().is_empty());
    }

    #[test]
    fn poll_unknown_watch() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        assert!(wm.poll("unknown", "admin", &graph).is_none());
    }

    #[test]
    fn poll_wrong_owner() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let version = graph.version();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into()], version)
            .unwrap();

        assert!(wm.poll(&watch_id, "other-user", &graph).is_none());
    }

    #[test]
    fn unsubscribe_removes_watch() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        assert!(wm.unsubscribe(&watch_id, "admin"));
        assert!(!wm.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn unsubscribe_wrong_owner() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        assert!(!wm.unsubscribe(&watch_id, "other-user"));
        assert!(wm.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn remove_ids_selective() {
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe(
                "admin",
                vec!["site-1".into(), "site-2".into(), "site-3".into()],
                0,
            )
            .unwrap();

        assert!(wm.remove_ids(&watch_id, "admin", &["site-2".into()]));

        let remaining = wm.get_ids(&watch_id).unwrap();
        assert_eq!(remaining.len(), 2);
        assert!(remaining.contains(&"site-1".to_string()));
        assert!(remaining.contains(&"site-3".to_string()));
        assert!(!remaining.contains(&"site-2".to_string()));
    }

    #[test]
    fn remove_ids_nonexistent_watch() {
        let wm = WatchManager::new();
        assert!(!wm.remove_ids("no-such-watch", "admin", &["site-1".into()]));
    }

    #[test]
    fn remove_ids_wrong_owner() {
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        assert!(!wm.remove_ids(&watch_id, "other-user", &["site-1".into()]));

        let remaining = wm.get_ids(&watch_id).unwrap();
        assert_eq!(remaining.len(), 2);
    }

    #[test]
    fn remove_ids_leaves_watch_alive() {
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        assert!(wm.remove_ids(&watch_id, "admin", &["site-1".into(), "site-2".into()]));

        let remaining = wm.get_ids(&watch_id).unwrap();
        assert!(remaining.is_empty());

        assert!(wm.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn unsubscribe_full_removal() {
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        assert!(wm.unsubscribe(&watch_id, "admin"));

        assert!(wm.get_ids(&watch_id).is_none());

        assert!(!wm.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn add_ids_ownership_check() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();

        assert!(!wm.add_ids(&watch_id, "other-user", vec!["site-2".into()]));

        assert!(wm.add_ids(&watch_id, "admin", vec!["site-2".into()]));

        let ids = wm.get_ids(&watch_id).unwrap();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"site-1".to_string()));
        assert!(ids.contains(&"site-2".to_string()));
    }

    #[test]
    fn add_ids_respects_limit() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();

        // 1 existing + MAX new ids exceeds the per-watch cap.
        assert!(!wm.add_ids(&watch_id, "admin", many_ids(MAX_ENTITY_IDS_PER_WATCH)));
        assert_eq!(wm.get_ids(&watch_id).unwrap().len(), 1);
    }

    #[test]
    fn add_ids_unknown_watch() {
        let wm = WatchManager::new();
        assert!(!wm.add_ids("no-such-watch", "admin", vec!["site-1".into()]));
    }

    #[test]
    fn get_ids_returns_none_for_unknown_watch() {
        let wm = WatchManager::new();
        assert!(wm.get_ids("nonexistent").is_none());
    }

    #[test]
    fn ws_request_deserialization() {
        let json = r#"{
            "op": "watchSub",
            "reqId": "abc-123",
            "watchDis": "my-watch",
            "ids": ["@ref1", "@ref2"]
        }"#;
        let req: WsRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.op, "watchSub");
        assert_eq!(req.req_id.as_deref(), Some("abc-123"));
        assert_eq!(req.watch_dis.as_deref(), Some("my-watch"));
        assert!(req.watch_id.is_none());
        let ids = req.ids.unwrap();
        assert_eq!(ids, vec!["@ref1", "@ref2"]);
    }

    #[test]
    fn ws_request_deserialization_minimal() {
        let json = r#"{"op": "watchPoll", "watchId": "w-1"}"#;
        let req: WsRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.op, "watchPoll");
        assert!(req.req_id.is_none());
        assert!(req.watch_dis.is_none());
        assert_eq!(req.watch_id.as_deref(), Some("w-1"));
        assert!(req.ids.is_none());
    }

    #[test]
    fn ws_response_serialization() {
        let resp = WsResponse::ok(
            Some("r-1".into()),
            vec![serde_json::json!({"id": "r:site-1"})],
            Some("w-1".into()),
        );
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "r-1");
        assert_eq!(json["watchId"], "w-1");
        assert!(json["rows"].is_array());
        assert_eq!(json["rows"][0]["id"], "r:site-1");
        assert!(json.get("error").is_none());
    }

    #[test]
    fn ws_response_omits_none_fields() {
        let resp = WsResponse::ok(None, vec![], None);
        let json = serde_json::to_value(&resp).unwrap();
        assert!(json.get("reqId").is_none());
        assert!(json.get("error").is_none());
        assert!(json.get("watchId").is_none());
        assert!(json["rows"].is_array());
    }

    #[test]
    fn ws_response_includes_req_id() {
        let resp = WsResponse::error(Some("req-42".into()), "something went wrong");
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "req-42");
        assert_eq!(json["error"], "something went wrong");
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }

    #[test]
    fn ws_error_response_format() {
        let resp = WsResponse::error(None, "bad request");
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["error"], "bad request");
        assert!(json.get("reqId").is_none());
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }
}