Skip to main content

brainwires_network/routing/
content.rs

1use anyhow::{Result, bail};
2use async_trait::async_trait;
3
4use super::peer_table::PeerTable;
5use super::traits::{Router, RoutingStrategy};
6use crate::network::{MessageEnvelope, MessageTarget};
7use crate::transport::TransportAddress;
8
/// Topic-driven (content-based) router.
///
/// An envelope whose recipient is a [`MessageTarget::Topic`] is resolved to
/// the transport addresses of the peers that the [`PeerTable`] lists as
/// subscribers of that topic, skipping the envelope's own sender.
///
/// The router is a unit struct and holds no state of its own; all routing
/// data is read from the peer table passed in at routing time.
#[derive(Default, Debug)]
pub struct ContentRouter;
15
16impl ContentRouter {
17    /// Create a new content router.
18    pub fn new() -> Self {
19        Self
20    }
21}
22
23#[async_trait]
24impl Router for ContentRouter {
25    async fn route(
26        &self,
27        envelope: &MessageEnvelope,
28        peers: &PeerTable,
29    ) -> Result<Vec<TransportAddress>> {
30        match &envelope.recipient {
31            MessageTarget::Topic(topic) => {
32                let subscribers = peers.subscribers(topic);
33                let mut addrs = Vec::new();
34
35                for sub_id in &subscribers {
36                    // Don't send back to sender
37                    if *sub_id == envelope.sender {
38                        continue;
39                    }
40                    if let Some(peer_addrs) = peers.addresses(sub_id) {
41                        addrs.extend_from_slice(peer_addrs);
42                    }
43                }
44
45                Ok(addrs)
46            }
47            MessageTarget::Direct(_) => {
48                bail!("ContentRouter does not handle direct messages");
49            }
50            MessageTarget::Broadcast => {
51                bail!("ContentRouter does not handle broadcast messages");
52            }
53        }
54    }
55
56    fn strategy(&self) -> RoutingStrategy {
57        RoutingStrategy::ContentBased
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::AgentIdentity;
    use crate::network::Payload;
    use uuid::Uuid;

    /// Every subscriber of the topic is reached; non-subscribers are not.
    #[tokio::test]
    async fn content_routes_to_subscribers() {
        let router = ContentRouter::new();
        let mut peers = PeerTable::new();

        let sender = AgentIdentity::new("sender");
        let sender_id = sender.id;
        let sub_a = AgentIdentity::new("sub-a");
        let sub_a_id = sub_a.id;
        let sub_b = AgentIdentity::new("sub-b");
        let sub_b_id = sub_b.id;
        let non_sub = AgentIdentity::new("non-sub");

        let addr_a = TransportAddress::Tcp("127.0.0.1:1000".parse().unwrap());
        let addr_b = TransportAddress::Tcp("127.0.0.1:2000".parse().unwrap());
        let addr_ns = TransportAddress::Tcp("127.0.0.1:3000".parse().unwrap());

        peers.upsert(sender, vec![]);
        peers.upsert(sub_a, vec![addr_a.clone()]);
        peers.upsert(sub_b, vec![addr_b.clone()]);
        peers.upsert(non_sub, vec![addr_ns.clone()]);

        peers.subscribe(sub_a_id, "events");
        peers.subscribe(sub_b_id, "events");

        let env = MessageEnvelope::topic(sender_id, "events", Payload::Text("update".into()));
        let addrs = router.route(&env, &peers).await.unwrap();

        assert_eq!(addrs.len(), 2);
        assert!(addrs.contains(&addr_a));
        assert!(addrs.contains(&addr_b));
        assert!(!addrs.contains(&addr_ns));
    }

    /// A sender that is itself subscribed (and has a known address) must not
    /// be routed its own message.
    #[tokio::test]
    async fn content_skips_subscribed_sender() {
        let router = ContentRouter::new();
        let mut peers = PeerTable::new();

        let sender = AgentIdentity::new("sender");
        let sender_id = sender.id;
        let sub = AgentIdentity::new("sub");
        let sub_id = sub.id;

        let addr_sender = TransportAddress::Tcp("127.0.0.1:4000".parse().unwrap());
        let addr_sub = TransportAddress::Tcp("127.0.0.1:5000".parse().unwrap());

        peers.upsert(sender, vec![addr_sender.clone()]);
        peers.upsert(sub, vec![addr_sub.clone()]);

        peers.subscribe(sender_id, "events");
        peers.subscribe(sub_id, "events");

        let env = MessageEnvelope::topic(sender_id, "events", Payload::Text("update".into()));
        let addrs = router.route(&env, &peers).await.unwrap();

        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr_sub));
        assert!(!addrs.contains(&addr_sender));
    }

    /// A topic nobody subscribed to routes to no addresses (and is not an error).
    #[tokio::test]
    async fn content_empty_topic() {
        let router = ContentRouter::new();
        let peers = PeerTable::new();

        let env = MessageEnvelope::topic(
            Uuid::new_v4(),
            "no-subscribers",
            Payload::Text("hello".into()),
        );
        let addrs = router.route(&env, &peers).await.unwrap();
        assert!(addrs.is_empty());
    }

    /// Direct messages are outside this router's responsibility.
    #[tokio::test]
    async fn content_rejects_direct() {
        let router = ContentRouter::new();
        let peers = PeerTable::new();

        let env = MessageEnvelope::direct(
            Uuid::new_v4(),
            Uuid::new_v4(),
            Payload::Text("hello".into()),
        );
        assert!(router.route(&env, &peers).await.is_err());
    }

    /// The router identifies itself as content-based.
    #[test]
    fn content_reports_strategy() {
        let router = ContentRouter::new();
        assert!(matches!(router.strategy(), RoutingStrategy::ContentBased));
    }
}