hashtree_cli/webrtc/peer.rs

1//! WebRTC peer connection for hashtree data exchange
2
3use anyhow::Result;
4use bytes::Bytes;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{mpsc, oneshot, Mutex};
8use tracing::{debug, error, info, warn};
9use webrtc::api::interceptor_registry::register_default_interceptors;
10use webrtc::api::media_engine::MediaEngine;
11use webrtc::api::setting_engine::SettingEngine;
12use webrtc::api::APIBuilder;
13use webrtc::ice::mdns::MulticastDnsMode;
14use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
15use webrtc::data_channel::data_channel_message::DataChannelMessage;
16use webrtc::data_channel::RTCDataChannel;
17use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
18use webrtc::ice_transport::ice_server::RTCIceServer;
19use webrtc::interceptor::registry::Registry;
20use webrtc::peer_connection::configuration::RTCConfiguration;
21use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
22use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
23use webrtc::peer_connection::RTCPeerConnection;
24
25use super::types::{DataMessage, DataRequest, DataResponse, PeerDirection, PeerId, PeerStateEvent, SignalingMessage, encode_message, encode_request, encode_response, parse_message, hash_to_hex};
26
27/// Trait for content storage that can be used by WebRTC peers
28pub trait ContentStore: Send + Sync + 'static {
29    /// Get content by hex hash
30    fn get(&self, hash_hex: &str) -> Result<Option<Vec<u8>>>;
31}
32
33/// Pending request tracking (keyed by hash hex)
34pub struct PendingRequest {
35    pub hash: Vec<u8>,
36    pub response_tx: oneshot::Sender<Option<Vec<u8>>>,
37}
38
39/// WebRTC peer connection with data channel protocol
40pub struct Peer {
41    pub peer_id: PeerId,
42    pub direction: PeerDirection,
43    pub created_at: std::time::Instant,
44    pub connected_at: Option<std::time::Instant>,
45
46    pc: Arc<RTCPeerConnection>,
47    /// Data channel - can be set from callback when receiving channel from peer
48    pub data_channel: Arc<Mutex<Option<Arc<RTCDataChannel>>>>,
49    signaling_tx: mpsc::Sender<SignalingMessage>,
50    my_peer_id: PeerId,
51
52    // Content store for serving requests
53    store: Option<Arc<dyn ContentStore>>,
54
55    // Track pending outgoing requests (keyed by hash hex)
56    pub pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
57
58    // Channel for incoming data messages
59    #[allow(dead_code)]
60    message_tx: mpsc::Sender<(DataMessage, Option<Vec<u8>>)>,
61    #[allow(dead_code)]
62    message_rx: Option<mpsc::Receiver<(DataMessage, Option<Vec<u8>>)>>,
63
64    // Optional channel to notify signaling layer of state changes
65    state_event_tx: Option<mpsc::Sender<PeerStateEvent>>,
66}
67
68impl Peer {
69    /// Create a new peer connection
70    pub async fn new(
71        peer_id: PeerId,
72        direction: PeerDirection,
73        my_peer_id: PeerId,
74        signaling_tx: mpsc::Sender<SignalingMessage>,
75        stun_servers: Vec<String>,
76    ) -> Result<Self> {
77        Self::new_with_store_and_events(peer_id, direction, my_peer_id, signaling_tx, stun_servers, None, None).await
78    }
79
80    /// Create a new peer connection with content store
81    pub async fn new_with_store(
82        peer_id: PeerId,
83        direction: PeerDirection,
84        my_peer_id: PeerId,
85        signaling_tx: mpsc::Sender<SignalingMessage>,
86        stun_servers: Vec<String>,
87        store: Option<Arc<dyn ContentStore>>,
88    ) -> Result<Self> {
89        Self::new_with_store_and_events(peer_id, direction, my_peer_id, signaling_tx, stun_servers, store, None).await
90    }
91
92    /// Create a new peer connection with content store and state event channel
93    pub async fn new_with_store_and_events(
94        peer_id: PeerId,
95        direction: PeerDirection,
96        my_peer_id: PeerId,
97        signaling_tx: mpsc::Sender<SignalingMessage>,
98        stun_servers: Vec<String>,
99        store: Option<Arc<dyn ContentStore>>,
100        state_event_tx: Option<mpsc::Sender<PeerStateEvent>>,
101    ) -> Result<Self> {
102        // Create WebRTC API
103        let mut m = MediaEngine::default();
104        m.register_default_codecs()?;
105
106        let mut registry = Registry::new();
107        registry = register_default_interceptors(registry, &mut m)?;
108
109        // Disable mDNS to prevent CPU spin from orphaned mDNS agents.
110        // See: https://github.com/webrtc-rs/webrtc/issues/616
111        // mDNS is only useful for LAN peer discovery which we don't need.
112        let mut setting_engine = SettingEngine::default();
113        setting_engine.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled);
114
115        let api = APIBuilder::new()
116            .with_media_engine(m)
117            .with_interceptor_registry(registry)
118            .with_setting_engine(setting_engine)
119            .build();
120
121        // Configure ICE servers
122        let ice_servers: Vec<RTCIceServer> = stun_servers
123            .iter()
124            .map(|url| RTCIceServer {
125                urls: vec![url.clone()],
126                ..Default::default()
127            })
128            .collect();
129
130        let config = RTCConfiguration {
131            ice_servers,
132            ..Default::default()
133        };
134
135        let pc = Arc::new(api.new_peer_connection(config).await?);
136        let (message_tx, message_rx) = mpsc::channel(100);
137
138        Ok(Self {
139            peer_id,
140            direction,
141            created_at: std::time::Instant::now(),
142            connected_at: None,
143            pc,
144            data_channel: Arc::new(Mutex::new(None)),
145            signaling_tx,
146            my_peer_id,
147            store,
148            pending_requests: Arc::new(Mutex::new(HashMap::new())),
149            message_tx,
150            message_rx: Some(message_rx),
151            state_event_tx,
152        })
153    }
154
155    /// Set content store
156    pub fn set_store(&mut self, store: Arc<dyn ContentStore>) {
157        self.store = Some(store);
158    }
159
160    /// Get connection state
161    pub fn state(&self) -> RTCPeerConnectionState {
162        self.pc.connection_state()
163    }
164
165    /// Get signaling state
166    pub fn signaling_state(&self) -> webrtc::peer_connection::signaling_state::RTCSignalingState {
167        self.pc.signaling_state()
168    }
169
170    /// Check if connected
171    pub fn is_connected(&self) -> bool {
172        self.pc.connection_state() == RTCPeerConnectionState::Connected
173    }
174
175    /// Setup event handlers for the peer connection
176    pub async fn setup_handlers(&mut self) -> Result<()> {
177        let peer_id = self.peer_id.clone();
178        let signaling_tx = self.signaling_tx.clone();
179        let my_peer_id_str = self.my_peer_id.to_string();
180        let recipient = self.peer_id.to_string();
181
182        // Handle ICE candidates - work MUST be inside the returned future
183        self.pc
184            .on_ice_candidate(Box::new(move |candidate: Option<RTCIceCandidate>| {
185                let signaling_tx = signaling_tx.clone();
186                let my_peer_id_str = my_peer_id_str.clone();
187                let recipient = recipient.clone();
188
189                Box::pin(async move {
190                    if let Some(c) = candidate {
191                        if let Some(init) = c.to_json().ok() {
192                            info!("ICE candidate generated: {}", &init.candidate[..init.candidate.len().min(60)]);
193                            let msg = SignalingMessage::candidate(
194                                serde_json::to_value(&init).unwrap_or_default(),
195                                &recipient,
196                                &my_peer_id_str,
197                            );
198                            if let Err(e) = signaling_tx.send(msg).await {
199                                error!("Failed to send ICE candidate: {}", e);
200                            }
201                        }
202                    }
203                })
204            }));
205
206        // Handle connection state changes - work MUST be inside the returned future
207        let peer_id_log = peer_id.clone();
208        let state_event_tx = self.state_event_tx.clone();
209        self.pc
210            .on_peer_connection_state_change(Box::new(move |state: RTCPeerConnectionState| {
211                let peer_id = peer_id_log.clone();
212                let state_event_tx = state_event_tx.clone();
213                Box::pin(async move {
214                    info!("Peer {} connection state: {:?}", peer_id.short(), state);
215
216                    // Notify signaling layer of state changes
217                    if let Some(tx) = state_event_tx {
218                        let event = match state {
219                            RTCPeerConnectionState::Connected => Some(PeerStateEvent::Connected(peer_id)),
220                            RTCPeerConnectionState::Failed => Some(PeerStateEvent::Failed(peer_id)),
221                            RTCPeerConnectionState::Disconnected | RTCPeerConnectionState::Closed => {
222                                Some(PeerStateEvent::Disconnected(peer_id))
223                            }
224                            _ => None,
225                        };
226                        if let Some(event) = event {
227                            if let Err(e) = tx.send(event).await {
228                                error!("Failed to send peer state event: {}", e);
229                            }
230                        }
231                    }
232                })
233            }));
234
235        Ok(())
236    }
237
238    /// Initiate connection (create offer) - for outbound connections
239    pub async fn connect(&mut self) -> Result<serde_json::Value> {
240        println!("[Peer {}] Creating data channel...", self.peer_id.short());
241        // Create data channel first
242        // Use unordered for better performance - protocol is stateless (each message self-describes)
243        let dc_init = RTCDataChannelInit {
244            ordered: Some(false),
245            ..Default::default()
246        };
247        let dc = self.pc.create_data_channel("hashtree", Some(dc_init)).await?;
248        println!("[Peer {}] Data channel created, setting up handlers...", self.peer_id.short());
249        self.setup_data_channel(dc.clone()).await?;
250        println!("[Peer {}] Handlers set up, storing data channel...", self.peer_id.short());
251        {
252            let mut dc_guard = self.data_channel.lock().await;
253            *dc_guard = Some(dc);
254        }
255        println!("[Peer {}] Data channel stored", self.peer_id.short());
256
257        // Create offer
258        let offer = self.pc.create_offer(None).await?;
259        self.pc.set_local_description(offer.clone()).await?;
260
261        // Return offer as JSON
262        let offer_json = serde_json::json!({
263            "type": offer.sdp_type.to_string().to_lowercase(),
264            "sdp": offer.sdp
265        });
266
267        Ok(offer_json)
268    }
269
270    /// Handle incoming offer and create answer
271    pub async fn handle_offer(&mut self, offer: serde_json::Value) -> Result<serde_json::Value> {
272        let sdp = offer
273            .get("sdp")
274            .and_then(|s| s.as_str())
275            .ok_or_else(|| anyhow::anyhow!("Missing SDP in offer"))?;
276
277        // Setup data channel handler BEFORE set_remote_description
278        // This ensures the handler is registered before any data channel events fire
279        let peer_id = self.peer_id.clone();
280        let message_tx = self.message_tx.clone();
281        let pending_requests = self.pending_requests.clone();
282        let store = self.store.clone();
283        let data_channel_holder = self.data_channel.clone();
284
285        self.pc
286            .on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
287                let peer_id = peer_id.clone();
288                let message_tx = message_tx.clone();
289                let pending_requests = pending_requests.clone();
290                let store = store.clone();
291                let data_channel_holder = data_channel_holder.clone();
292
293                // Work MUST be inside the returned future
294                Box::pin(async move {
295                    info!("Peer {} received data channel: {}", peer_id.short(), dc.label());
296
297                    // Store the received data channel
298                    {
299                        let mut dc_guard = data_channel_holder.lock().await;
300                        *dc_guard = Some(dc.clone());
301                    }
302
303                    // Set up message handlers
304                    Self::setup_dc_handlers(
305                        dc.clone(),
306                        peer_id,
307                        message_tx,
308                        pending_requests,
309                        store,
310                    )
311                    .await;
312                })
313            }));
314
315        // Now set remote description after handler is registered
316        let offer_desc = RTCSessionDescription::offer(sdp.to_string())?;
317        self.pc.set_remote_description(offer_desc).await?;
318
319        // Create answer
320        let answer = self.pc.create_answer(None).await?;
321        self.pc.set_local_description(answer.clone()).await?;
322
323        let answer_json = serde_json::json!({
324            "type": answer.sdp_type.to_string().to_lowercase(),
325            "sdp": answer.sdp
326        });
327
328        Ok(answer_json)
329    }
330
331    /// Handle incoming answer
332    pub async fn handle_answer(&mut self, answer: serde_json::Value) -> Result<()> {
333        let sdp = answer
334            .get("sdp")
335            .and_then(|s| s.as_str())
336            .ok_or_else(|| anyhow::anyhow!("Missing SDP in answer"))?;
337
338        let answer_desc = RTCSessionDescription::answer(sdp.to_string())?;
339        self.pc.set_remote_description(answer_desc).await?;
340
341        Ok(())
342    }
343
344    /// Handle incoming ICE candidate
345    pub async fn handle_candidate(&mut self, candidate: serde_json::Value) -> Result<()> {
346        let candidate_str = candidate
347            .get("candidate")
348            .and_then(|c| c.as_str())
349            .unwrap_or("");
350
351        let sdp_mid = candidate
352            .get("sdpMid")
353            .and_then(|m| m.as_str())
354            .map(|s| s.to_string());
355
356        let sdp_mline_index = candidate
357            .get("sdpMLineIndex")
358            .and_then(|i| i.as_u64())
359            .map(|i| i as u16);
360
361        if !candidate_str.is_empty() {
362            use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
363            let init = RTCIceCandidateInit {
364                candidate: candidate_str.to_string(),
365                sdp_mid,
366                sdp_mline_index,
367                username_fragment: candidate
368                    .get("usernameFragment")
369                    .and_then(|u| u.as_str())
370                    .map(|s| s.to_string()),
371            };
372            self.pc.add_ice_candidate(init).await?;
373        }
374
375        Ok(())
376    }
377
378    /// Setup data channel handlers
379    async fn setup_data_channel(&mut self, dc: Arc<RTCDataChannel>) -> Result<()> {
380        let peer_id = self.peer_id.clone();
381        let message_tx = self.message_tx.clone();
382        let pending_requests = self.pending_requests.clone();
383        let store = self.store.clone();
384
385        Self::setup_dc_handlers(dc, peer_id, message_tx, pending_requests, store).await;
386        Ok(())
387    }
388
389    /// Setup handlers for a data channel (shared between outbound and inbound)
390    async fn setup_dc_handlers(
391        dc: Arc<RTCDataChannel>,
392        peer_id: PeerId,
393        message_tx: mpsc::Sender<(DataMessage, Option<Vec<u8>>)>,
394        pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
395        store: Option<Arc<dyn ContentStore>>,
396    ) {
397        let label = dc.label().to_string();
398        let peer_short = peer_id.short();
399
400        // Track pending binary data (request_id -> expected after response)
401        let _pending_binary: Arc<Mutex<Option<u32>>> = Arc::new(Mutex::new(None));
402
403        let _dc_for_open = dc.clone();
404        let peer_short_open = peer_short.clone();
405        let label_clone = label.clone();
406        dc.on_open(Box::new(move || {
407            let peer_short_open = peer_short_open.clone();
408            let label_clone = label_clone.clone();
409            // Work MUST be inside the returned future
410            Box::pin(async move {
411                info!("[Peer {}] Data channel '{}' open", peer_short_open, label_clone);
412            })
413        }));
414
415        let dc_for_msg = dc.clone();
416        let peer_short_msg = peer_short.clone();
417        let _pending_binary_clone = _pending_binary.clone();
418        let store_clone = store.clone();
419
420        dc.on_message(Box::new(move |msg: DataChannelMessage| {
421            let dc = dc_for_msg.clone();
422            let peer_short = peer_short_msg.clone();
423            let pending_requests = pending_requests.clone();
424            let _pending_binary = _pending_binary_clone.clone();
425            let _message_tx = message_tx.clone();
426            let store = store_clone.clone();
427            let msg_data = msg.data.clone();
428
429            // Work MUST be inside the returned future
430            Box::pin(async move {
431                // All messages are binary with type prefix + MessagePack body
432                debug!("[Peer {}] Received {} bytes on data channel", peer_short, msg_data.len());
433                match parse_message(&msg_data) {
434                    Ok(data_msg) => match data_msg {
435                        DataMessage::Request(req) => {
436                            let hash_hex = hash_to_hex(&req.h);
437                            let hash_short = &hash_hex[..8.min(hash_hex.len())];
438                            info!(
439                                "[Peer {}] Received request for {}",
440                                peer_short, hash_short
441                            );
442
443                            // Handle request - look up in store
444                            let data = if let Some(ref store) = store {
445                                match store.get(&hash_hex) {
446                                    Ok(Some(data)) => {
447                                        info!("[Peer {}] Found {} in store ({} bytes)", peer_short, hash_short, data.len());
448                                        Some(data)
449                                    },
450                                    Ok(None) => {
451                                        info!("[Peer {}] Hash {} not in store", peer_short, hash_short);
452                                        None
453                                    },
454                                    Err(e) => {
455                                        warn!("[Peer {}] Store error: {}", peer_short, e);
456                                        None
457                                    }
458                                }
459                            } else {
460                                warn!("[Peer {}] No store configured - cannot serve requests", peer_short);
461                                None
462                            };
463
464                            // Send response only if we have data
465                            if let Some(data) = data {
466                                let data_len = data.len();
467                                let response = DataResponse {
468                                    h: req.h,
469                                    d: data,
470                                };
471                                if let Ok(wire) = encode_response(&response) {
472                                    if let Err(e) = dc.send(&Bytes::from(wire)).await {
473                                        error!(
474                                            "[Peer {}] Failed to send response: {}",
475                                            peer_short, e
476                                        );
477                                    } else {
478                                        info!(
479                                            "[Peer {}] Sent response for {} ({} bytes)",
480                                            peer_short, hash_short, data_len
481                                        );
482                                    }
483                                }
484                            } else {
485                                info!("[Peer {}] Content not found for {}", peer_short, hash_short);
486                            }
487                        }
488                        DataMessage::Response(res) => {
489                            let hash_hex = hash_to_hex(&res.h);
490                            let hash_short = &hash_hex[..8.min(hash_hex.len())];
491                            debug!(
492                                "[Peer {}] Received response for {} ({} bytes)",
493                                peer_short, hash_short, res.d.len()
494                            );
495
496                            // Resolve the pending request by hash
497                            let mut pending = pending_requests.lock().await;
498                            if let Some(req) = pending.remove(&hash_hex) {
499                                let _ = req.response_tx.send(Some(res.d));
500                            }
501                        }
502                    },
503                    Err(e) => {
504                        warn!("[Peer {}] Failed to parse message: {:?}", peer_short, e);
505                        // Log hex dump of first 50 bytes for debugging
506                        let hex_dump: String = msg_data.iter().take(50).map(|b| format!("{:02x}", b)).collect();
507                        warn!("[Peer {}] Message hex: {}", peer_short, hex_dump);
508                    }
509                }
510            })
511        }));
512    }
513
514    /// Check if data channel is ready
515    pub fn has_data_channel(&self) -> bool {
516        // Use try_lock for non-async context
517        self.data_channel
518            .try_lock()
519            .map(|guard| guard.is_some())
520            .unwrap_or(false)
521    }
522
523    /// Request content by hash from this peer
524    pub async fn request(&self, hash_hex: &str) -> Result<Option<Vec<u8>>> {
525        let dc_guard = self.data_channel.lock().await;
526        let dc = dc_guard
527            .as_ref()
528            .ok_or_else(|| anyhow::anyhow!("No data channel"))?
529            .clone();
530        drop(dc_guard);  // Release lock before async operations
531
532        // Convert hex to binary hash
533        let hash = hex::decode(hash_hex)
534            .map_err(|e| anyhow::anyhow!("Invalid hex hash: {}", e))?;
535
536        // Create response channel
537        let (tx, rx) = oneshot::channel();
538
539        // Store pending request (keyed by hash hex)
540        {
541            let mut pending = self.pending_requests.lock().await;
542            pending.insert(
543                hash_hex.to_string(),
544                PendingRequest {
545                    hash: hash.clone(),
546                    response_tx: tx,
547                },
548            );
549        }
550
551        // Send request with MAX_HTL (fresh request from us)
552        let req = DataRequest {
553            h: hash,
554            htl: crate::webrtc::types::MAX_HTL,
555        };
556        let wire = encode_request(&req)?;
557        dc.send(&Bytes::from(wire)).await?;
558
559        debug!(
560            "[Peer {}] Sent request for {}",
561            self.peer_id.short(),
562            &hash_hex[..8.min(hash_hex.len())]
563        );
564
565        // Wait for response with timeout
566        match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
567            Ok(Ok(data)) => Ok(data),
568            Ok(Err(_)) => {
569                // Channel closed
570                Ok(None)
571            }
572            Err(_) => {
573                // Timeout - clean up pending request
574                let mut pending = self.pending_requests.lock().await;
575                pending.remove(hash_hex);
576                Ok(None)
577            }
578        }
579    }
580
581    /// Send a message over the data channel
582    pub async fn send_message(&self, msg: &DataMessage) -> Result<()> {
583        let dc_guard = self.data_channel.lock().await;
584        if let Some(ref dc) = *dc_guard {
585            let wire = encode_message(msg)?;
586            dc.send(&Bytes::from(wire)).await?;
587        }
588        Ok(())
589    }
590
591    /// Close the connection
592    pub async fn close(&self) -> Result<()> {
593        {
594            let dc_guard = self.data_channel.lock().await;
595            if let Some(ref dc) = *dc_guard {
596                dc.close().await?;
597            }
598        }
599        self.pc.close().await?;
600        Ok(())
601    }
602}