Skip to main content

bamboo_server/app_state/
provider_api.rs

1use super::*;
2
3use crate::tools::ToolSurface;
4
5impl AppState {
6    /// Get a clone of the current provider
7    ///
8    /// Returns a thread-safe reference to the current LLM provider.
9    /// This is the preferred way to access the provider for making requests.
10    ///
11    /// # Returns
12    ///
13    /// An Arc reference to the current provider implementation.
14    ///
15    /// # Example
16    ///
17    /// ```rust,no_run
18    /// use bamboo_server::app_state::AppState;
19    /// use std::path::PathBuf;
20    ///
21    /// #[tokio::main]
22    /// async fn main() {
23    ///     let state = AppState::new(PathBuf::from("/path/to/.bamboo"))
24    ///         .await
25    ///         .expect("failed to initialize app state");
26    ///     let provider = state.get_provider().await;
27    ///
28    ///     // Use provider to make LLM requests...
29    /// }
30    /// ```
31    pub async fn get_provider(&self) -> Arc<dyn LLMProvider> {
32        // Important: return the reloadable handle, not a snapshot clone of the current provider.
33        // This ensures config/provider switches take effect without restarting the server.
34        self.provider_handle.clone()
35    }
36
37    /// Get a provider for a specific [`ProviderModelRef`].
38    ///
39    /// Used when `features.provider_model_ref` is enabled to route requests
40    /// to the correct provider based on the model reference.
41    pub fn get_provider_for_model_ref(
42        &self,
43        target: &bamboo_domain::ProviderModelRef,
44    ) -> Result<Arc<dyn LLMProvider>, AppError> {
45        self.provider_router
46            .route(target)
47            .map_err(|e| AppError::BadRequest(e.to_string()))
48    }
49
50    /// Get the appropriate provider for a named provider endpoint (e.g., "openai", "anthropic").
51    ///
52    /// Uses the registry when the `provider_model_ref` feature flag is enabled,
53    /// otherwise falls back to the default provider.
54    pub async fn get_provider_for_endpoint(
55        &self,
56        provider_name: &str,
57    ) -> Result<Arc<dyn LLMProvider>, AppError> {
58        let use_registry = {
59            let config = self.config.read().await;
60            config.features.provider_model_ref
61        };
62
63        if use_registry {
64            self.provider_registry.get(provider_name).ok_or_else(|| {
65                AppError::InternalError(anyhow::anyhow!(
66                    "Provider '{}' not found in registry",
67                    provider_name
68                ))
69            })
70        } else {
71            Ok(self.get_provider().await)
72        }
73    }
74
75    /// Shutdown all MCP servers gracefully
76    ///
77    /// Sends shutdown signals to all running MCP server processes
78    /// and waits for them to terminate cleanly.
79    ///
80    /// This should be called during application shutdown to ensure
81    /// MCP servers are not left running as orphaned processes.
82    #[allow(dead_code)]
83    pub async fn shutdown(&self) {
84        tracing::info!("Shutting down MCP servers...");
85        self.mcp_manager.shutdown_all().await;
86        tracing::info!("MCP servers shut down complete");
87    }
88
89    /// Get the tool executor for a specific surface variant.
90    ///
91    /// Use [`ToolSurface::Root`] for primary sessions,
92    /// [`ToolSurface::Child`] for child sessions, etc.
93    pub fn tools_for(
94        &self,
95        surface: ToolSurface,
96    ) -> Arc<dyn bamboo_agent_core::tools::ToolExecutor> {
97        self.tool_factory.get(surface)
98    }
99
100    /// Get all tool schemas from the composite tool executor
101    ///
102    /// Returns schemas for both built-in tools and MCP-provided tools.
103    /// These schemas are used to inform the LLM about available tools.
104    ///
105    /// # Returns
106    ///
107    /// Vector of tool schemas in Anthropic's tool definition format.
108    pub fn get_all_tool_schemas(&self) -> Vec<bamboo_agent_core::tools::ToolSchema> {
109        self.tool_factory.get(ToolSurface::Root).list_tools()
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `AppState` rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is handed back together with the state so the
    /// directory outlives the state for the duration of the test.
    async fn make_state() -> (tempfile::TempDir, AppState) {
        let temp_dir = tempfile::tempdir().expect("temp dir");
        let data_dir = temp_dir.path().to_path_buf();
        let state = AppState::new(data_dir).await.expect("app state");
        (temp_dir, state)
    }

    /// Switches the `provider_model_ref` feature flag on for `state`.
    ///
    /// The write guard is released when this helper returns, before the test
    /// goes on to call into the state.
    async fn enable_provider_model_ref(state: &AppState) {
        let mut config = state.config.write().await;
        config.features.provider_model_ref = true;
    }

    // ---- get_provider ----

    #[tokio::test]
    async fn get_provider_returns_a_provider() {
        let (_temp, state) = make_state().await;
        // Default config has no API keys, so the unconfigured provider
        // returns an auth error — but a provider handle must still come back.
        let models = state.get_provider().await.list_models().await;
        assert!(
            models.is_err(),
            "Default provider should be UnconfiguredProvider"
        );
    }

    // ---- get_provider_for_endpoint ----

    #[tokio::test]
    async fn endpoint_flag_off_returns_default_provider() {
        let (_temp, state) = make_state().await;
        // The flag is OFF by default; scope the read guard so it is dropped
        // before the call under test.
        {
            let config = state.config.read().await;
            assert!(!config.features.provider_model_ref);
        }

        let outcome = state.get_provider_for_endpoint("openai").await;
        assert!(
            outcome.is_ok(),
            "Flag OFF should always return default provider"
        );
    }

    #[tokio::test]
    async fn endpoint_flag_on_unknown_provider_returns_error() {
        let (_temp, state) = make_state().await;
        enable_provider_model_ref(&state).await;

        let outcome = state.get_provider_for_endpoint("nonexistent").await;
        assert!(outcome.is_err(), "Flag ON with unknown provider should fail");
    }

    #[tokio::test]
    async fn endpoint_flag_on_copilot_returns_provider() {
        let (_temp, state) = make_state().await;
        enable_provider_model_ref(&state).await;

        // Copilot is always available (no API key required)
        let outcome = state.get_provider_for_endpoint("copilot").await;
        assert!(outcome.is_ok(), "Flag ON with copilot should succeed");
    }

    // ---- get_provider_for_model_ref ----

    #[tokio::test]
    async fn model_ref_unknown_provider_returns_error() {
        let (_temp, state) = make_state().await;
        let unknown = bamboo_domain::ProviderModelRef::new("nonexistent", "some-model");
        let result = state.get_provider_for_model_ref(&unknown);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn model_ref_copilot_returns_provider() {
        let (_temp, state) = make_state().await;
        let copilot = bamboo_domain::ProviderModelRef::new("copilot", "gpt-4o");
        let outcome = state.get_provider_for_model_ref(&copilot);
        assert!(outcome.is_ok(), "copilot provider should be routable");
    }

    // ---- tools_for ----

    #[tokio::test]
    async fn tools_for_root_returns_tool_executor() {
        let (_temp, state) = make_state().await;
        let root_tools = state.tools_for(ToolSurface::Root).list_tools();
        assert!(!root_tools.is_empty(), "Root tools should not be empty");
    }

    // ---- get_all_tool_schemas ----

    #[tokio::test]
    async fn get_all_tool_schemas_includes_core_tools() {
        let (_temp, state) = make_state().await;
        let schemas = state.get_all_tool_schemas();
        // Collect the tool names once so each membership check is a set lookup.
        let names: std::collections::HashSet<&str> = schemas
            .iter()
            .map(|schema| schema.function.name.as_str())
            .collect();
        assert!(names.contains("Task"));
        assert!(names.contains("SubSession"));
    }
}