Skip to main content

pylon_runtime/
pubsub.rs

use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
4
5// ---------------------------------------------------------------------------
6// Message type
7// ---------------------------------------------------------------------------
8
9/// A message published to a channel.
10#[derive(Debug, Clone, serde::Serialize)]
11pub struct PubSubMessage {
12    pub channel: String,
13    pub message: String,
14    pub timestamp: String,
15}
16
17// ---------------------------------------------------------------------------
18// Subscriber callback
19// ---------------------------------------------------------------------------
20
21type Callback = Box<dyn Fn(&PubSubMessage) + Send + Sync>;
22
23// ---------------------------------------------------------------------------
24// PubSubBroker
25// ---------------------------------------------------------------------------
26
27/// In-memory pub/sub broker with channel-based messaging, history retention,
28/// and glob-pattern subscriptions.
29pub struct PubSubBroker {
30    /// channel -> list of (subscriber_id, callback)
31    subscriptions: Mutex<HashMap<String, Vec<(u64, Callback)>>>,
32    next_id: Mutex<u64>,
33    /// Recent messages per channel for late joiners.
34    history: Mutex<HashMap<String, Vec<PubSubMessage>>>,
35    max_history: usize,
36}
37
38impl PubSubBroker {
39    /// Create a new broker that retains up to `max_history_per_channel`
40    /// messages per channel.
41    pub fn new(max_history_per_channel: usize) -> Self {
42        Self {
43            subscriptions: Mutex::new(HashMap::new()),
44            next_id: Mutex::new(1),
45            history: Mutex::new(HashMap::new()),
46            max_history: max_history_per_channel,
47        }
48    }
49
50    /// Publish a message to a channel. Returns the number of subscribers
51    /// that were notified.
52    pub fn publish(&self, channel: &str, message: &str) -> usize {
53        let msg = PubSubMessage {
54            channel: channel.to_string(),
55            message: message.to_string(),
56            timestamp: now_iso(),
57        };
58
59        // Save to history.
60        {
61            let mut history = self.history.lock().unwrap();
62            let channel_history = history.entry(channel.to_string()).or_default();
63            channel_history.push(msg.clone());
64            if channel_history.len() > self.max_history {
65                channel_history.remove(0);
66            }
67        }
68
69        // Notify subscribers.
70        let subs = self.subscriptions.lock().unwrap();
71        if let Some(subscribers) = subs.get(channel) {
72            for (_, callback) in subscribers {
73                callback(&msg);
74            }
75            subscribers.len()
76        } else {
77            0
78        }
79    }
80
81    /// Subscribe to a channel. Returns a subscription ID that can be used
82    /// to unsubscribe later.
83    pub fn subscribe(&self, channel: &str, callback: Callback) -> u64 {
84        let id = {
85            let mut next = self.next_id.lock().unwrap();
86            let id = *next;
87            *next += 1;
88            id
89        };
90        let mut subs = self.subscriptions.lock().unwrap();
91        subs.entry(channel.to_string())
92            .or_default()
93            .push((id, callback));
94        id
95    }
96
97    /// Unsubscribe from a channel by subscription ID. Returns true if the
98    /// subscription was found and removed.
99    pub fn unsubscribe(&self, channel: &str, sub_id: u64) -> bool {
100        let mut subs = self.subscriptions.lock().unwrap();
101        if let Some(subscribers) = subs.get_mut(channel) {
102            let before = subscribers.len();
103            subscribers.retain(|(id, _)| *id != sub_id);
104            let removed = subscribers.len() < before;
105            // Clean up empty channel entries.
106            if subscribers.is_empty() {
107                subs.remove(channel);
108            }
109            removed
110        } else {
111            false
112        }
113    }
114
115    /// Get recent message history for a channel, up to `limit` messages.
116    /// Returns messages in chronological order (oldest first).
117    pub fn history(&self, channel: &str, limit: usize) -> Vec<PubSubMessage> {
118        let history = self.history.lock().unwrap();
119        match history.get(channel) {
120            Some(msgs) => {
121                let start = msgs.len().saturating_sub(limit);
122                msgs[start..].to_vec()
123            }
124            None => vec![],
125        }
126    }
127
128    /// List all channels that have at least one subscriber, along with their
129    /// subscriber counts.
130    pub fn channels(&self) -> Vec<(String, usize)> {
131        let subs = self.subscriptions.lock().unwrap();
132        let mut result: Vec<(String, usize)> =
133            subs.iter().map(|(ch, s)| (ch.clone(), s.len())).collect();
134        result.sort_by(|a, b| a.0.cmp(&b.0));
135        result
136    }
137
138    /// Get the number of subscribers for a specific channel.
139    pub fn subscriber_count(&self, channel: &str) -> usize {
140        let subs = self.subscriptions.lock().unwrap();
141        subs.get(channel).map(|s| s.len()).unwrap_or(0)
142    }
143
144    /// Pattern-subscribe: subscribe to all existing channels whose names
145    /// match a glob pattern. Returns the subscription IDs created (one per
146    /// matched channel).
147    ///
148    /// Note: this is a snapshot-based pattern subscribe. Channels created
149    /// after the call will not be matched automatically.
150    pub fn psubscribe(&self, pattern: &str, callback: Callback) -> Vec<u64> {
151        // Collect matching channel names first (to avoid holding both locks).
152        let matching: Vec<String> = {
153            let subs = self.subscriptions.lock().unwrap();
154            subs.keys()
155                .filter(|ch| glob_match(pattern, ch))
156                .cloned()
157                .collect()
158        };
159
160        // Also check history for channels that have messages but no current
161        // subscribers.
162        let history_channels: Vec<String> = {
163            let history = self.history.lock().unwrap();
164            history
165                .keys()
166                .filter(|ch| glob_match(pattern, ch) && !matching.contains(ch))
167                .cloned()
168                .collect()
169        };
170
171        let all_channels: Vec<String> = matching.into_iter().chain(history_channels).collect();
172
173        // We need to create a shared callback that can be used across
174        // multiple subscriptions. We wrap it in an Arc.
175        let shared_cb = std::sync::Arc::new(callback);
176        let mut ids = Vec::new();
177        for ch in &all_channels {
178            let cb = std::sync::Arc::clone(&shared_cb);
179            let id = self.subscribe(ch, Box::new(move |msg| cb(msg)));
180            ids.push(id);
181        }
182        ids
183    }
184
185    /// List all channels that have history entries (regardless of whether
186    /// they have active subscribers).
187    pub fn channels_with_history(&self) -> Vec<String> {
188        let history = self.history.lock().unwrap();
189        let mut channels: Vec<String> = history.keys().cloned().collect();
190        channels.sort();
191        channels
192    }
193}
194
195// ---------------------------------------------------------------------------
196// Helpers
197// ---------------------------------------------------------------------------
198
/// Return the current UTC time as an ISO 8601 string
/// (`YYYY-MM-DDTHH:MM:SSZ`, whole-second precision).
///
/// If the system clock reports a time before the Unix epoch, the epoch itself
/// (`1970-01-01T00:00:00Z`) is returned instead of panicking.
fn now_iso() -> String {
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    iso_from_unix_secs(secs)
}

