1use crate::completion::{self, CompletionError};
8use crate::http_client::HttpClientExt;
9use crate::providers::openai::responses_api::streaming::{
10 ItemChunk, ResponseChunk, ResponseChunkKind, StreamingCompletionChunk,
11};
12use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
13use futures::{SinkExt, StreamExt};
14use serde::{Deserialize, Serialize};
15use serde_json::{Map, Value};
16use std::time::Duration;
17use tokio::net::TcpStream;
18use tokio_tungstenite::{
19 MaybeTlsStream, WebSocketStream, connect_async,
20 tungstenite::{self, Message, client::IntoClientRequest},
21};
22use tracing::Level;
23use url::Url;
24
25use super::{CompletionResponse, ResponseError, ResponseStatus, ResponsesCompletionModel};
26
27type OpenAIWebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
28const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
29
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
32pub struct ResponsesWebSocketCreateOptions {
33 #[serde(skip_serializing_if = "Option::is_none")]
37 pub generate: Option<bool>,
38}
39
40impl ResponsesWebSocketCreateOptions {
41 #[must_use]
43 pub fn warmup() -> Self {
44 Self {
45 generate: Some(false),
46 }
47 }
48}
49
50#[derive(Debug, Clone, Serialize)]
51struct ResponsesWebSocketClientEvent {
52 #[serde(rename = "type")]
53 kind: ResponsesWebSocketClientEventKind,
54 #[serde(flatten)]
55 request: super::CompletionRequest,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 generate: Option<bool>,
58}
59
60#[derive(Debug, Clone, Serialize)]
61enum ResponsesWebSocketClientEventKind {
62 #[serde(rename = "response.create")]
63 ResponseCreate,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ResponsesWebSocketErrorEvent {
69 #[serde(rename = "type")]
71 pub kind: ResponsesWebSocketErrorEventKind,
72 pub error: ResponsesWebSocketErrorPayload,
74}
75
76impl std::fmt::Display for ResponsesWebSocketErrorEvent {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 self.error.fmt(f)
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub enum ResponsesWebSocketErrorEventKind {
85 #[serde(rename = "error")]
86 Error,
87}
88
89#[derive(Debug, Clone, Default, Serialize, Deserialize)]
91pub struct ResponsesWebSocketErrorPayload {
92 #[serde(skip_serializing_if = "Option::is_none")]
94 pub code: Option<String>,
95 #[serde(skip_serializing_if = "Option::is_none")]
97 pub message: Option<String>,
98 #[serde(flatten, default)]
100 pub extra: Map<String, Value>,
101}
102
103impl std::fmt::Display for ResponsesWebSocketErrorPayload {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match (&self.code, &self.message) {
106 (Some(code), Some(message)) => write!(f, "{code}: {message}"),
107 (None, Some(message)) => f.write_str(message),
108 (Some(code), None) => f.write_str(code),
109 (None, None) => f.write_str("OpenAI websocket error"),
110 }
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ResponsesWebSocketDoneEvent {
117 #[serde(rename = "type")]
119 pub kind: ResponsesWebSocketDoneEventKind,
120 pub response: Value,
122}
123
124impl ResponsesWebSocketDoneEvent {
125 #[must_use]
127 pub fn response_id(&self) -> Option<&str> {
128 self.response.get("id").and_then(Value::as_str)
129 }
130
131 fn status(&self) -> Option<ResponseStatus> {
132 self.response
133 .get("status")
134 .cloned()
135 .and_then(|status| serde_json::from_value(status).ok())
136 }
137
138 fn as_completion_response(&self) -> Option<CompletionResponse> {
139 serde_json::from_value(self.response.clone()).ok()
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub enum ResponsesWebSocketDoneEventKind {
146 #[serde(rename = "response.done")]
147 ResponseDone,
148}
149
150#[derive(Debug, Clone)]
152pub enum ResponsesWebSocketEvent {
153 Response(Box<ResponseChunk>),
155 Item(ItemChunk),
157 Error(ResponsesWebSocketErrorEvent),
159 Done(ResponsesWebSocketDoneEvent),
161}
162
163impl ResponsesWebSocketEvent {
164 #[must_use]
166 pub fn response_id(&self) -> Option<&str> {
167 match self {
168 Self::Response(chunk) => Some(&chunk.response.id),
169 Self::Done(done) => done.response_id(),
170 Self::Item(_) | Self::Error(_) => None,
171 }
172 }
173
174 #[must_use]
176 pub fn is_terminal(&self) -> bool {
177 match self {
178 Self::Response(chunk) => matches!(
179 chunk.kind,
180 ResponseChunkKind::ResponseCompleted
181 | ResponseChunkKind::ResponseFailed
182 | ResponseChunkKind::ResponseIncomplete
183 ),
184 Self::Error(_) | Self::Done(_) => true,
185 Self::Item(_) => false,
186 }
187 }
188}
189
190pub struct ResponsesWebSocketSessionBuilder<H = reqwest::Client> {
195 model: ResponsesCompletionModel<H>,
196 connect_timeout: Option<Duration>,
197 event_timeout: Option<Duration>,
198}
199
200impl<H> ResponsesWebSocketSessionBuilder<H> {
201 pub(crate) fn new(model: ResponsesCompletionModel<H>) -> Self {
202 Self {
203 model,
204 connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
205 event_timeout: None,
206 }
207 }
208
209 #[must_use]
211 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
212 self.connect_timeout = Some(timeout);
213 self
214 }
215
216 #[must_use]
218 pub fn without_connect_timeout(mut self) -> Self {
219 self.connect_timeout = None;
220 self
221 }
222
223 #[must_use]
225 pub fn event_timeout(mut self, timeout: Duration) -> Self {
226 self.event_timeout = Some(timeout);
227 self
228 }
229
230 #[must_use]
232 pub fn without_event_timeout(mut self) -> Self {
233 self.event_timeout = None;
234 self
235 }
236}
237
238impl<H> ResponsesWebSocketSessionBuilder<H>
239where
240 H: HttpClientExt
241 + Clone
242 + std::fmt::Debug
243 + Default
244 + WasmCompatSend
245 + WasmCompatSync
246 + 'static,
247{
248 pub async fn connect(self) -> Result<ResponsesWebSocketSession<H>, CompletionError> {
250 ResponsesWebSocketSession::connect_with_timeouts(
251 self.model,
252 self.connect_timeout,
253 self.event_timeout,
254 )
255 .await
256 }
257}
258
259pub struct ResponsesWebSocketSession<H = reqwest::Client> {
268 model: ResponsesCompletionModel<H>,
269 previous_response_id: Option<String>,
270 pending_done_response_id: Option<String>,
271 socket: OpenAIWebSocket,
272 in_flight: bool,
273 event_timeout: Option<Duration>,
274 closed: bool,
275 failed: bool,
276}
277
278impl<H> ResponsesWebSocketSession<H>
279where
280 H: HttpClientExt
281 + Clone
282 + std::fmt::Debug
283 + Default
284 + WasmCompatSend
285 + WasmCompatSync
286 + 'static,
287{
288 async fn connect_with_timeouts(
289 model: ResponsesCompletionModel<H>,
290 connect_timeout: Option<Duration>,
291 event_timeout: Option<Duration>,
292 ) -> Result<Self, CompletionError> {
293 let url = websocket_url(model.client.base_url())?;
294 let request = websocket_request(&url, model.client.headers())?;
295 let socket = connect_websocket(request, connect_timeout).await?;
296
297 Ok(Self {
298 model,
299 previous_response_id: None,
300 pending_done_response_id: None,
301 socket,
302 in_flight: false,
303 event_timeout,
304 closed: false,
305 failed: false,
306 })
307 }
308
309 #[must_use]
311 pub fn previous_response_id(&self) -> Option<&str> {
312 self.previous_response_id.as_deref()
313 }
314
315 pub fn clear_previous_response_id(&mut self) {
317 self.previous_response_id = None;
318 }
319
320 pub async fn send(
322 &mut self,
323 completion_request: crate::completion::CompletionRequest,
324 ) -> Result<(), CompletionError> {
325 self.send_with_options(
326 completion_request,
327 ResponsesWebSocketCreateOptions::default(),
328 )
329 .await
330 }
331
332 pub async fn send_with_options(
334 &mut self,
335 completion_request: crate::completion::CompletionRequest,
336 options: ResponsesWebSocketCreateOptions,
337 ) -> Result<(), CompletionError> {
338 self.ensure_open()?;
339
340 if self.in_flight {
341 return Err(CompletionError::ProviderError(
342 "An OpenAI websocket response is already in flight on this session".to_string(),
343 ));
344 }
345
346 let payload = ResponsesWebSocketClientEvent {
347 kind: ResponsesWebSocketClientEventKind::ResponseCreate,
348 request: self.prepare_request(completion_request)?,
349 generate: options.generate,
350 };
351
352 if tracing::enabled!(Level::TRACE) {
353 tracing::trace!(
354 target: "rig::completions",
355 "OpenAI websocket request: {}",
356 serde_json::to_string_pretty(&payload)?
357 );
358 }
359
360 let payload = serde_json::to_string(&payload)?;
361
362 if let Err(error) = self.socket.send(Message::text(payload)).await {
363 return Err(self.fail_session(websocket_provider_error(error)));
364 }
365 self.in_flight = true;
366
367 Ok(())
368 }
369
370 pub async fn next_event(&mut self) -> Result<ResponsesWebSocketEvent, CompletionError> {
372 self.ensure_open()?;
373
374 if !self.in_flight {
375 return Err(CompletionError::ProviderError(
376 "No OpenAI websocket response is currently in flight on this session".to_string(),
377 ));
378 }
379
380 loop {
381 let message = match self.read_next_message().await {
382 Ok(message) => message,
383 Err(error) => return Err(error),
384 };
385
386 let Some(message) = message else {
387 self.mark_closed();
388 return Err(CompletionError::ProviderError(
389 "The OpenAI websocket connection closed before the turn finished".to_string(),
390 ));
391 };
392
393 let message = match message {
394 Ok(message) => message,
395 Err(error) => return Err(self.fail_session(websocket_provider_error(error))),
396 };
397 let payload = match websocket_message_to_text(message) {
398 Ok(Some(payload)) => payload,
399 Ok(None) => continue,
400 Err(error) => return Err(self.fail_session(error)),
401 };
402 let event = match parse_server_event(&payload) {
403 Ok(Some(event)) => event,
404 Ok(None) => continue,
405 Err(error) => return Err(self.fail_session(error)),
406 };
407 if let ResponsesWebSocketEvent::Done(done) = &event {
408 if self.pending_done_response_id.as_deref() == done.response_id() {
411 self.pending_done_response_id = None;
412 continue;
413 }
414 }
415 self.update_state_for_event(&event);
416 return Ok(event);
417 }
418 }
419
420 pub async fn warmup(
422 &mut self,
423 completion_request: crate::completion::CompletionRequest,
424 ) -> Result<String, CompletionError> {
425 self.send_with_options(
426 completion_request,
427 ResponsesWebSocketCreateOptions::warmup(),
428 )
429 .await?;
430 let response = self.wait_for_completed_response().await?;
431 Ok(response.id)
432 }
433
434 pub async fn completion(
436 &mut self,
437 completion_request: crate::completion::CompletionRequest,
438 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
439 self.send(completion_request).await?;
440 let response = self.wait_for_completed_response().await?;
441 response.try_into()
442 }
443
444 pub async fn close(&mut self) -> Result<(), CompletionError> {
449 if self.closed {
450 return Ok(());
451 }
452
453 let result = self
454 .socket
455 .close(None)
456 .await
457 .map_err(websocket_provider_error);
458 self.mark_closed();
459 result
460 }
461
462 fn prepare_request(
463 &self,
464 completion_request: crate::completion::CompletionRequest,
465 ) -> Result<super::CompletionRequest, CompletionError> {
466 let mut request = self.model.create_completion_request(completion_request)?;
467
468 request.stream = None;
471 request.additional_parameters.background = None;
472
473 if request.additional_parameters.previous_response_id.is_none() {
474 request.additional_parameters.previous_response_id = self.previous_response_id.clone();
475 }
476
477 Ok(request)
478 }
479
480 async fn wait_for_completed_response(&mut self) -> Result<CompletionResponse, CompletionError> {
481 loop {
482 match self.next_event().await? {
483 ResponsesWebSocketEvent::Response(chunk) => {
484 if matches!(
485 chunk.kind,
486 ResponseChunkKind::ResponseCompleted
487 | ResponseChunkKind::ResponseFailed
488 | ResponseChunkKind::ResponseIncomplete
489 ) {
490 return terminal_response_result(chunk.response);
491 }
492 }
493 ResponsesWebSocketEvent::Done(done) => {
494 if let Some(response) = done.as_completion_response() {
495 return terminal_response_result(response);
496 }
497
498 let message = if let Some(response_id) = done.response_id() {
499 format!(
500 "OpenAI websocket turn ended with response.done before a terminal response body was available (response_id={response_id})"
501 )
502 } else {
503 "OpenAI websocket turn ended with response.done before a terminal response body was available"
504 .to_string()
505 };
506
507 return Err(CompletionError::ProviderError(message));
508 }
509 ResponsesWebSocketEvent::Error(error) => {
510 return Err(CompletionError::ProviderError(error.to_string()));
511 }
512 ResponsesWebSocketEvent::Item(_) => {}
513 }
514 }
515 }
516
517 fn update_state_for_event(&mut self, event: &ResponsesWebSocketEvent) {
518 match event {
519 ResponsesWebSocketEvent::Response(chunk) => match chunk.kind {
520 ResponseChunkKind::ResponseCompleted => {
521 let response_id = chunk.response.id.clone();
522 self.previous_response_id = Some(response_id.clone());
523 self.pending_done_response_id = Some(response_id);
524 self.in_flight = false;
525 }
526 ResponseChunkKind::ResponseFailed | ResponseChunkKind::ResponseIncomplete => {
527 self.pending_done_response_id = Some(chunk.response.id.clone());
528 self.previous_response_id = None;
529 self.in_flight = false;
530 }
531 ResponseChunkKind::ResponseCreated | ResponseChunkKind::ResponseInProgress => {}
532 },
533 ResponsesWebSocketEvent::Done(done) => {
534 match done.status() {
535 Some(ResponseStatus::Completed) => {
536 if let Some(response_id) = done.response_id() {
537 self.previous_response_id = Some(response_id.to_string());
538 }
539 }
540 Some(ResponseStatus::Failed)
541 | Some(ResponseStatus::Incomplete)
542 | Some(ResponseStatus::Cancelled) => {
543 self.previous_response_id = None;
544 }
545 Some(ResponseStatus::InProgress | ResponseStatus::Queued) | None => {}
546 }
547 self.pending_done_response_id = None;
548 self.in_flight = false;
549 }
550 ResponsesWebSocketEvent::Error(_) => {
551 self.previous_response_id = None;
552 self.pending_done_response_id = None;
553 self.in_flight = false;
554 }
555 ResponsesWebSocketEvent::Item(_) => {}
556 }
557 }
558
559 fn abort_turn(&mut self) {
560 self.previous_response_id = None;
561 self.pending_done_response_id = None;
562 self.in_flight = false;
563 }
564
565 fn mark_closed(&mut self) {
566 self.abort_turn();
567 self.closed = true;
568 self.failed = false;
569 }
570
571 fn mark_failed(&mut self) {
572 self.abort_turn();
573 self.failed = true;
574 }
575
576 fn ensure_open(&self) -> Result<(), CompletionError> {
577 if self.closed || self.failed {
578 return Err(CompletionError::ProviderError(
579 "The OpenAI websocket session is closed".to_string(),
580 ));
581 }
582
583 Ok(())
584 }
585
586 fn fail_session(&mut self, error: CompletionError) -> CompletionError {
587 self.mark_failed();
588 error
589 }
590
591 async fn read_next_message(
592 &mut self,
593 ) -> Result<Option<Result<Message, tungstenite::Error>>, CompletionError> {
594 if let Some(timeout_duration) = self.event_timeout {
595 match tokio::time::timeout(timeout_duration, self.socket.next()).await {
596 Ok(message) => Ok(message),
597 Err(_) => Err(self.fail_session(event_timeout_error(timeout_duration))),
598 }
599 } else {
600 Ok(self.socket.next().await)
601 }
602 }
603}
604
605impl<H> Drop for ResponsesWebSocketSession<H> {
606 fn drop(&mut self) {
607 if !self.closed {
608 tracing::warn!(
609 target: "rig::completions",
610 in_flight = self.in_flight,
611 "Dropping an OpenAI websocket session without calling close(); the connection will end without a close handshake"
612 );
613 }
614 }
615}
616
617fn terminal_response_result(
618 response: CompletionResponse,
619) -> Result<CompletionResponse, CompletionError> {
620 match response.status {
621 ResponseStatus::Completed => Ok(response),
622 ResponseStatus::Failed => Err(CompletionError::ProviderError(response_error_message(
623 response.error.as_ref(),
624 "failed response",
625 ))),
626 ResponseStatus::Incomplete => {
627 let reason = response
628 .incomplete_details
629 .as_ref()
630 .map(|details| details.reason.as_str())
631 .unwrap_or("unknown reason");
632 Err(CompletionError::ProviderError(format!(
633 "OpenAI websocket response was incomplete: {reason}"
634 )))
635 }
636 status => Err(CompletionError::ProviderError(format!(
637 "OpenAI websocket response ended with status {:?}",
638 status
639 ))),
640 }
641}
642
643fn response_error_message(error: Option<&ResponseError>, fallback: &str) -> String {
644 if let Some(error) = error {
645 if error.code.is_empty() {
646 error.message.clone()
647 } else {
648 format!("{}: {}", error.code, error.message)
649 }
650 } else {
651 format!("OpenAI websocket returned a {fallback}")
652 }
653}
654
/// Whether `kind` is a streaming event type that should be decoded as a
/// `StreamingCompletionChunk`.
///
/// `error` and `response.done` are handled separately by `parse_server_event`, and are
/// deliberately absent from this list.
fn is_known_streaming_event(kind: &str) -> bool {
    const KNOWN_EVENT_TYPES: &[&str] = &[
        // Response lifecycle.
        "response.created",
        "response.in_progress",
        "response.completed",
        "response.failed",
        "response.incomplete",
        // Output items and content parts.
        "response.output_item.added",
        "response.output_item.done",
        "response.content_part.added",
        "response.content_part.done",
        // Text, refusal and tool-call argument deltas.
        "response.output_text.delta",
        "response.output_text.done",
        "response.refusal.delta",
        "response.refusal.done",
        "response.function_call_arguments.delta",
        "response.function_call_arguments.done",
        // Reasoning summaries.
        "response.reasoning_summary_part.added",
        "response.reasoning_summary_part.done",
        "response.reasoning_summary_text.delta",
        "response.reasoning_summary_text.done",
    ];

    KNOWN_EVENT_TYPES.contains(&kind)
}
679
680fn parse_server_event(payload: &str) -> Result<Option<ResponsesWebSocketEvent>, CompletionError> {
681 #[derive(Deserialize)]
682 struct EventType {
683 #[serde(rename = "type")]
684 kind: String,
685 }
686
687 let event_type = serde_json::from_str::<EventType>(payload)?;
688 match event_type.kind.as_str() {
689 "error" => serde_json::from_str(payload)
690 .map(|e| Some(ResponsesWebSocketEvent::Error(e)))
691 .map_err(CompletionError::from),
692 "response.done" => serde_json::from_str(payload)
693 .map(|d| Some(ResponsesWebSocketEvent::Done(d)))
694 .map_err(CompletionError::from),
695 kind if is_known_streaming_event(kind) => match serde_json::from_str(payload)? {
696 StreamingCompletionChunk::Response(response) => {
697 Ok(Some(ResponsesWebSocketEvent::Response(response)))
698 }
699 StreamingCompletionChunk::Delta(item) => Ok(Some(ResponsesWebSocketEvent::Item(item))),
700 },
701 _ => {
702 tracing::debug!(
703 target: "rig::completions",
704 event_type = event_type.kind.as_str(),
705 "Skipping unrecognised OpenAI websocket event"
706 );
707 Ok(None)
708 }
709 }
710}
711
712fn websocket_message_to_text(message: Message) -> Result<Option<String>, CompletionError> {
713 match message {
714 Message::Text(text) => Ok(Some(text.to_string())),
715 Message::Binary(bytes) => String::from_utf8(bytes.to_vec())
716 .map(Some)
717 .map_err(|error| CompletionError::ResponseError(error.to_string())),
718 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Ok(None),
719 Message::Close(frame) => {
720 let reason = frame
721 .map(|frame| frame.reason.to_string())
722 .filter(|reason| !reason.is_empty())
723 .unwrap_or_else(|| "without a close reason".to_string());
724 Err(CompletionError::ProviderError(format!(
725 "The OpenAI websocket connection closed {reason}"
726 )))
727 }
728 }
729}
730
731fn websocket_url(base_url: &str) -> Result<String, CompletionError> {
732 let mut url = Url::parse(base_url)?;
733 match url.scheme() {
734 "https" => {
735 url.set_scheme("wss").map_err(|_| {
736 CompletionError::ProviderError("Failed to convert https URL to wss".to_string())
737 })?;
738 }
739 "http" => {
740 url.set_scheme("ws").map_err(|_| {
741 CompletionError::ProviderError("Failed to convert http URL to ws".to_string())
742 })?;
743 }
744 scheme => {
745 return Err(CompletionError::ProviderError(format!(
746 "Unsupported base URL scheme for OpenAI websocket mode: {scheme}"
747 )));
748 }
749 }
750
751 let path = format!("{}/responses", url.path().trim_end_matches('/'));
752 url.set_path(&path);
753 Ok(url.to_string())
754}
755
756fn websocket_request(
757 url: &str,
758 headers: &http::HeaderMap,
759) -> Result<http::Request<()>, CompletionError> {
760 let mut request = url.into_client_request().map_err(|error| {
761 CompletionError::ProviderError(format!("Failed to build OpenAI websocket request: {error}"))
762 })?;
763
764 for (name, value) in headers {
765 request.headers_mut().insert(name, value.clone());
766 }
767
768 Ok(request)
769}
770
771async fn connect_websocket(
772 request: http::Request<()>,
773 connect_timeout: Option<Duration>,
774) -> Result<OpenAIWebSocket, CompletionError> {
775 if let Some(timeout_duration) = connect_timeout {
776 match tokio::time::timeout(timeout_duration, connect_async(request)).await {
777 Ok(result) => result
778 .map(|(socket, _)| socket)
779 .map_err(websocket_provider_error),
780 Err(_) => Err(connect_timeout_error(timeout_duration)),
781 }
782 } else {
783 connect_async(request)
784 .await
785 .map(|(socket, _)| socket)
786 .map_err(websocket_provider_error)
787 }
788}
789
790fn connect_timeout_error(timeout: Duration) -> CompletionError {
791 CompletionError::ProviderError(format!(
792 "Timed out connecting to the OpenAI websocket after {timeout:?}"
793 ))
794}
795
796fn event_timeout_error(timeout: Duration) -> CompletionError {
797 CompletionError::ProviderError(format!(
798 "Timed out waiting for the next OpenAI websocket event after {timeout:?}"
799 ))
800}
801
802fn websocket_provider_error(error: tungstenite::Error) -> CompletionError {
803 CompletionError::ProviderError(error.to_string())
804}
805
806#[cfg(test)]
807mod tests {
808 use super::{
809 ResponsesWebSocketCreateOptions, ResponsesWebSocketDoneEvent, ResponsesWebSocketEvent,
810 parse_server_event, terminal_response_result, websocket_url,
811 };
812 use crate::client::CompletionClient;
813 use crate::completion::CompletionModel;
814 use crate::providers::openai::responses_api::{
815 CompletionResponse, ResponseObject, ResponseStatus, ResponsesUsage,
816 };
817 use futures::{SinkExt, StreamExt};
818 use serde_json::json;
819 use std::time::Duration;
820 use tokio::net::TcpListener;
821 use tokio::time::sleep;
822 use tokio_tungstenite::{accept_async, tungstenite::Message};
823
824 fn sample_response(status: ResponseStatus) -> CompletionResponse {
825 CompletionResponse {
826 id: "resp_123".to_string(),
827 object: ResponseObject::Response,
828 created_at: 0,
829 status,
830 error: None,
831 incomplete_details: None,
832 instructions: None,
833 max_output_tokens: None,
834 model: "gpt-5.4".to_string(),
835 usage: Some(ResponsesUsage {
836 input_tokens: 1,
837 input_tokens_details: None,
838 output_tokens: 2,
839 output_tokens_details: Some(
840 crate::providers::openai::responses_api::OutputTokensDetails {
841 reasoning_tokens: 0,
842 },
843 ),
844 total_tokens: 3,
845 }),
846 output: Vec::new(),
847 tools: Vec::new(),
848 additional_parameters: Default::default(),
849 }
850 }
851
852 #[test]
853 fn warmup_options_serialize_generate_false() {
854 let options = ResponsesWebSocketCreateOptions::warmup();
855 let json = serde_json::to_value(options).expect("options should serialize");
856
857 assert_eq!(json, json!({ "generate": false }));
858 }
859
860 #[test]
861 fn websocket_url_converts_https_to_wss() {
862 let url = websocket_url("https://api.openai.com/v1").expect("url should convert");
863 assert_eq!(url, "wss://api.openai.com/v1/responses");
864 }
865
866 #[test]
867 fn parse_done_event_exposes_response_id() {
868 let payload = json!({
869 "type": "response.done",
870 "response": {
871 "id": "resp_done_1",
872 "status": "completed"
873 }
874 });
875
876 let event = parse_server_event(&payload.to_string())
877 .expect("done event should deserialize")
878 .expect("done event should not be skipped");
879
880 assert!(matches!(
881 event,
882 ResponsesWebSocketEvent::Done(ResponsesWebSocketDoneEvent { .. })
883 ));
884 assert_eq!(event.response_id(), Some("resp_done_1"));
885 assert!(event.is_terminal());
886 }
887
888 #[test]
889 fn parse_response_completed_event_is_terminal() {
890 let payload = json!({
891 "type": "response.completed",
892 "sequence_number": 12,
893 "response": {
894 "id": "resp_completed_1",
895 "object": "response",
896 "created_at": 0,
897 "status": "completed",
898 "error": null,
899 "incomplete_details": null,
900 "instructions": null,
901 "max_output_tokens": null,
902 "model": "gpt-5.4",
903 "usage": null,
904 "output": [],
905 "tools": []
906 }
907 });
908
909 let event = parse_server_event(&payload.to_string())
910 .expect("response event should deserialize")
911 .expect("response event should not be skipped");
912
913 assert!(matches!(event, ResponsesWebSocketEvent::Response(_)));
914 assert!(event.is_terminal());
915 assert_eq!(event.response_id(), Some("resp_completed_1"));
916 }
917
918 #[test]
919 fn parse_live_output_item_added_event() {
920 let payload = json!({
921 "type": "response.output_item.added",
922 "item": {
923 "id": "msg_036471c3a72c147b0069ae7848d68881959773fd2d99e3d98a",
924 "type": "message",
925 "status": "in_progress",
926 "content": [],
927 "role": "assistant"
928 },
929 "output_index": 0,
930 "sequence_number": 2
931 });
932
933 let event = parse_server_event(&payload.to_string())
934 .expect("output item event should parse")
935 .expect("output item event should not be skipped");
936
937 assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
938 }
939
940 #[test]
941 fn parse_live_content_part_added_event() {
942 let payload = json!({
943 "type": "response.content_part.added",
944 "content_index": 0,
945 "item_id": "msg_036471c3a72c147b0069ae7848d68881959773fd2d99e3d98a",
946 "output_index": 0,
947 "part": {
948 "type": "output_text",
949 "annotations": [],
950 "logprobs": [],
951 "text": ""
952 },
953 "sequence_number": 3
954 });
955
956 let event = parse_server_event(&payload.to_string())
957 .expect("content part event should parse")
958 .expect("content part event should not be skipped");
959
960 assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
961 }
962
963 #[test]
964 fn parse_live_output_text_delta_event() {
965 let payload = json!({
966 "type": "response.output_text.delta",
967 "content_index": 0,
968 "delta": "Web",
969 "item_id": "msg_023af0f0a91bc2a90069ae788612e881958345bb156915ba29",
970 "logprobs": [],
971 "obfuscation": "2YYErYq7jkqqM",
972 "output_index": 0,
973 "sequence_number": 4
974 });
975
976 let event = parse_server_event(&payload.to_string())
977 .expect("output text delta event should parse")
978 .expect("output text delta event should not be skipped");
979
980 assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
981 }
982
983 #[test]
984 fn terminal_response_requires_completed_status() {
985 let completed = terminal_response_result(sample_response(ResponseStatus::Completed))
986 .expect("completed response should succeed");
987 assert_eq!(completed.id, "resp_123");
988
989 let failed = terminal_response_result(sample_response(ResponseStatus::Failed))
990 .expect_err("failed response should error");
991 assert!(failed.to_string().contains("failed response"));
992 }
993
994 #[tokio::test]
995 async fn malformed_known_event_rejects_reuse_and_allows_close() {
996 let listener = TcpListener::bind("127.0.0.1:0")
997 .await
998 .expect("listener should bind");
999 let address = listener.local_addr().expect("listener should have address");
1000
1001 let server = tokio::spawn(async move {
1002 let (stream, _) = listener.accept().await.expect("server should accept");
1003 let mut socket = accept_async(stream)
1004 .await
1005 .expect("server should upgrade websocket");
1006
1007 let request = socket
1008 .next()
1009 .await
1010 .expect("request should exist")
1011 .expect("request should be valid");
1012 let payload = request.into_text().expect("request should be text");
1013 assert!(
1014 payload.contains("\"type\":\"response.create\""),
1015 "expected response.create payload, got {payload}"
1016 );
1017
1018 socket
1019 .send(Message::text(
1020 json!({
1021 "type": "response.completed"
1022 })
1023 .to_string(),
1024 ))
1025 .await
1026 .expect("malformed known event should send");
1027
1028 let message = socket
1029 .next()
1030 .await
1031 .expect("close frame should arrive")
1032 .expect("close frame should be valid");
1033 assert!(
1034 matches!(message, Message::Close(_)),
1035 "expected close frame, got {message:?}"
1036 );
1037 });
1038
1039 let base_url = format!("http://{address}/v1");
1040 let client = crate::providers::openai::Client::builder()
1041 .api_key("test-key")
1042 .base_url(&base_url)
1043 .build()
1044 .expect("client should build");
1045 let model = client.completion_model("gpt-4o");
1046 let mut session = client
1047 .responses_websocket("gpt-4o")
1048 .await
1049 .expect("session should connect");
1050
1051 session
1052 .send(model.completion_request("hello").build())
1053 .await
1054 .expect("request should send");
1055
1056 let error = session
1057 .next_event()
1058 .await
1059 .expect_err("malformed known event should fail");
1060 assert!(
1061 error.to_string().contains("StreamingCompletionChunk"),
1062 "expected strict decode failure, got {error}"
1063 );
1064
1065 let closed = session
1066 .send(model.completion_request("retry").build())
1067 .await
1068 .expect_err("session should close after fatal parse error");
1069 assert!(
1070 closed.to_string().contains("session is closed"),
1071 "expected closed-session error, got {closed}"
1072 );
1073
1074 session
1075 .close()
1076 .await
1077 .expect("explicit close after fatal parse error should succeed");
1078
1079 server.await.expect("server task should finish");
1080 }
1081
1082 #[tokio::test]
1083 async fn event_timeout_rejects_reuse_and_allows_close() {
1084 let listener = TcpListener::bind("127.0.0.1:0")
1085 .await
1086 .expect("listener should bind");
1087 let address = listener.local_addr().expect("listener should have address");
1088
1089 let server = tokio::spawn(async move {
1090 let (stream, _) = listener.accept().await.expect("server should accept");
1091 let mut socket = accept_async(stream)
1092 .await
1093 .expect("server should upgrade websocket");
1094
1095 let request = socket
1096 .next()
1097 .await
1098 .expect("request should exist")
1099 .expect("request should be valid");
1100 let payload = request.into_text().expect("request should be text");
1101 assert!(
1102 payload.contains("\"type\":\"response.create\""),
1103 "expected response.create payload, got {payload}"
1104 );
1105
1106 sleep(Duration::from_millis(60)).await;
1107 let message = socket
1108 .next()
1109 .await
1110 .expect("close frame should arrive")
1111 .expect("close frame should be valid");
1112 assert!(
1113 matches!(message, Message::Close(_)),
1114 "expected close frame, got {message:?}"
1115 );
1116 });
1117
1118 let base_url = format!("http://{address}/v1");
1119 let client = crate::providers::openai::Client::builder()
1120 .api_key("test-key")
1121 .base_url(&base_url)
1122 .build()
1123 .expect("client should build");
1124 let model = client.completion_model("gpt-4o");
1125 let mut session = client
1126 .responses_websocket_builder("gpt-4o")
1127 .event_timeout(Duration::from_millis(20))
1128 .connect()
1129 .await
1130 .expect("session should connect");
1131
1132 session
1133 .send(model.completion_request("hello").build())
1134 .await
1135 .expect("request should send");
1136
1137 let error = session
1138 .next_event()
1139 .await
1140 .expect_err("next_event should time out");
1141 assert!(
1142 error
1143 .to_string()
1144 .contains("Timed out waiting for the next OpenAI websocket event"),
1145 "expected timeout error, got {error}"
1146 );
1147
1148 let closed = session
1149 .send(model.completion_request("retry").build())
1150 .await
1151 .expect_err("timed-out session should close");
1152 assert!(
1153 closed.to_string().contains("session is closed"),
1154 "expected closed-session error, got {closed}"
1155 );
1156
1157 session
1158 .close()
1159 .await
1160 .expect("explicit close after timeout should succeed");
1161
1162 server.await.expect("server task should finish");
1163 }
1164
1165 #[tokio::test]
1166 async fn late_response_done_is_ignored_on_next_turn() {
1167 let listener = TcpListener::bind("127.0.0.1:0")
1168 .await
1169 .expect("listener should bind");
1170 let address = listener.local_addr().expect("listener should have address");
1171
1172 let server = tokio::spawn(async move {
1173 let (stream, _) = listener.accept().await.expect("server should accept");
1174 let mut socket = accept_async(stream)
1175 .await
1176 .expect("server should upgrade websocket");
1177
1178 for (index, response_id) in ["resp_1", "resp_2"].iter().enumerate() {
1179 let request = socket
1180 .next()
1181 .await
1182 .expect("request should exist")
1183 .expect("request should be valid");
1184 let payload = request.into_text().expect("request should be text");
1185 assert!(
1186 payload.contains("\"type\":\"response.create\""),
1187 "expected response.create payload, got {payload}"
1188 );
1189
1190 let response = sample_response(ResponseStatus::Completed);
1191 let response = serde_json::to_value(CompletionResponse {
1192 id: (*response_id).to_string(),
1193 ..response
1194 })
1195 .expect("response should serialize");
1196
1197 socket
1198 .send(Message::text(
1199 json!({
1200 "type": "response.completed",
1201 "sequence_number": (index * 2) + 1,
1202 "response": response,
1203 })
1204 .to_string(),
1205 ))
1206 .await
1207 .expect("completed event should send");
1208 socket
1209 .send(Message::text(
1210 json!({
1211 "type": "response.done",
1212 "response": {
1213 "id": response_id,
1214 "status": "completed",
1215 },
1216 })
1217 .to_string(),
1218 ))
1219 .await
1220 .expect("done event should send");
1221 }
1222 });
1223
1224 let base_url = format!("http://{address}/v1");
1225 let client = crate::providers::openai::Client::builder()
1226 .api_key("test-key")
1227 .base_url(&base_url)
1228 .build()
1229 .expect("client should build");
1230 let model = client.completion_model("gpt-4o");
1231 let mut session = client
1232 .responses_websocket("gpt-4o")
1233 .await
1234 .expect("session should connect");
1235
1236 session
1237 .send(model.completion_request("first").build())
1238 .await
1239 .expect("first request should send");
1240 let first = session
1241 .wait_for_completed_response()
1242 .await
1243 .expect("first response should complete");
1244 assert_eq!(first.id, "resp_1");
1245 assert_eq!(session.previous_response_id(), Some("resp_1"));
1246
1247 session
1248 .send(model.completion_request("second").build())
1249 .await
1250 .expect("second request should send");
1251 let second = session
1252 .wait_for_completed_response()
1253 .await
1254 .expect("second response should complete");
1255 assert_eq!(second.id, "resp_2");
1256 assert_eq!(session.previous_response_id(), Some("resp_2"));
1257
1258 server.await.expect("server task should finish");
1259 }
1260
1261 #[tokio::test]
1262 async fn clearing_previous_response_id_does_not_disable_late_done_filter() {
1263 let listener = TcpListener::bind("127.0.0.1:0")
1264 .await
1265 .expect("listener should bind");
1266 let address = listener.local_addr().expect("listener should have address");
1267
1268 let server = tokio::spawn(async move {
1269 let (stream, _) = listener.accept().await.expect("server should accept");
1270 let mut socket = accept_async(stream)
1271 .await
1272 .expect("server should upgrade websocket");
1273
1274 for response_id in ["resp_1", "resp_2"] {
1275 let request = socket
1276 .next()
1277 .await
1278 .expect("request should exist")
1279 .expect("request should be valid");
1280 let payload = request.into_text().expect("request should be text");
1281 assert!(
1282 payload.contains("\"type\":\"response.create\""),
1283 "expected response.create payload, got {payload}"
1284 );
1285
1286 let response = sample_response(ResponseStatus::Completed);
1287 let response = serde_json::to_value(CompletionResponse {
1288 id: response_id.to_string(),
1289 ..response
1290 })
1291 .expect("response should serialize");
1292
1293 socket
1294 .send(Message::text(
1295 json!({
1296 "type": "response.completed",
1297 "sequence_number": 1,
1298 "response": response,
1299 })
1300 .to_string(),
1301 ))
1302 .await
1303 .expect("completed event should send");
1304 socket
1305 .send(Message::text(
1306 json!({
1307 "type": "response.done",
1308 "response": {
1309 "id": response_id,
1310 "status": "completed",
1311 },
1312 })
1313 .to_string(),
1314 ))
1315 .await
1316 .expect("done event should send");
1317 }
1318 });
1319
1320 let base_url = format!("http://{address}/v1");
1321 let client = crate::providers::openai::Client::builder()
1322 .api_key("test-key")
1323 .base_url(&base_url)
1324 .build()
1325 .expect("client should build");
1326 let model = client.completion_model("gpt-4o");
1327 let mut session = client
1328 .responses_websocket("gpt-4o")
1329 .await
1330 .expect("session should connect");
1331
1332 session
1333 .send(model.completion_request("first").build())
1334 .await
1335 .expect("first request should send");
1336 let first = session
1337 .wait_for_completed_response()
1338 .await
1339 .expect("first response should complete");
1340 assert_eq!(first.id, "resp_1");
1341
1342 session.clear_previous_response_id();
1343 assert_eq!(session.previous_response_id(), None);
1344
1345 session
1346 .send(model.completion_request("second").build())
1347 .await
1348 .expect("second request should send");
1349 let second = session
1350 .wait_for_completed_response()
1351 .await
1352 .expect("second response should complete");
1353 assert_eq!(second.id, "resp_2");
1354
1355 server.await.expect("server task should finish");
1356 }
1357
1358 #[tokio::test]
1359 async fn failed_turn_keeps_late_done_out_of_next_request() {
1360 let listener = TcpListener::bind("127.0.0.1:0")
1361 .await
1362 .expect("listener should bind");
1363 let address = listener.local_addr().expect("listener should have address");
1364
1365 let server = tokio::spawn(async move {
1366 let (stream, _) = listener.accept().await.expect("server should accept");
1367 let mut socket = accept_async(stream)
1368 .await
1369 .expect("server should upgrade websocket");
1370
1371 let first_request = socket
1372 .next()
1373 .await
1374 .expect("request should exist")
1375 .expect("request should be valid");
1376 let payload = first_request
1377 .into_text()
1378 .expect("failed request should be text");
1379 assert!(
1380 payload.contains("\"type\":\"response.create\""),
1381 "expected response.create payload, got {payload}"
1382 );
1383
1384 let failed_response = serde_json::to_value(CompletionResponse {
1385 id: "resp_failed".to_string(),
1386 status: ResponseStatus::Failed,
1387 ..sample_response(ResponseStatus::Completed)
1388 })
1389 .expect("failed response should serialize");
1390
1391 socket
1392 .send(Message::text(
1393 json!({
1394 "type": "response.failed",
1395 "sequence_number": 1,
1396 "response": failed_response,
1397 })
1398 .to_string(),
1399 ))
1400 .await
1401 .expect("failed event should send");
1402 socket
1403 .send(Message::text(
1404 json!({
1405 "type": "response.done",
1406 "response": {
1407 "id": "resp_failed",
1408 "status": "failed",
1409 },
1410 })
1411 .to_string(),
1412 ))
1413 .await
1414 .expect("done event should send");
1415
1416 let second_request = socket
1417 .next()
1418 .await
1419 .expect("request should exist")
1420 .expect("request should be valid");
1421 let payload = second_request
1422 .into_text()
1423 .expect("second request should be text");
1424 assert!(
1425 payload.contains("\"type\":\"response.create\""),
1426 "expected response.create payload, got {payload}"
1427 );
1428
1429 let response = sample_response(ResponseStatus::Completed);
1430 let response = serde_json::to_value(CompletionResponse {
1431 id: "resp_2".to_string(),
1432 ..response
1433 })
1434 .expect("response should serialize");
1435
1436 socket
1437 .send(Message::text(
1438 json!({
1439 "type": "response.completed",
1440 "sequence_number": 2,
1441 "response": response,
1442 })
1443 .to_string(),
1444 ))
1445 .await
1446 .expect("completed event should send");
1447 socket
1448 .send(Message::text(
1449 json!({
1450 "type": "response.done",
1451 "response": {
1452 "id": "resp_2",
1453 "status": "completed",
1454 },
1455 })
1456 .to_string(),
1457 ))
1458 .await
1459 .expect("done event should send");
1460 });
1461
1462 let base_url = format!("http://{address}/v1");
1463 let client = crate::providers::openai::Client::builder()
1464 .api_key("test-key")
1465 .base_url(&base_url)
1466 .build()
1467 .expect("client should build");
1468 let model = client.completion_model("gpt-4o");
1469 let mut session = client
1470 .responses_websocket("gpt-4o")
1471 .await
1472 .expect("session should connect");
1473
1474 session
1475 .send(model.completion_request("first").build())
1476 .await
1477 .expect("first request should send");
1478 let error = session
1479 .wait_for_completed_response()
1480 .await
1481 .expect_err("failed response should error");
1482 assert!(error.to_string().contains("failed response"));
1483 assert_eq!(session.previous_response_id(), None);
1484
1485 session
1486 .send(model.completion_request("second").build())
1487 .await
1488 .expect("second request should send");
1489 let second = session
1490 .wait_for_completed_response()
1491 .await
1492 .expect("second response should complete");
1493 assert_eq!(second.id, "resp_2");
1494
1495 server.await.expect("server task should finish");
1496 }
1497
1498 #[tokio::test]
1499 async fn done_first_completed_turn_updates_previous_response_id() {
1500 let listener = TcpListener::bind("127.0.0.1:0")
1501 .await
1502 .expect("listener should bind");
1503 let address = listener.local_addr().expect("listener should have address");
1504
1505 let server = tokio::spawn(async move {
1506 let (stream, _) = listener.accept().await.expect("server should accept");
1507 let mut socket = accept_async(stream)
1508 .await
1509 .expect("server should upgrade websocket");
1510
1511 for response_id in ["resp_1", "resp_2"] {
1512 let request = socket
1513 .next()
1514 .await
1515 .expect("request should exist")
1516 .expect("request should be valid");
1517 let payload = request.into_text().expect("request should be text");
1518 assert!(
1519 payload.contains("\"type\":\"response.create\""),
1520 "expected response.create payload, got {payload}"
1521 );
1522
1523 if response_id == "resp_2" {
1524 assert!(
1525 payload.contains("\"previous_response_id\":\"resp_1\""),
1526 "expected chained previous_response_id in payload, got {payload}"
1527 );
1528 }
1529
1530 let response = serde_json::to_value(CompletionResponse {
1531 id: response_id.to_string(),
1532 ..sample_response(ResponseStatus::Completed)
1533 })
1534 .expect("response should serialize");
1535
1536 socket
1537 .send(Message::text(
1538 json!({
1539 "type": "response.done",
1540 "response": response,
1541 })
1542 .to_string(),
1543 ))
1544 .await
1545 .expect("done event should send");
1546 }
1547 });
1548
1549 let base_url = format!("http://{address}/v1");
1550 let client = crate::providers::openai::Client::builder()
1551 .api_key("test-key")
1552 .base_url(&base_url)
1553 .build()
1554 .expect("client should build");
1555 let model = client.completion_model("gpt-4o");
1556 let mut session = client
1557 .responses_websocket("gpt-4o")
1558 .await
1559 .expect("session should connect");
1560
1561 session
1562 .send(model.completion_request("first").build())
1563 .await
1564 .expect("first request should send");
1565 let first = session
1566 .wait_for_completed_response()
1567 .await
1568 .expect("first response should complete");
1569 assert_eq!(first.id, "resp_1");
1570 assert_eq!(session.previous_response_id(), Some("resp_1"));
1571
1572 session
1573 .send(model.completion_request("second").build())
1574 .await
1575 .expect("second request should send");
1576 let second = session
1577 .wait_for_completed_response()
1578 .await
1579 .expect("second response should complete");
1580 assert_eq!(second.id, "resp_2");
1581 assert_eq!(session.previous_response_id(), Some("resp_2"));
1582
1583 server.await.expect("server task should finish");
1584 }
1585
1586 #[tokio::test]
1587 async fn done_first_failed_turn_does_not_chain_next_request() {
1588 let listener = TcpListener::bind("127.0.0.1:0")
1589 .await
1590 .expect("listener should bind");
1591 let address = listener.local_addr().expect("listener should have address");
1592
1593 let server = tokio::spawn(async move {
1594 let (stream, _) = listener.accept().await.expect("server should accept");
1595 let mut socket = accept_async(stream)
1596 .await
1597 .expect("server should upgrade websocket");
1598
1599 let first_request = socket
1600 .next()
1601 .await
1602 .expect("request should exist")
1603 .expect("request should be valid");
1604 let payload = first_request
1605 .into_text()
1606 .expect("first request should be text");
1607 assert!(
1608 payload.contains("\"type\":\"response.create\""),
1609 "expected response.create payload, got {payload}"
1610 );
1611 assert!(
1612 !payload.contains("\"previous_response_id\""),
1613 "did not expect previous_response_id in first payload, got {payload}"
1614 );
1615
1616 let failed_response = serde_json::to_value(CompletionResponse {
1617 id: "resp_failed".to_string(),
1618 status: ResponseStatus::Failed,
1619 ..sample_response(ResponseStatus::Completed)
1620 })
1621 .expect("failed response should serialize");
1622
1623 socket
1624 .send(Message::text(
1625 json!({
1626 "type": "response.done",
1627 "response": failed_response,
1628 })
1629 .to_string(),
1630 ))
1631 .await
1632 .expect("done event should send");
1633
1634 let second_request = socket
1635 .next()
1636 .await
1637 .expect("request should exist")
1638 .expect("request should be valid");
1639 let payload = second_request
1640 .into_text()
1641 .expect("second request should be text");
1642 assert!(
1643 payload.contains("\"type\":\"response.create\""),
1644 "expected response.create payload, got {payload}"
1645 );
1646 assert!(
1647 !payload.contains("\"previous_response_id\""),
1648 "did not expect chained previous_response_id in payload, got {payload}"
1649 );
1650
1651 let response = serde_json::to_value(CompletionResponse {
1652 id: "resp_2".to_string(),
1653 ..sample_response(ResponseStatus::Completed)
1654 })
1655 .expect("response should serialize");
1656
1657 socket
1658 .send(Message::text(
1659 json!({
1660 "type": "response.done",
1661 "response": response,
1662 })
1663 .to_string(),
1664 ))
1665 .await
1666 .expect("done event should send");
1667 });
1668
1669 let base_url = format!("http://{address}/v1");
1670 let client = crate::providers::openai::Client::builder()
1671 .api_key("test-key")
1672 .base_url(&base_url)
1673 .build()
1674 .expect("client should build");
1675 let model = client.completion_model("gpt-4o");
1676 let mut session = client
1677 .responses_websocket("gpt-4o")
1678 .await
1679 .expect("session should connect");
1680
1681 session
1682 .send(model.completion_request("first").build())
1683 .await
1684 .expect("first request should send");
1685 let error = session
1686 .wait_for_completed_response()
1687 .await
1688 .expect_err("failed response should error");
1689 assert!(error.to_string().contains("failed response"));
1690 assert_eq!(session.previous_response_id(), None);
1691
1692 session
1693 .send(model.completion_request("second").build())
1694 .await
1695 .expect("second request should send");
1696 let second = session
1697 .wait_for_completed_response()
1698 .await
1699 .expect("second response should complete");
1700 assert_eq!(second.id, "resp_2");
1701 assert_eq!(session.previous_response_id(), Some("resp_2"));
1702
1703 server.await.expect("server task should finish");
1704 }
1705
1706 #[test]
1707 fn websocket_url_converts_http_to_ws() {
1708 let url = websocket_url("http://localhost:8080/v1").expect("url should convert");
1709 assert_eq!(url, "ws://localhost:8080/v1/responses");
1710 }
1711
1712 #[test]
1713 fn websocket_url_rejects_unsupported_scheme() {
1714 let result = websocket_url("ftp://example.com/v1");
1715 assert!(result.is_err());
1716 }
1717
1718 #[test]
1719 fn websocket_url_trims_trailing_slash() {
1720 let url = websocket_url("https://api.openai.com/v1/").expect("url should convert");
1721 assert_eq!(url, "wss://api.openai.com/v1/responses");
1722 }
1723
1724 #[test]
1725 fn unknown_event_type_is_skipped() {
1726 let payload = json!({
1727 "type": "response.some_future_event",
1728 "data": "hello"
1729 });
1730
1731 let result =
1732 parse_server_event(&payload.to_string()).expect("unknown event should not error");
1733 assert!(result.is_none(), "unknown event should be skipped");
1734 }
1735
1736 #[test]
1737 fn malformed_known_event_returns_error() {
1738 let payload = json!({
1739 "type": "response.completed"
1740 });
1741
1742 let error = parse_server_event(&payload.to_string())
1743 .expect_err("malformed known event should error");
1744 assert!(
1745 error.to_string().contains("StreamingCompletionChunk"),
1746 "expected strict decode failure, got {error}"
1747 );
1748 }
1749
1750 #[tokio::test]
1751 async fn close_is_idempotent() {
1752 let listener = TcpListener::bind("127.0.0.1:0")
1753 .await
1754 .expect("listener should bind");
1755 let address = listener.local_addr().expect("listener should have address");
1756
1757 let server = tokio::spawn(async move {
1758 let (stream, _) = listener.accept().await.expect("server should accept");
1759 let mut socket = accept_async(stream)
1760 .await
1761 .expect("server should upgrade websocket");
1762
1763 let message = socket
1764 .next()
1765 .await
1766 .expect("close frame should arrive")
1767 .expect("close frame should be valid");
1768 assert!(
1769 matches!(message, Message::Close(_)),
1770 "expected close frame, got {message:?}"
1771 );
1772 });
1773
1774 let base_url = format!("http://{address}/v1");
1775 let client = crate::providers::openai::Client::builder()
1776 .api_key("test-key")
1777 .base_url(&base_url)
1778 .build()
1779 .expect("client should build");
1780 let mut session = client
1781 .responses_websocket("gpt-4o")
1782 .await
1783 .expect("session should connect");
1784
1785 session.close().await.expect("first close should succeed");
1786 session.close().await.expect("second close should succeed");
1787
1788 server.await.expect("server task should finish");
1789 }
1790
1791 #[tokio::test]
1792 async fn send_while_in_flight_returns_error() {
1793 let listener = TcpListener::bind("127.0.0.1:0")
1794 .await
1795 .expect("listener should bind");
1796 let address = listener.local_addr().expect("listener should have address");
1797
1798 let server = tokio::spawn(async move {
1799 let (stream, _) = listener.accept().await.expect("server should accept");
1800 let mut socket = accept_async(stream)
1801 .await
1802 .expect("server should upgrade websocket");
1803
1804 let _request = socket
1806 .next()
1807 .await
1808 .expect("request should exist")
1809 .expect("request should be valid");
1810
1811 sleep(Duration::from_millis(100)).await;
1813 let _ = socket.close(None).await;
1814 });
1815
1816 let base_url = format!("http://{address}/v1");
1817 let client = crate::providers::openai::Client::builder()
1818 .api_key("test-key")
1819 .base_url(&base_url)
1820 .build()
1821 .expect("client should build");
1822 let model = client.completion_model("gpt-4o");
1823 let mut session = client
1824 .responses_websocket("gpt-4o")
1825 .await
1826 .expect("session should connect");
1827
1828 session
1829 .send(model.completion_request("first").build())
1830 .await
1831 .expect("first request should send");
1832
1833 let error = session
1834 .send(model.completion_request("second").build())
1835 .await
1836 .expect_err("second send while in-flight should error");
1837 assert!(
1838 error.to_string().contains("already in flight"),
1839 "expected in-flight error, got {error}"
1840 );
1841
1842 server.await.expect("server task should finish");
1843 }
1844
1845 #[tokio::test]
1846 async fn send_after_close_returns_error() {
1847 let listener = TcpListener::bind("127.0.0.1:0")
1848 .await
1849 .expect("listener should bind");
1850 let address = listener.local_addr().expect("listener should have address");
1851
1852 let server = tokio::spawn(async move {
1853 let (stream, _) = listener.accept().await.expect("server should accept");
1854 let _socket = accept_async(stream)
1855 .await
1856 .expect("server should upgrade websocket");
1857 sleep(Duration::from_millis(100)).await;
1858 });
1859
1860 let base_url = format!("http://{address}/v1");
1861 let client = crate::providers::openai::Client::builder()
1862 .api_key("test-key")
1863 .base_url(&base_url)
1864 .build()
1865 .expect("client should build");
1866 let model = client.completion_model("gpt-4o");
1867 let mut session = client
1868 .responses_websocket("gpt-4o")
1869 .await
1870 .expect("session should connect");
1871
1872 session.close().await.expect("close should succeed");
1873
1874 let error = session
1875 .send(model.completion_request("after close").build())
1876 .await
1877 .expect_err("send after close should error");
1878 assert!(
1879 error.to_string().contains("session is closed"),
1880 "expected closed-session error, got {error}"
1881 );
1882
1883 server.await.expect("server task should finish");
1884 }
1885
1886 #[tokio::test]
1887 async fn next_event_without_send_returns_error() {
1888 let listener = TcpListener::bind("127.0.0.1:0")
1889 .await
1890 .expect("listener should bind");
1891 let address = listener.local_addr().expect("listener should have address");
1892
1893 let server = tokio::spawn(async move {
1894 let (stream, _) = listener.accept().await.expect("server should accept");
1895 let _socket = accept_async(stream)
1896 .await
1897 .expect("server should upgrade websocket");
1898 sleep(Duration::from_millis(100)).await;
1899 });
1900
1901 let base_url = format!("http://{address}/v1");
1902 let client = crate::providers::openai::Client::builder()
1903 .api_key("test-key")
1904 .base_url(&base_url)
1905 .build()
1906 .expect("client should build");
1907 let mut session = client
1908 .responses_websocket("gpt-4o")
1909 .await
1910 .expect("session should connect");
1911
1912 let error = session
1913 .next_event()
1914 .await
1915 .expect_err("next_event without send should error");
1916 assert!(
1917 error
1918 .to_string()
1919 .contains("No OpenAI websocket response is currently in flight"),
1920 "expected not-in-flight error, got {error}"
1921 );
1922
1923 server.await.expect("server task should finish");
1924 }
1925
1926 #[tokio::test]
1927 async fn unknown_event_is_skipped_during_session() {
1928 let listener = TcpListener::bind("127.0.0.1:0")
1929 .await
1930 .expect("listener should bind");
1931 let address = listener.local_addr().expect("listener should have address");
1932
1933 let server = tokio::spawn(async move {
1934 let (stream, _) = listener.accept().await.expect("server should accept");
1935 let mut socket = accept_async(stream)
1936 .await
1937 .expect("server should upgrade websocket");
1938
1939 let _request = socket
1940 .next()
1941 .await
1942 .expect("request should exist")
1943 .expect("request should be valid");
1944
1945 socket
1947 .send(Message::text(
1948 json!({
1949 "type": "response.some_future_event",
1950 "data": "should be skipped"
1951 })
1952 .to_string(),
1953 ))
1954 .await
1955 .expect("unknown event should send");
1956
1957 let response = serde_json::to_value(CompletionResponse {
1959 id: "resp_after_unknown".to_string(),
1960 ..sample_response(ResponseStatus::Completed)
1961 })
1962 .expect("response should serialize");
1963
1964 socket
1965 .send(Message::text(
1966 json!({
1967 "type": "response.completed",
1968 "sequence_number": 1,
1969 "response": response,
1970 })
1971 .to_string(),
1972 ))
1973 .await
1974 .expect("completed event should send");
1975 });
1976
1977 let base_url = format!("http://{address}/v1");
1978 let client = crate::providers::openai::Client::builder()
1979 .api_key("test-key")
1980 .base_url(&base_url)
1981 .build()
1982 .expect("client should build");
1983 let model = client.completion_model("gpt-4o");
1984 let mut session = client
1985 .responses_websocket("gpt-4o")
1986 .await
1987 .expect("session should connect");
1988
1989 session
1990 .send(model.completion_request("hello").build())
1991 .await
1992 .expect("send should succeed");
1993 let response = session
1994 .wait_for_completed_response()
1995 .await
1996 .expect("response should complete despite unknown event");
1997 assert_eq!(response.id, "resp_after_unknown");
1998
1999 server.await.expect("server task should finish");
2000 }
2001}