Skip to main content

oxihuman_core/
channel_router.rs

// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
//! Routes messages to named channels based on topic prefixes.
6
7use std::collections::HashMap;
8
/// A message with a topic and payload.
#[derive(Debug, Clone)]
pub struct RoutedMessage {
    /// Topic the message was routed under; compared against channel prefixes.
    pub topic: String,
    /// Message body, stored exactly as it was passed to [`ChannelRouter::route`].
    pub payload: String,
}

/// Routes messages to named channels.
///
/// Each channel owns a list of topic prefixes and a queue of pending
/// messages. A routed message is copied into the queue of every channel that
/// has at least one prefix the topic starts with.
#[derive(Debug, Clone, Default)]
pub struct ChannelRouter {
    /// channel_name -> list of topic prefixes it subscribes to
    routes: HashMap<String, Vec<String>>,
    /// channel_name -> queued messages, oldest first
    queues: HashMap<String, Vec<RoutedMessage>>,
    /// Number of deliveries so far: one per message per receiving channel.
    total_routed: u64,
}

impl ChannelRouter {
    /// Creates a router with no channels, no pending messages and a zero
    /// delivery counter.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `name` as a channel with no subscriptions and an empty queue.
    ///
    /// Calling this for a channel that already exists keeps its subscriptions
    /// and pending messages.
    pub fn add_channel(&mut self, name: &str) {
        self.routes.entry(name.to_string()).or_default();
        self.queues.entry(name.to_string()).or_default();
    }

    /// Subscribes `channel` to every topic that starts with `topic_prefix`.
    ///
    /// The channel is registered if it does not exist yet. An empty prefix
    /// matches every topic. Subscribing to the same prefix again is a no-op,
    /// so repeated calls cannot grow the prefix list without bound.
    pub fn subscribe(&mut self, channel: &str, topic_prefix: &str) {
        let prefixes = self.routes.entry(channel.to_string()).or_default();
        if !prefixes.iter().any(|known| known == topic_prefix) {
            prefixes.push(topic_prefix.to_string());
        }
    }

    /// Delivers a message to every channel subscribed to a prefix of `topic`.
    ///
    /// Each receiving channel gets its own copy, appended to the end of its
    /// queue, and `total_routed` grows by one per receiving channel. A channel
    /// receives the message at most once even when several of its prefixes
    /// match. Nothing is allocated when no channel matches.
    pub fn route(&mut self, topic: &str, payload: &str) {
        for (ch_name, prefixes) in &self.routes {
            if !prefixes.iter().any(|p| topic.starts_with(p.as_str())) {
                continue;
            }
            // Build the copy only for channels that actually receive it.
            let msg = RoutedMessage {
                topic: topic.to_string(),
                payload: payload.to_string(),
            };
            match self.queues.get_mut(ch_name) {
                Some(queue) => queue.push(msg),
                // Channels created through `subscribe` alone have no queue
                // yet; create it on first delivery.
                None => {
                    self.queues.insert(ch_name.clone(), vec![msg]);
                }
            }
            self.total_routed += 1;
        }
    }

    /// Removes and returns all pending messages of `channel`, oldest first.
    ///
    /// Returns an empty vector when the channel is unknown or has nothing
    /// pending. The channel and its subscriptions stay registered.
    pub fn drain_channel(&mut self, channel: &str) -> Vec<RoutedMessage> {
        self.queues
            .get_mut(channel)
            .map(std::mem::take)
            .unwrap_or_default()
    }

    /// Returns the number of registered channels.
    pub fn channel_count(&self) -> usize {
        self.routes.len()
    }

    /// Returns how many messages are waiting on `channel` (0 if unknown).
    pub fn pending_count(&self, channel: &str) -> usize {
        self.queues.get(channel).map_or(0, Vec::len)
    }

    /// Returns the total number of deliveries made since construction.
    ///
    /// A message routed to three channels counts three times; draining,
    /// clearing or removing channels does not reduce the value.
    pub fn total_routed(&self) -> u64 {
        self.total_routed
    }

    /// Returns `true` if `name` was registered via `add_channel` or
    /// `subscribe` and has not been removed since.
    pub fn has_channel(&self, name: &str) -> bool {
        self.routes.contains_key(name)
    }

    /// Removes `name` together with its subscriptions and pending messages.
    ///
    /// Unknown names are ignored.
    pub fn remove_channel(&mut self, name: &str) {
        self.routes.remove(name);
        self.queues.remove(name);
    }

    /// Discards the pending messages of every channel.
    ///
    /// Channels, their subscriptions and the delivery counter are kept.
    pub fn clear_all(&mut self) {
        for q in self.queues.values_mut() {
            q.clear();
        }
    }
}
106
#[cfg(test)]
mod tests {
    //! Unit tests for `ChannelRouter`, exercised through its public API only.

    use super::*;

    #[test]
    fn new_router_empty() {
        let r = ChannelRouter::new();
        assert_eq!(r.channel_count(), 0);
        assert_eq!(r.total_routed(), 0);
        assert!(!r.has_channel("log"));
    }

