// hashtree_cli/webrtc/peer.rs

//! WebRTC peer connection for hashtree data exchange
2
3use anyhow::Result;
4use bytes::Bytes;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{mpsc, oneshot, Mutex};
8use tracing::{debug, error, info, warn};
9use webrtc::api::interceptor_registry::register_default_interceptors;
10use webrtc::api::media_engine::MediaEngine;
11use webrtc::api::setting_engine::SettingEngine;
12use webrtc::api::APIBuilder;
13use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
14use webrtc::data_channel::data_channel_message::DataChannelMessage;
15use webrtc::data_channel::RTCDataChannel;
16use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
17use webrtc::ice_transport::ice_server::RTCIceServer;
18use webrtc::interceptor::registry::Registry;
19use webrtc::peer_connection::configuration::RTCConfiguration;
20use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
21use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
22use webrtc::peer_connection::RTCPeerConnection;
23
24use super::types::{DataMessage, DataRequest, DataResponse, PeerDirection, PeerId, PeerStateEvent, SignalingMessage, encode_message, encode_request, encode_response, parse_message, hash_to_hex};
25
26/// Trait for content storage that can be used by WebRTC peers
27pub trait ContentStore: Send + Sync + 'static {
28    /// Get content by hex hash
29    fn get(&self, hash_hex: &str) -> Result<Option<Vec<u8>>>;
30}
31
32/// Pending request tracking (keyed by hash hex)
33pub struct PendingRequest {
34    pub hash: Vec<u8>,
35    pub response_tx: oneshot::Sender<Option<Vec<u8>>>,
36}
37
38/// WebRTC peer connection with data channel protocol
39pub struct Peer {
40    pub peer_id: PeerId,
41    pub direction: PeerDirection,
42    pub created_at: std::time::Instant,
43    pub connected_at: Option<std::time::Instant>,
44
45    pc: Arc<RTCPeerConnection>,
46    /// Data channel - can be set from callback when receiving channel from peer
47    pub data_channel: Arc<Mutex<Option<Arc<RTCDataChannel>>>>,
48    signaling_tx: mpsc::Sender<SignalingMessage>,
49    my_peer_id: PeerId,
50
51    // Content store for serving requests
52    store: Option<Arc<dyn ContentStore>>,
53
54    // Track pending outgoing requests (keyed by hash hex)
55    pub pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
56
57    // Channel for incoming data messages
58    #[allow(dead_code)]
59    message_tx: mpsc::Sender<(DataMessage, Option<Vec<u8>>)>,
60    #[allow(dead_code)]
61    message_rx: Option<mpsc::Receiver<(DataMessage, Option<Vec<u8>>)>>,
62
63    // Optional channel to notify signaling layer of state changes
64    state_event_tx: Option<mpsc::Sender<PeerStateEvent>>,
65}
66
67impl Peer {
68    /// Create a new peer connection
69    pub async fn new(
70        peer_id: PeerId,
71        direction: PeerDirection,
72        my_peer_id: PeerId,
73        signaling_tx: mpsc::Sender<SignalingMessage>,
74        stun_servers: Vec<String>,
75    ) -> Result<Self> {
76        Self::new_with_store_and_events(peer_id, direction, my_peer_id, signaling_tx, stun_servers, None, None).await
77    }
78
79    /// Create a new peer connection with content store
80    pub async fn new_with_store(
81        peer_id: PeerId,
82        direction: PeerDirection,
83        my_peer_id: PeerId,
84        signaling_tx: mpsc::Sender<SignalingMessage>,
85        stun_servers: Vec<String>,
86        store: Option<Arc<dyn ContentStore>>,
87    ) -> Result<Self> {
88        Self::new_with_store_and_events(peer_id, direction, my_peer_id, signaling_tx, stun_servers, store, None).await
89    }
90
91    /// Create a new peer connection with content store and state event channel
92    pub async fn new_with_store_and_events(
93        peer_id: PeerId,
94        direction: PeerDirection,
95        my_peer_id: PeerId,
96        signaling_tx: mpsc::Sender<SignalingMessage>,
97        stun_servers: Vec<String>,
98        store: Option<Arc<dyn ContentStore>>,
99        state_event_tx: Option<mpsc::Sender<PeerStateEvent>>,
100    ) -> Result<Self> {
101        // Create WebRTC API
102        let mut m = MediaEngine::default();
103        m.register_default_codecs()?;
104
105        let mut registry = Registry::new();
106        registry = register_default_interceptors(registry, &mut m)?;
107
108        // Enable mDNS temporarily for debugging
109        // Previously disabled due to https://github.com/webrtc-rs/webrtc/issues/616
110        let setting_engine = SettingEngine::default();
111        // Note: mDNS enabled by default
112
113        let api = APIBuilder::new()
114            .with_media_engine(m)
115            .with_interceptor_registry(registry)
116            .with_setting_engine(setting_engine)
117            .build();
118
119        // Configure ICE servers
120        let ice_servers: Vec<RTCIceServer> = stun_servers
121            .iter()
122            .map(|url| RTCIceServer {
123                urls: vec![url.clone()],
124                ..Default::default()
125            })
126            .collect();
127
128        let config = RTCConfiguration {
129            ice_servers,
130            ..Default::default()
131        };
132
133        let pc = Arc::new(api.new_peer_connection(config).await?);
134        let (message_tx, message_rx) = mpsc::channel(100);
135        Ok(Self {
136            peer_id,
137            direction,
138            created_at: std::time::Instant::now(),
139            connected_at: None,
140            pc,
141            data_channel: Arc::new(Mutex::new(None)),
142            signaling_tx,
143            my_peer_id,
144            store,
145            pending_requests: Arc::new(Mutex::new(HashMap::new())),
146            message_tx,
147            message_rx: Some(message_rx),
148            state_event_tx,
149        })
150    }
151
152    /// Set content store
153    pub fn set_store(&mut self, store: Arc<dyn ContentStore>) {
154        self.store = Some(store);
155    }
156
157    /// Get connection state
158    pub fn state(&self) -> RTCPeerConnectionState {
159        self.pc.connection_state()
160    }
161
162    /// Get signaling state
163    pub fn signaling_state(&self) -> webrtc::peer_connection::signaling_state::RTCSignalingState {
164        self.pc.signaling_state()
165    }
166
167    /// Check if connected
168    pub fn is_connected(&self) -> bool {
169        self.pc.connection_state() == RTCPeerConnectionState::Connected
170    }
171
172    /// Setup event handlers for the peer connection
173    pub async fn setup_handlers(&mut self) -> Result<()> {
174        let peer_id = self.peer_id.clone();
175        let signaling_tx = self.signaling_tx.clone();
176        let my_peer_id_str = self.my_peer_id.to_string();
177        let recipient = self.peer_id.to_string();
178
179        // Handle ICE candidates - work MUST be inside the returned future
180        self.pc
181            .on_ice_candidate(Box::new(move |candidate: Option<RTCIceCandidate>| {
182                let signaling_tx = signaling_tx.clone();
183                let my_peer_id_str = my_peer_id_str.clone();
184                let recipient = recipient.clone();
185
186                Box::pin(async move {
187                    if let Some(c) = candidate {
188                        if let Some(init) = c.to_json().ok() {
189                            info!("ICE candidate generated: {}", &init.candidate[..init.candidate.len().min(60)]);
190                            let msg = SignalingMessage::candidate(
191                                serde_json::to_value(&init).unwrap_or_default(),
192                                &recipient,
193                                &my_peer_id_str,
194                            );
195                            if let Err(e) = signaling_tx.send(msg).await {
196                                error!("Failed to send ICE candidate: {}", e);
197                            }
198                        }
199                    }
200                })
201            }));
202
203        // Handle connection state changes - work MUST be inside the returned future
204        let peer_id_log = peer_id.clone();
205        let state_event_tx = self.state_event_tx.clone();
206        self.pc
207            .on_peer_connection_state_change(Box::new(move |state: RTCPeerConnectionState| {
208                let peer_id = peer_id_log.clone();
209                let state_event_tx = state_event_tx.clone();
210                Box::pin(async move {
211                    info!("Peer {} connection state: {:?}", peer_id.short(), state);
212
213                    // Notify signaling layer of state changes
214                    if let Some(tx) = state_event_tx {
215                        let event = match state {
216                            RTCPeerConnectionState::Connected => Some(PeerStateEvent::Connected(peer_id)),
217                            RTCPeerConnectionState::Failed => Some(PeerStateEvent::Failed(peer_id)),
218                            RTCPeerConnectionState::Disconnected | RTCPeerConnectionState::Closed => {
219                                Some(PeerStateEvent::Disconnected(peer_id))
220                            }
221                            _ => None,
222                        };
223                        if let Some(event) = event {
224                            if let Err(e) = tx.send(event).await {
225                                error!("Failed to send peer state event: {}", e);
226                            }
227                        }
228                    }
229                })
230            }));
231
232        Ok(())
233    }
234
235    /// Initiate connection (create offer) - for outbound connections
236    pub async fn connect(&mut self) -> Result<serde_json::Value> {
237        println!("[Peer {}] Creating data channel...", self.peer_id.short());
238        // Create data channel first
239        // Use unordered for better performance - protocol is stateless (each message self-describes)
240        let dc_init = RTCDataChannelInit {
241            ordered: Some(false),
242            ..Default::default()
243        };
244        let dc = self.pc.create_data_channel("hashtree", Some(dc_init)).await?;
245        println!("[Peer {}] Data channel created, setting up handlers...", self.peer_id.short());
246        self.setup_data_channel(dc.clone()).await?;
247        println!("[Peer {}] Handlers set up, storing data channel...", self.peer_id.short());
248        {
249            let mut dc_guard = self.data_channel.lock().await;
250            *dc_guard = Some(dc);
251        }
252        println!("[Peer {}] Data channel stored", self.peer_id.short());
253
254        // Create offer and wait for ICE gathering to complete
255        // This ensures all ICE candidates are embedded in the SDP
256        let offer = self.pc.create_offer(None).await?;
257        let mut gathering_complete = self.pc.gathering_complete_promise().await;
258        self.pc.set_local_description(offer).await?;
259
260        // Wait for ICE gathering to complete (with timeout)
261        let _ = tokio::time::timeout(
262            std::time::Duration::from_secs(10),
263            gathering_complete.recv()
264        ).await;
265
266        // Get the local description with ICE candidates embedded
267        let local_desc = self.pc.local_description().await
268            .ok_or_else(|| anyhow::anyhow!("No local description after gathering"))?;
269
270        debug!("Offer created, SDP len: {}, ice_gathering: {:?}",
271            local_desc.sdp.len(), self.pc.ice_gathering_state());
272
273        // Return offer as JSON
274        let offer_json = serde_json::json!({
275            "type": local_desc.sdp_type.to_string().to_lowercase(),
276            "sdp": local_desc.sdp
277        });
278
279        Ok(offer_json)
280    }
281
282    /// Handle incoming offer and create answer
283    pub async fn handle_offer(&mut self, offer: serde_json::Value) -> Result<serde_json::Value> {
284        let sdp = offer
285            .get("sdp")
286            .and_then(|s| s.as_str())
287            .ok_or_else(|| anyhow::anyhow!("Missing SDP in offer"))?;
288
289        // Setup data channel handler BEFORE set_remote_description
290        // This ensures the handler is registered before any data channel events fire
291        let peer_id = self.peer_id.clone();
292        let message_tx = self.message_tx.clone();
293        let pending_requests = self.pending_requests.clone();
294        let store = self.store.clone();
295        let data_channel_holder = self.data_channel.clone();
296
297        self.pc
298            .on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
299                let peer_id = peer_id.clone();
300                let message_tx = message_tx.clone();
301                let pending_requests = pending_requests.clone();
302                let store = store.clone();
303                let data_channel_holder = data_channel_holder.clone();
304
305                // Work MUST be inside the returned future
306                Box::pin(async move {
307                    info!("Peer {} received data channel: {}", peer_id.short(), dc.label());
308
309                    // Store the received data channel
310                    {
311                        let mut dc_guard = data_channel_holder.lock().await;
312                        *dc_guard = Some(dc.clone());
313                    }
314
315                    // Set up message handlers
316                    Self::setup_dc_handlers(
317                        dc.clone(),
318                        peer_id,
319                        message_tx,
320                        pending_requests,
321                        store,
322                    )
323                    .await;
324                })
325            }));
326
327        // Set remote description after handler is registered
328        let offer_desc = RTCSessionDescription::offer(sdp.to_string())?;
329        self.pc.set_remote_description(offer_desc).await?;
330
331        // Create answer and wait for ICE gathering to complete
332        // This ensures all ICE candidates are embedded in the SDP
333        let answer = self.pc.create_answer(None).await?;
334        let mut gathering_complete = self.pc.gathering_complete_promise().await;
335        self.pc.set_local_description(answer).await?;
336
337        // Wait for ICE gathering to complete (with timeout)
338        let _ = tokio::time::timeout(
339            std::time::Duration::from_secs(10),
340            gathering_complete.recv()
341        ).await;
342
343        // Get the local description with ICE candidates embedded
344        let local_desc = self.pc.local_description().await
345            .ok_or_else(|| anyhow::anyhow!("No local description after gathering"))?;
346
347        debug!("Answer created, SDP len: {}, ice_gathering: {:?}",
348            local_desc.sdp.len(), self.pc.ice_gathering_state());
349
350        let answer_json = serde_json::json!({
351            "type": local_desc.sdp_type.to_string().to_lowercase(),
352            "sdp": local_desc.sdp
353        });
354
355        Ok(answer_json)
356    }
357
358    /// Handle incoming answer
359    pub async fn handle_answer(&mut self, answer: serde_json::Value) -> Result<()> {
360        let sdp = answer
361            .get("sdp")
362            .and_then(|s| s.as_str())
363            .ok_or_else(|| anyhow::anyhow!("Missing SDP in answer"))?;
364
365        let answer_desc = RTCSessionDescription::answer(sdp.to_string())?;
366        self.pc.set_remote_description(answer_desc).await?;
367
368        Ok(())
369    }
370
371    /// Handle incoming ICE candidate
372    pub async fn handle_candidate(&mut self, candidate: serde_json::Value) -> Result<()> {
373        let candidate_str = candidate
374            .get("candidate")
375            .and_then(|c| c.as_str())
376            .unwrap_or("");
377
378        let sdp_mid = candidate
379            .get("sdpMid")
380            .and_then(|m| m.as_str())
381            .map(|s| s.to_string());
382
383        let sdp_mline_index = candidate
384            .get("sdpMLineIndex")
385            .and_then(|i| i.as_u64())
386            .map(|i| i as u16);
387
388        if !candidate_str.is_empty() {
389            use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
390            let init = RTCIceCandidateInit {
391                candidate: candidate_str.to_string(),
392                sdp_mid,
393                sdp_mline_index,
394                username_fragment: candidate
395                    .get("usernameFragment")
396                    .and_then(|u| u.as_str())
397                    .map(|s| s.to_string()),
398            };
399            self.pc.add_ice_candidate(init).await?;
400        }
401
402        Ok(())
403    }
404
405    /// Setup data channel handlers
406    async fn setup_data_channel(&mut self, dc: Arc<RTCDataChannel>) -> Result<()> {
407        let peer_id = self.peer_id.clone();
408        let message_tx = self.message_tx.clone();
409        let pending_requests = self.pending_requests.clone();
410        let store = self.store.clone();
411
412        Self::setup_dc_handlers(dc, peer_id, message_tx, pending_requests, store).await;
413        Ok(())
414    }
415
416    /// Setup handlers for a data channel (shared between outbound and inbound)
417    async fn setup_dc_handlers(
418        dc: Arc<RTCDataChannel>,
419        peer_id: PeerId,
420        message_tx: mpsc::Sender<(DataMessage, Option<Vec<u8>>)>,
421        pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
422        store: Option<Arc<dyn ContentStore>>,
423    ) {
424        let label = dc.label().to_string();
425        let peer_short = peer_id.short();
426
427        // Track pending binary data (request_id -> expected after response)
428        let _pending_binary: Arc<Mutex<Option<u32>>> = Arc::new(Mutex::new(None));
429
430        let _dc_for_open = dc.clone();
431        let peer_short_open = peer_short.clone();
432        let label_clone = label.clone();
433        dc.on_open(Box::new(move || {
434            let peer_short_open = peer_short_open.clone();
435            let label_clone = label_clone.clone();
436            // Work MUST be inside the returned future
437            Box::pin(async move {
438                info!("[Peer {}] Data channel '{}' open", peer_short_open, label_clone);
439            })
440        }));
441
442        let dc_for_msg = dc.clone();
443        let peer_short_msg = peer_short.clone();
444        let _pending_binary_clone = _pending_binary.clone();
445        let store_clone = store.clone();
446
447        dc.on_message(Box::new(move |msg: DataChannelMessage| {
448            let dc = dc_for_msg.clone();
449            let peer_short = peer_short_msg.clone();
450            let pending_requests = pending_requests.clone();
451            let _pending_binary = _pending_binary_clone.clone();
452            let _message_tx = message_tx.clone();
453            let store = store_clone.clone();
454            let msg_data = msg.data.clone();
455
456            // Work MUST be inside the returned future
457            Box::pin(async move {
458                // All messages are binary with type prefix + MessagePack body
459                debug!("[Peer {}] Received {} bytes on data channel", peer_short, msg_data.len());
460                match parse_message(&msg_data) {
461                    Ok(data_msg) => match data_msg {
462                        DataMessage::Request(req) => {
463                            let hash_hex = hash_to_hex(&req.h);
464                            let hash_short = &hash_hex[..8.min(hash_hex.len())];
465                            info!(
466                                "[Peer {}] Received request for {}",
467                                peer_short, hash_short
468                            );
469
470                            // Handle request - look up in store
471                            let data = if let Some(ref store) = store {
472                                match store.get(&hash_hex) {
473                                    Ok(Some(data)) => {
474                                        info!("[Peer {}] Found {} in store ({} bytes)", peer_short, hash_short, data.len());
475                                        Some(data)
476                                    },
477                                    Ok(None) => {
478                                        info!("[Peer {}] Hash {} not in store", peer_short, hash_short);
479                                        None
480                                    },
481                                    Err(e) => {
482                                        warn!("[Peer {}] Store error: {}", peer_short, e);
483                                        None
484                                    }
485                                }
486                            } else {
487                                warn!("[Peer {}] No store configured - cannot serve requests", peer_short);
488                                None
489                            };
490
491                            // Send response only if we have data
492                            if let Some(data) = data {
493                                let data_len = data.len();
494                                let response = DataResponse {
495                                    h: req.h,
496                                    d: data,
497                                };
498                                if let Ok(wire) = encode_response(&response) {
499                                    if let Err(e) = dc.send(&Bytes::from(wire)).await {
500                                        error!(
501                                            "[Peer {}] Failed to send response: {}",
502                                            peer_short, e
503                                        );
504                                    } else {
505                                        info!(
506                                            "[Peer {}] Sent response for {} ({} bytes)",
507                                            peer_short, hash_short, data_len
508                                        );
509                                    }
510                                }
511                            } else {
512                                info!("[Peer {}] Content not found for {}", peer_short, hash_short);
513                            }
514                        }
515                        DataMessage::Response(res) => {
516                            let hash_hex = hash_to_hex(&res.h);
517                            let hash_short = &hash_hex[..8.min(hash_hex.len())];
518                            debug!(
519                                "[Peer {}] Received response for {} ({} bytes)",
520                                peer_short, hash_short, res.d.len()
521                            );
522
523                            // Resolve the pending request by hash
524                            let mut pending = pending_requests.lock().await;
525                            if let Some(req) = pending.remove(&hash_hex) {
526                                let _ = req.response_tx.send(Some(res.d));
527                            }
528                        }
529                    },
530                    Err(e) => {
531                        warn!("[Peer {}] Failed to parse message: {:?}", peer_short, e);
532                        // Log hex dump of first 50 bytes for debugging
533                        let hex_dump: String = msg_data.iter().take(50).map(|b| format!("{:02x}", b)).collect();
534                        warn!("[Peer {}] Message hex: {}", peer_short, hex_dump);
535                    }
536                }
537            })
538        }));
539    }
540
541    /// Check if data channel is ready
542    pub fn has_data_channel(&self) -> bool {
543        // Use try_lock for non-async context
544        self.data_channel
545            .try_lock()
546            .map(|guard| guard.is_some())
547            .unwrap_or(false)
548    }
549
550    /// Request content by hash from this peer
551    pub async fn request(&self, hash_hex: &str) -> Result<Option<Vec<u8>>> {
552        let dc_guard = self.data_channel.lock().await;
553        let dc = dc_guard
554            .as_ref()
555            .ok_or_else(|| anyhow::anyhow!("No data channel"))?
556            .clone();
557        drop(dc_guard);  // Release lock before async operations
558
559        // Convert hex to binary hash
560        let hash = hex::decode(hash_hex)
561            .map_err(|e| anyhow::anyhow!("Invalid hex hash: {}", e))?;
562
563        // Create response channel
564        let (tx, rx) = oneshot::channel();
565
566        // Store pending request (keyed by hash hex)
567        {
568            let mut pending = self.pending_requests.lock().await;
569            pending.insert(
570                hash_hex.to_string(),
571                PendingRequest {
572                    hash: hash.clone(),
573                    response_tx: tx,
574                },
575            );
576        }
577
578        // Send request with MAX_HTL (fresh request from us)
579        let req = DataRequest {
580            h: hash,
581            htl: crate::webrtc::types::MAX_HTL,
582        };
583        let wire = encode_request(&req)?;
584        dc.send(&Bytes::from(wire)).await?;
585
586        debug!(
587            "[Peer {}] Sent request for {}",
588            self.peer_id.short(),
589            &hash_hex[..8.min(hash_hex.len())]
590        );
591
592        // Wait for response with timeout
593        match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
594            Ok(Ok(data)) => Ok(data),
595            Ok(Err(_)) => {
596                // Channel closed
597                Ok(None)
598            }
599            Err(_) => {
600                // Timeout - clean up pending request
601                let mut pending = self.pending_requests.lock().await;
602                pending.remove(hash_hex);
603                Ok(None)
604            }
605        }
606    }
607
608    /// Send a message over the data channel
609    pub async fn send_message(&self, msg: &DataMessage) -> Result<()> {
610        let dc_guard = self.data_channel.lock().await;
611        if let Some(ref dc) = *dc_guard {
612            let wire = encode_message(msg)?;
613            dc.send(&Bytes::from(wire)).await?;
614        }
615        Ok(())
616    }
617
618    /// Close the connection
619    pub async fn close(&self) -> Result<()> {
620        {
621            let dc_guard = self.data_channel.lock().await;
622            if let Some(ref dc) = *dc_guard {
623                dc.close().await?;
624            }
625        }
626        self.pc.close().await?;
627        Ok(())
628    }
629}