use std::sync::Arc;
use bamboo_domain::reasoning::ReasoningEffort;
use crate::config::GoldConfig;
use bamboo_infrastructure::Config;
use bamboo_infrastructure::{LLMError, ProviderModelRouter, ProviderRegistry, ResolvedModel};
/// Metadata key under which a gold configuration entry is stored.
///
/// NOTE(review): presumably session metadata holding the JSON that
/// `parse_session_gold_config` accepts — confirm against callers.
pub const GOLD_CONFIG_METADATA_KEY: &str = "gold_config";
/// Maps a provider name (or instance name) to its provider type.
///
/// The name is trimmed first; a blank name yields `None`. Lookup order is:
/// a configured provider instance, then the registry's metadata for that name,
/// and finally the trimmed name itself. A non-blank name therefore always
/// produces `Some`.
pub fn resolve_provider_type(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<String> {
    let name = provider_name.trim();
    if name.is_empty() {
        return None;
    }

    if let Some(instance) = config.provider_instances.get(name) {
        return Some(instance.provider_type.clone());
    }

    let registered_type = provider_registry
        .get_metadata(name)
        .map(|meta| meta.provider_type);
    Some(registered_type.unwrap_or_else(|| name.to_string()))
}
/// Parses a session-level gold configuration from its JSON text.
///
/// Returns `None` when no JSON is given, when it is blank after trimming, or
/// when it does not deserialize into a [`GoldConfig`] (parse errors are
/// intentionally discarded).
pub fn parse_session_gold_config(session_gold_config_json: Option<&str>) -> Option<GoldConfig> {
    match session_gold_config_json.map(str::trim) {
        Some(json) if !json.is_empty() => serde_json::from_str::<GoldConfig>(json).ok(),
        _ => None,
    }
}
/// Validates a JSON value as a [`GoldConfig`] and re-serializes it to a JSON string.
///
/// # Errors
///
/// Returns the `serde_json` error when `value` does not deserialize into a
/// `GoldConfig`, or when serializing the parsed config fails.
pub fn normalize_gold_config_json(value: &serde_json::Value) -> Result<String, serde_json::Error> {
    serde_json::from_value::<GoldConfig>(value.clone())
        .and_then(|normalized| serde_json::to_string(&normalized))
}
/// Reads the global gold configuration from the `"gold"` entry of `config.extra`.
///
/// Returns `None` when the entry is absent or does not deserialize into a
/// [`GoldConfig`].
pub fn resolve_global_gold_config(config: &Config) -> Option<GoldConfig> {
    let raw_gold = config.extra.get("gold")?;
    serde_json::from_value::<GoldConfig>(raw_gold.clone()).ok()
}
/// Returns the effective gold configuration for a session.
///
/// If a session value is supplied at all, it is authoritative: a blank or
/// unparsable session value yields `None` and does NOT fall back to the
/// global configuration. Only when no session value is given is the global
/// `"gold"` entry consulted.
pub fn resolve_gold_config(
    config: &Config,
    session_gold_config_json: Option<&str>,
) -> Option<GoldConfig> {
    match session_gold_config_json {
        Some(session_json) => parse_session_gold_config(Some(session_json)),
        None => resolve_global_gold_config(config),
    }
}
/// Returns the default chat model configured for a built-in provider.
///
/// `provider_name` is trimmed before being matched against `copilot`, `openai`,
/// `anthropic` and `gemini`. Copilot falls back to `"gpt-4o"` when no model is
/// configured. The other providers require both their provider section and an
/// explicit `model`.
///
/// # Errors
///
/// Returns [`LLMError::Auth`] when the provider section or its model is missing,
/// or when the name is not one of the built-in providers.
pub fn get_default_model_for_provider(
    config: &Config,
    provider_name: &str,
) -> Result<String, LLMError> {
    let providers = &config.providers;
    match provider_name.trim() {
        "copilot" => Ok(providers
            .copilot
            .as_ref()
            .and_then(|copilot| copilot.model.clone())
            .unwrap_or_else(|| "gpt-4o".to_string())),
        "openai" => providers
            .openai
            .as_ref()
            .ok_or_else(|| LLMError::Auth("OpenAI configuration required".to_string()))
            .and_then(|openai| {
                openai.model.clone().ok_or_else(|| {
                    LLMError::Auth("OpenAI model must be specified in config".to_string())
                })
            }),
        "anthropic" => providers
            .anthropic
            .as_ref()
            .ok_or_else(|| LLMError::Auth("Anthropic configuration required".to_string()))
            .and_then(|anthropic| {
                anthropic.model.clone().ok_or_else(|| {
                    LLMError::Auth("Anthropic model must be specified in config".to_string())
                })
            }),
        "gemini" => providers
            .gemini
            .as_ref()
            .ok_or_else(|| LLMError::Auth("Gemini configuration required".to_string()))
            .and_then(|gemini| {
                gemini.model.clone().ok_or_else(|| {
                    LLMError::Auth("Gemini model must be specified in config".to_string())
                })
            }),
        unknown => Err(LLMError::Auth(format!("Unknown provider: {}", unknown))),
    }
}
/// Returns the default model of the globally selected provider (`config.provider`).
///
/// # Errors
///
/// Propagates the errors of [`get_default_model_for_provider`].
pub fn get_default_model_from_config(config: &Config) -> Result<String, LLMError> {
    let active_provider: &str = &config.provider;
    get_default_model_for_provider(config, active_provider)
}
/// Returns the model used for scheduled work: `config.get_fast_model()`, trimmed.
///
/// # Errors
///
/// Returns [`LLMError::Auth`] when no fast model is available or it is blank.
pub fn get_schedule_model_from_config(config: &Config) -> Result<String, LLMError> {
    match config.get_fast_model().map(|model| model.trim().to_string()) {
        Some(model) if !model.is_empty() => Ok(model),
        _ => Err(LLMError::Auth(format!(
            "No fast/default model configured for provider '{}'",
            config.provider
        ))),
    }
}
/// Returns the fast model for a built-in provider, or its default model if none.
///
/// `provider_name` is trimmed before matching. Unknown providers, or providers
/// without a fast model, fall back to [`get_default_model_for_provider`]. Any
/// error from that call is mapped to `None`.
pub fn get_fast_model_for_provider(config: &Config, provider_name: &str) -> Option<String> {
    let providers = &config.providers;
    let explicit_fast = match provider_name.trim() {
        "openai" => providers.openai.as_ref().and_then(|p| p.fast_model.clone()),
        "anthropic" => providers.anthropic.as_ref().and_then(|p| p.fast_model.clone()),
        "gemini" => providers.gemini.as_ref().and_then(|p| p.fast_model.clone()),
        "copilot" => providers.copilot.as_ref().and_then(|p| p.fast_model.clone()),
        _ => None,
    };

    if let Some(model) = explicit_fast {
        return Some(model);
    }
    get_default_model_for_provider(config, provider_name).ok()
}
/// Returns the model used for background memory work for `provider_name`.
///
/// A non-blank `memory.background_model` (trimmed) takes precedence.
/// Otherwise this falls back to [`get_fast_model_for_provider`].
pub fn get_memory_background_model_for_provider(
    config: &Config,
    provider_name: &str,
) -> Option<String> {
    let override_model = config
        .memory
        .as_ref()
        .and_then(|memory| memory.background_model.as_ref())
        .map(|raw| raw.trim())
        .filter(|trimmed| !trimmed.is_empty());

    match override_model {
        Some(model) => Some(model.to_string()),
        None => get_fast_model_for_provider(config, provider_name),
    }
}
pub fn get_reasoning_effort_for_provider(
config: &Config,
provider_name: &str,
) -> Option<ReasoningEffort> {
config.reasoning_effort_for_key(provider_name)
}
/// Returns the model used for task summaries (`config.get_task_summary_model()`).
///
/// # Errors
///
/// Returns [`LLMError::Auth`] when no task summary model is available.
pub fn get_task_summary_model_from_config(config: &Config) -> Result<String, LLMError> {
    match config.get_task_summary_model() {
        Some(model) => Ok(model),
        None => Err(LLMError::Auth(format!(
            "No task summary model configured for provider '{}'",
            config.provider
        ))),
    }
}
/// Returns the background memory model (`config.get_memory_background_model()`).
///
/// # Errors
///
/// Returns [`LLMError::Auth`] when no background memory model is available.
pub fn get_memory_background_model_from_config(config: &Config) -> Result<String, LLMError> {
    match config.get_memory_background_model() {
        Some(model) => Ok(model),
        None => Err(LLMError::Auth(format!(
            "No background memory model configured for provider '{}'",
            config.provider
        ))),
    }
}
/// Returns the vision model (`config.get_vision_model()`).
///
/// # Errors
///
/// Returns [`LLMError::Auth`] when no vision model is available.
pub fn get_vision_model_from_config(config: &Config) -> Result<String, LLMError> {
    match config.get_vision_model() {
        Some(model) => Ok(model),
        None => Err(LLMError::Auth(format!(
            "No model configured for provider '{}'",
            config.provider
        ))),
    }
}
/// Resolves the provider and model used to generate task summaries.
///
/// When `provider_model_ref` is enabled and `defaults.task_summary` routes to a
/// provider, that reference wins. Otherwise this falls back to
/// [`resolve_background_model`] and then to the default chat model.
pub fn resolve_task_summary_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    let routed = config
        .defaults
        .as_ref()
        .filter(|_| config.features.provider_model_ref)
        .and_then(|defaults| defaults.task_summary.as_ref())
        .and_then(|model_ref| {
            ProviderModelRouter::new(Arc::clone(provider_registry))
                .route(model_ref)
                .ok()
                .map(|provider| ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                })
        });

    routed
        .or_else(|| resolve_background_model(config, provider_name, provider_registry))
        .or_else(|| resolve_default_chat_model(config, provider_name, provider_registry))
}
/// Resolves the provider and model used for background memory work.
///
/// With `provider_model_ref` enabled, `defaults.memory_background` (or, if it is
/// unset, `defaults.fast`) is used when it routes to a provider. Otherwise the
/// model comes from [`get_memory_background_model_for_provider`] and is paired
/// with `provider_name`'s registered provider.
///
/// Returns `None` when no model resolves or `provider_name` is not registered.
pub fn resolve_background_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    let routed = config
        .defaults
        .as_ref()
        .filter(|_| config.features.provider_model_ref)
        .and_then(|defaults| defaults.memory_background.as_ref().or(defaults.fast.as_ref()))
        .and_then(|model_ref| {
            ProviderModelRouter::new(Arc::clone(provider_registry))
                .route(model_ref)
                .ok()
                .map(|provider| ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                })
        });
    if let Some(resolved) = routed {
        return Some(resolved);
    }

    let model_name = get_memory_background_model_for_provider(config, provider_name)?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the provider and "fast" model used for lightweight requests.
