use super::processor::Processor;
use crate::{
    RealtimeOption, RealtimeType,
    event::{EventSender, SessionEvent},
    media::{AudioFrame, INTERNAL_SAMPLERATE, Samples, TrackId},
};
use anyhow::{Result, anyhow};
use base64::{Engine as _, engine::general_purpose};
use futures::{SinkExt, StreamExt};
use serde_json::{Value, json};
use tokio::sync::mpsc;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info};

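/// Forwards captured PCM audio to a realtime speech session (OpenAI, Azure,
/// or a compatible endpoint) that runs in a background task.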
pub struct RealtimeProcessor {
    audio_tx: mpsc::UnboundedSender<Vec<i16>>,
}

impl RealtimeProcessor {
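    /// Spawns the background task that owns the WebSocket connection and
    /// returns a processor whose channel feeds it with caller audio.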
    pub fn new(
        track_id: TrackId,
        cancel_token: CancellationToken,
        event_sender: EventSender,
        packet_sender: mpsc::UnboundedSender<AudioFrame>,
        option: RealtimeOption,
    ) -> Result<Self> {
        let (audio_tx, audio_rx) = mpsc::unbounded_channel::<Vec<i16>>();

        crate::spawn(async move {
            if let Err(e) = run_realtime_loop(
                track_id,
                cancel_token,
                event_sender,
                packet_sender,
                audio_rx,
                option,
            )
            .await
            {
                error!("Realtime loop failed: {:?}", e);
            }
        });

        Ok(Self { audio_tx })
    }
}

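/// Media-pipeline hook: non-empty PCM frames are copied into the channel
/// consumed by the realtime loop; send errors are ignored so a finished
/// session does not fail the pipeline.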
impl Processor for RealtimeProcessor {
    fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
        if let Samples::PCM { samples } = &frame.samples {
            if !samples.is_empty() {
                self.audio_tx.send(samples.clone()).ok();
            }
        }
        Ok(())
    }
}

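/// Connects to the realtime endpoint, configures the session, then relays
/// caller audio upstream and translates provider events into `SessionEvent`s
/// and outbound `AudioFrame`s until cancellation or a transport error.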
async fn run_realtime_loop(
    track_id: TrackId,
    cancel_token: CancellationToken,
    event_sender: EventSender,
    packet_sender: mpsc::UnboundedSender<AudioFrame>,
    mut audio_rx: mpsc::UnboundedReceiver<Vec<i16>>,
    option: RealtimeOption,
) -> Result<()> {
    let provider = option.provider.unwrap_or(RealtimeType::OpenAI);
    let model = option
        .model
        .unwrap_or_else(|| "gpt-4o-realtime-preview-2024-10-01".to_string());

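    // Build the provider-specific WebSocket URL; `Other` carries a complete
    // URL supplied by the caller.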
    let url = match provider {
        RealtimeType::OpenAI => {
            format!("wss://api.openai.com/v1/realtime?model={}", model)
        }
        RealtimeType::Azure => {
            let endpoint = option
                .endpoint
                .ok_or_else(|| anyhow!("Azure endpoint missing"))?;
            format!(
                "{}/openai/realtime?api-version=2024-10-01-preview&deployment={}",
                endpoint, model
            )
        }
        RealtimeType::Other(ref u) => u.clone(),
    };

    let api_key = option
        .secret_key
        .ok_or_else(|| anyhow!("API key missing"))?;

    let mut request_builder = http::Request::builder().uri(&url);

    // Authentication differs per provider: OpenAI uses a bearer token plus the
    // realtime beta header, while Azure expects an `api-key` header.
    match provider {
        RealtimeType::OpenAI => {
            request_builder = request_builder
                .header("Authorization", format!("Bearer {}", api_key))
                .header("OpenAI-Beta", "realtime=v1");
        }
        RealtimeType::Azure => {
            request_builder = request_builder.header("api-key", api_key);
        }
        _ => {}
    }

    let request = request_builder.body(())?;

    info!(url = %url, "Connecting to Realtime API");
    let (ws_stream, _) = connect_async(request).await?;
    let (mut ws_tx, mut ws_rx) = ws_stream.split();

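    // Configure the session up front: PCM16 audio in both directions, server-side
    // VAD unless the caller supplied its own turn detection, plus any tools.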
    let session_update = json!({
        "type": "session.update",
        "session": {
            "modalities": ["text", "audio"],
            "instructions": option.extra.as_ref().and_then(|e| e.get("instructions")).cloned().unwrap_or_default(),
            "voice": option.extra.as_ref().and_then(|e| e.get("voice")).cloned().unwrap_or_else(|| "alloy".to_string()),
            "input_audio_format": "pcm16",
            "output_audio_format": "pcm16",
            "turn_detection": option.turn_detection.unwrap_or(json!({
                "type": "server_vad",
                "threshold": 0.5,
                "prefix_padding_ms": 300,
                "silence_duration_ms": 500
            })),
            "tools": option.tools.unwrap_or_default(),
        }
    });
    ws_tx
        .send(Message::Text(session_update.to_string().into()))
        .await?;

    let mut event_rx = event_sender.subscribe();

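    // Main pump: multiplex session events, caller audio, and provider messages
    // until the track is cancelled or the transport fails.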
    loop {
        tokio::select! {
            _ = cancel_token.cancelled() => break,
            // Session events: any interrupt (e.g. barge-in) cancels the provider's
            // in-flight response.
            Ok(event) = event_rx.recv() => {
                match event {
                    SessionEvent::Interrupt { .. } => {
                        debug!("Interruption received, cancelling response");
                        let cancel_event = json!({
                            "type": "response.cancel"
                        });
                        ws_tx.send(Message::Text(cancel_event.to_string().into())).await?;
                    }
                    _ => {}
                }
            }
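            // Upstream audio from the caller: base64-encode the PCM16 samples and
            // append them to the provider's input audio buffer.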
            Some(samples) = audio_rx.recv() => {
                let base64_audio = general_purpose::STANDARD.encode(
                    samples.iter().flat_map(|&s| s.to_le_bytes()).collect::<Vec<u8>>()
                );
                let append_event = json!({
                    "type": "input_audio_buffer.append",
                    "audio": base64_audio
                });
                ws_tx.send(Message::Text(append_event.to_string().into())).await?;
            }
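            // Messages from the provider: decode audio, transcripts, VAD and tool
            // events into frames and session events.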
            Some(msg) = ws_rx.next() => {
                let msg = msg?;
                if let Message::Text(text) = msg {
                    let v: Value = serde_json::from_str(&text)?;
                    match v["type"].as_str() {
                        // Base64-encoded PCM16 from the model: decode it and push it
                        // to the playback track.
                        Some("response.audio.delta") => {
                            if let Some(delta_base64) = v["delta"].as_str() {
                                if let Ok(data) = general_purpose::STANDARD.decode(delta_base64) {
                                    let samples: Vec<i16> = data.chunks_exact(2)
                                        .map(|c| i16::from_le_bytes([c[0], c[1]]))
                                        .collect();

                                    packet_sender.send(AudioFrame {
                                        track_id: "server-side-track".to_string(),
                                        samples: Samples::PCM { samples },
                                        timestamp: crate::media::get_timestamp(),
                                        sample_rate: INTERNAL_SAMPLERATE,
                                        channels: 1,
                                    })?;
                                }
                            }
                        }
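                        // Server-side VAD heard the caller start speaking: report it
                        // and interrupt any audio currently being played back.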
184 Some("input_audio_buffer.speech_started") => {
185 debug!("Speech started detected by server");
186 event_sender.send(SessionEvent::Speaking{
187 track_id: track_id.clone(),
188 timestamp: crate::media::get_timestamp(),
189 start_time: v["audio_start_ms"].as_u64().unwrap_or_default(),
190 is_filler: None,
191 confidence: None,
192 }).ok();
193
194 event_sender.send(SessionEvent::Interrupt {
196 receiver: Some(track_id.clone()),
197 }).ok();
198 }
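                        // Partial transcript of the generated audio, forwarded as
                        // ASR delta events.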
199 Some("response.audio_transcript.delta") => {
200 if let Some(delta) = v["delta"].as_str() {
201 event_sender.send(SessionEvent::AsrDelta {
202 track_id: track_id.clone(),
203 index: v["content_index"].as_u64().unwrap_or_default() as u32,
204 text: delta.to_string(),
205 timestamp: crate::media::get_timestamp(),
206 task_id: Some(v["item_id"].as_str().unwrap_or_default().to_string()),
207 start_time: None,
208 end_time: None,
209 is_filler: None,
210 confidence: None,
211 }).ok();
212 }
213 }
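                        // The model finished emitting arguments for a tool call.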
214 Some("response.function_call_arguments.done") => {
215 if let (Some(name), Some(args), Some(call_id)) = (
216 v["name"].as_str(),
217 v["arguments"].as_str(),
218 v["call_id"].as_str(),
219 ) {
220 debug!(name, args, call_id, "Function call detected");
221 event_sender.send(SessionEvent::FunctionCall {
222 track_id: track_id.clone(),
223 call_id: call_id.to_string(),
224 name: name.to_string(),
225 arguments: args.to_string(),
226 timestamp: crate::media::get_timestamp(),
227 }).ok();
228 }
229 }
230 Some("error") => {
231 error!("Realtime error: {}", v["error"]["message"]);
232 }
233 _ => {
234 debug!(msg_type = ?v["type"], "Other Realtime event");
235 }
236 }
237 }
238 }
239 }
240 }
241
242 Ok(())
243}