1use crate::peer::{Peer, PeerError};
9use crate::peer_selector::PeerSelector;
10use crate::types::{
11 ClassifyRequest, ForwardRx, ForwardTx, PeerId, PeerPool, PeerState, SignalingMessage,
12 WebRTCStats, WebRTCStoreConfig, NOSTR_KIND_HASHTREE,
13};
14use crate::{build_hedged_wave_plan, normalize_dispatch_config, sync_selector_peers};
15use async_trait::async_trait;
16use hashtree_core::{to_hex, Hash, Store, StoreError};
17use nostr_sdk::prelude::*;
18use nostr_sdk::ClientBuilder;
19use std::collections::HashMap;
20use std::sync::Arc;
21use thiserror::Error;
22use tokio::sync::{mpsc, oneshot, RwLock};
23use uuid::Uuid;
24
/// Errors that can occur while operating the WebRTC-backed store.
#[derive(Debug, Error)]
pub enum WebRTCStoreError {
    /// Failure originating in a peer connection.
    #[error("Peer error: {0}")]
    Peer(#[from] PeerError),
    /// Failure in the Nostr signaling layer (relay, subscription, or
    /// internal channel).
    #[error("Nostr error: {0}")]
    Nostr(String),
    /// No connected peers were available to serve a request.
    #[error("No peers available")]
    NoPeers,
    /// The requested data was found neither locally nor on any peer.
    #[error("Data not found")]
    NotFound,
    /// Failure in the underlying local store.
    #[error("Store error: {0}")]
    Store(#[from] StoreError),
}

/// Alias for [`WebRTCStoreError`].
pub type MeshStoreError = WebRTCStoreError;
40
/// A connected peer together with the pool it was classified into.
struct PeerEntry<S: Store> {
    // Shared handle to the live peer connection.
    peer: Arc<Peer<S>>,
    // Pool (Follows/Other) used for connection limits and request ordering.
    pool: PeerPool,
}
46
/// A [`Store`] that serves reads from a local store and transparently
/// fetches missing content from WebRTC peers discovered via Nostr
/// signaling.
pub struct WebRTCStore<S: Store> {
    // Local backing store; all writes and peer-fetched cache fills land here.
    local_store: Arc<S>,
    // Static configuration (relays, roots, pools, dispatch strategy, debug).
    config: WebRTCStoreConfig,
    // Nostr client; `None` until `start` is called.
    client: Option<Client>,
    // Our own id; the pubkey half is filled in from the keys given to `start`.
    peer_id: PeerId,
    // Connected peers keyed by their peer-id string.
    peers: Arc<RwLock<HashMap<String, PeerEntry<S>>>>,
    // Roots advertised by each peer in its `Hello` message.
    peer_roots: Arc<RwLock<HashMap<String, Vec<String>>>>,
    // Producer side of the outbound signaling queue.
    signaling_tx: mpsc::Sender<SignalingMessage>,
    // Consumer side; taken by the signaling sender task on start.
    signaling_rx: Arc<RwLock<Option<mpsc::Receiver<SignalingMessage>>>>,
    // Producer side of the forward-request queue handed to peers.
    forward_tx: ForwardTx,
    // Consumer side; taken by the forward handler task on start.
    forward_rx: Arc<RwLock<Option<ForwardRx>>>,
    // Set on `start`, cleared on `stop`; background tasks poll this flag.
    running: Arc<RwLock<bool>>,
    // Transfer counters exposed via `stats()`.
    stats: Arc<RwLock<WebRTCStats>>,
    // Scores and orders peers for request dispatch.
    peer_selector: Arc<RwLock<PeerSelector>>,
}

/// Alias for [`WebRTCStore`].
pub type MeshStore<S> = WebRTCStore<S>;
78
79impl<S: Store + 'static> WebRTCStore<S> {
80 pub fn new(local_store: Arc<S>, config: WebRTCStoreConfig) -> Self {
82 let (signaling_tx, signaling_rx) = mpsc::channel(100);
83 let (forward_tx, forward_rx) = mpsc::channel(100);
84 let mut selector = PeerSelector::with_strategy(config.request_selection_strategy);
85 selector.set_fairness(config.request_fairness_enabled);
86
87 let peer_id = PeerId::new(String::new(), Uuid::new_v4().to_string());
88
89 Self {
90 local_store,
91 config,
92 client: None,
93 peer_id,
94 peers: Arc::new(RwLock::new(HashMap::new())),
95 peer_roots: Arc::new(RwLock::new(HashMap::new())),
96 signaling_tx,
97 signaling_rx: Arc::new(RwLock::new(Some(signaling_rx))),
98 forward_tx,
99 forward_rx: Arc::new(RwLock::new(Some(forward_rx))),
100 running: Arc::new(RwLock::new(false)),
101 stats: Arc::new(RwLock::new(WebRTCStats::default())),
102 peer_selector: Arc::new(RwLock::new(selector)),
103 }
104 }
105
106 async fn ordered_ready_peers_by_pool(
107 peers: &RwLock<HashMap<String, PeerEntry<S>>>,
108 peer_selector: &RwLock<PeerSelector>,
109 exclude_peer_id: Option<&str>,
110 ) -> (Vec<(String, Arc<Peer<S>>)>, Vec<(String, Arc<Peer<S>>)>) {
111 let current_peer_ids: Vec<String> = {
112 let peers_read = peers.read().await;
113 peers_read.keys().cloned().collect()
114 };
115 sync_selector_peers(peer_selector, ¤t_peer_ids).await;
116
117 let mut ordered_peer_ids = peer_selector.write().await.select_peers();
118 if ordered_peer_ids.is_empty() {
119 ordered_peer_ids = current_peer_ids;
120 ordered_peer_ids.sort();
121 }
122
123 let peers_read = peers.read().await;
124 let mut follows_peers = Vec::new();
125 let mut other_peers = Vec::new();
126 for peer_id in ordered_peer_ids {
127 if exclude_peer_id
128 .map(|excluded| excluded == peer_id)
129 .unwrap_or(false)
130 {
131 continue;
132 }
133
134 let Some((peer, pool)) = peers_read
135 .get(&peer_id)
136 .map(|entry| (entry.peer.clone(), entry.pool))
137 else {
138 continue;
139 };
140
141 if peer.state().await != PeerState::Ready {
142 continue;
143 }
144
145 match pool {
146 PeerPool::Follows => follows_peers.push((peer_id, peer)),
147 PeerPool::Other => other_peers.push((peer_id, peer)),
148 }
149 }
150
151 (follows_peers, other_peers)
152 }
153
154 pub fn forward_tx(&self) -> ForwardTx {
156 self.forward_tx.clone()
157 }
158
    /// Connects to the configured Nostr relays, announces this node with
    /// a `Hello`, and spawns the background tasks (event handler,
    /// signaling sender, hello timer, forward handler).
    ///
    /// `keys` is the Nostr identity; its public key becomes the pubkey
    /// half of this store's peer id.
    ///
    /// # Errors
    /// Returns [`WebRTCStoreError::Nostr`] if adding a relay or
    /// subscribing fails, or the internal signaling channel is closed.
    ///
    /// NOTE(review): the spawned handlers consume the receiver halves via
    /// `take()`, so this appears to be a start-once API — confirm.
    pub async fn start(&mut self, keys: Keys) -> Result<(), WebRTCStoreError> {
        // Complete our peer id with the Nostr public key.
        self.peer_id.pubkey = keys.public_key().to_hex();

        let client = ClientBuilder::new()
            .signer(keys.clone())
            .database(nostr_sdk::database::MemoryDatabase::new())
            .build();

        for relay in &self.config.relays {
            client
                .add_relay(relay)
                .await
                .map_err(|e| WebRTCStoreError::Nostr(e.to_string()))?;
        }

        client.connect().await;

        self.client = Some(client.clone());
        *self.running.write().await = true;

        // Subscribe only to hashtree signaling events from now onward.
        let filter = Filter::new()
            .kind(Kind::Custom(NOSTR_KIND_HASHTREE))
            .since(Timestamp::now());

        client
            .subscribe(vec![filter], None)
            .await
            .map_err(|e| WebRTCStoreError::Nostr(e.to_string()))?;

        // Announce ourselves immediately; the hello timer repeats this.
        self.send_hello().await?;

        self.start_event_handler(client.clone()).await;
        self.start_signaling_sender(client).await;
        self.start_hello_timer().await;
        self.start_forward_handler().await;

        Ok(())
    }
207
208 async fn start_forward_handler(&self) {
210 let mut rx = self.forward_rx.write().await.take().unwrap();
211 let peers = self.peers.clone();
212 let peer_selector = self.peer_selector.clone();
213 let local_store = self.local_store.clone();
214 let running = self.running.clone();
215 let debug = self.config.debug;
216
217 tokio::spawn(async move {
218 while let Some(req) = rx.recv().await {
219 if !*running.read().await {
220 break;
221 }
222
223 if debug {
224 println!(
225 "[Store] Forward request: hash={}..., htl={}, exclude={}",
226 &to_hex(&req.hash)[..16],
227 req.htl,
228 &req.exclude_peer_id[..req.exclude_peer_id.len().min(16)]
229 );
230 }
231
232 let (follows_peers, other_peers) = Self::ordered_ready_peers_by_pool(
233 peers.as_ref(),
234 peer_selector.as_ref(),
235 Some(&req.exclude_peer_id),
236 )
237 .await;
238
239 let request_bytes = 40u64;
241
242 let mut result = None;
244 for (peer_id, peer) in follows_peers.into_iter().chain(other_peers.into_iter()) {
245 peer_selector
247 .write()
248 .await
249 .record_request(&peer_id, request_bytes);
250 let start_time = std::time::Instant::now();
251
252 match tokio::time::timeout(
254 std::time::Duration::from_millis(500), peer.request_with_htl(&req.hash, req.htl),
256 )
257 .await
258 {
259 Ok(Ok(Some(data))) => {
260 if hashtree_core::sha256(&data) == req.hash {
262 let rtt_ms = start_time.elapsed().as_millis() as u64;
264 peer_selector.write().await.record_success(
265 &peer_id,
266 rtt_ms,
267 data.len() as u64,
268 );
269
270 let _ = local_store.put(req.hash, data.clone()).await;
272 result = Some(data);
273 break;
274 } else {
275 peer_selector.write().await.record_failure(&peer_id);
277 }
278 }
279 Ok(Ok(None)) => {
280 continue;
282 }
283 Ok(Err(_)) => {
284 peer_selector.write().await.record_failure(&peer_id);
286 continue;
287 }
288 Err(_) => {
289 peer_selector.write().await.record_timeout(&peer_id);
291 continue;
292 }
293 }
294 }
295
296 if debug {
297 println!(
298 "[Store] Forward result: hash={}..., found={}",
299 &to_hex(&req.hash)[..16],
300 result.is_some()
301 );
302 }
303
304 let _ = req.response.send(result);
305 }
306 });
307 }
308
309 async fn send_hello(&self) -> Result<(), WebRTCStoreError> {
311 let roots: Vec<String> = self.config.roots.iter().map(to_hex).collect();
312
313 let msg = SignalingMessage::Hello {
314 peer_id: self.peer_id.to_peer_string(),
315 roots,
316 };
317
318 self.signaling_tx
319 .send(msg)
320 .await
321 .map_err(|_| WebRTCStoreError::Nostr("Channel closed".to_string()))?;
322
323 Ok(())
324 }
325
    /// Spawns the task that consumes relay notifications, parses
    /// hashtree-kind events into [`SignalingMessage`]s, and dispatches
    /// them to `handle_signaling_message`.
    async fn start_event_handler(&self, client: Client) {
        let peers = self.peers.clone();
        let peer_roots = self.peer_roots.clone();
        let local_peer_id = self.peer_id.to_peer_string();
        let signaling_tx = self.signaling_tx.clone();
        let forward_tx = self.forward_tx.clone();
        let local_store = self.local_store.clone();
        let running = self.running.clone();
        let config = self.config.clone();
        let stats = self.stats.clone();
        let peer_selector = self.peer_selector.clone();

        let mut notifications = client.notifications();

        tokio::spawn(async move {
            loop {
                if !*running.read().await {
                    break;
                }

                // Poll with a short timeout so the `running` flag is
                // re-checked even when no notifications arrive.
                match tokio::time::timeout(
                    std::time::Duration::from_millis(100),
                    notifications.recv(),
                )
                .await
                {
                    Ok(Ok(notification)) => {
                        if let RelayPoolNotification::Event { event, .. } = notification {
                            if event.kind == Kind::Custom(NOSTR_KIND_HASHTREE) {
                                if config.debug {
                                    // Truncate long payloads in the trace.
                                    let content_preview = if event.content.len() > 80 {
                                        format!("{}...", &event.content[..80])
                                    } else {
                                        event.content.clone()
                                    };
                                    println!("[Store] Received event: {}", content_preview);
                                }
                                if let Ok(msg) =
                                    serde_json::from_str::<SignalingMessage>(&event.content)
                                {
                                    Self::handle_signaling_message(
                                        msg,
                                        &local_peer_id,
                                        peers.clone(),
                                        peer_roots.clone(),
                                        signaling_tx.clone(),
                                        forward_tx.clone(),
                                        local_store.clone(),
                                        &config,
                                        stats.clone(),
                                        peer_selector.clone(),
                                    )
                                    .await;
                                } else if config.debug {
                                    println!(
                                        "[Store] Failed to parse signaling message from event"
                                    );
                                }
                            }
                        }
                    }
                    Ok(Err(e)) => {
                        if config.debug {
                            println!("[Store] Notification channel error: {:?}", e);
                        }
                        // A closed channel means the client is gone; a
                        // lagged receiver keeps looping.
                        if matches!(e, tokio::sync::broadcast::error::RecvError::Closed) {
                            break;
                        }
                    }
                    Err(_) => {
                        // Timeout: loop back and re-check `running`.
                    }
                }
            }
        });
    }
410
411 async fn classify_peer(pubkey: &str, config: &WebRTCStoreConfig) -> PeerPool {
413 if let Some(ref classifier_tx) = config.classifier_tx {
414 let (response_tx, response_rx) = oneshot::channel();
415 let request = ClassifyRequest {
416 pubkey: pubkey.to_string(),
417 response: response_tx,
418 };
419 if classifier_tx.send(request).await.is_ok() {
420 if let Ok(pool) = response_rx.await {
421 return pool;
422 }
423 }
424 }
425 PeerPool::Other
426 }
427
428 async fn count_pools(peers: &HashMap<String, PeerEntry<S>>) -> (usize, usize) {
430 let mut follows = 0;
431 let mut other = 0;
432 for entry in peers.values() {
433 match entry.pool {
434 PeerPool::Follows => follows += 1,
435 PeerPool::Other => other += 1,
436 }
437 }
438 (follows, other)
439 }
440
441 fn can_accept_peer(
443 pool: PeerPool,
444 follows_count: usize,
445 other_count: usize,
446 config: &WebRTCStoreConfig,
447 ) -> bool {
448 match pool {
449 PeerPool::Follows => follows_count < config.pools.follows.max_connections,
450 PeerPool::Other => other_count < config.pools.other.max_connections,
451 }
452 }
453
454 fn pool_needs_peers(
456 pool: PeerPool,
457 follows_count: usize,
458 other_count: usize,
459 config: &WebRTCStoreConfig,
460 ) -> bool {
461 match pool {
462 PeerPool::Follows => follows_count < config.pools.follows.satisfied_connections,
463 PeerPool::Other => other_count < config.pools.other.satisfied_connections,
464 }
465 }
466
    /// Dispatches one parsed signaling message.
    ///
    /// A `Hello` may trigger an outbound connection to the sender when its
    /// pool still needs peers; directed messages (`Offer`/`Answer`/
    /// `Candidate`/`Candidates`) are routed to the matching peer, which is
    /// created on demand for incoming dials.
    ///
    /// NOTE(review): pool counts are read and the write lock released
    /// before the peer is inserted, so concurrent messages could briefly
    /// overshoot a pool's limit — confirm this is acceptable.
    #[allow(clippy::too_many_arguments)]
    async fn handle_signaling_message(
        msg: SignalingMessage,
        local_peer_id: &str,
        peers: Arc<RwLock<HashMap<String, PeerEntry<S>>>>,
        peer_roots: Arc<RwLock<HashMap<String, Vec<String>>>>,
        signaling_tx: mpsc::Sender<SignalingMessage>,
        forward_tx: ForwardTx,
        local_store: Arc<S>,
        config: &WebRTCStoreConfig,
        stats: Arc<RwLock<WebRTCStats>>,
        peer_selector: Arc<RwLock<PeerSelector>>,
    ) {
        match &msg {
            SignalingMessage::Hello { peer_id, roots } => {
                // Ignore our own broadcasts.
                if peer_id == local_peer_id {
                    return;
                }

                // Classify by the pubkey part of "<pubkey>:<session>".
                let peer_pubkey = peer_id.split(':').next().unwrap_or("");

                let pool = Self::classify_peer(peer_pubkey, config).await;

                let peers_read = peers.read().await;
                let (follows_count, other_count) = Self::count_pools(&peers_read).await;
                drop(peers_read);

                if !Self::can_accept_peer(pool, follows_count, other_count, config) {
                    if config.debug {
                        println!(
                            "[Store] Ignoring hello from {} - {:?} pool full",
                            peer_id, pool
                        );
                    }
                    return;
                }

                if config.debug {
                    println!("[Store] Received hello from {} (pool: {:?})", peer_id, pool);
                }

                // Remember the roots this peer advertises.
                peer_roots
                    .write()
                    .await
                    .insert(peer_id.clone(), roots.clone());

                // Dial out only while the pool is below its satisfied level
                // and we are not already connected to this peer.
                if Self::pool_needs_peers(pool, follows_count, other_count, config) {
                    if let Some(remote_id) = PeerId::from_peer_string(peer_id) {
                        if !peers.read().await.contains_key(peer_id) {
                            if config.debug {
                                println!(
                                    "[Store] Initiating connection to {} (pool: {:?})",
                                    peer_id, pool
                                );
                            }
                            if let Ok(peer) = Peer::with_forward_channel(
                                remote_id,
                                local_peer_id.to_string(),
                                signaling_tx.clone(),
                                local_store.clone(),
                                config.debug,
                                Some(forward_tx.clone()),
                            )
                            .await
                            {
                                let peer = Arc::new(peer);
                                peers.write().await.insert(
                                    peer_id.clone(),
                                    PeerEntry {
                                        peer: peer.clone(),
                                        pool,
                                    },
                                );
                                stats.write().await.connected_peers += 1;

                                peer_selector.write().await.add_peer(peer_id.clone());

                                // Drive the connection handshake in the
                                // background; errors are best-effort.
                                tokio::spawn(async move {
                                    let _ = peer.connect().await;
                                });
                            }
                        }
                    }
                }
            }
            SignalingMessage::Offer {
                peer_id,
                target_peer_id,
                ..
            }
            | SignalingMessage::Answer {
                peer_id,
                target_peer_id,
                ..
            }
            | SignalingMessage::Candidate {
                peer_id,
                target_peer_id,
                ..
            }
            | SignalingMessage::Candidates {
                peer_id,
                target_peer_id,
                ..
            } => {
                // Directed messages: only handle those addressed to us.
                if target_peer_id != local_peer_id {
                    return;
                }

                let peer_pubkey = peer_id.split(':').next().unwrap_or("");

                let pool = Self::classify_peer(peer_pubkey, config).await;

                let peers_read = peers.read().await;
                let (follows_count, other_count) = Self::count_pools(&peers_read).await;
                drop(peers_read);

                if !Self::can_accept_peer(pool, follows_count, other_count, config) {
                    if config.debug {
                        println!(
                            "[Store] Ignoring signaling from {} - {:?} pool full",
                            peer_id, pool
                        );
                    }
                    return;
                }

                // Look up the peer, creating it on demand for incoming dials.
                let peer = {
                    let peers_read = peers.read().await;
                    peers_read.get(peer_id).map(|e| e.peer.clone())
                };

                let peer = match peer {
                    Some(p) => p,
                    None => {
                        if let Some(remote_id) = PeerId::from_peer_string(peer_id) {
                            if let Ok(p) = Peer::with_forward_channel(
                                remote_id,
                                local_peer_id.to_string(),
                                signaling_tx.clone(),
                                local_store.clone(),
                                config.debug,
                                Some(forward_tx.clone()),
                            )
                            .await
                            {
                                let p = Arc::new(p);
                                peers.write().await.insert(
                                    peer_id.clone(),
                                    PeerEntry {
                                        peer: p.clone(),
                                        pool,
                                    },
                                );
                                stats.write().await.connected_peers += 1;

                                peer_selector.write().await.add_peer(peer_id.clone());
                                p
                            } else {
                                return;
                            }
                        } else {
                            return;
                        }
                    }
                };

                let _ = peer.handle_signaling(msg).await;
            }
        }
    }
653
654 async fn start_signaling_sender(&self, client: Client) {
656 let mut rx = self.signaling_rx.write().await.take().unwrap();
657 let running = self.running.clone();
658
659 tokio::spawn(async move {
660 while let Some(msg) = rx.recv().await {
661 if !*running.read().await {
662 break;
663 }
664
665 let json = serde_json::to_string(&msg).unwrap();
666 println!(
667 "[Store] Sending signaling: {}",
668 &json[..json.len().min(100)]
669 );
670 let builder = EventBuilder::new(Kind::Custom(NOSTR_KIND_HASHTREE), json, []);
671
672 match client.send_event_builder(builder).await {
673 Ok(output) => {
674 if output.success.is_empty() {
676 eprintln!("[Store] Warning: Event not sent to any relay");
677 }
678 }
679 Err(e) => {
680 eprintln!("[Store] Error sending event: {:?}", e);
681 }
682 }
683 }
684 });
685 }
686
687 async fn start_hello_timer(&self) {
689 let signaling_tx = self.signaling_tx.clone();
690 let peer_id = self.peer_id.to_peer_string();
691 let roots: Vec<String> = self.config.roots.iter().map(to_hex).collect();
692 let interval_ms = self.config.hello_interval_ms;
693 let running = self.running.clone();
694
695 tokio::spawn(async move {
696 let mut interval = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
697
698 loop {
699 interval.tick().await;
700
701 if !*running.read().await {
702 break;
703 }
704
705 let msg = SignalingMessage::Hello {
706 peer_id: peer_id.clone(),
707 roots: roots.clone(),
708 };
709
710 let _ = signaling_tx.send(msg).await;
711 }
712 });
713 }
714
715 pub async fn stop(&self) {
717 *self.running.write().await = false;
718
719 let peers = self.peers.read().await;
721 for entry in peers.values() {
722 let _ = entry.peer.close().await;
723 }
724
725 if let Some(ref client) = self.client {
727 let _ = client.disconnect().await;
728 }
729 }
730
731 pub async fn stats(&self) -> WebRTCStats {
733 self.stats.read().await.clone()
734 }
735
736 pub async fn peer_count(&self) -> usize {
738 let peers = self.peers.read().await;
739 let mut count = 0;
740 for entry in peers.values() {
741 if entry.peer.state().await == PeerState::Ready {
742 count += 1;
743 }
744 }
745 count
746 }
747
    /// Tries to fetch `hash` from connected peers using the configured
    /// hedged-wave dispatch plan: peers are contacted in selector order
    /// (follows pool first), wave by wave, until one returns data whose
    /// SHA-256 matches `hash`.
    ///
    /// Verified data is cached in the local store and counted in the
    /// transfer stats. Returns `Ok(None)` when no peer had the data.
    ///
    /// NOTE(review): peers within a wave are awaited sequentially, not
    /// concurrently — confirm that is the intended dispatch semantics.
    async fn request_from_peers(&self, hash: &Hash) -> Result<Option<Vec<u8>>, WebRTCStoreError> {
        let (follows_peers, other_peers) = Self::ordered_ready_peers_by_pool(
            self.peers.as_ref(),
            self.peer_selector.as_ref(),
            None,
        )
        .await;
        let ordered_peers: Vec<(String, Arc<Peer<S>>)> = follows_peers
            .into_iter()
            .chain(other_peers.into_iter())
            .collect();
        if ordered_peers.is_empty() {
            return Ok(None);
        }

        // Plan how many peers to contact in each wave.
        let dispatch = normalize_dispatch_config(self.config.request_dispatch, ordered_peers.len());
        let wave_plan = build_hedged_wave_plan(ordered_peers.len(), dispatch);
        if wave_plan.is_empty() {
            return Ok(None);
        }

        // Approximate wire size of a request, for fairness accounting.
        let request_bytes = 40u64;
        let mut next_peer_idx = 0usize;
        for (wave_idx, wave_size) in wave_plan.iter().copied().enumerate() {
            let from = next_peer_idx;
            let to = (next_peer_idx + wave_size).min(ordered_peers.len());
            next_peer_idx = to;

            for (peer_id, peer) in &ordered_peers[from..to] {
                let peer_id = peer_id.clone();
                self.peer_selector
                    .write()
                    .await
                    .record_request(&peer_id, request_bytes);
                let start_time = std::time::Instant::now();

                match peer.request(hash).await {
                    Ok(Some(data)) => {
                        // Verify content-addressing before accepting.
                        if hashtree_core::sha256(&data) == *hash {
                            let rtt_ms = start_time.elapsed().as_millis() as u64;
                            self.peer_selector.write().await.record_success(
                                &peer_id,
                                rtt_ms,
                                data.len() as u64,
                            );

                            // Cache locally; a cache failure is not fatal.
                            let _ = self.local_store.put(*hash, data.clone()).await;
                            let mut stats = self.stats.write().await;
                            stats.requests_fulfilled += 1;
                            stats.bytes_received += data.len() as u64;
                            return Ok(Some(data));
                        } else {
                            self.peer_selector.write().await.record_failure(&peer_id);
                        }
                    }
                    Ok(None) => {
                        continue;
                    }
                    Err(PeerError::Timeout) => {
                        self.peer_selector.write().await.record_timeout(&peer_id);
                        continue;
                    }
                    Err(_) => {
                        self.peer_selector.write().await.record_failure(&peer_id);
                        continue;
                    }
                }
            }

            // Hedge: pause before the next wave, unless we are done.
            let is_last_wave =
                wave_idx + 1 == wave_plan.len() || next_peer_idx >= ordered_peers.len();
            if !is_last_wave && dispatch.hedge_interval_ms > 0 {
                tokio::time::sleep(std::time::Duration::from_millis(dispatch.hedge_interval_ms))
                    .await;
            }
        }

        Ok(None)
    }
839
840 pub async fn selector_summary(&self) -> crate::peer_selector::SelectorSummary {
842 self.peer_selector.read().await.summary()
843 }
844}
845
846#[async_trait]
847impl<S: Store + 'static> Store for WebRTCStore<S> {
848 async fn put(&self, hash: Hash, data: Vec<u8>) -> Result<bool, StoreError> {
849 self.local_store.put(hash, data).await
850 }
851
852 async fn get(&self, hash: &Hash) -> Result<Option<Vec<u8>>, StoreError> {
853 if let Some(data) = self.local_store.get(hash).await? {
855 return Ok(Some(data));
856 }
857
858 self.stats.write().await.requests_made += 1;
860
861 match self.request_from_peers(hash).await {
863 Ok(data) => Ok(data),
864 Err(_) => Ok(None),
865 }
866 }
867
868 async fn has(&self, hash: &Hash) -> Result<bool, StoreError> {
869 self.local_store.has(hash).await
870 }
871
872 async fn delete(&self, hash: &Hash) -> Result<bool, StoreError> {
873 self.local_store.delete(hash).await
874 }
875}