1use futures_util::{SinkExt, StreamExt};
34use serde::Serialize;
35use tokio::net::TcpStream;
36use tokio_tungstenite::tungstenite::http::Request;
37use tokio_tungstenite::tungstenite::Message;
38use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
39
40use crate::client::Client;
41use crate::error::{ApiError, Error, Result};
42
43type WsSink = futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
44type WsStream = futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
45
/// Session settings for a realtime connection.
///
/// The connect functions in this file turn these values into the initial
/// `session.update` message (see `build_session_update`).
#[derive(Debug, Clone, Serialize)]
pub struct RealtimeConfig {
    /// Voice name, forwarded as `session.voice`. Defaults to `"Sal"`.
    pub voice: String,

    /// Instruction text, forwarded as `session.instructions`. Defaults to empty.
    pub instructions: String,

    /// Sample rate used as the `rate` of both the input and output `audio/pcm`
    /// formats when `model` contains neither `gpt-` nor `realtime`. Defaults to 24000.
    pub sample_rate: u32,

    /// Tool definitions as raw JSON values, forwarded as `session.tools`.
    /// Omitted when this struct itself is serialized and the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<serde_json::Value>,

    /// Model name. When non-empty it is forwarded as `session.model`; a value
    /// containing `gpt-` or `realtime` also switches the session payload to the
    /// OpenAI-style field layout. Omitted from direct serialization when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub model: String,
}
69
70impl Default for RealtimeConfig {
71 fn default() -> Self {
72 Self {
73 voice: "Sal".into(),
74 instructions: String::new(),
75 sample_rate: 24000,
76 tools: Vec::new(),
77 model: String::new(),
78 }
79 }
80}
81
/// A server message decoded by `parse_event`.
#[derive(Debug, Clone)]
pub enum RealtimeEvent {
    /// The server sent `session.updated`.
    SessionReady,

    /// An audio chunk from `response.audio.delta` or `response.output_audio.delta`.
    /// `delta` is the event's `delta` string, passed through unchanged
    /// (presumably base64-encoded audio — confirm against the API docs).
    AudioDelta { delta: String },

    /// An incremental transcript fragment. `source` is always `"output"` here.
    TranscriptDelta {
        delta: String,
        source: String,
    },

    /// A completed transcript. `source` is `"output"` for response transcripts and
    /// `"input"` for `conversation.item.input_audio_transcription.completed`.
    TranscriptDone {
        transcript: String,
        source: String,
    },

    /// The server sent `input_audio_buffer.speech_started`.
    SpeechStarted,

    /// The server sent `input_audio_buffer.speech_stopped`.
    SpeechStopped,

    /// The server sent `response.function_call_arguments.done`.
    /// `arguments` is the raw string from the event; it is not parsed here.
    FunctionCall {
        name: String,
        call_id: String,
        arguments: String,
    },

    /// The server sent `response.done`.
    ResponseDone,

    /// The server sent an `error` event. `message` comes from `error.message`,
    /// then top-level `message`, and falls back to `"unknown error"`.
    Error { message: String },

