1use async_trait::async_trait;
11use bytes::Bytes;
12use futures::{Stream, StreamExt};
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::Value as JsonValue;
16use std::pin::Pin;
17
18use crate::{
19 error::ProviderError, Api, AssistantMessage, ContentBlock, Context, Model, Provider,
20 ProviderEvent, StopReason, StreamOptions, Usage,
21};
22
23use super::shared_client;
24
/// Provider that streams completions from a `/responses` HTTP endpoint.
///
/// Clones copy the `&'static Client` reference, so they all use the client
/// returned by `shared_client()`.
#[derive(Clone)]
pub struct OpenAiResponsesProvider {
    /// HTTP client obtained from `shared_client()`.
    client: &'static Client,
    /// Provider-level API key; `StreamOptions::api_key` is consulted first in `stream`.
    api_key: Option<String>,
    /// When set, used instead of `Model::base_url` to build the request URL.
    base_url: Option<String>,
}
32
33impl OpenAiResponsesProvider {
34 pub fn new() -> Self {
38 Self {
39 client: shared_client(),
40 api_key: None,
41 base_url: None,
42 }
43 }
44
45 pub fn with_api_key(api_key: impl Into<String>) -> Self {
47 Self {
48 client: shared_client(),
49 api_key: Some(api_key.into()),
50 base_url: None,
51 }
52 }
53
54 pub fn with_base_url_and_key(base_url: &str, api_key: Option<String>) -> Self {
58 Self {
59 client: shared_client(),
60 api_key,
61 base_url: Some(base_url.to_string()),
62 }
63 }
64}
65
/// `Default` is equivalent to [`OpenAiResponsesProvider::new`]: no API key and
/// no base-URL override.
impl Default for OpenAiResponsesProvider {
    fn default() -> Self {
        Self::new()
    }
}
71
72#[async_trait]
73impl Provider for OpenAiResponsesProvider {
74 async fn stream(
75 &self,
76 model: &Model,
77 context: &Context,
78 options: Option<StreamOptions>,
79 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
80 let options = options.unwrap_or_default();
81
82 let effective_base_url = self.base_url.as_deref().unwrap_or(&model.base_url);
84 let url = format!("{}/responses", effective_base_url);
85
86 let api_key = options
88 .api_key
89 .as_ref()
90 .or(self.api_key.as_ref())
91 .ok_or_else(|| ProviderError::MissingApiKey)?;
92
93 let input = build_input(context)?;
95
96 let mut body = serde_json::json!({
98 "model": model.id,
99 "input": input,
100 "stream": true,
101 });
102
103 if let Some(temp) = options.temperature {
105 body["temperature"] = serde_json::json!(temp);
106 }
107
108 if let Some(max) = options.max_tokens {
109 body["max_tokens"] = serde_json::json!(max);
110 }
111
112 if !context.tools.is_empty() {
114 body["tools"] = build_tools(&context.tools);
115 }
116
117 if let Some(ref thinking_level) = options.thinking_level {
119 if thinking_level != &crate::ThinkingLevel::Off {
120 if let Some(effort) = thinking_level.as_str() {
121 body["reasoning"] = serde_json::json!({
122 "effort": effort,
123 });
124 }
125 }
126 }
127
128 let mut headers = reqwest::header::HeaderMap::new();
130 headers.insert(
131 reqwest::header::AUTHORIZATION,
132 format!("Bearer {}", api_key)
133 .parse()
134 .expect("valid bearer header"),
135 );
136 headers.insert(
137 reqwest::header::CONTENT_TYPE,
138 "application/json".parse().expect("valid header value"),
139 );
140
141 for (k, v) in &options.headers {
143 if let (Ok(name), Ok(value)) = (
144 k.parse::<reqwest::header::HeaderName>(),
145 v.parse::<reqwest::header::HeaderValue>(),
146 ) {
147 headers.insert(name, value);
148 }
149 }
150
151 let response = self
153 .client
154 .post(&url)
155 .headers(headers)
156 .json(&body)
157 .send()
158 .await
159 .map_err(ProviderError::RequestFailed)?;
160
161 if !response.status().is_success() {
162 let status = response.status();
163 let body: String = response.text().await.unwrap_or_default();
164 return Err(ProviderError::HttpError(status.as_u16(), body));
165 }
166
167 let provider_name = model.provider.clone();
169 let model_id = model.id.clone();
170
171 let stream = response.bytes_stream().flat_map(
172 move |chunk: Result<Bytes, reqwest::Error>| match chunk {
173 Ok(bytes) => {
174 let text = String::from_utf8_lossy(&bytes).to_string();
175 futures::stream::iter(parse_sse_events(&text, &provider_name, &model_id))
176 }
177 Err(e) => futures::stream::iter(vec![ProviderEvent::Error {
178 reason: StopReason::Error,
179 error: create_error_message(&e.to_string(), &provider_name, &model_id),
180 }]),
181 },
182 );
183
184 Ok(Box::pin(stream))
185 }
186
187 fn name(&self) -> &str {
188 "openai-responses"
189 }
190}
191
192fn build_input(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
197 let mut input = Vec::new();
198
199 if let Some(ref prompt) = context.system_prompt {
201 input.push(serde_json::json!({
202 "role": "developer",
203 "content": prompt,
204 }));
205 }
206
207 for msg in &context.messages {
209 match msg {
210 crate::Message::User(u) => {
211 let content = match &u.content {
212 crate::MessageContent::Text(s) => serde_json::json!(s.clone()),
213 crate::MessageContent::Blocks(blocks) => blocks_to_json(blocks)?,
214 };
215 input.push(serde_json::json!({
216 "role": "user",
217 "content": content,
218 }));
219 }
220 crate::Message::Assistant(a) => {
221 let content = blocks_to_json(&a.content)?;
222 input.push(serde_json::json!({
223 "role": "assistant",
224 "content": content,
225 }));
226 }
227 crate::Message::ToolResult(t) => {
228 let content = blocks_to_json(&t.content)?;
229 input.push(serde_json::json!({
230 "role": "user",
231 "content": content,
232 }));
233 }
234 }
235 }
236
237 Ok(input)
238}
239
240fn blocks_to_json(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
242 if blocks.len() == 1 {
243 if let Some(text) = blocks[0].as_text() {
244 return Ok(JsonValue::String(text.to_string()));
245 }
246 }
247
248 let items: Result<Vec<_>, _> = blocks
249 .iter()
250 .map(|block| match block {
251 ContentBlock::Text(t) => Ok(serde_json::json!({
252 "type": "output_text",
253 "text": t.text,
254 })),
255 ContentBlock::ToolCall(tc) => Ok(serde_json::json!({
256 "type": "function_call",
257 "id": tc.id,
258 "name": tc.name,
259 "arguments": tc.arguments.to_string(),
260 })),
261 ContentBlock::Thinking(th) => Ok(serde_json::json!({
262 "type": "reasoning",
263 "summary": [
264 {
265 "type": "summary_text",
266 "text": th.thinking,
267 }
268 ]
269 })),
270 ContentBlock::Image(img) => Ok(serde_json::json!({
271 "type": "input_image",
272 "data": format!("data:{};base64,{}", img.mime_type, img.data),
273 "mime_type": img.mime_type,
274 })),
275 ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
276 "Unknown content block type".into(),
277 )),
278 })
279 .collect();
280
281 Ok(serde_json::json!(items?))
282}
283
284fn build_tools(tools: &[crate::Tool]) -> JsonValue {
286 let items: Vec<_> = tools
287 .iter()
288 .map(|tool| {
289 serde_json::json!({
290 "type": "function",
291 "name": tool.name,
292 "description": tool.description,
293 "parameters": tool.parameters,
294 })
295 })
296 .collect();
297
298 serde_json::json!(items)
299}
300
301fn parse_sse_events(text: &str, provider: &str, model_id: &str) -> Vec<ProviderEvent> {
311 let mut events = Vec::new();
312 let mut partial_message = AssistantMessage::new(Api::OpenAiResponses, provider, model_id);
313 let mut current_text_index: Option<usize> = None;
314 let mut current_tool_call_index: Option<usize> = None;
315 let mut accumulated_usage = Usage::default();
316
317 let estimated_events = text
319 .split('\n')
320 .filter(|l| l.starts_with("event: ") || l.starts_with("data: "))
321 .count();
322 events.reserve(estimated_events);
323
324 for line in text.split('\n') {
325 let line = line.trim_end_matches('\r');
326 if line.is_empty() {
327 continue;
328 }
329
330 if line.starts_with("event: ") {
332 let event_name = line.strip_prefix("event: ").unwrap_or(line).trim();
333 match event_name {
335 "response.created"
336 | "response.output_item.added"
337 | "response.content_part.added"
338 | "response.output_text.delta"
339 | "response.function_call_arguments.delta"
340 | "response.completed"
341 | "response.output_text.done"
342 | "response.reasoning.done" => {
343 }
345 _ => {}
346 }
347 continue;
348 }
349
350 if !line.starts_with("data: ") {
352 continue;
353 }
354
355 let data = line[6..].trim();
356 if data.is_empty() || data == "[DONE]" {
357 continue;
358 }
359
360 if let Ok(event) = serde_json::from_str::<ResponsesEvent>(data) {
362 match event {
363 ResponsesEvent::ResponseCreatedData { response } => {
364 if let Some(id) = response.id {
365 partial_message.response_id = Some(id);
366 }
367 events.push(ProviderEvent::Start {
368 partial: partial_message.clone(),
369 });
370 }
371 ResponsesEvent::OutputItemAdded { output_item } => {
372 match output_item.r#type.as_str() {
373 "message" => {
374 events.push(ProviderEvent::ToolCallStart {
375 content_index: output_item.index,
376 tool_call_id: output_item.id.clone(),
377 tool_name: None,
378 partial: partial_message.clone(),
379 });
380 current_tool_call_index = Some(output_item.index);
381 }
382 "function_call" => {
383 events.push(ProviderEvent::ToolCallStart {
384 content_index: output_item.index,
385 tool_call_id: output_item.id.clone(),
386 tool_name: None,
387 partial: partial_message.clone(),
388 });
389 current_tool_call_index = Some(output_item.index);
390 }
391 "reasoning" => {
392 events.push(ProviderEvent::ThinkingStart {
393 content_index: output_item.index,
394 partial: partial_message.clone(),
395 });
396 }
397 _ => {}
398 }
399 }
400 ResponsesEvent::ContentPartAdded { content_part } => {
401 match content_part.r#type.as_str() {
402 "output_text" => {
403 events.push(ProviderEvent::TextStart {
404 content_index: content_part.index,
405 partial: partial_message.clone(),
406 });
407 current_text_index = Some(content_part.index);
408 }
409 "function_call" => {
410 events.push(ProviderEvent::ToolCallStart {
411 content_index: content_part.index,
412 tool_call_id: None,
413 tool_name: None,
414 partial: partial_message.clone(),
415 });
416 current_tool_call_index = Some(content_part.index);
417 }
418 _ => {}
419 }
420 }
421 ResponsesEvent::OutputTextDelta { output_text: delta } => {
422 let content_idx = delta.content_index.or(current_text_index).unwrap_or(0);
424 events.push(ProviderEvent::TextDelta {
425 content_index: content_idx,
426 delta: delta.slice.unwrap_or_default(),
427 partial: partial_message.clone(),
428 });
429 if current_text_index.is_none() {
431 current_text_index = Some(content_idx);
432 }
433 }
434 ResponsesEvent::FunctionCallArgumentsDelta {
435 function_call: delta,
436 } => {
437 let content_idx = delta.content_index.or(current_tool_call_index).unwrap_or(0);
439 events.push(ProviderEvent::ToolCallDelta {
440 content_index: content_idx,
441 delta: delta.arguments.unwrap_or_default(),
442 partial: partial_message.clone(),
443 });
444 if current_tool_call_index.is_none() {
446 current_tool_call_index = Some(content_idx);
447 }
448 }
449 ResponsesEvent::OutputTextDone { output_text } => {
450 if let Some(idx) = current_text_index {
451 let text_content = output_text
452 .content
453 .map(|c| c.text.unwrap_or_default())
454 .unwrap_or_default();
455 events.push(ProviderEvent::TextEnd {
456 content_index: idx,
457 content: text_content,
458 partial: partial_message.clone(),
459 });
460 current_text_index = None;
461 }
462 }
463 ResponsesEvent::ReasoningDone { reasoning } => {
464 if let Some(summary) = reasoning.summary {
465 for item in summary {
466 if item.r#type == "summary_text" {
467 events.push(ProviderEvent::ThinkingEnd {
468 content_index: 0,
469 content: item.text.unwrap_or_default(),
470 partial: partial_message.clone(),
471 });
472 }
473 }
474 }
475 }
476 ResponsesEvent::ResponseWithUsage { response } => {
477 let is_incomplete = response.incomplete_details.is_some();
479
480 if let Some(usage) = response.usage {
482 accumulated_usage.input = usage.input_tokens;
483 accumulated_usage.output = usage.output_tokens;
484 accumulated_usage.total_tokens = usage.total_tokens;
485 if let Some(cached) = usage.input_tokens_details {
486 accumulated_usage.cache_read = cached.cached_tokens;
487 }
488 }
489
490 let stop_reason = if is_incomplete {
492 if let Some(incomplete) = response.incomplete_details {
493 match incomplete.reason.as_str() {
494 "max_output_tokens" => StopReason::Length,
495 "content_filter" => StopReason::Error,
496 _ => StopReason::Stop,
497 }
498 } else {
499 StopReason::Stop
500 }
501 } else {
502 StopReason::Stop
503 };
504
505 let mut done_msg = partial_message.clone();
506 done_msg.usage = accumulated_usage.clone();
507 events.push(ProviderEvent::Done {
508 reason: stop_reason,
509 message: done_msg,
510 });
511 }
512 _ => {}
513 }
514 }
515 }
516
517 events
518}
519
520fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
522 let mut message = AssistantMessage::new(Api::OpenAiResponses, provider, model_id);
523 message.stop_reason = StopReason::Error;
524 message.error_message = Some(msg.to_string());
525 message
526}
527
/// One decoded SSE `data:` payload.
///
/// The enum is `#[serde(untagged)]`: serde tries the variants in declaration
/// order and keeps the first one whose shape matches, so the order below is
/// part of the behavior. Each variant is recognised by its single top-level
/// key.
///
/// `ResponseWithUsage` and `ResponseCreatedData` both use the key `response`.
/// Because `ResponseWithUsage` is declared first, `ResponseCreatedData` is
/// only reached when `ResponseWithUsageData` fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ResponsesEvent {
    /// `{"response": …}` — handled as the end of the response (`Done`).
    ResponseWithUsage {
        response: ResponseWithUsageData,
    },
    /// `{"output_item": …}` — a new output item was opened.
    OutputItemAdded {
        output_item: OutputItem,
    },
    /// `{"content_part": …}` — a new content part was opened.
    ContentPartAdded {
        content_part: ContentPart,
    },
    /// `{"output_text": …}` — a text fragment.
    OutputTextDelta {
        output_text: TextDelta,
    },
    /// `{"function_call": …}` — a fragment of tool-call arguments.
    FunctionCallArgumentsDelta {
        function_call: FunctionCallDelta,
    },
    /// `{"output_text": …}` — final text of a part.
    ///
    /// NOTE(review): same key as `OutputTextDelta`, whose fields are all
    /// optional, so the earlier variant appears to match first — confirm this
    /// variant is reachable.
    OutputTextDone {
        output_text: OutputTextDone,
    },
    /// `{"reasoning": …}` — reasoning summary.
    ReasoningDone {
        reasoning: ReasoningDone,
    },
    /// `{"response": …}` — handled as the start of the response (`Start`).
    ResponseCreatedData {
        response: ResponseCreatedData,
    },
    /// Fallback that accepts any other JSON value; such payloads are ignored.
    #[allow(dead_code)]
    Unknown(JsonValue),
}
572
/// Payload under `response` for `ResponsesEvent::ResponseCreatedData`.
///
/// Only `id` is read by `parse_sse_events`; the underscore-prefixed fields are
/// not used.
#[derive(Debug, Deserialize)]
struct ResponseCreatedData {
    /// Response identifier, copied into `AssistantMessage::response_id`.
    id: Option<String>,
    // Read from the JSON key `object`.
    #[serde(rename = "object")]
    _object: Option<String>,
    // NOTE(review): no `rename`, so serde looks for the JSON key `_status`
    // rather than `status` — confirm whether that is intended.
    _status: Option<String>,
    // Read from the JSON key `model`.
    #[serde(rename = "model")]
    _model: Option<String>,
    // NOTE(review): no `rename`, so serde looks for the JSON key `_created_at`.
    _created_at: Option<i64>,
}
584
585#[derive(Debug, Deserialize)]
586struct OutputItem {
588 index: usize,
589 #[serde(rename = "type")]
590 r#type: String,
591 id: Option<String>,
592 _status: Option<String>,
593}
594
595#[derive(Debug, Deserialize)]
596struct ContentPart {
597 index: usize,
598 #[serde(rename = "type")]
599 r#type: String,
600}
601
/// Payload under `output_text` for a text fragment. Every field is optional.
#[derive(Debug, Deserialize)]
struct TextDelta {
    /// Index of the content part this fragment belongs to, when given.
    content_index: Option<usize>,
    // Unused. NOTE(review): without a `rename`, serde looks for the JSON key
    // `_output_index`.
    _output_index: Option<usize>,
    /// The text fragment; forwarded as the `delta` of `ProviderEvent::TextDelta`.
    slice: Option<String>,
}
609
/// Payload under `function_call` for a fragment of tool-call arguments.
/// Every field is optional.
#[derive(Debug, Deserialize)]
struct FunctionCallDelta {
    /// Index of the content part this fragment belongs to, when given.
    content_index: Option<usize>,
    // The underscore-prefixed fields are unused. NOTE(review): without a
    // `rename`, serde looks for JSON keys that include the leading underscore.
    _output_index: Option<usize>,
    _name: Option<String>,
    /// Argument fragment; forwarded as the `delta` of `ProviderEvent::ToolCallDelta`.
    arguments: Option<String>,
    _call_id: Option<String>,
}
619
/// Payload under `output_text` announcing the final text of a part.
#[derive(Debug, Deserialize)]
struct OutputTextDone {
    // The underscore-prefixed fields are unused. NOTE(review): without a
    // `rename`, serde looks for JSON keys that include the leading underscore.
    _content_index: Option<usize>,
    _output_index: Option<usize>,
    /// Final text wrapper; its `text` becomes the `content` of `ProviderEvent::TextEnd`.
    content: Option<TextContent>,
}
627
/// Wrapper around the final text carried by [`OutputTextDone`].
#[derive(Debug, Deserialize)]
struct TextContent {
    /// The text itself; treated as empty when absent.
    text: Option<String>,
}
632
/// Payload under `reasoning`: the reasoning summary.
#[derive(Debug, Deserialize)]
struct ReasoningDone {
    // The underscore-prefixed fields are unused. NOTE(review): without a
    // `rename`, serde looks for JSON keys that include the leading underscore.
    _content_index: Option<usize>,
    _output_index: Option<usize>,
    /// Summary entries; each `summary_text` entry yields one `ProviderEvent::ThinkingEnd`.
    summary: Option<Vec<SummaryItem>>,
}
640
641#[derive(Debug, Deserialize)]
642struct SummaryItem {
643 #[serde(rename = "type")]
644 r#type: String,
645 text: Option<String>,
646}
647
648#[derive(Debug, Deserialize)]
650struct ResponseWithUsageData {
652 _id: Option<String>,
653 _status: Option<String>,
654 usage: Option<UsageData>,
655 incomplete_details: Option<IncompleteDetails>,
656}
657
/// Explains why a response stopped before completing normally.
#[derive(Debug, Deserialize)]
struct IncompleteDetails {
    /// Reason code. `parse_sse_events` maps `max_output_tokens` to
    /// `StopReason::Length`, `content_filter` to `StopReason::Error`, and
    /// anything else to `StopReason::Stop`.
    reason: String,
}
662
663#[derive(Debug, Deserialize)]
664struct UsageData {
666 input_tokens: usize,
667 output_tokens: usize,
668 total_tokens: usize,
669 #[serde(rename = "input_tokens_details")]
670 input_tokens_details: Option<InputTokensDetails>,
671}
672
673#[derive(Debug, Deserialize)]
674struct InputTokensDetails {
675 #[serde(rename = "cached_tokens")]
676 cached_tokens: usize,
677}
678
/// Unit tests for request construction (`build_input`, `blocks_to_json`,
/// `build_tools`) and for `parse_sse_events`. None of them performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;
    // The explicit import of `crate::TextContent` takes precedence over the
    // private `TextContent` struct brought in by the glob import above.
    use crate::{Context, Message, Model, TextContent};
    use serde_json::json;

    /// Model fixture pointing at the public OpenAI base URL.
    // NOTE(review): not called by any test in this module.
    fn create_test_model() -> Model {
        Model::new(
            "gpt-4o",
            "GPT-4o",
            Api::OpenAiResponses,
            "openai-responses",
            "https://api.openai.com/v1",
        )
    }

    /// Context fixture created with `Context::new()`.
    fn create_test_context() -> Context {
        Context::new()
    }

    #[test]
    fn test_provider_name() {
        let provider = OpenAiResponsesProvider::new();
        assert_eq!(provider.name(), "openai-responses");
    }

    // A plain-text user message becomes a single `user` item whose content is
    // the string itself.
    #[test]
    fn test_build_input_with_text() {
        let mut context = create_test_context();
        context.add_message(Message::user("Hello, world!"));

        let input = build_input(&context).unwrap();
        assert_eq!(input.len(), 1);
        assert_eq!(input[0]["role"], "user");
        assert_eq!(input[0]["content"], "Hello, world!");
    }

    // The system prompt is emitted first, under the `developer` role.
    #[test]
    fn test_build_input_with_system_prompt() {
        let mut context = create_test_context();
        context.set_system_prompt("You are a helpful assistant.");
        context.add_message(Message::user("Hi!"));

        let input = build_input(&context).unwrap();
        assert_eq!(input.len(), 2);
        assert_eq!(input[0]["role"], "developer");
        assert_eq!(input[0]["content"], "You are a helpful assistant.");
    }

    #[test]
    fn test_build_input_with_multiple_messages() {
        let mut context = create_test_context();
        context.add_message(Message::user("First message"));
        context.add_message(Message::user("Second message"));

        let input = build_input(&context).unwrap();
        assert_eq!(input.len(), 2);
    }

    // A single text block collapses to a bare JSON string.
    #[test]
    fn test_blocks_to_json_text() {
        let blocks = vec![ContentBlock::Text(TextContent::new("Hello"))];
        let result = blocks_to_json(&blocks).unwrap();
        assert_eq!(result, "Hello");
    }

    // More than one block yields an array with one entry per block.
    #[test]
    fn test_blocks_to_json_multiple_blocks() {
        let blocks = vec![
            ContentBlock::Text(TextContent::new("Hello")),
            ContentBlock::Text(TextContent::new("World")),
        ];
        let result = blocks_to_json(&blocks).unwrap();
        assert!(result.is_array());
        assert_eq!(result.as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_build_tools() {
        let tools = vec![crate::Tool {
            name: "get_weather".to_string(),
            description: "Get weather for a location".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {"type": "string"}
                }
            }),
        }];

        let result = build_tools(&tools);
        assert!(result.is_array());
        let tool = &result[0];
        assert_eq!(tool["type"], "function");
        assert_eq!(tool["name"], "get_weather");
    }

    // NOTE(review): the `if let` below only checks `api` when the first event
    // is `Start`; any other first event passes silently. Effectively this only
    // asserts that at least one event is produced.
    #[test]
    fn test_parse_response_created_event() {
        let sse_data =
            r#"data: {"response":{"id":"resp_123","status":"in_progress","model":"gpt-4o"}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(!events.is_empty());
        if let ProviderEvent::Start { partial } = &events[0] {
            assert_eq!(partial.api, Api::OpenAiResponses);
        }
    }

    // NOTE(review): pins that an output item of type `message` is reported as
    // `ToolCallStart`.
    #[test]
    fn test_parse_output_item_added_event() {
        let sse_data = r#"data: {"output_item":{"index":0,"id":"msg_123","type":"message","status":"in_progress"}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(events
            .iter()
            .any(|e| matches!(e, ProviderEvent::ToolCallStart { .. })));
    }

    #[test]
    fn test_parse_text_delta_event() {
        let sse_data = r#"data: {"output_text":{"content_index":0,"slice":"Hello"}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(events
            .iter()
            .any(|e| matches!(e, ProviderEvent::TextDelta { .. })));
    }

    #[test]
    fn test_parse_function_call_delta_event() {
        let sse_data = r#"data: {"function_call":{"content_index":0,"arguments":"{\"location"}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(events
            .iter()
            .any(|e| matches!(e, ProviderEvent::ToolCallDelta { .. })));
    }

    // A finished response carrying usage ends the stream with `StopReason::Stop`.
    #[test]
    fn test_parse_completed_event_with_usage() {
        let sse_data = r#"data: {"response":{"id":"resp_123","status":"completed","usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(events.iter().any(|e| matches!(
            e,
            ProviderEvent::Done {
                reason: StopReason::Stop,
                ..
            }
        )));
    }

    #[test]
    fn test_parse_reasoning_event() {
        let sse_data = r#"data: {"reasoning":{"content_index":0,"summary":[{"type":"summary_text","text":"Thinking process..."}]}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(events
            .iter()
            .any(|e| matches!(e, ProviderEvent::ThinkingEnd { .. })));
    }

    // NOTE(review): only checks the provider name; the stored key itself is
    // not observable from here.
    #[test]
    fn test_provider_with_api_key() {
        let provider = OpenAiResponsesProvider::with_api_key("sk-test-key");
        assert_eq!(provider.name(), "openai-responses");
    }

    // Four data lines must yield at least four events. The continuation lines
    // of the raw string start at column 0 on purpose: the parser requires the
    // `data:` prefix at the start of a line.
    #[test]
    fn test_multiple_events_in_stream() {
        let sse_data = r#"data: {"response":{"id":"resp_123"}}
data: {"output_item":{"index":0,"type":"message"}}
data: {"output_text":{"slice":"Hello"}}
data: {"response":{"status":"completed"}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(events.len() >= 4);
    }

    // A payload that is not valid JSON is skipped without affecting the lines
    // that follow it.
    #[test]
    fn test_invalid_json_skipped() {
        let sse_data = r#"event: response.created
data: {invalid json here}
event: response.created
data: {"response":{"id":"resp_123"}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(!events.is_empty());
    }

    // The `[DONE]` marker must not add events of its own.
    #[test]
    fn test_done_marker() {
        let sse_data = r#"event: response.created
data: {"response":{"id":"resp_123"}}
data: [DONE]"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(events.len() <= 2);
    }

    // `incomplete_details.reason == "max_output_tokens"` maps to `StopReason::Length`.
    #[test]
    fn test_incomplete_response() {
        let sse_data = r#"data: {"response":{"id":"resp_123","incomplete_details":{"reason":"max_output_tokens"}}}"#;

        let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
        assert!(events.iter().any(|e| matches!(
            e,
            ProviderEvent::Done {
                reason: StopReason::Length,
                ..
            }
        )));
    }
}