use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::{HeaderName, HeaderValue};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
/// Errors returned by the Realtime WebSocket client in this module.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RealtimeError {
/// Setting up the session failed: the URL could not be built, a request header was
/// invalid, or the WebSocket handshake failed. Carries the rendered cause.
#[error("realtime connection error: {0}")]
Connect(String),
/// Sending, receiving, or closing on an established WebSocket failed.
/// Carries the rendered cause.
#[error("realtime protocol error: {0}")]
Protocol(String),
/// An outgoing event could not be serialized, or an incoming frame was not valid JSON.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
}
/// Connection parameters for [`connect`].
///
/// `Debug` is implemented by hand so that credentials are never printed: the API key and
/// all `extra_headers` values are replaced with a placeholder.
#[derive(Clone)]
pub struct RealtimeConnectOptions {
    /// Secret sent as `Authorization: Bearer <api_key>`.
    pub api_key: String,
    /// `http(s)://` or `ws(s)://` base URL; an empty or whitespace-only value selects the
    /// crate's default base URL.
    pub base_url: String,
    /// Model name, sent percent-encoded as the `model` query parameter.
    pub model: String,
    /// Optional value for the `OpenAI-Organization` header.
    pub organization: Option<String>,
    /// Optional value for the `OpenAI-Project` header.
    pub project: Option<String>,
    /// Additional `(name, value)` headers. They are applied after the built-in headers, and
    /// each one replaces any earlier header with the same name.
    pub extra_headers: Vec<(String, String)>,
}

