use std::time::{Duration, SystemTime, UNIX_EPOCH};
use base64::Engine as _;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::mpsc;
use tokio_tungstenite::{
connect_async,
tungstenite::{Message, http::Request},
};
use tracing::{debug, error, info, trace, warn};
use crate::messages::ToolCallPart;
/// Errors produced by the Grok realtime client in this module.
#[derive(Debug, Error)]
pub enum Error {
    /// The channel feeding outgoing events no longer accepts messages.
    #[error("connection closed")]
    ConnectionClosed,

    /// A JSON (de)serialization step failed; holds the rendered cause.
    #[error("serialization error: {0}")]
    Serialization(String),

    /// A WebSocket-level failure; holds the rendered cause.
    #[error("websocket error: {0}")]
    WebSocket(String),

    /// A failure while building the request, connecting, or sending the
    /// initial session update; holds a descriptive message.
    #[error("provider error: {0}")]
    Provider(String),
}
impl From<serde_json::Error> for Error {
fn from(err: serde_json::Error) -> Self {
Self::Serialization(err.to_string())
}
}
impl From<tokio_tungstenite::tungstenite::Error> for Error {
fn from(err: tokio_tungstenite::tungstenite::Error) -> Self {
Self::WebSocket(err.to_string())
}
}
/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Events this client sends over the WebSocket.
///
/// Every variant serializes to a JSON object whose `type` field holds the
/// variant's explicit `rename` value. `event_id` is left out of the JSON
/// whenever it is `None`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientEvent {
    /// `session.update` — carries the session settings.
    #[serde(rename = "session.update")]
    SessionUpdate { session: SessionUpdatePayload },

    /// `input_audio_buffer.append` — carries one chunk of audio as a string.
    #[serde(rename = "input_audio_buffer.append")]
    InputAudioBufferAppend {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        audio: String,
    },

    /// `conversation.item.commit`.
    #[serde(rename = "conversation.item.commit")]
    ConversationItemCommit {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },

    /// `input_audio_buffer.clear`.
    #[serde(rename = "input_audio_buffer.clear")]
    InputAudioBufferClear {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },

    /// `conversation.item.create` — carries the item to add.
    #[serde(rename = "conversation.item.create")]
    ConversationItemCreate {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        item: ConversationItem,
    },

    /// `response.create` — optionally carries response options.
    #[serde(rename = "response.create")]
    ResponseCreate {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        response: Option<ResponseCreatePayload>,
    },

    /// `response.cancel`.
    #[serde(rename = "response.cancel")]
    ResponseCancel {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },
}
/// Body of a `session.update` event.
///
/// Every field is optional and is omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct SessionUpdatePayload {
    /// Instruction text for the session.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,

    /// Voice name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub voice: Option<String>,

    /// Turn-detection settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub turn_detection: Option<TurnDetection>,

    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GrokToolDefinition>>,

    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Input/output audio format settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<AudioConfig>,
}
/// Turn-detection settings sent as part of a session update.
///
/// `detection_type` is serialized under the key `type`; the remaining fields
/// are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct TurnDetection {
    /// Detection mode identifier (the default uses `"server_vad"`).
    #[serde(rename = "type")]
    pub detection_type: String,

    /// Detection threshold.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub threshold: Option<f32>,

    /// Prefix padding, in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefix_padding_ms: Option<u32>,

    /// Silence duration, in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub silence_duration_ms: Option<u32>,
}
impl Default for TurnDetection {
fn default() -> Self {
Self {
detection_type: "server_vad".to_string(),
threshold: Some(0.5),
prefix_padding_ms: Some(300),
silence_duration_ms: Some(200),
}
}
}
/// Audio settings for a session, split into input and output channels.
#[derive(Debug, Clone, Serialize)]
pub struct AudioConfig {
    /// Settings for the input channel.
    pub input: AudioChannelConfig,
    /// Settings for the output channel.
    pub output: AudioChannelConfig,
}
/// Settings for a single audio channel.
#[derive(Debug, Clone, Serialize)]
pub struct AudioChannelConfig {
    /// Format used on this channel.
    pub format: AudioFormat,
}
/// Audio format descriptor.
///
/// `format_type` is serialized under the key `type`; `rate` is omitted from
/// the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct AudioFormat {
    /// Format identifier, e.g. `"audio/pcmu"` (the session default).
    #[serde(rename = "type")]
    pub format_type: String,

    /// Optional rate (presumably the sample rate in Hz — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate: Option<u32>,
}
/// Tool definition in the shape sent with a session update.
///
/// `tool_type` is (de)serialized under the key `type`; `description` and
/// `parameters` are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrokToolDefinition {
    /// Kind of tool; the constructors in this file always use `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,

    /// Tool name.
    pub name: String,

    /// Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,

    /// Optional parameter description as arbitrary JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}
