use base64::prelude::*;
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async_with_config,
tungstenite::{client::IntoClientRequest, Message as WsMessage},
MaybeTlsStream, WebSocketStream,
};
use crate::common::auth::AuthProvider;
use crate::common::errors::{OpenAIToolError, Result};
use crate::common::models::RealtimeModel;
use crate::common::tool::Tool;
use super::audio::{AudioFormat, InputAudioTranscription, TranscriptionModel, Voice};
use super::conversation::{ConversationItem, FunctionCallOutputItem, MessageItem};
use super::events::client::ClientEvent;
use super::events::server::ServerEvent;
use super::session::{Modality, RealtimeTool, ResponseCreateConfig, SessionConfig};
use super::vad::{SemanticVadConfig, ServerVadConfig, TurnDetection};
// URL path segment of the realtime WebSocket endpooint; appended to the API base in `ws_endpoint`.
const REALTIME_PATH: &str = "realtime";
/// Builder-style client for the OpenAI / Azure OpenAI realtime WebSocket API.
///
/// Configure via the chained setter methods, then call `connect()` to open a
/// [`RealtimeSession`].
#[derive(Debug, Clone)]
pub struct RealtimeClient {
    // Credentials and endpoint selection (OpenAI or Azure).
    auth: AuthProvider,
    // Realtime model requested in the connection URL (OpenAI provider).
    model: RealtimeModel,
    // Session settings accumulated by the builder; sent via `session.update` after connect.
    session_config: SessionConfig,
}
impl RealtimeClient {
pub fn new() -> Self {
let auth = AuthProvider::openai_from_env().expect("OPENAI_API_KEY must be set");
Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
}
pub fn with_auth(auth: AuthProvider) -> Self {
Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
}
pub fn azure() -> Result<Self> {
let auth = AuthProvider::azure_from_env()?;
Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
}
pub fn detect_provider() -> Result<Self> {
let auth = AuthProvider::from_env()?;
Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
}
pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
let auth = AuthProvider::from_url_with_key(base_url, api_key);
Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
}
pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
let auth = AuthProvider::from_url(url)?;
Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
}
#[deprecated(since = "0.3.0", note = "Use `with_auth(AuthProvider::OpenAI(...))` instead")]
pub fn with_api_key(api_key: impl Into<String>) -> Self {
let auth = AuthProvider::OpenAI(crate::common::auth::OpenAIAuth::new(api_key));
Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
}
pub fn auth(&self) -> &AuthProvider {
&self.auth
}
pub fn model(&mut self, model: RealtimeModel) -> &mut Self {
self.model = model;
self
}
#[deprecated(since = "0.2.0", note = "Use `model(RealtimeModel)` instead for type safety")]
pub fn model_id(&mut self, model_id: impl Into<String>) -> &mut Self {
self.model = RealtimeModel::from(model_id.into().as_str());
self
}
pub fn modalities(&mut self, modalities: Vec<Modality>) -> &mut Self {
self.session_config.modalities = Some(modalities);
self
}
pub fn instructions(&mut self, instructions: impl Into<String>) -> &mut Self {
self.session_config.instructions = Some(instructions.into());
self
}
pub fn voice(&mut self, voice: Voice) -> &mut Self {
self.session_config.voice = Some(voice);
self
}
pub fn input_audio_format(&mut self, format: AudioFormat) -> &mut Self {
self.session_config.input_audio_format = Some(format);
self
}
pub fn output_audio_format(&mut self, format: AudioFormat) -> &mut Self {
self.session_config.output_audio_format = Some(format);
self
}
pub fn enable_transcription(&mut self, model: TranscriptionModel) -> &mut Self {
self.session_config.input_audio_transcription = Some(InputAudioTranscription::new(model));
self
}
pub fn transcription(&mut self, config: InputAudioTranscription) -> &mut Self {
self.session_config.input_audio_transcription = Some(config);
self
}
pub fn server_vad(&mut self, config: ServerVadConfig) -> &mut Self {
self.session_config.turn_detection = Some(TurnDetection::ServerVad(config));
self
}
pub fn semantic_vad(&mut self, config: SemanticVadConfig) -> &mut Self {
self.session_config.turn_detection = Some(TurnDetection::SemanticVad(config));
self
}
pub fn disable_turn_detection(&mut self) -> &mut Self {
self.session_config.turn_detection = None;
self
}
pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
self.session_config.tools = Some(tools.into_iter().map(RealtimeTool::from).collect());
self
}
pub fn realtime_tools(&mut self, tools: Vec<RealtimeTool>) -> &mut Self {
self.session_config.tools = Some(tools);
self
}
pub fn temperature(&mut self, temp: f32) -> &mut Self {
self.session_config.temperature = Some(temp);
self
}
pub async fn connect(&self) -> Result<RealtimeSession> {
let url = self.ws_endpoint();
let mut request = url.into_client_request().map_err(|e| OpenAIToolError::Error(format!("Failed to build request: {}", e)))?;
let headers = request.headers_mut();
match &self.auth {
AuthProvider::OpenAI(auth) => {
headers.insert(
"Authorization",
format!("Bearer {}", auth.api_key()).parse().map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?,
);
}
AuthProvider::Azure(auth) => {
headers.insert("api-key", auth.api_key().parse().map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?);
}
}
headers.insert("OpenAI-Beta", "realtime=v1".parse().map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?);
let (ws_stream, _response) = connect_async_with_config(request, None, false)
.await
.map_err(|e| OpenAIToolError::Error(format!("WebSocket connection failed: {}", e)))?;
let mut session = RealtimeSession::new(ws_stream);
session.wait_for_session_created().await?;
if self.session_config.modalities.is_some()
|| self.session_config.instructions.is_some()
|| self.session_config.voice.is_some()
|| self.session_config.tools.is_some()
|| self.session_config.turn_detection.is_some()
{
session.update_session(self.session_config.clone()).await?;
}
Ok(session)
}
fn ws_endpoint(&self) -> String {
match &self.auth {
AuthProvider::OpenAI(_) => {
format!("wss://api.openai.com/v1/{}?model={}", REALTIME_PATH, self.model.as_str())
}
AuthProvider::Azure(auth) => {
let base = auth.base_url();
if base.starts_with("https://") {
base.replacen("https://", "wss://", 1)
} else if base.starts_with("http://") {
base.replacen("http://", "ws://", 1)
} else {
base.to_string()
}
}
}
}
}
impl Default for RealtimeClient {
fn default() -> Self {
Self::new()
}
}
/// A live realtime session over an established WebSocket connection.
///
/// Created by `RealtimeClient::connect`; offers typed send/receive helpers for
/// the realtime client/server event protocol.
pub struct RealtimeSession {
    // Underlying (possibly TLS-wrapped) WebSocket transport.
    ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
}
impl RealtimeSession {
    /// Wraps an already-connected WebSocket stream.
    fn new(ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
        Self { ws_stream }
    }

