1use crate::transport::DataLane;
4use crate::transport::connection_event::{ConnectionEvent, ConnectionState};
5use crate::transport::{NetworkError, NetworkResult};
6use actr_protocol::prost::Message;
7use actr_protocol::{ActrId, PayloadType};
8use bytes::Bytes;
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU16, Ordering};
12use tokio::sync::{RwLock, broadcast, mpsc};
13use webrtc::data_channel::RTCDataChannel;
14use webrtc::peer_connection::{RTCPeerConnection, peer_connection_state::RTCPeerConnectionState};
15use webrtc::rtp_transceiver::rtp_sender::RTCRtpSender;
16use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
17
18type MediaTracks = Arc<RwLock<HashMap<String, (Arc<TrackLocalStaticRTP>, Arc<RTCRtpSender>)>>>;
20
21#[derive(Clone)]
23pub struct WebRtcConnection {
24 peer_id: ActrId,
26
27 peer_connection: Arc<RTCPeerConnection>,
29
30 data_channels: Arc<RwLock<[Option<Arc<RTCDataChannel>>; 4]>>,
34
35 media_tracks: MediaTracks,
37
38 track_sequence_numbers: Arc<RwLock<HashMap<String, Arc<AtomicU16>>>>,
40
41 track_ssrcs: Arc<RwLock<HashMap<String, u32>>>,
43
44 lane_cache: Arc<RwLock<[Option<DataLane>; 4]>>,
48
49 event_tx: broadcast::Sender<ConnectionEvent>,
51
52 connected: Arc<RwLock<bool>>,
54}
55
56impl std::fmt::Debug for WebRtcConnection {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("WebRtcConnection")
59 .field("peer_id", &self.peer_id)
60 .field("peer_connection", &"<RTCPeerConnection>")
61 .field("data_channels", &"<[Option<Arc<RTCDataChannel>>; 4]>")
62 .field("media_tracks", &"<HashMap<String, Arc<Track>>>")
63 .field("connected", &self.connected)
64 .finish()
65 }
66}
67
68impl WebRtcConnection {
69 pub fn new(
76 peer_id: ActrId,
77 peer_connection: Arc<RTCPeerConnection>,
78 event_tx: broadcast::Sender<ConnectionEvent>,
79 ) -> Self {
80 Self {
81 peer_id,
82 peer_connection,
83 data_channels: Arc::new(RwLock::new([None, None, None, None])),
84 media_tracks: Arc::new(RwLock::new(HashMap::new())),
85 track_sequence_numbers: Arc::new(RwLock::new(HashMap::new())),
86 track_ssrcs: Arc::new(RwLock::new(HashMap::new())),
87 lane_cache: Arc::new(RwLock::new([None, None, None, None])),
88 event_tx,
89 connected: Arc::new(RwLock::new(true)),
90 }
91 }
92
93 pub fn peer_id(&self) -> &ActrId {
95 &self.peer_id
96 }
97
98 pub(crate) async fn handle_state_change(&self, state: RTCPeerConnectionState) {
103 let is_connected = matches!(
105 state,
106 RTCPeerConnectionState::New
107 | RTCPeerConnectionState::Connecting
108 | RTCPeerConnectionState::Connected
109 );
110
111 let was_connected = {
113 let mut flag = self.connected.write().await;
114 let prev = *flag;
115 *flag = is_connected;
116 prev
117 };
118
119 let connection_state = match state {
121 RTCPeerConnectionState::New => ConnectionState::New,
122 RTCPeerConnectionState::Connecting => ConnectionState::Connecting,
123 RTCPeerConnectionState::Connected => ConnectionState::Connected,
124 RTCPeerConnectionState::Disconnected => ConnectionState::Disconnected,
125 RTCPeerConnectionState::Failed => ConnectionState::Failed,
126 RTCPeerConnectionState::Closed => ConnectionState::Closed,
127 _ => ConnectionState::Closed, };
129
130 tracing::info!(
131 "🔄 WebRtcConnection peer state changed: {:?}, connected={}",
132 state,
133 is_connected
134 );
135
136 let _ = self.event_tx.send(ConnectionEvent::StateChanged {
138 peer_id: self.peer_id.clone(),
139 state: connection_state.clone(),
140 });
141
142 if was_connected && matches!(state, RTCPeerConnectionState::Closed) {
146 tracing::info!(
147 "🔻 WebRtcConnection entering terminal state {:?}, calling close()",
148 state
149 );
150
151 if let Err(e) = self.close().await {
152 tracing::warn!("⚠️ WebRtcConnection::close() failed: {}", e);
153 }
154 }
155 }
156
157 pub fn install_state_change_handler(&self) {
163 let this = self.clone();
164
165 self.peer_connection
166 .on_peer_connection_state_change(Box::new(move |state: RTCPeerConnectionState| {
167 let this = this.clone();
168
169 Box::pin(async move {
170 this.handle_state_change(state).await;
171 })
172 }));
173 }
174
175 pub async fn connect(&self) -> NetworkResult<()> {
177 *self.connected.write().await = true;
178 Ok(())
179 }
180
181 fn notify_data_channel_closed(&self, payload_type: PayloadType) {
186 let _ = self.event_tx.send(ConnectionEvent::DataChannelClosed {
190 peer_id: self.peer_id.clone(),
191 payload_type,
192 });
193 }
194
195 pub fn subscribe_events(&self) -> broadcast::Receiver<ConnectionEvent> {
197 self.event_tx.subscribe()
198 }
199
200 #[inline]
202 pub fn is_connected(&self) -> bool {
203 *self.connected.blocking_read()
204 }
205
206 pub async fn close(&self) -> NetworkResult<()> {
208 *self.connected.write().await = false;
209 self.peer_connection.close().await?;
210
211 let mut channels = self.data_channels.write().await;
213 *channels = [None, None, None, None];
214
215 let mut tracks = self.media_tracks.write().await;
217 tracks.clear();
218
219 let mut seq_nums = self.track_sequence_numbers.write().await;
221 seq_nums.clear();
222
223 let mut ssrcs = self.track_ssrcs.write().await;
225 ssrcs.clear();
226
227 let mut cache = self.lane_cache.write().await;
229 *cache = [None, None, None, None];
230
231 let _ = self.event_tx.send(ConnectionEvent::ConnectionClosed {
233 peer_id: self.peer_id.clone(),
234 });
235
236 tracing::info!("🔌 WebRtcConnection closed for peer {:?}", self.peer_id);
237 Ok(())
238 }
239
240 fn get_data_channel_config(
242 payload_type: &PayloadType,
243 ) -> webrtc::data_channel::data_channel_init::RTCDataChannelInit {
244 use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
245
246 match payload_type {
247 PayloadType::StreamLatencyFirst => {
248 RTCDataChannelInit {
250 ordered: Some(false),
251 max_retransmits: Some(3),
252 max_packet_life_time: None,
253 protocol: Some("".to_string()),
254 negotiated: None,
255 }
256 }
257 _ => {
258 RTCDataChannelInit {
260 ordered: Some(true),
261 max_retransmits: None,
262 max_packet_life_time: None,
263 protocol: Some("".to_string()),
264 negotiated: None,
265 }
266 }
267 }
268 }
269}
270
271impl WebRtcConnection {
272 pub async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
274 if payload_type == PayloadType::MediaRtp {
276 return Err(NetworkError::NotImplemented(
277 "MediaTrack Lane requires stream_id, use get_media_lane() instead".to_string(),
278 ));
279 }
280
281 let idx = payload_type as usize;
282
283 let mut need_recreate = false;
285 {
286 let cache = self.lane_cache.read().await;
287 if let Some(lane) = &cache[idx] {
288 if let DataLane::WebRtcDataChannel { data_channel, .. } = lane {
290 use webrtc::data_channel::data_channel_state::RTCDataChannelState;
291 let state = data_channel.ready_state();
292 if matches!(
293 state,
294 RTCDataChannelState::Closed | RTCDataChannelState::Closing
295 ) {
296 tracing::warn!(
297 "♻️ Cached DataChannel for {:?} is {:?}, recreating lane",
298 payload_type,
299 state
300 );
301 need_recreate = true;
302 } else {
303 tracing::debug!("📦 ReuseCache DataLane: {:?}", payload_type);
304 return Ok(lane.clone());
305 }
306 } else {
307 tracing::debug!("📦 ReuseCache DataLane: {:?}", payload_type);
308 return Ok(lane.clone());
309 }
310 }
311 }
312
313 if need_recreate {
314 let mut cache = self.lane_cache.write().await;
316 cache[idx] = None;
317 let mut channels = self.data_channels.write().await;
318 channels[idx] = None;
319 }
320
321 let lane = self.create_lane_internal(payload_type).await?;
323
324 {
326 let mut cache = self.lane_cache.write().await;
327 cache[idx] = Some(lane.clone());
328 }
329
330 tracing::info!("✨ WebRtcConnection Createnew DataLane: {:?}", payload_type);
331
332 Ok(lane)
333 }
334
335 pub async fn invalidate_lane(&self, payload_type: PayloadType) {
340 let idx = payload_type as usize;
341 let mut cache = self.lane_cache.write().await;
342 cache[idx] = None;
343 let mut channels = self.data_channels.write().await;
344 channels[idx] = None;
345 }
346
347 async fn create_lane_internal(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
349 if payload_type == PayloadType::MediaRtp {
351 return Err(NetworkError::NotImplemented(
352 "MediaTrack Lane not implemented in this method".to_string(),
353 ));
354 }
355
356 let mut channels = self.data_channels.write().await;
358
359 let label = payload_type.as_str_name();
360
361 let dc_config = Self::get_data_channel_config(&payload_type);
362 let data_channel = self
363 .peer_connection
364 .create_data_channel(&label, Some(dc_config))
365 .await?;
366
367 data_channel.on_open(Box::new(move || {
368 tracing::info!("🔄 WebRTC DataChannel opened: {:?}", payload_type);
369 Box::pin(async move {})
370 }));
371
372 let channel_id = data_channel.id();
373 let payload_type_for_error = payload_type;
374 let label_for_error = label;
375 data_channel.on_error(Box::new(move |error| {
376 let payload_type = payload_type_for_error;
377 let label = label_for_error;
378 let channel_id = channel_id;
379 tracing::warn!(
380 "⚠️ WebRTC DataChannel error [{}] (payload_type={:?}, channel_id={}): {:?}",
381 label,
382 payload_type,
383 channel_id,
384 error
385 );
386 Box::pin(async move {})
387 }));
388
389 let this_for_close = self.clone();
390 let payload_type_for_close = payload_type;
391 let label_for_close = label;
392 let channel_id_for_close = channel_id;
393 data_channel.on_close(Box::new(move || {
394 let this = this_for_close.clone();
395 let payload_type = payload_type_for_close;
396 let label = label_for_close;
397 let channel_id = channel_id_for_close;
398 Box::pin(async move {
399 tracing::warn!(
400 "⚠️ WebRTC DataChannel closed [{}] (payload_type={:?}, channel_id={})",
401 label,
402 payload_type,
403 channel_id
404 );
405 this.invalidate_lane(payload_type).await;
407 this.notify_data_channel_closed(payload_type);
409 })
410 }));
411
412 let (tx, rx) = mpsc::channel(100);
414
415 let tx_clone = tx.clone();
417 data_channel.on_message(Box::new(
418 move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
419 let data = msg.data;
421 tracing::debug!("🔄 WebRTC DataChannel message received1111: {:?}", data);
422 let tx = tx_clone.clone();
423 Box::pin(async move {
424 if let Err(e) = tx.send(data).await {
425 tracing::warn!("❌ WebRTC DataChannel messageSend to Lane failure: {}", e);
426 }
427 })
428 },
429 ));
430
431 let idx = payload_type as usize;
433 channels[idx] = Some(Arc::clone(&data_channel));
434
435 Ok(DataLane::webrtc_data_channel(data_channel, rx))
437 }
438
439 pub async fn add_media_track(
452 &self,
453 track_id: String,
454 codec: &str,
455 media_type: &str,
456 ) -> NetworkResult<Arc<TrackLocalStaticRTP>> {
457 use webrtc::api::media_engine::MIME_TYPE_H264;
458 use webrtc::api::media_engine::MIME_TYPE_OPUS;
459 use webrtc::api::media_engine::MIME_TYPE_VP8;
460 use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
461
462 let mime_type = match (media_type, codec.to_uppercase().as_str()) {
464 ("video", "H264") => MIME_TYPE_H264,
465 ("video", "VP8") => MIME_TYPE_VP8,
466 ("audio", "OPUS") => MIME_TYPE_OPUS,
467 _ => {
468 return Err(NetworkError::WebRtcError(format!(
469 "Unsupported codec: {codec} for {media_type}"
470 )));
471 }
472 };
473
474 let track = Arc::new(TrackLocalStaticRTP::new(
476 RTCRtpCodecCapability {
477 mime_type: mime_type.to_string(),
478 ..Default::default()
479 },
480 track_id.clone(),
481 format!("actr-{media_type}"), ));
483
484 let rtp_sender =
486 self.peer_connection
487 .add_track(Arc::clone(&track)
488 as Arc<dyn webrtc::track::track_local::TrackLocal + Send + Sync>)
489 .await?;
490
491 let mut tracks = self.media_tracks.write().await;
493 tracks.insert(track_id.clone(), (Arc::clone(&track), rtp_sender));
494
495 let mut seq_nums = self.track_sequence_numbers.write().await;
497 seq_nums.insert(track_id.clone(), Arc::new(AtomicU16::new(0)));
498
499 let ssrc = rand::random::<u32>();
501 let mut ssrcs = self.track_ssrcs.write().await;
502 ssrcs.insert(track_id.clone(), ssrc);
503
504 tracing::info!(
505 "✨ Added media track: id={}, codec={}, type={}, ssrc=0x{:08x}",
506 track_id,
507 codec,
508 media_type,
509 ssrc
510 );
511
512 Ok(track)
513 }
514
515 pub async fn get_media_track(&self, track_id: &str) -> Option<Arc<TrackLocalStaticRTP>> {
517 let tracks = self.media_tracks.read().await;
518 tracks
519 .get(track_id)
520 .map(|(track, _sender)| Arc::clone(track))
521 }
522
523 pub async fn next_sequence_number(&self, track_id: &str) -> Option<u16> {
531 let seq_nums = self.track_sequence_numbers.read().await;
532 seq_nums
533 .get(track_id)
534 .map(|atomic_seq| atomic_seq.fetch_add(1, Ordering::SeqCst))
535 }
536
537 pub async fn get_ssrc(&self, track_id: &str) -> Option<u32> {
545 let ssrcs = self.track_ssrcs.read().await;
546 ssrcs.get(track_id).copied()
547 }
548
549 pub async fn create_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
556 self.get_lane(payload_type).await
557 }
558
559 pub async fn register_received_data_channel(
564 &self,
565 data_channel: Arc<RTCDataChannel>,
566 payload_type: PayloadType,
567 message_tx: mpsc::UnboundedSender<(Vec<u8>, Bytes, PayloadType)>,
568 ) -> NetworkResult<DataLane> {
569 if payload_type == PayloadType::MediaRtp {
571 return Err(NetworkError::NotImplemented(
572 "MediaTrack Lane not supported in this method".to_string(),
573 ));
574 }
575
576 let idx = payload_type as usize;
577 tracing::debug!(
578 "🔄 WebRTC DataChannel registered received: {:?}, idx={}",
579 payload_type,
580 idx
581 );
582 let label = format!("{payload_type:?}");
583
584 let payload_type_for_error = payload_type;
586 let label_for_error = label.clone();
587 data_channel.on_error(Box::new(move |error| {
588 let payload_type = payload_type_for_error;
589 let label = label_for_error.clone();
590 tracing::warn!(
591 "⚠️ WebRTC DataChannel error [{}] (payload_type={:?} ): {:?}",
592 label,
593 payload_type,
594 error
595 );
596 Box::pin(async move {})
597 }));
598
599 let this_for_close = self.clone();
601 let payload_type_for_close = payload_type;
602 let label_for_close = label.clone();
603
604 data_channel.on_close(Box::new(move || {
605 let this = this_for_close.clone();
606 let payload_type = payload_type_for_close;
607 let label = label_for_close.clone();
608
609 Box::pin(async move {
610 tracing::warn!(
611 "⚠️ WebRTC DataChannel closed [{}] (payload_type={:?})",
612 label,
613 payload_type,
614 );
615 this.invalidate_lane(payload_type).await;
617 this.notify_data_channel_closed(payload_type);
619 })
620 }));
621
622 let (tx, rx) = mpsc::channel(100);
624
625 let tx_clone = tx.clone();
627 data_channel.on_message(Box::new(
628 move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
629 let data = msg.data;
630 let tx = tx_clone.clone();
631 Box::pin(async move {
632 if let Err(e) = tx.send(data).await {
633 tracing::warn!("❌ WebRTC DataChannel message send to Lane failed: {}", e);
634 }
635 })
636 },
637 ));
638
639 {
641 let mut channels = self.data_channels.write().await;
642 channels[idx] = Some(Arc::clone(&data_channel));
643 }
644
645 let lane = DataLane::webrtc_data_channel(data_channel, rx);
647 {
648 let mut cache = self.lane_cache.write().await;
649 cache[idx] = Some(lane.clone());
650 }
651
652 tracing::info!(
653 "✨ WebRtcConnection registered received DataChannel: {:?}",
654 payload_type
655 );
656 let peer_id_clone = self.peer_id.clone();
657 let lane_clone = lane.clone();
658 tokio::spawn(async move {
659 loop {
661 match lane_clone.recv().await {
662 Ok(data) => {
663 tracing::debug!(
664 "📨 Received message from {:?} (PayloadType: {:?}): {} bytes",
665 peer_id_clone,
666 payload_type,
667 data.len()
668 );
669
670 let peer_id_bytes = peer_id_clone.encode_to_vec();
672
673 if let Err(e) = message_tx.send((peer_id_bytes, data, payload_type)) {
675 tracing::error!("❌ Message aggregation failed: {:?}", e);
676 break;
677 }
678 }
679 Err(e) => {
680 tracing::warn!(
681 "❌ Peer {:?} message receive failed (PayloadType: {:?}): {}",
682 peer_id_clone,
683 payload_type,
684 e
685 );
686 break;
687 }
688 }
689 }
690 });
691
692 Ok(lane)
693 }
694}