use crate::config::ProfileConfig;
use crate::config::ProviderType;
use stakpak_shared::models::integrations::anthropic::AnthropicModel;
use stakpak_shared::models::integrations::gemini::GeminiModel;
use stakpak_shared::models::integrations::openai::OpenAIModel;
use stakpak_shared::models::llm::ProviderConfig;
/// Returns a local profile preset for OpenAI.
///
/// The smart and eco slots use `OpenAIModel`'s default models, and a single `openai`
/// provider entry is registered with neither an API key nor an endpoint set.
pub fn generate_openai_profile() -> ProfileConfig {
    let smart_model = OpenAIModel::default_smart_model();
    let eco_model = OpenAIModel::default_eco_model();

    let mut profile = ProfileConfig::default();
    profile.provider = Some(ProviderType::Local);
    profile.smart_model = Some(smart_model);
    profile.eco_model = Some(eco_model);

    let openai = ProviderConfig::OpenAI {
        api_key: None,
        api_endpoint: None,
    };
    profile.providers.insert(String::from("openai"), openai);
    profile
}
/// Returns a local profile preset for Gemini.
///
/// The smart and eco slots use `GeminiModel`'s default models, and a single `gemini`
/// provider entry is registered with neither an API key nor an endpoint set.
pub fn generate_gemini_profile() -> ProfileConfig {
    let smart_model = GeminiModel::default_smart_model();
    let eco_model = GeminiModel::default_eco_model();

    let mut profile = ProfileConfig::default();
    profile.provider = Some(ProviderType::Local);
    profile.smart_model = Some(smart_model);
    profile.eco_model = Some(eco_model);

    let gemini = ProviderConfig::Gemini {
        api_key: None,
        api_endpoint: None,
    };
    profile.providers.insert(String::from("gemini"), gemini);
    profile
}
/// Returns a local profile preset for Anthropic.
///
/// The smart and eco slots use `AnthropicModel`'s default models, and a single
/// `anthropic` provider entry is registered with no API key, endpoint or access token.
pub fn generate_anthropic_profile() -> ProfileConfig {
    let smart_model = AnthropicModel::default_smart_model();
    let eco_model = AnthropicModel::default_eco_model();

    let mut profile = ProfileConfig::default();
    profile.provider = Some(ProviderType::Local);
    profile.smart_model = Some(smart_model);
    profile.eco_model = Some(eco_model);

    let anthropic = ProviderConfig::Anthropic {
        api_key: None,
        api_endpoint: None,
        access_token: None,
    };
    profile.providers.insert(String::from("anthropic"), anthropic);
    profile
}
/// Returns a local profile that routes both model slots through a user-defined provider.
///
/// Both model names are stored qualified as `"<provider_name>/<model>"`. The provider is
/// registered under `provider_name` as a `ProviderConfig::Custom` entry with the given
/// endpoint and optional API key.
pub fn generate_custom_provider_profile(
    provider_name: String,
    api_endpoint: String,
    api_key: Option<String>,
    smart_model: String,
    eco_model: String,
) -> ProfileConfig {
    // Prefix a bare model id with the provider name; the closure only borrows
    // `provider_name`, which is moved into the provider map afterwards.
    let qualify = |model: &str| format!("{}/{}", provider_name, model);
    let qualified_smart = qualify(&smart_model);
    let qualified_eco = qualify(&eco_model);

    let mut profile = ProfileConfig::default();
    profile.provider = Some(ProviderType::Local);
    profile.smart_model = Some(qualified_smart);
    profile.eco_model = Some(qualified_eco);

    let custom = ProviderConfig::Custom {
        api_key,
        api_endpoint,
    };
    profile.providers.insert(provider_name, custom);
    profile
}
/// One slot (smart or eco) of a hybrid configuration: the provider serving it, the model
/// id, and the API key used for that provider.
///
/// `Debug` is implemented by hand instead of derived so the API key is never written
/// verbatim into logs or panic messages: it is rendered as `"***"` when set and `""` when
/// empty.
#[derive(Clone)]
pub struct HybridModelConfig {
    pub provider: HybridProvider,
    pub model: String,
    pub api_key: String,
}

