Skip to main content

llm/providers/openai/
responses_provider.rs

1use std::collections::HashMap;
2
3use async_openai::Client;
4use async_openai::config::OpenAIConfig;
5use async_openai::types::responses::{
6    CreateResponse, EasyInputContent, EasyInputMessage, FunctionCallOutput, FunctionCallOutputItemParam, FunctionTool,
7    FunctionToolCall, ImageDetail, IncludeEnum, InputContent, InputImageContent, InputItem, InputParam,
8    InputTextContent, Item, MessageType, OutputItem, Reasoning, ReasoningEffort as OaiReasoningEffort,
9    ReasoningSummary, ResponseStreamEvent, ResponseUsage, Role, Tool,
10};
11use tokio_stream::StreamExt;
12use tracing::{debug, error};
13
14use crate::provider::get_context_window;
15use crate::{
16    ChatMessage, ContentBlock, Context, LlmError, LlmModel, LlmResponse, LlmResponseStream, ProviderFactory,
17    ReasoningEffort, Result, StopReason, StreamingModelProvider, TokenUsage, ToolDefinition,
18};
19
20impl From<ResponseUsage> for TokenUsage {
21    fn from(usage: ResponseUsage) -> Self {
22        TokenUsage {
23            input_tokens: usage.input_tokens,
24            output_tokens: usage.output_tokens,
25            cache_read_tokens: Some(usage.input_tokens_details.cached_tokens),
26            reasoning_tokens: Some(usage.output_tokens_details.reasoning_tokens),
27            ..TokenUsage::default()
28        }
29    }
30}
31
32pub(crate) fn map_user_content_for_responses(parts: &[ContentBlock]) -> Result<EasyInputContent> {
33    let mut items = Vec::with_capacity(parts.len());
34    for p in parts {
35        match p {
36            ContentBlock::Text { text } => {
37                items.push(InputContent::InputText(InputTextContent { text: text.clone() }));
38            }
39            ContentBlock::Image { .. } => {
40                items.push(InputContent::InputImage(InputImageContent {
41                    detail: ImageDetail::Auto,
42                    file_id: None,
43                    image_url: Some(p.as_data_uri().unwrap()),
44                }));
45            }
46            ContentBlock::Audio { .. } => {
47                return Err(LlmError::UnsupportedContent("OpenAI Responses does not support audio input".into()));
48            }
49        }
50    }
51    Ok(EasyInputContent::ContentList(items))
52}
53
/// Streaming LLM provider backed by the OpenAI Responses API.
pub struct OpenAiProvider {
    // HTTP client configured with the API key (see `ProviderFactory::from_env`).
    client: Client<OpenAIConfig>,
    // Model identifier sent with each request, e.g. "gpt-4.1".
    model: String,
}
58
59impl ProviderFactory for OpenAiProvider {
60    async fn from_env() -> Result<Self> {
61        let api_key =
62            std::env::var("OPENAI_API_KEY").map_err(|_| LlmError::MissingApiKey("OPENAI_API_KEY".to_string()))?;
63
64        let config = OpenAIConfig::new().with_api_key(api_key);
65
66        Ok(Self { client: Client::with_config(config), model: "gpt-4.1".to_string() })
67    }
68
69    fn with_model(mut self, model: &str) -> Self {
70        if !model.is_empty() {
71            self.model = model.to_string();
72        }
73        self
74    }
75}
76
impl StreamingModelProvider for OpenAiProvider {
    /// Streams a model response for `context` via the OpenAI Responses API.
    ///
    /// The request is built eagerly; a build failure is surfaced as a
    /// one-item error stream. Each API event is translated by `process_event`,
    /// and a terminal `done` is emitted if the API never signalled a start.
    fn stream_response(&self, context: &Context) -> LlmResponseStream {
        // Clone what the returned stream needs so it does not borrow `self`.
        let client = self.client.clone();
        let model = self.model.clone();
        let request = match build_response_request(&model, context) {
            Ok(req) => req,
            // Surface the build error as the stream's only item.
            Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
        };

        Box::pin(async_stream::stream! {
            debug!("Starting OpenAI Responses API stream for model: {model}");

            let stream = match client.responses().create_stream(request).await {
                Ok(s) => s,
                Err(e) => {
                    error!("Failed to create OpenAI Responses stream: {e:?}");
                    yield Err(LlmError::ApiRequest(e.to_string()));
                    return;
                }
            };

            let mut stream = Box::pin(stream);
            // In-flight function calls: item id -> (call id, tool name).
            let mut fn_calls: HashMap<String, (String, String)> = HashMap::new();
            // Set once a ResponseCreated event has been seen.
            let mut started = false;

            while let Some(result) = stream.next().await {
                match result {
                    Ok(event) => {
                        for response in process_event(event, &mut fn_calls, &mut started) {
                            yield response;
                        }
                    }
                    Err(e) => {
                        // Transport/API error mid-stream: report it and stop.
                        yield Err(LlmError::ApiError(e.to_string()));
                        break;
                    }
                }
            }

            // If the stream ended without ever starting, still emit a
            // terminal `done` so consumers are not left waiting.
            if !started {
                yield Ok(LlmResponse::done());
            }
        })
    }

