1use std::collections::HashMap;
2
3use async_openai::Client;
4use async_openai::config::OpenAIConfig;
5use async_openai::types::responses::{
6 CreateResponse, EasyInputContent, EasyInputMessage, FunctionCallOutput, FunctionCallOutputItemParam, FunctionTool,
7 FunctionToolCall, ImageDetail, IncludeEnum, InputContent, InputImageContent, InputItem, InputParam,
8 InputTextContent, Item, MessageType, OutputItem, Reasoning, ReasoningEffort as OaiReasoningEffort,
9 ReasoningSummary, ResponseStreamEvent, ResponseUsage, Role, Tool,
10};
11use tokio_stream::StreamExt;
12use tracing::{debug, error};
13
14use crate::provider::get_context_window;
15use crate::{
16 ChatMessage, ContentBlock, Context, LlmError, LlmModel, LlmResponse, LlmResponseStream, ProviderFactory,
17 ReasoningEffort, Result, StopReason, StreamingModelProvider, TokenUsage, ToolDefinition,
18};
19
20impl From<ResponseUsage> for TokenUsage {
21 fn from(usage: ResponseUsage) -> Self {
22 TokenUsage {
23 input_tokens: usage.input_tokens,
24 output_tokens: usage.output_tokens,
25 cache_read_tokens: Some(usage.input_tokens_details.cached_tokens),
26 reasoning_tokens: Some(usage.output_tokens_details.reasoning_tokens),
27 ..TokenUsage::default()
28 }
29 }
30}
31
32pub(crate) fn map_user_content_for_responses(parts: &[ContentBlock]) -> Result<EasyInputContent> {
33 let mut items = Vec::with_capacity(parts.len());
34 for p in parts {
35 match p {
36 ContentBlock::Text { text } => {
37 items.push(InputContent::InputText(InputTextContent { text: text.clone() }));
38 }
39 ContentBlock::Image { .. } => {
40 items.push(InputContent::InputImage(InputImageContent {
41 detail: ImageDetail::Auto,
42 file_id: None,
43 image_url: Some(p.as_data_uri().unwrap()),
44 }));
45 }
46 ContentBlock::Audio { .. } => {
47 return Err(LlmError::UnsupportedContent("OpenAI Responses does not support audio input".into()));
48 }
49 }
50 }
51 Ok(EasyInputContent::ContentList(items))
52}
53
/// Streaming chat provider backed by the OpenAI Responses API.
pub struct OpenAiProvider {
    // Configured async-openai client (carries the API key).
    client: Client<OpenAIConfig>,
    // Model identifier sent with every request, e.g. "gpt-4.1".
    model: String,
}
58
59impl ProviderFactory for OpenAiProvider {
60 async fn from_env() -> Result<Self> {
61 let api_key =
62 std::env::var("OPENAI_API_KEY").map_err(|_| LlmError::MissingApiKey("OPENAI_API_KEY".to_string()))?;
63
64 let config = OpenAIConfig::new().with_api_key(api_key);
65
66 Ok(Self { client: Client::with_config(config), model: "gpt-4.1".to_string() })
67 }
68
69 fn with_model(mut self, model: &str) -> Self {
70 if !model.is_empty() {
71 self.model = model.to_string();
72 }
73 self
74 }
75}
76
impl StreamingModelProvider for OpenAiProvider {
    /// Streams a model response for `context` via the Responses API.
    ///
    /// The request is built eagerly so a malformed context surfaces as a
    /// single-error stream; all network work is deferred into the stream.
    fn stream_response(&self, context: &Context) -> LlmResponseStream {
        // Clone what the returned stream needs so it does not borrow `self`.
        let client = self.client.clone();
        let model = self.model.clone();
        let request = match build_response_request(&model, context) {
            Ok(req) => req,
            // Request-building failure becomes a one-item error stream.
            Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
        };

        Box::pin(async_stream::stream! {
            debug!("Starting OpenAI Responses API stream for model: {model}");

            let stream = match client.responses().create_stream(request).await {
                Ok(s) => s,
                Err(e) => {
                    error!("Failed to create OpenAI Responses stream: {e:?}");
                    yield Err(LlmError::ApiRequest(e.to_string()));
                    return;
                }
            };

            let mut stream = Box::pin(stream);
            // In-flight function calls: output item id -> (call_id, tool name).
            let mut fn_calls: HashMap<String, (String, String)> = HashMap::new();
            // Set once a ResponseCreated event is observed.
            let mut started = false;

            while let Some(result) = stream.next().await {
                match result {
                    Ok(event) => {
                        for response in process_event(event, &mut fn_calls, &mut started) {
                            yield response;
                        }
                    }
                    Err(e) => {
                        // Stop on transport error after surfacing it.
                        yield Err(LlmError::ApiError(e.to_string()));
                        break;
                    }
                }
            }

            // Guarantee consumers always observe a terminal item, even when
            // the server never emitted a ResponseCreated event.
            if !started {
                yield Ok(LlmResponse::done());
            }
        })
    }

