Skip to main content

active_call/media/
stream.rs

1use crate::event::{EventSender, SessionEvent};
2use crate::media::dtmf::DtmfDetector;
3use crate::media::volume_control::HoldProcessor;
4use crate::media::{AudioFrame, Samples, TrackId};
5use crate::media::{
6    processor::Processor,
7    recorder::{Recorder, RecorderOption},
8    track::{Track, TrackPacketReceiver, TrackPacketSender},
9};
10use anyhow::Result;
11use std::collections::{HashMap, HashSet};
12use std::path::Path;
13use std::time::Duration;
14use tokio::task::JoinHandle;
15use tokio::{
16    select,
17    sync::{Mutex, mpsc},
18};
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, info, warn};
21use uuid;
22
23pub struct MediaStream {
24    id: String,
25    pub cancel_token: CancellationToken,
26    recorder_option: Mutex<Option<RecorderOption>>,
27    tracks: Mutex<HashMap<TrackId, (Box<dyn Track>, DtmfDetector)>>,
28    suppressed_sources: Mutex<HashSet<TrackId>>,
29    event_sender: EventSender,
30    pub packet_sender: TrackPacketSender,
31    packet_receiver: Mutex<Option<TrackPacketReceiver>>,
32    recorder_sender: mpsc::UnboundedSender<AudioFrame>,
33    recorder_receiver: Mutex<Option<mpsc::UnboundedReceiver<AudioFrame>>>,
34    recorder_handle: Mutex<Option<JoinHandle<()>>>,
35}
36
/// Track id of the callee leg; it never receives audio from the queue-hold track.
const CALLEE_TRACK_ID: &str = "callee-track";
/// Track id of the queue hold-music track; its packets are not forwarded to the callee.
const QUEUE_HOLD_TRACK_ID: &str = "queue-hold-track";
39
40pub struct MediaStreamBuilder {
41    cancel_token: Option<CancellationToken>,
42    id: Option<String>,
43    event_sender: EventSender,
44    recorder_config: Option<RecorderOption>,
45}
46
47impl MediaStreamBuilder {
48    pub fn new(event_sender: EventSender) -> Self {
49        Self {
50            id: Some(format!("ms:{}", uuid::Uuid::new_v4())),
51            cancel_token: None,
52            event_sender,
53            recorder_config: None,
54        }
55    }
56    pub fn with_id(mut self, id: String) -> Self {
57        self.id = Some(id);
58        self
59    }
60
61    pub fn with_cancel_token(mut self, cancel_token: CancellationToken) -> Self {
62        self.cancel_token = Some(cancel_token);
63        self
64    }
65
66    pub fn with_recorder_config(mut self, recorder_config: RecorderOption) -> Self {
67        self.recorder_config = Some(recorder_config);
68        self
69    }
70
71    pub fn build(self) -> MediaStream {
72        let cancel_token = self
73            .cancel_token
74            .unwrap_or_else(|| CancellationToken::new());
75        let tracks = Mutex::new(HashMap::new());
76        let (track_packet_sender, track_packet_receiver) = mpsc::unbounded_channel();
77        let (recorder_sender, recorder_receiver) = mpsc::unbounded_channel();
78        MediaStream {
79            id: self.id.unwrap_or_default(),
80            cancel_token,
81            recorder_option: Mutex::new(self.recorder_config),
82            tracks,
83            suppressed_sources: Mutex::new(HashSet::new()),
84            event_sender: self.event_sender,
85            packet_sender: track_packet_sender,
86            packet_receiver: Mutex::new(Some(track_packet_receiver)),
87            recorder_sender,
88            recorder_receiver: Mutex::new(Some(recorder_receiver)),
89            recorder_handle: Mutex::new(None),
90        }
91    }
92}
93
94impl MediaStream {
95    pub async fn serve(&self) -> Result<()> {
96        let packet_receiver = match self.packet_receiver.lock().await.take() {
97            Some(receiver) => receiver,
98            None => {
99                warn!(
100                    session_id = self.id,
101                    "MediaStream::serve() called multiple times, stream already serving"
102                );
103                return Ok(());
104            }
105        };
106        self.start_recorder().await.ok();
107        info!(session_id = self.id, "mediastream serving");
108        select! {
109            _ = self.cancel_token.cancelled() => {}
110            r = self.handle_forward_track(packet_receiver) => {
111                info!(session_id = self.id, "track packet receiver stopped {:?}", r);
112            }
113        }
114        Ok(())
115    }
116
117    pub fn stop(&self, _reason: Option<String>, _initiator: Option<String>) {
118        self.cancel_token.cancel()
119    }
120
121    pub async fn cleanup(&self) -> Result<()> {
122        self.cancel_token.cancel();
123        {
124            let mut tracks = self.tracks.lock().await;
125            for (id, (track, _)) in tracks.drain() {
126                if let Err(e) = track.stop().await {
127                    warn!(session_id = self.id, track_id = %id, "failed to stop track during cleanup: {}", e);
128                }
129            }
130        }
131        self.suppressed_sources.lock().await.clear();
132
133        if let Some(recorder_handle) = self.recorder_handle.lock().await.take() {
134            if let Ok(Ok(_)) = tokio::time::timeout(Duration::from_secs(30), recorder_handle).await
135            {
136                info!(session_id = self.id, "recorder stopped");
137            } else {
138                warn!(session_id = self.id, "recorder timeout");
139            }
140        }
141        Ok(())
142    }
143    pub async fn track_count(&self) -> usize {
144        self.tracks.lock().await.len()
145    }
146
147    pub async fn update_recorder_option(&self, recorder_config: RecorderOption) {
148        *self.recorder_option.lock().await = Some(recorder_config);
149        self.start_recorder().await.ok();
150    }
151
152    pub async fn remove_track(&self, id: &TrackId, graceful: bool) {
153        let track_entry = { self.tracks.lock().await.remove(id) };
154        if let Some((track, _)) = track_entry {
155            self.suppressed_sources.lock().await.remove(id);
156            let res = if !graceful {
157                track.stop().await
158            } else {
159                track.stop_graceful().await
160            };
161            match res {
162                Ok(_) => {}
163                Err(e) => {
164                    warn!(session_id = self.id, "failed to stop track: {}", e);
165                }
166            }
167        }
168    }
169    pub async fn update_remote_description(
170        &self,
171        track_id: &TrackId,
172        answer: &String,
173    ) -> Result<()> {
174        let track_entry = { self.tracks.lock().await.remove(track_id) };
175        if let Some((mut track, dtmf)) = track_entry {
176            let res = track.update_remote_description(answer).await;
177            self.tracks
178                .lock()
179                .await
180                .insert(track_id.clone(), (track, dtmf));
181            res?;
182        }
183        Ok(())
184    }
185
186    pub async fn update_remote_description_force(
187        &self,
188        track_id: &TrackId,
189        answer: &String,
190    ) -> Result<()> {
191        let track_entry = { self.tracks.lock().await.remove(track_id) };
192        if let Some((mut track, dtmf)) = track_entry {
193            let res = track.update_remote_description_force(answer).await;
194            self.tracks
195                .lock()
196                .await
197                .insert(track_id.clone(), (track, dtmf));
198            res?;
199        }
200        Ok(())
201    }
202
203    pub async fn handshake(
204        &self,
205        track_id: &TrackId,
206        offer: String,
207        timeout: Option<Duration>,
208    ) -> Result<String> {
209        let track_entry = { self.tracks.lock().await.remove(track_id) };
210        if let Some((mut track, dtmf)) = track_entry {
211            let res = track.handshake(offer, timeout).await;
212            self.tracks
213                .lock()
214                .await
215                .insert(track_id.clone(), (track, dtmf));
216            res
217        } else {
218            anyhow::bail!("track not found: {}", track_id)
219        }
220    }
221
222    pub async fn update_track(&self, mut track: Box<dyn Track>, play_id: Option<String>) {
223        self.remove_track(track.id(), false).await;
224        if self.recorder_option.lock().await.is_some() {
225            track.insert_processor(Box::new(RecorderProcessor::new(
226                self.recorder_sender.clone(),
227            )));
228        }
229        match track
230            .start(self.event_sender.clone(), self.packet_sender.clone())
231            .await
232        {
233            Ok(_) => {
234                info!(session_id = self.id, track_id = track.id(), "track started");
235                let track_id = track.id().clone();
236                self.tracks
237                    .lock()
238                    .await
239                    .insert(track_id.clone(), (track, DtmfDetector::new()));
240                self.event_sender
241                    .send(SessionEvent::TrackStart {
242                        track_id,
243                        timestamp: crate::media::get_timestamp(),
244                        play_id,
245                    })
246                    .ok();
247            }
248            Err(e) => {
249                warn!(
250                    session_id = self.id,
251                    track_id = track.id(),
252                    play_id = play_id.as_deref(),
253                    "Failed to start track: {}",
254                    e
255                );
256            }
257        }
258    }
259
260    pub async fn mute_track(&self, id: Option<TrackId>) {
261        if let Some(id) = id {
262            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
263                MuteProcessor::mute_track(track.as_mut());
264            }
265        } else {
266            for (track, _) in self.tracks.lock().await.values_mut() {
267                MuteProcessor::mute_track(track.as_mut());
268            }
269        }
270    }
271
272    pub async fn unmute_track(&self, id: Option<TrackId>) {
273        if let Some(id) = id {
274            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
275                MuteProcessor::unmute_track(track.as_mut());
276            }
277        } else {
278            for (track, _) in self.tracks.lock().await.values_mut() {
279                MuteProcessor::unmute_track(track.as_mut());
280            }
281        }
282    }
283
284    pub async fn hold_track(&self, id: Option<TrackId>) {
285        if let Some(id) = id {
286            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
287                HoldTrack::hold_track(track.as_mut());
288            }
289        } else {
290            for (track, _) in self.tracks.lock().await.values_mut() {
291                HoldTrack::hold_track(track.as_mut());
292            }
293        }
294    }
295
296    pub async fn resume_track(&self, id: Option<TrackId>) {
297        if let Some(id) = id {
298            if let Some((track, _)) = self.tracks.lock().await.get_mut(&id) {
299                HoldTrack::resume_track(track.as_mut());
300            }
301        } else {
302            for (track, _) in self.tracks.lock().await.values_mut() {
303                HoldTrack::resume_track(track.as_mut());
304            }
305        }
306    }
307
308    pub async fn suppress_forwarding(&self, track_id: &TrackId) {
309        self.suppressed_sources
310            .lock()
311            .await
312            .insert(track_id.clone());
313    }
314
315    pub async fn resume_forwarding(&self, track_id: &TrackId) {
316        self.suppressed_sources.lock().await.remove(track_id);
317    }
318
319    pub async fn remove_processor<T: 'static>(&self, track_id: &TrackId) -> Result<()> {
320        if let Some((track, _)) = self.tracks.lock().await.get_mut(track_id) {
321            track.as_mut().processor_chain().remove_processor::<T>();
322            Ok(())
323        } else {
324            Err(anyhow::anyhow!("Track {} not found", track_id))
325        }
326    }
327
328    pub async fn append_processor(
329        &self,
330        track_id: &TrackId,
331        processor: Box<dyn crate::media::processor::Processor>,
332    ) -> Result<()> {
333        if let Some((track, _)) = self.tracks.lock().await.get_mut(track_id) {
334            track.as_mut().processor_chain().append_processor(processor);
335            Ok(())
336        } else {
337            Err(anyhow::anyhow!("Track {} not found", track_id))
338        }
339    }
340}
341
342#[derive(Clone)]
343pub struct RecorderProcessor {
344    sender: mpsc::UnboundedSender<AudioFrame>,
345}
346
347impl RecorderProcessor {
348    pub fn new(sender: mpsc::UnboundedSender<AudioFrame>) -> Self {
349        Self { sender }
350    }
351}
352
353impl Processor for RecorderProcessor {
354    fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
355        let frame_clone = frame.clone();
356        let _ = self.sender.send(frame_clone);
357        Ok(())
358    }
359}
360
361impl MediaStream {
362    pub async fn start_recorder(&self) -> Result<()> {
363        let recorder_option = self.recorder_option.lock().await.clone();
364        if let Some(recorder_option) = recorder_option {
365            if recorder_option.recorder_file.is_empty() {
366                warn!(
367                    session_id = self.id,
368                    "recorder file is empty, skipping recorder start"
369                );
370                return Ok(());
371            }
372            let recorder_receiver = match self.recorder_receiver.lock().await.take() {
373                Some(receiver) => receiver,
374                None => {
375                    return Ok(());
376                }
377            };
378            let cancel_token = self.cancel_token.child_token();
379            let session_id_clone = self.id.clone();
380
381            info!(
382                session_id = session_id_clone,
383                sample_rate = recorder_option.samplerate,
384                ptime = recorder_option.ptime,
385                "start recorder",
386            );
387
388            let recorder_handle = crate::spawn(async move {
389                let recorder_file = recorder_option.recorder_file.clone();
390                let recorder =
391                    Recorder::new(cancel_token, session_id_clone.clone(), recorder_option);
392                match recorder
393                    .process_recording(Path::new(&recorder_file), recorder_receiver)
394                    .await
395                {
396                    Ok(_) => {}
397                    Err(e) => {
398                        warn!(
399                            session_id = session_id_clone,
400                            "Failed to process recorder: {}", e
401                        );
402                    }
403                }
404            });
405            *self.recorder_handle.lock().await = Some(recorder_handle);
406        }
407        Ok(())
408    }
409
410    async fn handle_forward_track(&self, mut packet_receiver: TrackPacketReceiver) {
411        let event_sender = self.event_sender.clone();
412        while let Some(packet) = packet_receiver.recv().await {
413            let suppressed = {
414                self.suppressed_sources
415                    .lock()
416                    .await
417                    .contains(&packet.track_id)
418            };
419            // Process the packet with each track
420            for (track, dtmf_detector) in self.tracks.lock().await.values_mut() {
421                if track.id() == &packet.track_id {
422                    match &packet.samples {
423                        Samples::RTP {
424                            payload_type,
425                            payload,
426                            ..
427                        } => {
428                            if let Some(digit) = dtmf_detector.detect_rtp(*payload_type, payload) {
429                                debug!(track_id = track.id(), digit, "DTMF detected");
430                                event_sender
431                                    .send(SessionEvent::Dtmf {
432                                        track_id: packet.track_id.to_string(),
433                                        timestamp: packet.timestamp,
434                                        digit,
435                                    })
436                                    .ok();
437                            }
438                        }
439                        _ => {}
440                    }
441                    continue;
442                }
443                if suppressed {
444                    continue;
445                }
446                if packet.track_id == QUEUE_HOLD_TRACK_ID && track.id() == CALLEE_TRACK_ID {
447                    continue;
448                }
449                if let Err(e) = track.send_packet(&packet).await {
450                    warn!(
451                        id = track.id(),
452                        "media_stream: Failed to send packet to track: {}", e
453                    );
454                }
455            }
456        }
457    }
458}
459
460pub struct MuteProcessor;
461
462impl MuteProcessor {
463    pub fn mute_track(track: &mut dyn Track) {
464        let chain = track.processor_chain();
465        if !chain.has_processor::<MuteProcessor>() {
466            chain.insert_processor(Box::new(MuteProcessor));
467        }
468    }
469
470    pub fn unmute_track(track: &mut dyn Track) {
471        let chain = track.processor_chain();
472        chain.remove_processor::<MuteProcessor>();
473    }
474}
475
476impl Processor for MuteProcessor {
477    fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
478        match &mut frame.samples {
479            Samples::PCM { samples } => {
480                samples.fill(0);
481            }
482            // discard DTMF frames
483            Samples::RTP { payload_type, .. } if *payload_type >= 96 && *payload_type <= 127 => {
484                frame.samples = Samples::Empty;
485            }
486            _ => {}
487        }
488        Ok(())
489    }
490}
491
492pub struct HoldTrack;
493
494impl HoldTrack {
495    pub fn hold_track(track: &mut dyn Track) {
496        let chain = track.processor_chain();
497        // Remove existing processor if present
498        chain.remove_processor::<HoldProcessor>();
499        // Add a new processor with hold state set to true
500        let processor = HoldProcessor::new();
501        processor.set_hold(true);
502        chain.insert_processor(Box::new(processor));
503    }
504
505    pub fn resume_track(track: &mut dyn Track) {
506        let chain = track.processor_chain();
507        // Simply remove the hold processor to resume normal operation
508        chain.remove_processor::<HoldProcessor>();
509    }
510}