hashtree_cli/webrtc/
peer.rs

1//! WebRTC peer connection for hashtree data exchange
2
3use anyhow::Result;
4use bytes::Bytes;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{mpsc, oneshot, Mutex};
8use tracing::{debug, error, info, warn};
9use webrtc::api::interceptor_registry::register_default_interceptors;
10use webrtc::api::media_engine::MediaEngine;
11use webrtc::api::APIBuilder;
12use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
13use webrtc::data_channel::data_channel_message::DataChannelMessage;
14use webrtc::data_channel::RTCDataChannel;
15use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
16use webrtc::ice_transport::ice_server::RTCIceServer;
17use webrtc::interceptor::registry::Registry;
18use webrtc::peer_connection::configuration::RTCConfiguration;
19use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
20use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
21use webrtc::peer_connection::RTCPeerConnection;
22
23use super::types::{DataMessage, DataRequest, DataResponse, PeerDirection, PeerId, PeerStateEvent, SignalingMessage, encode_message, encode_request, encode_response, parse_message, hash_to_hex};
24
/// Trait for content storage that can be used by WebRTC peers
///
/// Implementations are shared across handler futures behind an `Arc`, hence
/// the `Send + Sync + 'static` bounds.
pub trait ContentStore: Send + Sync + 'static {
    /// Get content by hex hash
    ///
    /// Within this file the method is called synchronously from the data
    /// channel message handler with the hex form of a requested hash:
    /// `Ok(Some(bytes))` is sent back as the response payload, while
    /// `Ok(None)` and `Err(_)` both result in no response being sent
    /// (errors are additionally logged).
    fn get(&self, hash_hex: &str) -> Result<Option<Vec<u8>>>;
}
30
/// Pending request tracking (keyed by hash hex)
///
/// One entry is stored in `Peer::pending_requests` per outstanding outgoing
/// request and removed again when a matching response arrives or the
/// request times out.
pub struct PendingRequest {
    /// Binary form of the requested hash (decoded from the hex key).
    pub hash: Vec<u8>,
    /// Resolved with `Some(payload)` when a response for this hash arrives.
    /// Dropping the sender without sending makes the waiting side observe a
    /// closed channel.
    pub response_tx: oneshot::Sender<Option<Vec<u8>>>,
}
36
/// WebRTC peer connection with data channel protocol
///
/// Wraps a single `RTCPeerConnection` plus the (at most one) data channel
/// used to exchange hash requests and responses with the remote peer.
pub struct Peer {
    /// Identifier of the peer on the other end; used as the recipient of
    /// outgoing signaling messages.
    pub peer_id: PeerId,
    /// Direction of this connection, as supplied by the caller at construction.
    pub direction: PeerDirection,
    /// Moment this `Peer` value was constructed.
    pub created_at: std::time::Instant,
    /// Initialised to `None`; nothing in this file updates it.
    /// NOTE(review): presumably set by the owner once connected - confirm.
    pub connected_at: Option<std::time::Instant>,

    pc: Arc<RTCPeerConnection>,
    /// Data channel - can be set from callback when receiving channel from peer
    pub data_channel: Arc<Mutex<Option<Arc<RTCDataChannel>>>>,
    // Outgoing signaling messages (ICE candidates) are pushed here
    signaling_tx: mpsc::Sender<SignalingMessage>,
    // Our own identifier, sent along with outgoing signaling messages
    my_peer_id: PeerId,

    // Content store for serving requests
    store: Option<Arc<dyn ContentStore>>,

    // Track pending outgoing requests (keyed by hash hex)
    pub pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,

    // Channel for incoming data messages
    // (handed to the data channel handlers, which do not currently send on it)
    #[allow(dead_code)]
    message_tx: mpsc::Sender<(DataMessage, Option<Vec<u8>>)>,
    // Receiving half of the channel above; never read within this file
    #[allow(dead_code)]
    message_rx: Option<mpsc::Receiver<(DataMessage, Option<Vec<u8>>)>>,