impl GrokToolDefinition {
pub fn function(
name: impl Into<String>,
description: impl Into<String>,
parameters: Value,
) -> Self {
Self {
tool_type: "function".to_string(),
name: name.into(),
description: Some(description.into()),
parameters: Some(parameters),
}
}
}
impl From<&crate::tools::ToolDefinition> for GrokToolDefinition {
fn from(tool: &crate::tools::ToolDefinition) -> Self {
Self {
tool_type: "function".to_string(),
name: tool.name.clone(),
description: tool.description.clone(),
parameters: Some(tool.parameters_json_schema.clone()),
}
}
}
/// Conversation item carried by a `conversation.item.create` event.
///
/// `item_type` is serialized under the key `type`. All other fields are
/// optional and omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct ConversationItem {
    /// Item kind; this file builds `"message"` and `"function_call_output"`.
    #[serde(rename = "type")]
    pub item_type: String,

    /// Optional item identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    /// Identifier of the function call this item answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub call_id: Option<String>,

    /// Function call output text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,

    /// Role of the message author (e.g. `"user"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,

    /// Content parts of a message item.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ContentPart>>,
}
impl ConversationItem {
pub fn function_call_output(call_id: String, output: String) -> Self {
Self {
item_type: "function_call_output".to_string(),
id: None,
call_id: Some(call_id),
output: Some(output),
role: None,
content: None,
}
}
pub fn user_text(text: impl Into<String>) -> Self {
Self {
item_type: "message".to_string(),
id: None,
call_id: None,
output: None,
role: Some("user".to_string()),
content: Some(vec![ContentPart {
content_type: "input_text".to_string(),
text: Some(text.into()),
audio: None,
}]),
}
}
}
/// One content part of a message item.
///
/// `content_type` is serialized under the key `type`; `text` and `audio` are
/// omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct ContentPart {
    /// Part kind (e.g. `"input_text"`).
    #[serde(rename = "type")]
    pub content_type: String,

    /// Text payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Audio payload as a string, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<String>,
}
/// Optional body of a `response.create` event.
#[derive(Debug, Clone, Serialize)]
pub struct ResponseCreatePayload {
    /// Requested modalities; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
}
/// Events received from the server.
///
/// The JSON `type` field selects the variant through each variant's explicit
/// `rename`; any other `type` value deserializes to [`ServerEvent::Unknown`].
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerEvent {
    /// `session.created`.
    #[serde(rename = "session.created")]
    SessionCreated { session: SessionInfo },

    /// `session.updated`.
    #[serde(rename = "session.updated")]
    SessionUpdated { session: SessionInfo },

    /// `conversation.created`.
    #[serde(rename = "conversation.created")]
    ConversationCreated {
        event_id: String,
        conversation: ConversationInfo,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// `response.audio.delta` — `delta` holds one audio chunk as a string
    /// (presumably base64 — confirm against the provider docs).
    #[serde(rename = "response.audio.delta")]
    ResponseAudioDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        content_index: u32,
        delta: String,
    },

    /// `response.output_audio.delta` — same field layout as
    /// [`ServerEvent::ResponseAudioDelta`].
    #[serde(rename = "response.output_audio.delta")]
    ResponseOutputAudioDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        content_index: u32,
        delta: String,
    },

    /// `response.function_call_arguments.delta` — partial argument text.
    #[serde(rename = "response.function_call_arguments.delta")]
    ResponseFunctionCallArgumentsDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        call_id: String,
        delta: String,
    },

    /// `response.function_call_arguments.done` — function name and the
    /// complete argument text.
    #[serde(rename = "response.function_call_arguments.done")]
    ResponseFunctionCallArgumentsDone {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        call_id: String,
        name: String,
        arguments: String,
    },

    /// `response.done`.
    #[serde(rename = "response.done")]
    ResponseDone {
        event_id: String,
        response_id: String,
        #[serde(default)]
        response: Option<ResponseInfo>,
    },

    /// `input_audio_buffer.speech_started`.
    #[serde(rename = "input_audio_buffer.speech_started")]
    InputAudioBufferSpeechStarted {
        event_id: String,
        audio_start_ms: u64,
        item_id: String,
    },

    /// `input_audio_buffer.speech_stopped`.
    #[serde(rename = "input_audio_buffer.speech_stopped")]
    InputAudioBufferSpeechStopped {
        event_id: String,
        audio_end_ms: u64,
        item_id: String,
    },

    /// `input_audio_buffer.committed`.
    #[serde(rename = "input_audio_buffer.committed")]
    InputAudioBufferCommitted {
        event_id: String,
        item_id: String,
        previous_item_id: Option<String>,
    },

    /// `conversation.item.input_audio_transcription.completed`.
    #[serde(rename = "conversation.item.input_audio_transcription.completed")]
    InputAudioTranscriptionCompleted {
        event_id: String,
        item_id: String,
        transcript: String,
        content_index: u32,
        status: String,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// `response.output_audio_transcript.delta`.
    #[serde(rename = "response.output_audio_transcript.delta")]
    ResponseOutputAudioTranscriptDelta {
        event_id: String,
        item_id: String,
        response_id: String,
        delta: String,
        content_index: u32,
        output_index: u32,
        #[serde(default)]
        start_time: Option<f32>,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// `response.output_audio_transcript.done`.
    #[serde(rename = "response.output_audio_transcript.done")]
    ResponseOutputAudioTranscriptDone {
        event_id: String,
        item_id: String,
        response_id: String,
        transcript: String,
        content_index: u32,
        output_index: u32,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// `rate_limits.updated`.
    #[serde(rename = "rate_limits.updated")]
    RateLimitsUpdated {
        event_id: String,
        rate_limits: Vec<RateLimit>,
    },

    /// `error`.
    #[serde(rename = "error")]
    Error { event_id: String, error: ErrorInfo },

    /// Any event whose `type` matches none of the variants above.
    #[serde(other)]
    Unknown,
}
// Unit tests for the wire types, builders and small helpers in this module.
// None of them opens a network connection.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::{Value, json};
// A fully populated `session.update` serializes with the expected `type` tag
// and nests instructions/voice under `session`.
#[test]
fn session_update_serializes() {
let event = ClientEvent::SessionUpdate {
session: SessionUpdatePayload {
instructions: Some("be concise".to_string()),
voice: Some("alloy".to_string()),
turn_detection: Some(TurnDetection::default()),
tools: Some(vec![GrokToolDefinition::function(
"echo",
"echo back",
json!({"type": "object", "properties": {}}),
)]),
temperature: Some(0.3),
audio: Some(AudioConfig {
input: AudioChannelConfig {
format: AudioFormat {
format_type: "audio/pcm".to_string(),
rate: Some(16_000),
},
},
output: AudioChannelConfig {
format: AudioFormat {
format_type: "audio/pcm".to_string(),
rate: Some(16_000),
},
},
}),
},
};
let value = serde_json::to_value(event).expect("serialize");
assert_eq!(
value.get("type"),
Some(&Value::String("session.update".to_string()))
);
assert_eq!(
value
.get("session")
.and_then(|v| v.get("instructions"))
.and_then(|v| v.as_str()),
Some("be concise")
);
assert_eq!(
value
.get("session")
.and_then(|v| v.get("voice"))
.and_then(|v| v.as_str()),
Some("alloy")
);
}
// The two `ConversationItem` constructors produce the expected JSON shapes.
#[test]
fn conversation_item_helpers_build_expected_shapes() {
let output = ConversationItem::function_call_output("call-1".to_string(), "ok".to_string());
let output_value = serde_json::to_value(output).expect("serialize output");
assert_eq!(
output_value.get("type"),
Some(&Value::String("function_call_output".to_string()))
);
assert_eq!(
output_value.get("call_id"),
Some(&Value::String("call-1".to_string()))
);
assert_eq!(
output_value.get("output"),
Some(&Value::String("ok".to_string()))
);
let user = ConversationItem::user_text("hello");
let user_value = serde_json::to_value(user).expect("serialize user");
assert_eq!(
user_value.get("type"),
Some(&Value::String("message".to_string()))
);
assert_eq!(
user_value.get("role"),
Some(&Value::String("user".to_string()))
);
let content = user_value
.get("content")
.and_then(|v| v.as_array())
.expect("content array");
assert_eq!(
content[0].get("type"),
Some(&Value::String("input_text".to_string()))
);
}
// Converting a crate-level `ToolDefinition` yields a "function" tool that
// keeps the name, description and schema.
#[test]
fn tool_definition_from_tool() {
let tool = crate::tools::ToolDefinition::new(
"tool",
Some("desc".to_string()),
json!({"type": "object", "properties": {}}),
);
let def: GrokToolDefinition = (&tool).into();
assert_eq!(def.tool_type, "function");
assert_eq!(def.name, "tool");
assert_eq!(def.description.as_deref(), Some("desc"));
assert!(def.parameters.is_some());
}
// `audio_delta` recognizes both audio delta variants; `function_call` only
// recognizes the "arguments done" variant.
#[test]
fn server_event_helpers_extract_audio_and_tool_calls() {
let audio_event = ServerEvent::ResponseAudioDelta {
event_id: "evt".to_string(),
response_id: "resp".to_string(),
item_id: "item".to_string(),
output_index: 0,
content_index: 0,
delta: "audio".to_string(),
};
assert_eq!(audio_event.audio_delta(), Some("audio"));
assert!(audio_event.function_call().is_none());
let output_audio_event = ServerEvent::ResponseOutputAudioDelta {
event_id: "evt".to_string(),
response_id: "resp".to_string(),
item_id: "item".to_string(),
output_index: 0,
content_index: 0,
delta: "audio2".to_string(),
};
assert_eq!(output_audio_event.audio_delta(), Some("audio2"));
let call_event = ServerEvent::ResponseFunctionCallArgumentsDone {
event_id: "evt".to_string(),
response_id: "resp".to_string(),
item_id: "item".to_string(),
output_index: 0,
call_id: "call".to_string(),
name: "tool".to_string(),
arguments: "{\"a\":1}".to_string(),
};
let call = call_event.function_call().expect("function call");
assert_eq!(call.call_id, "call");
assert_eq!(call.name, "tool");
}
// Valid JSON arguments are parsed; anything else is kept as a JSON string.
#[test]
fn function_call_to_tool_call_part_parses_json_or_string() {
let call = FunctionCall {
call_id: "call-1".to_string(),
name: "tool".to_string(),
arguments: "{\"a\":1}".to_string(),
};
let part = call.to_tool_call_part();
assert_eq!(part.name, "tool");
assert_eq!(part.arguments, json!({"a": 1}));
let call = FunctionCall {
call_id: "call-2".to_string(),
name: "tool".to_string(),
arguments: "not-json".to_string(),
};
let part = call.to_tool_call_part();
assert_eq!(part.arguments, Value::String("not-json".to_string()));
}
// Builder methods feed through to the update payload; an empty tool list is
// sent as `None`, a non-empty one as `Some`.
#[test]
fn session_config_builders_populate_payload() {
let config = SessionConfig::new("hello")
.with_voice("Nova")
.with_temperature(0.4)
.with_audio_format("audio/pcm", Some(16_000))
.with_turn_detection(TurnDetection::default());
let payload = config.to_update_payload();
assert_eq!(payload.instructions.as_deref(), Some("hello"));
assert_eq!(payload.voice.as_deref(), Some("Nova"));
assert!(payload.tools.is_none());
assert_eq!(payload.temperature, Some(0.4));
let audio = payload.audio.expect("audio");
assert_eq!(audio.input.format.format_type, "audio/pcm");
assert_eq!(audio.input.format.rate, Some(16_000));
let tools = vec![GrokToolDefinition::function(
"echo",
"Echo back",
json!({"type": "object"}),
)];
let config = SessionConfig::default().with_tools(tools.clone());
let payload = config.to_update_payload();
assert!(payload.tools.is_some());
assert_eq!(payload.tools.unwrap().len(), tools.len());
}
// Each `GrokSender` method enqueues the expected event(s) on the channel, in
// order. `send_tool_result` enqueues two events: the item, then a bare
// `response.create`.
#[tokio::test]
async fn grok_sender_emits_events() {
let (tx, mut rx) = mpsc::channel(10);
let sender = GrokSender { tx };
sender
.send_audio("audio".to_string())
.await
.expect("send audio");
match rx.recv().await.expect("audio event") {
ClientEvent::InputAudioBufferAppend { audio, .. } => {
assert_eq!(audio, "audio");
}
other => panic!("unexpected event: {other:?}"),
}
sender
.send_user_text("hello".to_string())
.await
.expect("send text");
match rx.recv().await.expect("user event") {
ClientEvent::ConversationItemCreate { item, .. } => {
assert_eq!(item.item_type, "message");
assert_eq!(item.role.as_deref(), Some("user"));
}
other => panic!("unexpected event: {other:?}"),
}
sender
.send_tool_result("call-1".to_string(), "ok".to_string())
.await
.expect("send tool result");
match rx.recv().await.expect("tool result") {
ClientEvent::ConversationItemCreate { item, .. } => {
assert_eq!(item.item_type, "function_call_output");
assert_eq!(item.call_id.as_deref(), Some("call-1"));
}
other => panic!("unexpected event: {other:?}"),
}
// Second event emitted by `send_tool_result`.
match rx.recv().await.expect("response create") {
ClientEvent::ResponseCreate { response, .. } => {
assert!(response.is_none());
}
other => panic!("unexpected event: {other:?}"),
}
sender
.request_response(Some(vec!["text".to_string()]))
.await
.expect("request response");
match rx.recv().await.expect("response create") {
ClientEvent::ResponseCreate { response, .. } => {
let response = response.expect("response payload");
assert_eq!(response.modalities, Some(vec!["text".to_string()]));
}
other => panic!("unexpected event: {other:?}"),
}
sender.cancel_response().await.expect("cancel response");
match rx.recv().await.expect("cancel event") {
ClientEvent::ResponseCancel { .. } => {}
other => panic!("unexpected event: {other:?}"),
}
sender.commit_audio().await.expect("commit audio");
match rx.recv().await.expect("commit event") {
ClientEvent::ConversationItemCommit { .. } => {}
other => panic!("unexpected event: {other:?}"),
}
}
// The WebSocket key decodes to 16 bytes, host extraction handles both
// schemes and keeps the port, and the turn-detection defaults are as expected.
#[test]
fn misc_helpers_cover_key_generation_and_host_extraction() {
let key = generate_ws_key();
let decoded = base64::engine::general_purpose::STANDARD
.decode(key.as_bytes())
.expect("decode");
assert_eq!(decoded.len(), 16);
assert_eq!(
extract_host("wss://api.x.ai/v1/realtime"),
"api.x.ai".to_string()
);
assert_eq!(
extract_host("ws://localhost:8080/socket"),
"localhost:8080".to_string()
);
let detection = TurnDetection::default();
assert_eq!(detection.detection_type, "server_vad");
assert_eq!(detection.threshold, Some(0.5));
}
// `GrokToolDefinition::function` fills every field.
#[test]
fn tool_definition_constructor_sets_fields() {
let def = GrokToolDefinition::function(
"tool",
"desc",
json!({"type": "object", "properties": {}}),
);
assert_eq!(def.tool_type, "function");
assert_eq!(def.name, "tool");
assert_eq!(def.description.as_deref(), Some("desc"));
assert!(def.parameters.is_some());
}
}
impl ServerEvent {
    /// Returns the `delta` string of either audio delta variant, or `None`
    /// for every other event.
    pub fn audio_delta(&self) -> Option<&str> {
        match self {
            Self::ResponseAudioDelta { delta, .. }
            | Self::ResponseOutputAudioDelta { delta, .. } => Some(delta.as_str()),
            _ => None,
        }
    }

    /// Returns the completed function call carried by a
    /// `response.function_call_arguments.done` event, or `None` for every
    /// other event. The strings are cloned out of the event.
    pub fn function_call(&self) -> Option<FunctionCall> {
        if let Self::ResponseFunctionCallArgumentsDone {
            call_id,
            name,
            arguments,
            ..
        } = self
        {
            let call = FunctionCall {
                call_id: call_id.clone(),
                name: name.clone(),
                arguments: arguments.clone(),
            };
            Some(call)
        } else {
            None
        }
    }
}
/// A completed function call extracted from a server event.
#[derive(Debug, Clone)]
pub struct FunctionCall {
    /// Identifier of the call, echoed back with the tool result.
    pub call_id: String,
    /// Name of the function to invoke.
    pub name: String,
    /// Raw argument text as received (expected to be JSON, but not validated).
    pub arguments: String,
}
impl FunctionCall {
    /// Converts this call into a [`ToolCallPart`].
    ///
    /// The argument text is parsed as JSON; if parsing fails, the raw text is
    /// kept as a JSON string value instead.
    pub fn to_tool_call_part(&self) -> ToolCallPart {
        let arguments = match serde_json::from_str::<Value>(&self.arguments) {
            Ok(parsed) => parsed,
            Err(_) => Value::String(self.arguments.clone()),
        };
        ToolCallPart {
            id: self.call_id.clone(),
            name: self.name.clone(),
            arguments,
        }
    }
}
/// Conversation description carried by a `conversation.created` event.
#[derive(Debug, Clone, Deserialize)]
pub struct ConversationInfo {
    /// Conversation identifier (required).
    pub id: String,

    /// Optional `object` field; `None` when absent.
    #[serde(default)]
    pub object: Option<String>,
}
/// Session description carried by `session.created` / `session.updated`.
///
/// Every field defaults to `None` when absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct SessionInfo {
    /// Session identifier.
    #[serde(default)]
    pub id: Option<String>,

    /// Model name.
    #[serde(default)]
    pub model: Option<String>,

    /// Voice name.
    #[serde(default)]
    pub voice: Option<String>,
}
/// Response summary optionally carried by a `response.done` event.
///
/// Both fields default to `None` when absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct ResponseInfo {
    /// Response identifier.
    #[serde(default)]
    pub id: Option<String>,

    /// Response status text.
    #[serde(default)]
    pub status: Option<String>,
}
/// One entry of a `rate_limits.updated` event. All fields are required.
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimit {
    /// Name of the limit.
    pub name: String,
    /// Configured limit.
    pub limit: u32,
    /// Remaining allowance.
    pub remaining: u32,
    /// Seconds until reset.
    pub reset_seconds: f32,
}
/// Error details carried by an `error` event.
#[derive(Debug, Clone, Deserialize)]
pub struct ErrorInfo {
    /// Error category, read from the JSON key `type`.
    #[serde(rename = "type")]
    pub error_type: String,

