// llm/providers/openai/responses_provider.rs

use std::collections::HashMap;

use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::responses::{
    CreateResponse, EasyInputContent, EasyInputMessage, FunctionCallOutput,
    FunctionCallOutputItemParam, FunctionTool, FunctionToolCall, ImageDetail, IncludeEnum,
    InputContent, InputImageContent, InputItem, InputParam, InputTextContent, Item, MessageType,
    OutputItem, Reasoning, ReasoningEffort as OaiReasoningEffort, ReasoningSummary,
    ResponseStreamEvent, Role, Tool,
};
use tokio_stream::StreamExt;
use tracing::{debug, error};

use crate::provider::get_context_window;
use crate::{
    ChatMessage, ContentBlock, Context, LlmError, LlmModel, LlmResponse, LlmResponseStream,
    ProviderFactory, ReasoningEffort, Result, StopReason, StreamingModelProvider, ToolDefinition,
};

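/// Maps provider-agnostic content blocks into Responses API input content.
/// Text becomes `InputText`, images become `InputImage` data URIs, and audio
/// is rejected because the Responses API does not accept audio input.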
pub(crate) fn map_user_content_for_responses(parts: &[ContentBlock]) -> Result<EasyInputContent> {
    let mut items = Vec::with_capacity(parts.len());
    for p in parts {
        match p {
            ContentBlock::Text { text } => {
                items.push(InputContent::InputText(InputTextContent {
                    text: text.clone(),
                }));
            }
            ContentBlock::Image { .. } => {
                // `as_data_uri` is assumed infallible for Image blocks, hence the unwrap.
                items.push(InputContent::InputImage(InputImageContent {
                    detail: ImageDetail::Auto,
                    file_id: None,
                    image_url: Some(p.as_data_uri().unwrap()),
                }));
            }
            ContentBlock::Audio { .. } => {
                return Err(LlmError::UnsupportedContent(
                    "OpenAI Responses does not support audio input".into(),
                ));
            }
        }
    }
    Ok(EasyInputContent::ContentList(items))
}

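/// Streaming chat provider backed by the OpenAI Responses API.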
pub struct OpenAiProvider {
    client: Client<OpenAIConfig>,
    model: String,
}

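// Construction sketch (assumes `OPENAI_API_KEY` is exported, as `from_env` requires):
//
//     let provider = OpenAiProvider::from_env()?.with_model("gpt-4.1");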
impl ProviderFactory for OpenAiProvider {
    fn from_env() -> Result<Self> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| LlmError::MissingApiKey("OPENAI_API_KEY".to_string()))?;

        let config = OpenAIConfig::new().with_api_key(api_key);

        Ok(Self {
            client: Client::with_config(config),
            model: "gpt-4.1".to_string(),
        })
    }

    fn with_model(mut self, model: &str) -> Self {
        if !model.is_empty() {
            self.model = model.to_string();
        }
        self
    }
}

impl StreamingModelProvider for OpenAiProvider {
    fn stream_response(&self, context: &Context) -> LlmResponseStream {
        let client = self.client.clone();
        let model = self.model.clone();
        // Surface request-building errors as a single-item error stream.
        let request = match build_response_request(&model, context) {
            Ok(req) => req,
            Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
        };

        Box::pin(async_stream::stream! {
            debug!("Starting OpenAI Responses API stream for model: {model}");

            let stream = match client.responses().create_stream(request).await {
                Ok(s) => s,
                Err(e) => {
                    error!("Failed to create OpenAI Responses stream: {e:?}");
                    yield Err(LlmError::ApiRequest(e.to_string()));
                    return;
                }
            };

            let mut stream = Box::pin(stream);
            // In-flight function calls, keyed by item_id -> (call_id, name).
            let mut fn_calls: HashMap<String, (String, String)> = HashMap::new();
            let mut started = false;

            while let Some(result) = stream.next().await {
                match result {
                    Ok(event) => {
                        for response in process_event(event, &mut fn_calls, &mut started) {
                            yield response;
                        }
                    }
                    Err(e) => {
                        yield Err(LlmError::ApiError(e.to_string()));
                        break;
                    }
                }
            }

            // If the stream ended before a ResponseCreated event, still emit a
            // terminal marker so consumers are not left waiting.
            if !started {
                yield Ok(LlmResponse::done());
            }
        })
    }

    fn display_name(&self) -> String {
        format!("OpenAI ({})", self.model)
    }

    fn context_window(&self) -> Option<u32> {
        get_context_window("openai", &self.model)
    }

    fn model(&self) -> Option<LlmModel> {
        format!("openai:{}", self.model).parse().ok()
    }
}

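/// Translates a single Responses API stream event into zero or more
/// `LlmResponse` items. `fn_calls` tracks in-flight function calls by
/// `item_id`, since argument deltas arrive keyed by item rather than call id.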
fn process_event(
    event: ResponseStreamEvent,
    fn_calls: &mut HashMap<String, (String, String)>,
    started: &mut bool,
) -> Vec<Result<LlmResponse>> {
    match event {
        ResponseStreamEvent::ResponseCreated(e) => {
            *started = true;
            vec![Ok(LlmResponse::start(&e.response.id))]
        }
        ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
            vec![Ok(LlmResponse::text(&e.delta))]
        }
        ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
            vec![Ok(LlmResponse::reasoning(&e.delta))]
        }
        ResponseStreamEvent::ResponseOutputItemAdded(e) => {
            if let OutputItem::FunctionCall(fc) = e.item {
                // Remember the call so later argument deltas (keyed by item_id)
                // can be routed to the right call_id.
                let item_id = fc.id.clone().unwrap_or_default();
                fn_calls.insert(item_id, (fc.call_id.clone(), fc.name.clone()));
                vec![Ok(LlmResponse::tool_request_start(&fc.call_id, &fc.name))]
            } else {
                vec![]
            }
        }
        ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
            if let Some((call_id, _)) = fn_calls.get(&e.item_id) {
                vec![Ok(LlmResponse::tool_request_arg(call_id, &e.delta))]
            } else {
                vec![]
            }
        }
        ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
            if let Some((call_id, name)) = fn_calls.remove(&e.item_id) {
                // Prefer the name from the done event when the API provides one.
                let name = e.name.unwrap_or(name);
                vec![Ok(LlmResponse::tool_request_complete(
                    &call_id,
                    &name,
                    &e.arguments,
                ))]
            } else {
                vec![]
            }
        }
        ResponseStreamEvent::ResponseCompleted(e) => {
            let mut results = Vec::new();
            if let Some(usage) = e.response.usage {
                // Only report cached tokens when there was an actual cache hit.
                let cached = usage.input_tokens_details.cached_tokens;
                let cached = if cached > 0 { Some(cached) } else { None };
                results.push(Ok(LlmResponse::usage_with_cache(
                    usage.input_tokens,
                    usage.output_tokens,
                    cached,
                )));
            }
            results.push(Ok(LlmResponse::done_with_stop_reason(StopReason::EndTurn)));
            results
        }
        ResponseStreamEvent::ResponseFailed(e) => {
            let msg = e
                .response
                .error
                .map_or_else(|| "Unknown error".to_string(), |err| err.message);
            vec![Err(LlmError::ApiError(msg))]
        }
        ResponseStreamEvent::ResponseIncomplete(_) => {
            vec![Ok(LlmResponse::done_with_stop_reason(StopReason::Length))]
        }
        ResponseStreamEvent::ResponseError(e) => {
            vec![Err(LlmError::ApiError(e.message))]
        }
        _ => vec![],
    }
}

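/// Flattens the chat context into a Responses API request: system messages
/// become `instructions`, user/assistant turns and tool traffic become input
/// items, and summaries are replayed as a user handoff message.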
fn build_response_request(model: &str, context: &Context) -> Result<CreateResponse> {
    let mut instructions: Option<String> = None;
    let mut items: Vec<InputItem> = Vec::new();

    for msg in context.messages() {
        match msg {
            ChatMessage::System { content, .. } => {
                // The Responses API takes system text via `instructions`,
                // not as an input message.
                instructions = Some(content.clone());
            }
            ChatMessage::User { content, .. } => {
                items.push(InputItem::EasyMessage(EasyInputMessage {
                    r#type: MessageType::Message,
                    role: Role::User,
                    content: map_user_content_for_responses(content)?,
                }));
            }
            ChatMessage::Assistant {
                content,
                tool_calls,
                ..
            } => {
                if !content.is_empty() {
                    items.push(InputItem::EasyMessage(EasyInputMessage {
                        r#type: MessageType::Message,
                        role: Role::Assistant,
                        content: EasyInputContent::Text(content.clone()),
                    }));
                }
                for tc in tool_calls {
                    items.push(InputItem::Item(Item::FunctionCall(FunctionToolCall {
                        call_id: tc.id.clone(),
                        name: tc.name.clone(),
                        arguments: tc.arguments.clone(),
                        id: None,
                        status: None,
                    })));
                }
            }
            ChatMessage::ToolCallResult(result) => {
                // Both success and failure are replayed as function call output.
                let (call_id, output) = match result {
                    Ok(r) => (r.id.clone(), r.result.clone()),
                    Err(e) => (e.id.clone(), e.error.clone()),
                };
                items.push(InputItem::Item(Item::FunctionCallOutput(
                    FunctionCallOutputItemParam {
                        call_id,
                        output: FunctionCallOutput::Text(output),
                        id: None,
                        status: None,
                    },
                )));
            }
            ChatMessage::Summary { content, .. } => {
                items.push(InputItem::EasyMessage(EasyInputMessage {
                    r#type: MessageType::Message,
                    role: Role::User,
                    content: EasyInputContent::Text(format!(
                        "[Previous conversation handoff]\n\n{content}"
                    )),
                }));
            }
            ChatMessage::Error { .. } => {}
        }
    }

    let tools = map_tools(context.tools())?;

    let reasoning = context.reasoning_effort().map(|effort| Reasoning {
        effort: Some(map_reasoning_effort(effort)),
        summary: Some(ReasoningSummary::Auto),
    });

    Ok(CreateResponse {
        model: Some(model.to_string()),
        input: InputParam::Items(items),
        instructions,
        tools: if tools.is_empty() { None } else { Some(tools) },
        reasoning,
        stream: Some(true),
        // Responses are not stored server-side, so request the encrypted
        // reasoning content alongside the output.
        include: Some(vec![IncludeEnum::ReasoningEncryptedContent]),
        store: Some(false),
        background: None,
        conversation: None,
        max_output_tokens: None,
        metadata: None,
        parallel_tool_calls: None,
        previous_response_id: None,
        prompt: None,
        service_tier: None,
        stream_options: None,
        temperature: None,
        text: None,
        tool_choice: None,
        top_p: None,
        truncation: None,
        prompt_cache_key: None,
        safety_identifier: None,
        max_tool_calls: None,
        prompt_cache_retention: None,
        top_logprobs: None,
    })
}

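/// Converts tool definitions into Responses API function tools, parsing each
/// tool's JSON-schema parameter string up front so malformed schemas fail fast.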
fn map_tools(tools: &[ToolDefinition]) -> Result<Vec<Tool>> {
    tools
        .iter()
        .map(|t| {
            let parameters: serde_json::Value =
                serde_json::from_str(&t.parameters).map_err(|e| {
                    LlmError::ToolParameterParsing {
                        tool_name: t.name.clone(),
                        error: e.to_string(),
                    }
                })?;

            Ok(Tool::Function(FunctionTool {
                name: t.name.clone(),
                description: Some(t.description.clone()),
                parameters: Some(parameters),
                strict: Some(false),
            }))
        })
        .collect()
}

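/// Maps the crate's reasoning effort levels onto the OpenAI equivalents.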
fn map_reasoning_effort(effort: ReasoningEffort) -> OaiReasoningEffort {
    match effort {
        ReasoningEffort::Low => OaiReasoningEffort::Low,
        ReasoningEffort::Medium => OaiReasoningEffort::Medium,
        ReasoningEffort::High => OaiReasoningEffort::High,
        ReasoningEffort::Xhigh => OaiReasoningEffort::Xhigh,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolCallRequest;
    use crate::types::IsoString;

    #[test]
    fn test_build_request_simple_user_message() {
        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::text("Hello")],
                timestamp: IsoString::now(),
            }],
            vec![],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(req.model, Some("gpt-4.1".to_string()));
        assert!(req.instructions.is_none());
        assert!(req.tools.is_none());
        assert!(req.reasoning.is_none());

        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"][0]["role"], "user");
        assert_eq!(json["input"][0]["content"][0]["text"], "Hello");
    }
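
    // Direct check of the content mapper itself; a minimal sketch matching
    // the `ContentList` shape the mapper constructs above.
    #[test]
    fn test_map_user_content_text_only() {
        let parts = vec![ContentBlock::text("Hello")];
        let content = map_user_content_for_responses(&parts).unwrap();
        match content {
            EasyInputContent::ContentList(items) => assert_eq!(items.len(), 1),
            _ => panic!("Expected EasyInputContent::ContentList"),
        }
    }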

    #[test]
    fn test_build_request_with_system_message() {
        let context = Context::new(
            vec![
                ChatMessage::System {
                    content: "You are helpful.".to_string(),
                    timestamp: IsoString::now(),
                },
                ChatMessage::User {
                    content: vec![ContentBlock::text("Hi")],
                    timestamp: IsoString::now(),
                },
            ],
            vec![],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(req.instructions, Some("You are helpful.".to_string()));

        let json = serde_json::to_value(&req).unwrap();
        let items = json["input"].as_array().unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["role"], "user");
    }

    #[test]
    fn test_build_request_with_tool_calls() {
        let context = Context::new(
            vec![
                ChatMessage::User {
                    content: vec![ContentBlock::text("Search for rust")],
                    timestamp: IsoString::now(),
                },
                ChatMessage::Assistant {
                    content: String::new(),
                    reasoning: Default::default(),
                    timestamp: IsoString::now(),
                    tool_calls: vec![ToolCallRequest {
                        id: "call_1".to_string(),
                        name: "search".to_string(),
                        arguments: r#"{"q":"rust"}"#.to_string(),
                    }],
                },
                ChatMessage::ToolCallResult(Ok(crate::ToolCallResult {
                    id: "call_1".to_string(),
                    name: "search".to_string(),
                    arguments: r#"{"q":"rust"}"#.to_string(),
                    result: "Found results".to_string(),
                })),
            ],
            vec![ToolDefinition {
                name: "search".to_string(),
                description: "Search".to_string(),
                parameters: r#"{"type":"object"}"#.to_string(),
                server: None,
            }],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        let json = serde_json::to_value(&req).unwrap();

        let items = json["input"].as_array().unwrap();
        assert_eq!(items[0]["role"], "user");
        assert_eq!(items[1]["type"], "function_call");
        assert_eq!(items[1]["call_id"], "call_1");
        assert_eq!(items[2]["type"], "function_call_output");
        assert_eq!(items[2]["call_id"], "call_1");
        assert_eq!(items[2]["output"], "Found results");

        assert!(req.tools.is_some());
        let tools_json = serde_json::to_value(&req.tools).unwrap();
        assert_eq!(tools_json[0]["type"], "function");
        assert_eq!(tools_json[0]["name"], "search");
    }

    #[test]
    fn test_build_request_with_reasoning_effort() {
        let mut context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::text("Think")],
                timestamp: IsoString::now(),
            }],
            vec![],
        );
        context.set_reasoning_effort(Some(ReasoningEffort::High));

        let req = build_response_request("o3", &context).unwrap();
        let reasoning = req.reasoning.unwrap();
        assert_eq!(reasoning.effort, Some(OaiReasoningEffort::High));
        assert_eq!(reasoning.summary, Some(ReasoningSummary::Auto));
    }
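
    // Exhaustive check of the effort mapping; relies on the same `PartialEq`
    // and `Debug` impls the reasoning-effort test above already exercises.
    #[test]
    fn test_map_reasoning_effort_variants() {
        assert_eq!(map_reasoning_effort(ReasoningEffort::Low), OaiReasoningEffort::Low);
        assert_eq!(map_reasoning_effort(ReasoningEffort::Medium), OaiReasoningEffort::Medium);
        assert_eq!(map_reasoning_effort(ReasoningEffort::High), OaiReasoningEffort::High);
        assert_eq!(map_reasoning_effort(ReasoningEffort::Xhigh), OaiReasoningEffort::Xhigh);
    }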

    #[test]
    fn test_build_request_with_audio_returns_unsupported_content() {
        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::Audio {
                    data: "YXVkaW8=".to_string(),
                    mime_type: "audio/wav".to_string(),
                }],
                timestamp: IsoString::now(),
            }],
            vec![],
        );

        assert!(matches!(
            build_response_request("gpt-4.1", &context),
            Err(LlmError::UnsupportedContent(_))
        ));
    }

    #[test]
    fn test_map_tools_valid() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#.to_string(),
            server: None,
        }];

        let result = map_tools(&tools).unwrap();
        assert_eq!(result.len(), 1);

        let json = serde_json::to_value(&result[0]).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["name"], "read_file");
    }

    #[test]
    fn test_map_tools_invalid_json() {
        let tools = vec![ToolDefinition {
            name: "broken".to_string(),
            description: "Broken".to_string(),
            parameters: "not json{".to_string(),
            server: None,
        }];

        let result = map_tools(&tools);
        assert!(result.is_err());
        match result.unwrap_err() {
            LlmError::ToolParameterParsing { tool_name, .. } => {
                assert_eq!(tool_name, "broken");
            }
            other => panic!("Expected ToolParameterParsing, got: {other}"),
        }
    }

    #[test]
    fn test_provider_display_name() {
        let provider = OpenAiProvider {
            client: Client::new(),
            model: "gpt-4.1".to_string(),
        };
        assert_eq!(provider.display_name(), "OpenAI (gpt-4.1)");
    }
}