use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::connect_async;
use crate::error::{Error, Result};
/// Entry point for opening realtime sessions: stores the credentials and
/// model name that `create_session` uses when it opens the WebSocket connection.
pub struct RealtimeProvider {
/// API key supplied to `new`, or read from `OPENAI_API_KEY` by `from_env`.
api_key: String,
/// Model name placed in the `model` query parameter of the WebSocket URL.
model: String,
}
impl RealtimeProvider {
pub fn new(api_key: &str, model: &str) -> Self {
Self {
api_key: api_key.to_string(),
model: model.to_string(),
}
}
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
Error::Configuration("OPENAI_API_KEY environment variable not set".to_string())
})?;
Ok(Self::new(&api_key, "gpt-4o-realtime-preview"))
}
pub async fn create_session(&self, config: SessionConfig) -> Result<RealtimeSession> {
let ws_url = format!(
"wss://api.openai.com/v1/realtime?model={}&api_key={}",
self.model, self.api_key
);
let (_ws_stream, _) = connect_async(&ws_url)
.await
.map_err(|e| Error::Configuration(format!("WebSocket connection failed: {}", e)))?;
let (tx, _rx) = mpsc::unbounded_channel();
Ok(RealtimeSession {
config: Arc::new(Mutex::new(config)),
tx,
})
}
}
/// Value used for `modalities` both by `Default` and when the field is
/// missing during deserialization.
fn default_modalities() -> Vec<String> {
    vec!["text-and-audio".to_string()]
}

/// Value used for `voice` both by `Default` and when the field is missing
/// during deserialization.
fn default_voice() -> String {
    "alloy".to_string()
}

/// Value used for both audio-format fields, by `Default` and when either
/// field is missing during deserialization.
fn default_audio_format() -> String {
    "pcm16".to_string()
}

/// Settings attached to a realtime session.
///
/// Deserializing input that omits `modalities`, `voice`,
/// `input_audio_format` or `output_audio_format` yields the same values as
/// [`SessionConfig::default`] for those fields (instead of an empty
/// list / empty strings). `Option` fields that are missing deserialize to
/// `None` and are left out of serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Optional model name override.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Modality names; defaults to `["text-and-audio"]`.
    #[serde(default = "default_modalities")]
    pub modalities: Vec<String>,
    /// Optional instruction text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Voice name; defaults to `"alloy"`.
    #[serde(default = "default_voice")]
    pub voice: String,
    /// Input audio format name; defaults to `"pcm16"`.
    #[serde(default = "default_audio_format")]
    pub input_audio_format: String,
    /// Output audio format name; defaults to `"pcm16"`.
    #[serde(default = "default_audio_format")]
    pub output_audio_format: String,
    /// Optional voice-activity-detection settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub voice_activity_detection: Option<VadConfig>,
    /// Optional cap on response output tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_response_output_tokens: Option<u32>,
    /// Optional tool definitions, kept as raw JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Optional tool-choice setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
}

