use std::convert::TryFrom;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use audiopus::coder::{Decoder, Encoder};
use audiopus::{Application, Channels, MutSignals, SampleRate};
use parking_lot::Mutex as ParkingMutex;
use str0m::Rtc;
use str0m::change::SdpAnswer;
use str0m::channel::ChannelId;
use str0m::media::{Direction, Frequency, MediaKind, MediaTime, Mid, Pt};
use tokio::sync::{Mutex, mpsc};
use std::sync::atomic::AtomicU64;
use crate::config::RealtimeConfig;
use crate::error::{RealtimeError, Result};
use crate::events::ServerEvent;
/// Upper bound for one encoded Opus frame; 4000 bytes is the conventional
/// safe maximum recommended for `opus_encode` output buffers.
const MAX_OPUS_FRAME_BYTES: usize = 4000;
/// Decode buffer size per channel: 5760 samples = 120 ms at 48 kHz, the
/// largest frame duration an Opus packet can carry.
const MAX_DECODED_SAMPLES_PER_CHANNEL: usize = 5760;
/// Paired Opus encoder/decoder fixed to a single sample rate and channel
/// layout chosen at construction time.
pub struct OpusCodec {
    /// Encoder configured for VoIP (see `OpusCodec::new`).
    encoder: Encoder,
    /// Decoder configured with the same rate/channels as the encoder.
    decoder: Decoder,
    /// Negotiated sample rate, kept for the `sample_rate()` accessor.
    sample_rate: SampleRate,
    /// Mono or Stereo; drives decode buffer sizing.
    channels: Channels,
}
impl OpusCodec {
    /// Builds an encoder/decoder pair for the given rate and channel count.
    ///
    /// # Errors
    /// Returns a `RealtimeError::opus` error when the rate is not one Opus
    /// supports, when `channels` is not 1 or 2, or when either coder fails
    /// to initialize.
    pub fn new(sample_rate: u32, channels: u8) -> Result<Self> {
        let rate = match SampleRate::try_from(sample_rate as i32) {
            Ok(r) => r,
            Err(e) => {
                return Err(RealtimeError::opus(format!(
                    "Invalid sample rate {sample_rate}: {e}"
                )));
            }
        };
        let chans = match channels {
            1 => Channels::Mono,
            2 => Channels::Stereo,
            other => {
                return Err(RealtimeError::opus(format!(
                    "Invalid channel count {other}: must be 1 (mono) or 2 (stereo)"
                )));
            }
        };
        let encoder = Encoder::new(rate, chans, Application::Voip)
            .map_err(|e| RealtimeError::opus(format!("Failed to create Opus encoder: {e}")))?;
        let decoder = Decoder::new(rate, chans)
            .map_err(|e| RealtimeError::opus(format!("Failed to create Opus decoder: {e}")))?;
        Ok(Self { encoder, decoder, sample_rate: rate, channels: chans })
    }

    /// Encodes one frame of interleaved PCM16 into a freshly allocated
    /// Opus packet, truncated to the actual encoded length.
    pub fn encode(&mut self, pcm: &[i16]) -> Result<Vec<u8>> {
        let mut buf = vec![0u8; MAX_OPUS_FRAME_BYTES];
        match self.encoder.encode(pcm, &mut buf) {
            Ok(written) => {
                buf.truncate(written);
                Ok(buf)
            }
            Err(e) => Err(RealtimeError::opus(format!("Opus encode failed: {e}"))),
        }
    }

