Skip to main content

phi_core/provider/
openai_responses.rs

1//! OpenAI Responses API provider.
2//!
3//! This is the newer OpenAI API that uses a different event format
4//! from Chat Completions. It has first-class support for reasoning items.
5/*
6ARCHITECTURE: OpenAiResponsesProvider — the "next generation" OpenAI API
7
8OpenAI launched the Responses API as a replacement for Chat Completions.
9Key differences vs Chat Completions:
10  - Input uses `input` field (not `messages`) with role/content objects
11  - Streaming events use `response.output_item.*` event names (not `choices[N].delta`)
12  - Reasoning blocks are first-class items in the output array (not `delta.reasoning_content`)
13  - Response has `output` array instead of `choices` array
14  - No `[DONE]` sentinel; stream ends when `response.completed` arrives
15
16This provider handles the Responses API event format. The Azure provider re-uses
17the same request body format (build_request_body / build_azure_request_body are similar)
18since Azure's OpenAI Responses API mirrors the OpenAI one.
19
20ARCHITECTURE: Tool call ID correlation
21
The Responses API announces a tool call via `response.output_item.added` and then streams
its arguments in `response.function_call_arguments.delta` events. This provider opens a
buffer when it sees `response.function_call_arguments.start` and appends each argument
delta to a buffer. Buffers live in a `HashMap<usize, ToolCallBuffer>` keyed by the content
index computed when the call starts (`content.len() + tool_call_buffers.len()`), not by the
API's output item index. HashMap is used instead of Vec because those indices are sparse —
text and reasoning blocks occupy other positions in the content list.
27*/
28
29use super::model::ModelConfig;
30use super::traits::*;
31use crate::types::*;
32use async_trait::async_trait;
33use futures::StreamExt;
34use reqwest_eventsource::EventSource;
35use serde::Deserialize;
36use tokio::sync::mpsc;
37use tracing::{debug, warn};
38
/// Streaming provider for the OpenAI Responses API (`provider_id` = "openai-responses").
///
/// Unit struct — no state. All logic lives in the `StreamProvider` impl, so a value can be
/// created freely wherever a provider is needed.
pub struct OpenAiResponsesProvider;
41
42#[async_trait]
43impl StreamProvider for OpenAiResponsesProvider {
44    fn provider_id(&self) -> &str {
45        "openai-responses"
46    }
47
48    async fn stream(
49        &self,
50        config: StreamConfig, // REQUEST — uses Responses API shape (input[] not messages[])
51        tx: mpsc::UnboundedSender<StreamEvent>, // OBSERVER — receives events; no [DONE] sentinel (stream just closes)
52        cancel: tokio_util::sync::CancellationToken, // ABORT — races against SSE stream
53    ) -> Result<Message, ProviderError> {
54        let model_config = &config.model_config;
55        // Resolve via CredentialProvider when set, else use the static `api_key`.
56        let api_key = model_config.resolve_api_key().await?;
57
58        let url = format!("{}/responses", model_config.base_url);
59        let body = build_request_body(&config, model_config);
60        debug!(
61            "OpenAI Responses request: model={} url={}",
62            config.model_config.id, url
63        );
64
65        let client = reqwest::Client::new();
66        let mut request = client
67            .post(&url)
68            .header("content-type", "application/json")
69            .header("authorization", format!("Bearer {}", api_key));
70
71        for (k, v) in &model_config.headers {
72            request = request.header(k, v);
73        }
74
75        let request = request.json(&body);
76        let mut es =
77            EventSource::new(request).map_err(|e| ProviderError::Network(e.to_string()))?;
78
79        let mut content: Vec<Content> = Vec::new();
80        let mut usage = Usage::default();
81        let mut stop_reason = StopReason::Stop;
82        /*
83        RUST QUIRK: `std::collections::HashMap<usize, ToolCallBuffer>` — sparse index map
84
85        Why HashMap here but Vec in the OpenAI compat provider?
86        The Responses API output array may have non-tool-call items (text, reasoning) at
87        arbitrary indices. A Vec indexed by output position would require filling gaps with
88        placeholder values. HashMap handles sparse indices naturally: only insert when we
89        see a tool call item, look up by exact index.
90
91        `std::collections::HashMap` (from the standard library) uses SipHash by default
92        — secure against hash-flooding DoS attacks, slightly slower than FxHash.
93        For small maps (< 10 entries), this is negligible.
94        */
95        let mut tool_call_buffers: std::collections::HashMap<usize, ToolCallBuffer> =
96            std::collections::HashMap::new();
97
98        let _ = tx.send(StreamEvent::Start);
99
100        loop {
101            tokio::select! {
102                _ = cancel.cancelled() => {
103                    es.close();
104                    return Err(ProviderError::Cancelled);
105                }
106                event = es.next() => {
107                    match event {
108                        None => break,
109                        Some(Ok(reqwest_eventsource::Event::Open)) => {}
110                        Some(Ok(reqwest_eventsource::Event::Message(msg))) => {
111                            match msg.event.as_str() {
112                                "response.output_text.delta" => {
113                                    if let Ok(data) = serde_json::from_str::<TextDeltaEvent>(&msg.data) {
114                                        let text_idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
115                                        let idx = match text_idx {
116                                            Some(i) => i,
117                                            None => {
118                                                content.push(Content::Text { text: String::new() });
119                                                content.len() - 1
120                                            }
121                                        };
122                                        if let Some(Content::Text { text }) = content.get_mut(idx) {
123                                            text.push_str(&data.delta);
124                                        }
125                                        let _ = tx.send(StreamEvent::TextDelta {
126                                            content_index: idx,
127                                            delta: data.delta,
128                                        });
129                                    }
130                                }
131                                "response.reasoning.delta" => {
132                                    if let Ok(data) = serde_json::from_str::<TextDeltaEvent>(&msg.data) {
133                                        let idx = content.iter().position(|c| matches!(c, Content::Thinking { .. }));
134                                        let idx = match idx {
135                                            Some(i) => i,
136                                            None => {
137                                                content.push(Content::Thinking { thinking: String::new(), signature: None });
138                                                content.len() - 1
139                                            }
140                                        };
141                                        if let Some(Content::Thinking { thinking, .. }) = content.get_mut(idx) {
142                                            thinking.push_str(&data.delta);
143                                        }
144                                        let _ = tx.send(StreamEvent::ThinkingDelta {
145                                            content_index: idx,
146                                            delta: data.delta,
147                                        });
148                                    }
149                                }
150                                "response.function_call_arguments.start" => {
151                                    if let Ok(data) = serde_json::from_str::<FunctionCallStartEvent>(&msg.data) {
152                                        let idx = content.len() + tool_call_buffers.len();
153                                        tool_call_buffers.insert(idx, ToolCallBuffer {
154                                            id: data.call_id.unwrap_or_default(),
155                                            name: data.name.unwrap_or_default(),
156                                            arguments: String::new(),
157                                        });
158                                        let buf = &tool_call_buffers[&idx];
159                                        let _ = tx.send(StreamEvent::ToolCallStart {
160                                            content_index: idx,
161                                            id: buf.id.clone(),
162                                            name: buf.name.clone(),
163                                        });
164                                    }
165                                }
166                                "response.function_call_arguments.delta" => {
167                                    if let Ok(data) = serde_json::from_str::<TextDeltaEvent>(&msg.data) {
168                                        // Find last buffer
169                                        if let Some((&idx, buf)) = tool_call_buffers.iter_mut().last() {
170                                            buf.arguments.push_str(&data.delta);
171                                            let _ = tx.send(StreamEvent::ToolCallDelta {
172                                                content_index: idx,
173                                                delta: data.delta,
174                                            });
175                                        }
176                                    }
177                                }
178                                "response.function_call_arguments.done" => {
179                                    // Tool call complete
180                                }
181                                "response.completed" => {
182                                    if let Ok(data) = serde_json::from_str::<ResponseCompletedEvent>(&msg.data) {
183                                        if let Some(resp) = data.response {
184                                            if let Some(u) = resp.usage {
185                                                usage.input = u.input_tokens;
186                                                usage.output = u.output_tokens;
187                                                usage.total_tokens = u.total_tokens;
188                                                if let Some(details) = u.output_token_details {
189                                                    usage.reasoning = details.reasoning_tokens;
190                                                }
191                                            }
192                                            if resp.status == Some("incomplete".to_string()) {
193                                                stop_reason = StopReason::Length;
194                                            }
195                                        }
196                                    }
197                                    break;
198                                }
199                                "error" => {
200                                    warn!("OpenAI Responses error: {}", msg.data);
201                                    let err_msg = Message::Assistant {
202                                        content: vec![Content::Text { text: String::new() }],
203                                        stop_reason: StopReason::Error,
204                                        model: config.model_config.id.clone(),
205                                        provider: model_config.provider.clone(),
206                                        usage: usage.clone(),
207                                        timestamp: now_ms(),
208                                        error_message: Some(msg.data),
209                                    };
210                                    let _ = tx.send(StreamEvent::Error { message: err_msg.clone() });
211                                    return Ok(err_msg);
212                                }
213                                _ => {
214                                    debug!("Unknown Responses event: {}", msg.event);
215                                }
216                            }
217                        }
218                        Some(Err(e)) => {
219                            let err_str = e.to_string();
220                            warn!("OpenAI Responses SSE error: {}", err_str);
221                            let err_msg = Message::Assistant {
222                                content: vec![Content::Text { text: String::new() }],
223                                stop_reason: StopReason::Error,
224                                model: config.model_config.id.clone(),
225                                provider: model_config.provider.clone(),
226                                usage: usage.clone(),
227                                timestamp: now_ms(),
228                                error_message: Some(err_str),
229                            };
230                            let _ = tx.send(StreamEvent::Error { message: err_msg.clone() });
231                            return Ok(err_msg);
232                        }
233                    }
234                }
235            }
236        }
237
238        // Finalize tool calls
239        for (_, buf) in tool_call_buffers {
240            let args = serde_json::from_str(&buf.arguments)
241                .unwrap_or(serde_json::Value::Object(Default::default()));
242            content.push(Content::ToolCall {
243                id: buf.id,
244                name: buf.name,
245                arguments: args,
246            });
247        }
248
249        if content
250            .iter()
251            .any(|c| matches!(c, Content::ToolCall { .. }))
252        {
253            stop_reason = StopReason::ToolUse;
254        }
255
256        let message = Message::Assistant {
257            content,
258            stop_reason,
259            model: config.model_config.id.clone(),
260            provider: model_config.provider.clone(),
261            usage,
262            timestamp: now_ms(),
263            error_message: None,
264        };
265
266        let _ = tx.send(StreamEvent::Done {
267            message: message.clone(),
268        });
269        Ok(message)
270    }
271}
272
/// Accumulates one streamed tool call until the stream finishes and it is turned into a
/// `Content::ToolCall`.
struct ToolCallBuffer {
    /// Call identifier taken from the start event (`call_id`); empty when the event omitted it.
    id: String,
    /// Function name taken from the start event; empty when the event omitted it.
    name: String,
    /// Concatenated argument deltas received so far; parsed as JSON after the stream ends.
    arguments: String,
}
278
279fn build_request_body(
280    config: &StreamConfig, // REQUEST — messages, tools, model, system prompt, cache config
281    _model_config: &ModelConfig, // UNUSED — reserved for future per-provider quirks (prefixed _ to suppress warning)
282) -> serde_json::Value {
283    let mut input: Vec<serde_json::Value> = Vec::new();
284
285    for msg in &config.messages {
286        match msg {
287            Message::User { content, .. } => {
288                // Build content array for user message (supports text + images)
289                let user_content: Vec<serde_json::Value> = content
290                    .iter()
291                    .filter_map(|c| match c {
292                        Content::Text { text } => Some(serde_json::json!({
293                            "type": "input_text",
294                            "text": text,
295                        })),
296                        Content::Image { data, mime_type } => Some(serde_json::json!({
297                            "type": "input_image",
298                            "image_url": format!("data:{};base64,{}", mime_type, data),
299                        })),
300                        _ => None,
301                    })
302                    .collect();
303
304                if user_content.len() == 1 && user_content[0]["type"] == "input_text" {
305                    // Simple text-only message can use shorthand format
306                    input.push(serde_json::json!({
307                        "role": "user",
308                        "content": user_content[0]["text"].as_str().unwrap_or(""),
309                    }));
310                } else {
311                    // Multi-modal content uses array format
312                    input.push(serde_json::json!({
313                        "role": "user",
314                        "content": user_content,
315                    }));
316                }
317            }
318            Message::Assistant { content, .. } => {
319                for c in content {
320                    match c {
321                        Content::Text { text } => {
322                            input.push(serde_json::json!({
323                                "type": "message",
324                                "role": "assistant",
325                                "content": [{"type": "output_text", "text": text}],
326                            }));
327                        }
328                        Content::ToolCall {
329                            id,
330                            name,
331                            arguments,
332                        } => {
333                            input.push(serde_json::json!({
334                                "type": "function_call",
335                                "call_id": id,
336                                "name": name,
337                                "arguments": arguments.to_string(),
338                            }));
339                        }
340                        _ => {}
341                    }
342                }
343            }
344            Message::ToolResult {
345                tool_call_id,
346                content,
347                ..
348            } => {
349                let output_val = if content.iter().any(|c| matches!(c, Content::Image { .. })) {
350                    // Images present: build content array
351                    let parts: Vec<serde_json::Value> = content
352                        .iter()
353                        .filter_map(|c| match c {
354                            Content::Text { text } => Some(serde_json::json!({
355                                "type": "input_text",
356                                "text": text,
357                            })),
358                            Content::Image { data, mime_type } => Some(serde_json::json!({
359                                "type": "input_image",
360                                "image_url": format!("data:{};base64,{}", mime_type, data),
361                            })),
362                            _ => None,
363                        })
364                        .collect();
365                    serde_json::json!(parts)
366                } else {
367                    let text = content
368                        .iter()
369                        .find_map(|c| match c {
370                            Content::Text { text } => Some(text.clone()),
371                            _ => None,
372                        })
373                        .unwrap_or_default();
374                    serde_json::json!(text)
375                };
376                input.push(serde_json::json!({
377                    "type": "function_call_output",
378                    "call_id": tool_call_id,
379                    "output": output_val,
380                }));
381            }
382        }
383    }
384
385    let mut body = serde_json::json!({
386        "model": config.model_config.id,
387        "stream": true,
388        "input": input,
389    });
390
391    if !config.system_prompt.is_empty() {
392        body["instructions"] = serde_json::json!(config.system_prompt);
393    }
394
395    if let Some(max) = config.max_tokens {
396        body["max_output_tokens"] = serde_json::json!(max);
397    }
398
399    if !config.tools.is_empty() {
400        let tools: Vec<serde_json::Value> = config
401            .tools
402            .iter()
403            .map(|t| {
404                serde_json::json!({
405                    "type": "function",
406                    "name": t.name,
407                    "description": t.description,
408                    "parameters": t.parameters,
409                })
410            })
411            .collect();
412        body["tools"] = serde_json::json!(tools);
413    }
414
415    if config.thinking_level != ThinkingLevel::Off {
416        let effort = match config.thinking_level {
417            ThinkingLevel::Minimal | ThinkingLevel::Low => "low",
418            ThinkingLevel::Medium => "medium",
419            ThinkingLevel::High => "high",
420            ThinkingLevel::Off => unreachable!(),
421        };
422        body["reasoning"] = serde_json::json!({"effort": effort});
423    }
424
425    if let Some(temp) = config.temperature {
426        body["temperature"] = serde_json::json!(temp);
427    }
428
429    // Structured-output wiring (Responses API shape). The Responses API uses
430    // `text.format` rather than the top-level `response_format` field that Chat
431    // Completions uses.
432    match &config.response_format {
433        ResponseFormat::Text => {} // default; omit the field
434        ResponseFormat::JsonObject => {
435            body["text"] = serde_json::json!({"format": {"type": "json_object"}});
436        }
437        ResponseFormat::JsonSchema {
438            schema,
439            name,
440            strict,
441        } => {
442            body["text"] = serde_json::json!({
443                "format": {
444                    "type": "json_schema",
445                    "name": name,
446                    "schema": schema,
447                    "strict": *strict,
448                },
449            });
450        }
451    }
452
453    body
454}
455
456// Event types
/// Payload shape shared by the delta events this provider consumes (output text, reasoning,
/// and function-call arguments). Only the `delta` string is read; other keys are ignored.
#[derive(Deserialize)]
struct TextDeltaEvent {
    /// The incremental text fragment carried by the event.
    delta: String,
}
461
/// Payload of the `response.function_call_arguments.start` event.
///
/// Both fields are optional; a missing key deserializes to `None` and the stream loop
/// substitutes an empty string.
#[derive(Deserialize)]
struct FunctionCallStartEvent {
    /// Identifier of the tool call, later used as the `id` of the resulting tool-call content.
    #[serde(default)]
    call_id: Option<String>,
    /// Name of the function being called.
    #[serde(default)]
    name: Option<String>,
}
469
/// Payload of the terminal `response.completed` event; wraps the final response object.
#[derive(Deserialize)]
struct ResponseCompletedEvent {
    /// The response object; `None` when the event carries no `response` key.
    #[serde(default)]
    response: Option<ResponseData>,
}
475
/// The subset of the final response object that the stream loop reads.
#[derive(Deserialize)]
struct ResponseData {
    /// Response status string; the value "incomplete" is mapped to `StopReason::Length`.
    #[serde(default)]
    status: Option<String>,
    /// Token accounting for the response, when present.
    #[serde(default)]
    usage: Option<ResponseUsage>,
}
483
484#[derive(Deserialize)]
485struct ResponseUsage {
486    #[serde(default)]
487    input_tokens: u64,
488    #[serde(default)]
489    output_tokens: u64,
490    #[serde(default)]
491    total_tokens: u64,
492    #[serde(default)]
493    output_token_details: Option<ResponseOutputTokenDetails>,
494}
495
/// Breakdown of output tokens nested inside the usage object.
#[derive(Deserialize)]
struct ResponseOutputTokenDetails {
    /// Tokens spent on reasoning; 0 when the key is missing. Copied into `usage.reasoning`.
    #[serde(default)]
    reasoning_tokens: u64,
}
500}