1use crate::llm::{
8 ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, StopReason, Usage,
9};
10use anyhow::Result;
11use async_trait::async_trait;
12use reqwest::StatusCode;
13use serde::{Deserialize, Serialize};
14
/// Root of the official OpenAI REST API; `OpenAIProvider::new` targets this URL.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

// --- GPT-5.2 family -------------------------------------------------------
// NOTE(review): confirm these IDs against the OpenAI models list; product
// names shown in ChatGPT do not always match API model identifiers.

/// Model ID for GPT-5.2 Instant.
pub const MODEL_GPT52_INSTANT: &str = "gpt-5.2-instant";
/// Model ID for GPT-5.2 Thinking.
pub const MODEL_GPT52_THINKING: &str = "gpt-5.2-thinking";
/// Model ID for GPT-5.2 Pro.
pub const MODEL_GPT52_PRO: &str = "gpt-5.2-pro";

// --- GPT-5 family ---------------------------------------------------------

/// Model ID for GPT-5.
pub const MODEL_GPT5: &str = "gpt-5";
/// Model ID for GPT-5 mini.
pub const MODEL_GPT5_MINI: &str = "gpt-5-mini";
/// Model ID for GPT-5 nano.
pub const MODEL_GPT5_NANO: &str = "gpt-5-nano";

// --- o-series reasoning models --------------------------------------------

/// Model ID for o3.
pub const MODEL_O3: &str = "o3";
/// Model ID for o3-mini.
pub const MODEL_O3_MINI: &str = "o3-mini";
/// Model ID for o4-mini.
pub const MODEL_O4_MINI: &str = "o4-mini";
/// Model ID for o1.
pub const MODEL_O1: &str = "o1";
/// Model ID for o1-mini.
pub const MODEL_O1_MINI: &str = "o1-mini";

// --- GPT-4.1 family -------------------------------------------------------

/// Model ID for GPT-4.1.
pub const MODEL_GPT41: &str = "gpt-4.1";
/// Model ID for GPT-4.1 mini.
pub const MODEL_GPT41_MINI: &str = "gpt-4.1-mini";
/// Model ID for GPT-4.1 nano.
pub const MODEL_GPT41_NANO: &str = "gpt-4.1-nano";

// --- GPT-4o family --------------------------------------------------------