    /// Decodes one Opus packet into interleaved PCM16 samples.
    ///
    /// The buffer is sized for the worst case (120 ms at 48 kHz per
    /// channel) and truncated to the decoded length afterwards.
    pub fn decode(&mut self, opus_data: &[u8]) -> Result<Vec<i16>> {
        // Same mapping as construction: stereo is 2, everything else 1.
        let channel_count = if matches!(self.channels, Channels::Stereo) { 2 } else { 1 };
        let mut pcm = vec![0i16; MAX_DECODED_SAMPLES_PER_CHANNEL * channel_count];
        let packet = audiopus::packet::Packet::try_from(opus_data)
            .map_err(|e| RealtimeError::opus(format!("Invalid Opus packet: {e}")))?;
        let signals = MutSignals::try_from(pcm.as_mut_slice())
            .map_err(|e| RealtimeError::opus(format!("Failed to create output buffer: {e}")))?;
        // `decode` reports samples per channel; scale by the channel count
        // to get the interleaved total.
        let per_channel = self
            .decoder
            .decode(Some(packet), signals, false)
            .map_err(|e| RealtimeError::opus(format!("Opus decode failed: {e}")))?;
        pcm.truncate(per_channel * channel_count);
        Ok(pcm)
    }

    /// Sample rate this codec was created with.
    pub fn sample_rate(&self) -> SampleRate {
        self.sample_rate
    }

    /// Channel layout this codec was created with.
    pub fn channels(&self) -> Channels {
        self.channels
    }
}
/// Base URL for OpenAI REST endpoints used during WebRTC signaling.
const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Label of the data channel carrying Realtime API JSON events.
const DATA_CHANNEL_LABEL: &str = "oai-events";
/// A WebRTC transport session to the OpenAI Realtime API, pairing an
/// Opus/RTP audio track with the "oai-events" JSON data channel.
pub struct OpenAIWebRTCSession {
    /// Locally generated UUID identifying this session (logging/Debug only).
    session_id: String,
    /// Set on connect, cleared by `close`; checked before sends.
    connected: Arc<AtomicBool>,
    /// str0m state machine; a tokio `Mutex` because the guard is held
    /// across `.await` points by callers.
    rtc: Arc<Mutex<Rtc>>,
    /// Mid of the negotiated send/recv audio track.
    audio_track_id: Mid,
    /// Id of the "oai-events" data channel added in the SDP offer.
    data_channel_id: ChannelId,
    /// Opus codec; parking_lot mutex — never held across an await.
    opus_encoder: Arc<ParkingMutex<OpusCodec>>,
    /// Payload type negotiated for Opus in the SDP answer.
    opus_pt: Pt,
    /// RTP clock rate negotiated for the audio track.
    clock_rate: Frequency,
    /// Running RTP timestamp offset (in clock-rate samples) for writes.
    rtp_sample_offset: AtomicU64,
    /// Receiver half for server events; wrapped so `receive_raw(&self)`
    /// can poll it.
    event_rx: Arc<Mutex<mpsc::UnboundedReceiver<Result<ServerEvent>>>>,
    /// Sender half handed out via `event_sender()`.
    event_tx: mpsc::UnboundedSender<Result<ServerEvent>>,
    /// Messages queued while the data channel is still opening.
    pending_dc_messages: Arc<Mutex<Vec<Vec<u8>>>>,
    /// Flipped to true by `flush_pending_dc_messages` once the channel opens.
    dc_open: Arc<AtomicBool>,
}
/// Response shape of `POST /realtime/sessions` (only the field we use).
#[derive(Debug, serde::Deserialize)]
struct EphemeralTokenResponse {
    client_secret: ClientSecret,
}

/// Nested `client_secret` object carrying the ephemeral bearer token value.
#[derive(Debug, serde::Deserialize)]
struct ClientSecret {
    value: String,
}

