Skip to main content

openai_core/websocket/
core.rs

1//! WebSocket transport and connection management internals.
2
3use futures_util::stream::BoxStream;
4use url::Url;
5
6use crate::Client;
7use crate::config::RequestOptions;
8use crate::error::{Error, Result, SerializationError, WebSocketError};
9
10#[cfg(any(feature = "realtime", feature = "responses-ws"))]
11mod enabled {
12    use std::collections::BTreeMap;
13    use std::sync::Arc;
14    use std::sync::atomic::{AtomicU8, Ordering};
15
16    use futures_util::{SinkExt, StreamExt};
17    use serde::Serialize;
18    use tokio::sync::{Mutex, broadcast};
19    use tokio_tungstenite::connect_async;
20    use tokio_tungstenite::tungstenite::Message;
21    use tokio_tungstenite::tungstenite::client::IntoClientRequest;
22    use tokio_tungstenite::tungstenite::protocol::CloseFrame;
23    use tokio_tungstenite::tungstenite::protocol::frame::Utf8Bytes;
24    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
25    use tracing::{debug, error, info, warn};
26
27    use super::{
28        BoxStream, Client, Error, RequestOptions, Result, SerializationError, Url, WebSocketError,
29    };
30    use crate::config::{LogLevel, LogRecord, LoggerHandle};
31    #[cfg(feature = "realtime")]
32    use crate::providers::ProviderKind;
33    use crate::transport::{join_url, prepare_request_context};
34    #[cfg(feature = "realtime")]
35    use crate::websocket::{RealtimeServerEvent, RealtimeStreamMessage};
36    #[cfg(feature = "responses-ws")]
37    use crate::websocket::{ResponsesServerEvent, ResponsesStreamMessage};
38    use crate::websocket::{SocketCloseOptions, SocketStreamMessage, WebSocketServerEvent};
39
40    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
41    enum ConnectionState {
42        Connecting,
43        Open,
44        Closing,
45        Closed,
46    }
47
48    impl ConnectionState {
49        fn as_u8(self) -> u8 {
50            match self {
51                Self::Connecting => 0,
52                Self::Open => 1,
53                Self::Closing => 2,
54                Self::Closed => 3,
55            }
56        }
57
58        fn from_u8(value: u8) -> Self {
59            match value {
60                0 => Self::Connecting,
61                1 => Self::Open,
62                2 => Self::Closing,
63                _ => Self::Closed,
64            }
65        }
66
67        fn into_message<T>(self) -> SocketStreamMessage<T> {
68            match self {
69                Self::Connecting => SocketStreamMessage::Connecting,
70                Self::Open => SocketStreamMessage::Open,
71                Self::Closing => SocketStreamMessage::Closing,
72                Self::Closed => SocketStreamMessage::Close,
73            }
74        }
75    }
76
77    type WsSink = futures_util::stream::SplitSink<
78        tokio_tungstenite::WebSocketStream<
79            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
80        >,
81        Message,
82    >;
83
84    struct SocketCore<T> {
85        url: Url,
86        state: AtomicU8,
87        events: broadcast::Sender<SocketStreamMessage<T>>,
88        sink: Mutex<WsSink>,
89        log_level: LogLevel,
90        logger: Option<LoggerHandle>,
91    }
92
93    impl<T> std::fmt::Debug for SocketCore<T> {
94        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95            f.debug_struct("SocketCore")
96                .field("url", &self.url)
97                .field(
98                    "state",
99                    &ConnectionState::from_u8(self.state.load(Ordering::SeqCst)),
100                )
101                .finish()
102        }
103    }
104
105    impl<T> SocketCore<T>
106    where
107        T: Clone + Send + 'static,
108    {
109        fn stream(&self) -> BoxStream<'static, SocketStreamMessage<T>> {
110            let initial =
111                ConnectionState::from_u8(self.state.load(Ordering::SeqCst)).into_message();
112            let receiver = self.events.subscribe();
113            Box::pin(futures_util::stream::unfold(
114                (Some(initial), receiver, false),
115                |(initial, mut receiver, closed)| async move {
116                    if closed {
117                        return None;
118                    }
119
120                    if let Some(message) = initial {
121                        let closed = matches!(message, SocketStreamMessage::Close);
122                        return Some((message, (None, receiver, closed)));
123                    }
124
125                    loop {
126                        match receiver.recv().await {
127                            Ok(message) => {
128                                let closed = matches!(message, SocketStreamMessage::Close);
129                                return Some((message, (None, receiver, closed)));
130                            }
131                            Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {}
132                            Err(tokio::sync::broadcast::error::RecvError::Closed) => return None,
133                        }
134                    }
135                },
136            ))
137        }
138    }
139
    /// Handle to a Realtime WebSocket connection.
    ///
    /// Clones share the same underlying connection state (the field is an `Arc`).
    #[cfg(feature = "realtime")]
    #[derive(Debug, Clone)]
    pub struct RealtimeSocket {
        inner: Arc<SocketCore<RealtimeServerEvent>>,
    }

    /// Handle to a Responses WebSocket connection.
    ///
    /// Clones share the same underlying connection state (the field is an `Arc`).
    #[cfg(feature = "responses-ws")]
    #[derive(Debug, Clone)]
    pub struct ResponsesSocket {
        inner: Arc<SocketCore<ResponsesServerEvent>>,
    }