///
/// With `provider_model_ref` enabled, a routable `defaults.fast` reference wins.
/// Otherwise the model comes from [`get_fast_model_for_provider`], which falls
/// back to the provider's default model. It is paired with `provider_name`'s
/// registered provider.
///
/// Returns `None` when no model resolves or `provider_name` is not registered.
pub fn resolve_fast_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    let routed = config
        .defaults
        .as_ref()
        .filter(|_| config.features.provider_model_ref)
        .and_then(|defaults| defaults.fast.as_ref())
        .and_then(|model_ref| {
            ProviderModelRouter::new(Arc::clone(provider_registry))
                .route(model_ref)
                .ok()
                .map(|provider| ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                })
        });
    if let Some(resolved) = routed {
        return Some(resolved);
    }

    let model_name = get_fast_model_for_provider(config, provider_name)?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the provider and model used for vision (image) requests.
///
/// With `provider_model_ref` enabled, a routable `defaults.vision` reference wins.
/// Otherwise the model name comes from `config.get_vision_model()` and is paired
/// with `provider_name`'s registered provider.
///
/// NOTE(review): `config.get_vision_model()` does not receive `provider_name`,
/// while the provider handle does. Verify the two agree when `provider_name`
/// differs from `config.provider`.
pub fn resolve_vision_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    let routed = config
        .defaults
        .as_ref()
        .filter(|_| config.features.provider_model_ref)
        .and_then(|defaults| defaults.vision.as_ref())
        .and_then(|model_ref| {
            ProviderModelRouter::new(Arc::clone(provider_registry))
                .route(model_ref)
                .ok()
                .map(|provider| ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                })
        });
    if let Some(resolved) = routed {
        return Some(resolved);
    }

    let model_name = config.get_vision_model()?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Resolves the provider and model used for planning.
