// actr_runtime/wire/webrtc/connection.rs
use crate::transport::DataLane;
use crate::transport::{NetworkError, NetworkResult};
use actr_protocol::PayloadType;
use std::collections::HashMap;
use std::sync::{Arc, Weak};
use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
use tokio::sync::{RwLock, mpsc};
use webrtc::data_channel::RTCDataChannel;
use webrtc::peer_connection::RTCPeerConnection;
use webrtc::rtp_transceiver::rtp_sender::RTCRtpSender;
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;

15type MediaTracks = Arc<RwLock<HashMap<String, (Arc<TrackLocalStaticRTP>, Arc<RTCRtpSender>)>>>;
17
18#[derive(Clone)]
20pub struct WebRtcConnection {
21 peer_connection: Arc<RTCPeerConnection>,
23
24 data_channels: Arc<RwLock<[Option<Arc<RTCDataChannel>>; 4]>>,
28
29 media_tracks: MediaTracks,
31
32 track_sequence_numbers: Arc<RwLock<HashMap<String, Arc<AtomicU16>>>>,
34
35 track_ssrcs: Arc<RwLock<HashMap<String, u32>>>,
37
38 lane_cache: Arc<RwLock<[Option<DataLane>; 4]>>,
42
43 connected: Arc<RwLock<bool>>,
45}
46
47impl std::fmt::Debug for WebRtcConnection {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("WebRtcConnection")
50 .field("peer_connection", &"<RTCPeerConnection>")
51 .field("data_channels", &"<[Option<Arc<RTCDataChannel>>; 4]>")
52 .field("media_tracks", &"<HashMap<String, Arc<Track>>>")
53 .field("connected", &self.connected)
54 .finish()
55 }
56}
57
58impl WebRtcConnection {
59 pub fn new(peer_connection: Arc<RTCPeerConnection>) -> Self {
64 Self {
65 peer_connection,
66 data_channels: Arc::new(RwLock::new([None, None, None, None])),
67 media_tracks: Arc::new(RwLock::new(HashMap::new())),
68 track_sequence_numbers: Arc::new(RwLock::new(HashMap::new())),
69 track_ssrcs: Arc::new(RwLock::new(HashMap::new())),
70 lane_cache: Arc::new(RwLock::new([None, None, None, None])),
71 connected: Arc::new(RwLock::new(true)),
72 }
73 }
74
75 pub fn install_state_change_handler(&self) {
81 use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
82
83 let this = self.clone();
84
85 self.peer_connection
86 .on_peer_connection_state_change(Box::new(move |state: RTCPeerConnectionState| {
87 let this = this.clone();
88
89 Box::pin(async move {
90 let is_connected = matches!(
92 state,
93 RTCPeerConnectionState::New
94 | RTCPeerConnectionState::Connecting
95 | RTCPeerConnectionState::Connected
96 );
97
98 let was_connected = {
100 let mut flag = this.connected.write().await;
101 let prev = *flag;
102 *flag = is_connected;
103 prev
104 };
105
106 tracing::info!(
107 "🔄 WebRtcConnection peer state changed: {:?}, connected={}",
108 state,
109 is_connected
110 );
111
112 if was_connected
117 && matches!(
118 state,
119 RTCPeerConnectionState::Disconnected
120 | RTCPeerConnectionState::Failed
121 | RTCPeerConnectionState::Closed
122 )
123 {
124 tracing::info!(
125 "🔻 WebRtcConnection entering terminal state {:?}, calling close()",
126 state
127 );
128
129 if let Err(e) = this.close().await {
130 tracing::warn!("⚠️ WebRtcConnection::close() failed: {}", e);
131 }
132 }
133 })
134 }));
135 }
136
137 pub async fn connect(&self) -> NetworkResult<()> {
139 *self.connected.write().await = true;
140 Ok(())
141 }
142
143 #[inline]
145 pub fn is_connected(&self) -> bool {
146 *self.connected.blocking_read()
147 }
148
149 pub async fn close(&self) -> NetworkResult<()> {
151 *self.connected.write().await = false;
152 self.peer_connection.close().await?;
153
154 let mut channels = self.data_channels.write().await;
156 *channels = [None, None, None, None];
157
158 let mut tracks = self.media_tracks.write().await;
160 tracks.clear();
161
162 let mut seq_nums = self.track_sequence_numbers.write().await;
164 seq_nums.clear();
165
166 let mut ssrcs = self.track_ssrcs.write().await;
168 ssrcs.clear();
169
170 let mut cache = self.lane_cache.write().await;
172 *cache = [None, None, None, None];
173
174 tracing::info!("🔌 WebRtcConnection already Close");
175 Ok(())
176 }
177
178 fn get_data_channel_config(
180 payload_type: PayloadType,
181 ) -> webrtc::data_channel::data_channel_init::RTCDataChannelInit {
182 use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
183
184 let channel_id = payload_type as u16;
187
188 match payload_type {
190 PayloadType::RpcSignal | PayloadType::RpcReliable => {
191 RTCDataChannelInit {
193 ordered: Some(true),
194 max_retransmits: None,
195 max_packet_life_time: None,
196 protocol: Some("".to_string()),
197 negotiated: Some(channel_id),
198 }
199 }
200 PayloadType::StreamLatencyFirst => {
201 RTCDataChannelInit {
204 ordered: Some(false),
205 max_retransmits: Some(3),
206 max_packet_life_time: None,
207 protocol: Some("".to_string()),
208 negotiated: Some(channel_id),
209 }
210 }
211 _ => {
212 RTCDataChannelInit {
214 ordered: Some(true),
215 max_retransmits: None,
216 max_packet_life_time: None,
217 protocol: Some("".to_string()),
218 negotiated: Some(channel_id),
219 }
220 }
221 }
222 }
223}
224
225impl WebRtcConnection {
226 pub async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
228 if payload_type == PayloadType::MediaRtp {
230 return Err(NetworkError::NotImplemented(
231 "MediaTrack Lane requires stream_id, use get_media_lane() instead".to_string(),
232 ));
233 }
234
235 let idx = payload_type as usize;
236
237 {
239 let cache = self.lane_cache.read().await;
240 if let Some(lane) = &cache[idx] {
241 tracing::debug!("📦 ReuseCache DataLane: {:?}", payload_type);
242 return Ok(lane.clone());
243 }
244 }
245
246 let lane = self.create_lane_internal(payload_type).await?;
248
249 {
251 let mut cache = self.lane_cache.write().await;
252 cache[idx] = Some(lane.clone());
253 }
254
255 tracing::info!("✨ WebRtcConnection Createnew DataLane: {:?}", payload_type);
256
257 Ok(lane)
258 }
259
260 async fn create_lane_internal(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
262 if payload_type == PayloadType::MediaRtp {
264 return Err(NetworkError::NotImplemented(
265 "MediaTrack Lane not implemented in this method".to_string(),
266 ));
267 }
268
269 let mut channels = self.data_channels.write().await;
271
272 let label = format!("{payload_type:?}");
273 let dc_config = Self::get_data_channel_config(payload_type);
274
275 let data_channel = self
276 .peer_connection
277 .create_data_channel(&label, Some(dc_config))
278 .await?;
279
280 let (tx, rx) = mpsc::channel(100);
282
283 let tx_clone = tx.clone();
285 data_channel.on_message(Box::new(
286 move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
287 let data = msg.data;
289 let tx = tx_clone.clone();
290 Box::pin(async move {
291 if let Err(e) = tx.send(data).await {
292 tracing::warn!("❌ WebRTC DataChannel messageSend to Lane failure: {}", e);
293 }
294 })
295 },
296 ));
297
298 let idx = payload_type as usize;
300 channels[idx] = Some(Arc::clone(&data_channel));
301
302 Ok(DataLane::webrtc_data_channel(data_channel, rx))
304 }
305
306 pub async fn add_media_track(
319 &self,
320 track_id: String,
321 codec: &str,
322 media_type: &str,
323 ) -> NetworkResult<Arc<TrackLocalStaticRTP>> {
324 use webrtc::api::media_engine::MIME_TYPE_H264;
325 use webrtc::api::media_engine::MIME_TYPE_OPUS;
326 use webrtc::api::media_engine::MIME_TYPE_VP8;
327 use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
328
329 let mime_type = match (media_type, codec.to_uppercase().as_str()) {
331 ("video", "H264") => MIME_TYPE_H264,
332 ("video", "VP8") => MIME_TYPE_VP8,
333 ("audio", "OPUS") => MIME_TYPE_OPUS,
334 _ => {
335 return Err(NetworkError::WebRtcError(format!(
336 "Unsupported codec: {codec} for {media_type}"
337 )));
338 }
339 };
340
341 let track = Arc::new(TrackLocalStaticRTP::new(
343 RTCRtpCodecCapability {
344 mime_type: mime_type.to_string(),
345 ..Default::default()
346 },
347 track_id.clone(),
348 format!("actr-{media_type}"), ));
350
351 let rtp_sender =
353 self.peer_connection
354 .add_track(Arc::clone(&track)
355 as Arc<dyn webrtc::track::track_local::TrackLocal + Send + Sync>)
356 .await?;
357
358 let mut tracks = self.media_tracks.write().await;
360 tracks.insert(track_id.clone(), (Arc::clone(&track), rtp_sender));
361
362 let mut seq_nums = self.track_sequence_numbers.write().await;
364 seq_nums.insert(track_id.clone(), Arc::new(AtomicU16::new(0)));
365
366 let ssrc = rand::random::<u32>();
368 let mut ssrcs = self.track_ssrcs.write().await;
369 ssrcs.insert(track_id.clone(), ssrc);
370
371 tracing::info!(
372 "✨ Added media track: id={}, codec={}, type={}, ssrc=0x{:08x}",
373 track_id,
374 codec,
375 media_type,
376 ssrc
377 );
378
379 Ok(track)
380 }
381
382 pub async fn get_media_track(&self, track_id: &str) -> Option<Arc<TrackLocalStaticRTP>> {
384 let tracks = self.media_tracks.read().await;
385 tracks
386 .get(track_id)
387 .map(|(track, _sender)| Arc::clone(track))
388 }
389
390 pub async fn next_sequence_number(&self, track_id: &str) -> Option<u16> {
398 let seq_nums = self.track_sequence_numbers.read().await;
399 seq_nums
400 .get(track_id)
401 .map(|atomic_seq| atomic_seq.fetch_add(1, Ordering::SeqCst))
402 }
403
404 pub async fn get_ssrc(&self, track_id: &str) -> Option<u32> {
412 let ssrcs = self.track_ssrcs.read().await;
413 ssrcs.get(track_id).copied()
414 }
415
416 pub async fn create_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
423 self.get_lane(payload_type).await
424 }
425
426 pub async fn register_received_data_channel(
431 &self,
432 data_channel: Arc<RTCDataChannel>,
433 payload_type: PayloadType,
434 ) -> NetworkResult<DataLane> {
435 if payload_type == PayloadType::MediaRtp {
437 return Err(NetworkError::NotImplemented(
438 "MediaTrack Lane not supported in this method".to_string(),
439 ));
440 }
441
442 let idx = payload_type as usize;
443
444 let (tx, rx) = mpsc::channel(100);
446
447 let tx_clone = tx.clone();
449 data_channel.on_message(Box::new(
450 move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
451 let data = msg.data;
452 let tx = tx_clone.clone();
453 Box::pin(async move {
454 if let Err(e) = tx.send(data).await {
455 tracing::warn!("❌ WebRTC DataChannel message send to Lane failed: {}", e);
456 }
457 })
458 },
459 ));
460
461 {
463 let mut channels = self.data_channels.write().await;
464 channels[idx] = Some(Arc::clone(&data_channel));
465 }
466
467 let lane = DataLane::webrtc_data_channel(data_channel, rx);
469 {
470 let mut cache = self.lane_cache.write().await;
471 cache[idx] = Some(lane.clone());
472 }
473
474 tracing::info!(
475 "✨ WebRtcConnection registered received DataChannel: {:?}",
476 payload_type
477 );
478
479 Ok(lane)
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Reliable RPC channels are ordered. Latency-first stream channels are unordered and
    /// capped at three retransmits.
    #[test]
    fn test_data_channel_config() {
        let reliable = WebRtcConnection::get_data_channel_config(PayloadType::RpcReliable);
        assert_eq!(reliable.ordered, Some(true));

        let latency_first =
            WebRtcConnection::get_data_channel_config(PayloadType::StreamLatencyFirst);
        assert_eq!(latency_first.ordered, Some(false));
        assert_eq!(latency_first.max_retransmits, Some(3));
    }
}