use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use futures_core::Stream;
use futures_util::stream::SplitSink;
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
use crate::config::Config;
use crate::error::OpenAIError;
use crate::types::responses::{Response, ResponseCreateRequest, ResponseStreamEvent, ResponseTool};
/// Response timeout (300 seconds) that `WsSession::connect` installs on every new session.
const DEFAULT_WS_RESPONSE_TIMEOUT: Duration = Duration::from_secs(300);
/// Client-side WebSocket stream running over a TCP connection that may be TLS-wrapped.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Write half obtained by splitting a [`WsStream`].
type WsSink = SplitSink<WsStream, Message>;
/// Read half obtained by splitting a [`WsStream`].
type WsReader = futures_util::stream::SplitStream<WsStream>;
/// A connected WebSocket session: both halves of the split stream plus the
/// timeout applied when waiting for a complete response.
pub struct WsSession {
/// Write half; used for outgoing requests, pong replies and the close frame.
sink: WsSink,
/// Read half; server messages are pulled from here.
reader: WsReader,
/// Upper bound on how long `read_until_completed` waits for a terminal event.
response_timeout: Duration,
}
impl WsSession {
/// Opens a WebSocket session against the Responses endpoint derived from `config`.
///
/// The target URL comes from `build_ws_url`. The handshake request carries the API
/// key as a bearer token and the `OpenAI-Beta: responses=v1` header. The returned
/// session starts with `DEFAULT_WS_RESPONSE_TIMEOUT` as its response timeout.
///
/// # Errors
///
/// Returns [`OpenAIError::WebSocketError`] when the handshake request cannot be
/// built or when establishing the connection fails.
pub async fn connect(config: &dyn Config) -> Result<Self, OpenAIError> {
    #[cfg(not(target_arch = "wasm32"))]
    crate::ensure_tls_provider();

    let ws_url = build_ws_url(config);
    tracing::debug!(url = %ws_url, "connecting to WebSocket");

    // The Host header must include the port whenever the URL names a non-default
    // one (e.g. `localhost:8080`); `Url::port` is `None` for the scheme default,
    // so standard `wss://` endpoints keep a bare host name.
    let host_header = match reqwest::Url::parse(&ws_url) {
        Ok(url) => {
            let host = url.host_str().unwrap_or("api.openai.com");
            match url.port() {
                Some(port) => format!("{host}:{port}"),
                None => host.to_string(),
            }
        }
        // Unparseable URL: keep the historical fallback and let the connection
        // attempt below report the real problem.
        Err(_) => "api.openai.com".to_string(),
    };

    let request = tokio_tungstenite::tungstenite::http::Request::builder()
        .uri(&ws_url)
        .header("Authorization", format!("Bearer {}", config.api_key()))
        .header("Sec-WebSocket-Version", "13")
        .header(
            "Sec-WebSocket-Key",
            tokio_tungstenite::tungstenite::handshake::client::generate_key(),
        )
        .header("Host", host_header)
        .header("OpenAI-Beta", "responses=v1")
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .body(())
        .map_err(|e| OpenAIError::WebSocketError(format!("build request: {e}")))?;

    let (stream, _response) = connect_async_with_config(request, None, false)
        .await
        .map_err(|e| OpenAIError::WebSocketError(format!("connection failed: {e}")))?;
    let (sink, reader) = stream.split();
    tracing::info!("WebSocket session connected");

    Ok(Self {
        sink,
        reader,
        response_timeout: DEFAULT_WS_RESPONSE_TIMEOUT,
    })
}
pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
self.response_timeout = timeout;
self
}
/// Sends `request` and waits for the final [`Response`].
///
/// Intermediate stream events are discarded; the wait is bounded by the
/// session's response timeout.
///
/// # Errors
///
/// Propagates any error from sending the request or from waiting for the
/// terminal event.
pub async fn send(&mut self, request: ResponseCreateRequest) -> Result<Response, OpenAIError> {
    self.send_request(&request).await?;
    let response = self.read_until_completed().await?;
    Ok(response)
}
/// Sends `request` and returns a stream over the events the server sends back.
///
/// The stream mutably borrows the session's read half, so the session cannot be
/// used for another request until the stream is dropped. The session's response
/// timeout is not applied to the returned stream.
///
/// # Errors
///
/// Propagates any error from serialising or sending the request.
pub async fn send_stream(
    &mut self,
    request: ResponseCreateRequest,
) -> Result<WsEventStream<'_>, OpenAIError> {
    self.send_request(&request).await?;
    let reader = &mut self.reader;
    Ok(WsEventStream {
        reader,
        done: false,
    })
}
/// Warms up the session by sending a non-generating `response.create` event.
///
/// The frame carries `model` plus, when provided, `tools` and `instructions`,
/// and sets `generate` to `false` so that no output is requested. The call then
/// waits (bounded by the response timeout) for the terminal event and discards
/// the resulting response.
///
/// # Errors
///
/// Returns [`OpenAIError::WebSocketError`] if the tools cannot be serialised or
/// the frame cannot be sent, and propagates any error from serialising the body
/// or from waiting for the terminal event.
pub async fn warmup(
    &mut self,
    model: impl Into<String>,
    tools: Option<Vec<ResponseTool>>,
    instructions: Option<String>,
) -> Result<(), OpenAIError> {
    // Like every other client frame (see `send_request`), the warm-up must be
    // tagged as `response.create`; `generate: false` is what turns it into a
    // warm-up instead of a real generation request.
    let mut warmup_body = serde_json::json!({
        "type": "response.create",
        "generate": false,
        "model": model.into(),
    });
    if let Some(tools) = tools {
        warmup_body["tools"] = serde_json::to_value(&tools)
            .map_err(|e| OpenAIError::WebSocketError(format!("serialize tools: {e}")))?;
    }
    if let Some(instructions) = instructions {
        warmup_body["instructions"] = serde_json::Value::String(instructions);
    }

    let text = serde_json::to_string(&warmup_body)?;
    self.sink
        .send(Message::Text(text.into()))
        .await
        .map_err(|e| OpenAIError::WebSocketError(format!("send warmup: {e}")))?;

    // Only the terminal event matters here; the response itself is not needed.
    self.read_until_completed().await?;
    Ok(())
}
/// Consumes the session and sends a close frame without a close reason.
///
/// # Errors
///
/// Returns [`OpenAIError::WebSocketError`] if the close frame cannot be sent.
pub async fn close(mut self) -> Result<(), OpenAIError> {
    let farewell = Message::Close(None);
    match self.sink.send(farewell).await {
        Ok(()) => Ok(()),
        Err(e) => Err(OpenAIError::WebSocketError(format!("close: {e}"))),
    }
}
/// Serialises `request`, tags it as a `response.create` event and sends it as a
/// text frame.
///
/// A `temperature` with a fractional part is dropped from the payload before
/// sending (workaround noted in the log message below); whole-number
/// temperatures are left in place.
async fn send_request(&mut self, request: &ResponseCreateRequest) -> Result<(), OpenAIError> {
    let mut payload = serde_json::to_value(request)?;

    if let Some(fields) = payload.as_object_mut() {
        fields.insert(
            "type".to_string(),
            serde_json::Value::String("response.create".to_string()),
        );

        // Only numeric temperatures with a non-zero fractional part are removed.
        let fractional = fields
            .get("temperature")
            .and_then(serde_json::Value::as_f64)
            .filter(|t| t.fract() != 0.0);
        if let Some(temperature) = fractional {
            tracing::debug!(
                temperature = temperature,
                "stripping decimal temperature (OpenAI WS bug)"
            );
            fields.remove("temperature");
        }
    }

    let body = serde_json::to_string(&payload)?;
    tracing::debug!(len = body.len(), "sending WS request");
    tracing::trace!(body = %body, "WS request body");

    match self.sink.send(Message::Text(body.into())).await {
        Ok(()) => Ok(()),
        Err(e) => Err(OpenAIError::WebSocketError(format!("send: {e}"))),
    }
}
/// Waits for the terminal event of the in-flight request, giving up after the
/// session's response timeout.
///
/// On timeout a [`OpenAIError::WebSocketError`] naming the elapsed limit is
/// returned; otherwise the result of `read_until_completed_inner` is passed
/// through unchanged.
async fn read_until_completed(&mut self) -> Result<Response, OpenAIError> {
    let limit = self.response_timeout;
    match tokio::time::timeout(limit, self.read_until_completed_inner()).await {
        Ok(outcome) => outcome,
        Err(_elapsed) => Err(OpenAIError::WebSocketError(format!(
            "timed out waiting for response.completed after {:?}",
            limit
        ))),
    }
}
/// Reads messages until a terminal event arrives, without any time limit.
///
/// * `ResponseCompleted` yields the contained response.
/// * `ResponseFailed` becomes an [`OpenAIError::ApiError`] with status `0` and
///   type `response_failed`.
/// * Any other event is traced and skipped.
/// * Ping frames are answered with a pong; a close frame, a read error or the end
///   of the stream becomes an [`OpenAIError::WebSocketError`].
/// * Text that does not deserialize as an event propagates the JSON error.
async fn read_until_completed_inner(&mut self) -> Result<Response, OpenAIError> {
    loop {
        let incoming = match self.reader.next().await {
            Some(Ok(message)) => message,
            Some(Err(e)) => return Err(OpenAIError::WebSocketError(format!("read: {e}"))),
            None => {
                return Err(OpenAIError::WebSocketError(
                    "connection closed before response.completed".into(),
                ));
            }
        };

        match incoming {
            Message::Text(text) => match serde_json::from_str::<ResponseStreamEvent>(&text)? {
                ResponseStreamEvent::ResponseCompleted(evt) => return Ok(evt.response),
                ResponseStreamEvent::ResponseFailed(evt) => {
                    // Without an error payload, fall back to a generic message and no code.
                    let (message, code) = match evt.response.error.as_ref() {
                        Some(err) => (err.message.clone(), Some(err.code.clone())),
                        None => ("unknown error".into(), None),
                    };
                    return Err(OpenAIError::ApiError {
                        status: 0,
                        message,
                        type_: Some("response_failed".into()),
                        code,
                        request_id: None,
                    });
                }
                other => {
                    tracing::trace!(event_type = %other.event_type(), "ws event (ignored in send)");
                }
            },
            Message::Ping(data) => {
                if let Err(e) = self.sink.send(Message::Pong(data)).await {
                    return Err(OpenAIError::WebSocketError(format!("pong: {e}")));
                }
            }
            Message::Close(frame) => {
                let reason = match frame.as_ref() {
                    Some(f) => format!("code={}, reason={}", f.code, f.reason),
                    None => "no close frame".into(),
                };
                tracing::warn!(reason = %reason, "WS server closed connection");
                return Err(OpenAIError::WebSocketError(format!(
                    "server closed connection: {reason}"
                )));
            }
            // Remaining frame kinds carry nothing this loop needs.
            _ => {}
        }
    }
}
}
/// Stream of decoded response events that borrows a session's read half.
pub struct WsEventStream<'a> {
/// Read half of the WebSocket, borrowed for the lifetime of the stream.
reader: &'a mut WsReader,
/// Set once a terminal event, a close frame, a read error or the end of the
/// underlying stream has been seen; afterwards `poll_next` only yields `None`.
done: bool,
}
impl<'a> Stream for WsEventStream<'a> {
    type Item = Result<ResponseStreamEvent, OpenAIError>;