impl std::fmt::Debug for RealtimeConnectOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Distinguish "missing" from "set" without revealing the key itself.
        let api_key: &str = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        // Header values may hold credentials too; show only the names.
        let extra_headers: Vec<(&str, &str)> = self
            .extra_headers
            .iter()
            .map(|(name, _)| (name.as_str(), "<redacted>"))
            .collect();
        f.debug_struct("RealtimeConnectOptions")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("organization", &self.organization)
            .field("project", &self.project)
            .field("extra_headers", &extra_headers)
            .finish()
    }
}
/// Builds the WebSocket URL of the Realtime endpoint for `model`.
///
/// `base_url` is trimmed, and an empty value falls back to `crate::DEFAULT_BASE_URL`.
/// Trailing slashes are then removed.
///
/// The scheme is rewritten as follows:
/// - `https` becomes `wss` and `http` becomes `ws`.
/// - `wss`/`ws` are kept.
/// - Matching is ASCII case-insensitive (RFC 3986 §3.1), and the emitted scheme is lowercase.
///
/// The result is `<base>/realtime?model=<percent-encoded model>`.
///
/// # Errors
/// Returns [`RealtimeError::Connect`] in two cases:
/// - the base URL contains a query (`?`) or fragment (`#`);
/// - the base URL has no `scheme://` prefix, or the scheme is not one of
///   `http`, `https`, `ws`, `wss`.
pub fn build_realtime_url(base_url: &str, model: &str) -> Result<String, RealtimeError> {
    let base = base_url.trim();
    let base = if base.is_empty() {
        crate::DEFAULT_BASE_URL
    } else {
        base
    };
    let base = base.trim_end_matches('/');
    if base.contains('?') || base.contains('#') {
        return Err(RealtimeError::Connect(format!(
            "base_url must not contain a query or fragment: {base}"
        )));
    }

    // Captures only a shared reference, so the closure is `Copy` and can be reused below.
    let unsupported = || RealtimeError::Connect(format!("unsupported base_url scheme: {base}"));
    let (scheme, rest) = base.split_once("://").ok_or_else(unsupported)?;
    let ws_scheme = if scheme.eq_ignore_ascii_case("https") || scheme.eq_ignore_ascii_case("wss")
    {
        "wss"
    } else if scheme.eq_ignore_ascii_case("http") || scheme.eq_ignore_ascii_case("ws") {
        "ws"
    } else {
        return Err(unsupported());
    };

    Ok(format!(
        "{ws_scheme}://{rest}/realtime?model={}",
        encode_query_component(model)
    ))
}
/// Percent-encodes `value` for use as a single URL query component.
///
/// Bytes in the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - _ . ~`) pass through unchanged.
/// Every other byte of the UTF-8 encoding, including multi-byte characters, is written as
/// `%XX` with uppercase hex digits.
fn encode_query_component(value: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(value.len());
    value.bytes().for_each(|b| {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(b));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
        }
    });
    encoded
}
/// Opens a Realtime WebSocket session described by `options`.
///
/// The handshake request carries these headers, in this order:
/// - `Authorization: Bearer <api_key>` (flagged as sensitive);
/// - `OpenAI-Beta: realtime=v1`;
/// - `OpenAI-Organization` and `OpenAI-Project`, when set;
/// - every `extra_headers` entry, each replacing any header of the same name set before it.
///
/// # Errors
/// Returns [`RealtimeError::Connect`] if any of these fail:
/// - the URL cannot be built or parsed;
/// - a header name or value is invalid;
/// - the WebSocket handshake fails.
pub async fn connect(options: RealtimeConnectOptions) -> Result<RealtimeSession, RealtimeError> {
    let url = build_realtime_url(&options.base_url, &options.model)?;
    let mut request = url
        .into_client_request()
        .map_err(|e| RealtimeError::Connect(e.to_string()))?;
    let headers = request.headers_mut();

    let mut auth = HeaderValue::from_str(&format!("Bearer {}", options.api_key))
        .map_err(|e| RealtimeError::Connect(format!("invalid api key header: {e}")))?;
    // Mark the credential sensitive so `http` prints it as "Sensitive" in `Debug` output
    // (e.g. if the request or its headers are ever logged).
    auth.set_sensitive(true);
    headers.insert("Authorization", auth);
    headers.insert("OpenAI-Beta", HeaderValue::from_static("realtime=v1"));

    if let Some(org) = &options.organization {
        let value = HeaderValue::from_str(org)
            .map_err(|e| RealtimeError::Connect(format!("invalid organization header: {e}")))?;
        headers.insert("OpenAI-Organization", value);
    }
    if let Some(project) = &options.project {
        let value = HeaderValue::from_str(project)
            .map_err(|e| RealtimeError::Connect(format!("invalid project header: {e}")))?;
        headers.insert("OpenAI-Project", value);
    }

    // `insert` (not `append`): caller-supplied headers override the defaults above.
    for (name, value) in &options.extra_headers {
        let header_name = HeaderName::from_bytes(name.as_bytes())
            .map_err(|e| RealtimeError::Connect(format!("invalid header name {name}: {e}")))?;
        let header_value = HeaderValue::from_str(value)
            .map_err(|e| RealtimeError::Connect(format!("invalid value for {name}: {e}")))?;
        headers.insert(header_name, header_value);
    }

    let (ws, _response) = connect_async(request)
        .await
        .map_err(|e| RealtimeError::Connect(e.to_string()))?;
    Ok(RealtimeSession { ws })
}
/// An open Realtime WebSocket connection, as returned by [`connect`].
///
/// JSON events are exchanged through the `send` and `recv` methods; `close` sends a
/// WebSocket close frame and consumes the session.
pub struct RealtimeSession {
/// Underlying WebSocket stream; the transport is tokio-tungstenite's `MaybeTlsStream`
/// (plain or TLS), as chosen by `connect_async`.
ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
}
impl RealtimeSession {
pub async fn send(&mut self, event: serde_json::Value) -> Result<(), RealtimeError> {
let text = serde_json::to_string(&event)?;
self.ws
.send(Message::Text(text))
.await
.map_err(|e| RealtimeError::Protocol(e.to_string()))?;
Ok(())
}
pub async fn recv(&mut self) -> Result<Option<serde_json::Value>, RealtimeError> {
while let Some(message) = self.ws.next().await {
let message = message.map_err(|e| RealtimeError::Protocol(e.to_string()))?;
match message {
Message::Text(text) => {
let value = serde_json::from_str(text.as_str())?;
return Ok(Some(value));
}
Message::Binary(bytes) => {
let value = serde_json::from_slice(bytes.as_ref())?;
return Ok(Some(value));
}
Message::Close(_) => return Ok(None),
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
}
}
Ok(None)
}
pub async fn close(mut self) -> Result<(), RealtimeError> {
self.ws
.close(None)
.await
.map_err(|e| RealtimeError::Protocol(e.to_string()))?;
Ok(())
}
}
/// Constructors for client-to-server Realtime events.
///
/// Each returns a JSON object whose `"type"` field names the event, followed by at most one
/// payload field.
pub mod events {
    use serde_json::{json, Value};

    /// Creates an event object containing only its `"type"` field.
    fn bare(kind: &str) -> Value {
        json!({ "type": kind })
    }

    /// Creates an event object with `"type"` first, then a single `key: payload` field.
    fn with_payload(kind: &str, key: &str, payload: Value) -> Value {
        let mut event = bare(kind);
        event[key] = payload;
        event
    }

    /// `session.update` event carrying `session` under the `"session"` key.
    pub fn session_update(session: Value) -> Value {
        with_payload("session.update", "session", session)
    }

    /// `input_audio_buffer.append` event carrying `base64_audio` as the `"audio"` string.
    pub fn input_audio_buffer_append(base64_audio: &str) -> Value {
        with_payload(
            "input_audio_buffer.append",
            "audio",
            Value::from(base64_audio),
        )
    }

    /// `input_audio_buffer.commit` event with no payload.
    pub fn input_audio_buffer_commit() -> Value {
        bare("input_audio_buffer.commit")
    }

    /// `conversation.item.create` event carrying `item` under the `"item"` key.
    pub fn conversation_item_create(item: Value) -> Value {
        with_payload("conversation.item.create", "item", item)
    }

    /// `response.create` event with no payload.
    pub fn response_create() -> Value {
        bare("response.create")
    }

    /// `response.create` event carrying `response` under the `"response"` key.
    pub fn response_create_with(response: Value) -> Value {
        with_payload("response.create", "response", response)
    }
}