/// JSON envelope some SDP-exchange responses use instead of a raw
/// `application/sdp` body; see `exchange_sdp`.
#[derive(Debug, serde::Deserialize)]
struct SdpExchangeResponse {
    sdp: String,
    // Present in the payload but unused; renamed because `type` is a keyword.
    #[serde(rename = "type")]
    _sdp_type: String,
}
impl OpenAIWebRTCSession {
    /// Opens a WebRTC session against the OpenAI Realtime API.
    ///
    /// Steps, strictly in order:
    /// 1. Build a local SDP offer with one send/recv audio track and the
    ///    "oai-events" data channel.
    /// 2. Exchange the long-lived API key for an ephemeral token.
    /// 3. POST the offer SDP and apply the returned answer.
    /// 4. Read the negotiated Opus payload type and RTP clock rate.
    ///
    /// `_config` is currently unused — NOTE(review): confirm whether session
    /// options should be forwarded to the server here.
    ///
    /// # Errors
    /// `RealtimeError::AuthError` for token failures; webrtc/connection
    /// variants for SDP or transport failures.
    pub async fn connect(api_key: &str, model_id: &str, _config: RealtimeConfig) -> Result<Self> {
        let mut rtc = Rtc::new(Instant::now());
        let mut changes = rtc.sdp_api();
        // Single bidirectional audio m-line: we send microphone audio and
        // receive synthesized speech on the same track.
        let audio_track_id = changes.add_media(
            MediaKind::Audio,
            Direction::SendRecv,
            None, None, None, );
        let data_channel_id = changes.add_channel(DATA_CHANNEL_LABEL.to_string());
        // `apply` yields None only when no changes are pending; we always
        // add media + channel above, so None indicates an internal error.
        let (offer, pending) = changes.apply().ok_or_else(|| {
            RealtimeError::webrtc("Failed to generate SDP offer: no changes to apply")
        })?;
        let offer_sdp = offer.to_sdp_string();
        tracing::debug!(
            audio_mid = %audio_track_id,
            channel_id = ?data_channel_id,
            "Generated local SDP offer for OpenAI WebRTC"
        );
        let http_client = reqwest::Client::new();
        // Only the short-lived ephemeral token — not the API key — is sent
        // to the SDP signaling endpoint.
        let ephemeral_token =
            Self::request_ephemeral_token(&http_client, api_key, model_id).await?;
        tracing::debug!("Obtained ephemeral token for WebRTC signaling");
        let answer_sdp =
            Self::exchange_sdp(&http_client, &ephemeral_token, model_id, &offer_sdp).await?;
        let answer = SdpAnswer::from_sdp_string(&answer_sdp)
            .map_err(|e| RealtimeError::webrtc(format!("Failed to parse SDP answer: {e}")))?;
        rtc.sdp_api()
            .accept_answer(pending, answer)
            .map_err(|e| RealtimeError::webrtc(format!("Failed to apply SDP answer: {e}")))?;
        tracing::info!(
            audio_mid = %audio_track_id,
            "OpenAI WebRTC SDP handshake complete"
        );
        // 24 kHz mono PCM16 — the fixed format used for outgoing audio in
        // this session (see `write_audio_to_track`'s 24000 divisor).
        let opus_codec = OpusCodec::new(24000, 1)?;
        // Capture the payload type and RTP clock rate negotiated in the
        // answer; both are required for every outgoing RTP write.
        let (opus_pt, clock_rate) = {
            let writer = rtc.writer(audio_track_id).ok_or_else(|| {
                RealtimeError::webrtc("Audio track writer not available after SDP answer")
            })?;
            let params = writer.payload_params().next().ok_or_else(|| {
                RealtimeError::webrtc(
                    "No payload type negotiated for audio track — SDP answer may be invalid",
                )
            })?;
            (params.pt(), params.spec().clock_rate)
        };
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        let session_id = uuid::Uuid::new_v4().to_string();
        Ok(Self {
            session_id,
            connected: Arc::new(AtomicBool::new(true)),
            rtc: Arc::new(Mutex::new(rtc)),
            audio_track_id,
            data_channel_id,
            opus_encoder: Arc::new(ParkingMutex::new(opus_codec)),
            opus_pt,
            clock_rate,
            rtp_sample_offset: AtomicU64::new(0),
            event_rx: Arc::new(Mutex::new(event_rx)),
            event_tx,
            pending_dc_messages: Arc::new(Mutex::new(Vec::new())),
            // The data channel opens asynchronously after the handshake;
            // until then messages are queued in `pending_dc_messages`.
            dc_open: Arc::new(AtomicBool::new(false)),
        })
    }

