1use async_trait::async_trait;
18use eventsource_stream::Eventsource;
19use futures::StreamExt;
20use reqwest::{Client, RequestBuilder, Url};
21use serde::{Deserialize, Serialize};
22use serde_json::{Value, json};
23use std::sync::{Arc, Mutex};
24
25use crate::error::{AgentLoopError, Result};
26use crate::llm_driver_registry::{
27 LlmCallConfig, LlmCompletionMetadata, LlmContentPart, LlmDriver, LlmMessage, LlmMessageContent,
28 LlmMessageRole, LlmResponseStream, LlmStreamEvent,
29};
30use crate::llm_retry::{
31 LlmRetryConfig, RateLimitInfo, RetryMetadata, is_rate_limit_status, is_transient_error,
32};
33use crate::tool_types::{ToolCall, ToolDefinition};
34
/// Endpoint used by `OpenAIProtocolLlmDriver::new` when the caller does not supply a URL.
35const DEFAULT_API_URL: &str = "https://api.openai.com/v1/chat/completions";
36
37pub(crate) fn apply_openai_api_auth(
38 request: RequestBuilder,
39 api_url: &str,
40 api_key: &str,
41) -> RequestBuilder {
42 if is_azure_openai_api_url(api_url) {
43 request.header("api-key", api_key)
44 } else {
45 request.header("Authorization", format!("Bearer {}", api_key))
46 }
47}
48
49pub fn is_azure_openai_api_url(api_url: &str) -> bool {
50 Url::parse(api_url)
51 .ok()
52 .and_then(|url| url.host_str().map(|host| host.to_ascii_lowercase()))
53 .is_some_and(|host| {
54 host.ends_with(".openai.azure.com") || host.ends_with(".services.ai.azure.com")
55 })
56}
57
58pub fn is_openai_api_url(api_url: &str) -> bool {
62 Url::parse(api_url)
63 .ok()
64 .and_then(|url| url.host_str().map(|host| host.to_ascii_lowercase()))
65 .is_some_and(|host| host == "api.openai.com")
66}
67
/// LLM driver that posts streaming chat-completion requests to an OpenAI-protocol endpoint.
68#[derive(Clone)]
94pub struct OpenAIProtocolLlmDriver {
    /// HTTP client used to send requests.
95 client: Client,
    /// Credential attached to every request by `apply_openai_api_auth`.
96 api_key: String,
    /// Full URL that requests are POSTed to.
97 api_url: String,
    /// Retry limits and backoff settings applied to transient HTTP failures.
98 retry_config: LlmRetryConfig,
100}
101
102impl OpenAIProtocolLlmDriver {
103 pub fn new(api_key: impl Into<String>) -> Self {
105 Self {
106 client: Client::new(),
107 api_key: api_key.into(),
108 api_url: DEFAULT_API_URL.to_string(),
109 retry_config: LlmRetryConfig::default(),
110 }
111 }
112
113 pub fn from_env() -> Result<Self> {
115 let api_key = std::env::var("OPENAI_API_KEY")
116 .map_err(|_| AgentLoopError::llm("OPENAI_API_KEY environment variable not set"))?;
117 Ok(Self::new(api_key))
118 }
119
120 pub fn with_base_url(api_key: impl Into<String>, api_url: impl Into<String>) -> Self {
122 Self {
123 client: Client::new(),
124 api_key: api_key.into(),
125 api_url: api_url.into(),
126 retry_config: LlmRetryConfig::default(),
127 }
128 }
129
130 pub fn with_retry_config(mut self, config: LlmRetryConfig) -> Self {
132 self.retry_config = config;
133 self
134 }
135
136 pub fn api_url(&self) -> &str {
138 &self.api_url
139 }
140
141 pub fn api_key(&self) -> &str {
143 &self.api_key
144 }
145
146 pub fn client(&self) -> &Client {
148 &self.client
149 }
150
151 fn convert_role(role: &LlmMessageRole) -> &'static str {
152 match role {
153 LlmMessageRole::System => "system",
154 LlmMessageRole::User => "user",
155 LlmMessageRole::Assistant => "assistant",
156 LlmMessageRole::Tool => "tool",
157 }
158 }
159
160 fn convert_message(msg: &LlmMessage) -> OpenAiMessage {
161 let content = match &msg.content {
162 LlmMessageContent::Text(text) => OpenAiContent::Text(text.clone()),
163 LlmMessageContent::Parts(parts) => {
164 let openai_parts: Vec<OpenAiContentPart> = parts
165 .iter()
166 .map(|part| match part {
167 LlmContentPart::Text { text } => OpenAiContentPart::Text {
168 r#type: "text".to_string(),
169 text: text.clone(),
170 },
171 LlmContentPart::Image { url } => OpenAiContentPart::ImageUrl {
172 r#type: "image_url".to_string(),
173 image_url: OpenAiImageUrl { url: url.clone() },
174 },
175 LlmContentPart::Audio { url } => OpenAiContentPart::InputAudio {
176 r#type: "input_audio".to_string(),
177 input_audio: OpenAiInputAudio {
178 data: url.clone(),
179 format: "wav".to_string(),
180 },
181 },
182 })
183 .collect();
184 OpenAiContent::Parts(openai_parts)
185 }
186 };
187
188 let tool_calls = if msg.role == LlmMessageRole::Assistant {
190 msg.tool_calls.as_ref().map(|calls| {
191 calls
192 .iter()
193 .map(|tc| OpenAiToolCall {
194 id: tc.id.clone(),
195 r#type: "function".to_string(),
196 function: OpenAiFunctionCall {
197 name: tc.name.clone(),
198 arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(),
199 },
200 })
201 .collect()
202 })
203 } else {
204 None
205 };
206
207 OpenAiMessage {
208 role: Self::convert_role(&msg.role).to_string(),
209 content: Some(content),
210 tool_calls,
211 tool_call_id: msg.tool_call_id.clone(),
212 }
213 }
214
215 fn convert_tools(tools: &[ToolDefinition]) -> Vec<OpenAiTool> {
216 tools
217 .iter()
218 .map(|tool| OpenAiTool {
219 r#type: "function".to_string(),
220 function: OpenAiFunction {
221 name: tool.name().to_string(),
222 description: tool.description().to_string(),
223 parameters: tool.parameters().clone(),
224 },
225 })
226 .collect()
227 }
228}
229
230fn drop_orphaned_tool_messages(messages: &[LlmMessage]) -> Vec<LlmMessage> {
234 use std::collections::HashSet;
235
236 let visible_call_ids: HashSet<&str> = messages
237 .iter()
238 .filter(|m| m.role == LlmMessageRole::Assistant)
239 .flat_map(|m| m.tool_calls.iter().flatten())
240 .map(|tc| tc.id.as_str())
241 .collect();
242
243 if visible_call_ids.is_empty() {
244 return messages
245 .iter()
246 .filter(|m| m.role != LlmMessageRole::Tool)
247 .cloned()
248 .collect();
249 }
250
251 messages
252 .iter()
253 .filter(|m| {
254 if m.role == LlmMessageRole::Tool {
255 return m
256 .tool_call_id
257 .as_deref()
258 .is_none_or(|id| visible_call_ids.contains(id));
259 }
260 true
261 })
262 .cloned()
263 .collect()
264}
265
/// Streaming chat-completion implementation of `LlmDriver` for OpenAI-protocol endpoints.
266#[async_trait]
267impl LlmDriver for OpenAIProtocolLlmDriver {
    /// Sends `messages` as a streaming chat-completion request to `self.api_url` and returns the
    /// response converted into a stream of `LlmStreamEvent`s.
    ///
    /// Tool messages without a matching assistant tool call are removed first (see
    /// `drop_orphaned_tool_messages`). Non-success responses that `is_transient_error` accepts
    /// are retried while `retry_metadata.attempts < retry_config.max_retries`.
    ///
    /// The returned stream emits one event per SSE chunk (an empty `TextDelta` when the chunk
    /// carries nothing to report) and, on the `[DONE]` sentinel, any still-pending tool calls
    /// followed by a `Done` event with the collected usage metadata.
    ///
    /// # Errors
    /// - `AgentLoopError::llm` when the request cannot be sent.
    /// - `AgentLoopError::request_too_large` when `is_openai_request_too_large` matches.
    /// - `AgentLoopError::model_not_available` when `is_openai_model_not_found` matches.
    /// - `AgentLoopError::llm` carrying status and body for every other non-success response.
268 async fn chat_completion_stream(
269 &self,
270 messages: Vec<LlmMessage>,
271 config: &LlmCallConfig,
272 ) -> Result<LlmResponseStream> {
        // Strip tool results that no visible assistant tool call refers to.
273 let messages = drop_orphaned_tool_messages(&messages);
277 let openai_messages: Vec<OpenAiMessage> =
278 messages.iter().map(Self::convert_message).collect();
279
        // `tools` and `metadata` are omitted from the payload entirely when empty.
280 let tools = if config.tools.is_empty() {
281 None
282 } else {
283 Some(Self::convert_tools(&config.tools))
284 };
285
286 let metadata = if config.metadata.is_empty() {
288 None
289 } else {
290 Some(config.metadata.clone())
291 };
292
293 let request = OpenAiRequest {
294 model: config.model.clone(),
295 messages: openai_messages,
296 temperature: config.temperature,
297 max_tokens: config.max_tokens,
298 stream: true,
            // Ask the server to include token usage in the stream.
299 stream_options: Some(OpenAiStreamOptions {
300 include_usage: true,
301 }),
302 tools,
            // An effort of "none" (any ASCII case) is treated as "do not send the field".
303 reasoning_effort: config
305 .reasoning_effort
306 .as_ref()
307 .filter(|e| !e.eq_ignore_ascii_case("none"))
308 .cloned(),
309 metadata,
310 };
311
312 let mut retry_metadata = RetryMetadata::default();
        // Body of the most recent retried failure, appended to the final error message.
314 let mut last_error: Option<String> = None;
315
        // Request/retry loop: breaks with the successful response, returns on a final error.
316 let response = loop {
317 let response = apply_openai_api_auth(
318 self.client.post(&self.api_url),
319 &self.api_url,
320 &self.api_key,
321 )
322 .header("Content-Type", "application/json")
323 .json(&request)
324 .send()
325 .await
            // Send failures are returned immediately; they never enter the retry path below.
326 .map_err(|e| AgentLoopError::llm(format!("Failed to send request: {}", e)))?;
327
328 let status = response.status();
329
330 if status.is_success() {
331 break response;
333 }
334
335 if is_transient_error(status) && retry_metadata.attempts < self.retry_config.max_retries
337 {
                // Headers must be read before `text()` consumes the response.
338 let rate_limit_info = if is_rate_limit_status(status) {
340 Some(RateLimitInfo::from_openai_headers(response.headers()))
341 } else {
342 None
343 };
344
345 let error_text = response.text().await.unwrap_or_default();
346
                // An oversized request is returned as an error right away instead of retried.
347 if is_openai_request_too_large(status, &error_text) {
349 return Err(AgentLoopError::request_too_large(format!(
350 "OpenAI API error ({}): {}",
351 status, error_text
352 )));
353 }
354
                // Use the rate-limit-derived wait when available, otherwise the configured backoff.
355 let wait_duration = rate_limit_info
357 .as_ref()
358 .map(|info| info.recommended_wait(&self.retry_config, retry_metadata.attempts))
359 .unwrap_or_else(|| {
360 self.retry_config.calculate_backoff(retry_metadata.attempts)
361 });
362
363 tracing::warn!(
364 status = %status,
365 attempt = retry_metadata.attempts + 1,
366 max_retries = self.retry_config.max_retries,
367 wait_secs = wait_duration.as_secs_f64(),
368 retry_after = ?rate_limit_info.as_ref().and_then(|i| i.retry_after_secs),
369 "OpenAIProtocolDriver: rate limit or transient error, retrying"
370 );
371
                // NOTE(review): presumably `record_retry` increments `attempts` — confirm in llm_retry.
372 retry_metadata.record_retry(wait_duration, rate_limit_info);
374 last_error = Some(error_text);
375
376 tokio::time::sleep(wait_duration).await;
378 continue;
379 }
380
            // Final failure: either not transient or the retry budget is used up.
381 let error_text = response.text().await.unwrap_or_default();
383 let error_msg = format!("OpenAI API error ({}): {}", status, error_text);
384
385 if is_openai_model_not_found(status, &error_text) {
387 return Err(AgentLoopError::model_not_available(config.model.clone()));
388 }
389
390 if is_openai_request_too_large(status, &error_text) {
392 return Err(AgentLoopError::request_too_large(error_msg));
393 }
394
395 if retry_metadata.attempts > 0 {
397 return Err(AgentLoopError::llm(format!(
398 "{} (after {} retries, last error: {})",
399 error_msg,
400 retry_metadata.attempts,
401 last_error.unwrap_or_default()
402 )));
403 }
404
405 return Err(AgentLoopError::llm(error_msg));
406 };
407
408 if retry_metadata.had_retries() {
410 tracing::info!(
411 attempts = retry_metadata.attempts,
412 total_wait_secs = retry_metadata.total_retry_wait.as_secs_f64(),
413 "OpenAIProtocolDriver: request succeeded after retries"
414 );
415 }
416
        // Decode the HTTP body as server-sent events.
417 let byte_stream = response.bytes_stream();
418 let event_stream = byte_stream.eventsource();
419
        // State shared between per-event futures; each event clones the `Arc`s it needs.
420 let model = config.model.clone();
        // Despite its name this holds the completion-token count: it is bumped once per
        // non-empty content delta and overwritten by `usage.completion_tokens` when reported.
421 let total_tokens = Arc::new(Mutex::new(0u32));
422 let prompt_tokens = Arc::new(Mutex::new(0u32));
423 let cache_read_tokens = Arc::new(Mutex::new(Option::<u32>::None));
424 let provider_cost_usd = Arc::new(Mutex::new(Option::<f64>::None));
        // Tool calls assembled from streamed fragments, indexed by the delta's `index`.
427 let accumulated_tool_calls = Arc::new(Mutex::new(Vec::<ToolCall>::new()));
428 let finish_reason = Arc::new(Mutex::new(Option::<String>::None));
        // Retry metadata is attached to the `Done` event only when a retry actually happened.
429 let shared_retry_metadata = if retry_metadata.had_retries() {
431 Some(Arc::new(retry_metadata))
432 } else {
433 None
434 };
435
        // Each SSE event maps to a Vec of results, flattened into the output stream below.
436 let converted_stream: LlmResponseStream = Box::pin(
440 event_stream
441 .then(move |result| {
442 let model = model.clone();
443 let total_tokens = Arc::clone(&total_tokens);
444 let prompt_tokens = Arc::clone(&prompt_tokens);
445 let cache_read_tokens = Arc::clone(&cache_read_tokens);
446 let provider_cost_usd = Arc::clone(&provider_cost_usd);
447 let accumulated_tool_calls = Arc::clone(&accumulated_tool_calls);
448 let finish_reason = Arc::clone(&finish_reason);
449 let retry_metadata_for_done = shared_retry_metadata.clone();
450
                    // The block below contains no `.await`, so no mutex guard is held across one.
451 async move {
                        // Transport/SSE errors surface as an in-stream `Error` event.
452 let event = match result {
453 Ok(event) => event,
454 Err(e) => {
455 return vec![Ok(LlmStreamEvent::Error(format!(
456 "Stream error: {}",
457 e
458 )))];
459 }
460 };
461
                        // End-of-stream sentinel: flush pending tool calls, then emit `Done`.
462 if event.data == "[DONE]" {
463 let output_tokens = *total_tokens.lock().unwrap();
464 let input_tokens = *prompt_tokens.lock().unwrap();
465 let cached = *cache_read_tokens.lock().unwrap();
466 let cost = *provider_cost_usd.lock().unwrap();
467 let mut reason = finish_reason.lock().unwrap().clone();
468
469 let mut events = Vec::new();
470
471 {
                                // Tool calls never emitted by a finish chunk are emitted here; the
                                // finish reason becomes "tool_calls" only if none was recorded.
480 let mut acc = accumulated_tool_calls.lock().unwrap();
481 if let Some(event) = take_pending_tool_calls(&mut acc) {
482 events.push(Ok(event));
483 reason.get_or_insert_with(|| "tool_calls".to_string());
484 }
485 }
486
487 events.push(Ok(LlmStreamEvent::Done(Box::new(
488 LlmCompletionMetadata {
489 total_tokens: Some(input_tokens + output_tokens),
490 prompt_tokens: Some(input_tokens),
491 completion_tokens: Some(output_tokens),
492 cache_read_tokens: cached,
493 cache_creation_tokens: None,
494 provider_cost_usd: cost,
495 model: Some(model),
                                    // Falls back to "stop" when no finish reason was recorded.
496 finish_reason: reason.or_else(|| Some("stop".to_string())),
497 retry_metadata: retry_metadata_for_done
498 .map(|arc| (*arc).clone()),
499 response_id: None,
500 phase: None,
501 },
502 ))));
503
504 return events;
505 }
506
507 match serde_json::from_str::<OpenAiStreamChunk>(&event.data) {
508 Ok(chunk) => {
                                // Usage fields overwrite the running values only when present.
509 if let Some(usage) = &chunk.usage {
511 if let Some(pt) = usage.prompt_tokens {
512 *prompt_tokens.lock().unwrap() = pt;
513 }
514 if let Some(ct) = usage.completion_tokens {
515 *total_tokens.lock().unwrap() = ct;
516 }
517 if let Some(details) = &usage.prompt_tokens_details
519 && details.cached_tokens.is_some()
520 {
521 *cache_read_tokens.lock().unwrap() = details.cached_tokens;
522 }
523 if usage.cost.is_some() {
526 *provider_cost_usd.lock().unwrap() = usage.cost;
527 }
528 }
529
                                // Only the first choice is processed.
530 if let Some(choice) = chunk.choices.first() {
531 let mut tt = total_tokens.lock().unwrap();
532 let mut acc = accumulated_tool_calls.lock().unwrap();
533 let mut fr = finish_reason.lock().unwrap();
534 let stream_event =
535 process_stream_choice(choice, &mut tt, &mut acc, &mut fr);
536 return vec![Ok(stream_event)];
537 }
                                // Chunks without choices (e.g. usage-only) yield an empty delta.
538 vec![Ok(LlmStreamEvent::TextDelta(String::new()))]
539 }
                            // Unparseable chunks surface as an in-stream `Error` event.
540 Err(e) => vec![Ok(LlmStreamEvent::Error(format!(
541 "Failed to parse chunk: {}",
542 e
543 )))],
544 }
545 }
546 })
547 .flat_map(futures::stream::iter),
548 );
549
550 Ok(converted_stream)
551 }
552}
553
554impl std::fmt::Debug for OpenAIProtocolLlmDriver {
555 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
556 f.debug_struct("OpenAIProtocolLlmDriver")
557 .field("api_url", &self.api_url)
558 .field("api_key", &"[REDACTED]")
559 .finish()
560 }
561}
562
563pub fn is_openai_model_not_found(status: reqwest::StatusCode, error_text: &str) -> bool {
574 let error_lower = error_text.to_lowercase();
575
576 if status == reqwest::StatusCode::NOT_FOUND
578 || status == reqwest::StatusCode::BAD_REQUEST
579 || status == reqwest::StatusCode::FORBIDDEN
580 {
581 if error_lower.contains("model_not_found") {
583 return true;
584 }
585 }
586
587 if status == reqwest::StatusCode::NOT_FOUND {
589 if error_lower.contains("does not exist") {
590 return true;
591 }
592 if error_lower.contains("model") && error_lower.contains("not found") {
593 return true;
594 }
595 }
596
597 false
598}
599
600pub fn is_openai_request_too_large(status: reqwest::StatusCode, error_text: &str) -> bool {
607 let error_lower = error_text.to_lowercase();
608
609 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
611 if error_lower.contains("request too large") {
613 return true;
614 }
615 if error_lower.contains("tokens") && error_lower.contains("limit") {
617 return true;
618 }
619 }
620
621 if status == reqwest::StatusCode::BAD_REQUEST {
623 if error_lower.contains("context_length_exceeded") {
625 return true;
626 }
627 if error_lower.contains("maximum context length") {
629 return true;
630 }
631 }
632
633 if error_lower.contains("tokens must be reduced")
635 || error_lower.contains("reduce the length")
636 || error_lower.contains("input is too long")
637 {
638 return true;
639 }
640
641 false
642}
643
/// JSON body of a chat-completion request. Every `Option` field is left out of the serialized
/// payload when it is `None`.
644#[derive(Debug, Serialize)]
649struct OpenAiRequest {
    /// Model identifier.
650 model: String,
    /// Conversation messages in wire format.
651 messages: Vec<OpenAiMessage>,
    /// Sampling temperature, if configured.
652 #[serde(skip_serializing_if = "Option::is_none")]
653 temperature: Option<f32>,
    /// Completion token cap, if configured.
654 #[serde(skip_serializing_if = "Option::is_none")]
655 max_tokens: Option<u32>,
    /// Whether the response should be streamed.
656 stream: bool,
    /// Streaming options (usage reporting).
657 #[serde(skip_serializing_if = "Option::is_none")]
659 stream_options: Option<OpenAiStreamOptions>,
    /// Tool descriptors offered to the model.
660 #[serde(skip_serializing_if = "Option::is_none")]
661 tools: Option<Vec<OpenAiTool>>,
    /// Reasoning effort label, if any.
662 #[serde(skip_serializing_if = "Option::is_none")]
663 reasoning_effort: Option<String>,
    /// Free-form string key/value metadata sent with the request.
664 #[serde(skip_serializing_if = "Option::is_none")]
667 metadata: Option<std::collections::HashMap<String, String>>,
668}
669
/// `stream_options` object of a chat-completion request.
670#[derive(Debug, Serialize)]
671struct OpenAiStreamOptions {
    /// Serialized as `include_usage`; the driver sets it to `true` to get usage data in the stream.
672 include_usage: bool,
673}
674
/// Message content: a plain string or a list of typed parts. Untagged, so it serializes as the
/// bare string or the bare array.
675#[derive(Debug, Serialize, Deserialize)]
676#[serde(untagged)]
677enum OpenAiContent {
    /// Plain text content.
678 Text(String),
    /// Multi-part content (text, image, audio).
679 Parts(Vec<OpenAiContentPart>),
680}
681
/// One element of multi-part message content. Untagged: each variant carries its own `type`
/// string, which the driver fills with `text`, `image_url` or `input_audio` respectively.
682#[derive(Debug, Serialize, Deserialize)]
683#[serde(untagged)]
684enum OpenAiContentPart {
    /// Text part.
685 Text {
686 r#type: String,
687 text: String,
688 },
    /// Image part referenced by URL.
689 ImageUrl {
690 r#type: String,
691 image_url: OpenAiImageUrl,
692 },
    /// Audio input part.
693 InputAudio {
694 r#type: String,
695 input_audio: OpenAiInputAudio,
696 },
697}
698
/// `image_url` object of an image content part.
699#[derive(Debug, Serialize, Deserialize)]
700struct OpenAiImageUrl {
    /// Location of the image.
701 url: String,
702}
703
/// `input_audio` object of an audio content part.
704#[derive(Debug, Serialize, Deserialize)]
705struct OpenAiInputAudio {
    /// Audio payload; the driver copies the source part's `url` value here.
706 data: String,
    /// Audio format label; the driver always sends `wav`.
707 format: String,
708}
709
/// One conversation message in wire format. `None` fields are omitted when serializing.
710#[derive(Debug, Serialize, Deserialize)]
711struct OpenAiMessage {
    /// Role string: `system`, `user`, `assistant` or `tool`.
712 role: String,
    /// Message content.
713 #[serde(skip_serializing_if = "Option::is_none")]
714 content: Option<OpenAiContent>,
    /// Tool calls; the driver sets this only on assistant messages.
715 #[serde(skip_serializing_if = "Option::is_none")]
716 tool_calls: Option<Vec<OpenAiToolCall>>,
    /// Id of the tool call this message refers to, copied from the source message.
717 #[serde(skip_serializing_if = "Option::is_none")]
718 tool_call_id: Option<String>,
719}
720
/// Tool descriptor offered to the model.
721#[derive(Debug, Serialize, Deserialize)]
722struct OpenAiTool {
    /// Tool kind; the driver always sends `function`.
723 r#type: String,
    /// Function signature exposed to the model.
724 function: OpenAiFunction,
725}
726
/// Function description inside a tool descriptor.
727#[derive(Debug, Serialize, Deserialize)]
728struct OpenAiFunction {
    /// Function name.
729 name: String,
    /// Human-readable description.
730 description: String,
    /// Parameter definition as arbitrary JSON, copied from the tool definition.
731 parameters: Value,
732}
733
/// Tool call attached to an assistant message in a request.
734#[derive(Debug, Serialize, Deserialize)]
735struct OpenAiToolCall {
    /// Call identifier.
736 id: String,
    /// Call kind; the driver always sends `function`.
737 r#type: String,
    /// Called function and its arguments.
738 function: OpenAiFunctionCall,
739}
740
/// Function invocation inside a tool call.
741#[derive(Debug, Serialize, Deserialize)]
742struct OpenAiFunctionCall {
    /// Function name.
743 name: String,
    /// Arguments as a JSON-encoded string.
744 arguments: String,
745}
746
/// One parsed SSE data payload of a streaming response.
747#[derive(Debug, Deserialize)]
// `id` and `model` are parsed but not read anywhere in this file's visible code.
748#[allow(dead_code)] struct OpenAiStreamChunk {
    /// Chunk identifier, if present.
