Skip to main content

hashtree_network/
runtime_state.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::RwLock;
7
8use crate::local_bus::SharedLocalNostrBus;
9use crate::mesh_session::{resolve_root_from_local_buses_with_source, MeshSession};
10use crate::root_events::PeerRootEvent;
11use crate::runtime_peer::MeshPeerEntry;
12
/// Shared runtime state for transport-backed mesh peers.
///
/// Bundles the peer table, per-peer `hash_get` flags, the local Nostr buses
/// used for root resolution, and global traffic counters. `P` is the session
/// type stored in each [`MeshPeerEntry`].
pub struct MeshRuntimeState<P> {
    /// Known peers; `record_sent` / `record_received` look entries up by the
    /// `peer_id` string they are given.
    pub peers: Arc<RwLock<HashMap<String, MeshPeerEntry<P>>>>,
    /// Connected-peer counter. Within this type only `reset` writes it (to 0);
    /// presumably maintained by the transport layer — TODO confirm.
    pub connected_count: Arc<AtomicUsize>,
    /// Per-peer `hash_get` flags; peers absent from the map are reported as
    /// enabled by `peer_hash_get_enabled`.
    peer_hash_get: Arc<RwLock<HashMap<String, bool>>>,
    /// Total bytes passed to `record_sent`, across all peers.
    pub bytes_sent: AtomicU64,
    /// Total bytes passed to `record_received`, across all peers.
    pub bytes_received: AtomicU64,
    /// Incremented by `record_mesh_received`.
    pub mesh_received: AtomicU64,
    /// Advanced by `record_mesh_forwarded(count)`.
    pub mesh_forwarded: AtomicU64,
    /// Incremented by `record_mesh_duplicate_drop`.
    pub mesh_dropped_duplicate: AtomicU64,
    /// Local Nostr buses consulted by the `resolve_root_from_local_buses*`
    /// methods.
    local_buses: RwLock<Vec<SharedLocalNostrBus>>,
}
25
26impl<P> Default for MeshRuntimeState<P>
27where
28    P: MeshSession + Send + Sync + 'static,
29{
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl<P> MeshRuntimeState<P>
36where
37    P: MeshSession + Send + Sync + 'static,
38{
39    pub fn new() -> Self {
40        Self {
41            peers: Arc::new(RwLock::new(HashMap::new())),
42            connected_count: Arc::new(AtomicUsize::new(0)),
43            peer_hash_get: Arc::new(RwLock::new(HashMap::new())),
44            bytes_sent: AtomicU64::new(0),
45            bytes_received: AtomicU64::new(0),
46            mesh_received: AtomicU64::new(0),
47            mesh_forwarded: AtomicU64::new(0),
48            mesh_dropped_duplicate: AtomicU64::new(0),
49            local_buses: RwLock::new(Vec::new()),
50        }
51    }
52
53    pub async fn set_local_buses(&self, buses: Vec<SharedLocalNostrBus>) {
54        *self.local_buses.write().await = buses;
55    }
56
57    pub async fn add_local_bus(&self, bus: SharedLocalNostrBus) {
58        self.local_buses.write().await.push(bus);
59    }
60
61    pub async fn set_peer_hash_get(&self, peer_id: &str, enabled: bool) {
62        self.peer_hash_get
63            .write()
64            .await
65            .insert(peer_id.to_string(), enabled);
66    }
67
68    pub async fn clear_peer_hash_get(&self, peer_id: &str) {
69        self.peer_hash_get.write().await.remove(peer_id);
70    }
71
72    pub async fn peer_hash_get_enabled(&self, peer_id: &str) -> bool {
73        self.peer_hash_get
74            .read()
75            .await
76            .get(peer_id)
77            .copied()
78            .unwrap_or(true)
79    }
80
81    pub async fn peer_hash_get_snapshot(&self) -> HashMap<String, bool> {
82        self.peer_hash_get.read().await.clone()
83    }
84
85    pub async fn local_buses(&self) -> Vec<SharedLocalNostrBus> {
86        self.local_buses.read().await.clone()
87    }
88
89    pub async fn reset(&self) {
90        self.set_local_buses(Vec::new()).await;
91        let peers = {
92            let mut peers = self.peers.write().await;
93            std::mem::take(&mut *peers)
94        };
95        self.peer_hash_get.write().await.clear();
96        self.connected_count.store(0, Ordering::Relaxed);
97        for entry in peers.into_values() {
98            if let Some(peer) = entry.peer {
99                let _ = peer.close().await;
100            }
101        }
102    }
103
104    pub fn get_bandwidth(&self) -> (u64, u64) {
105        (
106            self.bytes_sent.load(Ordering::Relaxed),
107            self.bytes_received.load(Ordering::Relaxed),
108        )
109    }
110
111    pub fn get_mesh_stats(&self) -> (u64, u64, u64) {
112        (
113            self.mesh_received.load(Ordering::Relaxed),
114            self.mesh_forwarded.load(Ordering::Relaxed),
115            self.mesh_dropped_duplicate.load(Ordering::Relaxed),
116        )
117    }
118
119    pub fn record_mesh_received(&self) {
120        self.mesh_received.fetch_add(1, Ordering::Relaxed);
121    }
122
123    pub fn record_mesh_forwarded(&self, count: u64) {
124        self.mesh_forwarded.fetch_add(count, Ordering::Relaxed);
125    }
126
127    pub fn record_mesh_duplicate_drop(&self) {
128        self.mesh_dropped_duplicate.fetch_add(1, Ordering::Relaxed);
129    }
130
131    pub async fn record_sent(&self, peer_id: &str, bytes: u64) {
132        self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
133        if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
134            entry.bytes_sent += bytes;
135        }
136    }
137
138    pub async fn record_received(&self, peer_id: &str, bytes: u64) {
139        self.bytes_received.fetch_add(bytes, Ordering::Relaxed);
140        if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
141            entry.bytes_received += bytes;
142        }
143    }
144
145    pub async fn resolve_root_from_local_buses_with_source(
146        &self,
147        owner_pubkey: &str,
148        tree_name: &str,
149        timeout: Duration,
150    ) -> Option<(&'static str, PeerRootEvent)> {
151        resolve_root_from_local_buses_with_source(
152            self.local_buses().await,
153            owner_pubkey,
154            tree_name,
155            timeout,
156        )
157        .await
158    }
159
160    pub async fn resolve_root_from_local_buses(
161        &self,
162        owner_pubkey: &str,
163        tree_name: &str,
164        timeout: Duration,
165    ) -> Option<PeerRootEvent> {
166        self.resolve_root_from_local_buses_with_source(owner_pubkey, tree_name, timeout)
167            .await
168            .map(|(_, root)| root)
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{Event, Filter};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
    use std::time::Instant;

    use crate::local_bus::LocalNostrBus;
    use crate::runtime_peer::{
        ConnectionState, MeshPeerEntry, PeerDirection, PeerSignalPath, PeerTransport,
    };
    use crate::types::{MeshNostrFrame, PeerHTLConfig, PeerId, PeerPool};

    /// Minimal `MeshSession` stub. `closed` is shared via `Arc` so a test can
    /// still observe `close()` after the session has been moved into the
    /// runtime's peer map (and dropped by `reset`).
    struct TestSession {
        closed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            true
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> Result<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> Result<()> {
            Ok(())
        }

        async fn close(&self) -> Result<()> {
            self.closed.store(true, AtomicOrdering::Relaxed);
            Ok(())
        }
    }

    /// Local bus stub that answers every root query with a fixed value.
    struct TestLocalBus {
        source: &'static str,
        root: Option<PeerRootEvent>,
    }

    #[async_trait]
    impl LocalNostrBus for TestLocalBus {
        fn source_name(&self) -> &'static str {
            self.source
        }

        async fn broadcast_event(&self, _event: &Event) -> Result<()> {
            Ok(())
        }

        async fn query_root(
            &self,
            _owner_pubkey: &str,
            _tree_name: &str,
            _timeout: Duration,
        ) -> Option<PeerRootEvent> {
            self.root.clone()
        }
    }

    /// Builds a connected outbound peer entry with zeroed byte counters.
    fn peer_entry(
        peer_id: PeerId,
        peer: Option<TestSession>,
        transport: PeerTransport,
        signal_path: PeerSignalPath,
    ) -> MeshPeerEntry<TestSession> {
        MeshPeerEntry {
            peer_id,
            direction: PeerDirection::Outbound,
            state: ConnectionState::Connected,
            last_seen: Instant::now(),
            peer,
            pool: PeerPool::Other,
            transport,
            signal_paths: BTreeSet::from([signal_path]),
            bytes_sent: 0,
            bytes_received: 0,
        }
    }

    #[tokio::test]
    async fn record_updates_global_and_per_peer_counters() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        runtime.peers.write().await.insert(
            peer_key.clone(),
            peer_entry(peer_id, None, PeerTransport::WebRtc, PeerSignalPath::Relay),
        );

        runtime.record_sent(&peer_key, 16).await;
        runtime.record_received(&peer_key, 32).await;
        // Traffic for an unknown peer still counts toward the global totals.
        runtime.record_sent("peer-unknown", 4).await;

        assert_eq!(runtime.get_bandwidth(), (20, 32));
        let peers = runtime.peers.read().await;
        let entry = peers.get(&peer_key).expect("peer");
        assert_eq!(entry.bytes_sent, 16);
        assert_eq!(entry.bytes_received, 32);
    }

    #[test]
    fn mesh_stats_accumulate() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        runtime.record_mesh_received();
        runtime.record_mesh_received();
        runtime.record_mesh_forwarded(3);
        runtime.record_mesh_duplicate_drop();

        assert_eq!(runtime.get_mesh_stats(), (2, 3, 1));
    }

    #[tokio::test]
    async fn peer_hash_get_defaults_to_enabled() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        assert!(runtime.peer_hash_get_enabled("peer-a").await);

        runtime.set_peer_hash_get("peer-a", false).await;
        assert!(!runtime.peer_hash_get_enabled("peer-a").await);
        assert_eq!(
            runtime.peer_hash_get_snapshot().await,
            HashMap::from([("peer-a".to_string(), false)])
        );

        runtime.clear_peer_hash_get("peer-a").await;
        assert!(runtime.peer_hash_get_enabled("peer-a").await);
        assert!(runtime.peer_hash_get_snapshot().await.is_empty());
    }

    #[tokio::test]
    async fn reset_closes_peers_and_clears_local_buses() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let session = TestSession {
            closed: Arc::clone(&closed),
        };
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        runtime.peers.write().await.insert(
            peer_key.clone(),
            peer_entry(
                peer_id,
                Some(session),
                PeerTransport::Bluetooth,
                PeerSignalPath::Bluetooth,
            ),
        );
        runtime.connected_count.store(1, Ordering::Relaxed);
        runtime.set_peer_hash_get(&peer_key, false).await;
        runtime
            .set_local_buses(vec![Arc::new(TestLocalBus {
                source: "mock",
                root: None,
            }) as SharedLocalNostrBus])
            .await;

        runtime.reset().await;

        assert!(
            closed.load(AtomicOrdering::Relaxed),
            "reset must close every peer session"
        );
        assert_eq!(runtime.connected_count.load(Ordering::Relaxed), 0);
        assert!(runtime.peers.read().await.is_empty());
        assert!(runtime.peer_hash_get_snapshot().await.is_empty());
        assert!(runtime.local_buses().await.is_empty());
    }

    #[tokio::test]
    async fn resolve_root_from_local_buses_returns_first_match() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let root = PeerRootEvent {
            hash: "ab".repeat(32),
            key: None,
            encrypted_key: None,
            self_encrypted_key: None,
            event_id: "event-1".to_string(),
            created_at: 1,
            peer_id: "bus-peer".to_string(),
        };
        runtime
            .set_local_buses(vec![
                Arc::new(TestLocalBus {
                    source: "empty",
                    root: None,
                }) as SharedLocalNostrBus,
                Arc::new(TestLocalBus {
                    source: "mock-bus",
                    root: Some(root.clone()),
                }) as SharedLocalNostrBus,
            ])
            .await;

        let resolved = runtime
            .resolve_root_from_local_buses_with_source("owner", "tree", Duration::from_millis(10))
            .await
            .expect("root");

        assert_eq!(resolved.0, "mock-bus");
        assert_eq!(resolved.1, root);

        // The plain variant returns the same root without the source name.
        assert_eq!(
            runtime
                .resolve_root_from_local_buses("owner", "tree", Duration::from_millis(10))
                .await,
            Some(root)
        );
    }
}