use std::sync::Arc;
use bamboo_domain::reasoning::ReasoningEffort;
use bamboo_infrastructure::Config;
use bamboo_infrastructure::{LLMError, ProviderModelRouter, ProviderRegistry, ResolvedModel};
/// Returns the default chat model for `provider_name`, reading the matching
/// provider section of `config`.
///
/// Copilot is special-cased with a built-in fallback ("gpt-4o") when no model
/// is configured. Every other known provider requires both a config section
/// and an explicit model, otherwise an `LLMError::Auth` is returned. Unknown
/// provider names are rejected with `LLMError::Auth`.
pub fn get_default_model_for_provider(
    config: &Config,
    provider_name: &str,
) -> Result<String, LLMError> {
    match provider_name.trim() {
        "copilot" => {
            // Copilot is the only provider with a hard-coded default model.
            let configured = config
                .providers
                .copilot
                .as_ref()
                .and_then(|c| c.model.clone());
            Ok(configured.unwrap_or_else(|| "gpt-4o".to_string()))
        }
        "openai" => {
            let section = config.providers.openai.as_ref();
            let section = section
                .ok_or_else(|| LLMError::Auth("OpenAI configuration required".to_string()))?;
            match section.model.clone() {
                Some(model) => Ok(model),
                None => Err(LLMError::Auth(
                    "OpenAI model must be specified in config".to_string(),
                )),
            }
        }
        "anthropic" => {
            let section = config.providers.anthropic.as_ref();
            let section = section
                .ok_or_else(|| LLMError::Auth("Anthropic configuration required".to_string()))?;
            match section.model.clone() {
                Some(model) => Ok(model),
                None => Err(LLMError::Auth(
                    "Anthropic model must be specified in config".to_string(),
                )),
            }
        }
        "gemini" => {
            let section = config.providers.gemini.as_ref();
            let section = section
                .ok_or_else(|| LLMError::Auth("Gemini configuration required".to_string()))?;
            match section.model.clone() {
                Some(model) => Ok(model),
                None => Err(LLMError::Auth(
                    "Gemini model must be specified in config".to_string(),
                )),
            }
        }
        other => Err(LLMError::Auth(format!("Unknown provider: {}", other))),
    }
}
/// Convenience wrapper: resolves the default model for the provider currently
/// selected in `config.provider`.
pub fn get_default_model_from_config(config: &Config) -> Result<String, LLMError> {
    let active_provider = config.provider.as_str();
    get_default_model_for_provider(config, active_provider)
}
/// Returns the configured "fast" model for `provider_name`, falling back to
/// the provider's default chat model when no fast model is set.
pub fn get_fast_model_for_provider(config: &Config, provider_name: &str) -> Option<String> {
    let providers = &config.providers;
    let configured = match provider_name.trim() {
        "openai" => providers.openai.as_ref().and_then(|c| c.fast_model.clone()),
        "anthropic" => providers
            .anthropic
            .as_ref()
            .and_then(|c| c.fast_model.clone()),
        "gemini" => providers.gemini.as_ref().and_then(|c| c.fast_model.clone()),
        "copilot" => providers
            .copilot
            .as_ref()
            .and_then(|c| c.fast_model.clone()),
        _ => None,
    };
    match configured {
        Some(model) => Some(model),
        // No dedicated fast model: fall back to the provider's default, if any.
        None => get_default_model_for_provider(config, provider_name).ok(),
    }
}
/// Returns the model used for background memory work.
///
/// A non-blank `memory.background_model` setting takes priority; otherwise the
/// provider's fast model (or its default, via the fast-model fallback) is used.
pub fn get_memory_background_model_for_provider(
    config: &Config,
    provider_name: &str,
) -> Option<String> {
    if let Some(memory) = config.memory.as_ref() {
        if let Some(raw) = memory.background_model.as_ref() {
            let trimmed = raw.trim();
            // A blank value is treated the same as "not configured".
            if !trimmed.is_empty() {
                return Some(trimmed.to_string());
            }
        }
    }
    get_fast_model_for_provider(config, provider_name)
}
/// Returns the reasoning-effort setting for `provider_name`, if both the
/// provider section and the setting are present. Unknown providers yield None.
pub fn get_reasoning_effort_for_provider(
    config: &Config,
    provider_name: &str,
) -> Option<ReasoningEffort> {
    let providers = &config.providers;
    match provider_name.trim() {
        "openai" => providers.openai.as_ref()?.reasoning_effort,
        "anthropic" => providers.anthropic.as_ref()?.reasoning_effort,
        "gemini" => providers.gemini.as_ref()?.reasoning_effort,
        "copilot" => providers.copilot.as_ref()?.reasoning_effort,
        _ => None,
    }
}
/// Resolves the background memory model via the config's own lookup, turning
/// an absent model into an `LLMError::Auth` naming the active provider.
pub fn get_memory_background_model_from_config(config: &Config) -> Result<String, LLMError> {
    match config.get_memory_background_model() {
        Some(model) => Ok(model),
        None => Err(LLMError::Auth(format!(
            "No background memory model configured for provider '{}'",
            config.provider
        ))),
    }
}
/// Resolves the vision model via the config's own lookup, turning an absent
/// model into an `LLMError::Auth` naming the active provider.
pub fn get_vision_model_from_config(config: &Config) -> Result<String, LLMError> {
    match config.get_vision_model() {
        Some(model) => Ok(model),
        None => Err(LLMError::Auth(format!(
            "No model configured for provider '{}'",
            config.provider
        ))),
    }
}
/// Resolves the provider + model used for background memory work.
///
/// With the `provider_model_ref` feature on, a `defaults.memory_background`
/// reference (or, failing that, `defaults.fast`) is routed through the
/// registry first; on any routing failure or when the feature is off, the
/// legacy per-provider lookup applies.
pub fn resolve_background_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let defaults = config.defaults.as_ref();
        let model_ref = defaults
            .and_then(|d| d.memory_background.as_ref())
            .or_else(|| defaults.and_then(|d| d.fast.as_ref()));
        if let Some(model_ref) = model_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
            // Routing failed: fall through to the legacy lookup below.
        }
    }
    let model_name = get_memory_background_model_for_provider(config, provider_name)?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the provider + model for fast/cheap requests.
