Skip to main content

openai_tools/realtime/
client.rs

1//! Realtime API client implementation.
2
3use base64::prelude::*;
4use futures_util::{SinkExt, StreamExt};
5use tokio::net::TcpStream;
6use tokio_tungstenite::{
7    connect_async_with_config,
8    tungstenite::{client::IntoClientRequest, Message as WsMessage},
9    MaybeTlsStream, WebSocketStream,
10};
11
12use crate::common::auth::AuthProvider;
13use crate::common::errors::{OpenAIToolError, Result};
14use crate::common::models::RealtimeModel;
15use crate::common::tool::Tool;
16
17use super::audio::{AudioFormat, InputAudioTranscription, TranscriptionModel, Voice};
18use super::conversation::{ConversationItem, FunctionCallOutputItem, MessageItem};
19use super::events::client::ClientEvent;
20use super::events::server::ServerEvent;
21use super::session::{Modality, RealtimeTool, ResponseCreateConfig, SessionConfig};
22use super::vad::{SemanticVadConfig, ServerVadConfig, TurnDetection};
23
/// The Realtime API WebSocket endpoint path, appended to the OpenAI base URL
/// (Azure URLs are taken verbatim from the auth provider instead).
const REALTIME_PATH: &str = "realtime";
26
/// Builder for creating Realtime API connections.
///
/// Configure the model and session options with the builder methods, then call
/// [`RealtimeClient::connect`] to open a WebSocket session.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::realtime::{RealtimeClient, Modality, Voice};
/// use openai_tools::common::models::RealtimeModel;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let mut client = RealtimeClient::new();
///     client
///         .model(RealtimeModel::Gpt4oRealtimePreview)
///         .modalities(vec![Modality::Text, Modality::Audio])
///         .voice(Voice::Alloy)
///         .instructions("You are a helpful assistant.");
///
///     let mut session = client.connect().await?;
///     // Use session...
///     session.close().await?;
///     Ok(())
/// }
/// ```
#[derive(Debug, Clone)]
pub struct RealtimeClient {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Model passed as the `model` query parameter (OpenAI endpoints only;
    /// Azure selects the model via the deployment in the URL).
    model: RealtimeModel,
    /// Session options accumulated by the builder methods; sent as a
    /// `session.update` right after the connection is established.
    session_config: SessionConfig,
}
57
58impl RealtimeClient {
59    /// Create a new RealtimeClient for OpenAI API.
60    ///
61    /// Loads the API key from the `OPENAI_API_KEY` environment variable.
62    pub fn new() -> Self {
63        let auth = AuthProvider::openai_from_env().expect("OPENAI_API_KEY must be set");
64        Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
65    }
66
67    /// Create a new RealtimeClient with a custom authentication provider.
68    pub fn with_auth(auth: AuthProvider) -> Self {
69        Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
70    }
71
72    /// Create a new RealtimeClient for Azure OpenAI API.
73    pub fn azure() -> Result<Self> {
74        let auth = AuthProvider::azure_from_env()?;
75        Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
76    }
77
78    /// Create a new RealtimeClient by auto-detecting the provider.
79    pub fn detect_provider() -> Result<Self> {
80        let auth = AuthProvider::from_env()?;
81        Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
82    }
83
84    /// Creates a new RealtimeClient with URL-based provider detection.
85    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
86        let auth = AuthProvider::from_url_with_key(base_url, api_key);
87        Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
88    }
89
90    /// Creates a new RealtimeClient from URL using environment variables.
91    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
92        let auth = AuthProvider::from_url(url)?;
93        Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
94    }
95
96    /// Create a new RealtimeClient with an explicit API key.
97    #[deprecated(since = "0.3.0", note = "Use `with_auth(AuthProvider::OpenAI(...))` instead")]
98    pub fn with_api_key(api_key: impl Into<String>) -> Self {
99        let auth = AuthProvider::OpenAI(crate::common::auth::OpenAIAuth::new(api_key));
100        Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
101    }
102
103    /// Returns the authentication provider.
104    pub fn auth(&self) -> &AuthProvider {
105        &self.auth
106    }
107
108    /// Set the model for the Realtime API.
109    ///
110    /// # Example
111    ///
112    /// ```rust,no_run
113    /// use openai_tools::realtime::RealtimeClient;
114    /// use openai_tools::common::models::RealtimeModel;
115    ///
116    /// let mut client = RealtimeClient::new();
117    /// client.model(RealtimeModel::Gpt4oRealtimePreview);
118    /// ```
119    pub fn model(&mut self, model: RealtimeModel) -> &mut Self {
120        self.model = model;
121        self
122    }
123
124    /// Set the model using a string ID (for backward compatibility).
125    ///
126    /// Prefer using [`model`] with `RealtimeModel` enum for type safety.
127    #[deprecated(since = "0.2.0", note = "Use `model(RealtimeModel)` instead for type safety")]
128    pub fn model_id(&mut self, model_id: impl Into<String>) -> &mut Self {
129        self.model = RealtimeModel::from(model_id.into().as_str());
130        self
131    }
132
133    /// Set the supported modalities.
134    pub fn modalities(&mut self, modalities: Vec<Modality>) -> &mut Self {
135        self.session_config.modalities = Some(modalities);
136        self
137    }
138
139    /// Set the system instructions.
140    pub fn instructions(&mut self, instructions: impl Into<String>) -> &mut Self {
141        self.session_config.instructions = Some(instructions.into());
142        self
143    }
144
145    /// Set the voice for audio output.
146    pub fn voice(&mut self, voice: Voice) -> &mut Self {
147        self.session_config.voice = Some(voice);
148        self
149    }
150
151    /// Set the input audio format.
152    pub fn input_audio_format(&mut self, format: AudioFormat) -> &mut Self {
153        self.session_config.input_audio_format = Some(format);
154        self
155    }
156
157    /// Set the output audio format.
158    pub fn output_audio_format(&mut self, format: AudioFormat) -> &mut Self {
159        self.session_config.output_audio_format = Some(format);
160        self
161    }
162
163    /// Enable input audio transcription.
164    pub fn enable_transcription(&mut self, model: TranscriptionModel) -> &mut Self {
165        self.session_config.input_audio_transcription = Some(InputAudioTranscription::new(model));
166        self
167    }
168
169    /// Set input audio transcription configuration.
170    pub fn transcription(&mut self, config: InputAudioTranscription) -> &mut Self {
171        self.session_config.input_audio_transcription = Some(config);
172        self
173    }
174
175    /// Set Server VAD turn detection.
176    pub fn server_vad(&mut self, config: ServerVadConfig) -> &mut Self {
177        self.session_config.turn_detection = Some(TurnDetection::ServerVad(config));
178        self
179    }
180
181    /// Set Semantic VAD turn detection.
182    pub fn semantic_vad(&mut self, config: SemanticVadConfig) -> &mut Self {
183        self.session_config.turn_detection = Some(TurnDetection::SemanticVad(config));
184        self
185    }
186
187    /// Disable turn detection (manual mode).
188    pub fn disable_turn_detection(&mut self) -> &mut Self {
189        self.session_config.turn_detection = None;
190        self
191    }
192
193    /// Set available tools for function calling.
194    ///
195    /// Accepts `Tool` from the common module and converts to `RealtimeTool`.
196    pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
197        self.session_config.tools = Some(tools.into_iter().map(RealtimeTool::from).collect());
198        self
199    }
200
201    /// Set available realtime tools directly.
202    pub fn realtime_tools(&mut self, tools: Vec<RealtimeTool>) -> &mut Self {
203        self.session_config.tools = Some(tools);
204        self
205    }
206
207    /// Set the sampling temperature.
208    pub fn temperature(&mut self, temp: f32) -> &mut Self {
209        self.session_config.temperature = Some(temp);
210        self
211    }
212
213    /// Connect to the Realtime API.
214    ///
215    /// Returns a `RealtimeSession` for sending and receiving events.
216    pub async fn connect(&self) -> Result<RealtimeSession> {
217        // Get the WebSocket URL based on auth provider
218        let url = self.ws_endpoint();
219
220        // Build WebSocket request with headers
221        let mut request = url.into_client_request().map_err(|e| OpenAIToolError::Error(format!("Failed to build request: {}", e)))?;
222
223        let headers = request.headers_mut();
224
225        // Apply auth headers based on provider
226        match &self.auth {
227            AuthProvider::OpenAI(auth) => {
228                headers.insert(
229                    "Authorization",
230                    format!("Bearer {}", auth.api_key()).parse().map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?,
231                );
232            }
233            AuthProvider::Azure(auth) => {
234                headers.insert("api-key", auth.api_key().parse().map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?);
235            }
236        }
237        headers.insert("OpenAI-Beta", "realtime=v1".parse().map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?);
238
239        let (ws_stream, _response) = connect_async_with_config(request, None, false)
240            .await
241            .map_err(|e| OpenAIToolError::Error(format!("WebSocket connection failed: {}", e)))?;
242
243        let mut session = RealtimeSession::new(ws_stream);
244
245        // Wait for session.created event
246        session.wait_for_session_created().await?;
247
248        // Send initial session.update if we have configuration
249        if self.session_config.modalities.is_some()
250            || self.session_config.instructions.is_some()
251            || self.session_config.voice.is_some()
252            || self.session_config.tools.is_some()
253            || self.session_config.turn_detection.is_some()
254        {
255            session.update_session(self.session_config.clone()).await?;
256        }
257
258        Ok(session)
259    }
260
261    /// Get the WebSocket endpoint URL based on auth provider.
262    fn ws_endpoint(&self) -> String {
263        match &self.auth {
264            AuthProvider::OpenAI(_) => {
265                format!("wss://api.openai.com/v1/{}?model={}", REALTIME_PATH, self.model.as_str())
266            }
267            AuthProvider::Azure(auth) => {
268                // Azure WebSocket endpoint: convert https to wss and use the base_url directly
269                // User should provide a WebSocket-compatible URL like:
270                // "wss://my-resource.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=my-deployment"
271                // or if they provide https, we convert it to wss
272                let base = auth.base_url();
273                if base.starts_with("https://") {
274                    base.replacen("https://", "wss://", 1)
275                } else if base.starts_with("http://") {
276                    base.replacen("http://", "ws://", 1)
277                } else {
278                    base.to_string()
279                }
280            }
281        }
282    }
283}
284
impl Default for RealtimeClient {
    /// Equivalent to [`RealtimeClient::new`]: reads `OPENAI_API_KEY` from the
    /// environment and panics if it is unset.
    fn default() -> Self {
        Self::new()
    }
}
290
/// An active Realtime API session.
///
/// Provides methods for sending events and receiving responses.
pub struct RealtimeSession {
    /// The underlying WebSocket connection (plain TCP or TLS) to the Realtime API.
    ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
}
297
298impl RealtimeSession {
299    /// Create a new session from a WebSocket stream.
300    fn new(ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
301        Self { ws_stream }
302    }
303
304    /// Send a client event to the server.
305    pub async fn send(&mut self, event: ClientEvent) -> Result<()> {
306        let json = serde_json::to_string(&event)?;
307        self.ws_stream.send(WsMessage::Text(json.into())).await.map_err(|e| OpenAIToolError::Error(format!("Failed to send event: {}", e)))?;
308        Ok(())
309    }
310
311    /// Receive the next server event.
312    ///
313    /// Returns `None` if the connection is closed.
314    pub async fn recv(&mut self) -> Result<Option<ServerEvent>> {
315        loop {
316            match self.ws_stream.next().await {
317                Some(Ok(WsMessage::Text(text))) => {
318                    let event: ServerEvent = serde_json::from_str(&text)?;
319                    return Ok(Some(event));
320                }
321                Some(Ok(WsMessage::Close(_))) => {
322                    return Ok(None);
323                }
324                Some(Ok(WsMessage::Ping(data))) => {
325                    self.ws_stream.send(WsMessage::Pong(data)).await.map_err(|e| OpenAIToolError::Error(format!("Failed to send pong: {}", e)))?;
326                    continue;
327                }
328                Some(Ok(_)) => continue, // Ignore other message types
329                Some(Err(e)) => {
330                    return Err(OpenAIToolError::Error(format!("WebSocket error: {}", e)));
331                }
332                None => {
333                    return Ok(None);
334                }
335            }
336        }
337    }
338
339    /// Wait for the session.created event.
340    async fn wait_for_session_created(&mut self) -> Result<()> {
341        match self.recv().await? {
342            Some(ServerEvent::SessionCreated(_)) => Ok(()),
343            Some(ServerEvent::Error(e)) => Err(OpenAIToolError::Error(format!("Session creation failed: {}", e.error.message))),
344            Some(event) => {
345                Err(OpenAIToolError::Error(format!("Unexpected event while waiting for session.created: {:?}", std::mem::discriminant(&event))))
346            }
347            None => Err(OpenAIToolError::Error("Connection closed before session.created".to_string())),
348        }
349    }
350
351    /// Update the session configuration.
352    pub async fn update_session(&mut self, config: SessionConfig) -> Result<()> {
353        self.send(ClientEvent::SessionUpdate { event_id: None, session: config }).await
354    }
355
356    /// Append base64-encoded audio to the input buffer.
357    pub async fn append_audio(&mut self, audio_base64: &str) -> Result<()> {
358        self.send(ClientEvent::InputAudioBufferAppend { event_id: None, audio: audio_base64.to_string() }).await
359    }
360
361    /// Append raw audio bytes to the input buffer.
362    ///
363    /// The bytes will be base64 encoded automatically.
364    pub async fn append_audio_bytes(&mut self, audio: &[u8]) -> Result<()> {
365        let encoded = BASE64_STANDARD.encode(audio);
366        self.append_audio(&encoded).await
367    }
368
369    /// Commit the input audio buffer.
370    ///
371    /// Creates a user message item from the buffered audio.
372    pub async fn commit_audio(&mut self) -> Result<()> {
373        self.send(ClientEvent::InputAudioBufferCommit { event_id: None }).await
374    }
375
376    /// Clear the input audio buffer.
377    pub async fn clear_audio(&mut self) -> Result<()> {
378        self.send(ClientEvent::InputAudioBufferClear { event_id: None }).await
379    }
380
381    /// Create a conversation item.
382    pub async fn create_item(&mut self, item: ConversationItem) -> Result<()> {
383        self.send(ClientEvent::ConversationItemCreate { event_id: None, previous_item_id: None, item }).await
384    }
385
386    /// Send a text message.
387    pub async fn send_text(&mut self, text: &str) -> Result<()> {
388        let item = ConversationItem::Message(MessageItem::user_text(text));
389        self.create_item(item).await
390    }
391
392    /// Create a response from the model.
393    pub async fn create_response(&mut self, config: Option<ResponseCreateConfig>) -> Result<()> {
394        self.send(ClientEvent::ResponseCreate { event_id: None, response: config }).await
395    }
396
397    /// Cancel the current response.
398    pub async fn cancel_response(&mut self) -> Result<()> {
399        self.send(ClientEvent::ResponseCancel { event_id: None }).await
400    }
401
402    /// Submit the output of a function call.
403    pub async fn submit_function_output(&mut self, call_id: &str, output: &str) -> Result<()> {
404        let item = ConversationItem::FunctionCallOutput(FunctionCallOutputItem::new(call_id, output));
405        self.create_item(item).await
406    }
407
408    /// Delete a conversation item.
409    pub async fn delete_item(&mut self, item_id: &str) -> Result<()> {
410        self.send(ClientEvent::ConversationItemDelete { event_id: None, item_id: item_id.to_string() }).await
411    }
412
413    /// Truncate a conversation item's audio.
414    pub async fn truncate_item(&mut self, item_id: &str, content_index: u32, audio_end_ms: u32) -> Result<()> {
415        self.send(ClientEvent::ConversationItemTruncate { event_id: None, item_id: item_id.to_string(), content_index, audio_end_ms }).await
416    }
417
418    /// Close the session.
419    pub async fn close(&mut self) -> Result<()> {
420        self.ws_stream.close(None).await.map_err(|e| OpenAIToolError::Error(format!("Failed to close session: {}", e)))
421    }
422}