1use crate::error::{Error, Result};
7use crate::http::client::Client;
8use crate::model::{
9 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
10 ToolCall, Usage, UserContent,
11};
12use crate::models::CompatConfig;
13use crate::provider::{Context, Provider, StreamOptions, ToolDef};
14use crate::sse::SseStream;
15use async_trait::async_trait;
16use base64::Engine;
17use futures::StreamExt;
18use futures::stream::{self, Stream};
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, VecDeque};
21use std::pin::Pin;
22
/// Default endpoint for the public OpenAI Responses API.
const OPENAI_RESPONSES_API_URL: &str = "https://api.openai.com/v1/responses";
/// Endpoint used when the provider runs in Codex (ChatGPT backend) mode.
pub(crate) const CODEX_RESPONSES_API_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
/// Fallback `max_output_tokens` applied when the caller does not set one
/// (non-Codex mode only; Codex requests omit the cap entirely).
const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;
30
/// Provider for the OpenAI Responses API (and, in Codex mode, the ChatGPT
/// Codex backend), streaming responses over SSE.
pub struct OpenAIResponsesProvider {
    /// HTTP client used for all requests.
    client: Client,
    /// Model id sent in the request payload (e.g. "gpt-4o").
    model: String,
    /// Endpoint URL; defaults to the public Responses API.
    base_url: String,
    /// Provider name reported via `Provider::name` (defaults to "openai").
    provider: String,
    /// API name reported via `Provider::api` (defaults to "openai-responses").
    api: String,
    /// When true, requests target the Codex backend and require an OAuth
    /// token carrying a `chatgpt_account_id` claim.
    codex_mode: bool,
    /// Optional compatibility tweaks (e.g. extra request headers).
    compat: Option<CompatConfig>,
}
45
46impl OpenAIResponsesProvider {
47 pub fn new(model: impl Into<String>) -> Self {
49 Self {
50 client: Client::new(),
51 model: model.into(),
52 base_url: OPENAI_RESPONSES_API_URL.to_string(),
53 provider: "openai".to_string(),
54 api: "openai-responses".to_string(),
55 codex_mode: false,
56 compat: None,
57 }
58 }
59
60 #[must_use]
62 pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
63 self.provider = provider.into();
64 self
65 }
66
67 #[must_use]
69 pub fn with_api_name(mut self, api: impl Into<String>) -> Self {
70 self.api = api.into();
71 self
72 }
73
74 #[must_use]
76 pub const fn with_codex_mode(mut self, enabled: bool) -> Self {
77 self.codex_mode = enabled;
78 self
79 }
80
81 #[must_use]
83 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
84 self.base_url = base_url.into();
85 self
86 }
87
88 #[must_use]
90 pub fn with_client(mut self, client: Client) -> Self {
91 self.client = client;
92 self
93 }
94
95 #[must_use]
97 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
98 self.compat = compat;
99 self
100 }
101
102 pub fn build_request(
103 &self,
104 context: &Context<'_>,
105 options: &StreamOptions,
106 ) -> OpenAIResponsesRequest {
107 let input = build_openai_responses_input(context);
108 let tools: Option<Vec<OpenAIResponsesTool>> = if context.tools.is_empty() {
109 None
110 } else {
111 Some(
112 context
113 .tools
114 .iter()
115 .map(convert_tool_to_openai_responses)
116 .collect(),
117 )
118 };
119
120 let instructions = context.system_prompt.as_deref().map(ToString::to_string);
121
122 let (tool_choice, parallel_tool_calls, text, include, reasoning) = if self.codex_mode {
125 let effort = options
126 .thinking_level
127 .as_ref()
128 .map_or_else(|| "high".to_string(), ToString::to_string);
129 (
130 Some("auto"),
131 Some(true),
132 Some(OpenAIResponsesTextConfig {
133 verbosity: "medium",
134 }),
135 Some(vec!["reasoning.encrypted_content"]),
136 Some(OpenAIResponsesReasoning {
137 effort,
138 summary: Some("auto"),
139 }),
140 )
141 } else {
142 (None, None, None, None, None)
143 };
144
145 OpenAIResponsesRequest {
146 model: self.model.clone(),
147 input,
148 instructions,
149 temperature: options.temperature,
150 max_output_tokens: if self.codex_mode {
151 None
152 } else {
153 options.max_tokens.or(Some(DEFAULT_MAX_OUTPUT_TOKENS))
154 },
155 tools,
156 stream: true,
157 store: false,
158 tool_choice,
159 parallel_tool_calls,
160 text,
161 include,
162 reasoning,
163 }
164 }
165}
166
/// Extracts the token from an `Authorization: Bearer <token>` header value.
///
/// Returns `None` unless the value consists of exactly two
/// whitespace-separated parts whose first part is `bearer`
/// (case-insensitive). Note `split_whitespace` never yields empty or
/// padded tokens, so no additional trimming is needed.
fn bearer_token_from_authorization_header(value: &str) -> Option<String> {
    let mut parts = value.split_whitespace();
    let scheme = parts.next()?;
    let token = parts.next()?;
    // Reject trailing junk ("Bearer a b") and non-Bearer schemes.
    if parts.next().is_some() || !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    Some(token.to_string())
}
180
#[async_trait]
impl Provider for OpenAIResponsesProvider {
    /// Provider label (defaults to "openai").
    fn name(&self) -> &str {
        &self.provider
    }

    /// API flavor label (defaults to "openai-responses").
    fn api(&self) -> &str {
        &self.api
    }

    /// Model id used in request payloads.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Opens a streaming Responses API request and adapts the SSE events
    /// into `StreamEvent`s.
    ///
    /// Authentication: an explicit `Authorization` header in
    /// `options.headers` takes precedence; otherwise `options.api_key` or
    /// the `OPENAI_API_KEY` env var is sent as a Bearer token. Codex mode
    /// additionally requires an OAuth token carrying a
    /// `chatgpt_account_id` claim.
    ///
    /// # Errors
    /// Returns a provider error when no credential is available, the Codex
    /// token is missing/invalid, the server responds with a non-2xx
    /// status, or the response is not an event stream.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        // Detect a caller-supplied Authorization header (case-insensitive)
        // and capture its trimmed value for Codex token extraction below.
        let has_authorization_header = options
            .headers
            .keys()
            .any(|key| key.eq_ignore_ascii_case("authorization"));
        let authorization_header_value = options.headers.iter().find_map(|(key, value)| {
            if key.eq_ignore_ascii_case("authorization") {
                Some(value.trim().to_string())
            } else {
                None
            }
        });

