1use crate::llm::{
11 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, StopReason,
12 StreamBox, StreamDelta, Usage,
13};
14use anyhow::Result;
15use async_trait::async_trait;
16use futures::StreamExt;
17use reqwest::StatusCode;
18use serde::{Deserialize, Serialize};
19
20use super::openai_responses::OpenAIResponsesProvider;
21
/// Default OpenAI API root; overridable via `OpenAIProvider::with_base_url`
/// for OpenAI-compatible servers (no trailing slash).
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
23
/// Whether `model` must be routed through the Responses API provider instead
/// of the Chat Completions endpoint (currently the codex model family).
fn requires_responses_api(model: &str) -> bool {
    model.find("codex").is_some()
}
28
// GPT-5.2 family.
pub const MODEL_GPT52_INSTANT: &str = "gpt-5.2-instant";
pub const MODEL_GPT52_THINKING: &str = "gpt-5.2-thinking";
pub const MODEL_GPT52_PRO: &str = "gpt-5.2-pro";
// Codex variant: routed via the Responses API (see `requires_responses_api`).
pub const MODEL_GPT52_CODEX: &str = "gpt-5.2-codex";

// GPT-5 family.
pub const MODEL_GPT5: &str = "gpt-5";
pub const MODEL_GPT5_MINI: &str = "gpt-5-mini";
pub const MODEL_GPT5_NANO: &str = "gpt-5-nano";

// o-series reasoning models.
pub const MODEL_O3: &str = "o3";
pub const MODEL_O3_MINI: &str = "o3-mini";
pub const MODEL_O4_MINI: &str = "o4-mini";
pub const MODEL_O1: &str = "o1";
pub const MODEL_O1_MINI: &str = "o1-mini";

// GPT-4.1 family.
pub const MODEL_GPT41: &str = "gpt-4.1";
pub const MODEL_GPT41_MINI: &str = "gpt-4.1-mini";
pub const MODEL_GPT41_NANO: &str = "gpt-4.1-nano";

// GPT-4o family.
pub const MODEL_GPT4O: &str = "gpt-4o";
pub const MODEL_GPT4O_MINI: &str = "gpt-4o-mini";
/// `LlmProvider` implementation backed by the OpenAI Chat Completions API.
///
/// Models matched by `requires_responses_api` are transparently delegated to
/// `OpenAIResponsesProvider` in both `chat` and `chat_stream`.
#[derive(Clone)]
pub struct OpenAIProvider {
    // Shared across requests; reqwest clients pool connections internally.
    client: reqwest::Client,
    api_key: String,
    model: String,
    // API root, e.g. "https://api.openai.com/v1" (no trailing slash).
    base_url: String,
}
67
68impl OpenAIProvider {
69 #[must_use]
71 pub fn new(api_key: String, model: String) -> Self {
72 Self {
73 client: reqwest::Client::new(),
74 api_key,
75 model,
76 base_url: DEFAULT_BASE_URL.to_owned(),
77 }
78 }
79
80 #[must_use]
82 pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
83 Self {
84 client: reqwest::Client::new(),
85 api_key,
86 model,
87 base_url,
88 }
89 }
90
91 #[must_use]
93 pub fn gpt52_instant(api_key: String) -> Self {
94 Self::new(api_key, MODEL_GPT52_INSTANT.to_owned())
95 }
96
97 #[must_use]
99 pub fn gpt52_thinking(api_key: String) -> Self {
100 Self::new(api_key, MODEL_GPT52_THINKING.to_owned())
101 }
102
103 #[must_use]
105 pub fn gpt52_pro(api_key: String) -> Self {
106 Self::new(api_key, MODEL_GPT52_PRO.to_owned())
107 }
108
109 #[must_use]
113 pub fn codex(api_key: String) -> Self {
114 Self::new(api_key, MODEL_GPT52_CODEX.to_owned())
115 }
116
117 #[must_use]
119 pub fn gpt5(api_key: String) -> Self {
120 Self::new(api_key, MODEL_GPT5.to_owned())
121 }
122
123 #[must_use]
125 pub fn gpt5_mini(api_key: String) -> Self {
126 Self::new(api_key, MODEL_GPT5_MINI.to_owned())
127 }
128
129 #[must_use]
131 pub fn gpt5_nano(api_key: String) -> Self {
132 Self::new(api_key, MODEL_GPT5_NANO.to_owned())
133 }
134
135 #[must_use]
137 pub fn o3(api_key: String) -> Self {
138 Self::new(api_key, MODEL_O3.to_owned())
139 }
140
141 #[must_use]
143 pub fn o3_mini(api_key: String) -> Self {
144 Self::new(api_key, MODEL_O3_MINI.to_owned())
145 }
146
147 #[must_use]
149 pub fn o4_mini(api_key: String) -> Self {
150 Self::new(api_key, MODEL_O4_MINI.to_owned())
151 }
152
153 #[must_use]
155 pub fn o1(api_key: String) -> Self {
156 Self::new(api_key, MODEL_O1.to_owned())
157 }
158
159 #[must_use]
161 pub fn o1_mini(api_key: String) -> Self {
162 Self::new(api_key, MODEL_O1_MINI.to_owned())
163 }
164
165 #[must_use]
167 pub fn gpt41(api_key: String) -> Self {
168 Self::new(api_key, MODEL_GPT41.to_owned())
169 }
170
171 #[must_use]
173 pub fn gpt41_mini(api_key: String) -> Self {
174 Self::new(api_key, MODEL_GPT41_MINI.to_owned())
175 }
176
177 #[must_use]
179 pub fn gpt4o(api_key: String) -> Self {
180 Self::new(api_key, MODEL_GPT4O.to_owned())
181 }
182
183 #[must_use]
185 pub fn gpt4o_mini(api_key: String) -> Self {
186 Self::new(api_key, MODEL_GPT4O_MINI.to_owned())
187 }
188}
189
#[async_trait]
impl LlmProvider for OpenAIProvider {
    /// One-shot (non-streaming) chat completion.
    ///
    /// `Err` is reserved for transport/parse failures; HTTP-level problems are
    /// reported in-band as `ChatOutcome::RateLimited` / `ServerError` /
    /// `InvalidRequest` so callers can decide whether to retry.
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
        // Codex-family models are only served by the Responses API.
        if requires_responses_api(&self.model) {
            let responses_provider =
                OpenAIResponsesProvider::new(self.api_key.clone(), self.model.clone());
            return responses_provider.chat(request).await;
        }