    /// Serializes a client event to JSON and sends it as a single text frame.
    pub async fn send(&mut self, event: ClientEvent) -> Result<()> {
        let payload = serde_json::to_string(&event)?;
        let frame = WsMessage::Text(payload.into());
        self.ws_stream
            .send(frame)
            .await
            .map_err(|e| OpenAIToolError::Error(format!("Failed to send event: {}", e)))
    }

    /// Receives the next server event.
    ///
    /// Returns `Ok(None)` once the connection is closed. Pings are answered
    /// transparently; other non-text frames are skipped.
    pub async fn recv(&mut self) -> Result<Option<ServerEvent>> {
        while let Some(incoming) = self.ws_stream.next().await {
            let message =
                incoming.map_err(|e| OpenAIToolError::Error(format!("WebSocket error: {}", e)))?;
            match message {
                WsMessage::Text(text) => {
                    let event: ServerEvent = serde_json::from_str(&text)?;
                    return Ok(Some(event));
                }
                WsMessage::Close(_) => return Ok(None),
                WsMessage::Ping(data) => {
                    self.ws_stream
                        .send(WsMessage::Pong(data))
                        .await
                        .map_err(|e| OpenAIToolError::Error(format!("Failed to send pong: {}", e)))?;
                }
                _ => {}
            }
        }
        Ok(None)
    }

