1use std::sync::Arc;
12
13use dashmap::DashMap;
14use serde::{Deserialize, Serialize};
15use tracing::{debug, warn};
16
17use crate::process::{Pid, ProcessState, ProcessTable};
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Subscription {
22 pub topic: String,
24
25 pub subscriber_pid: Pid,
27
28 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub filter: Option<String>,
31}
32
33pub struct TopicRouter {
44 subscriptions: DashMap<String, Vec<Pid>>,
46
47 process_table: Arc<ProcessTable>,
49}
50
51impl TopicRouter {
52 pub fn new(process_table: Arc<ProcessTable>) -> Self {
54 Self {
55 subscriptions: DashMap::new(),
56 process_table,
57 }
58 }
59
60 pub fn subscribe(&self, pid: Pid, topic: &str) {
64 debug!(pid, topic, "subscribing to topic");
65 self.subscriptions
66 .entry(topic.to_owned())
67 .or_default()
68 .push(pid);
69
70 if let Some(mut subs) = self.subscriptions.get_mut(topic) {
72 subs.dedup();
73 }
74 }
75
76 pub fn unsubscribe(&self, pid: Pid, topic: &str) {
81 debug!(pid, topic, "unsubscribing from topic");
82 if let Some(mut subs) = self.subscriptions.get_mut(topic) {
83 subs.retain(|&p| p != pid);
84 }
85
86 self.subscriptions.retain(|_, subs| !subs.is_empty());
88 }
89
90 pub fn live_subscribers(&self, topic: &str) -> Vec<Pid> {
95 let mut live = Vec::new();
96 let mut dead = Vec::new();
97
98 if let Some(subs) = self.subscriptions.get(topic) {
99 for &pid in subs.iter() {
100 if self.is_alive(pid) {
101 live.push(pid);
102 } else {
103 dead.push(pid);
104 }
105 }
106 }
107
108 if !dead.is_empty() {
110 if let Some(mut subs) = self.subscriptions.get_mut(topic) {
111 subs.retain(|p| !dead.contains(p));
112 }
113 warn!(
114 topic,
115 dead_count = dead.len(),
116 "cleaned up dead subscribers"
117 );
118 }
119
120 live
121 }
122
123 pub fn subscribers(&self, topic: &str) -> Vec<Pid> {
127 self.subscriptions
128 .get(topic)
129 .map(|subs| subs.clone())
130 .unwrap_or_default()
131 }
132
133 pub fn list_topics(&self) -> Vec<(String, usize)> {
135 self.subscriptions
136 .iter()
137 .map(|entry| (entry.key().clone(), entry.value().len()))
138 .collect()
139 }
140
141 pub fn topics_for_pid(&self, pid: Pid) -> Vec<String> {
143 self.subscriptions
144 .iter()
145 .filter(|entry| entry.value().contains(&pid))
146 .map(|entry| entry.key().clone())
147 .collect()
148 }
149
150 pub fn topic_count(&self) -> usize {
152 self.subscriptions.len()
153 }
154
155 pub fn has_subscribers(&self, topic: &str) -> bool {
157 self.subscriptions
158 .get(topic)
159 .is_some_and(|subs| !subs.is_empty())
160 }
161
162 pub fn unsubscribe_all(&self, pid: Pid) {
164 debug!(pid, "unsubscribing from all topics");
165 for mut entry in self.subscriptions.iter_mut() {
166 entry.value_mut().retain(|&p| p != pid);
167 }
168 self.subscriptions.retain(|_, subs| !subs.is_empty());
170 }
171
172 fn is_alive(&self, pid: Pid) -> bool {
174 self.process_table
175 .get(pid)
176 .is_some_and(|entry| !matches!(entry.state, ProcessState::Exited(_)))
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capability::AgentCapabilities;
    use crate::process::{ProcessEntry, ResourceUsage};
    use tokio_util::sync::CancellationToken;

    /// Builds a `Running` process entry for `agent_id`.
    ///
    /// The `pid` field is a placeholder (0). The tests always use the PID
    /// returned by `ProcessTable::insert`.
    fn running_entry(agent_id: impl Into<String>) -> ProcessEntry {
        ProcessEntry {
            pid: 0,
            agent_id: agent_id.into(),
            state: ProcessState::Running,
            capabilities: AgentCapabilities::default(),
            resource_usage: ResourceUsage::default(),
            cancel_token: CancellationToken::new(),
            parent_pid: None,
        }
    }

    /// Creates a router over a fresh table of `count` running processes
    /// (`agent-0`, `agent-1`, ...) and returns it together with their PIDs.
    fn make_router_with_processes(count: usize) -> (TopicRouter, Vec<Pid>) {
        let table = Arc::new(ProcessTable::new(64));
        let pids: Vec<Pid> = (0..count)
            .map(|i| table.insert(running_entry(format!("agent-{i}"))).unwrap())
            .collect();
        (TopicRouter::new(table), pids)
    }

    #[test]
    fn subscribe_and_list() {
        let (router, pids) = make_router_with_processes(2);
        for &pid in &pids {
            router.subscribe(pid, "build");
        }

        let found = router.subscribers("build");
        assert_eq!(found.len(), 2);
        assert!(found.contains(&pids[0]));
        assert!(found.contains(&pids[1]));
    }

    #[test]
    fn subscribe_idempotent() {
        let (router, pids) = make_router_with_processes(1);
        let only = pids[0];
        router.subscribe(only, "build");
        router.subscribe(only, "build");

        assert_eq!(router.subscribers("build").len(), 1);
    }

    #[test]
    fn unsubscribe() {
        let (router, pids) = make_router_with_processes(2);
        router.subscribe(pids[0], "build");
        router.subscribe(pids[1], "build");

        router.unsubscribe(pids[0], "build");

        assert_eq!(router.subscribers("build"), vec![pids[1]]);
    }

    #[test]
    fn unsubscribe_nonexistent_is_noop() {
        let (router, _pids) = make_router_with_processes(0);
        router.unsubscribe(999, "build");
        assert!(router.subscribers("build").is_empty());
    }

    #[test]
    fn unsubscribe_removes_empty_topic() {
        let (router, pids) = make_router_with_processes(1);
        router.subscribe(pids[0], "build");
        router.unsubscribe(pids[0], "build");

        assert_eq!(router.topic_count(), 0);
    }

    #[test]
    fn list_topics() {
        let (router, pids) = make_router_with_processes(2);
        router.subscribe(pids[0], "build");
        router.subscribe(pids[0], "test");
        router.subscribe(pids[1], "build");

        let listing = router.list_topics();
        assert_eq!(listing.len(), 2);

        let (_, build_count) = listing
            .iter()
            .find(|(name, _)| name == "build")
            .unwrap();
        assert_eq!(*build_count, 2);
    }

    #[test]
    fn topics_for_pid() {
        let (router, pids) = make_router_with_processes(1);
        for topic in ["build", "test", "deploy"] {
            router.subscribe(pids[0], topic);
        }

        assert_eq!(router.topics_for_pid(pids[0]).len(), 3);
    }

    #[test]
    fn has_subscribers() {
        let (router, pids) = make_router_with_processes(1);
        assert!(!router.has_subscribers("build"));

        router.subscribe(pids[0], "build");
        assert!(router.has_subscribers("build"));
    }

    #[test]
    fn live_subscribers_filters_dead() {
        let table = Arc::new(ProcessTable::new(64));
        let alive = table.insert(running_entry("alive")).unwrap();
        let dead = table.insert(running_entry("dead")).unwrap();
        table.update_state(dead, ProcessState::Exited(0)).unwrap();

        let router = TopicRouter::new(table);
        router.subscribe(alive, "build");
        router.subscribe(dead, "build");

        // Before pruning, both subscriptions are recorded.
        assert_eq!(router.subscribers("build").len(), 2);

        // Asking for live subscribers filters out the exited process...
        assert_eq!(router.live_subscribers("build"), vec![alive]);

        // ...and prunes it from the stored subscriber list.
        assert_eq!(router.subscribers("build").len(), 1);
    }

    #[test]
    fn unsubscribe_all() {
        let (router, pids) = make_router_with_processes(2);
        router.subscribe(pids[0], "build");
        router.subscribe(pids[0], "test");
        router.subscribe(pids[1], "build");

        router.unsubscribe_all(pids[0]);

        assert!(router.topics_for_pid(pids[0]).is_empty());
        assert_eq!(router.subscribers("build").len(), 1);
        assert_eq!(router.topic_count(), 1);
    }

    #[test]
    fn subscription_serde_roundtrip() {
        let original = Subscription {
            topic: "build".into(),
            subscriber_pid: 42,
            filter: Some("status:*".into()),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Subscription = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.topic, "build");
        assert_eq!(decoded.subscriber_pid, 42);
        assert_eq!(decoded.filter, Some("status:*".into()));
    }
}