Skip to main content

quantum_sdk/
realtime.rs

//! Realtime voice sessions via WebSocket.
//!
//! Connects to the QAI Realtime API (proxied xAI Realtime) for bidirectional
//! audio streaming with voice activity detection, transcription, and tool calling.
//!
//! # Example
//!
//! ```no_run
//! # async fn example() -> quantum_sdk::Result<()> {
//! let client = quantum_sdk::Client::new("qai_key_xxx");
//! let config = quantum_sdk::RealtimeConfig::default();
//!
//! let (mut sender, mut receiver) = client.realtime_connect(&config).await?;
//!
//! // Send audio in a task, receive events in another
//! tokio::spawn(async move {
//!     while let Some(event) = receiver.recv().await {
//!         match event {
//!             quantum_sdk::RealtimeEvent::AudioDelta { delta } => { /* play PCM */ }
//!             quantum_sdk::RealtimeEvent::TranscriptDone { transcript, .. } => {
//!                 println!("Transcript: {transcript}");
//!             }
//!             _ => {}
//!         }
//!     }
//! });
//!
//! // sender.send_audio(base64_pcm).await?;
//! # Ok(())
//! # }
//! ```
32
33use futures_util::{SinkExt, StreamExt};
34use serde::Serialize;
35use tokio::net::TcpStream;
36use tokio_tungstenite::tungstenite::http::Request;
37use tokio_tungstenite::tungstenite::Message;
38use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
39
40use crate::client::Client;
41use crate::error::{ApiError, Error, Result};
42
43type WsSink = futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
44type WsStream = futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
45
46// ── Public types ──
47
48/// Configuration for a realtime voice session.
49#[derive(Debug, Clone, Serialize)]
50pub struct RealtimeConfig {
51    /// Voice to use (e.g. "Sal", "Eve", "Vesper" for xAI; "alloy", "echo" for OpenAI).
52    pub voice: String,
53
54    /// System instructions for the AI.
55    pub instructions: String,
56
57    /// PCM sample rate in Hz.
58    pub sample_rate: u32,
59
60    /// Tool definitions (xAI Realtime API format).
61    #[serde(skip_serializing_if = "Vec::is_empty")]
62    pub tools: Vec<serde_json::Value>,
63
64    /// Model to use for the realtime session (e.g. "gpt-4o-realtime-preview").
65    /// When empty, the server picks the default for the provider.
66    #[serde(default, skip_serializing_if = "String::is_empty")]
67    pub model: String,
68}
69
70impl Default for RealtimeConfig {
71    fn default() -> Self {
72        Self {
73            voice: "Sal".into(),
74            instructions: String::new(),
75            sample_rate: 24000,
76            tools: Vec::new(),
77            model: String::new(),
78        }
79    }
80}
81
82/// Parsed incoming event from the realtime API.
83#[derive(Debug, Clone)]
84pub enum RealtimeEvent {
85    /// Session configuration acknowledged.
86    SessionReady,
87
88    /// Base64-encoded PCM audio chunk from the assistant.
89    AudioDelta { delta: String },
90
91    /// Partial transcript text.
92    TranscriptDelta {
93        delta: String,
94        /// "input" for user speech, "output" for assistant speech.
95        source: String,
96    },
97
98    /// Final transcript for a completed utterance.
99    TranscriptDone {
100        transcript: String,
101        /// "input" for user speech, "output" for assistant speech.
102        source: String,
103    },
104
105    /// Voice activity detected — user started speaking.
106    SpeechStarted,
107
108    /// Voice activity ended — user stopped speaking.
109    SpeechStopped,
110
111    /// The model is requesting a function/tool call.
112    FunctionCall {
113        name: String,
114        call_id: String,
115        arguments: String,
116    },
117
118    /// The model finished its response turn.
119    ResponseDone,
120
121    /// An error from the realtime API.
122    Error { message: String },
123
124    /// An event type we don't explicitly handle.
125    Unknown(serde_json::Value),
126}
127
128/// Write half of a realtime session — send audio and control messages.
129pub struct RealtimeSender {
130    sink: tokio::sync::Mutex<WsSink>,
131}
132
133/// Read half of a realtime session — receive audio, transcripts, and tool calls.
134pub struct RealtimeReceiver {
135    stream: WsStream,
136}
137
138// ── Client method ──
139
140impl Client {
141    /// Opens a realtime voice session via WebSocket.
142    ///
143    /// Returns `(sender, receiver)` for bidirectional communication.
144    /// The connection is made to `{base_url}/qai/v1/realtime` with the
145    /// client's auth token.
146    pub async fn realtime_connect(
147        &self,
148        config: &RealtimeConfig,
149    ) -> Result<(RealtimeSender, RealtimeReceiver)> {
150        // Convert https:// → wss://, http:// → ws://
151        let base = self.base_url();
152        let ws_base = if base.starts_with("https://") {
153            format!("wss://{}", &base[8..])
154        } else if base.starts_with("http://") {
155            format!("ws://{}", &base[7..])
156        } else {
157            return Err(Error::Api(ApiError {
158                status_code: 0,
159                code: "invalid_base_url".into(),
160                message: format!("Cannot convert base URL to WebSocket: {base}"),
161                request_id: String::new(),
162            }));
163        };
164
165        let url = format!("{ws_base}/qai/v1/realtime");
166
167        // Extract host from the base URL for the Host header
168        let host = base
169            .trim_start_matches("https://")
170            .trim_start_matches("http://")
171            .trim_end_matches('/')
172            .to_string();
173
174        let auth = self
175            .auth_header()
176            .to_str()
177            .unwrap_or("")
178            .to_string();
179
180        // Extract raw token (strip "Bearer " prefix) for X-API-Key
181        let raw_token = auth.strip_prefix("Bearer ").unwrap_or(&auth);
182
183        let request = Request::builder()
184            .uri(&url)
185            .header("Host", &host)
186            .header("Authorization", &auth)
187            .header("X-API-Key", raw_token)
188            .header("Connection", "Upgrade")
189            .header("Upgrade", "websocket")
190            .header("Sec-WebSocket-Version", "13")
191            .header(
192                "Sec-WebSocket-Key",
193                tokio_tungstenite::tungstenite::handshake::client::generate_key(),
194            )
195            .body(())
196            .map_err(|e| Error::Api(ApiError {
197                status_code: 0,
198                code: "websocket_request".into(),
199                message: format!("Failed to build WebSocket request: {e}"),
200                request_id: String::new(),
201            }))?;
202
203        // Connect with timeout
204        let (ws_stream, _response) = tokio::time::timeout(
205            std::time::Duration::from_secs(15),
206            tokio_tungstenite::connect_async(request),
207        )
208        .await
209        .map_err(|_| Error::Api(ApiError {
210            status_code: 0,
211            code: "timeout".into(),
212            message: "WebSocket connection timed out (15s)".into(),
213            request_id: String::new(),
214        }))?
215        .map_err(Error::WebSocket)?;
216
217        let (sink, stream) = ws_stream.split();
218        let sender = RealtimeSender {
219            sink: tokio::sync::Mutex::new(sink),
220        };
221        let receiver = RealtimeReceiver { stream };
222
223        // Send session.update with config
224        let session_update = build_session_update(config);
225        sender.send_raw(&serde_json::to_string(&session_update)?).await?;
226
227        Ok((sender, receiver))
228    }
229}
230
231/// Response from the QAI realtime session endpoint.
232#[derive(Debug, Clone, serde::Deserialize)]
233pub struct RealtimeSession {
234    /// Ephemeral token for direct WebSocket connection (xAI/OpenAI).
235    #[serde(default)]
236    pub ephemeral_token: String,
237    /// WebSocket URL to connect to.
238    /// For xAI: "wss://api.x.ai/v1/realtime"
239    /// For ElevenLabs: the signed WebSocket URL (includes auth in URL).
240    #[serde(default)]
241    pub url: String,
242    /// Signed URL (alias for url — ElevenLabs returns this field name).
243    #[serde(default)]
244    pub signed_url: String,
245    /// Session ID for billing (pass to realtime/end on disconnect).
246    #[serde(default)]
247    pub session_id: String,
248    /// Provider name (e.g. "elevenlabs", "xai").
249    #[serde(default)]
250    pub provider: String,
251}
252
253impl RealtimeSession {
254    /// Get the WebSocket URL — checks both `url` and `signed_url` fields.
255    pub fn ws_url(&self) -> &str {
256        if !self.signed_url.is_empty() { &self.signed_url }
257        else { &self.url }
258    }
259}
260
261impl Client {
262    /// Request an ephemeral token from the QAI proxy for direct voice connection.
263    /// Call this before `realtime_connect_direct` to get a scoped token.
264    /// Pass an optional `provider` to route to a specific backend (e.g. "openai").
265    pub async fn realtime_session(&self) -> Result<RealtimeSession> {
266        self.realtime_session_for(None).await
267    }
268
269    /// Request an ephemeral token for a specific provider.
270    pub async fn realtime_session_for(&self, provider: Option<&str>) -> Result<RealtimeSession> {
271        let mut body = serde_json::json!({});
272        if let Some(p) = provider {
273            body["provider"] = serde_json::Value::String(p.to_string());
274        }
275        let (session, _meta): (RealtimeSession, _) = self
276            .post_json("/qai/v1/realtime/session", &body)
277            .await?;
278        Ok(session)
279    }
280
281    /// End a realtime session and finalize billing.
282    pub async fn realtime_end(&self, session_id: &str, duration_seconds: u64) -> Result<()> {
283        let _: (serde_json::Value, _) = self
284            .post_json(
285                "/qai/v1/realtime/end",
286                &serde_json::json!({
287                    "session_id": session_id,
288                    "duration_seconds": duration_seconds,
289                }),
290            )
291            .await?;
292        Ok(())
293    }
294
295    /// Refresh an ephemeral token for long sessions (>4 min).
296    pub async fn realtime_refresh(&self, session_id: &str) -> Result<String> {
297        let (resp, _): (serde_json::Value, _) = self
298            .post_json(
299                "/qai/v1/realtime/refresh",
300                &serde_json::json!({ "session_id": session_id }),
301            )
302            .await?;
303        Ok(resp["ephemeral_token"]
304            .as_str()
305            .unwrap_or("")
306            .to_string())
307    }
308}
309
310/// Opens a realtime voice session directly to xAI (bypassing the proxy).
311///
312/// Use with an ephemeral token from `client.realtime_session()`.
313/// Much lower latency than the proxy path — no extra hop.
314pub async fn realtime_connect_direct(
315    ephemeral_token: &str,
316    config: &RealtimeConfig,
317) -> Result<(RealtimeSender, RealtimeReceiver)> {
318    realtime_connect_direct_to("wss://api.x.ai/v1/realtime", ephemeral_token, config).await
319}
320
321/// Opens a realtime voice session to a specific WebSocket URL.
322pub async fn realtime_connect_direct_to(
323    url: &str,
324    token: &str,
325    config: &RealtimeConfig,
326) -> Result<(RealtimeSender, RealtimeReceiver)> {
327    // Extract host from URL
328    let host = url
329        .trim_start_matches("wss://")
330        .trim_start_matches("ws://")
331        .split('/')
332        .next()
333        .unwrap_or("api.x.ai");
334
335    let request = Request::builder()
336        .uri(url)
337        .header("Host", host)
338        .header("Authorization", format!("Bearer {token}"))
339        .header("Connection", "Upgrade")
340        .header("Upgrade", "websocket")
341        .header("Sec-WebSocket-Version", "13")
342        .header(
343            "Sec-WebSocket-Key",
344            tokio_tungstenite::tungstenite::handshake::client::generate_key(),
345        )
346        .body(())
347        .map_err(|e| Error::Api(ApiError {
348            status_code: 0,
349            code: "websocket_request".into(),
350            message: format!("Failed to build WebSocket request: {e}"),
351            request_id: String::new(),
352        }))?;
353
354    let (ws_stream, _response) = tokio::time::timeout(
355        std::time::Duration::from_secs(10),
356        tokio_tungstenite::connect_async(request),
357    )
358    .await
359    .map_err(|_| Error::Api(ApiError {
360        status_code: 0,
361        code: "timeout".into(),
362        message: "Direct xAI WebSocket connection timed out (10s)".into(),
363        request_id: String::new(),
364    }))?
365    .map_err(Error::WebSocket)?;
366
367    let (sink, stream) = ws_stream.split();
368    let sender = RealtimeSender {
369        sink: tokio::sync::Mutex::new(sink),
370    };
371    let receiver = RealtimeReceiver { stream };
372
373    // Send session.update
374    let session_update = build_session_update(config);
375    sender.send_raw(&serde_json::to_string(&session_update)?).await?;
376
377    Ok((sender, receiver))
378}
379
380// ── Session update builder ──
381
382/// Build the `session.update` JSON payload from config.
383/// Adapts the format based on whether a model is specified (OpenAI uses `model`
384/// at the session level; xAI uses `input_audio_transcription.model`).
385fn build_session_update(config: &RealtimeConfig) -> serde_json::Value {
386    let is_openai = config.model.contains("gpt-") || config.model.contains("realtime");
387
388    let mut session = serde_json::json!({
389        "voice": config.voice,
390        "instructions": config.instructions,
391        "turn_detection": { "type": "server_vad" },
392        "tools": config.tools,
393    });
394
395    if !config.model.is_empty() {
396        session["model"] = serde_json::Value::String(config.model.clone());
397    }
398
399    if is_openai {
400        // OpenAI Realtime API format: modalities + input_audio_format/output_audio_format
401        session["modalities"] = serde_json::json!(["text", "audio"]);
402        session["input_audio_format"] = serde_json::json!("pcm16");
403        session["output_audio_format"] = serde_json::json!("pcm16");
404        session["input_audio_transcription"] = serde_json::json!({ "model": "gpt-4o-mini-transcribe" });
405    } else {
406        // xAI Realtime API format
407        session["input_audio_transcription"] = serde_json::json!({ "model": "grok-2-audio" });
408        session["audio"] = serde_json::json!({
409            "input": { "format": { "type": "audio/pcm", "rate": config.sample_rate } },
410            "output": { "format": { "type": "audio/pcm", "rate": config.sample_rate } },
411        });
412    }
413
414    serde_json::json!({
415        "type": "session.update",
416        "session": session,
417    })
418}
419
420// ── RealtimeSender ──
421
422// SAFETY: WsSink contains a TcpStream which is Send, and we wrap in tokio::sync::Mutex.
423unsafe impl Send for RealtimeSender {}
424unsafe impl Sync for RealtimeSender {}
425
426impl RealtimeSender {
427    /// Send a base64-encoded PCM audio chunk.
428    pub async fn send_audio(&self, base64_pcm: &str) -> Result<()> {
429        let msg = serde_json::json!({
430            "type": "input_audio_buffer.append",
431            "audio": base64_pcm,
432        });
433        self.send_raw(&serde_json::to_string(&msg)?).await
434    }
435
436    /// Send a text message (creates a conversation item and requests a response).
437    pub async fn send_text(&self, text: &str) -> Result<()> {
438        let item = serde_json::json!({
439            "type": "conversation.item.create",
440            "item": {
441                "type": "message",
442                "role": "user",
443                "content": [{
444                    "type": "input_text",
445                    "text": text,
446                }]
447            }
448        });
449        self.send_raw(&serde_json::to_string(&item)?).await?;
450
451        let response = serde_json::json!({
452            "type": "response.create",
453            "response": {
454                "modalities": ["text", "audio"],
455            }
456        });
457        self.send_raw(&serde_json::to_string(&response)?).await
458    }
459
460    /// Send a function/tool call result back to the model.
461    pub async fn send_function_result(&self, call_id: &str, output: &str) -> Result<()> {
462        let item = serde_json::json!({
463            "type": "conversation.item.create",
464            "item": {
465                "type": "function_call_output",
466                "call_id": call_id,
467                "output": output,
468            }
469        });
470        self.send_raw(&serde_json::to_string(&item)?).await?;
471
472        let response = serde_json::json!({
473            "type": "response.create",
474        });
475        self.send_raw(&serde_json::to_string(&response)?).await
476    }
477
478    /// Cancel the current response (interrupt the model).
479    pub async fn cancel_response(&self) -> Result<()> {
480        let msg = serde_json::json!({ "type": "response.cancel" });
481        self.send_raw(&serde_json::to_string(&msg)?).await
482    }
483
484    /// Close the WebSocket connection gracefully.
485    pub async fn close(self) -> Result<()> {
486        let mut sink = self.sink.into_inner();
487        sink.close().await.map_err(Error::WebSocket)
488    }
489
490    /// Send a raw text frame.
491    async fn send_raw(&self, text: &str) -> Result<()> {
492        let mut sink = self.sink.lock().await;
493        sink.send(Message::Text(text.into()))
494            .await
495            .map_err(Error::WebSocket)
496    }
497}
498
499// ── RealtimeReceiver ──
500
501impl RealtimeReceiver {
502    /// Receive the next event. Returns `None` when the connection closes.
503    pub async fn recv(&mut self) -> Option<RealtimeEvent> {
504        loop {
505            let msg = self.stream.next().await?;
506            match msg {
507                Ok(Message::Text(text)) => {
508                    return Some(parse_event(&text));
509                }
510                Ok(Message::Close(_)) => return None,
511                Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Frame(_)) => continue,
512                Ok(Message::Binary(_)) => continue,
513                Err(_) => return None,
514            }
515        }
516    }
517}
518
519// ── Event parsing ──
520
521fn parse_event(text: &str) -> RealtimeEvent {
522    let Ok(v) = serde_json::from_str::<serde_json::Value>(text) else {
523        return RealtimeEvent::Unknown(serde_json::Value::String(text.to_string()));
524    };
525
526    let event_type = v["type"].as_str().unwrap_or("");
527
528    match event_type {
529        "session.updated" => RealtimeEvent::SessionReady,
530
531        "response.audio.delta" => RealtimeEvent::AudioDelta {
532            delta: v["delta"].as_str().unwrap_or("").to_string(),
533        },
534
535        // Some API versions use "response.output_audio.delta"
536        "response.output_audio.delta" => RealtimeEvent::AudioDelta {
537            delta: v["delta"].as_str().unwrap_or("").to_string(),
538        },
539
540        "response.audio_transcript.delta" | "response.output_audio_transcript.delta" => {
541            RealtimeEvent::TranscriptDelta {
542                delta: v["delta"].as_str().unwrap_or("").to_string(),
543                source: "output".into(),
544            }
545        }
546
547        "response.audio_transcript.done" | "response.output_audio_transcript.done" => {
548            RealtimeEvent::TranscriptDone {
549                transcript: v["transcript"].as_str().unwrap_or("").to_string(),
550                source: "output".into(),
551            }
552        }
553
554        "conversation.item.input_audio_transcription.completed" => {
555            RealtimeEvent::TranscriptDone {
556                transcript: v["transcript"].as_str().unwrap_or("").to_string(),
557                source: "input".into(),
558            }
559        }
560
561        "input_audio_buffer.speech_started" => RealtimeEvent::SpeechStarted,
562        "input_audio_buffer.speech_stopped" => RealtimeEvent::SpeechStopped,
563
564        "response.function_call_arguments.done" => RealtimeEvent::FunctionCall {
565            name: v["name"].as_str().unwrap_or("").to_string(),
566            call_id: v["call_id"].as_str().unwrap_or("").to_string(),
567            arguments: v["arguments"].as_str().unwrap_or("").to_string(),
568        },
569
570        "response.done" => RealtimeEvent::ResponseDone,
571
572        "error" => RealtimeEvent::Error {
573            message: v["error"]["message"]
574                .as_str()
575                .or_else(|| v["message"].as_str())
576                .unwrap_or("unknown error")
577                .to_string(),
578        },
579
580        _ => RealtimeEvent::Unknown(v),
581    }
582}
583
584// ── Tests ──
585
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is the "Sal" voice at 24 kHz with nothing else set.
    #[test]
    fn default_config() {
        let RealtimeConfig {
            voice,
            instructions,
            sample_rate,
            tools,
            model,
        } = RealtimeConfig::default();

        assert_eq!(voice, "Sal");
        assert_eq!(sample_rate, 24000);
        assert!(instructions.is_empty());
        assert!(tools.is_empty());
        assert!(model.is_empty());
    }

    /// A populated configuration serializes its voice, sample rate and tools.
    #[test]
    fn config_serialization() {
        let weather_tool = serde_json::json!({
            "type": "function",
            "name": "get_weather",
            "description": "Get weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": { "type": "string" }
                },
                "required": ["location"]
            }
        });
        let config = RealtimeConfig {
            voice: "Eve".into(),
            instructions: "You are a helpful assistant.".into(),
            sample_rate: 16000,
            tools: vec![weather_tool],
            model: String::new(),
        };

        let serialized = serde_json::to_value(&config).unwrap();
        assert_eq!(serialized["voice"], "Eve");
        assert_eq!(serialized["sample_rate"], 16000);
        assert_eq!(serialized["tools"].as_array().unwrap().len(), 1);
    }

    /// `session.updated` maps to `SessionReady`.
    #[test]
    fn parse_session_ready() {
        let parsed = parse_event(r#"{"type":"session.updated","session":{}}"#);
        assert!(matches!(parsed, RealtimeEvent::SessionReady));
    }

    /// `response.audio.delta` carries its `delta` payload through.
    #[test]
    fn parse_audio_delta() {
        let parsed = parse_event(r#"{"type":"response.audio.delta","delta":"AQID"}"#);
        let RealtimeEvent::AudioDelta { delta } = parsed else {
            panic!("expected AudioDelta");
        };
        assert_eq!(delta, "AQID");
    }

    /// A completed input transcription becomes `TranscriptDone` with source "input".
    #[test]
    fn parse_transcript_done() {
        let parsed = parse_event(
            r#"{"type":"conversation.item.input_audio_transcription.completed","transcript":"hello"}"#,
        );
        let RealtimeEvent::TranscriptDone { transcript, source } = parsed else {
            panic!("expected TranscriptDone");
        };
        assert_eq!(transcript, "hello");
        assert_eq!(source, "input");
    }

    /// Finished function-call arguments become a `FunctionCall` event.
    #[test]
    fn parse_function_call() {
        let parsed = parse_event(
            r#"{"type":"response.function_call_arguments.done","name":"get_weather","call_id":"call_123","arguments":"{\"location\":\"London\"}"}"#,
        );
        let RealtimeEvent::FunctionCall { name, call_id, arguments } = parsed else {
            panic!("expected FunctionCall");
        };
        assert_eq!(name, "get_weather");
        assert_eq!(call_id, "call_123");
        assert!(arguments.contains("London"));
    }

    /// The nested `error.message` field is surfaced as the error message.
    #[test]
    fn parse_error() {
        let parsed = parse_event(r#"{"type":"error","error":{"message":"rate limited"}}"#);
        let RealtimeEvent::Error { message } = parsed else {
            panic!("expected Error");
        };
        assert_eq!(message, "rate limited");
    }

    /// Unrecognised event types are preserved as `Unknown`.
    #[test]
    fn parse_unknown() {
        let parsed = parse_event(r#"{"type":"some.future.event","data":42}"#);
        assert!(matches!(parsed, RealtimeEvent::Unknown(_)));
    }

    /// Payload-free events map to their unit variants.
    #[test]
    fn parse_speech_events() {
        let started = parse_event(r#"{"type":"input_audio_buffer.speech_started"}"#);
        assert!(matches!(started, RealtimeEvent::SpeechStarted));

        let stopped = parse_event(r#"{"type":"input_audio_buffer.speech_stopped"}"#);
        assert!(matches!(stopped, RealtimeEvent::SpeechStopped));

        let done = parse_event(r#"{"type":"response.done"}"#);
        assert!(matches!(done, RealtimeEvent::ResponseDone));
    }

    /// End-to-end smoke test; ignored by default because it needs a live server.
    #[ignore]
    #[tokio::test]
    async fn live_connect() {
        // Requires a running QAI server and valid API key.
        let api_key = std::env::var("QAI_API_KEY").expect("QAI_API_KEY required");
        let client = crate::Client::new(api_key);
        let config = RealtimeConfig::default();

        let (sender, mut receiver) = client.realtime_connect(&config).await.unwrap();

        // Should receive SessionReady
        let first = receiver.recv().await.unwrap();
        assert!(matches!(first, RealtimeEvent::SessionReady));

        sender.close().await.unwrap();
    }
}