    #[test]
    fn add_channel_and_subscribe() {
        let mut r = ChannelRouter::new();
        r.add_channel("log");
        r.subscribe("log", "system.");
        assert!(r.has_channel("log"));
        assert_eq!(r.channel_count(), 1);
        // The subscription must actually take effect, not just register.
        r.route("system.boot", "up");
        assert_eq!(r.pending_count("log"), 1);
    }

    #[test]
    fn subscribe_registers_unknown_channel() {
        let mut r = ChannelRouter::new();
        r.subscribe("auto", "evt.");
        assert!(r.has_channel("auto"));
        r.route("evt.x", "1");
        assert_eq!(r.pending_count("auto"), 1);
    }

    #[test]
    fn route_to_matching_channel() {
        let mut r = ChannelRouter::new();
        r.add_channel("log");
        r.subscribe("log", "sys.");
        r.route("sys.info", "hello");
        assert_eq!(r.pending_count("log"), 1);
        let msgs = r.drain_channel("log");
        assert_eq!(msgs[0].topic, "sys.info");
        assert_eq!(msgs[0].payload, "hello");
    }

    #[test]
    fn route_no_match() {
        let mut r = ChannelRouter::new();
        r.add_channel("log");
        r.subscribe("log", "sys.");
        r.route("net.error", "data");
        assert_eq!(r.pending_count("log"), 0);
        assert_eq!(r.total_routed(), 0);
    }

    #[test]
    fn duplicate_subscription_delivers_once() {
        let mut r = ChannelRouter::new();
        r.add_channel("log");
        r.subscribe("log", "sys.");
        r.subscribe("log", "sys.");
        r.route("sys.info", "hello");
        assert_eq!(r.pending_count("log"), 1);
        assert_eq!(r.total_routed(), 1);
    }

    #[test]
    fn overlapping_prefixes_deliver_once() {
        let mut r = ChannelRouter::new();
        r.add_channel("log");
        r.subscribe("log", "a.");
        r.subscribe("log", "a.b.");
        r.route("a.b.c", "x");
        assert_eq!(r.pending_count("log"), 1);
    }

    #[test]
    fn drain_channel_empties() {
        let mut r = ChannelRouter::new();
        r.add_channel("ch");
        r.subscribe("ch", "t.");
        r.route("t.1", "a");
        let msgs = r.drain_channel("ch");
        assert_eq!(msgs.len(), 1);
        assert_eq!(r.pending_count("ch"), 0);
        // Draining does not unregister the channel.
        assert!(r.has_channel("ch"));
    }

    #[test]
    fn drain_unknown_channel_is_empty() {
        let mut r = ChannelRouter::new();
        assert!(r.drain_channel("missing").is_empty());
    }

    #[test]
    fn total_routed_increments() {
        let mut r = ChannelRouter::new();
        r.add_channel("a");
        r.subscribe("a", "x");
        r.route("x1", "p");
        r.route("x2", "q");
        assert_eq!(r.total_routed(), 2);
    }

    #[test]
    fn total_routed_counts_each_delivery() {
        let mut r = ChannelRouter::new();
        r.add_channel("a");
        r.add_channel("b");
        r.subscribe("a", "shared.");
        r.subscribe("b", "shared.");
        r.route("shared.msg", "x");
        // One message, two receiving channels.
        assert_eq!(r.total_routed(), 2);
    }

    #[test]
    fn remove_channel() {
        let mut r = ChannelRouter::new();
        r.add_channel("tmp");
        r.subscribe("tmp", "t.");
        r.route("t.1", "a");
        r.remove_channel("tmp");
        assert!(!r.has_channel("tmp"));
        assert_eq!(r.pending_count("tmp"), 0);
        // A removed channel no longer receives anything.
        r.route("t.2", "b");
        assert_eq!(r.pending_count("tmp"), 0);
        assert_eq!(r.total_routed(), 1);
    }

    #[test]
    fn clear_all_queues() {
        let mut r = ChannelRouter::new();
        r.add_channel("c");
        r.subscribe("c", "");
        r.route("anything", "data");
        r.clear_all();
        assert_eq!(r.pending_count("c"), 0);
        // Only the queues are cleared.
        assert!(r.has_channel("c"));
        assert_eq!(r.total_routed(), 1);
    }

    #[test]
    fn multiple_channels_receive() {
        let mut r = ChannelRouter::new();
        r.add_channel("a");
        r.add_channel("b");
        r.subscribe("a", "shared.");
        r.subscribe("b", "shared.");
        r.route("shared.msg", "x");
        assert_eq!(r.pending_count("a"), 1);
        assert_eq!(r.pending_count("b"), 1);
    }

    #[test]
    fn default_is_empty() {
        let r = ChannelRouter::default();
        assert_eq!(r.channel_count(), 0);
        assert_eq!(r.total_routed(), 0);
    }
}