Skip to main content

rustic_ai/realtime/
grok.rs

1//! Grok Realtime WebSocket client and event types.
2//!
3//! Based on the OpenAI Realtime API specification which Grok is compatible with.
4
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use base64::Engine as _;
8use futures::{SinkExt, StreamExt};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use thiserror::Error;
12use tokio::sync::mpsc;
13use tokio_tungstenite::{
14    connect_async,
15    tungstenite::{Message, http::Request},
16};
17use tracing::{debug, error, info, trace, warn};
18
19use crate::messages::ToolCallPart;
20
21#[derive(Debug, Error)]
22pub enum Error {
23    #[error("connection closed")]
24    ConnectionClosed,
25    #[error("serialization error: {0}")]
26    Serialization(String),
27    #[error("websocket error: {0}")]
28    WebSocket(String),
29    #[error("provider error: {0}")]
30    Provider(String),
31}
32
33impl From<serde_json::Error> for Error {
34    fn from(err: serde_json::Error) -> Self {
35        Self::Serialization(err.to_string())
36    }
37}
38
39impl From<tokio_tungstenite::tungstenite::Error> for Error {
40    fn from(err: tokio_tungstenite::tungstenite::Error) -> Self {
41        Self::WebSocket(err.to_string())
42    }
43}
44
45pub type Result<T> = std::result::Result<T, Error>;
46
47/// Events sent from client to Grok
48#[derive(Debug, Clone, Serialize)]
49#[serde(tag = "type", rename_all = "snake_case")]
50pub enum ClientEvent {
51    /// Update session configuration
52    #[serde(rename = "session.update")]
53    SessionUpdate { session: SessionUpdatePayload },
54
55    /// Append audio to input buffer
56    #[serde(rename = "input_audio_buffer.append")]
57    InputAudioBufferAppend {
58        #[serde(skip_serializing_if = "Option::is_none")]
59        event_id: Option<String>,
60        audio: String, // base64 encoded
61    },
62
63    /// Commit the audio buffer (create user message)
64    #[serde(rename = "conversation.item.commit")]
65    ConversationItemCommit {
66        #[serde(skip_serializing_if = "Option::is_none")]
67        event_id: Option<String>,
68    },
69
70    /// Clear the audio buffer
71    #[serde(rename = "input_audio_buffer.clear")]
72    InputAudioBufferClear {
73        #[serde(skip_serializing_if = "Option::is_none")]
74        event_id: Option<String>,
75    },
76
77    /// Create a conversation item (e.g., tool result)
78    #[serde(rename = "conversation.item.create")]
79    ConversationItemCreate {
80        #[serde(skip_serializing_if = "Option::is_none")]
81        event_id: Option<String>,
82        item: ConversationItem,
83    },
84
85    /// Trigger a response from the model
86    #[serde(rename = "response.create")]
87    ResponseCreate {
88        #[serde(skip_serializing_if = "Option::is_none")]
89        event_id: Option<String>,
90        #[serde(skip_serializing_if = "Option::is_none")]
91        response: Option<ResponseCreatePayload>,
92    },
93
94    /// Cancel an in-progress response
95    #[serde(rename = "response.cancel")]
96    ResponseCancel {
97        #[serde(skip_serializing_if = "Option::is_none")]
98        event_id: Option<String>,
99    },
100}
101
102#[derive(Debug, Clone, Serialize)]
103pub struct SessionUpdatePayload {
104    #[serde(skip_serializing_if = "Option::is_none")]
105    pub instructions: Option<String>,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub voice: Option<String>,
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub turn_detection: Option<TurnDetection>,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub tools: Option<Vec<GrokToolDefinition>>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub temperature: Option<f32>,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub audio: Option<AudioConfig>,
116}
117
118#[derive(Debug, Clone, Serialize)]
119pub struct TurnDetection {
120    #[serde(rename = "type")]
121    pub detection_type: String, // "server_vad"
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub threshold: Option<f32>,
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub prefix_padding_ms: Option<u32>,
126    #[serde(skip_serializing_if = "Option::is_none")]
127    pub silence_duration_ms: Option<u32>,
128}
129
130impl Default for TurnDetection {
131    fn default() -> Self {
132        Self {
133            detection_type: "server_vad".to_string(),
134            threshold: Some(0.5),
135            prefix_padding_ms: Some(300),
136            silence_duration_ms: Some(200),
137        }
138    }
139}
140
141#[derive(Debug, Clone, Serialize)]
142pub struct AudioConfig {
143    pub input: AudioChannelConfig,
144    pub output: AudioChannelConfig,
145}
146
147#[derive(Debug, Clone, Serialize)]
148pub struct AudioChannelConfig {
149    pub format: AudioFormat,
150}
151
152#[derive(Debug, Clone, Serialize)]
153pub struct AudioFormat {
154    #[serde(rename = "type")]
155    pub format_type: String, // "audio/pcm", "audio/pcmu", "audio/pcma"
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub rate: Option<u32>,
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct GrokToolDefinition {
162    #[serde(rename = "type")]
163    pub tool_type: String, // "function"
164    pub name: String,
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub description: Option<String>,
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub parameters: Option<Value>, // JSON Schema
169}
170
171impl GrokToolDefinition {
172    pub fn function(
173        name: impl Into<String>,
174        description: impl Into<String>,
175        parameters: Value,
176    ) -> Self {
177        Self {
178            tool_type: "function".to_string(),
179            name: name.into(),
180            description: Some(description.into()),
181            parameters: Some(parameters),
182        }
183    }
184}
185
186impl From<&crate::tools::ToolDefinition> for GrokToolDefinition {
187    fn from(tool: &crate::tools::ToolDefinition) -> Self {
188        Self {
189            tool_type: "function".to_string(),
190            name: tool.name.clone(),
191            description: tool.description.clone(),
192            parameters: Some(tool.parameters_json_schema.clone()),
193        }
194    }
195}
196
197#[derive(Debug, Clone, Serialize)]
198pub struct ConversationItem {
199    #[serde(rename = "type")]
200    pub item_type: String,
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub id: Option<String>,
203    #[serde(skip_serializing_if = "Option::is_none")]
204    pub call_id: Option<String>,
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub output: Option<String>,
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub role: Option<String>,
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub content: Option<Vec<ContentPart>>,
211}
212
213impl ConversationItem {
214    /// Create a function call output item
215    pub fn function_call_output(call_id: String, output: String) -> Self {
216        Self {
217            item_type: "function_call_output".to_string(),
218            id: None,
219            call_id: Some(call_id),
220            output: Some(output),
221            role: None,
222            content: None,
223        }
224    }
225
226    /// Create a user text message item
227    pub fn user_text(text: impl Into<String>) -> Self {
228        Self {
229            item_type: "message".to_string(),
230            id: None,
231            call_id: None,
232            output: None,
233            role: Some("user".to_string()),
234            content: Some(vec![ContentPart {
235                content_type: "input_text".to_string(),
236                text: Some(text.into()),
237                audio: None,
238            }]),
239        }
240    }
241}
242
243#[derive(Debug, Clone, Serialize)]
244pub struct ContentPart {
245    #[serde(rename = "type")]
246    pub content_type: String,
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub text: Option<String>,
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub audio: Option<String>,
251}
252
253#[derive(Debug, Clone, Serialize)]
254pub struct ResponseCreatePayload {
255    #[serde(skip_serializing_if = "Option::is_none")]
256    pub modalities: Option<Vec<String>>,
257}
258
259/// Events received from Grok server
260#[derive(Debug, Clone, Deserialize)]
261#[serde(tag = "type", rename_all = "snake_case")]
262pub enum ServerEvent {
263    /// Session created
264    #[serde(rename = "session.created")]
265    SessionCreated { session: SessionInfo },
266
267    /// Session updated
268    #[serde(rename = "session.updated")]
269    SessionUpdated { session: SessionInfo },
270
271    /// Conversation created
272    #[serde(rename = "conversation.created")]
273    ConversationCreated {
274        event_id: String,
275        conversation: ConversationInfo,
276        #[serde(default)]
277        previous_item_id: Option<String>,
278    },
279
280    /// Audio delta from model response
281    #[serde(rename = "response.audio.delta")]
282    ResponseAudioDelta {
283        event_id: String,
284        response_id: String,
285        item_id: String,
286        output_index: u32,
287        content_index: u32,
288        delta: String, // base64 encoded audio
289    },
290
291    /// Alternative event name for audio delta
292    #[serde(rename = "response.output_audio.delta")]
293    ResponseOutputAudioDelta {
294        event_id: String,
295        response_id: String,
296        item_id: String,
297        output_index: u32,
298        content_index: u32,
299        delta: String,
300    },
301
302    /// Function call arguments streaming
303    #[serde(rename = "response.function_call_arguments.delta")]
304    ResponseFunctionCallArgumentsDelta {
305        event_id: String,
306        response_id: String,
307        item_id: String,
308        output_index: u32,
309        call_id: String,
310        delta: String,
311    },
312
313    /// Function call arguments complete
314    #[serde(rename = "response.function_call_arguments.done")]
315    ResponseFunctionCallArgumentsDone {
316        event_id: String,
317        response_id: String,
318        item_id: String,
319        output_index: u32,
320        call_id: String,
321        name: String,
322        arguments: String,
323    },
324
325    /// Response completed
326    #[serde(rename = "response.done")]
327    ResponseDone {
328        event_id: String,
329        response_id: String,
330        #[serde(default)]
331        response: Option<ResponseInfo>,
332    },
333
334    /// Speech started in input buffer
335    #[serde(rename = "input_audio_buffer.speech_started")]
336    InputAudioBufferSpeechStarted {
337        event_id: String,
338        audio_start_ms: u64,
339        item_id: String,
340    },
341
342    /// Speech stopped in input buffer
343    #[serde(rename = "input_audio_buffer.speech_stopped")]
344    InputAudioBufferSpeechStopped {
345        event_id: String,
346        audio_end_ms: u64,
347        item_id: String,
348    },
349
350    /// Input audio buffer committed
351    #[serde(rename = "input_audio_buffer.committed")]
352    InputAudioBufferCommitted {
353        event_id: String,
354        item_id: String,
355        previous_item_id: Option<String>,
356    },
357
358    /// Input audio transcription completed
359    #[serde(rename = "conversation.item.input_audio_transcription.completed")]
360    InputAudioTranscriptionCompleted {
361        event_id: String,
362        item_id: String,
363        transcript: String,
364        content_index: u32,
365        status: String,
366        #[serde(default)]
367        previous_item_id: Option<String>,
368    },
369
370    /// Output audio transcript delta
371    #[serde(rename = "response.output_audio_transcript.delta")]
372    ResponseOutputAudioTranscriptDelta {
373        event_id: String,
374        item_id: String,
375        response_id: String,
376        delta: String,
377        content_index: u32,
378        output_index: u32,
379        #[serde(default)]
380        start_time: Option<f32>,
381        #[serde(default)]
382        previous_item_id: Option<String>,
383    },
384
385    /// Output audio transcript completed
386    #[serde(rename = "response.output_audio_transcript.done")]
387    ResponseOutputAudioTranscriptDone {
388        event_id: String,
389        item_id: String,
390        response_id: String,
391        transcript: String,
392        content_index: u32,
393        output_index: u32,
394        #[serde(default)]
395        previous_item_id: Option<String>,
396    },
397
398    /// Rate limits updated
399    #[serde(rename = "rate_limits.updated")]
400    RateLimitsUpdated {
401        event_id: String,
402        rate_limits: Vec<RateLimit>,
403    },
404
405    /// Error from server
406    #[serde(rename = "error")]
407    Error { event_id: String, error: ErrorInfo },
408
409    /// Catch-all for unknown events
410    #[serde(other)]
411    Unknown,
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    #[test]
    fn session_update_serializes() {
        let event = ClientEvent::SessionUpdate {
            session: SessionUpdatePayload {
                instructions: Some("be concise".to_string()),
                voice: Some("alloy".to_string()),
                turn_detection: Some(TurnDetection::default()),
                tools: Some(vec![GrokToolDefinition::function(
                    "echo",
                    "echo back",
                    json!({"type": "object", "properties": {}}),
                )]),
                temperature: Some(0.3),
                audio: Some(AudioConfig {
                    input: AudioChannelConfig {
                        format: AudioFormat {
                            format_type: "audio/pcm".to_string(),
                            rate: Some(16_000),
                        },
                    },
                    output: AudioChannelConfig {
                        format: AudioFormat {
                            format_type: "audio/pcm".to_string(),
                            rate: Some(16_000),
                        },
                    },
                }),
            },
        };

        let value = serde_json::to_value(event).expect("serialize");
        assert_eq!(
            value.get("type"),
            Some(&Value::String("session.update".to_string()))
        );
        assert_eq!(
            value
                .get("session")
                .and_then(|v| v.get("instructions"))
                .and_then(|v| v.as_str()),
            Some("be concise")
        );
        assert_eq!(
            value
                .get("session")
                .and_then(|v| v.get("voice"))
                .and_then(|v| v.as_str()),
            Some("alloy")
        );
    }

    // Pins the exact JSON shape of client events: wire name in "type" and
    // unset optionals omitted.
    #[test]
    fn client_events_serialize_with_wire_names() {
        let append = serde_json::to_value(ClientEvent::InputAudioBufferAppend {
            event_id: None,
            audio: "AAAA".to_string(),
        })
        .expect("serialize append");
        assert_eq!(
            append,
            json!({"type": "input_audio_buffer.append", "audio": "AAAA"})
        );

        let cancel = serde_json::to_value(ClientEvent::ResponseCancel {
            event_id: Some("evt-1".to_string()),
        })
        .expect("serialize cancel");
        assert_eq!(
            cancel,
            json!({"type": "response.cancel", "event_id": "evt-1"})
        );

        let create = serde_json::to_value(ClientEvent::ResponseCreate {
            event_id: None,
            response: None,
        })
        .expect("serialize response.create");
        assert_eq!(create, json!({"type": "response.create"}));
    }

    #[test]
    fn conversation_item_helpers_build_expected_shapes() {
        let output = ConversationItem::function_call_output("call-1".to_string(), "ok".to_string());
        let output_value = serde_json::to_value(output).expect("serialize output");
        assert_eq!(
            output_value.get("type"),
            Some(&Value::String("function_call_output".to_string()))
        );
        assert_eq!(
            output_value.get("call_id"),
            Some(&Value::String("call-1".to_string()))
        );
        assert_eq!(
            output_value.get("output"),
            Some(&Value::String("ok".to_string()))
        );

        let user = ConversationItem::user_text("hello");
        let user_value = serde_json::to_value(user).expect("serialize user");
        assert_eq!(
            user_value.get("type"),
            Some(&Value::String("message".to_string()))
        );
        assert_eq!(
            user_value.get("role"),
            Some(&Value::String("user".to_string()))
        );
        let content = user_value
            .get("content")
            .and_then(|v| v.as_array())
            .expect("content array");
        assert_eq!(
            content[0].get("type"),
            Some(&Value::String("input_text".to_string()))
        );
    }

    #[test]
    fn tool_definition_from_tool() {
        let tool = crate::tools::ToolDefinition::new(
            "tool",
            Some("desc".to_string()),
            json!({"type": "object", "properties": {}}),
        );
        let def: GrokToolDefinition = (&tool).into();
        assert_eq!(def.tool_type, "function");
        assert_eq!(def.name, "tool");
        assert_eq!(def.description.as_deref(), Some("desc"));
        assert!(def.parameters.is_some());
    }

    #[test]
    fn server_event_helpers_extract_audio_and_tool_calls() {
        let audio_event = ServerEvent::ResponseAudioDelta {
            event_id: "evt".to_string(),
            response_id: "resp".to_string(),
            item_id: "item".to_string(),
            output_index: 0,
            content_index: 0,
            delta: "audio".to_string(),
        };
        assert_eq!(audio_event.audio_delta(), Some("audio"));
        assert!(audio_event.function_call().is_none());

        let output_audio_event = ServerEvent::ResponseOutputAudioDelta {
            event_id: "evt".to_string(),
            response_id: "resp".to_string(),
            item_id: "item".to_string(),
            output_index: 0,
            content_index: 0,
            delta: "audio2".to_string(),
        };
        assert_eq!(output_audio_event.audio_delta(), Some("audio2"));

        let call_event = ServerEvent::ResponseFunctionCallArgumentsDone {
            event_id: "evt".to_string(),
            response_id: "resp".to_string(),
            item_id: "item".to_string(),
            output_index: 0,
            call_id: "call".to_string(),
            name: "tool".to_string(),
            arguments: "{\"a\":1}".to_string(),
        };
        let call = call_event.function_call().expect("function call");
        assert_eq!(call.call_id, "call");
        assert_eq!(call.name, "tool");
    }

    // Exercises the `rename`d wire names and the `#[serde(other)]` fallback
    // on real JSON rather than on hand-built enum values.
    #[test]
    fn server_events_deserialize_from_wire_json() {
        let audio: ServerEvent = serde_json::from_value(json!({
            "type": "response.output_audio.delta",
            "event_id": "evt",
            "response_id": "resp",
            "item_id": "item",
            "output_index": 0,
            "content_index": 0,
            "delta": "AAAA"
        }))
        .expect("deserialize audio delta");
        assert_eq!(audio.audio_delta(), Some("AAAA"));

        let done: ServerEvent = serde_json::from_value(json!({
            "type": "response.function_call_arguments.done",
            "event_id": "evt",
            "response_id": "resp",
            "item_id": "item",
            "output_index": 1,
            "call_id": "call-9",
            "name": "lookup",
            "arguments": "{\"q\":\"x\"}"
        }))
        .expect("deserialize function call");
        let call = done.function_call().expect("function call");
        assert_eq!(call.call_id, "call-9");
        assert_eq!(call.name, "lookup");
        assert_eq!(call.to_tool_call_part().arguments, json!({"q": "x"}));

        let error: ServerEvent = serde_json::from_value(json!({
            "type": "error",
            "event_id": "evt",
            "error": {"type": "invalid_request_error", "code": null, "message": "bad request"}
        }))
        .expect("deserialize error");
        match error {
            ServerEvent::Error { error, .. } => {
                assert_eq!(error.error_type, "invalid_request_error");
                assert!(error.code.is_none());
                assert_eq!(error.message, "bad request");
            }
            other => panic!("unexpected event: {other:?}"),
        }

        // Event types the client does not model must not be a hard error.
        let unknown: ServerEvent = serde_json::from_value(json!({
            "type": "response.created",
            "event_id": "evt",
            "response": {"id": "resp"}
        }))
        .expect("unknown event types fall back to Unknown");
        assert!(matches!(unknown, ServerEvent::Unknown));
    }

    #[test]
    fn function_call_to_tool_call_part_parses_json_or_string() {
        let call = FunctionCall {
            call_id: "call-1".to_string(),
            name: "tool".to_string(),
            arguments: "{\"a\":1}".to_string(),
        };
        let part = call.to_tool_call_part();
        assert_eq!(part.name, "tool");
        assert_eq!(part.arguments, json!({"a": 1}));

        let call = FunctionCall {
            call_id: "call-2".to_string(),
            name: "tool".to_string(),
            arguments: "not-json".to_string(),
        };
        let part = call.to_tool_call_part();
        assert_eq!(part.arguments, Value::String("not-json".to_string()));
    }

    #[test]
    fn session_config_builders_populate_payload() {
        let config = SessionConfig::new("hello")
            .with_voice("Nova")
            .with_temperature(0.4)
            .with_audio_format("audio/pcm", Some(16_000))
            .with_turn_detection(TurnDetection::default());
        let payload = config.to_update_payload();
        assert_eq!(payload.instructions.as_deref(), Some("hello"));
        assert_eq!(payload.voice.as_deref(), Some("Nova"));
        assert!(payload.tools.is_none());
        assert_eq!(payload.temperature, Some(0.4));
        let audio = payload.audio.expect("audio");
        assert_eq!(audio.input.format.format_type, "audio/pcm");
        assert_eq!(audio.input.format.rate, Some(16_000));

        let tools = vec![GrokToolDefinition::function(
            "echo",
            "Echo back",
            json!({"type": "object"}),
        )];
        let config = SessionConfig::default().with_tools(tools.clone());
        let payload = config.to_update_payload();
        assert!(payload.tools.is_some());
        assert_eq!(payload.tools.unwrap().len(), tools.len());
    }

    #[tokio::test]
    async fn grok_sender_emits_events() {
        let (tx, mut rx) = mpsc::channel(10);
        let sender = GrokSender { tx };

        sender
            .send_audio("audio".to_string())
            .await
            .expect("send audio");
        match rx.recv().await.expect("audio event") {
            ClientEvent::InputAudioBufferAppend { audio, .. } => {
                assert_eq!(audio, "audio");
            }
            other => panic!("unexpected event: {other:?}"),
        }

        sender
            .send_user_text("hello".to_string())
            .await
            .expect("send text");
        match rx.recv().await.expect("user event") {
            ClientEvent::ConversationItemCreate { item, .. } => {
                assert_eq!(item.item_type, "message");
                assert_eq!(item.role.as_deref(), Some("user"));
            }
            other => panic!("unexpected event: {other:?}"),
        }

        sender
            .send_tool_result("call-1".to_string(), "ok".to_string())
            .await
            .expect("send tool result");
        match rx.recv().await.expect("tool result") {
            ClientEvent::ConversationItemCreate { item, .. } => {
                assert_eq!(item.item_type, "function_call_output");
                assert_eq!(item.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected event: {other:?}"),
        }
        match rx.recv().await.expect("response create") {
            ClientEvent::ResponseCreate { response, .. } => {
                assert!(response.is_none());
            }
            other => panic!("unexpected event: {other:?}"),
        }

        sender
            .request_response(Some(vec!["text".to_string()]))
            .await
            .expect("request response");
        match rx.recv().await.expect("response create") {
            ClientEvent::ResponseCreate { response, .. } => {
                let response = response.expect("response payload");
                assert_eq!(response.modalities, Some(vec!["text".to_string()]));
            }
            other => panic!("unexpected event: {other:?}"),
        }

        sender.cancel_response().await.expect("cancel response");
        match rx.recv().await.expect("cancel event") {
            ClientEvent::ResponseCancel { .. } => {}
            other => panic!("unexpected event: {other:?}"),
        }

        sender.commit_audio().await.expect("commit audio");
        match rx.recv().await.expect("commit event") {
            ClientEvent::ConversationItemCommit { .. } => {}
            other => panic!("unexpected event: {other:?}"),
        }
    }

    #[test]
    fn misc_helpers_cover_key_generation_and_host_extraction() {
        let key = generate_ws_key();
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(key.as_bytes())
            .expect("decode");
        assert_eq!(decoded.len(), 16);

        assert_eq!(
            extract_host("wss://api.x.ai/v1/realtime"),
            "api.x.ai".to_string()
        );
        assert_eq!(
            extract_host("ws://localhost:8080/socket"),
            "localhost:8080".to_string()
        );

        let detection = TurnDetection::default();
        assert_eq!(detection.detection_type, "server_vad");
        assert_eq!(detection.threshold, Some(0.5));
    }

    #[test]
    fn tool_definition_constructor_sets_fields() {
        let def = GrokToolDefinition::function(
            "tool",
            "desc",
            json!({"type": "object", "properties": {}}),
        );
        assert_eq!(def.tool_type, "function");
        assert_eq!(def.name, "tool");
        assert_eq!(def.description.as_deref(), Some("desc"));
        assert!(def.parameters.is_some());
    }
}
711
712impl ServerEvent {
713    /// Extract audio delta if this is an audio event
714    pub fn audio_delta(&self) -> Option<&str> {
715        match self {
716            Self::ResponseAudioDelta { delta, .. } => Some(delta),
717            Self::ResponseOutputAudioDelta { delta, .. } => Some(delta),
718            _ => None,
719        }
720    }
721
722    /// Check if this is a function call completion
723    pub fn function_call(&self) -> Option<FunctionCall> {
724        match self {
725            Self::ResponseFunctionCallArgumentsDone {
726                call_id,
727                name,
728                arguments,
729                ..
730            } => Some(FunctionCall {
731                call_id: call_id.clone(),
732                name: name.clone(),
733                arguments: arguments.clone(),
734            }),
735            _ => None,
736        }
737    }
738}
739
740#[derive(Debug, Clone)]
741pub struct FunctionCall {
742    pub call_id: String,
743    pub name: String,
744    pub arguments: String,
745}
746
747impl FunctionCall {
748    pub fn to_tool_call_part(&self) -> ToolCallPart {
749        let args = serde_json::from_str::<Value>(&self.arguments)
750            .unwrap_or_else(|_| Value::String(self.arguments.clone()));
751        ToolCallPart {
752            id: self.call_id.clone(),
753            name: self.name.clone(),
754            arguments: args,
755        }
756    }
757}
758
759#[derive(Debug, Clone, Deserialize)]
760pub struct ConversationInfo {
761    pub id: String,
762    #[serde(default)]
763    pub object: Option<String>,
764}
765
766#[derive(Debug, Clone, Deserialize)]
767pub struct SessionInfo {
768    #[serde(default)]
769    pub id: Option<String>,
770    #[serde(default)]
771    pub model: Option<String>,
772    #[serde(default)]
773    pub voice: Option<String>,
774}
775
776#[derive(Debug, Clone, Deserialize)]
777pub struct ResponseInfo {
778    #[serde(default)]
779    pub id: Option<String>,
780    #[serde(default)]
781    pub status: Option<String>,
782}
783
784#[derive(Debug, Clone, Deserialize)]
785pub struct RateLimit {
786    pub name: String,
787    pub limit: u32,
788    pub remaining: u32,
789    pub reset_seconds: f32,
790}
791
792#[derive(Debug, Clone, Deserialize)]
793pub struct ErrorInfo {
794    #[serde(rename = "type")]
795    pub error_type: String,
796    pub code: Option<String>,
797    pub message: String,
798}
799
800/// Configuration for a Grok Realtime session
801#[derive(Debug, Clone)]
802pub struct SessionConfig {
803    pub instructions: String,
804    pub voice: String,
805    pub tools: Vec<GrokToolDefinition>,
806    pub temperature: f32,
807    pub audio_format: AudioFormat,
808    pub turn_detection: TurnDetection,
809}
810
811impl Default for SessionConfig {
812    fn default() -> Self {
813        Self {
814            instructions: "You are a helpful voice assistant.".to_string(),
815            voice: "Ara".to_string(),
816            tools: Vec::new(),
817            temperature: 0.8,
818            audio_format: AudioFormat {
819                format_type: "audio/pcmu".to_string(),
820                rate: None,
821            },
822            turn_detection: TurnDetection::default(),
823        }
824    }
825}
826
827impl SessionConfig {
828    pub fn new(instructions: impl Into<String>) -> Self {
829        Self {
830            instructions: instructions.into(),
831            ..Default::default()
832        }
833    }
834
835    pub fn with_voice(mut self, voice: impl Into<String>) -> Self {
836        self.voice = voice.into();
837        self
838    }
839
840    pub fn with_tools(mut self, tools: Vec<GrokToolDefinition>) -> Self {
841        self.tools = tools;
842        self
843    }
844
845    pub fn with_rustic_tools(mut self, tools: &[crate::tools::ToolDefinition]) -> Self {
846        self.tools = tools.iter().map(GrokToolDefinition::from).collect();
847        self
848    }
849
850    pub fn with_temperature(mut self, temperature: f32) -> Self {
851        self.temperature = temperature;
852        self
853    }
854
855    pub fn with_audio_format(mut self, format_type: impl Into<String>, rate: Option<u32>) -> Self {
856        self.audio_format = AudioFormat {
857            format_type: format_type.into(),
858            rate,
859        };
860        self
861    }
862
863    pub fn with_turn_detection(mut self, detection: TurnDetection) -> Self {
864        self.turn_detection = detection;
865        self
866    }
867
868    /// Convert to session update payload for the API
869    pub fn to_update_payload(&self) -> SessionUpdatePayload {
870        SessionUpdatePayload {
871            instructions: Some(self.instructions.clone()),
872            voice: Some(self.voice.clone()),
873            turn_detection: Some(self.turn_detection.clone()),
874            tools: if self.tools.is_empty() {
875                None
876            } else {
877                Some(self.tools.clone())
878            },
879            temperature: Some(self.temperature),
880            audio: Some(AudioConfig {
881                input: AudioChannelConfig {
882                    format: self.audio_format.clone(),
883                },
884                output: AudioChannelConfig {
885                    format: self.audio_format.clone(),
886                },
887            }),
888        }
889    }
890}
891
892/// Handle for sending events to Grok
893#[derive(Clone)]
894pub struct GrokSender {
895    tx: mpsc::Sender<ClientEvent>,
896}
897
898impl GrokSender {
899    /// Send audio data to Grok
900    pub async fn send_audio(&self, audio_base64: String) -> Result<()> {
901        self.tx
902            .send(ClientEvent::InputAudioBufferAppend {
903                event_id: None,
904                audio: audio_base64,
905            })
906            .await
907            .map_err(|_| Error::ConnectionClosed)
908    }
909
910    /// Send a tool result back to Grok
911    pub async fn send_tool_result(&self, call_id: String, result: String) -> Result<()> {
912        self.tx
913            .send(ClientEvent::ConversationItemCreate {
914                event_id: None,
915                item: ConversationItem::function_call_output(call_id, result),
916            })
917            .await
918            .map_err(|_| Error::ConnectionClosed)?;
919
920        self.tx
921            .send(ClientEvent::ResponseCreate {
922                event_id: None,
923                response: None,
924            })
925            .await
926            .map_err(|_| Error::ConnectionClosed)
927    }
928
929    /// Send a user text message to Grok
930    pub async fn send_user_text(&self, text: String) -> Result<()> {
931        self.tx
932            .send(ClientEvent::ConversationItemCreate {
933                event_id: None,
934                item: ConversationItem::user_text(text),
935            })
936            .await
937            .map_err(|_| Error::ConnectionClosed)
938    }
939
940    /// Request a model response
941    pub async fn request_response(&self, modalities: Option<Vec<String>>) -> Result<()> {
942        self.tx
943            .send(ClientEvent::ResponseCreate {
944                event_id: None,
945                response: Some(ResponseCreatePayload { modalities }),
946            })
947            .await
948            .map_err(|_| Error::ConnectionClosed)
949    }
950
951    /// Cancel the current response (e.g., on interruption)
952    pub async fn cancel_response(&self) -> Result<()> {
953        self.tx
954            .send(ClientEvent::ResponseCancel { event_id: None })
955            .await
956            .map_err(|_| Error::ConnectionClosed)
957    }
958
959    /// Commit the current input audio buffer
960    pub async fn commit_audio(&self) -> Result<()> {
961        self.tx
962            .send(ClientEvent::ConversationItemCommit { event_id: None })
963            .await
964            .map_err(|_| Error::ConnectionClosed)
965    }
966}
967
/// Grok Realtime API client.
///
/// Stores the WebSocket endpoint and the API key that [`GrokClient::connect`]
/// sends as an `Authorization: Bearer` header when opening a session.
pub struct GrokClient {
    /// WebSocket endpoint to connect to.
    ws_url: String,
    /// Secret API key; deliberately redacted from `Debug` output.
    api_key: String,
}

// Hand-written instead of derived so the API key never ends up in logs or panics.
impl std::fmt::Debug for GrokClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokClient")
            .field("ws_url", &self.ws_url)
            .field("api_key", &"<redacted>")
            .finish()
    }
}
973
974impl GrokClient {
975    pub fn new(ws_url: String, api_key: String) -> Self {
976        Self { ws_url, api_key }
977    }
978
979    /// Connect to Grok and return sender/receiver handles
980    ///
981    /// Returns:
982    /// - `GrokSender`: For sending audio and tool results
983    /// - `mpsc::Receiver<ServerEvent>`: For receiving events from Grok
984    pub async fn connect(
985        &self,
986        session_config: SessionConfig,
987    ) -> Result<(GrokSender, mpsc::Receiver<ServerEvent>)> {
988        let request = Request::builder()
989            .uri(&self.ws_url)
990            .header("Authorization", format!("Bearer {}", self.api_key))
991            .header("Sec-WebSocket-Key", generate_ws_key())
992            .header("Sec-WebSocket-Version", "13")
993            .header("Connection", "Upgrade")
994            .header("Upgrade", "websocket")
995            .header("Host", extract_host(&self.ws_url))
996            .body(())
997            .map_err(|e| Error::Provider(format!("failed to build request: {e}")))?;
998
999        info!(url = %self.ws_url, "Connecting to Grok Realtime API");
1000
1001        let (ws_stream, _response) = connect_async(request)
1002            .await
1003            .map_err(|e| Error::Provider(format!("websocket connection failed: {e}")))?;
1004
1005        info!("Connected to Grok Realtime API");
1006
1007        let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
1008
1009        let (client_tx, mut client_rx) = mpsc::channel::<ClientEvent>(256);
1010        let (server_tx, server_rx) = mpsc::channel::<ServerEvent>(256);
1011
1012        let session_update = ClientEvent::SessionUpdate {
1013            session: session_config.to_update_payload(),
1014        };
1015        let msg = serde_json::to_string(&session_update)?;
1016        ws_sink
1017            .send(Message::Text(msg))
1018            .await
1019            .map_err(|e| Error::Provider(format!("failed to send session update: {e}")))?;
1020        debug!("Sent session.update");
1021
1022        tokio::spawn(async move {
1023            while let Some(event) = client_rx.recv().await {
1024                match serde_json::to_string(&event) {
1025                    Ok(msg) => {
1026                        if let Err(e) = ws_sink.send(Message::Text(msg)).await {
1027                            error!(error = %e, "Failed to send to Grok WebSocket");
1028                            break;
1029                        }
1030                    }
1031                    Err(e) => {
1032                        error!(error = %e, "Failed to serialize client event");
1033                    }
1034                }
1035            }
1036            debug!("Grok sender task ended");
1037        });
1038
1039        tokio::spawn(async move {
1040            while let Some(msg_result) = ws_stream_rx.next().await {
1041                match msg_result {
1042                    Ok(Message::Text(text)) => match serde_json::from_str::<Value>(&text) {
1043                        Ok(value) => {
1044                            let event_type = value
1045                                .get("type")
1046                                .and_then(|val| val.as_str())
1047                                .unwrap_or("unknown");
1048                            match serde_json::from_value::<ServerEvent>(value.clone()) {
1049                                Ok(event) => {
1050                                    if matches!(event, ServerEvent::Unknown) {
1051                                        trace!(event_type = %event_type, raw = %text, "Unhandled Grok event");
1052                                    } else if event.audio_delta().is_none() {
1053                                        debug!(?event, "Received Grok event");
1054                                    }
1055                                    if server_tx.send(event).await.is_err() {
1056                                        debug!("Server event receiver dropped");
1057                                        break;
1058                                    }
1059                                }
1060                                Err(e) => {
1061                                    warn!(
1062                                        error = %e,
1063                                        event_type = %event_type,
1064                                        "Failed to parse Grok event"
1065                                    );
1066                                    trace!(raw = %text, "Grok event parse failure payload");
1067                                }
1068                            }
1069                        }
1070                        Err(e) => {
1071                            warn!(error = %e, "Failed to parse Grok event");
1072                            trace!(raw = %text, "Grok event parse failure payload");
1073                        }
1074                    },
1075                    Ok(Message::Close(_)) => {
1076                        info!("Grok WebSocket closed");
1077                        break;
1078                    }
1079                    Ok(Message::Ping(data)) => {
1080                        debug!("Received ping from Grok");
1081                        let _ = data;
1082                    }
1083                    Ok(_) => {}
1084                    Err(e) => {
1085                        error!(error = %e, "Grok WebSocket error");
1086                        break;
1087                    }
1088                }
1089            }
1090            debug!("Grok receiver task ended");
1091        });
1092
1093        Ok((GrokSender { tx: client_tx }, server_rx))
1094    }
1095}
1096
1097fn generate_ws_key() -> String {
1098    let mut key = [0u8; 16];
1099    for (i, byte) in key.iter_mut().enumerate() {
1100        let now = SystemTime::now()
1101            .duration_since(UNIX_EPOCH)
1102            .unwrap_or(Duration::from_secs(0));
1103        *byte = (now.as_nanos() as u8).wrapping_add(i as u8);
1104    }
1105    base64::engine::general_purpose::STANDARD.encode(key)
1106}
1107
/// Extract the `host[:port]` authority from a `ws://` / `wss://` URL for the
/// `Host` handshake header.
///
/// Only a leading `ws://` or `wss://` scheme is stripped. The authority ends at
/// the first `/`, `?` or `#`. Any `user:password@` userinfo is dropped. Falls back
/// to `"api.x.ai"` when no host can be found (e.g. an empty URL).
fn extract_host(url: &str) -> String {
    const DEFAULT_HOST: &str = "api.x.ai";

    let rest = url
        .strip_prefix("wss://")
        .or_else(|| url.strip_prefix("ws://"))
        .unwrap_or(url);
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or_default();
    // Userinfo must never be sent in the Host header; keep only what follows the last '@'.
    let host = authority.rsplit('@').next().unwrap_or_default();

    if host.is_empty() {
        DEFAULT_HOST.to_string()
    } else {
        host.to_string()
    }
}