1use crate::completion::{self, CompletionError};
8use crate::http_client::HttpClientExt;
9use crate::providers::openai::responses_api::streaming::{
10 ItemChunk, ResponseChunk, ResponseChunkKind, StreamingCompletionChunk,
11};
12use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
13use futures::{SinkExt, StreamExt};
14use serde::{Deserialize, Serialize};
15use serde_json::{Map, Value};
16use std::time::Duration;
17use tokio::net::TcpStream;
18use tokio_tungstenite::{
19 MaybeTlsStream, WebSocketStream, connect_async,
20 tungstenite::{self, Message, client::IntoClientRequest},
21};
22use tracing::Level;
23use url::Url;
24
25use super::{CompletionResponse, ResponseError, ResponseStatus, ResponsesCompletionModel};
26
/// Alias for the tungstenite stream type used by all session sockets.
type OpenAIWebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Default cap on how long the websocket handshake may take before
/// `connect()` gives up (can be overridden or disabled via the builder).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
29
/// Options attached to a `response.create` websocket client event.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResponsesWebSocketCreateOptions {
    /// Optional `generate` flag sent alongside the request; omitted from the
    /// JSON when `None`. Set to `Some(false)` by [`Self::warmup`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generate: Option<bool>,
}
39
40impl ResponsesWebSocketCreateOptions {
41 #[must_use]
43 pub fn warmup() -> Self {
44 Self {
45 generate: Some(false),
46 }
47 }
48}
49
/// Client → server frame wrapping a Responses API completion request.
#[derive(Debug, Clone, Serialize)]
struct ResponsesWebSocketClientEvent {
    /// Serialized as the `"type"` field (always `response.create`).
    #[serde(rename = "type")]
    kind: ResponsesWebSocketClientEventKind,
    /// The completion request body, flattened into the same JSON object.
    #[serde(flatten)]
    request: super::CompletionRequest,
    /// Optional `generate` flag; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    generate: Option<bool>,
}
59
/// Event types the client can send; currently only `response.create`.
#[derive(Debug, Clone, Serialize)]
enum ResponsesWebSocketClientEventKind {
    /// Start a new response turn on the socket.
    #[serde(rename = "response.create")]
    ResponseCreate,
}
65
/// Server `error` event frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesWebSocketErrorEvent {
    /// The `"type"` field; always the `error` variant.
    #[serde(rename = "type")]
    pub kind: ResponsesWebSocketErrorEventKind,
    /// Error details reported by the server.
    pub error: ResponsesWebSocketErrorPayload,
}
75
76impl std::fmt::Display for ResponsesWebSocketErrorEvent {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 self.error.fmt(f)
79 }
80}
81
/// Discriminant for server error events; matches the `"error"` type tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponsesWebSocketErrorEventKind {
    /// The single `error` event type.
    #[serde(rename = "error")]
    Error,
}
88
/// Body of a server `error` event.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResponsesWebSocketErrorPayload {
    /// Machine-readable error code, when the server provides one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
    /// Human-readable error message, when the server provides one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
    /// Any additional fields the server included, captured verbatim.
    #[serde(flatten, default)]
    pub extra: Map<String, Value>,
}
102
103impl std::fmt::Display for ResponsesWebSocketErrorPayload {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match (&self.code, &self.message) {
106 (Some(code), Some(message)) => write!(f, "{code}: {message}"),
107 (None, Some(message)) => f.write_str(message),
108 (Some(code), None) => f.write_str(code),
109 (None, None) => f.write_str("OpenAI websocket error"),
110 }
111 }
112}
113
/// Server `response.done` event; carries the raw response body as JSON so
/// partially-formed bodies can still be inspected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesWebSocketDoneEvent {
    /// The `"type"` field; always the `response.done` variant.
    #[serde(rename = "type")]
    pub kind: ResponsesWebSocketDoneEventKind,
    /// Raw response object as received from the server.
    pub response: Value,
}
123
124impl ResponsesWebSocketDoneEvent {
125 #[must_use]
127 pub fn response_id(&self) -> Option<&str> {
128 self.response.get("id").and_then(Value::as_str)
129 }
130
131 fn status(&self) -> Option<ResponseStatus> {
132 self.response
133 .get("status")
134 .cloned()
135 .and_then(|status| serde_json::from_value(status).ok())
136 }
137
138 fn as_completion_response(&self) -> Option<CompletionResponse> {
139 serde_json::from_value(self.response.clone()).ok()
140 }
141}
142
/// Discriminant for done events; matches the `"response.done"` type tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponsesWebSocketDoneEventKind {
    /// The single `response.done` event type.
    #[serde(rename = "response.done")]
    ResponseDone,
}
149
/// Server events surfaced to callers of `ResponsesWebSocketSession::next_event`.
#[derive(Debug, Clone)]
pub enum ResponsesWebSocketEvent {
    /// Response-lifecycle chunk (created / in-progress / terminal states).
    Response(Box<ResponseChunk>),
    /// Incremental output item or content delta.
    Item(ItemChunk),
    /// Server-reported error event.
    Error(ResponsesWebSocketErrorEvent),
    /// `response.done` event carrying the raw response body.
    Done(ResponsesWebSocketDoneEvent),
}
162
163impl ResponsesWebSocketEvent {
164 #[must_use]
166 pub fn response_id(&self) -> Option<&str> {
167 match self {
168 Self::Response(chunk) => Some(&chunk.response.id),
169 Self::Done(done) => done.response_id(),
170 Self::Item(_) | Self::Error(_) => None,
171 }
172 }
173
174 #[must_use]
176 pub fn is_terminal(&self) -> bool {
177 match self {
178 Self::Response(chunk) => matches!(
179 chunk.kind,
180 ResponseChunkKind::ResponseCompleted
181 | ResponseChunkKind::ResponseFailed
182 | ResponseChunkKind::ResponseIncomplete
183 ),
184 Self::Error(_) | Self::Done(_) => true,
185 Self::Item(_) => false,
186 }
187 }
188}
189
/// Builder for [`ResponsesWebSocketSession`] with configurable timeouts.
pub struct ResponsesWebSocketSessionBuilder<T = reqwest::Client> {
    /// Completion model whose client config (base URL, headers) is used.
    model: ResponsesCompletionModel<T>,
    /// Handshake timeout; defaults to `DEFAULT_CONNECT_TIMEOUT`, `None`
    /// disables it entirely.
    connect_timeout: Option<Duration>,
    /// Per-event read timeout for `next_event`; disabled (`None`) by default.
    event_timeout: Option<Duration>,
}
199
200impl<T> ResponsesWebSocketSessionBuilder<T> {
201 pub(crate) fn new(model: ResponsesCompletionModel<T>) -> Self {
202 Self {
203 model,
204 connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
205 event_timeout: None,
206 }
207 }
208
209 #[must_use]
211 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
212 self.connect_timeout = Some(timeout);
213 self
214 }
215
216 #[must_use]
218 pub fn without_connect_timeout(mut self) -> Self {
219 self.connect_timeout = None;
220 self
221 }
222
223 #[must_use]
225 pub fn event_timeout(mut self, timeout: Duration) -> Self {
226 self.event_timeout = Some(timeout);
227 self
228 }
229
230 #[must_use]
232 pub fn without_event_timeout(mut self) -> Self {
233 self.event_timeout = None;
234 self
235 }
236}
237
238impl<T> ResponsesWebSocketSessionBuilder<T>
239where
240 T: HttpClientExt
241 + Clone
242 + std::fmt::Debug
243 + Default
244 + WasmCompatSend
245 + WasmCompatSync
246 + 'static,
247{
248 pub async fn connect(self) -> Result<ResponsesWebSocketSession<T>, CompletionError> {
250 ResponsesWebSocketSession::connect_with_timeouts(
251 self.model,
252 self.connect_timeout,
253 self.event_timeout,
254 )
255 .await
256 }
257}
258
/// A stateful websocket session against the OpenAI Responses API.
pub struct ResponsesWebSocketSession<T = reqwest::Client> {
    /// Completion model used to build outgoing requests.
    model: ResponsesCompletionModel<T>,
    /// Id of the last completed response; threaded into the next request as
    /// `previous_response_id` unless the caller sets one explicitly.
    previous_response_id: Option<String>,
    /// Id of a terminal response whose trailing `response.done` event has not
    /// been consumed yet; lets the next turn skip that late event.
    pending_done_response_id: Option<String>,
    /// The underlying websocket connection.
    socket: OpenAIWebSocket,
    /// True between a successful `response.create` send and the turn's
    /// terminal event.
    in_flight: bool,
    /// Optional timeout applied to each socket read in `next_event`.
    event_timeout: Option<Duration>,
    /// Set once the socket has been closed (explicitly or by the peer).
    closed: bool,
    /// Set after a fatal transport/decode error; the session refuses reuse.
    failed: bool,
}
277
impl<T> ResponsesWebSocketSession<T>
where
    T: HttpClientExt
        + Clone
        + std::fmt::Debug
        + Default
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    /// Derive the websocket endpoint and handshake request from the model's
    /// client configuration, then open the socket (optionally bounded by
    /// `connect_timeout`) and return a fresh session with no turn in flight.
    async fn connect_with_timeouts(
        model: ResponsesCompletionModel<T>,
        connect_timeout: Option<Duration>,
        event_timeout: Option<Duration>,
    ) -> Result<Self, CompletionError> {
        let url = websocket_url(model.client.base_url())?;
        let request = websocket_request(&url, model.client.headers())?;
        let socket = connect_websocket(request, connect_timeout).await?;

        Ok(Self {
            model,
            previous_response_id: None,
            pending_done_response_id: None,
            socket,
            in_flight: false,
            event_timeout,
            closed: false,
            failed: false,
        })
    }

    /// Id of the last completed response on this session, if any.
    #[must_use]
    pub fn previous_response_id(&self) -> Option<&str> {
        self.previous_response_id.as_deref()
    }

    /// Forget the stored response id so the next turn starts a fresh chain.
    pub fn clear_previous_response_id(&mut self) {
        self.previous_response_id = None;
    }

    /// Send a `response.create` event with default options.
    ///
    /// # Errors
    /// See [`Self::send_with_options`].
    pub async fn send(
        &mut self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<(), CompletionError> {
        self.send_with_options(
            completion_request,
            ResponsesWebSocketCreateOptions::default(),
        )
        .await
    }

    /// Send a `response.create` event carrying `completion_request` and the
    /// given options, and mark the turn as in flight.
    ///
    /// # Errors
    /// Fails if the session is closed or failed, if a turn is already in
    /// flight, if serialization fails, or if the socket write fails — the
    /// last of which also marks the session as failed.
    pub async fn send_with_options(
        &mut self,
        completion_request: crate::completion::CompletionRequest,
        options: ResponsesWebSocketCreateOptions,
    ) -> Result<(), CompletionError> {
        self.ensure_open()?;

        // One turn at a time: the event loop in `next_event` cannot
        // disambiguate interleaved turns.
        if self.in_flight {
            return Err(CompletionError::ProviderError(
                "An OpenAI websocket response is already in flight on this session".to_string(),
            ));
        }

        let payload = ResponsesWebSocketClientEvent {
            kind: ResponsesWebSocketClientEventKind::ResponseCreate,
            request: self.prepare_request(completion_request)?,
            generate: options.generate,
        };

        // Pretty-printing is costly, so only serialize twice when TRACE is on.
        if tracing::enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI websocket request: {}",
                serde_json::to_string_pretty(&payload)?
            );
        }

        let payload = serde_json::to_string(&payload)?;

        if let Err(error) = self.socket.send(Message::text(payload)).await {
            return Err(self.fail_session(websocket_provider_error(error)));
        }
        // Only flip the flag after the write succeeded.
        self.in_flight = true;

        Ok(())
    }

    /// Read server frames until one yields an event worth surfacing.
    ///
    /// Ping/pong/unparsed-but-unknown events are skipped transparently, as is
    /// a late `response.done` that belongs to the previous turn (matched via
    /// `pending_done_response_id`).
    ///
    /// # Errors
    /// Fails if no turn is in flight, on transport errors, undecodable
    /// payloads, read timeouts (all of which mark the session failed), or if
    /// the connection closes mid-turn (which marks it closed).
    pub async fn next_event(&mut self) -> Result<ResponsesWebSocketEvent, CompletionError> {
        self.ensure_open()?;

        if !self.in_flight {
            return Err(CompletionError::ProviderError(
                "No OpenAI websocket response is currently in flight on this session".to_string(),
            ));
        }

        loop {
            let message = match self.read_next_message().await {
                Ok(message) => message,
                Err(error) => return Err(error),
            };

            // `None` from the stream means the peer went away mid-turn.
            let Some(message) = message else {
                self.mark_closed();
                return Err(CompletionError::ProviderError(
                    "The OpenAI websocket connection closed before the turn finished".to_string(),
                ));
            };

            let message = match message {
                Ok(message) => message,
                Err(error) => return Err(self.fail_session(websocket_provider_error(error))),
            };
            let payload = match websocket_message_to_text(message) {
                Ok(Some(payload)) => payload,
                Ok(None) => continue, // control frame (ping/pong); keep reading
                Err(error) => return Err(self.fail_session(error)),
            };
            let event = match parse_server_event(&payload) {
                Ok(Some(event)) => event,
                Ok(None) => continue, // unrecognised event type; keep reading
                Err(error) => return Err(self.fail_session(error)),
            };
            // Swallow the previous turn's trailing `response.done`.
            // NOTE(review): if both sides are `None` (a done event with no id
            // while nothing is pending) this also skips the event — presumably
            // only reachable with a malformed frame, but worth confirming.
            if let ResponsesWebSocketEvent::Done(done) = &event {
                if self.pending_done_response_id.as_deref() == done.response_id() {
                    self.pending_done_response_id = None;
                    continue;
                }
            }
            self.update_state_for_event(&event);
            return Ok(event);
        }
    }

    /// Run a turn with `generate: Some(false)` (see
    /// [`ResponsesWebSocketCreateOptions::warmup`]), wait for its terminal
    /// response, and return the response id.
    pub async fn warmup(
        &mut self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<String, CompletionError> {
        self.send_with_options(
            completion_request,
            ResponsesWebSocketCreateOptions::warmup(),
        )
        .await?;
        let response = self.wait_for_completed_response().await?;
        Ok(response.id)
    }

    /// Run a full turn and convert its terminal response into the crate-level
    /// completion response.
    pub async fn completion(
        &mut self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        self.send(completion_request).await?;
        let response = self.wait_for_completed_response().await?;
        response.try_into()
    }

    /// Perform the websocket close handshake. Idempotent: a second call on an
    /// already-closed session is a no-op. The session is marked closed even
    /// when the close frame itself fails to send.
    pub async fn close(&mut self) -> Result<(), CompletionError> {
        if self.closed {
            return Ok(());
        }

        let result = self
            .socket
            .close(None)
            .await
            .map_err(websocket_provider_error);
        self.mark_closed();
        result
    }

    /// Build the provider-level request from a crate-level one, forcing off
    /// fields that do not apply to websocket turns and defaulting
    /// `previous_response_id` from session state.
    fn prepare_request(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<super::CompletionRequest, CompletionError> {
        let mut request = self.model.create_completion_request(completion_request)?;

        // NOTE(review): `stream` and `background` are cleared unconditionally —
        // presumably because delivery mode is dictated by the socket itself.
        request.stream = None;
        request.additional_parameters.background = None;

        // An explicitly-set previous_response_id wins over session state.
        if request.additional_parameters.previous_response_id.is_none() {
            request.additional_parameters.previous_response_id = self.previous_response_id.clone();
        }

        Ok(request)
    }

    /// Drain events for the in-flight turn until a terminal response body is
    /// available, returning it (item deltas are discarded here).
    async fn wait_for_completed_response(&mut self) -> Result<CompletionResponse, CompletionError> {
        loop {
            match self.next_event().await? {
                ResponsesWebSocketEvent::Response(chunk) => {
                    if matches!(
                        chunk.kind,
                        ResponseChunkKind::ResponseCompleted
                            | ResponseChunkKind::ResponseFailed
                            | ResponseChunkKind::ResponseIncomplete
                    ) {
                        return terminal_response_result(chunk.response);
                    }
                }
                ResponsesWebSocketEvent::Done(done) => {
                    // A done event can end the turn before any terminal chunk;
                    // use its embedded body when it decodes fully.
                    if let Some(response) = done.as_completion_response() {
                        return terminal_response_result(response);
                    }

                    let message = if let Some(response_id) = done.response_id() {
                        format!(
                            "OpenAI websocket turn ended with response.done before a terminal response body was available (response_id={response_id})"
                        )
                    } else {
                        "OpenAI websocket turn ended with response.done before a terminal response body was available"
                            .to_string()
                    };

                    return Err(CompletionError::ProviderError(message));
                }
                ResponsesWebSocketEvent::Error(error) => {
                    return Err(CompletionError::ProviderError(error.to_string()));
                }
                ResponsesWebSocketEvent::Item(_) => {}
            }
        }
    }

    /// Advance session state for an event about to be surfaced: terminal
    /// events end the turn, successful completions record the response id for
    /// chaining, and failures clear it.
    fn update_state_for_event(&mut self, event: &ResponsesWebSocketEvent) {
        match event {
            ResponsesWebSocketEvent::Response(chunk) => match chunk.kind {
                ResponseChunkKind::ResponseCompleted => {
                    let response_id = chunk.response.id.clone();
                    self.previous_response_id = Some(response_id.clone());
                    // A trailing response.done for this id may still arrive;
                    // remember it so the next turn can skip it.
                    self.pending_done_response_id = Some(response_id);
                    self.in_flight = false;
                }
                ResponseChunkKind::ResponseFailed | ResponseChunkKind::ResponseIncomplete => {
                    self.pending_done_response_id = Some(chunk.response.id.clone());
                    self.previous_response_id = None;
                    self.in_flight = false;
                }
                ResponseChunkKind::ResponseCreated | ResponseChunkKind::ResponseInProgress => {}
            },
            ResponsesWebSocketEvent::Done(done) => {
                match done.status() {
                    Some(ResponseStatus::Completed) => {
                        if let Some(response_id) = done.response_id() {
                            self.previous_response_id = Some(response_id.to_string());
                        }
                    }
                    Some(ResponseStatus::Failed)
                    | Some(ResponseStatus::Incomplete)
                    | Some(ResponseStatus::Cancelled) => {
                        self.previous_response_id = None;
                    }
                    Some(ResponseStatus::InProgress | ResponseStatus::Queued) | None => {}
                }
                // The done event itself was surfaced, so nothing is pending.
                self.pending_done_response_id = None;
                self.in_flight = false;
            }
            ResponsesWebSocketEvent::Error(_) => {
                self.previous_response_id = None;
                self.pending_done_response_id = None;
                self.in_flight = false;
            }
            ResponsesWebSocketEvent::Item(_) => {}
        }
    }

    /// Reset all per-turn state (used when the session ends or fails).
    fn abort_turn(&mut self) {
        self.previous_response_id = None;
        self.pending_done_response_id = None;
        self.in_flight = false;
    }

    /// Mark the session cleanly closed; clears any failure flag.
    fn mark_closed(&mut self) {
        self.abort_turn();
        self.closed = true;
        self.failed = false;
    }

    /// Mark the session as failed after a fatal transport/decode error.
    fn mark_failed(&mut self) {
        self.abort_turn();
        self.failed = true;
    }

    /// Error out if the session has been closed or has failed.
    fn ensure_open(&self) -> Result<(), CompletionError> {
        if self.closed || self.failed {
            return Err(CompletionError::ProviderError(
                "The OpenAI websocket session is closed".to_string(),
            ));
        }

        Ok(())
    }

    /// Mark the session failed and pass the error through for returning.
    fn fail_session(&mut self, error: CompletionError) -> CompletionError {
        self.mark_failed();
        error
    }

    /// Read the next raw frame, applying `event_timeout` when configured.
    /// A timeout marks the session failed.
    async fn read_next_message(
        &mut self,
    ) -> Result<Option<Result<Message, tungstenite::Error>>, CompletionError> {
        if let Some(timeout_duration) = self.event_timeout {
            match tokio::time::timeout(timeout_duration, self.socket.next()).await {
                Ok(message) => Ok(message),
                Err(_) => Err(self.fail_session(event_timeout_error(timeout_duration))),
            }
        } else {
            Ok(self.socket.next().await)
        }
    }
}
604
605impl<T> Drop for ResponsesWebSocketSession<T> {
606 fn drop(&mut self) {
607 if !self.closed {
608 tracing::warn!(
609 target: "rig::completions",
610 in_flight = self.in_flight,
611 "Dropping an OpenAI websocket session without calling close(); the connection will end without a close handshake"
612 );
613 }
614 }
615}
616
617fn terminal_response_result(
618 response: CompletionResponse,
619) -> Result<CompletionResponse, CompletionError> {
620 match response.status {
621 ResponseStatus::Completed => Ok(response),
622 ResponseStatus::Failed => Err(CompletionError::ProviderError(response_error_message(
623 response.error.as_ref(),
624 "failed response",
625 ))),
626 ResponseStatus::Incomplete => {
627 let reason = response
628 .incomplete_details
629 .as_ref()
630 .map(|details| details.reason.as_str())
631 .unwrap_or("unknown reason");
632 Err(CompletionError::ProviderError(format!(
633 "OpenAI websocket response was incomplete: {reason}"
634 )))
635 }
636 status => Err(CompletionError::ProviderError(format!(
637 "OpenAI websocket response ended with status {:?}",
638 status
639 ))),
640 }
641}
642
643fn response_error_message(error: Option<&ResponseError>, fallback: &str) -> String {
644 if let Some(error) = error {
645 if error.code.is_empty() {
646 error.message.clone()
647 } else {
648 format!("{}: {}", error.code, error.message)
649 }
650 } else {
651 format!("OpenAI websocket returned a {fallback}")
652 }
653}
654
/// Whether `kind` names a streaming Responses event this module decodes via
/// `StreamingCompletionChunk` (anything else is either handled specially —
/// `error`, `response.done` — or skipped).
fn is_known_streaming_event(kind: &str) -> bool {
    const KNOWN_STREAMING_EVENTS: &[&str] = &[
        "response.created",
        "response.in_progress",
        "response.completed",
        "response.failed",
        "response.incomplete",
        "response.output_item.added",
        "response.output_item.done",
        "response.content_part.added",
        "response.content_part.done",
        "response.output_text.delta",
        "response.output_text.done",
        "response.refusal.delta",
        "response.refusal.done",
        "response.function_call_arguments.delta",
        "response.function_call_arguments.done",
        "response.reasoning_summary_part.added",
        "response.reasoning_summary_part.done",
        "response.reasoning_summary_text.delta",
        "response.reasoning_summary_text.done",
    ];
    KNOWN_STREAMING_EVENTS.contains(&kind)
}
679
680fn parse_server_event(payload: &str) -> Result<Option<ResponsesWebSocketEvent>, CompletionError> {
681 #[derive(Deserialize)]
682 struct EventType {
683 #[serde(rename = "type")]
684 kind: String,
685 }
686
687 let event_type = serde_json::from_str::<EventType>(payload)?;
688 match event_type.kind.as_str() {
689 "error" => serde_json::from_str(payload)
690 .map(|e| Some(ResponsesWebSocketEvent::Error(e)))
691 .map_err(CompletionError::from),
692 "response.done" => serde_json::from_str(payload)
693 .map(|d| Some(ResponsesWebSocketEvent::Done(d)))
694 .map_err(CompletionError::from),
695 kind if is_known_streaming_event(kind) => match serde_json::from_str(payload)? {
696 StreamingCompletionChunk::Response(response) => {
697 Ok(Some(ResponsesWebSocketEvent::Response(response)))
698 }
699 StreamingCompletionChunk::Delta(item) => Ok(Some(ResponsesWebSocketEvent::Item(item))),
700 },
701 _ => {
702 tracing::debug!(
703 target: "rig::completions",
704 event_type = event_type.kind.as_str(),
705 "Skipping unrecognised OpenAI websocket event"
706 );
707 Ok(None)
708 }
709 }
710}
711
712fn websocket_message_to_text(message: Message) -> Result<Option<String>, CompletionError> {
713 match message {
714 Message::Text(text) => Ok(Some(text.to_string())),
715 Message::Binary(bytes) => String::from_utf8(bytes.to_vec())
716 .map(Some)
717 .map_err(|error| CompletionError::ResponseError(error.to_string())),
718 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Ok(None),
719 Message::Close(frame) => {
720 let reason = frame
721 .map(|frame| frame.reason.to_string())
722 .filter(|reason| !reason.is_empty())
723 .unwrap_or_else(|| "without a close reason".to_string());
724 Err(CompletionError::ProviderError(format!(
725 "The OpenAI websocket connection closed {reason}"
726 )))
727 }
728 }
729}
730
731fn websocket_url(base_url: &str) -> Result<String, CompletionError> {
732 let mut url = Url::parse(base_url)?;
733 match url.scheme() {
734 "https" => {
735 url.set_scheme("wss").map_err(|_| {
736 CompletionError::ProviderError("Failed to convert https URL to wss".to_string())
737 })?;
738 }
739 "http" => {
740 url.set_scheme("ws").map_err(|_| {
741 CompletionError::ProviderError("Failed to convert http URL to ws".to_string())
742 })?;
743 }
744 scheme => {
745 return Err(CompletionError::ProviderError(format!(
746 "Unsupported base URL scheme for OpenAI websocket mode: {scheme}"
747 )));
748 }
749 }
750
751 let path = format!("{}/responses", url.path().trim_end_matches('/'));
752 url.set_path(&path);
753 Ok(url.to_string())
754}
755
756fn websocket_request(
757 url: &str,
758 headers: &http::HeaderMap,
759) -> Result<http::Request<()>, CompletionError> {
760 let mut request = url.into_client_request().map_err(|error| {
761 CompletionError::ProviderError(format!("Failed to build OpenAI websocket request: {error}"))
762 })?;
763
764 for (name, value) in headers {
765 request.headers_mut().insert(name, value.clone());
766 }
767
768 Ok(request)
769}
770
771async fn connect_websocket(
772 request: http::Request<()>,
773 connect_timeout: Option<Duration>,
774) -> Result<OpenAIWebSocket, CompletionError> {
775 if let Some(timeout_duration) = connect_timeout {
776 match tokio::time::timeout(timeout_duration, connect_async(request)).await {
777 Ok(result) => result
778 .map(|(socket, _)| socket)
779 .map_err(websocket_provider_error),
780 Err(_) => Err(connect_timeout_error(timeout_duration)),
781 }
782 } else {
783 connect_async(request)
784 .await
785 .map(|(socket, _)| socket)
786 .map_err(websocket_provider_error)
787 }
788}
789
790fn connect_timeout_error(timeout: Duration) -> CompletionError {
791 CompletionError::ProviderError(format!(
792 "Timed out connecting to the OpenAI websocket after {timeout:?}"
793 ))
794}
795
796fn event_timeout_error(timeout: Duration) -> CompletionError {
797 CompletionError::ProviderError(format!(
798 "Timed out waiting for the next OpenAI websocket event after {timeout:?}"
799 ))
800}
801
802fn websocket_provider_error(error: tungstenite::Error) -> CompletionError {
803 CompletionError::ProviderError(error.to_string())
804}
805
806#[cfg(test)]
807mod tests {
808 use super::{
809 ResponsesWebSocketCreateOptions, ResponsesWebSocketDoneEvent, ResponsesWebSocketEvent,
810 parse_server_event, terminal_response_result, websocket_url,
811 };
812 use crate::client::CompletionClient;
813 use crate::completion::CompletionModel;
814 use crate::providers::openai::responses_api::{
815 CompletionResponse, ResponseObject, ResponseStatus, ResponsesUsage,
816 };
817 use futures::{SinkExt, StreamExt};
818 use serde_json::json;
819 use std::time::Duration;
820 use tokio::net::TcpListener;
821 use tokio::time::sleep;
822 use tokio_tungstenite::{accept_async, tungstenite::Message};
823
    /// Minimal `CompletionResponse` fixture with the given status and fixed
    /// id/model/usage values used across the tests below.
    fn sample_response(status: ResponseStatus) -> CompletionResponse {
        CompletionResponse {
            id: "resp_123".to_string(),
            object: ResponseObject::Response,
            created_at: 0,
            status,
            error: None,
            incomplete_details: None,
            instructions: None,
            max_output_tokens: None,
            model: "gpt-5.4".to_string(),
            usage: Some(ResponsesUsage {
                input_tokens: 1,
                input_tokens_details: None,
                output_tokens: 2,
                output_tokens_details:
                    crate::providers::openai::responses_api::OutputTokensDetails {
                        reasoning_tokens: 0,
                    },
                total_tokens: 3,
            }),
            output: Vec::new(),
            tools: Vec::new(),
            additional_parameters: Default::default(),
        }
    }
850
851 #[test]
852 fn warmup_options_serialize_generate_false() {
853 let options = ResponsesWebSocketCreateOptions::warmup();
854 let json = serde_json::to_value(options).expect("options should serialize");
855
856 assert_eq!(json, json!({ "generate": false }));
857 }
858
859 #[test]
860 fn websocket_url_converts_https_to_wss() {
861 let url = websocket_url("https://api.openai.com/v1").expect("url should convert");
862 assert_eq!(url, "wss://api.openai.com/v1/responses");
863 }
864
    /// A `response.done` frame parses into a terminal `Done` event that
    /// exposes the embedded response id.
    #[test]
    fn parse_done_event_exposes_response_id() {
        let payload = json!({
            "type": "response.done",
            "response": {
                "id": "resp_done_1",
                "status": "completed"
            }
        });

        let event = parse_server_event(&payload.to_string())
            .expect("done event should deserialize")
            .expect("done event should not be skipped");

        assert!(matches!(
            event,
            ResponsesWebSocketEvent::Done(ResponsesWebSocketDoneEvent { .. })
        ));
        assert_eq!(event.response_id(), Some("resp_done_1"));
        assert!(event.is_terminal());
    }
886
    /// A full `response.completed` chunk parses into a terminal `Response`
    /// event carrying the response id.
    #[test]
    fn parse_response_completed_event_is_terminal() {
        let payload = json!({
            "type": "response.completed",
            "sequence_number": 12,
            "response": {
                "id": "resp_completed_1",
                "object": "response",
                "created_at": 0,
                "status": "completed",
                "error": null,
                "incomplete_details": null,
                "instructions": null,
                "max_output_tokens": null,
                "model": "gpt-5.4",
                "usage": null,
                "output": [],
                "tools": []
            }
        });

        let event = parse_server_event(&payload.to_string())
            .expect("response event should deserialize")
            .expect("response event should not be skipped");

        assert!(matches!(event, ResponsesWebSocketEvent::Response(_)));
        assert!(event.is_terminal());
        assert_eq!(event.response_id(), Some("resp_completed_1"));
    }
916
    /// A captured live `response.output_item.added` frame parses into an
    /// `Item` event.
    #[test]
    fn parse_live_output_item_added_event() {
        let payload = json!({
            "type": "response.output_item.added",
            "item": {
                "id": "msg_036471c3a72c147b0069ae7848d68881959773fd2d99e3d98a",
                "type": "message",
                "status": "in_progress",
                "content": [],
                "role": "assistant"
            },
            "output_index": 0,
            "sequence_number": 2
        });

        let event = parse_server_event(&payload.to_string())
            .expect("output item event should parse")
            .expect("output item event should not be skipped");

        assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
    }
938
    /// A captured live `response.content_part.added` frame parses into an
    /// `Item` event.
    #[test]
    fn parse_live_content_part_added_event() {
        let payload = json!({
            "type": "response.content_part.added",
            "content_index": 0,
            "item_id": "msg_036471c3a72c147b0069ae7848d68881959773fd2d99e3d98a",
            "output_index": 0,
            "part": {
                "type": "output_text",
                "annotations": [],
                "logprobs": [],
                "text": ""
            },
            "sequence_number": 3
        });

        let event = parse_server_event(&payload.to_string())
            .expect("content part event should parse")
            .expect("content part event should not be skipped");

        assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
    }
961
    /// A captured live `response.output_text.delta` frame parses into an
    /// `Item` event.
    #[test]
    fn parse_live_output_text_delta_event() {
        let payload = json!({
            "type": "response.output_text.delta",
            "content_index": 0,
            "delta": "Web",
            "item_id": "msg_023af0f0a91bc2a90069ae788612e881958345bb156915ba29",
            "logprobs": [],
            "obfuscation": "2YYErYq7jkqqM",
            "output_index": 0,
            "sequence_number": 4
        });

        let event = parse_server_event(&payload.to_string())
            .expect("output text delta event should parse")
            .expect("output text delta event should not be skipped");

        assert!(matches!(event, ResponsesWebSocketEvent::Item(_)));
    }
981
    /// `terminal_response_result` passes completed responses through and
    /// converts failed responses into provider errors.
    #[test]
    fn terminal_response_requires_completed_status() {
        let completed = terminal_response_result(sample_response(ResponseStatus::Completed))
            .expect("completed response should succeed");
        assert_eq!(completed.id, "resp_123");

        let failed = terminal_response_result(sample_response(ResponseStatus::Failed))
            .expect_err("failed response should error");
        assert!(failed.to_string().contains("failed response"));
    }
992
    /// A known event type whose body fails strict decoding must fail the
    /// session: further sends are rejected, but an explicit `close()` still
    /// succeeds. Uses a local mock websocket server.
    #[tokio::test]
    async fn malformed_known_event_rejects_reuse_and_allows_close() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Mock server: accept one connection, echo back a malformed
        // `response.completed` (missing its `response` body), then expect the
        // client's close frame.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = request.into_text().expect("request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );

            socket
                .send(Message::text(
                    json!({
                        "type": "response.completed"
                    })
                    .to_string(),
                ))
                .await
                .expect("malformed known event should send");

            let message = socket
                .next()
                .await
                .expect("close frame should arrive")
                .expect("close frame should be valid");
            assert!(
                matches!(message, Message::Close(_)),
                "expected close frame, got {message:?}"
            );
        });

        // Client side: connect through the public builder API.
        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("hello").build())
            .await
            .expect("request should send");

        // The malformed body is a fatal decode error, not a skipped event.
        let error = session
            .next_event()
            .await
            .expect_err("malformed known event should fail");
        assert!(
            error.to_string().contains("StreamingCompletionChunk"),
            "expected strict decode failure, got {error}"
        );

        // The failed session must refuse reuse...
        let closed = session
            .send(model.completion_request("retry").build())
            .await
            .expect_err("session should close after fatal parse error");
        assert!(
            closed.to_string().contains("session is closed"),
            "expected closed-session error, got {closed}"
        );

        // ...but still allow an explicit, clean close.
        session
            .close()
            .await
            .expect("explicit close after fatal parse error should succeed");

        server.await.expect("server task should finish");
    }
