this/events/sinks/
websocket.rs1use crate::config::sinks::SinkType;
19use crate::events::sinks::Sink;
20use anyhow::Result;
21use async_trait::async_trait;
22use serde_json::Value;
23use std::collections::HashMap;
24use std::sync::Arc;
25
26#[async_trait]
34pub trait WebSocketDispatcher: Send + Sync + std::fmt::Debug {
35 async fn dispatch_to_recipient(&self, recipient_id: &str, payload: Value) -> Result<usize>;
42
43 async fn broadcast(&self, payload: Value) -> Result<usize>;
47}
48
49#[derive(Debug)]
54pub struct WebSocketSink {
55 dispatcher: Arc<dyn WebSocketDispatcher>,
57}
58
59impl WebSocketSink {
60 pub fn new(dispatcher: Arc<dyn WebSocketDispatcher>) -> Self {
62 Self { dispatcher }
63 }
64}
65
66#[async_trait]
67impl Sink for WebSocketSink {
68 async fn deliver(
69 &self,
70 payload: Value,
71 recipient_id: Option<&str>,
72 context_vars: &HashMap<String, Value>,
73 ) -> Result<()> {
74 let recipient = super::resolve_recipient(recipient_id, &payload, context_vars);
76
77 let count = match &recipient {
78 Some(rid) => {
79 tracing::debug!(
80 recipient = %rid,
81 "websocket sink: dispatching to recipient connections"
82 );
83 self.dispatcher.dispatch_to_recipient(rid, payload).await?
84 }
85 None => {
86 tracing::debug!("websocket sink: broadcasting to all connections");
87 self.dispatcher.broadcast(payload).await?
88 }
89 };
90
91 tracing::debug!(
92 connections = count,
93 "websocket sink: dispatched to connections"
94 );
95
96 Ok(())
97 }
98
99 fn name(&self) -> &str {
100 "websocket"
101 }
102
103 fn sink_type(&self) -> SinkType {
104 SinkType::WebSocket
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::sync::Mutex;

    /// Records every dispatch and reports a fixed number of reached connections.
    #[derive(Debug)]
    struct MockDispatcher {
        /// One `(recipient, payload)` entry per call; a `None` recipient marks a broadcast.
        dispatched: Mutex<Vec<(Option<String>, Value)>>,
        /// Value returned as the "connections reached" count.
        connections: usize,
    }

    impl MockDispatcher {
        fn new() -> Self {
            Self::with_connections(1)
        }

        fn with_connections(connections: usize) -> Self {
            Self {
                dispatched: Mutex::new(Vec::new()),
                connections,
            }
        }
    }

    #[async_trait]
    impl WebSocketDispatcher for MockDispatcher {
        async fn dispatch_to_recipient(&self, recipient_id: &str, payload: Value) -> Result<usize> {
            self.dispatched
                .lock()
                .await
                .push((Some(recipient_id.to_string()), payload));
            Ok(self.connections)
        }

        async fn broadcast(&self, payload: Value) -> Result<usize> {
            self.dispatched.lock().await.push((None, payload));
            Ok(self.connections)
        }
    }

    /// Dispatcher whose every call fails; used to check error propagation.
    #[derive(Debug)]
    struct FailingDispatcher;

    #[async_trait]
    impl WebSocketDispatcher for FailingDispatcher {
        async fn dispatch_to_recipient(
            &self,
            _recipient_id: &str,
            _payload: Value,
        ) -> Result<usize> {
            Err(anyhow::anyhow!("recipient dispatch failed"))
        }

        async fn broadcast(&self, _payload: Value) -> Result<usize> {
            Err(anyhow::anyhow!("broadcast failed"))
        }
    }

    #[tokio::test]
    async fn test_ws_deliver_to_recipient() {
        let dispatcher = Arc::new(MockDispatcher::new());
        let sink = WebSocketSink::new(dispatcher.clone());

        let payload = json!({
            "title": "New follower",
            "body": "Alice followed you",
            "recipient_id": "user-A"
        });

        sink.deliver(payload.clone(), None, &HashMap::new())
            .await
            .unwrap();

        let dispatched = dispatcher.dispatched.lock().await;
        assert_eq!(dispatched.len(), 1);
        assert_eq!(dispatched[0].0.as_deref(), Some("user-A"));
        assert_eq!(dispatched[0].1["title"], "New follower");
    }

    #[tokio::test]
    async fn test_ws_deliver_explicit_recipient() {
        let dispatcher = Arc::new(MockDispatcher::new());
        let sink = WebSocketSink::new(dispatcher.clone());

        let payload = json!({"title": "Test"});

        sink.deliver(payload, Some("user-B"), &HashMap::new())
            .await
            .unwrap();

        let dispatched = dispatcher.dispatched.lock().await;
        assert_eq!(dispatched.len(), 1);
        assert_eq!(dispatched[0].0.as_deref(), Some("user-B"));
    }

    #[tokio::test]
    async fn test_ws_broadcast_when_no_recipient() {
        let dispatcher = Arc::new(MockDispatcher::new());
        let sink = WebSocketSink::new(dispatcher.clone());

        let payload = json!({"title": "System announcement"});

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let dispatched = dispatcher.dispatched.lock().await;
        assert_eq!(dispatched.len(), 1);
        assert!(dispatched[0].0.is_none());
    }

    #[tokio::test]
    async fn test_ws_recipient_from_context() {
        let dispatcher = Arc::new(MockDispatcher::new());
        let sink = WebSocketSink::new(dispatcher.clone());

        let payload = json!({"title": "Test"});
        let mut vars = HashMap::new();
        vars.insert(
            "recipient_id".to_string(),
            Value::String("user-C".to_string()),
        );

        sink.deliver(payload, None, &vars).await.unwrap();

        let dispatched = dispatcher.dispatched.lock().await;
        assert_eq!(dispatched[0].0.as_deref(), Some("user-C"));
    }

    #[tokio::test]
    async fn test_ws_zero_connections_is_ok() {
        // Nobody connected is a normal situation, not a delivery failure.
        let dispatcher = Arc::new(MockDispatcher::with_connections(0));
        let sink = WebSocketSink::new(dispatcher.clone());

        sink.deliver(json!({"title": "Test"}), Some("user-E"), &HashMap::new())
            .await
            .unwrap();
        sink.deliver(json!({"title": "Test"}), None, &HashMap::new())
            .await
            .unwrap();

        assert_eq!(dispatcher.dispatched.lock().await.len(), 2);
    }

    #[tokio::test]
    async fn test_ws_dispatch_error_propagates() {
        let sink = WebSocketSink::new(Arc::new(FailingDispatcher));

        let err = sink
            .deliver(json!({"title": "Test"}), Some("user-D"), &HashMap::new())
            .await
            .unwrap_err();
        assert_eq!(err.to_string(), "recipient dispatch failed");

        let err = sink
            .deliver(json!({"title": "Test"}), None, &HashMap::new())
            .await
            .unwrap_err();
        assert_eq!(err.to_string(), "broadcast failed");
    }

    #[test]
    fn test_ws_sink_name_and_type() {
        let dispatcher = Arc::new(MockDispatcher::new());
        let sink = WebSocketSink::new(dispatcher);
        assert_eq!(sink.name(), "websocket");
        assert_eq!(sink.sink_type(), SinkType::WebSocket);
    }
}