use std::collections::HashMap;

use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::responses::{
    CreateResponse, EasyInputContent, EasyInputMessage, FunctionCallOutput,
    FunctionCallOutputItemParam, FunctionTool, FunctionToolCall, ImageDetail, IncludeEnum,
    InputContent, InputImageContent, InputItem, InputParam, InputTextContent, Item, MessageType,
    OutputItem, Reasoning, ReasoningEffort as OaiReasoningEffort, ReasoningSummary,
    ResponseStreamEvent, Role, Tool,
};
use tokio_stream::StreamExt;
use tracing::{debug, error};

use crate::provider::get_context_window;
use crate::{
    ChatMessage, ContentBlock, Context, LlmError, LlmModel, LlmResponse, LlmResponseStream,
    ProviderFactory, ReasoningEffort, Result, StopReason, StreamingModelProvider, ToolDefinition,
};

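/// Maps internal content blocks onto Responses API input content: text and
/// image blocks convert directly, while audio is rejected as unsupported.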
pub(crate) fn map_user_content_for_responses(parts: &[ContentBlock]) -> Result<EasyInputContent> {
    let mut items = Vec::with_capacity(parts.len());
    for p in parts {
        match p {
            ContentBlock::Text { text } => {
                items.push(InputContent::InputText(InputTextContent {
                    text: text.clone(),
                }));
            }
            ContentBlock::Image { .. } => {
                items.push(InputContent::InputImage(InputImageContent {
                    detail: ImageDetail::Auto,
                    file_id: None,
                    // Assumed infallible here: we matched the Image variant, so
                    // `as_data_uri` should always yield a data URI for this block.
                    image_url: Some(p.as_data_uri().unwrap()),
                }));
            }
            ContentBlock::Audio { .. } => {
                return Err(LlmError::UnsupportedContent(
                    "OpenAI Responses does not support audio input".into(),
                ));
            }
        }
    }
    Ok(EasyInputContent::ContentList(items))
}

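/// Streaming LLM provider backed by OpenAI's Responses API.
///
/// A minimal usage sketch (illustrative; assumes `OPENAI_API_KEY` is set and a
/// `Context` has already been built):
///
/// ```ignore
/// let provider = OpenAiProvider::from_env()?.with_model("gpt-4.1");
/// let mut stream = provider.stream_response(&context);
/// ```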
pub struct OpenAiProvider {
    client: Client<OpenAIConfig>,
    model: String,
}

impl ProviderFactory for OpenAiProvider {
    fn from_env() -> Result<Self> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| LlmError::MissingApiKey("OPENAI_API_KEY".to_string()))?;

        let config = OpenAIConfig::new().with_api_key(api_key);

        Ok(Self {
            client: Client::with_config(config),
            model: "gpt-4.1".to_string(),
        })
    }

    fn with_model(mut self, model: &str) -> Self {
        if !model.is_empty() {
            self.model = model.to_string();
        }
        self
    }
}

impl StreamingModelProvider for OpenAiProvider {
    fn stream_response(&self, context: &Context) -> LlmResponseStream {
        let client = self.client.clone();
        let model = self.model.clone();
        let request = match build_response_request(&model, context) {
            Ok(req) => req,
            Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
        };

        Box::pin(async_stream::stream! {
            debug!("Starting OpenAI Responses API stream for model: {model}");

            let stream = match client.responses().create_stream(request).await {
                Ok(s) => s,
                Err(e) => {
                    error!("Failed to create OpenAI Responses stream: {e:?}");
                    yield Err(LlmError::ApiRequest(e.to_string()));
                    return;
                }
            };

            let mut stream = Box::pin(stream);
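            // Track in-flight function calls: the stream keys argument deltas
            // by `item_id`, while our responses use the `call_id` and tool name.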
            let mut fn_calls: HashMap<String, (String, String)> = HashMap::new();
            let mut started = false;

            while let Some(result) = stream.next().await {
                match result {
                    Ok(event) => {
                        for response in process_event(event, &mut fn_calls, &mut started) {
                            yield response;
                        }
                    }
                    Err(e) => {
                        yield Err(LlmError::ApiError(e.to_string()));
                        break;
                    }
                }
            }

            if !started {
                yield Ok(LlmResponse::done());
            }
        })
    }

    fn display_name(&self) -> String {
        format!("OpenAI ({})", self.model)
    }

    fn context_window(&self) -> Option<u32> {
        get_context_window("openai", &self.model)
    }

    fn model(&self) -> Option<LlmModel> {
        format!("openai:{}", self.model).parse().ok()
    }
}

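/// Translates a single Responses API stream event into zero or more
/// `LlmResponse` items, accumulating function-call metadata in `fn_calls`
/// (keyed by item id) and flagging whether the response has started.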
fn process_event(
    event: ResponseStreamEvent,
    fn_calls: &mut HashMap<String, (String, String)>,
    started: &mut bool,
) -> Vec<Result<LlmResponse>> {
    match event {
        ResponseStreamEvent::ResponseCreated(e) => {
            *started = true;
            vec![Ok(LlmResponse::start(&e.response.id))]
        }
        ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
            vec![Ok(LlmResponse::text(&e.delta))]
        }
        ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
            vec![Ok(LlmResponse::reasoning(&e.delta))]
        }
        ResponseStreamEvent::ResponseOutputItemAdded(e) => {
            if let OutputItem::FunctionCall(fc) = e.item {
                let item_id = fc.id.clone().unwrap_or_default();
                fn_calls.insert(item_id, (fc.call_id.clone(), fc.name.clone()));
                vec![Ok(LlmResponse::tool_request_start(&fc.call_id, &fc.name))]
            } else {
                vec![]
            }
        }
        ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
            if let Some((call_id, _)) = fn_calls.get(&e.item_id) {
                vec![Ok(LlmResponse::tool_request_arg(call_id, &e.delta))]
            } else {
                vec![]
            }
        }
        ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
            if let Some((call_id, name)) = fn_calls.remove(&e.item_id) {
                let name = e.name.unwrap_or(name);
                vec![Ok(LlmResponse::tool_request_complete(
                    &call_id,
                    &name,
                    &e.arguments,
                ))]
            } else {
                vec![]
            }
        }
        ResponseStreamEvent::ResponseCompleted(e) => {
            let mut results = Vec::new();
            if let Some(usage) = e.response.usage {
                let cached = usage.input_tokens_details.cached_tokens;
                let cached = if cached > 0 { Some(cached) } else { None };
                results.push(Ok(LlmResponse::usage_with_cache(
                    usage.input_tokens,
                    usage.output_tokens,
                    cached,
                )));
            }
            results.push(Ok(LlmResponse::done_with_stop_reason(StopReason::EndTurn)));
            results
        }
        ResponseStreamEvent::ResponseFailed(e) => {
            let msg = e
                .response
                .error
                .map_or_else(|| "Unknown error".to_string(), |err| err.message);
            vec![Err(LlmError::ApiError(msg))]
        }
        ResponseStreamEvent::ResponseIncomplete(_) => {
            vec![Ok(LlmResponse::done_with_stop_reason(StopReason::Length))]
        }
        ResponseStreamEvent::ResponseError(e) => {
            vec![Err(LlmError::ApiError(e.message))]
        }
        _ => vec![],
    }
}

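/// Builds a streaming `CreateResponse` request from the conversation context.
///
/// System messages become top-level `instructions`; user, assistant,
/// tool-call, and summary messages are flattened into the `input` item list,
/// with summaries replayed as user messages. Error messages are skipped.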
fn build_response_request(model: &str, context: &Context) -> Result<CreateResponse> {
    let mut instructions: Option<String> = None;
    let mut items: Vec<InputItem> = Vec::new();

    for msg in context.messages() {
        match msg {
            ChatMessage::System { content, .. } => {
                instructions = Some(content.clone());
            }
            ChatMessage::User { content, .. } => {
                items.push(InputItem::EasyMessage(EasyInputMessage {
                    r#type: MessageType::Message,
                    role: Role::User,
                    content: map_user_content_for_responses(content)?,
                }));
            }
            ChatMessage::Assistant {
                content,
                tool_calls,
                ..
            } => {
                if !content.is_empty() {
                    items.push(InputItem::EasyMessage(EasyInputMessage {
                        r#type: MessageType::Message,
                        role: Role::Assistant,
                        content: EasyInputContent::Text(content.clone()),
                    }));
                }
                for tc in tool_calls {
                    items.push(InputItem::Item(Item::FunctionCall(FunctionToolCall {
                        call_id: tc.id.clone(),
                        name: tc.name.clone(),
                        arguments: tc.arguments.clone(),
                        id: None,
                        status: None,
                    })));
                }
            }
            ChatMessage::ToolCallResult(result) => {
                let (call_id, output) = match result {
                    Ok(r) => (r.id.clone(), r.result.clone()),
                    Err(e) => (e.id.clone(), e.error.clone()),
                };
                items.push(InputItem::Item(Item::FunctionCallOutput(
                    FunctionCallOutputItemParam {
                        call_id,
                        output: FunctionCallOutput::Text(output),
                        id: None,
                        status: None,
                    },
                )));
            }
            ChatMessage::Summary { content, .. } => {
                items.push(InputItem::EasyMessage(EasyInputMessage {
                    r#type: MessageType::Message,
                    role: Role::User,
                    content: EasyInputContent::Text(format!(
                        "[Previous conversation handoff]\n\n{content}"
                    )),
                }));
            }
            ChatMessage::Error { .. } => {}
        }
    }

    let tools = map_tools(context.tools())?;

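    // Request auto reasoning summaries so the stream can surface reasoning
    // deltas (handled above via `ResponseReasoningSummaryTextDelta`).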
    let reasoning = context.reasoning_effort().map(|effort| Reasoning {
        effort: Some(map_reasoning_effort(effort)),
        summary: Some(ReasoningSummary::Auto),
    });

    Ok(CreateResponse {
        model: Some(model.to_string()),
        input: InputParam::Items(items),
        instructions,
        tools: if tools.is_empty() { None } else { Some(tools) },
        reasoning,
        stream: Some(true),
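        // Responses are not stored server-side; including encrypted reasoning
        // content lets reasoning round-trip across turns statelessly.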
        include: Some(vec![IncludeEnum::ReasoningEncryptedContent]),
        store: Some(false),
        background: None,
        conversation: None,
        max_output_tokens: None,
        metadata: None,
        parallel_tool_calls: None,
        previous_response_id: None,
        prompt: None,
        service_tier: None,
        stream_options: None,
        temperature: None,
        text: None,
        tool_choice: None,
        top_p: None,
        truncation: None,
        prompt_cache_key: None,
        safety_identifier: None,
        max_tool_calls: None,
        prompt_cache_retention: None,
        top_logprobs: None,
    })
}

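/// Converts tool definitions into Responses API function tools, parsing each
/// tool's JSON-schema parameter string and surfacing parse failures as
/// `LlmError::ToolParameterParsing`.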
fn map_tools(tools: &[ToolDefinition]) -> Result<Vec<Tool>> {
    tools
        .iter()
        .map(|t| {
            let parameters: serde_json::Value =
                serde_json::from_str(&t.parameters).map_err(|e| {
                    LlmError::ToolParameterParsing {
                        tool_name: t.name.clone(),
                        error: e.to_string(),
                    }
                })?;

            Ok(Tool::Function(FunctionTool {
                name: t.name.clone(),
                description: Some(t.description.clone()),
                parameters: Some(parameters),
                strict: Some(false),
            }))
        })
        .collect()
}

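/// Maps the crate's reasoning-effort levels one-to-one onto their
/// async-openai equivalents.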
fn map_reasoning_effort(effort: ReasoningEffort) -> OaiReasoningEffort {
    match effort {
        ReasoningEffort::Low => OaiReasoningEffort::Low,
        ReasoningEffort::Medium => OaiReasoningEffort::Medium,
        ReasoningEffort::High => OaiReasoningEffort::High,
        ReasoningEffort::Xhigh => OaiReasoningEffort::Xhigh,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolCallRequest;
    use crate::types::IsoString;

    #[test]
    fn test_build_request_simple_user_message() {
        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::text("Hello")],
                timestamp: IsoString::now(),
            }],
            vec![],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(req.model, Some("gpt-4.1".to_string()));
        assert!(req.instructions.is_none());
        assert!(req.tools.is_none());
        assert!(req.reasoning.is_none());

        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"][0]["role"], "user");
        assert_eq!(json["input"][0]["content"][0]["text"], "Hello");
    }

    #[test]
    fn test_build_request_with_system_message() {
        let context = Context::new(
            vec![
                ChatMessage::System {
                    content: "You are helpful.".to_string(),
                    timestamp: IsoString::now(),
                },
                ChatMessage::User {
                    content: vec![ContentBlock::text("Hi")],
                    timestamp: IsoString::now(),
                },
            ],
            vec![],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(req.instructions, Some("You are helpful.".to_string()));

        let json = serde_json::to_value(&req).unwrap();
        let items = json["input"].as_array().unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["role"], "user");
    }

    #[test]
    fn test_build_request_with_tool_calls() {
        let context = Context::new(
            vec![
                ChatMessage::User {
                    content: vec![ContentBlock::text("Search for rust")],
                    timestamp: IsoString::now(),
                },
                ChatMessage::Assistant {
                    content: String::new(),
                    reasoning: Default::default(),
                    timestamp: IsoString::now(),
                    tool_calls: vec![ToolCallRequest {
                        id: "call_1".to_string(),
                        name: "search".to_string(),
                        arguments: r#"{"q":"rust"}"#.to_string(),
                    }],
                },
                ChatMessage::ToolCallResult(Ok(crate::ToolCallResult {
                    id: "call_1".to_string(),
                    name: "search".to_string(),
                    arguments: r#"{"q":"rust"}"#.to_string(),
                    result: "Found results".to_string(),
                })),
            ],
            vec![ToolDefinition {
                name: "search".to_string(),
                description: "Search".to_string(),
                parameters: r#"{"type":"object"}"#.to_string(),
                server: None,
            }],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        let json = serde_json::to_value(&req).unwrap();

        let items = json["input"].as_array().unwrap();
        assert_eq!(items[0]["role"], "user");
        assert_eq!(items[1]["type"], "function_call");
        assert_eq!(items[1]["call_id"], "call_1");
        assert_eq!(items[2]["type"], "function_call_output");
        assert_eq!(items[2]["call_id"], "call_1");
        assert_eq!(items[2]["output"], "Found results");

        assert!(req.tools.is_some());
        let tools_json = serde_json::to_value(&req.tools).unwrap();
        assert_eq!(tools_json[0]["type"], "function");
        assert_eq!(tools_json[0]["name"], "search");
    }

    #[test]
    fn test_build_request_with_reasoning_effort() {
        let mut context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::text("Think")],
                timestamp: IsoString::now(),
            }],
            vec![],
        );
        context.set_reasoning_effort(Some(ReasoningEffort::High));

        let req = build_response_request("o3", &context).unwrap();
        let reasoning = req.reasoning.unwrap();
        assert_eq!(reasoning.effort, Some(OaiReasoningEffort::High));
        assert_eq!(reasoning.summary, Some(ReasoningSummary::Auto));
    }

    #[test]
    fn test_build_request_with_audio_returns_unsupported_content() {
        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::Audio {
                    data: "YXVkaW8=".to_string(),
                    mime_type: "audio/wav".to_string(),
                }],
                timestamp: IsoString::now(),
            }],
            vec![],
        );

        assert!(matches!(
            build_response_request("gpt-4.1", &context),
            Err(LlmError::UnsupportedContent(_))
        ));
    }

    #[test]
    fn test_map_tools_valid() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#.to_string(),
            server: None,
        }];

        let result = map_tools(&tools).unwrap();
        assert_eq!(result.len(), 1);

        let json = serde_json::to_value(&result[0]).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["name"], "read_file");
    }

    #[test]
    fn test_map_tools_invalid_json() {
        let tools = vec![ToolDefinition {
            name: "broken".to_string(),
            description: "Broken".to_string(),
            parameters: "not json{".to_string(),
            server: None,
        }];

        let result = map_tools(&tools);
        assert!(result.is_err());
        match result.unwrap_err() {
            LlmError::ToolParameterParsing { tool_name, .. } => {
                assert_eq!(tool_name, "broken");
            }
            other => panic!("Expected ToolParameterParsing, got: {other}"),
        }
    }

    #[test]
    fn test_provider_display_name() {
        let provider = OpenAiProvider {
            client: Client::new(),
            model: "gpt-4.1".to_string(),
        };
        assert_eq!(provider.display_name(), "OpenAI (gpt-4.1)");
    }
}