750 #[serde(default)]
752 id: Option<String>,
    /// Model name, if present.
753 #[serde(default)]
755 model: Option<String>,
    /// Choices; required in the payload, may be empty (e.g. a usage-only chunk).
756 choices: Vec<OpenAiStreamChoice>,
    /// Token usage, when the chunk carries it.
757 #[serde(default)]
758 usage: Option<OpenAiUsage>,
759}
760
/// `usage` object of a stream chunk. Every field is optional.
761#[derive(Debug, Deserialize)]
762struct OpenAiUsage {
    /// Prompt (input) token count.
763 prompt_tokens: Option<u32>,
    /// Completion (output) token count.
764 completion_tokens: Option<u32>,
    /// Prompt token breakdown (cached tokens).
765 #[serde(default)]
767 prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
    /// Cost reported by the provider, if any; forwarded as `provider_cost_usd`.
768 #[serde(default)]
771 cost: Option<f64>,
772}
773
/// `prompt_tokens_details` object inside `usage`.
774#[derive(Debug, Deserialize, Default)]
775struct OpenAiPromptTokensDetails {
    /// Cached prompt token count; forwarded as `cache_read_tokens`.
776 #[serde(default)]
778 cached_tokens: Option<u32>,
779}
780
/// One entry of a stream chunk's `choices` array.
781#[derive(Debug, Deserialize)]
782struct OpenAiStreamChoice {
    /// Incremental content or tool-call fragment (required).
783 delta: OpenAiDelta,
    /// Finish reason, when the chunk carries one.
784 #[serde(default)]
785 finish_reason: Option<String>,
786}
787
/// `delta` object of a stream choice.
788#[derive(Debug, Deserialize)]
789struct OpenAiDelta {
    /// Text fragment, if any.
790 #[serde(default)]
791 content: Option<String>,
    /// Tool-call fragments, if any.
792 #[serde(default)]
793 tool_calls: Option<Vec<OpenAiStreamToolCall>>,
794}
795
/// Tool-call fragment inside a streamed delta.
796#[derive(Debug, Deserialize)]
797struct OpenAiStreamToolCall {
    /// Slot of the tool call this fragment belongs to (required).
798 index: u32,
    /// Call id, when the fragment carries it.
799 id: Option<String>,
    /// Function name and/or argument fragment.
800 function: Option<OpenAiStreamFunction>,
801}
802
/// Function part of a streamed tool-call fragment.
803#[derive(Debug, Deserialize)]
804struct OpenAiStreamFunction {
    /// Function name, when the fragment carries it.
805 name: Option<String>,
    /// Piece of the JSON-encoded argument string; pieces are concatenated by the driver.
806 arguments: Option<String>,
807}
808
809fn finalize_tool_calls(tool_calls: Vec<ToolCall>) -> Vec<ToolCall> {
812 tool_calls
813 .into_iter()
814 .map(|mut tc| {
815 if let Some(args_str) = tc.arguments.as_str() {
816 tc.arguments = serde_json::from_str(args_str).unwrap_or(json!({}));
817 }
818 tc
819 })
820 .collect()
821}
822
823fn take_pending_tool_calls(accumulated_tool_calls: &mut Vec<ToolCall>) -> Option<LlmStreamEvent> {
827 if accumulated_tool_calls.is_empty() {
828 return None;
829 }
830 let calls = std::mem::take(accumulated_tool_calls);
831 Some(LlmStreamEvent::ToolCalls(finalize_tool_calls(calls)))
832}
833
834fn process_stream_choice(
844 choice: &OpenAiStreamChoice,
845 total_tokens: &mut u32,
846 accumulated_tool_calls: &mut Vec<ToolCall>,
847 finish_reason: &mut Option<String>,
848) -> LlmStreamEvent {
849 if let Some(tool_calls) = &choice.delta.tool_calls {
851 for tc in tool_calls {
852 let idx = tc.index as usize;
853 while accumulated_tool_calls.len() <= idx {
854 accumulated_tool_calls.push(ToolCall {
855 id: String::new(),
856 name: String::new(),
857 arguments: json!(""),
858 });
859 }
860
861 if let Some(id) = &tc.id {
862 accumulated_tool_calls[idx].id = id.clone();
863 }
864 if let Some(function) = &tc.function {
865 if let Some(name) = &function.name {
866 accumulated_tool_calls[idx].name = name.clone();
867 }
868 if let Some(args) = &function.arguments {
869 let current = accumulated_tool_calls[idx].arguments.as_str().unwrap_or("");
870 let combined = format!("{}{}", current, args);
871 accumulated_tool_calls[idx].arguments = json!(combined);
872 }
873 }
874 }
875 return LlmStreamEvent::TextDelta(String::new());
876 }
877
878 if let Some(content) = &choice.delta.content
881 && !content.is_empty()
882 {
883 *total_tokens += 1;
884 return LlmStreamEvent::TextDelta(content.clone());
885 }
886
887 if let Some(fr) = &choice.finish_reason {
891 *finish_reason = Some(fr.clone());
892
893 if fr == "tool_calls" && !accumulated_tool_calls.is_empty() {
894 let calls = std::mem::take(accumulated_tool_calls);
895 return LlmStreamEvent::ToolCalls(finalize_tool_calls(calls));
896 }
897 }
898
899 LlmStreamEvent::TextDelta(String::new())
900}
901
902#[cfg(test)]
907mod tests {
908 use super::*;
909
910 #[test]
911 fn test_driver_with_api_key() {
912 let driver = OpenAIProtocolLlmDriver::new("test-key");
913 assert!(format!("{:?}", driver).contains("OpenAIProtocolLlmDriver"));
914 }
915
916 #[test]
917 fn test_driver_with_base_url() {
918 let driver = OpenAIProtocolLlmDriver::with_base_url(
919 "test-key",
920 "https://custom.api.com/v1/completions",
921 );
922 assert!(format!("{:?}", driver).contains("OpenAIProtocolLlmDriver"));
923 assert_eq!(driver.api_url(), "https://custom.api.com/v1/completions");
924 }
925
926 #[test]
927 fn test_is_azure_openai_api_url() {
928 assert!(is_azure_openai_api_url(
929 "https://example.openai.azure.com/openai/v1/chat/completions"
930 ));
931 assert!(is_azure_openai_api_url(
932 "https://example.services.ai.azure.com/openai/v1/responses"
933 ));
934 assert!(!is_azure_openai_api_url(
935 "https://api.openai.com/v1/chat/completions"
936 ));
937 }
938
939 #[test]
940 fn test_request_includes_stream_options_for_usage() {
941 let request = OpenAiRequest {
944 model: "gpt-4o".to_string(),
945 messages: vec![OpenAiMessage {
946 role: "user".to_string(),
947 content: Some(OpenAiContent::Text("Hello".to_string())),
948 tool_calls: None,
949 tool_call_id: None,
950 }],
951 temperature: None,
952 max_tokens: None,
953 stream: true,
954 stream_options: Some(OpenAiStreamOptions {
955 include_usage: true,
956 }),
957 tools: None,
958 reasoning_effort: None,
959 metadata: None,
960 };
961
962 let json = serde_json::to_value(&request).unwrap();
963 assert_eq!(json["stream"], true);
964 assert_eq!(json["stream_options"]["include_usage"], true);
965 }
966
967 #[test]
968 fn test_request_includes_metadata() {
969 let mut metadata = std::collections::HashMap::new();
971 metadata.insert("session_id".to_string(), "session_abc123".to_string());
972 metadata.insert("agent_id".to_string(), "agent_xyz789".to_string());
973
974 let request = OpenAiRequest {
975 model: "gpt-4o".to_string(),
976 messages: vec![OpenAiMessage {
977 role: "user".to_string(),
978 content: Some(OpenAiContent::Text("Hello".to_string())),
979 tool_calls: None,
980 tool_call_id: None,
981 }],
982 temperature: None,
983 max_tokens: None,
984 stream: true,
985 stream_options: None,
986 tools: None,
987 reasoning_effort: None,
988 metadata: Some(metadata),
989 };
990
991 let json = serde_json::to_value(&request).unwrap();
992 assert_eq!(json["metadata"]["session_id"], "session_abc123");
993 assert_eq!(json["metadata"]["agent_id"], "agent_xyz789");
994 }
995
    /// A usage-only chunk (empty `choices`) parses and exposes prompt and completion token
    /// counts; the unknown `object`, `created` and `total_tokens` keys are accepted.
996 #[test]
997 fn test_usage_chunk_parsing() {
998 let usage_chunk = r#"{
1001 "id": "chatcmpl-123",
1002 "object": "chat.completion.chunk",
1003 "created": 1234567890,
1004 "model": "gpt-4o",
1005 "choices": [],
1006 "usage": {
1007 "prompt_tokens": 150,
1008 "completion_tokens": 42,
1009 "total_tokens": 192
1010 }
1011 }"#;
1012
1013 let chunk: OpenAiStreamChunk = serde_json::from_str(usage_chunk).unwrap();
1014 assert!(chunk.usage.is_some());
1015 let usage = chunk.usage.unwrap();
1016 assert_eq!(usage.prompt_tokens, Some(150));
1017 assert_eq!(usage.completion_tokens, Some(42));
1018 }
1019
    /// `usage.prompt_tokens_details.cached_tokens` is parsed alongside the token counts.
