1use crate::llm::{
11 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, StopReason,
12 StreamBox, StreamDelta, Usage,
13};
14use anyhow::Result;
15use async_trait::async_trait;
16use futures::StreamExt;
17use reqwest::StatusCode;
18use serde::{Deserialize, Serialize};
19
20use super::openai_responses::OpenAIResponsesProvider;
21
/// Default OpenAI API root; override via [`OpenAIProvider::with_base_url`]
/// for OpenAI-compatible servers (Ollama, vLLM, proxies, …).
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
23
/// Returns `true` when `model` must be served through the Responses API
/// instead of the Chat Completions API (currently: any codex-family model).
fn requires_responses_api(model: &str) -> bool {
    const RESPONSES_ONLY_MARKER: &str = "codex";
    model.contains(RESPONSES_ONLY_MARKER)
}
28
// GPT-5.2 family.
pub const MODEL_GPT52_INSTANT: &str = "gpt-5.2-instant";
pub const MODEL_GPT52_THINKING: &str = "gpt-5.2-thinking";
pub const MODEL_GPT52_PRO: &str = "gpt-5.2-pro";
// Codex variant — routed through the Responses API (see `requires_responses_api`).
pub const MODEL_GPT52_CODEX: &str = "gpt-5.2-codex";

// GPT-5 family.
pub const MODEL_GPT5: &str = "gpt-5";
pub const MODEL_GPT5_MINI: &str = "gpt-5-mini";
pub const MODEL_GPT5_NANO: &str = "gpt-5-nano";

// o-series reasoning models.
pub const MODEL_O3: &str = "o3";
pub const MODEL_O3_MINI: &str = "o3-mini";
pub const MODEL_O4_MINI: &str = "o4-mini";
pub const MODEL_O1: &str = "o1";
pub const MODEL_O1_MINI: &str = "o1-mini";

// GPT-4.1 family.
pub const MODEL_GPT41: &str = "gpt-4.1";
pub const MODEL_GPT41_MINI: &str = "gpt-4.1-mini";
pub const MODEL_GPT41_NANO: &str = "gpt-4.1-nano";