        let messages = build_api_messages(&request);
        let tools: Option<Vec<ApiTool>> = request
            .tools
            .map(|ts| ts.into_iter().map(convert_tool).collect());

        let api_request = ApiChatRequest {
            model: &self.model,
            messages: &messages,
            max_completion_tokens: Some(request.max_tokens),
            tools: tools.as_deref(),
        };

        tracing::debug!(
            model = %self.model,
            max_tokens = request.max_tokens,
            "OpenAI LLM request"
        );

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&api_request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;

        // Read the whole body before branching on status so error bodies can
        // be logged and returned verbatim.
        let status = response.status();
        let bytes = response
            .bytes()
            .await
            .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;

        tracing::debug!(
            status = %status,
            body_len = bytes.len(),
            "OpenAI LLM response"
        );

        if status == StatusCode::TOO_MANY_REQUESTS {
            return Ok(ChatOutcome::RateLimited);
        }

        if status.is_server_error() {
            let body = String::from_utf8_lossy(&bytes);
            tracing::error!(status = %status, body = %body, "OpenAI server error");
            return Ok(ChatOutcome::ServerError(body.into_owned()));
        }

        if status.is_client_error() {
            let body = String::from_utf8_lossy(&bytes);
            tracing::warn!(status = %status, body = %body, "OpenAI client error");
            return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
        }

        let api_response: ApiChatResponse = serde_json::from_slice(&bytes)
            .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;

