Skip to main content

active_call/media/
stream.rs

1use crate::event::{EventSender, SessionEvent};
2use crate::media::dtmf::DtmfDetector;
3use crate::media::{AudioFrame, Samples, TrackId};
4use crate::media::{
5    processor::Processor,
6    recorder::{Recorder, RecorderOption},
7    track::{Track, TrackPacketReceiver, TrackPacketSender},
8};
9use anyhow::Result;
10use std::collections::{HashMap, HashSet};
11use std::path::Path;
12use std::time::Duration;
13use tokio::task::JoinHandle;
14use tokio::{
15    select,
16    sync::{Mutex, mpsc},
17};
18use tokio_util::sync::CancellationToken;
19use tracing::{debug, info, warn};
20use uuid;
21
22pub struct MediaStream {
23    id: String,
24    pub cancel_token: CancellationToken,
25    recorder_option: Mutex<Option<RecorderOption>>,
26    tracks: Mutex<HashMap<TrackId, (Box<dyn Track>, DtmfDetector)>>,
27    suppressed_sources: Mutex<HashSet<TrackId>>,
28    event_sender: EventSender,
29    pub packet_sender: TrackPacketSender,
30    packet_receiver: Mutex<Option<TrackPacketReceiver>>,
31    recorder_sender: mpsc::UnboundedSender<AudioFrame>,
32    recorder_receiver: Mutex<Option<mpsc::UnboundedReceiver<AudioFrame>>>,
33    recorder_handle: Mutex<Option<JoinHandle<()>>>,
34}
35
/// Track id of the queue hold-music source; the forwarding loop never sends its
/// packets to the track identified by `CALLEE_TRACK_ID`.
const QUEUE_HOLD_TRACK_ID: &str = "queue-hold-track";
/// Track id of the callee track, which must not receive queue hold-music packets.
const CALLEE_TRACK_ID: &str = "callee-track";
38
39pub struct MediaStreamBuilder {
40    cancel_token: Option<CancellationToken>,
41    id: Option<String>,
42    event_sender: EventSender,
43    recorder_config: Option<RecorderOption>,
44}
45
46impl MediaStreamBuilder {
47    pub fn new(event_sender: EventSender) -> Self {
48        Self {
49            id: Some(format!("ms:{}", uuid::Uuid::new_v4())),
50            cancel_token: None,
51            event_sender,
52            recorder_config: None,
53        }
54    }
55    pub fn with_id(mut self, id: String) -> Self {
56        self.id = Some(id);
57        self
58    }
59
60    pub fn with_cancel_token(mut self, cancel_token: CancellationToken) -> Self {
61        self.cancel_token = Some(cancel_token);
62        self
63    }
64
65    pub fn with_recorder_config(mut self, recorder_config: RecorderOption) -> Self {
66        self.recorder_config = Some(recorder_config);
67        self
68    }
69
70    pub fn build(self) -> MediaStream {
71        let cancel_token = self
72            .cancel_token
73            .unwrap_or_else(|| CancellationToken::new());
74        let tracks = Mutex::new(HashMap::new());
75        let (track_packet_sender, track_packet_receiver) = mpsc::unbounded_channel();
76        let (recorder_sender, recorder_receiver) = mpsc::unbounded_channel();
77        MediaStream {
78            id: self.id.unwrap_or_default(),
79            cancel_token,
80            recorder_option: Mutex::new(self.recorder_config),
81            tracks,
82            suppressed_sources: Mutex::new(HashSet::new()),
83            event_sender: self.event_sender,
84            packet_sender: track_packet_sender,
85            packet_receiver: Mutex::new(Some(track_packet_receiver)),
86            recorder_sender,
87            recorder_receiver: Mutex::new(Some(recorder_receiver)),
88            recorder_handle: Mutex::new(None),
89        }
90    }
91}
92
93impl MediaStream {
94    pub async fn serve(&self) -> Result<()> {
95        let packet_receiver = match self.packet_receiver.lock().await.take() {
96            Some(receiver) => receiver,
97            None => {
98                warn!(
99                    session_id = self.id,
100                    "MediaStream::serve() called multiple times, stream already serving"
101                );
102                return Ok(());
103            }
104        };
105        self.start_recorder().await.ok();
106        info!(session_id = self.id, "mediastream serving");
107        select! {
108            _ = self.cancel_token.cancelled() => {}
109            r = self.handle_forward_track(packet_receiver) => {
110                info!(session_id = self.id, "track packet receiver stopped {:?}", r);
111            }
112        }
113        Ok(())
114    }
115
116    pub fn stop(&self, _reason: Option<String>, _initiator: Option<String>) {
117        self.cancel_token.cancel()
118    }
119
120    pub async fn cleanup(&self) -> Result<()> {
121        self.cancel_token.cancel();
122        if let Some(recorder_handle) = self.recorder_handle.lock().await.take() {
123            if let Ok(Ok(_)) = tokio::time::timeout(Duration::from_secs(30), recorder_handle).await
124            {
125                info!(session_id = self.id, "recorder stopped");
126            } else {
127                warn!(session_id = self.id, "recorder timeout");
128            }
129        }
130        Ok(())
131    }
132
133    pub async fn update_recorder_option(&self, recorder_config: RecorderOption) {
134        *self.recorder_option.lock().await = Some(recorder_config);
135        self.start_recorder().await.ok();
136    }
137
138    pub async fn remove_track(&self, id: &TrackId, graceful: bool) {
139        let track_entry = { self.tracks.lock().await.remove(id) };
140        if let Some((track, _)) = track_entry {
141            self.suppressed_sources.lock().await.remove(id);
142            let res = if !graceful {
143                track.stop().await
144            } else {
145                track.stop_graceful().await
146            };
147            match res {
148                Ok(_) => {}
149                Err(e) => {
150                    warn!(session_id = self.id, "failed to stop track: {}", e);
151                }
152            }
153        }
154    }
155    pub async fn update_remote_description(
156        &self,
157        track_id: &TrackId,
158        answer: &String,
159    ) -> Result<()> {
160        let track_entry = { self.tracks.lock().await.remove(track_id) };
161        if let Some((mut track, dtmf)) = track_entry {
162            let res = track.update_remote_description(answer).await;
163            self.tracks
164                .lock()
165                .await
166                .insert(track_id.clone(), (track, dtmf));
167            res?;
168        }
169        Ok(())
170    }
171
172    pub async fn update_remote_description_force(
173        &self,
174        track_id: &TrackId,
175        answer: &String,
176    ) -> Result<()> {
177        let track_entry = { self.tracks.lock().await.remove(track_id) };
178        if let Some((mut track, dtmf)) = track_entry {
179            let res = track.update_remote_description_force(answer).await;
180            self.tracks
181                .lock()
182                .await
183                .insert(track_id.clone(), (track, dtmf));
184            res?;
185        }
186        Ok(())
187    }
188
189    pub async fn handshake(
190        &self,
191        track_id: &TrackId,
192        offer: String,
193        timeout: Option<Duration>,
194    ) -> Result<String> {
195        let track_entry = { self.tracks.lock().await.remove(track_id) };
196        if let Some((mut track, dtmf)) = track_entry {
197            let res = track.handshake(offer, timeout).await;
198            self.tracks
199                .lock()
200                .await
201                .insert(track_id.clone(), (track, dtmf));
202            res
203        } else {
204            anyhow::bail!("track not found: {}", track_id)
205        }
206    }
207
208    pub async fn update_track(&self, mut track: Box<dyn Track>, play_id: Option<String>) {
209        self.remove_track(track.id(), false).await;
210        if self.recorder_option.lock().await.is_some() {
211            track.insert_processor(Box::new(RecorderProcessor::new(
212                self.recorder_sender.clone(),
213            )));
214        }
215        match track
216            .start(self.event_sender.clone(), self.packet_sender.clone())
217            .await
218        {
219            Ok(_) => {
220                info!(session_id = self.id, track_id = track.id(), "track started");
221                let track_id = track.id().clone();
222                self.tracks
223                    .lock()
224                    .await
225                    .insert(track_id.clone(), (track, DtmfDetector::new()));
226                self.event_sender
227                    .send(SessionEvent::TrackStart {
228                        track_id,
229                        timestamp: crate::media::get_timestamp(),
230                        play_id,
231                    })
232                    .ok();
233            }
234            Err(e) => {
235                warn!(
236                    session_id = self.id,
237                    track_id = track.id(),
238                    play_id = play_id.as_deref(),
239                    "Failed to start track: {}",
240                    e
241                );
242            }
243        }
244    }
245
246    pub async fn mute_track(&self, id: Option<TrackId>) {
247        if let Some(id) = id {
248            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
249                MuteProcessor::mute_track(track.as_mut());
250            }
251        } else {
252            for (track, _) in self.tracks.lock().await.values_mut() {
253                MuteProcessor::mute_track(track.as_mut());
254            }
255        }
256    }
257
258    pub async fn unmute_track(&self, id: Option<TrackId>) {
259        if let Some(id) = id {
260            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
261                MuteProcessor::unmute_track(track.as_mut());
262            }
263        } else {
264            for (track, _) in self.tracks.lock().await.values_mut() {
265                MuteProcessor::unmute_track(track.as_mut());
266            }
267        }
268    }
269
270    pub async fn suppress_forwarding(&self, track_id: &TrackId) {
271        self.suppressed_sources
272            .lock()
273            .await
274            .insert(track_id.clone());
275    }
276
277    pub async fn resume_forwarding(&self, track_id: &TrackId) {
278        self.suppressed_sources.lock().await.remove(track_id);
279    }
280}
281
282#[derive(Clone)]
283pub struct RecorderProcessor {
284    sender: mpsc::UnboundedSender<AudioFrame>,
285}
286
287impl RecorderProcessor {
288    pub fn new(sender: mpsc::UnboundedSender<AudioFrame>) -> Self {
289        Self { sender }
290    }
291}
292
293impl Processor for RecorderProcessor {
294    fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
295        let frame_clone = frame.clone();
296        let _ = self.sender.send(frame_clone);
297        Ok(())
298    }
299}
300
301impl MediaStream {
302    pub async fn start_recorder(&self) -> Result<()> {
303        let recorder_option = self.recorder_option.lock().await.clone();
304        if let Some(recorder_option) = recorder_option {
305            if recorder_option.recorder_file.is_empty() {
306                warn!(
307                    session_id = self.id,
308                    "recorder file is empty, skipping recorder start"
309                );
310                return Ok(());
311            }
312            let recorder_receiver = match self.recorder_receiver.lock().await.take() {
313                Some(receiver) => receiver,
314                None => {
315                    return Ok(());
316                }
317            };
318            let cancel_token = self.cancel_token.child_token();
319            let session_id_clone = self.id.clone();
320
321            info!(
322                session_id = session_id_clone,
323                sample_rate = recorder_option.samplerate,
324                ptime = recorder_option.ptime,
325                "start recorder",
326            );
327
328            let recorder_handle = crate::spawn(async move {
329                let recorder_file = recorder_option.recorder_file.clone();
330                let recorder =
331                    Recorder::new(cancel_token, session_id_clone.clone(), recorder_option);
332                match recorder
333                    .process_recording(Path::new(&recorder_file), recorder_receiver)
334                    .await
335                {
336                    Ok(_) => {}
337                    Err(e) => {
338                        warn!(
339                            session_id = session_id_clone,
340                            "Failed to process recorder: {}", e
341                        );
342                    }
343                }
344            });
345            *self.recorder_handle.lock().await = Some(recorder_handle);
346        }
347        Ok(())
348    }
349
350    async fn handle_forward_track(&self, mut packet_receiver: TrackPacketReceiver) {
351        let event_sender = self.event_sender.clone();
352        while let Some(packet) = packet_receiver.recv().await {
353            let suppressed = {
354                self.suppressed_sources
355                    .lock()
356                    .await
357                    .contains(&packet.track_id)
358            };
359            // Process the packet with each track
360            for (track, dtmf_detector) in self.tracks.lock().await.values_mut() {
361                if track.id() == &packet.track_id {
362                    match &packet.samples {
363                        Samples::RTP {
364                            payload_type,
365                            payload,
366                            ..
367                        } => {
368                            if let Some(digit) = dtmf_detector.detect_rtp(*payload_type, payload) {
369                                debug!(track_id = track.id(), digit, "DTMF detected");
370                                event_sender
371                                    .send(SessionEvent::Dtmf {
372                                        track_id: packet.track_id.to_string(),
373                                        timestamp: packet.timestamp,
374                                        digit,
375                                    })
376                                    .ok();
377                            }
378                        }
379                        _ => {}
380                    }
381                    continue;
382                }
383                if suppressed {
384                    continue;
385                }
386                if packet.track_id == QUEUE_HOLD_TRACK_ID && track.id() == CALLEE_TRACK_ID {
387                    continue;
388                }
389                if let Err(e) = track.send_packet(&packet).await {
390                    warn!(
391                        id = track.id(),
392                        "media_stream: Failed to send packet to track: {}", e
393                    );
394                }
395            }
396        }
397    }
398}
399
/// Stateless processor that silences frames while it sits in a track's processor chain.
///
/// The type carries no data, so it derives the cheap standard traits; in particular
/// `Debug` makes it printable when a processor chain is inspected.
#[derive(Debug, Clone, Copy, Default)]
pub struct MuteProcessor;
401
402impl MuteProcessor {
403    pub fn mute_track(track: &mut dyn Track) {
404        let chain = track.processor_chain();
405        if !chain.has_processor::<MuteProcessor>() {
406            chain.insert_processor(Box::new(MuteProcessor));
407        }
408    }
409
410    pub fn unmute_track(track: &mut dyn Track) {
411        let chain = track.processor_chain();
412        chain.remove_processor::<MuteProcessor>();
413    }
414}
415
416impl Processor for MuteProcessor {
417    fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
418        match &mut frame.samples {
419            Samples::PCM { samples } => {
420                samples.fill(0);
421            }
422            // discard DTMF frames
423            Samples::RTP { payload_type, .. } if *payload_type >= 96 && *payload_type <= 127 => {
424                frame.samples = Samples::Empty;
425            }
426            _ => {}
427        }
428        Ok(())
429    }
430}