1020 #[test]
1021 fn test_usage_chunk_with_cached_tokens() {
1022 let usage_chunk = r#"{
1024 "id": "chatcmpl-123",
1025 "choices": [],
1026 "usage": {
1027 "prompt_tokens": 150,
1028 "completion_tokens": 42,
1029 "prompt_tokens_details": {
1030 "cached_tokens": 100
1031 }
1032 }
1033 }"#;
1034
1035 let chunk: OpenAiStreamChunk = serde_json::from_str(usage_chunk).unwrap();
1036 let usage = chunk.usage.unwrap();
1037 assert_eq!(usage.prompt_tokens, Some(150));
1038 assert_eq!(usage.completion_tokens, Some(42));
1039 assert!(usage.prompt_tokens_details.is_some());
1040 assert_eq!(
1041 usage.prompt_tokens_details.unwrap().cached_tokens,
1042 Some(100)
1043 );
1044 }
1045
    /// A `usage.cost` value in the chunk is parsed into `OpenAiUsage::cost`.
1046 #[test]
1047 fn test_usage_chunk_with_openrouter_cost() {
1048 let usage_chunk = r#"{
1050 "id": "gen-123",
1051 "choices": [],
1052 "usage": {
1053 "prompt_tokens": 194,
1054 "completion_tokens": 2,
1055 "total_tokens": 196,
1056 "cost": 0.00095
1057 }
1058 }"#;
1059
1060 let chunk: OpenAiStreamChunk = serde_json::from_str(usage_chunk).unwrap();
1061 let usage = chunk.usage.unwrap();
1062 assert_eq!(usage.cost, Some(0.00095));
1063 }
1064
    /// When `usage` has no `cost` key, `OpenAiUsage::cost` is `None`.
