use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{Context, Result};
use futures::SinkExt;
use futures::stream::{BoxStream, StreamExt};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::header::{AUTHORIZATION, HeaderName, HeaderValue};

use super::types::websocket::WsResponseCreate;
use super::types::{CreateResponseRequest, ResponseStreamEvent};
/// Endpoint used unless the caller overrides it with `with_ws_url`.
const DEFAULT_WS_URL: &str = "wss://api.openai.com/v1/responses";

/// Age at which a cached connection is replaced instead of reused.
///
/// Expressed as one minute short of the 60-minute limit mentioned in the reconnect log
/// message, so a request never starts on a socket that is about to hit that limit.
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(60 * 60 - 60);
/// Socket type produced by `tokio_tungstenite::connect_async` for this client: a WebSocket
/// over TCP, where `MaybeTlsStream` covers both plain and TLS transports.
type WsStream = tokio_tungstenite::WebSocketStream<
    tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>;
/// An open socket paired with the time it was established, so its age can be compared
/// against `CONNECTION_TIMEOUT`.
struct WsConnection {
    /// Moment the connection was stored; used to decide when to reconnect.
    connected_at: Instant,
    /// The underlying WebSocket.
    stream: WsStream,
}
/// Client for the OpenAI Responses API that talks over one lazily-opened WebSocket.
///
/// The connection lives behind an async mutex that `create_stream` holds for the whole
/// exchange, so requests made through one client are processed one at a time.
pub struct ResponsesWebSocket {
    /// Cached connection; `None` before first use and after the connection is dropped.
    connection: Arc<Mutex<Option<WsConnection>>>,
    /// Token sent as `Authorization: Bearer …` during the handshake.
    api_key: String,
    /// Endpoint to connect to; starts as `DEFAULT_WS_URL`.
    ws_url: String,
    /// Optional `OpenAI-Organization` handshake header value.
    organization: Option<String>,
}
impl ResponsesWebSocket {
pub fn new(api_key: String) -> Self {
Self {
api_key,
ws_url: DEFAULT_WS_URL.to_string(),
organization: None,
connection: Arc::new(Mutex::new(None)),
}
}
pub fn with_ws_url(mut self, url: String) -> Self {
self.ws_url = url;
self
}
pub fn with_organization(mut self, org: String) -> Self {
self.organization = Some(org);
self
}
async fn ensure_connected(&self) -> Result<()> {
let mut conn = self.connection.lock().await;
if let Some(ref c) = *conn {
if c.connected_at.elapsed() < CONNECTION_TIMEOUT {
return Ok(());
}
tracing::info!("WebSocket connection approaching 60-min limit, reconnecting");
}
let mut builder = tokio_tungstenite::tungstenite::http::Request::builder()
.uri(&self.ws_url)
.header("Authorization", format!("Bearer {}", self.api_key));
if let Some(ref org) = self.organization {
builder = builder.header("OpenAI-Organization", org.as_str());
}
let request = builder
.body(())
.context("Failed to build WebSocket request")?;
let (ws_stream, _response) = tokio_tungstenite::connect_async(request)
.await
.context("Failed to connect to OpenAI Responses WebSocket")?;
tracing::info!(url = %self.ws_url, "WebSocket connection established");
*conn = Some(WsConnection {
stream: ws_stream,
connected_at: Instant::now(),
});
Ok(())
}
pub fn create_stream<'a>(
&'a self,
req: &'a CreateResponseRequest,
) -> BoxStream<'a, Result<ResponseStreamEvent>> {
Box::pin(async_stream::stream! {
if let Err(e) = self.ensure_connected().await {
yield Err(e);
return;
}
let ws_msg = WsResponseCreate::new(req.clone());
let json = match serde_json::to_string(&ws_msg) {
Ok(j) => j,
Err(e) => {
yield Err(anyhow::anyhow!("Failed to serialize WebSocket request: {}", e));
return;
}
};
let mut conn = self.connection.lock().await;
let ws = match conn.as_mut() {
Some(c) => &mut c.stream,
None => {
yield Err(anyhow::anyhow!("WebSocket connection lost"));
return;
}
};
if let Err(e) = ws.send(WsMessage::Text(json.into())).await {
*conn = None;
yield Err(anyhow::anyhow!("Failed to send WebSocket message: {}", e));
return;
}
loop {
match ws.next().await {
Some(Ok(WsMessage::Text(text))) => {
match serde_json::from_str::<ResponseStreamEvent>(&text) {
Ok(event) => {
let is_terminal = matches!(
&event,
ResponseStreamEvent::ResponseCompleted { .. }
| ResponseStreamEvent::ResponseFailed { .. }
| ResponseStreamEvent::ResponseIncomplete { .. }
);
yield Ok(event);
if is_terminal {
return;
}
}
Err(e) => {
tracing::warn!(
"Failed to parse WebSocket event: {} — data: {}",
e,
&text[..text.len().min(200)]
);
}
}
}
Some(Ok(WsMessage::Close(frame))) => {
let reason = frame
.map(|f| f.reason.to_string())
.unwrap_or_else(|| "unknown".to_string());
tracing::info!(reason = %reason, "WebSocket closed by server");
*conn = None;
yield Err(anyhow::anyhow!("WebSocket closed by server: {}", reason));
return;
}
Some(Ok(WsMessage::Ping(data))) => {
if let Err(e) = ws.send(WsMessage::Pong(data)).await {
tracing::warn!("Failed to send pong: {}", e);
}
}
Some(Ok(_)) => {
}
Some(Err(e)) => {
*conn = None;
yield Err(anyhow::anyhow!("WebSocket error: {}", e));
return;
}
None => {
*conn = None;
yield Err(anyhow::anyhow!("WebSocket stream ended unexpectedly"));
return;
}
}
}
})
}
pub async fn disconnect(&self) {
let mut conn = self.connection.lock().await;
if let Some(mut c) = conn.take() {
let _ = c.stream.close(None).await;
}
}
pub async fn is_connected(&self) -> bool {
let conn = self.connection.lock().await;
conn.as_ref()
.is_some_and(|c| c.connected_at.elapsed() < CONNECTION_TIMEOUT)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with a throwaway key; constructing it opens no connection.
    fn client() -> ResponsesWebSocket {
        ResponsesWebSocket::new(String::from("test-key"))
    }

    #[test]
    fn test_default_url() {
        assert_eq!(client().ws_url, "wss://api.openai.com/v1/responses");
    }

    #[test]
    fn test_custom_url() {
        let custom = "wss://custom.api.com/v1/responses";
        let ws = client().with_ws_url(custom.to_owned());
        assert_eq!(ws.ws_url, custom);
    }

    #[test]
    fn test_organization() {
        let ws = client().with_organization("org-123".into());
        assert_eq!(ws.organization.as_deref(), Some("org-123"));
    }

    #[tokio::test]
    async fn test_not_connected_initially() {
        assert!(!client().is_connected().await);
    }
}