    /// Blocks until the server acknowledges session creation.
    async fn wait_for_session_created(&mut self) -> Result<()> {
        let first = self.recv().await?;
        match first {
            Some(ServerEvent::SessionCreated(_)) => Ok(()),
            Some(ServerEvent::Error(e)) => Err(OpenAIToolError::Error(format!(
                "Session creation failed: {}",
                e.error.message
            ))),
            Some(event) => Err(OpenAIToolError::Error(format!(
                "Unexpected event while waiting for session.created: {:?}",
                std::mem::discriminant(&event)
            ))),
            None => {
                Err(OpenAIToolError::Error("Connection closed before session.created".to_string()))
            }
        }
    }

    /// Sends a `session.update` with the given configuration.
    pub async fn update_session(&mut self, config: SessionConfig) -> Result<()> {
        let event = ClientEvent::SessionUpdate { event_id: None, session: config };
        self.send(event).await
    }

    /// Appends base64-encoded audio to the input buffer.
    pub async fn append_audio(&mut self, audio_base64: &str) -> Result<()> {
        let event =
            ClientEvent::InputAudioBufferAppend { event_id: None, audio: audio_base64.to_string() };
        self.send(event).await
    }

    /// Base64-encodes raw audio bytes and appends them to the input buffer.
    pub async fn append_audio_bytes(&mut self, audio: &[u8]) -> Result<()> {
        self.append_audio(&BASE64_STANDARD.encode(audio)).await
    }

    /// Commits the pending input audio buffer.
    pub async fn commit_audio(&mut self) -> Result<()> {
        self.send(ClientEvent::InputAudioBufferCommit { event_id: None }).await
    }

    /// Clears the pending input audio buffer.
    pub async fn clear_audio(&mut self) -> Result<()> {
        self.send(ClientEvent::InputAudioBufferClear { event_id: None }).await
    }

    /// Appends a conversation item at the end of the conversation.
    pub async fn create_item(&mut self, item: ConversationItem) -> Result<()> {
        let event = ClientEvent::ConversationItemCreate { event_id: None, previous_item_id: None, item };
        self.send(event).await
    }

    /// Adds a user text message to the conversation.
    pub async fn send_text(&mut self, text: &str) -> Result<()> {
        self.create_item(ConversationItem::Message(MessageItem::user_text(text))).await
    }

    /// Asks the server to generate a response, optionally with per-response overrides.
    pub async fn create_response(&mut self, config: Option<ResponseCreateConfig>) -> Result<()> {
        self.send(ClientEvent::ResponseCreate { event_id: None, response: config }).await
    }

    /// Cancels the in-progress response, if any.
    pub async fn cancel_response(&mut self) -> Result<()> {
        self.send(ClientEvent::ResponseCancel { event_id: None }).await
    }

    /// Submits the output of a previously-requested function call.
    pub async fn submit_function_output(&mut self, call_id: &str, output: &str) -> Result<()> {
        let item = FunctionCallOutputItem::new(call_id, output);
        self.create_item(ConversationItem::FunctionCallOutput(item)).await
    }

    /// Deletes a conversation item by id.
    pub async fn delete_item(&mut self, item_id: &str) -> Result<()> {
        let event = ClientEvent::ConversationItemDelete { event_id: None, item_id: item_id.to_string() };
        self.send(event).await
    }

    /// Truncates a previously-played audio item at `audio_end_ms`.
    pub async fn truncate_item(&mut self, item_id: &str, content_index: u32, audio_end_ms: u32) -> Result<()> {
        let event = ClientEvent::ConversationItemTruncate {
            event_id: None,
            item_id: item_id.to_string(),
            content_index,
            audio_end_ms,
        };
        self.send(event).await
    }

    /// Performs a graceful WebSocket close handshake.
    pub async fn close(&mut self) -> Result<()> {
        match self.ws_stream.close(None).await {
            Ok(()) => Ok(()),
            Err(e) => Err(OpenAIToolError::Error(format!("Failed to close session: {}", e))),
        }
    }
}