iroh_topic_tracker/
topic_tracker.rs1use std::{collections::HashMap, future::Future, str::FromStr, sync::Arc};
2
3use anyhow::{bail, Result};
4use iroh::{
5 endpoint::{Connection, Endpoint, RecvStream, SendStream},
6 protocol::{AcceptError, ProtocolHandler},
7 NodeId, SecretKey,
8};
9use iroh_gossip::proto::TopicId;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12use tokio::{
13 io::{AsyncReadExt, AsyncWriteExt},
14 sync::Mutex,
15};
16
17use crate::utils::wait_for_relay;
18
19#[derive(Debug, Clone)]
20pub struct TopicTracker {
21 pub node_id: NodeId,
22 endpoint: Endpoint,
23 kv: Arc<Mutex<HashMap<[u8; 32], Vec<NodeId>>>>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27enum Protocol {
28 TopicRequest((Topic, NodeId)),
29 TopicList(Vec<NodeId>),
30 Done,
31}
32
33impl TopicTracker {
34 pub const ALPN: &'static [u8] = b"iroh/topictracker/1";
35 pub const MAX_TOPIC_LIST_SIZE: usize = 10;
36 pub const MAX_NODE_IDS_PER_TOPIC: usize = 100;
37 pub const BOOTSTRAP_NODES: &str =
38 "abcdef4df4d74587095d071406c2a8462bde5079cbbc0c50051b9b2e84d67691";
39 pub const MAX_MSG_SIZE_BYTES: u64 = 1024 * 1024;
40
41 pub fn new(endpoint: &Endpoint) -> Self {
42 let me = Self {
43 endpoint: endpoint.clone(),
44 node_id: endpoint.node_id(),
45 kv: Arc::new(Mutex::new(HashMap::new())),
46 };
47 me
48 }
49
50 pub async fn spawn_optional(self) -> Result<Self> {
51 tokio::spawn({
52 let me2 = self.clone();
53 async move {
54 while let Some(connecting) = me2.clone().endpoint.accept().await {
55 match connecting.accept() {
56 Ok(conn) => {
57 if let Ok(con) = conn.await {
58 tokio::spawn({
59 let me3 = me2.clone();
60 async move {
61 let _ = me3.accept(con).await;
62 }
63 });
64 }
65 }
66 Err(err) => {
67 println!("Failed to connect {err}");
68 }
69 }
70 }
71 }
72 });
73 Ok(self)
74 }
75
76 async fn send_msg(msg: Protocol, send: &mut SendStream) -> Result<()> {
77 let encoded = postcard::to_stdvec(&msg)?;
78 assert!(encoded.len() <= Self::MAX_MSG_SIZE_BYTES as usize);
79
80 send.write_u64_le(encoded.len() as u64).await?;
81 send.write(&encoded).await?;
82 Ok(())
83 }
84
85 async fn recv_msg(recv: &mut RecvStream) -> Result<Protocol> {
86 let len = recv.read_u64_le().await?;
87
88 assert!(len <= Self::MAX_MSG_SIZE_BYTES);
89
90 let mut buffer = vec![0u8; len as usize];
91 recv.read_exact(&mut buffer).await?;
92 let msg: Protocol = postcard::from_bytes(&buffer)?;
93 Ok(msg)
94 }
95
96 pub async fn get_topic_nodes(self: Self, topic: &Topic) -> Result<Vec<NodeId>> {
97 wait_for_relay(&self.endpoint).await?;
98
99 let conn_res = self
100 .endpoint
101 .connect(NodeId::from_str(Self::BOOTSTRAP_NODES)?, Self::ALPN)
102 .await;
103 let conn = conn_res?;
104
105 let (mut send, mut recv) = conn.open_bi().await?;
106
107 let msg = Protocol::TopicRequest((topic.clone(), self.node_id.clone()));
108 Self::send_msg(msg, &mut send).await?;
109
110 let back = match Self::recv_msg(&mut recv).await? {
111 Protocol::TopicList(vec) => {
112 let mut _kv = self.kv.lock().await;
113 match _kv.get_mut(&topic.0) {
114 Some(node_ids) => {
115 for id in vec.clone() {
116 if node_ids.contains(&id) {
117 node_ids.retain(|nid| !nid.eq(&id));
118 }
119 node_ids.push(id);
120 }
121 }
122 None => {
123 let mut node_ids = Vec::with_capacity(Self::MAX_NODE_IDS_PER_TOPIC);
124 for id in vec.clone() {
125 node_ids.push(id);
126 }
127 _kv.insert(topic.0, node_ids);
128 }
129 };
130 drop(_kv);
131 Ok(vec)
132 }
133 _ => bail!("illegal message received"),
134 };
135
136 Self::send_msg(Protocol::Done, &mut send).await?;
137 back
138 }
139
140 async fn accept(&self, conn: Connection) -> Result<()> {
141 let (mut send, mut recv) = conn.accept_bi().await?;
142 let msg = Self::recv_msg(&mut recv).await?;
143
144 match msg {
145 Protocol::TopicRequest((topic, remote_node_id)) => {
146 let mut _kv = self.kv.lock().await;
147 let resp;
148 match _kv.get_mut(&topic.0) {
149 Some(node_ids) => {
150 let latest_list = node_ids
151 .iter()
152 .filter(|&i| !i.eq(&remote_node_id))
153 .rev()
154 .take(Self::MAX_TOPIC_LIST_SIZE)
155 .map(|i| *i)
156 .collect();
157
158 resp = Protocol::TopicList(latest_list);
159
160 if node_ids.contains(&remote_node_id) {
161 node_ids.retain(|nid| !nid.eq(&remote_node_id));
162 }
163 node_ids.push(remote_node_id);
164 }
165 None => {
166 let mut node_ids = Vec::with_capacity(Self::MAX_NODE_IDS_PER_TOPIC);
167 node_ids.push(remote_node_id);
168 _kv.insert(topic.0, node_ids);
169
170 resp = Protocol::TopicList(vec![]);
171 }
172 };
173
174 Self::send_msg(resp, &mut send).await?;
175 Self::send_msg(Protocol::Done, &mut send).await?;
176 Self::recv_msg(&mut recv).await?;
177
178 drop(_kv);
179 }
180 _ => {
181 bail!("Illegal request");
182 }
183 };
184
185 send.finish()?;
186 Ok(())
187 }
188
189 pub async fn memory_footprint(&self) -> usize {
190 let _kv = self.kv.lock().await;
191 let val = &*_kv;
192 size_of_val(val)
193 }
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
197pub struct Topic([u8; 32]);
198
199impl Topic {
200 pub fn new(topic: [u8; 32]) -> Self {
201 Self(topic)
202 }
203
204 pub fn from_passphrase(phrase: &str) -> Self {
205 Self(Self::hash(phrase))
206 }
207
208 fn hash(s: &str) -> [u8; 32] {
209 let mut hasher = Sha256::new();
210 hasher.update(s);
211 let mut buf = [0u8; 32];
212 buf.copy_from_slice(&hasher.finalize()[..32]);
213 buf
214 }
215
216 pub fn to_string(&self) -> String {
217 z32::encode(&self.0)
218 }
219
220 pub fn to_secret_key(&self) -> SecretKey {
221 SecretKey::from_bytes(&self.0.clone())
222 }
223}
224
225impl Default for Topic {
226 fn default() -> Self {
227 Self::from_passphrase("password")
228 }
229}
230
231impl ProtocolHandler for TopicTracker {
232 fn accept(
233 &self,
234 conn: Connection,
235 ) -> impl Future<Output = Result<(), AcceptError>> + Send {
236 let topic_tracker = self.clone();
237
238 Box::pin(async move {
239 topic_tracker
240 .accept(conn)
241 .await
242 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
243 Ok(())
244 })
245 }
246}
247
// Implement `From` rather than `Into`: the std blanket impl then provides
// `Into<TopicId> for Topic` automatically, so existing `.into()` callers keep working.
#[cfg(feature = "iroh-gossip-cast")]
impl From<Topic> for iroh_gossip::proto::TopicId {
    /// Reuses the topic's 32 bytes as a gossip [`TopicId`].
    fn from(topic: Topic) -> Self {
        TopicId::from_bytes(topic.0)
    }
}
254
/// Builds a [`Topic`] from the raw bytes of a gossip `TopicId`.
#[cfg(feature = "iroh-gossip-cast")]
impl From<iroh_gossip::proto::TopicId> for Topic {
    fn from(id: iroh_gossip::proto::TopicId) -> Self {
        let bytes = *id.as_bytes();
        Topic(bytes)
    }
}
263
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_topic_from_passphrase() {
        let topic = Topic::from_passphrase("test123");
        assert_eq!(topic.0.len(), 32);
        // Same phrase must always map to the same topic; different phrases must not collide.
        assert_eq!(topic, Topic::from_passphrase("test123"));
        assert_ne!(topic, Topic::from_passphrase("test124"));
    }

    #[test]
    fn test_topic_from_passphrase_is_sha256() {
        // FIPS 180-2 known-answer vector: SHA-256("abc").
        let expected: [u8; 32] = [
            0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, 0x41, 0x41, 0x40, 0xde, 0x5d, 0xae,
            0x22, 0x23, 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c, 0xb4, 0x10, 0xff, 0x61,
            0xf2, 0x00, 0x15, 0xad,
        ];
        assert_eq!(Topic::from_passphrase("abc"), Topic::new(expected));
    }

    #[test]
    fn test_topic_new() {
        let bytes = [0u8; 32];
        let topic = Topic::new(bytes);
        assert_eq!(topic.0, bytes);
    }

    #[test]
    fn test_topic_default() {
        let topic = Topic::default();
        assert_eq!(topic, Topic::from_passphrase("password"));
    }

    #[test]
    fn test_topic_to_string() {
        let topic = Topic::from_passphrase("test");
        let encoded = topic.to_string();
        assert!(!encoded.is_empty());
        assert_eq!(encoded, Topic::from_passphrase("test").to_string());
        assert_ne!(encoded, Topic::from_passphrase("other").to_string());
    }
}