use std::sync::Arc;
use bamboo_domain::reasoning::ReasoningEffort;
use bamboo_infrastructure::Config;
use bamboo_infrastructure::{LLMError, ProviderModelRouter, ProviderRegistry, ResolvedModel};
/// Returns the default chat model configured for the named provider.
///
/// `copilot` is special-cased: a missing model falls back to `"gpt-4o"`.
/// Every other known provider must have both a config section and an
/// explicit `model`, otherwise an [`LLMError::Auth`] is returned. Unknown
/// provider names are rejected with [`LLMError::Auth`] as well.
pub fn get_default_model_for_provider(
    config: &Config,
    provider_name: &str,
) -> Result<String, LLMError> {
    // Shared extraction for providers that require an explicit model:
    // outer `None` means the provider section is absent, inner `None`
    // means the section exists but no model was set.
    fn required_model(
        section_model: Option<Option<String>>,
        label: &str,
    ) -> Result<String, LLMError> {
        let model = section_model
            .ok_or_else(|| LLMError::Auth(format!("{} configuration required", label)))?;
        model.ok_or_else(|| {
            LLMError::Auth(format!("{} model must be specified in config", label))
        })
    }

    match provider_name.trim() {
        "copilot" => {
            // Copilot is the only provider with a built-in model fallback.
            let configured = config
                .providers
                .copilot
                .as_ref()
                .and_then(|c| c.model.clone());
            Ok(configured.unwrap_or_else(|| "gpt-4o".to_string()))
        }
        "openai" => required_model(
            config.providers.openai.as_ref().map(|c| c.model.clone()),
            "OpenAI",
        ),
        "anthropic" => required_model(
            config.providers.anthropic.as_ref().map(|c| c.model.clone()),
            "Anthropic",
        ),
        "gemini" => required_model(
            config.providers.gemini.as_ref().map(|c| c.model.clone()),
            "Gemini",
        ),
        other => Err(LLMError::Auth(format!("Unknown provider: {}", other))),
    }
}
/// Default chat model for the provider currently selected in `config`.
pub fn get_default_model_from_config(config: &Config) -> Result<String, LLMError> {
    let active_provider = config.provider.as_str();
    get_default_model_for_provider(config, active_provider)
}
/// Looks up the provider's configured fast (low-latency) model.
///
/// Falls back to the provider's default chat model when no `fast_model`
/// is set; returns `None` only when neither can be determined.
pub fn get_fast_model_for_provider(config: &Config, provider_name: &str) -> Option<String> {
    let providers = &config.providers;
    let explicit = match provider_name.trim() {
        "openai" => providers.openai.as_ref().and_then(|c| c.fast_model.clone()),
        "anthropic" => providers.anthropic.as_ref().and_then(|c| c.fast_model.clone()),
        "gemini" => providers.gemini.as_ref().and_then(|c| c.fast_model.clone()),
        "copilot" => providers.copilot.as_ref().and_then(|c| c.fast_model.clone()),
        _ => None,
    };
    if explicit.is_some() {
        explicit
    } else {
        // No dedicated fast model: reuse the provider's default model.
        get_default_model_for_provider(config, provider_name).ok()
    }
}
/// Model used for background memory processing.
///
/// Prefers an explicitly configured `memory.background_model` (whitespace
/// is trimmed; blank values are treated as unset), falling back to the
/// provider's fast model.
pub fn get_memory_background_model_for_provider(
    config: &Config,
    provider_name: &str,
) -> Option<String> {
    if let Some(memory) = config.memory.as_ref() {
        if let Some(raw) = memory.background_model.as_ref() {
            let trimmed = raw.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_string());
            }
        }
    }
    // Not configured (or blank): fall back to the fast model.
    get_fast_model_for_provider(config, provider_name)
}
/// Reasoning-effort setting from the named provider's config section,
/// if both the section and the setting are present.
pub fn get_reasoning_effort_for_provider(
    config: &Config,
    provider_name: &str,
) -> Option<ReasoningEffort> {
    let providers = &config.providers;
    // `?` short-circuits to `None` when the provider section is absent.
    match provider_name.trim() {
        "openai" => providers.openai.as_ref()?.reasoning_effort,
        "anthropic" => providers.anthropic.as_ref()?.reasoning_effort,
        "gemini" => providers.gemini.as_ref()?.reasoning_effort,
        "copilot" => providers.copilot.as_ref()?.reasoning_effort,
        _ => None,
    }
}
/// Like `Config::get_memory_background_model`, but converts the missing
/// case into an [`LLMError::Auth`] naming the active provider.
pub fn get_memory_background_model_from_config(config: &Config) -> Result<String, LLMError> {
    match config.get_memory_background_model() {
        Some(model) => Ok(model),
        None => Err(LLMError::Auth(format!(
            "No background memory model configured for provider '{}'",
            config.provider
        ))),
    }
}
/// Like `Config::get_vision_model`, but converts the missing case into
/// an [`LLMError::Auth`] naming the active provider.
pub fn get_vision_model_from_config(config: &Config) -> Result<String, LLMError> {
    match config.get_vision_model() {
        Some(model) => Ok(model),
        None => Err(LLMError::Auth(format!(
            "No model configured for provider '{}'",
            config.provider
        ))),
    }
}
/// Resolves the background-memory model to a concrete provider + model name.
///
/// With the `provider_model_ref` feature enabled, tries to route
/// `defaults.memory_background` (then `defaults.fast`) through the router.
/// Otherwise — or if routing fails — falls back to the provider-section
/// configuration and a direct registry lookup.
pub fn resolve_background_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let defaults = config.defaults.as_ref();
        let candidate = defaults
            .and_then(|d| d.memory_background.as_ref())
            .or_else(|| defaults.and_then(|d| d.fast.as_ref()));
        if let Some(model_ref) = candidate {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    // Legacy path: provider-section configuration + registry lookup.
    let model_name = get_memory_background_model_for_provider(config, provider_name)?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the fast model to a concrete provider + model name.
///
/// With the `provider_model_ref` feature enabled, tries to route
/// `defaults.fast`; otherwise — or if routing fails — falls back to the
/// provider-section fast model and a direct registry lookup.
pub fn resolve_fast_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let fast_ref = config.defaults.as_ref().and_then(|d| d.fast.as_ref());
        if let Some(model_ref) = fast_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    // Legacy path: provider-section configuration + registry lookup.
    let model_name = get_fast_model_for_provider(config, provider_name)?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the vision model to a concrete provider + model name.
///
/// With the `provider_model_ref` feature enabled, tries to route
/// `defaults.vision`; otherwise — or if routing fails — falls back to
/// `Config::get_vision_model` and a direct registry lookup.
pub fn resolve_vision_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let vision_ref = config.defaults.as_ref().and_then(|d| d.vision.as_ref());
        if let Some(model_ref) = vision_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    // Legacy path: config-level vision model + registry lookup.
    let model_name = config.get_vision_model()?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the planning model to a concrete provider + model name.
///
/// With the `provider_model_ref` feature enabled, tries to route
/// `defaults.planning`; otherwise — or if routing fails — falls back to
/// the default chat model.
pub fn resolve_planning_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let planning_ref = config.defaults.as_ref().and_then(|d| d.planning.as_ref());
        if let Some(model_ref) = planning_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    // No dedicated planning model: use the default chat model.
    resolve_default_chat_model(config, provider_name, provider_registry)
}
/// Resolves the search model to a concrete provider + model name.
///
/// With the `provider_model_ref` feature enabled, tries to route
/// `defaults.search`; otherwise — or if routing fails — falls back to the
/// fast model, then to the default chat model.
pub fn resolve_search_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let search_ref = config.defaults.as_ref().and_then(|d| d.search.as_ref());
        if let Some(model_ref) = search_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    // Fallback chain: fast model first, then default chat model.
    if let Some(resolved) = resolve_fast_model(config, provider_name, provider_registry) {
        return Some(resolved);
    }
    resolve_default_chat_model(config, provider_name, provider_registry)
}
/// Resolves the code-review model to a concrete provider + model name.
///
/// With the `provider_model_ref` feature enabled, tries to route
/// `defaults.code_review`; otherwise — or if routing fails — falls back
/// to the default chat model.
pub fn resolve_code_review_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        let review_ref = config
            .defaults
            .as_ref()
            .and_then(|d| d.code_review.as_ref());
        if let Some(model_ref) = review_ref {
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    // No dedicated code-review model: use the default chat model.
    resolve_default_chat_model(config, provider_name, provider_registry)
}
/// Picks the provider/model reference a subagent of `subagent_type` should use.
///
/// With the `provider_model_ref` feature enabled, candidates are tried in
/// order: the per-subagent override, then `defaults.sub_session`, then
/// `defaults.fast`; the first that routes successfully wins. Otherwise —
/// or when none routes — falls back to the fast (then default chat) model
/// under the caller-supplied provider name.
pub fn resolve_subagent_model_ref(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
    subagent_type: &str,
) -> Option<bamboo_domain::ProviderModelRef> {
    if config.features.provider_model_ref {
        if let Some(defaults) = config.defaults.as_ref() {
            let router = ProviderModelRouter::new(provider_registry.clone());
            // Preference order: subagent-specific, sub-session, fast.
            let candidates = defaults
                .subagent_models
                .get(subagent_type)
                .into_iter()
                .chain(defaults.sub_session.as_ref())
                .chain(defaults.fast.as_ref());
            for candidate in candidates {
                if router.route(candidate).is_ok() {
                    return Some(candidate.clone());
                }
            }
        }
    }
    // Legacy fallback: derive a ref from the resolved fast/chat model.
    let resolved = resolve_fast_model(config, provider_name, provider_registry)
        .or_else(|| resolve_default_chat_model(config, provider_name, provider_registry))?;
    Some(bamboo_domain::ProviderModelRef::new(
        provider_name,
        resolved.model_name,
    ))
}
/// Resolves a subagent's model reference down to a concrete provider handle.
///
/// The reference chosen by [`resolve_subagent_model_ref`] is routed first;
/// if routing fails, a direct registry lookup by the reference's provider
/// name is attempted before giving up.
pub fn resolve_subagent_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
    subagent_type: &str,
) -> Option<ResolvedModel> {
    let model_ref =
        resolve_subagent_model_ref(config, provider_name, provider_registry, subagent_type)?;
    let routed = ProviderModelRouter::new(provider_registry.clone()).route(&model_ref);
    let provider = match routed {
        Ok(provider) => provider,
        // Routing failed: try the registry directly by provider name.
        Err(_) => provider_registry.get(&model_ref.provider)?,
    };
    Some(ResolvedModel {
        provider,
        model_name: model_ref.model,
    })
}
/// Resolves the default chat model to a concrete provider + model name.
///
/// With the `provider_model_ref` feature enabled, tries to route
/// `defaults.chat`; otherwise — or if routing fails — falls back to the
/// provider-section default model and a direct registry lookup.
fn resolve_default_chat_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    if config.features.provider_model_ref {
        if let Some(defaults) = config.defaults.as_ref() {
            // `chat` is mandatory inside `defaults`, so no inner Option here.
            let model_ref = &defaults.chat;
            let router = ProviderModelRouter::new(provider_registry.clone());
            if let Ok(provider) = router.route(model_ref) {
                return Some(ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                });
            }
        }
    }
    // Legacy path: provider-section configuration + registry lookup.
    let model_name = get_default_model_for_provider(config, provider_name).ok()?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::ToolSchema;
    use bamboo_agent_core::Message;
    use bamboo_domain::ProviderModelRef;
    use bamboo_infrastructure::{
        CopilotConfig, DefaultsConfig, LLMProvider, LLMStream, OpenAIConfig, ProviderConfigs,
    };
    use std::collections::HashMap;

    // Minimal LLMProvider stub: chat always fails. These tests only exercise
    // model *resolution*, never actual chat streaming.
    struct NoopProvider;

    #[async_trait::async_trait]
    impl LLMProvider for NoopProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream, LLMError> {
            Err(LLMError::Api("noop".to_string()))
        }
    }

    // Registry containing a single "openai" provider, used by the
    // provider_model_ref routing tests below.
    fn test_registry() -> Arc<ProviderRegistry> {
        let mut providers: HashMap<String, Arc<dyn LLMProvider>> = HashMap::new();
        providers.insert("openai".to_string(), Arc::new(NoopProvider));
        Arc::new(ProviderRegistry::new(providers, "openai".to_string()))
    }

    // The active provider's configured `model` is returned as the default.
    #[test]
    fn test_get_model_from_openai_config() {
        let config = Config {
            provider: "openai".to_string(),
            providers: ProviderConfigs {
                openai: Some(OpenAIConfig {
                    api_key: "test".to_string(),
                    api_key_encrypted: None,
                    base_url: None,
                    model: Some("gpt-4o".to_string()),
                    fast_model: None,
                    vision_model: None,
                    reasoning_effort: None,
                    responses_only_models: vec![],
                    request_overrides: None,
                    extra: Default::default(),
                }),
                ..ProviderConfigs::default()
            },
            ..Config::default()
        };
        let result = get_default_model_from_config(&config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "gpt-4o");
    }

    // An OpenAI section without a `model` produces an Auth error (OpenAI has
    // no built-in fallback, unlike Copilot).
    #[test]
    fn test_error_when_model_not_configured() {
        let config = Config {
            provider: "openai".to_string(),
            providers: ProviderConfigs {
                openai: Some(OpenAIConfig {
                    api_key: "test".to_string(),
                    api_key_encrypted: None,
                    base_url: None,
                    model: None,
                    fast_model: None,
                    vision_model: None,
                    reasoning_effort: None,
                    responses_only_models: vec![],
                    request_overrides: None,
                    extra: Default::default(),
                }),
                ..ProviderConfigs::default()
            },
            ..Config::default()
        };
        let result = get_default_model_from_config(&config);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("model must be specified"));
    }

    // A configured Copilot `model` takes precedence over the built-in default.
    #[test]
    fn test_get_model_from_copilot_provider_config() {
        let config = Config {
            provider: "copilot".to_string(),
            providers: ProviderConfigs {
                copilot: Some(CopilotConfig {
                    enabled: true,
                    headless_auth: false,
                    model: Some("gpt-4o-mini".to_string()),
                    fast_model: None,
                    vision_model: None,
                    reasoning_effort: None,
                    responses_only_models: vec![],
                    request_overrides: None,
                    extra: Default::default(),
                }),
                ..ProviderConfigs::default()
            },
            ..Config::default()
        };
        let result = get_default_model_from_config(&config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "gpt-4o-mini");
    }

    // With no Copilot section at all, the hard-coded "gpt-4o" fallback applies.
    #[test]
    fn test_get_model_from_copilot_default_fallback() {
        let config = Config {
            provider: "copilot".to_string(),
            providers: ProviderConfigs::default(),
            ..Config::default()
        };
        let result = get_default_model_from_config(&config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "gpt-4o");
    }

    // The per-provider lookup works even when `config.provider` names a
    // different provider ("anthropic" here).
    #[test]
    fn test_get_default_model_for_specific_provider() {
        let config = Config {
            provider: "anthropic".to_string(),
            providers: ProviderConfigs {
                openai: Some(OpenAIConfig {
                    api_key: "test".to_string(),
                    api_key_encrypted: None,
                    base_url: None,
                    model: Some("gpt-4o".to_string()),
                    fast_model: Some("gpt-4o-mini".to_string()),
                    vision_model: None,
                    reasoning_effort: Some(ReasoningEffort::Medium),
                    responses_only_models: vec![],
                    request_overrides: None,
                    extra: Default::default(),
                }),
                ..ProviderConfigs::default()
            },
            ..Config::default()
        };
        let result = get_default_model_for_provider(&config, "openai").expect("openai config");
        assert_eq!(result, "gpt-4o");
    }

    // The fast-model lookup likewise honors the explicitly named provider.
    #[test]
    fn test_get_fast_model_for_specific_provider() {
        let config = Config {
            provider: "anthropic".to_string(),
            providers: ProviderConfigs {
                openai: Some(OpenAIConfig {
                    api_key: "test".to_string(),
                    api_key_encrypted: None,
                    base_url: None,
                    model: Some("gpt-4o".to_string()),
                    fast_model: Some("gpt-4o-mini".to_string()),
                    vision_model: None,
                    reasoning_effort: Some(ReasoningEffort::Medium),
                    responses_only_models: vec![],
                    request_overrides: None,
                    extra: Default::default(),
                }),
                ..ProviderConfigs::default()
            },
            ..Config::default()
        };
        assert_eq!(
            get_fast_model_for_provider(&config, "openai").as_deref(),
            Some("gpt-4o-mini")
        );
    }

    // Reasoning effort comes from the named provider's section, not the
    // active provider.
    #[test]
    fn test_get_reasoning_effort_for_specific_provider() {
        let config = Config {
            provider: "anthropic".to_string(),
            providers: ProviderConfigs {
                openai: Some(OpenAIConfig {
                    api_key: "test".to_string(),
                    api_key_encrypted: None,
                    base_url: None,
                    model: Some("gpt-4o".to_string()),
                    fast_model: Some("gpt-4o-mini".to_string()),
                    vision_model: None,
                    reasoning_effort: Some(ReasoningEffort::Medium),
                    responses_only_models: vec![],
                    request_overrides: None,
                    extra: Default::default(),
                }),
                ..ProviderConfigs::default()
            },
            ..Config::default()
        };
        assert_eq!(
            get_reasoning_effort_for_provider(&config, "openai"),
            Some(ReasoningEffort::Medium)
        );
    }

    // With no per-subagent override, `sub_session` outranks `fast`.
    #[test]
    fn resolve_subagent_model_ref_prefers_sub_session_over_fast() {
        let mut config = Config::default();
        config.provider = "openai".to_string();
        config.features.provider_model_ref = true;
        config.defaults = Some(DefaultsConfig {
            chat: ProviderModelRef::new("openai", "gpt-chat"),
            fast: Some(ProviderModelRef::new("openai", "gpt-fast")),
            vision: None,
            memory_background: None,
            planning: None,
            search: None,
            code_review: None,
            sub_session: Some(ProviderModelRef::new("openai", "gpt-sub-session")),
            subagent_models: HashMap::new(),
        });
        let resolved = resolve_subagent_model_ref(&config, "openai", &test_registry(), "coder")
            .expect("sub-session model should resolve");
        assert_eq!(resolved, ProviderModelRef::new("openai", "gpt-sub-session"));
    }

    // With neither override nor sub_session set, `fast` is the next candidate.
    #[test]
    fn resolve_subagent_model_ref_falls_back_to_fast_when_sub_session_unset() {
        let mut config = Config::default();
        config.provider = "openai".to_string();
        config.features.provider_model_ref = true;
        config.defaults = Some(DefaultsConfig {
            chat: ProviderModelRef::new("openai", "gpt-chat"),
            fast: Some(ProviderModelRef::new("openai", "gpt-fast")),
            vision: None,
            memory_background: None,
            planning: None,
            search: None,
            code_review: None,
            sub_session: None,
            subagent_models: HashMap::new(),
        });
        let resolved = resolve_subagent_model_ref(&config, "openai", &test_registry(), "coder")
            .expect("fast model should resolve");
        assert_eq!(resolved, ProviderModelRef::new("openai", "gpt-fast"));
    }
}