use crate::test_client::TestClient;
use crate::types::LlmClient;
use std::sync::Arc;
/// Errors that can occur while constructing an LLM client.
#[derive(Debug, thiserror::Error)]
pub enum FactoryError {
    /// The requested provider is unknown or its cargo feature is disabled.
    #[error("Unsupported provider: {0}")]
    UnsupportedProvider(String),
    /// No API key was supplied by the caller, the config, or the environment.
    #[error("Missing API key for provider: {0}")]
    MissingApiKey(String),
    /// The underlying client constructor returned an error.
    #[error("Failed to create client: {0}")]
    ClientCreationFailed(String),
}
/// Supported LLM backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LlmProvider {
    Anthropic,
    OpenAi,
    Gemini,
}

impl LlmProvider {
    /// Canonical lowercase name and API-key environment variable, per provider.
    fn metadata(self) -> (&'static str, &'static str) {
        match self {
            Self::Anthropic => ("anthropic", "ANTHROPIC_API_KEY"),
            Self::OpenAi => ("openai", "OPENAI_API_KEY"),
            Self::Gemini => ("gemini", "GEMINI_API_KEY"),
        }
    }

    /// Parses a provider name case-insensitively, accepting the common aliases
    /// "claude", "gpt", "chatgpt", and "google". Unknown names yield `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        if matches!(normalized.as_str(), "anthropic" | "claude") {
            Some(Self::Anthropic)
        } else if matches!(normalized.as_str(), "openai" | "gpt" | "chatgpt") {
            Some(Self::OpenAi)
        } else if matches!(normalized.as_str(), "gemini" | "google") {
            Some(Self::Gemini)
        } else {
            None
        }
    }

    /// Name of the conventional environment variable holding this provider's API key.
    pub fn env_var(&self) -> &'static str {
        self.metadata().1
    }

    /// Canonical lowercase provider name.
    pub fn as_str(&self) -> &'static str {
        self.metadata().0
    }
}

impl std::fmt::Display for LlmProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Abstraction over client construction so callers can inject alternative
/// factories (e.g. a mock factory for testing).
pub trait LlmClientFactory: Send + Sync {
    /// Builds a client for `provider`, preferring the explicit `api_key`;
    /// implementations may fall back to configuration or environment lookups.
    ///
    /// # Errors
    /// Returns a [`FactoryError`] when the provider is unsupported, the key is
    /// missing, or the client cannot be constructed.
    fn create_client(
        &self,
        provider: LlmProvider,
        api_key: Option<String>,
    ) -> Result<Arc<dyn LlmClient>, FactoryError>;
    /// Lists the providers this factory can construct clients for.
    fn supported_providers(&self) -> Vec<LlmProvider>;
    /// Returns `true` if `provider` appears in [`Self::supported_providers`].
    fn is_supported(&self, provider: LlmProvider) -> bool {
        self.supported_providers().contains(&provider)
    }
}
/// Configuration for [`DefaultClientFactory`]: explicit API keys and base-URL
/// overrides per provider. Unset keys fall back to environment variables at
/// lookup time.
#[derive(Debug, Clone, Default)]
pub struct DefaultFactoryConfig {
    // Explicit API keys; these take precedence over environment lookups.
    pub anthropic_api_key: Option<String>,
    pub openai_api_key: Option<String>,
    pub gemini_api_key: Option<String>,
    // Optional base-URL overrides (e.g. proxies or self-hosted gateways).
    pub anthropic_base_url: Option<String>,
    pub openai_base_url: Option<String>,
    pub gemini_base_url: Option<String>,
}
impl DefaultFactoryConfig {
    /// Creates an empty configuration; all lookups fall back to the environment.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets an explicit Anthropic API key.
    pub fn with_anthropic_key(mut self, value: impl Into<String>) -> Self {
        self.anthropic_api_key = Some(value.into());
        self
    }