///
/// With `provider_model_ref` enabled, a routable `defaults.planning` reference
/// wins. Otherwise the default chat model is used.
pub fn resolve_planning_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    let routed = config
        .defaults
        .as_ref()
        .filter(|_| config.features.provider_model_ref)
        .and_then(|defaults| defaults.planning.as_ref())
        .and_then(|model_ref| {
            ProviderModelRouter::new(Arc::clone(provider_registry))
                .route(model_ref)
                .ok()
                .map(|provider| ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                })
        });

    routed.or_else(|| resolve_default_chat_model(config, provider_name, provider_registry))
}
/// Resolves the provider and model used for search.
///
/// With `provider_model_ref` enabled, a routable `defaults.search` reference wins.
/// Otherwise this falls back to [`resolve_fast_model`] and then to the default
/// chat model.
pub fn resolve_search_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    let routed = config
        .defaults
        .as_ref()
        .filter(|_| config.features.provider_model_ref)
        .and_then(|defaults| defaults.search.as_ref())
        .and_then(|model_ref| {
            ProviderModelRouter::new(Arc::clone(provider_registry))
                .route(model_ref)
                .ok()
                .map(|provider| ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                })
        });

    routed
        .or_else(|| resolve_fast_model(config, provider_name, provider_registry))
        .or_else(|| resolve_default_chat_model(config, provider_name, provider_registry))
}
/// Resolves the provider and model used for code review.
///
/// With `provider_model_ref` enabled, a routable `defaults.code_review` reference
/// wins. Otherwise the default chat model is used.
pub fn resolve_code_review_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    let routed = config
        .defaults
        .as_ref()
        .filter(|_| config.features.provider_model_ref)
        .and_then(|defaults| defaults.code_review.as_ref())
        .and_then(|model_ref| {
            ProviderModelRouter::new(Arc::clone(provider_registry))
                .route(model_ref)
                .ok()
                .map(|provider| ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                })
        });

    routed.or_else(|| resolve_default_chat_model(config, provider_name, provider_registry))
}
pub fn resolve_subagent_model_ref(
config: &Config,
provider_name: &str,
provider_registry: &Arc<ProviderRegistry>,
subagent_type: &str,
) -> Option<bamboo_domain::ProviderModelRef> {
if config.features.provider_model_ref {
let router = ProviderModelRouter::new(provider_registry.clone());
if let Some(defaults) = config.defaults.as_ref() {
let candidate_refs = [
defaults.subagent_models.get(subagent_type),
defaults.sub_agent.as_ref(),
defaults.fast.as_ref(),
];
for model_ref in candidate_refs.into_iter().flatten() {
if router.route(model_ref).is_ok() {
return Some(model_ref.clone());
}
}
}
}
resolve_fast_model(config, provider_name, provider_registry)
.or_else(|| resolve_default_chat_model(config, provider_name, provider_registry))
.map(|resolved| bamboo_domain::ProviderModelRef::new(provider_name, resolved.model_name))
}
/// Resolves the provider and model a sub-agent of `subagent_type` should run on.
///
/// The reference from [`resolve_subagent_model_ref`] is routed through
/// [`ProviderModelRouter`]. If routing fails, the provider is looked up directly
/// in the registry by the reference's provider name.
///
/// Returns `None` when no reference resolves or neither lookup yields a provider.
pub fn resolve_subagent_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
    subagent_type: &str,
) -> Option<ResolvedModel> {
    let model_ref =
        resolve_subagent_model_ref(config, provider_name, provider_registry, subagent_type)?;

    // The routing error is not surfaced to callers, so fall back with a plain
    // `Option` chain instead of building an error value that is thrown away.
    let provider = ProviderModelRouter::new(provider_registry.clone())
        .route(&model_ref)
        .ok()
        .or_else(|| provider_registry.get(&model_ref.provider))?;

    Some(ResolvedModel {
        provider,
        model_name: model_ref.model,
    })
}
/// Resolves the default chat provider and model.
///
/// With `provider_model_ref` enabled, a routable `defaults.chat` reference wins.
/// Otherwise the model comes from [`get_default_model_for_provider`], with any
/// error mapped to `None`. It is paired with `provider_name`'s registered
/// provider.
fn resolve_default_chat_model(
    config: &Config,
    provider_name: &str,
    provider_registry: &Arc<ProviderRegistry>,
) -> Option<ResolvedModel> {
    let routed = config
        .defaults
        .as_ref()
        .filter(|_| config.features.provider_model_ref)
        .map(|defaults| &defaults.chat)
        .and_then(|model_ref| {
            ProviderModelRouter::new(Arc::clone(provider_registry))
                .route(model_ref)
                .ok()
                .map(|provider| ResolvedModel {
                    provider,
                    model_name: model_ref.model.clone(),
                })
        });
    if let Some(resolved) = routed {
        return Some(resolved);
    }

    let model_name = get_default_model_for_provider(config, provider_name).ok()?;
    provider_registry
        .get(provider_name)
        .map(|provider| ResolvedModel {
            provider,
            model_name,
        })
}
/// Builds the image-fallback hook configuration from a config snapshot.
///
/// Returns `Ok(None)` when `hooks.image_fallback.enabled` is false. Otherwise
/// `hooks.image_fallback.mode` is trimmed, lower-cased and mapped to an
/// `ImageFallbackMode`. Only `vision` mode looks up a vision model, via
/// `Config::get_vision_model`.
///
/// # Errors
///
/// Returns a descriptive message when `mode` is not `placeholder`, `error`,
/// `ocr` or `vision`.
pub fn resolve_image_fallback(
    config_snapshot: &Config,
) -> Result<Option<crate::ImageFallbackConfig>, String> {
    use crate::ImageFallbackMode;

    let fallback = &config_snapshot.hooks.image_fallback;
    if !fallback.enabled {
        return Ok(None);
    }

    let normalized_mode = fallback.mode.trim().to_ascii_lowercase();
    let mode = match normalized_mode.as_str() {
        "placeholder" => ImageFallbackMode::Placeholder,
        "error" => ImageFallbackMode::Error,
        "ocr" => ImageFallbackMode::Ocr,
        "vision" => ImageFallbackMode::Vision,
        other => {
            return Err(format!(
                "Invalid config: hooks.image_fallback.mode must be 'placeholder', 'error', 'ocr', or 'vision' (got '{other}')"
            ));
        }
    };

    let vision_model = match mode {
        ImageFallbackMode::Vision => config_snapshot.get_vision_model(),
        _ => None,
    };
    Ok(Some(crate::ImageFallbackConfig { mode, vision_model }))
}
// Unit tests for the model/provider resolution helpers in this module.
#[cfg(test)]
mod tests {
use super::*;
use bamboo_agent_core::tools::ToolSchema;
use bamboo_agent_core::Message;
use bamboo_domain::ProviderModelRef;
use bamboo_infrastructure::{
CopilotConfig, DefaultsConfig, LLMProvider, LLMStream, OpenAIConfig, ProviderConfigs,
};
use std::collections::HashMap;
// Stub provider: every chat_stream call fails with LLMError::Api("noop").
// It only exists so the test registry has an entry that references can route to.
struct NoopProvider;
#[async_trait::async_trait]
impl LLMProvider for NoopProvider {
async fn chat_stream(
&self,
_messages: &[Message],
_tools: &[ToolSchema],
_max_output_tokens: Option<u32>,
_model: &str,
) -> Result<LLMStream, LLMError> {
Err(LLMError::Api("noop".to_string()))
}
}
// Registry holding a single "openai" provider backed by NoopProvider. "openai"
// is also passed as ProviderRegistry::new's second argument (presumably the
// default provider name — confirm).
fn test_registry() -> Arc<ProviderRegistry> {
let mut providers: HashMap<String, Arc<dyn LLMProvider>> = HashMap::new();
providers.insert("openai".to_string(), Arc::new(NoopProvider));
Arc::new(ProviderRegistry::new(providers, "openai".to_string()))
}
// The active provider's explicit OpenAI `model` is returned as the default model.
#[test]
fn test_get_model_from_openai_config() {
let config = Config {
provider: "openai".to_string(),
providers: ProviderConfigs {
openai: Some(OpenAIConfig {
api_key: "test".to_string(),
api_key_encrypted: None,
base_url: None,
model: Some("gpt-4o".to_string()),
fast_model: None,
vision_model: None,
reasoning_effort: None,
responses_only_models: vec![],
request_overrides: None,
extra: Default::default(),
}),
..ProviderConfigs::default()
},
..Config::default()
};
let result = get_default_model_from_config(&config);
assert!(result.is_ok());
assert_eq!(result.unwrap(), "gpt-4o");
}
// An OpenAI section without `model` is an error mentioning "model must be specified".
#[test]
fn test_error_when_model_not_configured() {
let config = Config {
provider: "openai".to_string(),
providers: ProviderConfigs {
openai: Some(OpenAIConfig {
api_key: "test".to_string(),
api_key_encrypted: None,
base_url: None,
model: None, fast_model: None,
vision_model: None,
reasoning_effort: None,
responses_only_models: vec![],
request_overrides: None,
extra: Default::default(),
}),
..ProviderConfigs::default()
},
..Config::default()
};
let result = get_default_model_from_config(&config);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("model must be specified"));
}
// Copilot's configured `model` is used when present.
#[test]
fn test_get_model_from_copilot_provider_config() {
let config = Config {
provider: "copilot".to_string(),
providers: ProviderConfigs {
copilot: Some(CopilotConfig {
enabled: true,
headless_auth: false,
model: Some("gpt-4o-mini".to_string()),
fast_model: None,
vision_model: None,
reasoning_effort: None,
responses_only_models: vec![],
request_overrides: None,
extra: Default::default(),
}),
..ProviderConfigs::default()
},
..Config::default()
};
let result = get_default_model_from_config(&config);
assert!(result.is_ok());
assert_eq!(result.unwrap(), "gpt-4o-mini");
}
// Without a copilot section, copilot falls back to the hard-coded "gpt-4o".
#[test]
fn test_get_model_from_copilot_default_fallback() {
let config = Config {
provider: "copilot".to_string(),
providers: ProviderConfigs::default(),
..Config::default()
};
let result = get_default_model_from_config(&config);
assert!(result.is_ok());
assert_eq!(result.unwrap(), "gpt-4o");
}
// Lookup by explicit provider name is independent of the active provider
// ("anthropic" here).
#[test]
fn test_get_default_model_for_specific_provider() {
let config = Config {
provider: "anthropic".to_string(),
providers: ProviderConfigs {
openai: Some(OpenAIConfig {
api_key: "test".to_string(),
api_key_encrypted: None,
base_url: None,
model: Some("gpt-4o".to_string()),
fast_model: Some("gpt-4o-mini".to_string()),
vision_model: None,
reasoning_effort: Some(ReasoningEffort::Medium),
responses_only_models: vec![],
request_overrides: None,
extra: Default::default(),
}),
..ProviderConfigs::default()
},
..Config::default()
};
let result = get_default_model_for_provider(&config, "openai").expect("openai config");
assert_eq!(result, "gpt-4o");
}
// The named provider's `fast_model` is returned even when another provider is active.
#[test]
fn test_get_fast_model_for_specific_provider() {
let config = Config {
provider: "anthropic".to_string(),
providers: ProviderConfigs {
openai: Some(OpenAIConfig {
api_key: "test".to_string(),
api_key_encrypted: None,
base_url: None,
model: Some("gpt-4o".to_string()),
fast_model: Some("gpt-4o-mini".to_string()),
vision_model: None,
reasoning_effort: Some(ReasoningEffort::Medium),
responses_only_models: vec![],
request_overrides: None,
extra: Default::default(),
}),
..ProviderConfigs::default()
},
..Config::default()
};
assert_eq!(
get_fast_model_for_provider(&config, "openai").as_deref(),
Some("gpt-4o-mini")
);
}
// With provider_model_ref off, the schedule model is the provider's `fast_model`.
#[test]
fn test_get_schedule_model_from_config_prefers_fast_model() {
let config = Config {
provider: "openai".to_string(),
defaults: None,
features: bamboo_infrastructure::FeatureFlags {
provider_model_ref: false,
..Default::default()
},
providers: ProviderConfigs {
openai: Some(OpenAIConfig {
api_key: "test".to_string(),
api_key_encrypted: None,
base_url: None,
model: Some("gpt-4o".to_string()),
fast_model: Some("gpt-4o-mini".to_string()),
vision_model: None,
reasoning_effort: None,
responses_only_models: vec![],
request_overrides: None,
extra: Default::default(),
}),
..ProviderConfigs::default()
},
..Config::default()
};
let result = get_schedule_model_from_config(&config);
assert!(result.is_ok());
assert_eq!(result.unwrap(), "gpt-4o-mini");
}
// Without a `fast_model`, the schedule model falls back to the provider's `model`.
#[test]
fn test_get_schedule_model_from_config_falls_back_to_default_model() {
let config = Config {
provider: "openai".to_string(),
defaults: None,
features: bamboo_infrastructure::FeatureFlags {
provider_model_ref: false,
..Default::default()
},
providers: ProviderConfigs {
openai: Some(OpenAIConfig {
api_key: "test".to_string(),
api_key_encrypted: None,
base_url: None,
model: Some("gpt-4o".to_string()),
fast_model: None,
vision_model: None,
reasoning_effort: None,
responses_only_models: vec![],
request_overrides: None,
extra: Default::default(),
}),
..ProviderConfigs::default()
},
..Config::default()
};
let result = get_schedule_model_from_config(&config);
assert!(result.is_ok());
assert_eq!(result.unwrap(), "gpt-4o");
}
// With provider_model_ref on, `defaults.fast` is preferred over `defaults.chat`.
#[test]
fn test_get_schedule_model_from_config_prefers_defaults_fast_over_chat() {
let config = Config {
provider: "openai".to_string(),
features: bamboo_infrastructure::FeatureFlags {
provider_model_ref: true,
..Default::default()
},
defaults: Some(DefaultsConfig {
chat: ProviderModelRef::new("openai", "gpt-chat"),
fast: Some(ProviderModelRef::new("openai", "gpt-fast")),
task_summary: None,
vision: None,
memory_background: None,
planning: None,
search: None,
code_review: None,
sub_agent: None,
subagent_models: HashMap::new(),
}),
..Default::default()
};
let result = get_schedule_model_from_config(&config);
assert!(result.is_ok());
assert_eq!(result.unwrap(), "gpt-fast");
}
// Reasoning effort is read from the named provider's section, not the active one.
#[test]
fn test_get_reasoning_effort_for_specific_provider() {
let config = Config {
provider: "anthropic".to_string(),
providers: ProviderConfigs {
openai: Some(OpenAIConfig {
api_key: "test".to_string(),
api_key_encrypted: None,
base_url: None,
model: Some("gpt-4o".to_string()),
fast_model: Some("gpt-4o-mini".to_string()),
vision_model: None,
reasoning_effort: Some(ReasoningEffort::Medium),
responses_only_models: vec![],
request_overrides: None,
extra: Default::default(),
}),
..ProviderConfigs::default()
},
..Config::default()
};
assert_eq!(
get_reasoning_effort_for_provider(&config, "openai"),
Some(ReasoningEffort::Medium)
);
}
// With no per-type override, `defaults.sub_agent` wins over `defaults.fast`.
#[test]
fn resolve_subagent_model_ref_prefers_sub_agent_over_fast() {
let config = Config {
provider: "openai".to_string(),
features: bamboo_infrastructure::FeatureFlags {
provider_model_ref: true,
..Default::default()
},
defaults: Some(DefaultsConfig {
chat: ProviderModelRef::new("openai", "gpt-chat"),
fast: Some(ProviderModelRef::new("openai", "gpt-fast")),
task_summary: None,
vision: None,
memory_background: None,
planning: None,
search: None,
code_review: None,
sub_agent: Some(ProviderModelRef::new("openai", "gpt-sub-agent")),
subagent_models: HashMap::new(),
}),
..Default::default()
};
let resolved = resolve_subagent_model_ref(&config, "openai", &test_registry(), "coder")
.expect("sub-agent model should resolve");
assert_eq!(resolved, ProviderModelRef::new("openai", "gpt-sub-agent"));
}
// With `defaults.sub_agent` unset, `defaults.fast` is used.
#[test]
fn resolve_subagent_model_ref_falls_back_to_fast_when_sub_agent_unset() {
let config = Config {
provider: "openai".to_string(),
features: bamboo_infrastructure::FeatureFlags {
provider_model_ref: true,
..Default::default()
},
defaults: Some(DefaultsConfig {
chat: ProviderModelRef::new("openai", "gpt-chat"),
fast: Some(ProviderModelRef::new("openai", "gpt-fast")),
task_summary: None,
vision: None,
memory_background: None,
planning: None,
search: None,
code_review: None,
sub_agent: None,
subagent_models: HashMap::new(),
}),
..Default::default()
};
let resolved = resolve_subagent_model_ref(&config, "openai", &test_registry(), "coder")
.expect("fast model should resolve");
assert_eq!(resolved, ProviderModelRef::new("openai", "gpt-fast"));
}
}