1use futures_util::{SinkExt, StreamExt};
34use serde::Serialize;
35use tokio::net::TcpStream;
36use tokio_tungstenite::tungstenite::http::Request;
37use tokio_tungstenite::tungstenite::Message;
38use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
39
40use crate::client::Client;
41use crate::error::{ApiError, Error, Result};
42
43type WsSink = futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
44type WsStream = futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
45
/// Settings used to build the initial `session.update` message of a realtime
/// session.
///
/// The struct derives `Serialize`; when serialized directly, `tools` and
/// `model` are omitted while they are empty.
#[derive(Debug, Clone, Serialize)]
pub struct RealtimeConfig {
    /// Voice name, sent as `session.voice`.
    pub voice: String,

    /// Instructions text, sent as `session.instructions`.
    pub instructions: String,

    /// Audio sample rate used for both the input and output `audio/pcm`
    /// formats on the non-OpenAI session layout (presumably in Hz — confirm
    /// against the provider documentation).
    pub sample_rate: u32,

    /// Tool definitions, forwarded verbatim as `session.tools`.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<serde_json::Value>,

    /// Model name. When non-empty it is sent as `session.model`; a name
    /// containing `gpt-` or `realtime` selects the OpenAI-style session layout.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub model: String,
}
69
70impl Default for RealtimeConfig {
71 fn default() -> Self {
72 Self {
73 voice: "Sal".into(),
74 instructions: String::new(),
75 sample_rate: 24000,
76 tools: Vec::new(),
77 model: String::new(),
78 }
79 }
80}
81
/// A server event received on a realtime session, as decoded by `parse_event`.
#[derive(Debug, Clone)]
pub enum RealtimeEvent {
    /// The server sent `session.updated`.
    SessionReady,

    /// An audio chunk (`response.audio.delta` or `response.output_audio.delta`).
    /// `delta` is the event's `delta` string (presumably base64-encoded audio —
    /// confirm against the provider documentation).
    AudioDelta { delta: String },

    /// An incremental piece of a transcript.
    TranscriptDelta {
        /// The event's `delta` string.
        delta: String,
        /// Which side the transcript belongs to; the parser emits `"output"`.
        source: String,
    },

    /// A completed transcript.
    TranscriptDone {
        /// The event's `transcript` string.
        transcript: String,
        /// `"output"` for response transcripts, `"input"` for input-audio
        /// transcription results.
        source: String,
    },

    /// The server sent `input_audio_buffer.speech_started`.
    SpeechStarted,

    /// The server sent `input_audio_buffer.speech_stopped`.
    SpeechStopped,

    /// A function call whose arguments are complete
    /// (`response.function_call_arguments.done`).
    FunctionCall {
        /// The event's `name` string.
        name: String,
        /// The event's `call_id` string.
        call_id: String,
        /// The event's `arguments` string, passed through unparsed.
        arguments: String,
    },

    /// The server sent `response.done`.
    ResponseDone,

    /// The server sent an `error` event; `message` is taken from
    /// `error.message`, then `message`, else `"unknown error"`.
    Error { message: String },

