1use futures_util::stream::BoxStream;
4use url::Url;
5
6use crate::Client;
7use crate::config::RequestOptions;
8use crate::error::{Error, Result, SerializationError, WebSocketError};
9
10#[cfg(any(feature = "realtime", feature = "responses-ws"))]
11mod enabled {
12 use std::collections::BTreeMap;
13 use std::sync::Arc;
14 use std::sync::atomic::{AtomicU8, Ordering};
15
16 use futures_util::{SinkExt, StreamExt};
17 use serde::Serialize;
18 use tokio::sync::{Mutex, broadcast};
19 use tokio_tungstenite::connect_async;
20 use tokio_tungstenite::tungstenite::Message;
21 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
22 use tokio_tungstenite::tungstenite::protocol::CloseFrame;
23 use tokio_tungstenite::tungstenite::protocol::frame::Utf8Bytes;
24 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
25 use tracing::{debug, error, info, warn};
26
27 use super::{
28 BoxStream, Client, Error, RequestOptions, Result, SerializationError, Url, WebSocketError,
29 };
30 use crate::config::{LogLevel, LogRecord, LoggerHandle};
31 #[cfg(feature = "realtime")]
32 use crate::providers::ProviderKind;
33 use crate::transport::{join_url, prepare_request_context};
34 #[cfg(feature = "realtime")]
35 use crate::websocket::{RealtimeServerEvent, RealtimeStreamMessage};
36 #[cfg(feature = "responses-ws")]
37 use crate::websocket::{ResponsesServerEvent, ResponsesStreamMessage};
38 use crate::websocket::{SocketCloseOptions, SocketStreamMessage, WebSocketServerEvent};
39
40 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
41 enum ConnectionState {
42 Connecting,
43 Open,
44 Closing,
45 Closed,
46 }
47
48 impl ConnectionState {
49 fn as_u8(self) -> u8 {
50 match self {
51 Self::Connecting => 0,
52 Self::Open => 1,
53 Self::Closing => 2,
54 Self::Closed => 3,
55 }
56 }
57
58 fn from_u8(value: u8) -> Self {
59 match value {
60 0 => Self::Connecting,
61 1 => Self::Open,
62 2 => Self::Closing,
63 _ => Self::Closed,
64 }
65 }
66
67 fn into_message<T>(self) -> SocketStreamMessage<T> {
68 match self {
69 Self::Connecting => SocketStreamMessage::Connecting,
70 Self::Open => SocketStreamMessage::Open,
71 Self::Closing => SocketStreamMessage::Closing,
72 Self::Closed => SocketStreamMessage::Close,
73 }
74 }
75 }
76
77 type WsSink = futures_util::stream::SplitSink<
78 tokio_tungstenite::WebSocketStream<
79 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
80 >,
81 Message,
82 >;
83
84 struct SocketCore<T> {
85 url: Url,
86 state: AtomicU8,
87 events: broadcast::Sender<SocketStreamMessage<T>>,
88 sink: Mutex<WsSink>,
89 log_level: LogLevel,
90 logger: Option<LoggerHandle>,
91 }
92
93 impl<T> std::fmt::Debug for SocketCore<T> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("SocketCore")
96 .field("url", &self.url)
97 .field(
98 "state",
99 &ConnectionState::from_u8(self.state.load(Ordering::SeqCst)),
100 )
101 .finish()
102 }
103 }
104
105 impl<T> SocketCore<T>
106 where
107 T: Clone + Send + 'static,
108 {
109 fn stream(&self) -> BoxStream<'static, SocketStreamMessage<T>> {
110 let initial =
111 ConnectionState::from_u8(self.state.load(Ordering::SeqCst)).into_message();
112 let receiver = self.events.subscribe();
113 Box::pin(futures_util::stream::unfold(
114 (Some(initial), receiver, false),
115 |(initial, mut receiver, closed)| async move {
116 if closed {
117 return None;
118 }
119
120 if let Some(message) = initial {
121 let closed = matches!(message, SocketStreamMessage::Close);
122 return Some((message, (None, receiver, closed)));
123 }
124
125 loop {
126 match receiver.recv().await {
127 Ok(message) => {
128 let closed = matches!(message, SocketStreamMessage::Close);
129 return Some((message, (None, receiver, closed)));
130 }
131 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {}
132 Err(tokio::sync::broadcast::error::RecvError::Closed) => return None,
133 }
134 }
135 },
136 ))
137 }
138 }
139
#[cfg(feature = "realtime")]
/// Handle to a realtime WebSocket connection.
///
/// Cloning the handle is cheap and shares the same underlying connection, because the
/// connection state lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct RealtimeSocket {
    // Shared connection state; the background reader task holds another reference to it.
    inner: Arc<SocketCore<RealtimeServerEvent>>,
}
146
#[cfg(feature = "responses-ws")]
/// Handle to a Responses-API WebSocket connection.
///
/// Cloning the handle is cheap and shares the same underlying connection, because the
/// connection state lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct ResponsesSocket {
    // Shared connection state; the background reader task holds another reference to it.
    inner: Arc<SocketCore<ResponsesServerEvent>>,
}
153
#[cfg(feature = "realtime")]
impl RealtimeSocket {
    /// Opens a realtime WebSocket connection for the client's provider.
    ///
    /// * Azure: `model`, when given, is sent as the `deployment` query parameter. The final URL
    ///   must contain a `deployment` parameter, either from `model` or added while the request
    ///   was prepared. Otherwise the freshly opened connection is closed again and
    ///   `Error::MissingRequiredField { field: "deployment" }` is returned.
    /// * Other providers: `model` is required and is sent as the `model` query parameter.
    ///
    /// # Errors
    /// Returns `Error::MissingRequiredField` as described above, or any error raised while
    /// preparing the request or performing the handshake.
    pub(crate) async fn connect(
        client: &Client,
        model: Option<String>,
        mut options: RequestOptions,
    ) -> Result<Self> {
        match client.provider().kind() {
            ProviderKind::Azure => {
                if let Some(model) = model {
                    options.insert_query("deployment", model);
                }
                let socket =
                    connect_socket(client, "realtime.ws.connect", "/realtime", options).await?;
                if !socket.url.query_pairs().any(|(key, _)| key == "deployment") {
                    // The handshake has already completed, and the spawned reader task keeps the
                    // connection alive on its own. Just dropping `socket` would leak an open
                    // connection, so start the close handshake first. Failures are ignored:
                    // the missing deployment is the error worth reporting.
                    let _ = socket.sink.lock().await.close().await;
                    return Err(Error::MissingRequiredField {
                        field: "deployment",
                    });
                }
                Ok(Self { inner: socket })
            }
            _ => {
                let Some(model) = model else {
                    return Err(Error::MissingRequiredField { field: "model" });
                };
                options.insert_query("model", model);
                Ok(Self {
                    inner: connect_socket(client, "realtime.ws.connect", "/realtime", options)
                        .await?,
                })
            }
        }
    }

