// openai_realtime_proxy/lib.rs

1use axum::extract::ws::WebSocket;
2use futures::{SinkExt, StreamExt};
3use http::{header, HeaderValue};
4use tokio::net::TcpStream;
5use tokio_tungstenite::{
6    tungstenite::{
7        client::IntoClientRequest,
8        handshake::client::Response,
9        protocol::{frame::coding::CloseCode, CloseFrame},
10        Message,
11    },
12    MaybeTlsStream, WebSocketStream,
13};
14use url::Url;
15
/// WebSocket proxy that relays a client connection to OpenAI's realtime API.
///
/// The proxy owns the API token used to authenticate the upstream handshake,
/// so the token never has to be handed to the connecting client.
pub struct Proxy {
    /// Secret sent upstream as a `Bearer` credential.
    api_token: String,
}

/// Manual `Debug` so that logging a `Proxy` can never leak the API token.
impl std::fmt::Debug for Proxy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Proxy")
            .field("api_token", &"<redacted>")
            .finish()
    }
}
19
20impl Proxy {
21    pub fn new(api_token: String) -> Self {
22        Self { api_token }
23    }
24
25    pub async fn handle(self, socket: WebSocket) {
26        // connect to server
27        let openai_stream = match self.connect().await {
28            Ok((stream, response)) => {
29                println!("Server response was {response:?}");
30                stream
31            }
32            Err(e) => {
33                println!("WebSocket handshake failed with {e}!");
34                return;
35            }
36        };
37
38        let (mut openai_sender, mut openai_receiver) = openai_stream.split();
39
40        // ...
41        let (mut client_sender, mut client_receiver) = socket.split();
42
43        let mut openai_to_client = tokio::spawn(async move {
44            while let Some(Ok(msg)) = openai_receiver.next().await {
45                let Some(msg) = msg.into_axum() else {
46                    continue;
47                };
48
49                if let Err(e) = client_sender.send(msg).await {
50                    println!("Error sending message to client {e:?}");
51                    break;
52                }
53            }
54        });
55
56        let mut client_to_openai = tokio::spawn(async move {
57            while let Some(Ok(msg)) = client_receiver.next().await {
58                if let Err(e) = openai_sender.send(msg.into_tungstenite()).await {
59                    println!("Error sending message to openai {e:?}");
60                    break;
61                }
62            }
63        });
64
65        tokio::select! {
66            result = (&mut openai_to_client) => {
67                if let Err(error) = result {
68                    println!("Error in openai_to_client {error:?}");
69                }
70                client_to_openai.abort();
71            },
72            result = (&mut client_to_openai) => {
73                if let Err(error) = result {
74                    println!("Error in client_to_openai {error:?}");
75                }
76                openai_to_client.abort();
77            }
78        }
79    }
80
81    async fn connect(
82        &self,
83    ) -> Result<
84        (WebSocketStream<MaybeTlsStream<TcpStream>>, Response),
85        tokio_tungstenite::tungstenite::Error,
86    > {
87        let url =
88            Url::parse("wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01")
89                .unwrap();
90
91        let mut request = url.into_client_request().unwrap();
92        let headers = request.headers_mut();
93
94        headers.insert("OpenAI-Beta", HeaderValue::from_static("realtime=v1"));
95        headers.insert(
96            header::USER_AGENT,
97            HeaderValue::from_static("rust-openai-proxy"),
98        );
99        headers.insert(
100            header::AUTHORIZATION,
101            HeaderValue::from_str(&format!("Bearer {}", self.api_token)).unwrap(),
102        );
103
104        tokio_tungstenite::connect_async(request).await
105    }
106}
107
108trait TungsteniteConverter {
109    fn into_tungstenite(self) -> Message;
110}
111
112impl TungsteniteConverter for axum::extract::ws::Message {
113    fn into_tungstenite(self) -> Message {
114        match self {
115            Self::Text(text) => Message::Text(text),
116            Self::Binary(binary) => Message::Binary(binary),
117            Self::Ping(ping) => Message::Ping(ping),
118            Self::Pong(pong) => Message::Pong(pong),
119            Self::Close(Some(close)) => Message::Close(Some(CloseFrame {
120                code: CloseCode::from(close.code),
121                reason: close.reason,
122            })),
123            Self::Close(None) => Message::Close(None),
124        }
125    }
126}
127
128trait AxumConverter {
129    fn into_axum(self) -> Option<axum::extract::ws::Message>;
130}
131
132impl AxumConverter for Message {
133    fn into_axum(self) -> Option<axum::extract::ws::Message> {
134        match self {
135            Self::Text(text) => Some(axum::extract::ws::Message::Text(text)),
136            Self::Binary(binary) => Some(axum::extract::ws::Message::Binary(binary)),
137            Self::Ping(ping) => Some(axum::extract::ws::Message::Ping(ping)),
138            Self::Pong(pong) => Some(axum::extract::ws::Message::Pong(pong)),
139            Self::Close(Some(close)) => Some(axum::extract::ws::Message::Close(Some(
140                axum::extract::ws::CloseFrame {
141                    code: close.code.into(),
142                    reason: close.reason,
143                },
144            ))),
145            Self::Close(None) => Some(axum::extract::ws::Message::Close(None)),
146            Self::Frame(_) => None,
147        }
148    }
149}