1080
    /// A per-event timeout expiring mid-turn must fail the session: further
    /// sends are rejected, but an explicit `close()` still succeeds. Uses a
    /// mock server that stays silent past the configured timeout.
    #[tokio::test]
    async fn event_timeout_rejects_reuse_and_allows_close() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Mock server: accept the request, then sleep longer than the
        // client's 20ms event timeout before expecting its close frame.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = request.into_text().expect("request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );

            sleep(Duration::from_millis(60)).await;
            let message = socket
                .next()
                .await
                .expect("close frame should arrive")
                .expect("close frame should be valid");
            assert!(
                matches!(message, Message::Close(_)),
                "expected close frame, got {message:?}"
            );
        });

        // Client side: builder path with a 20ms event timeout.
        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket_builder("gpt-4o")
            .event_timeout(Duration::from_millis(20))
            .connect()
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("hello").build())
            .await
            .expect("request should send");

        let error = session
            .next_event()
            .await
            .expect_err("next_event should time out");
        assert!(
            error
                .to_string()
                .contains("Timed out waiting for the next OpenAI websocket event"),
            "expected timeout error, got {error}"
        );

        // The timed-out session must refuse reuse...
        let closed = session
            .send(model.completion_request("retry").build())
            .await
            .expect_err("timed-out session should close");
        assert!(
            closed.to_string().contains("session is closed"),
            "expected closed-session error, got {closed}"
        );

        // ...but still allow an explicit, clean close.
        session
            .close()
            .await
            .expect("explicit close after timeout should succeed");

        server.await.expect("server task should finish");
    }