153
154    #[cfg(feature = "realtime")]
155    impl RealtimeSocket {
156        /// 建立 Realtime WebSocket 连接。
157        pub(crate) async fn connect(
158            client: &Client,
159            model: Option<String>,
160            mut options: RequestOptions,
161        ) -> Result<Self> {
162            match client.provider().kind() {
163                ProviderKind::Azure => {
164                    if let Some(model) = model {
165                        options.insert_query("deployment", model);
166                    }
167                    let socket =
168                        connect_socket(client, "realtime.ws.connect", "/realtime", options).await?;
169                    if !socket.url.query_pairs().any(|(key, _)| key == "deployment") {
170                        return Err(Error::MissingRequiredField {
171                            field: "deployment",
172                        });
173                    }
174                    Ok(Self { inner: socket })
175                }
176                _ => {
177                    let Some(model) = model else {
178                        return Err(Error::MissingRequiredField { field: "model" });
179                    };
180                    options.insert_query("model", model);
181                    Ok(Self {
182                        inner: connect_socket(client, "realtime.ws.connect", "/realtime", options)
183                            .await?,
184                    })
185                }
186            }
187        }
188
189        /// 返回当前连接的 URL。
190        pub fn url(&self) -> &Url {
191            &self.inner.url
192        }
193
194        /// 返回一个可迭代的事件流。
195        pub fn stream(&self) -> BoxStream<'static, RealtimeStreamMessage> {
196            self.inner.stream()
197        }
198
199        /// 发送一个可序列化事件。
200        pub async fn send_json<T>(&self, event: &T) -> Result<()>
201        where
202            T: Serialize,
203        {
204            send_json(&self.inner, event).await
205        }
206
207        /// 主动关闭连接。
208        pub async fn close(&self, options: SocketCloseOptions) -> Result<()> {
209            close_socket(&self.inner, options).await
210        }
211    }
212
213    #[cfg(feature = "responses-ws")]
214    impl ResponsesSocket {
215        /// 建立 Responses WebSocket 连接。
216        pub(crate) async fn connect(client: &Client, options: RequestOptions) -> Result<Self> {
217            Ok(Self {
218                inner: connect_socket(client, "responses.ws.connect", "/responses", options)
219                    .await?,
220            })
221        }
222
223        /// 返回当前连接的 URL。
224        pub fn url(&self) -> &Url {
225            &self.inner.url
226        }
227
228        /// 返回一个可迭代的事件流。
229        pub fn stream(&self) -> BoxStream<'static, ResponsesStreamMessage> {
230            self.inner.stream()
231        }
232
233        /// 发送一个可序列化事件。
234        pub async fn send_json<T>(&self, event: &T) -> Result<()>
235        where
236            T: Serialize,
237        {
238            send_json(&self.inner, event).await
239        }
240
241        /// 主动关闭连接。
242        pub async fn close(&self, options: SocketCloseOptions) -> Result<()> {
243            close_socket(&self.inner, options).await
244        }
245    }
246
247    async fn connect_socket<T>(
248        client: &Client,
249        endpoint_id: &'static str,
250        path: &str,
251        options: RequestOptions,
252    ) -> Result<Arc<SocketCore<T>>>
253    where
254        T: serde::de::DeserializeOwned + Clone + Send + 'static,
255    {
256        let context =
257            prepare_request_context(&client.inner, endpoint_id, path.into(), None, &options)
258                .await?;
259        let url = build_websocket_url(client.base_url(), &context.path, &context.query)?;
260        emit_socket_log(
261            client.inner.options.log_level,
262            client.inner.options.logger.clone(),
263            LogLevel::Debug,
264            "openai_core::websocket",
265            "建立 WebSocket 连接",
266            BTreeMap::from([
267                ("endpoint_id".into(), endpoint_id.to_string()),
268                ("url".into(), url.to_string()),
269            ]),
270        );
271        let request = build_websocket_request(&url, &context.headers)?;
272        let (stream, _) = connect_async(request)
273            .await
274            .map_err(|error| Error::WebSocket(WebSocketError::transport(error.to_string())))?;
275
276        let (sink, mut source) = stream.split();
277        let (sender, _) = broadcast::channel(128);
278        let inner = Arc::new(SocketCore {
279            url,
280            state: AtomicU8::new(ConnectionState::Open.as_u8()),
281            events: sender,
282            sink: Mutex::new(sink),
283            log_level: client.inner.options.log_level,
284            logger: client.inner.options.logger.clone(),
285        });
286        let reader_inner = inner.clone();
287
288        tokio::spawn(async move {
289            while let Some(message) = source.next().await {
290                match message {
291                    Ok(Message::Text(text)) => {
292                        handle_server_payload::<T>(&reader_inner, text.as_bytes());
293                    }
294                    Ok(Message::Binary(bytes)) => {
295                        handle_server_payload::<T>(&reader_inner, bytes.as_ref());
296                    }
297                    Ok(Message::Close(frame)) => {
298                        handle_close_frame(&reader_inner, frame);
299                        break;
300                    }
301                    Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
302                    Ok(_) => {}
303                    Err(error) => {
304                        push_error(&reader_inner, WebSocketError::transport(error.to_string()));
305                        mark_closed(&reader_inner);
306                        break;
307                    }
308                }
309            }
310
311            if ConnectionState::from_u8(reader_inner.state.load(Ordering::SeqCst))
312                != ConnectionState::Closed
313            {
314                mark_closed(&reader_inner);
315            }
316        });
317
318        Ok(inner)
319    }
320
321    fn handle_server_payload<T>(inner: &Arc<SocketCore<T>>, payload: &[u8])
322    where
323        T: serde::de::DeserializeOwned + Clone + Send + 'static,
324    {
325        let raw = match serde_json::from_slice::<WebSocketServerEvent>(payload) {
326            Ok(raw) => raw,
327            Err(error) => {
328                let error = Error::Serialization(SerializationError::new(format!(
329                    "WebSocket 事件反序列化失败: {error}"
330                )));
331                push_error(inner, WebSocketError::protocol(error.to_string()));
332                return;
333            }
334        };
335
336        if raw.is_error() {
337            let message = raw
338                .error_message()
339                .unwrap_or_else(|| "WebSocket 收到错误事件".into());
340            emit_socket_log(
341                inner.log_level,
342                inner.logger.clone(),
343                LogLevel::Info,
344                "openai_core::websocket",
345                "收到 WebSocket 错误事件",
346                BTreeMap::from([("event_type".into(), raw.event_type.clone())]),
347            );
348            push_error(
349                inner,
350                WebSocketError::server(message, Some(raw.event_type.clone())),
351            );
352            return;
353        }
354
355        match serde_json::from_slice::<T>(payload) {
356            Ok(event) => {
357                emit_socket_log(
358                    inner.log_level,
359                    inner.logger.clone(),
360                    LogLevel::Debug,
361                    "openai_core::websocket",
362                    "收到 WebSocket 事件",
363                    BTreeMap::from([("event_type".into(), raw.event_type.clone())]),
364                );
365                let _ = inner.events.send(SocketStreamMessage::Message(event));
366            }
367            Err(error) => {
368                let error = Error::Serialization(SerializationError::new(format!(
369                    "WebSocket 事件反序列化失败: {error}"
370                )));
371                push_error(inner, WebSocketError::protocol(error.to_string()));
372            }
373        }
374    }
375
376    fn push_error<T>(inner: &Arc<SocketCore<T>>, error: WebSocketError)
377    where
378        T: Clone + Send + 'static,
379    {
380        let _ = inner.events.send(SocketStreamMessage::Error(error));
381    }
382
383    fn handle_close_frame<T>(inner: &Arc<SocketCore<T>>, frame: Option<CloseFrame>)
384    where
385        T: Clone + Send + 'static,
386    {
387        let state = ConnectionState::from_u8(inner.state.load(Ordering::SeqCst));
388        if state != ConnectionState::Closing
389            && let Some(frame) = frame.as_ref()
390            && let Some(error) = map_close_frame_error(frame)
391        {
392            push_error(inner, error);
393        }
394        mark_closed(inner);
395    }
396
397    fn map_close_frame_error(frame: &CloseFrame) -> Option<WebSocketError> {
398        if frame.code == CloseCode::Normal {
399            return None;
400        }
401
402        let code = u16::from(frame.code);
403        let reason = frame.reason.to_string();
404        let message = if reason.is_empty() {
405            format!("WebSocket 连接被关闭: code={code}")
406        } else {
407            format!("WebSocket 连接被关闭: code={code}, reason={reason}")
408        };
409        Some(WebSocketError::protocol(message))
410    }
411
412    fn mark_closed<T>(inner: &Arc<SocketCore<T>>)
413    where
414        T: Clone + Send + 'static,
415    {
416        inner
417            .state
418            .store(ConnectionState::Closed.as_u8(), Ordering::SeqCst);
419        let _ = inner.events.send(SocketStreamMessage::Close);
420    }
421
422    async fn send_json<T, U>(inner: &Arc<SocketCore<T>>, event: &U) -> Result<()>
423    where
424        T: Clone + Send + 'static,
425        U: Serialize,
426    {
427        let payload = serde_json::to_string(event)
428            .map_err(|error| Error::Serialization(SerializationError::new(error.to_string())))?;
429        emit_socket_log(
430            inner.log_level,
431            inner.logger.clone(),
432            LogLevel::Debug,
433            "openai_core::websocket",
434            "发送 WebSocket 消息",
435            BTreeMap::from([("url".into(), inner.url.to_string())]),
436        );
437        let mut sink = inner.sink.lock().await;
438        sink.send(Message::Text(payload.into()))
439            .await
440            .map_err(|error| Error::WebSocket(WebSocketError::transport(error.to_string())))
441    }
442
443    async fn close_socket<T>(inner: &Arc<SocketCore<T>>, options: SocketCloseOptions) -> Result<()>
444    where
445        T: Clone + Send + 'static,
446    {
447        inner
448            .state
449            .store(ConnectionState::Closing.as_u8(), Ordering::SeqCst);
450        let _ = inner.events.send(SocketStreamMessage::Closing);
451        emit_socket_log(
452            inner.log_level,
453            inner.logger.clone(),
454            LogLevel::Info,
455            "openai_core::websocket",
456            "关闭 WebSocket 连接",
457            BTreeMap::from([
458                ("url".into(), inner.url.to_string()),
459                ("code".into(), options.code.to_string()),
460            ]),
461        );
462
463        let mut sink = inner.sink.lock().await;
464        sink.send(Message::Close(Some(CloseFrame {
465            code: CloseCode::from(options.code),
466            reason: Utf8Bytes::from(options.reason),
467        })))
468        .await
469        .map_err(|error| Error::WebSocket(WebSocketError::transport(error.to_string())))?;
470        Ok(())
471    }
472
473    fn build_websocket_url(
474        base_url: &str,
475        path: &str,
476        query: &BTreeMap<String, String>,
477    ) -> Result<Url> {
478        let joined = join_url(base_url, path)?;
479        let mut url = Url::parse(&joined)
480            .map_err(|error| Error::InvalidConfig(format!("WebSocket URL 无效: {error}")))?;
481        match url.scheme() {
482            "http" => {
483                let _ = url.set_scheme("ws");
484            }
485            "https" => {
486                let _ = url.set_scheme("wss");
487            }
488            "ws" | "wss" => {}
489            scheme => {
490                return Err(Error::InvalidConfig(format!(
491                    "不支持的 WebSocket 基础协议: {scheme}"
492                )));
493            }
494        }
495
496        if !query.is_empty() {
497            let mut pairs = url.query_pairs_mut();
498            pairs.clear();
499            for (key, value) in query {
500                pairs.append_pair(key, value);
501            }
502        }
503        Ok(url)
504    }
505
506    fn emit_socket_log(
507        configured_level: LogLevel,
508        logger: Option<LoggerHandle>,
509        level: LogLevel,
510        target: &'static str,
511        message: impl Into<String>,
512        fields: BTreeMap<String, String>,
513    ) {
514        if !configured_level.allows(level) {
515            return;
516        }
517
518        let record = LogRecord {
519            level,
520            target,
521            message: message.into(),
522            fields,
523        };
524        if let Some(logger) = &logger {
525            logger.log(&record);
526        }
527
528        let rendered_fields = if record.fields.is_empty() {
529            String::new()
530        } else {
531            format!(
532                " {}",
533                record
534                    .fields
535                    .iter()
536                    .map(|(key, value)| format!("{key}={value}"))
537                    .collect::<Vec<_>>()
538                    .join(" ")
539            )
540        };
541        let rendered = format!("[{}] {}{}", target, record.message, rendered_fields);
542        match level {
543            LogLevel::Off => {}
544            LogLevel::Error => error!("{rendered}"),
545            LogLevel::Warn => warn!("{rendered}"),
546            LogLevel::Info => info!("{rendered}"),
547            LogLevel::Debug => debug!("{rendered}"),
548        }
549    }
550
551    fn build_websocket_request(
552        url: &Url,
553        headers: &BTreeMap<String, String>,
554    ) -> Result<http::Request<()>> {
555        let mut request = url.as_str().into_client_request().map_err(|error| {
556            Error::InvalidConfig(format!("构建 WebSocket 握手请求失败: {error}"))
557        })?;
558        for (key, value) in headers {
559            request.headers_mut().insert(
560                http::header::HeaderName::from_bytes(key.as_bytes()).map_err(|error| {
561                    Error::InvalidConfig(format!("构建 WebSocket 握手请求失败: {error}"))
562                })?,
563                http::header::HeaderValue::from_str(value).map_err(|error| {
564                    Error::InvalidConfig(format!("构建 WebSocket 握手请求失败: {error}"))
565                })?,
566            );
567        }
568        Ok(request)
569    }
570
571    #[cfg(test)]
572    mod tests {
573        use std::collections::BTreeMap;
574
575        use super::*;
576        use crate::error::WebSocketErrorKind;
577
578        #[test]
579        fn test_should_build_ws_url_from_https_base_url() {
580            let url = build_websocket_url(
581                "https://api.openai.com/v1",
582                "/realtime",
583                &BTreeMap::from([("model".into(), "gpt-4o-realtime-preview".into())]),
584            )
585            .unwrap();
586
587            assert_eq!(
588                url.as_str(),
589                "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview"
590            );
591        }
592
593        #[test]
594        fn test_should_reject_unsupported_websocket_base_scheme() {
595            let error = build_websocket_url("ftp://example.com", "/realtime", &BTreeMap::new())
596                .unwrap_err();
597
598            assert!(matches!(error, Error::InvalidConfig(_)));
599            assert!(error.to_string().contains("ftp"));
600        }
601
602        #[test]
603        fn test_should_reject_invalid_websocket_headers() {
604            let error = build_websocket_request(
605                &Url::parse("ws://example.com/realtime").unwrap(),
606                &BTreeMap::from([("x-test".into(), "bad\nvalue".into())]),
607            )
608            .unwrap_err();
609
610            assert!(matches!(error, Error::InvalidConfig(_)));
611        }
612
613        #[test]
614        fn test_should_parse_error_message_from_event() {
615            let event = WebSocketServerEvent {
616                event_type: "error".into(),
617                data: BTreeMap::from([(
618                    "error".into(),
619                    serde_json::json!({
620                        "message": "bad request"
621                    }),
622                )]),
623            };
624
625            assert_eq!(event.error_message().as_deref(), Some("bad request"));
626        }
627
628        #[test]
629        fn test_should_map_abnormal_close_frame_to_protocol_error() {
630            let error = map_close_frame_error(&CloseFrame {
631                code: CloseCode::from(1008),
632                reason: Utf8Bytes::from("quota exceeded"),
633            })
634            .unwrap();
635
636            assert_eq!(error.kind, WebSocketErrorKind::Protocol);
637            assert!(error.message.contains("1008"));
638            assert!(error.message.contains("quota exceeded"));
639        }
640
641        #[test]
642        fn test_should_ignore_normal_close_frame_for_error_mapping() {
643            let error = map_close_frame_error(&CloseFrame {
644                code: CloseCode::Normal,
645                reason: Utf8Bytes::from("OK"),
646            });
647
648            assert!(error.is_none());
649        }
650    }
651}
652
653#[cfg(not(any(feature = "realtime", feature = "responses-ws")))]
654mod enabled {
655    use futures_util::stream::{self, BoxStream};
656    use serde::Serialize;
657
658    use super::{Client, Error, RequestOptions, Result, Url};
659    use crate::websocket::{RealtimeStreamMessage, ResponsesStreamMessage, SocketCloseOptions};
660
661    /// 表示 Realtime WebSocket 连接句柄。
662    #[derive(Debug, Clone)]
663    pub struct RealtimeSocket {
664        url: Url,
665    }
666
667    /// 表示 Responses WebSocket 连接句柄。
668    #[derive(Debug, Clone)]
669    pub struct ResponsesSocket {
670        url: Url,
671    }
672
673    impl RealtimeSocket {
674        /// 建立 Realtime WebSocket 连接。
675        pub(crate) async fn connect(
676            _client: &Client,
677            _model: Option<String>,
678            _options: RequestOptions,
679        ) -> Result<Self> {
680            Err(Error::InvalidConfig(
681                "当前未启用 WebSocket 支持,请开启 `realtime` 或 `responses-ws` feature".into(),
682            ))
683        }
684
685        /// 返回当前连接的 URL。
686        pub fn url(&self) -> &Url {
687            &self.url
688        }
689
690        /// 返回一个空事件流。
691        pub fn stream(&self) -> BoxStream<'static, RealtimeStreamMessage> {
692            Box::pin(stream::empty())
693        }
694
695        /// 发送一个可序列化事件。
696        pub async fn send_json<T>(&self, _event: &T) -> Result<()>
697        where
698            T: Serialize,
699        {
700            Err(Error::InvalidConfig(
701                "当前未启用 WebSocket 支持,请开启 `realtime` 或 `responses-ws` feature".into(),
702            ))
703        }
704
705        /// 主动关闭连接。
706        pub async fn close(&self, _options: SocketCloseOptions) -> Result<()> {
707            Ok(())
708        }
709    }
710
711    impl ResponsesSocket {
712        /// 建立 Responses WebSocket 连接。
713        pub(crate) async fn connect(_client: &Client, _options: RequestOptions) -> Result<Self> {
714            Err(Error::InvalidConfig(
715                "当前未启用 WebSocket 支持,请开启 `realtime` 或 `responses-ws` feature".into(),
716            ))
717        }
718
719        /// 返回当前连接的 URL。
720        pub fn url(&self) -> &Url {
721            &self.url
722        }
723
724        /// 返回一个空事件流。
725        pub fn stream(&self) -> BoxStream<'static, ResponsesStreamMessage> {
726            Box::pin(stream::empty())
727        }
728
729        /// 发送一个可序列化事件。
730        pub async fn send_json<T>(&self, _event: &T) -> Result<()>
731        where
732            T: Serialize,
733        {
734            Err(Error::InvalidConfig(
735                "当前未启用 WebSocket 支持,请开启 `realtime` 或 `responses-ws` feature".into(),
736            ))
737        }
738
739        /// 主动关闭连接。
740        pub async fn close(&self, _options: SocketCloseOptions) -> Result<()> {
741            Ok(())
742        }
743    }
744}
745
746#[cfg(feature = "realtime")]
747pub use enabled::RealtimeSocket;
748#[cfg(feature = "responses-ws")]
749pub use enabled::ResponsesSocket;