bamboo_server/app_state/
provider_api.rs1use super::*;
2
3use crate::tools::ToolSurface;
4
5impl AppState {
6 pub async fn get_provider(&self) -> Arc<dyn LLMProvider> {
32 self.provider_handle.clone()
35 }
36
37 pub fn get_provider_for_model_ref(
42 &self,
43 target: &bamboo_domain::ProviderModelRef,
44 ) -> Result<Arc<dyn LLMProvider>, AppError> {
45 self.provider_router
46 .route(target)
47 .map_err(|e| AppError::BadRequest(e.to_string()))
48 }
49
50 pub async fn get_provider_for_endpoint(
55 &self,
56 provider_name: &str,
57 ) -> Result<Arc<dyn LLMProvider>, AppError> {
58 let use_registry = {
59 let config = self.config.read().await;
60 config.features.provider_model_ref
61 };
62
63 if use_registry {
64 self.provider_registry.get(provider_name).ok_or_else(|| {
65 AppError::InternalError(anyhow::anyhow!(
66 "Provider '{}' not found in registry",
67 provider_name
68 ))
69 })
70 } else {
71 Ok(self.get_provider().await)
72 }
73 }
74
75 #[allow(dead_code)]
83 pub async fn shutdown(&self) {
84 tracing::info!("Shutting down MCP servers...");
85 self.mcp_manager.shutdown_all().await;
86 tracing::info!("MCP servers shut down complete");
87 }
88
89 pub fn tools_for(
94 &self,
95 surface: ToolSurface,
96 ) -> Arc<dyn bamboo_agent_core::tools::ToolExecutor> {
97 self.tool_factory.get(surface)
98 }
99
100 pub fn get_all_tool_schemas(&self) -> Vec<bamboo_agent_core::tools::ToolSchema> {
109 self.tool_factory.get(ToolSurface::Root).list_tools()
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `AppState` rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the caller controls how long the
    /// directory lives; bind it (e.g. `_temp`) for the whole test.
    async fn make_state() -> (tempfile::TempDir, AppState) {
        let temp_dir = tempfile::tempdir().expect("temp dir");
        let state = AppState::new(temp_dir.path().to_path_buf())
            .await
            .expect("app state");
        (temp_dir, state)
    }

    /// Turns on the `provider_model_ref` feature flag so that
    /// `get_provider_for_endpoint` resolves names through the provider registry.
    async fn enable_provider_model_ref(state: &AppState) {
        state.config.write().await.features.provider_model_ref = true;
    }

    #[tokio::test]
    async fn get_provider_returns_a_provider() {
        let (_temp, state) = make_state().await;
        let provider = state.get_provider().await;
        let models = provider.list_models().await;
        // With no provider configured, listing models is expected to fail.
        assert!(
            models.is_err(),
            "Default provider should be UnconfiguredProvider"
        );
    }

    #[tokio::test]
    async fn endpoint_flag_off_returns_default_provider() {
        let (_temp, state) = make_state().await;
        // Precondition: a freshly built state has the registry flag off.
        assert!(
            !state.config.read().await.features.provider_model_ref,
            "provider_model_ref should default to off"
        );

        let result = state.get_provider_for_endpoint("openai").await;
        assert!(
            result.is_ok(),
            "Flag OFF should always return default provider"
        );
    }

    #[tokio::test]
    async fn endpoint_flag_on_unknown_provider_returns_error() {
        let (_temp, state) = make_state().await;
        enable_provider_model_ref(&state).await;

        let result = state.get_provider_for_endpoint("nonexistent").await;
        // A registry miss is reported as an internal error, not a client error.
        assert!(
            matches!(result, Err(AppError::InternalError(_))),
            "Flag ON with unknown provider should fail with InternalError"
        );
    }

    #[tokio::test]
    async fn endpoint_flag_on_copilot_returns_provider() {
        let (_temp, state) = make_state().await;
        enable_provider_model_ref(&state).await;

        let result = state.get_provider_for_endpoint("copilot").await;
        assert!(result.is_ok(), "Flag ON with copilot should succeed");
    }

    #[tokio::test]
    async fn model_ref_unknown_provider_returns_error() {
        let (_temp, state) = make_state().await;
        let target = bamboo_domain::ProviderModelRef::new("nonexistent", "some-model");
        let result = state.get_provider_for_model_ref(&target);
        // Routing failures are surfaced to the caller as a bad request.
        assert!(
            matches!(result, Err(AppError::BadRequest(_))),
            "unknown provider in a model ref should be rejected as BadRequest"
        );
    }

    #[tokio::test]
    async fn model_ref_copilot_returns_provider() {
        let (_temp, state) = make_state().await;
        let target = bamboo_domain::ProviderModelRef::new("copilot", "gpt-4o");
        let result = state.get_provider_for_model_ref(&target);
        assert!(result.is_ok(), "copilot provider should be routable");
    }

    #[tokio::test]
    async fn tools_for_root_returns_tool_executor() {
        let (_temp, state) = make_state().await;
        let executor = state.tools_for(ToolSurface::Root);
        let schemas = executor.list_tools();
        assert!(!schemas.is_empty(), "Root tools should not be empty");
    }

    #[tokio::test]
    async fn get_all_tool_schemas_includes_core_tools() {
        let (_temp, state) = make_state().await;
        let schemas = state.get_all_tool_schemas();
        let names: std::collections::HashSet<&str> =
            schemas.iter().map(|s| s.function.name.as_str()).collect();
        assert!(names.contains("Task"), "root schemas should include Task");
        assert!(
            names.contains("SubSession"),
            "root schemas should include SubSession"
        );
    }
}