// rig/providers/openai/responses_api/websocket.rs

1//! WebSocket session support for the OpenAI Responses API.
2//!
3//! This module implements OpenAI's `/v1/responses` WebSocket mode as a stateful,
4//! sequential session. Each connection supports a single in-flight response at a
5//! time, which matches OpenAI's current protocol constraints.
6
7use crate::completion::{self, CompletionError};
8use crate::http_client::HttpClientExt;
9use crate::providers::openai::responses_api::streaming::{
10    ItemChunk, ResponseChunk, ResponseChunkKind, StreamingCompletionChunk,
11};
12use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
13use futures::{SinkExt, StreamExt};
14use serde::{Deserialize, Serialize};
15use serde_json::{Map, Value};
16use std::time::Duration;
17use tokio::net::TcpStream;
18use tokio_tungstenite::{
19    MaybeTlsStream, WebSocketStream, connect_async,
20    tungstenite::{self, Message, client::IntoClientRequest},
21};
22use tracing::Level;
23use url::Url;
24
25use super::{CompletionResponse, ResponseError, ResponseStatus, ResponsesCompletionModel};
26
/// A plain or TLS-wrapped websocket stream over TCP, as produced by `tokio-tungstenite`.
type OpenAIWebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Default timeout applied to the websocket connection handshake (30 seconds).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
29
/// Options for a `response.create` message sent over OpenAI WebSocket mode.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResponsesWebSocketCreateOptions {
    /// When set to `false`, OpenAI prepares request state without generating a model output.
    ///
    /// This is the "warmup" mode described in the OpenAI WebSocket mode guide.
    /// `None` omits the field from the serialized payload entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generate: Option<bool>,
}
39
40impl ResponsesWebSocketCreateOptions {
41    /// Creates warmup options equivalent to `generate: false`.
42    #[must_use]
43    pub fn warmup() -> Self {
44        Self {
45            generate: Some(false),
46        }
47    }
48}
49
/// The client-to-server `response.create` event envelope serialized over the websocket.
#[derive(Debug, Clone, Serialize)]
struct ResponsesWebSocketClientEvent {
    /// The event type tag (serialized as the JSON `type` field).
    #[serde(rename = "type")]
    kind: ResponsesWebSocketClientEventKind,
    /// The Responses API request payload, flattened into the event object.
    #[serde(flatten)]
    request: super::CompletionRequest,
    /// Optional websocket-mode `generate` flag; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    generate: Option<bool>,
}
59
/// The set of client event types this module can send over the websocket.
#[derive(Debug, Clone, Serialize)]
enum ResponsesWebSocketClientEventKind {
    /// Serialized as `"response.create"`.
    #[serde(rename = "response.create")]
    ResponseCreate,
}
65
/// A protocol error event emitted by OpenAI WebSocket mode.
///
/// Deserialized from server messages whose `type` field is `"error"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesWebSocketErrorEvent {
    /// The event type.
    #[serde(rename = "type")]
    pub kind: ResponsesWebSocketErrorEventKind,
    /// The provider error payload.
    pub error: ResponsesWebSocketErrorPayload,
}
75
76impl std::fmt::Display for ResponsesWebSocketErrorEvent {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        self.error.fmt(f)
79    }
80}
81
/// The event kind for an OpenAI WebSocket protocol error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponsesWebSocketErrorEventKind {
    /// Serialized as `"error"`.
    #[serde(rename = "error")]
    Error,
}
88
/// The payload carried by an OpenAI WebSocket protocol error event.
///
/// All known fields are optional; unrecognized fields are preserved in `extra`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResponsesWebSocketErrorPayload {
    /// Provider-specific error code when supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
    /// Human-readable error message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
    /// Any extra fields supplied by the provider.
    #[serde(flatten, default)]
    pub extra: Map<String, Value>,
}
102
103impl std::fmt::Display for ResponsesWebSocketErrorPayload {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        match (&self.code, &self.message) {
106            (Some(code), Some(message)) => write!(f, "{code}: {message}"),
107            (None, Some(message)) => f.write_str(message),
108            (Some(code), None) => f.write_str(code),
109            (None, None) => f.write_str("OpenAI websocket error"),
110        }
111    }
112}
113
/// The optional `response.done` event emitted by OpenAI WebSocket mode.
///
/// The `response` body is kept as raw JSON because its shape is not guaranteed;
/// accessors on this type extract what they can from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesWebSocketDoneEvent {
    /// The event type.
    #[serde(rename = "type")]
    pub kind: ResponsesWebSocketDoneEventKind,
    /// The provider payload for the finished response.
    pub response: Value,
}
123
124impl ResponsesWebSocketDoneEvent {
125    /// Returns the response ID if the payload includes one.
126    #[must_use]
127    pub fn response_id(&self) -> Option<&str> {
128        self.response.get("id").and_then(Value::as_str)
129    }
130
131    fn status(&self) -> Option<ResponseStatus> {
132        self.response
133            .get("status")
134            .cloned()
135            .and_then(|status| serde_json::from_value(status).ok())
136    }
137
138    fn as_completion_response(&self) -> Option<CompletionResponse> {
139        serde_json::from_value(self.response.clone()).ok()
140    }
141}
142
/// The event kind for the terminal websocket event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponsesWebSocketDoneEventKind {
    /// Serialized as `"response.done"`.
    #[serde(rename = "response.done")]
    ResponseDone,
}
149
/// A server event emitted by OpenAI WebSocket mode.
#[derive(Debug, Clone)]
pub enum ResponsesWebSocketEvent {
    /// A response lifecycle event such as `response.created` or `response.completed`.
    ///
    /// Boxed to keep the enum's size small; the chunk carries a full response payload.
    Response(Box<ResponseChunk>),
    /// A streaming item/delta event such as `response.output_text.delta`.
    Item(ItemChunk),
    /// A protocol-level websocket error event.
    Error(ResponsesWebSocketErrorEvent),
    /// An optional `response.done` event emitted by OpenAI over WebSockets.
    Done(ResponsesWebSocketDoneEvent),
}
162
163impl ResponsesWebSocketEvent {
164    /// Returns the response ID when the event includes one.
165    #[must_use]
166    pub fn response_id(&self) -> Option<&str> {
167        match self {
168            Self::Response(chunk) => Some(&chunk.response.id),
169            Self::Done(done) => done.response_id(),
170            Self::Item(_) | Self::Error(_) => None,
171        }
172    }
173
174    /// Returns `true` when this event ends the current in-flight websocket turn.
175    #[must_use]
176    pub fn is_terminal(&self) -> bool {
177        match self {
178            Self::Response(chunk) => matches!(
179                chunk.kind,
180                ResponseChunkKind::ResponseCompleted
181                    | ResponseChunkKind::ResponseFailed
182                    | ResponseChunkKind::ResponseIncomplete
183            ),
184            Self::Error(_) | Self::Done(_) => true,
185            Self::Item(_) => false,
186        }
187    }
188}
189
/// A builder for an OpenAI Responses WebSocket session.
///
/// The default builder applies a 30 second connection timeout and leaves the
/// per-event timeout disabled.
pub struct ResponsesWebSocketSessionBuilder<T = reqwest::Client> {
    // Completion model whose client configuration (base URL, headers) is used to connect.
    model: ResponsesCompletionModel<T>,
    // Timeout for the websocket handshake; `None` disables it.
    connect_timeout: Option<Duration>,
    // Timeout for each server-event read; `None` disables it.
    event_timeout: Option<Duration>,
}
199
200impl<T> ResponsesWebSocketSessionBuilder<T> {
201    pub(crate) fn new(model: ResponsesCompletionModel<T>) -> Self {
202        Self {
203            model,
204            connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
205            event_timeout: None,
206        }
207    }
208
209    /// Sets the timeout for establishing the websocket connection.
210    #[must_use]
211    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
212        self.connect_timeout = Some(timeout);
213        self
214    }
215
216    /// Disables the websocket connection timeout.
217    #[must_use]
218    pub fn without_connect_timeout(mut self) -> Self {
219        self.connect_timeout = None;
220        self
221    }
222
223    /// Sets the timeout for waiting on the next websocket event.
224    #[must_use]
225    pub fn event_timeout(mut self, timeout: Duration) -> Self {
226        self.event_timeout = Some(timeout);
227        self
228    }
229
230    /// Disables the websocket event timeout.
231    #[must_use]
232    pub fn without_event_timeout(mut self) -> Self {
233        self.event_timeout = None;
234        self
235    }
236}
237
238impl<T> ResponsesWebSocketSessionBuilder<T>
239where
240    T: HttpClientExt
241        + Clone
242        + std::fmt::Debug
243        + Default
244        + WasmCompatSend
245        + WasmCompatSync
246        + 'static,
247{
248    /// Opens the websocket session using the configured builder options.
249    pub async fn connect(self) -> Result<ResponsesWebSocketSession<T>, CompletionError> {
250        ResponsesWebSocketSession::connect_with_timeouts(
251            self.model,
252            self.connect_timeout,
253            self.event_timeout,
254        )
255        .await
256    }
257}
258
/// A stateful OpenAI Responses WebSocket session.
///
/// This session keeps track of the most recent successful `response.id` so later
/// turns can automatically chain via `previous_response_id` unless the request
/// explicitly sets a different one.
///
/// Call [`ResponsesWebSocketSession::close`] when you are finished with the
/// session so the websocket can complete a close handshake cleanly.
pub struct ResponsesWebSocketSession<T = reqwest::Client> {
    // Completion model used to build provider request payloads and supply client config.
    model: ResponsesCompletionModel<T>,
    // Last successful `response.id`; injected into the next request's
    // `previous_response_id` when the caller did not set one.
    previous_response_id: Option<String>,
    // ID of a turn that already ended via `response.completed`/failed/incomplete,
    // whose trailing `response.done` event may still arrive and must be skipped.
    pending_done_response_id: Option<String>,
    // The underlying (possibly TLS) websocket connection to OpenAI.
    socket: OpenAIWebSocket,
    // `true` from `response.create` until the turn's terminal event is delivered.
    in_flight: bool,
    // Optional timeout applied to each read in `next_event`; `None` disables it.
    event_timeout: Option<Duration>,
    // Set once the socket has been closed (gracefully or by the server).
    closed: bool,
    // Set when a transport/protocol error has poisoned the session.
    failed: bool,
}
277
impl<T> ResponsesWebSocketSession<T>
where
    T: HttpClientExt
        + Clone
        + std::fmt::Debug
        + Default
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    /// Establishes the websocket connection and initialises session state.
    ///
    /// The `wss://…/responses` URL and handshake headers are derived from the
    /// model's client configuration; the handshake is optionally bounded by
    /// `connect_timeout`.
    async fn connect_with_timeouts(
        model: ResponsesCompletionModel<T>,
        connect_timeout: Option<Duration>,
        event_timeout: Option<Duration>,
    ) -> Result<Self, CompletionError> {
        let url = websocket_url(model.client.base_url())?;
        let request = websocket_request(&url, model.client.headers())?;
        let socket = connect_websocket(request, connect_timeout).await?;

        Ok(Self {
            model,
            previous_response_id: None,
            pending_done_response_id: None,
            socket,
            in_flight: false,
            event_timeout,
            closed: false,
            failed: false,
        })
    }

    /// Returns the most recent successful `response.id` tracked by this session.
    #[must_use]
    pub fn previous_response_id(&self) -> Option<&str> {
        self.previous_response_id.as_deref()
    }

    /// Clears the cached `previous_response_id` so the next turn starts a fresh chain.
    pub fn clear_previous_response_id(&mut self) {
        self.previous_response_id = None;
    }

    /// Sends a `response.create` event for a Rig completion request.
    ///
    /// Uses default websocket-mode options (no `generate` flag).
    pub async fn send(
        &mut self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<(), CompletionError> {
        self.send_with_options(
            completion_request,
            ResponsesWebSocketCreateOptions::default(),
        )
        .await
    }

    /// Sends a `response.create` event with explicit websocket-mode options.
    ///
    /// # Errors
    ///
    /// Fails if the session is closed/failed, a response is already in flight
    /// (the protocol allows one in-flight response per connection), the request
    /// cannot be serialized, or the socket write fails (which poisons the session).
    pub async fn send_with_options(
        &mut self,
        completion_request: crate::completion::CompletionRequest,
        options: ResponsesWebSocketCreateOptions,
    ) -> Result<(), CompletionError> {
        self.ensure_open()?;

        if self.in_flight {
            return Err(CompletionError::ProviderError(
                "An OpenAI websocket response is already in flight on this session".to_string(),
            ));
        }

        let payload = ResponsesWebSocketClientEvent {
            kind: ResponsesWebSocketClientEventKind::ResponseCreate,
            request: self.prepare_request(completion_request)?,
            generate: options.generate,
        };

        // Guarded so pretty-serialization only runs when TRACE logging is active.
        if tracing::enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI websocket request: {}",
                serde_json::to_string_pretty(&payload)?
            );
        }

        let payload = serde_json::to_string(&payload)?;

        if let Err(error) = self.socket.send(Message::text(payload)).await {
            return Err(self.fail_session(websocket_provider_error(error)));
        }
        // Mark the turn in flight only after the write actually succeeded.
        self.in_flight = true;

        Ok(())
    }

    /// Reads the next server event for the current in-flight turn.
    ///
    /// Skips non-text frames and unrecognised event types; transport or parse
    /// errors poison the session, and a stream-end marks it closed.
    pub async fn next_event(&mut self) -> Result<ResponsesWebSocketEvent, CompletionError> {
        self.ensure_open()?;

        if !self.in_flight {
            return Err(CompletionError::ProviderError(
                "No OpenAI websocket response is currently in flight on this session".to_string(),
            ));
        }

        loop {
            let message = match self.read_next_message().await {
                Ok(message) => message,
                Err(error) => return Err(error),
            };

            // `None` means the stream ended without a close frame we could report.
            let Some(message) = message else {
                self.mark_closed();
                return Err(CompletionError::ProviderError(
                    "The OpenAI websocket connection closed before the turn finished".to_string(),
                ));
            };

            let message = match message {
                Ok(message) => message,
                Err(error) => return Err(self.fail_session(websocket_provider_error(error))),
            };
            let payload = match websocket_message_to_text(message) {
                Ok(Some(payload)) => payload,
                Ok(None) => continue,
                Err(error) => return Err(self.fail_session(error)),
            };
            let event = match parse_server_event(&payload) {
                Ok(Some(event)) => event,
                Ok(None) => continue,
                Err(error) => return Err(self.fail_session(error)),
            };
            if let ResponsesWebSocketEvent::Done(done) = &event {
                // OpenAI may emit `response.done` after the turn has already ended at
                // `response.completed`. Ignore that trailing event on the next turn.
                // NOTE(review): if no trailing done is pending AND the done payload has
                // no `id`, both sides are `None` and this comparison still matches,
                // silently skipping the event — confirm the provider always includes
                // `response.id` in `response.done` payloads.
                if self.pending_done_response_id.as_deref() == done.response_id() {
                    self.pending_done_response_id = None;
                    continue;
                }
            }
            self.update_state_for_event(&event);
            return Ok(event);
        }
    }

    /// Sends a warmup turn (`generate: false`) and returns the resulting response ID.
    pub async fn warmup(
        &mut self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<String, CompletionError> {
        self.send_with_options(
            completion_request,
            ResponsesWebSocketCreateOptions::warmup(),
        )
        .await?;
        let response = self.wait_for_completed_response().await?;
        Ok(response.id)
    }

    /// Sends a completion turn and collects the final OpenAI response.
    pub async fn completion(
        &mut self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        self.send(completion_request).await?;
        let response = self.wait_for_completed_response().await?;
        response.try_into()
    }

    /// Closes the websocket connection.
    ///
    /// Call this when you are finished with the session so the websocket can
    /// terminate with a clean close handshake.
    pub async fn close(&mut self) -> Result<(), CompletionError> {
        if self.closed {
            return Ok(());
        }

        let result = self
            .socket
            .close(None)
            .await
            .map_err(websocket_provider_error);
        // Mark closed even if the close frame failed so later calls don't retry.
        self.mark_closed();
        result
    }

    /// Converts a Rig completion request into the provider payload, stripping
    /// HTTP-only flags and applying this session's automatic response chaining.
    fn prepare_request(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<super::CompletionRequest, CompletionError> {
        let mut request = self.model.create_completion_request(completion_request)?;

        // WebSocket mode is always event-driven, so these HTTP/SSE-specific flags
        // are ignored by the provider and only add noise to the payload.
        request.stream = None;
        request.additional_parameters.background = None;

        // Only chain automatically when the caller didn't set an explicit ID.
        if request.additional_parameters.previous_response_id.is_none() {
            request.additional_parameters.previous_response_id = self.previous_response_id.clone();
        }

        Ok(request)
    }

    /// Drains events until the turn ends, returning the terminal response body
    /// or an error describing why no usable body was produced.
    async fn wait_for_completed_response(&mut self) -> Result<CompletionResponse, CompletionError> {
        loop {
            match self.next_event().await? {
                ResponsesWebSocketEvent::Response(chunk) => {
                    if matches!(
                        chunk.kind,
                        ResponseChunkKind::ResponseCompleted
                            | ResponseChunkKind::ResponseFailed
                            | ResponseChunkKind::ResponseIncomplete
                    ) {
                        return terminal_response_result(chunk.response);
                    }
                }
                ResponsesWebSocketEvent::Done(done) => {
                    // A `response.done` arriving here ends the turn; use its body if
                    // it parses as a full response, otherwise report the gap.
                    if let Some(response) = done.as_completion_response() {
                        return terminal_response_result(response);
                    }

                    let message = if let Some(response_id) = done.response_id() {
                        format!(
                            "OpenAI websocket turn ended with response.done before a terminal response body was available (response_id={response_id})"
                        )
                    } else {
                        "OpenAI websocket turn ended with response.done before a terminal response body was available"
                            .to_string()
                    };

                    return Err(CompletionError::ProviderError(message));
                }
                ResponsesWebSocketEvent::Error(error) => {
                    return Err(CompletionError::ProviderError(error.to_string()));
                }
                // Deltas are irrelevant here; only the terminal body matters.
                ResponsesWebSocketEvent::Item(_) => {}
            }
        }
    }

    /// Updates chaining and in-flight bookkeeping after an event has been
    /// accepted for delivery to the caller.
    fn update_state_for_event(&mut self, event: &ResponsesWebSocketEvent) {
        match event {
            ResponsesWebSocketEvent::Response(chunk) => match chunk.kind {
                // Successful turn: remember the ID both for chaining the next
                // request and for swallowing the trailing `response.done`.
                ResponseChunkKind::ResponseCompleted => {
                    let response_id = chunk.response.id.clone();
                    self.previous_response_id = Some(response_id.clone());
                    self.pending_done_response_id = Some(response_id);
                    self.in_flight = false;
                }
                // Unsuccessful turn: don't chain from it, but still expect the
                // trailing `response.done` for this ID.
                ResponseChunkKind::ResponseFailed | ResponseChunkKind::ResponseIncomplete => {
                    self.pending_done_response_id = Some(chunk.response.id.clone());
                    self.previous_response_id = None;
                    self.in_flight = false;
                }
                ResponseChunkKind::ResponseCreated | ResponseChunkKind::ResponseInProgress => {}
            },
            // `response.done` delivered directly (no prior terminal lifecycle
            // event): the turn is over and no trailing done remains to skip.
            ResponsesWebSocketEvent::Done(done) => {
                match done.status() {
                    Some(ResponseStatus::Completed) => {
                        if let Some(response_id) = done.response_id() {
                            self.previous_response_id = Some(response_id.to_string());
                        }
                    }
                    Some(ResponseStatus::Failed)
                    | Some(ResponseStatus::Incomplete)
                    | Some(ResponseStatus::Cancelled) => {
                        self.previous_response_id = None;
                    }
                    Some(ResponseStatus::InProgress | ResponseStatus::Queued) | None => {}
                }
                self.pending_done_response_id = None;
                self.in_flight = false;
            }
            // Protocol error: reset all turn state.
            ResponsesWebSocketEvent::Error(_) => {
                self.previous_response_id = None;
                self.pending_done_response_id = None;
                self.in_flight = false;
            }
            ResponsesWebSocketEvent::Item(_) => {}
        }
    }

    /// Resets all per-turn and chaining state.
    fn abort_turn(&mut self) {
        self.previous_response_id = None;
        self.pending_done_response_id = None;
        self.in_flight = false;
    }

    /// Records that the socket is closed; clears the failure flag since
    /// `closed` now governs `ensure_open` on its own.
    fn mark_closed(&mut self) {
        self.abort_turn();
        self.closed = true;
        self.failed = false;
    }

    /// Poisons the session after a transport or protocol error.
    fn mark_failed(&mut self) {
        self.abort_turn();
        self.failed = true;
    }

    /// Rejects use of a session that has been closed or has failed.
    fn ensure_open(&self) -> Result<(), CompletionError> {
        if self.closed || self.failed {
            return Err(CompletionError::ProviderError(
                "The OpenAI websocket session is closed".to_string(),
            ));
        }

        Ok(())
    }

    /// Marks the session failed and passes the error through, for use at
    /// `return Err(self.fail_session(...))` call sites.
    fn fail_session(&mut self, error: CompletionError) -> CompletionError {
        self.mark_failed();
        error
    }

    /// Reads the next raw websocket frame, enforcing the optional per-event
    /// timeout. A timeout poisons the session; `Ok(None)` means stream end.
    async fn read_next_message(
        &mut self,
    ) -> Result<Option<Result<Message, tungstenite::Error>>, CompletionError> {
        if let Some(timeout_duration) = self.event_timeout {
            match tokio::time::timeout(timeout_duration, self.socket.next()).await {
                Ok(message) => Ok(message),
                Err(_) => Err(self.fail_session(event_timeout_error(timeout_duration))),
            }
        } else {
            Ok(self.socket.next().await)
        }
    }
}
604
605impl<T> Drop for ResponsesWebSocketSession<T> {
606    fn drop(&mut self) {
607        if !self.closed {
608            tracing::warn!(
609                target: "rig::completions",
610                in_flight = self.in_flight,
611                "Dropping an OpenAI websocket session without calling close(); the connection will end without a close handshake"
612            );
613        }
614    }
615}
616
617fn terminal_response_result(
618    response: CompletionResponse,
619) -> Result<CompletionResponse, CompletionError> {
620    match response.status {
621        ResponseStatus::Completed => Ok(response),
622        ResponseStatus::Failed => Err(CompletionError::ProviderError(response_error_message(
623            response.error.as_ref(),
624            "failed response",
625        ))),
626        ResponseStatus::Incomplete => {
627            let reason = response
628                .incomplete_details
629                .as_ref()
630                .map(|details| details.reason.as_str())
631                .unwrap_or("unknown reason");
632            Err(CompletionError::ProviderError(format!(
633                "OpenAI websocket response was incomplete: {reason}"
634            )))
635        }
636        status => Err(CompletionError::ProviderError(format!(
637            "OpenAI websocket response ended with status {:?}",
638            status
639        ))),
640    }
641}
642
643fn response_error_message(error: Option<&ResponseError>, fallback: &str) -> String {
644    if let Some(error) = error {
645        if error.code.is_empty() {
646            error.message.clone()
647        } else {
648            format!("{}: {}", error.code, error.message)
649        }
650    } else {
651        format!("OpenAI websocket returned a {fallback}")
652    }
653}
654
/// Returns `true` for event types that deserialize as `StreamingCompletionChunk`
/// (response lifecycle events plus item/delta events).
fn is_known_streaming_event(kind: &str) -> bool {
    const KNOWN_STREAMING_EVENTS: &[&str] = &[
        "response.created",
        "response.in_progress",
        "response.completed",
        "response.failed",
        "response.incomplete",
        "response.output_item.added",
        "response.output_item.done",
        "response.content_part.added",
        "response.content_part.done",
        "response.output_text.delta",
        "response.output_text.done",
        "response.refusal.delta",
        "response.refusal.done",
        "response.function_call_arguments.delta",
        "response.function_call_arguments.done",
        "response.reasoning_summary_part.added",
        "response.reasoning_summary_part.done",
        "response.reasoning_summary_text.delta",
        "response.reasoning_summary_text.done",
    ];
    KNOWN_STREAMING_EVENTS.contains(&kind)
}
679
680fn parse_server_event(payload: &str) -> Result<Option<ResponsesWebSocketEvent>, CompletionError> {
681    #[derive(Deserialize)]
682    struct EventType {
683        #[serde(rename = "type")]
684        kind: String,
685    }
686
687    let event_type = serde_json::from_str::<EventType>(payload)?;
688    match event_type.kind.as_str() {
689        "error" => serde_json::from_str(payload)
690            .map(|e| Some(ResponsesWebSocketEvent::Error(e)))
691            .map_err(CompletionError::from),
692        "response.done" => serde_json::from_str(payload)
693            .map(|d| Some(ResponsesWebSocketEvent::Done(d)))
694            .map_err(CompletionError::from),
695        kind if is_known_streaming_event(kind) => match serde_json::from_str(payload)? {
696            StreamingCompletionChunk::Response(response) => {
697                Ok(Some(ResponsesWebSocketEvent::Response(response)))
698            }
699            StreamingCompletionChunk::Delta(item) => Ok(Some(ResponsesWebSocketEvent::Item(item))),
700        },
701        _ => {
702            tracing::debug!(
703                target: "rig::completions",
704                event_type = event_type.kind.as_str(),
705                "Skipping unrecognised OpenAI websocket event"
706            );
707            Ok(None)
708        }
709    }
710}
711
712fn websocket_message_to_text(message: Message) -> Result<Option<String>, CompletionError> {
713    match message {
714        Message::Text(text) => Ok(Some(text.to_string())),
715        Message::Binary(bytes) => String::from_utf8(bytes.to_vec())
716            .map(Some)
717            .map_err(|error| CompletionError::ResponseError(error.to_string())),
718        Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Ok(None),
719        Message::Close(frame) => {
720            let reason = frame
721                .map(|frame| frame.reason.to_string())
722                .filter(|reason| !reason.is_empty())
723                .unwrap_or_else(|| "without a close reason".to_string());
724            Err(CompletionError::ProviderError(format!(
725                "The OpenAI websocket connection closed {reason}"
726            )))
727        }
728    }
729}
730
731fn websocket_url(base_url: &str) -> Result<String, CompletionError> {
732    let mut url = Url::parse(base_url)?;
733    match url.scheme() {
734        "https" => {
735            url.set_scheme("wss").map_err(|_| {
736                CompletionError::ProviderError("Failed to convert https URL to wss".to_string())
737            })?;
738        }
739        "http" => {
740            url.set_scheme("ws").map_err(|_| {
741                CompletionError::ProviderError("Failed to convert http URL to ws".to_string())
742            })?;
743        }
744        scheme => {
745            return Err(CompletionError::ProviderError(format!(
746                "Unsupported base URL scheme for OpenAI websocket mode: {scheme}"
747            )));
748        }
749    }
750
751    let path = format!("{}/responses", url.path().trim_end_matches('/'));
752    url.set_path(&path);
753    Ok(url.to_string())
754}
755
756fn websocket_request(
757    url: &str,
758    headers: &http::HeaderMap,
759) -> Result<http::Request<()>, CompletionError> {
760    let mut request = url.into_client_request().map_err(|error| {
761        CompletionError::ProviderError(format!("Failed to build OpenAI websocket request: {error}"))
762    })?;
763
764    for (name, value) in headers {
765        request.headers_mut().insert(name, value.clone());
766    }
767
768    Ok(request)
769}
770
771async fn connect_websocket(
772    request: http::Request<()>,
773    connect_timeout: Option<Duration>,
774) -> Result<OpenAIWebSocket, CompletionError> {
775    if let Some(timeout_duration) = connect_timeout {
776        match tokio::time::timeout(timeout_duration, connect_async(request)).await {
777            Ok(result) => result
778                .map(|(socket, _)| socket)
779                .map_err(websocket_provider_error),
780            Err(_) => Err(connect_timeout_error(timeout_duration)),
781        }
782    } else {
783        connect_async(request)
784            .await
785            .map(|(socket, _)| socket)
786            .map_err(websocket_provider_error)
787    }
788}
789
790fn connect_timeout_error(timeout: Duration) -> CompletionError {
791    CompletionError::ProviderError(format!(
792        "Timed out connecting to the OpenAI websocket after {timeout:?}"
793    ))
794}
795
796fn event_timeout_error(timeout: Duration) -> CompletionError {
797    CompletionError::ProviderError(format!(
798        "Timed out waiting for the next OpenAI websocket event after {timeout:?}"
799    ))
800}
801
802fn websocket_provider_error(error: tungstenite::Error) -> CompletionError {
803    CompletionError::ProviderError(error.to_string())
804}
805
806#[cfg(test)]
807mod tests {
808    use super::{
809        ResponsesWebSocketCreateOptions, ResponsesWebSocketDoneEvent, ResponsesWebSocketEvent,
810        parse_server_event, terminal_response_result, websocket_url,
811    };
812    use crate::client::CompletionClient;
813    use crate::completion::CompletionModel;
814    use crate::providers::openai::responses_api::{
815        CompletionResponse, ResponseObject, ResponseStatus, ResponsesUsage,
816    };
817    use futures::{SinkExt, StreamExt};
818    use serde_json::json;
819    use std::time::Duration;
820    use tokio::net::TcpListener;
821    use tokio::time::sleep;
822    use tokio_tungstenite::{accept_async, tungstenite::Message};
823
824    fn sample_response(status: ResponseStatus) -> CompletionResponse {
825        CompletionResponse {
826            id: "resp_123".to_string(),
827            object: ResponseObject::Response,
828            created_at: 0,
829            status,
830            error: None,
831            incomplete_details: None,
832            instructions: None,
833            max_output_tokens: None,
834            model: "gpt-5.4".to_string(),
835            usage: Some(ResponsesUsage {
836                input_tokens: 1,
837                input_tokens_details: None,
838                output_tokens: 2,
839                output_tokens_details:
840                    crate::providers::openai::responses_api::OutputTokensDetails {
841                        reasoning_tokens: 0,
842                    },
843                total_tokens: 3,
844            }),
845            output: Vec::new(),
846            tools: Vec::new(),
847            additional_parameters: Default::default(),
848        }
849    }
850
851    #[test]
852    fn warmup_options_serialize_generate_false() {
853        let options = ResponsesWebSocketCreateOptions::warmup();
854        let json = serde_json::to_value(options).expect("options should serialize");
855
856        assert_eq!(json, json!({ "generate": false }));
857    }
858
859    #[test]
860    fn websocket_url_converts_https_to_wss() {
861        let url = websocket_url("https://api.openai.com/v1").expect("url should convert");
862        assert_eq!(url, "wss://api.openai.com/v1/responses");
863    }
864
865    #[test]
866    fn parse_done_event_exposes_response_id() {
867        let payload = json!({
868            "type": "response.done",
869            "response": {
870                "id": "resp_done_1",
871                "status": "completed"
872            }
873        });
874
875        let event = parse_server_event(&payload.to_string())
876            .expect("done event should deserialize")
877            .expect("done event should not be skipped");
878
879        assert!(matches!(
880            event,
881            ResponsesWebSocketEvent::Done(ResponsesWebSocketDoneEvent { .. })
882        ));
883        assert_eq!(event.response_id(), Some("resp_done_1"));
884        assert!(event.is_terminal());
885    }
886
887    #[test]
888    fn parse_response_completed_event_is_terminal() {
889        let payload = json!({
890            "type": "response.completed",
891            "sequence_number": 12,
892            "response": {
893                "id": "resp_completed_1",
894                "object": "response",
895                "created_at": 0,
896                "status": "completed",
897                "error": null,
898                "incomplete_details": null,
899                "instructions": null,
900                "max_output_tokens": null,
901                "model": "gpt-5.4",
902                "usage": null,
903                "output": [],
904                "tools": []
905            }
906        });
907
908        let event = parse_server_event(&payload.to_string())
909            .expect("response event should deserialize")
910            .expect("response event should not be skipped");
911
912        assert!(matches!(event, ResponsesWebSocketEvent::Response(_)));
913        assert!(event.is_terminal());
914        assert_eq!(event.response_id(), Some("resp_completed_1"));
915    }
916
917    #[test]
918    fn parse_live_output_item_added_event() {
919        let payload = json!({
920            "type": "response.output_item.added",
921            "item": {
922                "id": "msg_036471c3a72c147b0069ae7848d68881959773fd2d99e3d98a",
923                "type": "message",
924                "status": "in_progress",
925                "content": [],
926                "role": "assistant"
927            },
928            "output_index": 0,
929            "sequence_number": 2
930        });
931
932        let event = parse_server_event(&payload.to_string())
933            .expect("output item event should parse")
934            .expect("output item event should not be skipped");
935
936        assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
937    }
938
939    #[test]
940    fn parse_live_content_part_added_event() {
941        let payload = json!({
942            "type": "response.content_part.added",
943            "content_index": 0,
944            "item_id": "msg_036471c3a72c147b0069ae7848d68881959773fd2d99e3d98a",
945            "output_index": 0,
946            "part": {
947                "type": "output_text",
948                "annotations": [],
949                "logprobs": [],
950                "text": ""
951            },
952            "sequence_number": 3
953        });
954
955        let event = parse_server_event(&payload.to_string())
956            .expect("content part event should parse")
957            .expect("content part event should not be skipped");
958
959        assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
960    }
961
962    #[test]
963    fn parse_live_output_text_delta_event() {
964        let payload = json!({
965            "type": "response.output_text.delta",
966            "content_index": 0,
967            "delta": "Web",
968            "item_id": "msg_023af0f0a91bc2a90069ae788612e881958345bb156915ba29",
969            "logprobs": [],
970            "obfuscation": "2YYErYq7jkqqM",
971            "output_index": 0,
972            "sequence_number": 4
973        });
974
975        let event = parse_server_event(&payload.to_string())
976            .expect("output text delta event should parse")
977            .expect("output text delta event should not be skipped");
978
979        assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
980    }
981
982    #[test]
983    fn terminal_response_requires_completed_status() {
984        let completed = terminal_response_result(sample_response(ResponseStatus::Completed))
985            .expect("completed response should succeed");
986        assert_eq!(completed.id, "resp_123");
987
988        let failed = terminal_response_result(sample_response(ResponseStatus::Failed))
989            .expect_err("failed response should error");
990        assert!(failed.to_string().contains("failed response"));
991    }
992
    // A structurally known event type (`response.completed`) whose body is
    // missing the required `response` payload must poison the session: the
    // decode failure surfaces from `next_event`, later `send`s are rejected
    // as "session is closed", but an explicit `close` still succeeds.
    #[tokio::test]
    async fn malformed_known_event_rejects_reuse_and_allows_close() {
        // Bind on an ephemeral port so parallel tests don't collide.
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake OpenAI endpoint: accept one websocket, read the client's
        // `response.create`, reply with a `response.completed` frame that has
        // no `response` body, then expect the client's close frame.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = request.into_text().expect("request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );

            // Deliberately malformed: type is known but `response` is absent.
            socket
                .send(Message::text(
                    json!({
                        "type": "response.completed"
                    })
                    .to_string(),
                ))
                .await
                .expect("malformed known event should send");

            // The client is expected to tear the connection down.
            let message = socket
                .next()
                .await
                .expect("close frame should arrive")
                .expect("close frame should be valid");
            assert!(
                matches!(message, Message::Close(_)),
                "expected close frame, got {message:?}"
            );
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("hello").build())
            .await
            .expect("request should send");

        // Strict decoding of the known-but-invalid event must fail loudly.
        let error = session
            .next_event()
            .await
            .expect_err("malformed known event should fail");
        assert!(
            error.to_string().contains("StreamingCompletionChunk"),
            "expected strict decode failure, got {error}"
        );

        // After the fatal parse error the session refuses to be reused.
        let closed = session
            .send(model.completion_request("retry").build())
            .await
            .expect_err("session should close after fatal parse error");
        assert!(
            closed.to_string().contains("session is closed"),
            "expected closed-session error, got {closed}"
        );

        // Explicit close must still be a clean no-op style shutdown.
        session
            .close()
            .await
            .expect("explicit close after fatal parse error should succeed");

        server.await.expect("server task should finish");
    }
1080
    // When no server frame arrives within the configured `event_timeout`,
    // `next_event` must return a timeout error, subsequent `send`s must be
    // rejected as "session is closed", and an explicit `close` must succeed.
    #[tokio::test]
    async fn event_timeout_rejects_reuse_and_allows_close() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake endpoint: read the request, then stay silent longer (60ms)
        // than the client's 20ms event timeout before expecting its close.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = request.into_text().expect("request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );

            // Outlast the client-side event timeout without sending anything.
            sleep(Duration::from_millis(60)).await;
            let message = socket
                .next()
                .await
                .expect("close frame should arrive")
                .expect("close frame should be valid");
            assert!(
                matches!(message, Message::Close(_)),
                "expected close frame, got {message:?}"
            );
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        // Use the builder so the per-event timeout can be shrunk for the test.
        let mut session = client
            .responses_websocket_builder("gpt-4o")
            .event_timeout(Duration::from_millis(20))
            .connect()
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("hello").build())
            .await
            .expect("request should send");

        let error = session
            .next_event()
            .await
            .expect_err("next_event should time out");
        assert!(
            error
                .to_string()
                .contains("Timed out waiting for the next OpenAI websocket event"),
            "expected timeout error, got {error}"
        );

        // A timed-out session must not accept further requests.
        let closed = session
            .send(model.completion_request("retry").build())
            .await
            .expect_err("timed-out session should close");
        assert!(
            closed.to_string().contains("session is closed"),
            "expected closed-session error, got {closed}"
        );

        session
            .close()
            .await
            .expect("explicit close after timeout should succeed");

        server.await.expect("server task should finish");
    }
