1use std::sync::Arc;
21
22use bamboo_domain::reasoning::ReasoningEffort;
23use bamboo_domain::ProviderModelRef;
24use bamboo_infrastructure::{Config, ProviderRegistry, ResolvedModel};
25
26use crate::model_config_helper::{
27 resolve_background_model, resolve_fast_model, resolve_subagent_model, resolve_task_summary_model,
28 resolve_vision_model,
29};
30
31pub struct GlobalAreaModels {
38 pub fast: Option<ResolvedModel>,
40 pub fast_ref: Option<ProviderModelRef>,
41 pub background: Option<ResolvedModel>,
43 pub background_ref: Option<ProviderModelRef>,
44 pub summarization: Option<ResolvedModel>,
46 pub summarization_ref: Option<ProviderModelRef>,
47}
48
49pub fn resolve_global_area_models(
58 config: &Config,
59 provider_name: &str,
60 provider_registry: &Arc<ProviderRegistry>,
61) -> GlobalAreaModels {
62 let defaults = config.defaults.as_ref();
63 GlobalAreaModels {
64 fast: resolve_fast_model(config, provider_name, provider_registry),
65 fast_ref: defaults.and_then(|d| d.fast.clone()),
66 background: resolve_background_model(config, provider_name, provider_registry),
67 background_ref: defaults.and_then(|d| d.memory_background.clone()),
68 summarization: resolve_task_summary_model(config, provider_name, provider_registry),
69 summarization_ref: defaults.and_then(|d| d.task_summary.clone()),
70 }
71}
72
73pub fn resolve_global_vision_model(
77 config: &Config,
78 provider_name: &str,
79 provider_registry: &Arc<ProviderRegistry>,
80) -> Option<ResolvedModel> {
81 resolve_vision_model(config, provider_name, provider_registry)
82}
83
84pub fn resolve_global_subagent_model(
89 config: &Config,
90 provider_name: &str,
91 provider_registry: &Arc<ProviderRegistry>,
92 subagent_type: &str,
93) -> Option<ResolvedModel> {
94 resolve_subagent_model(config, provider_name, provider_registry, subagent_type)
95}
96
/// Identifies which configuration level supplied the effective reasoning effort.
///
/// This is returned alongside the effort itself by
/// `resolve_effective_reasoning_effort`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReasoningEffortSource {
    /// The session's own reasoning-effort setting was used.
    Session,
    /// The effort supplied with the request was used.
    Request,
    /// The provider-level default was used.
    ProviderDefault,
    /// No level supplied an effort.
    None,
}

