use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::events::{AudioFormat, SessionConfig, TurnDetection, Voice};
/// Client-side configuration for a realtime session.
///
/// The session-related fields are mapped onto a `SessionConfig` by
/// `to_session_config`, and `get_ws_url` combines `ws_url` with `model` to
/// form the connection URL. Optional fields are left out of serialized output
/// when they are `None`.
///
/// NOTE(review): `Debug` is derived, so a set `api_key` appears verbatim in
/// debug output, and it is also serialized in plain text when present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealtimeConfig {
/// Model identifier; `get_ws_url` appends it as the `model` query parameter.
/// Has no serde default, so it must be present in deserialized input.
pub model: String,
/// API key, if one has been supplied (`from_env` reads it from `OPENAI_API_KEY`).
#[serde(skip_serializing_if = "Option::is_none")]
pub api_key: Option<String>,
/// Base WebSocket URL; falls back to `default_ws_url()` when absent from the input.
#[serde(default = "default_ws_url")]
pub ws_url: String,
/// Voice forwarded to the session config; `Voice::default()` when absent.
#[serde(default)]
pub voice: Voice,
/// Optional instructions forwarded unchanged to the session config.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Input audio format forwarded to the session config; `AudioFormat::default()` when absent.
#[serde(default)]
pub input_audio_format: AudioFormat,
/// Output audio format forwarded to the session config; `AudioFormat::default()` when absent.
#[serde(default)]
pub output_audio_format: AudioFormat,
/// Optional turn-detection setting, cloned into the session config as-is.
#[serde(skip_serializing_if = "Option::is_none")]
pub turn_detection: Option<TurnDetection>,
/// Optional sampling temperature forwarded to the session config.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Optional output token limit; `to_session_config` wraps it in `MaxTokens::Number`
/// and sends it as `max_response_output_tokens`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u32>,
/// Timeout in seconds; falls back to `default_timeout()` (30) when absent.
/// NOTE(review): the operation this bounds is not visible in this file.
#[serde(default = "default_timeout")]
pub timeout_seconds: u64,
/// Ping interval in seconds; falls back to `default_ping_interval()` (30) when absent.
/// NOTE(review): the code that sends pings is not visible in this file.
#[serde(default = "default_ping_interval")]
pub ping_interval_seconds: u64,
/// Extra key/value headers; empty when absent.
/// Presumably attached to the WebSocket handshake -- confirm against the connection code.
#[serde(default)]
pub headers: HashMap<String, String>,
/// When `true`, `to_session_config` populates `input_audio_transcription`.
#[serde(default)]
pub transcribe_input: bool,
/// Transcription model name; `to_session_config` falls back to `"whisper-1"` when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub transcription_model: Option<String>,
}
/// Serde fallback for the `ws_url` field: the OpenAI realtime WebSocket endpoint.
fn default_ws_url() -> String {
    String::from("wss://api.openai.com/v1/realtime")
}
/// Serde fallback for the `timeout_seconds` field.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
/// Serde fallback for the `ping_interval_seconds` field.
fn default_ping_interval() -> u64 {
    const DEFAULT_PING_INTERVAL_SECS: u64 = 30;
    DEFAULT_PING_INTERVAL_SECS
}
impl Default for RealtimeConfig {
fn default() -> Self {
Self {
model: "gpt-4o-realtime-preview".to_string(),
api_key: None,
ws_url: default_ws_url(),
voice: Voice::default(),
instructions: None,
input_audio_format: AudioFormat::default(),
output_audio_format: AudioFormat::default(),
turn_detection: None,
temperature: None,
max_output_tokens: None,
timeout_seconds: default_timeout(),
ping_interval_seconds: default_ping_interval(),
headers: HashMap::new(),
transcribe_input: false,
transcription_model: None,
}
}
}
impl RealtimeConfig {
    /// Creates a configuration for `model`; every other field takes its value
    /// from [`RealtimeConfig::default`].
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    /// Builds a default configuration whose API key is read from the
    /// `OPENAI_API_KEY` environment variable.
    ///
    /// Returns `None` when the variable is unset, is not valid Unicode, or is
    /// blank (empty or whitespace only). A blank key can never authenticate,
    /// so it is treated the same as a missing one.
    #[must_use]
    pub fn from_env() -> Option<Self> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .ok()
            .filter(|key| !key.trim().is_empty())?;
        Some(Self::default().api_key(api_key))
    }

    /// Sets the API key.
    #[must_use]
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Overrides the base WebSocket URL used by [`RealtimeConfig::get_ws_url`].
    #[must_use]
    pub fn ws_url(mut self, url: impl Into<String>) -> Self {
        self.ws_url = url.into();
        self
    }

    /// Sets the voice forwarded to the session config.
    #[must_use]
    pub fn voice(mut self, voice: Voice) -> Self {
        self.voice = voice;
        self
    }

    /// Sets the instructions forwarded to the session config.
    #[must_use]
    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
        self.instructions = Some(instructions.into());
        self
    }

    /// Sets the input audio format.
    #[must_use]
    pub fn input_audio_format(mut self, format: AudioFormat) -> Self {
        self.input_audio_format = format;
        self
    }

    /// Sets the output audio format.
    #[must_use]
    pub fn output_audio_format(mut self, format: AudioFormat) -> Self {
        self.output_audio_format = format;
        self
    }

    /// Sets the turn-detection setting.
    #[must_use]
    pub fn turn_detection(mut self, detection: TurnDetection) -> Self {
        self.turn_detection = Some(detection);
        self
    }

    /// Explicitly selects [`TurnDetection::None`].
    ///
    /// This differs from never calling [`RealtimeConfig::turn_detection`]:
    /// the field becomes `Some(TurnDetection::None)` rather than staying
    /// `None`, so the choice is included in the session config.
    #[must_use]
    pub fn no_turn_detection(mut self) -> Self {
        self.turn_detection = Some(TurnDetection::None);
        self
    }

    /// Sets the sampling temperature. The value is stored as given; no range
    /// check is performed here.
    #[must_use]
    pub fn temperature(mut self, temp: f64) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Sets the output token limit.
    #[must_use]
    pub fn max_output_tokens(mut self, tokens: u32) -> Self {
        self.max_output_tokens = Some(tokens);
        self
    }

    /// Sets `timeout_seconds`.
    #[must_use]
    pub fn timeout(mut self, seconds: u64) -> Self {
        self.timeout_seconds = seconds;
        self
    }

    /// Sets `ping_interval_seconds`; the counterpart of
    /// [`RealtimeConfig::timeout`] for the ping interval.
    #[must_use]
    pub fn ping_interval(mut self, seconds: u64) -> Self {
        self.ping_interval_seconds = seconds;
        self
    }

    /// Adds an entry to `headers`, replacing any existing value stored under
    /// the same key.
    #[must_use]
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Enables input transcription, leaving the transcription model unchanged.
    #[must_use]
    pub fn with_transcription(mut self) -> Self {
        self.transcribe_input = true;
        self
    }

    /// Sets the transcription model and enables input transcription.
    #[must_use]
    pub fn transcription_model(mut self, model: impl Into<String>) -> Self {
        self.transcription_model = Some(model.into());
        self.transcribe_input = true;
        self
    }

    /// Maps this configuration onto a [`SessionConfig`].
    ///
    /// Both the `"text"` and `"audio"` modalities are always requested, and
    /// `tools` / `tool_choice` are left unset. Input transcription is included
    /// only when `transcribe_input` is `true`, using `"whisper-1"` when no
    /// transcription model has been chosen.
    #[must_use]
    pub fn to_session_config(&self) -> SessionConfig {
        use super::events::{InputAudioTranscription, MaxTokens};

        let input_audio_transcription = self.transcribe_input.then(|| InputAudioTranscription {
            model: self
                .transcription_model
                .clone()
                .unwrap_or_else(|| "whisper-1".to_string()),
        });

        SessionConfig {
            modalities: Some(vec!["text".to_string(), "audio".to_string()]),
            instructions: self.instructions.clone(),
            voice: Some(self.voice),
            input_audio_format: Some(self.input_audio_format),
            output_audio_format: Some(self.output_audio_format),
            input_audio_transcription,
            turn_detection: self.turn_detection.clone(),
            tools: None,
            tool_choice: None,
            temperature: self.temperature,
            max_response_output_tokens: self.max_output_tokens.map(MaxTokens::Number),
        }
    }

    /// Returns the connection URL: `ws_url` with `model` appended as the
    /// `model` query parameter.
    ///
    /// The parameter is joined with `?` when `ws_url` has no query string and
    /// with `&` when it already has one, so a custom endpoint that carries its
    /// own query parameters still yields a well-formed URL. The model name is
    /// inserted as-is, without percent-encoding.
    #[must_use]
    pub fn get_ws_url(&self) -> String {
        let separator = if !self.ws_url.contains('?') {
            "?"
        } else if self.ws_url.ends_with('?') || self.ws_url.ends_with('&') {
            // The URL already ends in a delimiter; adding another would
            // create an empty query pair.
            ""
        } else {
            "&"
        };
        format!("{}{}model={}", self.ws_url, separator, self.model)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses the preview model, the stock endpoint
    /// and intervals, and leaves optional settings unset.
    #[test]
    fn test_config_default() {
        let config = RealtimeConfig::default();
        assert_eq!(config.model, "gpt-4o-realtime-preview");
        assert_eq!(config.voice, Voice::Alloy);
        assert!(config.api_key.is_none());
        assert_eq!(config.ws_url, "wss://api.openai.com/v1/realtime");
        assert_eq!(config.timeout_seconds, 30);
        assert_eq!(config.ping_interval_seconds, 30);
        assert!(config.headers.is_empty());
        assert!(!config.transcribe_input);
    }

    /// Each builder method stores its argument in the matching field.
    #[test]
    fn test_config_builder() {
        let config = RealtimeConfig::new("gpt-4o-realtime-preview")
            .api_key("sk-test")
            .voice(Voice::Nova)
            .instructions("You are a helpful assistant")
            .temperature(0.7)
            .max_output_tokens(1000)
            .with_transcription();
        assert_eq!(config.model, "gpt-4o-realtime-preview");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(config.voice, Voice::Nova);
        assert_eq!(
            config.instructions,
            Some("You are a helpful assistant".to_string())
        );
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1000));
        assert!(config.transcribe_input);
    }

    /// The connection URL is the base URL plus the `model` query parameter.
    /// Checked for exact equality so a malformed URL cannot slip through a
    /// substring match.
    #[test]
    fn test_get_ws_url() {
        let config = RealtimeConfig::new("gpt-4o-realtime-preview");
        assert_eq!(
            config.get_ws_url(),
            "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview"
        );
    }

    /// Voice, instructions and temperature are carried over to the session config.
    #[test]
    fn test_to_session_config() {
        let config = RealtimeConfig::new("gpt-4o-realtime-preview")
            .voice(Voice::Echo)
            .instructions("Test instructions")
            .temperature(0.5);
        let session = config.to_session_config();
        assert_eq!(session.voice, Some(Voice::Echo));
        assert_eq!(session.instructions, Some("Test instructions".to_string()));
        assert_eq!(session.temperature, Some(0.5));
    }

    /// Transcription is off by default, falls back to `whisper-1` when merely
    /// enabled, and is switched on implicitly by choosing a model.
    #[test]
    fn test_transcription_settings() {
        let session = RealtimeConfig::default().to_session_config();
        assert!(session.input_audio_transcription.is_none());

        let session = RealtimeConfig::default()
            .with_transcription()
            .to_session_config();
        let model = session.input_audio_transcription.map(|t| t.model);
        assert_eq!(model.as_deref(), Some("whisper-1"));

        let config = RealtimeConfig::default().transcription_model("custom-model");
        assert!(config.transcribe_input);
        let model = config
            .to_session_config()
            .input_audio_transcription
            .map(|t| t.model);
        assert_eq!(model.as_deref(), Some("custom-model"));
    }

    /// `turn_detection` stores the given setting, and `no_turn_detection`
    /// stores an explicit `TurnDetection::None`.
    #[test]
    fn test_turn_detection() {
        let config = RealtimeConfig::default().turn_detection(TurnDetection::ServerVad {
            threshold: Some(0.5),
            prefix_padding_ms: Some(300),
            silence_duration_ms: Some(500),
        });
        assert!(config.turn_detection.is_some());
        let config = RealtimeConfig::default().no_turn_detection();
        assert!(matches!(config.turn_detection, Some(TurnDetection::None)));
    }
}