active_call/media/realtime_processor.rs

1use super::processor::Processor;
2use crate::{
3    RealtimeOption, RealtimeType,
4    event::{EventSender, SessionEvent},
5    media::{AudioFrame, INTERNAL_SAMPLERATE, Samples, TrackId},
6};
7use anyhow::{Result, anyhow};
8use base64::{Engine as _, engine::general_purpose};
9use futures::{SinkExt, StreamExt};
10use serde_json::{Value, json};
11use tokio::sync::mpsc;
12use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
13use tokio_util::sync::CancellationToken;
14use tracing::{debug, error, info};
15
16pub struct RealtimeProcessor {
17    audio_tx: mpsc::UnboundedSender<Vec<i16>>,
18}
19
20impl RealtimeProcessor {
21    pub fn new(
22        track_id: TrackId,
23        cancel_token: CancellationToken,
24        event_sender: EventSender,
25        packet_sender: mpsc::UnboundedSender<AudioFrame>,
26        option: RealtimeOption,
27    ) -> Result<Self> {
28        let (audio_tx, audio_rx) = mpsc::unbounded_channel::<Vec<i16>>();
29
30        crate::spawn(async move {
31            if let Err(e) = run_realtime_loop(
32                track_id,
33                cancel_token,
34                event_sender,
35                packet_sender,
36                audio_rx,
37                option,
38            )
39            .await
40            {
41                error!("Realtime loop failed: {:?}", e);
42            }
43        });
44
45        Ok(Self { audio_tx })
46    }
47}
48
49impl Processor for RealtimeProcessor {
50    fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
51        if let Samples::PCM { samples } = &frame.samples {
52            if !samples.is_empty() {
53                self.audio_tx.send(samples.clone()).ok();
54            }
55        }
56        Ok(())
57    }
58}
59
60async fn run_realtime_loop(
61    track_id: TrackId,
62    cancel_token: CancellationToken,
63    event_sender: EventSender,
64    packet_sender: mpsc::UnboundedSender<AudioFrame>,
65    mut audio_rx: mpsc::UnboundedReceiver<Vec<i16>>,
66    option: RealtimeOption,
67) -> Result<()> {
68    let provider = option.provider.unwrap_or(RealtimeType::OpenAI);
69    let model = option
70        .model
71        .unwrap_or_else(|| "gpt-4o-realtime-preview-2024-10-01".to_string());
72
73    let url = match provider {
74        RealtimeType::OpenAI => {
75            format!("wss://api.openai.com/v1/realtime?model={}", model)
76        }
77        RealtimeType::Azure => {
78            let endpoint = option
79                .endpoint
80                .ok_or_else(|| anyhow!("Azure endpoint missing"))?;
81            format!(
82                "{}/openai/realtime?api-version=2024-10-01-preview&deployment={}",
83                endpoint, model
84            )
85        }
86        RealtimeType::Other(ref u) => u.clone(),
87    };
88
89    let api_key = option
90        .secret_key
91        .ok_or_else(|| anyhow!("API key missing"))?;
92
93    let mut request_builder = http::Request::builder().uri(&url);
94
95    match provider {
96        RealtimeType::OpenAI => {
97            request_builder = request_builder
98                .header("Authorization", format!("Bearer {}", api_key))
99                .header("OpenAI-Beta", "realtime=v1");
100        }
101        RealtimeType::Azure => {
102            request_builder = request_builder.header("api-key", api_key);
103        }
104        _ => {}
105    }
106
107    let request = request_builder.body(())?;
108
109    info!(url = %url, "Connecting to Realtime API");
110    let (ws_stream, _) = connect_async(request).await?;
111    let (mut ws_tx, mut ws_rx) = ws_stream.split();
112
113    // Initialize session
114    let session_update = json!({
115        "type": "session.update",
116        "session": {
117            "modalities": ["text", "audio"],
118            "instructions": option.extra.as_ref().and_then(|e| e.get("instructions")).cloned().unwrap_or_default(),
119            "voice": option.extra.as_ref().and_then(|e| e.get("voice")).cloned().unwrap_or_else(|| "alloy".to_string()),
120            "input_audio_format": "pcm16",
121            "output_audio_format": "pcm16",
122            "turn_detection": option.turn_detection.unwrap_or(json!({
123                "type": "server_vad",
124                "threshold": 0.5,
125                "prefix_padding_ms": 300,
126                "silence_duration_ms": 500
127            })),
128            "tools": option.tools.unwrap_or_default(),
129        }
130    });
131    ws_tx
132        .send(Message::Text(session_update.to_string().into()))
133        .await?;
134
135    let mut event_rx = event_sender.subscribe();
136
137    loop {
138        tokio::select! {
139            _ = cancel_token.cancelled() => break,
140            Ok(event) = event_rx.recv() => {
141                match event {
142                    SessionEvent::Interrupt { .. } => {
143                        debug!("Interruption received, cancelling response");
144                        let cancel_event = json!({
145                            "type": "response.cancel"
146                        });
147                        ws_tx.send(Message::Text(cancel_event.to_string().into())).await?;
148                    }
149                    _ => {}
150                }
151            }
152            Some(samples) = audio_rx.recv() => {
153                let base64_audio = general_purpose::STANDARD.encode(
154                    samples.iter().flat_map(|&s| s.to_le_bytes()).collect::<Vec<u8>>()
155                );
156                let append_event = json!({
157                    "type": "input_audio_buffer.append",
158                    "audio": base64_audio
159                });
160                ws_tx.send(Message::Text(append_event.to_string().into())).await?;
161            }
162            Some(msg) = ws_rx.next() => {
163                let msg = msg?;
164                if let Message::Text(text) = msg {
165                    let v: Value = serde_json::from_str(&text)?;
166                    match v["type"].as_str() {
167                        Some("response.audio.delta") => {
168                            if let Some(delta_base64) = v["delta"].as_str() {
169                                if let Ok(data) = general_purpose::STANDARD.decode(delta_base64) {
170                                    let samples: Vec<i16> = data.chunks_exact(2)
171                                        .map(|c| i16::from_le_bytes([c[0], c[1]]))
172                                        .collect();
173
174                                    packet_sender.send(AudioFrame {
175                                        track_id: "server-side-track".to_string(),
176                                        samples: Samples::PCM { samples },
177                                        timestamp: crate::media::get_timestamp(),
178                                        sample_rate: INTERNAL_SAMPLERATE,
179                                        channels: 1,
180                                    })?;
181                                }
182                            }
183                        }
184                        Some("input_audio_buffer.speech_started") => {
185                            debug!("Speech started detected by server");
186                            event_sender.send(SessionEvent::Transcription {
187                                track_id: track_id.clone(),
188                                text: "".to_string(),
189                                is_final: false,
190                                timestamp: crate::media::get_timestamp(),
191                                extra: Some(json!({ "event": "speech_started" })),
192                            }).ok();
193
194                            // Immediately signal interruption to stop current local playback
195                            event_sender.send(SessionEvent::Interrupt {
196                                receiver: Some(track_id.clone()),
197                            }).ok();
198                        }
199                        Some("response.audio_transcript.delta") => {
200                            if let Some(delta) = v["delta"].as_str() {
201                                event_sender.send(SessionEvent::Transcription {
202                                    track_id: track_id.clone(),
203                                    text: delta.to_string(),
204                                    is_final: false,
205                                    timestamp: crate::media::get_timestamp(),
206                                    extra: None,
207                                }).ok();
208                            }
209                        }
210                        Some("response.function_call_arguments.done") => {
211                            if let (Some(name), Some(args), Some(call_id)) = (
212                                v["name"].as_str(),
213                                v["arguments"].as_str(),
214                                v["call_id"].as_str(),
215                            ) {
216                                debug!(name, args, call_id, "Function call detected");
217                                event_sender.send(SessionEvent::FunctionCall {
218                                    track_id: track_id.clone(),
219                                    call_id: call_id.to_string(),
220                                    name: name.to_string(),
221                                    arguments: args.to_string(),
222                                    timestamp: crate::media::get_timestamp(),
223                                }).ok();
224                            }
225                        }
226                        Some("error") => {
227                            error!("Realtime error: {}", v["error"]["message"]);
228                        }
229                        _ => {
230                            debug!(msg_type = ?v["type"], "Other Realtime event");
231                        }
232                    }
233                }
234            }
235        }
236    }
237
238    Ok(())
239}