    /// POSTs to `/realtime/sessions` to mint a short-lived client secret
    /// used as the bearer token for the SDP exchange.
    ///
    /// # Errors
    /// `RealtimeError::AuthError` on transport failure, non-2xx status, or
    /// an unparseable response body.
    async fn request_ephemeral_token(
        client: &reqwest::Client,
        api_key: &str,
        model_id: &str,
    ) -> Result<String> {
        let url = format!("{}/realtime/sessions", OPENAI_API_BASE);
        let response = client
            .post(&url)
            .header("Authorization", format!("Bearer {api_key}"))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "model": model_id,
                "voice": "alloy",
            }))
            .send()
            .await
            .map_err(|e| {
                RealtimeError::AuthError(format!("Failed to request ephemeral token: {e}"))
            })?;
        let status = response.status();
        if !status.is_success() {
            // Include the body — error payloads carry the actionable detail
            // (invalid key, unknown model, ...). Best-effort read.
            let body = response.text().await.unwrap_or_default();
            return Err(RealtimeError::AuthError(format!(
                "Ephemeral token request failed with status {status}: {body}"
            )));
        }
        let token_response: EphemeralTokenResponse = response.json().await.map_err(|e| {
            RealtimeError::AuthError(format!("Failed to parse ephemeral token response: {e}"))
        })?;
        Ok(token_response.client_secret.value)
    }

    /// Sends our SDP offer to the realtime endpoint and returns the answer
    /// SDP, accepting either a raw `application/sdp` body or a JSON
    /// envelope with an `sdp` field.
    async fn exchange_sdp(
        client: &reqwest::Client,
        ephemeral_token: &str,
        model_id: &str,
        offer_sdp: &str,
    ) -> Result<String> {
        let url = format!("{}/realtime?model={}", OPENAI_API_BASE, model_id);
        let response = client
            .post(&url)
            .header("Authorization", format!("Bearer {ephemeral_token}"))
            .header("Content-Type", "application/sdp")
            .body(offer_sdp.to_string())
            .send()
            .await
            .map_err(|e| RealtimeError::connection(format!("SDP exchange request failed: {e}")))?;
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(RealtimeError::connection(format!(
                "SDP exchange failed with status {status}: {body}"
            )));
        }
        // Snapshot the content type before `text()` consumes the response.
        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_string();
        let body = response.text().await.map_err(|e| {
            RealtimeError::connection(format!("Failed to read SDP answer body: {e}"))
        })?;
        if content_type.contains("application/sdp") {
            Ok(body)
        } else {
            // Fallback: some responses wrap the answer in JSON.
            let parsed: SdpExchangeResponse = serde_json::from_str(&body).map_err(|e| {
                RealtimeError::connection(format!(
                    "Failed to parse SDP exchange response as JSON: {e}"
                ))
            })?;
            Ok(parsed.sdp)
        }
    }

    /// Mid of the negotiated bidirectional audio track.
    pub fn audio_track_id(&self) -> Mid {
        self.audio_track_id
    }

    /// Id of the "oai-events" data channel.
    pub fn data_channel_id(&self) -> ChannelId {
        self.data_channel_id
    }

    /// Shared handle to the underlying str0m `Rtc` state machine.
    pub fn rtc(&self) -> &Arc<Mutex<Rtc>> {
        &self.rtc
    }

    /// Clone of the sender half feeding server events to `receive_raw`.
    pub fn event_sender(&self) -> mpsc::UnboundedSender<Result<ServerEvent>> {
        self.event_tx.clone()
    }
}
impl std::fmt::Debug for OpenAIWebRTCSession {
    /// Manual Debug impl: prints only the identifying fields, skipping the
    /// non-Debug internals (Rtc state machine, codec, channels, queues).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("OpenAIWebRTCSession");
        dbg.field("session_id", &self.session_id);
        dbg.field("connected", &self.connected.load(Ordering::Relaxed));
        dbg.field("audio_track_id", &self.audio_track_id);
        dbg.field("data_channel_id", &self.data_channel_id);
        dbg.finish()
    }
}
use async_trait::async_trait;
use base64::Engine;
use serde_json::Value;
impl OpenAIWebRTCSession {
    /// Serializes `value` as JSON and sends it over the "oai-events" data
    /// channel, queuing the message (up to 50) while the channel is still
    /// opening.
    ///
    /// The open-flag check happens while holding the pending-queue lock.
    /// Checking `dc_open` before taking the lock (as a naive version would)
    /// races with `flush_pending_dc_messages`: the flag could flip and the
    /// queue drain between the check and the push, stranding the message in
    /// the queue forever.
    ///
    /// # Errors
    /// `RealtimeError::protocol` on serialization failure; webrtc errors
    /// when the queue is full or the channel write fails.
    async fn send_data_channel_message(&self, value: &Value) -> Result<()> {
        let json_bytes = serde_json::to_vec(value)
            .map_err(|e| RealtimeError::protocol(format!("JSON serialize error: {e}")))?;
        {
            // Read the flag under the queue lock so the enqueue decision and
            // the push are atomic with respect to the flush path (which sets
            // the flag under the same lock before draining).
            let mut pending = self.pending_dc_messages.lock().await;
            if !self.dc_open.load(Ordering::Acquire) {
                if pending.len() >= 50 {
                    return Err(RealtimeError::webrtc(
                        "Data channel message queue full (50 messages). Channel may not be opening.",
                    ));
                }
                pending.push(json_bytes);
                tracing::debug!(
                    "Data channel not open yet, queued message ({} pending)",
                    pending.len()
                );
                return Ok(());
            }
            // Channel is open: fall through and write directly. Drop the
            // queue lock first — lock order is always pending -> rtc.
        }
        let mut rtc = self.rtc.lock().await;
        let mut channel = rtc
            .channel(self.data_channel_id)
            .ok_or_else(|| RealtimeError::webrtc("Data channel 'oai-events' not available"))?;
        channel
            .write(true, json_bytes.as_slice())
            .map_err(|e| RealtimeError::webrtc(format!("Data channel write failed: {e}")))?;
        Ok(())
    }