    /// Any other event. Holds the parsed JSON, or the raw text as a JSON string
    /// when the payload was not valid JSON.
    Unknown(serde_json::Value),
}
127
/// Sending half of a realtime session.
pub struct RealtimeSender {
    /// Write half of the WebSocket, behind an async mutex so the send methods
    /// can take `&self`.
    sink: tokio::sync::Mutex<WsSink>,
}
132
/// Receiving half of a realtime session.
pub struct RealtimeReceiver {
    /// Read half of the WebSocket.
    stream: WsStream,
}
137
138impl Client {
141 pub async fn realtime_connect(
147 &self,
148 config: &RealtimeConfig,
149 ) -> Result<(RealtimeSender, RealtimeReceiver)> {
150 let base = self.base_url();
152 let ws_base = if base.starts_with("https://") {
153 format!("wss://{}", &base[8..])
154 } else if base.starts_with("http://") {
155 format!("ws://{}", &base[7..])
156 } else {
157 return Err(Error::Api(ApiError {
158 status_code: 0,
159 code: "invalid_base_url".into(),
160 message: format!("Cannot convert base URL to WebSocket: {base}"),
161 request_id: String::new(),
162 }));
163 };
164
165 let url = format!("{ws_base}/qai/v1/realtime");
166
167 let host = base
169 .trim_start_matches("https://")
170 .trim_start_matches("http://")
171 .trim_end_matches('/')
172 .to_string();
173
174 let auth = self
175 .auth_header()
176 .to_str()
177 .unwrap_or("")
178 .to_string();
179
180 let raw_token = auth.strip_prefix("Bearer ").unwrap_or(&auth);
182
183 let request = Request::builder()
184 .uri(&url)
185 .header("Host", &host)
186 .header("Authorization", &auth)
187 .header("X-API-Key", raw_token)
188 .header("Connection", "Upgrade")
189 .header("Upgrade", "websocket")
190 .header("Sec-WebSocket-Version", "13")
191 .header(
192 "Sec-WebSocket-Key",
193 tokio_tungstenite::tungstenite::handshake::client::generate_key(),
194 )
195 .body(())
196 .map_err(|e| Error::Api(ApiError {
197 status_code: 0,
198 code: "websocket_request".into(),
199 message: format!("Failed to build WebSocket request: {e}"),
200 request_id: String::new(),
201 }))?;
202
203 let (ws_stream, _response) = tokio::time::timeout(
205 std::time::Duration::from_secs(15),
206 tokio_tungstenite::connect_async(request),
207 )
208 .await
209 .map_err(|_| Error::Api(ApiError {
210 status_code: 0,
211 code: "timeout".into(),
212 message: "WebSocket connection timed out (15s)".into(),
213 request_id: String::new(),
214 }))?
215 .map_err(Error::WebSocket)?;
216
217 let (sink, stream) = ws_stream.split();
218 let sender = RealtimeSender {
219 sink: tokio::sync::Mutex::new(sink),
220 };
221 let receiver = RealtimeReceiver { stream };
222
223 let session_update = build_session_update(config);
225 sender.send_raw(&serde_json::to_string(&session_update)?).await?;
226
227 Ok((sender, receiver))
228 }
229}
230
/// Session descriptor deserialized from the `/qai/v1/realtime/session`
/// response. Every field falls back to an empty string when absent.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RealtimeSession {
    /// Short-lived token (presumably the bearer token for a direct provider
    /// connection — confirm against the API documentation).
    #[serde(default)]
    pub ephemeral_token: String,
    /// WebSocket URL, used by `ws_url` when `signed_url` is empty.
    #[serde(default)]
    pub url: String,
    /// Signed WebSocket URL; preferred by `ws_url` when non-empty.
    #[serde(default)]
    pub signed_url: String,
    /// Session identifier; `realtime_end` and `realtime_refresh` take such an id.
    #[serde(default)]
    pub session_id: String,
    /// Provider name reported by the server.
    #[serde(default)]
    pub provider: String,
}
252
253pub type RealtimeSessionResponse = RealtimeSession;
255
256impl RealtimeSession {
257 pub fn ws_url(&self) -> &str {
259 if !self.signed_url.is_empty() { &self.signed_url }
260 else { &self.url }
261 }
262}
263
264impl Client {
265 pub async fn realtime_session(&self) -> Result<RealtimeSession> {
269 self.realtime_session_for(None).await
270 }
271
272 pub async fn realtime_session_for(&self, provider: Option<&str>) -> Result<RealtimeSession> {
274 self.realtime_session_with(provider, serde_json::json!({})).await
275 }
276
277 pub async fn realtime_session_with(
281 &self,
282 provider: Option<&str>,
283 mut body: serde_json::Value,
284 ) -> Result<RealtimeSession> {
285 if let Some(p) = provider {
286 body["provider"] = serde_json::Value::String(p.to_string());
287 }
288 let (session, _meta): (RealtimeSession, _) = self
289 .post_json("/qai/v1/realtime/session", &body)
290 .await?;
291 Ok(session)
292 }
293
294 pub async fn realtime_end(&self, session_id: &str, duration_seconds: u64) -> Result<()> {
296 let _: (serde_json::Value, _) = self
297 .post_json(
298 "/qai/v1/realtime/end",
299 &serde_json::json!({
300 "session_id": session_id,
301 "duration_seconds": duration_seconds,
302 }),
303 )
304 .await?;
305 Ok(())
306 }
307
308 pub async fn realtime_refresh(&self, session_id: &str) -> Result<String> {
310 let (resp, _): (serde_json::Value, _) = self
311 .post_json(
312 "/qai/v1/realtime/refresh",
313 &serde_json::json!({ "session_id": session_id }),
314 )
315 .await?;
316 Ok(resp["ephemeral_token"]
317 .as_str()
318 .unwrap_or("")
319 .to_string())
320 }
321}
322
323pub async fn realtime_connect_direct(
328 ephemeral_token: &str,
329 config: &RealtimeConfig,
330) -> Result<(RealtimeSender, RealtimeReceiver)> {
331 realtime_connect_direct_to("wss://api.x.ai/v1/realtime", ephemeral_token, config).await
332}
333
334pub async fn realtime_connect_direct_to(
336 url: &str,
337 token: &str,
338 config: &RealtimeConfig,
339) -> Result<(RealtimeSender, RealtimeReceiver)> {
340 let host = url
342 .trim_start_matches("wss://")
343 .trim_start_matches("ws://")
344 .split('/')
345 .next()
346 .unwrap_or("api.x.ai");
347
348 let request = Request::builder()
349 .uri(url)
350 .header("Host", host)
351 .header("Authorization", format!("Bearer {token}"))
352 .header("Connection", "Upgrade")
353 .header("Upgrade", "websocket")
354 .header("Sec-WebSocket-Version", "13")
355 .header(
356 "Sec-WebSocket-Key",
357 tokio_tungstenite::tungstenite::handshake::client::generate_key(),
358 )
359 .body(())
360 .map_err(|e| Error::Api(ApiError {
361 status_code: 0,
362 code: "websocket_request".into(),
363 message: format!("Failed to build WebSocket request: {e}"),
364 request_id: String::new(),
365 }))?;
366
367 let (ws_stream, _response) = tokio::time::timeout(
368 std::time::Duration::from_secs(10),
369 tokio_tungstenite::connect_async(request),
370 )
371 .await
372 .map_err(|_| Error::Api(ApiError {
373 status_code: 0,
374 code: "timeout".into(),
375 message: "Direct xAI WebSocket connection timed out (10s)".into(),
376 request_id: String::new(),
377 }))?
378 .map_err(Error::WebSocket)?;
379
380 let (sink, stream) = ws_stream.split();
381 let sender = RealtimeSender {
382 sink: tokio::sync::Mutex::new(sink),
383 };
384 let receiver = RealtimeReceiver { stream };
385
386 let session_update = build_session_update(config);
388 sender.send_raw(&serde_json::to_string(&session_update)?).await?;
389
390 Ok((sender, receiver))
391}
392
393fn build_session_update(config: &RealtimeConfig) -> serde_json::Value {
399 let is_openai = config.model.contains("gpt-") || config.model.contains("realtime");
400
401 let mut session = serde_json::json!({
402 "voice": config.voice,
403 "instructions": config.instructions,
404 "turn_detection": { "type": "server_vad" },
405 "tools": config.tools,
406 });
407
408 if !config.model.is_empty() {
409 session["model"] = serde_json::Value::String(config.model.clone());
410 }
411
412 if is_openai {
413 session["modalities"] = serde_json::json!(["text", "audio"]);
415 session["input_audio_format"] = serde_json::json!("pcm16");
416 session["output_audio_format"] = serde_json::json!("pcm16");
417 session["input_audio_transcription"] = serde_json::json!({ "model": "gpt-4o-mini-transcribe" });
418 } else {
419 session["input_audio_transcription"] = serde_json::json!({ "model": "grok-2-audio" });
421 session["audio"] = serde_json::json!({
422 "input": { "format": { "type": "audio/pcm", "rate": config.sample_rate } },
423 "output": { "format": { "type": "audio/pcm", "rate": config.sample_rate } },
424 });
425 }
426
427 serde_json::json!({
428 "type": "session.update",
429 "session": session,
430 })
431}
432
433unsafe impl Send for RealtimeSender {}
437unsafe impl Sync for RealtimeSender {}
438
439impl RealtimeSender {
440 pub async fn send_audio(&self, base64_pcm: &str) -> Result<()> {
442 let msg = serde_json::json!({
443 "type": "input_audio_buffer.append",
444 "audio": base64_pcm,
445 });
446 self.send_raw(&serde_json::to_string(&msg)?).await
447 }
448
449 pub async fn send_text(&self, text: &str) -> Result<()> {
451 let item = serde_json::json!({
452 "type": "conversation.item.create",
453 "item": {
454 "type": "message",
455 "role": "user",
456 "content": [{
457 "type": "input_text",
458 "text": text,
459 }]
460 }
461 });
462 self.send_raw(&serde_json::to_string(&item)?).await?;
463
464 let response = serde_json::json!({
465 "type": "response.create",
466 "response": {
467 "modalities": ["text", "audio"],
468 }
469 });
470 self.send_raw(&serde_json::to_string(&response)?).await
471 }
472
473 pub async fn send_function_result(&self, call_id: &str, output: &str) -> Result<()> {
475 let item = serde_json::json!({
476 "type": "conversation.item.create",
477 "item": {
478 "type": "function_call_output",
479 "call_id": call_id,
480 "output": output,
481 }
482 });
483 self.send_raw(&serde_json::to_string(&item)?).await?;
484
485 let response = serde_json::json!({
486 "type": "response.create",
487 });
488 self.send_raw(&serde_json::to_string(&response)?).await
489 }
490
491 pub async fn cancel_response(&self) -> Result<()> {
493 let msg = serde_json::json!({ "type": "response.cancel" });
494 self.send_raw(&serde_json::to_string(&msg)?).await
495 }
496
497 pub async fn close(self) -> Result<()> {
499 let mut sink = self.sink.into_inner();
500 sink.close().await.map_err(Error::WebSocket)
501 }
502
503 async fn send_raw(&self, text: &str) -> Result<()> {
505 let mut sink = self.sink.lock().await;
506 sink.send(Message::Text(text.into()))
507 .await
508 .map_err(Error::WebSocket)
509 }
510}
511
512impl RealtimeReceiver {
515 pub async fn recv(&mut self) -> Option<RealtimeEvent> {
517 loop {
518 let msg = self.stream.next().await?;
519 match msg {
520 Ok(Message::Text(text)) => {
521 return Some(parse_event(&text));
522 }
523 Ok(Message::Close(_)) => return None,
524 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Frame(_)) => continue,
525 Ok(Message::Binary(_)) => continue,
526 Err(_) => return None,
527 }
528 }
529 }
530}
531
532fn parse_event(text: &str) -> RealtimeEvent {
535 let Ok(v) = serde_json::from_str::<serde_json::Value>(text) else {
536 return RealtimeEvent::Unknown(serde_json::Value::String(text.to_string()));
537 };
538
539 let event_type = v["type"].as_str().unwrap_or("");
540
541 match event_type {
542 "session.updated" => RealtimeEvent::SessionReady,
543
544 "response.audio.delta" => RealtimeEvent::AudioDelta {
545 delta: v["delta"].as_str().unwrap_or("").to_string(),
546 },
547
548 "response.output_audio.delta" => RealtimeEvent::AudioDelta {
550 delta: v["delta"].as_str().unwrap_or("").to_string(),
551 },
552
553 "response.audio_transcript.delta" | "response.output_audio_transcript.delta" => {
554 RealtimeEvent::TranscriptDelta {
555 delta: v["delta"].as_str().unwrap_or("").to_string(),
556 source: "output".into(),
557 }
558 }
559
560 "response.audio_transcript.done" | "response.output_audio_transcript.done" => {
561 RealtimeEvent::TranscriptDone {
562 transcript: v["transcript"].as_str().unwrap_or("").to_string(),
563 source: "output".into(),
564 }
565 }
566
567 "conversation.item.input_audio_transcription.completed" => {
568 RealtimeEvent::TranscriptDone {
569 transcript: v["transcript"].as_str().unwrap_or("").to_string(),
570 source: "input".into(),
571 }
572 }
573
574 "input_audio_buffer.speech_started" => RealtimeEvent::SpeechStarted,
575 "input_audio_buffer.speech_stopped" => RealtimeEvent::SpeechStopped,
576
577 "response.function_call_arguments.done" => RealtimeEvent::FunctionCall {
578 name: v["name"].as_str().unwrap_or("").to_string(),
579 call_id: v["call_id"].as_str().unwrap_or("").to_string(),
580 arguments: v["arguments"].as_str().unwrap_or("").to_string(),
581 },
582
583 "response.done" => RealtimeEvent::ResponseDone,
584
585 "error" => RealtimeEvent::Error {
586 message: v["error"]["message"]
587 .as_str()
588 .or_else(|| v["message"].as_str())
589 .unwrap_or("unknown error")
590 .to_string(),
591 },
592
593 _ => RealtimeEvent::Unknown(v),
594 }
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses voice "Sal", rate 24000, and is
    /// otherwise empty.
    #[test]
    fn default_config() {
        let RealtimeConfig {
            voice,
            instructions,
            sample_rate,
            tools,
            model,
        } = RealtimeConfig::default();

        assert_eq!(voice, "Sal");
        assert_eq!(sample_rate, 24000);
        assert!(instructions.is_empty());
        assert!(tools.is_empty());
        assert!(model.is_empty());
    }

    /// Serializing a config keeps voice, sample rate and the tool list.
    #[test]
    fn config_serialization() {
        let weather_tool = serde_json::json!({
            "type": "function",
            "name": "get_weather",
            "description": "Get weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": { "type": "string" }
                },
                "required": ["location"]
            }
        });
        let config = RealtimeConfig {
            voice: String::from("Eve"),
            instructions: String::from("You are a helpful assistant."),
            sample_rate: 16000,
            tools: vec![weather_tool],
            model: String::new(),
        };

        let encoded = serde_json::to_value(&config).unwrap();
        assert_eq!(encoded["voice"], "Eve");
        assert_eq!(encoded["sample_rate"], 16000);
        assert_eq!(encoded["tools"].as_array().unwrap().len(), 1);
    }

    /// `session.updated` decodes to `SessionReady`.
    #[test]
    fn parse_session_ready() {
        let parsed = parse_event(r#"{"type":"session.updated","session":{}}"#);
        assert!(matches!(parsed, RealtimeEvent::SessionReady));
    }

    /// `response.audio.delta` carries its `delta` through.
    #[test]
    fn parse_audio_delta() {
        let parsed = parse_event(r#"{"type":"response.audio.delta","delta":"AQID"}"#);
        let RealtimeEvent::AudioDelta { delta } = parsed else {
            panic!("expected AudioDelta");
        };
        assert_eq!(delta, "AQID");
    }

    /// A completed input transcription is tagged with source "input".
    #[test]
    fn parse_transcript_done() {
        let parsed = parse_event(
            r#"{"type":"conversation.item.input_audio_transcription.completed","transcript":"hello"}"#,
        );
        let RealtimeEvent::TranscriptDone { transcript, source } = parsed else {
            panic!("expected TranscriptDone");
        };
        assert_eq!(transcript, "hello");
        assert_eq!(source, "input");
    }

    /// Function-call events expose name, call id and raw arguments.
    #[test]
    fn parse_function_call() {
        let parsed = parse_event(
            r#"{"type":"response.function_call_arguments.done","name":"get_weather","call_id":"call_123","arguments":"{\"location\":\"London\"}"}"#,
        );
        let RealtimeEvent::FunctionCall { name, call_id, arguments } = parsed else {
            panic!("expected FunctionCall");
        };
        assert_eq!(name, "get_weather");
        assert_eq!(call_id, "call_123");
        assert!(arguments.contains("London"));
    }

    /// The nested `error.message` becomes the error text.
    #[test]
    fn parse_error() {
        let parsed = parse_event(r#"{"type":"error","error":{"message":"rate limited"}}"#);
        let RealtimeEvent::Error { message } = parsed else {
            panic!("expected Error");
        };
        assert_eq!(message, "rate limited");
    }

    /// Unrecognised event types fall through to `Unknown`.
    #[test]
    fn parse_unknown() {
        let parsed = parse_event(r#"{"type":"some.future.event","data":42}"#);
        assert!(matches!(parsed, RealtimeEvent::Unknown(_)));
    }

    /// Payload-free events map to their unit variants.
    #[test]
    fn parse_speech_events() {
        let started = parse_event(r#"{"type":"input_audio_buffer.speech_started"}"#);
        assert!(matches!(started, RealtimeEvent::SpeechStarted));

        let stopped = parse_event(r#"{"type":"input_audio_buffer.speech_stopped"}"#);
        assert!(matches!(stopped, RealtimeEvent::SpeechStopped));

        let done = parse_event(r#"{"type":"response.done"}"#);
        assert!(matches!(done, RealtimeEvent::ResponseDone));
    }

    /// Live smoke test; needs `QAI_API_KEY` and network access, so it is
    /// ignored by default.
    #[ignore]
    #[tokio::test]
    async fn live_connect() {
        let api_key = std::env::var("QAI_API_KEY").expect("QAI_API_KEY required");
        let client = crate::Client::new(api_key);
        let config = RealtimeConfig::default();

        let (sender, mut receiver) = client.realtime_connect(&config).await.unwrap();

        let first = receiver.recv().await.unwrap();
        assert!(matches!(first, RealtimeEvent::SessionReady));

        sender.close().await.unwrap();
    }
}