        // Only fall back to api_key / OPENAI_API_KEY when the caller did
        // not provide their own Authorization header.
        let auth_value = if has_authorization_header {
            None
        } else {
            Some(
                options
                    .api_key
                    .clone()
                    .or_else(|| std::env::var("OPENAI_API_KEY").ok())
                    .ok_or_else(|| {
                        Error::provider(
                            self.name(),
                            "Missing API key for OpenAI/Codex. Set OPENAI_API_KEY or run /login.",
                        )
                    })?,
            )
        };

        let request_body = self.build_request(context, options);

        let mut request = self
            .client
            .post(&self.base_url)
            .header("Accept", "text/event-stream");

        if let Some(ref auth_value) = auth_value {
            request = request.header("Authorization", format!("Bearer {auth_value}"));
        }

        if self.codex_mode {
            // Prefer the bearer token from the explicit Authorization
            // header; otherwise reuse the resolved api-key credential.
            let codex_token = authorization_header_value
                .as_deref()
                .and_then(bearer_token_from_authorization_header)
                .or_else(|| auth_value.clone())
                .ok_or_else(|| {
                    Error::provider(
                        self.name(),
                        "OpenAI Codex mode requires a Bearer token. Provide one via /login openai-codex or an Authorization: Bearer <token> header.",
                    )
                })?;
            // The Codex backend routes by account id, taken from the
            // OAuth token's chatgpt_account_id claim.
            let account_id = extract_chatgpt_account_id(&codex_token).ok_or_else(|| {
                Error::provider(
                    self.name(),
                    "Invalid OpenAI Codex OAuth token (missing chatgpt_account_id claim). Run /login openai-codex again.",
                )
            })?;
            request = request
                .header("chatgpt-account-id", account_id)
                .header("OpenAI-Beta", "responses=experimental")
                .header("originator", "pi")
                .header("User-Agent", "pi_agent_rust");
            if let Some(session_id) = &options.session_id {
                request = request.header("session_id", session_id);
            }
        }