    // Optional channel to notify signaling layer of state changes
    state_event_tx: Option<mpsc::Sender<PeerStateEvent>>,
}
65
66impl Peer {
67    /// Create a new peer connection
68    pub async fn new(
69        peer_id: PeerId,
70        direction: PeerDirection,
71        my_peer_id: PeerId,
72        signaling_tx: mpsc::Sender<SignalingMessage>,
73        stun_servers: Vec<String>,
74    ) -> Result<Self> {
75        Self::new_with_store_and_events(peer_id, direction, my_peer_id, signaling_tx, stun_servers, None, None).await
76    }
77
78    /// Create a new peer connection with content store
79    pub async fn new_with_store(
80        peer_id: PeerId,
81        direction: PeerDirection,
82        my_peer_id: PeerId,
83        signaling_tx: mpsc::Sender<SignalingMessage>,
84        stun_servers: Vec<String>,
85        store: Option<Arc<dyn ContentStore>>,
86    ) -> Result<Self> {
87        Self::new_with_store_and_events(peer_id, direction, my_peer_id, signaling_tx, stun_servers, store, None).await
88    }
89
90    /// Create a new peer connection with content store and state event channel
91    pub async fn new_with_store_and_events(
92        peer_id: PeerId,
93        direction: PeerDirection,
94        my_peer_id: PeerId,
95        signaling_tx: mpsc::Sender<SignalingMessage>,
96        stun_servers: Vec<String>,
97        store: Option<Arc<dyn ContentStore>>,
98        state_event_tx: Option<mpsc::Sender<PeerStateEvent>>,
99    ) -> Result<Self> {
100        // Create WebRTC API
101        let mut m = MediaEngine::default();
102        m.register_default_codecs()?;
103
104        let mut registry = Registry::new();
105        registry = register_default_interceptors(registry, &mut m)?;
106
107        let api = APIBuilder::new()
108            .with_media_engine(m)
109            .with_interceptor_registry(registry)
110            .build();
111
112        // Configure ICE servers
113        let ice_servers: Vec<RTCIceServer> = stun_servers
114            .iter()
115            .map(|url| RTCIceServer {
116                urls: vec![url.clone()],
117                ..Default::default()
118            })
119            .collect();
120
121        let config = RTCConfiguration {
122            ice_servers,
123            ..Default::default()
124        };
125
126        let pc = Arc::new(api.new_peer_connection(config).await?);
127        let (message_tx, message_rx) = mpsc::channel(100);
128
129        Ok(Self {
130            peer_id,
131            direction,
132            created_at: std::time::Instant::now(),
133            connected_at: None,
134            pc,
135            data_channel: Arc::new(Mutex::new(None)),
136            signaling_tx,
137            my_peer_id,
138            store,
139            pending_requests: Arc::new(Mutex::new(HashMap::new())),
140            message_tx,
141            message_rx: Some(message_rx),
142            state_event_tx,
143        })
144    }
145
146    /// Set content store
147    pub fn set_store(&mut self, store: Arc<dyn ContentStore>) {
148        self.store = Some(store);
149    }
150
151    /// Get connection state
152    pub fn state(&self) -> RTCPeerConnectionState {
153        self.pc.connection_state()
154    }
155
156    /// Get signaling state
157    pub fn signaling_state(&self) -> webrtc::peer_connection::signaling_state::RTCSignalingState {
158        self.pc.signaling_state()
159    }
160
161    /// Check if connected
162    pub fn is_connected(&self) -> bool {
163        self.pc.connection_state() == RTCPeerConnectionState::Connected
164    }
165
166    /// Setup event handlers for the peer connection
167    pub async fn setup_handlers(&mut self) -> Result<()> {
168        let peer_id = self.peer_id.clone();
169        let signaling_tx = self.signaling_tx.clone();
170        let my_peer_id_str = self.my_peer_id.to_string();
171        let recipient = self.peer_id.to_string();
172
173        // Handle ICE candidates - work MUST be inside the returned future
174        self.pc
175            .on_ice_candidate(Box::new(move |candidate: Option<RTCIceCandidate>| {
176                let signaling_tx = signaling_tx.clone();
177                let my_peer_id_str = my_peer_id_str.clone();
178                let recipient = recipient.clone();
179
180                Box::pin(async move {
181                    if let Some(c) = candidate {
182                        if let Some(init) = c.to_json().ok() {
183                            info!("ICE candidate generated: {}", &init.candidate[..init.candidate.len().min(60)]);
184                            let msg = SignalingMessage::candidate(
185                                serde_json::to_value(&init).unwrap_or_default(),
186                                &recipient,
187                                &my_peer_id_str,
188                            );
189                            if let Err(e) = signaling_tx.send(msg).await {
190                                error!("Failed to send ICE candidate: {}", e);
191                            }
192                        }
193                    }
194                })
195            }));
196
197        // Handle connection state changes - work MUST be inside the returned future
198        let peer_id_log = peer_id.clone();
199        let state_event_tx = self.state_event_tx.clone();
200        self.pc
201            .on_peer_connection_state_change(Box::new(move |state: RTCPeerConnectionState| {
202                let peer_id = peer_id_log.clone();
203                let state_event_tx = state_event_tx.clone();
204                Box::pin(async move {
205                    info!("Peer {} connection state: {:?}", peer_id.short(), state);
206
207                    // Notify signaling layer of state changes
208                    if let Some(tx) = state_event_tx {
209                        let event = match state {
210                            RTCPeerConnectionState::Connected => Some(PeerStateEvent::Connected(peer_id)),
211                            RTCPeerConnectionState::Failed => Some(PeerStateEvent::Failed(peer_id)),
212                            RTCPeerConnectionState::Disconnected | RTCPeerConnectionState::Closed => {
213                                Some(PeerStateEvent::Disconnected(peer_id))
214                            }
215                            _ => None,
216                        };
217                        if let Some(event) = event {
218                            if let Err(e) = tx.send(event).await {
219                                error!("Failed to send peer state event: {}", e);
220                            }
221                        }
222                    }
223                })
224            }));
225
226        Ok(())
227    }
228
229    /// Initiate connection (create offer) - for outbound connections
230    pub async fn connect(&mut self) -> Result<serde_json::Value> {
231        println!("[Peer {}] Creating data channel...", self.peer_id.short());
232        // Create data channel first
233        // Use unordered for better performance - protocol is stateless (each message self-describes)
234        let dc_init = RTCDataChannelInit {
235            ordered: Some(false),
236            ..Default::default()
237        };
238        let dc = self.pc.create_data_channel("hashtree", Some(dc_init)).await?;
239        println!("[Peer {}] Data channel created, setting up handlers...", self.peer_id.short());
240        self.setup_data_channel(dc.clone()).await?;
241        println!("[Peer {}] Handlers set up, storing data channel...", self.peer_id.short());
242        {
243            let mut dc_guard = self.data_channel.lock().await;
244            *dc_guard = Some(dc);
245        }
246        println!("[Peer {}] Data channel stored", self.peer_id.short());
247
248        // Create offer
249        let offer = self.pc.create_offer(None).await?;
250        self.pc.set_local_description(offer.clone()).await?;
251
252        // Return offer as JSON
253        let offer_json = serde_json::json!({
254            "type": offer.sdp_type.to_string().to_lowercase(),
255            "sdp": offer.sdp
256        });
257
258        Ok(offer_json)
259    }
260
261    /// Handle incoming offer and create answer
262    pub async fn handle_offer(&mut self, offer: serde_json::Value) -> Result<serde_json::Value> {
263        let sdp = offer
264            .get("sdp")
265            .and_then(|s| s.as_str())
266            .ok_or_else(|| anyhow::anyhow!("Missing SDP in offer"))?;
267
268        // Setup data channel handler BEFORE set_remote_description
269        // This ensures the handler is registered before any data channel events fire
270        let peer_id = self.peer_id.clone();
271        let message_tx = self.message_tx.clone();
272        let pending_requests = self.pending_requests.clone();
273        let store = self.store.clone();
274        let data_channel_holder = self.data_channel.clone();
275
276        self.pc
277            .on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
278                let peer_id = peer_id.clone();
279                let message_tx = message_tx.clone();
280                let pending_requests = pending_requests.clone();
281                let store = store.clone();
282                let data_channel_holder = data_channel_holder.clone();
283
284                // Work MUST be inside the returned future
285                Box::pin(async move {
286                    info!("Peer {} received data channel: {}", peer_id.short(), dc.label());
287
288                    // Store the received data channel
289                    {
290                        let mut dc_guard = data_channel_holder.lock().await;
291                        *dc_guard = Some(dc.clone());
292                    }
293
294                    // Set up message handlers
295                    Self::setup_dc_handlers(
296                        dc.clone(),
297                        peer_id,
298                        message_tx,
299                        pending_requests,
300                        store,
301                    )
302                    .await;
303                })
304            }));
305
306        // Now set remote description after handler is registered
307        let offer_desc = RTCSessionDescription::offer(sdp.to_string())?;
308        self.pc.set_remote_description(offer_desc).await?;
309
310        // Create answer
311        let answer = self.pc.create_answer(None).await?;
312        self.pc.set_local_description(answer.clone()).await?;
313
314        let answer_json = serde_json::json!({
315            "type": answer.sdp_type.to_string().to_lowercase(),
316            "sdp": answer.sdp
317        });
318
319        Ok(answer_json)
320    }
321
322    /// Handle incoming answer
323    pub async fn handle_answer(&mut self, answer: serde_json::Value) -> Result<()> {
324        let sdp = answer
325            .get("sdp")
326            .and_then(|s| s.as_str())
327            .ok_or_else(|| anyhow::anyhow!("Missing SDP in answer"))?;
328
329        let answer_desc = RTCSessionDescription::answer(sdp.to_string())?;
330        self.pc.set_remote_description(answer_desc).await?;
331
332        Ok(())
333    }
334
335    /// Handle incoming ICE candidate
336    pub async fn handle_candidate(&mut self, candidate: serde_json::Value) -> Result<()> {
337        let candidate_str = candidate
338            .get("candidate")
339            .and_then(|c| c.as_str())
340            .unwrap_or("");
341
342        let sdp_mid = candidate
343            .get("sdpMid")
344            .and_then(|m| m.as_str())
345            .map(|s| s.to_string());
346
347        let sdp_mline_index = candidate
348            .get("sdpMLineIndex")
349            .and_then(|i| i.as_u64())
350            .map(|i| i as u16);
351
352        if !candidate_str.is_empty() {
353            use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
354            let init = RTCIceCandidateInit {
355                candidate: candidate_str.to_string(),
356                sdp_mid,
357                sdp_mline_index,
358                username_fragment: candidate
359                    .get("usernameFragment")
360                    .and_then(|u| u.as_str())
361                    .map(|s| s.to_string()),
362            };
363            self.pc.add_ice_candidate(init).await?;
364        }
365
366        Ok(())
367    }
368
369    /// Setup data channel handlers
370    async fn setup_data_channel(&mut self, dc: Arc<RTCDataChannel>) -> Result<()> {
371        let peer_id = self.peer_id.clone();
372        let message_tx = self.message_tx.clone();
373        let pending_requests = self.pending_requests.clone();
374        let store = self.store.clone();
375
376        Self::setup_dc_handlers(dc, peer_id, message_tx, pending_requests, store).await;
377        Ok(())
378    }
379
380    /// Setup handlers for a data channel (shared between outbound and inbound)
381    async fn setup_dc_handlers(
382        dc: Arc<RTCDataChannel>,
383        peer_id: PeerId,
384        message_tx: mpsc::Sender<(DataMessage, Option<Vec<u8>>)>,
385        pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
386        store: Option<Arc<dyn ContentStore>>,
387    ) {
388        let label = dc.label().to_string();
389        let peer_short = peer_id.short();
390
391        // Track pending binary data (request_id -> expected after response)
392        let _pending_binary: Arc<Mutex<Option<u32>>> = Arc::new(Mutex::new(None));
393
394        let _dc_for_open = dc.clone();
395        let peer_short_open = peer_short.clone();
396        let label_clone = label.clone();
397        dc.on_open(Box::new(move || {
398            let peer_short_open = peer_short_open.clone();
399            let label_clone = label_clone.clone();
400            // Work MUST be inside the returned future
401            Box::pin(async move {
402                info!("[Peer {}] Data channel '{}' open", peer_short_open, label_clone);
403            })
404        }));
405
406        let dc_for_msg = dc.clone();
407        let peer_short_msg = peer_short.clone();
408        let _pending_binary_clone = _pending_binary.clone();
409        let store_clone = store.clone();
410
411        dc.on_message(Box::new(move |msg: DataChannelMessage| {
412            let dc = dc_for_msg.clone();
413            let peer_short = peer_short_msg.clone();
414            let pending_requests = pending_requests.clone();
415            let _pending_binary = _pending_binary_clone.clone();
416            let _message_tx = message_tx.clone();
417            let store = store_clone.clone();
418            let msg_data = msg.data.clone();
419
420            // Work MUST be inside the returned future
421            Box::pin(async move {
422                // All messages are binary with type prefix + MessagePack body
423                debug!("[Peer {}] Received {} bytes on data channel", peer_short, msg_data.len());
424                match parse_message(&msg_data) {
425                    Ok(data_msg) => match data_msg {
426                        DataMessage::Request(req) => {
427                            let hash_hex = hash_to_hex(&req.h);
428                            let hash_short = &hash_hex[..8.min(hash_hex.len())];
429                            info!(
430                                "[Peer {}] Received request for {}",
431                                peer_short, hash_short
432                            );
433
434                            // Handle request - look up in store
435                            let data = if let Some(ref store) = store {
436                                match store.get(&hash_hex) {
437                                    Ok(Some(data)) => {
438                                        info!("[Peer {}] Found {} in store ({} bytes)", peer_short, hash_short, data.len());
439                                        Some(data)
440                                    },
441                                    Ok(None) => {
442                                        info!("[Peer {}] Hash {} not in store", peer_short, hash_short);
443                                        None
444                                    },
445                                    Err(e) => {
446                                        warn!("[Peer {}] Store error: {}", peer_short, e);
447                                        None
448                                    }
449                                }
450                            } else {
451                                warn!("[Peer {}] No store configured - cannot serve requests", peer_short);
452                                None
453                            };
454
455                            // Send response only if we have data
456                            if let Some(data) = data {
457                                let data_len = data.len();
458                                let response = DataResponse {
459                                    h: req.h,
460                                    d: data,
461                                };
462                                if let Ok(wire) = encode_response(&response) {
463                                    if let Err(e) = dc.send(&Bytes::from(wire)).await {
464                                        error!(
465                                            "[Peer {}] Failed to send response: {}",
466                                            peer_short, e
467                                        );
468                                    } else {
469                                        info!(
470                                            "[Peer {}] Sent response for {} ({} bytes)",
471                                            peer_short, hash_short, data_len
472                                        );
473                                    }
474                                }
475                            } else {
476                                info!("[Peer {}] Content not found for {}", peer_short, hash_short);
477                            }
478                        }
479                        DataMessage::Response(res) => {
480                            let hash_hex = hash_to_hex(&res.h);
481                            let hash_short = &hash_hex[..8.min(hash_hex.len())];
482                            debug!(
483                                "[Peer {}] Received response for {} ({} bytes)",
484                                peer_short, hash_short, res.d.len()
485                            );
486
487                            // Resolve the pending request by hash
488                            let mut pending = pending_requests.lock().await;
489                            if let Some(req) = pending.remove(&hash_hex) {
490                                let _ = req.response_tx.send(Some(res.d));
491                            }
492                        }
493                    },
494                    Err(e) => {
495                        warn!("[Peer {}] Failed to parse message: {:?}", peer_short, e);
496                        // Log hex dump of first 50 bytes for debugging
497                        let hex_dump: String = msg_data.iter().take(50).map(|b| format!("{:02x}", b)).collect();
498                        warn!("[Peer {}] Message hex: {}", peer_short, hex_dump);
499                    }
500                }
501            })
502        }));
503    }
504
505    /// Check if data channel is ready
506    pub fn has_data_channel(&self) -> bool {
507        // Use try_lock for non-async context
508        self.data_channel
509            .try_lock()
510            .map(|guard| guard.is_some())
511            .unwrap_or(false)
512    }
513
514    /// Request content by hash from this peer
515    pub async fn request(&self, hash_hex: &str) -> Result<Option<Vec<u8>>> {
516        let dc_guard = self.data_channel.lock().await;
517        let dc = dc_guard
518            .as_ref()
519            .ok_or_else(|| anyhow::anyhow!("No data channel"))?
520            .clone();
521        drop(dc_guard);  // Release lock before async operations
522
523        // Convert hex to binary hash
524        let hash = hex::decode(hash_hex)
525            .map_err(|e| anyhow::anyhow!("Invalid hex hash: {}", e))?;
526
527        // Create response channel
528        let (tx, rx) = oneshot::channel();
529
530        // Store pending request (keyed by hash hex)
531        {
532            let mut pending = self.pending_requests.lock().await;
533            pending.insert(
534                hash_hex.to_string(),
535                PendingRequest {
536                    hash: hash.clone(),
537                    response_tx: tx,
538                },
539            );
540        }
541
542        // Send request with MAX_HTL (fresh request from us)
543        let req = DataRequest {
544            h: hash,
545            htl: crate::webrtc::types::MAX_HTL,
546        };
547        let wire = encode_request(&req)?;
548        dc.send(&Bytes::from(wire)).await?;
549
550        debug!(
551            "[Peer {}] Sent request for {}",
552            self.peer_id.short(),
553            &hash_hex[..8.min(hash_hex.len())]
554        );
555
556        // Wait for response with timeout
557        match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
558            Ok(Ok(data)) => Ok(data),
559            Ok(Err(_)) => {
560                // Channel closed
561                Ok(None)
562            }
563            Err(_) => {
564                // Timeout - clean up pending request
565                let mut pending = self.pending_requests.lock().await;
566                pending.remove(hash_hex);
567                Ok(None)
568            }
569        }
570    }
571
572    /// Send a message over the data channel
573    pub async fn send_message(&self, msg: &DataMessage) -> Result<()> {
574        let dc_guard = self.data_channel.lock().await;
575        if let Some(ref dc) = *dc_guard {
576            let wire = encode_message(msg)?;
577            dc.send(&Bytes::from(wire)).await?;
578        }
579        Ok(())
580    }
581
582    /// Close the connection
583    pub async fn close(&self) -> Result<()> {
584        {
585            let dc_guard = self.data_channel.lock().await;
586            if let Some(ref dc) = *dc_guard {
587                dc.close().await?;
588            }
589        }
590        self.pc.close().await?;
591        Ok(())
592    }
593}