impl Default for SessionConfig {
    /// Returns the baseline configuration: text-and-audio modality, the
    /// `alloy` voice, `pcm16` audio in both directions, default VAD settings
    /// and a 4096-token output cap; every other option is unset.
    fn default() -> Self {
        Self {
            model: None,
            // Shared with the serde defaults so the two cannot drift apart.
            modalities: default_modalities(),
            instructions: None,
            voice: default_voice(),
            input_audio_format: default_audio_format(),
            output_audio_format: default_audio_format(),
            voice_activity_detection: Some(VadConfig::default()),
            max_response_output_tokens: Some(4096),
            tools: None,
            tool_choice: None,
            temperature: None,
        }
    }
}
/// Voice-activity-detection settings, stored in
/// `SessionConfig::voice_activity_detection`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VadConfig {
/// Silence duration in milliseconds. When the field is absent during
/// deserialization, `default_silence_duration()` supplies the value.
#[serde(default = "default_silence_duration")]
pub silence_duration_ms: u32,
/// Optional detection threshold; left out of serialized output when `None`.
/// NOTE(review): valid range/units are not visible in this file — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub threshold: Option<f32>,
}
/// Serde fallback for `VadConfig::silence_duration_ms`: the silence duration,
/// in milliseconds, used when the field is missing from the input.
fn default_silence_duration() -> u32 {
    const DEFAULT_SILENCE_DURATION_MS: u32 = 500;
    DEFAULT_SILENCE_DURATION_MS
}
impl Default for VadConfig {
fn default() -> Self {
Self {
silence_duration_ms: 500,
threshold: None,
}
}
}
/// Handle returned by `RealtimeProvider::create_session`: a shared, mutable
/// copy of the session configuration plus the sending half of the outbound
/// event channel.
pub struct RealtimeSession {
/// Current configuration, behind an async mutex so `&self` methods can
/// read and replace it.
config: Arc<Mutex<SessionConfig>>,
/// Sending half of the unbounded channel onto which client events are queued.
tx: mpsc::UnboundedSender<ClientEvent>,
}
impl RealtimeSession {
pub async fn send_text(&self, text: &str) -> Result<()> {
let event = ClientEvent::InputUserMessageText {
user_message_text: text.to_string(),
};
self.tx
.send(event)
.map_err(|e| Error::InvalidRequest(format!("Failed to send message: {}", e)))?;
Ok(())
}
pub async fn send_audio(&self, audio_data: Vec<u8>) -> Result<()> {
use base64::Engine;
let base64_audio = base64::engine::general_purpose::STANDARD.encode(&audio_data);
let event = ClientEvent::InputAudioBufferAppend {
audio: base64_audio,
};
self.tx
.send(event)
.map_err(|e| Error::InvalidRequest(format!("Failed to send audio: {}", e)))?;
Ok(())
}
pub async fn commit_audio(&self) -> Result<()> {
let event = ClientEvent::InputAudioBufferCommit {};
self.tx
.send(event)
.map_err(|e| Error::InvalidRequest(format!("Failed to commit audio: {}", e)))?;
Ok(())
}
pub async fn get_config(&self) -> SessionConfig {
self.config.lock().await.clone()
}
pub async fn update_config(&self, config: SessionConfig) -> Result<()> {
let mut current = self.config.lock().await;
*current = config;
Ok(())
}
}
/// Events queued by `RealtimeSession` for sending.
///
/// Serialized as an internally tagged object: the `type` field carries the
/// variant name in snake_case (e.g. `input_audio_buffer_commit`), alongside
/// the variant's own fields.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
// Every variant starts with `Input` because the serialized `type` value is
// derived from the variant name, so the shared prefix cannot be dropped.
#[allow(clippy::enum_variant_names)]
enum ClientEvent {
/// A user text message; built by `send_text`.
InputUserMessageText { user_message_text: String },
/// A chunk of audio as a base64 string; built by `send_audio`.
InputAudioBufferAppend { audio: String },
/// Commit marker with no payload; built by `commit_audio`.
InputAudioBufferCommit {},
}
/// Events deserialized from incoming JSON (presumably sent by the server —
/// nothing in this file reads them off the socket yet).
///
/// The JSON `type` field selects the variant; it must equal the variant name
/// in snake_case (e.g. `"session_created"`, `"rate_limit_updated"`). A `type`
/// that matches no variant fails deserialization.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ServerEvent {
/// `type: "session_created"` — carries the session description.
SessionCreated { session: SessionData },
/// `type: "session_updated"` — carries the session description.
SessionUpdated { session: SessionData },
/// `type: "response_created"` — carries the response description.
ResponseCreated { response: ResponseData },
/// `type: "response_content_part_added"` — a content part attached to an
/// item of a response.
ResponseContentPartAdded {
response_id: String,
item_index: u32,
content_part: ContentPart,
},
/// `type: "response_text_delta"` — a piece of response text.
ResponseTextDelta {
response_id: String,
item_index: u32,
index: u32,
text: String,
},
/// `type: "response_audio_transcript_delta"` — a piece of transcript text.
ResponseAudioTranscriptDelta {
response_id: String,
item_index: u32,
index: u32,
transcript: String,
},
/// `type: "response_audio_delta"` — a piece of response audio.
ResponseAudioDelta {
response_id: String,
item_index: u32,
index: u32,
/// Read from the JSON key `delta`, not `audio`.
/// NOTE(review): presumably base64 text, mirroring `send_audio` — confirm.
#[serde(rename = "delta")]
audio: String,
},
/// `type: "response_done"` — carries the response description.
ResponseDone { response: ResponseData },
/// `type: "rate_limit_updated"` — carries rate-limit counters.
RateLimitUpdated { rate_limit_info: RateLimitInfo },
/// `type: "error"` — carries error details.
Error { error: ErrorData },
}
/// Session description carried by `ServerEvent::SessionCreated` and
/// `ServerEvent::SessionUpdated`. Every field is required when deserializing.
#[derive(Debug, Clone, Deserialize)]
pub struct SessionData {
/// Session identifier.
pub id: String,
/// Object kind label (the test fixture in this file uses `"realtime.session"`).
pub object: String,
/// Creation time, kept as the raw string from the payload (the test fixture
/// uses `"2025-01-02T12:00:00Z"`).
pub created_at: String,
/// Model name.
pub model: String,
/// Modality names.
pub modalities: Vec<String>,
}
/// Response description carried by `ServerEvent::ResponseCreated` and
/// `ServerEvent::ResponseDone`.
#[derive(Debug, Clone, Deserialize)]
pub struct ResponseData {
/// Response identifier.
pub id: String,
/// Object kind label.
pub object: String,
/// Creation time, kept as the raw string from the payload.
pub created_at: String,
/// Status label. NOTE(review): the set of possible values is not visible here.
pub status: String,
/// Free-form status details, kept as raw JSON; `None` when absent or null.
pub status_details: Option<serde_json::Value>,
}
/// One content part of a response item, as carried by
/// `ServerEvent::ResponseContentPartAdded`.
///
/// The JSON `type` field selects the variant by its snake_case name:
/// `input_text`, `input_audio`, `text` or `audio`.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ContentPart {
/// `type: "input_text"`.
InputText { text: String },
/// `type: "input_audio"`. NOTE(review): encoding of `audio` not shown here.
InputAudio { audio: String },
/// `type: "text"`.
Text { text: String },
/// `type: "audio"`. NOTE(review): encoding of `audio` not shown here.
Audio { audio: String },
}
/// Rate-limit counters carried by `ServerEvent::RateLimitUpdated`.
/// All three fields are required, non-negative integers that fit in `u32`.
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimitInfo {
/// Presumably the per-minute token limit — confirm against the API docs.
pub request_limit_tokens_per_min: u32,
/// Presumably seconds until the token limit resets — confirm.
pub request_limit_tokens_reset_seconds: u32,
/// Presumably tokens consumed by the current request — confirm.
pub tokens_used_current_request: u32,
}
/// Error details carried by `ServerEvent::Error`.
#[derive(Debug, Clone, Deserialize)]
pub struct ErrorData {
/// Error code string (the test fixture uses `"invalid_api_key"`).
pub code: String,
/// Error description text.
pub message: String,
/// Optional parameter name; `None` when absent or null.
pub param: Option<String>,
/// Optional event identifier; `None` when absent or null.
pub event_id: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` into a [`ServerEvent`], panicking with the common
    /// failure message when the JSON does not match any variant.
    fn parse_event(raw: &str) -> ServerEvent {
        serde_json::from_str(raw).expect("deserialization failed")
    }

    /// `new` stores both arguments verbatim.
    #[test]
    fn test_provider_creation() {
        let RealtimeProvider { api_key, model } =
            RealtimeProvider::new("test-key", "gpt-4o-realtime-preview");
        assert_eq!(api_key, "test-key");
        assert_eq!(model, "gpt-4o-realtime-preview");
    }

    /// The default configuration carries the expected modality, voice and
    /// audio formats.
    #[test]
    fn test_session_config_default() {
        let SessionConfig {
            modalities,
            voice,
            input_audio_format,
            output_audio_format,
            ..
        } = SessionConfig::default();
        assert_eq!(modalities, vec!["text-and-audio"]);
        assert_eq!(voice, "alloy");
        assert_eq!(input_audio_format, "pcm16");
        assert_eq!(output_audio_format, "pcm16");
    }

    /// The default VAD silence duration is 500 ms.
    #[test]
    fn test_vad_config_default() {
        let VadConfig {
            silence_duration_ms,
            ..
        } = VadConfig::default();
        assert_eq!(silence_duration_ms, 500);
    }

    /// Explicitly set fields show up in the serialized JSON.
    #[test]
    fn test_session_config_serialization() {
        let config = SessionConfig {
            model: Some("gpt-4o-realtime-preview".to_string()),
            modalities: vec!["text-and-audio".to_string()],
            voice: "shimmer".to_string(),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).expect("serialization failed");
        for needle in ["gpt-4o-realtime-preview", "shimmer"] {
            assert!(json.contains(needle));
        }
    }

    /// A `session_created` payload maps onto `ServerEvent::SessionCreated`.
    #[test]
    fn test_server_event_deserialization() {
        let raw = r#"{
            "type": "session_created",
            "session": {
                "id": "sess_123",
                "object": "realtime.session",
                "created_at": "2025-01-02T12:00:00Z",
                "model": "gpt-4o-realtime-preview",
                "modalities": ["text-and-audio"]
            }
        }"#;
        let session = match parse_event(raw) {
            ServerEvent::SessionCreated { session } => session,
            other => panic!("expected SessionCreated, got {:?}", other),
        };
        assert_eq!(session.id, "sess_123");
        assert_eq!(session.model, "gpt-4o-realtime-preview");
    }

    /// An `error` payload maps onto `ServerEvent::Error`, including a null `param`.
    #[test]
    fn test_error_deserialization() {
        let raw = r#"{
            "type": "error",
            "error": {
                "code": "invalid_api_key",
                "message": "Invalid API key",
                "param": null,
                "event_id": "evt_123"
            }
        }"#;
        let error = match parse_event(raw) {
            ServerEvent::Error { error } => error,
            other => panic!("expected Error, got {:?}", other),
        };
        assert_eq!(error.code, "invalid_api_key");
        assert_eq!(error.message, "Invalid API key");
    }

    /// A `rate_limit_updated` payload maps onto `ServerEvent::RateLimitUpdated`.
    #[test]
    fn test_rate_limit_deserialization() {
        let raw = r#"{
            "type": "rate_limit_updated",
            "rate_limit_info": {
                "request_limit_tokens_per_min": 100000,
                "request_limit_tokens_reset_seconds": 60,
                "tokens_used_current_request": 150
            }
        }"#;
        let rate_limit_info = match parse_event(raw) {
            ServerEvent::RateLimitUpdated { rate_limit_info } => rate_limit_info,
            other => panic!("expected RateLimitUpdated, got {:?}", other),
        };
        assert_eq!(rate_limit_info.request_limit_tokens_per_min, 100000);
        assert_eq!(rate_limit_info.tokens_used_current_request, 150);
    }
}