use kernex_core::error::KernexError;
use kernex_core::run::ModelTier;
use kernex_core::traits::Provider;
use std::path::PathBuf;
/// Settings consumed by [`ProviderFactory::create`] when building a provider.
///
/// Every field is optional. Not every provider reads every field; the notes on
/// each field describe how `create` uses it.
#[derive(Default, Clone, Debug)]
pub struct ProviderConfig {
/// API base URL. Only the `openai` and `ollama` providers read it (each has its
/// own fallback URL); the other providers ignore it.
pub base_url: Option<String>,
/// API key. Forwarded to `openai`, `anthropic`, `gemini` and `openrouter`
/// (as an empty string when `None`); not read for `ollama` or `claude-code`.
pub api_key: Option<String>,
/// Explicit model name. When set, it takes precedence over `tier`.
pub model: Option<String>,
/// Token limit forwarded to the `anthropic` provider only (8192 when `None`).
pub max_tokens: Option<u32>,
/// Workspace path forwarded unchanged to every provider's `from_config`.
pub workspace_path: Option<PathBuf>,
/// Sandbox profile applied to the built provider; `SandboxProfile::default()`
/// is used when `None`.
pub sandbox_profile: Option<kernex_sandbox::SandboxProfile>,
/// Model tier used to pick a built-in model when `model` is `None`.
/// Not consulted by the `claude-code` provider.
pub tier: Option<ModelTier>,
}
/// Picks the model name for `provider` from the caller's configuration.
///
/// Precedence:
/// 1. an explicit `model` is returned unchanged (even if it is empty);
/// 2. otherwise, when `tier` is set and [`model_from_tier`] has a non-empty
///    entry for `(provider, tier)`, that entry is returned.
///
/// Returns `None` when neither source yields a name, so the caller can apply
/// its own per-provider default.
fn resolve_model(
    provider: &str,
    model: Option<String>,
    tier: Option<ModelTier>,
) -> Option<String> {
    if let Some(explicit) = model {
        return Some(explicit);
    }
    let mapped = model_from_tier(provider, tier?);
    // `model_from_tier` signals "no mapping for this pair" with an empty string.
    (!mapped.is_empty()).then(|| mapped.to_string())
}
/// Maps a provider name and [`ModelTier`] to a built-in model identifier.
///
/// `provider` is compared exactly, so callers must pass the lowercase
/// canonical name (e.g. `"openai"`). Returns an empty string when the provider
/// has no entry in the table or the tier is neither `Standard` nor `Flagship`;
/// `resolve_model` treats that empty string as "no tier default".
fn model_from_tier(provider: &str, tier: ModelTier) -> &'static str {
    // Each row: (provider, Standard-tier model, Flagship-tier model).
    static TIER_TABLE: [(&str, &str, &str); 5] = [
        ("openai", "gpt-4o-mini", "gpt-4o"),
        ("anthropic", "claude-sonnet-4-6", "claude-opus-4-6"),
        ("gemini", "gemini-2.0-flash", "gemini-2.5-pro"),
        ("ollama", "llama3.2", "llama3.1:70b"),
        (
            "openrouter",
            "anthropic/claude-sonnet-4-6",
            "anthropic/claude-opus-4-6",
        ),
    ];

    match TIER_TABLE.iter().find(|(name, _, _)| *name == provider) {
        Some(&(_, standard, _)) if matches!(tier, ModelTier::Standard) => standard,
        Some(&(_, _, flagship)) if matches!(tier, ModelTier::Flagship) => flagship,
        _ => "",
    }
}
pub struct ProviderFactory;
impl ProviderFactory {
pub fn create(
provider: &str,
config: ProviderConfig,
) -> Result<Box<dyn Provider>, KernexError> {
match provider.to_lowercase().as_str() {
"openai" => {
let model = resolve_model("openai", config.model, config.tier)
.unwrap_or_else(|| "gpt-4o".to_string());
let p = crate::openai::OpenAiProvider::from_config(
config
.base_url
.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
config.api_key.unwrap_or_default(),
model,
config.workspace_path,
)?
.with_sandbox_profile(config.sandbox_profile.unwrap_or_default());
Ok(Box::new(p))
}
"anthropic" => {
let model = resolve_model("anthropic", config.model, config.tier)
.unwrap_or_else(|| "claude-3-7-sonnet-20250219".to_string());
let p = crate::anthropic::AnthropicProvider::from_config(
config.api_key.unwrap_or_default(),
model,
config.max_tokens.unwrap_or(8192),
config.workspace_path,
)?
.with_sandbox_profile(config.sandbox_profile.unwrap_or_default());
Ok(Box::new(p))
}
"gemini" => {
let model = resolve_model("gemini", config.model, config.tier)
.unwrap_or_else(|| "gemini-2.5-flash".to_string());
let p = crate::gemini::GeminiProvider::from_config(
config.api_key.unwrap_or_default(),
model,
config.workspace_path,
)?
.with_sandbox_profile(config.sandbox_profile.unwrap_or_default());
Ok(Box::new(p))
}
"ollama" => {
let model = resolve_model("ollama", config.model, config.tier)
.unwrap_or_else(|| "llama3.2".to_string());
let p = crate::ollama::OllamaProvider::from_config(
config
.base_url
.unwrap_or_else(|| "http://localhost:11434".to_string()),
model,
config.workspace_path,
)?
.with_sandbox_profile(config.sandbox_profile.unwrap_or_default());
Ok(Box::new(p))
}
"openrouter" => {
let model = resolve_model("openrouter", config.model, config.tier)
.unwrap_or_else(|| "anthropic/claude-3.5-sonnet".to_string());
let p = crate::openrouter::OpenRouterProvider::from_config(
config.api_key.unwrap_or_default(),
model,
config.workspace_path,
)?
.with_sandbox_profile(config.sandbox_profile.unwrap_or_default());
Ok(Box::new(p))
}
"claude-code" => {
let p = crate::claude_code::ClaudeCodeProvider::from_config(
25, vec![], 3600, config.workspace_path,
5, config.model.unwrap_or_default(),
None, )
.with_sandbox_profile(config.sandbox_profile.unwrap_or_default());
Ok(Box::new(p))
}
_ => Err(KernexError::Provider(format!(
"Unknown provider type: {}",
provider
))),
}
}
pub fn supported_providers() -> &'static [&'static str] {
&[
"openai",
"anthropic",
"gemini",
"ollama",
"openrouter",
"claude-code",
]
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn provider_config_default() {
        let config = ProviderConfig::default();
        assert!(config.base_url.is_none());
        assert!(config.api_key.is_none());
        assert!(config.model.is_none());
        assert!(config.max_tokens.is_none());
        assert!(config.workspace_path.is_none());
        assert!(config.sandbox_profile.is_none());
        assert!(config.tier.is_none());
    }

    #[test]
    fn provider_config_with_values() {
        let config = ProviderConfig {
            base_url: Some("https://api.example.com".to_string()),
            api_key: Some("sk-test".to_string()),
            model: Some("gpt-4".to_string()),
            max_tokens: Some(4096),
            workspace_path: Some(PathBuf::from("/tmp")),
            sandbox_profile: None,
            tier: None,
        };
        assert_eq!(config.base_url, Some("https://api.example.com".to_string()));
        assert_eq!(config.model, Some("gpt-4".to_string()));
    }

    #[test]
    fn factory_unknown_provider_error() {
        let result = ProviderFactory::create("unknown-provider", ProviderConfig::default());
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unknown provider type"));
        }
    }

    #[test]
    fn factory_supported_providers_list() {
        let providers = ProviderFactory::supported_providers();
        assert!(providers.contains(&"openai"));
        assert!(providers.contains(&"anthropic"));
        assert!(providers.contains(&"gemini"));
        assert!(providers.contains(&"ollama"));
        assert!(providers.contains(&"openrouter"));
        assert!(providers.contains(&"claude-code"));
        assert_eq!(providers.len(), 6);
    }

    /// Known names must resolve regardless of casing; unknown names must still
    /// fail. (Checking only an unknown name would pass even if matching were
    /// case-sensitive.)
    #[test]
    fn factory_case_insensitive() {
        let result = ProviderFactory::create("UNKNOWN", ProviderConfig::default());
        assert!(result.is_err());

        let config = ProviderConfig {
            api_key: Some("test-key".to_string()),
            workspace_path: Some(PathBuf::from("/tmp")),
            ..Default::default()
        };
        let result = ProviderFactory::create("OpenAI", config);
        assert!(result.is_ok(), "mixed-case provider name should resolve");
        assert_eq!(result.unwrap().name(), "openai");

        let config = ProviderConfig {
            workspace_path: Some(PathBuf::from("/tmp")),
            ..Default::default()
        };
        let result = ProviderFactory::create("OLLAMA", config);
        assert!(result.is_ok(), "upper-case provider name should resolve");
        assert_eq!(result.unwrap().name(), "ollama");
    }

    #[test]
    fn factory_creates_openai() {
        let config = ProviderConfig {
            api_key: Some("test-key".to_string()),
            workspace_path: Some(PathBuf::from("/tmp")),
            ..Default::default()
        };
        let result = ProviderFactory::create("openai", config);
        assert!(result.is_ok());
        let provider = result.unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn factory_creates_anthropic() {
        let config = ProviderConfig {
            api_key: Some("test-key".to_string()),
            workspace_path: Some(PathBuf::from("/tmp")),
            ..Default::default()
        };
        let result = ProviderFactory::create("anthropic", config);
        assert!(result.is_ok());
        let provider = result.unwrap();
        assert_eq!(provider.name(), "anthropic");
    }

    #[test]
    fn factory_creates_gemini() {
        let config = ProviderConfig {
            api_key: Some("test-key".to_string()),
            workspace_path: Some(PathBuf::from("/tmp")),
            ..Default::default()
        };
        let result = ProviderFactory::create("gemini", config);
        assert!(result.is_ok());
        let provider = result.unwrap();
        assert_eq!(provider.name(), "gemini");
    }

    #[test]
    fn factory_creates_ollama() {
        let config = ProviderConfig {
            workspace_path: Some(PathBuf::from("/tmp")),
            ..Default::default()
        };
        let result = ProviderFactory::create("ollama", config);
        assert!(result.is_ok());
        let provider = result.unwrap();
        assert_eq!(provider.name(), "ollama");
    }

    #[test]
    fn factory_creates_openrouter() {
        let config = ProviderConfig {
            api_key: Some("test-key".to_string()),
            workspace_path: Some(PathBuf::from("/tmp")),
            ..Default::default()
        };
        let result = ProviderFactory::create("openrouter", config);
        assert!(result.is_ok());
        let provider = result.unwrap();
        assert_eq!(provider.name(), "openrouter");
    }

    #[test]
    fn factory_creates_claude_code() {
        let config = ProviderConfig {
            workspace_path: Some(PathBuf::from("/tmp")),
            ..Default::default()
        };
        let result = ProviderFactory::create("claude-code", config);
        assert!(result.is_ok());
        let provider = result.unwrap();
        assert_eq!(provider.name(), "claude-code");
    }

    #[test]
    fn model_from_tier_standard() {
        assert_eq!(
            model_from_tier("anthropic", ModelTier::Standard),
            "claude-sonnet-4-6"
        );
        assert_eq!(
            model_from_tier("openai", ModelTier::Standard),
            "gpt-4o-mini"
        );
        assert_eq!(
            model_from_tier("gemini", ModelTier::Standard),
            "gemini-2.0-flash"
        );
        assert_eq!(model_from_tier("ollama", ModelTier::Standard), "llama3.2");
        assert_eq!(
            model_from_tier("openrouter", ModelTier::Standard),
            "anthropic/claude-sonnet-4-6"
        );
    }

    #[test]
    fn model_from_tier_flagship() {
        assert_eq!(
            model_from_tier("anthropic", ModelTier::Flagship),
            "claude-opus-4-6"
        );
        assert_eq!(model_from_tier("openai", ModelTier::Flagship), "gpt-4o");
        assert_eq!(
            model_from_tier("gemini", ModelTier::Flagship),
            "gemini-2.5-pro"
        );
        assert_eq!(
            model_from_tier("ollama", ModelTier::Flagship),
            "llama3.1:70b"
        );
        assert_eq!(
            model_from_tier("openrouter", ModelTier::Flagship),
            "anthropic/claude-opus-4-6"
        );
    }

    #[test]
    fn model_from_tier_unmapped_provider_is_empty() {
        assert_eq!(model_from_tier("claude-code", ModelTier::Standard), "");
        // The lookup is exact: callers must pass the lowercase provider name.
        assert_eq!(model_from_tier("OpenAI", ModelTier::Flagship), "");
    }

    #[test]
    fn resolve_model_explicit_wins_over_tier() {
        let result = resolve_model(
            "anthropic",
            Some("my-custom-model".to_string()),
            Some(ModelTier::Flagship),
        );
        assert_eq!(result, Some("my-custom-model".to_string()));
    }

    #[test]
    fn resolve_model_tier_used_when_no_explicit_model() {
        let result = resolve_model("anthropic", None, Some(ModelTier::Standard));
        assert_eq!(result, Some("claude-sonnet-4-6".to_string()));
    }

    #[test]
    fn resolve_model_returns_none_when_both_absent() {
        let result = resolve_model("anthropic", None, None);
        assert!(result.is_none());
    }

    #[test]
    fn resolve_model_none_when_tier_has_no_mapping() {
        let result = resolve_model("claude-code", None, Some(ModelTier::Flagship));
        assert!(result.is_none());
    }

    #[test]
    fn factory_creates_anthropic_with_tier() {
        let config = ProviderConfig {
            api_key: Some("sk-test".to_string()),
            workspace_path: Some(PathBuf::from("/tmp")),
            tier: Some(ModelTier::Standard),
            ..Default::default()
        };
        let result = ProviderFactory::create("anthropic", config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().name(), "anthropic");
    }
}