    /// Human-readable label, e.g. "OpenAI (gpt-4.1)".
    fn display_name(&self) -> String {
        format!("OpenAI ({})", self.model)
    }

    /// Context-window size for the configured model, when known.
    fn context_window(&self) -> Option<u32> {
        get_context_window("openai", &self.model)
    }

    /// The configured model parsed as "openai:<model>", if parseable.
    fn model(&self) -> Option<LlmModel> {
        format!("openai:{}", self.model).parse().ok()
    }
}
134
135fn process_event(
136 event: ResponseStreamEvent,
137 fn_calls: &mut HashMap<String, (String, String)>,
138 started: &mut bool,
139) -> Vec<Result<LlmResponse>> {
140 match event {
141 ResponseStreamEvent::ResponseCreated(e) => {
142 *started = true;
143 vec![Ok(LlmResponse::start(&e.response.id))]
144 }
145 ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
146 vec![Ok(LlmResponse::text(&e.delta))]
147 }
148 ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
149 vec![Ok(LlmResponse::reasoning(&e.delta))]
150 }
151 ResponseStreamEvent::ResponseOutputItemAdded(e) => {
152 if let OutputItem::FunctionCall(fc) = e.item {
153 let item_id = fc.id.clone().unwrap_or_default();
154 fn_calls.insert(item_id, (fc.call_id.clone(), fc.name.clone()));
155 vec![Ok(LlmResponse::tool_request_start(&fc.call_id, &fc.name))]
156 } else {
157 vec![]
158 }
159 }
160 ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
161 if let Some((call_id, _)) = fn_calls.get(&e.item_id) {
162 vec![Ok(LlmResponse::tool_request_arg(call_id, &e.delta))]
163 } else {
164 vec![]
165 }
166 }
167 ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
168 if let Some((call_id, name)) = fn_calls.remove(&e.item_id) {
169 let name = e.name.unwrap_or(name);
170 vec![Ok(LlmResponse::tool_request_complete(&call_id, &name, &e.arguments))]
171 } else {
172 vec![]
173 }
174 }
175 ResponseStreamEvent::ResponseCompleted(e) => {
176 let mut results = Vec::new();
177 if let Some(usage) = e.response.usage {
178 results.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
179 }
180 results.push(Ok(LlmResponse::done_with_stop_reason(StopReason::EndTurn)));
181 results
182 }
183 ResponseStreamEvent::ResponseFailed(e) => {
184 let msg = e.response.error.map_or_else(|| "Unknown error".to_string(), |err| err.message);
185 vec![Err(LlmError::ApiError(msg))]
186 }
187 ResponseStreamEvent::ResponseIncomplete(_) => {
188 vec![Ok(LlmResponse::done_with_stop_reason(StopReason::Length))]
189 }
190 ResponseStreamEvent::ResponseError(e) => {
191 vec![Err(LlmError::ApiError(e.message))]
192 }
193 _ => vec![],
194 }
195}
196
197fn build_response_request(model: &str, context: &Context) -> Result<CreateResponse> {
198 let mut instructions: Option<String> = None;
199 let mut items: Vec<InputItem> = Vec::new();
200
201 for msg in context.messages() {
202 match msg {
203 ChatMessage::System { content, .. } => {
204 instructions = Some(content.clone());
205 }
206 ChatMessage::User { content, .. } => {
207 items.push(InputItem::EasyMessage(EasyInputMessage {
208 r#type: MessageType::Message,
209 role: Role::User,
210 content: map_user_content_for_responses(content)?,
211 phase: None,
212 }));
213 }
214 ChatMessage::Assistant { content, tool_calls, .. } => {
215 if !content.is_empty() {
216 items.push(InputItem::EasyMessage(EasyInputMessage {
217 r#type: MessageType::Message,
218 role: Role::Assistant,
219 content: EasyInputContent::Text(content.clone()),
220 phase: None,
221 }));
222 }
223 for tc in tool_calls {
224 items.push(InputItem::Item(Item::FunctionCall(FunctionToolCall {
225 call_id: tc.id.clone(),
226 name: tc.name.clone(),
227 arguments: tc.arguments.clone(),
228 namespace: None,
229 id: None,
230 status: None,
231 })));
232 }
233 }
234 ChatMessage::ToolCallResult(result) => {
235 let (call_id, output) = match result {
236 Ok(r) => (r.id.clone(), r.result.clone()),
237 Err(e) => (e.id.clone(), e.error.clone()),
238 };
239 items.push(InputItem::Item(Item::FunctionCallOutput(FunctionCallOutputItemParam {
240 call_id,
241 output: FunctionCallOutput::Text(output),
242 id: None,
243 status: None,
244 })));
245 }
246 ChatMessage::Summary { content, .. } => {
247 items.push(InputItem::EasyMessage(EasyInputMessage {
248 r#type: MessageType::Message,
249 role: Role::User,
250 content: EasyInputContent::Text(format!("[Previous conversation handoff]\n\n{content}")),
251 phase: None,
252 }));
253 }
254 ChatMessage::Error { .. } => {}
255 }
256 }
257
258 let tools = map_tools(context.tools())?;
259
260 let reasoning = context
261 .reasoning_effort()
262 .map(|effort| Reasoning { effort: Some(map_reasoning_effort(effort)), summary: Some(ReasoningSummary::Auto) });
263
264 Ok(CreateResponse {
265 model: Some(model.to_string()),
266 input: InputParam::Items(items),
267 instructions,
268 tools: if tools.is_empty() { None } else { Some(tools) },
269 reasoning,
270 stream: Some(true),
271 include: Some(vec![IncludeEnum::ReasoningEncryptedContent]),
272 store: Some(false),
273 background: None,
274 conversation: None,
275 max_output_tokens: None,
276 metadata: None,
277 parallel_tool_calls: None,
278 previous_response_id: None,
279 prompt: None,
280 service_tier: None,
281 stream_options: None,
282 temperature: None,
283 text: None,
284 tool_choice: None,
285 top_p: None,
286 truncation: None,
287 prompt_cache_key: None,
288 safety_identifier: None,
289 max_tool_calls: None,
290 prompt_cache_retention: None,
291 top_logprobs: None,
292 })
293}
294
295fn map_tools(tools: &[ToolDefinition]) -> Result<Vec<Tool>> {
296 tools
297 .iter()
298 .map(|t| {
299 let parameters: serde_json::Value = serde_json::from_str(&t.parameters)
300 .map_err(|e| LlmError::ToolParameterParsing { tool_name: t.name.clone(), error: e.to_string() })?;
301
302 Ok(Tool::Function(FunctionTool {
303 name: t.name.clone(),
304 description: Some(t.description.clone()),
305 parameters: Some(parameters),
306 strict: Some(false),
307 defer_loading: None,
308 }))
309 })
310 .collect()
311}
312
313fn map_reasoning_effort(effort: ReasoningEffort) -> OaiReasoningEffort {
314 match effort {
315 ReasoningEffort::Low => OaiReasoningEffort::Low,
316 ReasoningEffort::Medium => OaiReasoningEffort::Medium,
317 ReasoningEffort::High => OaiReasoningEffort::High,
318 ReasoningEffort::Xhigh => OaiReasoningEffort::Xhigh,
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AssistantReasoning;
    use crate::ToolCallRequest;
    use crate::types::IsoString;

    // A lone user message serializes as one "user" input item with text
    // content; model is set and optional request fields stay unset.
    #[test]
    fn test_build_request_simple_user_message() {
        let context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("Hello")], timestamp: IsoString::now() }],
            vec![],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(req.model, Some("gpt-4.1".to_string()));
        assert!(req.instructions.is_none());
        assert!(req.tools.is_none());
        assert!(req.reasoning.is_none());

        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"][0]["role"], "user");
        assert_eq!(json["input"][0]["content"][0]["text"], "Hello");
    }

    // System messages go into `instructions`, not the input item list.
    #[test]
    fn test_build_request_with_system_message() {
        let context = Context::new(
            vec![
                ChatMessage::System { content: "You are helpful.".to_string(), timestamp: IsoString::now() },
                ChatMessage::User { content: vec![ContentBlock::text("Hi")], timestamp: IsoString::now() },
            ],
            vec![],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        assert_eq!(req.instructions, Some("You are helpful.".to_string()));

        let json = serde_json::to_value(&req).unwrap();
        let items = json["input"].as_array().unwrap();
        // Only the user message remains as an input item.
        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["role"], "user");
    }

    // Assistant tool calls are replayed as function_call items, tool results
    // as function_call_output items, and tool definitions map to tools.
    #[test]
    fn test_build_request_with_tool_calls() {
        let context = Context::new(
            vec![
                ChatMessage::User { content: vec![ContentBlock::text("Search for rust")], timestamp: IsoString::now() },
                ChatMessage::Assistant {
                    content: String::new(),
                    reasoning: AssistantReasoning::default(),
                    timestamp: IsoString::now(),
                    tool_calls: vec![ToolCallRequest {
                        id: "call_1".to_string(),
                        name: "search".to_string(),
                        arguments: r#"{"q":"rust"}"#.to_string(),
                    }],
                },
                ChatMessage::ToolCallResult(Ok(crate::ToolCallResult {
                    id: "call_1".to_string(),
                    name: "search".to_string(),
                    arguments: r#"{"q":"rust"}"#.to_string(),
                    result: "Found results".to_string(),
                })),
            ],
            vec![ToolDefinition {
                name: "search".to_string(),
                description: "Search".to_string(),
                parameters: r#"{"type":"object"}"#.to_string(),
                server: None,
            }],
        );

        let req = build_response_request("gpt-4.1", &context).unwrap();
        let json = serde_json::to_value(&req).unwrap();

        let items = json["input"].as_array().unwrap();
        assert_eq!(items[0]["role"], "user");
        assert_eq!(items[1]["type"], "function_call");
        assert_eq!(items[1]["call_id"], "call_1");
        assert_eq!(items[2]["type"], "function_call_output");
        assert_eq!(items[2]["call_id"], "call_1");
        assert_eq!(items[2]["output"], "Found results");

        assert!(req.tools.is_some());
        let tools_json = serde_json::to_value(&req.tools).unwrap();
        assert_eq!(tools_json[0]["type"], "function");
        assert_eq!(tools_json[0]["name"], "search");
    }

    // Reasoning effort on the context becomes a Reasoning block with an
    // automatic summary.
    #[test]
    fn test_build_request_with_reasoning_effort() {
        let mut context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("Think")], timestamp: IsoString::now() }],
            vec![],
        );
        context.set_reasoning_effort(Some(ReasoningEffort::High));

        let req = build_response_request("o3", &context).unwrap();
        let reasoning = req.reasoning.unwrap();
        assert_eq!(reasoning.effort, Some(OaiReasoningEffort::High));
        assert_eq!(reasoning.summary, Some(ReasoningSummary::Auto));
    }

    // Audio content is rejected with UnsupportedContent before any request
    // is built.
    #[test]
    fn test_build_request_with_audio_returns_unsupported_content() {
        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![ContentBlock::Audio { data: "YXVkaW8=".to_string(), mime_type: "audio/wav".to_string() }],
                timestamp: IsoString::now(),
            }],
            vec![],
        );

        assert!(matches!(build_response_request("gpt-4.1", &context), Err(LlmError::UnsupportedContent(_))));
    }

    // A tool with valid JSON-schema parameters maps to a function tool.
    #[test]
    fn test_map_tools_valid() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#.to_string(),
            server: None,
        }];

        let result = map_tools(&tools).unwrap();
        assert_eq!(result.len(), 1);

        let json = serde_json::to_value(&result[0]).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["name"], "read_file");
    }

    // Invalid parameter JSON surfaces as ToolParameterParsing naming the
    // offending tool.
    #[test]
    fn test_map_tools_invalid_json() {
        let tools = vec![ToolDefinition {
            name: "broken".to_string(),
            description: "Broken".to_string(),
            parameters: "not json{".to_string(),
            server: None,
        }];

        let result = map_tools(&tools);
        assert!(result.is_err());
        match result.unwrap_err() {
            LlmError::ToolParameterParsing { tool_name, .. } => {
                assert_eq!(tool_name, "broken");
            }
            other => panic!("Expected ToolParameterParsing, got: {other}"),
        }
    }

    // display_name embeds the configured model in the provider label.
    #[test]
    fn test_provider_display_name() {
        let provider = OpenAiProvider { client: Client::new(), model: "gpt-4.1".to_string() };
        assert_eq!(provider.display_name(), "OpenAI (gpt-4.1)");
    }
}