1use super::model::ModelConfig;
30use super::traits::*;
31use crate::types::*;
32use async_trait::async_trait;
33use futures::StreamExt;
34use reqwest_eventsource::EventSource;
35use serde::Deserialize;
36use tokio::sync::mpsc;
37use tracing::{debug, warn};
38
39pub struct OpenAiResponsesProvider;
41
42#[async_trait]
43impl StreamProvider for OpenAiResponsesProvider {
44 fn provider_id(&self) -> &str {
45 "openai-responses"
46 }
47
48 async fn stream(
49 &self,
50 config: StreamConfig, tx: mpsc::UnboundedSender<StreamEvent>, cancel: tokio_util::sync::CancellationToken, ) -> Result<Message, ProviderError> {
54 let model_config = &config.model_config;
55 let api_key = model_config.resolve_api_key().await?;
57
58 let url = format!("{}/responses", model_config.base_url);
59 let body = build_request_body(&config, model_config);
60 debug!(
61 "OpenAI Responses request: model={} url={}",
62 config.model_config.id, url
63 );
64
65 let client = reqwest::Client::new();
66 let mut request = client
67 .post(&url)
68 .header("content-type", "application/json")
69 .header("authorization", format!("Bearer {}", api_key));
70
71 for (k, v) in &model_config.headers {
72 request = request.header(k, v);
73 }
74
75 let request = request.json(&body);
76 let mut es =
77 EventSource::new(request).map_err(|e| ProviderError::Network(e.to_string()))?;
78
79 let mut content: Vec<Content> = Vec::new();
80 let mut usage = Usage::default();
81 let mut stop_reason = StopReason::Stop;
82 let mut tool_call_buffers: std::collections::HashMap<usize, ToolCallBuffer> =
96 std::collections::HashMap::new();
97
98 let _ = tx.send(StreamEvent::Start);
99
100 loop {
101 tokio::select! {
102 _ = cancel.cancelled() => {
103 es.close();
104 return Err(ProviderError::Cancelled);
105 }
106 event = es.next() => {
107 match event {
108 None => break,
109 Some(Ok(reqwest_eventsource::Event::Open)) => {}
110 Some(Ok(reqwest_eventsource::Event::Message(msg))) => {
111 match msg.event.as_str() {
112 "response.output_text.delta" => {
113 if let Ok(data) = serde_json::from_str::<TextDeltaEvent>(&msg.data) {
114 let text_idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
115 let idx = match text_idx {
116 Some(i) => i,
117 None => {
118 content.push(Content::Text { text: String::new() });
119 content.len() - 1
120 }
121 };
122 if let Some(Content::Text { text }) = content.get_mut(idx) {
123 text.push_str(&data.delta);
124 }
125 let _ = tx.send(StreamEvent::TextDelta {
126 content_index: idx,
127 delta: data.delta,
128 });
129 }
130 }
131 "response.reasoning.delta" => {
132 if let Ok(data) = serde_json::from_str::<TextDeltaEvent>(&msg.data) {
133 let idx = content.iter().position(|c| matches!(c, Content::Thinking { .. }));
134 let idx = match idx {
135 Some(i) => i,
136 None => {
137 content.push(Content::Thinking { thinking: String::new(), signature: None });
138 content.len() - 1
139 }
140 };
141 if let Some(Content::Thinking { thinking, .. }) = content.get_mut(idx) {
142 thinking.push_str(&data.delta);
143 }
144 let _ = tx.send(StreamEvent::ThinkingDelta {
145 content_index: idx,
146 delta: data.delta,
147 });
148 }
149 }
150 "response.function_call_arguments.start" => {
151 if let Ok(data) = serde_json::from_str::<FunctionCallStartEvent>(&msg.data) {
152 let idx = content.len() + tool_call_buffers.len();
153 tool_call_buffers.insert(idx, ToolCallBuffer {
154 id: data.call_id.unwrap_or_default(),
155 name: data.name.unwrap_or_default(),
156 arguments: String::new(),
157 });
158 let buf = &tool_call_buffers[&idx];
159 let _ = tx.send(StreamEvent::ToolCallStart {
160 content_index: idx,
161 id: buf.id.clone(),
162 name: buf.name.clone(),
163 });
164 }
165 }
166 "response.function_call_arguments.delta" => {
167 if let Ok(data) = serde_json::from_str::<TextDeltaEvent>(&msg.data) {
168 if let Some((&idx, buf)) = tool_call_buffers.iter_mut().last() {
170 buf.arguments.push_str(&data.delta);
171 let _ = tx.send(StreamEvent::ToolCallDelta {
172 content_index: idx,
173 delta: data.delta,
174 });
175 }
176 }
177 }
178 "response.function_call_arguments.done" => {
179 }
181 "response.completed" => {
182 if let Ok(data) = serde_json::from_str::<ResponseCompletedEvent>(&msg.data) {
183 if let Some(resp) = data.response {
184 if let Some(u) = resp.usage {
185 usage.input = u.input_tokens;
186 usage.output = u.output_tokens;
187 usage.total_tokens = u.total_tokens;
188 if let Some(details) = u.output_token_details {
189 usage.reasoning = details.reasoning_tokens;
190 }
191 }
192 if resp.status == Some("incomplete".to_string()) {
193 stop_reason = StopReason::Length;
194 }
195 }
196 }
197 break;
198 }
199 "error" => {
200 warn!("OpenAI Responses error: {}", msg.data);
201 let err_msg = Message::Assistant {
202 content: vec![Content::Text { text: String::new() }],
203 stop_reason: StopReason::Error,
204 model: config.model_config.id.clone(),
205 provider: model_config.provider.clone(),
206 usage: usage.clone(),
207 timestamp: now_ms(),
208 error_message: Some(msg.data),
209 };
210 let _ = tx.send(StreamEvent::Error { message: err_msg.clone() });
211 return Ok(err_msg);
212 }
213 _ => {
214 debug!("Unknown Responses event: {}", msg.event);
215 }
216 }
217 }
218 Some(Err(e)) => {
219 let err_str = e.to_string();
220 warn!("OpenAI Responses SSE error: {}", err_str);
221 let err_msg = Message::Assistant {
222 content: vec![Content::Text { text: String::new() }],
223 stop_reason: StopReason::Error,
224 model: config.model_config.id.clone(),
225 provider: model_config.provider.clone(),
226 usage: usage.clone(),
227 timestamp: now_ms(),
228 error_message: Some(err_str),
229 };
230 let _ = tx.send(StreamEvent::Error { message: err_msg.clone() });
231 return Ok(err_msg);
232 }
233 }
234 }
235 }
236 }
237
238 for (_, buf) in tool_call_buffers {
240 let args = serde_json::from_str(&buf.arguments)
241 .unwrap_or(serde_json::Value::Object(Default::default()));
242 content.push(Content::ToolCall {
243 id: buf.id,
244 name: buf.name,
245 arguments: args,
246 });
247 }
248
249 if content
250 .iter()
251 .any(|c| matches!(c, Content::ToolCall { .. }))
252 {
253 stop_reason = StopReason::ToolUse;
254 }
255
256 let message = Message::Assistant {
257 content,
258 stop_reason,
259 model: config.model_config.id.clone(),
260 provider: model_config.provider.clone(),
261 usage,
262 timestamp: now_ms(),
263 error_message: None,
264 };
265
266 let _ = tx.send(StreamEvent::Done {
267 message: message.clone(),
268 });
269 Ok(message)
270 }
271}
272
/// Accumulates one streamed function call until the response finishes.
struct ToolCallBuffer {
    /// Call id reported by the server (empty when none was sent).
    id: String,
    /// Name of the function being called (empty when none was sent).
    name: String,
    /// Raw JSON argument text built from the streamed deltas; parsed once the stream ends.
    arguments: String,
}
278
279fn build_request_body(
280 config: &StreamConfig, _model_config: &ModelConfig, ) -> serde_json::Value {
283 let mut input: Vec<serde_json::Value> = Vec::new();
284
285 for msg in &config.messages {
286 match msg {
287 Message::User { content, .. } => {
288 let user_content: Vec<serde_json::Value> = content
290 .iter()
291 .filter_map(|c| match c {
292 Content::Text { text } => Some(serde_json::json!({
293 "type": "input_text",
294 "text": text,
295 })),
296 Content::Image { data, mime_type } => Some(serde_json::json!({
297 "type": "input_image",
298 "image_url": format!("data:{};base64,{}", mime_type, data),
299 })),
300 _ => None,
301 })
302 .collect();
303
304 if user_content.len() == 1 && user_content[0]["type"] == "input_text" {
305 input.push(serde_json::json!({
307 "role": "user",
308 "content": user_content[0]["text"].as_str().unwrap_or(""),
309 }));
310 } else {
311 input.push(serde_json::json!({
313 "role": "user",
314 "content": user_content,
315 }));
316 }
317 }
318 Message::Assistant { content, .. } => {
319 for c in content {
320 match c {
321 Content::Text { text } => {
322 input.push(serde_json::json!({
323 "type": "message",
324 "role": "assistant",
325 "content": [{"type": "output_text", "text": text}],
326 }));
327 }
328 Content::ToolCall {
329 id,
330 name,
331 arguments,
332 } => {
333 input.push(serde_json::json!({
334 "type": "function_call",
335 "call_id": id,
336 "name": name,
337 "arguments": arguments.to_string(),
338 }));
339 }
340 _ => {}
341 }
342 }
343 }
344 Message::ToolResult {
345 tool_call_id,
346 content,
347 ..
348 } => {
349 let output_val = if content.iter().any(|c| matches!(c, Content::Image { .. })) {
350 let parts: Vec<serde_json::Value> = content
352 .iter()
353 .filter_map(|c| match c {
354 Content::Text { text } => Some(serde_json::json!({
355 "type": "input_text",
356 "text": text,
357 })),
358 Content::Image { data, mime_type } => Some(serde_json::json!({
359 "type": "input_image",
360 "image_url": format!("data:{};base64,{}", mime_type, data),
361 })),
362 _ => None,
363 })
364 .collect();
365 serde_json::json!(parts)
366 } else {
367 let text = content
368 .iter()
369 .find_map(|c| match c {
370 Content::Text { text } => Some(text.clone()),
371 _ => None,
372 })
373 .unwrap_or_default();
374 serde_json::json!(text)
375 };
376 input.push(serde_json::json!({
377 "type": "function_call_output",
378 "call_id": tool_call_id,
379 "output": output_val,
380 }));
381 }
382 }
383 }
384
385 let mut body = serde_json::json!({
386 "model": config.model_config.id,
387 "stream": true,
388 "input": input,
389 });
390
391 if !config.system_prompt.is_empty() {
392 body["instructions"] = serde_json::json!(config.system_prompt);
393 }
394
395 if let Some(max) = config.max_tokens {
396 body["max_output_tokens"] = serde_json::json!(max);
397 }
398
399 if !config.tools.is_empty() {
400 let tools: Vec<serde_json::Value> = config
401 .tools
402 .iter()
403 .map(|t| {
404 serde_json::json!({
405 "type": "function",
406 "name": t.name,
407 "description": t.description,
408 "parameters": t.parameters,
409 })
410 })
411 .collect();
412 body["tools"] = serde_json::json!(tools);
413 }
414
415 if config.thinking_level != ThinkingLevel::Off {
416 let effort = match config.thinking_level {
417 ThinkingLevel::Minimal | ThinkingLevel::Low => "low",
418 ThinkingLevel::Medium => "medium",
419 ThinkingLevel::High => "high",
420 ThinkingLevel::Off => unreachable!(),
421 };
422 body["reasoning"] = serde_json::json!({"effort": effort});
423 }
424
425 if let Some(temp) = config.temperature {
426 body["temperature"] = serde_json::json!(temp);
427 }
428
429 match &config.response_format {
433 ResponseFormat::Text => {} ResponseFormat::JsonObject => {
435 body["text"] = serde_json::json!({"format": {"type": "json_object"}});
436 }
437 ResponseFormat::JsonSchema {
438 schema,
439 name,
440 strict,
441 } => {
442 body["text"] = serde_json::json!({
443 "format": {
444 "type": "json_schema",
445 "name": name,
446 "schema": schema,
447 "strict": *strict,
448 },
449 });
450 }
451 }
452
453 body
454}
455
456#[derive(Deserialize)]
458struct TextDeltaEvent {
459 delta: String,
460}
461
462#[derive(Deserialize)]
463struct FunctionCallStartEvent {
464 #[serde(default)]
465 call_id: Option<String>,
466 #[serde(default)]
467 name: Option<String>,
468}
469
470#[derive(Deserialize)]
471struct ResponseCompletedEvent {
472 #[serde(default)]
473 response: Option<ResponseData>,
474}
475
476#[derive(Deserialize)]
477struct ResponseData {
478 #[serde(default)]
479 status: Option<String>,
480 #[serde(default)]
481 usage: Option<ResponseUsage>,
482}
483
484#[derive(Deserialize)]
485struct ResponseUsage {
486 #[serde(default)]
487 input_tokens: u64,
488 #[serde(default)]
489 output_tokens: u64,
490 #[serde(default)]
491 total_tokens: u64,
492 #[serde(default)]
493 output_token_details: Option<ResponseOutputTokenDetails>,
494}
495
496#[derive(Deserialize)]
497struct ResponseOutputTokenDetails {
498 #[serde(default)]
499 reasoning_tokens: u64,
500}