    /// Marks the data channel as open and writes out every queued message
    /// in FIFO order. Call once the "oai-events" channel reports open.
    ///
    /// # Errors
    /// Webrtc errors if the channel is unavailable or a write fails;
    /// messages not yet drained remain queued in that case.
    pub async fn flush_pending_dc_messages(&self) -> Result<()> {
        // Take the queue lock first, then publish the open flag under it:
        // any sender that subsequently sees the flag false is already
        // holding/waiting on this lock and will enqueue before we drain,
        // or will observe the flag true and write directly.
        let mut pending = self.pending_dc_messages.lock().await;
        self.dc_open.store(true, Ordering::Release);
        if pending.is_empty() {
            return Ok(());
        }
        let count = pending.len();
        tracing::info!("Data channel opened — flushing {count} queued messages");
        let mut rtc = self.rtc.lock().await;
        let mut channel = rtc.channel(self.data_channel_id).ok_or_else(|| {
            RealtimeError::webrtc("Data channel 'oai-events' not available during flush")
        })?;
        for msg in pending.drain(..) {
            channel
                .write(true, msg.as_slice())
                .map_err(|e| RealtimeError::webrtc(format!("Data channel flush failed: {e}")))?;
        }
        Ok(())
    }

    /// Opus-encodes `pcm_samples` and writes the packet to the negotiated
    /// audio track with a monotonically advancing RTP timestamp.
    ///
    /// # Errors
    /// Opus errors on encode failure; webrtc errors when the track writer
    /// is unavailable or the RTP write fails.
    async fn write_audio_to_track(&self, pcm_samples: &[i16]) -> Result<()> {
        /// Input rate of the encoder created in `connect` (24 kHz mono);
        /// used to convert sample counts to the track's RTP clock.
        const ENCODER_SAMPLE_RATE_HZ: u64 = 24_000;
        let opus_data = {
            // parking_lot mutex: guard is dropped before any await point.
            let mut encoder = self.opus_encoder.lock();
            encoder.encode(pcm_samples)?
        };
        // Rescale the frame's sample count to the negotiated RTP clock and
        // atomically claim this frame's slot in the running timestamp.
        let clock_hz = self.clock_rate.get() as u64;
        let samples_at_clock = (pcm_samples.len() as u64) * clock_hz / ENCODER_SAMPLE_RATE_HZ;
        let rtp_offset = self.rtp_sample_offset.fetch_add(samples_at_clock, Ordering::Relaxed);
        let mut rtc = self.rtc.lock().await;
        let writer = rtc
            .writer(self.audio_track_id)
            .ok_or_else(|| RealtimeError::webrtc("Audio track writer not available"))?;
        let now = Instant::now();
        let rtp_time = MediaTime::new(rtp_offset, self.clock_rate);
        writer
            .write(self.opus_pt, now, rtp_time, opus_data)
            .map_err(|e| RealtimeError::webrtc(format!("Audio track write failed: {e}")))?;
        Ok(())
    }
}
use crate::openai::protocol::OpenAITransportLink;
#[async_trait]
impl OpenAITransportLink for OpenAIWebRTCSession {
    /// Locally generated UUID identifying this session.
    fn session_id(&self) -> &str {
        self.session_id.as_str()
    }