    /// Human-readable label, e.g. "OpenAI (gpt-4.1)".
    fn display_name(&self) -> String {
        format!("OpenAI ({})", self.model)
    }

    /// Context window for the configured model, if known to the registry.
    fn context_window(&self) -> Option<u32> {
        get_context_window("openai", &self.model)
    }

    /// Parses "openai:<model>" into an [`LlmModel`], if valid.
    fn model(&self) -> Option<LlmModel> {
        format!("openai:{}", self.model).parse().ok()
    }
}
134
135fn process_event(
136    event: ResponseStreamEvent,
137    fn_calls: &mut HashMap<String, (String, String)>,
138    started: &mut bool,
139) -> Vec<Result<LlmResponse>> {
140    match event {
141        ResponseStreamEvent::ResponseCreated(e) => {
142            *started = true;
143            vec![Ok(LlmResponse::start(&e.response.id))]
144        }
145        ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
146            vec![Ok(LlmResponse::text(&e.delta))]
147        }
148        ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
149            vec![Ok(LlmResponse::reasoning(&e.delta))]
150        }
151        ResponseStreamEvent::ResponseOutputItemAdded(e) => {
152            if let OutputItem::FunctionCall(fc) = e.item {
153                let item_id = fc.id.clone().unwrap_or_default();
154                fn_calls.insert(item_id, (fc.call_id.clone(), fc.name.clone()));
155                vec![Ok(LlmResponse::tool_request_start(&fc.call_id, &fc.name))]
156            } else {
157                vec![]
158            }
159        }
160        ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
161            if let Some((call_id, _)) = fn_calls.get(&e.item_id) {
162                vec![Ok(LlmResponse::tool_request_arg(call_id, &e.delta))]
163            } else {
164                vec![]
165            }
166        }
167        ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
168            if let Some((call_id, name)) = fn_calls.remove(&e.item_id) {
169                let name = e.name.unwrap_or(name);
170                vec![Ok(LlmResponse::tool_request_complete(&call_id, &name, &e.arguments))]
171            } else {
172                vec![]
173            }
174        }
175        ResponseStreamEvent::ResponseCompleted(e) => {
176            let mut results = Vec::new();
177            if let Some(usage) = e.response.usage {
178                results.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
179            }
180            results.push(Ok(LlmResponse::done_with_stop_reason(StopReason::EndTurn)));
181            results
182        }
183        ResponseStreamEvent::ResponseFailed(e) => {
184            let msg = e.response.error.map_or_else(|| "Unknown error".to_string(), |err| err.message);
185            vec![Err(LlmError::ApiError(msg))]
186        }
187        ResponseStreamEvent::ResponseIncomplete(_) => {
188            vec![Ok(LlmResponse::done_with_stop_reason(StopReason::Length))]
189        }
190        ResponseStreamEvent::ResponseError(e) => {
191            vec![Err(LlmError::ApiError(e.message))]
192        }
193        _ => vec![],
194    }
195}
196
/// Builds a streaming `CreateResponse` request from the chat context.
///
/// System messages become the single `instructions` string (last one wins),
/// the remaining history is flattened into Responses API input items, and
/// tool definitions / reasoning settings are mapped when present.
///
/// # Errors
///
/// Propagates content-mapping errors (e.g. unsupported audio input) and
/// tool-parameter parsing failures.
fn build_response_request(model: &str, context: &Context) -> Result<CreateResponse> {
    let mut instructions: Option<String> = None;
    let mut items: Vec<InputItem> = Vec::new();

    for msg in context.messages() {
        match msg {
            // The Responses API takes instructions out-of-band rather than
            // as a system role in the input list.
            ChatMessage::System { content, .. } => {
                instructions = Some(content.clone());
            }
            ChatMessage::User { content, .. } => {
                items.push(InputItem::EasyMessage(EasyInputMessage {
                    r#type: MessageType::Message,
                    role: Role::User,
                    content: map_user_content_for_responses(content)?,
                    phase: None,
                }));
            }
            ChatMessage::Assistant { content, tool_calls, .. } => {
                // Only emit an assistant message when there is text;
                // tool-call-only turns contribute no message item.
                if !content.is_empty() {
                    items.push(InputItem::EasyMessage(EasyInputMessage {
                        r#type: MessageType::Message,
                        role: Role::Assistant,
                        content: EasyInputContent::Text(content.clone()),
                        phase: None,
                    }));
                }
                for tc in tool_calls {
                    items.push(InputItem::Item(Item::FunctionCall(FunctionToolCall {
                        call_id: tc.id.clone(),
                        name: tc.name.clone(),
                        arguments: tc.arguments.clone(),
                        namespace: None,
                        id: None,
                        status: None,
                    })));
                }
            }
            ChatMessage::ToolCallResult(result) => {
                // Both success and failure are sent as the call's text
                // output; the error string doubles as the result.
                let (call_id, output) = match result {
                    Ok(r) => (r.id.clone(), r.result.clone()),
                    Err(e) => (e.id.clone(), e.error.clone()),
                };
                items.push(InputItem::Item(Item::FunctionCallOutput(FunctionCallOutputItemParam {
                    call_id,
                    output: FunctionCallOutput::Text(output),
                    id: None,
                    status: None,
                })));
            }
            // Conversation summaries are replayed as a tagged user message.
            ChatMessage::Summary { content, .. } => {
                items.push(InputItem::EasyMessage(EasyInputMessage {
                    r#type: MessageType::Message,
                    role: Role::User,
                    content: EasyInputContent::Text(format!("[Previous conversation handoff]\n\n{content}")),
                    phase: None,
                }));
            }
            // Error entries are local bookkeeping; not sent to the API.
            ChatMessage::Error { .. } => {}
        }
    }

    let tools = map_tools(context.tools())?;

    // Reasoning config is only attached when the context requests an effort;
    // summaries are always auto-enabled alongside it.
    let reasoning = context
        .reasoning_effort()
        .map(|effort| Reasoning { effort: Some(map_reasoning_effort(effort)), summary: Some(ReasoningSummary::Auto) });

    Ok(CreateResponse {
        model: Some(model.to_string()),
        input: InputParam::Items(items),
        instructions,
        tools: if tools.is_empty() { None } else { Some(tools) },
        reasoning,
        stream: Some(true),
        // Encrypted reasoning content is requested because responses are not
        // stored server-side (store: false below).
        include: Some(vec![IncludeEnum::ReasoningEncryptedContent]),
        store: Some(false),
        background: None,
        conversation: None,
        max_output_tokens: None,
        metadata: None,
        parallel_tool_calls: None,
        previous_response_id: None,
        prompt: None,
        service_tier: None,
        stream_options: None,
        temperature: None,
        text: None,
        tool_choice: None,
        top_p: None,
        truncation: None,
        prompt_cache_key: None,
        safety_identifier: None,
        max_tool_calls: None,
        prompt_cache_retention: None,
        top_logprobs: None,
    })
}
294
295fn map_tools(tools: &[ToolDefinition]) -> Result<Vec<Tool>> {
296    tools
297        .iter()
298        .map(|t| {
299            let parameters: serde_json::Value = serde_json::from_str(&t.parameters)
300                .map_err(|e| LlmError::ToolParameterParsing { tool_name: t.name.clone(), error: e.to_string() })?;
301
302            Ok(Tool::Function(FunctionTool {
303                name: t.name.clone(),
304                description: Some(t.description.clone()),
305                parameters: Some(parameters),
306                strict: Some(false),
307                defer_loading: None,
308            }))
309        })
310        .collect()
311}
312
313fn map_reasoning_effort(effort: ReasoningEffort) -> OaiReasoningEffort {
314    match effort {
315        ReasoningEffort::Low => OaiReasoningEffort::Low,
316        ReasoningEffort::Medium => OaiReasoningEffort::Medium,
317        ReasoningEffort::High => OaiReasoningEffort::High,
318        ReasoningEffort::Xhigh => OaiReasoningEffort::Xhigh,
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AssistantReasoning;
    use crate::ToolCallRequest;
    use crate::types::IsoString;

    // A lone user message maps to one "user" input item carrying the text,
    // with no instructions, tools, or reasoning attached.
    #[test]
    fn test_build_request_simple_user_message() {
        let context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("Hello")], timestamp: IsoString::now() }],
            vec![],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(req.model, Some("gpt-4.1".to_string()));
        assert!(req.instructions.is_none());
        assert!(req.tools.is_none());
        assert!(req.reasoning.is_none());

        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"][0]["role"], "user");
        assert_eq!(json["input"][0]["content"][0]["text"], "Hello");
    }

    // System messages become `instructions` and do NOT appear in the input
    // item list.
    #[test]
    fn test_build_request_with_system_message() {
        let context = Context::new(
            vec![
                ChatMessage::System { content: "You are helpful.".to_string(), timestamp: IsoString::now() },
                ChatMessage::User { content: vec![ContentBlock::text("Hi")], timestamp: IsoString::now() },
            ],
            vec![],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(req.instructions, Some("You are helpful.".to_string()));

        let json = serde_json::to_value(&req).unwrap();
        let items = json["input"].as_array().unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["role"], "user");
    }

    // A tool-call round trip serializes as user message, function_call item,
    // and function_call_output item sharing the same call_id; the tool
    // definition lands in `tools`.
    #[test]
    fn test_build_request_with_tool_calls() {
        let context = Context::new(
            vec![
                ChatMessage::User { content: vec![ContentBlock::text("Search for rust")], timestamp: IsoString::now() },
                ChatMessage::Assistant {
                    content: String::new(),
                    reasoning: AssistantReasoning::default(),
                    timestamp: IsoString::now(),
                    tool_calls: vec![ToolCallRequest {
                        id: "call_1".to_string(),
                        name: "search".to_string(),
                        arguments: r#"{"q":"rust"}"#.to_string(),
                    }],
                },
                ChatMessage::ToolCallResult(Ok(crate::ToolCallResult {
                    id: "call_1".to_string(),
                    name: "search".to_string(),
                    arguments: r#"{"q":"rust"}"#.to_string(),
                    result: "Found results".to_string(),
                })),
            ],
            vec![ToolDefinition {
                name: "search".to_string(),
                description: "Search".to_string(),
                parameters: r#"{"type":"object"}"#.to_string(),
                server: None,
            }],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        let json = serde_json::to_value(&req).unwrap();

        let items = json["input"].as_array().unwrap();
        assert_eq!(items[0]["role"], "user");
        // Note: the empty assistant `content` produces no message item.
        assert_eq!(items[1]["type"], "function_call");
        assert_eq!(items[1]["call_id"], "call_1");
        assert_eq!(items[2]["type"], "function_call_output");
        assert_eq!(items[2]["call_id"], "call_1");
        assert_eq!(items[2]["output"], "Found results");

        assert!(req.tools.is_some());
        let tools_json = serde_json::to_value(&req.tools).unwrap();
        assert_eq!(tools_json[0]["type"], "function");
        assert_eq!(tools_json[0]["name"], "search");
    }

    // Setting a reasoning effort attaches a Reasoning config with that
    // effort and an auto summary.
    #[test]
    fn test_build_request_with_reasoning_effort() {
        let mut context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("Think")], timestamp: IsoString::now() }],
            vec![],
        );
        context.set_reasoning_effort(Some(ReasoningEffort::High));

        let req = build_response_request("o3", &context).unwrap();
        let reasoning = req.reasoning.unwrap();
        assert_eq!(reasoning.effort, Some(OaiReasoningEffort::High));
        assert_eq!(reasoning.summary, Some(ReasoningSummary::Auto));
    }

    // Audio content is rejected with UnsupportedContent rather than being
    // silently dropped.
    #[test]
    fn test_build_request_with_audio_returns_unsupported_content() {
        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::Audio { data: "YXVkaW8=".to_string(), mime_type: "audio/wav".to_string() }],
                timestamp: IsoString::now(),
            }],
            vec![],
        );

        assert!(matches!(build_response_request("gpt-4.1", &context), Err(LlmError::UnsupportedContent(_))));
    }

    // Valid JSON-schema parameters produce a serialized function tool.
    #[test]
    fn test_map_tools_valid() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#.to_string(),
            server: None,
        }];

        let result = map_tools(&tools).unwrap();
        assert_eq!(result.len(), 1);

        let json = serde_json::to_value(&result[0]).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["name"], "read_file");
    }

    // Malformed parameter JSON yields ToolParameterParsing naming the tool.
    #[test]
    fn test_map_tools_invalid_json() {
        let tools = vec![ToolDefinition {
            name: "broken".to_string(),
            description: "Broken".to_string(),
            parameters: "not json{".to_string(),
            server: None,
        }];

        let result = map_tools(&tools);
        assert!(result.is_err());
        match result.unwrap_err() {
            LlmError::ToolParameterParsing { tool_name, .. } => {
                assert_eq!(tool_name, "broken");
            }
            other => panic!("Expected ToolParameterParsing, got: {other}"),
        }
    }

    // display_name interpolates the configured model into the label.
    #[test]
    fn test_provider_display_name() {
        let provider = OpenAiProvider { client: Client::new(), model: "gpt-4.1".to_string() };
        assert_eq!(provider.display_name(), "OpenAI (gpt-4.1)");
    }
}