        // Only the first choice is consumed (request does not set `n`).
        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("no choices in response"))?;

        let content = build_content_blocks(&choice.message);

        // Map OpenAI finish reasons onto the provider-agnostic enum.
        // `content_filter` has no direct analogue and is folded into
        // `StopSequence`.
        let stop_reason = choice.finish_reason.map(|r| match r {
            ApiFinishReason::Stop => StopReason::EndTurn,
            ApiFinishReason::ToolCalls => StopReason::ToolUse,
            ApiFinishReason::Length => StopReason::MaxTokens,
            ApiFinishReason::ContentFilter => StopReason::StopSequence,
        });

        Ok(ChatOutcome::Success(ChatResponse {
            id: api_response.id,
            content,
            model: api_response.model,
            stop_reason,
            usage: Usage {
                input_tokens: api_response.usage.prompt_tokens,
                output_tokens: api_response.usage.completion_tokens,
            },
        }))
    }

    /// Streaming chat completion over SSE.
    ///
    /// HTTP errors are surfaced in-band as `StreamDelta::Error` (with
    /// `recoverable` set for 429/5xx); `Err` items indicate transport failure.
    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        // Delegate codex-family models; the inner stream is re-yielded item by
        // item so the concrete return type still matches `StreamBox`.
        if requires_responses_api(&self.model) {
            let api_key = self.api_key.clone();
            let model = self.model.clone();
            return Box::pin(async_stream::stream! {
                let responses_provider = OpenAIResponsesProvider::new(api_key, model);
                let mut stream = std::pin::pin!(responses_provider.chat_stream(request));
                while let Some(item) = futures::StreamExt::next(&mut stream).await {
                    yield item;
                }
            });
        }

        Box::pin(async_stream::stream! {
            let messages = build_api_messages(&request);
            let tools: Option<Vec<ApiTool>> = request
                .tools
                .map(|ts| ts.into_iter().map(convert_tool).collect());

            let api_request = ApiChatRequestStreaming { model: &self.model, messages: &messages, max_completion_tokens: Some(request.max_tokens), tools: tools.as_deref(), stream: true };

            tracing::debug!(model = %self.model, max_tokens = request.max_tokens, "OpenAI streaming LLM request");

            let Ok(response) = self.client
                .post(format!("{}/chat/completions", self.base_url))
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", self.api_key))
                .json(&api_request)
                .send()
                .await
            else {
                yield Err(anyhow::anyhow!("request failed"));
                return;
            };

            let status = response.status();

            if !status.is_success() {
                let body = response.text().await.unwrap_or_default();
                // 429 and 5xx are retryable; other 4xx indicates a caller bug.
                let (recoverable, level) = if status == StatusCode::TOO_MANY_REQUESTS {
                    (true, "rate_limit")
                } else if status.is_server_error() {
                    (true, "server_error")
                } else {
                    (false, "client_error")
                };
                tracing::warn!(status = %status, body = %body, kind = level, "OpenAI error");
                yield Ok(StreamDelta::Error { message: body, recoverable });
                return;
            }

            // Tool-call fragments are accumulated per API-provided index and
            // flushed as ToolUseStart/ToolInputDelta when the stream ends.
            let mut tool_calls: std::collections::HashMap<usize, ToolCallAccumulator> =
                std::collections::HashMap::new();
            let mut usage: Option<Usage> = None;
            // Carry-over buffer: SSE lines may be split across network chunks.
            let mut buffer = String::new();
            let mut stream = response.bytes_stream();

            while let Some(chunk_result) = stream.next().await {
                let Ok(chunk) = chunk_result else {
                    yield Err(anyhow::anyhow!("stream error: {}", chunk_result.unwrap_err()));
                    return;
                };
                buffer.push_str(&String::from_utf8_lossy(&chunk));

                // Drain every complete line currently buffered.
                while let Some(pos) = buffer.find('\n') {
                    let line = buffer[..pos].trim().to_string();
                    buffer = buffer[pos + 1..].to_string();
                    if line.is_empty() { continue; }
                    // Only SSE `data:` fields carry payloads we care about.
                    let Some(data) = line.strip_prefix("data: ") else { continue; };

                    for result in process_sse_data(data) {
                        match result {
                            // Text always streams on block 0; tool blocks use index+1.
                            SseProcessResult::TextDelta(c) => yield Ok(StreamDelta::TextDelta { delta: c, block_index: 0 }),
                            SseProcessResult::ToolCallUpdate { index, id, name, arguments } => apply_tool_call_update(&mut tool_calls, index, id, name, arguments),
                            SseProcessResult::Usage(u) => usage = Some(u),
                            SseProcessResult::Done(sr) => {
                                for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) { yield Ok(d); }
                                return;
                            }
                            // "[DONE]" sentinel without a prior finish_reason:
                            // infer the stop reason from accumulated tool calls.
                            SseProcessResult::Sentinel => {
                                let sr = if tool_calls.is_empty() { StopReason::EndTurn } else { StopReason::ToolUse };
                                for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) { yield Ok(d); }
                                return;
                            }
                        }
                    }
                }
            }

            // Stream closed without finish_reason or sentinel: flush whatever
            // accumulated, defaulting the stop reason to EndTurn.
            for delta in build_stream_end_deltas(&tool_calls, usage, StopReason::EndTurn) {
                yield Ok(delta);
            }
        })
    }

    /// The model identifier this provider was constructed with.
    fn model(&self) -> &str {
        &self.model
    }

    /// Stable provider identifier.
    fn provider(&self) -> &'static str {
        "openai"
    }
}
392
393fn apply_tool_call_update(
395 tool_calls: &mut std::collections::HashMap<usize, ToolCallAccumulator>,
396 index: usize,
397 id: Option<String>,
398 name: Option<String>,
399 arguments: Option<String>,
400) {
401 let entry = tool_calls
402 .entry(index)
403 .or_insert_with(|| ToolCallAccumulator {
404 id: String::new(),
405 name: String::new(),
406 arguments: String::new(),
407 });
408 if let Some(id) = id {
409 entry.id = id;
410 }
411 if let Some(name) = name {
412 entry.name = name;
413 }
414 if let Some(args) = arguments {
415 entry.arguments.push_str(&args);
416 }
417}
418
419fn build_stream_end_deltas(
421 tool_calls: &std::collections::HashMap<usize, ToolCallAccumulator>,
422 usage: Option<Usage>,
423 stop_reason: StopReason,
424) -> Vec<StreamDelta> {
425 let mut deltas = Vec::new();
426
427 for (idx, tool) in tool_calls {
429 deltas.push(StreamDelta::ToolUseStart {
430 id: tool.id.clone(),
431 name: tool.name.clone(),
432 block_index: *idx + 1,
433 });
434 deltas.push(StreamDelta::ToolInputDelta {
435 id: tool.id.clone(),
436 delta: tool.arguments.clone(),
437 block_index: *idx + 1,
438 });
439 }
440
441 if let Some(u) = usage {
443 deltas.push(StreamDelta::Usage(u));
444 }
445
446 deltas.push(StreamDelta::Done {
448 stop_reason: Some(stop_reason),
449 });
450
451 deltas
452}
453
/// Normalized event extracted from a single SSE `data:` payload.
enum SseProcessResult {
    // A piece of assistant text to stream through immediately.
    TextDelta(String),
    // A fragment of a tool call, keyed by the choice-level `index`.
    ToolCallUpdate {
        index: usize,
        id: Option<String>,
        name: Option<String>,
        arguments: Option<String>,
    },
    // Token counts reported by the API.
    Usage(Usage),
    // The chunk carried an explicit `finish_reason`.
    Done(StopReason),
    // The literal "[DONE]" terminator (no finish_reason attached).
    Sentinel,
}
472
473fn process_sse_data(data: &str) -> Vec<SseProcessResult> {
475 if data == "[DONE]" {
476 return vec![SseProcessResult::Sentinel];
477 }
478
479 let Ok(chunk) = serde_json::from_str::<SseChunk>(data) else {
480 return vec![];
481 };
482
483 let mut results = Vec::new();
484
485 if let Some(u) = chunk.usage {
487 results.push(SseProcessResult::Usage(Usage {
488 input_tokens: u.prompt_tokens,
489 output_tokens: u.completion_tokens,
490 }));
491 }
492
493 if let Some(choice) = chunk.choices.into_iter().next() {
495 if let Some(content) = choice.delta.content
497 && !content.is_empty()
498 {
499 results.push(SseProcessResult::TextDelta(content));
500 }
501
502 if let Some(tc_deltas) = choice.delta.tool_calls {
504 for tc in tc_deltas {
505 results.push(SseProcessResult::ToolCallUpdate {
506 index: tc.index,
507 id: tc.id,
508 name: tc.function.as_ref().and_then(|f| f.name.clone()),
509 arguments: tc.function.as_ref().and_then(|f| f.arguments.clone()),
510 });
511 }
512 }
513
514 if let Some(finish_reason) = choice.finish_reason {
516 let stop_reason = match finish_reason {
517 SseFinishReason::Stop => StopReason::EndTurn,
518 SseFinishReason::ToolCalls => StopReason::ToolUse,
519 SseFinishReason::Length => StopReason::MaxTokens,
520 SseFinishReason::ContentFilter => StopReason::StopSequence,
521 };
522 results.push(SseProcessResult::Done(stop_reason));
523 }
524 }
525
526 results
527}
528
/// Converts a provider-agnostic `ChatRequest` into the OpenAI message list.
///
/// A non-empty system prompt becomes a leading `system` message. Plain text
/// content maps 1:1. Block content is flattened: text blocks are joined with
/// newlines, tool-use blocks become `tool_calls` on an assistant message, and
/// tool-result blocks are emitted immediately as separate `tool` messages.
fn build_api_messages(request: &ChatRequest) -> Vec<ApiMessage> {
    let mut messages = Vec::new();

    if !request.system.is_empty() {
        messages.push(ApiMessage {
            role: ApiRole::System,
            content: Some(request.system.clone()),
            tool_calls: None,
            tool_call_id: None,
        });
    }

    for msg in &request.messages {
        match &msg.content {
            Content::Text(text) => {
                messages.push(ApiMessage {
                    role: match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    },
                    content: Some(text.clone()),
                    tool_calls: None,
                    tool_call_id: None,
                });
            }
            Content::Blocks(blocks) => {
                let mut text_parts = Vec::new();
                let mut tool_calls = Vec::new();

                for block in blocks {
                    match block {
                        ContentBlock::Text { text } => {
                            text_parts.push(text.clone());
                        }
                        ContentBlock::ToolUse {
                            id, name, input, ..
                        } => {
                            tool_calls.push(ApiToolCall {
                                id: id.clone(),
                                r#type: "function".to_owned(),
                                function: ApiFunctionCall {
                                    name: name.clone(),
                                    // Tool input is a JSON value on our side
                                    // but a string on the wire; fall back to
                                    // "{}" if serialization fails.
                                    arguments: serde_json::to_string(input)
                                        .unwrap_or_else(|_| "{}".to_owned()),
                                },
                            });
                        }
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Tool results are pushed right away — BEFORE the
                            // aggregated text/tool_calls message built below.
                            messages.push(ApiMessage {
                                role: ApiRole::Tool,
                                content: Some(content.clone()),
                                tool_calls: None,
                                tool_call_id: Some(tool_use_id.clone()),
                            });
                        }
                    }
                }

                if !text_parts.is_empty() || !tool_calls.is_empty() {
                    let role = match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    };

                    // User messages with tool_calls but no text are dropped:
                    // the API has no user-role tool_calls, so there is nothing
                    // valid to send in that case.
                    if role == ApiRole::Assistant || !text_parts.is_empty() {
                        messages.push(ApiMessage {
                            role,
                            content: if text_parts.is_empty() {
                                None
                            } else {
                                Some(text_parts.join("\n"))
                            },
                            tool_calls: if tool_calls.is_empty() {
                                None
                            } else {
                                Some(tool_calls)
                            },
                            tool_call_id: None,
                        });
                    }
                }
            }
        }
    }

    messages
}
626
627fn convert_tool(t: crate::llm::Tool) -> ApiTool {
628 ApiTool {
629 r#type: "function".to_owned(),
630 function: ApiFunction {
631 name: t.name,
632 description: t.description,
633 parameters: t.input_schema,
634 },
635 }
636}
637
638fn build_content_blocks(message: &ApiResponseMessage) -> Vec<ContentBlock> {
639 let mut blocks = Vec::new();
640
641 if let Some(content) = &message.content
643 && !content.is_empty()
644 {
645 blocks.push(ContentBlock::Text {
646 text: content.clone(),
647 });
648 }
649
650 if let Some(tool_calls) = &message.tool_calls {
652 for tc in tool_calls {
653 let input: serde_json::Value =
654 serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null);
655 blocks.push(ContentBlock::ToolUse {
656 id: tc.id.clone(),
657 name: tc.function.name.clone(),
658 input,
659 thought_signature: None,
660 });
661 }
662 }
663
664 blocks
665}
666
/// Request body for `POST /chat/completions` (non-streaming).
#[derive(Serialize)]
struct ApiChatRequest<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    // The newer parameter name is used here, not the legacy `max_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
}

