use crate::models::{
    integrations::{anthropic::AnthropicModel, gemini::GeminiModel, openai::OpenAIModel},
    model_pricing::{ContextAware, ModelContextInfo},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;

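/// Connection settings for a single LLM provider.
///
/// Serialized with an internal `type` tag (e.g. `{"type": "openai", ...}`);
/// optional fields are omitted when `None`.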
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    OpenAI {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    Anthropic {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
    },
    Gemini {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    Custom {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        api_endpoint: String,
    },
}

impl ProviderConfig {
    pub fn provider_type(&self) -> &'static str {
        match self {
            ProviderConfig::OpenAI { .. } => "openai",
            ProviderConfig::Anthropic { .. } => "anthropic",
            ProviderConfig::Gemini { .. } => "gemini",
            ProviderConfig::Custom { .. } => "custom",
        }
    }

    pub fn api_key(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
        }
    }

    pub fn api_endpoint(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
        }
    }

    pub fn access_token(&self) -> Option<&str> {
        match self {
            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
            _ => None,
        }
    }

    pub fn openai(api_key: Option<String>) -> Self {
        ProviderConfig::OpenAI {
            api_key,
            api_endpoint: None,
        }
    }

    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
        ProviderConfig::Anthropic {
            api_key,
            api_endpoint: None,
            access_token,
        }
    }

    pub fn gemini(api_key: Option<String>) -> Self {
        ProviderConfig::Gemini {
            api_key,
            api_endpoint: None,
        }
    }

    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
        ProviderConfig::Custom {
            api_key,
            api_endpoint,
        }
    }
}

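/// A concrete model selection: one of the first-party providers, or an
/// arbitrary `provider/model` pair routed through a custom provider.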
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum LLMModel {
    Anthropic(AnthropicModel),
    Gemini(GeminiModel),
    OpenAI(OpenAIModel),
    Custom {
        provider: String,
        model: String,
    },
}

impl ContextAware for LLMModel {
    fn context_info(&self) -> ModelContextInfo {
        match self {
            LLMModel::Anthropic(model) => model.context_info(),
            LLMModel::Gemini(model) => model.context_info(),
            LLMModel::OpenAI(model) => model.context_info(),
            // Unknown custom models fall back to the default context limits.
            LLMModel::Custom { .. } => ModelContextInfo::default(),
        }
    }

    fn model_name(&self) -> String {
        match self {
            LLMModel::Anthropic(model) => model.model_name(),
            LLMModel::Gemini(model) => model.model_name(),
            LLMModel::OpenAI(model) => model.model_name(),
            LLMModel::Custom { provider, model } => format!("{}/{}", provider, model),
        }
    }
}

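/// Provider configurations keyed by the user-facing provider name
/// (e.g. "openai", "anthropic", "litellm").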
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    pub providers: HashMap<String, ProviderConfig>,
}

impl LLMProviderConfig {
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
        self.providers.insert(name.into(), config);
    }

    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
        self.providers.get(name)
    }

    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}

impl From<String> for LLMModel {
    fn from(value: String) -> Self {
        // A "provider/model" string routes by its provider prefix: known
        // first-party prefixes resolve to concrete models, everything else
        // becomes a custom provider.
        if let Some((provider, model)) = value.split_once('/') {
            match provider {
                "anthropic" => return Self::from_model_name(model),
                "openai" => return Self::from_model_name(model),
                "google" | "gemini" => return Self::from_model_name(model),
                _ => {
                    return LLMModel::Custom {
                        provider: provider.to_string(),
                        model: model.to_string(),
                    };
                }
            }
        }

        // Bare model names fall through to prefix matching.
        Self::from_model_name(&value)
    }
}

impl LLMModel {
    // Matching is order-sensitive: more specific patterns (e.g. "gpt-5-mini",
    // "gemini-2.5-flash-lite") must be checked before their shorter prefixes.
    fn from_model_name(model: &str) -> Self {
        if model.starts_with("claude-haiku-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if model.starts_with("claude-sonnet-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
        } else if model.starts_with("claude-opus-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        } else if model == "gemini-2.5-flash-lite" {
            LLMModel::Gemini(GeminiModel::Gemini25FlashLite)
        } else if model.starts_with("gemini-2.5-flash") {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else if model.starts_with("gemini-2.5-pro") {
            LLMModel::Gemini(GeminiModel::Gemini25Pro)
        } else if model.starts_with("gemini-3-pro-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Pro)
        } else if model.starts_with("gemini-3-flash-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Flash)
        } else if model.starts_with("gpt-5-mini") {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if model.starts_with("gpt-5") {
            LLMModel::OpenAI(OpenAIModel::GPT5)
        } else {
            LLMModel::Custom {
                provider: "custom".to_string(),
                model: model.to_string(),
            }
        }
    }

    pub fn provider_name(&self) -> &str {
        match self {
            LLMModel::Anthropic(_) => "anthropic",
            LLMModel::Gemini(_) => "google",
            LLMModel::OpenAI(_) => "openai",
            LLMModel::Custom { provider, .. } => provider,
        }
    }

    pub fn model_id(&self) -> String {
        match self {
            LLMModel::Anthropic(m) => m.to_string(),
            LLMModel::Gemini(m) => m.to_string(),
            LLMModel::OpenAI(m) => m.to_string(),
            LLMModel::Custom { model, .. } => model.clone(),
        }
    }
}

impl Display for LLMModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LLMModel::Anthropic(model) => write!(f, "{}", model),
            LLMModel::Gemini(model) => write!(f, "{}", model),
            LLMModel::OpenAI(model) => write!(f, "{}", model),
            LLMModel::Custom { provider, model } => write!(f, "{}/{}", provider, model),
        }
    }
}

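/// Provider-specific generation options, grouped by provider so unrelated
/// settings can travel through a single field.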
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    pub budget_tokens: u32,
}

impl LLMThinkingOptions {
    pub fn new(budget_tokens: u32) -> Self {
        Self {
            // Clamp to a 1024-token floor, the minimum thinking budget
            // accepted by Anthropic's API.
            budget_tokens: budget_tokens.max(1024),
        }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}

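/// A complete (non-streaming) generation request.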
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
}

/// Like [`LLMInput`], but deltas are streamed through `stream_channel_tx`.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
            provider_options: value.provider_options.clone(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

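/// Message content in either of the two wire shapes: a plain string, or a
/// list of typed content blocks. `untagged` lets serde accept both.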
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            // Non-textual blocks (tool calls, images) flatten to empty strings.
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        // Also accepts the Anthropic-style "input" field name.
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

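/// A single entry in an OpenAI-style `choices` array.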
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    /// Iterates over every token category, treating missing counts as zero.
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    // Missing counts are treated as zero, so every sum field is `Some`.
    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

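/// One incremental update emitted on the streaming channel while a
/// generation is in flight.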
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_model_from_known_anthropic_model() {
        let model = LLMModel::from("claude-opus-4-5-20251101".to_string());
        assert!(matches!(
            model,
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        ));
    }

    #[test]
    fn test_llm_model_from_known_openai_model() {
        let model = LLMModel::from("gpt-5".to_string());
        assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5)));
    }

    #[test]
    fn test_llm_model_from_known_gemini_model() {
        let model = LLMModel::from("gemini-2.5-flash".to_string());
        assert!(matches!(
            model,
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        ));
    }

    #[test]
    fn test_llm_model_from_custom_provider_with_slash() {
        let model = LLMModel::from("litellm/claude-opus-4-5".to_string());
        match model {
            LLMModel::Custom { provider, model } => {
                assert_eq!(provider, "litellm");
                assert_eq!(model, "claude-opus-4-5");
            }
            _ => panic!("Expected Custom model"),
        }
    }

    #[test]
    fn test_llm_model_from_ollama_provider() {
        let model = LLMModel::from("ollama/llama3".to_string());
        match model {
            LLMModel::Custom { provider, model } => {
                assert_eq!(provider, "ollama");
                assert_eq!(model, "llama3");
            }
            _ => panic!("Expected Custom model"),
        }
    }

    #[test]
    fn test_llm_model_explicit_anthropic_prefix() {
        let model = LLMModel::from("anthropic/claude-opus-4-5".to_string());
        assert!(matches!(
            model,
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        ));
    }

    #[test]
    fn test_llm_model_explicit_openai_prefix() {
        let model = LLMModel::from("openai/gpt-5".to_string());
        assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5)));
    }

    #[test]
    fn test_llm_model_explicit_google_prefix() {
        let model = LLMModel::from("google/gemini-2.5-flash".to_string());
        assert!(matches!(
            model,
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        ));
    }

    #[test]
    fn test_llm_model_explicit_gemini_prefix() {
        let model = LLMModel::from("gemini/gemini-2.5-flash".to_string());
        assert!(matches!(
            model,
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        ));
    }

    #[test]
    fn test_llm_model_unknown_model_becomes_custom() {
        let model = LLMModel::from("some-random-model".to_string());
        match model {
            LLMModel::Custom { provider, model } => {
                assert_eq!(provider, "custom");
                assert_eq!(model, "some-random-model");
            }
            _ => panic!("Expected Custom model"),
        }
    }

    #[test]
    fn test_llm_model_display_anthropic() {
        let model = LLMModel::Anthropic(AnthropicModel::Claude45Sonnet);
        let s = model.to_string();
        assert!(s.contains("claude"));
    }

    #[test]
    fn test_llm_model_display_custom() {
        let model = LLMModel::Custom {
            provider: "litellm".to_string(),
            model: "claude-opus".to_string(),
        };
        assert_eq!(model.to_string(), "litellm/claude-opus");
    }

    #[test]
    fn test_llm_model_provider_name() {
        assert_eq!(
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet).provider_name(),
            "anthropic"
        );
        assert_eq!(
            LLMModel::OpenAI(OpenAIModel::GPT5).provider_name(),
            "openai"
        );
        assert_eq!(
            LLMModel::Gemini(GeminiModel::Gemini25Flash).provider_name(),
            "google"
        );
        assert_eq!(
            LLMModel::Custom {
                provider: "litellm".to_string(),
                model: "test".to_string()
            }
            .provider_name(),
            "litellm"
        );
    }

    #[test]
    fn test_llm_model_model_id() {
        let model = LLMModel::Custom {
            provider: "litellm".to_string(),
            model: "claude-opus-4-5".to_string(),
        };
        assert_eq!(model.model_id(), "claude-opus-4-5");
    }

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint"));
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key"));
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

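    // `LLMThinkingOptions::new` clamps small budgets up to the 1024-token
    // floor and leaves larger values untouched.
    #[test]
    fn test_llm_thinking_options_minimum_budget() {
        assert_eq!(LLMThinkingOptions::new(100).budget_tokens, 1024);
        assert_eq!(LLMThinkingOptions::new(4096).budget_tokens, 4096);
    }
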
    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

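    // `PromptTokensDetails` addition treats missing counts as zero and
    // always produces `Some` sums.
    #[test]
    fn test_prompt_tokens_details_add() {
        let a = PromptTokensDetails {
            input_tokens: Some(10),
            cache_read_input_tokens: Some(5),
            ..Default::default()
        };
        let b = PromptTokensDetails {
            input_tokens: Some(3),
            ..Default::default()
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, Some(13));
        assert_eq!(sum.cache_read_input_tokens, Some(5));
        assert_eq!(sum.output_tokens, Some(0));
    }
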
    #[test]
    fn test_provider_config_map_deserialization() {
        // Parses a full providers map, as it would appear in a config file.
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }
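
    // The untagged `LLMMessageContent` should accept both wire shapes: a bare
    // string and a list of typed blocks.
    #[test]
    fn test_llm_message_content_untagged_deserialization() {
        let s: LLMMessageContent = serde_json::from_str(r#""hello""#).unwrap();
        assert!(matches!(s, LLMMessageContent::String(_)));

        let json = r#"[{"type":"text","text":"hi"}]"#;
        let l: LLMMessageContent = serde_json::from_str(json).unwrap();
        match l {
            LLMMessageContent::List(items) => {
                assert!(matches!(items[0], LLMMessageTypedContent::Text { .. }));
            }
            _ => panic!("Expected List"),
        }
    }

    // `GenerationDelta` serializes with an internal `type` tag carrying the
    // variant name.
    #[test]
    fn test_generation_delta_tagged_serialization() {
        let delta = GenerationDelta::Content {
            content: "hi".to_string(),
        };
        let json = serde_json::to_string(&delta).unwrap();
        assert!(json.contains("\"type\":\"Content\""));
        assert!(json.contains("\"content\":\"hi\""));
    }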
}