Skip to main content

adk_server/
config.rs

1use std::{sync::Arc, time::Duration};
2
3use crate::auth_bridge::RequestContextExtractor;
4use adk_core::{CacheCapable, ContextCacheConfig, EventsCompactionConfig};
5
6#[cfg(feature = "yaml-agent")]
7use std::path::PathBuf;
8
/// Security configuration for the ADK server.
///
/// Three presets are provided: [`SecurityConfig::default`] and
/// [`SecurityConfig::production`] keep error details and admin debug endpoints
/// hidden, while [`SecurityConfig::development`] turns both on and uses a
/// longer request timeout.
///
/// The type is comparable (`PartialEq`/`Eq`) so configurations can be checked
/// directly in assertions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SecurityConfig {
    /// Allowed origins for CORS (empty = allow all, which is NOT recommended for production)
    pub allowed_origins: Vec<String>,
    /// Maximum request body size in bytes (default: 10 MiB)
    pub max_body_size: usize,
    /// Request timeout duration (default: 30 seconds)
    pub request_timeout: Duration,
    /// Whether to include detailed error messages in responses (default: false for production)
    pub expose_error_details: bool,
    /// Whether to expose admin-only debug endpoints when auth is configured.
    pub expose_admin_debug: bool,
}

impl Default for SecurityConfig {
    /// Returns the baseline configuration: no explicit CORS origins, a 10 MiB
    /// body limit, a 30 second timeout, and both diagnostic flags disabled.
    fn default() -> Self {
        Self {
            // Empty = permissive (for dev), should be configured for prod
            allowed_origins: Vec::new(),
            max_body_size: Self::DEFAULT_MAX_BODY_SIZE,
            request_timeout: Self::DEFAULT_REQUEST_TIMEOUT,
            expose_error_details: false,
            expose_admin_debug: false,
        }
    }
}

impl SecurityConfig {
    /// Body size limit shared by every preset (10 MiB).
    const DEFAULT_MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
    /// Request timeout used by the default and production presets.
    const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
    /// Longer request timeout used by the development preset.
    const DEVELOPMENT_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

    /// Create a development configuration (permissive CORS, detailed errors)
    ///
    /// Differs from [`SecurityConfig::default`] only in the 60 second timeout
    /// and in enabling `expose_error_details` and `expose_admin_debug`.
    #[must_use]
    pub fn development() -> Self {
        Self {
            request_timeout: Self::DEVELOPMENT_REQUEST_TIMEOUT,
            expose_error_details: true,
            expose_admin_debug: true,
            ..Self::default()
        }
    }