// GPT-4o family.
pub const MODEL_GPT4O: &str = "gpt-4o";
pub const MODEL_GPT4O_MINI: &str = "gpt-4o-mini";
/// LLM provider backed by the OpenAI Chat Completions API.
///
/// Codex-family models are transparently routed to the Responses API
/// (see `requires_responses_api` and `OpenAIResponsesProvider`).
#[derive(Clone)]
pub struct OpenAIProvider {
    /// Shared HTTP client; reused across requests for connection pooling.
    client: reqwest::Client,
    /// API key, sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// Model identifier, e.g. `"gpt-4o"`.
    model: String,
    /// API root; defaults to [`DEFAULT_BASE_URL`].
    base_url: String,
}
67
68impl OpenAIProvider {
69 #[must_use]
71 pub fn new(api_key: String, model: String) -> Self {
72 Self {
73 client: reqwest::Client::new(),
74 api_key,
75 model,
76 base_url: DEFAULT_BASE_URL.to_owned(),
77 }
78 }
79
80 #[must_use]
82 pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
83 Self {
84 client: reqwest::Client::new(),
85 api_key,
86 model,
87 base_url,
88 }
89 }
90
91 #[must_use]
93 pub fn gpt52_instant(api_key: String) -> Self {
94 Self::new(api_key, MODEL_GPT52_INSTANT.to_owned())
95 }
96
97 #[must_use]
99 pub fn gpt52_thinking(api_key: String) -> Self {
100 Self::new(api_key, MODEL_GPT52_THINKING.to_owned())
101 }
102
103 #[must_use]
105 pub fn gpt52_pro(api_key: String) -> Self {
106 Self::new(api_key, MODEL_GPT52_PRO.to_owned())
107 }
108
109 #[must_use]
113 pub fn codex(api_key: String) -> Self {
114 Self::new(api_key, MODEL_GPT52_CODEX.to_owned())
115 }
116
117 #[must_use]
119 pub fn gpt5(api_key: String) -> Self {
120 Self::new(api_key, MODEL_GPT5.to_owned())
121 }
122
123 #[must_use]
125 pub fn gpt5_mini(api_key: String) -> Self {
126 Self::new(api_key, MODEL_GPT5_MINI.to_owned())
127 }
128
129 #[must_use]
131 pub fn gpt5_nano(api_key: String) -> Self {
132 Self::new(api_key, MODEL_GPT5_NANO.to_owned())
133 }
134
135 #[must_use]
137 pub fn o3(api_key: String) -> Self {
138 Self::new(api_key, MODEL_O3.to_owned())
139 }
140
141 #[must_use]
143 pub fn o3_mini(api_key: String) -> Self {
144 Self::new(api_key, MODEL_O3_MINI.to_owned())
145 }
146
147 #[must_use]
149 pub fn o4_mini(api_key: String) -> Self {
150 Self::new(api_key, MODEL_O4_MINI.to_owned())
151 }
152
153 #[must_use]
155 pub fn o1(api_key: String) -> Self {
156 Self::new(api_key, MODEL_O1.to_owned())
157 }
158
159 #[must_use]
161 pub fn o1_mini(api_key: String) -> Self {
162 Self::new(api_key, MODEL_O1_MINI.to_owned())
163 }
164
165 #[must_use]
167 pub fn gpt41(api_key: String) -> Self {
168 Self::new(api_key, MODEL_GPT41.to_owned())
169 }
170
171 #[must_use]
173 pub fn gpt41_mini(api_key: String) -> Self {
174 Self::new(api_key, MODEL_GPT41_MINI.to_owned())
175 }
176
177 #[must_use]
179 pub fn gpt4o(api_key: String) -> Self {
180 Self::new(api_key, MODEL_GPT4O.to_owned())
181 }
182
183 #[must_use]
185 pub fn gpt4o_mini(api_key: String) -> Self {
186 Self::new(api_key, MODEL_GPT4O_MINI.to_owned())
187 }
188}
189
190#[async_trait]
191impl LlmProvider for OpenAIProvider {
192 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
193 if requires_responses_api(&self.model) {
195 let responses_provider =
196 OpenAIResponsesProvider::new(self.api_key.clone(), self.model.clone());
197 return responses_provider.chat(request).await;
198 }
199
200 let messages = build_api_messages(&request);
201 let tools: Option<Vec<ApiTool>> = request
202 .tools
203 .map(|ts| ts.into_iter().map(convert_tool).collect());
204
205 let api_request = ApiChatRequest {
206 model: &self.model,
207 messages: &messages,
208 max_completion_tokens: Some(request.max_tokens),
209 tools: tools.as_deref(),
210 };
211
212 tracing::debug!(
213 model = %self.model,
214 max_tokens = request.max_tokens,
215 "OpenAI LLM request"
216 );
217
218 let response = self
219 .client
220 .post(format!("{}/chat/completions", self.base_url))
221 .header("Content-Type", "application/json")
222 .header("Authorization", format!("Bearer {}", self.api_key))
223 .json(&api_request)
224 .send()
225 .await
226 .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;
227
228 let status = response.status();
229 let bytes = response
230 .bytes()
231 .await
232 .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;
233
234 tracing::debug!(
235 status = %status,
236 body_len = bytes.len(),
237 "OpenAI LLM response"
238 );
239
240 if status == StatusCode::TOO_MANY_REQUESTS {
241 return Ok(ChatOutcome::RateLimited);
242 }
243
244 if status.is_server_error() {
245 let body = String::from_utf8_lossy(&bytes);
246 tracing::error!(status = %status, body = %body, "OpenAI server error");
247 return Ok(ChatOutcome::ServerError(body.into_owned()));
248 }
249
250 if status.is_client_error() {
251 let body = String::from_utf8_lossy(&bytes);
252 tracing::warn!(status = %status, body = %body, "OpenAI client error");
253 return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
254 }
255
256 let api_response: ApiChatResponse = serde_json::from_slice(&bytes)
257 .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;
258
259 let choice = api_response
260 .choices
261 .into_iter()
262 .next()
263 .ok_or_else(|| anyhow::anyhow!("no choices in response"))?;
264
265 let content = build_content_blocks(&choice.message);
266
267 let stop_reason = choice.finish_reason.map(|r| match r {
268 ApiFinishReason::Stop => StopReason::EndTurn,
269 ApiFinishReason::ToolCalls => StopReason::ToolUse,
270 ApiFinishReason::Length => StopReason::MaxTokens,
271 ApiFinishReason::ContentFilter => StopReason::StopSequence,
272 });
273
274 Ok(ChatOutcome::Success(ChatResponse {
275 id: api_response.id,
276 content,
277 model: api_response.model,
278 stop_reason,
279 usage: Usage {
280 input_tokens: api_response.usage.prompt_tokens,
281 output_tokens: api_response.usage.completion_tokens,
282 },
283 }))
284 }
285
286 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
287 if requires_responses_api(&self.model) {
289 let api_key = self.api_key.clone();
290 let model = self.model.clone();
291 return Box::pin(async_stream::stream! {
292 let responses_provider = OpenAIResponsesProvider::new(api_key, model);
293 let mut stream = std::pin::pin!(responses_provider.chat_stream(request));
294 while let Some(item) = futures::StreamExt::next(&mut stream).await {
295 yield item;
296 }
297 });
298 }
299
300 Box::pin(async_stream::stream! {
301 let messages = build_api_messages(&request);
302 let tools: Option<Vec<ApiTool>> = request
303 .tools
304 .map(|ts| ts.into_iter().map(convert_tool).collect());
305
306 let api_request = ApiChatRequestStreaming { model: &self.model, messages: &messages, max_completion_tokens: Some(request.max_tokens), tools: tools.as_deref(), stream: true };
307
308 tracing::debug!(model = %self.model, max_tokens = request.max_tokens, "OpenAI streaming LLM request");
309
310 let Ok(response) = self.client
311 .post(format!("{}/chat/completions", self.base_url))
312 .header("Content-Type", "application/json")
313 .header("Authorization", format!("Bearer {}", self.api_key))
314 .json(&api_request)
315 .send()
316 .await
317 else {
318 yield Err(anyhow::anyhow!("request failed"));
319 return;
320 };
321
322 let status = response.status();
323
324 if !status.is_success() {
325 let body = response.text().await.unwrap_or_default();
326 let (recoverable, level) = if status == StatusCode::TOO_MANY_REQUESTS {
327 (true, "rate_limit")
328 } else if status.is_server_error() {
329 (true, "server_error")
330 } else {
331 (false, "client_error")
332 };
333 tracing::warn!(status = %status, body = %body, kind = level, "OpenAI error");
334 yield Ok(StreamDelta::Error { message: body, recoverable });
335 return;
336 }
337
338 let mut tool_calls: std::collections::HashMap<usize, ToolCallAccumulator> =
340 std::collections::HashMap::new();
341 let mut usage: Option<Usage> = None;
342 let mut buffer = String::new();
343 let mut stream = response.bytes_stream();
344
345 while let Some(chunk_result) = stream.next().await {
346 let Ok(chunk) = chunk_result else {
347 yield Err(anyhow::anyhow!("stream error: {}", chunk_result.unwrap_err()));
348 return;
349 };
350 buffer.push_str(&String::from_utf8_lossy(&chunk));
351
352 while let Some(pos) = buffer.find('\n') {
353 let line = buffer[..pos].trim().to_string();
354 buffer = buffer[pos + 1..].to_string();
355 if line.is_empty() { continue; }
356 let Some(data) = line.strip_prefix("data: ") else { continue; };
357
358 for result in process_sse_data(data) {
359 match result {
360 SseProcessResult::TextDelta(c) => yield Ok(StreamDelta::TextDelta { delta: c, block_index: 0 }),
361 SseProcessResult::ToolCallUpdate { index, id, name, arguments } => apply_tool_call_update(&mut tool_calls, index, id, name, arguments),
362 SseProcessResult::Usage(u) => usage = Some(u),
363 SseProcessResult::Done(sr) => {
364 for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) { yield Ok(d); }
365 return;
366 }
367 SseProcessResult::Sentinel => {
368 let sr = if tool_calls.is_empty() { StopReason::EndTurn } else { StopReason::ToolUse };
369 for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) { yield Ok(d); }
370 return;
371 }
372 }
373 }
374 }
375 }
376
377 for delta in build_stream_end_deltas(&tool_calls, usage, StopReason::EndTurn) {
379 yield Ok(delta);
380 }
381 })
382 }
383
384 fn model(&self) -> &str {
385 &self.model
386 }
387
388 fn provider(&self) -> &'static str {
389 "openai"
390 }
391}
392
393fn apply_tool_call_update(
395 tool_calls: &mut std::collections::HashMap<usize, ToolCallAccumulator>,
396 index: usize,
397 id: Option<String>,
398 name: Option<String>,
399 arguments: Option<String>,
400) {
401 let entry = tool_calls
402 .entry(index)
403 .or_insert_with(|| ToolCallAccumulator {
404 id: String::new(),
405 name: String::new(),
406 arguments: String::new(),
407 });
408 if let Some(id) = id {
409 entry.id = id;
410 }
411 if let Some(name) = name {
412 entry.name = name;
413 }
414 if let Some(args) = arguments {
415 entry.arguments.push_str(&args);
416 }
417}
418
419fn build_stream_end_deltas(
421 tool_calls: &std::collections::HashMap<usize, ToolCallAccumulator>,
422 usage: Option<Usage>,
423 stop_reason: StopReason,
424) -> Vec<StreamDelta> {
425 let mut deltas = Vec::new();
426
427 for (idx, tool) in tool_calls {
429 deltas.push(StreamDelta::ToolUseStart {
430 id: tool.id.clone(),
431 name: tool.name.clone(),
432 block_index: *idx + 1,
433 });
434 deltas.push(StreamDelta::ToolInputDelta {
435 id: tool.id.clone(),
436 delta: tool.arguments.clone(),
437 block_index: *idx + 1,
438 });
439 }
440
441 if let Some(u) = usage {
443 deltas.push(StreamDelta::Usage(u));
444 }
445
446 deltas.push(StreamDelta::Done {
448 stop_reason: Some(stop_reason),
449 });
450
451 deltas
452}
453
/// One semantic event extracted from a single SSE `data:` payload.
enum SseProcessResult {
    /// A fragment of assistant text.
    TextDelta(String),
    /// A fragment of a streamed tool call, keyed by `index`; `id`/`name`
    /// arrive once, `arguments` accumulates across fragments.
    ToolCallUpdate {
        index: usize,
        id: Option<String>,
        name: Option<String>,
        arguments: Option<String>,
    },
    /// Token usage totals reported by the API.
    Usage(Usage),
    /// The API reported a finish reason for the choice.
    Done(StopReason),
    /// The literal `[DONE]` sentinel terminating the stream.
    Sentinel,
}
472
473fn process_sse_data(data: &str) -> Vec<SseProcessResult> {
475 if data == "[DONE]" {
476 return vec![SseProcessResult::Sentinel];
477 }
478
479 let Ok(chunk) = serde_json::from_str::<SseChunk>(data) else {
480 return vec![];
481 };
482
483 let mut results = Vec::new();
484
485 if let Some(u) = chunk.usage {
487 results.push(SseProcessResult::Usage(Usage {
488 input_tokens: u.prompt_tokens,
489 output_tokens: u.completion_tokens,
490 }));
491 }
492
493 if let Some(choice) = chunk.choices.into_iter().next() {
495 if let Some(content) = choice.delta.content
497 && !content.is_empty()
498 {
499 results.push(SseProcessResult::TextDelta(content));
500 }
501
502 if let Some(tc_deltas) = choice.delta.tool_calls {
504 for tc in tc_deltas {
505 results.push(SseProcessResult::ToolCallUpdate {
506 index: tc.index,
507 id: tc.id,
508 name: tc.function.as_ref().and_then(|f| f.name.clone()),
509 arguments: tc.function.as_ref().and_then(|f| f.arguments.clone()),
510 });
511 }
512 }
513
514 if let Some(finish_reason) = choice.finish_reason {
516 let stop_reason = match finish_reason {
517 SseFinishReason::Stop => StopReason::EndTurn,
518 SseFinishReason::ToolCalls => StopReason::ToolUse,
519 SseFinishReason::Length => StopReason::MaxTokens,
520 SseFinishReason::ContentFilter => StopReason::StopSequence,
521 };
522 results.push(SseProcessResult::Done(stop_reason));
523 }
524 }
525
526 results
527}
528
/// Flattens an internal [`ChatRequest`] into OpenAI chat-completions messages.
///
/// - A non-empty `system` prompt becomes a leading `system` message.
/// - `Content::Text` messages map one-to-one.
/// - `Content::Blocks` messages are split: each `ToolResult` block becomes
///   its own `tool` message (pushed in block order, *before* the merged
///   message below), while `Text` and `ToolUse` blocks are merged into a
///   single user/assistant message. `Thinking` blocks are dropped — the
///   chat-completions API has no equivalent.
fn build_api_messages(request: &ChatRequest) -> Vec<ApiMessage> {
    let mut messages = Vec::new();

    if !request.system.is_empty() {
        messages.push(ApiMessage {
            role: ApiRole::System,
            content: Some(request.system.clone()),
            tool_calls: None,
            tool_call_id: None,
        });
    }

    for msg in &request.messages {
        match &msg.content {
            Content::Text(text) => {
                messages.push(ApiMessage {
                    role: match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    },
                    content: Some(text.clone()),
                    tool_calls: None,
                    tool_call_id: None,
                });
            }
            Content::Blocks(blocks) => {
                // Collect text and tool calls separately; tool results are
                // emitted immediately as standalone `tool` messages.
                let mut text_parts = Vec::new();
                let mut tool_calls = Vec::new();

                for block in blocks {
                    match block {
                        ContentBlock::Text { text } => {
                            text_parts.push(text.clone());
                        }
                        ContentBlock::Thinking { .. } => {
                            // Intentionally dropped: no chat-completions
                            // representation for thinking content.
                        }
                        ContentBlock::ToolUse {
                            id, name, input, ..
                        } => {
                            tool_calls.push(ApiToolCall {
                                id: id.clone(),
                                r#type: "function".to_owned(),
                                function: ApiFunctionCall {
                                    name: name.clone(),
                                    // Fall back to "{}" if the input is not
                                    // serializable rather than failing.
                                    arguments: serde_json::to_string(input)
                                        .unwrap_or_else(|_| "{}".to_owned()),
                                },
                            });
                        }
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            messages.push(ApiMessage {
                                role: ApiRole::Tool,
                                content: Some(content.clone()),
                                tool_calls: None,
                                tool_call_id: Some(tool_use_id.clone()),
                            });
                        }
                    }
                }

                if !text_parts.is_empty() || !tool_calls.is_empty() {
                    let role = match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    };

                    // A user message must carry text; only assistant messages
                    // may consist solely of tool calls.
                    if role == ApiRole::Assistant || !text_parts.is_empty() {
                        messages.push(ApiMessage {
                            role,
                            content: if text_parts.is_empty() {
                                None
                            } else {
                                Some(text_parts.join("\n"))
                            },
                            tool_calls: if tool_calls.is_empty() {
                                None
                            } else {
                                Some(tool_calls)
                            },
                            tool_call_id: None,
                        });
                    }
                }
            }
        }
    }

    messages
}
629
630fn convert_tool(t: crate::llm::Tool) -> ApiTool {
631 ApiTool {
632 r#type: "function".to_owned(),
633 function: ApiFunction {
634 name: t.name,
635 description: t.description,
636 parameters: t.input_schema,
637 },
638 }
639}
640
641fn build_content_blocks(message: &ApiResponseMessage) -> Vec<ContentBlock> {
642 let mut blocks = Vec::new();
643
644 if let Some(content) = &message.content
646 && !content.is_empty()
647 {
648 blocks.push(ContentBlock::Text {
649 text: content.clone(),
650 });
651 }
652
653 if let Some(tool_calls) = &message.tool_calls {
655 for tc in tool_calls {
656 let input: serde_json::Value =
657 serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null);
658 blocks.push(ContentBlock::ToolUse {
659 id: tc.id.clone(),
660 name: tc.function.name.clone(),
661 input,
662 thought_signature: None,
663 });
664 }
665 }
666
667 blocks
668}
669
/// Request payload for `POST /chat/completions` (non-streaming).
#[derive(Serialize)]
struct ApiChatRequest<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
}

