rsfu/sfu/
router.rs

1use super::down_track::DownTrack;
2use super::down_track_internal::DownTrackInternal;
3use super::errors::Result;
4use super::receiver::Receiver;
5use super::receiver::WebRTCReceiver;
6use super::session::Session;
7use super::simulcast::SimulcastConfig;
8
9use super::subscriber::Subscriber;
10use crate::buffer::buffer::Options as BufferOptions;
11use crate::buffer::buffer_io::BufferIO;
12use crate::buffer::factory::AtomicFactory;
13use crate::stats::stream::Stream;
14use crate::twcc::twcc::Responder;
15use async_trait::async_trait;
16use rtcp::packet::Packet as RtcpPacket;
17
18use std::collections::HashMap;
19use std::future::Future;
20use std::pin::Pin;
21use std::sync::Arc;
22use tokio::sync::Mutex;
23use webrtc::error::Error as RTCError;
24use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
25
26use serde::Deserialize;
27use tokio::sync::mpsc;
28use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
29use webrtc::rtp_transceiver::rtp_codec::RTPCodecType;
30use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver;
31use webrtc::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection;
32use webrtc::rtp_transceiver::RTCPFeedback;
33use webrtc::rtp_transceiver::RTCRtpTransceiverInit;
34use webrtc::track::track_remote::TrackRemote;
35
/// Sending half of an unbounded channel whose messages are batches of boxed RTCP packets.
pub type RtcpDataSender = mpsc::UnboundedSender<Vec<Box<dyn RtcpPacket + Send + Sync>>>;
/// Receiving half of an unbounded channel whose messages are batches of boxed RTCP packets.
pub type RtcpDataReceiver = mpsc::UnboundedReceiver<Vec<Box<dyn RtcpPacket + Send + Sync>>>;
38
/// Boxed callback that is handed a batch of RTCP packets and returns a boxed
/// future resolving to `Result<()>`.
///
/// The returned future does no work until it is awaited.
pub type RtcpWriterFn = Box<
    dyn (FnMut(
            Vec<Box<dyn RtcpPacket + Send + Sync>>,
        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>)
        + Send
        + Sync,
>;
46
/// Boxed callback that is handed a receiver and returns a boxed future
/// resolving to `Result<()>`.
///
/// This is the handler type accepted by `Router::on_add_receiver_track`.
/// NOTE(review): "Reciver" is a typo for "Receiver", but the name is public and
/// is therefore left unchanged.
pub type OnAddReciverTrackFn = Box<
    dyn (FnMut(
            Arc<dyn Receiver + Send + Sync>,
        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>)
        + Send
        + Sync,
>;
/// Boxed callback that is handed a receiver and returns a boxed future
/// resolving to `Result<()>`.
///
/// This is the handler type accepted by `Router::on_del_receiver_track`. It is
/// structurally identical to `OnAddReciverTrackFn`.
pub type OnDelReciverTrackFn = Box<
    dyn (FnMut(
            Arc<dyn Receiver + Send + Sync>,
        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>)
        + Send
        + Sync,
>;
61
/// Interface of a router: it keeps a map of receivers, creates down tracks for
/// subscribers and forwards RTCP through a user-supplied writer.
#[async_trait]
pub trait Router {
    /// Returns the router's identifier.
    fn id(&self) -> String;
    /// Registers an incoming remote track.
    ///
    /// Returns the receiver for the track together with a flag; in
    /// `RouterLocal` the flag is `true` when the receiver was created by this
    /// call rather than found in the receiver map.
    async fn add_receiver(
        &self,
        receiver: Arc<RTCRtpReceiver>,
        track: Arc<TrackRemote>,
        track_id: String,
        stream_id: String,
    ) -> (Arc<dyn Receiver + Send + Sync>, bool);
    /// Adds down tracks for subscriber `s`: for the receiver `r` when it is
    /// `Some`, otherwise (in `RouterLocal`) for every receiver currently known.
    async fn add_down_tracks(
        &self,
        s: Arc<Subscriber>,
        r: Option<Arc<dyn Receiver + Send + Sync>>,
    ) -> Result<()>;
    /// Adds a single down track that forwards receiver `r` to subscriber `s`.
    async fn add_down_track(
        &self,
        s: Arc<Subscriber>,
        r: Arc<dyn Receiver + Send + Sync>,
    ) -> Result<Option<Arc<DownTrack>>>;
    /// Installs the callback used to write RTCP packet batches.
    async fn set_rtcp_writer(&self, writer: RtcpWriterFn);
    /// Returns a shared handle to the receiver map (keyed by `String`).
    fn get_receiver(&self) -> Arc<Mutex<HashMap<String, Arc<dyn Receiver + Send + Sync>>>>;

    /// Stops the router.
    async fn stop(&self);
    /// Installs the handler to be notified when a receiver is added.
    async fn on_add_receiver_track(&self, f: OnAddReciverTrackFn);
    /// Installs the handler to be notified when a receiver is deleted.
    async fn on_del_receiver_track(&self, f: OnDelReciverTrackFn);
    /// Runs the RTCP forwarding loop; in `RouterLocal` it returns once a stop
    /// signal is received.
    async fn send_rtcp(&self);
}
90
/// Router settings, deserialized with the lower-case key names given in the
/// `serde(rename)` attributes.
#[derive(Default, Clone, Deserialize)]
pub struct RouterConfig {
    /// When set, `RouterLocal` keeps a per-SSRC stats entry for every added receiver.
    #[serde(rename = "withstats")]
    pub(super) with_stats: bool,
    /// Passed to the buffer as `BufferOptions::max_bitrate` when a track is bound.
    #[serde(rename = "maxbandwidth")]
    max_bandwidth: u64,
    /// Passed to `DownTrackInternal::new` for every down track that is created.
    #[serde(rename = "maxpackettrack")]
    pub max_packet_track: i32,
    /// Not read in this file. NOTE(review): presumably the audio level reporting
    /// interval; confirm at the consumer.
    #[serde(rename = "audiolevelinterval")]
    pub audio_level_interval: i32,
    /// Currently unused (hence `dead_code`); kept so the config key still parses.
    #[serde(rename = "audiolevelthreshold")]
    #[allow(dead_code)]
    audio_level_threshold: u8,
    /// Currently unused (hence `dead_code`); kept so the config key still parses.
    #[serde(rename = "audiolevelfilter")]
    #[allow(dead_code)]
    audio_level_filter: i32,
    /// Simulcast settings; this file reads `best_quality_first` from it.
    simulcast: SimulcastConfig,
}
109
/// In-process `Router` implementation.
pub struct RouterLocal {
    /// Identifier returned by `id()` and handed to every `WebRTCReceiver` created here.
    id: String,
    /// Transport-wide congestion control responder; created on the first video track.
    twcc: Arc<Mutex<Option<Responder>>>,
    /// Per-SSRC stats streams; populated only when `config.with_stats` is set.
    stats: Arc<Mutex<HashMap<u32, Stream>>>,
    /// Producer side of the RTCP channel, cloned into buffers, the TWCC responder and receivers.
    rtcp_sender_channel: Arc<RtcpDataSender>,
    /// Consumer side of the RTCP channel, drained by `send_rtcp`.
    rtcp_receiver_channel: Arc<Mutex<RtcpDataReceiver>>,
    /// Producer side of the stop signal, used by `stop`.
    stop_sender_channel: Arc<Mutex<mpsc::UnboundedSender<()>>>,
    /// Consumer side of the stop signal, awaited by `send_rtcp`.
    stop_receiver_channel: Arc<Mutex<mpsc::UnboundedReceiver<()>>>,
    /// Router configuration.
    config: RouterConfig,
    /// Owning session; used here to reach the audio observer.
    session: Arc<dyn Session + Send + Sync>,
    /// Receivers keyed by track id.
    receivers: Arc<Mutex<HashMap<String, Arc<dyn Receiver + Send + Sync>>>>,
    /// Source of the per-SSRC RTP buffers and RTCP readers.
    buffer_factory: AtomicFactory,
    /// Callback installed by `set_rtcp_writer`.
    rtcp_writer_handler: Arc<Mutex<Option<RtcpWriterFn>>>,
    /// Callback installed by `on_add_receiver_track`.
    on_add_receiver_track_handler: Arc<Mutex<Option<OnAddReciverTrackFn>>>,
    /// Callback installed by `on_del_receiver_track`.
    on_del_receiver_track_handler: Arc<Mutex<Option<OnDelReciverTrackFn>>>,
}
126impl RouterLocal {
127    pub fn new(id: String, session: Arc<dyn Session + Send + Sync>, config: RouterConfig) -> Self {
128        let (s, r) = mpsc::unbounded_channel();
129        let (sender, receiver) = mpsc::unbounded_channel();
130        Self {
131            id,
132            twcc: Arc::new(Mutex::new(None)),
133            stats: Arc::new(Mutex::new(HashMap::new())),
134            rtcp_sender_channel: Arc::new(s),
135            rtcp_receiver_channel: Arc::new(Mutex::new(r)),
136            stop_sender_channel: Arc::new(Mutex::new(sender)),
137            stop_receiver_channel: Arc::new(Mutex::new(receiver)),
138            config,
139            session,
140            receivers: Arc::new(Mutex::new(HashMap::new())),
141            buffer_factory: AtomicFactory::new(100, 100),
142            rtcp_writer_handler: Arc::new(Mutex::new(None)),
143            on_add_receiver_track_handler: Arc::new(Mutex::new(None)),
144            on_del_receiver_track_handler: Arc::new(Mutex::new(None)),
145        }
146    }
147    #[allow(dead_code)]
148    async fn delete_receiver(&self, track: String, ssrc: u32) {
149        if let Some(f) = &mut *self.on_del_receiver_track_handler.lock().await {
150            if let Some(track) = self.receivers.lock().await.get(&track) {
151                f(track.clone());
152            }
153        }
154        self.receivers.lock().await.remove(&track);
155        self.stats.lock().await.remove(&ssrc);
156    }
157}
158
159#[async_trait]
160impl Router for RouterLocal {
161    fn get_receiver(&self) -> Arc<Mutex<HashMap<String, Arc<dyn Receiver + Send + Sync>>>> {
162        self.receivers.clone()
163    }
164
165    async fn on_add_receiver_track(&self, f: OnAddReciverTrackFn) {
166        let mut handler = self.on_add_receiver_track_handler.lock().await;
167        *handler = Some(f);
168    }
169    async fn on_del_receiver_track(&self, f: OnDelReciverTrackFn) {
170        let mut handler = self.on_del_receiver_track_handler.lock().await;
171        *handler = Some(f);
172    }
173
174    fn id(&self) -> String {
175        self.id.clone()
176    }
177
178    async fn stop(&self) {
179        if let Err(err) = self.stop_sender_channel.lock().await.send(()) {
180            log::error!("stop err: {}", err);
181        }
182        if self.config.with_stats {}
183    }
184
185    async fn add_receiver(
186        &self,
187        receiver: Arc<RTCRtpReceiver>,
188        track: Arc<TrackRemote>,
189        track_id: String,
190        stream_id: String,
191    ) -> (Arc<dyn Receiver + Send + Sync>, bool) {
192        let mut publish = false;
193
194        let buffer = self.buffer_factory.get_or_new_buffer(track.ssrc()).await;
195
196        let sender_for_buffer = Arc::clone(&self.rtcp_sender_channel);
197        buffer
198            .on_feedback_callback(Box::new(
199                move |packets: Vec<Box<dyn RtcpPacket + Send + Sync>>| {
200                    let sender_for_buffer_in = Arc::clone(&sender_for_buffer);
201                    Box::pin(async move {
202                        if let Err(err) = sender_for_buffer_in.send(packets) {
203                            log::error!("send err: {}", err);
204                        }
205                    })
206                },
207            ))
208            .await;
209
210        match track.kind() {
211            RTPCodecType::Audio => {
212                let session_out = Arc::clone(&self.session);
213                let stream_id_out = stream_id.clone();
214                buffer
215                    .on_audio_level(Box::new(move |level: u8| {
216                        let session_in = Arc::clone(&session_out);
217                        let stream_id_in = stream_id_out.clone();
218                        Box::pin(async move {
219                            if let Some(observer) = session_in.audio_obserber() {
220                                observer.lock().await.observe(stream_id_in, level).await;
221                            }
222                        })
223                    }))
224                    .await;
225                if let Some(observer) = self.session.audio_obserber() {
226                    observer.lock().await.add_stream(stream_id).await;
227                }
228            }
229            RTPCodecType::Video => {
230                log::info!("track video");
231                if self.twcc.lock().await.is_none() {
232                    let mut twcc = Responder::new(track.ssrc());
233                    let sender = Arc::clone(&self.rtcp_sender_channel);
234                    twcc.on_feedback(Box::new(
235                        move |rtcp_packet: Box<dyn RtcpPacket + Send + Sync>| {
236                            let sender_in = Arc::clone(&sender);
237                            Box::pin(async move {
238                                let mut data = Vec::new();
239                                data.push(rtcp_packet);
240                                if let Err(err) = sender_in.send(data) {
241                                    log::error!("send err: {}", err);
242                                }
243                            })
244                        },
245                    ))
246                    .await;
247                    let mut t = self.twcc.lock().await;
248                    *t = Some(twcc);
249                    //self.twcc = Arc::new(Mutex::new(Some(twcc)));
250                }
251
252                let twcc_out = Arc::clone(&self.twcc);
253                buffer
254                    .on_transport_wide_cc(Box::new(move |sn: u16, time_ns: i64, marker: bool| {
255                        let twcc_in = Arc::clone(&twcc_out);
256                        Box::pin(async move {
257                            if let Some(twcc) = &mut *twcc_in.lock().await {
258                                twcc.push(sn, time_ns, marker).await;
259                            }
260                        })
261                    }))
262                    .await;
263            }
264            RTPCodecType::Unspecified => {}
265        }
266
267        if self.config.with_stats {
268            let stream = Stream::new(Arc::clone(&buffer));
269            self.stats.lock().await.insert(track.ssrc(), stream);
270        }
271
272        let rtcp_reader = self
273            .buffer_factory
274            .get_or_new_rtcp_buffer(track.ssrc())
275            .await;
276
277        let stats_out = Arc::clone(&self.stats);
278        let buffer_out = Arc::clone(&buffer);
279        let with_status = self.config.with_stats;
280
281        rtcp_reader
282            .lock()
283            .await
284            .on_packet(Box::new(move |packet: Vec<u8>| {
285                let stats_in = Arc::clone(&stats_out);
286                let buffer_in = Arc::clone(&buffer_out);
287                Box::pin(async move {
288                    let mut buf = &packet[..];
289                    let pkts_result = rtcp::packet::unmarshal(&mut buf)?;
290                    for pkt in pkts_result {
291                        if let Some(source_description) =
292                            pkt.as_any()
293                                .downcast_ref::<rtcp::source_description::SourceDescription>()
294                        {
295                            if with_status {
296                                for chunk in &source_description.chunks {
297                                    if let Some(stream) =
298                                        stats_in.lock().await.get_mut(&chunk.source)
299                                    {
300                                        for item in &chunk.items {
301                                            if item.sdes_type
302                                                == rtcp::source_description::SdesType::SdesCname
303                                            {
304                                                stream
305                                                    .set_cname(
306                                                        String::from_utf8(item.text.to_vec())
307                                                            .unwrap(),
308                                                    )
309                                                    .await;
310                                            }
311                                        }
312                                    }
313                                }
314                            }
315                        } else if let Some(sender_report) =
316                            pkt.as_any()
317                                .downcast_ref::<rtcp::sender_report::SenderReport>()
318                        {
319                            buffer_in
320                                .set_sender_report_data(
321                                    sender_report.rtp_time,
322                                    sender_report.ntp_time,
323                                )
324                                .await;
325                            if with_status {
326                                if let Some(_stream) =
327                                    stats_in.lock().await.get_mut(&sender_report.ssrc)
328                                {
329                                    //update stats
330                                }
331                            }
332                        }
333                    }
334                    Ok(())
335                })
336            }))
337            .await;
338
339        let result_receiver;
340
341        let mut receivers = self.receivers.lock().await;
342        if let Some(recv) = receivers.get(&track_id) {
343            result_receiver = recv.clone();
344        } else {
345            let mut rv =
346                WebRTCReceiver::new(receiver.clone(), track.clone(), self.id.clone()).await;
347            rv.set_rtcp_channel(self.rtcp_sender_channel.clone());
348            let recv_kind = rv.kind();
349            let session_out = self.session.clone();
350            let stream_id = track.stream_id().await;
351
352            let receivers_out = self.receivers.clone();
353            let stats_out = self.stats.clone();
354            let del_handler_out = self.on_add_receiver_track_handler.clone();
355            let track_id_out = track_id.clone();
356            let track_ssrc = track.ssrc();
357            rv.on_close_handler(Box::new(move || {
358                //let stats_in = Arc::clone(&stats_out);
359                let session_in = session_out.clone();
360                let stream_id_in = stream_id.clone();
361                let track_id_in = track_id_out.clone();
362
363                let receivers_in = receivers_out.clone();
364                let stats_in = stats_out.clone();
365                let del_handler_in = del_handler_out.clone();
366
367                Box::pin(async move {
368                    if with_status {
369                        // match track.kind() {
370                        //     RTPCodecType::Video => {
371                        //         //todo
372                        //     }
373                        //     _ => {
374                        //         //todo
375                        //     }
376                        // }
377                    }
378                    if recv_kind == RTPCodecType::Audio {
379                        if let Some(audio_observer) = session_in.audio_obserber() {
380                            audio_observer
381                                .lock()
382                                .await
383                                .remove_stream(stream_id_in)
384                                .await;
385                        }
386                    }
387                    delete_receiver(
388                        &track_id_in,
389                        &track_ssrc,
390                        del_handler_in,
391                        receivers_in,
392                        stats_in,
393                    )
394                    .await;
395                })
396            }))
397            .await;
398            result_receiver = Arc::new(rv);
399            receivers.insert(track_id, result_receiver.clone());
400            publish = true;
401
402            if let Some(f) = &mut *self.on_add_receiver_track_handler.lock().await {
403                f(result_receiver.clone());
404            }
405        }
406
407        let layer = result_receiver
408            .add_up_track(
409                track.clone(),
410                buffer.clone(),
411                self.config.simulcast.best_quality_first,
412            )
413            .await;
414
415        if let Some(layer_val) = layer {
416            let receiver_clone = result_receiver.clone();
417            tokio::spawn(async move { receiver_clone.write_rtp(layer_val).await });
418        }
419
420        buffer
421            .bind(
422                receiver.get_parameters().await,
423                BufferOptions {
424                    max_bitrate: self.config.max_bandwidth,
425                },
426            )
427            .await;
428
429        let track_clone = track.clone();
430        let buffer_clone = buffer.clone();
431
432        tokio::spawn(async move {
433            let mut b = vec![0u8; 1500];
434
435            while let Ok((n, _)) = track_clone.read(&mut b).await {
436                if let Err(err) = buffer_clone.write(&b[..n]).await {
437                    log::error!("write error: {}", err);
438                }
439            }
440
441            Result::<()>::Ok(())
442        });
443        (result_receiver, publish)
444    }
445
446    async fn add_down_tracks(
447        &self,
448        s: Arc<Subscriber>,
449        r: Option<Arc<dyn Receiver + Send + Sync>>,
450    ) -> Result<()> {
451        if s.no_auto_subscribe {
452            return Ok(());
453        }
454
455        if let Some(receiver) = r {
456            self.add_down_track(s.clone(), receiver).await?;
457            s.negotiate().await?;
458            return Ok(());
459        }
460
461        let mut recs = Vec::new();
462        {
463            let mut receivers = self.receivers.lock().await;
464            for receiver in (*receivers).values_mut() {
465                recs.push(receiver.clone())
466            }
467        }
468
469        if !recs.is_empty() {
470            for val in recs {
471                self.add_down_track(s.clone(), val.clone()).await?;
472            }
473            s.negotiate().await?;
474        }
475
476        Ok(())
477    }
478
479    async fn add_down_track(
480        &self,
481        s: Arc<Subscriber>,
482        r: Arc<dyn Receiver + Send + Sync>,
483    ) -> Result<Option<Arc<DownTrack>>> {
484        let recv = r.clone();
485        let downtracks = s.get_downtracks(recv.stream_id()).await;
486        if let Some(downtracks_data) = downtracks {
487            for dt in downtracks_data {
488                if dt.id() == recv.track_id() {
489                    return Ok(Some(dt));
490                }
491            }
492        }
493        let codec = recv.codec();
494        s.me.lock()
495            .await
496            .register_codec(codec.clone(), recv.kind())?;
497
498        let codec_capability = RTCRtpCodecCapability {
499            mime_type: codec.capability.mime_type,
500            clock_rate: codec.capability.clock_rate,
501            channels: codec.capability.channels,
502            sdp_fmtp_line: codec.capability.sdp_fmtp_line,
503            rtcp_feedback: vec![
504                RTCPFeedback {
505                    typ: String::from("goog-remb"),
506                    parameter: String::from(""),
507                },
508                RTCPFeedback {
509                    typ: String::from("nack"),
510                    parameter: String::from(""),
511                },
512                RTCPFeedback {
513                    typ: String::from("nack"),
514                    parameter: String::from("pli"),
515                },
516            ],
517        };
518
519        let down_track_local =
520            DownTrackInternal::new(codec_capability, r.clone(), self.config.max_packet_track).await;
521
522        let down_track_arc = Arc::new(down_track_local);
523        let transceiver =
524            s.pc.add_transceiver_from_track(
525                down_track_arc.clone(),
526                &[RTCRtpTransceiverInit {
527                    direction: RTCRtpTransceiverDirection::Sendonly,
528                    send_encodings: Vec::new(),
529                }],
530            )
531            .await?;
532
533        let mut down_track = DownTrack::new_track_local(s.id.clone(), down_track_arc);
534        down_track.set_transceiver(transceiver.clone());
535
536        let down_track_arc = Arc::new(down_track);
537
538        let s_out = s.clone();
539        let r_out = r.clone();
540        let transceiver_out = transceiver.clone();
541        let down_track_arc_out = down_track_arc.clone();
542
543        down_track_arc
544            .on_close_handler(Box::new(move || {
545                let s_in = s_out.clone();
546                let r_in = r_out.clone();
547                let transceiver_in = transceiver_out.clone();
548                let down_track_arc_in = down_track_arc_out.clone();
549                Box::pin(async move {
550                    if s_in.pc.connection_state() != RTCPeerConnectionState::Closed {
551                        let rv = s_in
552                            .pc
553                            .remove_track(&transceiver_in.sender().await.unwrap())
554                            .await;
555                        match rv {
556                            Ok(_) => {
557                                s_in.remove_down_track(r_in.stream_id(), down_track_arc_in)
558                                    .await;
559                                log::info!("RemoveDownTrack Negotiate");
560                                if let Err(err) = s_in.negotiate().await {
561                                    log::error!("negotiate err:{} ", err);
562                                }
563                            }
564                            Err(err) => {
565                                if err == RTCError::ErrConnectionClosed {
566                                    // return;
567                                }
568                            }
569                        }
570                    }
571                })
572            }))
573            .await;
574
575        let s_out_1 = s.clone();
576        let r_out_1 = r.clone();
577        down_track_arc
578            .on_bind(Box::new(move || {
579                let s_in = s_out_1.clone();
580                let r_in = r_out_1.clone();
581
582                Box::pin(async move {
583                    tokio::spawn(async move {
584                        s_in.send_stream_down_track_reports(r_in.stream_id()).await;
585                    });
586                })
587            }))
588            .await;
589
590        s.add_down_track(recv.stream_id(), down_track_arc.clone())
591            .await;
592        recv.add_down_track(down_track_arc, self.config.simulcast.best_quality_first)
593            .await;
594
595        Ok(None)
596    }
597
598    async fn set_rtcp_writer(&self, writer: RtcpWriterFn) {
599        let mut handler = self.rtcp_writer_handler.lock().await;
600        *handler = Some(writer);
601    }
602
603    async fn send_rtcp(&self) {
604        loop {
605            let mut rtcp_receiver = self.rtcp_receiver_channel.lock().await;
606            let mut stop_receiver = self.stop_receiver_channel.lock().await;
607            tokio::select! {
608              data = rtcp_receiver.recv() => {
609                 if let Some(val) = data{
610                    if let Some(f) = &mut *self.rtcp_writer_handler.lock().await {
611                        f(val);
612                    }
613                }
614              }
615              _data = stop_receiver.recv() => {
616                return ;
617              }
618            };
619        }
620    }
621}
622
623async fn delete_receiver(
624    track: &String,
625    ssrc: &u32,
626    del_handler: Arc<Mutex<Option<OnDelReciverTrackFn>>>,
627    receivers: Arc<Mutex<HashMap<String, Arc<dyn Receiver + Send + Sync>>>>,
628    stats: Arc<Mutex<HashMap<u32, Stream>>>,
629) {
630    if let Some(f) = &mut *del_handler.lock().await {
631        if let Some(track) = receivers.lock().await.get(track) {
632            f(track.clone());
633        }
634    }
635    receivers.lock().await.remove(track);
636    stats.lock().await.remove(ssrc);
637}