    /// Optional error code.
    pub code: Option<String>,

    /// Human-readable message.
    pub message: String,
}
/// Caller-facing session settings, converted to a [`SessionUpdatePayload`]
/// by [`SessionConfig::to_update_payload`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Instruction text for the session.
    pub instructions: String,
    /// Voice name.
    pub voice: String,
    /// Tools offered to the model; an empty list is not sent at all.
    pub tools: Vec<GrokToolDefinition>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Format applied to both the input and the output audio channel.
    pub audio_format: AudioFormat,
    /// Turn-detection settings.
    pub turn_detection: TurnDetection,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
instructions: "You are a helpful voice assistant.".to_string(),
voice: "Ara".to_string(),
tools: Vec::new(),
temperature: 0.8,
audio_format: AudioFormat {
format_type: "audio/pcmu".to_string(),
rate: None,
},
turn_detection: TurnDetection::default(),
}
}
}
impl SessionConfig {
    /// Creates a config with the given instructions and defaults for
    /// everything else.
    pub fn new(instructions: impl Into<String>) -> Self {
        let mut config = Self::default();
        config.instructions = instructions.into();
        config
    }

    /// Replaces the voice name.
    pub fn with_voice(self, voice: impl Into<String>) -> Self {
        Self {
            voice: voice.into(),
            ..self
        }
    }

