Skip to main content

xai_rust/api/
realtime.rs

1//! Realtime API for voice interactions via WebSocket.
2
3use futures_util::{SinkExt, StreamExt};
4use tokio_tungstenite::{
5    connect_async,
6    tungstenite::{client::IntoClientRequest, Message as WsMessage},
7};
8
9use crate::client::XaiClient;
10use crate::models::tool::Tool;
11use crate::models::voice::{
12    AudioFormat, ConversationItem, RealtimeClientMessage, RealtimeServerMessage, SessionConfig,
13    Voice,
14};
15use crate::{Error, Result};
16
17/// Realtime API for voice interactions.
18#[derive(Debug, Clone)]
19pub struct RealtimeApi {
20    client: XaiClient,
21}
22
23impl RealtimeApi {
24    pub(crate) fn new(client: XaiClient) -> Self {
25        Self { client }
26    }
27
28    /// Start building a realtime session.
29    ///
30    /// # Example
31    ///
32    /// ```rust,no_run
33    /// use xai_rust::{XaiClient, Voice, AudioFormat};
34    ///
35    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
36    /// let client = XaiClient::from_env()?;
37    ///
38    /// let session = client.realtime()
39    ///     .connect("grok-4")
40    ///     .voice(Voice::Ara)
41    ///     .audio_format(AudioFormat::Pcm16)
42    ///     .instructions("You are a helpful voice assistant.")
43    ///     .start()
44    ///     .await?;
45    /// # Ok(())
46    /// # }
47    /// ```
48    pub fn connect(&self, model: impl Into<String>) -> RealtimeSessionBuilder {
49        RealtimeSessionBuilder::new(self.client.clone(), model.into())
50    }
51
52    /// Resume a realtime session from an existing session configuration.
53    ///
54    /// This is useful when reconnecting after a dropped socket and you want
55    /// to restore the same voice/tool/instruction settings in one call.
56    ///
57    /// # Example
58    ///
59    /// ```rust,no_run
60    /// use xai_rust::{SessionConfig, Voice, XaiClient};
61    ///
62    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
63    /// let client = XaiClient::from_env()?;
64    ///
65    /// let config = SessionConfig::new("grok-4")
66    ///     .voice(Voice::Rex)
67    ///     .instructions("Resume with prior settings");
68    ///
69    /// let session = client.realtime().resume(config).start().await?;
70    /// # Ok(())
71    /// # }
72    /// ```
73    pub fn resume(&self, config: SessionConfig) -> RealtimeSessionBuilder {
74        RealtimeSessionBuilder::from_config(self.client.clone(), config)
75    }
76}
77
78/// Builder for realtime sessions.
79#[derive(Debug)]
80pub struct RealtimeSessionBuilder {
81    client: XaiClient,
82    config: SessionConfig,
83}
84
85impl RealtimeSessionBuilder {
86    fn new(client: XaiClient, model: String) -> Self {
87        Self::from_config(client, SessionConfig::new(model))
88    }
89
90    fn from_config(client: XaiClient, config: SessionConfig) -> Self {
91        Self { client, config }
92    }
93
94    /// Set the voice.
95    pub fn voice(mut self, voice: Voice) -> Self {
96        self.config.voice = voice;
97        self
98    }
99
100    /// Set the audio format for both input and output.
101    pub fn audio_format(mut self, format: AudioFormat) -> Self {
102        self.config.input_audio_format = format;
103        self.config.output_audio_format = format;
104        self
105    }
106
107    /// Set the input audio format.
108    pub fn input_format(mut self, format: AudioFormat) -> Self {
109        self.config.input_audio_format = format;
110        self
111    }
112
113    /// Set the output audio format.
114    pub fn output_format(mut self, format: AudioFormat) -> Self {
115        self.config.output_audio_format = format;
116        self
117    }
118
119    /// Set system instructions.
120    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
121        self.config.instructions = Some(instructions.into());
122        self
123    }
124
125    /// Add tools.
126    pub fn tools(mut self, tools: Vec<Tool>) -> Self {
127        self.config.tools = Some(tools);
128        self
129    }
130
131    /// Start the realtime session.
132    ///
133    /// This sends the configured session settings to the server immediately
134    /// after connecting.
135    pub async fn start(self) -> Result<RealtimeSession> {
136        // Build WebSocket URL using proper URL parsing
137        let mut parsed = url::Url::parse(self.client.base_url())?;
138        let scheme = match parsed.scheme() {
139            "https" => "wss",
140            "http" => "ws",
141            s => {
142                return Err(Error::InvalidRequest(format!(
143                    "Unsupported URL scheme: {}",
144                    s
145                )))
146            }
147        };
148        parsed
149            .set_scheme(scheme)
150            .map_err(|_| Error::InvalidRequest("Failed to set WebSocket scheme".to_string()))?;
151        parsed
152            .path_segments_mut()
153            .map_err(|_| Error::InvalidRequest("Cannot-be-a-base URL".to_string()))?
154            .push("realtime");
155        parsed
156            .query_pairs_mut()
157            .append_pair("model", &self.config.model);
158        let ws_url = parsed.to_string();
159
160        // Create WebSocket request with required handshake headers plus auth.
161        let mut request = ws_url
162            .into_client_request()
163            .map_err(|e| Error::InvalidRequest(e.to_string()))?;
164        request.headers_mut().insert(
165            "Authorization",
166            http::HeaderValue::from_str(&format!("Bearer {}", self.client.api_key()))
167                .map_err(|e| Error::InvalidRequest(e.to_string()))?,
168        );
169        request.headers_mut().insert(
170            "Sec-WebSocket-Protocol",
171            http::HeaderValue::from_static("realtime"),
172        );
173
174        let (ws_stream, _) = connect_async(request).await?;
175        let (write, read) = ws_stream.split();
176
177        let mut session = RealtimeSession {
178            client: self.client.clone(),
179            config: self.config.clone(),
180            write: Box::new(write),
181            read: Box::new(read),
182        };
183
184        // Apply builder configuration on connect without consuming
185        // any incoming server messages.
186        session.update_session(self.config).await?;
187
188        Ok(session)
189    }
190}
191
192/// An active realtime session.
193pub struct RealtimeSession {
194    client: XaiClient,
195    config: SessionConfig,
196    write: Box<
197        dyn futures_util::Sink<WsMessage, Error = tokio_tungstenite::tungstenite::Error>
198            + Send
199            + Unpin,
200    >,
201    read: Box<
202        dyn futures_util::Stream<
203                Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>,
204            > + Send
205            + Unpin,
206    >,
207}
208
209impl RealtimeSession {
210    /// Get the session configuration.
211    pub fn config(&self) -> &SessionConfig {
212        &self.config
213    }
214
215    /// Create a reconnect builder seeded with the current session configuration.
216    ///
217    /// This is helpful when you want to tweak a setting before reconnecting.
218    pub fn reconnect_builder(&self) -> RealtimeSessionBuilder {
219        RealtimeSessionBuilder::from_config(self.client.clone(), self.config.clone())
220    }
221
222    /// Reconnect using the current session configuration.
223    ///
224    /// This creates a new socket and reapplies the current session config by
225    /// sending a `session_update` during startup.
226    pub async fn reconnect(&self) -> Result<RealtimeSession> {
227        self.reconnect_builder().start().await
228    }
229
230    /// Reconnect and replay conversation items into the new session.
231    ///
232    /// Use this to rehydrate local conversation context after reconnecting.
233    pub async fn reconnect_and_replay(
234        &self,
235        items: impl IntoIterator<Item = ConversationItem>,
236    ) -> Result<RealtimeSession> {
237        let mut session = self.reconnect().await?;
238        for item in items {
239            session.create_item(item).await?;
240        }
241        Ok(session)
242    }
243
244    /// Send a message to the server.
245    pub async fn send(&mut self, message: RealtimeClientMessage) -> Result<()> {
246        let json = serde_json::to_string(&message)?;
247        self.write.send(WsMessage::Text(json.into())).await?;
248        Ok(())
249    }
250
251    /// Receive the next message from the server.
252    pub async fn receive(&mut self) -> Result<Option<RealtimeServerMessage>> {
253        while let Some(result) = self.read.next().await {
254            match result? {
255                WsMessage::Text(text) => {
256                    let message: RealtimeServerMessage = serde_json::from_str(&text)?;
257                    return Ok(Some(message));
258                }
259                WsMessage::Binary(_) => {
260                    // Binary frames contain raw audio data, not JSON messages.
261                    // Skip them here; use `receive_raw()` if you need audio frames.
262                    continue;
263                }
264                WsMessage::Close(_) => return Ok(None),
265                WsMessage::Ping(_) | WsMessage::Pong(_) | WsMessage::Frame(_) => continue,
266            }
267        }
268        Ok(None)
269    }
270
271    /// Receive the next raw WebSocket message (including binary audio frames).
272    pub async fn receive_raw(&mut self) -> Result<Option<RawRealtimeMessage>> {
273        while let Some(result) = self.read.next().await {
274            match result? {
275                WsMessage::Text(text) => {
276                    let message: RealtimeServerMessage = serde_json::from_str(&text)?;
277                    return Ok(Some(RawRealtimeMessage::Event(message)));
278                }
279                WsMessage::Binary(data) => {
280                    return Ok(Some(RawRealtimeMessage::Audio(data.to_vec())));
281                }
282                WsMessage::Close(_) => return Ok(None),
283                WsMessage::Ping(_) | WsMessage::Pong(_) | WsMessage::Frame(_) => continue,
284            }
285        }
286        Ok(None)
287    }
288
289    /// Update the session configuration.
290    pub async fn update_session(&mut self, config: SessionConfig) -> Result<()> {
291        self.config = config.clone();
292        self.send(RealtimeClientMessage::SessionUpdate { session: config })
293            .await
294    }
295
296    /// Append audio data to the input buffer.
297    pub async fn append_audio(&mut self, audio_base64: impl Into<String>) -> Result<()> {
298        self.send(RealtimeClientMessage::InputAudioBufferAppend {
299            audio: audio_base64.into(),
300        })
301        .await
302    }
303
304    /// Commit the audio buffer.
305    pub async fn commit_audio(&mut self) -> Result<()> {
306        self.send(RealtimeClientMessage::InputAudioBufferCommit {})
307            .await
308    }
309
310    /// Clear the audio buffer.
311    pub async fn clear_audio(&mut self) -> Result<()> {
312        self.send(RealtimeClientMessage::InputAudioBufferClear {})
313            .await
314    }
315
316    /// Create a conversation item.
317    pub async fn create_item(&mut self, item: ConversationItem) -> Result<()> {
318        self.send(RealtimeClientMessage::ConversationItemCreate { item })
319            .await
320    }
321
322    /// Request a response from the model.
323    pub async fn create_response(&mut self) -> Result<()> {
324        self.send(RealtimeClientMessage::ResponseCreate { response: None })
325            .await
326    }
327
328    /// Cancel the current response.
329    pub async fn cancel_response(&mut self) -> Result<()> {
330        self.send(RealtimeClientMessage::ResponseCancel {}).await
331    }
332
333    /// Close the session.
334    pub async fn close(mut self) -> Result<()> {
335        self.write.close().await?;
336        Ok(())
337    }
338}
339
340impl std::fmt::Debug for RealtimeSession {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        f.debug_struct("RealtimeSession")
343            .field("config", &self.config)
344            .finish_non_exhaustive()
345    }
346}
347
348/// A raw message received from the WebSocket, which may be either
349/// a parsed JSON event or raw binary audio data.
350#[derive(Debug)]
351pub enum RawRealtimeMessage {
352    /// A parsed server event (JSON text frame).
353    Event(RealtimeServerMessage),
354    /// Raw binary audio data.
355    Audio(Vec<u8>),
356}
357
358#[cfg(test)]
359mod tests {
360    use super::*;
361    use crate::models::voice::{
362        ConversationContent, ConversationContentType, ConversationItemType,
363    };
364    use futures_util::{SinkExt, StreamExt};
365    use std::time::Duration;
366    use tokio::net::TcpListener;
367    use tokio::time::timeout;
368    use tokio_tungstenite::{
369        accept_hdr_async,
370        tungstenite::{
371            handshake::server::{Request, Response},
372            Message as WsMessage,
373        },
374    };
375
376    async fn spawn_capture_server(
377        expected_messages: usize,
378    ) -> (String, tokio::task::JoinHandle<Vec<String>>) {
379        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
380        let addr = listener.local_addr().unwrap();
381
382        let handle = tokio::spawn(async move {
383            let (stream, _) = listener.accept().await.unwrap();
384            let ws_stream =
385                accept_hdr_async(stream, |request: &Request, mut response: Response| {
386                    let has_realtime = request
387                        .headers()
388                        .get("Sec-WebSocket-Protocol")
389                        .and_then(|v| v.to_str().ok())
390                        .map(|raw| raw.split(',').any(|v| v.trim() == "realtime"))
391                        .unwrap_or(false);
392
393                    if has_realtime {
394                        response.headers_mut().insert(
395                            "Sec-WebSocket-Protocol",
396                            http::HeaderValue::from_static("realtime"),
397                        );
398                    }
399
400                    Ok(response)
401                })
402                .await
403                .unwrap();
404            let (_write, mut read) = ws_stream.split();
405
406            let mut messages = Vec::new();
407            while let Some(frame) = read.next().await {
408                match frame.unwrap() {
409                    WsMessage::Text(text) => {
410                        messages.push(text.to_string());
411                        if messages.len() >= expected_messages {
412                            break;
413                        }
414                    }
415                    WsMessage::Close(_) => break,
416                    WsMessage::Ping(_)
417                    | WsMessage::Pong(_)
418                    | WsMessage::Binary(_)
419                    | WsMessage::Frame(_) => {}
420                }
421            }
422            messages
423        });
424
425        (format!("http://{}", addr), handle)
426    }
427
428    async fn spawn_multi_capture_server(
429        expected_messages_per_connection: Vec<usize>,
430    ) -> (String, tokio::task::JoinHandle<Vec<Vec<String>>>) {
431        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
432        let addr = listener.local_addr().unwrap();
433
434        let handle = tokio::spawn(async move {
435            let mut all_connections = Vec::new();
436
437            for expected_messages in expected_messages_per_connection {
438                let (stream, _) = listener.accept().await.unwrap();
439                let ws_stream =
440                    accept_hdr_async(stream, |request: &Request, mut response: Response| {
441                        let has_realtime = request
442                            .headers()
443                            .get("Sec-WebSocket-Protocol")
444                            .and_then(|v| v.to_str().ok())
445                            .map(|raw| raw.split(',').any(|v| v.trim() == "realtime"))
446                            .unwrap_or(false);
447
448                        if has_realtime {
449                            response.headers_mut().insert(
450                                "Sec-WebSocket-Protocol",
451                                http::HeaderValue::from_static("realtime"),
452                            );
453                        }
454
455                        Ok(response)
456                    })
457                    .await
458                    .unwrap();
459                let (_write, mut read) = ws_stream.split();
460
461                let mut messages = Vec::new();
462                while let Some(frame) = read.next().await {
463                    match frame.unwrap() {
464                        WsMessage::Text(text) => {
465                            messages.push(text.to_string());
466                            if messages.len() >= expected_messages {
467                                break;
468                            }
469                        }
470                        WsMessage::Close(_) => break,
471                        WsMessage::Ping(_)
472                        | WsMessage::Pong(_)
473                        | WsMessage::Binary(_)
474                        | WsMessage::Frame(_) => {}
475                    }
476                }
477
478                all_connections.push(messages);
479            }
480
481            all_connections
482        });
483
484        (format!("http://{}", addr), handle)
485    }
486
487    async fn spawn_response_server(
488        frames: Vec<WsMessage>,
489    ) -> (String, tokio::task::JoinHandle<()>) {
490        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
491        let addr = listener.local_addr().unwrap();
492
493        let handle = tokio::spawn(async move {
494            let (stream, _) = listener.accept().await.unwrap();
495            let ws_stream =
496                accept_hdr_async(stream, |request: &Request, mut response: Response| {
497                    let has_realtime = request
498                        .headers()
499                        .get("Sec-WebSocket-Protocol")
500                        .and_then(|v| v.to_str().ok())
501                        .map(|raw| raw.split(',').any(|v| v.trim() == "realtime"))
502                        .unwrap_or(false);
503
504                    if has_realtime {
505                        response.headers_mut().insert(
506                            "Sec-WebSocket-Protocol",
507                            http::HeaderValue::from_static("realtime"),
508                        );
509                    }
510
511                    Ok(response)
512                })
513                .await
514                .unwrap();
515            let (mut write, mut read) = ws_stream.split();
516            for frame in frames {
517                write.send(frame).await.unwrap();
518            }
519
520            while let Some(frame) = read.next().await {
521                if matches!(frame, Ok(WsMessage::Close(_))) {
522                    break;
523                }
524            }
525        });
526
527        (format!("http://{}", addr), handle)
528    }
529
530    #[tokio::test]
531    async fn start_sends_session_update_with_builder_config() {
532        let (base_url, server_handle) = spawn_capture_server(1).await;
533
534        let client = XaiClient::builder()
535            .api_key("test-key")
536            .base_url(base_url)
537            .build()
538            .unwrap();
539
540        let session = timeout(
541            Duration::from_secs(2),
542            client
543                .realtime()
544                .connect("grok-4")
545                .voice(Voice::Rex)
546                .audio_format(AudioFormat::G711Ulaw)
547                .instructions("Be brief")
548                .start(),
549        )
550        .await
551        .expect("start() should not wait on incoming messages")
552        .unwrap();
553
554        assert_eq!(session.config().voice, Voice::Rex);
555        assert_eq!(session.config().instructions.as_deref(), Some("Be brief"));
556
557        let frames = timeout(Duration::from_secs(2), server_handle)
558            .await
559            .unwrap()
560            .unwrap();
561        assert_eq!(frames.len(), 1);
562
563        let msg: serde_json::Value = serde_json::from_str(&frames[0]).unwrap();
564        assert_eq!(msg["type"], "session_update");
565        assert_eq!(msg["session"]["model"], "grok-4");
566        assert_eq!(msg["session"]["voice"], "rex");
567        assert_eq!(msg["session"]["input_audio_format"], "g711_ulaw");
568        assert_eq!(msg["session"]["output_audio_format"], "g711_ulaw");
569        assert_eq!(msg["session"]["instructions"], "Be brief");
570
571        session.close().await.unwrap();
572    }
573
574    #[tokio::test]
575    async fn start_sends_session_update_with_format_and_tools_config() {
576        let (base_url, server_handle) = spawn_capture_server(1).await;
577
578        let client = XaiClient::builder()
579            .api_key("test-key")
580            .base_url(base_url)
581            .build()
582            .unwrap();
583
584        let session = timeout(
585            Duration::from_secs(2),
586            client
587                .realtime()
588                .connect("grok-4")
589                .voice(Voice::Ara)
590                .input_format(AudioFormat::G711Alaw)
591                .output_format(AudioFormat::G711Ulaw)
592                .tools(vec![Tool::web_search(), Tool::x_search()])
593                .instructions("Use web and x search tools")
594                .start(),
595        )
596        .await
597        .expect("start() should not wait on incoming messages")
598        .unwrap();
599
600        assert_eq!(session.config().voice, Voice::Ara);
601        assert_eq!(
602            session.config().instructions.as_deref(),
603            Some("Use web and x search tools")
604        );
605
606        let frames = timeout(Duration::from_secs(2), server_handle)
607            .await
608            .unwrap()
609            .unwrap();
610        assert_eq!(frames.len(), 1);
611
612        let msg: serde_json::Value = serde_json::from_str(&frames[0]).unwrap();
613        assert_eq!(msg["type"], "session_update");
614        assert_eq!(msg["session"]["model"], "grok-4");
615        assert_eq!(msg["session"]["voice"], "ara");
616        assert_eq!(msg["session"]["input_audio_format"], "g711_alaw");
617        assert_eq!(msg["session"]["output_audio_format"], "g711_ulaw");
618
619        let tools = msg["session"]["tools"].as_array().unwrap();
620        assert_eq!(tools.len(), 2);
621        assert_eq!(tools[0]["type"], "web_search");
622        assert_eq!(tools[1]["type"], "x_search");
623
624        session.close().await.unwrap();
625    }
626
627    #[tokio::test]
628    async fn update_session_updates_local_config_and_sends_update() {
629        let (base_url, server_handle) = spawn_capture_server(2).await;
630
631        let client = XaiClient::builder()
632            .api_key("test-key")
633            .base_url(base_url)
634            .build()
635            .unwrap();
636
637        let mut session = client
638            .realtime()
639            .connect("grok-4")
640            .voice(Voice::Ara)
641            .start()
642            .await
643            .unwrap();
644
645        let updated = SessionConfig::new("grok-4")
646            .voice(Voice::Leo)
647            .input_format(AudioFormat::G711Alaw)
648            .output_format(AudioFormat::G711Alaw)
649            .instructions("Updated instructions");
650
651        session.update_session(updated.clone()).await.unwrap();
652
653        assert_eq!(session.config().voice, Voice::Leo);
654        assert_eq!(session.config().input_audio_format, AudioFormat::G711Alaw);
655        assert_eq!(session.config().output_audio_format, AudioFormat::G711Alaw);
656        assert_eq!(
657            session.config().instructions.as_deref(),
658            Some("Updated instructions")
659        );
660
661        session.close().await.unwrap();
662
663        let frames = timeout(Duration::from_secs(2), server_handle)
664            .await
665            .unwrap()
666            .unwrap();
667        assert_eq!(frames.len(), 2);
668
669        let second: serde_json::Value = serde_json::from_str(&frames[1]).unwrap();
670        assert_eq!(second["type"], "session_update");
671        assert_eq!(second["session"]["voice"], "leo");
672        assert_eq!(second["session"]["input_audio_format"], "g711_alaw");
673        assert_eq!(second["session"]["output_audio_format"], "g711_alaw");
674        assert_eq!(second["session"]["instructions"], "Updated instructions");
675    }
676
677    #[tokio::test]
678    async fn resume_starts_session_with_existing_config() {
679        let (base_url, server_handle) = spawn_capture_server(1).await;
680
681        let client = XaiClient::builder()
682            .api_key("test-key")
683            .base_url(base_url)
684            .build()
685            .unwrap();
686
687        let config = SessionConfig::new("grok-4")
688            .voice(Voice::Leo)
689            .input_format(AudioFormat::G711Alaw)
690            .output_format(AudioFormat::G711Ulaw)
691            .instructions("Resume config");
692
693        let session = client
694            .realtime()
695            .resume(config.clone())
696            .start()
697            .await
698            .unwrap();
699
700        assert_eq!(session.config().model, "grok-4");
701        assert_eq!(session.config().voice, Voice::Leo);
702        assert_eq!(session.config().input_audio_format, AudioFormat::G711Alaw);
703        assert_eq!(session.config().output_audio_format, AudioFormat::G711Ulaw);
704        assert_eq!(
705            session.config().instructions.as_deref(),
706            Some("Resume config")
707        );
708
709        session.close().await.unwrap();
710
711        let frames = timeout(Duration::from_secs(2), server_handle)
712            .await
713            .unwrap()
714            .unwrap();
715        assert_eq!(frames.len(), 1);
716
717        let msg: serde_json::Value = serde_json::from_str(&frames[0]).unwrap();
718        assert_eq!(msg["type"], "session_update");
719        assert_eq!(msg["session"]["model"], "grok-4");
720        assert_eq!(msg["session"]["voice"], "leo");
721        assert_eq!(msg["session"]["input_audio_format"], "g711_alaw");
722        assert_eq!(msg["session"]["output_audio_format"], "g711_ulaw");
723        assert_eq!(msg["session"]["instructions"], "Resume config");
724    }
725
726    #[tokio::test]
727    async fn reconnect_reuses_current_session_config() {
728        let (base_url, server_handle) = spawn_multi_capture_server(vec![1, 1]).await;
729
730        let client = XaiClient::builder()
731            .api_key("test-key")
732            .base_url(base_url)
733            .build()
734            .unwrap();
735
736        let session = client
737            .realtime()
738            .connect("grok-4")
739            .voice(Voice::Rex)
740            .instructions("Reconnect me")
741            .start()
742            .await
743            .unwrap();
744
745        let reconnected = session.reconnect().await.unwrap();
746
747        assert_eq!(reconnected.config().model, "grok-4");
748        assert_eq!(reconnected.config().voice, Voice::Rex);
749        assert_eq!(
750            reconnected.config().instructions.as_deref(),
751            Some("Reconnect me")
752        );
753
754        session.close().await.unwrap();
755        reconnected.close().await.unwrap();
756
757        let frames_by_connection = timeout(Duration::from_secs(2), server_handle)
758            .await
759            .unwrap()
760            .unwrap();
761        assert_eq!(frames_by_connection.len(), 2);
762        assert_eq!(frames_by_connection[0].len(), 1);
763        assert_eq!(frames_by_connection[1].len(), 1);
764
765        let first: serde_json::Value = serde_json::from_str(&frames_by_connection[0][0]).unwrap();
766        let second: serde_json::Value = serde_json::from_str(&frames_by_connection[1][0]).unwrap();
767        assert_eq!(first["type"], "session_update");
768        assert_eq!(second["type"], "session_update");
769        assert_eq!(second["session"]["model"], "grok-4");
770        assert_eq!(second["session"]["voice"], "rex");
771        assert_eq!(second["session"]["instructions"], "Reconnect me");
772    }
773
774    #[tokio::test]
775    async fn reconnect_and_replay_sends_conversation_items() {
776        let (base_url, server_handle) = spawn_multi_capture_server(vec![1, 3]).await;
777
778        let client = XaiClient::builder()
779            .api_key("test-key")
780            .base_url(base_url)
781            .build()
782            .unwrap();
783
784        let session = client.realtime().connect("grok-4").start().await.unwrap();
785
786        let items = vec![
787            ConversationItem {
788                id: Some("item-1".to_string()),
789                item_type: ConversationItemType::Message,
790                role: Some("user".to_string()),
791                content: Some(vec![ConversationContent {
792                    content_type: ConversationContentType::InputText,
793                    text: Some("hello".to_string()),
794                    audio: None,
795                    transcript: None,
796                }]),
797            },
798            ConversationItem {
799                id: Some("item-2".to_string()),
800                item_type: ConversationItemType::Message,
801                role: Some("assistant".to_string()),
802                content: Some(vec![ConversationContent {
803                    content_type: ConversationContentType::Text,
804                    text: Some("world".to_string()),
805                    audio: None,
806                    transcript: None,
807                }]),
808            },
809        ];
810
811        let resumed = session.reconnect_and_replay(items).await.unwrap();
812
813        session.close().await.unwrap();
814        resumed.close().await.unwrap();
815
816        let frames_by_connection = timeout(Duration::from_secs(2), server_handle)
817            .await
818            .unwrap()
819            .unwrap();
820        assert_eq!(frames_by_connection.len(), 2);
821        assert_eq!(frames_by_connection[1].len(), 3);
822
823        let second_connection = &frames_by_connection[1];
824        let first: serde_json::Value = serde_json::from_str(&second_connection[0]).unwrap();
825        let second: serde_json::Value = serde_json::from_str(&second_connection[1]).unwrap();
826        let third: serde_json::Value = serde_json::from_str(&second_connection[2]).unwrap();
827
828        assert_eq!(first["type"], "session_update");
829        assert_eq!(second["type"], "conversation_item_create");
830        assert_eq!(second["item"]["id"], "item-1");
831        assert_eq!(second["item"]["content"][0]["text"], "hello");
832        assert_eq!(third["type"], "conversation_item_create");
833        assert_eq!(third["item"]["id"], "item-2");
834        assert_eq!(third["item"]["content"][0]["text"], "world");
835    }
836
837    #[tokio::test]
838    async fn receive_skips_binary_frames_and_returns_event() {
839        let binary = WsMessage::Binary(vec![0x10, 0x20].into());
840        let event = WsMessage::Text(
841            r#"{"type":"session_updated","session":{"model":"grok-4","voice":"rex","input_audio_format":"pcm16","output_audio_format":"pcm16"}}"#
842                .to_string()
843                .into(),
844        );
845
846        let (base_url, server_handle) = spawn_response_server(vec![binary, event]).await;
847
848        let client = XaiClient::builder()
849            .api_key("test-key")
850            .base_url(base_url)
851            .build()
852            .unwrap();
853
854        let mut session = client
855            .realtime()
856            .connect("grok-4")
857            .voice(Voice::Rex)
858            .start()
859            .await
860            .unwrap();
861
862        let event = session
863            .receive()
864            .await
865            .expect("event should arrive after binary frame");
866        assert!(matches!(
867            event,
868            Some(RealtimeServerMessage::SessionUpdated { .. })
869        ));
870
871        session.close().await.unwrap();
872        server_handle.await.unwrap();
873    }
874
875    #[tokio::test]
876    async fn receive_returns_none_on_close() {
877        let close = WsMessage::Close(None);
878
879        let (base_url, server_handle) = spawn_response_server(vec![close]).await;
880
881        let client = XaiClient::builder()
882            .api_key("test-key")
883            .base_url(base_url)
884            .build()
885            .unwrap();
886
887        let mut session = client.realtime().connect("grok-4").start().await.unwrap();
888
889        let event = session.receive().await.expect("close should be observed");
890        assert!(event.is_none());
891
892        session.close().await.unwrap();
893        server_handle.await.unwrap();
894    }
895
896    #[tokio::test]
897    async fn receive_skips_control_frames_until_event() {
898        let ping = WsMessage::Ping(vec![0x01].into());
899        let pong = WsMessage::Pong(vec![0x02].into());
900        let event = WsMessage::Text(
901            r#"{"type":"session_updated","session":{"model":"grok-4","voice":"rex","input_audio_format":"pcm16","output_audio_format":"pcm16"}}"#
902                .to_string()
903                .into(),
904        );
905
906        let (base_url, server_handle) = spawn_response_server(vec![ping, pong, event]).await;
907
908        let client = XaiClient::builder()
909            .api_key("test-key")
910            .base_url(base_url)
911            .build()
912            .unwrap();
913
914        let mut session = client
915            .realtime()
916            .connect("grok-4")
917            .voice(Voice::Rex)
918            .start()
919            .await
920            .unwrap();
921
922        let event = session
923            .receive()
924            .await
925            .expect("event should arrive after control frames");
926        assert!(matches!(
927            event,
928            Some(RealtimeServerMessage::SessionUpdated { .. })
929        ));
930
931        session.close().await.unwrap();
932        server_handle.await.unwrap();
933    }
934
935    #[tokio::test]
936    async fn receive_raw_returns_none_on_close() {
937        let close = WsMessage::Close(None);
938
939        let (base_url, server_handle) = spawn_response_server(vec![close]).await;
940
941        let client = XaiClient::builder()
942            .api_key("test-key")
943            .base_url(base_url)
944            .build()
945            .unwrap();
946
947        let mut session = client.realtime().connect("grok-4").start().await.unwrap();
948
949        let event = session
950            .receive_raw()
951            .await
952            .expect("close should be observed");
953        assert!(event.is_none());
954
955        session.close().await.unwrap();
956        server_handle.await.unwrap();
957    }
958
959    #[tokio::test]
960    async fn receive_raw_skips_control_frames_until_audio() {
961        let ping = WsMessage::Ping(vec![0x01].into());
962        let audio = WsMessage::Binary(vec![9, 8, 7].into());
963
964        let (base_url, server_handle) = spawn_response_server(vec![ping, audio]).await;
965
966        let client = XaiClient::builder()
967            .api_key("test-key")
968            .base_url(base_url)
969            .build()
970            .unwrap();
971
972        let mut session = client.realtime().connect("grok-4").start().await.unwrap();
973
974        let first = session
975            .receive_raw()
976            .await
977            .expect("audio after control frame")
978            .expect("audio is present");
979        assert!(matches!(first, RawRealtimeMessage::Audio(_)));
980        match first {
981            RawRealtimeMessage::Audio(bytes) => assert_eq!(bytes, vec![9, 8, 7]),
982            RawRealtimeMessage::Event(_) => unreachable!("expected audio"),
983        }
984
985        session.close().await.unwrap();
986        server_handle.await.unwrap();
987    }
988
989    #[tokio::test]
990    async fn start_rejects_unsupported_base_url_scheme() {
991        let client = XaiClient::builder()
992            .api_key("test-key")
993            .base_url("ftp://localhost")
994            .build()
995            .unwrap();
996
997        let err = client
998            .realtime()
999            .connect("grok-4")
1000            .start()
1001            .await
1002            .unwrap_err();
1003
1004        match err {
1005            Error::InvalidRequest(message) => {
1006                assert_eq!(message, "Unsupported URL scheme: ftp")
1007            }
1008            _ => panic!("expected unsupported scheme error"),
1009        }
1010    }
1011
1012    #[tokio::test]
1013    async fn receive_raw_supports_event_and_audio() {
1014        let event = WsMessage::Text(
1015            r#"{"type":"response_audio_delta","response_id":"resp","item_id":"item","delta":"AQID"}"#
1016                .to_string()
1017                .into(),
1018        );
1019        let audio = WsMessage::Binary(vec![9, 8, 7].into());
1020
1021        let (base_url, server_handle) = spawn_response_server(vec![event, audio]).await;
1022
1023        let client = XaiClient::builder()
1024            .api_key("test-key")
1025            .base_url(base_url)
1026            .build()
1027            .unwrap();
1028
1029        let mut session = client.realtime().connect("grok-4").start().await.unwrap();
1030
1031        let first = session
1032            .receive_raw()
1033            .await
1034            .expect("received event")
1035            .expect("event is present");
1036        match first {
1037            RawRealtimeMessage::Event(message) => {
1038                assert!(matches!(
1039                    message,
1040                    RealtimeServerMessage::ResponseAudioDelta { .. }
1041                ))
1042            }
1043            RawRealtimeMessage::Audio(_) => panic!("expected event first"),
1044        }
1045
1046        let second = session
1047            .receive_raw()
1048            .await
1049            .expect("received audio")
1050            .expect("audio is present");
1051        match second {
1052            RawRealtimeMessage::Audio(bytes) => assert_eq!(bytes, vec![9, 8, 7]),
1053            RawRealtimeMessage::Event(_) => panic!("expected audio second"),
1054        }
1055
1056        session.close().await.unwrap();
1057        server_handle.await.unwrap();
1058    }
1059}