Skip to main content

hashtree_network/
real_factory.rs

1//! Real WebRTC peer-link factory
2//!
3//! Wraps the `webrtc` crate to implement the generic peer-link factory for
4//! production WebRTC use.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{mpsc, Mutex, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, TransportError};
13use crate::types::DATA_CHANNEL_LABEL;
14
15use webrtc::api::interceptor_registry::register_default_interceptors;
16use webrtc::api::media_engine::MediaEngine;
17use webrtc::api::APIBuilder;
18use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
19use webrtc::data_channel::data_channel_message::DataChannelMessage;
20use webrtc::data_channel::RTCDataChannel;
21use webrtc::ice_transport::ice_server::RTCIceServer;
22use webrtc::interceptor::registry::Registry;
23use webrtc::peer_connection::configuration::RTCConfiguration;
24use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
25use webrtc::peer_connection::RTCPeerConnection;
26
27/// Wrapper around `RTCDataChannel` that implements the generic peer-link trait.
28pub struct RealDataChannel {
29    dc: Arc<RTCDataChannel>,
30    /// Receiver for incoming messages (populated by on_message callback)
31    msg_rx: Mutex<mpsc::Receiver<Vec<u8>>>,
32}
33
34impl RealDataChannel {
35    /// Create a new RealDataChannel with message handling
36    pub fn new(dc: Arc<RTCDataChannel>) -> Arc<Self> {
37        let (msg_tx, msg_rx) = mpsc::channel(100);
38
39        // Set up on_message handler to forward messages to channel
40        let tx = msg_tx.clone();
41        dc.on_message(Box::new(move |msg: DataChannelMessage| {
42            let tx = tx.clone();
43            let data = msg.data.to_vec();
44            Box::pin(async move {
45                let _ = tx.send(data).await;
46            })
47        }));
48
49        Arc::new(Self {
50            dc,
51            msg_rx: Mutex::new(msg_rx),
52        })
53    }
54}
55
56#[async_trait]
57impl PeerLink for RealDataChannel {
58    async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
59        self.dc
60            .send(&bytes::Bytes::from(data))
61            .await
62            .map(|_| ())
63            .map_err(|e| TransportError::SendFailed(e.to_string()))
64    }
65
66    async fn recv(&self) -> Option<Vec<u8>> {
67        self.msg_rx.lock().await.recv().await
68    }
69
70    fn try_recv(&self) -> Option<Vec<u8>> {
71        let Ok(mut rx) = self.msg_rx.try_lock() else {
72            return None;
73        };
74        rx.try_recv().ok()
75    }
76
77    fn is_open(&self) -> bool {
78        self.dc.ready_state() == webrtc::data_channel::data_channel_state::RTCDataChannelState::Open
79    }
80
81    async fn close(&self) {
82        let _ = self.dc.close().await;
83    }
84}
85
86/// Pending connection state
87struct PendingConnection {
88    connection: Arc<RTCPeerConnection>,
89    data_channel: Option<Arc<RTCDataChannel>>,
90}
91
92/// WebRTC peer-link factory for the default production transport stack.
93pub struct WebRtcPeerLinkFactory {
94    /// Pending outbound connections (we sent offer, waiting for answer)
95    pending: RwLock<HashMap<String, PendingConnection>>,
96    /// Pending inbound connections (we received offer, sent answer)
97    inbound: RwLock<HashMap<String, PendingConnection>>,
98    /// STUN servers for ICE
99    stun_servers: Vec<String>,
100}
101
102impl WebRtcPeerLinkFactory {
103    pub fn new() -> Self {
104        Self::with_stun_servers(vec![
105            "stun:stun.iris.to:3478".to_string(),
106            "stun:stun.l.google.com:19302".to_string(),
107            "stun:stun.cloudflare.com:3478".to_string(),
108        ])
109    }
110
111    pub fn with_stun_servers(stun_servers: Vec<String>) -> Self {
112        Self {
113            pending: RwLock::new(HashMap::new()),
114            inbound: RwLock::new(HashMap::new()),
115            stun_servers,
116        }
117    }
118
119    async fn create_connection(&self) -> Result<Arc<RTCPeerConnection>, TransportError> {
120        let mut media_engine = MediaEngine::default();
121        media_engine
122            .register_default_codecs()
123            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
124
125        let mut registry = Registry::new();
126        registry = register_default_interceptors(registry, &mut media_engine)
127            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
128
129        let api = APIBuilder::new()
130            .with_media_engine(media_engine)
131            .with_interceptor_registry(registry)
132            .build();
133
134        let config = RTCConfiguration {
135            ice_servers: vec![RTCIceServer {
136                urls: self.stun_servers.clone(),
137                ..Default::default()
138            }],
139            ..Default::default()
140        };
141
142        api.new_peer_connection(config)
143            .await
144            .map(Arc::new)
145            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))
146    }
147
148    /// Wait for ICE gathering to complete and return the SDP with embedded candidates
149    async fn wait_for_ice_gathering(
150        connection: &Arc<RTCPeerConnection>,
151    ) -> Result<String, TransportError> {
152        let mut gathering_complete = connection.gathering_complete_promise().await;
153
154        // Wait for ICE gathering to complete (with timeout)
155        let _ = tokio::time::timeout(Duration::from_secs(10), gathering_complete.recv()).await;
156
157        // Get the local description with ICE candidates embedded
158        let local_desc = connection.local_description().await.ok_or_else(|| {
159            TransportError::ConnectionFailed("No local description after ICE gathering".to_string())
160        })?;
161
162        Ok(local_desc.sdp)
163    }
164}
165
166impl Default for WebRtcPeerLinkFactory {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172#[async_trait]
173impl PeerLinkFactory for WebRtcPeerLinkFactory {
174    async fn create_offer(
175        &self,
176        target_peer_id: &str,
177    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
178        let connection = self.create_connection().await?;
179
180        // Create data channel (unordered for better performance - protocol is stateless)
181        let dc_init = RTCDataChannelInit {
182            ordered: Some(false),
183            ..Default::default()
184        };
185        let dc = connection
186            .create_data_channel(DATA_CHANNEL_LABEL, Some(dc_init))
187            .await
188            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
189
190        // Create offer and set local description to start ICE gathering
191        let offer = connection
192            .create_offer(None)
193            .await
194            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
195
196        connection
197            .set_local_description(offer)
198            .await
199            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
200
201        // Wait for ICE gathering to complete - this embeds ICE candidates in the SDP
202        let sdp = Self::wait_for_ice_gathering(&connection).await?;
203
204        // Store pending connection (we'll need it when answer arrives)
205        self.pending.write().await.insert(
206            target_peer_id.to_string(),
207            PendingConnection {
208                connection,
209                data_channel: Some(dc.clone()),
210            },
211        );
212
213        // Create channel wrapper with message handling
214        let channel: Arc<dyn PeerLink> = RealDataChannel::new(dc);
215        Ok((channel, sdp))
216    }
217
218    async fn accept_offer(
219        &self,
220        from_peer_id: &str,
221        offer_sdp: &str,
222    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
223        let connection = self.create_connection().await?;
224
225        // Set up data channel callback BEFORE setting remote description
226        // This ensures we catch the data channel when it arrives
227        let (dc_tx, dc_rx) = tokio::sync::oneshot::channel::<Arc<RTCDataChannel>>();
228        let dc_tx = Arc::new(Mutex::new(Some(dc_tx)));
229
230        connection.on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
231            let dc_tx = dc_tx.clone();
232            Box::pin(async move {
233                if let Some(tx) = dc_tx.lock().await.take() {
234                    let _ = tx.send(dc);
235                }
236            })
237        }));
238
239        // Set remote description (the offer)
240        let offer = RTCSessionDescription::offer(offer_sdp.to_string())
241            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
242        connection
243            .set_remote_description(offer)
244            .await
245            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
246
247        // Create answer and set local description to start ICE gathering
248        let answer = connection
249            .create_answer(None)
250            .await
251            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
252        connection
253            .set_local_description(answer)
254            .await
255            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
256
257        // Wait for ICE gathering to complete - this embeds ICE candidates in the SDP
258        let sdp = Self::wait_for_ice_gathering(&connection).await?;
259
260        // Wait for data channel from remote peer (with timeout)
261        let dc = tokio::time::timeout(Duration::from_secs(30), dc_rx)
262            .await
263            .map_err(|_| {
264                TransportError::ConnectionFailed("Timeout waiting for data channel".to_string())
265            })?
266            .map_err(|_| {
267                TransportError::ConnectionFailed("Data channel sender dropped".to_string())
268            })?;
269
270        // Store connection for potential future use
271        self.inbound.write().await.insert(
272            from_peer_id.to_string(),
273            PendingConnection {
274                connection,
275                data_channel: Some(dc.clone()),
276            },
277        );
278
279        // Create channel wrapper with message handling
280        let channel: Arc<dyn PeerLink> = RealDataChannel::new(dc);
281        Ok((channel, sdp))
282    }
283
284    async fn handle_answer(
285        &self,
286        target_peer_id: &str,
287        answer_sdp: &str,
288    ) -> Result<Arc<dyn PeerLink>, TransportError> {
289        let pending = self
290            .pending
291            .write()
292            .await
293            .remove(target_peer_id)
294            .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
295
296        // Set remote description (the answer)
297        let answer = RTCSessionDescription::answer(answer_sdp.to_string())
298            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
299        pending
300            .connection
301            .set_remote_description(answer)
302            .await
303            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
304
305        // Return the data channel we created earlier with message handling
306        let dc = pending
307            .data_channel
308            .ok_or_else(|| TransportError::ConnectionFailed("No data channel".to_string()))?;
309
310        Ok(RealDataChannel::new(dc))
311    }
312}