/// Format `secs` seconds since the Unix epoch as an ISO 8601 UTC timestamp
/// (`YYYY-MM-DDTHH:MM:SSZ`) in the proleptic Gregorian calendar, without
/// depending on the chrono crate.
///
/// Kept separate from [`now_iso`] so the calendar arithmetic is a pure
/// function that can be tested against fixed instants.
fn iso_from_unix_secs(secs: u64) -> String {
    // Days in one full 400-year Gregorian cycle (97 of the 400 years are leap).
    const DAYS_PER_400_YEARS: i64 = 146_097;

    let days = secs / 86_400;
    let time_of_day = secs % 86_400;
    let hours = time_of_day / 3_600;
    let minutes = (time_of_day % 3_600) / 60;
    let seconds = time_of_day % 60;

    // `days` is at most u64::MAX / 86_400 (about 2.1e14), so it fits in i64.
    let mut remaining = days as i64;
    // Every 400-year span has the same length. Skipping whole cycles first
    // bounds the year-by-year walk below to under 400 iterations for any input.
    let mut year = 1970 + 400 * (remaining / DAYS_PER_400_YEARS);
    remaining %= DAYS_PER_400_YEARS;
    loop {
        let days_in_year = if is_leap(year) { 366 } else { 365 };
        if remaining < days_in_year {
            break;
        }
        remaining -= days_in_year;
        year += 1;
    }

    let february = if is_leap(year) { 29 } else { 28 };
    let month_days: [i64; 12] = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    // `remaining` is now the 0-based day of the year. It is always smaller
    // than the sum of `month_days`, so this stops at a month index in 0..12.
    let mut month = 0usize;
    for &len in &month_days {
        if remaining < len {
            break;
        }
        remaining -= len;
        month += 1;
    }

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year,
        month + 1,
        remaining + 1,
        hours,
        minutes,
        seconds
    )
}

