actr_runtime/wire/webrtc/
connection.rs1use crate::transport::DataLane;
4use crate::transport::{NetworkError, NetworkResult};
5use actr_protocol::PayloadType;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU16, Ordering};
9use tokio::sync::{RwLock, mpsc};
10use webrtc::data_channel::RTCDataChannel;
11use webrtc::peer_connection::RTCPeerConnection;
12use webrtc::rtp_transceiver::rtp_sender::RTCRtpSender;
13use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
14
15type MediaTracks = Arc<RwLock<HashMap<String, (Arc<TrackLocalStaticRTP>, Arc<RTCRtpSender>)>>>;
17
18#[derive(Clone)]
20pub struct WebRtcConnection {
21 peer_connection: Arc<RTCPeerConnection>,
23
24 data_channels: Arc<RwLock<[Option<Arc<RTCDataChannel>>; 4]>>,
27
28 media_tracks: MediaTracks,
30
31 track_sequence_numbers: Arc<RwLock<HashMap<String, Arc<AtomicU16>>>>,
33
34 track_ssrcs: Arc<RwLock<HashMap<String, u32>>>,
36
37 lane_cache: Arc<RwLock<[Option<DataLane>; 4]>>,
41
42 connected: Arc<RwLock<bool>>,
44}
45
46impl std::fmt::Debug for WebRtcConnection {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("WebRtcConnection")
49 .field("peer_connection", &"<RTCPeerConnection>")
50 .field("data_channels", &"<[Option<Arc<RTCDataChannel>>; 4]>")
51 .field("media_tracks", &"<HashMap<String, Arc<Track>>>")
52 .field("connected", &self.connected)
53 .finish()
54 }
55}
56
57impl WebRtcConnection {
58 pub fn new(peer_connection: Arc<RTCPeerConnection>) -> Self {
63 Self {
64 peer_connection,
65 data_channels: Arc::new(RwLock::new([None, None, None, None])),
66 media_tracks: Arc::new(RwLock::new(HashMap::new())),
67 track_sequence_numbers: Arc::new(RwLock::new(HashMap::new())),
68 track_ssrcs: Arc::new(RwLock::new(HashMap::new())),
69 lane_cache: Arc::new(RwLock::new([None, None, None, None])),
70 connected: Arc::new(RwLock::new(true)),
71 }
72 }
73
74 pub async fn connect(&self) -> NetworkResult<()> {
76 *self.connected.write().await = true;
77 Ok(())
78 }
79
80 #[inline]
82 pub fn is_connected(&self) -> bool {
83 *self.connected.blocking_read()
84 }
85
86 pub async fn close(&self) -> NetworkResult<()> {
88 *self.connected.write().await = false;
89 self.peer_connection.close().await?;
90
91 let mut channels = self.data_channels.write().await;
93 *channels = [None, None, None, None];
94
95 let mut tracks = self.media_tracks.write().await;
97 tracks.clear();
98
99 let mut seq_nums = self.track_sequence_numbers.write().await;
101 seq_nums.clear();
102
103 let mut ssrcs = self.track_ssrcs.write().await;
105 ssrcs.clear();
106
107 let mut cache = self.lane_cache.write().await;
109 *cache = [None, None, None, None];
110
111 tracing::info!("🔌 WebRtcConnection already Close");
112 Ok(())
113 }
114
115 fn get_data_channel_config(
117 payload_type: PayloadType,
118 ) -> webrtc::data_channel::data_channel_init::RTCDataChannelInit {
119 use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
120
121 let channel_id = payload_type as u16;
124
125 match payload_type {
126 PayloadType::RpcSignal | PayloadType::RpcReliable => {
127 RTCDataChannelInit {
129 ordered: Some(true),
130 max_retransmits: None,
131 max_packet_life_time: None,
132 protocol: Some("".to_string()),
133 negotiated: Some(channel_id),
134 }
135 }
136 PayloadType::StreamLatencyFirst => {
137 RTCDataChannelInit {
139 ordered: Some(false),
140 max_retransmits: Some(3),
141 max_packet_life_time: Some(100),
142 protocol: Some("".to_string()),
143 negotiated: Some(channel_id),
144 }
145 }
146 _ => {
147 RTCDataChannelInit {
149 ordered: Some(true),
150 max_retransmits: None,
151 max_packet_life_time: None,
152 protocol: Some("".to_string()),
153 negotiated: Some(channel_id),
154 }
155 }
156 }
157 }
158}
159
160impl WebRtcConnection {
161 pub async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
163 if payload_type == PayloadType::MediaRtp {
165 return Err(NetworkError::NotImplemented(
166 "MediaTrack Lane requires stream_id, use get_media_lane() instead".to_string(),
167 ));
168 }
169
170 let idx = payload_type as usize;
171
172 {
174 let cache = self.lane_cache.read().await;
175 if let Some(lane) = &cache[idx] {
176 tracing::debug!("📦 ReuseCache DataLane: {:?}", payload_type);
177 return Ok(lane.clone());
178 }
179 }
180
181 let lane = self.create_lane_internal(payload_type).await?;
183
184 {
186 let mut cache = self.lane_cache.write().await;
187 cache[idx] = Some(lane.clone());
188 }
189
190 tracing::info!("✨ WebRtcConnection Createnew DataLane: {:?}", payload_type);
191
192 Ok(lane)
193 }
194
195 async fn create_lane_internal(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
197 if payload_type == PayloadType::MediaRtp {
199 return Err(NetworkError::NotImplemented(
200 "MediaTrack Lane not implemented in this method".to_string(),
201 ));
202 }
203
204 let mut channels = self.data_channels.write().await;
206
207 let label = format!("{payload_type:?}");
208 let dc_config = Self::get_data_channel_config(payload_type);
209
210 let data_channel = self
211 .peer_connection
212 .create_data_channel(&label, Some(dc_config))
213 .await?;
214
215 let (tx, rx) = mpsc::channel(100);
217
218 let tx_clone = tx.clone();
220 data_channel.on_message(Box::new(
221 move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
222 let data = msg.data;
224 let tx = tx_clone.clone();
225 Box::pin(async move {
226 if let Err(e) = tx.send(data).await {
227 tracing::warn!("❌ WebRTC DataChannel messageSend to Lane failure: {}", e);
228 }
229 })
230 },
231 ));
232
233 let idx = payload_type as usize;
235 channels[idx] = Some(Arc::clone(&data_channel));
236
237 Ok(DataLane::webrtc_data_channel(data_channel, rx))
239 }
240
241 pub async fn add_media_track(
254 &self,
255 track_id: String,
256 codec: &str,
257 media_type: &str,
258 ) -> NetworkResult<Arc<TrackLocalStaticRTP>> {
259 use webrtc::api::media_engine::MIME_TYPE_H264;
260 use webrtc::api::media_engine::MIME_TYPE_OPUS;
261 use webrtc::api::media_engine::MIME_TYPE_VP8;
262 use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
263
264 let mime_type = match (media_type, codec.to_uppercase().as_str()) {
266 ("video", "H264") => MIME_TYPE_H264,
267 ("video", "VP8") => MIME_TYPE_VP8,
268 ("audio", "OPUS") => MIME_TYPE_OPUS,
269 _ => {
270 return Err(NetworkError::WebRtcError(format!(
271 "Unsupported codec: {codec} for {media_type}"
272 )));
273 }
274 };
275
276 let track = Arc::new(TrackLocalStaticRTP::new(
278 RTCRtpCodecCapability {
279 mime_type: mime_type.to_string(),
280 ..Default::default()
281 },
282 track_id.clone(),
283 format!("actr-{media_type}"), ));
285
286 let rtp_sender =
288 self.peer_connection
289 .add_track(Arc::clone(&track)
290 as Arc<dyn webrtc::track::track_local::TrackLocal + Send + Sync>)
291 .await?;
292
293 let mut tracks = self.media_tracks.write().await;
295 tracks.insert(track_id.clone(), (Arc::clone(&track), rtp_sender));
296
297 let mut seq_nums = self.track_sequence_numbers.write().await;
299 seq_nums.insert(track_id.clone(), Arc::new(AtomicU16::new(0)));
300
301 let ssrc = rand::random::<u32>();
303 let mut ssrcs = self.track_ssrcs.write().await;
304 ssrcs.insert(track_id.clone(), ssrc);
305
306 tracing::info!(
307 "✨ Added media track: id={}, codec={}, type={}, ssrc=0x{:08x}",
308 track_id,
309 codec,
310 media_type,
311 ssrc
312 );
313
314 Ok(track)
315 }
316
317 pub async fn get_media_track(&self, track_id: &str) -> Option<Arc<TrackLocalStaticRTP>> {
319 let tracks = self.media_tracks.read().await;
320 tracks
321 .get(track_id)
322 .map(|(track, _sender)| Arc::clone(track))
323 }
324
325 pub async fn next_sequence_number(&self, track_id: &str) -> Option<u16> {
333 let seq_nums = self.track_sequence_numbers.read().await;
334 seq_nums
335 .get(track_id)
336 .map(|atomic_seq| atomic_seq.fetch_add(1, Ordering::SeqCst))
337 }
338
339 pub async fn get_ssrc(&self, track_id: &str) -> Option<u32> {
347 let ssrcs = self.track_ssrcs.read().await;
348 ssrcs.get(track_id).copied()
349 }
350
351 pub async fn create_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
358 self.get_lane(payload_type).await
359 }
360
361 pub async fn register_received_data_channel(
366 &self,
367 data_channel: Arc<RTCDataChannel>,
368 payload_type: PayloadType,
369 ) -> NetworkResult<DataLane> {
370 if payload_type == PayloadType::MediaRtp {
372 return Err(NetworkError::NotImplemented(
373 "MediaTrack Lane not supported in this method".to_string(),
374 ));
375 }
376
377 let idx = payload_type as usize;
378
379 let (tx, rx) = mpsc::channel(100);
381
382 let tx_clone = tx.clone();
384 data_channel.on_message(Box::new(
385 move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
386 let data = msg.data;
387 let tx = tx_clone.clone();
388 Box::pin(async move {
389 if let Err(e) = tx.send(data).await {
390 tracing::warn!("❌ WebRTC DataChannel message send to Lane failed: {}", e);
391 }
392 })
393 },
394 ));
395
396 {
398 let mut channels = self.data_channels.write().await;
399 channels[idx] = Some(Arc::clone(&data_channel));
400 }
401
402 let lane = DataLane::webrtc_data_channel(data_channel, rx);
404 {
405 let mut cache = self.lane_cache.write().await;
406 cache[idx] = Some(lane.clone());
407 }
408
409 tracing::info!(
410 "✨ WebRtcConnection registered received DataChannel: {:?}",
411 payload_type
412 );
413
414 Ok(lane)
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the ordering and retransmission settings chosen for a reliable RPC channel and
    /// for the latency-first stream channel.
    #[test]
    fn test_data_channel_config() {
        let reliable = WebRtcConnection::get_data_channel_config(PayloadType::RpcReliable);
        assert_eq!(reliable.ordered, Some(true));

        let latency_first =
            WebRtcConnection::get_data_channel_config(PayloadType::StreamLatencyFirst);
        assert_eq!(latency_first.ordered, Some(false));
        assert_eq!(latency_first.max_retransmits, Some(3));
    }
}