iroh_topic_tracker/
topic_tracker.rs

1use std::{collections::HashMap, future::Future, str::FromStr, sync::Arc};
2
3use anyhow::{bail, Result};
4use iroh::{
5    endpoint::{Connection, Endpoint, RecvStream, SendStream},
6    protocol::{AcceptError, ProtocolHandler},
7    NodeId, SecretKey,
8};
9use iroh_gossip::proto::TopicId;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12use tokio::{
13    io::{AsyncReadExt, AsyncWriteExt},
14    sync::Mutex,
15};
16
17use crate::utils::wait_for_relay;
18
19#[derive(Debug, Clone)]
20pub struct TopicTracker {
21    pub node_id: NodeId,
22    endpoint: Endpoint,
23    kv: Arc<Mutex<HashMap<[u8; 32], Vec<NodeId>>>>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27enum Protocol {
28    TopicRequest((Topic, NodeId)),
29    TopicList(Vec<NodeId>),
30    Done,
31}
32
33impl TopicTracker {
34    pub const ALPN: &'static [u8] = b"iroh/topictracker/1";
35    pub const MAX_TOPIC_LIST_SIZE: usize = 10;
36    pub const MAX_NODE_IDS_PER_TOPIC: usize = 100;
37    pub const BOOTSTRAP_NODES: &str =
38        "abcdef4df4d74587095d071406c2a8462bde5079cbbc0c50051b9b2e84d67691";
39    pub const MAX_MSG_SIZE_BYTES: u64 = 1024 * 1024;
40
41    pub fn new(endpoint: &Endpoint) -> Self {
42        let me = Self {
43            endpoint: endpoint.clone(),
44            node_id: endpoint.node_id(),
45            kv: Arc::new(Mutex::new(HashMap::new())),
46        };
47        me
48    }
49
50    pub async fn spawn_optional(self) -> Result<Self> {
51        tokio::spawn({
52            let me2 = self.clone();
53            async move {
54                while let Some(connecting) = me2.clone().endpoint.accept().await {
55                    match connecting.accept() {
56                        Ok(conn) => {
57                            if let Ok(con) = conn.await {
58                                tokio::spawn({
59                                    let me3 = me2.clone();
60                                    async move {
61                                        let _ = me3.accept(con).await;
62                                    }
63                                });
64                            }
65                        }
66                        Err(err) => {
67                            println!("Failed to connect {err}");
68                        }
69                    }
70                }
71            }
72        });
73        Ok(self)
74    }
75
76    async fn send_msg(msg: Protocol, send: &mut SendStream) -> Result<()> {
77        let encoded = postcard::to_stdvec(&msg)?;
78        assert!(encoded.len() <= Self::MAX_MSG_SIZE_BYTES as usize);
79
80        send.write_u64_le(encoded.len() as u64).await?;
81        send.write(&encoded).await?;
82        Ok(())
83    }
84
85    async fn recv_msg(recv: &mut RecvStream) -> Result<Protocol> {
86        let len = recv.read_u64_le().await?;
87
88        assert!(len <= Self::MAX_MSG_SIZE_BYTES);
89
90        let mut buffer = vec![0u8; len as usize];
91        recv.read_exact(&mut buffer).await?;
92        let msg: Protocol = postcard::from_bytes(&buffer)?;
93        Ok(msg)
94    }
95
96    pub async fn get_topic_nodes(self: Self, topic: &Topic) -> Result<Vec<NodeId>> {
97        wait_for_relay(&self.endpoint).await?;
98
99        let conn_res = self
100            .endpoint
101            .connect(NodeId::from_str(Self::BOOTSTRAP_NODES)?, Self::ALPN)
102            .await;
103        let conn = conn_res?;
104
105        let (mut send, mut recv) = conn.open_bi().await?;
106
107        let msg = Protocol::TopicRequest((topic.clone(), self.node_id.clone()));
108        Self::send_msg(msg, &mut send).await?;
109
110        let back = match Self::recv_msg(&mut recv).await? {
111            Protocol::TopicList(vec) => {
112                let mut _kv = self.kv.lock().await;
113                match _kv.get_mut(&topic.0) {
114                    Some(node_ids) => {
115                        for id in vec.clone() {
116                            if node_ids.contains(&id) {
117                                node_ids.retain(|nid| !nid.eq(&id));
118                            }
119                            node_ids.push(id);
120                        }
121                    }
122                    None => {
123                        let mut node_ids = Vec::with_capacity(Self::MAX_NODE_IDS_PER_TOPIC);
124                        for id in vec.clone() {
125                            node_ids.push(id);
126                        }
127                        _kv.insert(topic.0, node_ids);
128                    }
129                };
130                drop(_kv);
131                Ok(vec)
132            }
133            _ => bail!("illegal message received"),
134        };
135
136        Self::send_msg(Protocol::Done, &mut send).await?;
137        back
138    }
139
140    async fn accept(&self, conn: Connection) -> Result<()> {
141        let (mut send, mut recv) = conn.accept_bi().await?;
142        let msg = Self::recv_msg(&mut recv).await?;
143
144        match msg {
145            Protocol::TopicRequest((topic, remote_node_id)) => {
146                let mut _kv = self.kv.lock().await;
147                let resp;
148                match _kv.get_mut(&topic.0) {
149                    Some(node_ids) => {
150                        let latest_list = node_ids
151                            .iter()
152                            .filter(|&i| !i.eq(&remote_node_id))
153                            .rev()
154                            .take(Self::MAX_TOPIC_LIST_SIZE)
155                            .map(|i| *i)
156                            .collect();
157
158                        resp = Protocol::TopicList(latest_list);
159
160                        if node_ids.contains(&remote_node_id) {
161                            node_ids.retain(|nid| !nid.eq(&remote_node_id));
162                        }
163                        node_ids.push(remote_node_id);
164                    }
165                    None => {
166                        let mut node_ids = Vec::with_capacity(Self::MAX_NODE_IDS_PER_TOPIC);
167                        node_ids.push(remote_node_id);
168                        _kv.insert(topic.0, node_ids);
169
170                        resp = Protocol::TopicList(vec![]);
171                    }
172                };
173
174                Self::send_msg(resp, &mut send).await?;
175                Self::send_msg(Protocol::Done, &mut send).await?;
176                Self::recv_msg(&mut recv).await?;
177
178                drop(_kv);
179            }
180            _ => {
181                bail!("Illegal request");
182            }
183        };
184
185        send.finish()?;
186        Ok(())
187    }
188
189    pub async fn memory_footprint(&self) -> usize {
190        let _kv = self.kv.lock().await;
191        let val = &*_kv;
192        size_of_val(val)
193    }
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
197pub struct Topic([u8; 32]);
198
199impl Topic {
200    pub fn new(topic: [u8; 32]) -> Self {
201        Self(topic)
202    }
203
204    pub fn from_passphrase(phrase: &str) -> Self {
205        Self(Self::hash(phrase))
206    }
207
208    fn hash(s: &str) -> [u8; 32] {
209        let mut hasher = Sha256::new();
210        hasher.update(s);
211        let mut buf = [0u8; 32];
212        buf.copy_from_slice(&hasher.finalize()[..32]);
213        buf
214    }
215
216    pub fn to_string(&self) -> String {
217        z32::encode(&self.0)
218    }
219
220    pub fn to_secret_key(&self) -> SecretKey {
221        SecretKey::from_bytes(&self.0.clone())
222    }
223}
224
225impl Default for Topic {
226    fn default() -> Self {
227        Self::from_passphrase("password")
228    }
229}
230
231impl ProtocolHandler for TopicTracker {
232    fn accept(
233        &self,
234        conn: Connection,
235    ) -> impl Future<Output = Result<(), AcceptError>> + Send {
236        let topic_tracker = self.clone();
237
238        Box::pin(async move {
239            topic_tracker
240                .accept(conn)
241                .await
242                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
243            Ok(())
244        })
245    }
246}
247
/// Converts a topic into a gossip `TopicId` carrying the same 32 bytes.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl still
/// provides `Into<TopicId> for Topic`, so existing `.into()` calls keep
/// working, and `TopicId::from(topic)` becomes available as well.
#[cfg(feature = "iroh-gossip-cast")]
impl From<Topic> for iroh_gossip::proto::TopicId {
    fn from(topic: Topic) -> Self {
        TopicId::from_bytes(topic.0)
    }
}
254
/// Converts a gossip `TopicId` into a topic by copying its 32 bytes.
#[cfg(feature = "iroh-gossip-cast")]
impl From<iroh_gossip::proto::TopicId> for Topic {
    fn from(topic_id: iroh_gossip::proto::TopicId) -> Self {
        Self(*topic_id.as_bytes())
    }
}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// The same passphrase must always map to the same topic, and different
    /// passphrases must map to different topics.
    #[test]
    fn test_topic_from_passphrase() {
        let topic = Topic::from_passphrase("test123");
        assert_eq!(topic, Topic::from_passphrase("test123"));
        assert_ne!(topic, Topic::from_passphrase("test124"));
    }

    /// `Topic::new` stores the given bytes unchanged.
    #[test]
    fn test_topic_new() {
        let bytes = [0u8; 32];
        let topic = Topic::new(bytes);
        assert_eq!(topic.0, bytes);
    }

    /// The default topic is derived from "password"; the expected bytes are
    /// the SHA-256 digest of that string, which pins the derivation so that
    /// nodes keep agreeing on topic ids.
    #[test]
    fn test_topic_default() {
        let topic = Topic::default();
        assert_eq!(topic, Topic::from_passphrase("password"));
        assert_eq!(
            topic.0,
            [
                0x5e, 0x88, 0x48, 0x98, 0xda, 0x28, 0x04, 0x71, 0x51, 0xd0, 0xe5, 0x6f, 0x8d,
                0xc6, 0x29, 0x27, 0x73, 0x60, 0x3d, 0x0d, 0x6a, 0xab, 0xbd, 0xd6, 0x2a, 0x11,
                0xef, 0x72, 0x1d, 0x15, 0x42, 0xd8,
            ]
        );
    }

    /// The textual form is non-empty, stable, and differs between topics.
    #[test]
    fn test_topic_to_string() {
        let encoded = Topic::from_passphrase("test").to_string();
        assert!(!encoded.is_empty());
        assert_eq!(encoded, Topic::from_passphrase("test").to_string());
        assert_ne!(encoded, Topic::from_passphrase("other").to_string());
    }
}
290        assert!(!topic.to_string().is_empty());
291    }
292}