1163
    // Verifies that a stale `response.done` left over from a finished turn is
    // filtered out of the next turn, and that `previous_response_id` tracks the
    // most recently completed response across turns.
    #[tokio::test]
    async fn late_response_done_is_ignored_on_next_turn() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake OpenAI websocket server: for each of two turns, read one
        // `response.create` request, then emit `response.completed` followed by
        // a late `response.done` for the same response id.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            for (index, response_id) in ["resp_1", "resp_2"].iter().enumerate() {
                let request = socket
                    .next()
                    .await
                    .expect("request should exist")
                    .expect("request should be valid");
                let payload = request.into_text().expect("request should be text");
                assert!(
                    payload.contains("\"type\":\"response.create\""),
                    "expected response.create payload, got {payload}"
                );

                let response = sample_response(ResponseStatus::Completed);
                let response = serde_json::to_value(CompletionResponse {
                    id: (*response_id).to_string(),
                    ..response
                })
                .expect("response should serialize");

                socket
                    .send(Message::text(
                        json!({
                            "type": "response.completed",
                            // Distinct, increasing sequence numbers per turn.
                            "sequence_number": (index * 2) + 1,
                            "response": response,
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("completed event should send");
                // The `done` arrives after the turn already resolved via
                // `response.completed`; the client must not surface it later.
                socket
                    .send(Message::text(
                        json!({
                            "type": "response.done",
                            "response": {
                                "id": response_id,
                                "status": "completed",
                            },
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("done event should send");
            }
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        // First turn completes with resp_1 and records it for chaining.
        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        let first = session
            .wait_for_completed_response()
            .await
            .expect("first response should complete");
        assert_eq!(first.id, "resp_1");
        assert_eq!(session.previous_response_id(), Some("resp_1"));

        // Second turn: the stale done from turn one must be skipped, so this
        // turn resolves with resp_2 and updates the chaining id.
        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");
        assert_eq!(session.previous_response_id(), Some("resp_2"));

        server.await.expect("server task should finish");
    }
1259
    // Verifies that clearing the chaining id (e.g. to start a fresh
    // conversation) does not also disable the filtering of late
    // `response.done` events left over from the previous turn.
    #[tokio::test]
    async fn clearing_previous_response_id_does_not_disable_late_done_filter() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: each turn answers with `response.completed` followed by
        // a late `response.done` for the same response id.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            for response_id in ["resp_1", "resp_2"] {
                let request = socket
                    .next()
                    .await
                    .expect("request should exist")
                    .expect("request should be valid");
                let payload = request.into_text().expect("request should be text");
                assert!(
                    payload.contains("\"type\":\"response.create\""),
                    "expected response.create payload, got {payload}"
                );

                let response = sample_response(ResponseStatus::Completed);
                let response = serde_json::to_value(CompletionResponse {
                    id: response_id.to_string(),
                    ..response
                })
                .expect("response should serialize");

                socket
                    .send(Message::text(
                        json!({
                            "type": "response.completed",
                            "sequence_number": 1,
                            "response": response,
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("completed event should send");
                socket
                    .send(Message::text(
                        json!({
                            "type": "response.done",
                            "response": {
                                "id": response_id,
                                "status": "completed",
                            },
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("done event should send");
            }
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        let first = session
            .wait_for_completed_response()
            .await
            .expect("first response should complete");
        assert_eq!(first.id, "resp_1");

        // Reset conversation chaining between turns.
        session.clear_previous_response_id();
        assert_eq!(session.previous_response_id(), None);

        // Even with the chaining id cleared, the late done from turn one must
        // still be filtered, so this turn resolves with resp_2.
        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");

        server.await.expect("server task should finish");
    }
1356
    // Verifies that after a failed turn (`response.failed` + late
    // `response.done`), the stale done does not bleed into the next turn, and
    // the failed response id is not used for chaining.
    #[tokio::test]
    async fn failed_turn_keeps_late_done_out_of_next_request() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: first turn fails, second turn completes normally. Both
        // turns are trailed by a `response.done` event.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let first_request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = first_request
                .into_text()
                .expect("failed request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );

            // Start from the completed sample but flip the status to Failed.
            let failed_response = serde_json::to_value(CompletionResponse {
                id: "resp_failed".to_string(),
                status: ResponseStatus::Failed,
                ..sample_response(ResponseStatus::Completed)
            })
            .expect("failed response should serialize");

            socket
                .send(Message::text(
                    json!({
                        "type": "response.failed",
                        "sequence_number": 1,
                        "response": failed_response,
                    })
                    .to_string(),
                ))
                .await
                .expect("failed event should send");
            // Late done for the failed turn; must not be surfaced next turn.
            socket
                .send(Message::text(
                    json!({
                        "type": "response.done",
                        "response": {
                            "id": "resp_failed",
                            "status": "failed",
                        },
                    })
                    .to_string(),
                ))
                .await
                .expect("done event should send");

            let second_request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = second_request
                .into_text()
                .expect("second request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );

            let response = sample_response(ResponseStatus::Completed);
            let response = serde_json::to_value(CompletionResponse {
                id: "resp_2".to_string(),
                ..response
            })
            .expect("response should serialize");

            socket
                .send(Message::text(
                    json!({
                        "type": "response.completed",
                        "sequence_number": 2,
                        "response": response,
                    })
                    .to_string(),
                ))
                .await
                .expect("completed event should send");
            socket
                .send(Message::text(
                    json!({
                        "type": "response.done",
                        "response": {
                            "id": "resp_2",
                            "status": "completed",
                        },
                    })
                    .to_string(),
                ))
                .await
                .expect("done event should send");
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        // First turn fails; the failed id must not be recorded for chaining.
        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        let error = session
            .wait_for_completed_response()
            .await
            .expect_err("failed response should error");
        assert!(error.to_string().contains("failed response"));
        assert_eq!(session.previous_response_id(), None);

        // Second turn must resolve cleanly with resp_2 despite the late done
        // from the failed first turn.
        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");

        server.await.expect("server task should finish");
    }
1496
    // Verifies that a turn which resolves via `response.done` alone (no
    // preceding `response.completed`) still updates `previous_response_id`,
    // and that the next request chains on it.
    #[tokio::test]
    async fn done_first_completed_turn_updates_previous_response_id() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: each turn replies with only a `response.done` carrying a
        // full completed response body.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            for response_id in ["resp_1", "resp_2"] {
                let request = socket
                    .next()
                    .await
                    .expect("request should exist")
                    .expect("request should be valid");
                let payload = request.into_text().expect("request should be text");
                assert!(
                    payload.contains("\"type\":\"response.create\""),
                    "expected response.create payload, got {payload}"
                );

                // The second request must chain on the first turn's id.
                if response_id == "resp_2" {
                    assert!(
                        payload.contains("\"previous_response_id\":\"resp_1\""),
                        "expected chained previous_response_id in payload, got {payload}"
                    );
                }

                let response = serde_json::to_value(CompletionResponse {
                    id: response_id.to_string(),
                    ..sample_response(ResponseStatus::Completed)
                })
                .expect("response should serialize");

                socket
                    .send(Message::text(
                        json!({
                            "type": "response.done",
                            "response": response,
                        })
                        .to_string(),
                    ))
                    .await
                    .expect("done event should send");
            }
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        let first = session
            .wait_for_completed_response()
            .await
            .expect("first response should complete");
        assert_eq!(first.id, "resp_1");
        assert_eq!(session.previous_response_id(), Some("resp_1"));

        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");
        assert_eq!(session.previous_response_id(), Some("resp_2"));

        server.await.expect("server task should finish");
    }
1584
    // Verifies that a turn which fails via a done-only `response.done` (status
    // failed, no `response.failed` event) does not record its id for chaining,
    // so the next request carries no `previous_response_id`.
    #[tokio::test]
    async fn done_first_failed_turn_does_not_chain_next_request() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: first turn replies with a done-only failed response,
        // second turn with a done-only completed response. Both requests are
        // asserted to carry no previous_response_id.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let first_request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = first_request
                .into_text()
                .expect("first request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );
            assert!(
                !payload.contains("\"previous_response_id\""),
                "did not expect previous_response_id in first payload, got {payload}"
            );

            // Start from the completed sample but flip the status to Failed.
            let failed_response = serde_json::to_value(CompletionResponse {
                id: "resp_failed".to_string(),
                status: ResponseStatus::Failed,
                ..sample_response(ResponseStatus::Completed)
            })
            .expect("failed response should serialize");

            socket
                .send(Message::text(
                    json!({
                        "type": "response.done",
                        "response": failed_response,
                    })
                    .to_string(),
                ))
                .await
                .expect("done event should send");

            let second_request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");
            let payload = second_request
                .into_text()
                .expect("second request should be text");
            assert!(
                payload.contains("\"type\":\"response.create\""),
                "expected response.create payload, got {payload}"
            );
            // The failed turn must not have been chained.
            assert!(
                !payload.contains("\"previous_response_id\""),
                "did not expect chained previous_response_id in payload, got {payload}"
            );

            let response = serde_json::to_value(CompletionResponse {
                id: "resp_2".to_string(),
                ..sample_response(ResponseStatus::Completed)
            })
            .expect("response should serialize");

            socket
                .send(Message::text(
                    json!({
                        "type": "response.done",
                        "response": response,
                    })
                    .to_string(),
                ))
                .await
                .expect("done event should send");
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        // First turn fails; no chaining id is recorded.
        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");
        let error = session
            .wait_for_completed_response()
            .await
            .expect_err("failed response should error");
        assert!(error.to_string().contains("failed response"));
        assert_eq!(session.previous_response_id(), None);

        // Second turn completes and becomes the new chaining id.
        session
            .send(model.completion_request("second").build())
            .await
            .expect("second request should send");
        let second = session
            .wait_for_completed_response()
            .await
            .expect("second response should complete");
        assert_eq!(second.id, "resp_2");
        assert_eq!(session.previous_response_id(), Some("resp_2"));

        server.await.expect("server task should finish");
    }
1704
1705 #[test]
1706 fn websocket_url_converts_http_to_ws() {
1707 let url = websocket_url("http://localhost:8080/v1").expect("url should convert");
1708 assert_eq!(url, "ws://localhost:8080/v1/responses");
1709 }
1710
1711 #[test]
1712 fn websocket_url_rejects_unsupported_scheme() {
1713 let result = websocket_url("ftp://example.com/v1");
1714 assert!(result.is_err());
1715 }
1716
1717 #[test]
1718 fn websocket_url_trims_trailing_slash() {
1719 let url = websocket_url("https://api.openai.com/v1/").expect("url should convert");
1720 assert_eq!(url, "wss://api.openai.com/v1/responses");
1721 }
1722
1723 #[test]
1724 fn unknown_event_type_is_skipped() {
1725 let payload = json!({
1726 "type": "response.some_future_event",
1727 "data": "hello"
1728 });
1729
1730 let result =
1731 parse_server_event(&payload.to_string()).expect("unknown event should not error");
1732 assert!(result.is_none(), "unknown event should be skipped");
1733 }
1734
1735 #[test]
1736 fn malformed_known_event_returns_error() {
1737 let payload = json!({
1738 "type": "response.completed"
1739 });
1740
1741 let error = parse_server_event(&payload.to_string())
1742 .expect_err("malformed known event should error");
1743 assert!(
1744 error.to_string().contains("StreamingCompletionChunk"),
1745 "expected strict decode failure, got {error}"
1746 );
1747 }
1748
    // Closing an already-closed session must be a no-op, not an error. The
    // server expects exactly one close frame on the wire.
    #[tokio::test]
    async fn close_is_idempotent() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: the first (and only) frame received must be a close.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let message = socket
                .next()
                .await
                .expect("close frame should arrive")
                .expect("close frame should be valid");
            assert!(
                matches!(message, Message::Close(_)),
                "expected close frame, got {message:?}"
            );
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        // Both calls succeed; the second is a no-op.
        session.close().await.expect("first close should succeed");
        session.close().await.expect("second close should succeed");

        server.await.expect("server task should finish");
    }
1789
    // Verifies that sending a second request while the first turn has not yet
    // resolved is rejected with an "already in flight" error.
    #[tokio::test]
    async fn send_while_in_flight_returns_error() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: consume the first request but never answer it, keeping
        // the turn in flight, then close the socket.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let _request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");

            // Keep the turn pending long enough for the second send to fail.
            sleep(Duration::from_millis(100)).await;
            let _ = socket.close(None).await;
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("first").build())
            .await
            .expect("first request should send");

        // The first turn never resolved, so a second send must be refused.
        let error = session
            .send(model.completion_request("second").build())
            .await
            .expect_err("second send while in-flight should error");
        assert!(
            error.to_string().contains("already in flight"),
            "expected in-flight error, got {error}"
        );

        server.await.expect("server task should finish");
    }
1843
    // Verifies that sending on a session after an explicit close is rejected
    // with a closed-session error.
    #[tokio::test]
    async fn send_after_close_returns_error() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: just accept the upgrade and linger so the client can
        // close from its side.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let _socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");
            sleep(Duration::from_millis(100)).await;
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session.close().await.expect("close should succeed");

        // Any send after close must fail fast rather than hit the dead socket.
        let error = session
            .send(model.completion_request("after close").build())
            .await
            .expect_err("send after close should error");
        assert!(
            error.to_string().contains("session is closed"),
            "expected closed-session error, got {error}"
        );

        server.await.expect("server task should finish");
    }
1884
    // Verifies that polling for events before any request has been sent is
    // rejected with a not-in-flight error.
    #[tokio::test]
    async fn next_event_without_send_returns_error() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: accept the upgrade and linger; no traffic is expected.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let _socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");
            sleep(Duration::from_millis(100)).await;
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        // No send has happened, so next_event must fail immediately instead
        // of blocking on the socket.
        let error = session
            .next_event()
            .await
            .expect_err("next_event without send should error");
        assert!(
            error
                .to_string()
                .contains("No OpenAI websocket response is currently in flight"),
            "expected not-in-flight error, got {error}"
        );

        server.await.expect("server task should finish");
    }
1924
    // Verifies end-to-end (not just in the parser) that an unrecognized event
    // arriving mid-turn is skipped and the turn still completes normally.
    #[tokio::test]
    async fn unknown_event_is_skipped_during_session() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Fake server: answer the request with an unknown event first, then a
        // proper `response.completed`.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("server should accept");
            let mut socket = accept_async(stream)
                .await
                .expect("server should upgrade websocket");

            let _request = socket
                .next()
                .await
                .expect("request should exist")
                .expect("request should be valid");

            // Unknown event type: the client must silently discard it.
            socket
                .send(Message::text(
                    json!({
                        "type": "response.some_future_event",
                        "data": "should be skipped"
                    })
                    .to_string(),
                ))
                .await
                .expect("unknown event should send");

            let response = serde_json::to_value(CompletionResponse {
                id: "resp_after_unknown".to_string(),
                ..sample_response(ResponseStatus::Completed)
            })
            .expect("response should serialize");

            socket
                .send(Message::text(
                    json!({
                        "type": "response.completed",
                        "sequence_number": 1,
                        "response": response,
                    })
                    .to_string(),
                ))
                .await
                .expect("completed event should send");
        });

        let base_url = format!("http://{address}/v1");
        let client = crate::providers::openai::Client::builder()
            .api_key("test-key")
            .base_url(&base_url)
            .build()
            .expect("client should build");
        let model = client.completion_model("gpt-4o");
        let mut session = client
            .responses_websocket("gpt-4o")
            .await
            .expect("session should connect");

        session
            .send(model.completion_request("hello").build())
            .await
            .expect("send should succeed");
        // The turn resolves with the completed event that followed the
        // unknown one.
        let response = session
            .wait_for_completed_response()
            .await
            .expect("response should complete despite unknown event");
        assert_eq!(response.id, "resp_after_unknown");

        server.await.expect("server task should finish");
    }
2000}