        // Compat-config headers first, then caller headers (applied last;
        // whether later values replace or append depends on Client semantics).
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                for (key, value) in custom_headers {
                    request = request.header(key, value);
                }
            }
        }

        for (key, value) in &options.headers {
            request = request.header(key, value);
        }

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Surface the error body verbatim to aid debugging.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                self.name(),
                format!("OpenAI API error (HTTP {status}): {body}"),
            ));
        }

        // Guard against non-stream responses (e.g. a JSON error page from
        // a proxy) before handing the body to the SSE parser.
        let content_type = response
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
            .map(|(_, value)| value.to_ascii_lowercase());
        if let Some(ref ct) = content_type {
            if !ct.contains("text/event-stream") && !ct.contains("application/x-ndjson") {
                return Err(Error::api(format!(
                    "OpenAI API protocol error (HTTP {status}): unexpected Content-Type {ct} (expected text/event-stream)"
                )));
            }
        }

        let event_source = SseStream::new(response.bytes_stream());

        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Drive the SSE source through StreamState, which accumulates the
        // partial message and queues StreamEvents for emission.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                loop {
                    // Drain queued events before pulling more SSE data.
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    if state.finished {
                        return None;
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            if msg.data == "[DONE]" {
                                // Sentinel terminator: flush end events,
                                // then loop to emit them.
                                state.finish(None);
                                continue;
                            }

                            if let Err(e) = state.process_event(&msg.data) {
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        None => {
                            // Transport closed before a terminal event.
                            return Some((
                                Err(Error::api("Stream ended without Done event")),
                                state,
                            ));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
370
/// Identifies one text output block: SSE item id plus content index.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TextKey {
    item_id: String,
    content_index: u32,
}
380
/// Identifies one reasoning-summary block: SSE item id plus summary index.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ReasoningKey {
    item_id: String,
    summary_index: u32,
}
386
/// In-flight function call being accumulated from streamed argument deltas.
struct ToolCallState {
    /// Index of the corresponding `ContentBlock::ToolCall` in the partial message.
    content_index: usize,
    /// Provider-issued call id (echoed back in tool results).
    call_id: String,
    /// Tool name.
    name: String,
    /// Raw JSON argument text accumulated so far (parsed on completion).
    arguments: String,
}
393
/// Accumulator threaded through `stream::unfold`: parses SSE payloads,
/// builds up the partial assistant message, and queues `StreamEvent`s.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Underlying SSE byte stream.
    event_source: SseStream<S>,
    /// Message accumulated so far.
    partial: AssistantMessage,
    /// Events queued for emission before pulling more SSE data.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the `Start` event has been queued.
    started: bool,
    /// Whether a terminal event (`Done`/`Error`) has been queued.
    finished: bool,
    /// Maps (item id, content index) -> index into `partial.content`.
    text_blocks: HashMap<TextKey, usize>,
    /// Maps (item id, summary index) -> index into `partial.content`.
    reasoning_blocks: HashMap<ReasoningKey, usize>,
    /// Open tool calls keyed by SSE item id.
    tool_calls_by_item_id: HashMap<String, ToolCallState>,
}
407
408impl<S> StreamState<S>
409where
410 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
411{
412 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
413 Self {
414 event_source,
415 partial: AssistantMessage {
416 content: Vec::new(),
417 api,
418 provider,
419 model,
420 usage: Usage::default(),
421 stop_reason: StopReason::Stop,
422 error_message: None,
423 timestamp: chrono::Utc::now().timestamp_millis(),
424 },
425 pending_events: VecDeque::new(),
426 started: false,
427 finished: false,
428 text_blocks: HashMap::new(),
429 reasoning_blocks: HashMap::new(),
430 tool_calls_by_item_id: HashMap::new(),
431 }
432 }
433
434 fn ensure_started(&mut self) {
435 if !self.started {
436 self.started = true;
437 self.pending_events.push_back(StreamEvent::Start {
438 partial: self.partial.clone(),
439 });
440 }
441 }
442
443 fn text_block_for(&mut self, item_id: String, content_index: u32) -> usize {
444 let key = TextKey {
445 item_id,
446 content_index,
447 };
448 if let Some(idx) = self.text_blocks.get(&key) {
449 return *idx;
450 }
451
452 let idx = self.partial.content.len();
453 self.partial
454 .content
455 .push(ContentBlock::Text(TextContent::new("")));
456 self.text_blocks.insert(key, idx);
457 self.pending_events
458 .push_back(StreamEvent::TextStart { content_index: idx });
459 idx
460 }
461
462 fn reasoning_block_for(&mut self, item_id: String, summary_index: u32) -> usize {
463 let key = ReasoningKey {
464 item_id,
465 summary_index,
466 };
467 if let Some(idx) = self.reasoning_blocks.get(&key) {
468 return *idx;
469 }
470
471 let idx = self.partial.content.len();
472 self.partial
473 .content
474 .push(ContentBlock::Thinking(ThinkingContent {
475 thinking: String::new(),
476 thinking_signature: None,
477 }));
478 self.reasoning_blocks.insert(key, idx);
479 self.pending_events
480 .push_back(StreamEvent::ThinkingStart { content_index: idx });
481 idx
482 }
483
484 #[allow(clippy::too_many_lines)]
485 fn process_event(&mut self, data: &str) -> Result<()> {
486 let chunk: OpenAIResponsesChunk =
487 serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;
488
489 match chunk {
490 OpenAIResponsesChunk::OutputTextDelta {
491 item_id,
492 content_index,
493 delta,
494 } => {
495 self.ensure_started();
496 let idx = self.text_block_for(item_id, content_index);
497 if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(idx) {
498 t.text.push_str(&delta);
499 }
500 self.pending_events.push_back(StreamEvent::TextDelta {
501 content_index: idx,
502 delta,
503 });
504 }
505 OpenAIResponsesChunk::ReasoningSummaryTextDelta {
506 item_id,
507 summary_index,
508 delta,
509 } => {
510 self.ensure_started();
511 let idx = self.reasoning_block_for(item_id, summary_index);
512 if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(idx) {
513 t.thinking.push_str(&delta);
514 }
515 self.pending_events.push_back(StreamEvent::ThinkingDelta {
516 content_index: idx,
517 delta,
518 });
519 }
520 OpenAIResponsesChunk::OutputItemAdded { item } => {
521 if let OpenAIResponsesOutputItem::FunctionCall {
522 id,
523 call_id,
524 name,
525 arguments,
526 } = item
527 {
528 self.ensure_started();
529
530 let content_index = self.partial.content.len();
531 self.partial.content.push(ContentBlock::ToolCall(ToolCall {
532 id: call_id.clone(),
533 name: name.clone(),
534 arguments: serde_json::Value::Null,
535 thought_signature: None,
536 }));
537
538 self.tool_calls_by_item_id.insert(
539 id,
540 ToolCallState {
541 content_index,
542 call_id,
543 name,
544 arguments: arguments.clone(),
545 },
546 );
547
548 self.pending_events
549 .push_back(StreamEvent::ToolCallStart { content_index });
550
551 if !arguments.is_empty() {
552 self.pending_events.push_back(StreamEvent::ToolCallDelta {
553 content_index,
554 delta: arguments,
555 });
556 }
557 }
558 }
559 OpenAIResponsesChunk::FunctionCallArgumentsDelta { item_id, delta } => {
560 self.ensure_started();
561 if let Some(tc) = self.tool_calls_by_item_id.get_mut(&item_id) {
562 tc.arguments.push_str(&delta);
563 self.pending_events.push_back(StreamEvent::ToolCallDelta {
564 content_index: tc.content_index,
565 delta,
566 });
567 }
568 }
569 OpenAIResponsesChunk::OutputItemDone { item } => {
570 if let OpenAIResponsesOutputItemDone::FunctionCall {
571 id,
572 call_id,
573 name,
574 arguments,
575 } = item
576 {
577 self.ensure_started();
578 self.end_tool_call(&id, &call_id, &name, &arguments);
579 }
580 }
581 OpenAIResponsesChunk::ResponseCompleted { response }
582 | OpenAIResponsesChunk::ResponseDone { response }
583 | OpenAIResponsesChunk::ResponseIncomplete { response } => {
584 self.ensure_started();
585 self.partial.usage.input = response.usage.input_tokens;
586 self.partial.usage.output = response.usage.output_tokens;
587 self.partial.usage.total_tokens = response
588 .usage
589 .total_tokens
590 .unwrap_or(response.usage.input_tokens + response.usage.output_tokens);
591
592 self.finish(response.incomplete_reason());
593 }
594 OpenAIResponsesChunk::ResponseFailed { response } => {
595 self.ensure_started();
596 self.partial.stop_reason = StopReason::Error;
597 self.partial.error_message = Some(
598 response
599 .error
600 .and_then(|error| error.message)
601 .unwrap_or_else(|| "Codex response failed".to_string()),
602 );
603 self.pending_events.push_back(StreamEvent::Error {
604 reason: StopReason::Error,
605 error: std::mem::take(&mut self.partial),
606 });
607 self.finished = true;
608 }
609 OpenAIResponsesChunk::Error { message } => {
610 self.ensure_started();
611 self.partial.stop_reason = StopReason::Error;
612 self.partial.error_message = Some(message);
613 self.pending_events.push_back(StreamEvent::Error {
614 reason: StopReason::Error,
615 error: std::mem::take(&mut self.partial),
616 });
617 self.finished = true;
618 }
619 OpenAIResponsesChunk::Unknown => {}
620 }
621
622 Ok(())
623 }
624
625 fn partial_has_tool_call(&self) -> bool {
626 self.partial
627 .content
628 .iter()
629 .any(|b| matches!(b, ContentBlock::ToolCall(_)))
630 }
631
632 fn end_tool_call(&mut self, item_id: &str, call_id: &str, name: &str, arguments: &str) {
633 let mut tc = self
634 .tool_calls_by_item_id
635 .remove(item_id)
636 .unwrap_or_else(|| {
637 let content_index = self.partial.content.len();
639 self.partial.content.push(ContentBlock::ToolCall(ToolCall {
640 id: call_id.to_string(),
641 name: name.to_string(),
642 arguments: serde_json::Value::Null,
643 thought_signature: None,
644 }));
645 ToolCallState {
646 content_index,
647 call_id: call_id.to_string(),
648 name: name.to_string(),
649 arguments: String::new(),
650 }
651 });
652
653 if !arguments.is_empty() {
655 tc.arguments = arguments.to_string();
656 }
657
658 let parsed_args: serde_json::Value = serde_json::from_str(&tc.arguments).unwrap_or_else(|e| {
659 tracing::warn!(error = %e, raw = %tc.arguments, "Failed to parse tool arguments as JSON");
660 serde_json::Value::Null
661 });
662
663 self.partial.stop_reason = StopReason::ToolUse;
664 self.pending_events.push_back(StreamEvent::ToolCallEnd {
665 content_index: tc.content_index,
666 tool_call: ToolCall {
667 id: tc.call_id.clone(),
668 name: tc.name.clone(),
669 arguments: parsed_args.clone(),
670 thought_signature: None,
671 },
672 });
673
674 if let Some(ContentBlock::ToolCall(block)) = self.partial.content.get_mut(tc.content_index)
675 {
676 block.id = tc.call_id;
677 block.name = tc.name;
678 block.arguments = parsed_args;
679 }
680 }
681
682 fn finish(&mut self, incomplete_reason: Option<String>) {
683 if self.finished {
684 return;
685 }
686
687 for idx in self.text_blocks.values() {
689 if let Some(ContentBlock::Text(t)) = self.partial.content.get(*idx) {
690 self.pending_events.push_back(StreamEvent::TextEnd {
691 content_index: *idx,
692 content: t.text.clone(),
693 });
694 }
695 }
696
697 for idx in self.reasoning_blocks.values() {
699 if let Some(ContentBlock::Thinking(t)) = self.partial.content.get(*idx) {
700 self.pending_events.push_back(StreamEvent::ThinkingEnd {
701 content_index: *idx,
702 content: t.thinking.clone(),
703 });
704 }
705 }
706
707 let ids: Vec<String> = self.tool_calls_by_item_id.keys().cloned().collect();
709 for id in ids {
710 let (call_id, name, arguments) = match self.tool_calls_by_item_id.get(&id) {
712 Some(tc) => (tc.call_id.clone(), tc.name.clone(), tc.arguments.clone()),
713 None => continue,
714 };
715 self.end_tool_call(&id, &call_id, &name, &arguments);
716 }
717
718 if let Some(reason) = incomplete_reason {
720 let reason_lower = reason.to_ascii_lowercase();
721 if reason_lower.contains("max_output") || reason_lower.contains("length") {
722 self.partial.stop_reason = StopReason::Length;
723 } else if reason_lower.contains("tool") {
724 self.partial.stop_reason = StopReason::ToolUse;
725 } else if reason_lower.contains("content_filter") || reason_lower.contains("error") {
726 self.partial.stop_reason = StopReason::Error;
727 }
728 } else if self.partial_has_tool_call() {
729 self.partial.stop_reason = StopReason::ToolUse;
730 }
731
732 let reason = self.partial.stop_reason;
733 self.pending_events.push_back(StreamEvent::Done {
734 reason,
735 message: self.partial.clone(),
736 });
737 self.finished = true;
738 }
739}
740
741fn extract_chatgpt_account_id(token: &str) -> Option<String> {
742 let mut parts = token.split('.');
743 let _header = parts.next()?;
744 let payload = parts.next()?;
745 let _signature = parts.next()?;
746 if parts.next().is_some() {
747 return None;
748 }
749
750 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
751 .decode(payload)
752 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(payload))
753 .ok()?;
754 let payload_json: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
755 payload_json
756 .get("https://api.openai.com/auth")
757 .and_then(|auth| auth.get("chatgpt_account_id"))
758 .and_then(serde_json::Value::as_str)
759 .map(str::trim)
760 .filter(|value| !value.is_empty())
761 .map(ToString::to_string)
762}
763
/// JSON body for a streaming `/responses` call. Optional fields are
/// omitted from the wire when `None`.
#[derive(Debug, Serialize)]
pub struct OpenAIResponsesRequest {
    model: String,
    /// Conversation history converted to Responses input items.
    input: Vec<OpenAIResponsesInputItem>,
    /// System prompt, sent as top-level `instructions`.
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAIResponsesTool>>,
    /// Always true: this provider only streams.
    stream: bool,
    /// Always false: responses are not persisted server-side.
    store: bool,
    // The remaining fields are only populated in Codex mode.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<&'static str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<OpenAIResponsesTextConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    include: Option<Vec<&'static str>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning: Option<OpenAIResponsesReasoning>,
}
793
/// `text` request option (Codex mode); controls output verbosity.
#[derive(Debug, Serialize)]
struct OpenAIResponsesTextConfig {
    verbosity: &'static str,
}
798
/// `reasoning` request option (Codex mode): effort level and summary mode.
#[derive(Debug, Serialize)]
struct OpenAIResponsesReasoning {
    effort: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    summary: Option<&'static str>,
}
805
/// One entry of the request `input` array. Serialized untagged: the
/// `role` / `type` fields inside each variant discriminate on the wire.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum OpenAIResponsesInputItem {
    System {
        role: &'static str,
        content: String,
    },
    User {
        role: &'static str,
        content: Vec<OpenAIResponsesUserContentPart>,
    },
    Assistant {
        role: &'static str,
        content: Vec<OpenAIResponsesAssistantContentPart>,
    },
    /// A prior tool invocation by the assistant (`type: "function_call"`).
    FunctionCall {
        #[serde(rename = "type")]
        r#type: &'static str,
        call_id: String,
        name: String,
        arguments: String,
    },
    /// The result of a prior tool invocation (`type: "function_call_output"`).
    FunctionCallOutput {
        #[serde(rename = "type")]
        r#type: &'static str,
        call_id: String,
        output: String,
    },
}
835
/// User message content part: text or a (data-URL) image.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OpenAIResponsesUserContentPart {
    #[serde(rename = "input_text")]
    InputText { text: String },
    #[serde(rename = "input_image")]
    InputImage { image_url: String },
}
844
/// Assistant message content part; only plain text is round-tripped.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OpenAIResponsesAssistantContentPart {
    #[serde(rename = "output_text")]
    OutputText { text: String },
}
851
/// Function tool definition in Responses API shape (flattened, no nested
/// `function` object as in the Chat Completions API).
#[derive(Debug, Serialize)]
struct OpenAIResponsesTool {
    #[serde(rename = "type")]
    r#type: &'static str,
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// JSON Schema for the tool's arguments.
    parameters: serde_json::Value,
}
861
862fn convert_tool_to_openai_responses(tool: &ToolDef) -> OpenAIResponsesTool {
863 OpenAIResponsesTool {
864 r#type: "function",
865 name: tool.name.clone(),
866 description: if tool.description.trim().is_empty() {
867 None
868 } else {
869 Some(tool.description.clone())
870 },
871 parameters: tool.parameters.clone(),
872 }
873}
874
875fn build_openai_responses_input(context: &Context<'_>) -> Vec<OpenAIResponsesInputItem> {
876 let mut input = Vec::with_capacity(context.messages.len());
877
878 for message in context.messages.iter() {
883 match message {
884 Message::User(user) => input.push(convert_user_message_to_responses(&user.content)),
885 Message::Custom(custom) => input.push(OpenAIResponsesInputItem::User {
886 role: "user",
887 content: vec![OpenAIResponsesUserContentPart::InputText {
888 text: custom.content.clone(),
889 }],
890 }),
891 Message::Assistant(assistant) => {
892 let mut pending_text = String::new();
894
895 for block in &assistant.content {
896 match block {
897 ContentBlock::Text(t) => pending_text.push_str(&t.text),
898 ContentBlock::ToolCall(tc) => {
899 if !pending_text.is_empty() {
900 input.push(OpenAIResponsesInputItem::Assistant {
901 role: "assistant",
902 content: vec![
903 OpenAIResponsesAssistantContentPart::OutputText {
904 text: std::mem::take(&mut pending_text),
905 },
906 ],
907 });
908 }
909 input.push(OpenAIResponsesInputItem::FunctionCall {
910 r#type: "function_call",
911 call_id: tc.id.clone(),
912 name: tc.name.clone(),
913 arguments: tc.arguments.to_string(),
914 });
915 }
916 _ => {}
917 }
918 }
919
920 if !pending_text.is_empty() {
921 input.push(OpenAIResponsesInputItem::Assistant {
922 role: "assistant",
923 content: vec![OpenAIResponsesAssistantContentPart::OutputText {
924 text: pending_text,
925 }],
926 });
927 }
928 }
929 Message::ToolResult(result) => {
930 let mut out = String::new();
931 for (i, block) in result.content.iter().enumerate() {
932 if i > 0 {
933 out.push('\n');
934 }
935 if let ContentBlock::Text(t) = block {
936 out.push_str(&t.text);
937 }
938 }
939 input.push(OpenAIResponsesInputItem::FunctionCallOutput {
940 r#type: "function_call_output",
941 call_id: result.tool_call_id.clone(),
942 output: out,
943 });
944 }
945 }
946 }
947
948 input
949}
950
951fn convert_user_message_to_responses(content: &UserContent) -> OpenAIResponsesInputItem {
952 match content {
953 UserContent::Text(text) => OpenAIResponsesInputItem::User {
954 role: "user",
955 content: vec![OpenAIResponsesUserContentPart::InputText { text: text.clone() }],
956 },
957 UserContent::Blocks(blocks) => {
958 let mut parts = Vec::new();
959 for block in blocks {
960 match block {
961 ContentBlock::Text(t) => {
962 parts.push(OpenAIResponsesUserContentPart::InputText {
963 text: t.text.clone(),
964 });
965 }
966 ContentBlock::Image(img) => {
967 let url = format!("data:{};base64,{}", img.mime_type, img.data);
968 parts.push(OpenAIResponsesUserContentPart::InputImage { image_url: url });
969 }
970 _ => {}
971 }
972 }
973 if parts.is_empty() {
974 parts.push(OpenAIResponsesUserContentPart::InputText {
975 text: String::new(),
976 });
977 }
978 OpenAIResponsesInputItem::User {
979 role: "user",
980 content: parts,
981 }
982 }
983 }
984}
985
/// Deserialized SSE event payload, discriminated by the `type` field.
/// Unrecognized event types fall through to `Unknown` and are ignored.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum OpenAIResponsesChunk {
    /// Incremental assistant text.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta {
        item_id: String,
        content_index: u32,
        delta: String,
    },
    /// A new output item (used here to detect function calls starting).
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded { item: OpenAIResponsesOutputItem },
    /// An output item completed (finalizes function-call arguments).
    #[serde(rename = "response.output_item.done")]
    OutputItemDone { item: OpenAIResponsesOutputItemDone },
    /// Incremental function-call argument text.
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgumentsDelta { item_id: String, delta: String },
    /// Incremental reasoning-summary text.
    #[serde(rename = "response.reasoning_summary_text.delta")]
    ReasoningSummaryTextDelta {
        item_id: String,
        summary_index: u32,
        delta: String,
    },
    // The three terminal success-ish events share one payload shape.
    #[serde(rename = "response.completed")]
    ResponseCompleted {
        response: OpenAIResponsesDonePayload,
    },
    #[serde(rename = "response.done")]
    ResponseDone {
        response: OpenAIResponsesDonePayload,
    },
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete {
        response: OpenAIResponsesDonePayload,
    },
    /// Terminal failure with an optional server error message.
    #[serde(rename = "response.failed")]
    ResponseFailed {
        response: OpenAIResponsesFailedPayload,
    },
    /// Top-level error event.
    #[serde(rename = "error")]
    Error { message: String },
    #[serde(other)]
    Unknown,
}
1032
/// Item payload of `response.output_item.added`; only function calls are
/// acted on, every other item type collapses to `Unknown`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum OpenAIResponsesOutputItem {
    #[serde(rename = "function_call")]
    FunctionCall {
        /// SSE item id (keys the in-flight tool-call state).
        id: String,
        /// Call id echoed back in tool results.
        call_id: String,
        name: String,
        /// Initial argument text; may be empty or absent.
        #[serde(default)]
        arguments: String,
    },
    #[serde(other)]
    Unknown,
}
1047
/// Item payload of `response.output_item.done`; mirrors
/// `OpenAIResponsesOutputItem` but carries the final argument text.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum OpenAIResponsesOutputItemDone {
    #[serde(rename = "function_call")]
    FunctionCall {
        id: String,
        call_id: String,
        name: String,
        /// Final, complete argument JSON text; may be empty, in which case
        /// the delta-accumulated text is used instead.
        #[serde(default)]
        arguments: String,
    },
    #[serde(other)]
    Unknown,
}
1062
/// `response` payload of the terminal completed/done/incomplete events.
#[derive(Debug, Deserialize)]
struct OpenAIResponsesDonePayload {
    /// Present when the response was cut short (e.g. token limit).
    #[serde(default)]
    incomplete_details: Option<OpenAIResponsesIncompleteDetails>,
    usage: OpenAIResponsesUsage,
}
1069
/// `response` payload of the terminal `response.failed` event.
#[derive(Debug, Deserialize)]
struct OpenAIResponsesFailedPayload {
    #[serde(default)]
    error: Option<OpenAIResponsesFailedError>,
}
1075
/// Error detail inside a failed response; the message is optional.
#[derive(Debug, Deserialize)]
struct OpenAIResponsesFailedError {
    #[serde(default)]
    message: Option<String>,
}
1081
/// Free-form reason string explaining why a response is incomplete.
#[derive(Debug, Deserialize)]
struct OpenAIResponsesIncompleteDetails {
    reason: String,
}
1086
/// Token usage reported in the terminal event.
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct OpenAIResponsesUsage {
    input_tokens: u64,
    output_tokens: u64,
    /// May be absent; callers fall back to input + output.
    #[serde(default)]
    total_tokens: Option<u64>,
}
1095
1096impl OpenAIResponsesDonePayload {
1097 fn incomplete_reason(&self) -> Option<String> {
1098 self.incomplete_details.as_ref().map(|d| d.reason.clone())
1099 }
1100}
1101
1102#[cfg(test)]
1107mod tests {
1108 use super::*;
1109 use asupersync::runtime::RuntimeBuilder;
1110 use futures::stream;
1111 use serde_json::{Value, json};
1112 use std::collections::HashMap;
1113 use std::io::{Read, Write};
1114 use std::net::TcpListener;
1115 use std::sync::mpsc;
1116 use std::time::Duration;
1117
1118 #[test]
1119 fn test_provider_info() {
1120 let provider = OpenAIResponsesProvider::new("gpt-4o");
1121 assert_eq!(provider.name(), "openai");
1122 assert_eq!(provider.api(), "openai-responses");
1123 }
1124
1125 #[test]
1126 fn test_build_request_includes_system_tools_and_defaults() {
1127 let provider = OpenAIResponsesProvider::new("gpt-4o");
1128 let context = Context::owned(
1129 Some("System guidance".to_string()),
1130 vec![Message::User(crate::model::UserMessage {
1131 content: UserContent::Text("Ping".to_string()),
1132 timestamp: 0,
1133 })],
1134 vec![
1135 ToolDef {
1136 name: "search".to_string(),
1137 description: "Search docs".to_string(),
1138 parameters: json!({
1139 "type": "object",
1140 "properties": { "q": { "type": "string" } },
1141 "required": ["q"]
1142 }),
1143 },
1144 ToolDef {
1145 name: "blank_desc".to_string(),
1146 description: " ".to_string(),
1147 parameters: json!({ "type": "object" }),
1148 },
1149 ],
1150 );
1151 let options = StreamOptions {
1152 temperature: Some(0.3),
1153 ..Default::default()
1154 };
1155
1156 let request = provider.build_request(&context, &options);
1157 let value = serde_json::to_value(&request).expect("serialize request");
1158 assert_eq!(value["model"], "gpt-4o");
1159 let temperature = value["temperature"]
1160 .as_f64()
1161 .expect("temperature should serialize as number");
1162 assert!((temperature - 0.3).abs() < 1e-6);
1163 assert_eq!(value["max_output_tokens"], DEFAULT_MAX_OUTPUT_TOKENS);
1164 assert_eq!(value["stream"], true);
1165 assert_eq!(value["instructions"], "System guidance");
1166 assert_eq!(value["input"][0]["role"], "user");
1167 assert_eq!(value["input"][0]["content"][0]["type"], "input_text");
1168 assert_eq!(value["input"][0]["content"][0]["text"], "Ping");
1169 assert_eq!(value["tools"][0]["type"], "function");
1170 assert_eq!(value["tools"][0]["name"], "search");
1171 assert_eq!(value["tools"][0]["description"], "Search docs");
1172 assert_eq!(
1173 value["tools"][0]["parameters"],
1174 json!({
1175 "type": "object",
1176 "properties": { "q": { "type": "string" } },
1177 "required": ["q"]
1178 })
1179 );
1180 assert!(value["tools"][1].get("description").is_none());
1181 }
1182
1183 #[test]
1184 fn test_stream_parses_text_and_tool_call() {
1185 let events = vec![
1186 json!({
1187 "type": "response.output_text.delta",
1188 "item_id": "msg_1",
1189 "content_index": 0,
1190 "delta": "Hello"
1191 }),
1192 json!({
1193 "type": "response.output_item.added",
1194 "output_index": 1,
1195 "item": {
1196 "type": "function_call",
1197 "id": "fc_1",
1198 "call_id": "call_1",
1199 "name": "echo",
1200 "arguments": ""
1201 }
1202 }),
1203 json!({
1204 "type": "response.function_call_arguments.delta",
1205 "item_id": "fc_1",
1206 "output_index": 1,
1207 "delta": "{\"text\":\"hi\"}"
1208 }),
1209 json!({
1210 "type": "response.output_item.done",
1211 "output_index": 1,
1212 "item": {
1213 "type": "function_call",
1214 "id": "fc_1",
1215 "call_id": "call_1",
1216 "name": "echo",
1217 "arguments": "{\"text\":\"hi\"}",
1218 "status": "completed"
1219 }
1220 }),
1221 json!({
1222 "type": "response.completed",
1223 "response": {
1224 "incomplete_details": null,
1225 "usage": {
1226 "input_tokens": 1,
1227 "output_tokens": 2,
1228 "total_tokens": 3
1229 }
1230 }
1231 }),
1232 ];
1233
1234 let out = collect_events(&events);
1235 assert!(matches!(out.first(), Some(StreamEvent::Start { .. })));
1236 assert!(
1237 out.iter()
1238 .any(|e| matches!(e, StreamEvent::TextDelta { delta, .. } if delta == "Hello"))
1239 );
1240 assert!(out.iter().any(
1241 |e| matches!(e, StreamEvent::ToolCallEnd { tool_call, .. } if tool_call.name == "echo")
1242 ));
1243 assert!(out.iter().any(|e| matches!(
1244 e,
1245 StreamEvent::Done {
1246 reason: StopReason::ToolUse,
1247 ..
1248 }
1249 )));
1250 }
1251
1252 #[test]
1253 fn test_stream_accumulates_function_call_arguments_deltas() {
1254 let events = vec![
1255 json!({
1256 "type": "response.output_item.added",
1257 "item": {
1258 "type": "function_call",
1259 "id": "fc_2",
1260 "call_id": "call_2",
1261 "name": "search",
1262 "arguments": "{\"q\":\"ru"
1263 }
1264 }),
1265 json!({
1266 "type": "response.function_call_arguments.delta",
1267 "item_id": "fc_2",
1268 "delta": "st\"}"
1269 }),
1270 json!({
1271 "type": "response.output_item.done",
1272 "item": {
1273 "type": "function_call",
1274 "id": "fc_2",
1275 "call_id": "call_2",
1276 "name": "search",
1277 "arguments": ""
1278 }
1279 }),
1280 json!({
1281 "type": "response.completed",
1282 "response": {
1283 "incomplete_details": null,
1284 "usage": {
1285 "input_tokens": 1,
1286 "output_tokens": 1,
1287 "total_tokens": 2
1288 }
1289 }
1290 }),
1291 ];
1292
1293 let out = collect_events(&events);
1294 let tool_end = out
1295 .iter()
1296 .find_map(|event| match event {
1297 StreamEvent::ToolCallEnd { tool_call, .. } => Some(tool_call),
1298 _ => None,
1299 })
1300 .expect("tool call end");
1301 assert_eq!(tool_end.id, "call_2");
1302 assert_eq!(tool_end.name, "search");
1303 assert_eq!(tool_end.arguments, json!({ "q": "rust" }));
1304 }
1305
1306 #[test]
1307 fn test_stream_sets_bearer_auth_header() {
1308 let captured = run_stream_and_capture_headers().expect("captured request");
1309 assert_eq!(
1310 captured.headers.get("authorization").map(String::as_str),
1311 Some("Bearer test-openai-key")
1312 );
1313 assert_eq!(
1314 captured.headers.get("accept").map(String::as_str),
1315 Some("text/event-stream")
1316 );
1317
1318 let body: Value = serde_json::from_str(&captured.body).expect("request body json");
1319 assert_eq!(body["stream"], true);
1320 assert_eq!(body["input"][0]["role"], "user");
1321 assert_eq!(body["input"][0]["content"][0]["type"], "input_text");
1322 }
1323
1324 fn build_test_jwt(account_id: &str) -> String {
1325 let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
1326 .encode(br#"{"alg":"none","typ":"JWT"}"#);
1327 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(
1328 serde_json::to_vec(&json!({
1329 "https://api.openai.com/auth": {
1330 "chatgpt_account_id": account_id
1331 }
1332 }))
1333 .expect("payload json"),
1334 );
1335 format!("{header}.{payload}.sig")
1336 }
1337
1338 #[test]
1339 fn test_bearer_token_parser_accepts_case_insensitive_scheme() {
1340 let token = super::bearer_token_from_authorization_header("bEaReR abc.def.ghi");
1341 assert_eq!(token.as_deref(), Some("abc.def.ghi"));
1342 assert!(super::bearer_token_from_authorization_header("Basic abc").is_none());
1343 assert!(super::bearer_token_from_authorization_header("Bearer").is_none());
1344 }
1345
1346 #[test]
1347 fn test_codex_mode_adds_required_headers_with_authorization_override() {
1348 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1349 let provider = OpenAIResponsesProvider::new("gpt-4o")
1350 .with_provider_name("openai-codex")
1351 .with_api_name("openai-codex-responses")
1352 .with_codex_mode(true)
1353 .with_base_url(base_url);
1354 let context = Context::owned(
1355 None,
1356 vec![Message::User(crate::model::UserMessage {
1357 content: UserContent::Text("ping".to_string()),
1358 timestamp: 0,
1359 })],
1360 Vec::new(),
1361 );
1362 let token = build_test_jwt("acct_test_123");
1363 let mut headers = HashMap::new();
1364 headers.insert("Authorization".to_string(), format!("Bearer {token}"));
1365 let options = StreamOptions {
1366 headers,
1367 session_id: Some("session-abc".to_string()),
1368 ..Default::default()
1369 };
1370
1371 let runtime = RuntimeBuilder::current_thread()
1372 .build()
1373 .expect("runtime build");
1374 runtime.block_on(async {
1375 let mut stream = provider.stream(&context, &options).await.expect("stream");
1376 while let Some(event) = stream.next().await {
1377 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1378 break;
1379 }
1380 }
1381 });
1382
1383 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1384 let expected_auth = format!("Bearer {token}");
1385 assert_eq!(
1386 captured.headers.get("authorization").map(String::as_str),
1387 Some(expected_auth.as_str())
1388 );
1389 assert_eq!(
1390 captured
1391 .headers
1392 .get("chatgpt-account-id")
1393 .map(String::as_str),
1394 Some("acct_test_123")
1395 );
1396 assert_eq!(
1397 captured.headers.get("openai-beta").map(String::as_str),
1398 Some("responses=experimental")
1399 );
1400 assert_eq!(
1401 captured.headers.get("session_id").map(String::as_str),
1402 Some("session-abc")
1403 );
1404 }
1405
1406 fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1407 let runtime = RuntimeBuilder::current_thread()
1408 .build()
1409 .expect("runtime build");
1410 runtime.block_on(async move {
1411 let byte_stream = stream::iter(events.iter().map(|event| {
1412 let data = serde_json::to_string(event).expect("serialize event");
1413 Ok(format!("data: {data}\n\n").into_bytes())
1414 }));
1415
1416 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1417 let mut state = StreamState::new(
1418 event_source,
1419 "gpt-test".to_string(),
1420 "openai-responses".to_string(),
1421 "openai".to_string(),
1422 );
1423
1424 let mut out = Vec::new();
1425 while let Some(item) = state.event_source.next().await {
1426 let msg = item.expect("SSE event");
1427 state.process_event(&msg.data).expect("process_event");
1428 out.extend(state.pending_events.drain(..));
1429 if state.finished {
1430 break;
1431 }
1432 }
1433
1434 out
1435 })
1436 }
1437
    /// Request captured by the in-process test server (see `spawn_test_server`).
    #[derive(Debug)]
    struct CapturedRequest {
        // Header names lowercased, values trimmed (see `parse_headers`).
        headers: HashMap<String, String>,
        // Raw request body, lossily decoded to UTF-8.
        body: String,
    }