    /// Sets an explicit OpenAI API key.
    pub fn with_openai_key(mut self, value: impl Into<String>) -> Self {
        self.openai_api_key = Some(value.into());
        self
    }

    /// Sets an explicit Gemini API key.
    pub fn with_gemini_key(mut self, value: impl Into<String>) -> Self {
        self.gemini_api_key = Some(value.into());
        self
    }

    /// Overrides the Anthropic base URL.
    pub fn with_anthropic_base_url(mut self, value: impl Into<String>) -> Self {
        self.anthropic_base_url = Some(value.into());
        self
    }

    /// Overrides the OpenAI base URL.
    pub fn with_openai_base_url(mut self, value: impl Into<String>) -> Self {
        self.openai_base_url = Some(value.into());
        self
    }

    /// Overrides the Gemini base URL.
    pub fn with_gemini_base_url(mut self, value: impl Into<String>) -> Self {
        self.gemini_base_url = Some(value.into());
        self
    }

    /// Resolves the API key for `provider`: explicit config first, then the
    /// RKAT-prefixed env var, then the provider's conventional env var
    /// (Gemini additionally falls back to GOOGLE_API_KEY).
    fn get_api_key(&self, provider: LlmProvider) -> Option<String> {
        let configured = match provider {
            LlmProvider::Anthropic => &self.anthropic_api_key,
            LlmProvider::OpenAi => &self.openai_api_key,
            LlmProvider::Gemini => &self.gemini_api_key,
        };
        if let Some(key) = configured {
            return Some(key.clone());
        }
        match provider {
            LlmProvider::Anthropic => {
                env_preferring_rkat("RKAT_ANTHROPIC_API_KEY", "ANTHROPIC_API_KEY")
            }
            LlmProvider::OpenAi => env_preferring_rkat("RKAT_OPENAI_API_KEY", "OPENAI_API_KEY"),
            LlmProvider::Gemini => env_preferring_rkat("RKAT_GEMINI_API_KEY", "GEMINI_API_KEY")
                .or_else(|| std::env::var("GOOGLE_API_KEY").ok()),
        }
    }