    /// Any other event type (the full JSON value), or text that was not valid
    /// JSON (wrapped as a JSON string).
    Unknown(serde_json::Value),
}
127
/// Sending half of a realtime WebSocket connection.
pub struct RealtimeSender {
    /// Write half of the socket, behind an async mutex so the send methods
    /// can take `&self`.
    sink: tokio::sync::Mutex<WsSink>,
}
132
/// Receiving half of a realtime WebSocket connection.
pub struct RealtimeReceiver {
    /// Read half of the socket; consumed by `recv`.
    stream: WsStream,
}
137
138impl Client {
141 pub async fn realtime_connect(
147 &self,
148 config: &RealtimeConfig,
149 ) -> Result<(RealtimeSender, RealtimeReceiver)> {
150 let base = self.base_url();
152 let ws_base = if base.starts_with("https://") {
153 format!("wss://{}", &base[8..])
154 } else if base.starts_with("http://") {
155 format!("ws://{}", &base[7..])
156 } else {
157 return Err(Error::Api(ApiError {
158 status_code: 0,
159 code: "invalid_base_url".into(),
160 message: format!("Cannot convert base URL to WebSocket: {base}"),
161 request_id: String::new(),
162 }));
163 };
164
165 let url = format!("{ws_base}/qai/v1/realtime");
166
167 let host = base
169 .trim_start_matches("https://")
170 .trim_start_matches("http://")
171 .trim_end_matches('/')
172 .to_string();
173
174 let auth = self
175 .auth_header()
176 .to_str()
177 .unwrap_or("")
178 .to_string();
179
180 let raw_token = auth.strip_prefix("Bearer ").unwrap_or(&auth);
182
183 let request = Request::builder()
184 .uri(&url)
185 .header("Host", &host)
186 .header("Authorization", &auth)
187 .header("X-API-Key", raw_token)
188 .header("Connection", "Upgrade")
189 .header("Upgrade", "websocket")
190 .header("Sec-WebSocket-Version", "13")
191 .header(
192 "Sec-WebSocket-Key",
193 tokio_tungstenite::tungstenite::handshake::client::generate_key(),
194 )
195 .body(())
196 .map_err(|e| Error::Api(ApiError {
197 status_code: 0,
198 code: "websocket_request".into(),
199 message: format!("Failed to build WebSocket request: {e}"),
200 request_id: String::new(),
201 }))?;
202
203 let (ws_stream, _response) = tokio::time::timeout(
205 std::time::Duration::from_secs(15),
206 tokio_tungstenite::connect_async(request),
207 )
208 .await
209 .map_err(|_| Error::Api(ApiError {
210 status_code: 0,
211 code: "timeout".into(),
212 message: "WebSocket connection timed out (15s)".into(),
213 request_id: String::new(),
214 }))?
215 .map_err(Error::WebSocket)?;
216
217 let (sink, stream) = ws_stream.split();
218 let sender = RealtimeSender {
219 sink: tokio::sync::Mutex::new(sink),
220 };
221 let receiver = RealtimeReceiver { stream };
222
223 let session_update = build_session_update(config);
225 sender.send_raw(&serde_json::to_string(&session_update)?).await?;
226
227 Ok((sender, receiver))
228 }
229}
230
/// Response body of `POST /qai/v1/realtime/session`.
///
/// Every field falls back to an empty string when missing from the response.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RealtimeSession {
    /// Short-lived token (presumably the one to pass to
    /// `realtime_connect_direct` — confirm against the API docs).
    #[serde(default)]
    pub ephemeral_token: String,
    /// WebSocket URL; `ws_url` returns it when `signed_url` is empty.
    #[serde(default)]
    pub url: String,
    /// Signed WebSocket URL; `ws_url` prefers it over `url` when non-empty.
    #[serde(default)]
    pub signed_url: String,
    /// Session identifier, as taken by `realtime_end` and `realtime_refresh`.
    #[serde(default)]
    pub session_id: String,
    /// Provider name reported by the server.
    #[serde(default)]
    pub provider: String,
}
252
253impl RealtimeSession {
254 pub fn ws_url(&self) -> &str {
256 if !self.signed_url.is_empty() { &self.signed_url }
257 else { &self.url }
258 }
259}
260
261impl Client {
262 pub async fn realtime_session(&self) -> Result<RealtimeSession> {
266 self.realtime_session_for(None).await
267 }
268
269 pub async fn realtime_session_for(&self, provider: Option<&str>) -> Result<RealtimeSession> {
271 let mut body = serde_json::json!({});
272 if let Some(p) = provider {
273 body["provider"] = serde_json::Value::String(p.to_string());
274 }
275 let (session, _meta): (RealtimeSession, _) = self
276 .post_json("/qai/v1/realtime/session", &body)
277 .await?;
278 Ok(session)
279 }
280
281 pub async fn realtime_end(&self, session_id: &str, duration_seconds: u64) -> Result<()> {
283 let _: (serde_json::Value, _) = self
284 .post_json(
285 "/qai/v1/realtime/end",
286 &serde_json::json!({
287 "session_id": session_id,
288 "duration_seconds": duration_seconds,
289 }),
290 )
291 .await?;
292 Ok(())
293 }
294
295 pub async fn realtime_refresh(&self, session_id: &str) -> Result<String> {
297 let (resp, _): (serde_json::Value, _) = self
298 .post_json(
299 "/qai/v1/realtime/refresh",
300 &serde_json::json!({ "session_id": session_id }),
301 )
302 .await?;
303 Ok(resp["ephemeral_token"]
304 .as_str()
305 .unwrap_or("")
306 .to_string())
307 }
308}
309
310pub async fn realtime_connect_direct(
315 ephemeral_token: &str,
316 config: &RealtimeConfig,
317) -> Result<(RealtimeSender, RealtimeReceiver)> {
318 realtime_connect_direct_to("wss://api.x.ai/v1/realtime", ephemeral_token, config).await
319}
320
321pub async fn realtime_connect_direct_to(
323 url: &str,
324 token: &str,
325 config: &RealtimeConfig,
326) -> Result<(RealtimeSender, RealtimeReceiver)> {
327 let host = url
329 .trim_start_matches("wss://")
330 .trim_start_matches("ws://")
331 .split('/')
332 .next()
333 .unwrap_or("api.x.ai");
334
335 let request = Request::builder()
336 .uri(url)
337 .header("Host", host)
338 .header("Authorization", format!("Bearer {token}"))
339 .header("Connection", "Upgrade")
340 .header("Upgrade", "websocket")
341 .header("Sec-WebSocket-Version", "13")
342 .header(
343 "Sec-WebSocket-Key",
344 tokio_tungstenite::tungstenite::handshake::client::generate_key(),
345 )
346 .body(())
347 .map_err(|e| Error::Api(ApiError {
348 status_code: 0,
349 code: "websocket_request".into(),
350 message: format!("Failed to build WebSocket request: {e}"),
351 request_id: String::new(),
352 }))?;
353
354 let (ws_stream, _response) = tokio::time::timeout(
355 std::time::Duration::from_secs(10),
356 tokio_tungstenite::connect_async(request),
357 )
358 .await
359 .map_err(|_| Error::Api(ApiError {
360 status_code: 0,
361 code: "timeout".into(),
362 message: "Direct xAI WebSocket connection timed out (10s)".into(),
363 request_id: String::new(),
364 }))?
365 .map_err(Error::WebSocket)?;
366
367 let (sink, stream) = ws_stream.split();
368 let sender = RealtimeSender {
369 sink: tokio::sync::Mutex::new(sink),
370 };
371 let receiver = RealtimeReceiver { stream };
372
373 let session_update = build_session_update(config);
375 sender.send_raw(&serde_json::to_string(&session_update)?).await?;
376
377 Ok((sender, receiver))
378}
379
380fn build_session_update(config: &RealtimeConfig) -> serde_json::Value {
386 let is_openai = config.model.contains("gpt-") || config.model.contains("realtime");
387
388 let mut session = serde_json::json!({
389 "voice": config.voice,
390 "instructions": config.instructions,
391 "turn_detection": { "type": "server_vad" },
392 "tools": config.tools,
393 });
394
395 if !config.model.is_empty() {
396 session["model"] = serde_json::Value::String(config.model.clone());
397 }
398
399 if is_openai {
400 session["modalities"] = serde_json::json!(["text", "audio"]);
402 session["input_audio_format"] = serde_json::json!("pcm16");
403 session["output_audio_format"] = serde_json::json!("pcm16");
404 session["input_audio_transcription"] = serde_json::json!({ "model": "gpt-4o-mini-transcribe" });
405 } else {
406 session["input_audio_transcription"] = serde_json::json!({ "model": "grok-2-audio" });
408 session["audio"] = serde_json::json!({
409 "input": { "format": { "type": "audio/pcm", "rate": config.sample_rate } },
410 "output": { "format": { "type": "audio/pcm", "rate": config.sample_rate } },
411 });
412 }
413
414 serde_json::json!({
415 "type": "session.update",
416 "session": session,
417 })
418}
419
420unsafe impl Send for RealtimeSender {}
424unsafe impl Sync for RealtimeSender {}
425
426impl RealtimeSender {
427 pub async fn send_audio(&self, base64_pcm: &str) -> Result<()> {
429 let msg = serde_json::json!({
430 "type": "input_audio_buffer.append",
431 "audio": base64_pcm,
432 });
433 self.send_raw(&serde_json::to_string(&msg)?).await
434 }
435
436 pub async fn send_text(&self, text: &str) -> Result<()> {
438 let item = serde_json::json!({
439 "type": "conversation.item.create",
440 "item": {
441 "type": "message",
442 "role": "user",
443 "content": [{
444 "type": "input_text",
445 "text": text,
446 }]
447 }
448 });
449 self.send_raw(&serde_json::to_string(&item)?).await?;
450
451 let response = serde_json::json!({
452 "type": "response.create",
453 "response": {
454 "modalities": ["text", "audio"],
455 }
456 });
457 self.send_raw(&serde_json::to_string(&response)?).await
458 }
459
460 pub async fn send_function_result(&self, call_id: &str, output: &str) -> Result<()> {
462 let item = serde_json::json!({
463 "type": "conversation.item.create",
464 "item": {
465 "type": "function_call_output",
466 "call_id": call_id,
467 "output": output,
468 }
469 });
470 self.send_raw(&serde_json::to_string(&item)?).await?;
471
472 let response = serde_json::json!({
473 "type": "response.create",
474 });
475 self.send_raw(&serde_json::to_string(&response)?).await
476 }
477
478 pub async fn cancel_response(&self) -> Result<()> {
480 let msg = serde_json::json!({ "type": "response.cancel" });
481 self.send_raw(&serde_json::to_string(&msg)?).await
482 }
483
484 pub async fn close(self) -> Result<()> {
486 let mut sink = self.sink.into_inner();
487 sink.close().await.map_err(Error::WebSocket)
488 }
489
490 async fn send_raw(&self, text: &str) -> Result<()> {
492 let mut sink = self.sink.lock().await;
493 sink.send(Message::Text(text.into()))
494 .await
495 .map_err(Error::WebSocket)
496 }
497}
498
499impl RealtimeReceiver {
502 pub async fn recv(&mut self) -> Option<RealtimeEvent> {
504 loop {
505 let msg = self.stream.next().await?;
506 match msg {
507 Ok(Message::Text(text)) => {
508 return Some(parse_event(&text));
509 }
510 Ok(Message::Close(_)) => return None,
511 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Frame(_)) => continue,
512 Ok(Message::Binary(_)) => continue,
513 Err(_) => return None,
514 }
515 }
516 }
517}
518
519fn parse_event(text: &str) -> RealtimeEvent {
522 let Ok(v) = serde_json::from_str::<serde_json::Value>(text) else {
523 return RealtimeEvent::Unknown(serde_json::Value::String(text.to_string()));
524 };
525
526 let event_type = v["type"].as_str().unwrap_or("");
527
528 match event_type {
529 "session.updated" => RealtimeEvent::SessionReady,
530
531 "response.audio.delta" => RealtimeEvent::AudioDelta {
532 delta: v["delta"].as_str().unwrap_or("").to_string(),
533 },
534
535 "response.output_audio.delta" => RealtimeEvent::AudioDelta {
537 delta: v["delta"].as_str().unwrap_or("").to_string(),
538 },
539
540 "response.audio_transcript.delta" | "response.output_audio_transcript.delta" => {
541 RealtimeEvent::TranscriptDelta {
542 delta: v["delta"].as_str().unwrap_or("").to_string(),
543 source: "output".into(),
544 }
545 }
546
547 "response.audio_transcript.done" | "response.output_audio_transcript.done" => {
548 RealtimeEvent::TranscriptDone {
549 transcript: v["transcript"].as_str().unwrap_or("").to_string(),
550 source: "output".into(),
551 }
552 }
553
554 "conversation.item.input_audio_transcription.completed" => {
555 RealtimeEvent::TranscriptDone {
556 transcript: v["transcript"].as_str().unwrap_or("").to_string(),
557 source: "input".into(),
558 }
559 }
560
561 "input_audio_buffer.speech_started" => RealtimeEvent::SpeechStarted,
562 "input_audio_buffer.speech_stopped" => RealtimeEvent::SpeechStopped,
563
564 "response.function_call_arguments.done" => RealtimeEvent::FunctionCall {
565 name: v["name"].as_str().unwrap_or("").to_string(),
566 call_id: v["call_id"].as_str().unwrap_or("").to_string(),
567 arguments: v["arguments"].as_str().unwrap_or("").to_string(),
568 },
569
570 "response.done" => RealtimeEvent::ResponseDone,
571
572 "error" => RealtimeEvent::Error {
573 message: v["error"]["message"]
574 .as_str()
575 .or_else(|| v["message"].as_str())
576 .unwrap_or("unknown error")
577 .to_string(),
578 },
579
580 _ => RealtimeEvent::Unknown(v),
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config uses voice "Sal", 24000 sample rate and is otherwise empty.
    #[test]
    fn default_config() {
        let RealtimeConfig {
            voice,
            instructions,
            sample_rate,
            tools,
            model,
        } = RealtimeConfig::default();

        assert_eq!(voice, "Sal");
        assert_eq!(sample_rate, 24000);
        assert!(instructions.is_empty());
        assert!(tools.is_empty());
        assert!(model.is_empty());
    }

    /// Serializing a config keeps voice, sample rate and the tool list.
    #[test]
    fn config_serialization() {
        let weather_tool = serde_json::json!({
            "type": "function",
            "name": "get_weather",
            "description": "Get weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": { "type": "string" }
                },
                "required": ["location"]
            }
        });
        let config = RealtimeConfig {
            voice: "Eve".into(),
            instructions: "You are a helpful assistant.".into(),
            sample_rate: 16000,
            tools: vec![weather_tool],
            model: String::new(),
        };

        let json = serde_json::to_value(&config).unwrap();
        assert_eq!(json["voice"], "Eve");
        assert_eq!(json["sample_rate"], 16000);
        assert_eq!(json["tools"].as_array().unwrap().len(), 1);
    }

    /// `session.updated` maps to `SessionReady`.
    #[test]
    fn parse_session_ready() {
        let raw = r#"{"type":"session.updated","session":{}}"#;
        assert!(matches!(parse_event(raw), RealtimeEvent::SessionReady));
    }

    /// `response.audio.delta` carries its `delta` through unchanged.
    #[test]
    fn parse_audio_delta() {
        let raw = r#"{"type":"response.audio.delta","delta":"AQID"}"#;
        let RealtimeEvent::AudioDelta { delta } = parse_event(raw) else {
            panic!("expected AudioDelta");
        };
        assert_eq!(delta, "AQID");
    }

    /// A completed input transcription is reported with source "input".
    #[test]
    fn parse_transcript_done() {
        let raw =
            r#"{"type":"conversation.item.input_audio_transcription.completed","transcript":"hello"}"#;
        let RealtimeEvent::TranscriptDone { transcript, source } = parse_event(raw) else {
            panic!("expected TranscriptDone");
        };
        assert_eq!(transcript, "hello");
        assert_eq!(source, "input");
    }

    /// Function-call events expose name, call id and the raw argument string.
    #[test]
    fn parse_function_call() {
        let raw = r#"{"type":"response.function_call_arguments.done","name":"get_weather","call_id":"call_123","arguments":"{\"location\":\"London\"}"}"#;
        let RealtimeEvent::FunctionCall { name, call_id, arguments } = parse_event(raw) else {
            panic!("expected FunctionCall");
        };
        assert_eq!(name, "get_weather");
        assert_eq!(call_id, "call_123");
        assert!(arguments.contains("London"));
    }

    /// Error events take their message from `error.message`.
    #[test]
    fn parse_error() {
        let raw = r#"{"type":"error","error":{"message":"rate limited"}}"#;
        let RealtimeEvent::Error { message } = parse_event(raw) else {
            panic!("expected Error");
        };
        assert_eq!(message, "rate limited");
    }

    /// Unrecognized event types fall through to `Unknown`.
    #[test]
    fn parse_unknown() {
        let raw = r#"{"type":"some.future.event","data":42}"#;
        assert!(matches!(parse_event(raw), RealtimeEvent::Unknown(_)));
    }

    /// Payload-free events map to their unit variants.
    #[test]
    fn parse_speech_events() {
        let started = parse_event(r#"{"type":"input_audio_buffer.speech_started"}"#);
        assert!(matches!(started, RealtimeEvent::SpeechStarted));

        let stopped = parse_event(r#"{"type":"input_audio_buffer.speech_stopped"}"#);
        assert!(matches!(stopped, RealtimeEvent::SpeechStopped));

        let done = parse_event(r#"{"type":"response.done"}"#);
        assert!(matches!(done, RealtimeEvent::ResponseDone));
    }

    /// Live smoke test; needs `QAI_API_KEY` and network access, so it is ignored by default.
    #[ignore]
    #[tokio::test]
    async fn live_connect() {
        let key = std::env::var("QAI_API_KEY").expect("QAI_API_KEY required");
        let client = crate::Client::new(key);
        let config = RealtimeConfig::default();

        let (sender, mut receiver) = client.realtime_connect(&config).await.unwrap();

        let first = receiver.recv().await.unwrap();
        assert!(matches!(first, RealtimeEvent::SessionReady));

        sender.close().await.unwrap();
    }
}