Skip to main content

hashtree_network/
real_factory.rs

//! Real WebRTC peer-link factory
//!
//! Wraps the `webrtc` crate to implement the generic peer-link factory for
//! production WebRTC use.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{mpsc, Mutex, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, TransportError};
13use crate::types::DATA_CHANNEL_LABEL;
14
15use webrtc::api::interceptor_registry::register_default_interceptors;
16use webrtc::api::media_engine::MediaEngine;
17use webrtc::api::APIBuilder;
18use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
19use webrtc::data_channel::data_channel_message::DataChannelMessage;
20use webrtc::data_channel::RTCDataChannel;
21use webrtc::ice_transport::ice_server::RTCIceServer;
22use webrtc::interceptor::registry::Registry;
23use webrtc::peer_connection::configuration::RTCConfiguration;
24use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
25use webrtc::peer_connection::RTCPeerConnection;
26
27/// Wrapper around `RTCDataChannel` that implements the generic peer-link trait.
28pub struct RealDataChannel {
29    dc: Arc<RTCDataChannel>,
30    /// Receiver for incoming messages (populated by on_message callback)
31    msg_rx: Mutex<mpsc::Receiver<Vec<u8>>>,
32}
33
34impl RealDataChannel {
35    /// Create a new RealDataChannel with message handling
36    pub fn new(dc: Arc<RTCDataChannel>) -> Arc<Self> {
37        let (msg_tx, msg_rx) = mpsc::channel(100);
38
39        // Set up on_message handler to forward messages to channel
40        let tx = msg_tx.clone();
41        dc.on_message(Box::new(move |msg: DataChannelMessage| {
42            let tx = tx.clone();
43            let data = msg.data.to_vec();
44            Box::pin(async move {
45                let _ = tx.send(data).await;
46            })
47        }));
48
49        Arc::new(Self {
50            dc,
51            msg_rx: Mutex::new(msg_rx),
52        })
53    }
54}
55
56#[async_trait]
57impl PeerLink for RealDataChannel {
58    async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
59        self.dc
60            .send(&bytes::Bytes::from(data))
61            .await
62            .map(|_| ())
63            .map_err(|e| TransportError::SendFailed(e.to_string()))
64    }
65
66    async fn recv(&self) -> Option<Vec<u8>> {
67        self.msg_rx.lock().await.recv().await
68    }
69
70    fn try_recv(&self) -> Option<Vec<u8>> {
71        let Ok(mut rx) = self.msg_rx.try_lock() else {
72            return None;
73        };
74        rx.try_recv().ok()
75    }
76
77    fn is_open(&self) -> bool {
78        self.dc.ready_state() == webrtc::data_channel::data_channel_state::RTCDataChannelState::Open
79    }
80
81    async fn close(&self) {
82        let _ = self.dc.close().await;
83    }
84}
85
86/// Pending connection state
87struct PendingConnection {
88    connection: Arc<RTCPeerConnection>,
89    data_channel: Option<Arc<RTCDataChannel>>,
90}
91
92/// Real WebRTC peer-link factory
93///
94/// Creates actual WebRTC connections using the webrtc crate.
95pub struct RealPeerConnectionFactory {
96    /// Pending outbound connections (we sent offer, waiting for answer)
97    pending: RwLock<HashMap<String, PendingConnection>>,
98    /// Pending inbound connections (we received offer, sent answer)
99    inbound: RwLock<HashMap<String, PendingConnection>>,
100    /// STUN servers for ICE
101    stun_servers: Vec<String>,
102}
103
104impl RealPeerConnectionFactory {
105    pub fn new() -> Self {
106        Self::with_stun_servers(vec![
107            "stun:stun.iris.to:3478".to_string(),
108            "stun:stun.l.google.com:19302".to_string(),
109            "stun:stun.cloudflare.com:3478".to_string(),
110        ])
111    }
112
113    pub fn with_stun_servers(stun_servers: Vec<String>) -> Self {
114        Self {
115            pending: RwLock::new(HashMap::new()),
116            inbound: RwLock::new(HashMap::new()),
117            stun_servers,
118        }
119    }
120
121    async fn create_connection(&self) -> Result<Arc<RTCPeerConnection>, TransportError> {
122        let mut media_engine = MediaEngine::default();
123        media_engine
124            .register_default_codecs()
125            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
126
127        let mut registry = Registry::new();
128        registry = register_default_interceptors(registry, &mut media_engine)
129            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
130
131        let api = APIBuilder::new()
132            .with_media_engine(media_engine)
133            .with_interceptor_registry(registry)
134            .build();
135
136        let config = RTCConfiguration {
137            ice_servers: vec![RTCIceServer {
138                urls: self.stun_servers.clone(),
139                ..Default::default()
140            }],
141            ..Default::default()
142        };
143
144        api.new_peer_connection(config)
145            .await
146            .map(Arc::new)
147            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))
148    }
149
150    /// Wait for ICE gathering to complete and return the SDP with embedded candidates
151    async fn wait_for_ice_gathering(
152        connection: &Arc<RTCPeerConnection>,
153    ) -> Result<String, TransportError> {
154        let mut gathering_complete = connection.gathering_complete_promise().await;
155
156        // Wait for ICE gathering to complete (with timeout)
157        let _ = tokio::time::timeout(Duration::from_secs(10), gathering_complete.recv()).await;
158
159        // Get the local description with ICE candidates embedded
160        let local_desc = connection.local_description().await.ok_or_else(|| {
161            TransportError::ConnectionFailed("No local description after ICE gathering".to_string())
162        })?;
163
164        Ok(local_desc.sdp)
165    }
166}
167
168impl Default for RealPeerConnectionFactory {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
174#[async_trait]
175impl PeerLinkFactory for RealPeerConnectionFactory {
176    async fn create_offer(
177        &self,
178        target_peer_id: &str,
179    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
180        let connection = self.create_connection().await?;
181
182        // Create data channel (unordered for better performance - protocol is stateless)
183        let dc_init = RTCDataChannelInit {
184            ordered: Some(false),
185            ..Default::default()
186        };
187        let dc = connection
188            .create_data_channel(DATA_CHANNEL_LABEL, Some(dc_init))
189            .await
190            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
191
192        // Create offer and set local description to start ICE gathering
193        let offer = connection
194            .create_offer(None)
195            .await
196            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
197
198        connection
199            .set_local_description(offer)
200            .await
201            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
202
203        // Wait for ICE gathering to complete - this embeds ICE candidates in the SDP
204        let sdp = Self::wait_for_ice_gathering(&connection).await?;
205
206        // Store pending connection (we'll need it when answer arrives)
207        self.pending.write().await.insert(
208            target_peer_id.to_string(),
209            PendingConnection {
210                connection,
211                data_channel: Some(dc.clone()),
212            },
213        );
214
215        // Create channel wrapper with message handling
216        let channel: Arc<dyn PeerLink> = RealDataChannel::new(dc);
217        Ok((channel, sdp))
218    }
219
220    async fn accept_offer(
221        &self,
222        from_peer_id: &str,
223        offer_sdp: &str,
224    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
225        let connection = self.create_connection().await?;
226
227        // Set up data channel callback BEFORE setting remote description
228        // This ensures we catch the data channel when it arrives
229        let (dc_tx, dc_rx) = tokio::sync::oneshot::channel::<Arc<RTCDataChannel>>();
230        let dc_tx = Arc::new(Mutex::new(Some(dc_tx)));
231
232        connection.on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
233            let dc_tx = dc_tx.clone();
234            Box::pin(async move {
235                if let Some(tx) = dc_tx.lock().await.take() {
236                    let _ = tx.send(dc);
237                }
238            })
239        }));
240
241        // Set remote description (the offer)
242        let offer = RTCSessionDescription::offer(offer_sdp.to_string())
243            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
244        connection
245            .set_remote_description(offer)
246            .await
247            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
248
249        // Create answer and set local description to start ICE gathering
250        let answer = connection
251            .create_answer(None)
252            .await
253            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
254        connection
255            .set_local_description(answer)
256            .await
257            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
258
259        // Wait for ICE gathering to complete - this embeds ICE candidates in the SDP
260        let sdp = Self::wait_for_ice_gathering(&connection).await?;
261
262        // Wait for data channel from remote peer (with timeout)
263        let dc = tokio::time::timeout(Duration::from_secs(30), dc_rx)
264            .await
265            .map_err(|_| {
266                TransportError::ConnectionFailed("Timeout waiting for data channel".to_string())
267            })?
268            .map_err(|_| {
269                TransportError::ConnectionFailed("Data channel sender dropped".to_string())
270            })?;
271
272        // Store connection for potential future use
273        self.inbound.write().await.insert(
274            from_peer_id.to_string(),
275            PendingConnection {
276                connection,
277                data_channel: Some(dc.clone()),
278            },
279        );
280
281        // Create channel wrapper with message handling
282        let channel: Arc<dyn PeerLink> = RealDataChannel::new(dc);
283        Ok((channel, sdp))
284    }
285
286    async fn handle_answer(
287        &self,
288        target_peer_id: &str,
289        answer_sdp: &str,
290    ) -> Result<Arc<dyn PeerLink>, TransportError> {
291        let pending = self
292            .pending
293            .write()
294            .await
295            .remove(target_peer_id)
296            .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
297
298        // Set remote description (the answer)
299        let answer = RTCSessionDescription::answer(answer_sdp.to_string())
300            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
301        pending
302            .connection
303            .set_remote_description(answer)
304            .await
305            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
306
307        // Return the data channel we created earlier with message handling
308        let dc = pending
309            .data_channel
310            .ok_or_else(|| TransportError::ConnectionFailed("No data channel".to_string()))?;
311
312        Ok(RealDataChannel::new(dc))
313    }
314}
315
316pub type WebRtcPeerLinkFactory = RealPeerConnectionFactory;