/// Same shape as [`ApiChatRequest`] plus the `stream` flag.
///
/// NOTE(review): no `stream_options.include_usage` field is sent, so the API
/// presumably never emits the usage chunk that `SseChunk.usage` would parse —
/// confirm whether streaming token accounting is expected to work.
#[derive(Serialize)]
struct ApiChatRequestStreaming<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
    stream: bool,
}
691
/// One message in the request `messages` array. Optional fields are omitted
/// from the JSON entirely when `None`.
#[derive(Serialize)]
struct ApiMessage {
    role: ApiRole,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Present only on assistant messages that invoke tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ApiToolCall>>,
    // Present only on tool-role messages, linking back to the originating call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Wire-format role names ("system", "user", "assistant", "tool").
#[derive(Debug, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum ApiRole {
    System,
    User,
    Assistant,
    Tool,
}

/// A tool invocation attached to an assistant message.
#[derive(Serialize)]
struct ApiToolCall {
    id: String,
    // Always "function" in this provider.
    r#type: String,
    function: ApiFunctionCall,
}

/// Function name plus its arguments as a JSON-encoded string.
#[derive(Serialize)]
struct ApiFunctionCall {
    name: String,
    arguments: String,
}

/// A tool definition offered to the model.
#[derive(Serialize)]
struct ApiTool {
    // Always "function" in this provider.
    r#type: String,
    function: ApiFunction,
}