1163
    // Each turn's server sends `response.completed` (terminal) immediately
    // followed by a trailing `response.done` for the same id. The client
    // resolves the turn on `response.completed`, so the late `done` must be
    // filtered out at the start of the next turn rather than mistaken for
    // that turn's result.
    #[tokio::test]
    async fn late_response_done_is_ignored_on_next_turn() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake endpoint serving two sequential turns (resp_1 then resp_2).
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            for (index, response_id) in ["resp_1", "resp_2"].iter().enumerate() {
                let request = socket
                    .next()
                    .await
                    .expect("request should exist")
                    .expect("request should be valid");
                let payload = request.into_text().expect("request should be text");
                assert!(
                    payload.contains("\"type\":\"response.create\""),
                    "expected response.create payload, got {payload}"
                );

                // Same fixture each turn, only the response id differs.
                let response = sample_response(ResponseStatus::Completed);
                let response = serde_json::to_value(CompletionResponse {
                    id: (*response_id).to_string(),
                    ..response
                })
                .expect("response should serialize");

                // Terminal event for this turn.
                socket
                    .send(Message::text(
                        json!({
                            "type": "response.completed",
                            "sequence_number": (index * 2) + 1,
                            "response": response,
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("completed event should send");
                // Trailing `done` for the same id, sent after the terminal
                // event — this is the "late" frame the client must ignore.
                socket
                    .send(Message::text(
                        json!({
                            "type": "response.done",
                            "response": {
                                "id": response_id,
                                "status": "completed",
                            },
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("done event should send");
            }
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        let first = session
            .wait_for_completed_response()
            .await
            .expect("first response should complete");
        assert_eq!(first.id, "resp_1");
        assert_eq!(session.previous_response_id(), Some("resp_1"));

        // The second turn must resolve to resp_2, not resp_1's stale `done`.
        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");
        assert_eq!(session.previous_response_id(), Some("resp_2"));

        server.await.expect("server task should finish");
    }
1259
    // Same late-`done` scenario as above, but the client explicitly clears
    // its `previous_response_id` between turns. Clearing the chain id must
    // not also disable the filtering of the previous turn's trailing `done`.
    #[tokio::test]
    async fn clearing_previous_response_id_does_not_disable_late_done_filter() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake endpoint serving two turns, each followed by a late `done`.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            for response_id in ["resp_1", "resp_2"] {
                let request = socket
                    .next()
                    .await
                    .expect("request should exist")
                    .expect("request should be valid");
                let payload = request.into_text().expect("request should be text");
                assert!(
                    payload.contains("\"type\":\"response.create\""),
                    "expected response.create payload, got {payload}"
                );

                let response = sample_response(ResponseStatus::Completed);
                let response = serde_json::to_value(CompletionResponse {
                    id: response_id.to_string(),
                    ..response
                })
                .expect("response should serialize");

                // NOTE(review): sequence_number stays 1 on both turns, unlike
                // the sibling test which increments it — presumably the client
                // does not validate it; confirm.
                socket
                    .send(Message::text(
                        json!({
                            "type": "response.completed",
                            "sequence_number": 1,
                            "response": response,
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("completed event should send");
                // Late `done` for the already-resolved turn.
                socket
                    .send(Message::text(
                        json!({
                            "type": "response.done",
                            "response": {
                                "id": response_id,
                                "status": "completed",
                            },
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("done event should send");
            }
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        let first = session
            .wait_for_completed_response()
            .await
            .expect("first response should complete");
        assert_eq!(first.id, "resp_1");

        // Drop the conversation chain before the second turn.
        session.clear_previous_response_id();
        assert_eq!(session.previous_response_id(), None);

        // The stale `done` for resp_1 must still be skipped here.
        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");

        server.await.expect("server task should finish");
    }
1356
    // A failed first turn (`response.failed` then a late `done`) must not
    // contaminate the next turn: the error is surfaced, no
    // `previous_response_id` is recorded, and the second turn still resolves
    // to its own completed response.
    #[tokio::test]
    async fn failed_turn_keeps_late_done_out_of_next_request() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            // Turn 1: fail the response, then emit a trailing `done`.
            let first_request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = first_request
                .into_text()
                .expect("failed request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );

            let failed_response = serde_json::to_value(CompletionResponse {
                id: "resp_failed".to_string(),
                status: ResponseStatus::Failed,
                ..sample_response(ResponseStatus::Completed)
            })
            .expect("failed response should serialize");

            socket
                .send(Message::text(
                    json!({
                        "type": "response.failed",
                        "sequence_number": 1,
                        "response": failed_response,
                    })
                    .to_string(),
                ))
                .await
                .expect("failed event should send");
            // Late `done` for the failed turn — must be ignored by turn 2.
            socket
                .send(Message::text(
                    json!({
                        "type": "response.done",
                        "response": {
                            "id": "resp_failed",
                            "status": "failed",
                        },
                    })
                    .to_string(),
                ))
                .await
                .expect("done event should send");

            // Turn 2: a normal completed response.
            let second_request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = second_request
                .into_text()
                .expect("second request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );

            let response = sample_response(ResponseStatus::Completed);
            let response = serde_json::to_value(CompletionResponse {
                id: "resp_2".to_string(),
                ..response
            })
            .expect("response should serialize");

            socket
                .send(Message::text(
                    json!({
                        "type": "response.completed",
                        "sequence_number": 2,
                        "response": response,
                    })
                    .to_string(),
                ))
                .await
                .expect("completed event should send");
            socket
                .send(Message::text(
                    json!({
                        "type": "response.done",
                        "response": {
                            "id": "resp_2",
                            "status": "completed",
                        },
                    })
                    .to_string(),
                ))
                .await
                .expect("done event should send");
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        // The failed turn surfaces as an error and records no chain id.
        let error = session
            .wait_for_completed_response()
            .await
            .expect_err("failed response should error");
        assert!(error.to_string().contains("failed response"));
        assert_eq!(session.previous_response_id(), None);

        // The second turn resolves cleanly despite the earlier late `done`.
        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");

        server.await.expect("server task should finish");
    }
1496
    // Servers may terminate a turn with only a `response.done` carrying the
    // full response object (no preceding `response.completed`). The session
    // must treat that as the terminal result and chain the id: the second
    // request must carry `previous_response_id: "resp_1"`.
    #[tokio::test]
    async fn done_first_completed_turn_updates_previous_response_id() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake endpoint serving two turns, each answered with a single
        // `response.done` that embeds the complete response.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            for response_id in ["resp_1", "resp_2"] {
                let request = socket
                    .next()
                    .await
                    .expect("request should exist")
                    .expect("request should be valid");
                let payload = request.into_text().expect("request should be text");
                assert!(
                    payload.contains("\"type\":\"response.create\""),
                    "expected response.create payload, got {payload}"
                );

                // The second request must chain onto the first turn's id.
                if response_id == "resp_2" {
                    assert!(
                        payload.contains("\"previous_response_id\":\"resp_1\""),
                        "expected chained previous_response_id in payload, got {payload}"
                    );
                }

                let response = serde_json::to_value(CompletionResponse {
                    id: response_id.to_string(),
                    ..sample_response(ResponseStatus::Completed)
                })
                .expect("response should serialize");

                socket
                    .send(Message::text(
                        json!({
                            "type": "response.done",
                            "response": response,
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("done event should send");
            }
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        let first = session
            .wait_for_completed_response()
            .await
            .expect("first response should complete");
        assert_eq!(first.id, "resp_1");
        assert_eq!(session.previous_response_id(), Some("resp_1"));

        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");
        assert_eq!(session.previous_response_id(), Some("resp_2"));

        server.await.expect("server task should finish");
    }
1584
1585    #[tokio::test]
1586    async fn done_first_failed_turn_does_not_chain_next_request() {
1587        let listener = TcpListener::bind("127.0.0.1:0")
1588            .await
1589            .expect("listener should bind");
1590        let address = listener.local_addr().expect("listener should have address");
1591
1592        let server = tokio::spawn(async move {
1593            let (stream, _) = listener.accept().await.expect("server should accept");
1594            let mut socket = accept_async(stream)
1595                .await
1596                .expect("server should upgrade websocket");
1597
1598            let first_request = socket
1599                .next()
1600                .await
1601                .expect("request should exist")
1602                .expect("request should be valid");
1603            let payload = first_request
1604                .into_text()
1605                .expect("first request should be text");
1606            assert!(
1607                payload.contains("\"type\":\"response.create\""),
1608                "expected response.create payload, got {payload}"
1609            );
1610            assert!(
1611                !payload.contains("\"previous_response_id\""),
1612                "did not expect previous_response_id in first payload, got {payload}"
1613            );
1614
1615            let failed_response = serde_json::to_value(CompletionResponse {
1616                id: "resp_failed".to_string(),
1617                status: ResponseStatus::Failed,
1618                ..sample_response(ResponseStatus::Completed)
1619            })
1620            .expect("failed response should serialize");
1621
1622            socket
1623                .send(Message::text(
1624                    json!({
1625                        "type": "response.done",
1626                        "response": failed_response,
1627                    })
1628                    .to_string(),
1629                ))
1630                .await
1631                .expect("done event should send");
1632
1633            let second_request = socket
1634                .next()
1635                .await
1636                .expect("request should exist")
1637                .expect("request should be valid");
1638            let payload = second_request
1639                .into_text()
1640                .expect("second request should be text");
1641            assert!(
1642                payload.contains("\"type\":\"response.create\""),
1643                "expected response.create payload, got {payload}"
1644            );
1645            assert!(
1646                !payload.contains("\"previous_response_id\""),
1647                "did not expect chained previous_response_id in payload, got {payload}"
1648            );
1649
1650            let response = serde_json::to_value(CompletionResponse {
1651                id: "resp_2".to_string(),
1652                ..sample_response(ResponseStatus::Completed)
1653            })
1654            .expect("response should serialize");
1655
1656            socket
1657                .send(Message::text(
1658                    json!({
1659                        "type": "response.done",
1660                        "response": response,
1661                    })
1662                    .to_string(),
1663                ))
1664                .await
1665                .expect("done event should send");
1666        });
1667
1668        let base_url = format!("http://{address}/v1");
1669        let client = crate::providers::openai::Client::builder()
1670            .api_key("test-key")
1671            .base_url(&base_url)
1672            .build()
1673            .expect("client should build");
1674        let model = client.completion_model("gpt-4o");
1675        let mut session = client
1676            .responses_websocket("gpt-4o")
1677            .await
1678            .expect("session should connect");
1679
1680        session
1681            .send(model.completion_request("first").build())
1682            .await
1683            .expect("first request should send");
1684        let error = session
1685            .wait_for_completed_response()
1686            .await
1687            .expect_err("failed response should error");
1688        assert!(error.to_string().contains("failed response"));
1689        assert_eq!(session.previous_response_id(), None);
1690
1691        session
1692            .send(model.completion_request("second").build())
1693            .await
1694            .expect("second request should send");
1695        let second = session
1696            .wait_for_completed_response()
1697            .await
1698            .expect("second response should complete");
1699        assert_eq!(second.id, "resp_2");
1700        assert_eq!(session.previous_response_id(), Some("resp_2"));
1701
1702        server.await.expect("server task should finish");
1703    }
1704
1705    #[test]
1706    fn websocket_url_converts_http_to_ws() {
1707        let url = websocket_url("http://localhost:8080/v1").expect("url should convert");
1708        assert_eq!(url, "ws://localhost:8080/v1/responses");
1709    }
1710
1711    #[test]
1712    fn websocket_url_rejects_unsupported_scheme() {
1713        let result = websocket_url("ftp://example.com/v1");
1714        assert!(result.is_err());
1715    }
1716
1717    #[test]
1718    fn websocket_url_trims_trailing_slash() {
1719        let url = websocket_url("https://api.openai.com/v1/").expect("url should convert");
1720        assert_eq!(url, "wss://api.openai.com/v1/responses");
1721    }
1722
1723    #[test]
1724    fn unknown_event_type_is_skipped() {
1725        let payload = json!({
1726            "type": "response.some_future_event",
1727            "data": "hello"
1728        });
1729
1730        let result =
1731            parse_server_event(&payload.to_string()).expect("unknown event should not error");
1732        assert!(result.is_none(), "unknown event should be skipped");
1733    }
1734
1735    #[test]
1736    fn malformed_known_event_returns_error() {
1737        let payload = json!({
1738            "type": "response.completed"
1739        });
1740
1741        let error = parse_server_event(&payload.to_string())
1742            .expect_err("malformed known event should error");
1743        assert!(
1744            error.to_string().contains("StreamingCompletionChunk"),
1745            "expected strict decode failure, got {error}"
1746        );
1747    }
1748
1749    #[tokio::test]
1750    async fn close_is_idempotent() {
1751        let listener = TcpListener::bind("127.0.0.1:0")
1752            .await
1753            .expect("listener should bind");
1754        let address = listener.local_addr().expect("listener should have address");
1755
1756        let server = tokio::spawn(async move {
1757            let (stream, _) = listener.accept().await.expect("server should accept");
1758            let mut socket = accept_async(stream)
1759                .await
1760                .expect("server should upgrade websocket");
1761
1762            let message = socket
1763                .next()
1764                .await
1765                .expect("close frame should arrive")
1766                .expect("close frame should be valid");
1767            assert!(
1768                matches!(message, Message::Close(_)),
1769                "expected close frame, got {message:?}"
1770            );
1771        });
1772
1773        let base_url = format!("http://{address}/v1");
1774        let client = crate::providers::openai::Client::builder()
1775            .api_key("test-key")
1776            .base_url(&base_url)
1777            .build()
1778            .expect("client should build");
1779        let mut session = client
1780            .responses_websocket("gpt-4o")
1781            .await
1782            .expect("session should connect");
1783
1784        session.close().await.expect("first close should succeed");
1785        session.close().await.expect("second close should succeed");
1786
1787        server.await.expect("server task should finish");
1788    }
1789
1790    #[tokio::test]
1791    async fn send_while_in_flight_returns_error() {
1792        let listener = TcpListener::bind("127.0.0.1:0")
1793            .await
1794            .expect("listener should bind");
1795        let address = listener.local_addr().expect("listener should have address");
1796
1797        let server = tokio::spawn(async move {
1798            let (stream, _) = listener.accept().await.expect("server should accept");
1799            let mut socket = accept_async(stream)
1800                .await
1801                .expect("server should upgrade websocket");
1802
1803            // Read the first request but don't respond — keep it in-flight
1804            let _request = socket
1805                .next()
1806                .await
1807                .expect("request should exist")
1808                .expect("request should be valid");
1809
1810            // Wait for client to finish its test
1811            sleep(Duration::from_millis(100)).await;
1812            let _ = socket.close(None).await;
1813        });
1814
1815        let base_url = format!("http://{address}/v1");
1816        let client = crate::providers::openai::Client::builder()
1817            .api_key("test-key")
1818            .base_url(&base_url)
1819            .build()
1820            .expect("client should build");
1821        let model = client.completion_model("gpt-4o");
1822        let mut session = client
1823            .responses_websocket("gpt-4o")
1824            .await
1825            .expect("session should connect");
1826
1827        session
1828            .send(model.completion_request("first").build())
1829            .await
1830            .expect("first request should send");
1831
1832        let error = session
1833            .send(model.completion_request("second").build())
1834            .await
1835            .expect_err("second send while in-flight should error");
1836        assert!(
1837            error.to_string().contains("already in flight"),
1838            "expected in-flight error, got {error}"
1839        );
1840
1841        server.await.expect("server task should finish");
1842    }
1843
1844    #[tokio::test]
1845    async fn send_after_close_returns_error() {
1846        let listener = TcpListener::bind("127.0.0.1:0")
1847            .await
1848            .expect("listener should bind");
1849        let address = listener.local_addr().expect("listener should have address");
1850
1851        let server = tokio::spawn(async move {
1852            let (stream, _) = listener.accept().await.expect("server should accept");
1853            let _socket = accept_async(stream)
1854                .await
1855                .expect("server should upgrade websocket");
1856            sleep(Duration::from_millis(100)).await;
1857        });
1858
1859        let base_url = format!("http://{address}/v1");
1860        let client = crate::providers::openai::Client::builder()
1861            .api_key("test-key")
1862            .base_url(&base_url)
1863            .build()
1864            .expect("client should build");
1865        let model = client.completion_model("gpt-4o");
1866        let mut session = client
1867            .responses_websocket("gpt-4o")
1868            .await
1869            .expect("session should connect");
1870
1871        session.close().await.expect("close should succeed");
1872
1873        let error = session
1874            .send(model.completion_request("after close").build())
1875            .await
1876            .expect_err("send after close should error");
1877        assert!(
1878            error.to_string().contains("session is closed"),
1879            "expected closed-session error, got {error}"
1880        );
1881
1882        server.await.expect("server task should finish");
1883    }
1884
1885    #[tokio::test]
1886    async fn next_event_without_send_returns_error() {
1887        let listener = TcpListener::bind("127.0.0.1:0")
1888            .await
1889            .expect("listener should bind");
1890        let address = listener.local_addr().expect("listener should have address");
1891
1892        let server = tokio::spawn(async move {
1893            let (stream, _) = listener.accept().await.expect("server should accept");
1894            let _socket = accept_async(stream)
1895                .await
1896                .expect("server should upgrade websocket");
1897            sleep(Duration::from_millis(100)).await;
1898        });
1899
1900        let base_url = format!("http://{address}/v1");
1901        let client = crate::providers::openai::Client::builder()
1902            .api_key("test-key")
1903            .base_url(&base_url)
1904            .build()
1905            .expect("client should build");
1906        let mut session = client
1907            .responses_websocket("gpt-4o")
1908            .await
1909            .expect("session should connect");
1910
1911        let error = session
1912            .next_event()
1913            .await
1914            .expect_err("next_event without send should error");
1915        assert!(
1916            error
1917                .to_string()
1918                .contains("No OpenAI websocket response is currently in flight"),
1919            "expected not-in-flight error, got {error}"
1920        );
1921
1922        server.await.expect("server task should finish");
1923    }
1924
1925    #[tokio::test]
1926    async fn unknown_event_is_skipped_during_session() {
1927        let listener = TcpListener::bind("127.0.0.1:0")
1928            .await
1929            .expect("listener should bind");
1930        let address = listener.local_addr().expect("listener should have address");
1931
1932        let server = tokio::spawn(async move {
1933            let (stream, _) = listener.accept().await.expect("server should accept");
1934            let mut socket = accept_async(stream)
1935                .await
1936                .expect("server should upgrade websocket");
1937
1938            let _request = socket
1939                .next()
1940                .await
1941                .expect("request should exist")
1942                .expect("request should be valid");
1943
1944            // Send an unknown event type first
1945            socket
1946                .send(Message::text(
1947                    json!({
1948                        "type": "response.some_future_event",
1949                        "data": "should be skipped"
1950                    })
1951                    .to_string(),
1952                ))
1953                .await
1954                .expect("unknown event should send");
1955
1956            // Then send the real completed response
1957            let response = serde_json::to_value(CompletionResponse {
1958                id: "resp_after_unknown".to_string(),
1959                ..sample_response(ResponseStatus::Completed)
1960            })
1961            .expect("response should serialize");
1962
1963            socket
1964                .send(Message::text(
1965                    json!({
1966                        "type": "response.completed",
1967                        "sequence_number": 1,
1968                        "response": response,
1969                    })
1970                    .to_string(),
1971                ))
1972                .await
1973                .expect("completed event should send");
1974        });
1975
1976        let base_url = format!("http://{address}/v1");
1977        let client = crate::providers::openai::Client::builder()
1978            .api_key("test-key")
1979            .base_url(&base_url)
1980            .build()
1981            .expect("client should build");
1982        let model = client.completion_model("gpt-4o");
1983        let mut session = client
1984            .responses_websocket("gpt-4o")
1985            .await
1986            .expect("session should connect");
1987
1988        session
1989            .send(model.completion_request("hello").build())
1990            .await
1991            .expect("send should succeed");
1992        let response = session
1993            .wait_for_completed_response()
1994            .await
1995            .expect("response should complete despite unknown event");
1996        assert_eq!(response.id, "resp_after_unknown");
1997
1998        server.await.expect("server task should finish");
1999    }
2000}