    /// Returns the configured base-URL override for `provider`, if any.
    fn get_base_url(&self, provider: LlmProvider) -> Option<&str> {
        let slot = match provider {
            LlmProvider::Anthropic => &self.anthropic_base_url,
            LlmProvider::OpenAi => &self.openai_base_url,
            LlmProvider::Gemini => &self.gemini_base_url,
        };
        slot.as_deref()
    }
}
pub struct DefaultClientFactory {
config: DefaultFactoryConfig,
}
impl DefaultClientFactory {
pub fn new() -> Self {
Self {
config: DefaultFactoryConfig::default(),
}
}
pub fn with_config(config: DefaultFactoryConfig) -> Self {
Self { config }
}
}
impl Default for DefaultClientFactory {
fn default() -> Self {
Self::new()
}
}
impl LlmClientFactory for DefaultClientFactory {
    /// Builds a client for `provider`.
    ///
    /// Key resolution order: explicit `api_key` argument, then configured key,
    /// then environment variables (see `DefaultFactoryConfig::get_api_key`).
    /// Setting the env var `RKAT_TEST_CLIENT=1` short-circuits everything and
    /// returns a `TestClient`, so tests never hit a real API.
    ///
    /// # Errors
    /// - `MissingApiKey` when no key can be resolved for an enabled provider.
    /// - `ClientCreationFailed` when the underlying constructor fails.
    /// - `UnsupportedProvider` when the provider's cargo feature is disabled.
    fn create_client(
        &self,
        provider: LlmProvider,
        api_key: Option<String>,
    ) -> Result<Arc<dyn LlmClient>, FactoryError> {
        if std::env::var("RKAT_TEST_CLIENT").ok().as_deref() == Some("1") {
            return Ok(Arc::new(TestClient::default()));
        }
        // Explicit key takes precedence over config/env lookup.
        let key = api_key.or_else(|| self.config.get_api_key(provider));
        match provider {
            #[cfg(feature = "anthropic")]
            LlmProvider::Anthropic => {
                let key = key.ok_or_else(|| FactoryError::MissingApiKey("anthropic".into()))?;
                // Bug fix: a failing constructor is a creation failure, not an
                // unsupported provider. `ClientCreationFailed` existed for
                // exactly this case but was never constructed.
                let mut client = crate::AnthropicClient::new(key)
                    .map_err(|e| FactoryError::ClientCreationFailed(e.to_string()))?;
                if let Some(base_url) = self.config.get_base_url(provider) {
                    client = client.with_base_url(base_url.to_string());
                }
                Ok(Arc::new(client))
            }
            #[cfg(not(feature = "anthropic"))]
            LlmProvider::Anthropic => Err(FactoryError::UnsupportedProvider(
                "anthropic (feature not enabled)".into(),
            )),
            #[cfg(feature = "openai")]
            LlmProvider::OpenAi => {
                let key = key.ok_or_else(|| FactoryError::MissingApiKey("openai".into()))?;
                let mut client = crate::OpenAiClient::new(key);
                if let Some(base_url) = self.config.get_base_url(provider) {
                    client = client.with_base_url(base_url.to_string());
                }
                Ok(Arc::new(client))
            }
            #[cfg(not(feature = "openai"))]
            LlmProvider::OpenAi => Err(FactoryError::UnsupportedProvider(
                "openai (feature not enabled)".into(),
            )),
            #[cfg(feature = "gemini")]
            LlmProvider::Gemini => {
                let key = key.ok_or_else(|| FactoryError::MissingApiKey("gemini".into()))?;
                let mut client = crate::GeminiClient::new(key);
                if let Some(base_url) = self.config.get_base_url(provider) {
                    client = client.with_base_url(base_url.to_string());
                }
                Ok(Arc::new(client))
            }
            #[cfg(not(feature = "gemini"))]
            LlmProvider::Gemini => Err(FactoryError::UnsupportedProvider(
                "gemini (feature not enabled)".into(),
            )),
        }
    }

