openai_realtime/
websocket.rs

use crate::api::response::ResponseCreateEvent;
use crate::api::session::{Session, SessionUpdateEvent};
use crate::error::RealtimeError;
use crate::event::{Event, EventMessage};
use crate::websocket::config::WebsocketConfig;
use async_trait::async_trait;
use ezsockets::{Error, Utf8Bytes};
use nanoid::nanoid;
use serde::Serialize;
use serde_json::{Value, json};
use std::sync::Arc;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tokio::sync::{Mutex, oneshot};
use tracing::{debug, error, info};

pub mod config {
    use crate::ApiKeyRef;
    use crate::api::model::Model;
    use url::Url;

    #[derive(Debug, Default)]
    pub struct WebsocketConfig {
        pub model: Model,
        pub api_key_ref: ApiKeyRef,
    }

    impl WebsocketConfig {
        pub fn url(&self) -> Url {
            Url::parse(format!("wss://api.openai.com/v1/realtime?model={}", self.model).as_str())
                .expect("realtime endpoint URL is always valid")
        }
    }
}

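/// Opens a realtime websocket connection to OpenAI, waits for the websocket
/// handshake to complete, and returns the [`RealtimeSession`] handle together
/// with a receiver for the raw audio bytes streamed back by the server.
///
/// Usage sketch (not run as a doctest; `my_session_update` stands in for a
/// [`SessionUpdateEvent`] you build yourself, and a valid API key is assumed to
/// be available through the default `ApiKeyRef`):
///
/// ```ignore
/// let (session, mut rx_audio) = connect(WebsocketConfig::default()).await?;
/// session.session_update(my_session_update)?;
/// while let Some(chunk) = rx_audio.recv().await {
///     // feed `chunk` to an audio sink
/// }
/// ```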
pub async fn connect(
    config: WebsocketConfig,
) -> Result<(Arc<RealtimeSession>, UnboundedReceiver<Vec<u8>>), RealtimeError> {
    let ws_config = ezsockets::ClientConfig::new(config.url())
        .bearer(config.api_key_ref.api_key())
        .header("openai-beta", "realtime=v1");

    let (tx_events, mut rx_events) = unbounded_channel();
    let (tx_connected, rx_connected) = oneshot::channel();

    let session_id = nanoid!(6);

    let (handle, _) = ezsockets::connect(
        |handle| WebsocketHandle {
            _handle: handle,
            session_id: session_id.clone(),
            tx_events,
            connected: Some(tx_connected),
        },
        ws_config,
    )
    .await;

    // wait until the websocket handshake has completed
    rx_connected
        .await
        .expect("connected signal dropped before the websocket handshake completed");

    info!("connected");

    // create a new realtime session backed by this websocket
    let (realtime_session, rx_audio) = RealtimeSession::new(session_id, Arc::new(handle));

    // forward incoming server events to the session
    let realtime_session_for_events = realtime_session.clone();
    tokio::spawn(async move {
        while let Some(evt) = rx_events.recv().await {
            realtime_session_for_events.handle_event(evt).await;
        }
    });

    Ok((realtime_session, rx_audio))
}

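/// `ezsockets` client handler: parses the JSON events coming from the server,
/// converts the ones we care about into [`Event`] values, and forwards them on
/// `tx_events`. The `connected` oneshot is fired once the websocket is up so
/// that [`connect`] only returns after the handshake has completed.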
pub struct WebsocketHandle {
    _handle: ezsockets::Client<Self>,
    session_id: String,
    tx_events: UnboundedSender<Event>,
    connected: Option<oneshot::Sender<()>>,
}

#[async_trait]
impl ezsockets::ClientExt for WebsocketHandle {
    type Call = ();

    async fn on_text(&mut self, text: Utf8Bytes) -> Result<(), ezsockets::Error> {
        let j: Value = serde_json::from_str(text.as_str()).unwrap();

        let m = j.as_object().unwrap();
        let event_type = m.get("type").unwrap().as_str().unwrap();

        // audio deltas are too frequent to pretty-print
        if event_type != "response.audio.delta" {
            debug!(
                "openai: received event: {event_type}\n{}",
                serde_json::to_string_pretty(&j).unwrap()
            );
        }

        debug!("session({})> event: {}", self.session_id, event_type);

        match event_type {
            "session.created" => {
                self.tx_events
                    .send(Event::SessionCreated(
                        serde_json::from_value(m.get("session").unwrap().clone()).unwrap(),
                    ))
                    .unwrap();
            }
            "response.audio.delta" => {
                let decoded = base64::decode(m.get("delta").unwrap().as_str().unwrap()).unwrap();
                self.tx_events.send(Event::Audio(decoded)).unwrap();
            }
            "response.audio_transcript.delta" => {
                self.tx_events
                    .send(Event::TranscriptDelta(
                        serde_json::from_value(m.get("delta").unwrap().clone()).unwrap(),
                    ))
                    .unwrap();
            }
            "response.audio_transcript.done" => {
                self.tx_events
                    .send(Event::TranscriptDone(
                        serde_json::from_value(m.get("transcript").unwrap().clone()).unwrap(),
                    ))
                    .unwrap();
            }
            "input_audio_buffer.speech_started" => {
                self.tx_events
                    .send(Event::InputAudioBufferSpeechStarted)
                    .unwrap();
            }
            "response.audio.done" => {
                debug!("response.audio.done {:?}", m);

                // TODO: when done, we should generate a bit of silence at the end

                self.tx_events.send(Event::AudioDone).unwrap();

                // TODO: figure out how much silence we actually need
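                // Note: 96,000 zero bytes is roughly 2 seconds of silence assuming the
                // API's default output format of 16-bit mono PCM at 24 kHz.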
                let silence: Vec<u8> = vec![0; 48_000 * 2];
                self.tx_events.send(Event::Audio(silence)).unwrap();
            }
            // TODO: response.done
            _ => debug!(
                "Unhandled event:\n{}",
                serde_json::to_string_pretty(&j).unwrap()
            ),
        }

        Ok(())
    }

    async fn on_binary(&mut self, _bytes: ezsockets::Bytes) -> Result<(), ezsockets::Error> {
        unimplemented!()
    }

    async fn on_call(&mut self, _call: Self::Call) -> Result<(), ezsockets::Error> {
        Ok(())
    }

    async fn on_connect(&mut self) -> Result<(), Error> {
        if let Some(connected) = self.connected.take() {
            connected.send(()).unwrap();
        }
        Ok(())
    }
}

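/// Handle for an active realtime session. Client events (`session.update`,
/// `response.create`, `input_audio_buffer.append`) are serialized and queued on
/// the outgoing websocket channel; audio received from the server is delivered
/// through the receiver returned by [`connect`].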
pub struct RealtimeSession {
    id: String,
    session: Mutex<Option<Session>>,
    tx_audio: UnboundedSender<Vec<u8>>,
    tx_msg_out: UnboundedSender<Utf8Bytes>,
}

impl RealtimeSession {
    pub fn new(
        id: String,
        ws: Arc<ezsockets::Client<WebsocketHandle>>,
    ) -> (Arc<Self>, UnboundedReceiver<Vec<u8>>) {
        let (tx_audio_out, rx_audio_out) = unbounded_channel();

        let (tx_msg_out, mut rx_msg_out) = unbounded_channel::<Utf8Bytes>();

        // writer task: drain the outgoing message channel into the websocket
        let ws_2 = ws.clone();
        tokio::spawn(async move {
            while let Some(data) = rx_msg_out.recv().await {
                if let Err(e) = ws_2.text(data) {
                    error!("error sending: {}", e);
                }
            }
            // all senders dropped (the session was dropped): stop the writer task
            debug!("outgoing message channel closed");
        });

        let session = Arc::new(Self {
            id,
            session: Mutex::new(None),
            tx_audio: tx_audio_out,
            tx_msg_out,
        });

        // TODO: send from websocket to tx_audio

        (session, rx_audio_out)
    }

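    /// Wraps `body` in an [`EventMessage`] envelope for `evt` and queues the
    /// serialized JSON on the outgoing websocket channel.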
    fn send(&self, evt: &str, body: impl Serialize) -> anyhow::Result<()> {
        let body_str = serde_json::to_string_pretty(&EventMessage::wrap(evt, body))?;
        if evt != "input_audio_buffer.append" {
            debug!("session({})> send: {} {}", self.id, evt, body_str);
        }
        self.tx_msg_out.send(Utf8Bytes::from(body_str))?;
        Ok(())
    }

    /// Updates the session.
    /// See: <https://platform.openai.com/docs/api-reference/realtime-client-events/session/update>
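    ///
    /// Usage sketch (not run as a doctest; the `instructions` field shown here is
    /// only an assumed example of what [`SessionUpdateEvent`] may expose):
    ///
    /// ```ignore
    /// session.session_update(SessionUpdateEvent {
    ///     instructions: Some("You are a helpful voice assistant.".into()),
    ///     ..Default::default()
    /// })?;
    /// ```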
    pub fn session_update(&self, session: SessionUpdateEvent) -> anyhow::Result<()> {
        self.send(
            "session.update",
            json!({
                "session": session
            }),
        )
    }

    /// Instructs the server to create a Response, i.e. trigger model inference.
    /// When in Server VAD mode, the server creates Responses automatically.
    /// See: <https://platform.openai.com/docs/api-reference/realtime-client-events/response/create>
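    ///
    /// Usage sketch (not run as a doctest; assumes [`ResponseCreateEvent`] can be
    /// built with `Default`):
    ///
    /// ```ignore
    /// session.response_create(ResponseCreateEvent::default())?;
    /// ```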
    pub fn response_create(&self, response: ResponseCreateEvent) -> anyhow::Result<()> {
        self.send(
            "response.create",
            json!({
                "response": response
            }),
        )
    }

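    /// Appends raw audio bytes to the server-side input audio buffer
    /// (`input_audio_buffer.append`). The buffer is base64-encoded before it is
    /// sent; it is expected to already be in the session's configured input
    /// audio format.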
    pub fn audio_append(&self, buffer: Vec<u8>) -> anyhow::Result<()> {
        debug!("session({})> audio --> {} bytes", self.id, buffer.len());
        self.send(
            "input_audio_buffer.append",
            json!({
                "audio": base64::encode(buffer)
            }),
        )
    }

    async fn handle_event(&self, evt: Event) {
        // debug logging
        match &evt {
            Event::Audio(audio) => {
                debug!("session({})> audio <-- {} bytes", self.id, audio.len());
            }
            _ => debug!("{:?}", evt),
        }

        match evt {
            Event::Audio(audio) => {
                if let Err(e) = self.tx_audio.send(audio) {
                    error!("error handling audio event: {}", e);
                }
            }
            Event::SessionCreated(session) => {
                info!("Session created: {}", session.id);
                self.session.lock().await.replace(session);
            }
            Event::TranscriptDone(transcript) => {
                info!("transcript done: {transcript}");
            }
            Event::InputAudioBufferSpeechStarted => {
                // TODO: cancel any in-flight response ("response.cancel") when the user starts speaking
            }
            _ => {}
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::WebsocketConfig;
    use crate::websocket::connect;

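    // NOTE: exercises the live realtime endpoint; it needs network access and a
    // valid API key available through the default `ApiKeyRef`.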
    #[tokio::test]
    async fn it_works() {
        let _client = connect(WebsocketConfig::default()).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
    }
}