impl std::fmt::Debug for HybridModelConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is present, never its value.
        let redacted_key = if self.api_key.is_empty() { "" } else { "***" };
        f.debug_struct("HybridModelConfig")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("api_key", &redacted_key)
            .finish()
    }
}
/// Providers that can serve one slot of a hybrid smart/eco configuration.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum HybridProvider {
    OpenAI,
    Gemini,
    Anthropic,
}

impl HybridProvider {
    /// Returns the provider's name in its conventional capitalisation (e.g. `"OpenAI"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::OpenAI => "OpenAI",
            Self::Gemini => "Gemini",
            Self::Anthropic => "Anthropic",
        }
    }
}
/// Builds a local profile whose smart and eco slots may be served by different providers.
///
/// `smart.model` and `eco.model` are stored verbatim as the profile's `smart_model` and
/// `eco_model`. One provider entry (with no custom endpoint) is registered for each distinct
/// provider referenced by either slot, carrying that slot's API key. When both slots use
/// the same provider only one entry is created, and the smart slot's key takes precedence.
pub fn generate_hybrid_config(smart: HybridModelConfig, eco: HybridModelConfig) -> ProfileConfig {
    // Both configs are owned, so the model names can be moved rather than cloned; the
    // remaining fields (`provider` is `Copy`, `api_key` is only borrowed) stay usable.
    let mut profile = ProfileConfig {
        provider: Some(ProviderType::Local),
        smart_model: Some(smart.model),
        eco_model: Some(eco.model),
        ..ProfileConfig::default()
    };

    // Fixed iteration order keeps entries inserted as OpenAI, Gemini, then Anthropic.
    for kind in [
        HybridProvider::OpenAI,
        HybridProvider::Gemini,
        HybridProvider::Anthropic,
    ] {
        // The smart slot wins when both slots share this provider.
        let api_key = if smart.provider == kind {
            Some(smart.api_key.clone())
        } else if eco.provider == kind {
            Some(eco.api_key.clone())
        } else {
            continue;
        };

        let (name, config) = match kind {
            HybridProvider::OpenAI => (
                "openai",
                ProviderConfig::OpenAI {
                    api_key,
                    api_endpoint: None,
                },
            ),
            HybridProvider::Gemini => (
                "gemini",
                ProviderConfig::Gemini {
                    api_key,
                    api_endpoint: None,
                },
            ),
            HybridProvider::Anthropic => (
                "anthropic",
                ProviderConfig::Anthropic {
                    api_key,
                    api_endpoint: None,
                    access_token: None,
                },
            ),
        };
        profile.providers.insert(name.to_string(), config);
    }

    profile
}
/// Renders `profile` as a human-readable TOML snippet under `[profiles.default]`.
///
/// Secrets are never echoed: `api_key` and `access_token` values are shown as `"***"` when
/// non-empty and `""` when empty. String values are emitted as escaped TOML basic strings,
/// and provider names that are not valid TOML bare keys are quoted. As a result,
/// user-supplied names or endpoints containing quotes, backslashes, dots or spaces still
/// produce well-formed TOML. Ordinary values render exactly as before.
pub fn config_to_toml_preview(profile: &ProfileConfig) -> String {
    /// Returns `value` as a double-quoted TOML basic string with the required escapes.
    fn toml_string(value: &str) -> String {
        let mut out = String::with_capacity(value.len() + 2);
        out.push('"');
        for c in value.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                '\u{08}' => out.push_str("\\b"),
                '\u{0C}' => out.push_str("\\f"),
                // Other control characters may not appear raw inside a basic string.
                c if c.is_control() => out.push_str(&format!("\\u{:04X}", c as u32)),
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }

    /// Returns `name` unchanged when it is a valid TOML bare key, otherwise quoted.
    fn toml_key(name: &str) -> String {
        let is_bare = !name.is_empty()
            && name
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
        if is_bare {
            name.to_string()
        } else {
            toml_string(name)
        }
    }

    /// Appends a `key = "value"` line, escaping `value`.
    fn push_field(out: &mut String, key: &str, value: impl std::fmt::Display) {
        out.push_str(key);
        out.push_str(" = ");
        out.push_str(&toml_string(&value.to_string()));
        out.push('\n');
    }

    /// Placeholder printed instead of a secret: `""` for an empty one, `"***"` otherwise.
    fn masked(is_empty: bool) -> &'static str {
        if is_empty {
            ""
        } else {
            "***"
        }
    }

    let mut toml = String::from("[profiles.default]\n");

    if let Some(provider) = &profile.provider {
        let provider = match provider {
            ProviderType::Remote => "remote",
            ProviderType::Local => "local",
        };
        push_field(&mut toml, "provider", provider);
    }
    if let Some(smart_model) = &profile.smart_model {
        push_field(&mut toml, "smart_model", smart_model);
    }
    if let Some(eco_model) = &profile.eco_model {
        push_field(&mut toml, "eco_model", eco_model);
    }

    for (name, config) in &profile.providers {
        // Provider names can be user-supplied (custom providers), so quote when needed
        // to keep them a single table-path segment.
        toml.push_str(&format!(
            "\n[profiles.default.providers.{}]\n",
            toml_key(name)
        ));
        match config {
            ProviderConfig::OpenAI {
                api_key,
                api_endpoint,
            } => {
                push_field(&mut toml, "type", "openai");
                if let Some(endpoint) = api_endpoint {
                    push_field(&mut toml, "api_endpoint", endpoint);
                }
                if let Some(key) = api_key {
                    push_field(&mut toml, "api_key", masked(key.is_empty()));
                }
            }
            ProviderConfig::Anthropic {
                api_key,
                api_endpoint,
                access_token,
            } => {
                push_field(&mut toml, "type", "anthropic");
                if let Some(endpoint) = api_endpoint {
                    push_field(&mut toml, "api_endpoint", endpoint);
                }
                if let Some(key) = api_key {
                    push_field(&mut toml, "api_key", masked(key.is_empty()));
                }
                if let Some(token) = access_token {
                    push_field(&mut toml, "access_token", masked(token.is_empty()));
                }
            }
            ProviderConfig::Gemini {
                api_key,
                api_endpoint,
            } => {
                push_field(&mut toml, "type", "gemini");
                if let Some(endpoint) = api_endpoint {
                    push_field(&mut toml, "api_endpoint", endpoint);
                }
                if let Some(key) = api_key {
                    push_field(&mut toml, "api_key", masked(key.is_empty()));
                }
            }
            ProviderConfig::Custom {
                api_key,
                api_endpoint,
            } => {
                push_field(&mut toml, "type", "custom");
                push_field(&mut toml, "api_endpoint", api_endpoint);
                if let Some(key) = api_key {
                    push_field(&mut toml, "api_key", masked(key.is_empty()));
                }
            }
            ProviderConfig::Stakpak {
                api_key,
                api_endpoint,
            } => {
                push_field(&mut toml, "type", "stakpak");
                push_field(&mut toml, "api_key", masked(api_key.is_empty()));
                if let Some(endpoint) = api_endpoint {
                    push_field(&mut toml, "api_endpoint", endpoint);
                }
            }
            ProviderConfig::Bedrock {
                region,
                profile_name,
            } => {
                push_field(&mut toml, "type", "amazon-bedrock");
                push_field(&mut toml, "region", region);
                if let Some(aws_profile) = profile_name {
                    push_field(&mut toml, "profile_name", aws_profile);
                }
            }
        }
    }

    toml
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for one hybrid slot.
    fn hybrid(provider: HybridProvider, model: &str, api_key: &str) -> HybridModelConfig {
        HybridModelConfig {
            provider,
            model: model.to_string(),
            api_key: api_key.to_string(),
        }
    }

    #[test]
    fn test_generate_custom_provider_profile() {
        let profile = generate_custom_provider_profile(
            "litellm".to_string(),
            "http://localhost:4000".to_string(),
            Some("sk-1234".to_string()),
            "claude-opus".to_string(),
            "claude-haiku".to_string(),
        );
        assert!(matches!(profile.provider, Some(ProviderType::Local)));
        assert_eq!(profile.smart_model, Some("litellm/claude-opus".to_string()));
        assert_eq!(profile.eco_model, Some("litellm/claude-haiku".to_string()));
        let provider = profile
            .providers
            .get("litellm")
            .expect("litellm provider should exist");
        match provider {
            ProviderConfig::Custom {
                api_key,
                api_endpoint,
            } => {
                assert_eq!(api_endpoint, "http://localhost:4000");
                assert_eq!(api_key, &Some("sk-1234".to_string()));
            }
            _ => panic!("Expected Custom provider"),
        }
    }

    #[test]
    fn test_generate_custom_provider_profile_without_api_key() {
        let profile = generate_custom_provider_profile(
            "ollama".to_string(),
            "http://localhost:11434/v1".to_string(),
            None,
            "llama3".to_string(),
            "llama3".to_string(),
        );
        assert!(matches!(profile.provider, Some(ProviderType::Local)));
        assert_eq!(profile.smart_model, Some("ollama/llama3".to_string()));
        assert_eq!(profile.eco_model, Some("ollama/llama3".to_string()));
        let provider = profile
            .providers
            .get("ollama")
            .expect("ollama provider should exist");
        match provider {
            ProviderConfig::Custom {
                api_key,
                api_endpoint,
            } => {
                assert_eq!(api_endpoint, "http://localhost:11434/v1");
                assert!(api_key.is_none());
            }
            _ => panic!("Expected Custom provider"),
        }
    }

    #[test]
    fn test_builtin_profiles_register_provider_without_credentials() {
        let openai = generate_openai_profile();
        assert!(matches!(openai.provider, Some(ProviderType::Local)));
        assert!(matches!(
            openai.providers.get("openai"),
            Some(ProviderConfig::OpenAI {
                api_key: None,
                api_endpoint: None
            })
        ));

        let gemini = generate_gemini_profile();
        assert!(matches!(gemini.provider, Some(ProviderType::Local)));
        assert!(matches!(
            gemini.providers.get("gemini"),
            Some(ProviderConfig::Gemini {
                api_key: None,
                api_endpoint: None
            })
        ));

        let anthropic = generate_anthropic_profile();
        assert!(matches!(anthropic.provider, Some(ProviderType::Local)));
        assert!(matches!(
            anthropic.providers.get("anthropic"),
            Some(ProviderConfig::Anthropic {
                api_key: None,
                api_endpoint: None,
                access_token: None
            })
        ));
    }

    #[test]
    fn test_hybrid_provider_as_str() {
        assert_eq!(HybridProvider::OpenAI.as_str(), "OpenAI");
        assert_eq!(HybridProvider::Gemini.as_str(), "Gemini");
        assert_eq!(HybridProvider::Anthropic.as_str(), "Anthropic");
    }

    #[test]
    fn test_generate_hybrid_config_mixed_providers() {
        let profile = generate_hybrid_config(
            hybrid(HybridProvider::OpenAI, "gpt-smart", "openai-key"),
            hybrid(HybridProvider::Gemini, "gemini-eco", "gemini-key"),
        );
        assert!(matches!(profile.provider, Some(ProviderType::Local)));
        assert_eq!(profile.smart_model, Some("gpt-smart".to_string()));
        assert_eq!(profile.eco_model, Some("gemini-eco".to_string()));
        match profile.providers.get("openai") {
            Some(ProviderConfig::OpenAI {
                api_key,
                api_endpoint,
            }) => {
                assert_eq!(api_key, &Some("openai-key".to_string()));
                assert!(api_endpoint.is_none());
            }
            _ => panic!("Expected OpenAI provider"),
        }
        match profile.providers.get("gemini") {
            Some(ProviderConfig::Gemini {
                api_key,
                api_endpoint,
            }) => {
                assert_eq!(api_key, &Some("gemini-key".to_string()));
                assert!(api_endpoint.is_none());
            }
            _ => panic!("Expected Gemini provider"),
        }
        assert!(profile.providers.get("anthropic").is_none());
    }

    #[test]
    fn test_generate_hybrid_config_same_provider_prefers_smart_key() {
        let profile = generate_hybrid_config(
            hybrid(HybridProvider::Anthropic, "claude-opus", "smart-key"),
            hybrid(HybridProvider::Anthropic, "claude-haiku", "eco-key"),
        );
        assert_eq!(profile.smart_model, Some("claude-opus".to_string()));
        assert_eq!(profile.eco_model, Some("claude-haiku".to_string()));
        match profile.providers.get("anthropic") {
            Some(ProviderConfig::Anthropic {
                api_key,
                api_endpoint,
                access_token,
            }) => {
                assert_eq!(api_key, &Some("smart-key".to_string()));
                assert!(api_endpoint.is_none());
                assert!(access_token.is_none());
            }
            _ => panic!("Expected Anthropic provider"),
        }
        assert!(profile.providers.get("openai").is_none());
        assert!(profile.providers.get("gemini").is_none());
    }

    #[test]
    fn test_config_to_toml_preview_masks_hybrid_keys() {
        let profile = generate_hybrid_config(
            hybrid(HybridProvider::OpenAI, "gpt-smart", "sk-secret-openai"),
            hybrid(HybridProvider::Anthropic, "claude-eco", "sk-secret-anthropic"),
        );
        let toml = config_to_toml_preview(&profile);
        assert!(toml.contains("[profiles.default.providers.openai]"));
        assert!(toml.contains("[profiles.default.providers.anthropic]"));
        assert!(toml.contains("api_key = \"***\""));
        assert!(!toml.contains("sk-secret-openai"));
        assert!(!toml.contains("sk-secret-anthropic"));
    }

    #[test]
    fn test_config_to_toml_preview_with_custom_provider() {
        let mut profile = ProfileConfig {
            provider: Some(ProviderType::Local),
            smart_model: Some("litellm/claude-opus".to_string()),
            eco_model: Some("litellm/claude-haiku".to_string()),
            ..ProfileConfig::default()
        };
        profile.providers.insert(
            "litellm".to_string(),
            ProviderConfig::Custom {
                api_endpoint: "http://localhost:4000".to_string(),
                api_key: Some("sk-1234".to_string()),
            },
        );
        let toml = config_to_toml_preview(&profile);
        assert!(toml.contains("provider = \"local\""));
        assert!(toml.contains("smart_model = \"litellm/claude-opus\""));
        assert!(toml.contains("eco_model = \"litellm/claude-haiku\""));
        assert!(toml.contains("[profiles.default.providers.litellm]"));
        assert!(toml.contains("type = \"custom\""));
        assert!(toml.contains("api_endpoint = \"http://localhost:4000\""));
        assert!(toml.contains("api_key = \"***\""));
        assert!(!toml.contains("sk-1234"));
    }

    #[test]
    fn test_config_to_toml_preview_custom_provider_no_api_key() {
        let mut profile = ProfileConfig {
            provider: Some(ProviderType::Local),
            smart_model: Some("ollama/llama3".to_string()),
            eco_model: Some("ollama/llama3".to_string()),
            ..ProfileConfig::default()
        };
        profile.providers.insert(
            "ollama".to_string(),
            ProviderConfig::Custom {
                api_endpoint: "http://localhost:11434/v1".to_string(),
                api_key: None,
            },
        );
        let toml = config_to_toml_preview(&profile);
        assert!(toml.contains("[profiles.default.providers.ollama]"));
        assert!(toml.contains("type = \"custom\""));
        assert!(toml.contains("api_endpoint = \"http://localhost:11434/v1\""));
        let lines: Vec<&str> = toml.lines().collect();
        let has_api_key_in_ollama_section = lines
            .iter()
            .skip_while(|l| !l.contains("providers.ollama"))
            .take_while(|l| !l.starts_with('[') || l.contains("providers.ollama"))
            .any(|l| l.contains("api_key"));
        assert!(!has_api_key_in_ollama_section);
    }
}