    /// Providers whose cargo features are enabled in this build.
    #[allow(clippy::vec_init_then_push)]
    fn supported_providers(&self) -> Vec<LlmProvider> {
        #[allow(unused_mut)]
        let mut providers = Vec::new();
        #[cfg(feature = "anthropic")]
        providers.push(LlmProvider::Anthropic);
        #[cfg(feature = "openai")]
        providers.push(LlmProvider::OpenAi);
        #[cfg(feature = "gemini")]
        providers.push(LlmProvider::Gemini);
        providers
    }
}
/// Reads `rkat_key` from the environment, falling back to `provider_key` when
/// the RKAT-prefixed variable is unset or not valid Unicode.
fn env_preferring_rkat(rkat_key: &str, provider_key: &str) -> Option<String> {
    std::env::var(rkat_key)
        .or_else(|_| std::env::var(provider_key))
        .ok()
}
// Unit tests for provider parsing, factory configuration, and error display.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    // Every accepted alias maps to the right provider; parsing is
    // case-insensitive and unknown names yield None.
    #[test]
    fn test_provider_parse() {
        assert_eq!(
            LlmProvider::parse("anthropic"),
            Some(LlmProvider::Anthropic)
        );
        assert_eq!(LlmProvider::parse("claude"), Some(LlmProvider::Anthropic));
        assert_eq!(
            LlmProvider::parse("ANTHROPIC"),
            Some(LlmProvider::Anthropic)
        );
        assert_eq!(LlmProvider::parse("openai"), Some(LlmProvider::OpenAi));
        assert_eq!(LlmProvider::parse("gpt"), Some(LlmProvider::OpenAi));
        assert_eq!(LlmProvider::parse("chatgpt"), Some(LlmProvider::OpenAi));
        assert_eq!(LlmProvider::parse("gemini"), Some(LlmProvider::Gemini));
        assert_eq!(LlmProvider::parse("google"), Some(LlmProvider::Gemini));
        assert_eq!(LlmProvider::parse("unknown"), None);
    }

    // env_var returns the conventional *_API_KEY variable name.
    #[test]
    fn test_provider_env_var() {
        assert_eq!(LlmProvider::Anthropic.env_var(), "ANTHROPIC_API_KEY");
        assert_eq!(LlmProvider::OpenAi.env_var(), "OPENAI_API_KEY");
        assert_eq!(LlmProvider::Gemini.env_var(), "GEMINI_API_KEY");
    }

    // as_str returns the canonical lowercase name.
    #[test]
    fn test_provider_as_str() {
        assert_eq!(LlmProvider::Anthropic.as_str(), "anthropic");
        assert_eq!(LlmProvider::OpenAi.as_str(), "openai");
        assert_eq!(LlmProvider::Gemini.as_str(), "gemini");
    }

    // Display output matches as_str.
    #[test]
    fn test_provider_display() {
        assert_eq!(format!("{}", LlmProvider::Anthropic), "anthropic");
    }

    // Builder methods store exactly the values provided.
    #[test]
    fn test_factory_config_builder() {
        let config = DefaultFactoryConfig::new()
            .with_anthropic_key("test-key")
            .with_openai_base_url("https://custom.api.com");
        assert_eq!(config.anthropic_api_key, Some("test-key".to_string()));
        assert_eq!(
            config.openai_base_url,
            Some("https://custom.api.com".to_string())
        );
    }

    // Each provider appears in supported_providers iff its feature is enabled.
    #[test]
    fn test_default_factory_supported_providers() {
        let factory = DefaultClientFactory::new();
        let providers = factory.supported_providers();
        #[cfg(feature = "anthropic")]
        assert!(providers.contains(&LlmProvider::Anthropic));
        #[cfg(feature = "openai")]
        assert!(providers.contains(&LlmProvider::OpenAi));
        #[cfg(feature = "gemini")]
        assert!(providers.contains(&LlmProvider::Gemini));
    }

    // is_supported mirrors the compile-time feature flags.
    #[test]
    fn test_factory_is_supported() {
        let factory = DefaultClientFactory::new();
        #[cfg(feature = "anthropic")]
        assert!(factory.is_supported(LlmProvider::Anthropic));
        #[cfg(not(feature = "anthropic"))]
        assert!(!factory.is_supported(LlmProvider::Anthropic));
    }

    // Without a key anywhere, creation fails with MissingApiKey. The assertion
    // is guarded on the host env var so machines with a real key still pass.
    #[test]
    fn test_factory_missing_api_key() {
        let factory = DefaultClientFactory::with_config(DefaultFactoryConfig::new());
        let result = factory.create_client(LlmProvider::Anthropic, None);
        if std::env::var("ANTHROPIC_API_KEY").is_err() {
            #[cfg(feature = "anthropic")]
            assert!(matches!(result, Err(FactoryError::MissingApiKey(_))));
        }
    }

    // An explicitly passed key satisfies creation without any env setup.
    #[test]
    fn test_factory_with_explicit_key() {
        let factory = DefaultClientFactory::new();
        #[cfg(feature = "anthropic")]
        {
            let result =
                factory.create_client(LlmProvider::Anthropic, Some("test-api-key".to_string()));
            assert!(result.is_ok());
        }
    }

    // Display strings must match the #[error] attributes on FactoryError.
    #[test]
    fn test_factory_error_display() {
        let err = FactoryError::UnsupportedProvider("test".into());
        assert_eq!(err.to_string(), "Unsupported provider: test");
        let err = FactoryError::MissingApiKey("anthropic".into());
        assert_eq!(err.to_string(), "Missing API key for provider: anthropic");
        let err = FactoryError::ClientCreationFailed("timeout".into());
        assert_eq!(err.to_string(), "Failed to create client: timeout");
    }
}