    /// True until `close` is called.
    fn is_connected(&self) -> bool {
        self.connected.load(Ordering::Relaxed)
    }

    /// Forwards an arbitrary JSON payload over the data channel.
    async fn send_raw(&self, payload: &Value) -> Result<()> {
        self.send_data_channel_message(payload).await
    }

    /// Encodes a PCM16 audio chunk and writes it to the audio track.
    async fn send_audio(&self, audio: &crate::audio::AudioChunk) -> Result<()> {
        if !self.is_connected() {
            return Err(RealtimeError::NotConnected);
        }
        let pcm = audio
            .to_i16_samples()
            .map_err(|e| RealtimeError::opus(format!("Invalid PCM16 audio data: {e}")))?;
        self.write_audio_to_track(&pcm).await
    }

    /// Decodes base64-encoded little-endian PCM16 bytes and writes the
    /// samples to the audio track.
    async fn send_audio_base64(&self, audio_base64: &str) -> Result<()> {
        if !self.is_connected() {
            return Err(RealtimeError::NotConnected);
        }
        let raw_bytes = base64::engine::general_purpose::STANDARD
            .decode(audio_base64)
            .map_err(|e| RealtimeError::audio(format!("Invalid base64 audio: {e}")))?;
        // PCM16 is two bytes per sample; an odd length means corrupt input.
        if raw_bytes.len() % 2 != 0 {
            return Err(RealtimeError::audio(format!(
                "Invalid PCM16 data length: {} (must be even)",
                raw_bytes.len()
            )));
        }
        let mut samples = Vec::with_capacity(raw_bytes.len() / 2);
        for pair in raw_bytes.chunks_exact(2) {
            samples.push(i16::from_le_bytes([pair[0], pair[1]]));
        }
        self.write_audio_to_track(&samples).await
    }

    /// Waits for the next server event; `None` once the sender side closes.
    async fn receive_raw(&self) -> Option<Result<ServerEvent>> {
        self.event_rx.lock().await.recv().await
    }

    /// Marks the session disconnected and tears down the WebRTC state.
    async fn close(&self) -> Result<()> {
        self.connected.store(false, Ordering::Relaxed);
        self.rtc.lock().await.disconnect();
        Ok(())
    }
}