hashtree_network/
real_factory.rs1use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{mpsc, Mutex, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, TransportError};
13use crate::types::DATA_CHANNEL_LABEL;
14
15use webrtc::api::interceptor_registry::register_default_interceptors;
16use webrtc::api::media_engine::MediaEngine;
17use webrtc::api::APIBuilder;
18use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
19use webrtc::data_channel::data_channel_message::DataChannelMessage;
20use webrtc::data_channel::RTCDataChannel;
21use webrtc::ice_transport::ice_server::RTCIceServer;
22use webrtc::interceptor::registry::Registry;
23use webrtc::peer_connection::configuration::RTCConfiguration;
24use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
25use webrtc::peer_connection::RTCPeerConnection;
26
27pub struct RealDataChannel {
29 dc: Arc<RTCDataChannel>,
30 msg_rx: Mutex<mpsc::Receiver<Vec<u8>>>,
32}
33
34impl RealDataChannel {
35 pub fn new(dc: Arc<RTCDataChannel>) -> Arc<Self> {
37 let (msg_tx, msg_rx) = mpsc::channel(100);
38
39 let tx = msg_tx.clone();
41 dc.on_message(Box::new(move |msg: DataChannelMessage| {
42 let tx = tx.clone();
43 let data = msg.data.to_vec();
44 Box::pin(async move {
45 let _ = tx.send(data).await;
46 })
47 }));
48
49 Arc::new(Self {
50 dc,
51 msg_rx: Mutex::new(msg_rx),
52 })
53 }
54}
55
56#[async_trait]
57impl PeerLink for RealDataChannel {
58 async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
59 self.dc
60 .send(&bytes::Bytes::from(data))
61 .await
62 .map(|_| ())
63 .map_err(|e| TransportError::SendFailed(e.to_string()))
64 }
65
66 async fn recv(&self) -> Option<Vec<u8>> {
67 self.msg_rx.lock().await.recv().await
68 }
69
70 fn try_recv(&self) -> Option<Vec<u8>> {
71 let Ok(mut rx) = self.msg_rx.try_lock() else {
72 return None;
73 };
74 rx.try_recv().ok()
75 }
76
77 fn is_open(&self) -> bool {
78 self.dc.ready_state() == webrtc::data_channel::data_channel_state::RTCDataChannelState::Open
79 }
80
81 async fn close(&self) {
82 let _ = self.dc.close().await;
83 }
84}
85
86struct PendingConnection {
88 connection: Arc<RTCPeerConnection>,
89 data_channel: Option<Arc<RTCDataChannel>>,
90}
91
92pub struct WebRtcPeerLinkFactory {
94 pending: RwLock<HashMap<String, PendingConnection>>,
96 inbound: RwLock<HashMap<String, PendingConnection>>,
98 stun_servers: Vec<String>,
100}
101
102impl WebRtcPeerLinkFactory {
103 pub fn new() -> Self {
104 Self::with_stun_servers(vec![
105 "stun:stun.iris.to:3478".to_string(),
106 "stun:stun.l.google.com:19302".to_string(),
107 "stun:stun.cloudflare.com:3478".to_string(),
108 ])
109 }
110
111 pub fn with_stun_servers(stun_servers: Vec<String>) -> Self {
112 Self {
113 pending: RwLock::new(HashMap::new()),
114 inbound: RwLock::new(HashMap::new()),
115 stun_servers,
116 }
117 }
118
119 async fn create_connection(&self) -> Result<Arc<RTCPeerConnection>, TransportError> {
120 let mut media_engine = MediaEngine::default();
121 media_engine
122 .register_default_codecs()
123 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
124
125 let mut registry = Registry::new();
126 registry = register_default_interceptors(registry, &mut media_engine)
127 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
128
129 let api = APIBuilder::new()
130 .with_media_engine(media_engine)
131 .with_interceptor_registry(registry)
132 .build();
133
134 let config = RTCConfiguration {
135 ice_servers: vec![RTCIceServer {
136 urls: self.stun_servers.clone(),
137 ..Default::default()
138 }],
139 ..Default::default()
140 };
141
142 api.new_peer_connection(config)
143 .await
144 .map(Arc::new)
145 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))
146 }
147
148 async fn wait_for_ice_gathering(
150 connection: &Arc<RTCPeerConnection>,
151 ) -> Result<String, TransportError> {
152 let mut gathering_complete = connection.gathering_complete_promise().await;
153
154 let _ = tokio::time::timeout(Duration::from_secs(10), gathering_complete.recv()).await;
156
157 let local_desc = connection.local_description().await.ok_or_else(|| {
159 TransportError::ConnectionFailed("No local description after ICE gathering".to_string())
160 })?;
161
162 Ok(local_desc.sdp)
163 }
164}
165
166impl Default for WebRtcPeerLinkFactory {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172#[async_trait]
173impl PeerLinkFactory for WebRtcPeerLinkFactory {
174 async fn create_offer(
175 &self,
176 target_peer_id: &str,
177 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
178 let connection = self.create_connection().await?;
179
180 let dc_init = RTCDataChannelInit {
182 ordered: Some(false),
183 ..Default::default()
184 };
185 let dc = connection
186 .create_data_channel(DATA_CHANNEL_LABEL, Some(dc_init))
187 .await
188 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
189
190 let offer = connection
192 .create_offer(None)
193 .await
194 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
195
196 connection
197 .set_local_description(offer)
198 .await
199 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
200
201 let sdp = Self::wait_for_ice_gathering(&connection).await?;
203
204 self.pending.write().await.insert(
206 target_peer_id.to_string(),
207 PendingConnection {
208 connection,
209 data_channel: Some(dc.clone()),
210 },
211 );
212
213 let channel: Arc<dyn PeerLink> = RealDataChannel::new(dc);
215 Ok((channel, sdp))
216 }
217
218 async fn accept_offer(
219 &self,
220 from_peer_id: &str,
221 offer_sdp: &str,
222 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
223 let connection = self.create_connection().await?;
224
225 let (dc_tx, dc_rx) = tokio::sync::oneshot::channel::<Arc<RTCDataChannel>>();
228 let dc_tx = Arc::new(Mutex::new(Some(dc_tx)));
229
230 connection.on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
231 let dc_tx = dc_tx.clone();
232 Box::pin(async move {
233 if let Some(tx) = dc_tx.lock().await.take() {
234 let _ = tx.send(dc);
235 }
236 })
237 }));
238
239 let offer = RTCSessionDescription::offer(offer_sdp.to_string())
241 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
242 connection
243 .set_remote_description(offer)
244 .await
245 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
246
247 let answer = connection
249 .create_answer(None)
250 .await
251 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
252 connection
253 .set_local_description(answer)
254 .await
255 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
256
257 let sdp = Self::wait_for_ice_gathering(&connection).await?;
259
260 let dc = tokio::time::timeout(Duration::from_secs(30), dc_rx)
262 .await
263 .map_err(|_| {
264 TransportError::ConnectionFailed("Timeout waiting for data channel".to_string())
265 })?
266 .map_err(|_| {
267 TransportError::ConnectionFailed("Data channel sender dropped".to_string())
268 })?;
269
270 self.inbound.write().await.insert(
272 from_peer_id.to_string(),
273 PendingConnection {
274 connection,
275 data_channel: Some(dc.clone()),
276 },
277 );
278
279 let channel: Arc<dyn PeerLink> = RealDataChannel::new(dc);
281 Ok((channel, sdp))
282 }
283
284 async fn handle_answer(
285 &self,
286 target_peer_id: &str,
287 answer_sdp: &str,
288 ) -> Result<Arc<dyn PeerLink>, TransportError> {
289 let pending = self
290 .pending
291 .write()
292 .await
293 .remove(target_peer_id)
294 .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
295
296 let answer = RTCSessionDescription::answer(answer_sdp.to_string())
298 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
299 pending
300 .connection
301 .set_remote_description(answer)
302 .await
303 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
304
305 let dc = pending
307 .data_channel
308 .ok_or_else(|| TransportError::ConnectionFailed("No data channel".to_string()))?;
309
310 Ok(RealDataChannel::new(dc))
311 }
312}