/// Same payload as [`ApiChatRequest`] with the added `stream` flag.
#[derive(Serialize)]
struct ApiChatRequestStreaming<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
    // Always `true`; instructs the API to respond with SSE chunks.
    stream: bool,
}
694
/// One message in the chat-completions `messages` array.
#[derive(Serialize)]
struct ApiMessage {
    role: ApiRole,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Present only on assistant messages that invoke tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ApiToolCall>>,
    /// Present only on `tool` role messages, linking back to the call id.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Chat-completions message roles; serialized lowercase on the wire.
#[derive(Debug, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum ApiRole {
    System,
    User,
    Assistant,
    Tool,
}
714
/// A tool invocation attached to an assistant message.
#[derive(Serialize)]
struct ApiToolCall {
    id: String,
    /// Always `"function"` — the only tool type this provider emits.
    r#type: String,
    function: ApiFunctionCall,
}

/// The function name and its JSON-encoded arguments string.
#[derive(Serialize)]
struct ApiFunctionCall {
    name: String,
    arguments: String,
}

/// A tool definition offered to the model.
#[derive(Serialize)]
struct ApiTool {
    /// Always `"function"`.
    r#type: String,
    function: ApiFunction,
}

/// Function schema: name, description and JSON-schema parameters.
#[derive(Serialize)]
struct ApiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
740
/// Top-level non-streaming chat-completions response body.
#[derive(Deserialize)]
struct ApiChatResponse {
    id: String,
    choices: Vec<ApiChoice>,
    model: String,
    usage: ApiUsage,
}

