use base64::prelude::*;
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::{
    connect_async_with_config,
    tungstenite::{client::IntoClientRequest, Message as WsMessage},
    MaybeTlsStream, WebSocketStream,
};

use crate::common::auth::AuthProvider;
use crate::common::errors::{OpenAIToolError, Result};
use crate::common::models::RealtimeModel;
use crate::common::tool::Tool;

use super::audio::{AudioFormat, InputAudioTranscription, TranscriptionModel, Voice};
use super::conversation::{ConversationItem, FunctionCallOutputItem, MessageItem};
use super::events::client::ClientEvent;
use super::events::server::ServerEvent;
use super::session::{Modality, RealtimeTool, ResponseCreateConfig, SessionConfig};
use super::vad::{SemanticVadConfig, ServerVadConfig, TurnDetection};

const REALTIME_PATH: &str = "realtime";
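
/// Builder-style client for opening Realtime API WebSocket sessions.
///
/// Any configuration set via the builder methods is pushed to the server as a
/// session update immediately after the connection is established. A minimal
/// usage sketch (values are illustrative; assumes `OPENAI_API_KEY` is set and
/// error handling is elided):
///
/// ```ignore
/// let mut client = RealtimeClient::new();
/// client
///     .instructions("You are a helpful voice assistant.")
///     .temperature(0.8);
/// let mut session = client.connect().await?;
/// ```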
#[derive(Debug, Clone)]
pub struct RealtimeClient {
    auth: AuthProvider,
    model: RealtimeModel,
    session_config: SessionConfig,
}

impl RealtimeClient {
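    /// Creates a client authenticated from the `OPENAI_API_KEY` environment
    /// variable, using the default model and session configuration.
    ///
    /// Panics if `OPENAI_API_KEY` is not set; [`Self::detect_provider`] is the
    /// non-panicking alternative.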
    pub fn new() -> Self {
        let auth = AuthProvider::openai_from_env().expect("OPENAI_API_KEY must be set");
        Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
    }

    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
    }

    pub fn azure() -> Result<Self> {
        let auth = AuthProvider::azure_from_env()?;
        Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
    }

    pub fn detect_provider() -> Result<Self> {
        let auth = AuthProvider::from_env()?;
        Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
    }

    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        let auth = AuthProvider::from_url_with_key(base_url, api_key);
        Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
    }

    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        let auth = AuthProvider::from_url(url)?;
        Ok(Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() })
    }

    #[deprecated(since = "0.3.0", note = "Use `with_auth(AuthProvider::OpenAI(...))` instead")]
    pub fn with_api_key(api_key: impl Into<String>) -> Self {
        let auth = AuthProvider::OpenAI(crate::common::auth::OpenAIAuth::new(api_key));
        Self { auth, model: RealtimeModel::default(), session_config: SessionConfig::default() }
    }

    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    pub fn model(&mut self, model: RealtimeModel) -> &mut Self {
        self.model = model;
        self
    }

    #[deprecated(since = "0.2.0", note = "Use `model(RealtimeModel)` instead for type safety")]
    pub fn model_id(&mut self, model_id: impl Into<String>) -> &mut Self {
        self.model = RealtimeModel::from(model_id.into().as_str());
        self
    }

    pub fn modalities(&mut self, modalities: Vec<Modality>) -> &mut Self {
        self.session_config.modalities = Some(modalities);
        self
    }

    pub fn instructions(&mut self, instructions: impl Into<String>) -> &mut Self {
        self.session_config.instructions = Some(instructions.into());
        self
    }

    pub fn voice(&mut self, voice: Voice) -> &mut Self {
        self.session_config.voice = Some(voice);
        self
    }

    pub fn input_audio_format(&mut self, format: AudioFormat) -> &mut Self {
        self.session_config.input_audio_format = Some(format);
        self
    }

    pub fn output_audio_format(&mut self, format: AudioFormat) -> &mut Self {
        self.session_config.output_audio_format = Some(format);
        self
    }

    pub fn enable_transcription(&mut self, model: TranscriptionModel) -> &mut Self {
        self.session_config.input_audio_transcription = Some(InputAudioTranscription::new(model));
        self
    }

    pub fn transcription(&mut self, config: InputAudioTranscription) -> &mut Self {
        self.session_config.input_audio_transcription = Some(config);
        self
    }

    pub fn server_vad(&mut self, config: ServerVadConfig) -> &mut Self {
        self.session_config.turn_detection = Some(TurnDetection::ServerVad(config));
        self
    }

    pub fn semantic_vad(&mut self, config: SemanticVadConfig) -> &mut Self {
        self.session_config.turn_detection = Some(TurnDetection::SemanticVad(config));
        self
    }

    pub fn disable_turn_detection(&mut self) -> &mut Self {
        self.session_config.turn_detection = None;
        self
    }

    pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
        self.session_config.tools = Some(tools.into_iter().map(RealtimeTool::from).collect());
        self
    }

    pub fn realtime_tools(&mut self, tools: Vec<RealtimeTool>) -> &mut Self {
        self.session_config.tools = Some(tools);
        self
    }

    pub fn temperature(&mut self, temp: f32) -> &mut Self {
        self.session_config.temperature = Some(temp);
        self
    }
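
    /// Opens the WebSocket connection, waits for `session.created`, and sends a
    /// session update if any builder configuration was set.
    ///
    /// A minimal sketch (error handling and event processing elided):
    ///
    /// ```ignore
    /// let mut session = RealtimeClient::new().connect().await?;
    /// session.send_text("Hello!").await?;
    /// session.create_response(None).await?;
    /// ```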
    pub async fn connect(&self) -> Result<RealtimeSession> {
        let url = self.ws_endpoint();

        let mut request = url
            .into_client_request()
            .map_err(|e| OpenAIToolError::Error(format!("Failed to build request: {}", e)))?;

        let headers = request.headers_mut();

        match &self.auth {
            AuthProvider::OpenAI(auth) => {
                headers.insert(
                    "Authorization",
                    format!("Bearer {}", auth.api_key())
                        .parse()
                        .map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?,
                );
            }
            AuthProvider::Azure(auth) => {
                headers.insert(
                    "api-key",
                    auth.api_key()
                        .parse()
                        .map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?,
                );
            }
        }
        headers.insert(
            "OpenAI-Beta",
            "realtime=v1"
                .parse()
                .map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?,
        );

        let (ws_stream, _response) = connect_async_with_config(request, None, false)
            .await
            .map_err(|e| OpenAIToolError::Error(format!("WebSocket connection failed: {}", e)))?;

        let mut session = RealtimeSession::new(ws_stream);

        session.wait_for_session_created().await?;

        if self.session_config.modalities.is_some()
            || self.session_config.instructions.is_some()
            || self.session_config.voice.is_some()
            || self.session_config.tools.is_some()
            || self.session_config.turn_detection.is_some()
        {
            session.update_session(self.session_config.clone()).await?;
        }

        Ok(session)
    }

    /// Builds the WebSocket endpoint for the configured provider. For OpenAI this
    /// targets `wss://api.openai.com/v1/realtime` with the model as a query
    /// parameter; for Azure the configured base URL is reused with its scheme
    /// switched to `wss`/`ws`.
    fn ws_endpoint(&self) -> String {
        match &self.auth {
            AuthProvider::OpenAI(_) => {
                format!("wss://api.openai.com/v1/{}?model={}", REALTIME_PATH, self.model.as_str())
            }
            AuthProvider::Azure(auth) => {
                let base = auth.base_url();
                if base.starts_with("https://") {
                    base.replacen("https://", "wss://", 1)
                } else if base.starts_with("http://") {
                    base.replacen("http://", "ws://", 1)
                } else {
                    base.to_string()
                }
            }
        }
    }
}

impl Default for RealtimeClient {
    fn default() -> Self {
        Self::new()
    }
}
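
/// An active Realtime API session over a WebSocket connection.
///
/// Events are sent with [`send`](Self::send) (or the typed helpers below) and
/// received with [`recv`](Self::recv). A minimal receive loop, as a sketch
/// (the variant handling is illustrative):
///
/// ```ignore
/// while let Some(event) = session.recv().await? {
///     match event {
///         ServerEvent::Error(e) => eprintln!("server error: {}", e.error.message),
///         _ => { /* handle other events */ }
///     }
/// }
/// ```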
pub struct RealtimeSession {
    ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
}

impl RealtimeSession {
    fn new(ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
        Self { ws_stream }
    }

    /// Serializes a [`ClientEvent`] to JSON and sends it as a text frame.
    pub async fn send(&mut self, event: ClientEvent) -> Result<()> {
        let json = serde_json::to_string(&event)?;
        self.ws_stream
            .send(WsMessage::Text(json.into()))
            .await
            .map_err(|e| OpenAIToolError::Error(format!("Failed to send event: {}", e)))?;
        Ok(())
    }
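
    /// Receives the next server event, transparently answering pings and
    /// skipping non-text frames. Returns `Ok(None)` once the connection is
    /// closed by the server.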
    pub async fn recv(&mut self) -> Result<Option<ServerEvent>> {
        loop {
            match self.ws_stream.next().await {
                Some(Ok(WsMessage::Text(text))) => {
                    let event: ServerEvent = serde_json::from_str(&text)?;
                    return Ok(Some(event));
                }
                Some(Ok(WsMessage::Close(_))) => {
                    return Ok(None);
                }
                Some(Ok(WsMessage::Ping(data))) => {
                    // Answer pings so the server keeps the connection alive.
                    self.ws_stream
                        .send(WsMessage::Pong(data))
                        .await
                        .map_err(|e| OpenAIToolError::Error(format!("Failed to send pong: {}", e)))?;
                    continue;
                }
                // Ignore other frame types (binary, pong, raw frames).
                Some(Ok(_)) => continue,
                Some(Err(e)) => {
                    return Err(OpenAIToolError::Error(format!("WebSocket error: {}", e)));
                }
                None => {
                    return Ok(None);
                }
            }
        }
    }

    async fn wait_for_session_created(&mut self) -> Result<()> {
        match self.recv().await? {
            Some(ServerEvent::SessionCreated(_)) => Ok(()),
            Some(ServerEvent::Error(e)) => Err(OpenAIToolError::Error(format!(
                "Session creation failed: {}",
                e.error.message
            ))),
            Some(event) => Err(OpenAIToolError::Error(format!(
                "Unexpected event while waiting for session.created: {:?}",
                std::mem::discriminant(&event)
            ))),
            None => Err(OpenAIToolError::Error("Connection closed before session.created".to_string())),
        }
    }

    /// Sends a session update event with the given configuration.
    pub async fn update_session(&mut self, config: SessionConfig) -> Result<()> {
        self.send(ClientEvent::SessionUpdate { event_id: None, session: config }).await
    }

    /// Appends base64-encoded audio to the input audio buffer.
    pub async fn append_audio(&mut self, audio_base64: &str) -> Result<()> {
        self.send(ClientEvent::InputAudioBufferAppend { event_id: None, audio: audio_base64.to_string() }).await
    }
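
    /// Base64-encodes raw audio bytes and appends them to the input audio buffer.
    ///
    /// A sketch of the manual (no server VAD) input flow; the audio source is an
    /// assumption, not part of this crate:
    ///
    /// ```ignore
    /// let pcm_chunk: Vec<u8> = read_audio_from_microphone(); // hypothetical helper
    /// session.append_audio_bytes(&pcm_chunk).await?;
    /// session.commit_audio().await?;
    /// session.create_response(None).await?;
    /// ```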
    pub async fn append_audio_bytes(&mut self, audio: &[u8]) -> Result<()> {
        let encoded = BASE64_STANDARD.encode(audio);
        self.append_audio(&encoded).await
    }

    /// Commits the pending input audio buffer as user input.
    pub async fn commit_audio(&mut self) -> Result<()> {
        self.send(ClientEvent::InputAudioBufferCommit { event_id: None }).await
    }

    /// Clears any audio pending in the input audio buffer.
    pub async fn clear_audio(&mut self) -> Result<()> {
        self.send(ClientEvent::InputAudioBufferClear { event_id: None }).await
    }

    /// Adds a conversation item (message, function call output, etc.) to the session.
    pub async fn create_item(&mut self, item: ConversationItem) -> Result<()> {
        self.send(ClientEvent::ConversationItemCreate { event_id: None, previous_item_id: None, item }).await
    }
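
    /// Creates a user message item containing the given text.
    ///
    /// Adding the item does not by itself trigger a reply; a sketch of a simple
    /// text turn (error handling elided):
    ///
    /// ```ignore
    /// session.send_text("What is the weather like in Paris?").await?;
    /// session.create_response(None).await?;
    /// ```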
    pub async fn send_text(&mut self, text: &str) -> Result<()> {
        let item = ConversationItem::Message(MessageItem::user_text(text));
        self.create_item(item).await
    }

    /// Asks the server to generate a response, optionally passing per-response
    /// configuration.
    pub async fn create_response(&mut self, config: Option<ResponseCreateConfig>) -> Result<()> {
        self.send(ClientEvent::ResponseCreate { event_id: None, response: config }).await
    }

    /// Cancels the response currently being generated, if any.
    pub async fn cancel_response(&mut self) -> Result<()> {
        self.send(ClientEvent::ResponseCancel { event_id: None }).await
    }
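
    /// Sends the result of a tool/function call back to the model as a
    /// function call output item.
    ///
    /// A sketch of the round trip; the call ID and JSON payload are placeholders,
    /// and whether a follow-up `create_response` is needed depends on your
    /// configuration:
    ///
    /// ```ignore
    /// session.submit_function_output("call_123", "{\"ok\": true}").await?; // placeholder values
    /// session.create_response(None).await?;
    /// ```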
    pub async fn submit_function_output(&mut self, call_id: &str, output: &str) -> Result<()> {
        let item = ConversationItem::FunctionCallOutput(FunctionCallOutputItem::new(call_id, output));
        self.create_item(item).await
    }

    /// Deletes a conversation item by ID.
    pub async fn delete_item(&mut self, item_id: &str) -> Result<()> {
        self.send(ClientEvent::ConversationItemDelete { event_id: None, item_id: item_id.to_string() }).await
    }

    /// Truncates a previously generated audio item, e.g. after an interruption,
    /// so the conversation history matches what was actually played back.
    pub async fn truncate_item(&mut self, item_id: &str, content_index: u32, audio_end_ms: u32) -> Result<()> {
        self.send(ClientEvent::ConversationItemTruncate {
            event_id: None,
            item_id: item_id.to_string(),
            content_index,
            audio_end_ms,
        }).await
    }

    /// Closes the WebSocket connection gracefully.
    pub async fn close(&mut self) -> Result<()> {
        self.ws_stream
            .close(None)
            .await
            .map_err(|e| OpenAIToolError::Error(format!("Failed to close session: {}", e)))
    }
}