actr_runtime/wire/webrtc/
connection.rs

1//! WebRTC P2P Connection implementation
2
3use crate::transport::DataLane;
4use crate::transport::{NetworkError, NetworkResult};
5use actr_protocol::PayloadType;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU16, Ordering};
9use tokio::sync::{RwLock, mpsc};
10use webrtc::data_channel::RTCDataChannel;
11use webrtc::peer_connection::RTCPeerConnection;
12use webrtc::rtp_transceiver::rtp_sender::RTCRtpSender;
13use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
14
15/// Type alias for media track storage (track_id → (Track, Sender))
16type MediaTracks = Arc<RwLock<HashMap<String, (Arc<TrackLocalStaticRTP>, Arc<RTCRtpSender>)>>>;
17
18/// WebRtcConnection - WebRTC P2P Connect
19#[derive(Clone)]
20pub struct WebRtcConnection {
21    /// underlying RTCPeerConnection
22    peer_connection: Arc<RTCPeerConnection>,
23
24    /// DataChannel Cache:PayloadType → DataChannel(4 types use DataChannel)
25    /// index reference mapping:RpcReliable(0), RpcSignal(1), StreamReliable(2), StreamLatencyFirst(3)
26    data_channels: Arc<RwLock<[Option<Arc<RTCDataChannel>>; 4]>>,
27
28    /// MediaTrack Cache:track_id → (Track, RtpSender)
29    media_tracks: MediaTracks,
30
31    /// RTP sequence numbers per track (track_id → sequence_number)
32    track_sequence_numbers: Arc<RwLock<HashMap<String, Arc<AtomicU16>>>>,
33
34    /// RTP SSRC per track (track_id → ssrc)
35    track_ssrcs: Arc<RwLock<HashMap<String, u32>>>,
36
37    /// Lane Cache:PayloadType → Lane( merely 3 solely proportion Type)
38    /// index reference mapping:RpcReliable(0), RpcSignal(1), StreamReliable(2), StreamLatencyFirst(3)
39    /// MediaTrack not Cachein array in ,using HashMap
40    lane_cache: Arc<RwLock<[Option<DataLane>; 4]>>,
41
42    /// connection status
43    connected: Arc<RwLock<bool>>,
44}
45
46impl std::fmt::Debug for WebRtcConnection {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("WebRtcConnection")
49            .field("peer_connection", &"<RTCPeerConnection>")
50            .field("data_channels", &"<[Option<Arc<RTCDataChannel>>; 4]>")
51            .field("media_tracks", &"<HashMap<String, Arc<Track>>>")
52            .field("connected", &self.connected)
53            .finish()
54    }
55}
56
57impl WebRtcConnection {
58    /// from RTCPeerConnection CreateConnect
59    ///
60    /// # Arguments
61    /// - `peer_connection`: Arc package pack 's RTCPeerConnection
62    pub fn new(peer_connection: Arc<RTCPeerConnection>) -> Self {
63        Self {
64            peer_connection,
65            data_channels: Arc::new(RwLock::new([None, None, None, None])),
66            media_tracks: Arc::new(RwLock::new(HashMap::new())),
67            track_sequence_numbers: Arc::new(RwLock::new(HashMap::new())),
68            track_ssrcs: Arc::new(RwLock::new(HashMap::new())),
69            lane_cache: Arc::new(RwLock::new([None, None, None, None])),
70            connected: Arc::new(RwLock::new(true)),
71        }
72    }
73
74    /// establish Connect(WebRTC Connect already alreadyvia signaling establish , this in only is mark record )
75    pub async fn connect(&self) -> NetworkResult<()> {
76        *self.connected.write().await = true;
77        Ok(())
78    }
79
80    /// Checkwhether already Connect
81    #[inline]
82    pub fn is_connected(&self) -> bool {
83        *self.connected.blocking_read()
84    }
85
86    /// CloseConnect
87    pub async fn close(&self) -> NetworkResult<()> {
88        *self.connected.write().await = false;
89        self.peer_connection.close().await?;
90
91        // clear blank DataChannel Cache
92        let mut channels = self.data_channels.write().await;
93        *channels = [None, None, None, None];
94
95        // clear blank MediaTrack Cache
96        let mut tracks = self.media_tracks.write().await;
97        tracks.clear();
98
99        // clear blank sequence number cache
100        let mut seq_nums = self.track_sequence_numbers.write().await;
101        seq_nums.clear();
102
103        // clear blank SSRC cache
104        let mut ssrcs = self.track_ssrcs.write().await;
105        ssrcs.clear();
106
107        // clear blank Lane Cache
108        let mut cache = self.lane_cache.write().await;
109        *cache = [None, None, None, None];
110
111        tracing::info!("🔌 WebRtcConnection already Close");
112        Ok(())
113    }
114
115    /// based on PayloadType configuration DataChannel
116    fn get_data_channel_config(
117        payload_type: PayloadType,
118    ) -> webrtc::data_channel::data_channel_init::RTCDataChannelInit {
119        use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
120
121        // Use negotiated DataChannel with fixed IDs based on PayloadType
122        // This allows both sides to create the same channel without on_data_channel callback
123        let channel_id = payload_type as u16;
124
125        match payload_type {
126            PayloadType::RpcSignal | PayloadType::RpcReliable => {
127                // reliable ordered transmission
128                RTCDataChannelInit {
129                    ordered: Some(true),
130                    max_retransmits: None,
131                    max_packet_life_time: None,
132                    protocol: Some("".to_string()),
133                    negotiated: Some(channel_id),
134                }
135            }
136            PayloadType::StreamLatencyFirst => {
137                // partial reliable transmission (low latency priority)
138                RTCDataChannelInit {
139                    ordered: Some(false),
140                    max_retransmits: Some(3),
141                    max_packet_life_time: Some(100),
142                    protocol: Some("".to_string()),
143                    negotiated: Some(channel_id),
144                }
145            }
146            _ => {
147                // default reliable transmission
148                RTCDataChannelInit {
149                    ordered: Some(true),
150                    max_retransmits: None,
151                    max_packet_life_time: None,
152                    protocol: Some("".to_string()),
153                    negotiated: Some(channel_id),
154                }
155            }
156        }
157    }
158}
159
160impl WebRtcConnection {
161    /// GetorCreate DataLane( carry Cache)
162    pub async fn get_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
163        // MediaTrack not Supportin this Method in Create(need stream_id)
164        if payload_type == PayloadType::MediaRtp {
165            return Err(NetworkError::NotImplemented(
166                "MediaTrack Lane requires stream_id, use get_media_lane() instead".to_string(),
167            ));
168        }
169
170        let idx = payload_type as usize;
171
172        // 1. CheckCache
173        {
174            let cache = self.lane_cache.read().await;
175            if let Some(lane) = &cache[idx] {
176                tracing::debug!("📦 ReuseCache DataLane: {:?}", payload_type);
177                return Ok(lane.clone());
178            }
179        }
180
181        // 2. Createnew DataLane
182        let lane = self.create_lane_internal(payload_type).await?;
183
184        // 3. Cache
185        {
186            let mut cache = self.lane_cache.write().await;
187            cache[idx] = Some(lane.clone());
188        }
189
190        tracing::info!("✨ WebRtcConnection Createnew DataLane: {:?}", payload_type);
191
192        Ok(lane)
193    }
194
195    /// inner part Method:Create DataChannel Lane( not carry Cache)
196    async fn create_lane_internal(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
197        // Checkwhetheras MediaTrack Type
198        if payload_type == PayloadType::MediaRtp {
199            return Err(NetworkError::NotImplemented(
200                "MediaTrack Lane not implemented in this method".to_string(),
201            ));
202        }
203
204        // Create new DataChannel
205        let mut channels = self.data_channels.write().await;
206
207        let label = format!("{payload_type:?}");
208        let dc_config = Self::get_data_channel_config(payload_type);
209
210        let data_channel = self
211            .peer_connection
212            .create_data_channel(&label, Some(dc_config))
213            .await?;
214
215        // CreateReceive channel (using Bytes)
216        let (tx, rx) = mpsc::channel(100);
217
218        // Set onmessage return adjust
219        let tx_clone = tx.clone();
220        data_channel.on_message(Box::new(
221            move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
222                // zero-copy: directly using msg.data (Bytes)
223                let data = msg.data;
224                let tx = tx_clone.clone();
225                Box::pin(async move {
226                    if let Err(e) = tx.send(data).await {
227                        tracing::warn!("❌ WebRTC DataChannel messageSend to Lane failure: {}", e);
228                    }
229                })
230            },
231        ));
232
233        // Cache DataChannel( index reference directly using PayloadType value )
234        let idx = payload_type as usize;
235        channels[idx] = Some(Arc::clone(&data_channel));
236
237        // Returns Lane
238        Ok(DataLane::webrtc_data_channel(data_channel, rx))
239    }
240
241    /// Add media track to PeerConnection
242    ///
243    /// # Arguments
244    /// - `track_id`: Unique track identifier
245    /// - `codec`: Codec name (e.g., "H264", "VP8", "opus")
246    /// - `media_type`: "video" or "audio"
247    ///
248    /// # Returns
249    /// Reference to the created TrackLocalStaticRTP
250    ///
251    /// # Note
252    /// Must be called BEFORE create_offer/create_answer for track to appear in SDP
253    pub async fn add_media_track(
254        &self,
255        track_id: String,
256        codec: &str,
257        media_type: &str,
258    ) -> NetworkResult<Arc<TrackLocalStaticRTP>> {
259        use webrtc::api::media_engine::MIME_TYPE_H264;
260        use webrtc::api::media_engine::MIME_TYPE_OPUS;
261        use webrtc::api::media_engine::MIME_TYPE_VP8;
262        use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
263
264        // Determine MIME type based on codec and media_type
265        let mime_type = match (media_type, codec.to_uppercase().as_str()) {
266            ("video", "H264") => MIME_TYPE_H264,
267            ("video", "VP8") => MIME_TYPE_VP8,
268            ("audio", "OPUS") => MIME_TYPE_OPUS,
269            _ => {
270                return Err(NetworkError::WebRtcError(format!(
271                    "Unsupported codec: {codec} for {media_type}"
272                )));
273            }
274        };
275
276        // Create TrackLocalStaticRTP
277        let track = Arc::new(TrackLocalStaticRTP::new(
278            RTCRtpCodecCapability {
279                mime_type: mime_type.to_string(),
280                ..Default::default()
281            },
282            track_id.clone(),
283            format!("actr-{media_type}"), // stream_id
284        ));
285
286        // Add track to PeerConnection
287        let rtp_sender =
288            self.peer_connection
289                .add_track(Arc::clone(&track)
290                    as Arc<dyn webrtc::track::track_local::TrackLocal + Send + Sync>)
291                .await?;
292
293        // Cache track and sender
294        let mut tracks = self.media_tracks.write().await;
295        tracks.insert(track_id.clone(), (Arc::clone(&track), rtp_sender));
296
297        // Initialize sequence number for this track
298        let mut seq_nums = self.track_sequence_numbers.write().await;
299        seq_nums.insert(track_id.clone(), Arc::new(AtomicU16::new(0)));
300
301        // Generate unique SSRC for this track (random u32)
302        let ssrc = rand::random::<u32>();
303        let mut ssrcs = self.track_ssrcs.write().await;
304        ssrcs.insert(track_id.clone(), ssrc);
305
306        tracing::info!(
307            "✨ Added media track: id={}, codec={}, type={}, ssrc=0x{:08x}",
308            track_id,
309            codec,
310            media_type,
311            ssrc
312        );
313
314        Ok(track)
315    }
316
317    /// Get existing media track by ID
318    pub async fn get_media_track(&self, track_id: &str) -> Option<Arc<TrackLocalStaticRTP>> {
319        let tracks = self.media_tracks.read().await;
320        tracks
321            .get(track_id)
322            .map(|(track, _sender)| Arc::clone(track))
323    }
324
325    /// Get next RTP sequence number for track (atomically increments)
326    ///
327    /// # Arguments
328    /// - `track_id`: Track identifier
329    ///
330    /// # Returns
331    /// Next sequence number (wraps at 65535)
332    pub async fn next_sequence_number(&self, track_id: &str) -> Option<u16> {
333        let seq_nums = self.track_sequence_numbers.read().await;
334        seq_nums
335            .get(track_id)
336            .map(|atomic_seq| atomic_seq.fetch_add(1, Ordering::SeqCst))
337    }
338
339    /// Get SSRC for track
340    ///
341    /// # Arguments
342    /// - `track_id`: Track identifier
343    ///
344    /// # Returns
345    /// SSRC value for this track
346    pub async fn get_ssrc(&self, track_id: &str) -> Option<u32> {
347        let ssrcs = self.track_ssrcs.read().await;
348        ssrcs.get(track_id).copied()
349    }
350
351    /// GetorCreate MediaTrack Lane( carry Cache)
352    ///
353    /// # Arguments
354    /// - `_stream_id`: Media stream ID
355    ///
356    /// backwardaftercompatible hold Method:create_lane adjust usage get_lane
357    pub async fn create_lane(&self, payload_type: PayloadType) -> NetworkResult<DataLane> {
358        self.get_lane(payload_type).await
359    }
360
361    /// Register received DataChannel (for passive side)
362    ///
363    /// When receiving an Offer, the passive side should register DataChannels
364    /// received via on_data_channel callback instead of creating new ones.
365    pub async fn register_received_data_channel(
366        &self,
367        data_channel: Arc<RTCDataChannel>,
368        payload_type: PayloadType,
369    ) -> NetworkResult<DataLane> {
370        // Check if it's MediaTrack type
371        if payload_type == PayloadType::MediaRtp {
372            return Err(NetworkError::NotImplemented(
373                "MediaTrack Lane not supported in this method".to_string(),
374            ));
375        }
376
377        let idx = payload_type as usize;
378
379        // Create receive channel
380        let (tx, rx) = mpsc::channel(100);
381
382        // Set on_message callback
383        let tx_clone = tx.clone();
384        data_channel.on_message(Box::new(
385            move |msg: webrtc::data_channel::data_channel_message::DataChannelMessage| {
386                let data = msg.data;
387                let tx = tx_clone.clone();
388                Box::pin(async move {
389                    if let Err(e) = tx.send(data).await {
390                        tracing::warn!("❌ WebRTC DataChannel message send to Lane failed: {}", e);
391                    }
392                })
393            },
394        ));
395
396        // Cache DataChannel
397        {
398            let mut channels = self.data_channels.write().await;
399            channels[idx] = Some(Arc::clone(&data_channel));
400        }
401
402        // Create and cache Lane
403        let lane = DataLane::webrtc_data_channel(data_channel, rx);
404        {
405            let mut cache = self.lane_cache.write().await;
406            cache[idx] = Some(lane.clone());
407        }
408
409        tracing::info!(
410            "✨ WebRtcConnection registered received DataChannel: {:?}",
411            payload_type
412        );
413
414        Ok(lane)
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    // Note: WebRTC integration tests need a complete signaling flow, so only
    // unit-level checks are done here.

    /// All payload types that are carried over DataChannels.
    const DATA_CHANNEL_TYPES: [PayloadType; 4] = [
        PayloadType::RpcReliable,
        PayloadType::RpcSignal,
        PayloadType::StreamReliable,
        PayloadType::StreamLatencyFirst,
    ];

    #[test]
    fn test_data_channel_config() {
        let config = WebRtcConnection::get_data_channel_config(PayloadType::RpcReliable);
        assert_eq!(config.ordered, Some(true));

        let config = WebRtcConnection::get_data_channel_config(PayloadType::StreamLatencyFirst);
        assert_eq!(config.ordered, Some(false));
        assert_eq!(config.max_retransmits, Some(3));
    }

    #[test]
    fn test_data_channel_config_negotiated_id_matches_payload_type() {
        for payload_type in DATA_CHANNEL_TYPES {
            let config = WebRtcConnection::get_data_channel_config(payload_type);
            assert_eq!(config.negotiated, Some(payload_type as u16));
        }
    }

    #[test]
    fn test_reliable_data_channel_configs_have_no_limits() {
        for payload_type in [
            PayloadType::RpcReliable,
            PayloadType::RpcSignal,
            PayloadType::StreamReliable,
        ] {
            let config = WebRtcConnection::get_data_channel_config(payload_type);
            assert_eq!(config.ordered, Some(true));
            assert_eq!(config.max_retransmits, None);
            assert_eq!(config.max_packet_life_time, None);
        }
    }
}