/// One completion choice; only the first is consumed.
#[derive(Deserialize)]
struct ApiChoice {
    message: ApiResponseMessage,
    finish_reason: Option<ApiFinishReason>,
}

/// The assistant message of a choice: optional text and/or tool calls.
#[derive(Deserialize)]
struct ApiResponseMessage {
    content: Option<String>,
    tool_calls: Option<Vec<ApiResponseToolCall>>,
}

/// A tool call in a non-streaming response.
#[derive(Deserialize)]
struct ApiResponseToolCall {
    id: String,
    function: ApiResponseFunctionCall,
}

/// Function name plus its complete JSON-encoded arguments string.
#[derive(Deserialize)]
struct ApiResponseFunctionCall {
    name: String,
    arguments: String,
}

/// Wire values of `finish_reason` (snake_case on the wire).
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum ApiFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

/// Token accounting reported by the API.
#[derive(Deserialize)]
struct ApiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
791
/// Accumulates streamed tool-call fragments until the stream completes.
struct ToolCallAccumulator {
    /// Call id; overwritten when its fragment arrives.
    id: String,
    /// Function name; overwritten when its fragment arrives.
    name: String,
    /// JSON arguments string, appended fragment by fragment.
    arguments: String,
}
802
/// One decoded SSE `data:` chunk of a streaming response.
#[derive(Deserialize)]
struct SseChunk {
    choices: Vec<SseChoice>,
    /// Usage appears only on the final chunk(s); default to `None` otherwise.
    #[serde(default)]
    usage: Option<SseUsage>,
}

