Skip to main content

active_call/media/
stream.rs

1use crate::event::{EventSender, SessionEvent};
2use crate::media::dtmf::DtmfDetector;
3use crate::media::volume_control::HoldProcessor;
4use crate::media::{AudioFrame, Samples, TrackId};
5use crate::media::{
6    processor::Processor,
7    recorder::{Recorder, RecorderOption},
8    track::{Track, TrackPacketReceiver, TrackPacketSender},
9};
10use anyhow::Result;
11use std::collections::{HashMap, HashSet};
12use std::path::Path;
13use std::time::Duration;
14use tokio::task::JoinHandle;
15use tokio::{
16    select,
17    sync::{Mutex, mpsc},
18};
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, info, warn};
21use uuid;
22
23pub struct MediaStream {
24    id: String,
25    pub cancel_token: CancellationToken,
26    recorder_option: Mutex<Option<RecorderOption>>,
27    tracks: Mutex<HashMap<TrackId, (Box<dyn Track>, DtmfDetector)>>,
28    suppressed_sources: Mutex<HashSet<TrackId>>,
29    event_sender: EventSender,
30    pub packet_sender: TrackPacketSender,
31    packet_receiver: Mutex<Option<TrackPacketReceiver>>,
32    recorder_sender: mpsc::UnboundedSender<AudioFrame>,
33    recorder_receiver: Mutex<Option<mpsc::UnboundedReceiver<AudioFrame>>>,
34    recorder_handle: Mutex<Option<JoinHandle<()>>>,
35}
36
/// Track id of the callee leg; it never receives packets from [`QUEUE_HOLD_TRACK_ID`].
const CALLEE_TRACK_ID: &'static str = "callee-track";
/// Track id of the queue-hold source; its packets are not forwarded to the callee track.
const QUEUE_HOLD_TRACK_ID: &'static str = "queue-hold-track";
39
40pub struct MediaStreamBuilder {
41    cancel_token: Option<CancellationToken>,
42    id: Option<String>,
43    event_sender: EventSender,
44    recorder_config: Option<RecorderOption>,
45}
46
47impl MediaStreamBuilder {
48    pub fn new(event_sender: EventSender) -> Self {
49        Self {
50            id: Some(format!("ms:{}", uuid::Uuid::new_v4())),
51            cancel_token: None,
52            event_sender,
53            recorder_config: None,
54        }
55    }
56    pub fn with_id(mut self, id: String) -> Self {
57        self.id = Some(id);
58        self
59    }
60
61    pub fn with_cancel_token(mut self, cancel_token: CancellationToken) -> Self {
62        self.cancel_token = Some(cancel_token);
63        self
64    }
65
66    pub fn with_recorder_config(mut self, recorder_config: RecorderOption) -> Self {
67        self.recorder_config = Some(recorder_config);
68        self
69    }
70
71    pub fn build(self) -> MediaStream {
72        let cancel_token = self
73            .cancel_token
74            .unwrap_or_else(|| CancellationToken::new());
75        let tracks = Mutex::new(HashMap::new());
76        let (track_packet_sender, track_packet_receiver) = mpsc::unbounded_channel();
77        let (recorder_sender, recorder_receiver) = mpsc::unbounded_channel();
78        MediaStream {
79            id: self.id.unwrap_or_default(),
80            cancel_token,
81            recorder_option: Mutex::new(self.recorder_config),
82            tracks,
83            suppressed_sources: Mutex::new(HashSet::new()),
84            event_sender: self.event_sender,
85            packet_sender: track_packet_sender,
86            packet_receiver: Mutex::new(Some(track_packet_receiver)),
87            recorder_sender,
88            recorder_receiver: Mutex::new(Some(recorder_receiver)),
89            recorder_handle: Mutex::new(None),
90        }
91    }
92}
93
94impl MediaStream {
95    pub async fn serve(&self) -> Result<()> {
96        let packet_receiver = match self.packet_receiver.lock().await.take() {
97            Some(receiver) => receiver,
98            None => {
99                warn!(
100                    session_id = self.id,
101                    "MediaStream::serve() called multiple times, stream already serving"
102                );
103                return Ok(());
104            }
105        };
106        self.start_recorder().await.ok();
107        info!(session_id = self.id, "mediastream serving");
108        select! {
109            _ = self.cancel_token.cancelled() => {}
110            r = self.handle_forward_track(packet_receiver) => {
111                info!(session_id = self.id, "track packet receiver stopped {:?}", r);
112            }
113        }
114        Ok(())
115    }
116
117    pub fn stop(&self, _reason: Option<String>, _initiator: Option<String>) {
118        self.cancel_token.cancel()
119    }
120
121    pub async fn cleanup(&self) -> Result<()> {
122        self.cancel_token.cancel();
123        if let Some(recorder_handle) = self.recorder_handle.lock().await.take() {
124            if let Ok(Ok(_)) = tokio::time::timeout(Duration::from_secs(30), recorder_handle).await
125            {
126                info!(session_id = self.id, "recorder stopped");
127            } else {
128                warn!(session_id = self.id, "recorder timeout");
129            }
130        }
131        Ok(())
132    }
133
134    pub async fn update_recorder_option(&self, recorder_config: RecorderOption) {
135        *self.recorder_option.lock().await = Some(recorder_config);
136        self.start_recorder().await.ok();
137    }
138
139    pub async fn remove_track(&self, id: &TrackId, graceful: bool) {
140        let track_entry = { self.tracks.lock().await.remove(id) };
141        if let Some((track, _)) = track_entry {
142            self.suppressed_sources.lock().await.remove(id);
143            let res = if !graceful {
144                track.stop().await
145            } else {
146                track.stop_graceful().await
147            };
148            match res {
149                Ok(_) => {}
150                Err(e) => {
151                    warn!(session_id = self.id, "failed to stop track: {}", e);
152                }
153            }
154        }
155    }
156    pub async fn update_remote_description(
157        &self,
158        track_id: &TrackId,
159        answer: &String,
160    ) -> Result<()> {
161        let track_entry = { self.tracks.lock().await.remove(track_id) };
162        if let Some((mut track, dtmf)) = track_entry {
163            let res = track.update_remote_description(answer).await;
164            self.tracks
165                .lock()
166                .await
167                .insert(track_id.clone(), (track, dtmf));
168            res?;
169        }
170        Ok(())
171    }
172
173    pub async fn update_remote_description_force(
174        &self,
175        track_id: &TrackId,
176        answer: &String,
177    ) -> Result<()> {
178        let track_entry = { self.tracks.lock().await.remove(track_id) };
179        if let Some((mut track, dtmf)) = track_entry {
180            let res = track.update_remote_description_force(answer).await;
181            self.tracks
182                .lock()
183                .await
184                .insert(track_id.clone(), (track, dtmf));
185            res?;
186        }
187        Ok(())
188    }
189
190    pub async fn handshake(
191        &self,
192        track_id: &TrackId,
193        offer: String,
194        timeout: Option<Duration>,
195    ) -> Result<String> {
196        let track_entry = { self.tracks.lock().await.remove(track_id) };
197        if let Some((mut track, dtmf)) = track_entry {
198            let res = track.handshake(offer, timeout).await;
199            self.tracks
200                .lock()
201                .await
202                .insert(track_id.clone(), (track, dtmf));
203            res
204        } else {
205            anyhow::bail!("track not found: {}", track_id)
206        }
207    }
208
209    pub async fn update_track(&self, mut track: Box<dyn Track>, play_id: Option<String>) {
210        self.remove_track(track.id(), false).await;
211        if self.recorder_option.lock().await.is_some() {
212            track.insert_processor(Box::new(RecorderProcessor::new(
213                self.recorder_sender.clone(),
214            )));
215        }
216        match track
217            .start(self.event_sender.clone(), self.packet_sender.clone())
218            .await
219        {
220            Ok(_) => {
221                info!(session_id = self.id, track_id = track.id(), "track started");
222                let track_id = track.id().clone();
223                self.tracks
224                    .lock()
225                    .await
226                    .insert(track_id.clone(), (track, DtmfDetector::new()));
227                self.event_sender
228                    .send(SessionEvent::TrackStart {
229                        track_id,
230                        timestamp: crate::media::get_timestamp(),
231                        play_id,
232                    })
233                    .ok();
234            }
235            Err(e) => {
236                warn!(
237                    session_id = self.id,
238                    track_id = track.id(),
239                    play_id = play_id.as_deref(),
240                    "Failed to start track: {}",
241                    e
242                );
243            }
244        }
245    }
246
247    pub async fn mute_track(&self, id: Option<TrackId>) {
248        if let Some(id) = id {
249            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
250                MuteProcessor::mute_track(track.as_mut());
251            }
252        } else {
253            for (track, _) in self.tracks.lock().await.values_mut() {
254                MuteProcessor::mute_track(track.as_mut());
255            }
256        }
257    }
258
259    pub async fn unmute_track(&self, id: Option<TrackId>) {
260        if let Some(id) = id {
261            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
262                MuteProcessor::unmute_track(track.as_mut());
263            }
264        } else {
265            for (track, _) in self.tracks.lock().await.values_mut() {
266                MuteProcessor::unmute_track(track.as_mut());
267            }
268        }
269    }
270
271    pub async fn hold_track(&self, id: Option<TrackId>) {
272        if let Some(id) = id {
273            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
274                HoldTrack::hold_track(track.as_mut());
275            }
276        } else {
277            for (track, _) in self.tracks.lock().await.values_mut() {
278                HoldTrack::hold_track(track.as_mut());
279            }
280        }
281    }
282
283    pub async fn resume_track(&self, id: Option<TrackId>) {
284        if let Some(id) = id {
285            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
286                HoldTrack::resume_track(track.as_mut());
287            }
288        } else {
289            for (track, _) in self.tracks.lock().await.values_mut() {
290                HoldTrack::resume_track(track.as_mut());
291            }
292        }
293    }
294
295    pub async fn suppress_forwarding(&self, track_id: &TrackId) {
296        self.suppressed_sources
297            .lock()
298            .await
299            .insert(track_id.clone());
300    }
301
302    pub async fn resume_forwarding(&self, track_id: &TrackId) {
303        self.suppressed_sources.lock().await.remove(track_id);
304    }
305
306    pub async fn remove_processor<T: 'static>(&self, track_id: &TrackId) -> Result<()> {
307        if let Some((track, _)) = self.tracks.lock().await.get_mut(track_id) {
308            track.as_mut().processor_chain().remove_processor::<T>();
309            Ok(())
310        } else {
311            Err(anyhow::anyhow!("Track {} not found", track_id))
312        }
313    }
314
315    pub async fn append_processor(
316        &self,
317        track_id: &TrackId,
318        processor: Box<dyn crate::media::processor::Processor>,
319    ) -> Result<()> {
320        if let Some((track, _)) = self.tracks.lock().await.get_mut(track_id) {
321            track.as_mut().processor_chain().append_processor(processor);
322            Ok(())
323        } else {
324            Err(anyhow::anyhow!("Track {} not found", track_id))
325        }
326    }
327}
328
329#[derive(Clone)]
330pub struct RecorderProcessor {
331    sender: mpsc::UnboundedSender<AudioFrame>,
332}
333
334impl RecorderProcessor {
335    pub fn new(sender: mpsc::UnboundedSender<AudioFrame>) -> Self {
336        Self { sender }
337    }
338}
339
340impl Processor for RecorderProcessor {
341    fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
342        let frame_clone = frame.clone();
343        let _ = self.sender.send(frame_clone);
344        Ok(())
345    }
346}
347
348impl MediaStream {
349    pub async fn start_recorder(&self) -> Result<()> {
350        let recorder_option = self.recorder_option.lock().await.clone();
351        if let Some(recorder_option) = recorder_option {
352            if recorder_option.recorder_file.is_empty() {
353                warn!(
354                    session_id = self.id,
355                    "recorder file is empty, skipping recorder start"
356                );
357                return Ok(());
358            }
359            let recorder_receiver = match self.recorder_receiver.lock().await.take() {
360                Some(receiver) => receiver,
361                None => {
362                    return Ok(());
363                }
364            };
365            let cancel_token = self.cancel_token.child_token();
366            let session_id_clone = self.id.clone();
367
368            info!(
369                session_id = session_id_clone,
370                sample_rate = recorder_option.samplerate,
371                ptime = recorder_option.ptime,
372                "start recorder",
373            );
374
375            let recorder_handle = crate::spawn(async move {
376                let recorder_file = recorder_option.recorder_file.clone();
377                let recorder =
378                    Recorder::new(cancel_token, session_id_clone.clone(), recorder_option);
379                match recorder
380                    .process_recording(Path::new(&recorder_file), recorder_receiver)
381                    .await
382                {
383                    Ok(_) => {}
384                    Err(e) => {
385                        warn!(
386                            session_id = session_id_clone,
387                            "Failed to process recorder: {}", e
388                        );
389                    }
390                }
391            });
392            *self.recorder_handle.lock().await = Some(recorder_handle);
393        }
394        Ok(())
395    }
396
397    async fn handle_forward_track(&self, mut packet_receiver: TrackPacketReceiver) {
398        let event_sender = self.event_sender.clone();
399        while let Some(packet) = packet_receiver.recv().await {
400            let suppressed = {
401                self.suppressed_sources
402                    .lock()
403                    .await
404                    .contains(&packet.track_id)
405            };
406            // Process the packet with each track
407            for (track, dtmf_detector) in self.tracks.lock().await.values_mut() {
408                if track.id() == &packet.track_id {
409                    match &packet.samples {
410                        Samples::RTP {
411                            payload_type,
412                            payload,
413                            ..
414                        } => {
415                            if let Some(digit) = dtmf_detector.detect_rtp(*payload_type, payload) {
416                                debug!(track_id = track.id(), digit, "DTMF detected");
417                                event_sender
418                                    .send(SessionEvent::Dtmf {
419                                        track_id: packet.track_id.to_string(),
420                                        timestamp: packet.timestamp,
421                                        digit,
422                                    })
423                                    .ok();
424                            }
425                        }
426                        _ => {}
427                    }
428                    continue;
429                }
430                if suppressed {
431                    continue;
432                }
433                if packet.track_id == QUEUE_HOLD_TRACK_ID && track.id() == CALLEE_TRACK_ID {
434                    continue;
435                }
436                if let Err(e) = track.send_packet(&packet).await {
437                    warn!(
438                        id = track.id(),
439                        "media_stream: Failed to send packet to track: {}", e
440                    );
441                }
442            }
443        }
444    }
445}
446
/// Stateless processor inserted into a track's processor chain to mute it.
pub struct MuteProcessor;
448
449impl MuteProcessor {
450    pub fn mute_track(track: &mut dyn Track) {
451        let chain = track.processor_chain();
452        if !chain.has_processor::<MuteProcessor>() {
453            chain.insert_processor(Box::new(MuteProcessor));
454        }
455    }
456
457    pub fn unmute_track(track: &mut dyn Track) {
458        let chain = track.processor_chain();
459        chain.remove_processor::<MuteProcessor>();
460    }
461}
462
463impl Processor for MuteProcessor {
464    fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
465        match &mut frame.samples {
466            Samples::PCM { samples } => {
467                samples.fill(0);
468            }
469            // discard DTMF frames
470            Samples::RTP { payload_type, .. } if *payload_type >= 96 && *payload_type <= 127 => {
471                frame.samples = Samples::Empty;
472            }
473            _ => {}
474        }
475        Ok(())
476    }
477}
478
/// Namespace for putting tracks on hold and resuming them via `HoldProcessor`.
pub struct HoldTrack;
480
481impl HoldTrack {
482    pub fn hold_track(track: &mut dyn Track) {
483        let chain = track.processor_chain();
484        // Remove existing processor if present
485        chain.remove_processor::<HoldProcessor>();
486        // Add a new processor with hold state set to true
487        let processor = HoldProcessor::new();
488        processor.set_hold(true);
489        chain.insert_processor(Box::new(processor));
490    }
491
492    pub fn resume_track(track: &mut dyn Track) {
493        let chain = track.processor_chain();
494        // Simply remove the hold processor to resume normal operation
495        chain.remove_processor::<HoldProcessor>();
496    }
497}