    /// Replaces the tool list.
    pub fn with_tools(self, tools: Vec<GrokToolDefinition>) -> Self {
        Self { tools, ..self }
    }

    /// Replaces the tool list with conversions of crate-level definitions.
    pub fn with_rustic_tools(self, tools: &[crate::tools::ToolDefinition]) -> Self {
        let converted: Vec<GrokToolDefinition> =
            tools.iter().map(GrokToolDefinition::from).collect();
        Self {
            tools: converted,
            ..self
        }
    }

    /// Replaces the sampling temperature.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature,
            ..self
        }
    }

    /// Replaces the audio format used for both input and output.
    pub fn with_audio_format(self, format_type: impl Into<String>, rate: Option<u32>) -> Self {
        let audio_format = AudioFormat {
            format_type: format_type.into(),
            rate,
        };
        Self {
            audio_format,
            ..self
        }
    }

    /// Replaces the turn-detection settings.
    pub fn with_turn_detection(self, detection: TurnDetection) -> Self {
        Self {
            turn_detection: detection,
            ..self
        }
    }

    /// Builds the `session.update` payload for this config.
    ///
    /// All fields are populated except `tools`, which is `None` when the tool
    /// list is empty. The same audio format is used for input and output.
    pub fn to_update_payload(&self) -> SessionUpdatePayload {
        let tools = (!self.tools.is_empty()).then(|| self.tools.clone());
        let channel = || AudioChannelConfig {
            format: self.audio_format.clone(),
        };
        let audio = AudioConfig {
            input: channel(),
            output: channel(),
        };
        SessionUpdatePayload {
            instructions: Some(self.instructions.clone()),
            voice: Some(self.voice.clone()),
            turn_detection: Some(self.turn_detection.clone()),
            tools,
            temperature: Some(self.temperature),
            audio: Some(audio),
        }
    }
}
/// Cloneable handle for queueing [`ClientEvent`]s onto the channel that a
/// background task drains into the WebSocket.
#[derive(Clone)]
pub struct GrokSender {
    /// Sending half of the client-event channel.
    tx: mpsc::Sender<ClientEvent>,
}
impl GrokSender {
pub async fn send_audio(&self, audio_base64: String) -> Result<()> {
self.tx
.send(ClientEvent::InputAudioBufferAppend {
event_id: None,
audio: audio_base64,
})
.await
.map_err(|_| Error::ConnectionClosed)
}
pub async fn send_tool_result(&self, call_id: String, result: String) -> Result<()> {
self.tx
.send(ClientEvent::ConversationItemCreate {
event_id: None,
item: ConversationItem::function_call_output(call_id, result),
})
.await
.map_err(|_| Error::ConnectionClosed)?;
self.tx
.send(ClientEvent::ResponseCreate {
event_id: None,
response: None,
})
.await
.map_err(|_| Error::ConnectionClosed)
}
pub async fn send_user_text(&self, text: String) -> Result<()> {
self.tx
.send(ClientEvent::ConversationItemCreate {
event_id: None,
item: ConversationItem::user_text(text),
})
.await
.map_err(|_| Error::ConnectionClosed)
}
pub async fn request_response(&self, modalities: Option<Vec<String>>) -> Result<()> {
self.tx
.send(ClientEvent::ResponseCreate {
event_id: None,
response: Some(ResponseCreatePayload { modalities }),
})
.await
.map_err(|_| Error::ConnectionClosed)
}
pub async fn cancel_response(&self) -> Result<()> {
self.tx
.send(ClientEvent::ResponseCancel { event_id: None })
.await
.map_err(|_| Error::ConnectionClosed)
}
pub async fn commit_audio(&self) -> Result<()> {
self.tx
.send(ClientEvent::ConversationItemCommit { event_id: None })
.await
.map_err(|_| Error::ConnectionClosed)
}
}
/// Connection settings for the Grok realtime WebSocket endpoint.
pub struct GrokClient {
    /// WebSocket URL to connect to.
    ws_url: String,
    /// API key sent as a bearer token in the `Authorization` header.
    api_key: String,
}
impl GrokClient {
pub fn new(ws_url: String, api_key: String) -> Self {
Self { ws_url, api_key }
}
pub async fn connect(
&self,
session_config: SessionConfig,
) -> Result<(GrokSender, mpsc::Receiver<ServerEvent>)> {
let request = Request::builder()
.uri(&self.ws_url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Sec-WebSocket-Key", generate_ws_key())
.header("Sec-WebSocket-Version", "13")
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Host", extract_host(&self.ws_url))
.body(())
.map_err(|e| Error::Provider(format!("failed to build request: {e}")))?;
info!(url = %self.ws_url, "Connecting to Grok Realtime API");
let (ws_stream, _response) = connect_async(request)
.await
.map_err(|e| Error::Provider(format!("websocket connection failed: {e}")))?;
info!("Connected to Grok Realtime API");
let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
let (client_tx, mut client_rx) = mpsc::channel::<ClientEvent>(256);
let (server_tx, server_rx) = mpsc::channel::<ServerEvent>(256);
let session_update = ClientEvent::SessionUpdate {
session: session_config.to_update_payload(),
};
let msg = serde_json::to_string(&session_update)?;
ws_sink
.send(Message::Text(msg))
.await
.map_err(|e| Error::Provider(format!("failed to send session update: {e}")))?;
debug!("Sent session.update");
tokio::spawn(async move {
while let Some(event) = client_rx.recv().await {
match serde_json::to_string(&event) {
Ok(msg) => {
if let Err(e) = ws_sink.send(Message::Text(msg)).await {
error!(error = %e, "Failed to send to Grok WebSocket");
break;
}
}
Err(e) => {
error!(error = %e, "Failed to serialize client event");
}
}
}
debug!("Grok sender task ended");
});
tokio::spawn(async move {
while let Some(msg_result) = ws_stream_rx.next().await {
match msg_result {
Ok(Message::Text(text)) => match serde_json::from_str::<Value>(&text) {
Ok(value) => {
let event_type = value
.get("type")
.and_then(|val| val.as_str())
.unwrap_or("unknown");
match serde_json::from_value::<ServerEvent>(value.clone()) {
Ok(event) => {
if matches!(event, ServerEvent::Unknown) {
trace!(event_type = %event_type, raw = %text, "Unhandled Grok event");
} else if event.audio_delta().is_none() {
debug!(?event, "Received Grok event");
}
if server_tx.send(event).await.is_err() {
debug!("Server event receiver dropped");
break;
}
}
Err(e) => {
warn!(
error = %e,
event_type = %event_type,
"Failed to parse Grok event"
);
trace!(raw = %text, "Grok event parse failure payload");
}
}
}
Err(e) => {
warn!(error = %e, "Failed to parse Grok event");
trace!(raw = %text, "Grok event parse failure payload");
}
},
Ok(Message::Close(_)) => {
info!("Grok WebSocket closed");
break;
}
Ok(Message::Ping(data)) => {
debug!("Received ping from Grok");
let _ = data;
}
Ok(_) => {}
Err(e) => {
error!(error = %e, "Grok WebSocket error");
break;
}
}
}
debug!("Grok receiver task ended");
});
Ok((GrokSender { tx: client_tx }, server_rx))
}
}
/// Generates the value for the `Sec-WebSocket-Key` handshake header:
/// 16 bytes, base64-encoded with the standard alphabet.
///
/// The bytes come from a `RandomState`-keyed hasher (randomly seeded by the
/// standard library) fed with the current clock reading and a chunk index,
/// so the key is not predictable from the time of the request alone.
fn generate_ws_key() -> String {
    use std::hash::{BuildHasher, Hasher};

    let random_state = std::collections::hash_map::RandomState::new();
    // A clock before the epoch falls back to zero; the hasher keys still vary.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::from_secs(0))
        .as_nanos();

    let mut key = [0u8; 16];
    // Each 64-bit hash output fills one 8-byte half of the key.
    for (index, chunk) in key.chunks_exact_mut(8).enumerate() {
        let mut hasher = random_state.build_hasher();
        hasher.write_u128(nanos);
        hasher.write_usize(index);
        chunk.copy_from_slice(&hasher.finish().to_ne_bytes());
    }
    base64::engine::general_purpose::STANDARD.encode(key)
}
/// Extracts the `host[:port]` part of a WebSocket URL for the `Host` header.
///
/// A leading `wss://` or `ws://` is stripped, everything from the first `/`,
/// `?` or `#` onward is dropped, and any `user:pass@` prefix is removed.
/// Returns `"api.x.ai"` when no host remains (for example, an empty URL).
fn extract_host(url: &str) -> String {
    let without_scheme = url
        .strip_prefix("wss://")
        .or_else(|| url.strip_prefix("ws://"))
        .unwrap_or(url);

    // The authority ends at the first path, query, or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");

    // Credentials must not leak into the Host header.
    let host = authority.rsplit('@').next().unwrap_or(authority);

    if host.is_empty() {
        "api.x.ai".to_string()
    } else {
        host.to_string()
    }
}