/// Model ID for GPT-4o.
pub const MODEL_GPT4O: &str = "gpt-4o";
/// Model ID for GPT-4o mini.
pub const MODEL_GPT4O_MINI: &str = "gpt-4o-mini";
42
43#[derive(Clone)]
48pub struct OpenAIProvider {
49 client: reqwest::Client,
50 api_key: String,
51 model: String,
52 base_url: String,
53}
54
55impl OpenAIProvider {
56 #[must_use]
58 pub fn new(api_key: String, model: String) -> Self {
59 Self {
60 client: reqwest::Client::new(),
61 api_key,
62 model,
63 base_url: DEFAULT_BASE_URL.to_owned(),
64 }
65 }
66
67 #[must_use]
69 pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
70 Self {
71 client: reqwest::Client::new(),
72 api_key,
73 model,
74 base_url,
75 }
76 }
77
78 #[must_use]
80 pub fn gpt52_instant(api_key: String) -> Self {
81 Self::new(api_key, MODEL_GPT52_INSTANT.to_owned())
82 }
83
84 #[must_use]
86 pub fn gpt52_thinking(api_key: String) -> Self {
87 Self::new(api_key, MODEL_GPT52_THINKING.to_owned())
88 }
89
90 #[must_use]
92 pub fn gpt52_pro(api_key: String) -> Self {
93 Self::new(api_key, MODEL_GPT52_PRO.to_owned())
94 }
95
96 #[must_use]
98 pub fn gpt5(api_key: String) -> Self {
99 Self::new(api_key, MODEL_GPT5.to_owned())
100 }
101
102 #[must_use]
104 pub fn gpt5_mini(api_key: String) -> Self {
105 Self::new(api_key, MODEL_GPT5_MINI.to_owned())
106 }
107
108 #[must_use]
110 pub fn gpt5_nano(api_key: String) -> Self {
111 Self::new(api_key, MODEL_GPT5_NANO.to_owned())
112 }
113
114 #[must_use]
116 pub fn o3(api_key: String) -> Self {
117 Self::new(api_key, MODEL_O3.to_owned())
118 }
119
120 #[must_use]
122 pub fn o3_mini(api_key: String) -> Self {
123 Self::new(api_key, MODEL_O3_MINI.to_owned())
124 }
125
126 #[must_use]
128 pub fn o4_mini(api_key: String) -> Self {
129 Self::new(api_key, MODEL_O4_MINI.to_owned())
130 }
131
132 #[must_use]
134 pub fn o1(api_key: String) -> Self {
135 Self::new(api_key, MODEL_O1.to_owned())
136 }
137
138 #[must_use]
140 pub fn o1_mini(api_key: String) -> Self {
141 Self::new(api_key, MODEL_O1_MINI.to_owned())
142 }
143
144 #[must_use]
146 pub fn gpt41(api_key: String) -> Self {
147 Self::new(api_key, MODEL_GPT41.to_owned())
148 }
149
150 #[must_use]
152 pub fn gpt41_mini(api_key: String) -> Self {
153 Self::new(api_key, MODEL_GPT41_MINI.to_owned())
154 }
155
156 #[must_use]
158 pub fn gpt4o(api_key: String) -> Self {
159 Self::new(api_key, MODEL_GPT4O.to_owned())
160 }
161
162 #[must_use]
164 pub fn gpt4o_mini(api_key: String) -> Self {
165 Self::new(api_key, MODEL_GPT4O_MINI.to_owned())
166 }
167}
168
169#[async_trait]
170impl LlmProvider for OpenAIProvider {
171 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
172 let messages = build_api_messages(&request);
173 let tools: Option<Vec<ApiTool>> = request
174 .tools
175 .map(|ts| ts.into_iter().map(convert_tool).collect());
176
177 let api_request = ApiChatRequest {
178 model: &self.model,
179 messages: &messages,
180 max_completion_tokens: Some(request.max_tokens),
181 tools: tools.as_deref(),
182 };
183
184 tracing::debug!(
185 model = %self.model,
186 max_tokens = request.max_tokens,
187 "OpenAI LLM request"
188 );
189
190 let response = self
191 .client
192 .post(format!("{}/chat/completions", self.base_url))
193 .header("Content-Type", "application/json")
194 .header("Authorization", format!("Bearer {}", self.api_key))
195 .json(&api_request)
196 .send()
197 .await
198 .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;
199
200 let status = response.status();
201 let bytes = response
202 .bytes()
203 .await
204 .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;
205
206 tracing::debug!(
207 status = %status,
208 body_len = bytes.len(),
209 "OpenAI LLM response"
210 );
211
212 if status == StatusCode::TOO_MANY_REQUESTS {
213 return Ok(ChatOutcome::RateLimited);
214 }
215
216 if status.is_server_error() {
217 let body = String::from_utf8_lossy(&bytes);
218 tracing::error!(status = %status, body = %body, "OpenAI server error");
219 return Ok(ChatOutcome::ServerError(body.into_owned()));
220 }
221
222 if status.is_client_error() {
223 let body = String::from_utf8_lossy(&bytes);
224 tracing::warn!(status = %status, body = %body, "OpenAI client error");
225 return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
226 }
227
228 let api_response: ApiChatResponse = serde_json::from_slice(&bytes)
229 .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;
230
231 let choice = api_response
232 .choices
233 .into_iter()
234 .next()
235 .ok_or_else(|| anyhow::anyhow!("no choices in response"))?;
236
237 let content = build_content_blocks(&choice.message);
238
239 let stop_reason = choice.finish_reason.map(|r| match r {
240 ApiFinishReason::Stop => StopReason::EndTurn,
241 ApiFinishReason::ToolCalls => StopReason::ToolUse,
242 ApiFinishReason::Length => StopReason::MaxTokens,
243 ApiFinishReason::ContentFilter => StopReason::StopSequence,
244 });
245
246 Ok(ChatOutcome::Success(ChatResponse {
247 id: api_response.id,
248 content,
249 model: api_response.model,
250 stop_reason,
251 usage: Usage {
252 input_tokens: api_response.usage.prompt_tokens,
253 output_tokens: api_response.usage.completion_tokens,
254 },
255 }))
256 }
257
258 fn model(&self) -> &str {
259 &self.model
260 }
261
262 fn provider(&self) -> &'static str {
263 "openai"
264 }
265}
266
267fn build_api_messages(request: &ChatRequest) -> Vec<ApiMessage> {
268 let mut messages = Vec::new();
269
270 if !request.system.is_empty() {
272 messages.push(ApiMessage {
273 role: ApiRole::System,
274 content: Some(request.system.clone()),
275 tool_calls: None,
276 tool_call_id: None,
277 });
278 }
279
280 for msg in &request.messages {
282 match &msg.content {
283 Content::Text(text) => {
284 messages.push(ApiMessage {
285 role: match msg.role {
286 crate::llm::Role::User => ApiRole::User,
287 crate::llm::Role::Assistant => ApiRole::Assistant,
288 },
289 content: Some(text.clone()),
290 tool_calls: None,
291 tool_call_id: None,
292 });
293 }
294 Content::Blocks(blocks) => {
295 let mut text_parts = Vec::new();
297 let mut tool_calls = Vec::new();
298
299 for block in blocks {
300 match block {
301 ContentBlock::Text { text } => {
302 text_parts.push(text.clone());
303 }
304 ContentBlock::ToolUse { id, name, input } => {
305 tool_calls.push(ApiToolCall {
306 id: id.clone(),
307 r#type: "function".to_owned(),
308 function: ApiFunctionCall {
309 name: name.clone(),
310 arguments: serde_json::to_string(input)
311 .unwrap_or_else(|_| "{}".to_owned()),
312 },
313 });
314 }
315 ContentBlock::ToolResult {
316 tool_use_id,
317 content,
318 ..
319 } => {
320 messages.push(ApiMessage {
322 role: ApiRole::Tool,
323 content: Some(content.clone()),
324 tool_calls: None,
325 tool_call_id: Some(tool_use_id.clone()),
326 });
327 }
328 }
329 }
330
331 if !text_parts.is_empty() || !tool_calls.is_empty() {
333 let role = match msg.role {
334 crate::llm::Role::User => ApiRole::User,
335 crate::llm::Role::Assistant => ApiRole::Assistant,
336 };
337
338 if role == ApiRole::Assistant || !text_parts.is_empty() {
340 messages.push(ApiMessage {
341 role,
342 content: if text_parts.is_empty() {
343 None
344 } else {
345 Some(text_parts.join("\n"))
346 },
347 tool_calls: if tool_calls.is_empty() {
348 None
349 } else {
350 Some(tool_calls)
351 },
352 tool_call_id: None,
353 });
354 }
355 }
356 }
357 }
358 }
359
360 messages
361}
362
363fn convert_tool(t: crate::llm::Tool) -> ApiTool {
364 ApiTool {
365 r#type: "function".to_owned(),
366 function: ApiFunction {
367 name: t.name,
368 description: t.description,
369 parameters: t.input_schema,
370 },
371 }
372}
373
374fn build_content_blocks(message: &ApiResponseMessage) -> Vec<ContentBlock> {
375 let mut blocks = Vec::new();
376
377 if let Some(content) = &message.content
379 && !content.is_empty()
380 {
381 blocks.push(ContentBlock::Text {
382 text: content.clone(),
383 });
384 }
385
386 if let Some(tool_calls) = &message.tool_calls {
388 for tc in tool_calls {
389 let input: serde_json::Value =
390 serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null);
391 blocks.push(ContentBlock::ToolUse {
392 id: tc.id.clone(),
393 name: tc.function.name.clone(),
394 input,
395 });
396 }
397 }
398
399 blocks
400}
401
402#[derive(Serialize)]
407struct ApiChatRequest<'a> {
408 model: &'a str,
409 messages: &'a [ApiMessage],
410 #[serde(skip_serializing_if = "Option::is_none")]
411 max_completion_tokens: Option<u32>,
412 #[serde(skip_serializing_if = "Option::is_none")]
413 tools: Option<&'a [ApiTool]>,
414}
415
416#[derive(Serialize)]
417struct ApiMessage {
418 role: ApiRole,
419 #[serde(skip_serializing_if = "Option::is_none")]
420 content: Option<String>,
421 #[serde(skip_serializing_if = "Option::is_none")]
422 tool_calls: Option<Vec<ApiToolCall>>,
423 #[serde(skip_serializing_if = "Option::is_none")]
424 tool_call_id: Option<String>,
425}
426
427#[derive(Debug, Serialize, PartialEq, Eq)]
428#[serde(rename_all = "lowercase")]
429enum ApiRole {
430 System,
431 User,
432 Assistant,
433 Tool,
434}
435
436#[derive(Serialize)]
437struct ApiToolCall {
438 id: String,
439 r#type: String,
440 function: ApiFunctionCall,
441}
442
443#[derive(Serialize)]
444struct ApiFunctionCall {
445 name: String,
446 arguments: String,
447}
448
449#[derive(Serialize)]
450struct ApiTool {
451 r#type: String,
452 function: ApiFunction,
453}
454
455#[derive(Serialize)]
456struct ApiFunction {
457 name: String,
458 description: String,
459 parameters: serde_json::Value,
460}
461
462#[derive(Deserialize)]
467struct ApiChatResponse {
468 id: String,
469 choices: Vec<ApiChoice>,
470 model: String,
471 usage: ApiUsage,
472}
473
474#[derive(Deserialize)]
475struct ApiChoice {
476 message: ApiResponseMessage,
477 finish_reason: Option<ApiFinishReason>,
478}
479
480#[derive(Deserialize)]
481struct ApiResponseMessage {
482 content: Option<String>,
483 tool_calls: Option<Vec<ApiResponseToolCall>>,
484}
485
486#[derive(Deserialize)]
487struct ApiResponseToolCall {
488 id: String,
489 function: ApiResponseFunctionCall,
490}
491
492#[derive(Deserialize)]
493struct ApiResponseFunctionCall {
494 name: String,
495 arguments: String,
496}
497
498#[derive(Deserialize)]
499#[serde(rename_all = "snake_case")]
500enum ApiFinishReason {
501 Stop,
502 ToolCalls,
503 Length,
504 ContentFilter,
505}
506
507#[derive(Deserialize)]
508struct ApiUsage {
509 prompt_tokens: u32,
510 completion_tokens: u32,
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "custom-model".to_string());

        assert_eq!(provider.model(), "custom-model");
        assert_eq!(provider.provider(), "openai");
        assert_eq!(provider.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_with_base_url_creates_provider_with_custom_url() {
        let provider = OpenAIProvider::with_base_url(
            "test-api-key".to_string(),
            "llama3".to_string(),
            "http://localhost:11434/v1".to_string(),
        );

        assert_eq!(provider.model(), "llama3");
        assert_eq!(provider.base_url, "http://localhost:11434/v1");
    }

    #[test]
    fn test_gpt4o_factory_creates_gpt4o_provider() {
        let provider = OpenAIProvider::gpt4o("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt4o_mini_factory_creates_gpt4o_mini_provider() {
        let provider = OpenAIProvider::gpt4o_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt52_thinking_factory_creates_provider() {
        let provider = OpenAIProvider::gpt52_thinking("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT52_THINKING);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_factory_creates_gpt5_provider() {
        let provider = OpenAIProvider::gpt5("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_mini_factory_creates_provider() {
        let provider = OpenAIProvider::gpt5_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o3_factory_creates_o3_provider() {
        let provider = OpenAIProvider::o3("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O3);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o4_mini_factory_creates_o4_mini_provider() {
        let provider = OpenAIProvider::o4_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O4_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o1_factory_creates_o1_provider() {
        let provider = OpenAIProvider::o1("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O1);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt41_factory_creates_gpt41_provider() {
        let provider = OpenAIProvider::gpt41("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT41);
        assert_eq!(provider.provider(), "openai");
    }

    /// Covers the factories that previously had no dedicated test.
    #[test]
    fn test_remaining_factories_select_expected_models() {
        let key = || "test-api-key".to_string();
        let cases = [
            (OpenAIProvider::gpt52_instant(key()), MODEL_GPT52_INSTANT),
            (OpenAIProvider::gpt52_pro(key()), MODEL_GPT52_PRO),
            (OpenAIProvider::gpt5_nano(key()), MODEL_GPT5_NANO),
            (OpenAIProvider::o3_mini(key()), MODEL_O3_MINI),
            (OpenAIProvider::o1_mini(key()), MODEL_O1_MINI),
            (OpenAIProvider::gpt41_mini(key()), MODEL_GPT41_MINI),
        ];

        for (provider, expected) in &cases {
            assert_eq!(provider.model(), *expected);
            assert_eq!(provider.provider(), "openai");
            assert_eq!(provider.base_url, DEFAULT_BASE_URL);
        }
    }

    #[test]
    fn test_model_constants_have_expected_values() {
        assert_eq!(MODEL_GPT52_INSTANT, "gpt-5.2-instant");
        assert_eq!(MODEL_GPT52_THINKING, "gpt-5.2-thinking");
        assert_eq!(MODEL_GPT52_PRO, "gpt-5.2-pro");
        assert_eq!(MODEL_GPT5, "gpt-5");
        assert_eq!(MODEL_GPT5_MINI, "gpt-5-mini");
        assert_eq!(MODEL_GPT5_NANO, "gpt-5-nano");
        assert_eq!(MODEL_O3, "o3");
        assert_eq!(MODEL_O3_MINI, "o3-mini");
        assert_eq!(MODEL_O4_MINI, "o4-mini");
        assert_eq!(MODEL_O1, "o1");
        assert_eq!(MODEL_O1_MINI, "o1-mini");
        assert_eq!(MODEL_GPT41, "gpt-4.1");
        assert_eq!(MODEL_GPT41_MINI, "gpt-4.1-mini");
        assert_eq!(MODEL_GPT41_NANO, "gpt-4.1-nano");
        assert_eq!(MODEL_GPT4O, "gpt-4o");
        assert_eq!(MODEL_GPT4O_MINI, "gpt-4o-mini");
    }

    #[test]
    fn test_provider_is_cloneable() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "test-model".to_string());
        let cloned = provider.clone();

        assert_eq!(provider.model(), cloned.model());
        assert_eq!(provider.provider(), cloned.provider());
        assert_eq!(provider.base_url, cloned.base_url);
    }

    #[test]
    fn test_api_role_serialization() {
        let system_role = ApiRole::System;
        let user_role = ApiRole::User;
        let assistant_role = ApiRole::Assistant;
        let tool_role = ApiRole::Tool;

        assert_eq!(serde_json::to_string(&system_role).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&user_role).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&assistant_role).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&tool_role).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_api_message_serialization_simple() {
        let message = ApiMessage {
            role: ApiRole::User,
            content: Some("Hello, world!".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello, world!\""));
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
    }

    #[test]
    fn test_api_message_serialization_with_tool_calls() {
        let message = ApiMessage {
            role: ApiRole::Assistant,
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiToolCall {
                id: "call_123".to_string(),
                r#type: "function".to_string(),
                function: ApiFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"/test.txt\"}".to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"assistant\""));
        assert!(json.contains("\"tool_calls\""));
        assert!(json.contains("\"id\":\"call_123\""));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"read_file\""));
    }

    #[test]
    fn test_api_tool_message_serialization() {
        let message = ApiMessage {
            role: ApiRole::Tool,
            content: Some("File contents here".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"tool\""));
        assert!(json.contains("\"tool_call_id\":\"call_123\""));
        assert!(json.contains("\"content\":\"File contents here\""));
    }

    #[test]
    fn test_api_tool_serialization() {
        let tool = ApiTool {
            r#type: "function".to_string(),
            function: ApiFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "arg": {"type": "string"}
                    }
                }),
            },
        };

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_tool\""));
        assert!(json.contains("\"description\":\"A test tool\""));
        assert!(json.contains("\"parameters\""));
    }

    #[test]
    fn test_api_chat_request_omits_absent_optional_fields() {
        let request = ApiChatRequest {
            model: "gpt-4o",
            messages: &[],
            max_completion_tokens: None,
            tools: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert_eq!(json, r#"{"model":"gpt-4o","messages":[]}"#);
    }

    #[test]
    fn test_api_response_deserialization() {
        let json = r#"{
            "id": "chatcmpl-123",
            "choices": [
                {
                    "message": {
                        "content": "Hello!"
                    },
                    "finish_reason": "stop"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.usage.prompt_tokens, 100);
        assert_eq!(response.usage.completion_tokens, 50);
        assert_eq!(response.choices.len(), 1);
        assert_eq!(
            response.choices[0].message.content,
            Some("Hello!".to_string())
        );
    }

    #[test]
    fn test_api_response_without_finish_reason() {
        let json = r#"{
            "id": "chatcmpl-789",
            "choices": [{"message": {"content": "hi"}}],
            "model": "gpt-4o",
            "usage": {"prompt_tokens": 1, "completion_tokens": 2}
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert!(response.choices[0].finish_reason.is_none());
    }

    #[test]
    fn test_api_response_with_tool_calls_deserialization() {
        let json = r#"{
            "id": "chatcmpl-456",
            "choices": [
                {
                    "message": {
                        "content": null,
                        "tool_calls": [
                            {
                                "id": "call_abc",
                                "type": "function",
                                "function": {
                                    "name": "read_file",
                                    "arguments": "{\"path\": \"test.txt\"}"
                                }
                            }
                        ]
                    },
                    "finish_reason": "tool_calls"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 150,
                "completion_tokens": 30
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc");
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_api_finish_reason_deserialization() {
        let stop: ApiFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: ApiFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: ApiFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: ApiFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, ApiFinishReason::Stop));
        assert!(matches!(tool_calls, ApiFinishReason::ToolCalls));
        assert!(matches!(length, ApiFinishReason::Length));
        assert!(matches!(content_filter, ApiFinishReason::ContentFilter));
    }

    #[test]
    fn test_build_api_messages_with_system() {
        let request = ChatRequest {
            system: "You are helpful.".to_string(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 2);
        assert_eq!(api_messages[0].role, ApiRole::System);
        assert_eq!(
            api_messages[0].content,
            Some("You are helpful.".to_string())
        );
        assert_eq!(api_messages[1].role, ApiRole::User);
        assert_eq!(api_messages[1].content, Some("Hello".to_string()));
    }

    #[test]
    fn test_build_api_messages_empty_system() {
        let request = ChatRequest {
            system: String::new(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 1);
        assert_eq!(api_messages[0].role, ApiRole::User);
    }

    #[test]
    fn test_convert_tool() {
        let tool = crate::llm::Tool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({"type": "object"}),
        };

        let api_tool = convert_tool(tool);
        assert_eq!(api_tool.r#type, "function");
        assert_eq!(api_tool.function.name, "test_tool");
        assert_eq!(api_tool.function.description, "A test tool");
    }

    #[test]
    fn test_build_content_blocks_text_only() {
        let message = ApiResponseMessage {
            content: Some("Hello!".to_string()),
            tool_calls: None,
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello!"));
    }

    #[test]
    fn test_build_content_blocks_skips_empty_text() {
        let message = ApiResponseMessage {
            content: Some(String::new()),
            tool_calls: None,
        };

        assert!(build_content_blocks(&message).is_empty());
    }

    #[test]
    fn test_build_content_blocks_with_tool_calls() {
        let message = ApiResponseMessage {
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiResponseToolCall {
                id: "call_123".to_string(),
                function: ApiResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"test.txt\"}".to_string(),
                },
            }]),
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Let me help."));
        assert!(
            matches!(&blocks[1], ContentBlock::ToolUse { id, name, .. } if id == "call_123" && name == "read_file")
        );
        assert!(matches!(
            &blocks[1],
            ContentBlock::ToolUse { input, .. } if *input == serde_json::json!({"path": "test.txt"})
        ));
    }

    #[test]
    fn test_build_content_blocks_invalid_arguments_become_null() {
        let message = ApiResponseMessage {
            content: None,
            tool_calls: Some(vec![ApiResponseToolCall {
                id: "call_bad".to_string(),
                function: ApiResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{not json".to_string(),
                },
            }]),
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::ToolUse { id, input, .. } if id == "call_bad" && input.is_null()
        ));
    }
}