/// Function-tool declaration: name, description, and JSON-schema parameters.
#[derive(Serialize)]
struct ApiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
737
/// Top-level body of a non-streaming `/chat/completions` response.
/// Only the fields this provider consumes are modeled.
#[derive(Deserialize)]
struct ApiChatResponse {
    id: String,
    choices: Vec<ApiChoice>,
    model: String,
    usage: ApiUsage,
}

/// One completion choice; only the first is used by `chat`.
#[derive(Deserialize)]
struct ApiChoice {
    message: ApiResponseMessage,
    finish_reason: Option<ApiFinishReason>,
}

/// Assistant message in a response: text and/or tool calls, either optional.
#[derive(Deserialize)]
struct ApiResponseMessage {
    content: Option<String>,
    tool_calls: Option<Vec<ApiResponseToolCall>>,
}

/// A completed tool invocation returned by the model.
#[derive(Deserialize)]
struct ApiResponseToolCall {
    id: String,
    function: ApiResponseFunctionCall,
}

/// Function name plus arguments as a JSON-encoded string.
#[derive(Deserialize)]
struct ApiResponseFunctionCall {
    name: String,
    arguments: String,
}

/// Why generation stopped, as reported by the non-streaming API.
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum ApiFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

/// Token counts for the whole request/response pair.
#[derive(Deserialize)]
struct ApiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