    /// Yields the next decoded event.
    ///
    /// * Text frames are parsed as [`ResponseStreamEvent`]; a parse failure is
    ///   yielded as an error item without ending the stream.
    /// * `ResponseCompleted` and `ResponseFailed` are yielded and then end the stream.
    /// * A close frame or the end of the underlying stream ends this stream
    ///   without an error item; a read error is yielded once and ends it.
    /// * All other frames (ping, pong, binary, ...) are skipped.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.done {
            return Poll::Ready(None);
        }

        // Skip non-data frames by polling the reader again right away instead of
        // waking ourselves and returning `Pending`, which would cost a round-trip
        // through the executor for every control frame.
        loop {
            let msg = match this.reader.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    this.done = true;
                    return Poll::Ready(None);
                }
                Poll::Ready(Some(Err(e))) => {
                    this.done = true;
                    return Poll::Ready(Some(Err(OpenAIError::WebSocketError(format!(
                        "read: {e}"
                    )))));
                }
                Poll::Ready(Some(Ok(msg))) => msg,
            };

            match msg {
                Message::Text(text) => {
                    let item = match serde_json::from_str::<ResponseStreamEvent>(&text) {
                        Ok(event) => {
                            if matches!(
                                event,
                                ResponseStreamEvent::ResponseCompleted(_)
                                    | ResponseStreamEvent::ResponseFailed(_)
                            ) {
                                this.done = true;
                            }
                            Ok(event)
                        }
                        Err(e) => Err(OpenAIError::JsonError(e)),
                    };
                    return Poll::Ready(Some(item));
                }
                Message::Close(_) => {
                    this.done = true;
                    return Poll::Ready(None);
                }
                // Control and binary frames carry no event; keep reading.
                _ => continue,
            }
        }
    }
}
/// Derives the WebSocket URL of the Responses endpoint from the configured base URL.
///
/// `https://` becomes `wss://`, `http://` becomes `ws://`, and any other base
/// (for example one that already uses `wss://`) is kept as is. Trailing slashes
/// on the base are dropped so the result never contains `//responses`.
fn build_ws_url(config: &dyn Config) -> String {
    let base = config.base_url();
    // A base such as `https://host/v1/` would otherwise produce `.../v1//responses`.
    let trimmed = base.trim_end_matches('/');

    let ws_base = if let Some(rest) = trimmed.strip_prefix("https://") {
        format!("wss://{rest}")
    } else if let Some(rest) = trimmed.strip_prefix("http://") {
        format!("ws://{rest}")
    } else {
        trimmed.to_string()
    };

    format!("{ws_base}/responses")
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ClientConfig;
/// The default configuration maps to the secure WebSocket Responses endpoint.
#[test]
fn test_build_ws_url_default() {
    let url = build_ws_url(&ClientConfig::new("sk-test-key-123"));
    assert_eq!(url, "wss://api.openai.com/v1/responses");
}
/// A custom `https://` base is rewritten to `wss://` with its path preserved.
#[test]
fn test_build_ws_url_custom_base() {
    let url = build_ws_url(&ClientConfig::new("sk-abc").base_url("https://custom.api.com/v2"));
    assert_eq!(url, "wss://custom.api.com/v2/responses");
}
/// A plain `http://` base is rewritten to `ws://`, keeping host, port and path.
#[test]
fn test_build_ws_url_http() {
    let url = build_ws_url(&ClientConfig::new("sk-local").base_url("http://localhost:8080/v1"));
    assert_eq!(url, "ws://localhost:8080/v1/responses");
}
/// Despite its name, this covers a base that has no *HTTP* scheme to rewrite:
/// a URL already using `wss://` must pass through unchanged.
#[test]
fn test_build_ws_url_no_scheme() {
    let url = build_ws_url(&ClientConfig::new("sk-x").base_url("wss://already-wss.com/v1"));
    assert_eq!(url, "wss://already-wss.com/v1/responses");
}
/// Live smoke test (ignored by default): connect, send one small request, expect
/// non-empty text, then close. Reads the API key from the environment.
#[tokio::test]
#[ignore]
async fn ws_live() {
    let client = crate::OpenAI::from_env().expect("OPENAI_API_KEY");

    eprintln!("Connecting WS...");
    let connected = WsSession::connect(&*client.config).await;
    let mut session = connected.expect("ws connect failed");

    eprintln!("Connected. Sending request...");
    let resp = session
        .send(ResponseCreateRequest::new("gpt-5.4-mini").input("Say hello in exactly 3 words"))
        .await
        .expect("ws send failed");

    let text = resp.output_text();
    eprintln!("Response: {text}");
    assert!(!text.is_empty(), "Expected non-empty response");

    session.close().await.ok();
}
/// Live test (ignored by default): a request with 60,000 characters of
/// instructions must still complete over the WebSocket.
#[tokio::test]
#[ignore]
async fn ws_live_large() {
    let client = crate::OpenAI::from_env().expect("OPENAI_API_KEY");
    let mut session = WsSession::connect(&*client.config)
        .await
        .expect("ws connect");

    let big_system = "X".repeat(60_000);
    let req = ResponseCreateRequest::new("gpt-5.4-mini")
        .instructions(&big_system)
        .input("Say hi in 3 words")
        .max_output_tokens(50);

    eprintln!("Sending ~70KB request via WS...");
    let resp = session
        .send(req)
        .await
        .unwrap_or_else(|e| panic!("FAILED with large payload: {e}"));
    eprintln!("OK with large payload: {}", resp.output_text());
}
/// Live test (ignored by default): a request that declares a function tool must
/// complete; the number of function calls and the text are printed.
#[tokio::test]
#[ignore]
async fn ws_live_tools() {
    let client = crate::OpenAI::from_env().expect("OPENAI_API_KEY");
    let mut session = WsSession::connect(&*client.config)
        .await
        .expect("ws connect");

    let calculator = ResponseTool::Function {
        name: "calculate".into(),
        description: Some("Math calculation".into()),
        parameters: Some(serde_json::json!({"type":"object","properties":{"expr":{"type":"string"}},"required":["expr"]})),
        strict: None,
    };
    let req = ResponseCreateRequest::new("gpt-5.4-mini")
        .input("What is 2+2?")
        .tools(vec![calculator])
        .store(true);

    eprintln!("Sending WS request with tools...");
    let resp = session
        .send(req)
        .await
        .unwrap_or_else(|e| panic!("FAILED with tools: {e}"));
    let calls = resp.function_calls();
    eprintln!(
        "OK tools: {} function calls, text={}",
        calls.len(),
        resp.output_text()
    );
}
/// Live test (ignored by default): sends a request shaped like a server-side
/// coaching call, with a large system message, a user message and one function tool.
#[tokio::test]
#[ignore]
async fn ws_live_server_sim() {
    let client = crate::OpenAI::from_env().expect("OPENAI_API_KEY");
    let mut session = WsSession::connect(&*client.config)
        .await
        .expect("ws connect");

    let big_system = "You are a sales coach. ".repeat(2000);
    let system_item =
        serde_json::json!({"type": "message", "role": "system", "content": big_system});
    let user_item = serde_json::json!({"type": "message", "role": "user", "content": "Rep said hello to customer"});
    let input = vec![system_item, user_item];

    let whisper = ResponseTool::Function {
        name: "whisper".into(),
        description: Some("Coach the rep".into()),
        parameters: Some(serde_json::json!({
            "type": "object",
            "properties": {"message": {"type": "string"}},
            "required": ["message"]
        })),
        strict: None,
    };

    let mut req = ResponseCreateRequest::new("gpt-5.4-mini");
    req.input = Some(crate::types::responses::ResponseInput::Items(input));
    req = req.tools(vec![whisper]).store(true).max_output_tokens(100);

    // Serialised only to report the payload size.
    let payload_len = serde_json::to_string(&req).unwrap().len();
    eprintln!(
        "Sending server-sim payload via WS ({} bytes)...",
        payload_len
    );

    let resp = session
        .send(req)
        .await
        .unwrap_or_else(|e| panic!("FAILED server-sim: {e}"));
    let calls = resp.function_calls();
    eprintln!(
        "OK: {} function_calls, text={}",
        calls.len(),
        resp.output_text()
    );
}
/// Live test (ignored by default): the session must still accept a request after
/// sitting idle for five seconds.
#[tokio::test]
#[ignore]
async fn ws_live_delay() {
    let client = crate::OpenAI::from_env().expect("OPENAI_API_KEY");
    let mut session = WsSession::connect(&*client.config)
        .await
        .expect("ws connect");

    eprintln!("Connected. Waiting 5 seconds...");
    let idle = std::time::Duration::from_secs(5);
    tokio::time::sleep(idle).await;

    eprintln!("Sending after 5s delay...");
    let req = ResponseCreateRequest::new("gpt-5.4-mini").input("Say hi");
    let resp = session
        .send(req)
        .await
        .unwrap_or_else(|e| panic!("FAILED after 5s delay: {e}"));
    eprintln!("OK after delay: {}", resp.output_text());
}
/// The plain request serialisation (before `send_request` adds the event type)
/// carries the builder's fields and no `stream` key.
#[test]
fn test_request_serialization_for_ws() {
    let request = ResponseCreateRequest::new("gpt-4o-mini")
        .input("Hello")
        .instructions("Be concise")
        .temperature(0.5);

    // Round-trip through text so the assertions see what would go on the wire.
    let wire = serde_json::to_string(&request).unwrap();
    let parsed = serde_json::from_str::<serde_json::Value>(&wire).unwrap();

    let expected_strings = [
        ("model", "gpt-4o-mini"),
        ("input", "Hello"),
        ("instructions", "Be concise"),
    ];
    for (key, expected) in expected_strings {
        assert_eq!(parsed[key], expected);
    }
    assert_eq!(parsed["temperature"], 0.5);
    assert!(parsed.get("stream").is_none());
}
}