hashtree_network/
real_factory.rs1use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{mpsc, Mutex, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, TransportError};
13use crate::types::DATA_CHANNEL_LABEL;
14
15use webrtc::api::interceptor_registry::register_default_interceptors;
16use webrtc::api::media_engine::MediaEngine;
17use webrtc::api::APIBuilder;
18use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
19use webrtc::data_channel::data_channel_message::DataChannelMessage;
20use webrtc::data_channel::RTCDataChannel;
21use webrtc::ice_transport::ice_server::RTCIceServer;
22use webrtc::interceptor::registry::Registry;
23use webrtc::peer_connection::configuration::RTCConfiguration;
24use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
25use webrtc::peer_connection::RTCPeerConnection;
26
27pub struct RealDataChannel {
29 dc: Arc<RTCDataChannel>,
30 msg_rx: Mutex<mpsc::Receiver<Vec<u8>>>,
32}
33
34impl RealDataChannel {
35 pub fn new(dc: Arc<RTCDataChannel>) -> Arc<Self> {
37 let (msg_tx, msg_rx) = mpsc::channel(100);
38
39 let tx = msg_tx.clone();
41 dc.on_message(Box::new(move |msg: DataChannelMessage| {
42 let tx = tx.clone();
43 let data = msg.data.to_vec();
44 Box::pin(async move {
45 let _ = tx.send(data).await;
46 })
47 }));
48
49 Arc::new(Self {
50 dc,
51 msg_rx: Mutex::new(msg_rx),
52 })
53 }
54}
55
56#[async_trait]
57impl PeerLink for RealDataChannel {
58 async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
59 self.dc
60 .send(&bytes::Bytes::from(data))
61 .await
62 .map(|_| ())
63 .map_err(|e| TransportError::SendFailed(e.to_string()))
64 }
65
66 async fn recv(&self) -> Option<Vec<u8>> {
67 self.msg_rx.lock().await.recv().await
68 }
69
70 fn try_recv(&self) -> Option<Vec<u8>> {
71 let Ok(mut rx) = self.msg_rx.try_lock() else {
72 return None;
73 };
74 rx.try_recv().ok()
75 }
76
77 fn is_open(&self) -> bool {
78 self.dc.ready_state() == webrtc::data_channel::data_channel_state::RTCDataChannelState::Open
79 }
80
81 async fn close(&self) {
82 let _ = self.dc.close().await;
83 }
84}
85
86struct PendingConnection {
88 connection: Arc<RTCPeerConnection>,
89 data_channel: Option<Arc<RTCDataChannel>>,
90}
91
92pub struct RealPeerConnectionFactory {
96 pending: RwLock<HashMap<String, PendingConnection>>,
98 inbound: RwLock<HashMap<String, PendingConnection>>,
100 stun_servers: Vec<String>,
102}
103
104impl RealPeerConnectionFactory {
105 pub fn new() -> Self {
106 Self::with_stun_servers(vec![
107 "stun:stun.iris.to:3478".to_string(),
108 "stun:stun.l.google.com:19302".to_string(),
109 "stun:stun.cloudflare.com:3478".to_string(),
110 ])
111 }
112
113 pub fn with_stun_servers(stun_servers: Vec<String>) -> Self {
114 Self {
115 pending: RwLock::new(HashMap::new()),
116 inbound: RwLock::new(HashMap::new()),
117 stun_servers,
118 }
119 }
120
121 async fn create_connection(&self) -> Result<Arc<RTCPeerConnection>, TransportError> {
122 let mut media_engine = MediaEngine::default();
123 media_engine
124 .register_default_codecs()
125 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
126
127 let mut registry = Registry::new();
128 registry = register_default_interceptors(registry, &mut media_engine)
129 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
130
131 let api = APIBuilder::new()
132 .with_media_engine(media_engine)
133 .with_interceptor_registry(registry)
134 .build();
135
136 let config = RTCConfiguration {
137 ice_servers: vec![RTCIceServer {
138 urls: self.stun_servers.clone(),
139 ..Default::default()
140 }],
141 ..Default::default()
142 };
143
144 api.new_peer_connection(config)
145 .await
146 .map(Arc::new)
147 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))
148 }
149
150 async fn wait_for_ice_gathering(
152 connection: &Arc<RTCPeerConnection>,
153 ) -> Result<String, TransportError> {
154 let mut gathering_complete = connection.gathering_complete_promise().await;
155
156 let _ = tokio::time::timeout(Duration::from_secs(10), gathering_complete.recv()).await;
158
159 let local_desc = connection.local_description().await.ok_or_else(|| {
161 TransportError::ConnectionFailed("No local description after ICE gathering".to_string())
162 })?;
163
164 Ok(local_desc.sdp)
165 }
166}
167
168impl Default for RealPeerConnectionFactory {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174#[async_trait]
175impl PeerLinkFactory for RealPeerConnectionFactory {
176 async fn create_offer(
177 &self,
178 target_peer_id: &str,
179 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
180 let connection = self.create_connection().await?;
181
182 let dc_init = RTCDataChannelInit {
184 ordered: Some(false),
185 ..Default::default()
186 };
187 let dc = connection
188 .create_data_channel(DATA_CHANNEL_LABEL, Some(dc_init))
189 .await
190 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
191
192 let offer = connection
194 .create_offer(None)
195 .await
196 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
197
198 connection
199 .set_local_description(offer)
200 .await
201 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
202
203 let sdp = Self::wait_for_ice_gathering(&connection).await?;
205
206 self.pending.write().await.insert(
208 target_peer_id.to_string(),
209 PendingConnection {
210 connection,
211 data_channel: Some(dc.clone()),
212 },
213 );
214
215 let channel: Arc<dyn PeerLink> = RealDataChannel::new(dc);
217 Ok((channel, sdp))
218 }
219
220 async fn accept_offer(
221 &self,
222 from_peer_id: &str,
223 offer_sdp: &str,
224 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
225 let connection = self.create_connection().await?;
226
227 let (dc_tx, dc_rx) = tokio::sync::oneshot::channel::<Arc<RTCDataChannel>>();
230 let dc_tx = Arc::new(Mutex::new(Some(dc_tx)));
231
232 connection.on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
233 let dc_tx = dc_tx.clone();
234 Box::pin(async move {
235 if let Some(tx) = dc_tx.lock().await.take() {
236 let _ = tx.send(dc);
237 }
238 })
239 }));
240
241 let offer = RTCSessionDescription::offer(offer_sdp.to_string())
243 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
244 connection
245 .set_remote_description(offer)
246 .await
247 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
248
249 let answer = connection
251 .create_answer(None)
252 .await
253 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
254 connection
255 .set_local_description(answer)
256 .await
257 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
258
259 let sdp = Self::wait_for_ice_gathering(&connection).await?;
261
262 let dc = tokio::time::timeout(Duration::from_secs(30), dc_rx)
264 .await
265 .map_err(|_| {
266 TransportError::ConnectionFailed("Timeout waiting for data channel".to_string())
267 })?
268 .map_err(|_| {
269 TransportError::ConnectionFailed("Data channel sender dropped".to_string())
270 })?;
271
272 self.inbound.write().await.insert(
274 from_peer_id.to_string(),
275 PendingConnection {
276 connection,
277 data_channel: Some(dc.clone()),
278 },
279 );
280
281 let channel: Arc<dyn PeerLink> = RealDataChannel::new(dc);
283 Ok((channel, sdp))
284 }
285
286 async fn handle_answer(
287 &self,
288 target_peer_id: &str,
289 answer_sdp: &str,
290 ) -> Result<Arc<dyn PeerLink>, TransportError> {
291 let pending = self
292 .pending
293 .write()
294 .await
295 .remove(target_peer_id)
296 .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
297
298 let answer = RTCSessionDescription::answer(answer_sdp.to_string())
300 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
301 pending
302 .connection
303 .set_remote_description(answer)
304 .await
305 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
306
307 let dc = pending
309 .data_channel
310 .ok_or_else(|| TransportError::ConnectionFailed("No data channel".to_string()))?;
311
312 Ok(RealDataChannel::new(dc))
313 }
314}
315
316pub type WebRtcPeerLinkFactory = RealPeerConnectionFactory;