/// Return whether `y` is a leap year in the Gregorian calendar.
fn is_leap(y: i64) -> bool {
    (y % 4 == 0 && y % 100 != 0) || y % 400 == 0
}
260
261/// Simple glob matching supporting `*` (any sequence) and `?` (single char).
262fn glob_match(pattern: &str, text: &str) -> bool {
263    let pat: Vec<char> = pattern.chars().collect();
264    let txt: Vec<char> = text.chars().collect();
265    glob_inner(&pat, &txt)
266}
267
/// Match `txt` against the glob `pat`: `*` matches any run of chars (including
/// none), `?` matches exactly one char, and anything else matches literally.
///
/// This is an iterative two-pointer matcher. On a mismatch it only ever
/// backtracks to the most recent `*`, letting that star absorb one more
/// character, so the worst case is O(pat.len() * txt.len()).
///
/// Naive recursion over every split point is exponential for patterns with
/// many stars (e.g. `a*a*a*...b`). That matters here because patterns come
/// from `psubscribe` callers.
fn glob_inner(pat: &[char], txt: &[char]) -> bool {
    let mut p = 0;
    let mut t = 0;
    // (index of the most recent `*` in `pat`, index in `txt` where the text
    // following that star is currently assumed to start).
    let mut last_star: Option<(usize, usize)> = None;

    while t < txt.len() {
        match pat.get(p).copied() {
            Some('*') => {
                // Tentatively let the star match nothing.
                last_star = Some((p, t));
                p += 1;
            }
            Some(c) if c == '?' || c == txt[t] => {
                p += 1;
                t += 1;
            }
            _ => match last_star {
                // Retry with the last star swallowing one more character.
                Some((star_p, star_t)) => {
                    p = star_p + 1;
                    t = star_t + 1;
                    last_star = Some((star_p, star_t + 1));
                }
                None => return false,
            },
        }
    }
    // Text exhausted: only trailing stars (which can match nothing) may remain.
    pat[p..].iter().all(|&c| c == '*')
}
297
298// ---------------------------------------------------------------------------
299// Tests
300// ---------------------------------------------------------------------------
301
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Builds a fresh hit counter together with a callback that bumps it.
    fn counting_callback() -> (Arc<AtomicUsize>, Callback) {
        let counter = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&counter);
        let callback: Callback = Box::new(move |_msg| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        (counter, callback)
    }

    /// Reads the current value of a hit counter.
    fn hit_count(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    #[test]
    fn publish_and_subscribe() {
        let broker = PubSubBroker::new(10);
        let (count, callback) = counting_callback();
        broker.subscribe("chat", callback);

        assert_eq!(broker.publish("chat", "hello"), 1);
        assert_eq!(hit_count(&count), 1);
    }

    #[test]
    fn publish_to_empty_channel() {
        let notified = PubSubBroker::new(10).publish("empty", "no one listening");
        assert_eq!(notified, 0);
    }

    #[test]
    fn multiple_subscribers() {
        let broker = PubSubBroker::new(10);
        let count = Arc::new(AtomicUsize::new(0));
        (0..5).for_each(|_| {
            let sink = Arc::clone(&count);
            broker.subscribe(
                "events",
                Box::new(move |_msg| {
                    sink.fetch_add(1, Ordering::SeqCst);
                }),
            );
        });

        assert_eq!(broker.publish("events", "boom"), 5);
        assert_eq!(hit_count(&count), 5);
    }

    #[test]
    fn unsubscribe() {
        let broker = PubSubBroker::new(10);
        let (count, callback) = counting_callback();
        let id = broker.subscribe("ch", callback);

        broker.publish("ch", "first");
        assert_eq!(hit_count(&count), 1);

        assert!(broker.unsubscribe("ch", id));
        broker.publish("ch", "second");
        // The subscription is gone, so the second publish must not reach it.
        assert_eq!(hit_count(&count), 1);
    }

    #[test]
    fn unsubscribe_nonexistent() {
        assert!(!PubSubBroker::new(10).unsubscribe("nope", 999));
    }

    #[test]
    fn history_basic() {
        let broker = PubSubBroker::new(10);
        for headline in ["headline 1", "headline 2", "headline 3"] {
            broker.publish("news", headline);
        }

        let msgs = broker.history("news", 10);
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0].message, "headline 1");
        assert_eq!(msgs[2].message, "headline 3");
    }

    #[test]
    fn history_limit() {
        let broker = PubSubBroker::new(10);
        (0..10).for_each(|i| {
            broker.publish("ch", &format!("msg {i}"));
        });

        let tail = broker.history("ch", 3);
        assert_eq!(tail.len(), 3);
        assert_eq!(tail[0].message, "msg 7");
        assert_eq!(tail[2].message, "msg 9");
    }

    #[test]
    fn history_eviction() {
        let broker = PubSubBroker::new(3);
        for body in ["a", "b", "c", "d"] {
            broker.publish("ch", body);
        }

        let retained = broker.history("ch", 10);
        assert_eq!(retained.len(), 3);
        // Capacity is 3, so the oldest message ("a") was dropped.
        assert_eq!(retained[0].message, "b");
    }

    #[test]
    fn history_empty_channel() {
        assert!(PubSubBroker::new(10).history("nonexistent", 10).is_empty());
    }

    #[test]
    fn channels_list() {
        let broker = PubSubBroker::new(10);
        for ch in ["alpha", "alpha", "beta"] {
            broker.subscribe(ch, Box::new(|_| {}));
        }

        let listed = broker.channels();
        assert_eq!(listed.len(), 2);
        // Entries come back sorted by channel name.
        assert_eq!((listed[0].0.as_str(), listed[0].1), ("alpha", 2));
        assert_eq!((listed[1].0.as_str(), listed[1].1), ("beta", 1));
    }

    #[test]
    fn subscriber_count() {
        let broker = PubSubBroker::new(10);
        assert_eq!(broker.subscriber_count("ch"), 0);
        for _ in 0..2 {
            broker.subscribe("ch", Box::new(|_| {}));
        }
        assert_eq!(broker.subscriber_count("ch"), 2);
    }

    #[test]
    fn pattern_subscribe() {
        let broker = PubSubBroker::new(10);
        // Publishing creates the channels (they exist only in history so far).
        for ch in ["user:1", "user:2", "system:1"] {
            broker.publish(ch, "event");
        }

        let (count, callback) = counting_callback();
        let ids = broker.psubscribe("user:*", callback);
        assert_eq!(ids.len(), 2); // user:1 and user:2

        broker.publish("user:1", "hello");
        broker.publish("user:2", "world");
        assert_eq!(hit_count(&count), 2);
    }

    #[test]
    fn message_contains_metadata() {
        let broker = PubSubBroker::new(10);
        let slot: Arc<Mutex<Option<PubSubMessage>>> = Arc::new(Mutex::new(None));
        let writer = Arc::clone(&slot);
        broker.subscribe(
            "meta",
            Box::new(move |msg| {
                writer.lock().unwrap().replace(msg.clone());
            }),
        );
        broker.publish("meta", "payload");

        let msg = slot
            .lock()
            .unwrap()
            .take()
            .expect("callback should have captured the message");
        assert_eq!(msg.channel, "meta");
        assert_eq!(msg.message, "payload");
        assert!(!msg.timestamp.is_empty());
        // ISO 8601 shape: date/time separator and UTC designator.
        assert!(msg.timestamp.contains('T'));
        assert!(msg.timestamp.ends_with('Z'));
    }

    #[test]
    fn glob_match_works() {
        let cases = [
            ("*", "anything", true),
            ("user:*", "user:123", true),
            ("user:*", "session:1", false),
            ("u?er:*", "user:1", true),
            ("u?er:*", "uuser:1", false),
            ("*:*", "a:b", true),
        ];
        for (pattern, text, expected) in cases {
            assert_eq!(
                glob_match(pattern, text),
                expected,
                "glob_match({pattern:?}, {text:?})"
            );
        }
    }

    #[test]
    fn channels_with_history_list() {
        let broker = PubSubBroker::new(10);
        for ch in ["alpha", "beta"] {
            broker.publish(ch, "msg");
        }

        let names = broker.channels_with_history();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }
}