1065 #[test]
1066 fn test_usage_chunk_without_cost_defaults_none() {
1067 let usage_chunk = r#"{
1069 "id": "chatcmpl-123",
1070 "choices": [],
1071 "usage": { "prompt_tokens": 10, "completion_tokens": 5 }
1072 }"#;
1073
1074 let chunk: OpenAiStreamChunk = serde_json::from_str(usage_chunk).unwrap();
1075 assert_eq!(chunk.usage.unwrap().cost, None);
1076 }
1077
    /// A finish chunk with an empty delta parses with no usage, one choice and the given
    /// `finish_reason`.
1078 #[test]
1079 fn test_finish_reason_chunk_parsing() {
1080 let finish_chunk = r#"{
1082 "id": "chatcmpl-123",
1083 "choices": [{
1084 "index": 0,
1085 "delta": {},
1086 "finish_reason": "stop"
1087 }]
1088 }"#;
1089
1090 let chunk: OpenAiStreamChunk = serde_json::from_str(finish_chunk).unwrap();
1091 assert!(chunk.usage.is_none()); assert_eq!(chunk.choices.len(), 1);
1093 assert_eq!(chunk.choices[0].finish_reason, Some("stop".to_string()));
1094 }
1095
1096 #[test]
1101 fn test_is_openai_request_too_large_429_request_too_large() {
1102 let error = r#"{"error":{"message":"Request too large for gpt-4o in organization org-xxx on tokens per min (TPM): Limit 500000, Requested 538772."}}"#;
1103 assert!(is_openai_request_too_large(
1104 reqwest::StatusCode::TOO_MANY_REQUESTS,
1105 error
1106 ));
1107 }
1108
1109 #[test]
1110 fn test_is_openai_request_too_large_429_token_limit() {
1111 let error =
1112 r#"{"error":{"message":"tokens per min (TPM): Limit 500000, Requested 600000"}}"#;
1113 assert!(is_openai_request_too_large(
1114 reqwest::StatusCode::TOO_MANY_REQUESTS,
1115 error
1116 ));
1117 }
1118
1119 #[test]
1120 fn test_is_openai_request_too_large_400_context_length() {
1121 let error = r#"{"error":{"code":"context_length_exceeded","message":"This model's maximum context length is 128000 tokens."}}"#;
1122 assert!(is_openai_request_too_large(
1123 reqwest::StatusCode::BAD_REQUEST,
1124 error
1125 ));
1126 }
1127
1128 #[test]
1129 fn test_is_openai_request_too_large_400_max_context() {
1130 let error =
1131 r#"{"error":{"message":"This model's maximum context length is 128000 tokens"}}"#;
1132 assert!(is_openai_request_too_large(
1133 reqwest::StatusCode::BAD_REQUEST,
1134 error
1135 ));
1136 }
1137
1138 #[test]
1139 fn test_is_openai_request_too_large_tokens_must_be_reduced() {
1140 let error = r#"{"error":{"message":"The input or output tokens must be reduced"}}"#;
1141 assert!(is_openai_request_too_large(
1142 reqwest::StatusCode::BAD_REQUEST,
1143 error
1144 ));
1145 }
1146
1147 #[test]
1148 fn test_is_openai_request_too_large_false_for_other_errors() {
1149 let error = r#"{"error":{"message":"Rate limit exceeded: too many requests per minute"}}"#;
1151 assert!(!is_openai_request_too_large(
1152 reqwest::StatusCode::TOO_MANY_REQUESTS,
1153 error
1154 ));
1155
1156 let error = r#"{"error":{"message":"Internal server error"}}"#;
1158 assert!(!is_openai_request_too_large(
1159 reqwest::StatusCode::INTERNAL_SERVER_ERROR,
1160 error
1161 ));
1162
1163 let error = r#"{"error":{"message":"Invalid request"}}"#;
1165 assert!(!is_openai_request_too_large(
1166 reqwest::StatusCode::BAD_REQUEST,
1167 error
1168 ));
1169 }
1170
1171 #[test]
1176 fn test_is_openai_model_not_found_real_error() {
1177 let error = r#"{"error":{"code":"model_not_found","message":"The model 'gpt-99' does not exist or you do not have access to it.","type":"invalid_request_error","param":null}}"#;
1179 assert!(is_openai_model_not_found(
1180 reqwest::StatusCode::NOT_FOUND,
1181 error
1182 ));
1183 }
1184
1185 #[test]
1186 fn test_is_openai_model_not_found_does_not_exist() {
1187 let error = r#"{"error":{"message":"The model 'fake-model' does not exist"}}"#;
1188 assert!(is_openai_model_not_found(
1189 reqwest::StatusCode::NOT_FOUND,
1190 error
1191 ));
1192 }
1193
1194 #[test]
1195 fn test_is_openai_model_not_found_generic_not_found() {
1196 let error = r#"{"error":{"message":"Model not found"}}"#;
1197 assert!(is_openai_model_not_found(
1198 reqwest::StatusCode::NOT_FOUND,
1199 error
1200 ));
1201 }
1202
1203 #[test]
1204 fn test_is_openai_model_not_found_400_with_model_not_found_code() {
1205 let error = r#"{"error":{"code":"model_not_found","message":"The requested model 'gpt-99' does not exist.","type":"invalid_request_error","param":"model"}}"#;
1207 assert!(is_openai_model_not_found(
1208 reqwest::StatusCode::BAD_REQUEST,
1209 error
1210 ));
1211 }
1212
1213 #[test]
1214 fn test_is_openai_model_not_found_false_for_non_model_error() {
1215 let error = r#"{"error":{"code":"invalid_request","message":"Some other error"}}"#;
1217 assert!(!is_openai_model_not_found(
1218 reqwest::StatusCode::BAD_REQUEST,
1219 error
1220 ));
1221 }
1222
1223 #[test]
1224 fn test_is_openai_model_not_found_false_for_other_404() {
1225 let error = r#"{"error":{"message":"Endpoint not found"}}"#;
1227 assert!(!is_openai_model_not_found(
1228 reqwest::StatusCode::NOT_FOUND,
1229 error
1230 ));
1231 }
1232
1233 #[test]
1234 fn test_is_openai_model_not_found_403_tier_gated_model() {
1235 let error = r#"{"error":{"code":"model_not_found","message":"The model 'gpt-5.4-mini' does not exist or you do not have access to it.","type":"invalid_request_error","param":null}}"#;
1238 assert!(is_openai_model_not_found(
1239 reqwest::StatusCode::FORBIDDEN,
1240 error
1241 ));
1242 }
1243
1244 #[test]
1245 fn test_is_openai_model_not_found_403_plain_auth_error_is_not_model_not_found() {
1246 let error = r#"{"error":{"message":"Invalid authentication credentials","type":"authentication_error"}}"#;
1249 assert!(!is_openai_model_not_found(
1250 reqwest::StatusCode::FORBIDDEN,
1251 error
1252 ));
1253 }
1254
1255 #[test]
1260 fn test_reasoning_effort_none_is_omitted() {
1261 let request = OpenAiRequest {
1264 model: "gpt-4o-mini".to_string(),
1265 messages: vec![OpenAiMessage {
1266 role: "user".to_string(),
1267 content: Some(OpenAiContent::Text("Hello".to_string())),
1268 tool_calls: None,
1269 tool_call_id: None,
1270 }],
1271 temperature: None,
1272 max_tokens: None,
1273 stream: true,
1274 stream_options: None,
1275 tools: None,
1276 reasoning_effort: Some("none".to_string())
1277 .as_ref()
1278 .filter(|e| !e.eq_ignore_ascii_case("none"))
1279 .cloned(),
1280 metadata: None,
1281 };
1282
1283 let json = serde_json::to_value(&request).unwrap();
1284 assert!(
1285 json.get("reasoning_effort").is_none(),
1286 "reasoning_effort should be omitted when effort is 'none'"
1287 );
1288 }
1289
1290 #[test]
1291 fn test_reasoning_effort_high_is_included() {
1292 let request = OpenAiRequest {
1293 model: "o3-mini".to_string(),
1294 messages: vec![OpenAiMessage {
1295 role: "user".to_string(),
1296 content: Some(OpenAiContent::Text("Hello".to_string())),
1297 tool_calls: None,
1298 tool_call_id: None,
1299 }],
1300 temperature: None,
1301 max_tokens: None,
1302 stream: true,
1303 stream_options: None,
1304 tools: None,
1305 reasoning_effort: Some("high".to_string())
1306 .as_ref()
1307 .filter(|e| !e.eq_ignore_ascii_case("none"))
1308 .cloned(),
1309 metadata: None,
1310 };
1311
1312 let json = serde_json::to_value(&request).unwrap();
1313 assert_eq!(json["reasoning_effort"], "high");
1314 }
1315
1316 fn choice(json_str: &str) -> OpenAiStreamChoice {
1321 serde_json::from_str(json_str).unwrap()
1322 }
1323
1324 #[test]
1328 fn test_empty_content_finish_chunk_still_emits_tool_calls() {
1329 let mut total_tokens = 0u32;
1330 let mut acc: Vec<ToolCall> = Vec::new();
1331 let mut finish_reason: Option<String> = None;
1332
1333 let e = process_stream_choice(
1335 &choice(
1336 r#"{"delta":{"content":null,"tool_calls":[{"index":0,"id":"call_1","function":{"name":"read_file","arguments":""}}]},"finish_reason":null}"#,
1337 ),
1338 &mut total_tokens,
1339 &mut acc,
1340 &mut finish_reason,
1341 );
1342 assert!(matches!(e, LlmStreamEvent::TextDelta(s) if s.is_empty()));
1343
1344 let e = process_stream_choice(
1346 &choice(
1347 r#"{"delta":{"content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"Cargo.toml\"}"}}]},"finish_reason":null}"#,
1348 ),
1349 &mut total_tokens,
1350 &mut acc,
1351 &mut finish_reason,
1352 );
1353 assert!(matches!(e, LlmStreamEvent::TextDelta(s) if s.is_empty()));
1354
1355 let e = process_stream_choice(
1358 &choice(r#"{"delta":{"content":""},"finish_reason":"tool_calls"}"#),
1359 &mut total_tokens,
1360 &mut acc,
1361 &mut finish_reason,
1362 );
1363 match e {
1364 LlmStreamEvent::ToolCalls(calls) => {
1365 assert_eq!(calls.len(), 1);
1366 assert_eq!(calls[0].id, "call_1");
1367 assert_eq!(calls[0].name, "read_file");
1368 assert_eq!(calls[0].arguments, json!({"path": "Cargo.toml"}));
1369 }
1370 other => panic!("expected ToolCalls, got {:?}", other),
1371 }
1372 assert_eq!(finish_reason.as_deref(), Some("tool_calls"));
1373
1374 let e = process_stream_choice(
1377 &choice(r#"{"delta":{"content":""},"finish_reason":"tool_calls"}"#),
1378 &mut total_tokens,
1379 &mut acc,
1380 &mut finish_reason,
1381 );
1382 assert!(
1383 matches!(e, LlmStreamEvent::TextDelta(s) if s.is_empty()),
1384 "tool calls must only be emitted once"
1385 );
1386 }
1387
1388 #[test]
1390 fn test_non_empty_content_is_emitted() {
1391 let mut total_tokens = 0u32;
1392 let mut acc: Vec<ToolCall> = Vec::new();
1393 let mut finish_reason: Option<String> = None;
1394
1395 let e = process_stream_choice(
1396 &choice(r#"{"delta":{"content":"hello"},"finish_reason":null}"#),
1397 &mut total_tokens,
1398 &mut acc,
1399 &mut finish_reason,
1400 );
1401 assert!(matches!(e, LlmStreamEvent::TextDelta(s) if s == "hello"));
1402 assert_eq!(total_tokens, 1);
1403 }
1404
1405 #[test]
1408 fn test_finish_chunk_without_content_emits_tool_calls() {
1409 let mut total_tokens = 0u32;
1410 let mut acc: Vec<ToolCall> = Vec::new();
1411 let mut finish_reason: Option<String> = None;
1412
1413 process_stream_choice(
1414 &choice(
1415 r#"{"delta":{"tool_calls":[{"index":0,"id":"call_9","function":{"name":"list_dir","arguments":"{}"}}]},"finish_reason":null}"#,
1416 ),
1417 &mut total_tokens,
1418 &mut acc,
1419 &mut finish_reason,
1420 );
1421
1422 let e = process_stream_choice(
1423 &choice(r#"{"delta":{},"finish_reason":"tool_calls"}"#),
1424 &mut total_tokens,
1425 &mut acc,
1426 &mut finish_reason,
1427 );
1428 match e {
1429 LlmStreamEvent::ToolCalls(calls) => {
1430 assert_eq!(calls.len(), 1);
1431 assert_eq!(calls[0].name, "list_dir");
1432 }
1433 other => panic!("expected ToolCalls, got {:?}", other),
1434 }
1435 }
1436
1437 #[test]
1440 fn test_take_pending_tool_calls_flushes_then_drains() {
1441 let mut acc = vec![ToolCall {
1442 id: "call_1".to_string(),
1443 name: "read_file".to_string(),
1444 arguments: json!(r#"{"path":"Cargo.toml"}"#),
1445 }];
1446
1447 match take_pending_tool_calls(&mut acc) {
1448 Some(LlmStreamEvent::ToolCalls(calls)) => {
1449 assert_eq!(calls.len(), 1);
1450 assert_eq!(calls[0].name, "read_file");
1451 assert_eq!(calls[0].arguments, json!({"path": "Cargo.toml"}));
1452 }
1453 other => panic!("expected ToolCalls, got {:?}", other),
1454 }
1455 assert!(acc.is_empty(), "accumulator must be drained after flush");
1456 assert!(take_pending_tool_calls(&mut acc).is_none());
1457 }
1458
1459 #[test]
1460 fn test_finalize_tool_calls_parses_arguments() {
1461 let calls = vec![ToolCall {
1462 id: "call_1".to_string(),
1463 name: "read_file".to_string(),
1464 arguments: json!(r#"{"path":"src/main.rs"}"#),
1465 }];
1466 let finalized = finalize_tool_calls(calls);
1467 assert_eq!(finalized[0].arguments, json!({"path": "src/main.rs"}));
1468 }
1469
1470 #[test]
1471 fn drop_orphaned_tool_messages_removes_unmatched_tool_results() {
1472 use crate::llm_driver_registry::LlmMessageContent;
1473
1474 let messages = vec![
1475 LlmMessage::text(LlmMessageRole::User, "hello"),
1476 LlmMessage {
1477 role: LlmMessageRole::Tool,
1478 content: LlmMessageContent::Text("result".to_string()),
1479 tool_calls: None,
1480 tool_call_id: Some("call_trimmed".to_string()),
1481 phase: None,
1482 thinking: None,
1483 thinking_signature: None,
1484 },
1485 ];
1486 let filtered = drop_orphaned_tool_messages(&messages);
1487 assert_eq!(filtered.len(), 1);
1488 assert_eq!(filtered[0].role, LlmMessageRole::User);
1489 }
1490
1491 #[test]
1492 fn drop_orphaned_tool_messages_keeps_matched_tool_results() {
1493 use crate::llm_driver_registry::LlmMessageContent;
1494 use crate::tool_types::ToolCall;
1495
1496 let messages = vec![
1497 LlmMessage {
1498 role: LlmMessageRole::Assistant,
1499 content: LlmMessageContent::Text(String::new()),
1500 tool_calls: Some(vec![ToolCall {
1501 id: "call_1".to_string(),
1502 name: "read_file".to_string(),
1503 arguments: json!({}),
1504 }]),
1505 tool_call_id: None,
1506 phase: None,
1507 thinking: None,
1508 thinking_signature: None,
1509 },
1510 LlmMessage {
1511 role: LlmMessageRole::Tool,
1512 content: LlmMessageContent::Text("file content".to_string()),
1513 tool_calls: None,
1514 tool_call_id: Some("call_1".to_string()),
1515 phase: None,
1516 thinking: None,
1517 thinking_signature: None,
1518 },
1519 ];
1520 let filtered = drop_orphaned_tool_messages(&messages);
1521 assert_eq!(filtered.len(), 2);
1522 }
1523}