Skip to main content

llm/providers/openai/
responses_provider.rs

1use std::collections::HashMap;
2
3use async_openai::Client;
4use async_openai::config::OpenAIConfig;
5use async_openai::types::responses::{
6    CreateResponse, EasyInputContent, EasyInputMessage, FunctionCallOutput, FunctionCallOutputItemParam, FunctionTool,
7    FunctionToolCall, ImageDetail, IncludeEnum, InputContent, InputImageContent, InputItem, InputParam,
8    InputTextContent, Item, MessageType, OutputItem, Reasoning, ReasoningEffort as OaiReasoningEffort,
9    ReasoningSummary, ResponseStreamEvent, ResponseUsage, Role, Tool,
10};
11use tokio_stream::StreamExt;
12use tracing::{debug, error};
13
14use crate::provider::get_context_window;
15use crate::providers::openai_compatible::AetherOpenAiConfig;
16use crate::{
17    ChatMessage, ContentBlock, Context, LlmError, LlmModel, LlmResponse, LlmResponseStream, ProviderAuthMode,
18    ProviderConnectionConfig, ProviderFactory, ReasoningEffort, Result, StopReason, StreamingModelProvider, TokenUsage,
19    ToolDefinition,
20};
21
22impl From<ResponseUsage> for TokenUsage {
23    fn from(usage: ResponseUsage) -> Self {
24        TokenUsage {
25            input_tokens: usage.input_tokens,
26            output_tokens: usage.output_tokens,
27            cache_read_tokens: Some(usage.input_tokens_details.cached_tokens),
28            reasoning_tokens: Some(usage.output_tokens_details.reasoning_tokens),
29            ..TokenUsage::default()
30        }
31    }
32}
33
34pub(crate) fn map_user_content_for_responses(parts: &[ContentBlock]) -> Result<EasyInputContent> {
35    let mut items = Vec::with_capacity(parts.len());
36    for p in parts {
37        match p {
38            ContentBlock::Text { text } => {
39                items.push(InputContent::InputText(InputTextContent { text: text.clone() }));
40            }
41            ContentBlock::Image { .. } => {
42                items.push(InputContent::InputImage(InputImageContent {
43                    detail: ImageDetail::Auto,
44                    file_id: None,
45                    image_url: Some(p.as_data_uri().unwrap()),
46                }));
47            }
48            ContentBlock::Audio { .. } => {
49                return Err(LlmError::UnsupportedContent("OpenAI Responses does not support audio input".into()));
50            }
51        }
52    }
53    Ok(EasyInputContent::ContentList(items))
54}
55
56pub struct OpenAiProvider {
57    client: Client<AetherOpenAiConfig>,
58    model: String,
59}
60
61impl ProviderFactory for OpenAiProvider {
62    async fn from_env() -> Result<Self> {
63        Self::from_env_with_connection(ProviderConnectionConfig::default()).await
64    }
65
66    async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
67        let api_key = match connection.auth_mode {
68            ProviderAuthMode::Default => {
69                std::env::var("OPENAI_API_KEY").map_err(|_| LlmError::MissingApiKey("OPENAI_API_KEY".to_string()))?
70            }
71            ProviderAuthMode::None => String::new(),
72        };
73
74        let mut config = OpenAIConfig::new().with_api_key(api_key);
75        if let Some(base_url) = connection.base_url {
76            config = config.with_api_base(base_url);
77        }
78        let config = AetherOpenAiConfig::new(config, connection.auth_mode);
79
80        Ok(Self { client: Client::with_config(config), model: "gpt-4.1".to_string() })
81    }
82
83    fn with_model(mut self, model: &str) -> Self {
84        if !model.is_empty() {
85            self.model = model.to_string();
86        }
87        self
88    }
89}
90
91impl StreamingModelProvider for OpenAiProvider {
92    fn stream_response(&self, context: &Context) -> LlmResponseStream {
93        let client = self.client.clone();
94        let model = self.model.clone();
95        let request = match build_response_request(&model, context) {
96            Ok(req) => req,
97            Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
98        };
99
100        Box::pin(async_stream::stream! {
101            debug!("Starting OpenAI Responses API stream for model: {model}");
102
103            let stream = match client.responses().create_stream(request).await {
104                Ok(s) => s,
105                Err(e) => {
106                    error!("Failed to create OpenAI Responses stream: {e:?}");
107                    yield Err(LlmError::ApiRequest(e.to_string()));
108                    return;
109                }
110            };
111
112            let mut stream = Box::pin(stream);
113            let mut fn_calls: HashMap<String, (String, String)> = HashMap::new();
114            let mut started = false;
115
116            while let Some(result) = stream.next().await {
117                match result {
118                    Ok(event) => {
119                        for response in process_event(event, &mut fn_calls, &mut started) {
120                            yield response;
121                        }
122                    }
123                    Err(e) => {
124                        yield Err(LlmError::ApiError(e.to_string()));
125                        break;
126                    }
127                }
128            }
129
130            if !started {
131                yield Ok(LlmResponse::done());
132            }
133        })
134    }
135
136    fn display_name(&self) -> String {
137        format!("OpenAI ({})", self.model)
138    }
139
140    fn context_window(&self) -> Option<u32> {
141        get_context_window("openai", &self.model)
142    }
143
144    fn model(&self) -> Option<LlmModel> {
145        format!("openai:{}", self.model).parse().ok()
146    }
147}
148
149fn process_event(
150    event: ResponseStreamEvent,
151    fn_calls: &mut HashMap<String, (String, String)>,
152    started: &mut bool,
153) -> Vec<Result<LlmResponse>> {
154    match event {
155        ResponseStreamEvent::ResponseCreated(e) => {
156            *started = true;
157            vec![Ok(LlmResponse::start(&e.response.id))]
158        }
159        ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
160            vec![Ok(LlmResponse::text(&e.delta))]
161        }
162        ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
163            vec![Ok(LlmResponse::reasoning(&e.delta))]
164        }
165        ResponseStreamEvent::ResponseOutputItemAdded(e) => {
166            if let OutputItem::FunctionCall(fc) = e.item {
167                let item_id = fc.id.clone().unwrap_or_default();
168                fn_calls.insert(item_id, (fc.call_id.clone(), fc.name.clone()));
169                vec![Ok(LlmResponse::tool_request_start(&fc.call_id, &fc.name))]
170            } else {
171                vec![]
172            }
173        }
174        ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
175            if let Some((call_id, _)) = fn_calls.get(&e.item_id) {
176                vec![Ok(LlmResponse::tool_request_arg(call_id, &e.delta))]
177            } else {
178                vec![]
179            }
180        }
181        ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
182            if let Some((call_id, name)) = fn_calls.remove(&e.item_id) {
183                let name = e.name.unwrap_or(name);
184                vec![Ok(LlmResponse::tool_request_complete(&call_id, &name, &e.arguments))]
185            } else {
186                vec![]
187            }
188        }
189        ResponseStreamEvent::ResponseCompleted(e) => {
190            let mut results = Vec::new();
191            if let Some(usage) = e.response.usage {
192                results.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
193            }
194            results.push(Ok(LlmResponse::done_with_stop_reason(StopReason::EndTurn)));
195            results
196        }
197        ResponseStreamEvent::ResponseFailed(e) => {
198            let msg = e.response.error.map_or_else(|| "Unknown error".to_string(), |err| err.message);
199            vec![Err(LlmError::ApiError(msg))]
200        }
201        ResponseStreamEvent::ResponseIncomplete(_) => {
202            vec![Ok(LlmResponse::done_with_stop_reason(StopReason::Length))]
203        }
204        ResponseStreamEvent::ResponseError(e) => {
205            vec![Err(LlmError::ApiError(e.message))]
206        }
207        _ => vec![],
208    }
209}
210
211fn build_response_request(model: &str, context: &Context) -> Result<CreateResponse> {
212    let mut instructions: Option<String> = None;
213    let mut items: Vec<InputItem> = Vec::new();
214
215    for msg in context.messages() {
216        match msg {
217            ChatMessage::System { content, .. } => {
218                instructions = Some(content.clone());
219            }
220            ChatMessage::User { content, .. } => {
221                items.push(InputItem::EasyMessage(EasyInputMessage {
222                    r#type: MessageType::Message,
223                    role: Role::User,
224                    content: map_user_content_for_responses(content)?,
225                    phase: None,
226                }));
227            }
228            ChatMessage::Assistant { content, tool_calls, .. } => {
229                if !content.is_empty() {
230                    items.push(InputItem::EasyMessage(EasyInputMessage {
231                        r#type: MessageType::Message,
232                        role: Role::Assistant,
233                        content: EasyInputContent::Text(content.clone()),
234                        phase: None,
235                    }));
236                }
237                for tc in tool_calls {
238                    items.push(InputItem::Item(Item::FunctionCall(FunctionToolCall {
239                        call_id: tc.id.clone(),
240                        name: tc.name.clone(),
241                        arguments: tc.arguments.clone(),
242                        namespace: None,
243                        id: None,
244                        status: None,
245                    })));
246                }
247            }
248            ChatMessage::ToolCallResult(result) => {
249                let (call_id, output) = match result {
250                    Ok(r) => (r.id.clone(), r.result.clone()),
251                    Err(e) => (e.id.clone(), e.error.clone()),
252                };
253                items.push(InputItem::Item(Item::FunctionCallOutput(FunctionCallOutputItemParam {
254                    call_id,
255                    output: FunctionCallOutput::Text(output),
256                    id: None,
257                    status: None,
258                })));
259            }
260            ChatMessage::Summary { content, .. } => {
261                items.push(InputItem::EasyMessage(EasyInputMessage {
262                    r#type: MessageType::Message,
263                    role: Role::User,
264                    content: EasyInputContent::Text(format!("[Previous conversation handoff]\n\n{content}")),
265                    phase: None,
266                }));
267            }
268            ChatMessage::Error { .. } => {}
269        }
270    }
271
272    let tools = map_tools(context.tools())?;
273
274    let reasoning = context
275        .reasoning_effort()
276        .map(|effort| Reasoning { effort: Some(map_reasoning_effort(effort)), summary: Some(ReasoningSummary::Auto) });
277
278    Ok(CreateResponse {
279        model: Some(model.to_string()),
280        input: InputParam::Items(items),
281        instructions,
282        tools: if tools.is_empty() { None } else { Some(tools) },
283        reasoning,
284        stream: Some(true),
285        include: Some(vec![IncludeEnum::ReasoningEncryptedContent]),
286        store: Some(false),
287        background: None,
288        conversation: None,
289        max_output_tokens: None,
290        metadata: None,
291        parallel_tool_calls: None,
292        previous_response_id: None,
293        prompt: None,
294        service_tier: None,
295        stream_options: None,
296        temperature: None,
297        text: None,
298        tool_choice: None,
299        top_p: None,
300        truncation: None,
301        prompt_cache_key: None,
302        safety_identifier: None,
303        max_tool_calls: None,
304        prompt_cache_retention: None,
305        top_logprobs: None,
306    })
307}
308
309fn map_tools(tools: &[ToolDefinition]) -> Result<Vec<Tool>> {
310    tools
311        .iter()
312        .map(|t| {
313            let parameters: serde_json::Value = serde_json::from_str(&t.parameters)
314                .map_err(|e| LlmError::ToolParameterParsing { tool_name: t.name.clone(), error: e.to_string() })?;
315
316            Ok(Tool::Function(FunctionTool {
317                name: t.name.clone(),
318                description: Some(t.description.clone()),
319                parameters: Some(parameters),
320                strict: Some(false),
321                defer_loading: None,
322            }))
323        })
324        .collect()
325}
326
327fn map_reasoning_effort(effort: ReasoningEffort) -> OaiReasoningEffort {
328    match effort {
329        ReasoningEffort::Low => OaiReasoningEffort::Low,
330        ReasoningEffort::Medium => OaiReasoningEffort::Medium,
331        ReasoningEffort::High => OaiReasoningEffort::High,
332        ReasoningEffort::Xhigh => OaiReasoningEffort::Xhigh,
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AssistantReasoning;
    use crate::ToolCallRequest;
    use crate::types::IsoString;

    /// Builds a user message holding a single text block.
    fn user_text(text: &str) -> ChatMessage {
        ChatMessage::User { content: vec![ContentBlock::text(text)], timestamp: IsoString::now() }
    }

    /// Builds the request for `model` and also returns its serialized JSON form.
    fn request_and_json(model: &str, context: &Context) -> (CreateResponse, serde_json::Value) {
        let request = build_response_request(model, context).unwrap();
        let json = serde_json::to_value(&request).unwrap();
        (request, json)
    }

    #[test]
    fn test_build_request_simple_user_message() {
        let context = Context::new(vec![user_text("Hello")], vec![]);
        let (req, json) = request_and_json("gpt-4.1", &context);

        assert_eq!(req.model.as_deref(), Some("gpt-4.1"));
        assert!(req.instructions.is_none());
        assert!(req.tools.is_none());
        assert!(req.reasoning.is_none());

        let first_item = &json["input"][0];
        assert_eq!(first_item["role"], "user");
        assert_eq!(first_item["content"][0]["text"], "Hello");
    }

    #[test]
    fn test_build_request_with_system_message() {
        let system = ChatMessage::System { content: "You are helpful.".to_string(), timestamp: IsoString::now() };
        let context = Context::new(vec![system, user_text("Hi")], vec![]);
        let (req, json) = request_and_json("gpt-4.1", &context);

        assert_eq!(req.instructions.as_deref(), Some("You are helpful."));

        let items = json["input"].as_array().unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["role"], "user");
    }

    #[test]
    fn test_build_request_with_tool_calls() {
        let args = r#"{"q":"rust"}"#;
        let assistant = ChatMessage::Assistant {
            content: String::new(),
            reasoning: AssistantReasoning::default(),
            timestamp: IsoString::now(),
            tool_calls: vec![ToolCallRequest {
                id: "call_1".to_string(),
                name: "search".to_string(),
                arguments: args.to_string(),
            }],
        };
        let tool_result = ChatMessage::ToolCallResult(Ok(crate::ToolCallResult {
            id: "call_1".to_string(),
            name: "search".to_string(),
            arguments: args.to_string(),
            result: "Found results".to_string(),
        }));
        let search_tool = ToolDefinition {
            name: "search".to_string(),
            description: "Search".to_string(),
            parameters: r#"{"type":"object"}"#.to_string(),
            server: None,
        };
        let context = Context::new(vec![user_text("Search for rust"), assistant, tool_result], vec![search_tool]);
        let (req, json) = request_and_json("gpt-4.1", &context);

        let items = json["input"].as_array().unwrap();
        let (user_item, call_item, output_item) = (&items[0], &items[1], &items[2]);
        assert_eq!(user_item["role"], "user");
        assert_eq!(call_item["type"], "function_call");
        assert_eq!(call_item["call_id"], "call_1");
        assert_eq!(output_item["type"], "function_call_output");
        assert_eq!(output_item["call_id"], "call_1");
        assert_eq!(output_item["output"], "Found results");

        assert!(req.tools.is_some());
        let tools_json = serde_json::to_value(&req.tools).unwrap();
        assert_eq!(tools_json[0]["type"], "function");
        assert_eq!(tools_json[0]["name"], "search");
    }

    #[test]
    fn test_build_request_with_reasoning_effort() {
        let mut context = Context::new(vec![user_text("Think")], vec![]);
        context.set_reasoning_effort(Some(ReasoningEffort::High));

        let reasoning = build_response_request("o3", &context).unwrap().reasoning.unwrap();
        assert_eq!(reasoning.effort, Some(OaiReasoningEffort::High));
        assert_eq!(reasoning.summary, Some(ReasoningSummary::Auto));
    }

    #[test]
    fn test_build_request_with_audio_returns_unsupported_content() {
        let audio = ContentBlock::Audio { data: "YXVkaW8=".to_string(), mime_type: "audio/wav".to_string() };
        let message = ChatMessage::User { content: vec![audio], timestamp: IsoString::now() };
        let context = Context::new(vec![message], vec![]);

        let outcome = build_response_request("gpt-4.1", &context);
        assert!(matches!(outcome, Err(LlmError::UnsupportedContent(_))));
    }

    #[test]
    fn test_map_tools_valid() {
        let tools = [ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#.to_string(),
            server: None,
        }];

        let mapped = map_tools(&tools).unwrap();
        assert_eq!(mapped.len(), 1);

        let tool_json = serde_json::to_value(&mapped[0]).unwrap();
        assert_eq!(tool_json["type"], "function");
        assert_eq!(tool_json["name"], "read_file");
    }

    #[test]
    fn test_map_tools_invalid_json() {
        let tools = [ToolDefinition {
            name: "broken".to_string(),
            description: "Broken".to_string(),
            parameters: "not json{".to_string(),
            server: None,
        }];

        let outcome = map_tools(&tools);
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            LlmError::ToolParameterParsing { tool_name, .. } => assert_eq!(tool_name, "broken"),
            other => panic!("Expected ToolParameterParsing, got: {other}"),
        }
    }

    #[test]
    fn test_provider_display_name() {
        let base = OpenAIConfig::new().with_api_key("test");
        let config = AetherOpenAiConfig::new(base, ProviderAuthMode::Default);
        let provider = OpenAiProvider { client: Client::with_config(config), model: "gpt-4.1".to_string() };
        assert_eq!(provider.display_name(), "OpenAI (gpt-4.1)");
    }
}