1443
1444 fn run_stream_and_capture_headers() -> Option<CapturedRequest> {
1445 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1446 let provider = OpenAIResponsesProvider::new("gpt-4o").with_base_url(base_url);
1447 let context = Context::owned(
1448 None,
1449 vec![Message::User(crate::model::UserMessage {
1450 content: UserContent::Text("ping".to_string()),
1451 timestamp: 0,
1452 })],
1453 Vec::new(),
1454 );
1455 let options = StreamOptions {
1456 api_key: Some("test-openai-key".to_string()),
1457 ..Default::default()
1458 };
1459
1460 let runtime = RuntimeBuilder::current_thread()
1461 .build()
1462 .expect("runtime build");
1463 runtime.block_on(async {
1464 let mut stream = provider.stream(&context, &options).await.expect("stream");
1465 while let Some(event) = stream.next().await {
1466 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1467 break;
1468 }
1469 }
1470 });
1471
1472 rx.recv_timeout(Duration::from_secs(2)).ok()
1473 }
1474
1475 fn success_sse_body() -> String {
1476 [
1477 r#"data: {"type":"response.output_text.delta","item_id":"msg_1","content_index":0,"delta":"ok"}"#,
1478 "",
1479 r#"data: {"type":"response.completed","response":{"incomplete_details":null,"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}"#,
1480 "",
1481 ]
1482 .join("\n")
1483 }
1484
    /// Spawns a one-shot HTTP server on an ephemeral local port.
    ///
    /// The server thread accepts a single connection, captures the request's
    /// headers and body (delivered through the returned channel), then replies
    /// with `status_code`, `content_type`, and `body`. Returns the base URL to
    /// hit plus the receiver for the captured request.
    fn spawn_test_server(
        status_code: u16,
        content_type: &str,
        body: &str,
    ) -> (String, mpsc::Receiver<CapturedRequest>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let (tx, rx) = mpsc::channel();
        // Owned copies that can move into the server thread.
        let body = body.to_string();
        let content_type = content_type.to_string();

        std::thread::spawn(move || {
            let (mut socket, _) = listener.accept().expect("accept");
            socket
                .set_read_timeout(Some(Duration::from_secs(2)))
                .expect("set read timeout");

            // Read until the end of the request headers (CRLF CRLF), EOF, or
            // the read timeout fires.
            let mut bytes = Vec::new();
            let mut chunk = [0_u8; 4096];
            loop {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => {
                        bytes.extend_from_slice(&chunk[..n]);
                        if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
                            break;
                        }
                    }
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("read request failed: {err}"),
                }
            }

            let header_end = bytes
                .windows(4)
                .position(|window| window == b"\r\n\r\n")
                .expect("request header boundary");
            let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
            let headers = parse_headers(&header_text);
            // Anything already read past the header boundary is body bytes.
            let mut request_body = bytes[header_end + 4..].to_vec();

            // Keep reading until Content-Length bytes of body have arrived
            // (or the peer closes / the read times out).
            let content_length = headers
                .get("content-length")
                .and_then(|value| value.parse::<usize>().ok())
                .unwrap_or(0);
            while request_body.len() < content_length {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => request_body.extend_from_slice(&chunk[..n]),
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("read request body failed: {err}"),
                }
            }

            let captured = CapturedRequest {
                headers,
                body: String::from_utf8_lossy(&request_body).to_string(),
            };
            tx.send(captured).expect("send captured request");

            let reason = match status_code {
                401 => "Unauthorized",
                500 => "Internal Server Error",
                _ => "OK",
            };
            let response = format!(
                "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                body.len()
            );
            socket
                .write_all(response.as_bytes())
                .expect("write response");
            socket.flush().expect("flush response");
        });

        (format!("http://{addr}/responses"), rx)
    }