impl ReasoningEffortSource {
    /// Returns the fixed snake_case label for this source.
    pub fn as_str(self) -> &'static str {
        match self {
            ReasoningEffortSource::Session => "session",
            ReasoningEffortSource::Request => "request",
            ReasoningEffortSource::ProviderDefault => "provider_default",
            ReasoningEffortSource::None => "none",
        }
    }
}
117
118pub fn resolve_effective_reasoning_effort(
126 session_effort: Option<ReasoningEffort>,
127 request_effort: Option<ReasoningEffort>,
128 provider_default: Option<ReasoningEffort>,
129) -> (Option<ReasoningEffort>, ReasoningEffortSource) {
130 if let Some(effort) = session_effort {
131 (Some(effort), ReasoningEffortSource::Session)
132 } else if let Some(effort) = request_effort {
133 (Some(effort), ReasoningEffortSource::Request)
134 } else if let Some(effort) = provider_default {
135 (Some(effort), ReasoningEffortSource::ProviderDefault)
136 } else {
137 (None, ReasoningEffortSource::None)
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::ToolSchema;
    use bamboo_agent_core::Message;
    use bamboo_domain::{Session, DEFAULT_REASONING_EFFORT};
    use bamboo_infrastructure::{
        DefaultsConfig, FeatureFlags, LLMError, LLMProvider, LLMStream, OpenAIConfig,
        ProviderConfigs,
    };
    use std::collections::HashMap;

    /// Provider stub whose chat call always fails.
    ///
    /// The tests in this module only resolve models and never invoke chat.
    struct NoopProvider;

    #[async_trait::async_trait]
    impl LLMProvider for NoopProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream, LLMError> {
            Err(LLMError::Api("noop".to_string()))
        }
    }

    /// Builds a registry with a single `"openai"` entry backed by `NoopProvider`.
    fn test_registry() -> Arc<ProviderRegistry> {
        let mut providers: HashMap<String, Arc<dyn LLMProvider>> = HashMap::new();
        providers.insert("openai".to_string(), Arc::new(NoopProvider));
        Arc::new(ProviderRegistry::new(providers, "openai".to_string()))
    }

    /// Builds defaults in which every area exercised here has a distinct model name.
    fn defaults_with_all_areas() -> DefaultsConfig {
        DefaultsConfig {
            chat: ProviderModelRef::new("openai", "gpt-chat"),
            fast: Some(ProviderModelRef::new("openai", "gpt-fast")),
            task_summary: Some(ProviderModelRef::new("openai", "gpt-summary")),
            vision: Some(ProviderModelRef::new("openai", "gpt-vision")),
            memory_background: Some(ProviderModelRef::new("openai", "gpt-memory")),
            planning: None,
            search: None,
            code_review: None,
            sub_agent: Some(ProviderModelRef::new("openai", "gpt-sub")),
            subagent_models: HashMap::new(),
        }
    }

    /// Builds an `openai` config with the `provider_model_ref` feature enabled
    /// and the given defaults.
    fn config_with_defaults(defaults: DefaultsConfig) -> Config {
        Config {
            provider: "openai".to_string(),
            features: FeatureFlags {
                provider_model_ref: true,
                ..Default::default()
            },
            defaults: Some(defaults),
            ..Config::default()
        }
    }

    #[test]
    fn global_area_models_read_each_area_from_its_own_default() {
        let config = config_with_defaults(defaults_with_all_areas());
        let areas = resolve_global_area_models(&config, "openai", &test_registry());

        assert_eq!(areas.fast.as_ref().map(|m| m.model_name.as_str()), Some("gpt-fast"));
        assert_eq!(
            areas.summarization.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-summary")
        );
        assert_eq!(
            areas.background.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-memory")
        );
        // The raw refs mirror exactly what was configured per area.
        assert_eq!(areas.fast_ref, Some(ProviderModelRef::new("openai", "gpt-fast")));
        assert_eq!(
            areas.summarization_ref,
            Some(ProviderModelRef::new("openai", "gpt-summary"))
        );
        assert_eq!(
            areas.background_ref,
            Some(ProviderModelRef::new("openai", "gpt-memory"))
        );
    }

    #[test]
    fn global_area_models_are_independent_of_any_session() {
        let config = config_with_defaults(defaults_with_all_areas());
        let registry = test_registry();

        let before = resolve_global_area_models(&config, "openai", &registry);

        // Set up a session with a deliberately unusual model and effort. The
        // resolver takes no session argument, so none of this may show up in
        // the global area models.
        let mut session = Session::new("s1", "some-exotic-session-model");
        session.model_ref = Some(ProviderModelRef::new("openai", "some-exotic-session-model"));
        session.reasoning_effort = Some(ReasoningEffort::Max);
        // Touch `session` so the setup above is not reported as unused.
        let _ = &session;

        let after = resolve_global_area_models(&config, "openai", &registry);

        assert_eq!(
            before.fast.as_ref().map(|m| m.model_name.clone()),
            after.fast.as_ref().map(|m| m.model_name.clone())
        );
        assert_eq!(
            before.background.as_ref().map(|m| m.model_name.clone()),
            after.background.as_ref().map(|m| m.model_name.clone())
        );
        assert_eq!(
            before.summarization.as_ref().map(|m| m.model_name.clone()),
            after.summarization.as_ref().map(|m| m.model_name.clone())
        );
        assert_ne!(
            after.fast.as_ref().map(|m| m.model_name.as_str()),
            Some("some-exotic-session-model")
        );
    }

    #[test]
    fn vision_model_is_global_from_defaults() {
        let config = config_with_defaults(defaults_with_all_areas());
        let vision = resolve_global_vision_model(&config, "openai", &test_registry());
        assert_eq!(vision.as_ref().map(|m| m.model_name.as_str()), Some("gpt-vision"));
    }

    #[test]
    fn subagent_model_is_global_from_defaults() {
        let config = config_with_defaults(defaults_with_all_areas());
        // `subagent_models` is empty, so the generic `sub_agent` default applies.
        let sub = resolve_global_subagent_model(&config, "openai", &test_registry(), "coder");
        assert_eq!(sub.as_ref().map(|m| m.model_name.as_str()), Some("gpt-sub"));
    }

    #[test]
    fn background_falls_back_to_fast_when_memory_background_unset() {
        let mut defaults = defaults_with_all_areas();
        defaults.memory_background = None;
        let config = config_with_defaults(defaults);

        let areas = resolve_global_area_models(&config, "openai", &test_registry());
        assert_eq!(
            areas.background.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-fast")
        );
    }

    #[test]
    fn legacy_mode_resolves_fast_from_provider_config() {
        // With `provider_model_ref` off and no defaults, the fast model comes
        // from the per-provider `fast_model` setting.
        let config = Config {
            provider: "openai".to_string(),
            features: FeatureFlags {
                provider_model_ref: false,
                ..Default::default()
            },
            defaults: None,
            providers: ProviderConfigs {
                openai: Some(OpenAIConfig {
                    api_key: "test".to_string(),
                    api_key_encrypted: None,
                    base_url: None,
                    model: Some("gpt-4o".to_string()),
                    fast_model: Some("gpt-4o-mini".to_string()),
                    vision_model: None,
                    reasoning_effort: None,
                    responses_only_models: vec![],
                    request_overrides: None,
                    extra: Default::default(),
                }),
                ..ProviderConfigs::default()
            },
            ..Config::default()
        };

        let areas = resolve_global_area_models(&config, "openai", &test_registry());
        assert_eq!(areas.fast.as_ref().map(|m| m.model_name.as_str()), Some("gpt-4o-mini"));
    }

    #[test]
    fn reasoning_prefers_session_then_request_then_provider() {
        assert_eq!(
            resolve_effective_reasoning_effort(
                Some(ReasoningEffort::Max),
                Some(ReasoningEffort::High),
                Some(ReasoningEffort::Low),
            ),
            (Some(ReasoningEffort::Max), ReasoningEffortSource::Session)
        );
        assert_eq!(
            resolve_effective_reasoning_effort(
                None,
                Some(ReasoningEffort::High),
                Some(ReasoningEffort::Low),
            ),
            (Some(ReasoningEffort::High), ReasoningEffortSource::Request)
        );
        assert_eq!(
            resolve_effective_reasoning_effort(None, None, Some(ReasoningEffort::Low)),
            (Some(ReasoningEffort::Low), ReasoningEffortSource::ProviderDefault)
        );
    }

    #[test]
    fn reasoning_none_when_nothing_configured() {
        let (effort, source) = resolve_effective_reasoning_effort(None, None, None);
        assert_eq!(effort, None);
        assert_eq!(source, ReasoningEffortSource::None);
    }

    #[test]
    fn canonical_default_is_medium_and_used_as_terminal() {
        assert_eq!(DEFAULT_REASONING_EFFORT, ReasoningEffort::Medium);
        // The resolver returns `None` here; callers are expected to fall back
        // to the canonical default themselves.
        let (effort, _) = resolve_effective_reasoning_effort(None, None, None);
        assert_eq!(effort.unwrap_or(DEFAULT_REASONING_EFFORT), ReasoningEffort::Medium);
    }
}