1use crate::api::response::ResponseCreateEvent;
2use crate::api::session::{Session, SessionUpdateEvent};
3use crate::error::RealtimeError;
4use crate::event::{Event, EventMessage};
5use crate::websocket::config::WebsocketConfig;
6use async_trait::async_trait;
7use ezsockets::{Error, Utf8Bytes};
8use nanoid::nanoid;
9use serde::Serialize;
10use serde_json::{Value, json};
11use std::sync::Arc;
12use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
13use tokio::sync::{Mutex, oneshot};
14use tracing::{debug, error, info};
15
16pub mod config {
17 use crate::ApiKeyRef;
18 use crate::api::model::Model;
19 use url::Url;
20
21 #[derive(Debug)]
22 pub struct WebsocketConfig {
23 pub model: Model,
24 pub api_key_ref: ApiKeyRef,
25 }
26
27 impl Default for WebsocketConfig {
28 fn default() -> Self {
29 Self {
30 model: Model::default(),
31 api_key_ref: ApiKeyRef::default(),
32 }
33 }
34 }
35
36 impl WebsocketConfig {
37 pub fn url(&self) -> Url {
38 Url::parse(format!("wss://api.openai.com/v1/realtime?model={}", self.model).as_str())
39 .unwrap()
40 }
41 }
42}
43
44pub async fn connect(
45 config: WebsocketConfig,
46) -> Result<(Arc<RealtimeSession>, UnboundedReceiver<Vec<u8>>), RealtimeError> {
47 let ws_config = ezsockets::ClientConfig::new(config.url())
48 .bearer(config.api_key_ref.api_key())
49 .header("openai-beta", "realtime=v1");
50
51 let (tx_events, mut rx_events) = unbounded_channel();
52 let (tx_connected, rx_connected) = oneshot::channel();
53
54 let session_id = nanoid!(6);
55
56 let (handle, _) = ezsockets::connect(
57 |handle| WebsocketHandle {
58 _handle: handle,
59 session_id: session_id.clone(),
60 tx_events,
61 connected: Some(tx_connected),
62 },
63 ws_config,
64 )
65 .await;
66
67 rx_connected.await.unwrap();
68
69 info!("connected");
70
71 let (realtime_session, rx_audio) = RealtimeSession::new(session_id, Arc::new(handle));
73
74 let realtime_session_for_events = realtime_session.clone();
76 tokio::spawn(async move {
77 while let Some(evt) = rx_events.recv().await {
78 realtime_session_for_events.handle_event(evt).await;
79 }
80 });
81
82 Ok((realtime_session, rx_audio))
83}
84
85pub struct WebsocketHandle {
86 _handle: ezsockets::Client<Self>,
87 session_id: String,
88 tx_events: UnboundedSender<Event>,
89 connected: Option<oneshot::Sender<()>>,
90}
91
92#[async_trait]
93impl ezsockets::ClientExt for WebsocketHandle {
94 type Call = ();
95
96 async fn on_text(&mut self, text: Utf8Bytes) -> Result<(), ezsockets::Error> {
97 let j: Value = serde_json::from_str(text.as_str()).unwrap();
98
99 let m = j.as_object().unwrap();
100 let event_type = m.get("type").unwrap().as_str().unwrap();
101
102 if event_type.to_string() != "response.audio.delta" {
103 debug!(
104 "openai: received event: {event_type}\n{}",
105 serde_json::to_string_pretty(&j.clone()).unwrap()
106 );
107 }
108
109 debug!("session({})> event: {}", self.session_id, event_type);
110
111 match event_type {
112 "session.created" => {
113 self.tx_events
114 .send(Event::SessionCreated(
115 serde_json::from_value(m.get("session").unwrap().clone()).unwrap(),
116 ))
117 .unwrap();
118 }
119 "response.audio.delta" => {
120 let decoded = base64::decode(m.get("delta").unwrap().as_str().unwrap()).unwrap();
121 self.tx_events.send(Event::Audio(decoded)).unwrap();
122 }
123 "response.audio_transcript.delta" => {
124 self.tx_events
125 .send(Event::TranscriptDelta(
126 serde_json::from_value(m.get("delta").unwrap().clone()).unwrap(),
127 ))
128 .unwrap();
129 }
130 "response.audio_transcript.done" => {
131 self.tx_events
132 .send(Event::TranscriptDone(
133 serde_json::from_value(m.get("transcript").unwrap().clone()).unwrap(),
134 ))
135 .unwrap();
136 }
137 "input_audio_buffer.speech_started" => {
138 self.tx_events
139 .send(Event::InputAudioBufferSpeechStarted)
140 .unwrap();
141 }
142 "response.audio.done" => {
143 println!("response.audio.done {:?}", m);
144
145 self.tx_events.send(Event::AudioDone).unwrap();
148
149 let silence: Vec<u8> = vec![0; 48_000 * 2];
151 self.tx_events.send(Event::Audio(silence)).unwrap();
152 }
153 _ => debug!(
157 "Unhandled event:\n{}",
158 serde_json::to_string_pretty(&j.clone()).unwrap()
159 ),
160 }
161
162 Ok(())
164 }
165
166 async fn on_binary(&mut self, _bytes: ezsockets::Bytes) -> Result<(), ezsockets::Error> {
167 unimplemented!()
168 }
169
170 async fn on_call(&mut self, call: Self::Call) -> Result<(), ezsockets::Error> {
171 Ok(())
172 }
173
174 async fn on_connect(&mut self) -> Result<(), Error> {
175 if let Some(connected) = self.connected.take() {
176 connected.send(()).unwrap();
177 }
178 Ok(())
179 }
180}
181
182pub struct RealtimeSession {
183 id: String,
184 session: Mutex<Option<Session>>,
185 tx_audio: UnboundedSender<Vec<u8>>,
186 tx_msg_out: UnboundedSender<Utf8Bytes>,
187}
188
189impl RealtimeSession {
190 pub fn new(
191 id: String,
192 ws: Arc<ezsockets::Client<WebsocketHandle>>,
193 ) -> (Arc<Self>, UnboundedReceiver<Vec<u8>>) {
194 let (tx_audio_out, rx_audio_out) = unbounded_channel();
195
196 let (tx_msg_out, mut rx_msg_out) = unbounded_channel::<Utf8Bytes>();
197
198 let ws_2 = ws.clone();
199 tokio::spawn(async move {
200 while let Some(data) = rx_msg_out.recv().await {
201 match ws_2.text(data) {
202 Ok(_) => {}
203 Err(e) => {
204 error!("error sending: {}", e);
205 }
206 }
207 }
208 panic!("websocket closed");
209 });
210
211 let session = Arc::new(Self {
212 id,
213 session: Mutex::new(None),
214 tx_audio: tx_audio_out,
215 tx_msg_out: tx_msg_out.clone(),
216 });
217
218 (session, rx_audio_out)
221 }
222
223 fn send(&self, evt: &str, body: impl Serialize) -> anyhow::Result<()> {
224 let body_str = serde_json::to_string_pretty(&EventMessage::wrap(evt, body))?;
225 if evt != "input_audio_buffer.append" {
226 debug!("session({})> send: {} {}", self.id, evt, body_str);
227 }
228 self.tx_msg_out.send(Utf8Bytes::from(body_str))?;
229 Ok(())
230 }
231
232 pub fn session_update(&self, session: SessionUpdateEvent) -> anyhow::Result<()> {
235 self.send(
236 "session.update",
237 json!({
238 "session": session
239 }),
240 )
241 }
242
243 pub fn response_create(&self, response: ResponseCreateEvent) -> anyhow::Result<()> {
246 self.send(
247 "response.create",
248 json!({
249 "response": response
250 }),
251 )
252 }
253
254 pub fn audio_append(&self, buffer: Vec<u8>) -> anyhow::Result<()> {
255 debug!("session({})> audio --> {} bytes", self.id, buffer.len());
256 self.send(
257 "input_audio_buffer.append",
258 json!({
259 "audio": base64::encode(buffer)
260 }),
261 )
262 }
263
264 async fn handle_event(&self, evt: Event) {
265 match evt.clone() {
267 Event::Audio(audio) => {
268 debug!("session({})> audio <-- {} bytes", self.id, audio.len());
269 }
270 _ => debug!("{:?}", evt),
271 }
272
273 match evt {
274 Event::Audio(audio) => match self.tx_audio.send(audio) {
275 Ok(_) => {}
276 Err(e) => {
277 error!("error handling audio event: {}", e);
278 }
279 },
280 Event::SessionCreated(session) => {
281 info!("Session created: {}", session.id);
282 {
283 self.session.lock().await.replace(session);
284 }
285 }
286 Event::TranscriptDone(transcript) => {
287 info!("transcript done: {transcript}");
288 }
289 Event::InputAudioBufferSpeechStarted => {
290 }
294 _ => {}
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use crate::WebsocketConfig;
    use crate::websocket::connect;

    /// Smoke test against the live OpenAI realtime endpoint.
    ///
    /// It needs network access and a usable API key from `ApiKeyRef::default()`, so it
    /// only runs on request: `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access and an OpenAI API key"]
    async fn it_works() {
        // Bind (rather than discard) the session and audio receiver so they stay alive
        // while the connection runs.
        let (_session, _rx_audio) = connect(WebsocketConfig::default()).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
    }
}
309}