    /// The `ws://` / `wss://` URL this socket connected to.
    pub fn url(&self) -> &Url {
        &self.inner.url
    }

    /// Stream of connection-state changes, server events and errors. It starts with the
    /// current connection state.
    pub fn stream(&self) -> BoxStream<'static, RealtimeStreamMessage> {
        self.inner.stream()
    }

    /// Serializes `event` as JSON and sends it as a text frame.
    ///
    /// # Errors
    /// `Error::Serialization` if `event` cannot be serialized; `Error::WebSocket` if sending fails.
    pub async fn send_json<T>(&self, event: &T) -> Result<()>
    where
        T: Serialize,
    {
        send_json(&self.inner, event).await
    }

    /// Initiates closing the connection with the code and reason from `options`.
    ///
    /// # Errors
    /// `Error::WebSocket` if the close frame cannot be sent.
    pub async fn close(&self, options: SocketCloseOptions) -> Result<()> {
        close_socket(&self.inner, options).await
    }
}
212
#[cfg(feature = "responses-ws")]
impl ResponsesSocket {
    /// Opens a WebSocket connection to the `/responses` endpoint.
    ///
    /// # Errors
    /// Propagates any error raised while preparing the request or performing the handshake.
    pub(crate) async fn connect(client: &Client, options: RequestOptions) -> Result<Self> {
        let inner = connect_socket(client, "responses.ws.connect", "/responses", options).await?;
        Ok(Self { inner })
    }

    /// The `ws://` / `wss://` URL this socket connected to.
    pub fn url(&self) -> &Url {
        let core = &self.inner;
        &core.url
    }