/// Accumulates streamed tool-call fragments: `id` and `name` arrive once,
/// `arguments` builds up from string deltas until the stream ends.
struct ToolCallAccumulator {
    id: String,
    name: String,
    arguments: String,
}
799
/// One parsed SSE `data:` chunk from the streaming endpoint.
#[derive(Deserialize)]
struct SseChunk {
    choices: Vec<SseChoice>,
    // Only present on the final usage chunk, if the API sends one at all.
    #[serde(default)]
    usage: Option<SseUsage>,
}

/// One streamed choice; carries an incremental delta plus an optional
/// terminal finish reason.
#[derive(Deserialize)]
struct SseChoice {
    delta: SseDelta,
    finish_reason: Option<SseFinishReason>,
}

/// Incremental payload: a text fragment and/or tool-call fragments.
#[derive(Deserialize)]
struct SseDelta {
    content: Option<String>,
    tool_calls: Option<Vec<SseToolCallDelta>>,
}

/// Fragment of a tool call; `index` identifies which call it belongs to.
#[derive(Deserialize)]
struct SseToolCallDelta {
    index: usize,
    // Sent only on the first fragment of a call.
    id: Option<String>,
    function: Option<SseFunctionDelta>,
}

/// Function-call fragment: name (first fragment) or an arguments substring.
#[derive(Deserialize)]
struct SseFunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}

/// Why generation stopped, as reported on the final streamed choice.
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum SseFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

/// Token counts reported in the streaming usage chunk.
#[derive(Deserialize)]
struct SseUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
847
#[cfg(test)]
mod tests {
    use super::*;

