use async_trait::async_trait;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::sync::RwLock;
use std::time::Duration;

use crate::error::{Error, Result};
use crate::llm::ollama::types::{
    ChatMessage as OllamaChatMessage, ChatRequest as OllamaChatRequest,
};
use crate::llm::ollama::OllamaClient;

static HTTP_CLIENT_POOL: Lazy<RwLock<HashMap<u64, reqwest::Client>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

static DEFAULT_HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
    reqwest::Client::builder()
        .timeout(Duration::from_secs(120))
        .pool_max_idle_per_host(10)
        .pool_idle_timeout(Duration::from_secs(90))
        .tcp_keepalive(Duration::from_secs(60))
        .build()
        .expect("Failed to create default HTTP client")
});

fn env_first(keys: &[&str]) -> Option<String> {
    for k in keys {
        if let Ok(v) = std::env::var(k) {
            let v = v.trim().to_string();
            if !v.is_empty() {
                return Some(v);
            }
        }
    }
    None
}

fn get_pooled_client(timeout_secs: u64) -> reqwest::Client {
    if timeout_secs == 120 {
        return DEFAULT_HTTP_CLIENT.clone();
    }

    if let Ok(pool) = HTTP_CLIENT_POOL.read() {
        if let Some(client) = pool.get(&timeout_secs) {
            return client.clone();
        }
    }

    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(timeout_secs))
        .pool_max_idle_per_host(10)
        .pool_idle_timeout(Duration::from_secs(90))
        .tcp_keepalive(Duration::from_secs(60))
        .build()
        .unwrap_or_else(|_| DEFAULT_HTTP_CLIENT.clone());

    if let Ok(mut pool) = HTTP_CLIENT_POOL.write() {
        pool.insert(timeout_secs, client.clone());
    }

    client
}

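/// Configuration for a single provider/model pair.
///
/// Illustrative builder usage; a minimal sketch in which the model ID and API
/// key are placeholders, and the doctest is marked `ignore` because the public
/// re-export path is crate-specific:
///
/// ```ignore
/// let config = LlmConfig::for_provider(LlmProvider::Groq, "llama-3.3-70b-versatile")
///     .with_api_key("example-api-key") // placeholder, not a real key
///     .with_temperature(0.2)
///     .with_max_tokens(1024);
/// assert_eq!(config.provider, LlmProvider::Groq);
/// ```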
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    pub provider: LlmProvider,

    pub model: String,

    pub api_key: Option<String>,

    pub base_url: Option<String>,

    #[serde(default = "default_temperature")]
    pub temperature: f64,

    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,

    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,

    #[serde(default)]
    pub extra: ProviderExtra,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ProviderExtra {
    pub azure_resource: Option<String>,

    pub azure_deployment: Option<String>,

    pub aws_region: Option<String>,

    pub gcp_project: Option<String>,

    pub gcp_location: Option<String>,

    pub cf_account_id: Option<String>,

    pub cf_gateway_id: Option<String>,

    pub gateway_provider: Option<String>,
}

fn default_temperature() -> f64 {
    0.7
}

fn default_max_tokens() -> u32 {
    2000
}

fn default_timeout() -> u64 {
    60
}

impl Default for LlmConfig {
    fn default() -> Self {
        Self {
            provider: LlmProvider::Anthropic,
            model: "claude-opus-4-5".to_string(),
            api_key: None,
            base_url: None,
            temperature: default_temperature(),
            max_tokens: default_max_tokens(),
            timeout_secs: default_timeout(),
            extra: ProviderExtra::default(),
        }
    }
}

impl LlmConfig {
    pub fn for_provider(provider: LlmProvider, model: impl Into<String>) -> Self {
        Self {
            provider,
            model: model.into(),
            ..Default::default()
        }
    }

    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = Some(url.into());
        self
    }

    pub fn with_temperature(mut self, temp: f64) -> Self {
        self.temperature = temp;
        self
    }

    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = tokens;
        self
    }

    pub fn with_azure(
        mut self,
        resource: impl Into<String>,
        deployment: impl Into<String>,
    ) -> Self {
        self.extra.azure_resource = Some(resource.into());
        self.extra.azure_deployment = Some(deployment.into());
        self
    }

    pub fn with_aws_region(mut self, region: impl Into<String>) -> Self {
        self.extra.aws_region = Some(region.into());
        self
    }

    pub fn with_gcp(mut self, project: impl Into<String>, location: impl Into<String>) -> Self {
        self.extra.gcp_project = Some(project.into());
        self.extra.gcp_location = Some(location.into());
        self
    }

    pub fn with_cloudflare_gateway(
        mut self,
        account_id: impl Into<String>,
        gateway_id: impl Into<String>,
    ) -> Self {
        self.extra.cf_account_id = Some(account_id.into());
        self.extra.cf_gateway_id = Some(gateway_id.into());
        self
    }
}

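/// The set of supported LLM providers.
///
/// Minimal introspection sketch; the doctest is marked `ignore` because the
/// public re-export path is crate-specific:
///
/// ```ignore
/// for provider in LlmProvider::all() {
///     // Display uses the human-readable name, e.g. "xAI (Grok)".
///     println!("{} -> {}", provider, provider.default_base_url());
/// }
/// assert_eq!(LlmProvider::Anthropic.env_var(), "ANTHROPIC_API_KEY");
/// ```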
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum LlmProvider {
    #[default]
    Anthropic,

    #[serde(rename = "open_ai")]
    OpenAI,

    GoogleGemini,

    GoogleVertex,

    #[serde(rename = "azure_open_ai")]
    AzureOpenAI,

    #[serde(rename = "aws_bedrock")]
    AWSBedrock,

    Ollama,

    #[serde(rename = "xai")]
    XAI,

    Groq,

    Mistral,

    DeepSeek,

    Cohere,

    Perplexity,

    Cerebras,

    #[serde(rename = "together_ai")]
    TogetherAI,

    #[serde(rename = "fireworks_ai")]
    FireworksAI,

    AlibabaQwen,

    OpenRouter,

    #[serde(rename = "cloudflare_ai")]
    CloudflareAI,

    Opencode,
}

impl LlmProvider {
    pub fn all() -> &'static [LlmProvider] {
        &[
            LlmProvider::Anthropic,
            LlmProvider::OpenAI,
            LlmProvider::GoogleGemini,
            LlmProvider::GoogleVertex,
            LlmProvider::AzureOpenAI,
            LlmProvider::AWSBedrock,
            LlmProvider::Ollama,
            LlmProvider::XAI,
            LlmProvider::Groq,
            LlmProvider::Mistral,
            LlmProvider::DeepSeek,
            LlmProvider::Cohere,
            LlmProvider::Perplexity,
            LlmProvider::Cerebras,
            LlmProvider::TogetherAI,
            LlmProvider::FireworksAI,
            LlmProvider::AlibabaQwen,
            LlmProvider::OpenRouter,
            LlmProvider::CloudflareAI,
            LlmProvider::Opencode,
        ]
    }

    pub fn env_var(&self) -> &'static str {
        match self {
            LlmProvider::Anthropic => "ANTHROPIC_API_KEY",
            LlmProvider::OpenAI => "OPENAI_API_KEY",
            LlmProvider::GoogleGemini => "GEMINI_API_KEY",
            LlmProvider::GoogleVertex => "GOOGLE_APPLICATION_CREDENTIALS",
            LlmProvider::AzureOpenAI => "AZURE_OPENAI_API_KEY",
            LlmProvider::AWSBedrock => "AWS_ACCESS_KEY_ID",
            LlmProvider::Ollama => "RK_OLLAMA_MODEL",
            LlmProvider::XAI => "XAI_API_KEY",
            LlmProvider::Groq => "GROQ_API_KEY",
            LlmProvider::Mistral => "MISTRAL_API_KEY",
            LlmProvider::DeepSeek => "DEEPSEEK_API_KEY",
            LlmProvider::Cohere => "COHERE_API_KEY",
            LlmProvider::Perplexity => "PERPLEXITY_API_KEY",
            LlmProvider::Cerebras => "CEREBRAS_API_KEY",
            LlmProvider::TogetherAI => "TOGETHER_API_KEY",
            LlmProvider::FireworksAI => "FIREWORKS_API_KEY",
            LlmProvider::AlibabaQwen => "DASHSCOPE_API_KEY",
            LlmProvider::OpenRouter => "OPENROUTER_API_KEY",
            LlmProvider::CloudflareAI => "CLOUDFLARE_API_KEY",
            LlmProvider::Opencode => "OPENCODE_API_KEY",
        }
    }

    pub fn default_base_url(&self) -> &'static str {
        match self {
            LlmProvider::Anthropic => "https://api.anthropic.com/v1",
            LlmProvider::OpenAI => "https://api.openai.com/v1",
            LlmProvider::GoogleGemini => "https://generativelanguage.googleapis.com/v1beta/openai",
            LlmProvider::GoogleVertex => "https://aiplatform.googleapis.com/v1",
            LlmProvider::AzureOpenAI => "https://RESOURCE.openai.azure.com/openai",
            LlmProvider::AWSBedrock => "https://bedrock-runtime.us-east-1.amazonaws.com",
            LlmProvider::Ollama => "http://localhost:11434",
            LlmProvider::XAI => "https://api.x.ai/v1",
            LlmProvider::Groq => "https://api.groq.com/openai/v1",
            LlmProvider::Mistral => "https://api.mistral.ai/v1",
            LlmProvider::DeepSeek => "https://api.deepseek.com/v1",
            LlmProvider::Cohere => "https://api.cohere.ai/v1",
            LlmProvider::Perplexity => "https://api.perplexity.ai",
            LlmProvider::Cerebras => "https://api.cerebras.ai/v1",
            LlmProvider::TogetherAI => "https://api.together.xyz/v1",
            LlmProvider::FireworksAI => "https://api.fireworks.ai/inference/v1",
            LlmProvider::AlibabaQwen => "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
            LlmProvider::OpenRouter => "https://openrouter.ai/api/v1",
            LlmProvider::CloudflareAI => "https://gateway.ai.cloudflare.com/v1",
            LlmProvider::Opencode => "https://api.opencode.ai/v1",
        }
    }

    pub fn default_model(&self) -> &'static str {
        match self {
            LlmProvider::Anthropic => "claude-opus-4-5",
            LlmProvider::OpenAI => "gpt-5.1",
            LlmProvider::GoogleGemini => "gemini-3.0-pro",
            LlmProvider::GoogleVertex => "gemini-3.0-pro",
            LlmProvider::AzureOpenAI => "gpt-5.1",
            LlmProvider::AWSBedrock => "anthropic.claude-opus-4-5-v1:0",
            LlmProvider::Ollama => "deepseek-v3.2:cloud",
            LlmProvider::XAI => "grok-4.1",
            LlmProvider::Groq => "llama-3.3-70b-versatile",
            LlmProvider::Mistral => "mistral-large-3",
            LlmProvider::DeepSeek => "deepseek-v3.2",
            LlmProvider::Cohere => "command-a",
            LlmProvider::Perplexity => "sonar-pro",
            LlmProvider::Cerebras => "llama-4-scout",
            LlmProvider::TogetherAI => "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            LlmProvider::FireworksAI => "accounts/fireworks/models/llama-v4-scout-instruct",
            LlmProvider::AlibabaQwen => "qwen3-max",
            LlmProvider::OpenRouter => "anthropic/claude-opus-4-5",
            LlmProvider::CloudflareAI => "@cf/meta/llama-4-scout-instruct-fp8-fast",
            LlmProvider::Opencode => "default",
        }
    }

    pub fn is_anthropic_format(&self) -> bool {
        matches!(self, LlmProvider::Anthropic)
    }

    pub fn is_openai_compatible(&self) -> bool {
        !self.is_anthropic_format()
    }

    pub fn requires_special_auth(&self) -> bool {
        matches!(
            self,
            LlmProvider::AzureOpenAI | LlmProvider::AWSBedrock | LlmProvider::GoogleVertex
        )
    }

    pub fn display_name(&self) -> &'static str {
        match self {
            LlmProvider::Anthropic => "Anthropic",
            LlmProvider::OpenAI => "OpenAI",
            LlmProvider::GoogleGemini => "Google Gemini (AI Studio)",
            LlmProvider::GoogleVertex => "Google Vertex AI",
            LlmProvider::AzureOpenAI => "Azure OpenAI",
            LlmProvider::AWSBedrock => "AWS Bedrock",
            LlmProvider::Ollama => "Ollama",
            LlmProvider::XAI => "xAI (Grok)",
            LlmProvider::Groq => "Groq",
            LlmProvider::Mistral => "Mistral AI",
            LlmProvider::DeepSeek => "DeepSeek",
            LlmProvider::Cohere => "Cohere",
            LlmProvider::Perplexity => "Perplexity",
            LlmProvider::Cerebras => "Cerebras",
            LlmProvider::TogetherAI => "Together AI",
            LlmProvider::FireworksAI => "Fireworks AI",
            LlmProvider::AlibabaQwen => "Alibaba Qwen",
            LlmProvider::OpenRouter => "OpenRouter",
            LlmProvider::CloudflareAI => "Cloudflare AI",
            LlmProvider::Opencode => "Opencode AI",
        }
    }
}

impl std::fmt::Display for LlmProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.display_name())
    }
}

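/// A single prompt/completion request, independent of provider.
///
/// Minimal sketch; the doctest is marked `ignore` because the public re-export
/// path is crate-specific:
///
/// ```ignore
/// let request = LlmRequest::new("Summarize this repository")
///     .with_system("You are a concise assistant")
///     .with_temperature(0.2)
///     .with_max_tokens(256);
/// assert_eq!(request.max_tokens, Some(256));
/// ```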
#[derive(Debug, Clone)]
pub struct LlmRequest {
    pub system: Option<String>,

    pub prompt: String,

    pub temperature: Option<f64>,

    pub max_tokens: Option<u32>,
}

impl LlmRequest {
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            system: None,
            prompt: prompt.into(),
            temperature: None,
            max_tokens: None,
        }
    }

    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    pub fn with_temperature(mut self, temp: f64) -> Self {
        self.temperature = Some(temp);
        self
    }

    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = Some(tokens);
        self
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: String,

    pub model: String,

    pub finish_reason: FinishReason,

    pub usage: LlmUsage,

    #[serde(default)]
    pub provider: Option<LlmProvider>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum FinishReason {
    #[default]
    Stop,
    MaxTokens,
    ContentFilter,
    Error,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmUsage {
    pub input_tokens: u32,

    pub output_tokens: u32,

    pub total_tokens: u32,
}

impl LlmUsage {
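    /// Rough cost estimate in USD based on hard-coded per-million-token prices.
    ///
    /// Prices are approximations that drift over time; unknown models fall back
    /// to a generic rate. Illustrative sketch (doctest marked `ignore` because
    /// the public re-export path is crate-specific):
    ///
    /// ```ignore
    /// let usage = LlmUsage { input_tokens: 1_000, output_tokens: 500, total_tokens: 1_500 };
    /// // 1k input at $15/M plus 0.5k output at $75/M comes to $0.0525.
    /// assert!((usage.cost_usd("claude-opus-4-5") - 0.0525).abs() < 1e-6);
    /// ```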
    pub fn cost_usd(&self, model: &str) -> f64 {
        let (input_price, output_price) = match model {
            m if m.contains("claude-opus-4-5") || m.contains("claude-opus-4.5") => (15.0, 75.0),
            m if m.contains("claude-opus-4") => (15.0, 75.0),
            m if m.contains("claude-sonnet-4") => (3.0, 15.0),
            m if m.contains("claude-3-5-sonnet") => (3.0, 15.0),
            m if m.contains("claude-sonnet-4-5") => (3.0, 15.0),
            m if m.contains("claude-3-haiku") => (0.25, 1.25),

            m if m.contains("gpt-5.1") || m.contains("gpt-5") => (5.0, 15.0),
            m if m.contains("gpt-4o") => (2.5, 10.0),
            m if m.contains("gpt-4-turbo") => (10.0, 30.0),
            m if m.contains("gpt-3.5") => (0.5, 1.5),
            m if m.contains("o1") || m.contains("o3") || m.contains("o4") => (15.0, 60.0),

            m if m.contains("gemini-3.0-pro") => (1.5, 6.0),
            m if m.contains("gemini-2.5-pro") => (1.25, 5.0),
            m if m.contains("gemini-2.0-flash") || m.contains("gemini-2.5-flash") => (0.1, 0.4),
            m if m.contains("gemini-1.5-pro") => (1.25, 5.0),
            m if m.contains("gemini-1.5-flash") => (0.075, 0.3),

            m if m.contains("grok-4.1") || m.contains("grok-4") => (3.0, 12.0),
            m if m.contains("grok-2") || m.contains("grok-3") => (2.0, 10.0),

            m if m.contains("llama") && m.contains("groq") => (0.05, 0.08),
            m if m.contains("mixtral") && m.contains("groq") => (0.24, 0.24),
            m if m.contains("llama-3.3-70b-versatile") => (0.59, 0.79),

            m if m.contains("mistral-large-3") => (2.5, 7.5),
            m if m.contains("mistral-large") => (2.0, 6.0),
            m if m.contains("ministral") => (0.1, 0.3),
            m if m.contains("mistral-small") => (0.2, 0.6),
            m if m.contains("codestral") => (0.2, 0.6),

            m if m.contains("deepseek-v3.2") || m.contains("deepseek-v3") => (0.27, 1.10),
            m if m.contains("deepseek") => (0.14, 0.28),

            m if m.contains("llama-4") || m.contains("Llama-4") => (0.18, 0.59),
            m if m.contains("llama-3.3-70b") => (0.88, 0.88),
            m if m.contains("llama-3.1-405b") => (3.5, 3.5),

            m if m.contains("qwen3") || m.contains("qwen-max") => (0.4, 1.2),
            m if m.contains("qwen") => (0.3, 0.3),

            m if m.contains("command-a") => (2.5, 10.0),
            m if m.contains("command-r-plus") => (2.5, 10.0),
            m if m.contains("command-r") => (0.15, 0.6),

            m if m.contains("sonar-pro") => (3.0, 15.0),
            m if m.contains("sonar") => (1.0, 1.0),

            m if m.contains("cerebras") => (0.6, 0.6),

            _ => (1.0, 3.0),
        };

        let input_cost = (self.input_tokens as f64 / 1_000_000.0) * input_price;
        let output_cost = (self.output_tokens as f64 / 1_000_000.0) * output_price;

        input_cost + output_cost
    }
}

#[async_trait]
pub trait LlmClient: Send + Sync {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse>;

    fn provider(&self) -> LlmProvider;

    fn model(&self) -> &str;
}

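/// Provider-agnostic client that dispatches to the matching wire format.
///
/// Minimal async usage sketch; the model ID is only an example, a real API key
/// is required at runtime, and the doctest is marked `ignore` because the
/// public re-export path is crate-specific:
///
/// ```ignore
/// async fn demo() -> Result<()> {
///     let client = UnifiedLlmClient::openai("gpt-4o")?;
///     let response = client
///         .complete(LlmRequest::new("Say hello").with_max_tokens(32))
///         .await?;
///     println!("{}", response.content);
///     Ok(())
/// }
/// ```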
pub struct UnifiedLlmClient {
    config: LlmConfig,
    http_client: reqwest::Client,
}

impl UnifiedLlmClient {
    pub fn new(config: LlmConfig) -> Result<Self> {
        let http_client = get_pooled_client(config.timeout_secs);

        Ok(Self {
            config,
            http_client,
        })
    }

    pub fn default_anthropic() -> Result<Self> {
        Self::new(LlmConfig::default())
    }

    pub fn openai(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::OpenAI, model))
    }

    pub fn openrouter(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::OpenRouter, model))
    }

    pub fn gemini(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::GoogleGemini, model))
    }

    pub fn grok(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::XAI, model))
    }

    pub fn groq(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::Groq, model))
    }

    pub fn mistral(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::Mistral, model))
    }

    pub fn deepseek(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::DeepSeek, model))
    }

    pub fn together(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::TogetherAI, model))
    }

    pub fn fireworks(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::FireworksAI, model))
    }

    pub fn qwen(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::AlibabaQwen, model))
    }

    pub fn cohere(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::Cohere, model))
    }

    pub fn perplexity(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::Perplexity, model))
    }

    pub fn cerebras(model: impl Into<String>) -> Result<Self> {
        Self::new(LlmConfig::for_provider(LlmProvider::Cerebras, model))
    }

    pub fn azure(
        resource: impl Into<String>,
        deployment: impl Into<String>,
        model: impl Into<String>,
    ) -> Result<Self> {
        Self::new(
            LlmConfig::for_provider(LlmProvider::AzureOpenAI, model)
                .with_azure(resource, deployment),
        )
    }

    pub fn cloudflare_gateway(
        account_id: impl Into<String>,
        gateway_id: impl Into<String>,
        model: impl Into<String>,
    ) -> Result<Self> {
        Self::new(
            LlmConfig::for_provider(LlmProvider::CloudflareAI, model)
                .with_cloudflare_gateway(account_id, gateway_id),
        )
    }

    pub fn ollama() -> Result<Self> {
        let model = env_first(&["RK_OLLAMA_MODEL", "OLLAMA_MODEL"]).ok_or_else(|| {
            Error::Config("Ollama model not set. Set RK_OLLAMA_MODEL (or OLLAMA_MODEL)".to_string())
        })?;

        let base_url = env_first(&["RK_OLLAMA_URL", "OLLAMA_URL"]);

        let mut cfg = LlmConfig::for_provider(LlmProvider::Ollama, model);
        cfg.base_url = base_url;

        Self::new(cfg)
    }

    fn get_api_key(&self) -> Result<String> {
        if let Some(key) = &self.config.api_key {
            return Ok(key.clone());
        }

        let env_var = self.config.provider.env_var();
        std::env::var(env_var).map_err(|_| {
            Error::Config(format!(
                "API key not found. Set {} or provide in config",
                env_var
            ))
        })
    }

    fn get_base_url(&self) -> Result<String> {
        if let Some(url) = &self.config.base_url {
            return Ok(url.clone());
        }

        match self.config.provider {
            LlmProvider::Ollama => {
                Ok(env_first(&["RK_OLLAMA_URL", "OLLAMA_URL"])
                    .unwrap_or_else(|| self.config.provider.default_base_url().to_string()))
            }
            LlmProvider::AzureOpenAI => {
                let resource = self
                    .config
                    .extra
                    .azure_resource
                    .as_ref()
                    .ok_or_else(|| Error::Config("Azure resource name required".to_string()))?;
                let deployment =
                    self.config.extra.azure_deployment.as_ref().ok_or_else(|| {
                        Error::Config("Azure deployment name required".to_string())
                    })?;
                Ok(format!(
                    "https://{}.openai.azure.com/openai/deployments/{}",
                    resource, deployment
                ))
            }
            LlmProvider::GoogleVertex => {
                let project = self
                    .config
                    .extra
                    .gcp_project
                    .as_ref()
                    .ok_or_else(|| Error::Config("GCP project ID required".to_string()))?;
                let default_location = "us-central1".to_string();
                let location = self
                    .config
                    .extra
                    .gcp_location
                    .as_ref()
                    .unwrap_or(&default_location);
                Ok(format!(
                    "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models",
                    location, project, location
                ))
            }
            LlmProvider::AWSBedrock => {
                let default_region = "us-east-1".to_string();
                let region = self
                    .config
                    .extra
                    .aws_region
                    .as_ref()
                    .unwrap_or(&default_region);
                Ok(format!("https://bedrock-runtime.{}.amazonaws.com", region))
            }
            LlmProvider::CloudflareAI => {
                let account_id =
                    self.config.extra.cf_account_id.as_ref().ok_or_else(|| {
                        Error::Config("Cloudflare account ID required".to_string())
                    })?;
                let gateway_id =
                    self.config.extra.cf_gateway_id.as_ref().ok_or_else(|| {
                        Error::Config("Cloudflare gateway ID required".to_string())
                    })?;
                Ok(format!(
                    "https://gateway.ai.cloudflare.com/v1/{}/{}/openai",
                    account_id, gateway_id
                ))
            }
            _ => Ok(self.config.provider.default_base_url().to_string()),
        }
    }

    async fn call_anthropic(&self, request: LlmRequest) -> Result<LlmResponse> {
        let api_key = self.get_api_key()?;
        let base_url = self.get_base_url()?;
        let url = format!("{}/messages", base_url);

        let messages = vec![serde_json::json!({
            "role": "user",
            "content": request.prompt
        })];

        let body = serde_json::json!({
            "model": self.config.model,
            "max_tokens": request.max_tokens.unwrap_or(self.config.max_tokens),
            "temperature": request.temperature.unwrap_or(self.config.temperature),
            "system": request.system.unwrap_or_else(|| "You are ReasonKit, a structured reasoning engine. You answer precisely and accurately.".to_string()),
            "messages": messages
        });

        let response = self
            .http_client
            .post(&url)
            .header("x-api-key", &api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| Error::Network(format!("Anthropic API request failed: {}", e)))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(Error::Network(format!(
                "Anthropic API error {}: {}",
                status, text
            )));
        }

        let json: AnthropicResponse = response.json().await.map_err(|e| Error::Parse {
            message: format!("Failed to parse Anthropic response: {}", e),
        })?;

        Ok(LlmResponse {
            content: json
                .content
                .first()
                .map(|c| c.text.clone())
                .unwrap_or_default(),
            model: json.model,
            finish_reason: match json.stop_reason.as_deref() {
                Some("end_turn") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::MaxTokens,
                _ => FinishReason::Stop,
            },
            usage: LlmUsage {
                input_tokens: json.usage.input_tokens,
                output_tokens: json.usage.output_tokens,
                total_tokens: json.usage.input_tokens + json.usage.output_tokens,
            },
            provider: Some(LlmProvider::Anthropic),
        })
    }

    async fn call_ollama_chat(&self, request: LlmRequest) -> Result<LlmResponse> {
        let base_url = self.get_base_url()?;

        let mut messages: Vec<OllamaChatMessage> = Vec::new();
        if let Some(system) = &request.system {
            messages.push(OllamaChatMessage {
                role: "system".to_string(),
                content: system.clone(),
            });
        }
        messages.push(OllamaChatMessage {
            role: "user".to_string(),
            content: request.prompt,
        });

        let mut options: BTreeMap<String, serde_json::Value> = BTreeMap::new();
        options.insert(
            "temperature".to_string(),
            serde_json::Value::from(request.temperature.unwrap_or(self.config.temperature)),
        );
        options.insert(
            "num_predict".to_string(),
            serde_json::Value::from(
                request
                    .max_tokens
                    .unwrap_or(self.config.max_tokens)
                    .min(i32::MAX as u32) as i64,
            ),
        );

        let req = OllamaChatRequest {
            model: self.config.model.clone(),
            messages,
            stream: Some(false),
            options: Some(options),
        };

        let client = OllamaClient::new(base_url)?;
        let resp = client
            .chat(req)
            .await
            .map_err(|e| Error::Network(format!("Ollama API request failed: {}", e)))?;

        Ok(LlmResponse {
            content: resp.message.content,
            model: resp.model,
            finish_reason: FinishReason::Stop,
            usage: LlmUsage::default(),
            provider: Some(LlmProvider::Ollama),
        })
    }

    async fn call_openai_compatible(&self, request: LlmRequest) -> Result<LlmResponse> {
        let api_key = self.get_api_key()?;
        let base_url = self.get_base_url()?;
        let url = format!("{}/chat/completions", base_url);

        let mut messages = Vec::new();

        if let Some(system) = &request.system {
            messages.push(serde_json::json!({
                "role": "system",
                "content": system
            }));
        }

        messages.push(serde_json::json!({
            "role": "user",
            "content": request.prompt
        }));

        let body = serde_json::json!({
            "model": self.config.model,
            "max_tokens": request.max_tokens.unwrap_or(self.config.max_tokens),
            "temperature": request.temperature.unwrap_or(self.config.temperature),
            "messages": messages
        });

        let mut req = self
            .http_client
            .post(&url)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("content-type", "application/json");

        match self.config.provider {
            LlmProvider::OpenRouter => {
                req = req
                    .header("HTTP-Referer", "https://reasonkit.sh")
                    .header("X-Title", "ReasonKit ThinkTool");
            }
            LlmProvider::AzureOpenAI => {
                req = req
                    .header("api-key", &api_key)
                    .header("api-version", "2024-02-15-preview");
            }
            LlmProvider::GoogleGemini => {
                req = req.header("x-goog-api-key", &api_key);
            }
            _ => {}
        }

        let response = req.json(&body).send().await.map_err(|e| {
            Error::Network(format!(
                "{} API request failed: {}",
                self.config.provider, e
            ))
        })?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(Error::Network(format!(
                "{} API error {}: {}",
                self.config.provider, status, text
            )));
        }

        let json: OpenAIResponse = response.json().await.map_err(|e| Error::Parse {
            message: format!("Failed to parse {} response: {}", self.config.provider, e),
        })?;

        let choice = json.choices.first().ok_or_else(|| Error::Parse {
            message: "No choices in response".to_string(),
        })?;

        Ok(LlmResponse {
            content: choice.message.content.clone().unwrap_or_default(),
            model: json.model,
            finish_reason: match choice.finish_reason.as_deref() {
                Some("stop") => FinishReason::Stop,
                Some("length") => FinishReason::MaxTokens,
                Some("content_filter") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            },
            usage: LlmUsage {
                input_tokens: json.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
                output_tokens: json
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens)
                    .unwrap_or(0),
                total_tokens: json.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
            },
            provider: Some(self.config.provider),
        })
    }
}

#[async_trait]
impl LlmClient for UnifiedLlmClient {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        match self.config.provider {
            LlmProvider::Anthropic => self.call_anthropic(request).await,
            LlmProvider::Ollama => self.call_ollama_chat(request).await,
            _ => self.call_openai_compatible(request).await,
        }
    }

    fn provider(&self) -> LlmProvider {
        self.config.provider
    }

    fn model(&self) -> &str {
        &self.config.model
    }
}

pub fn discover_available_providers() -> Vec<LlmProvider> {
    LlmProvider::all()
        .iter()
        .filter(|p| {
            if p.requires_special_auth() {
                return false;
            }

            match p {
                LlmProvider::Ollama => env_first(&["RK_OLLAMA_MODEL", "OLLAMA_MODEL"]).is_some(),
                _ => std::env::var(p.env_var()).is_ok(),
            }
        })
        .copied()
        .collect()
}

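/// Builds a client for the first provider whose credentials are present in the
/// environment.
///
/// Minimal sketch; the doctest is marked `ignore` because availability depends
/// entirely on the caller's environment variables:
///
/// ```ignore
/// match create_available_client() {
///     Ok(client) => println!("using {} / {}", client.provider(), client.model()),
///     Err(e) => eprintln!("no provider configured: {}", e),
/// }
/// ```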
pub fn create_available_client() -> Result<UnifiedLlmClient> {
    let available = discover_available_providers();

    if available.is_empty() {
        return Err(Error::Config(
            "No LLM providers found. Set an API key (e.g. ANTHROPIC_API_KEY, OPENAI_API_KEY, ...) or set RK_OLLAMA_MODEL for Ollama.".to_string(),
        ));
    }

    let provider = available[0];

    if provider == LlmProvider::Ollama {
        return UnifiedLlmClient::ollama();
    }

    UnifiedLlmClient::new(LlmConfig {
        provider,
        model: provider.default_model().to_string(),
        ..Default::default()
    })
}

#[derive(Debug, Clone, Serialize)]
pub struct ProviderInfo {
    pub id: LlmProvider,
    pub name: &'static str,
    pub env_var: &'static str,
    pub default_model: &'static str,
    pub base_url: &'static str,
    pub is_available: bool,
}

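/// Snapshot of every supported provider plus whether it looks configured.
///
/// Minimal sketch; the doctest is marked `ignore` because the public re-export
/// path is crate-specific:
///
/// ```ignore
/// for info in get_provider_info() {
///     println!("{:<20} {} (set {})", info.name, info.default_model, info.env_var);
/// }
/// ```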
pub fn get_provider_info() -> Vec<ProviderInfo> {
    LlmProvider::all()
        .iter()
        .map(|p| {
            let is_available = match p {
                LlmProvider::Ollama => env_first(&["RK_OLLAMA_MODEL", "OLLAMA_MODEL"]).is_some(),
                _ => std::env::var(p.env_var()).is_ok(),
            };

            ProviderInfo {
                id: *p,
                name: p.display_name(),
                env_var: p.env_var(),
                default_model: p.default_model(),
                base_url: p.default_base_url(),
                is_available,
            }
        })
        .collect()
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    model: String,
    content: Vec<AnthropicContent>,
    stop_reason: Option<String>,
    usage: AnthropicUsage,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    #[serde(rename = "type")]
    #[allow(dead_code)]
    content_type: String,
    text: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    model: String,
    choices: Vec<OpenAIChoice>,
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessage {
    content: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_count() {
        assert_eq!(LlmProvider::all().len(), 20);
    }

    #[test]
    fn test_all_providers_unique() {
        let providers = LlmProvider::all();
        let mut seen = std::collections::HashSet::new();
        for p in providers {
            assert!(seen.insert(p), "Duplicate provider: {:?}", p);
        }
    }

    #[test]
    fn test_provider_default_is_anthropic() {
        let default = LlmProvider::default();
        assert_eq!(default, LlmProvider::Anthropic);
    }

    #[test]
    fn test_provider_env_vars() {
        assert_eq!(LlmProvider::Anthropic.env_var(), "ANTHROPIC_API_KEY");
        assert_eq!(LlmProvider::OpenAI.env_var(), "OPENAI_API_KEY");
        assert_eq!(LlmProvider::GoogleGemini.env_var(), "GEMINI_API_KEY");
        assert_eq!(LlmProvider::Ollama.env_var(), "RK_OLLAMA_MODEL");
        assert_eq!(LlmProvider::XAI.env_var(), "XAI_API_KEY");
        assert_eq!(LlmProvider::Groq.env_var(), "GROQ_API_KEY");
        assert_eq!(LlmProvider::Mistral.env_var(), "MISTRAL_API_KEY");
        assert_eq!(LlmProvider::DeepSeek.env_var(), "DEEPSEEK_API_KEY");
        assert_eq!(LlmProvider::TogetherAI.env_var(), "TOGETHER_API_KEY");
        assert_eq!(LlmProvider::FireworksAI.env_var(), "FIREWORKS_API_KEY");
        assert_eq!(LlmProvider::AlibabaQwen.env_var(), "DASHSCOPE_API_KEY");
        assert_eq!(LlmProvider::OpenRouter.env_var(), "OPENROUTER_API_KEY");
        assert_eq!(LlmProvider::CloudflareAI.env_var(), "CLOUDFLARE_API_KEY");
        assert_eq!(LlmProvider::Cohere.env_var(), "COHERE_API_KEY");
        assert_eq!(LlmProvider::Perplexity.env_var(), "PERPLEXITY_API_KEY");
        assert_eq!(LlmProvider::Cerebras.env_var(), "CEREBRAS_API_KEY");
        assert_eq!(LlmProvider::Opencode.env_var(), "OPENCODE_API_KEY");
    }

    #[test]
    fn test_provider_base_urls() {
        assert_eq!(
            LlmProvider::Anthropic.default_base_url(),
            "https://api.anthropic.com/v1"
        );
        assert_eq!(
            LlmProvider::OpenAI.default_base_url(),
            "https://api.openai.com/v1"
        );
        assert_eq!(
            LlmProvider::Ollama.default_base_url(),
            "http://localhost:11434"
        );
        assert_eq!(
            LlmProvider::Groq.default_base_url(),
            "https://api.groq.com/openai/v1"
        );
        assert_eq!(LlmProvider::XAI.default_base_url(), "https://api.x.ai/v1");
        assert_eq!(
            LlmProvider::Mistral.default_base_url(),
            "https://api.mistral.ai/v1"
        );
        assert_eq!(
            LlmProvider::DeepSeek.default_base_url(),
            "https://api.deepseek.com/v1"
        );
        assert_eq!(
            LlmProvider::TogetherAI.default_base_url(),
            "https://api.together.xyz/v1"
        );
        assert_eq!(
            LlmProvider::FireworksAI.default_base_url(),
            "https://api.fireworks.ai/inference/v1"
        );
        assert_eq!(
            LlmProvider::Cohere.default_base_url(),
            "https://api.cohere.ai/v1"
        );
        assert_eq!(
            LlmProvider::Perplexity.default_base_url(),
            "https://api.perplexity.ai"
        );
        assert_eq!(
            LlmProvider::Cerebras.default_base_url(),
            "https://api.cerebras.ai/v1"
        );
        assert_eq!(
            LlmProvider::Opencode.default_base_url(),
            "https://api.opencode.ai/v1"
        );
    }

    #[test]
    fn test_provider_base_urls_contain_https() {
        for provider in LlmProvider::all() {
            let url = provider.default_base_url();

            if *provider == LlmProvider::Ollama {
                assert!(
                    url.starts_with("http://localhost"),
                    "Ollama default URL should be localhost http://: {}",
                    url
                );
                continue;
            }

            assert!(
                url.starts_with("https://"),
                "Provider {:?} URL does not start with https://: {}",
                provider,
                url
            );
        }
    }

    #[test]
    fn test_provider_compatibility() {
        assert!(LlmProvider::Anthropic.is_anthropic_format());
        assert!(!LlmProvider::OpenAI.is_anthropic_format());

        assert!(!LlmProvider::Anthropic.is_openai_compatible());
        assert!(LlmProvider::OpenAI.is_openai_compatible());
        assert!(LlmProvider::Groq.is_openai_compatible());
        assert!(LlmProvider::XAI.is_openai_compatible());
        assert!(LlmProvider::GoogleGemini.is_openai_compatible());
        assert!(LlmProvider::Mistral.is_openai_compatible());
        assert!(LlmProvider::DeepSeek.is_openai_compatible());
        assert!(LlmProvider::TogetherAI.is_openai_compatible());
        assert!(LlmProvider::FireworksAI.is_openai_compatible());
        assert!(LlmProvider::OpenRouter.is_openai_compatible());
    }

    #[test]
    fn test_special_auth_providers() {
        assert!(LlmProvider::AzureOpenAI.requires_special_auth());
        assert!(LlmProvider::AWSBedrock.requires_special_auth());
        assert!(LlmProvider::GoogleVertex.requires_special_auth());
        assert!(!LlmProvider::OpenAI.requires_special_auth());
        assert!(!LlmProvider::Groq.requires_special_auth());
        assert!(!LlmProvider::Anthropic.requires_special_auth());
        assert!(!LlmProvider::DeepSeek.requires_special_auth());
    }

    #[test]
    fn test_provider_display() {
        assert_eq!(LlmProvider::Anthropic.to_string(), "Anthropic");
        assert_eq!(LlmProvider::XAI.to_string(), "xAI (Grok)");
        assert_eq!(
            LlmProvider::GoogleGemini.to_string(),
            "Google Gemini (AI Studio)"
        );
        assert_eq!(LlmProvider::Groq.to_string(), "Groq");
        assert_eq!(LlmProvider::OpenRouter.to_string(), "OpenRouter");
        assert_eq!(LlmProvider::DeepSeek.to_string(), "DeepSeek");
    }

    #[test]
    fn test_provider_display_names_non_empty() {
        for provider in LlmProvider::all() {
            let name = provider.display_name();
            assert!(
                !name.is_empty(),
                "Provider {:?} has empty display name",
                provider
            );
        }
    }

    #[test]
    fn test_provider_default_models_non_empty() {
        for provider in LlmProvider::all() {
            let model = provider.default_model();
            assert!(
                !model.is_empty(),
                "Provider {:?} has empty default model",
                provider
            );
        }
    }

    #[test]
    fn test_provider_serialization_roundtrip() {
        for provider in LlmProvider::all() {
            let json = serde_json::to_string(provider).expect("Serialization failed");
            let parsed: LlmProvider = serde_json::from_str(&json).expect("Deserialization failed");
            assert_eq!(*provider, parsed);
        }
    }

    #[test]
    fn test_provider_serialization_snake_case() {
        let json = serde_json::to_string(&LlmProvider::OpenAI).unwrap();
        assert_eq!(json, "\"open_ai\"");

        let json = serde_json::to_string(&LlmProvider::GoogleGemini).unwrap();
        assert_eq!(json, "\"google_gemini\"");

        let json = serde_json::to_string(&LlmProvider::TogetherAI).unwrap();
        assert_eq!(json, "\"together_ai\"");
    }

    #[test]
    fn test_provider_deserialization_from_snake_case() {
        let provider: LlmProvider = serde_json::from_str("\"open_ai\"").unwrap();
        assert_eq!(provider, LlmProvider::OpenAI);

        let provider: LlmProvider = serde_json::from_str("\"deep_seek\"").unwrap();
        assert_eq!(provider, LlmProvider::DeepSeek);
    }

    #[test]
    fn test_llm_config_default() {
        let config = LlmConfig::default();
        assert_eq!(config.provider, LlmProvider::Anthropic);
        assert_eq!(config.model, "claude-opus-4-5");
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, 2000);
        assert_eq!(config.timeout_secs, 60);
        assert!(config.api_key.is_none());
        assert!(config.base_url.is_none());
    }

    #[test]
    fn test_llm_config_for_provider() {
        let config = LlmConfig::for_provider(LlmProvider::Groq, "llama-3.3-70b-versatile");
        assert_eq!(config.provider, LlmProvider::Groq);
        assert_eq!(config.model, "llama-3.3-70b-versatile");
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, 2000);
    }

    #[test]
    fn test_llm_config_builder() {
        let config = LlmConfig::for_provider(LlmProvider::Groq, "llama-3.3-70b-versatile")
            .with_temperature(0.5)
            .with_max_tokens(4000);

        assert_eq!(config.provider, LlmProvider::Groq);
        assert_eq!(config.model, "llama-3.3-70b-versatile");
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.max_tokens, 4000);
    }

    #[test]
    fn test_llm_config_with_api_key() {
        let config = LlmConfig::default().with_api_key("test-key-12345");
        assert_eq!(config.api_key, Some("test-key-12345".to_string()));
    }

    #[test]
    fn test_llm_config_with_base_url() {
        let config = LlmConfig::default().with_base_url("https://custom.api.com/v1");
        assert_eq!(
            config.base_url,
            Some("https://custom.api.com/v1".to_string())
        );
    }

    #[test]
    fn test_llm_config_chained_builders() {
        let config = LlmConfig::for_provider(LlmProvider::OpenAI, "gpt-4o")
            .with_api_key("sk-test")
            .with_base_url("https://proxy.example.com/v1")
            .with_temperature(0.3)
            .with_max_tokens(8000);

        assert_eq!(config.provider, LlmProvider::OpenAI);
        assert_eq!(config.model, "gpt-4o");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(
            config.base_url,
            Some("https://proxy.example.com/v1".to_string())
        );
        assert_eq!(config.temperature, 0.3);
        assert_eq!(config.max_tokens, 8000);
    }

    #[test]
    fn test_azure_config() {
        let config = LlmConfig::for_provider(LlmProvider::AzureOpenAI, "gpt-4o")
            .with_azure("my-resource", "my-deployment");

        assert_eq!(config.extra.azure_resource, Some("my-resource".to_string()));
        assert_eq!(
            config.extra.azure_deployment,
            Some("my-deployment".to_string())
        );
    }

    #[test]
    fn test_gcp_config() {
        let config = LlmConfig::for_provider(LlmProvider::GoogleVertex, "gemini-3.0-pro")
            .with_gcp("my-project-123", "us-west1");

        assert_eq!(config.extra.gcp_project, Some("my-project-123".to_string()));
        assert_eq!(config.extra.gcp_location, Some("us-west1".to_string()));
    }

    #[test]
    fn test_aws_region_config() {
        let config = LlmConfig::for_provider(LlmProvider::AWSBedrock, "anthropic.claude-v2")
            .with_aws_region("eu-west-1");

        assert_eq!(config.extra.aws_region, Some("eu-west-1".to_string()));
    }

    #[test]
    fn test_cloudflare_config() {
        let config = LlmConfig::for_provider(LlmProvider::CloudflareAI, "@cf/meta/llama-3.3-70b")
            .with_cloudflare_gateway("account123", "gateway456");

        assert_eq!(config.extra.cf_account_id, Some("account123".to_string()));
        assert_eq!(config.extra.cf_gateway_id, Some("gateway456".to_string()));
    }

    #[test]
    fn test_llm_config_serialization() {
        let config = LlmConfig::for_provider(LlmProvider::OpenAI, "gpt-4o").with_temperature(0.5);

        let json = serde_json::to_string(&config).expect("Serialization failed");
        let parsed: LlmConfig = serde_json::from_str(&json).expect("Deserialization failed");

        assert_eq!(parsed.provider, LlmProvider::OpenAI);
        assert_eq!(parsed.model, "gpt-4o");
        assert_eq!(parsed.temperature, 0.5);
    }

    #[test]
    fn test_llm_request_new() {
        let request = LlmRequest::new("Hello, world!");
        assert_eq!(request.prompt, "Hello, world!");
        assert!(request.system.is_none());
        assert!(request.temperature.is_none());
        assert!(request.max_tokens.is_none());
    }

    #[test]
    fn test_llm_request_builder() {
        let request = LlmRequest::new("Hello")
            .with_system("You are helpful")
            .with_temperature(0.5)
            .with_max_tokens(100);

        assert_eq!(request.prompt, "Hello");
        assert_eq!(request.system, Some("You are helpful".to_string()));
        assert_eq!(request.temperature, Some(0.5));
        assert_eq!(request.max_tokens, Some(100));
    }

    #[test]
    fn test_llm_request_with_system_only() {
        let request = LlmRequest::new("Test prompt").with_system("System prompt here");
        assert_eq!(request.system, Some("System prompt here".to_string()));
        assert!(request.temperature.is_none());
        assert!(request.max_tokens.is_none());
    }

    #[test]
    fn test_llm_request_with_long_prompt() {
        let long_prompt = "a".repeat(100_000);
        let request = LlmRequest::new(&long_prompt);
        assert_eq!(request.prompt.len(), 100_000);
    }

    #[test]
    fn test_llm_request_with_unicode() {
        let request =
            LlmRequest::new("Hello world in Japanese: Konnichiwa! Chinese: Ni hao! Emoji: Test");
        assert!(request.prompt.contains("Konnichiwa"));
        assert!(request.prompt.contains("Ni hao"));
    }

    #[test]
    fn test_llm_request_temperature_boundaries() {
        let request = LlmRequest::new("Test").with_temperature(0.0);
        assert_eq!(request.temperature, Some(0.0));

        let request = LlmRequest::new("Test").with_temperature(2.0);
        assert_eq!(request.temperature, Some(2.0));

        let request = LlmRequest::new("Test").with_temperature(1.0);
        assert_eq!(request.temperature, Some(1.0));
    }

    #[test]
    fn test_finish_reason_default() {
        let reason = FinishReason::default();
        assert_eq!(reason, FinishReason::Stop);
    }

    #[test]
    fn test_finish_reason_serialization() {
        let reasons = vec![
            FinishReason::Stop,
            FinishReason::MaxTokens,
            FinishReason::ContentFilter,
            FinishReason::Error,
        ];

        for reason in reasons {
            let json = serde_json::to_string(&reason).expect("Serialization failed");
            let parsed: FinishReason = serde_json::from_str(&json).expect("Deserialization failed");
            assert_eq!(reason, parsed);
        }
    }

    #[test]
    fn test_llm_response_serialization() {
        let response = LlmResponse {
            content: "Test response".to_string(),
            model: "gpt-4o".to_string(),
            finish_reason: FinishReason::Stop,
            usage: LlmUsage {
                input_tokens: 100,
                output_tokens: 50,
                total_tokens: 150,
            },
            provider: Some(LlmProvider::OpenAI),
        };

        let json = serde_json::to_string(&response).expect("Serialization failed");
        let parsed: LlmResponse = serde_json::from_str(&json).expect("Deserialization failed");

        assert_eq!(parsed.content, "Test response");
        assert_eq!(parsed.model, "gpt-4o");
        assert_eq!(parsed.finish_reason, FinishReason::Stop);
        assert_eq!(parsed.usage.input_tokens, 100);
        assert_eq!(parsed.provider, Some(LlmProvider::OpenAI));
    }

    #[test]
    fn test_llm_usage_default() {
        let usage = LlmUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.total_tokens, 0);
    }

    #[test]
    fn test_cost_calculation_claude() {
        let usage = LlmUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
        };

        let cost = usage.cost_usd("claude-opus-4-5");
        let expected = (1000.0 / 1_000_000.0) * 15.0 + (500.0 / 1_000_000.0) * 75.0;
        assert!((cost - expected).abs() < 0.0001);
    }

    #[test]
    fn test_cost_calculation_claude_sonnet() {
        let usage = LlmUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
        };

        let cost = usage.cost_usd("claude-sonnet-4");
        assert!(cost > 0.0);
        assert!(cost < 0.02);
    }

    #[test]
    fn test_cost_calculation_gpt35_cheaper_than_sonnet() {
        let usage = LlmUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
        };

        let cost_sonnet = usage.cost_usd("claude-sonnet-4");
        let cost_gpt35 = usage.cost_usd("gpt-3.5-turbo");
        assert!(cost_gpt35 < cost_sonnet);
    }

    #[test]
    fn test_cost_calculation_groq_very_cheap() {
        let usage = LlmUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
        };

        let cost_gpt35 = usage.cost_usd("gpt-3.5-turbo");
        let cost_groq = usage.cost_usd("llama-groq");
        assert!(cost_groq < cost_gpt35);
    }

    #[test]
    fn test_cost_calculation_deepseek_cheap() {
        let usage = LlmUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
        };

        let cost_sonnet = usage.cost_usd("claude-sonnet-4");
        let cost_deepseek = usage.cost_usd("deepseek-chat");
        assert!(cost_deepseek < cost_sonnet);
    }

    #[test]
    fn test_cost_calculation_zero_tokens() {
        let usage = LlmUsage {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
        };

        let cost = usage.cost_usd("gpt-4o");
        assert_eq!(cost, 0.0);
    }

    #[test]
    fn test_cost_calculation_large_token_count() {
        let usage = LlmUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            total_tokens: 1_500_000,
        };

        let cost = usage.cost_usd("gpt-4o");
        let expected = 2.5 + 5.0;
        assert!((cost - expected).abs() < 0.01);
    }

    #[test]
    fn test_cost_calculation_various_models() {
        let usage = LlmUsage {
            input_tokens: 10000,
            output_tokens: 5000,
            total_tokens: 15000,
        };

        let models = vec![
            "gpt-5.1",
            "gpt-4o",
            "gemini-3.0-pro",
            "grok-4.1",
            "mistral-large-3",
            "deepseek-v3.2",
            "llama-4-scout",
            "qwen3-max",
            "command-a",
            "sonar-pro",
        ];

        for model in models {
            let cost = usage.cost_usd(model);
            assert!(cost > 0.0, "Cost for {} should be positive", model);
        }
    }

    #[test]
    fn test_cost_calculation_unknown_model_uses_default() {
        let usage = LlmUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
        };

        let cost = usage.cost_usd("some-unknown-model-xyz");
        let expected = (1000.0 / 1_000_000.0) * 1.0 + (500.0 / 1_000_000.0) * 3.0;
        assert!((cost - expected).abs() < 0.0001);
    }

1961 #[test]
1966 fn test_client_creation_default() {
1967 let config = LlmConfig::default();
1968 let client = UnifiedLlmClient::new(config);
1969 assert!(client.is_ok());
1970 }
1971
1972 #[test]
1973 fn test_client_creation_for_each_provider() {
1974 for provider in LlmProvider::all() {
1975 let config = LlmConfig::for_provider(*provider, provider.default_model());
1976 let client = UnifiedLlmClient::new(config);
1977 assert!(client.is_ok(), "Client creation failed for {:?}", provider);
1978 }
1979 }
1980
1981 #[test]
1982 fn test_client_provider_method() {
1983 let config = LlmConfig::for_provider(LlmProvider::Groq, "llama-3.3-70b-versatile");
1984 let client = UnifiedLlmClient::new(config).unwrap();
1985 assert_eq!(client.provider(), LlmProvider::Groq);
1986 }
1987
1988 #[test]
1989 fn test_client_model_method() {
1990 let config = LlmConfig::for_provider(LlmProvider::OpenAI, "gpt-4o-mini");
1991 let client = UnifiedLlmClient::new(config).unwrap();
1992 assert_eq!(client.model(), "gpt-4o-mini");
1993 }
1994
1995 #[test]
1996 fn test_convenience_constructor_openai() {
1997 let client = UnifiedLlmClient::openai("gpt-4o");
1998 assert!(client.is_ok());
1999 assert_eq!(client.unwrap().provider(), LlmProvider::OpenAI);
2000 }
2001
2002 #[test]
2003 fn test_convenience_constructor_groq() {
2004 let client = UnifiedLlmClient::groq("llama-3.3-70b-versatile");
2005 assert!(client.is_ok());
2006 assert_eq!(client.unwrap().provider(), LlmProvider::Groq);
2007 }
2008
2009 #[test]
2010 fn test_convenience_constructor_deepseek() {
2011 let client = UnifiedLlmClient::deepseek("deepseek-v3");
2012 assert!(client.is_ok());
2013 assert_eq!(client.unwrap().provider(), LlmProvider::DeepSeek);
2014 }
2015
2016 #[test]
2017 fn test_convenience_constructor_mistral() {
2018 let client = UnifiedLlmClient::mistral("mistral-large");
2019 assert!(client.is_ok());
2020 assert_eq!(client.unwrap().provider(), LlmProvider::Mistral);
2021 }
2022
2023 #[test]
2024 fn test_convenience_constructor_grok() {
2025 let client = UnifiedLlmClient::grok("grok-2");
2026 assert!(client.is_ok());
2027 assert_eq!(client.unwrap().provider(), LlmProvider::XAI);
2028 }
2029
2030 #[test]
2031 fn test_convenience_constructor_openrouter() {
2032 let client = UnifiedLlmClient::openrouter("anthropic/claude-3.5-sonnet");
2033 assert!(client.is_ok());
2034 assert_eq!(client.unwrap().provider(), LlmProvider::OpenRouter);
2035 }
2036
2037 #[test]
2042 fn test_azure_url_construction() {
2043 let config = LlmConfig::for_provider(LlmProvider::AzureOpenAI, "gpt-4o")
2044 .with_azure("my-resource", "my-deployment")
2045 .with_api_key("test-key");
2046 let client = UnifiedLlmClient::new(config).unwrap();
2047
2048 let url = client.get_base_url().unwrap();
2049 assert_eq!(
2050 url,
2051 "https://my-resource.openai.azure.com/openai/deployments/my-deployment"
2052 );
2053 }
2054
2055 #[test]
2056 fn test_azure_url_missing_resource_error() {
2057 let config = LlmConfig::for_provider(LlmProvider::AzureOpenAI, "gpt-4o");
2058 let client = UnifiedLlmClient::new(config).unwrap();
2059
2060 let result = client.get_base_url();
2061 assert!(result.is_err());
2062 let err = result.unwrap_err().to_string();
2063 assert!(err.contains("Azure resource name required"));
2064 }
2065
2066 #[test]
2067 fn test_azure_url_missing_deployment_error() {
2068 let mut config = LlmConfig::for_provider(LlmProvider::AzureOpenAI, "gpt-4o");
2069 config.extra.azure_resource = Some("my-resource".to_string());
2070 let client = UnifiedLlmClient::new(config).unwrap();
2071
2072 let result = client.get_base_url();
2073 assert!(result.is_err());
2074 let err = result.unwrap_err().to_string();
2075 assert!(err.contains("Azure deployment name required"));
2076 }
2077
2078 #[test]
2079 fn test_vertex_url_construction() {
2080 let config = LlmConfig::for_provider(LlmProvider::GoogleVertex, "gemini-3.0-pro")
2081 .with_gcp("my-project", "us-west1");
2082 let client = UnifiedLlmClient::new(config).unwrap();
2083
2084 let url = client.get_base_url().unwrap();
2085 assert!(url.contains("us-west1"));
2086 assert!(url.contains("my-project"));
2087 assert!(url.contains("aiplatform.googleapis.com"));
2088 }
2089
2090 #[test]
2091 fn test_vertex_url_default_location() {
2092 let mut config = LlmConfig::for_provider(LlmProvider::GoogleVertex, "gemini-3.0-pro");
2093 config.extra.gcp_project = Some("my-project".to_string());
2094 let client = UnifiedLlmClient::new(config).unwrap();
2095
2096 let url = client.get_base_url().unwrap();
2097 assert!(url.contains("us-central1")); }
2099
2100 #[test]
2101 fn test_vertex_url_missing_project_error() {
2102 let config = LlmConfig::for_provider(LlmProvider::GoogleVertex, "gemini-3.0-pro");
2103 let client = UnifiedLlmClient::new(config).unwrap();
2104
2105 let result = client.get_base_url();
2106 assert!(result.is_err());
2107 let err = result.unwrap_err().to_string();
2108 assert!(err.contains("GCP project ID required"));
2109 }
2110
2111 #[test]
2112 fn test_bedrock_url_construction() {
2113 let config = LlmConfig::for_provider(LlmProvider::AWSBedrock, "anthropic.claude-v2")
2114 .with_aws_region("eu-west-1");
2115 let client = UnifiedLlmClient::new(config).unwrap();
2116
2117 let url = client.get_base_url().unwrap();
2118 assert_eq!(url, "https://bedrock-runtime.eu-west-1.amazonaws.com");
2119 }
2120
2121 #[test]
2122 fn test_bedrock_url_default_region() {
2123 let config = LlmConfig::for_provider(LlmProvider::AWSBedrock, "anthropic.claude-v2");
2124 let client = UnifiedLlmClient::new(config).unwrap();
2125
2126 let url = client.get_base_url().unwrap();
2127 assert_eq!(url, "https://bedrock-runtime.us-east-1.amazonaws.com");
2128 }
2129
2130 #[test]
2131 fn test_cloudflare_url_construction() {
2132 let config = LlmConfig::for_provider(LlmProvider::CloudflareAI, "@cf/meta/llama-3")
2133 .with_cloudflare_gateway("acc123", "gw456");
2134 let client = UnifiedLlmClient::new(config).unwrap();
2135
2136 let url = client.get_base_url().unwrap();
2137 assert_eq!(
2138 url,
2139 "https://gateway.ai.cloudflare.com/v1/acc123/gw456/openai"
2140 );
2141 }
2142
2143 #[test]
2144 fn test_cloudflare_url_missing_account_error() {
2145 let config = LlmConfig::for_provider(LlmProvider::CloudflareAI, "@cf/meta/llama-3");
2146 let client = UnifiedLlmClient::new(config).unwrap();
2147
2148 let result = client.get_base_url();
2149 assert!(result.is_err());
2150 let err = result.unwrap_err().to_string();
2151 assert!(err.contains("Cloudflare account ID required"));
2152 }
2153
2154 #[test]
2155 fn test_base_url_override() {
2156 let config = LlmConfig::for_provider(LlmProvider::OpenAI, "gpt-4o")
2157 .with_base_url("https://custom-proxy.example.com/v1");
2158 let client = UnifiedLlmClient::new(config).unwrap();
2159
2160 let url = client.get_base_url().unwrap();
2161 assert_eq!(url, "https://custom-proxy.example.com/v1");
2162 }
2163
2164 #[test]
2169 fn test_api_key_from_config() {
2170 let config = LlmConfig::default().with_api_key("config-key-123");
2171 let client = UnifiedLlmClient::new(config).unwrap();
2172
2173 let key = client.get_api_key().unwrap();
2174 assert_eq!(key, "config-key-123");
2175 }
2176
2177 #[test]
2178 fn test_api_key_missing_error() {
2179 std::env::remove_var("ANTHROPIC_API_KEY");
2181
2182 let config = LlmConfig::default();
2183 let client = UnifiedLlmClient::new(config).unwrap();
2184
2185 let result = client.get_api_key();
2186 assert!(result.is_err());
2187 let err = result.unwrap_err().to_string();
2188 assert!(err.contains("ANTHROPIC_API_KEY"));
2189 }

    #[test]
    fn test_provider_info() {
        let info = get_provider_info();
        assert_eq!(info.len(), 20);

        let anthropic = info
            .iter()
            .find(|i| i.id == LlmProvider::Anthropic)
            .unwrap();
        assert_eq!(anthropic.name, "Anthropic");
        assert_eq!(anthropic.env_var, "ANTHROPIC_API_KEY");
        assert_eq!(anthropic.default_model, "claude-opus-4-5");
    }

    #[test]
    fn test_provider_info_all_fields_populated() {
        let info = get_provider_info();

        for provider_info in info {
            assert!(!provider_info.name.is_empty());
            assert!(!provider_info.env_var.is_empty());
            assert!(!provider_info.default_model.is_empty());
            assert!(!provider_info.base_url.is_empty());
        }
    }

    #[test]
    fn test_http_client_pooling_default_timeout() {
        let client1 = get_pooled_client(120);
        let client2 = get_pooled_client(120);

        assert!(client1.get("https://example.com").build().is_ok());
        assert!(client2.get("https://example.com").build().is_ok());
    }

    #[test]
    fn test_http_client_pooling_custom_timeout() {
        let client1 = get_pooled_client(30);
        let client2 = get_pooled_client(30);

        assert!(client1.get("https://example.com").build().is_ok());
        assert!(client2.get("https://example.com").build().is_ok());
    }

    #[test]
    fn test_http_client_pooling_different_timeouts() {
        let client_30 = get_pooled_client(30);
        let client_60 = get_pooled_client(60);
        let client_90 = get_pooled_client(90);

        assert!(client_30.get("https://example.com").build().is_ok());
        assert!(client_60.get("https://example.com").build().is_ok());
        assert!(client_90.get("https://example.com").build().is_ok());
    }
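
    // Added sketch: get_pooled_client caches clients for non-default timeouts in
    // HTTP_CLIENT_POOL, so the key should be present after the first call. Assumes
    // the static is visible here via `use super::*;`; the 45-second timeout is an
    // arbitrary value chosen to avoid colliding with the other pooling tests.
    #[test]
    fn test_http_client_pool_caches_custom_timeout_sketch() {
        let timeout_secs = 45u64;
        let _client = get_pooled_client(timeout_secs);

        let pool = HTTP_CLIENT_POOL.read().unwrap();
        assert!(pool.contains_key(&timeout_secs));
    }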

    #[test]
    fn test_discover_providers_filters_special_auth() {
        let providers = discover_available_providers();

        assert!(!providers.contains(&LlmProvider::AzureOpenAI));
        assert!(!providers.contains(&LlmProvider::AWSBedrock));
        assert!(!providers.contains(&LlmProvider::GoogleVertex));
    }

    #[test]
    fn test_parse_anthropic_response() {
        let json = r#"{
            "model": "claude-opus-4-5-20251101",
            "content": [{"type": "text", "text": "Hello, world!"}],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        }"#;

        let response: AnthropicResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.model, "claude-opus-4-5-20251101");
        assert_eq!(response.content.len(), 1);
        assert_eq!(response.content[0].text, "Hello, world!");
        assert_eq!(response.stop_reason, Some("end_turn".to_string()));
        assert_eq!(response.usage.input_tokens, 10);
        assert_eq!(response.usage.output_tokens, 5);
    }

    #[test]
    fn test_parse_openai_response() {
        let json = r#"{
            "model": "gpt-4o",
            "choices": [
                {
                    "message": {"content": "Test response"},
                    "finish_reason": "stop"
                }
            ],
            "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30}
        }"#;

        let response: OpenAIResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(
            response.choices[0].message.content,
            Some("Test response".to_string())
        );
        assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 20);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 30);
    }

    #[test]
    fn test_parse_openai_response_no_usage() {
        let json = r#"{
            "model": "gpt-4o",
            "choices": [
                {
                    "message": {"content": "No usage info"},
                    "finish_reason": "stop"
                }
            ]
        }"#;

        let response: OpenAIResponse = serde_json::from_str(json).unwrap();
        assert!(response.usage.is_none());
    }

    #[test]
    fn test_parse_openai_response_null_content() {
        let json = r#"{
            "model": "gpt-4o",
            "choices": [
                {
                    "message": {"content": null},
                    "finish_reason": "stop"
                }
            ]
        }"#;

        let response: OpenAIResponse = serde_json::from_str(json).unwrap();
        assert!(response.choices[0].message.content.is_none());
    }

    #[test]
    fn test_parse_anthropic_max_tokens_finish() {
        let json = r#"{
            "model": "claude-sonnet-4-5",
            "content": [{"type": "text", "text": "Truncated..."}],
            "stop_reason": "max_tokens",
            "usage": {"input_tokens": 100, "output_tokens": 4000}
        }"#;

        let response: AnthropicResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.stop_reason, Some("max_tokens".to_string()));
    }

    #[test]
    fn test_parse_openai_content_filter_finish() {
        let json = r#"{
            "model": "gpt-4o",
            "choices": [
                {
                    "message": {"content": ""},
                    "finish_reason": "content_filter"
                }
            ]
        }"#;

        let response: OpenAIResponse = serde_json::from_str(json).unwrap();
        assert_eq!(
            response.choices[0].finish_reason,
            Some("content_filter".to_string())
        );
    }

    #[test]
    fn test_empty_prompt_allowed() {
        let request = LlmRequest::new("");
        assert_eq!(request.prompt, "");
    }

    #[test]
    fn test_config_temperature_extreme_values() {
        let config = LlmConfig::default().with_temperature(0.0);
        assert_eq!(config.temperature, 0.0);

        let config = LlmConfig::default().with_temperature(2.0);
        assert_eq!(config.temperature, 2.0);
    }

    #[test]
    fn test_config_max_tokens_extreme_values() {
        let config = LlmConfig::default().with_max_tokens(1);
        assert_eq!(config.max_tokens, 1);

        let config = LlmConfig::default().with_max_tokens(1_000_000);
        assert_eq!(config.max_tokens, 1_000_000);
    }

    #[test]
    fn test_provider_extra_defaults() {
        let extra = ProviderExtra::default();
        assert!(extra.azure_resource.is_none());
        assert!(extra.azure_deployment.is_none());
        assert!(extra.aws_region.is_none());
        assert!(extra.gcp_project.is_none());
        assert!(extra.gcp_location.is_none());
        assert!(extra.cf_account_id.is_none());
        assert!(extra.cf_gateway_id.is_none());
        assert!(extra.gateway_provider.is_none());
    }
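
    // Added sketch: LlmConfig derives Serialize and Deserialize, so a JSON round
    // trip should preserve both the top-level fields and the provider extras.
    // serde_json is already used by the response-parsing tests above; the
    // concrete values here are arbitrary.
    #[test]
    fn test_config_serde_round_trip_sketch() {
        let mut config = LlmConfig::for_provider(LlmProvider::GoogleVertex, "gemini-3.0-pro")
            .with_max_tokens(512);
        config.extra.gcp_project = Some("my-project".to_string());

        let json = serde_json::to_string(&config).unwrap();
        let restored: LlmConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.provider, LlmProvider::GoogleVertex);
        assert_eq!(restored.model, "gemini-3.0-pro");
        assert_eq!(restored.max_tokens, 512);
        assert_eq!(restored.extra.gcp_project, Some("my-project".to_string()));
    }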

    // Minimal in-memory stand-in for a client-side rate limiter, used by the
    // tests below; it only counts acquisitions and never actually sleeps.
    struct MockRateLimiter {
        requests_per_second: u32,
        current_count: std::sync::atomic::AtomicU32,
    }

    impl MockRateLimiter {
        fn new(rps: u32) -> Self {
            Self {
                requests_per_second: rps,
                current_count: std::sync::atomic::AtomicU32::new(0),
            }
        }

        fn try_acquire(&self) -> bool {
            let count = self
                .current_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            count < self.requests_per_second
        }

        fn reset(&self) {
            self.current_count
                .store(0, std::sync::atomic::Ordering::SeqCst);
        }
    }

    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let limiter = MockRateLimiter::new(10);

        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let limiter = MockRateLimiter::new(5);

        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }

        assert!(!limiter.try_acquire());
    }

    #[test]
    fn test_rate_limiter_reset() {
        let limiter = MockRateLimiter::new(3);

        for _ in 0..3 {
            limiter.try_acquire();
        }
        assert!(!limiter.try_acquire());

        limiter.reset();
        assert!(limiter.try_acquire());
    }
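
    // Added sketch: try_acquire is a single atomic fetch_add, so exactly
    // `requests_per_second` acquisitions should succeed even when callers race
    // from several threads. The thread count and limit are arbitrary test values.
    #[test]
    fn test_rate_limiter_concurrent_acquires_sketch() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let limiter = MockRateLimiter::new(8);
        let successes = AtomicU32::new(0);

        std::thread::scope(|scope| {
            for _ in 0..16 {
                scope.spawn(|| {
                    if limiter.try_acquire() {
                        successes.fetch_add(1, Ordering::SeqCst);
                    }
                });
            }
        });

        assert_eq!(successes.load(Ordering::SeqCst), 8);
    }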

    // Test double for the LlmClient trait: always returns a pre-configured
    // response instead of performing any network I/O.
    struct MockLlmClient {
        provider: LlmProvider,
        model: String,
        response: LlmResponse,
    }

    impl MockLlmClient {
        fn new(provider: LlmProvider, model: impl Into<String>) -> Self {
            Self {
                provider,
                model: model.into(),
                response: LlmResponse {
                    content: "Mock response".to_string(),
                    model: "mock-model".to_string(),
                    finish_reason: FinishReason::Stop,
                    usage: LlmUsage {
                        input_tokens: 10,
                        output_tokens: 5,
                        total_tokens: 15,
                    },
                    provider: Some(provider),
                },
            }
        }

        fn with_response(mut self, content: impl Into<String>) -> Self {
            self.response.content = content.into();
            self
        }

        fn with_finish_reason(mut self, reason: FinishReason) -> Self {
            self.response.finish_reason = reason;
            self
        }

        fn with_usage(mut self, input: u32, output: u32) -> Self {
            self.response.usage = LlmUsage {
                input_tokens: input,
                output_tokens: output,
                total_tokens: input + output,
            };
            self
        }
    }

    #[async_trait]
    impl LlmClient for MockLlmClient {
        async fn complete(&self, _request: LlmRequest) -> Result<LlmResponse> {
            Ok(self.response.clone())
        }

        fn provider(&self) -> LlmProvider {
            self.provider
        }

        fn model(&self) -> &str {
            &self.model
        }
    }

    #[tokio::test]
    async fn test_mock_client_returns_configured_response() {
        let client = MockLlmClient::new(LlmProvider::OpenAI, "gpt-4o")
            .with_response("Custom test response")
            .with_usage(100, 50);

        let request = LlmRequest::new("Test prompt");
        let response = client.complete(request).await.unwrap();

        assert_eq!(response.content, "Custom test response");
        assert_eq!(response.usage.input_tokens, 100);
        assert_eq!(response.usage.output_tokens, 50);
        assert_eq!(response.provider, Some(LlmProvider::OpenAI));
    }

    #[tokio::test]
    async fn test_mock_client_finish_reason() {
        let client = MockLlmClient::new(LlmProvider::Anthropic, "claude-3")
            .with_finish_reason(FinishReason::MaxTokens);

        let response = client.complete(LlmRequest::new("Test")).await.unwrap();
        assert_eq!(response.finish_reason, FinishReason::MaxTokens);
    }

    #[test]
    fn test_mock_client_provider_and_model() {
        let client = MockLlmClient::new(LlmProvider::Groq, "llama-3.3-70b");

        assert_eq!(client.provider(), LlmProvider::Groq);
        assert_eq!(client.model(), "llama-3.3-70b");
    }

    // Minimal mock of a streaming response: an ordered list of chunks where the
    // last chunk is marked final, mirroring how a streamed completion ends.
    #[derive(Debug, Clone)]
    struct StreamChunk {
        content: String,
        is_final: bool,
    }

    struct MockStreamingResponse {
        chunks: Vec<StreamChunk>,
    }

    impl MockStreamingResponse {
        fn new(chunks: Vec<&str>) -> Self {
            let mut stream_chunks: Vec<StreamChunk> = chunks
                .into_iter()
                .map(|c| StreamChunk {
                    content: c.to_string(),
                    is_final: false,
                })
                .collect();

            if let Some(last) = stream_chunks.last_mut() {
                last.is_final = true;
            }

            Self {
                chunks: stream_chunks,
            }
        }

        fn collect_content(&self) -> String {
            self.chunks.iter().map(|c| c.content.as_str()).collect()
        }
    }

    #[test]
    fn test_streaming_chunks_collection() {
        let stream = MockStreamingResponse::new(vec!["Hello", " ", "world", "!"]);

        assert_eq!(stream.chunks.len(), 4);
        assert_eq!(stream.collect_content(), "Hello world!");
    }

    #[test]
    fn test_streaming_final_flag() {
        let stream = MockStreamingResponse::new(vec!["Part 1", "Part 2", "Part 3"]);

        assert!(!stream.chunks[0].is_final);
        assert!(!stream.chunks[1].is_final);
        assert!(stream.chunks[2].is_final);
    }

    #[test]
    fn test_streaming_empty() {
        let stream = MockStreamingResponse::new(vec![]);

        assert!(stream.chunks.is_empty());
        assert_eq!(stream.collect_content(), "");
    }

    #[test]
    fn test_streaming_single_chunk() {
        let stream = MockStreamingResponse::new(vec!["Complete response in one chunk"]);

        assert_eq!(stream.chunks.len(), 1);
        assert!(stream.chunks[0].is_final);
    }
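
    // Added sketch: consumes the mock stream the way a caller would, appending
    // chunk content until the final chunk is observed. This only exercises the
    // mock defined above; no real streaming transport is involved.
    #[test]
    fn test_streaming_incremental_consumption_sketch() {
        let stream = MockStreamingResponse::new(vec!["The answer", " is", " 42."]);

        let mut assembled = String::new();
        let mut saw_final = false;
        for chunk in &stream.chunks {
            assembled.push_str(&chunk.content);
            if chunk.is_final {
                saw_final = true;
                break;
            }
        }

        assert!(saw_final);
        assert_eq!(assembled, "The answer is 42.");
        assert_eq!(assembled, stream.collect_content());
    }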

    #[tokio::test]
    async fn test_full_request_response_cycle_mock() {
        let client = MockLlmClient::new(LlmProvider::OpenAI, "gpt-4o")
            .with_response("This is a comprehensive analysis of your question.")
            .with_usage(150, 75);

        let request = LlmRequest::new("Analyze the impact of AI on healthcare")
            .with_system("You are a medical AI expert")
            .with_temperature(0.3)
            .with_max_tokens(500);

        let response = client.complete(request).await.unwrap();

        assert_eq!(
            response.content,
            "This is a comprehensive analysis of your question."
        );
        assert_eq!(response.finish_reason, FinishReason::Stop);
        assert_eq!(response.usage.total_tokens, 225);
    }

    #[tokio::test]
    async fn test_multiple_providers_mock() {
        let providers_and_models = vec![
            (LlmProvider::OpenAI, "gpt-4o"),
            (LlmProvider::Anthropic, "claude-sonnet-4-5"),
            (LlmProvider::Groq, "llama-3.3-70b"),
            (LlmProvider::DeepSeek, "deepseek-v3"),
            (LlmProvider::Mistral, "mistral-large"),
        ];

        for (provider, model) in providers_and_models {
            let client = MockLlmClient::new(provider, model)
                .with_response(format!("Response from {}", provider.display_name()));

            let response = client.complete(LlmRequest::new("Test")).await.unwrap();

            assert!(response.content.contains(provider.display_name()));
            assert_eq!(response.provider, Some(provider));
        }
    }

    #[tokio::test]
    async fn test_concurrent_client_creation() {
        use tokio::task::JoinSet;

        let mut tasks = JoinSet::new();

        for i in 0..10 {
            tasks.spawn(async move {
                let config = LlmConfig::for_provider(LlmProvider::OpenAI, format!("gpt-4o-{}", i));
                UnifiedLlmClient::new(config)
            });
        }

        let mut success_count = 0;
        while let Some(result) = tasks.join_next().await {
            if result.unwrap().is_ok() {
                success_count += 1;
            }
        }

        assert_eq!(success_count, 10);
    }

    #[tokio::test]
    async fn test_concurrent_mock_requests() {
        use std::sync::Arc;
        use tokio::task::JoinSet;

        let client = Arc::new(
            MockLlmClient::new(LlmProvider::OpenAI, "gpt-4o").with_response("Concurrent response"),
        );

        let mut tasks = JoinSet::new();

        for i in 0..20 {
            let client = Arc::clone(&client);
            tasks.spawn(async move {
                let request = LlmRequest::new(format!("Request {}", i));
                client.complete(request).await
            });
        }

        let mut success_count = 0;
        while let Some(result) = tasks.join_next().await {
            if result.unwrap().is_ok() {
                success_count += 1;
            }
        }

        assert_eq!(success_count, 20);
    }
}