///
/// Prefers a routed `defaults.fast` model reference when the
/// `provider_model_ref` feature is enabled; otherwise (or on routing failure)
/// uses the per-provider fast-model lookup against the registry.
pub fn resolve_fast_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let fast_ref = config.defaults.as_ref().and_then(|d| d.fast.as_ref());
        if let Some(model_ref) = fast_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    let model_name = get_fast_model_for_provider(config, provider_name)?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the provider + model for vision (image-capable) requests.
///
/// Prefers a routed `defaults.vision` model reference when the
/// `provider_model_ref` feature is enabled; otherwise (or on routing failure)
/// uses the config's vision-model lookup against the registry.
pub fn resolve_vision_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let vision_ref = config.defaults.as_ref().and_then(|d| d.vision.as_ref());
        if let Some(model_ref) = vision_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    let model_name = config.get_vision_model()?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the provider + model for planning tasks.
///
/// Prefers a routed `defaults.planning` model reference when the
/// `provider_model_ref` feature is enabled; otherwise (or on routing failure)
/// falls back to the default chat model resolution.
pub fn resolve_planning_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let planning_ref = config.defaults.as_ref().and_then(|d| d.planning.as_ref());
        if let Some(model_ref) = planning_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    resolve_default_chat_model(config, provider_name, provider_registry)
}
/// Resolves the provider + model for search tasks.
///
/// Prefers a routed `defaults.search` model reference when the
/// `provider_model_ref` feature is enabled; otherwise tries the fast model
/// first and falls back to the default chat model.
pub fn resolve_search_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let search_ref = config.defaults.as_ref().and_then(|d| d.search.as_ref());
        if let Some(model_ref) = search_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    if let Some(resolved) = resolve_fast_model(config, provider_name, provider_registry) {
        return Some(resolved);
    }
    resolve_default_chat_model(config, provider_name, provider_registry)
}
/// Resolves the provider + model for code-review tasks.
///
/// Prefers a routed `defaults.code_review` model reference when the
/// `provider_model_ref` feature is enabled; otherwise (or on routing failure)
/// falls back to the default chat model resolution.
pub fn resolve_code_review_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let review_ref = config
            .defaults
            .as_ref()
            .and_then(|d| d.code_review.as_ref());
        if let Some(model_ref) = review_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    resolve_default_chat_model(config, provider_name, provider_registry)
}
/// Resolves the provider + model for a subagent of the given `subagent_type`.
///
/// Prefers a routed entry from `defaults.subagent_models` keyed by
/// `subagent_type` when the `provider_model_ref` feature is enabled; otherwise
/// (or on routing failure) falls back to the default chat model resolution.
pub fn resolve_subagent_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
    subagent_type: &str,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let subagent_ref = config
            .defaults
            .as_ref()
            .and_then(|d| d.subagent_models.get(subagent_type));
        if let Some(model_ref) = subagent_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    resolve_default_chat_model(config, provider_name, provider_registry)
}
/// Shared fallback: resolves the provider + model for ordinary chat.
///
/// Prefers a routed `defaults.chat` model reference (always present when
/// `defaults` is set) under the `provider_model_ref` feature; otherwise uses
/// the per-provider default-model lookup against the registry.
fn resolve_default_chat_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let chat_ref = config.defaults.as_ref().map(|d| &d.chat);
        if let Some(model_ref) = chat_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    let model_name = get_default_model_for_provider(config, provider_name).ok()?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_infrastructure::{CopilotConfig, OpenAIConfig, ProviderConfigs};

    /// Builds a `ProviderConfigs` with only the OpenAI section populated.
    fn openai_providers(
        model: Option<&str>,
        fast_model: Option<&str>,
        reasoning_effort: Option<ReasoningEffort>,
    ) -> ProviderConfigs {
        ProviderConfigs {
            openai: Some(OpenAIConfig {
                api_key: "test".to_string(),
                api_key_encrypted: None,
                base_url: None,
                model: model.map(str::to_string),
                fast_model: fast_model.map(str::to_string),
                vision_model: None,
                reasoning_effort,
                responses_only_models: vec![],
                request_overrides: None,
                extra: Default::default(),
            }),
            ..ProviderConfigs::default()
        }
    }

    /// Builds a `Config` pointed at `provider` with the given provider sections.
    fn config_with(provider: &str, providers: ProviderConfigs) -> Config {
        Config {
            provider: provider.to_string(),
            providers,
            ..Config::default()
        }
    }

    #[test]
    fn test_get_model_from_openai_config() {
        let config = config_with("openai", openai_providers(Some("gpt-4o"), None, None));
        let result = get_default_model_from_config(&config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "gpt-4o");
    }

    #[test]
    fn test_error_when_model_not_configured() {
        let config = config_with("openai", openai_providers(None, None, None));
        let result = get_default_model_from_config(&config);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("model must be specified"));
    }

    #[test]
    fn test_get_model_from_copilot_provider_config() {
        let providers = ProviderConfigs {
            copilot: Some(CopilotConfig {
                enabled: true,
                headless_auth: false,
                model: Some("gpt-4o-mini".to_string()),
                fast_model: None,
                vision_model: None,
                reasoning_effort: None,
                responses_only_models: vec![],
                request_overrides: None,
                extra: Default::default(),
            }),
            ..ProviderConfigs::default()
        };
        let config = config_with("copilot", providers);
        let result = get_default_model_from_config(&config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "gpt-4o-mini");
    }

    #[test]
    fn test_get_model_from_copilot_default_fallback() {
        // No copilot section at all: the hard-coded default applies.
        let config = config_with("copilot", ProviderConfigs::default());
        let result = get_default_model_from_config(&config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "gpt-4o");
    }

    #[test]
    fn test_get_default_model_for_specific_provider() {
        // Active provider is anthropic, but we query openai explicitly.
        let config = config_with(
            "anthropic",
            openai_providers(
                Some("gpt-4o"),
                Some("gpt-4o-mini"),
                Some(ReasoningEffort::Medium),
            ),
        );
        let result = get_default_model_for_provider(&config, "openai").expect("openai config");
        assert_eq!(result, "gpt-4o");
    }

    #[test]
    fn test_get_fast_model_for_specific_provider() {
        let config = config_with(
            "anthropic",
            openai_providers(
                Some("gpt-4o"),
                Some("gpt-4o-mini"),
                Some(ReasoningEffort::Medium),
            ),
        );
        assert_eq!(
            get_fast_model_for_provider(&config, "openai").as_deref(),
            Some("gpt-4o-mini")
        );
    }

    #[test]
    fn test_get_reasoning_effort_for_specific_provider() {
        let config = config_with(
            "anthropic",
            openai_providers(
                Some("gpt-4o"),
                Some("gpt-4o-mini"),
                Some(ReasoningEffort::Medium),
            ),
        );
        assert_eq!(
            get_reasoning_effort_for_provider(&config, "openai"),
            Some(ReasoningEffort::Medium)
        );
    }
}