// active_call/media/track/websocket.rs

1use super::{Track, TrackConfig, TrackPacketSender, track_codec::TrackCodec};
2use crate::{
3    event::{EventSender, SessionEvent},
4    media::AudioFrame,
5    media::Samples,
6    media::TrackId,
7    media::processor::ProcessorChain,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use audio_codec::bytes_to_samples;
12use bytes::Bytes;
13use std::{sync::Mutex, time::Duration};
14use tokio::select;
15use tokio_util::sync::CancellationToken;
16use tracing::{info, warn};
17
18pub type WebsocketBytesSender = tokio::sync::mpsc::UnboundedSender<Bytes>;
19pub type WebsocketBytesReceiver = tokio::sync::mpsc::UnboundedReceiver<Bytes>;
20
21pub struct WebsocketTrack {
22    track_id: TrackId,
23    config: TrackConfig,
24    cancel_token: CancellationToken,
25    processor_chain: ProcessorChain,
26    rx: Mutex<Option<WebsocketBytesReceiver>>,
27    encoder: TrackCodec,
28    payload_type: u8,
29    event_sender: EventSender,
30    ssrc: u32,
31}
32
33impl WebsocketTrack {
34    pub fn new(
35        cancel_token: CancellationToken,
36        track_id: TrackId,
37        track_config: TrackConfig,
38        event_sender: EventSender,
39        audio_receiver: WebsocketBytesReceiver,
40        codec: Option<String>,
41        ssrc: u32,
42    ) -> Self {
43        let processor_chain = ProcessorChain::new(track_config.samplerate);
44        let payload_type = match codec.unwrap_or("pcm".to_string()).to_lowercase().as_str() {
45            "pcmu" => 0,
46            "pcma" => 8,
47            "g722" => 9,
48            _ => u8::MAX, // PCM
49        };
50        Self {
51            track_id,
52            config: track_config,
53            cancel_token,
54            processor_chain,
55            rx: Mutex::new(Some(audio_receiver)),
56            encoder: TrackCodec::new(),
57            payload_type,
58            event_sender,
59            ssrc,
60        }
61    }
62}
63
64#[async_trait]
65impl Track for WebsocketTrack {
66    fn ssrc(&self) -> u32 {
67        self.ssrc
68    }
69    fn id(&self) -> &TrackId {
70        &self.track_id
71    }
72    fn config(&self) -> &TrackConfig {
73        &self.config
74    }
75    fn processor_chain(&mut self) -> &mut ProcessorChain {
76        &mut self.processor_chain
77    }
78
79    async fn handshake(&mut self, _offer: String, _timeout: Option<Duration>) -> Result<String> {
80        Ok("".to_string())
81    }
82    async fn update_remote_description(&mut self, _answer: &String) -> Result<()> {
83        Ok(())
84    }
85
86    async fn start(
87        &mut self,
88        event_sender: EventSender,
89        packet_sender: TrackPacketSender,
90    ) -> Result<()> {
91        let track_id = self.track_id.clone();
92        let token = self.cancel_token.clone();
93        let mut audio_from_ws = match self.rx.lock().unwrap().take() {
94            Some(rx) => rx,
95            None => {
96                warn!(track_id, "no audio from ws");
97                return Ok(());
98            }
99        };
100        let sample_rate = self.config.samplerate;
101        let channels = self.config.channels;
102        let payload_type = self.payload_type;
103        let start_time = crate::media::get_timestamp();
104        let ssrc = self.ssrc;
105        let mut processor_chain = self.processor_chain.clone();
106        crate::spawn(async move {
107            let track_id_clone = track_id.clone();
108            let audio_from_ws_loop = async move {
109                let mut sequence_number = 0;
110                while let Some(bytes) = audio_from_ws.recv().await {
111                    sequence_number += 1;
112
113                    let samples = match payload_type {
114                        u8::MAX => Samples::PCM {
115                            samples: bytes_to_samples(&bytes.to_vec()),
116                        },
117                        _ => Samples::RTP {
118                            sequence_number,
119                            payload_type,
120                            payload: bytes.to_vec(),
121                        },
122                    };
123
124                    let mut packet = AudioFrame {
125                        track_id: track_id_clone.clone(),
126                        samples,
127                        timestamp: crate::media::get_timestamp(),
128                        sample_rate,
129                        channels,
130                    };
131
132                    if let Err(e) = processor_chain.process_frame(&mut packet) {
133                        warn!("error processing frame: {}", e);
134                    }
135
136                    match packet_sender.send(packet) {
137                        Ok(_) => (),
138                        Err(e) => {
139                            warn!("error sending packet: {}", e);
140                            break;
141                        }
142                    }
143                }
144            };
145
146            select! {
147                _ = token.cancelled() => {
148                    info!("RTC process cancelled");
149                },
150                _ = audio_from_ws_loop => {
151                    info!("audio_from_ws_loop");
152                }
153            };
154
155            event_sender
156                .send(SessionEvent::TrackEnd {
157                    track_id,
158                    timestamp: crate::media::get_timestamp(),
159                    duration: crate::media::get_timestamp() - start_time,
160                    ssrc,
161                    play_id: None,
162                })
163                .ok();
164        });
165        Ok(())
166    }
167
168    async fn stop(&self) -> Result<()> {
169        self.cancel_token.cancel();
170        Ok(())
171    }
172
173    async fn send_packet(&mut self, packet: &AudioFrame) -> Result<()> {
174        let packet = packet.clone();
175        // Do not run the processor chain for outgoing packets to the user.
176        // The processor chain (VAD, ASR, etc.) is intended for audio coming FROM the user.
177
178        let (_, payload) = self.encoder.encode(self.payload_type, packet);
179        if payload.is_empty() {
180            return Ok(());
181        }
182        self.event_sender
183            .send(SessionEvent::Binary {
184                track_id: self.track_id.clone(),
185                timestamp: crate::media::get_timestamp(),
186                data: payload,
187            })
188            .map(|_| ())
189            .map_err(|_| anyhow::anyhow!("error sending binary event"))
190    }
191}