    /// Stream of connection-state changes, server events and errors. It starts with the
    /// current connection state.
    pub fn stream(&self) -> BoxStream<'static, ResponsesStreamMessage> {
        let core = &self.inner;
        core.stream()
    }

    /// Serializes `event` as JSON and sends it as a text frame.
    ///
    /// # Errors
    /// `Error::Serialization` if `event` cannot be serialized; `Error::WebSocket` if sending fails.
    pub async fn send_json<T>(&self, event: &T) -> Result<()>
    where
        T: Serialize,
    {
        let core = &self.inner;
        send_json(core, event).await
    }

    /// Initiates closing the connection with the code and reason from `options`.
    ///
    /// # Errors
    /// `Error::WebSocket` if the close frame cannot be sent.
    pub async fn close(&self, options: SocketCloseOptions) -> Result<()> {
        let core = &self.inner;
        close_socket(core, options).await
    }
}
246
247 async fn connect_socket<T>(
248 client: &Client,
249 endpoint_id: &'static str,
250 path: &str,
251 options: RequestOptions,
252 ) -> Result<Arc<SocketCore<T>>>
253 where
254 T: serde::de::DeserializeOwned + Clone + Send + 'static,
255 {
256 let context =
257 prepare_request_context(&client.inner, endpoint_id, path.into(), None, &options)
258 .await?;
259 let url = build_websocket_url(client.base_url(), &context.path, &context.query)?;
260 emit_socket_log(
261 client.inner.options.log_level,
262 client.inner.options.logger.clone(),
263 LogLevel::Debug,
264 "openai_core::websocket",
265 "建立 WebSocket 连接",
266 BTreeMap::from([
267 ("endpoint_id".into(), endpoint_id.to_string()),
268 ("url".into(), url.to_string()),
269 ]),
270 );
271 let request = build_websocket_request(&url, &context.headers)?;
272 let (stream, _) = connect_async(request)
273 .await
274 .map_err(|error| Error::WebSocket(WebSocketError::transport(error.to_string())))?;
275
276 let (sink, mut source) = stream.split();
277 let (sender, _) = broadcast::channel(128);
278 let inner = Arc::new(SocketCore {
279 url,
280 state: AtomicU8::new(ConnectionState::Open.as_u8()),
281 events: sender,
282 sink: Mutex::new(sink),
283 log_level: client.inner.options.log_level,
284 logger: client.inner.options.logger.clone(),
285 });
286 let reader_inner = inner.clone();
287
288 tokio::spawn(async move {
289 while let Some(message) = source.next().await {
290 match message {
291 Ok(Message::Text(text)) => {
292 handle_server_payload::<T>(&reader_inner, text.as_bytes());
293 }
294 Ok(Message::Binary(bytes)) => {
295 handle_server_payload::<T>(&reader_inner, bytes.as_ref());
296 }
297 Ok(Message::Close(frame)) => {
298 handle_close_frame(&reader_inner, frame);
299 break;
300 }
301 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
302 Ok(_) => {}
303 Err(error) => {
304 push_error(&reader_inner, WebSocketError::transport(error.to_string()));
305 mark_closed(&reader_inner);
306 break;
307 }
308 }
309 }
310
311 if ConnectionState::from_u8(reader_inner.state.load(Ordering::SeqCst))
312 != ConnectionState::Closed
313 {
314 mark_closed(&reader_inner);
315 }
316 });
317
318 Ok(inner)
319 }
320
321 fn handle_server_payload<T>(inner: &Arc<SocketCore<T>>, payload: &[u8])
322 where
323 T: serde::de::DeserializeOwned + Clone + Send + 'static,
324 {
325 let raw = match serde_json::from_slice::<WebSocketServerEvent>(payload) {
326 Ok(raw) => raw,
327 Err(error) => {
328 let error = Error::Serialization(SerializationError::new(format!(
329 "WebSocket 事件反序列化失败: {error}"
330 )));
331 push_error(inner, WebSocketError::protocol(error.to_string()));
332 return;
333 }
334 };
335
336 if raw.is_error() {
337 let message = raw
338 .error_message()
339 .unwrap_or_else(|| "WebSocket 收到错误事件".into());
340 emit_socket_log(
341 inner.log_level,
342 inner.logger.clone(),
343 LogLevel::Info,
344 "openai_core::websocket",
345 "收到 WebSocket 错误事件",
346 BTreeMap::from([("event_type".into(), raw.event_type.clone())]),
347 );
348 push_error(
349 inner,
350 WebSocketError::server(message, Some(raw.event_type.clone())),
351 );
352 return;
353 }
354
355 match serde_json::from_slice::<T>(payload) {
356 Ok(event) => {
357 emit_socket_log(
358 inner.log_level,
359 inner.logger.clone(),
360 LogLevel::Debug,
361 "openai_core::websocket",
362 "收到 WebSocket 事件",
363 BTreeMap::from([("event_type".into(), raw.event_type.clone())]),
364 );
365 let _ = inner.events.send(SocketStreamMessage::Message(event));
366 }
367 Err(error) => {
368 let error = Error::Serialization(SerializationError::new(format!(
369 "WebSocket 事件反序列化失败: {error}"
370 )));
371 push_error(inner, WebSocketError::protocol(error.to_string()));
372 }
373 }
374 }
375
376 fn push_error<T>(inner: &Arc<SocketCore<T>>, error: WebSocketError)
377 where
378 T: Clone + Send + 'static,
379 {
380 let _ = inner.events.send(SocketStreamMessage::Error(error));
381 }
382
383 fn handle_close_frame<T>(inner: &Arc<SocketCore<T>>, frame: Option<CloseFrame>)
384 where
385 T: Clone + Send + 'static,
386 {
387 let state = ConnectionState::from_u8(inner.state.load(Ordering::SeqCst));
388 if state != ConnectionState::Closing
389 && let Some(frame) = frame.as_ref()
390 && let Some(error) = map_close_frame_error(frame)
391 {
392 push_error(inner, error);
393 }
394 mark_closed(inner);
395 }
396
397 fn map_close_frame_error(frame: &CloseFrame) -> Option<WebSocketError> {
398 if frame.code == CloseCode::Normal {
399 return None;
400 }
401
402 let code = u16::from(frame.code);
403 let reason = frame.reason.to_string();
404 let message = if reason.is_empty() {
405 format!("WebSocket 连接被关闭: code={code}")
406 } else {
407 format!("WebSocket 连接被关闭: code={code}, reason={reason}")
408 };
409 Some(WebSocketError::protocol(message))
410 }
411
412 fn mark_closed<T>(inner: &Arc<SocketCore<T>>)
413 where
414 T: Clone + Send + 'static,
415 {
416 inner
417 .state
418 .store(ConnectionState::Closed.as_u8(), Ordering::SeqCst);
419 let _ = inner.events.send(SocketStreamMessage::Close);
420 }
421
422 async fn send_json<T, U>(inner: &Arc<SocketCore<T>>, event: &U) -> Result<()>
423 where
424 T: Clone + Send + 'static,
425 U: Serialize,
426 {
427 let payload = serde_json::to_string(event)
428 .map_err(|error| Error::Serialization(SerializationError::new(error.to_string())))?;
429 emit_socket_log(
430 inner.log_level,
431 inner.logger.clone(),
432 LogLevel::Debug,
433 "openai_core::websocket",
434 "发送 WebSocket 消息",
435 BTreeMap::from([("url".into(), inner.url.to_string())]),
436 );
437 let mut sink = inner.sink.lock().await;
438 sink.send(Message::Text(payload.into()))
439 .await
440 .map_err(|error| Error::WebSocket(WebSocketError::transport(error.to_string())))
441 }
442
443 async fn close_socket<T>(inner: &Arc<SocketCore<T>>, options: SocketCloseOptions) -> Result<()>
444 where
445 T: Clone + Send + 'static,
446 {
447 inner
448 .state
449 .store(ConnectionState::Closing.as_u8(), Ordering::SeqCst);
450 let _ = inner.events.send(SocketStreamMessage::Closing);
451 emit_socket_log(
452 inner.log_level,
453 inner.logger.clone(),
454 LogLevel::Info,
455 "openai_core::websocket",
456 "关闭 WebSocket 连接",
457 BTreeMap::from([
458 ("url".into(), inner.url.to_string()),
459 ("code".into(), options.code.to_string()),
460 ]),
461 );
462
463 let mut sink = inner.sink.lock().await;
464 sink.send(Message::Close(Some(CloseFrame {
465 code: CloseCode::from(options.code),
466 reason: Utf8Bytes::from(options.reason),
467 })))
468 .await
469 .map_err(|error| Error::WebSocket(WebSocketError::transport(error.to_string())))?;
470 Ok(())
471 }
472
473 fn build_websocket_url(
474 base_url: &str,
475 path: &str,
476 query: &BTreeMap<String, String>,
477 ) -> Result<Url> {
478 let joined = join_url(base_url, path)?;
479 let mut url = Url::parse(&joined)
480 .map_err(|error| Error::InvalidConfig(format!("WebSocket URL 无效: {error}")))?;
481 match url.scheme() {
482 "http" => {
483 let _ = url.set_scheme("ws");
484 }
485 "https" => {
486 let _ = url.set_scheme("wss");
487 }
488 "ws" | "wss" => {}
489 scheme => {
490 return Err(Error::InvalidConfig(format!(
491 "不支持的 WebSocket 基础协议: {scheme}"
492 )));
493 }
494 }
495
496 if !query.is_empty() {
497 let mut pairs = url.query_pairs_mut();
498 pairs.clear();
499 for (key, value) in query {
500 pairs.append_pair(key, value);
501 }
502 }
503 Ok(url)
504 }
505
506 fn emit_socket_log(
507 configured_level: LogLevel,
508 logger: Option<LoggerHandle>,
509 level: LogLevel,
510 target: &'static str,
511 message: impl Into<String>,
512 fields: BTreeMap<String, String>,
513 ) {
514 if !configured_level.allows(level) {
515 return;
516 }
517
518 let record = LogRecord {
519 level,
520 target,
521 message: message.into(),
522 fields,
523 };
524 if let Some(logger) = &logger {
525 logger.log(&record);
526 }
527
528 let rendered_fields = if record.fields.is_empty() {
529 String::new()
530 } else {
531 format!(
532 " {}",
533 record
534 .fields
535 .iter()
536 .map(|(key, value)| format!("{key}={value}"))
537 .collect::<Vec<_>>()
538 .join(" ")
539 )
540 };
541 let rendered = format!("[{}] {}{}", target, record.message, rendered_fields);
542 match level {
543 LogLevel::Off => {}
544 LogLevel::Error => error!("{rendered}"),
545 LogLevel::Warn => warn!("{rendered}"),
546 LogLevel::Info => info!("{rendered}"),
547 LogLevel::Debug => debug!("{rendered}"),
548 }
549 }
550
551 fn build_websocket_request(
552 url: &Url,
553 headers: &BTreeMap<String, String>,
554 ) -> Result<http::Request<()>> {
555 let mut request = url.as_str().into_client_request().map_err(|error| {
556 Error::InvalidConfig(format!("构建 WebSocket 握手请求失败: {error}"))
557 })?;
558 for (key, value) in headers {
559 request.headers_mut().insert(
560 http::header::HeaderName::from_bytes(key.as_bytes()).map_err(|error| {
561 Error::InvalidConfig(format!("构建 WebSocket 握手请求失败: {error}"))
562 })?,
563 http::header::HeaderValue::from_str(value).map_err(|error| {
564 Error::InvalidConfig(format!("构建 WebSocket 握手请求失败: {error}"))
565 })?,
566 );
567 }
568 Ok(request)
569 }
570
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::error::WebSocketErrorKind;

    /// Builds a close frame for the close-mapping tests.
    fn close_frame(code: CloseCode, reason: &'static str) -> CloseFrame {
        CloseFrame {
            code,
            reason: Utf8Bytes::from(reason),
        }
    }

    #[test]
    fn test_should_build_ws_url_from_https_base_url() {
        let query = BTreeMap::from([(
            "model".to_string(),
            "gpt-4o-realtime-preview".to_string(),
        )]);
        let url = build_websocket_url("https://api.openai.com/v1", "/realtime", &query)
            .expect("an https base URL should convert to wss");

        assert_eq!(
            url.as_str(),
            "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview"
        );
    }

    #[test]
    fn test_should_reject_unsupported_websocket_base_scheme() {
        let outcome = build_websocket_url("ftp://example.com", "/realtime", &BTreeMap::new());
        let Err(error) = outcome else {
            panic!("an ftp base URL must be rejected");
        };

        assert!(matches!(error, Error::InvalidConfig(_)));
        assert!(error.to_string().contains("ftp"));
    }

    #[test]
    fn test_should_reject_invalid_websocket_headers() {
        let target = Url::parse("ws://example.com/realtime").unwrap();
        let headers = BTreeMap::from([("x-test".to_string(), "bad\nvalue".to_string())]);

        let outcome = build_websocket_request(&target, &headers);

        assert!(matches!(outcome, Err(Error::InvalidConfig(_))));
    }

    #[test]
    fn test_should_parse_error_message_from_event() {
        let error_body = serde_json::json!({ "message": "bad request" });
        let event = WebSocketServerEvent {
            event_type: "error".into(),
            data: BTreeMap::from([("error".into(), error_body)]),
        };

        assert_eq!(event.error_message().as_deref(), Some("bad request"));
    }

    #[test]
    fn test_should_map_abnormal_close_frame_to_protocol_error() {
        let error = map_close_frame_error(&close_frame(CloseCode::from(1008), "quota exceeded"))
            .expect("a policy-violation close must map to an error");

        assert_eq!(error.kind, WebSocketErrorKind::Protocol);
        for fragment in ["1008", "quota exceeded"] {
            assert!(
                error.message.contains(fragment),
                "missing {fragment:?} in {}",
                error.message
            );
        }
    }

    #[test]
    fn test_should_ignore_normal_close_frame_for_error_mapping() {
        let mapped = map_close_frame_error(&close_frame(CloseCode::Normal, "OK"));

        assert!(mapped.is_none());
    }
}
651}
652
653#[cfg(not(any(feature = "realtime", feature = "responses-ws")))]
654mod enabled {
655 use futures_util::stream::{self, BoxStream};
656 use serde::Serialize;
657
658 use super::{Client, Error, RequestOptions, Result, Url};
659 use crate::websocket::{RealtimeStreamMessage, ResponsesStreamMessage, SocketCloseOptions};
660
661 #[derive(Debug, Clone)]
663 pub struct RealtimeSocket {
664 url: Url,
665 }
666
667 #[derive(Debug, Clone)]
669 pub struct ResponsesSocket {
670 url: Url,
671 }
672
673 impl RealtimeSocket {
674 pub(crate) async fn connect(
676 _client: &Client,
677 _model: Option<String>,
678 _options: RequestOptions,
679 ) -> Result<Self> {
680 Err(Error::InvalidConfig(
681 "当前未启用 WebSocket 支持,请开启 `realtime` 或 `responses-ws` feature".into(),
682 ))
683 }
684
685 pub fn url(&self) -> &Url {
687 &self.url
688 }
689
690 pub fn stream(&self) -> BoxStream<'static, RealtimeStreamMessage> {
692 Box::pin(stream::empty())
693 }
694
695 pub async fn send_json<T>(&self, _event: &T) -> Result<()>
697 where
698 T: Serialize,
699 {
700 Err(Error::InvalidConfig(
701 "当前未启用 WebSocket 支持,请开启 `realtime` 或 `responses-ws` feature".into(),
702 ))
703 }
704
705 pub async fn close(&self, _options: SocketCloseOptions) -> Result<()> {
707 Ok(())
708 }
709 }
710
711 impl ResponsesSocket {
712 pub(crate) async fn connect(_client: &Client, _options: RequestOptions) -> Result<Self> {
714 Err(Error::InvalidConfig(
715 "当前未启用 WebSocket 支持,请开启 `realtime` 或 `responses-ws` feature".into(),
716 ))
717 }
718
719 pub fn url(&self) -> &Url {
721 &self.url
722 }
723
724 pub fn stream(&self) -> BoxStream<'static, ResponsesStreamMessage> {
726 Box::pin(stream::empty())
727 }
728
729 pub async fn send_json<T>(&self, _event: &T) -> Result<()>
731 where
732 T: Serialize,
733 {
734 Err(Error::InvalidConfig(
735 "当前未启用 WebSocket 支持,请开启 `realtime` 或 `responses-ws` feature".into(),
736 ))
737 }
738
739 pub async fn close(&self, _options: SocketCloseOptions) -> Result<()> {
741 Ok(())
742 }
743 }
744}
745
746#[cfg(feature = "realtime")]
747pub use enabled::RealtimeSocket;
748#[cfg(feature = "responses-ws")]
749pub use enabled::ResponsesSocket;