/// A streamed choice: incremental delta plus an optional finish reason.
#[derive(Deserialize)]
struct SseChoice {
    delta: SseDelta,
    finish_reason: Option<SseFinishReason>,
}

/// Incremental message content carried by a chunk.
#[derive(Deserialize)]
struct SseDelta {
    content: Option<String>,
    tool_calls: Option<Vec<SseToolCallDelta>>,
}

/// A fragment of a streamed tool call; `index` correlates fragments.
#[derive(Deserialize)]
struct SseToolCallDelta {
    index: usize,
    id: Option<String>,
    function: Option<SseFunctionDelta>,
}

/// Partial function data: `name` arrives once, `arguments` in pieces.
#[derive(Deserialize)]
struct SseFunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}

/// Wire values of the streaming `finish_reason` (snake_case on the wire).
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum SseFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

/// Token accounting on the final streaming chunk.
#[derive(Deserialize)]
struct SseUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
850
//  Unit tests: constructors/factories, wire-format (de)serialization,
//  message building, and SSE chunk parsing. No network access — HTTP paths
//  are exercised only through their pure helper functions.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Constructors and factories -------------------------------------

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "custom-model".to_string());

        assert_eq!(provider.model(), "custom-model");
        assert_eq!(provider.provider(), "openai");
        assert_eq!(provider.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_with_base_url_creates_provider_with_custom_url() {
        let provider = OpenAIProvider::with_base_url(
            "test-api-key".to_string(),
            "llama3".to_string(),
            "http://localhost:11434/v1".to_string(),
        );

        assert_eq!(provider.model(), "llama3");
        assert_eq!(provider.base_url, "http://localhost:11434/v1");
    }

    #[test]
    fn test_gpt4o_factory_creates_gpt4o_provider() {
        let provider = OpenAIProvider::gpt4o("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt4o_mini_factory_creates_gpt4o_mini_provider() {
        let provider = OpenAIProvider::gpt4o_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt52_thinking_factory_creates_provider() {
        let provider = OpenAIProvider::gpt52_thinking("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT52_THINKING);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_factory_creates_gpt5_provider() {
        let provider = OpenAIProvider::gpt5("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_mini_factory_creates_provider() {
        let provider = OpenAIProvider::gpt5_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o3_factory_creates_o3_provider() {
        let provider = OpenAIProvider::o3("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O3);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o4_mini_factory_creates_o4_mini_provider() {
        let provider = OpenAIProvider::o4_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O4_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o1_factory_creates_o1_provider() {
        let provider = OpenAIProvider::o1("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O1);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt41_factory_creates_gpt41_provider() {
        let provider = OpenAIProvider::gpt41("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT41);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_model_constants_have_expected_values() {
        assert_eq!(MODEL_GPT52_INSTANT, "gpt-5.2-instant");
        assert_eq!(MODEL_GPT52_THINKING, "gpt-5.2-thinking");
        assert_eq!(MODEL_GPT52_PRO, "gpt-5.2-pro");
        assert_eq!(MODEL_GPT5, "gpt-5");
        assert_eq!(MODEL_GPT5_MINI, "gpt-5-mini");
        assert_eq!(MODEL_GPT5_NANO, "gpt-5-nano");
        assert_eq!(MODEL_O3, "o3");
        assert_eq!(MODEL_O3_MINI, "o3-mini");
        assert_eq!(MODEL_O4_MINI, "o4-mini");
        assert_eq!(MODEL_O1, "o1");
        assert_eq!(MODEL_O1_MINI, "o1-mini");
        assert_eq!(MODEL_GPT41, "gpt-4.1");
        assert_eq!(MODEL_GPT41_MINI, "gpt-4.1-mini");
        assert_eq!(MODEL_GPT41_NANO, "gpt-4.1-nano");
        assert_eq!(MODEL_GPT4O, "gpt-4o");
        assert_eq!(MODEL_GPT4O_MINI, "gpt-4o-mini");
    }

    #[test]
    fn test_provider_is_cloneable() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "test-model".to_string());
        let cloned = provider.clone();

        assert_eq!(provider.model(), cloned.model());
        assert_eq!(provider.provider(), cloned.provider());
        assert_eq!(provider.base_url, cloned.base_url);
    }

    // --- Request-side serialization --------------------------------------

    #[test]
    fn test_api_role_serialization() {
        let system_role = ApiRole::System;
        let user_role = ApiRole::User;
        let assistant_role = ApiRole::Assistant;
        let tool_role = ApiRole::Tool;

        assert_eq!(serde_json::to_string(&system_role).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&user_role).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&assistant_role).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&tool_role).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_api_message_serialization_simple() {
        let message = ApiMessage {
            role: ApiRole::User,
            content: Some("Hello, world!".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello, world!\""));
        // `None` fields must be omitted entirely, not serialized as null.
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
    }

    #[test]
    fn test_api_message_serialization_with_tool_calls() {
        let message = ApiMessage {
            role: ApiRole::Assistant,
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiToolCall {
                id: "call_123".to_string(),
                r#type: "function".to_string(),
                function: ApiFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"/test.txt\"}".to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"assistant\""));
        assert!(json.contains("\"tool_calls\""));
        assert!(json.contains("\"id\":\"call_123\""));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"read_file\""));
    }

    #[test]
    fn test_api_tool_message_serialization() {
        let message = ApiMessage {
            role: ApiRole::Tool,
            content: Some("File contents here".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"tool\""));
        assert!(json.contains("\"tool_call_id\":\"call_123\""));
        assert!(json.contains("\"content\":\"File contents here\""));
    }

    #[test]
    fn test_api_tool_serialization() {
        let tool = ApiTool {
            r#type: "function".to_string(),
            function: ApiFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "arg": {"type": "string"}
                    }
                }),
            },
        };

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_tool\""));
        assert!(json.contains("\"description\":\"A test tool\""));
        assert!(json.contains("\"parameters\""));
    }

    // --- Response-side deserialization -----------------------------------

    #[test]
    fn test_api_response_deserialization() {
        let json = r#"{
            "id": "chatcmpl-123",
            "choices": [
                {
                    "message": {
                        "content": "Hello!"
                    },
                    "finish_reason": "stop"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.usage.prompt_tokens, 100);
        assert_eq!(response.usage.completion_tokens, 50);
        assert_eq!(response.choices.len(), 1);
        assert_eq!(
            response.choices[0].message.content,
            Some("Hello!".to_string())
        );
    }

    #[test]
    fn test_api_response_with_tool_calls_deserialization() {
        let json = r#"{
            "id": "chatcmpl-456",
            "choices": [
                {
                    "message": {
                        "content": null,
                        "tool_calls": [
                            {
                                "id": "call_abc",
                                "type": "function",
                                "function": {
                                    "name": "read_file",
                                    "arguments": "{\"path\": \"test.txt\"}"
                                }
                            }
                        ]
                    },
                    "finish_reason": "tool_calls"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 150,
                "completion_tokens": 30
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc");
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_api_finish_reason_deserialization() {
        let stop: ApiFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: ApiFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: ApiFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: ApiFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, ApiFinishReason::Stop));
        assert!(matches!(tool_calls, ApiFinishReason::ToolCalls));
        assert!(matches!(length, ApiFinishReason::Length));
        assert!(matches!(content_filter, ApiFinishReason::ContentFilter));
    }

    // --- Message building and conversion helpers -------------------------

    #[test]
    fn test_build_api_messages_with_system() {
        let request = ChatRequest {
            system: "You are helpful.".to_string(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
            thinking: None,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 2);
        assert_eq!(api_messages[0].role, ApiRole::System);
        assert_eq!(
            api_messages[0].content,
            Some("You are helpful.".to_string())
        );
        assert_eq!(api_messages[1].role, ApiRole::User);
        assert_eq!(api_messages[1].content, Some("Hello".to_string()));
    }

    #[test]
    fn test_build_api_messages_empty_system() {
        let request = ChatRequest {
            system: String::new(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
            thinking: None,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 1);
        assert_eq!(api_messages[0].role, ApiRole::User);
    }

    #[test]
    fn test_convert_tool() {
        let tool = crate::llm::Tool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({"type": "object"}),
        };

        let api_tool = convert_tool(tool);
        assert_eq!(api_tool.r#type, "function");
        assert_eq!(api_tool.function.name, "test_tool");
        assert_eq!(api_tool.function.description, "A test tool");
    }

    #[test]
    fn test_build_content_blocks_text_only() {
        let message = ApiResponseMessage {
            content: Some("Hello!".to_string()),
            tool_calls: None,
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello!"));
    }

    #[test]
    fn test_build_content_blocks_with_tool_calls() {
        let message = ApiResponseMessage {
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiResponseToolCall {
                id: "call_123".to_string(),
                function: ApiResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"test.txt\"}".to_string(),
                },
            }]),
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Let me help."));
        assert!(
            matches!(&blocks[1], ContentBlock::ToolUse { id, name, .. } if id == "call_123" && name == "read_file")
        );
    }

    // --- SSE chunk parsing ------------------------------------------------

    #[test]
    fn test_sse_chunk_text_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello"
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.choices[0].finish_reason.is_none());
    }

    #[test]
    fn test_sse_chunk_tool_call_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_abc",
                        "function": {
                            "name": "read_file",
                            "arguments": ""
                        }
                    }]
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].index, 0);
        assert_eq!(tool_calls[0].id, Some("call_abc".to_string()));
        assert_eq!(
            tool_calls[0].function.as_ref().unwrap().name,
            Some("read_file".to_string())
        );
    }

    #[test]
    fn test_sse_chunk_tool_call_arguments_delta_deserialization() {
        // Later fragments carry only `arguments`; `id` and `name` are absent.
        let json = r#"{
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"path\":"
                        }
                    }]
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls[0].id, None);
        assert_eq!(
            tool_calls[0].function.as_ref().unwrap().arguments,
            Some("{\"path\":".to_string())
        );
    }

    #[test]
    fn test_sse_chunk_with_finish_reason_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {},
                "finish_reason": "stop"
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        assert!(matches!(
            chunk.choices[0].finish_reason,
            Some(SseFinishReason::Stop)
        ));
    }

    #[test]
    fn test_sse_chunk_with_usage_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
    }

    #[test]
    fn test_sse_finish_reason_deserialization() {
        let stop: SseFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: SseFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: SseFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: SseFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, SseFinishReason::Stop));
        assert!(matches!(tool_calls, SseFinishReason::ToolCalls));
        assert!(matches!(length, SseFinishReason::Length));
        assert!(matches!(content_filter, SseFinishReason::ContentFilter));
    }

    #[test]
    fn test_streaming_request_serialization() {
        let messages = vec![ApiMessage {
            role: ApiRole::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let request = ApiChatRequestStreaming {
            model: "gpt-4o",
            messages: &messages,
            max_completion_tokens: Some(1024),
            tools: None,
            stream: true,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"stream\":true"));
        assert!(json.contains("\"model\":\"gpt-4o\""));
    }
}