actr_runtime/wire/webrtc/
connection.rs

1//! WebRTC P2P Connection implementation
2
3use crate::transport::DataLane;
4use crate::transport::{NetworkError, NetworkResult};
5use actr_protocol::PayloadType;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU16, Ordering};
9use tokio::sync::{RwLock, mpsc};
10use webrtc::data_channel::RTCDataChannel;
11use webrtc::peer_connection::RTCPeerConnection;
12use webrtc::rtp_transceiver::rtp_sender::RTCRtpSender;
13use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
14
15/// Type alias for media track storage (track_id → (Track, Sender))
16type MediaTracks = Arc<RwLock<HashMap<String, (Arc<TrackLocalStaticRTP>, Arc<RTCRtpSender>)>>>;
17
18/// WebRtcConnection - WebRTC P2P Connect
19#[derive(Clone)]
20pub struct WebRtcConnection {
21    /// underlying RTCPeerConnection
22    peer_connection: Arc<RTCPeerConnection>,
23
24    // TODO: useless property, remove this
25    /// DataChannel Cache:PayloadType → DataChannel(4 types use DataChannel)
26    /// index reference mapping:RpcReliable(0), RpcSignal(1), StreamReliable(2), StreamLatencyFirst(3)
27    data_channels: Arc<RwLock<[Option<Arc<RTCDataChannel>>; 4]>>,
28
29    /// MediaTrack Cache:track_id → (Track, RtpSender)
30    media_tracks: MediaTracks,
31
32    /// RTP sequence numbers per track (track_id → sequence_number)
33    track_sequence_numbers: Arc<RwLock<HashMap<String, Arc<AtomicU16>>>>,
34
35    /// RTP SSRC per track (track_id → ssrc)
36    track_ssrcs: Arc<RwLock<HashMap<String, u32>>>,
37
38    /// Lane Cache:PayloadType → Lane( merely 3 solely proportion Type)
39    /// index reference mapping:RpcReliable(0), RpcSignal(1), StreamReliable(2), StreamLatencyFirst(3)
40    /// MediaTrack not Cachein array in ,using HashMap
41    lane_cache: Arc<RwLock<[Option<DataLane>; 4]>>,
42
43    /// connection status
44    connected: Arc<RwLock<bool>>,
45}
46
47impl std::fmt::Debug for WebRtcConnection {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("WebRtcConnection")
50            .field("peer_connection", &"<RTCPeerConnection>")
51            .field("data_channels", &"<[Option<Arc<RTCDataChannel>>; 4]>")
52            .field("media_tracks", &"<HashMap<String, Arc<Track>>>")
53            .field("connected", &self.connected)
54            .finish()
55    }
56}
57
58impl WebRtcConnection {
59    /// from RTCPeerConnection CreateConnect
60    ///
61    /// # Arguments
62    /// - `peer_connection`: Arc package pack 's RTCPeerConnection
63    pub fn new(peer_connection: Arc<RTCPeerConnection>) -> Self {
64        Self {
65            peer_connection,
66            data_channels: Arc::new(RwLock::new([None, None, None, None])),
67            media_tracks: Arc::new(RwLock::new(HashMap::new())),
68            track_sequence_numbers: Arc::new(RwLock::new(HashMap::new())),
69            track_ssrcs: Arc::new(RwLock::new(HashMap::new())),
70            lane_cache: Arc::new(RwLock::new([None, None, None, None])),
71            connected: Arc::new(RwLock::new(true)),
72        }
73    }
74
75    /// Install a state-change handler on the underlying RTCPeerConnection.
76    ///
77    /// This keeps `connected` in sync with the WebRTC connection state and
78    /// proactively closes the PeerConnection and clears internal caches when
79    /// entering a terminal state (Disconnected/Failed/Closed).
80    pub fn install_state_change_handler(&self) {
81        use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
82
83        let this = self.clone();
84
85        self.peer_connection
86            .on_peer_connection_state_change(Box::new(move |state: RTCPeerConnectionState| {
87                let this = this.clone();
88
89                Box::pin(async move {
90                    // Treat New/Connecting/Connected as "connected"; others as disconnected.
91                    let is_connected = matches!(
92                        state,
93                        RTCPeerConnectionState::New
94                            | RTCPeerConnectionState::Connecting
95                            | RTCPeerConnectionState::Connected
96                    );
97
98                    // Update flag and detect transitions from connected -> disconnected.
99                    let was_connected = {
100                        let mut flag = this.connected.write().await;
101                        let prev = *flag;
102                        *flag = is_connected;
103                        prev
104                    };
105
106                    tracing::info!(
107                        "🔄 WebRtcConnection peer state changed: {:?}, connected={}",
108                        state,
109                        is_connected
110                    );
111
112                    // For terminal states, proactively close the connection and let
113                    // `close()` perform all resource cleanup. Only trigger when we
114                    // transition from connected -> disconnected to avoid loops.
115                    // not support ice restart, FIXME: ice restart
116                    if was_connected
117                        && matches!(
118                            state,
119                            RTCPeerConnectionState::Disconnected
120                                | RTCPeerConnectionState::Failed
121                                | RTCPeerConnectionState::Closed
122                        )
123                    {
124                        tracing::info!(
125                            "🔻 WebRtcConnection entering terminal state {:?}, calling close()",
126                            state
127                        );
128
129                        if let Err(e) = this.close().await {
130                            tracing::warn!("⚠️ WebRtcConnection::close() failed: {}", e);
131                        }
132                    }
133                })
134            }));
135    }
136
137    /// establish Connect(WebRTC Connect already alreadyvia signaling establish , this in only is mark record )
138    pub async fn connect(&self) -> NetworkResult<()> {
139        *self.connected.write().await = true;
140        Ok(())
141    }
142
143    /// Checkwhether already Connect
144    #[inline]
145    pub fn is_connected(&self) -> bool {
146        *self.connected.blocking_read()
147    }
148
149    /// CloseConnect
150    pub async fn close(&self) -> NetworkResult<()> {
151        *self.connected.write().await = false;
152        self.peer_connection.close().await?;
153
154        // clear blank DataChannel Cache
155        let mut channels = self.data_channels.write().await;
156        *channels = [None, None, None, None];
157
158        // clear blank MediaTrack Cache
159        let mut tracks = self.media_tracks.write().await;
160        tracks.clear();
161
162        // clear blank sequence number cache
163        let mut seq_nums = self.track_sequence_numbers.write().await;
164        seq_nums.clear();
165
166        // clear blank SSRC cache
167        let mut ssrcs = self.track_ssrcs.write().await;
168        ssrcs.clear();
169
170        // clear blank Lane Cache
171        let mut cache = self.lane_cache.write().await;
172        *cache = [None, None, None, None];
173
174        tracing::info!("🔌 WebRtcConnection already Close");
175        Ok(())
176    }
177
178    /// based on PayloadType configuration DataChannel
179    fn get_data_channel_config(
180        payload_type: PayloadType,
181    ) -> webrtc::data_channel::data_channel_init::RTCDataChannelInit {
182        use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
183
184        // Use negotiated DataChannel with fixed IDs based on PayloadType
185        // This allows both sides to create the same channel without on_data_channel callback
186        let channel_id = payload_type as u16;
187
188        // TODO: remove negotiated flag to use auto-negotiation
189        match payload_type {
190            PayloadType::RpcSignal | PayloadType::RpcReliable => {
191                // reliable ordered transmission
192                RTCDataChannelInit {
193                    ordered: Some(true),
194                    max_retransmits: None,
195                    max_packet_life_time: None,
196                    protocol: Some("".to_string()),
197                    negotiated: Some(channel_id),
198                }
199            }
200            PayloadType::StreamLatencyFirst => {
201                // partial reliable transmission (low latency priority)
202                // NOTE: WebRTC spec forbids setting both max_retransmits and max_packet_life_time.
203                RTCDataChannelInit {
204                    ordered: Some(false),
205                    max_retransmits: Some(3),
206                    max_packet_life_time: None,
207                    protocol: Some("".to_string()),
208                    negotiated: Some(channel_id),
209                }
210            }
211            _ => {
212                // default reliable transmission
213                RTCDataChannelInit {
214                    ordered: Some(true),
215                    max_retransmits: None,
216                    max_packet_life_time: None,
217                    protocol: Some("".to_string()),
218                    negotiated: Some(channel_id),
219                }
220            }
221        }
222    }
223}
224
225impl WebRtcConnection {
226    /// GetorCreate DataLane( carry Cache)
227    pub async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
228        // MediaTrack not Supportin this Method in Create(need stream_id)
229        if payload_type == PayloadType::MediaRtp {
230            return Err(NetworkError::NotImplemented(
231                "MediaTrack Lane requires stream_id, use get_media_lane() instead".to_string(),
232            ));
233        }
234
235        let idx = payload_type as usize;
236
237        // 1. CheckCache
238        {
239            let cache = self.lane_cache.read().await;
240            if let Some(lane) = &cache[idx] {
241                tracing::debug!("📦 ReuseCache DataLane: {:?}", payload_type);
242                return Ok(lane.clone());
243            }
244        }
245
246        // 2. Createnew DataLane
247        let lane = self.create_lane_internal(payload_type).await?;
248
249        // 3. Cache
250        {
251            let mut cache = self.lane_cache.write().await;
252            cache[idx] = Some(lane.clone());
253        }
254
255        tracing::info!("✨ WebRtcConnection Createnew DataLane: {:?}", payload_type);
256
257        Ok(lane)
258    }
259
260    /// inner part Method:Create DataChannel Lane( not carry Cache)
261    async fn create_lane_internal(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
262        // Checkwhetheras MediaTrack Type
263        if payload_type == PayloadType::MediaRtp {
264            return Err(NetworkError::NotImplemented(
265                "MediaTrack Lane not implemented in this method".to_string(),
266            ));
267        }
268
269        // Create new DataChannel
270        let mut channels = self.data_channels.write().await;
271
272        let label = format!("{payload_type:?}");
273        let dc_config = Self::get_data_channel_config(payload_type);
274
275        let data_channel = self
276            .peer_connection
277            .create_data_channel(&label, Some(dc_config))
278            .await?;
279
280        // CreateReceive channel (using Bytes)
281        let (tx, rx) = mpsc::channel(100);
282
283        // Set onmessage return adjust
284        let tx_clone = tx.clone();
285        data_channel.on_message(Box::new(
286            move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
287                // zero-copy: directly using msg.data (Bytes)
288                let data = msg.data;
289                let tx = tx_clone.clone();
290                Box::pin(async move {
291                    if let Err(e) = tx.send(data).await {
292                        tracing::warn!("❌ WebRTC DataChannel messageSend to Lane failure: {}", e);
293                    }
294                })
295            },
296        ));
297
298        // Cache DataChannel( index reference directly using PayloadType value )
299        let idx = payload_type as usize;
300        channels[idx] = Some(Arc::clone(&data_channel));
301
302        // Returns Lane
303        Ok(DataLane::webrtc_data_channel(data_channel, rx))
304    }
305
306    /// Add media track to PeerConnection
307    ///
308    /// # Arguments
309    /// - `track_id`: Unique track identifier
310    /// - `codec`: Codec name (e.g., "H264", "VP8", "opus")
311    /// - `media_type`: "video" or "audio"
312    ///
313    /// # Returns
314    /// Reference to the created TrackLocalStaticRTP
315    ///
316    /// # Note
317    /// Must be called BEFORE create_offer/create_answer for track to appear in SDP
318    pub async fn add_media_track(
319        &self,
320        track_id: String,
321        codec: &str,
322        media_type: &str,
323    ) -> NetworkResult<Arc<TrackLocalStaticRTP>> {
324        use webrtc::api::media_engine::MIME_TYPE_H264;
325        use webrtc::api::media_engine::MIME_TYPE_OPUS;
326        use webrtc::api::media_engine::MIME_TYPE_VP8;
327        use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
328
329        // Determine MIME type based on codec and media_type
330        let mime_type = match (media_type, codec.to_uppercase().as_str()) {
331            ("video", "H264") => MIME_TYPE_H264,
332            ("video", "VP8") => MIME_TYPE_VP8,
333            ("audio", "OPUS") => MIME_TYPE_OPUS,
334            _ => {
335                return Err(NetworkError::WebRtcError(format!(
336                    "Unsupported codec: {codec} for {media_type}"
337                )));
338            }
339        };
340
341        // Create TrackLocalStaticRTP
342        let track = Arc::new(TrackLocalStaticRTP::new(
343            RTCRtpCodecCapability {
344                mime_type: mime_type.to_string(),
345                ..Default::default()
346            },
347            track_id.clone(),
348            format!("actr-{media_type}"), // stream_id
349        ));
350
351        // Add track to PeerConnection
352        let rtp_sender =
353            self.peer_connection
354                .add_track(Arc::clone(&track)
355                    as Arc<dyn webrtc::track::track_local::TrackLocal + Send + Sync>)
356                .await?;
357
358        // Cache track and sender
359        let mut tracks = self.media_tracks.write().await;
360        tracks.insert(track_id.clone(), (Arc::clone(&track), rtp_sender));
361
362        // Initialize sequence number for this track
363        let mut seq_nums = self.track_sequence_numbers.write().await;
364        seq_nums.insert(track_id.clone(), Arc::new(AtomicU16::new(0)));
365
366        // Generate unique SSRC for this track (random u32)
367        let ssrc = rand::random::<u32>();
368        let mut ssrcs = self.track_ssrcs.write().await;
369        ssrcs.insert(track_id.clone(), ssrc);
370
371        tracing::info!(
372            "✨ Added media track: id={}, codec={}, type={}, ssrc=0x{:08x}",
373            track_id,
374            codec,
375            media_type,
376            ssrc
377        );
378
379        Ok(track)
380    }
381
382    /// Get existing media track by ID
383    pub async fn get_media_track(&self, track_id: &str) -> Option<Arc<TrackLocalStaticRTP>> {
384        let tracks = self.media_tracks.read().await;
385        tracks
386            .get(track_id)
387            .map(|(track, _sender)| Arc::clone(track))
388    }
389
390    /// Get next RTP sequence number for track (atomically increments)
391    ///
392    /// # Arguments
393    /// - `track_id`: Track identifier
394    ///
395    /// # Returns
396    /// Next sequence number (wraps at 65535)
397    pub async fn next_sequence_number(&self, track_id: &str) -> Option<u16> {
398        let seq_nums = self.track_sequence_numbers.read().await;
399        seq_nums
400            .get(track_id)
401            .map(|atomic_seq| atomic_seq.fetch_add(1, Ordering::SeqCst))
402    }
403
404    /// Get SSRC for track
405    ///
406    /// # Arguments
407    /// - `track_id`: Track identifier
408    ///
409    /// # Returns
410    /// SSRC value for this track
411    pub async fn get_ssrc(&self, track_id: &str) -> Option<u32> {
412        let ssrcs = self.track_ssrcs.read().await;
413        ssrcs.get(track_id).copied()
414    }
415
416    /// GetorCreate MediaTrack Lane( carry Cache)
417    ///
418    /// # Arguments
419    /// - `_stream_id`: Media stream ID
420    ///
421    /// backwardaftercompatible hold Method:create_lane adjust usage get_lane
422    pub async fn create_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
423        self.get_lane(payload_type).await
424    }
425
426    /// Register received DataChannel (for passive side)
427    ///
428    /// When receiving an Offer, the passive side should register DataChannels
429    /// received via on_data_channel callback instead of creating new ones.
430    pub async fn register_received_data_channel(
431        &self,
432        data_channel: Arc<RTCDataChannel>,
433        payload_type: PayloadType,
434    ) -> NetworkResult<DataLane> {
435        // Check if it's MediaTrack type
436        if payload_type == PayloadType::MediaRtp {
437            return Err(NetworkError::NotImplemented(
438                "MediaTrack Lane not supported in this method".to_string(),
439            ));
440        }
441
442        let idx = payload_type as usize;
443
444        // Create receive channel
445        let (tx, rx) = mpsc::channel(100);
446
447        // Set on_message callback
448        let tx_clone = tx.clone();
449        data_channel.on_message(Box::new(
450            move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
451                let data = msg.data;
452                let tx = tx_clone.clone();
453                Box::pin(async move {
454                    if let Err(e) = tx.send(data).await {
455                        tracing::warn!("❌ WebRTC DataChannel message send to Lane failed: {}", e);
456                    }
457                })
458            },
459        ));
460
461        // Cache DataChannel
462        {
463            let mut channels = self.data_channels.write().await;
464            channels[idx] = Some(Arc::clone(&data_channel));
465        }
466
467        // Create and cache Lane
468        let lane = DataLane::webrtc_data_channel(data_channel, rx);
469        {
470            let mut cache = self.lane_cache.write().await;
471            cache[idx] = Some(lane.clone());
472        }
473
474        tracing::info!(
475            "✨ WebRtcConnection registered received DataChannel: {:?}",
476            payload_type
477        );
478
479        Ok(lane)
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    // Note: WebRTC integration tests need a complete signaling flow, so only
    // unit tests live here.

    /// Checks the ordering / retransmit settings chosen per payload type.
    #[test]
    fn test_data_channel_config() {
        let reliable = WebRtcConnection::get_data_channel_config(PayloadType::RpcReliable);
        assert_eq!(reliable.ordered, Some(true));

        let latency_first =
            WebRtcConnection::get_data_channel_config(PayloadType::StreamLatencyFirst);
        assert_eq!(latency_first.ordered, Some(false));
        assert_eq!(latency_first.max_retransmits, Some(3));
    }
}
498}