1use std::collections::HashMap;
2
3use async_openai::Client;
4use async_openai::config::OpenAIConfig;
5use async_openai::types::responses::{
6 CreateResponse, EasyInputContent, EasyInputMessage, FunctionCallOutput, FunctionCallOutputItemParam, FunctionTool,
7 FunctionToolCall, ImageDetail, IncludeEnum, InputContent, InputImageContent, InputItem, InputParam,
8 InputTextContent, Item, MessageType, OutputItem, Reasoning, ReasoningEffort as OaiReasoningEffort,
9 ReasoningSummary, ResponseStreamEvent, ResponseUsage, Role, Tool,
10};
11use tokio_stream::StreamExt;
12use tracing::{debug, error};
13
14use crate::provider::get_context_window;
15use crate::providers::openai_compatible::AetherOpenAiConfig;
16use crate::{
17 ChatMessage, ContentBlock, Context, LlmError, LlmModel, LlmResponse, LlmResponseStream, ProviderAuthMode,
18 ProviderConnectionConfig, ProviderFactory, ReasoningEffort, Result, StopReason, StreamingModelProvider, TokenUsage,
19 ToolDefinition,
20};
21
22impl From<ResponseUsage> for TokenUsage {
23 fn from(usage: ResponseUsage) -> Self {
24 TokenUsage {
25 input_tokens: usage.input_tokens,
26 output_tokens: usage.output_tokens,
27 cache_read_tokens: Some(usage.input_tokens_details.cached_tokens),
28 reasoning_tokens: Some(usage.output_tokens_details.reasoning_tokens),
29 ..TokenUsage::default()
30 }
31 }
32}
33
34pub(crate) fn map_user_content_for_responses(parts: &[ContentBlock]) -> Result<EasyInputContent> {
35 let mut items = Vec::with_capacity(parts.len());
36 for p in parts {
37 match p {
38 ContentBlock::Text { text } => {
39 items.push(InputContent::InputText(InputTextContent { text: text.clone() }));
40 }
41 ContentBlock::Image { .. } => {
42 items.push(InputContent::InputImage(InputImageContent {
43 detail: ImageDetail::Auto,
44 file_id: None,
45 image_url: Some(p.as_data_uri().unwrap()),
46 }));
47 }
48 ContentBlock::Audio { .. } => {
49 return Err(LlmError::UnsupportedContent("OpenAI Responses does not support audio input".into()));
50 }
51 }
52 }
53 Ok(EasyInputContent::ContentList(items))
54}
55
56pub struct OpenAiProvider {
57 client: Client<AetherOpenAiConfig>,
58 model: String,
59}
60
61impl ProviderFactory for OpenAiProvider {
62 async fn from_env() -> Result<Self> {
63 Self::from_env_with_connection(ProviderConnectionConfig::default()).await
64 }
65
66 async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
67 let api_key = match connection.auth_mode {
68 ProviderAuthMode::Default => {
69 std::env::var("OPENAI_API_KEY").map_err(|_| LlmError::MissingApiKey("OPENAI_API_KEY".to_string()))?
70 }
71 ProviderAuthMode::None => String::new(),
72 };
73
74 let mut config = OpenAIConfig::new().with_api_key(api_key);
75 if let Some(base_url) = connection.base_url {
76 config = config.with_api_base(base_url);
77 }
78 let config = AetherOpenAiConfig::new(config, connection.auth_mode);
79
80 Ok(Self { client: Client::with_config(config), model: "gpt-4.1".to_string() })
81 }
82
83 fn with_model(mut self, model: &str) -> Self {
84 if !model.is_empty() {
85 self.model = model.to_string();
86 }
87 self
88 }
89}
90
91impl StreamingModelProvider for OpenAiProvider {
92 fn stream_response(&self, context: &Context) -> LlmResponseStream {
93 let client = self.client.clone();
94 let model = self.model.clone();
95 let request = match build_response_request(&model, context) {
96 Ok(req) => req,
97 Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
98 };
99
100 Box::pin(async_stream::stream! {
101 debug!("Starting OpenAI Responses API stream for model: {model}");
102
103 let stream = match client.responses().create_stream(request).await {
104 Ok(s) => s,
105 Err(e) => {
106 error!("Failed to create OpenAI Responses stream: {e:?}");
107 yield Err(LlmError::ApiRequest(e.to_string()));
108 return;
109 }
110 };
111
112 let mut stream = Box::pin(stream);
113 let mut fn_calls: HashMap<String, (String, String)> = HashMap::new();
114 let mut started = false;
115
116 while let Some(result) = stream.next().await {
117 match result {
118 Ok(event) => {
119 for response in process_event(event, &mut fn_calls, &mut started) {
120 yield response;
121 }
122 }
123 Err(e) => {
124 yield Err(LlmError::ApiError(e.to_string()));
125 break;
126 }
127 }
128 }
129
130 if !started {
131 yield Ok(LlmResponse::done());
132 }
133 })
134 }
135
136 fn display_name(&self) -> String {
137 format!("OpenAI ({})", self.model)
138 }
139
140 fn context_window(&self) -> Option<u32> {
141 get_context_window("openai", &self.model)
142 }
143
144 fn model(&self) -> Option<LlmModel> {
145 format!("openai:{}", self.model).parse().ok()
146 }
147}
148
149fn process_event(
150 event: ResponseStreamEvent,
151 fn_calls: &mut HashMap<String, (String, String)>,
152 started: &mut bool,
153) -> Vec<Result<LlmResponse>> {
154 match event {
155 ResponseStreamEvent::ResponseCreated(e) => {
156 *started = true;
157 vec![Ok(LlmResponse::start(&e.response.id))]
158 }
159 ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
160 vec![Ok(LlmResponse::text(&e.delta))]
161 }
162 ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
163 vec![Ok(LlmResponse::reasoning(&e.delta))]
164 }
165 ResponseStreamEvent::ResponseOutputItemAdded(e) => {
166 if let OutputItem::FunctionCall(fc) = e.item {
167 let item_id = fc.id.clone().unwrap_or_default();
168 fn_calls.insert(item_id, (fc.call_id.clone(), fc.name.clone()));
169 vec![Ok(LlmResponse::tool_request_start(&fc.call_id, &fc.name))]
170 } else {
171 vec![]
172 }
173 }
174 ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
175 if let Some((call_id, _)) = fn_calls.get(&e.item_id) {
176 vec![Ok(LlmResponse::tool_request_arg(call_id, &e.delta))]
177 } else {
178 vec![]
179 }
180 }
181 ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
182 if let Some((call_id, name)) = fn_calls.remove(&e.item_id) {
183 let name = e.name.unwrap_or(name);
184 vec![Ok(LlmResponse::tool_request_complete(&call_id, &name, &e.arguments))]
185 } else {
186 vec![]
187 }
188 }
189 ResponseStreamEvent::ResponseCompleted(e) => {
190 let mut results = Vec::new();
191 if let Some(usage) = e.response.usage {
192 results.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
193 }
194 results.push(Ok(LlmResponse::done_with_stop_reason(StopReason::EndTurn)));
195 results
196 }
197 ResponseStreamEvent::ResponseFailed(e) => {
198 let msg = e.response.error.map_or_else(|| "Unknown error".to_string(), |err| err.message);
199 vec![Err(LlmError::ApiError(msg))]
200 }
201 ResponseStreamEvent::ResponseIncomplete(_) => {
202 vec![Ok(LlmResponse::done_with_stop_reason(StopReason::Length))]
203 }
204 ResponseStreamEvent::ResponseError(e) => {
205 vec![Err(LlmError::ApiError(e.message))]
206 }
207 _ => vec![],
208 }
209}
210
211fn build_response_request(model: &str, context: &Context) -> Result<CreateResponse> {
212 let mut instructions: Option<String> = None;
213 let mut items: Vec<InputItem> = Vec::new();
214
215 for msg in context.messages() {
216 match msg {
217 ChatMessage::System { content, .. } => {
218 instructions = Some(content.clone());
219 }
220 ChatMessage::User { content, .. } => {
221 items.push(InputItem::EasyMessage(EasyInputMessage {
222 r#type: MessageType::Message,
223 role: Role::User,
224 content: map_user_content_for_responses(content)?,
225 phase: None,
226 }));
227 }
228 ChatMessage::Assistant { content, tool_calls, .. } => {
229 if !content.is_empty() {
230 items.push(InputItem::EasyMessage(EasyInputMessage {
231 r#type: MessageType::Message,
232 role: Role::Assistant,
233 content: EasyInputContent::Text(content.clone()),
234 phase: None,
235 }));
236 }
237 for tc in tool_calls {
238 items.push(InputItem::Item(Item::FunctionCall(FunctionToolCall {
239 call_id: tc.id.clone(),
240 name: tc.name.clone(),
241 arguments: tc.arguments.clone(),
242 namespace: None,
243 id: None,
244 status: None,
245 })));
246 }
247 }
248 ChatMessage::ToolCallResult(result) => {
249 let (call_id, output) = match result {
250 Ok(r) => (r.id.clone(), r.result.clone()),
251 Err(e) => (e.id.clone(), e.error.clone()),
252 };
253 items.push(InputItem::Item(Item::FunctionCallOutput(FunctionCallOutputItemParam {
254 call_id,
255 output: FunctionCallOutput::Text(output),
256 id: None,
257 status: None,
258 })));
259 }
260 ChatMessage::Summary { content, .. } => {
261 items.push(InputItem::EasyMessage(EasyInputMessage {
262 r#type: MessageType::Message,
263 role: Role::User,
264 content: EasyInputContent::Text(format!("[Previous conversation handoff]\n\n{content}")),
265 phase: None,
266 }));
267 }
268 ChatMessage::Error { .. } => {}
269 }
270 }
271
272 let tools = map_tools(context.tools())?;
273
274 let reasoning = context
275 .reasoning_effort()
276 .map(|effort| Reasoning { effort: Some(map_reasoning_effort(effort)), summary: Some(ReasoningSummary::Auto) });
277
278 Ok(CreateResponse {
279 model: Some(model.to_string()),
280 input: InputParam::Items(items),
281 instructions,
282 tools: if tools.is_empty() { None } else { Some(tools) },
283 reasoning,
284 stream: Some(true),
285 include: Some(vec![IncludeEnum::ReasoningEncryptedContent]),
286 store: Some(false),
287 background: None,
288 conversation: None,
289 max_output_tokens: None,
290 metadata: None,
291 parallel_tool_calls: None,
292 previous_response_id: None,
293 prompt: None,
294 service_tier: None,
295 stream_options: None,
296 temperature: None,
297 text: None,
298 tool_choice: None,
299 top_p: None,
300 truncation: None,
301 prompt_cache_key: None,
302 safety_identifier: None,
303 max_tool_calls: None,
304 prompt_cache_retention: None,
305 top_logprobs: None,
306 })
307}
308
309fn map_tools(tools: &[ToolDefinition]) -> Result<Vec<Tool>> {
310 tools
311 .iter()
312 .map(|t| {
313 let parameters: serde_json::Value = serde_json::from_str(&t.parameters)
314 .map_err(|e| LlmError::ToolParameterParsing { tool_name: t.name.clone(), error: e.to_string() })?;
315
316 Ok(Tool::Function(FunctionTool {
317 name: t.name.clone(),
318 description: Some(t.description.clone()),
319 parameters: Some(parameters),
320 strict: Some(false),
321 defer_loading: None,
322 }))
323 })
324 .collect()
325}
326
327fn map_reasoning_effort(effort: ReasoningEffort) -> OaiReasoningEffort {
328 match effort {
329 ReasoningEffort::Low => OaiReasoningEffort::Low,
330 ReasoningEffort::Medium => OaiReasoningEffort::Medium,
331 ReasoningEffort::High => OaiReasoningEffort::High,
332 ReasoningEffort::Xhigh => OaiReasoningEffort::Xhigh,
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AssistantReasoning;
    use crate::ToolCallRequest;
    use crate::types::IsoString;

    /// Builds a user message holding a single text block.
    fn user_text(text: &str) -> ChatMessage {
        ChatMessage::User { content: vec![ContentBlock::text(text)], timestamp: IsoString::now() }
    }

    /// Builds a tool definition with no associated server.
    fn tool_definition(name: &str, description: &str, parameters: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: description.to_string(),
            parameters: parameters.to_string(),
            server: None,
        }
    }

    // A lone user message yields a request with no instructions, tools, or reasoning.
    #[test]
    fn test_build_request_simple_user_message() {
        let context = Context::new(vec![user_text("Hello")], vec![]);

        let request = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(request.model, Some("gpt-4.1".to_string()));
        assert!(request.instructions.is_none());
        assert!(request.tools.is_none());
        assert!(request.reasoning.is_none());

        let serialized = serde_json::to_value(&request).unwrap();
        let first = &serialized["input"][0];
        assert_eq!(first["role"], "user");
        assert_eq!(first["content"][0]["text"], "Hello");
    }

    // The system message moves into `instructions` and is not sent as an input item.
    #[test]
    fn test_build_request_with_system_message() {
        let system =
            ChatMessage::System { content: "You are helpful.".to_string(), timestamp: IsoString::now() };
        let context = Context::new(vec![system, user_text("Hi")], vec![]);

        let request = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(request.instructions, Some("You are helpful.".to_string()));

        let serialized = serde_json::to_value(&request).unwrap();
        let input = serialized["input"].as_array().unwrap();
        assert_eq!(input.len(), 1);
        assert_eq!(input[0]["role"], "user");
    }

    // Tool calls and their results serialise as function_call / function_call_output items.
    #[test]
    fn test_build_request_with_tool_calls() {
        let arguments = r#"{"q":"rust"}"#;

        let assistant = ChatMessage::Assistant {
            content: String::new(),
            reasoning: AssistantReasoning::default(),
            timestamp: IsoString::now(),
            tool_calls: vec![ToolCallRequest {
                id: "call_1".to_string(),
                name: "search".to_string(),
                arguments: arguments.to_string(),
            }],
        };
        let tool_result = ChatMessage::ToolCallResult(Ok(crate::ToolCallResult {
            id: "call_1".to_string(),
            name: "search".to_string(),
            arguments: arguments.to_string(),
            result: "Found results".to_string(),
        }));

        let context = Context::new(
            vec![user_text("Search for rust"), assistant, tool_result],
            vec![tool_definition("search", "Search", r#"{"type":"object"}"#)],
        );

        let request = build_response_request("gpt-4.1", &context).unwrap();
        let serialized = serde_json::to_value(&request).unwrap();

        let input = serialized["input"].as_array().unwrap();
        assert_eq!(input[0]["role"], "user");
        assert_eq!(input[1]["type"], "function_call");
        assert_eq!(input[1]["call_id"], "call_1");
        assert_eq!(input[2]["type"], "function_call_output");
        assert_eq!(input[2]["call_id"], "call_1");
        assert_eq!(input[2]["output"], "Found results");

        assert!(request.tools.is_some());
        let tools = serde_json::to_value(&request.tools).unwrap();
        assert_eq!(tools[0]["type"], "function");
        assert_eq!(tools[0]["name"], "search");
    }

    // A configured reasoning effort is forwarded together with an automatic summary.
    #[test]
    fn test_build_request_with_reasoning_effort() {
        let mut context = Context::new(vec![user_text("Think")], vec![]);
        context.set_reasoning_effort(Some(ReasoningEffort::High));

        let request = build_response_request("o3", &context).unwrap();
        let reasoning = request.reasoning.unwrap();
        assert_eq!(reasoning.effort, Some(OaiReasoningEffort::High));
        assert_eq!(reasoning.summary, Some(ReasoningSummary::Auto));
    }

    // Audio input is rejected rather than silently dropped.
    #[test]
    fn test_build_request_with_audio_returns_unsupported_content() {
        let audio = ContentBlock::Audio { data: "YXVkaW8=".to_string(), mime_type: "audio/wav".to_string() };
        let context =
            Context::new(vec![ChatMessage::User { content: vec![audio], timestamp: IsoString::now() }], vec![]);

        let outcome = build_response_request("gpt-4.1", &context);
        assert!(matches!(outcome, Err(LlmError::UnsupportedContent(_))));
    }

    // A well-formed definition maps to a single function tool.
    #[test]
    fn test_map_tools_valid() {
        let definitions = vec![tool_definition(
            "read_file",
            "Read a file",
            r#"{"type":"object","properties":{"path":{"type":"string"}}}"#,
        )];

        let mapped = map_tools(&definitions).unwrap();
        assert_eq!(mapped.len(), 1);

        let serialized = serde_json::to_value(&mapped[0]).unwrap();
        assert_eq!(serialized["type"], "function");
        assert_eq!(serialized["name"], "read_file");
    }

    // Malformed parameter JSON is reported with the offending tool's name.
    #[test]
    fn test_map_tools_invalid_json() {
        let definitions = vec![tool_definition("broken", "Broken", "not json{")];

        let outcome = map_tools(&definitions);
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            LlmError::ToolParameterParsing { tool_name, .. } => {
                assert_eq!(tool_name, "broken");
            }
            other => panic!("Expected ToolParameterParsing, got: {other}"),
        }
    }

    // The display name embeds the model identifier.
    #[test]
    fn test_provider_display_name() {
        let openai_config = OpenAIConfig::new().with_api_key("test");
        let config = AetherOpenAiConfig::new(openai_config, ProviderAuthMode::Default);
        let provider = OpenAiProvider { client: Client::with_config(config), model: "gpt-4.1".to_string() };

        assert_eq!(provider.display_name(), "OpenAI (gpt-4.1)");
    }
}