Skip to main content

hashtree_network/
runtime_state.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5
6use tokio::sync::RwLock;
7
8use crate::local_bus::SharedLocalNostrBus;
9use crate::mesh_session::{resolve_root_from_local_buses_with_source, MeshSession};
10use crate::root_events::PeerRootEvent;
11use crate::runtime_peer::MeshPeerEntry;
12use crate::types::{KnownPeerRecord, KnownPeerSnapshot};
13
/// Shared runtime state for transport-backed mesh peers.
///
/// Bundles the live peer table, per-peer hash-get flags, aggregate traffic and
/// mesh-forwarding counters, the local Nostr buses consulted for tree-root
/// lookups, and a table of known peers' signal URLs that can be exported to and
/// re-imported from a `KnownPeerSnapshot`.
pub struct MeshRuntimeState<P> {
    /// Live peers, keyed by peer id string.
    pub peers: Arc<RwLock<HashMap<String, MeshPeerEntry<P>>>>,
    /// Count of connected peers. Nothing in this type increments it; presumably
    /// maintained by the transport layer (TODO confirm). `reset` sets it to 0.
    pub connected_count: Arc<AtomicUsize>,
    /// Per-peer hash-get flag; a peer with no entry is treated as enabled.
    peer_hash_get: Arc<RwLock<HashMap<String, bool>>>,
    /// Total bytes sent to all peers (updated by `record_sent`).
    pub bytes_sent: AtomicU64,
    /// Total bytes received from all peers (updated by `record_received`).
    pub bytes_received: AtomicU64,
    /// Mesh frames received (updated by `record_mesh_received`).
    pub mesh_received: AtomicU64,
    /// Mesh frame forwards performed (updated by `record_mesh_forwarded`).
    pub mesh_forwarded: AtomicU64,
    /// Mesh frames dropped as duplicates (updated by `record_mesh_duplicate_drop`).
    pub mesh_dropped_duplicate: AtomicU64,
    /// Local buses queried, via a cloned snapshot, when resolving tree roots.
    local_buses: RwLock<Vec<SharedLocalNostrBus>>,
    /// Known peers' normalized signal URLs, keyed by peer id.
    known_peers: RwLock<HashMap<String, KnownPeerRecord>>,
}
27
28impl<P> Default for MeshRuntimeState<P>
29where
30    P: MeshSession + Send + Sync + 'static,
31{
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<P> MeshRuntimeState<P>
38where
39    P: MeshSession + Send + Sync + 'static,
40{
41    pub fn new() -> Self {
42        Self {
43            peers: Arc::new(RwLock::new(HashMap::new())),
44            connected_count: Arc::new(AtomicUsize::new(0)),
45            peer_hash_get: Arc::new(RwLock::new(HashMap::new())),
46            bytes_sent: AtomicU64::new(0),
47            bytes_received: AtomicU64::new(0),
48            mesh_received: AtomicU64::new(0),
49            mesh_forwarded: AtomicU64::new(0),
50            mesh_dropped_duplicate: AtomicU64::new(0),
51            local_buses: RwLock::new(Vec::new()),
52            known_peers: RwLock::new(HashMap::new()),
53        }
54    }
55
56    pub async fn set_local_buses(&self, buses: Vec<SharedLocalNostrBus>) {
57        *self.local_buses.write().await = buses;
58    }
59
60    pub async fn add_local_bus(&self, bus: SharedLocalNostrBus) {
61        self.local_buses.write().await.push(bus);
62    }
63
64    pub async fn set_peer_hash_get(&self, peer_id: &str, enabled: bool) {
65        self.peer_hash_get
66            .write()
67            .await
68            .insert(peer_id.to_string(), enabled);
69    }
70
71    pub async fn clear_peer_hash_get(&self, peer_id: &str) {
72        self.peer_hash_get.write().await.remove(peer_id);
73    }
74
75    pub async fn peer_hash_get_enabled(&self, peer_id: &str) -> bool {
76        self.peer_hash_get
77            .read()
78            .await
79            .get(peer_id)
80            .copied()
81            .unwrap_or(true)
82    }
83
84    pub async fn peer_hash_get_snapshot(&self) -> HashMap<String, bool> {
85        self.peer_hash_get.read().await.clone()
86    }
87
88    pub async fn record_known_peer_signal_urls(
89        &self,
90        peer_id: &str,
91        signal_urls: &[String],
92        source: &str,
93    ) {
94        let clean_signal_urls = normalize_peer_signal_urls(signal_urls);
95        if clean_signal_urls.is_empty() {
96            return;
97        }
98
99        let mut known = self.known_peers.write().await;
100        let entry = known
101            .entry(peer_id.to_string())
102            .or_insert_with(|| KnownPeerRecord {
103                peer_id: peer_id.to_string(),
104                signal_urls: Vec::new(),
105                last_seen_unix_ms: 0,
106                last_source: None,
107            });
108        for signal_url in clean_signal_urls {
109            if !entry.signal_urls.contains(&signal_url) {
110                entry.signal_urls.push(signal_url);
111            }
112        }
113        entry.signal_urls.sort();
114        entry.last_seen_unix_ms = now_unix_ms();
115        entry.last_source = Some(source.to_string());
116    }
117
118    pub async fn known_peer_snapshot(&self) -> KnownPeerSnapshot {
119        let mut peers: Vec<KnownPeerRecord> =
120            self.known_peers.read().await.values().cloned().collect();
121        peers.sort_by(|a, b| a.peer_id.cmp(&b.peer_id));
122        KnownPeerSnapshot { version: 1, peers }
123    }
124
125    pub async fn import_known_peer_snapshot(&self, snapshot: &KnownPeerSnapshot) {
126        if snapshot.version != 1 {
127            return;
128        }
129        let mut known = self.known_peers.write().await;
130        known.clear();
131        for peer in &snapshot.peers {
132            let signal_urls = normalize_peer_signal_urls(&peer.signal_urls);
133            if peer.peer_id.trim().is_empty() || signal_urls.is_empty() {
134                continue;
135            }
136            known.insert(
137                peer.peer_id.clone(),
138                KnownPeerRecord {
139                    peer_id: peer.peer_id.clone(),
140                    signal_urls,
141                    last_seen_unix_ms: peer.last_seen_unix_ms,
142                    last_source: peer.last_source.clone(),
143                },
144            );
145        }
146    }
147
148    pub async fn local_buses(&self) -> Vec<SharedLocalNostrBus> {
149        self.local_buses.read().await.clone()
150    }
151
152    pub async fn reset(&self) {
153        self.set_local_buses(Vec::new()).await;
154        let peers = {
155            let mut peers = self.peers.write().await;
156            std::mem::take(&mut *peers)
157        };
158        self.peer_hash_get.write().await.clear();
159        self.connected_count.store(0, Ordering::Relaxed);
160        for entry in peers.into_values() {
161            if let Some(peer) = entry.peer {
162                let _ = peer.close().await;
163            }
164        }
165    }
166
167    pub fn get_bandwidth(&self) -> (u64, u64) {
168        (
169            self.bytes_sent.load(Ordering::Relaxed),
170            self.bytes_received.load(Ordering::Relaxed),
171        )
172    }
173
174    pub fn get_mesh_stats(&self) -> (u64, u64, u64) {
175        (
176            self.mesh_received.load(Ordering::Relaxed),
177            self.mesh_forwarded.load(Ordering::Relaxed),
178            self.mesh_dropped_duplicate.load(Ordering::Relaxed),
179        )
180    }
181
182    pub fn record_mesh_received(&self) {
183        self.mesh_received.fetch_add(1, Ordering::Relaxed);
184    }
185
186    pub fn record_mesh_forwarded(&self, count: u64) {
187        self.mesh_forwarded.fetch_add(count, Ordering::Relaxed);
188    }
189
190    pub fn record_mesh_duplicate_drop(&self) {
191        self.mesh_dropped_duplicate.fetch_add(1, Ordering::Relaxed);
192    }
193
194    pub async fn record_sent(&self, peer_id: &str, bytes: u64) {
195        self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
196        if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
197            entry.bytes_sent += bytes;
198        }
199    }
200
201    pub async fn record_received(&self, peer_id: &str, bytes: u64) {
202        self.bytes_received.fetch_add(bytes, Ordering::Relaxed);
203        if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
204            entry.bytes_received += bytes;
205        }
206    }
207
208    pub async fn resolve_root_from_local_buses_with_source(
209        &self,
210        owner_pubkey: &str,
211        tree_name: &str,
212        timeout: Duration,
213    ) -> Option<(&'static str, PeerRootEvent)> {
214        resolve_root_from_local_buses_with_source(
215            self.local_buses().await,
216            owner_pubkey,
217            tree_name,
218            timeout,
219        )
220        .await
221    }
222
223    pub async fn resolve_root_from_local_buses(
224        &self,
225        owner_pubkey: &str,
226        tree_name: &str,
227        timeout: Duration,
228    ) -> Option<PeerRootEvent> {
229        self.resolve_root_from_local_buses_with_source(owner_pubkey, tree_name, timeout)
230            .await
231            .map(|(_, root)| root)
232    }
233}
234
/// Milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock is set before the epoch. Saturates at
/// `u64::MAX` in the (practically unreachable) case the value overflows.
fn now_unix_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
241
/// Normalizes a list of peer signal URLs.
///
/// Each URL is trimmed of surrounding whitespace and trailing `/` characters.
/// Only URLs that then start with `http://` are kept. The result is sorted and
/// contains no duplicates.
fn normalize_peer_signal_urls(signal_urls: &[String]) -> Vec<String> {
    let mut normalized: Vec<String> = signal_urls
        .iter()
        .map(|raw| raw.trim().trim_end_matches('/'))
        // A non-empty `http://` prefix also rules out empty candidates.
        .filter(|candidate| candidate.starts_with("http://"))
        .map(str::to_owned)
        .collect();
    normalized.sort();
    // After sorting, duplicates are adjacent, so `dedup` removes them all.
    normalized.dedup();
    normalized
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{Event, Filter};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
    use std::time::Instant;

    use crate::local_bus::LocalNostrBus;
    use crate::runtime_peer::{
        ConnectionState, MeshPeerEntry, PeerDirection, PeerSignalPath, PeerTransport,
    };
    use crate::types::{MeshNostrFrame, PeerHTLConfig, PeerId, PeerPool};

    /// Inert session whose only observable behavior is recording `close()`.
    ///
    /// The flag lives behind an `Arc` so a test can keep a handle to it after
    /// the session has been moved into the runtime's peer table.
    struct TestSession {
        closed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            true
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> Result<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> Result<()> {
            Ok(())
        }

        async fn close(&self) -> Result<()> {
            self.closed.store(true, AtomicOrdering::Relaxed);
            Ok(())
        }
    }

    /// Local bus that answers every root query with a fixed value.
    struct TestLocalBus {
        source: &'static str,
        root: Option<PeerRootEvent>,
    }

    #[async_trait]
    impl LocalNostrBus for TestLocalBus {
        fn source_name(&self) -> &'static str {
            self.source
        }

        async fn broadcast_event(&self, _event: &Event) -> Result<()> {
            Ok(())
        }

        async fn query_root(
            &self,
            _owner_pubkey: &str,
            _tree_name: &str,
            _timeout: Duration,
        ) -> Option<PeerRootEvent> {
            self.root.clone()
        }
    }

    #[tokio::test]
    async fn record_updates_global_and_per_peer_counters() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        runtime.peers.write().await.insert(
            peer_key.clone(),
            MeshPeerEntry {
                peer_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: None,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::from([PeerSignalPath::Relay]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        runtime.record_sent(&peer_key, 16).await;
        runtime.record_received(&peer_key, 32).await;
        // Unknown peers still count toward the global totals.
        runtime.record_sent("unknown-peer", 4).await;

        assert_eq!(runtime.get_bandwidth(), (20, 32));
        let peers = runtime.peers.read().await;
        let entry = peers.get(&peer_key).expect("peer");
        assert_eq!(entry.bytes_sent, 16);
        assert_eq!(entry.bytes_received, 32);
    }

    #[tokio::test]
    async fn reset_closes_peers_and_clears_local_buses() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let session = TestSession {
            closed: Arc::clone(&closed),
        };
        let peer_id = PeerId::new("peer-a".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(session),
                pool: PeerPool::Other,
                transport: PeerTransport::Bluetooth,
                signal_paths: BTreeSet::from([PeerSignalPath::Bluetooth]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );
        runtime.connected_count.store(1, Ordering::Relaxed);
        runtime.set_peer_hash_get("peer-a", false).await;
        runtime
            .set_local_buses(vec![Arc::new(TestLocalBus {
                source: "mock",
                root: None,
            }) as SharedLocalNostrBus])
            .await;

        runtime.reset().await;

        assert!(
            closed.load(AtomicOrdering::Relaxed),
            "reset must close every removed session"
        );
        assert_eq!(runtime.connected_count.load(Ordering::Relaxed), 0);
        assert!(runtime.peers.read().await.is_empty());
        assert!(runtime.peer_hash_get_snapshot().await.is_empty());
        assert!(runtime.local_buses().await.is_empty());
    }

    #[tokio::test]
    async fn peer_hash_get_defaults_to_enabled() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        assert!(runtime.peer_hash_get_enabled("peer-a").await);

        runtime.set_peer_hash_get("peer-a", false).await;
        assert!(!runtime.peer_hash_get_enabled("peer-a").await);

        runtime.clear_peer_hash_get("peer-a").await;
        assert!(runtime.peer_hash_get_enabled("peer-a").await);
    }

    #[tokio::test]
    async fn known_peers_are_normalized_and_round_trip_through_snapshot() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        runtime
            .record_known_peer_signal_urls(
                "peer-b",
                &[
                    " http://b.local:8080/ ".to_string(),
                    "https://ignored.example".to_string(),
                ],
                "test",
            )
            .await;
        runtime
            .record_known_peer_signal_urls("peer-b", &["http://a.local".to_string()], "again")
            .await;

        let snapshot = runtime.known_peer_snapshot().await;
        assert_eq!(snapshot.peers.len(), 1);
        let record = &snapshot.peers[0];
        assert_eq!(record.peer_id, "peer-b");
        assert_eq!(
            record.signal_urls,
            vec!["http://a.local".to_string(), "http://b.local:8080".to_string()]
        );
        assert_eq!(record.last_source.as_deref(), Some("again"));

        let restored = MeshRuntimeState::<TestSession>::new();
        restored.import_known_peer_snapshot(&snapshot).await;
        let restored_snapshot = restored.known_peer_snapshot().await;
        assert_eq!(restored_snapshot.peers.len(), 1);
        assert_eq!(restored_snapshot.peers[0].signal_urls, record.signal_urls);

        // Unknown snapshot versions are ignored and leave the table untouched.
        restored
            .import_known_peer_snapshot(&KnownPeerSnapshot {
                version: 2,
                peers: Vec::new(),
            })
            .await;
        assert_eq!(restored.known_peer_snapshot().await.peers.len(), 1);
    }

    #[tokio::test]
    async fn resolve_root_from_local_buses_returns_first_match() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let root = PeerRootEvent {
            hash: "ab".repeat(32),
            key: None,
            encrypted_key: None,
            self_encrypted_key: None,
            event_id: "event-1".to_string(),
            created_at: 1,
            peer_id: "bus-peer".to_string(),
        };
        runtime
            .set_local_buses(vec![
                Arc::new(TestLocalBus {
                    source: "empty",
                    root: None,
                }) as SharedLocalNostrBus,
                Arc::new(TestLocalBus {
                    source: "mock-bus",
                    root: Some(root.clone()),
                }) as SharedLocalNostrBus,
            ])
            .await;

        let resolved = runtime
            .resolve_root_from_local_buses_with_source("owner", "tree", Duration::from_millis(10))
            .await
            .expect("root");

        assert_eq!(resolved.0, "mock-bus");
        assert_eq!(resolved.1, root);
    }
}