1572
1573 fn parse_headers(header_text: &str) -> HashMap<String, String> {
1574 let mut headers = HashMap::new();
1575 for line in header_text.lines().skip(1) {
1576 if let Some((name, value)) = line.split_once(':') {
1577 headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1578 }
1579 }
1580 headers
1581 }
1582
    /// Top-level shape of a provider stream fixture file.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        cases: Vec<ProviderCase>,
    }
1591
    /// One named fixture case: raw SSE event payloads plus the expected
    /// summarized stream events.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        name: String,
        events: Vec<Value>,
        expected: Vec<EventSummary>,
    }
1598
    /// Flattened, comparison-friendly view of a `StreamEvent` used by the
    /// fixture files; fields are `None` when an event kind does not carry them.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        kind: String,
        #[serde(default)]
        content_index: Option<usize>,
        #[serde(default)]
        delta: Option<String>,
        #[serde(default)]
        content: Option<String>,
        #[serde(default)]
        reason: Option<String>,
    }
1611
1612 fn load_fixture(file_name: &str) -> ProviderFixture {
1613 let path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1614 .join("tests/fixtures/provider_responses")
1615 .join(file_name);
1616 let data = std::fs::read_to_string(&path).expect("read fixture file");
1617 serde_json::from_str(&data).expect("parse fixture JSON")
1618 }
1619
1620 fn summarize_event(event: &StreamEvent) -> EventSummary {
1621 match event {
1622 StreamEvent::Start { .. } => EventSummary {
1623 kind: "start".to_string(),
1624 content_index: None,
1625 delta: None,
1626 content: None,
1627 reason: None,
1628 },
1629 StreamEvent::TextStart { content_index, .. } => EventSummary {
1630 kind: "text_start".to_string(),
1631 content_index: Some(*content_index),
1632 delta: None,
1633 content: None,
1634 reason: None,
1635 },
1636 StreamEvent::TextDelta {
1637 content_index,
1638 delta,
1639 ..
1640 } => EventSummary {
1641 kind: "text_delta".to_string(),
1642 content_index: Some(*content_index),
1643 delta: Some(delta.clone()),
1644 content: None,
1645 reason: None,
1646 },
1647 StreamEvent::TextEnd {
1648 content_index,
1649 content,
1650 ..
1651 } => EventSummary {
1652 kind: "text_end".to_string(),
1653 content_index: Some(*content_index),
1654 delta: None,
1655 content: Some(content.clone()),
1656 reason: None,
1657 },
1658 StreamEvent::Done { reason, .. } => EventSummary {
1659 kind: "done".to_string(),
1660 content_index: None,
1661 delta: None,
1662 content: None,
1663 reason: Some(reason_to_string(*reason)),
1664 },
1665 StreamEvent::Error { reason, .. } => EventSummary {
1666 kind: "error".to_string(),
1667 content_index: None,
1668 delta: None,
1669 content: None,
1670 reason: Some(reason_to_string(*reason)),
1671 },
1672 _ => EventSummary {
1673 kind: "other".to_string(),
1674 content_index: None,
1675 delta: None,
1676 content: None,
1677 reason: None,
1678 },
1679 }
1680 }
1681
1682 fn reason_to_string(reason: StopReason) -> String {
1683 match reason {
1684 StopReason::Stop => "stop".to_string(),
1685 StopReason::ToolUse => "tool_use".to_string(),
1686 StopReason::Length => "length".to_string(),
1687 StopReason::Error => "error".to_string(),
1688 StopReason::Aborted => "aborted".to_string(),
1689 }
1690 }
1691
1692 #[test]
1693 fn test_stream_fixtures() {
1694 let fixture = load_fixture("openai_responses_stream.json");
1695 for case in fixture.cases {
1696 let events = collect_events(&case.events);
1697 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1698 assert_eq!(summaries, case.expected, "case: {}", case.name);
1699 }
1700 }
1701}
1702
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Helpers that expose the SSE event processor to fuzz targets.
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Wraps a `StreamState` backed by a permanently-empty byte stream so fuzz
    /// targets can feed raw SSE `data:` payloads straight into the parser.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Builds a processor whose underlying byte stream never yields.
        pub fn new() -> Self {
            let source: FuzzStream = Box::pin(stream::empty());
            Self(StreamState::new(
                crate::sse::SseStream::new(source),
                "gpt-responses-fuzz".into(),
                "openai-responses".into(),
                "openai".into(),
            ))
        }

        /// Processes one raw SSE payload and returns the events it produced.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            let drained: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            Ok(drained)
        }
    }
}