    /// Create a production configuration with specific allowed origins
    ///
    /// Every other field takes its [`SecurityConfig::default`] value. Note that
    /// an empty `allowed_origins` list is stored as-is; per the field
    /// documentation that means "allow all", so pass the real origin list.
    #[must_use]
    pub fn production(allowed_origins: Vec<String>) -> Self {
        Self { allowed_origins, ..Self::default() }
    }
}
59
60/// Configuration for the ADK server.
61#[derive(Clone)]
62pub struct ServerConfig {
63    pub agent_loader: Arc<dyn adk_core::AgentLoader>,
64    pub session_service: Arc<dyn adk_session::SessionService>,
65    pub artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
66    pub memory_service: Option<Arc<dyn adk_core::Memory>>,
67    pub compaction_config: Option<EventsCompactionConfig>,
68    pub context_cache_config: Option<ContextCacheConfig>,
69    pub cache_capable: Option<Arc<dyn CacheCapable>>,
70    pub span_exporter: Option<Arc<adk_telemetry::AdkSpanExporter>>,
71    pub backend_url: Option<String>,
72    pub security: SecurityConfig,
73    pub request_context_extractor: Option<Arc<dyn RequestContextExtractor>>,
74    /// Optional interceptor chain for A2A request/response middleware.
75    ///
76    /// When set, the chain is invoked before and after A2A executor processing.
77    #[cfg(feature = "a2a-interceptors")]
78    pub interceptor_chain: Option<Arc<crate::a2a::interceptor::InterceptorChain>>,
79    /// Directories containing YAML agent definitions to watch for hot reload.
80    /// Only used when the `yaml-agent` feature is enabled.
81    #[cfg(feature = "yaml-agent")]
82    pub yaml_agent_dirs: Vec<PathBuf>,
83}
84
85impl ServerConfig {
86    pub fn new(
87        agent_loader: Arc<dyn adk_core::AgentLoader>,
88        session_service: Arc<dyn adk_session::SessionService>,
89    ) -> Self {
90        Self {
91            agent_loader,
92            session_service,
93            artifact_service: None,
94            memory_service: None,
95            compaction_config: None,
96            context_cache_config: None,
97            cache_capable: None,
98            span_exporter: None,
99            backend_url: None,
100            security: SecurityConfig::default(),
101            request_context_extractor: None,
102            #[cfg(feature = "a2a-interceptors")]
103            interceptor_chain: None,
104            #[cfg(feature = "yaml-agent")]
105            yaml_agent_dirs: Vec::new(),
106        }
107    }
108
109    pub fn with_artifact_service(
110        mut self,
111        artifact_service: Arc<dyn adk_artifact::ArtifactService>,
112    ) -> Self {
113        self.artifact_service = Some(artifact_service);
114        self
115    }
116
117    pub fn with_artifact_service_opt(
118        mut self,
119        artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
120    ) -> Self {
121        self.artifact_service = artifact_service;
122        self
123    }
124
125    /// Configure a memory service for semantic search across sessions.
126    ///
127    /// When set, the runner injects memory into the invocation context,
128    /// allowing agents to search previous conversation content via
129    /// `ToolContext::search_memory()`.
130    pub fn with_memory_service(mut self, memory_service: Arc<dyn adk_core::Memory>) -> Self {
131        self.memory_service = Some(memory_service);
132        self
133    }
134
135    /// Configure automatic context compaction for long-running sessions.
136    pub fn with_compaction(mut self, compaction_config: EventsCompactionConfig) -> Self {
137        self.compaction_config = Some(compaction_config);
138        self
139    }
140
141    /// Configure automatic prompt-cache lifecycle management for cache-capable models.
142    pub fn with_context_cache(
143        mut self,
144        context_cache_config: ContextCacheConfig,
145        cache_capable: Arc<dyn CacheCapable>,
146    ) -> Self {
147        self.context_cache_config = Some(context_cache_config);
148        self.cache_capable = Some(cache_capable);
149        self
150    }
151
152    pub fn with_backend_url(mut self, backend_url: impl Into<String>) -> Self {
153        self.backend_url = Some(backend_url.into());
154        self
155    }
156
157    pub fn with_security(mut self, security: SecurityConfig) -> Self {
158        self.security = security;
159        self
160    }
161
162    pub fn with_span_exporter(
163        mut self,
164        span_exporter: Arc<adk_telemetry::AdkSpanExporter>,
165    ) -> Self {
166        self.span_exporter = Some(span_exporter);
167        self
168    }
169
170    /// Configure allowed CORS origins
171    pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
172        self.security.allowed_origins = origins;
173        self
174    }
175
176    /// Configure maximum request body size
177    pub fn with_max_body_size(mut self, size: usize) -> Self {
178        self.security.max_body_size = size;
179        self
180    }
181
182    /// Configure request timeout
183    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
184        self.security.request_timeout = timeout;
185        self
186    }
187
188    /// Enable detailed error messages (for development only)
189    pub fn with_error_details(mut self, expose: bool) -> Self {
190        self.security.expose_error_details = expose;
191        self
192    }
193
194    /// Configure a request context extractor for auth middleware bridging.
195    ///
196    /// When set, the server invokes the extractor on each incoming request
197    /// to extract authenticated identity (user_id, scopes, metadata) from
198    /// HTTP headers. The extracted context flows into `InvocationContext`,
199    /// making scopes available via `ToolContext::user_scopes()`.
200    pub fn with_request_context(mut self, extractor: Arc<dyn RequestContextExtractor>) -> Self {
201        self.request_context_extractor = Some(extractor);
202        self
203    }
204
205    /// Configure an interceptor chain for A2A request/response middleware.
206    ///
207    /// When set, the chain's `run_before` is called before the A2A executor
208    /// processes a request, and `run_after` is called after it completes.
209    /// Interceptors can reject, short-circuit, or modify requests and responses.
210    ///
211    /// # Example
212    ///
213    /// ```rust,ignore
214    /// use adk_server::a2a::interceptor::InterceptorChain;
215    /// use std::sync::Arc;
216    ///
217    /// let chain = InterceptorChain::new()
218    ///     .add(my_auth_interceptor)
219    ///     .add(my_rate_limiter);
220    ///
221    /// let config = ServerConfig::new(loader, session)
222    ///     .with_interceptor_chain(Arc::new(chain));
223    /// ```
224    #[cfg(feature = "a2a-interceptors")]
225    pub fn with_interceptor_chain(
226        mut self,
227        chain: Arc<crate::a2a::interceptor::InterceptorChain>,
228    ) -> Self {
229        self.interceptor_chain = Some(chain);
230        self
231    }
232
233    /// Configure directories containing YAML agent definitions to watch.
234    ///
235    /// When the `yaml-agent` feature is enabled and directories are configured,
236    /// the server starts a [`HotReloadWatcher`](crate::yaml_agent::HotReloadWatcher)
237    /// for each directory at startup, automatically loading and hot-reloading
238    /// YAML-defined agents.
239    #[cfg(feature = "yaml-agent")]
240    pub fn with_yaml_agent_dirs(mut self, dirs: Vec<PathBuf>) -> Self {
241        self.yaml_agent_dirs = dirs;
242        self
243    }
244
245    /// Add a single YAML agent directory to watch.
246    ///
247    /// Convenience method that appends one directory to the list.
248    #[cfg(feature = "yaml-agent")]
249    pub fn with_yaml_agent_dir(mut self, dir: impl Into<PathBuf>) -> Self {
250        self.yaml_agent_dirs.push(dir.into());
251        self
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::{
        Agent, BaseEventsSummarizer, Event, EventStream, InvocationContext, Result as AdkResult,
        SingleAgentLoader,
    };
    use adk_session::InMemorySessionService;
    use async_trait::async_trait;
    use futures::stream;

    /// Agent stub that produces an empty event stream.
    struct TestAgent;

    #[async_trait]
    impl Agent for TestAgent {
        fn name(&self) -> &str {
            "server_config_test_agent"
        }

        fn description(&self) -> &str {
            "server config test agent"
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> AdkResult<EventStream> {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// Cache stub whose operations always succeed.
    struct TestCache;

    /// Summarizer stub that always returns a fixed summary event.
    struct TestSummarizer;

    #[async_trait]
    impl CacheCapable for TestCache {
        async fn create_cache(
            &self,
            _system_instruction: &str,
            _tools: &std::collections::HashMap<String, serde_json::Value>,
            _ttl_seconds: u32,
        ) -> AdkResult<String> {
            Ok("cache".to_string())
        }

        async fn delete_cache(&self, _name: &str) -> AdkResult<()> {
            Ok(())
        }
    }

    #[async_trait]
    impl BaseEventsSummarizer for TestSummarizer {
        async fn summarize_events(&self, _events: &[Event]) -> AdkResult<Option<Event>> {
            Ok(Some(Event::new("summary")))
        }
    }

    /// Builds a baseline config backed by the stub agent and an in-memory session service.
    fn test_config() -> ServerConfig {
        let agent_loader = Arc::new(SingleAgentLoader::new(Arc::new(TestAgent)));
        let session_service = Arc::new(InMemorySessionService::new());
        ServerConfig::new(agent_loader, session_service)
    }

    #[test]
    fn new_starts_without_optional_services_and_with_default_security() {
        let config = test_config();
        let defaults = SecurityConfig::default();

        assert!(config.artifact_service.is_none());
        assert!(config.memory_service.is_none());
        assert!(config.compaction_config.is_none());
        assert!(config.context_cache_config.is_none());
        assert!(config.cache_capable.is_none());
        assert!(config.span_exporter.is_none());
        assert!(config.backend_url.is_none());
        assert!(config.request_context_extractor.is_none());
        assert_eq!(config.security.allowed_origins, defaults.allowed_origins);
        assert_eq!(config.security.max_body_size, defaults.max_body_size);
        assert_eq!(config.security.request_timeout, defaults.request_timeout);
        assert_eq!(config.security.expose_error_details, defaults.expose_error_details);
        assert_eq!(config.security.expose_admin_debug, defaults.expose_admin_debug);
    }

    #[test]
    fn with_compaction_sets_optional_config() {
        // The config is moved into the builder; no clone is needed.
        let config = test_config().with_compaction(EventsCompactionConfig {
            compaction_interval: 10,
            overlap_size: 2,
            summarizer: Arc::new(TestSummarizer),
        });

        let compaction =
            config.compaction_config.as_ref().expect("with_compaction must store the config");
        assert_eq!(compaction.compaction_interval, 10);
        assert_eq!(compaction.overlap_size, 2);
    }

    #[test]
    fn with_context_cache_sets_cache_fields() {
        let cache_capable = Arc::new(TestCache);

        // `.clone()` (method syntax) lets the `Arc<TestCache>` coerce to the trait object.
        let config = test_config().with_context_cache(
            ContextCacheConfig { min_tokens: 512, ttl_seconds: 300, cache_intervals: 2 },
            cache_capable.clone(),
        );

        let cache_config = config
            .context_cache_config
            .as_ref()
            .expect("with_context_cache must store the cache config");
        assert_eq!(cache_config.min_tokens, 512);
        assert_eq!(cache_config.ttl_seconds, 300);
        assert_eq!(cache_config.cache_intervals, 2);

        let configured = config
            .cache_capable
            .as_ref()
            .expect("with_context_cache must store the cache-capable handle");
        let expected: Arc<dyn CacheCapable> = cache_capable;
        assert!(Arc::ptr_eq(configured, &expected));
    }

    #[test]
    fn security_shortcuts_update_nested_security_config() {
        let config = test_config()
            .with_allowed_origins(vec!["https://example.com".to_string()])
            .with_max_body_size(1024)
            .with_request_timeout(Duration::from_secs(5))
            .with_error_details(true);

        assert_eq!(config.security.allowed_origins, vec!["https://example.com".to_string()]);
        assert_eq!(config.security.max_body_size, 1024);
        assert_eq!(config.security.request_timeout, Duration::from_secs(5));
        assert!(config.security.expose_error_details);
        // Untouched fields keep their default.
        assert!(!config.security.expose_admin_debug);
    }

    #[test]
    fn with_security_replaces_earlier_shortcut_values() {
        let config = test_config()
            .with_max_body_size(1)
            .with_security(SecurityConfig::development());

        let development = SecurityConfig::development();
        assert_eq!(config.security.max_body_size, development.max_body_size);
        assert_eq!(config.security.request_timeout, development.request_timeout);
        assert!(config.security.expose_error_details);
        assert!(config.security.expose_admin_debug);
    }

    #[test]
    fn with_backend_url_stores_owned_string() {
        let config = test_config().with_backend_url("http://localhost:8080");

        assert_eq!(config.backend_url.as_deref(), Some("http://localhost:8080"));
    }

    #[test]
    fn with_artifact_service_opt_none_leaves_service_unset() {
        let config = test_config().with_artifact_service_opt(None);

        assert!(config.artifact_service.is_none());
    }
}