//! `oai_rt_rs` — `lib.rs`: client library for the `OpenAI` Realtime API.
1#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
2#![allow(clippy::module_name_repetitions)]
3#![allow(clippy::multiple_crate_versions)]
4
5pub mod error;
6pub mod protocol;
7pub mod sdk;
8pub mod transport;
9
10pub use error::{Error, Result};
11pub use protocol::client_events::ClientEvent;
12pub use protocol::models::{
13    ApprovalFilter, ApprovalMode, AudioConfig, AudioFormat, CachedTokenDetails, ContentPart,
14    ConversationMode, Eagerness, Infinite, InputAudioConfig, InputAudioTranscription, InputItem,
15    InputTokenDetails, Item, ItemStatus, MaxTokens, McpError, McpToolConfig, McpToolInfo, Modality,
16    NoiseReduction, NoiseReductionType, OutputAudioConfig, OutputModalities, OutputTokenDetails,
17    PromptRef, RequireApproval, Response, ResponseConfig, ResponseStatus, RetentionRatioTruncation,
18    Role, Session, SessionConfig, SessionKind, SessionUpdate, SessionUpdateConfig, Temperature,
19    TokenLimits, Tool, ToolChoice, ToolChoiceMode, Tracing, TracingAuto, TracingConfig, Truncation,
20    TruncationStrategy, TruncationType, Usage, Voice,
21};
22pub use protocol::server_events::ServerEvent;
23pub use sdk::{
24    AudioChunk, AudioIn, EventStream, Realtime, RealtimeBuilder, ResponseBuilder, SdkEvent,
25    Session as RealtimeSession, SessionHandle, ToolCall, ToolFuture, ToolRegistry, ToolResult,
26    ToolSpec, TranscriptChunk, VoiceEvent, VoiceEventStream, VoiceSessionBuilder,
27};
28
29use crate::protocol::models;
30use futures::stream::BoxStream;
31use futures::{SinkExt, StreamExt};
32use serde_json::from_str;
33use tokio_tungstenite::tungstenite::protocol::Message;
34use transport::ws::WsStream;
35
/// Maximum number of payload bytes included in a single trace-log line;
/// longer payloads are shortened via `safe_truncate`.
const TRACE_LOG_MAX_BYTES: usize = 1024;
/// Maximum accepted decoded size (15 MiB) of a single
/// `input_audio_buffer.append` audio chunk.
const MAX_INPUT_AUDIO_CHUNK_BYTES: usize = 15 * 1024 * 1024;
/// Marker inserted into trace-log output when a payload has been shortened.
const TRACE_TRUNCATE_SUFFIX: &str = "... (truncated)";
39
/// The main client for interacting with the `OpenAI` Realtime API.
///
/// Wraps a single WebSocket connection; use [`RealtimeClient::split`] to
/// obtain independent send/receive halves for concurrent use.
///
/// Thread safety: `RealtimeClient` is `Send` but not `Sync` because the underlying
/// WebSocket stream is not `Sync`.
#[must_use]
pub struct RealtimeClient {
    // The underlying WebSocket stream (see `transport::ws::connect`).
    stream: WsStream,
}
48
49impl RealtimeClient {
50    /// Connect to the `OpenAI` Realtime API.
51    ///
52    /// # Errors
53    /// Returns an error if the connection fails or if the URL is invalid.
54    pub async fn connect(
55        api_key: &str,
56        model: Option<&str>,
57        call_id: Option<&str>,
58    ) -> Result<Self> {
59        let stream = transport::ws::connect(api_key, model, call_id).await?;
60        Ok(Self { stream })
61    }
62
63    /// Send a client event to the server.
64    ///
65    /// # Errors
66    /// Returns an error if serialization fails or if the WebSocket send fails.
67    pub async fn send(&mut self, event: ClientEvent) -> Result<()> {
68        validate_client_event(&event)?;
69        let json = serde_json::to_string(&event)?;
70        tracing::trace!(
71            "Sending event: {}",
72            safe_truncate(&json, TRACE_LOG_MAX_BYTES)
73        );
74        self.stream.send(Message::Text(json.into())).await?;
75        Ok(())
76    }
77
78    /// Receive the next server event.
79    ///
80    /// # Errors
81    /// Returns an error if deserialization fails or if the WebSocket fails.
82    pub async fn next_event(&mut self) -> Result<Option<ServerEvent>> {
83        while let Some(msg) = self.stream.next().await {
84            match msg? {
85                Message::Text(text) => {
86                    tracing::trace!(
87                        "Received event: {}",
88                        safe_truncate(&text, TRACE_LOG_MAX_BYTES)
89                    );
90                    return Ok(Some(from_str::<ServerEvent>(&text)?));
91                }
92                Message::Close(_) => {
93                    tracing::info!("WebSocket connection closed by server");
94                    return Ok(None);
95                }
96                Message::Ping(payload) => {
97                    tracing::debug!("Received Ping, sending Pong");
98                    self.stream.send(Message::Pong(payload)).await?;
99                }
100                _ => (),
101            }
102        }
103        Ok(None)
104    }
105
106    /// Split the client into a sender and a receiver for concurrent usage.
107    pub fn split(self) -> (RealtimeSender, RealtimeReceiver) {
108        let (write, read) = self.stream.split();
109        (RealtimeSender { write }, RealtimeReceiver { read })
110    }
111
112    /// Re-unify a split client.
113    ///
114    /// # Errors
115    /// Returns an error if the split halves don't match or cannot be reunited.
116    #[allow(clippy::result_large_err)]
117    pub fn unsplit(sender: RealtimeSender, receiver: RealtimeReceiver) -> Result<Self> {
118        let stream = receiver.read.reunite(sender.write)?;
119        Ok(Self { stream })
120    }
121}
122
123fn safe_truncate(s: &str, max_bytes: usize) -> std::borrow::Cow<'_, str> {
124    if s.len() <= max_bytes {
125        return std::borrow::Cow::Borrowed(s);
126    }
127
128    let mut end = max_bytes;
129    while end > 0 && !s.is_char_boundary(end) {
130        end -= 1;
131    }
132    std::borrow::Cow::Owned(format!(
133        "{} {} {} bytes",
134        &s[..end],
135        TRACE_TRUNCATE_SUFFIX,
136        s.len() - end
137    ))
138}
139
/// The sending half of a split `RealtimeClient`.
pub struct RealtimeSender {
    // Sink half of the WebSocket stream produced by `RealtimeClient::split`.
    write: futures::stream::SplitSink<WsStream, Message>,
}
144
145impl RealtimeSender {
146    /// Send a client event.
147    ///
148    /// # Errors
149    /// Returns an error if serialization or sending fails.
150    pub async fn send(&mut self, event: ClientEvent) -> Result<()> {
151        validate_client_event(&event)?;
152        let json = serde_json::to_string(&event)?;
153        tracing::trace!(
154            "Sending event (split): {}",
155            safe_truncate(&json, TRACE_LOG_MAX_BYTES)
156        );
157        self.write.send(Message::Text(json.into())).await?;
158        Ok(())
159    }
160}
161
162#[allow(clippy::result_large_err)]
163fn validate_client_event(event: &ClientEvent) -> Result<()> {
164    match event {
165        ClientEvent::InputAudioBufferAppend { audio, .. } => {
166            let size = estimate_base64_decoded_len(audio)?;
167            if size > MAX_INPUT_AUDIO_CHUNK_BYTES {
168                return Err(Error::InvalidClientEvent(format!(
169                    "input_audio_buffer.append exceeds 15MB ({size} bytes)",
170                )));
171            }
172        }
173        ClientEvent::SessionUpdate { session, .. } => {
174            validate_session_update(session.as_ref())?;
175        }
176        ClientEvent::ResponseCreate {
177            response: Some(config),
178            ..
179        } => {
180            validate_response_config(config.as_ref())?;
181        }
182        _ => {}
183    }
184    Ok(())
185}
186
187#[allow(clippy::result_large_err)]
188fn validate_session_update(session: &models::SessionUpdate) -> Result<()> {
189    let config = &session.config;
190    if let Some(format) = &config.input_audio_format {
191        validate_audio_format(format)?;
192    }
193    if let Some(format) = &config.output_audio_format {
194        validate_audio_format(format)?;
195    }
196    if let Some(audio) = &config.audio {
197        validate_audio_config(audio)?;
198    }
199    if let Some(tools) = &config.tools {
200        validate_tools(tools)?;
201    }
202    Ok(())
203}
204
205#[allow(clippy::result_large_err)]
206fn validate_response_config(config: &models::ResponseConfig) -> Result<()> {
207    if let Some(audio) = &config.audio {
208        validate_audio_config(audio)?;
209    }
210    if let Some(format) = &config.input_audio_format {
211        validate_audio_format(format)?;
212        if let Some(audio) = &config.audio {
213            if let Some(input) = &audio.input {
214                if let Some(nested) = &input.format {
215                    if nested != format {
216                        return Err(Error::InvalidClientEvent(
217                            "response.input_audio_format conflicts with response.audio.input.format"
218                                .to_string(),
219                        ));
220                    }
221                }
222            }
223        }
224    }
225    if let Some(tools) = &config.tools {
226        validate_tools(tools)?;
227    }
228    Ok(())
229}
230
231#[allow(clippy::result_large_err)]
232fn validate_audio_config(audio: &models::AudioConfig) -> Result<()> {
233    if let Some(input) = &audio.input {
234        validate_input_audio_config(input)?;
235    }
236    if let Some(output) = &audio.output {
237        validate_output_audio_config(output)?;
238    }
239    Ok(())
240}
241
242#[allow(clippy::result_large_err)]
243fn validate_input_audio_config(audio: &models::InputAudioConfig) -> Result<()> {
244    if let Some(format) = &audio.format {
245        validate_audio_format(format)?;
246    }
247    Ok(())
248}
249
250#[allow(clippy::result_large_err)]
251fn validate_output_audio_config(audio: &models::OutputAudioConfig) -> Result<()> {
252    if let Some(format) = &audio.format {
253        validate_audio_format(format)?;
254    }
255    Ok(())
256}
257
258#[allow(clippy::result_large_err)]
259fn validate_audio_format(format: &models::AudioFormat) -> Result<()> {
260    format.validate()?;
261    Ok(())
262}
263
264#[allow(clippy::result_large_err)]
265fn validate_tools(tools: &[models::Tool]) -> Result<()> {
266    for tool in tools {
267        if let models::Tool::Mcp(config) = tool {
268            config.validate()?;
269        }
270    }
271    Ok(())
272}
273
274#[allow(clippy::result_large_err)]
275fn estimate_base64_decoded_len(s: &str) -> Result<usize> {
276    let bytes = s.as_bytes();
277    if bytes.len() % 4 != 0 {
278        return Err(Error::InvalidClientEvent(
279            "input_audio_buffer.append invalid base64 length".to_string(),
280        ));
281    }
282
283    let mut padding = 0;
284    let mut seen_padding = false;
285    for &b in bytes {
286        if b == b'=' {
287            seen_padding = true;
288            padding += 1;
289            continue;
290        }
291        if seen_padding {
292            return Err(Error::InvalidClientEvent(
293                "input_audio_buffer.append invalid base64 padding".to_string(),
294            ));
295        }
296        let is_valid = matches!(b,
297            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/'
298        );
299        if !is_valid {
300            return Err(Error::InvalidClientEvent(
301                "input_audio_buffer.append invalid base64 character".to_string(),
302            ));
303        }
304    }
305
306    if padding > 2 {
307        return Err(Error::InvalidClientEvent(
308            "input_audio_buffer.append invalid base64 padding length".to_string(),
309        ));
310    }
311
312    Ok(bytes.len() / 4 * 3 - padding)
313}
314
/// The receiving half of a split `RealtimeClient`.
pub struct RealtimeReceiver {
    // Stream half of the WebSocket produced by `RealtimeClient::split`.
    read: futures::stream::SplitStream<WsStream>,
}
319
320impl RealtimeReceiver {
321    /// Exposes an asynchronous stream of `Result<ServerEvent>` that preserves Errors.
322    #[must_use]
323    pub fn try_into_stream(self) -> BoxStream<'static, Result<ServerEvent>> {
324        self.read
325            .map(|res| res.map_err(Error::from))
326            .filter_map(|res| async move {
327                match res {
328                    Ok(Message::Text(text)) => {
329                        tracing::trace!(
330                            "Received event (stream): {}",
331                            safe_truncate(&text, TRACE_LOG_MAX_BYTES)
332                        );
333                        Some(from_str::<ServerEvent>(&text).map_err(Error::from))
334                    }
335                    Ok(_) => None,
336                    Err(e) => Some(Err(e)),
337                }
338            })
339            .boxed()
340    }
341}