1use std::sync::Arc;
21
22use bamboo_domain::reasoning::ReasoningEffort;
23use bamboo_domain::ProviderModelRef;
24use bamboo_infrastructure::{Config, ProviderRegistry, ResolvedModel};
25
26use crate::model_config_helper::{
27 resolve_background_model, resolve_fast_model, resolve_subagent_model,
28 resolve_task_summary_model, resolve_vision_model,
29};
30
31pub struct GlobalAreaModels {
38 pub fast: Option<ResolvedModel>,
40 pub fast_ref: Option<ProviderModelRef>,
41 pub background: Option<ResolvedModel>,
43 pub background_ref: Option<ProviderModelRef>,
44 pub summarization: Option<ResolvedModel>,
46 pub summarization_ref: Option<ProviderModelRef>,
47}
48
49pub fn resolve_global_area_models(
58 config: &Config,
59 provider_name: &str,
60 provider_registry: &Arc<ProviderRegistry>,
61) -> GlobalAreaModels {
62 let defaults = config.defaults.as_ref();
63 GlobalAreaModels {
64 fast: resolve_fast_model(config, provider_name, provider_registry),
65 fast_ref: defaults.and_then(|d| d.fast.clone()),
66 background: resolve_background_model(config, provider_name, provider_registry),
67 background_ref: defaults.and_then(|d| d.memory_background.clone()),
68 summarization: resolve_task_summary_model(config, provider_name, provider_registry),
69 summarization_ref: defaults.and_then(|d| d.task_summary.clone()),
70 }
71}
72
73pub fn resolve_global_vision_model(
77 config: &Config,
78 provider_name: &str,
79 provider_registry: &Arc<ProviderRegistry>,
80) -> Option<ResolvedModel> {
81 resolve_vision_model(config, provider_name, provider_registry)
82}
83
84pub fn resolve_global_subagent_model(
89 config: &Config,
90 provider_name: &str,
91 provider_registry: &Arc<ProviderRegistry>,
92 subagent_type: &str,
93) -> Option<ResolvedModel> {
94 resolve_subagent_model(config, provider_name, provider_registry, subagent_type)
95}
96
/// Records which input supplied the reasoning effort picked by
/// `resolve_effective_reasoning_effort`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReasoningEffortSource {
    /// The session-level effort was set and won.
    Session,
    /// No session effort; the per-request effort was used.
    Request,
    /// Neither session nor request effort; the provider default was used.
    ProviderDefault,
    /// No input supplied an effort, so the resolved effort is `None`.
    None,
}