    // --- Constructors and factory methods ---

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "custom-model".to_string());

        assert_eq!(provider.model(), "custom-model");
        assert_eq!(provider.provider(), "openai");
        assert_eq!(provider.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_with_base_url_creates_provider_with_custom_url() {
        let provider = OpenAIProvider::with_base_url(
            "test-api-key".to_string(),
            "llama3".to_string(),
            "http://localhost:11434/v1".to_string(),
        );

        assert_eq!(provider.model(), "llama3");
        assert_eq!(provider.base_url, "http://localhost:11434/v1");
    }

    #[test]
    fn test_gpt4o_factory_creates_gpt4o_provider() {
        let provider = OpenAIProvider::gpt4o("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt4o_mini_factory_creates_gpt4o_mini_provider() {
        let provider = OpenAIProvider::gpt4o_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt52_thinking_factory_creates_provider() {
        let provider = OpenAIProvider::gpt52_thinking("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT52_THINKING);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_factory_creates_gpt5_provider() {
        let provider = OpenAIProvider::gpt5("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_mini_factory_creates_provider() {
        let provider = OpenAIProvider::gpt5_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o3_factory_creates_o3_provider() {
        let provider = OpenAIProvider::o3("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O3);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o4_mini_factory_creates_o4_mini_provider() {
        let provider = OpenAIProvider::o4_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O4_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o1_factory_creates_o1_provider() {
        let provider = OpenAIProvider::o1("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O1);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt41_factory_creates_gpt41_provider() {
        let provider = OpenAIProvider::gpt41("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT41);
        assert_eq!(provider.provider(), "openai");
    }

    // --- Model constant values ---

    #[test]
    fn test_model_constants_have_expected_values() {
        assert_eq!(MODEL_GPT52_INSTANT, "gpt-5.2-instant");
        assert_eq!(MODEL_GPT52_THINKING, "gpt-5.2-thinking");
        assert_eq!(MODEL_GPT52_PRO, "gpt-5.2-pro");
        assert_eq!(MODEL_GPT5, "gpt-5");
        assert_eq!(MODEL_GPT5_MINI, "gpt-5-mini");
        assert_eq!(MODEL_GPT5_NANO, "gpt-5-nano");
        assert_eq!(MODEL_O3, "o3");
        assert_eq!(MODEL_O3_MINI, "o3-mini");
        assert_eq!(MODEL_O4_MINI, "o4-mini");
        assert_eq!(MODEL_O1, "o1");
        assert_eq!(MODEL_O1_MINI, "o1-mini");
        assert_eq!(MODEL_GPT41, "gpt-4.1");
        assert_eq!(MODEL_GPT41_MINI, "gpt-4.1-mini");
        assert_eq!(MODEL_GPT41_NANO, "gpt-4.1-nano");
        assert_eq!(MODEL_GPT4O, "gpt-4o");
        assert_eq!(MODEL_GPT4O_MINI, "gpt-4o-mini");
    }

    #[test]
    fn test_provider_is_cloneable() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "test-model".to_string());
        let cloned = provider.clone();

        assert_eq!(provider.model(), cloned.model());
        assert_eq!(provider.provider(), cloned.provider());
        assert_eq!(provider.base_url, cloned.base_url);
    }

    // --- Request serialization (wire format) ---

    #[test]
    fn test_api_role_serialization() {
        let system_role = ApiRole::System;
        let user_role = ApiRole::User;
        let assistant_role = ApiRole::Assistant;
        let tool_role = ApiRole::Tool;

        assert_eq!(serde_json::to_string(&system_role).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&user_role).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&assistant_role).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&tool_role).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_api_message_serialization_simple() {
        let message = ApiMessage {
            role: ApiRole::User,
            content: Some("Hello, world!".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello, world!\""));
        // `None` fields must be omitted entirely, not serialized as null.
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
    }

    #[test]
    fn test_api_message_serialization_with_tool_calls() {
        let message = ApiMessage {
            role: ApiRole::Assistant,
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiToolCall {
                id: "call_123".to_string(),
                r#type: "function".to_string(),
                function: ApiFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"/test.txt\"}".to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"assistant\""));
        assert!(json.contains("\"tool_calls\""));
        assert!(json.contains("\"id\":\"call_123\""));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"read_file\""));
    }

    #[test]
    fn test_api_tool_message_serialization() {
        let message = ApiMessage {
            role: ApiRole::Tool,
            content: Some("File contents here".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"tool\""));
        assert!(json.contains("\"tool_call_id\":\"call_123\""));
        assert!(json.contains("\"content\":\"File contents here\""));
    }

    #[test]
    fn test_api_tool_serialization() {
        let tool = ApiTool {
            r#type: "function".to_string(),
            function: ApiFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "arg": {"type": "string"}
                    }
                }),
            },
        };

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_tool\""));
        assert!(json.contains("\"description\":\"A test tool\""));
        assert!(json.contains("\"parameters\""));
    }

    // --- Response deserialization ---

    #[test]
    fn test_api_response_deserialization() {
        let json = r#"{
            "id": "chatcmpl-123",
            "choices": [
                {
                    "message": {
                        "content": "Hello!"
                    },
                    "finish_reason": "stop"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.usage.prompt_tokens, 100);
        assert_eq!(response.usage.completion_tokens, 50);
        assert_eq!(response.choices.len(), 1);
        assert_eq!(
            response.choices[0].message.content,
            Some("Hello!".to_string())
        );
    }

    #[test]
    fn test_api_response_with_tool_calls_deserialization() {
        let json = r#"{
            "id": "chatcmpl-456",
            "choices": [
                {
                    "message": {
                        "content": null,
                        "tool_calls": [
                            {
                                "id": "call_abc",
                                "type": "function",
                                "function": {
                                    "name": "read_file",
                                    "arguments": "{\"path\": \"test.txt\"}"
                                }
                            }
                        ]
                    },
                    "finish_reason": "tool_calls"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 150,
                "completion_tokens": 30
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc");
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_api_finish_reason_deserialization() {
        let stop: ApiFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: ApiFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: ApiFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: ApiFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, ApiFinishReason::Stop));
        assert!(matches!(tool_calls, ApiFinishReason::ToolCalls));
        assert!(matches!(length, ApiFinishReason::Length));
        assert!(matches!(content_filter, ApiFinishReason::ContentFilter));
    }

    // --- Message building and conversion helpers ---

    #[test]
    fn test_build_api_messages_with_system() {
        let request = ChatRequest {
            system: "You are helpful.".to_string(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 2);
        assert_eq!(api_messages[0].role, ApiRole::System);
        assert_eq!(
            api_messages[0].content,
            Some("You are helpful.".to_string())
        );
        assert_eq!(api_messages[1].role, ApiRole::User);
        assert_eq!(api_messages[1].content, Some("Hello".to_string()));
    }

    #[test]
    fn test_build_api_messages_empty_system() {
        let request = ChatRequest {
            system: String::new(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 1);
        assert_eq!(api_messages[0].role, ApiRole::User);
    }

    #[test]
    fn test_convert_tool() {
        let tool = crate::llm::Tool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({"type": "object"}),
        };

        let api_tool = convert_tool(tool);
        assert_eq!(api_tool.r#type, "function");
        assert_eq!(api_tool.function.name, "test_tool");
        assert_eq!(api_tool.function.description, "A test tool");
    }

    #[test]
    fn test_build_content_blocks_text_only() {
        let message = ApiResponseMessage {
            content: Some("Hello!".to_string()),
            tool_calls: None,
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello!"));
    }

    #[test]
    fn test_build_content_blocks_with_tool_calls() {
        let message = ApiResponseMessage {
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiResponseToolCall {
                id: "call_123".to_string(),
                function: ApiResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"test.txt\"}".to_string(),
                },
            }]),
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Let me help."));
        assert!(
            matches!(&blocks[1], ContentBlock::ToolUse { id, name, .. } if id == "call_123" && name == "read_file")
        );
    }

    // --- SSE (streaming) chunk deserialization ---

    #[test]
    fn test_sse_chunk_text_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello"
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.choices[0].finish_reason.is_none());
    }

    #[test]
    fn test_sse_chunk_tool_call_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_abc",
                        "function": {
                            "name": "read_file",
                            "arguments": ""
                        }
                    }]
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].index, 0);
        assert_eq!(tool_calls[0].id, Some("call_abc".to_string()));
        assert_eq!(
            tool_calls[0].function.as_ref().unwrap().name,
            Some("read_file".to_string())
        );
    }

    #[test]
    fn test_sse_chunk_tool_call_arguments_delta_deserialization() {
        // Continuation fragments omit `id`/`name` and carry only arguments.
        let json = r#"{
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"path\":"
                        }
                    }]
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls[0].id, None);
        assert_eq!(
            tool_calls[0].function.as_ref().unwrap().arguments,
            Some("{\"path\":".to_string())
        );
    }

    #[test]
    fn test_sse_chunk_with_finish_reason_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {},
                "finish_reason": "stop"
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        assert!(matches!(
            chunk.choices[0].finish_reason,
            Some(SseFinishReason::Stop)
        ));
    }

    #[test]
    fn test_sse_chunk_with_usage_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
    }

    #[test]
    fn test_sse_finish_reason_deserialization() {
        let stop: SseFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: SseFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: SseFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: SseFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, SseFinishReason::Stop));
        assert!(matches!(tool_calls, SseFinishReason::ToolCalls));
        assert!(matches!(length, SseFinishReason::Length));
        assert!(matches!(content_filter, SseFinishReason::ContentFilter));
    }

    #[test]
    fn test_streaming_request_serialization() {
        let messages = vec![ApiMessage {
            role: ApiRole::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let request = ApiChatRequestStreaming {
            model: "gpt-4o",
            messages: &messages,
            max_completion_tokens: Some(1024),
            tools: None,
            stream: true,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"stream\":true"));
        assert!(json.contains("\"model\":\"gpt-4o\""));
    }
}