impl ReasoningEffortSource {
    /// Returns a fixed snake_case label naming this source.
    pub fn as_str(self) -> &'static str {
        match self {
            ReasoningEffortSource::Session => "session",
            ReasoningEffortSource::Request => "request",
            ReasoningEffortSource::ProviderDefault => "provider_default",
            ReasoningEffortSource::None => "none",
        }
    }
}
117
118pub fn resolve_effective_reasoning_effort(
126 session_effort: Option<ReasoningEffort>,
127 request_effort: Option<ReasoningEffort>,
128 provider_default: Option<ReasoningEffort>,
129) -> (Option<ReasoningEffort>, ReasoningEffortSource) {
130 if let Some(effort) = session_effort {
131 (Some(effort), ReasoningEffortSource::Session)
132 } else if let Some(effort) = request_effort {
133 (Some(effort), ReasoningEffortSource::Request)
134 } else if let Some(effort) = provider_default {
135 (Some(effort), ReasoningEffortSource::ProviderDefault)
136 } else {
137 (None, ReasoningEffortSource::None)
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::ToolSchema;
    use bamboo_agent_core::Message;
    use bamboo_domain::{Session, DEFAULT_REASONING_EFFORT};
    use bamboo_infrastructure::{
        DefaultsConfig, FeatureFlags, LLMError, LLMProvider, LLMStream, OpenAIConfig,
        ProviderConfigs,
    };
    use std::collections::HashMap;

    /// Provider stub whose `chat_stream` always fails; it only exists to populate the registry.
    struct NoopProvider;

    #[async_trait::async_trait]
    impl LLMProvider for NoopProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream, LLMError> {
            Err(LLMError::Api("noop".to_string()))
        }
    }

    /// Registry with a single `"openai"` provider, which is also the default.
    fn test_registry() -> Arc<ProviderRegistry> {
        let mut providers: HashMap<String, Arc<dyn LLMProvider>> = HashMap::new();
        providers.insert("openai".to_string(), Arc::new(NoopProvider));
        Arc::new(ProviderRegistry::new(providers, "openai".to_string()))
    }

    /// Defaults where every area this module reads has a distinct model name.
    fn defaults_with_all_areas() -> DefaultsConfig {
        DefaultsConfig {
            chat: ProviderModelRef::new("openai", "gpt-chat"),
            fast: Some(ProviderModelRef::new("openai", "gpt-fast")),
            task_summary: Some(ProviderModelRef::new("openai", "gpt-summary")),
            vision: Some(ProviderModelRef::new("openai", "gpt-vision")),
            memory_background: Some(ProviderModelRef::new("openai", "gpt-memory")),
            planning: None,
            search: None,
            code_review: None,
            sub_agent: Some(ProviderModelRef::new("openai", "gpt-sub")),
            subagent_models: HashMap::new(),
        }
    }

    /// Config with the `provider_model_ref` feature enabled and the given defaults.
    fn config_with_defaults(defaults: DefaultsConfig) -> Config {
        Config {
            provider: "openai".to_string(),
            features: FeatureFlags {
                provider_model_ref: true,
                ..Default::default()
            },
            defaults: Some(defaults),
            ..Config::default()
        }
    }

    #[test]
    fn global_area_models_read_each_area_from_its_own_default() {
        let config = config_with_defaults(defaults_with_all_areas());
        let areas = resolve_global_area_models(&config, "openai", &test_registry());

        assert_eq!(
            areas.fast.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-fast")
        );
        assert_eq!(
            areas.summarization.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-summary")
        );
        assert_eq!(
            areas.background.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-memory")
        );
        assert_eq!(
            areas.fast_ref,
            Some(ProviderModelRef::new("openai", "gpt-fast"))
        );
        assert_eq!(
            areas.summarization_ref,
            Some(ProviderModelRef::new("openai", "gpt-summary"))
        );
        assert_eq!(
            areas.background_ref,
            Some(ProviderModelRef::new("openai", "gpt-memory"))
        );
    }

    #[test]
    fn global_area_models_are_independent_of_any_session() {
        let config = config_with_defaults(defaults_with_all_areas());
        let registry = test_registry();

        let before = resolve_global_area_models(&config, "openai", &registry);

        // Build a session with an unusual model and effort. It is deliberately
        // never passed to the resolver: the resolved areas must not change
        // regardless of what a session carries.
        let mut session = Session::new("s1", "some-exotic-session-model");
        session.model_ref = Some(ProviderModelRef::new("openai", "some-exotic-session-model"));
        session.reasoning_effort = Some(ReasoningEffort::Max);
        // Touch the session so it is not reported as unused.
        let _ = &session;

        let after = resolve_global_area_models(&config, "openai", &registry);

        assert_eq!(
            before.fast.as_ref().map(|m| m.model_name.clone()),
            after.fast.as_ref().map(|m| m.model_name.clone())
        );
        assert_eq!(
            before.background.as_ref().map(|m| m.model_name.clone()),
            after.background.as_ref().map(|m| m.model_name.clone())
        );
        assert_eq!(
            before.summarization.as_ref().map(|m| m.model_name.clone()),
            after.summarization.as_ref().map(|m| m.model_name.clone())
        );
        assert_ne!(
            after.fast.as_ref().map(|m| m.model_name.as_str()),
            Some("some-exotic-session-model")
        );
    }

    #[test]
    fn vision_model_is_global_from_defaults() {
        let config = config_with_defaults(defaults_with_all_areas());
        let vision = resolve_global_vision_model(&config, "openai", &test_registry());
        assert_eq!(
            vision.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-vision")
        );
    }

    #[test]
    fn subagent_model_is_global_from_defaults() {
        let config = config_with_defaults(defaults_with_all_areas());
        let sub = resolve_global_subagent_model(&config, "openai", &test_registry(), "coder");
        assert_eq!(sub.as_ref().map(|m| m.model_name.as_str()), Some("gpt-sub"));
    }

    #[test]
    fn background_falls_back_to_fast_when_memory_background_unset() {
        let mut defaults = defaults_with_all_areas();
        defaults.memory_background = None;
        let config = config_with_defaults(defaults);

        let areas = resolve_global_area_models(&config, "openai", &test_registry());
        assert_eq!(
            areas.background.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-fast")
        );
    }

    #[test]
    fn legacy_mode_resolves_fast_from_provider_config() {
        let config = Config {
            provider: "openai".to_string(),
            features: FeatureFlags {
                provider_model_ref: false,
                ..Default::default()
            },
            defaults: None,
            providers: ProviderConfigs {
                openai: Some(OpenAIConfig {
                    api_key: "test".to_string(),
                    api_key_encrypted: None,
                    base_url: None,
                    model: Some("gpt-4o".to_string()),
                    fast_model: Some("gpt-4o-mini".to_string()),
                    vision_model: None,
                    reasoning_effort: None,
                    responses_only_models: vec![],
                    request_overrides: None,
                    extra: Default::default(),
                }),
                ..ProviderConfigs::default()
            },
            ..Config::default()
        };

        let areas = resolve_global_area_models(&config, "openai", &test_registry());
        assert_eq!(
            areas.fast.as_ref().map(|m| m.model_name.as_str()),
            Some("gpt-4o-mini")
        );
    }

    #[test]
    fn reasoning_prefers_session_then_request_then_provider() {
        assert_eq!(
            resolve_effective_reasoning_effort(
                Some(ReasoningEffort::Max),
                Some(ReasoningEffort::High),
                Some(ReasoningEffort::Low),
            ),
            (Some(ReasoningEffort::Max), ReasoningEffortSource::Session)
        );
        assert_eq!(
            resolve_effective_reasoning_effort(
                None,
                Some(ReasoningEffort::High),
                Some(ReasoningEffort::Low),
            ),
            (Some(ReasoningEffort::High), ReasoningEffortSource::Request)
        );
        assert_eq!(
            resolve_effective_reasoning_effort(None, None, Some(ReasoningEffort::Low)),
            (
                Some(ReasoningEffort::Low),
                ReasoningEffortSource::ProviderDefault
            )
        );
    }

    #[test]
    fn reasoning_none_when_nothing_configured() {
        let (effort, source) = resolve_effective_reasoning_effort(None, None, None);
        assert_eq!(effort, None);
        assert_eq!(source, ReasoningEffortSource::None);
    }

    #[test]
    fn canonical_default_is_medium_and_used_as_terminal() {
        assert_eq!(DEFAULT_REASONING_EFFORT, ReasoningEffort::Medium);
        // With nothing configured, callers apply DEFAULT_REASONING_EFFORT themselves.
        let (effort, _) = resolve_effective_reasoning_effort(None, None, None);
        assert_eq!(
            effort.unwrap_or(DEFAULT_REASONING_EFFORT),
            ReasoningEffort::Medium
        );
    }
}