Skip to main content

adk_server/
config.rs

1use std::{sync::Arc, time::Duration};
2
3use crate::auth_bridge::RequestContextExtractor;
4use adk_core::{CacheCapable, ContextCacheConfig, EventsCompactionConfig};
5
6#[cfg(feature = "yaml-agent")]
7use std::path::PathBuf;
8
/// Security configuration for the ADK server.
///
/// Start from [`SecurityConfig::default`], [`SecurityConfig::development`] or
/// [`SecurityConfig::production`]. The two presets are derived from `default()`, so they only
/// differ from the baseline in the fields they explicitly override.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SecurityConfig {
    /// Allowed origins for CORS (empty = allow all, which is NOT recommended for production)
    pub allowed_origins: Vec<String>,
    /// Maximum request body size in bytes (default: 10MB)
    pub max_body_size: usize,
    /// Request timeout duration (default: 30 seconds)
    pub request_timeout: Duration,
    /// Whether to include detailed error messages in responses (default: false for production)
    pub expose_error_details: bool,
    /// Whether to expose admin-only debug endpoints when auth is configured (default: false).
    pub expose_admin_debug: bool,
}

impl Default for SecurityConfig {
    /// Baseline settings: an empty `allowed_origins` list (permissive CORS — acceptable for
    /// development, configure explicit origins for production), a 10 MiB body limit, a
    /// 30-second request timeout, no detailed error messages and no admin debug endpoints.
    fn default() -> Self {
        Self {
            allowed_origins: Vec::new(),
            max_body_size: Self::DEFAULT_MAX_BODY_SIZE,
            request_timeout: Self::DEFAULT_REQUEST_TIMEOUT,
            expose_error_details: false,
            expose_admin_debug: false,
        }
    }
}

impl SecurityConfig {
    /// Baseline maximum request body size in bytes (10 MiB), shared by every preset.
    const DEFAULT_MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
    /// Baseline request timeout, used by `default()` and `production()`.
    const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
    /// Longer request timeout used by `development()`.
    const DEVELOPMENT_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

    /// Create a development configuration (permissive CORS, detailed errors).
    ///
    /// Compared with [`SecurityConfig::default`] this doubles the request timeout to
    /// 60 seconds and enables both detailed error messages and admin debug endpoints.
    pub fn development() -> Self {
        Self {
            request_timeout: Self::DEVELOPMENT_REQUEST_TIMEOUT,
            expose_error_details: true,
            expose_admin_debug: true,
            // Everything else (origins, body size) comes from the shared baseline.
            ..Self::default()
        }
    }

    /// Create a production configuration with specific allowed origins.
    ///
    /// Identical to [`SecurityConfig::default`] except that CORS is restricted to
    /// `allowed_origins`.
    pub fn production(allowed_origins: Vec<String>) -> Self {
        Self { allowed_origins, ..Self::default() }
    }
}
59
60/// Configuration for the ADK server.
61#[derive(Clone)]
62pub struct ServerConfig {
63    pub agent_loader: Arc<dyn adk_core::AgentLoader>,
64    pub session_service: Arc<dyn adk_session::SessionService>,
65    pub artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
66    pub memory_service: Option<Arc<dyn adk_core::Memory>>,
67    pub compaction_config: Option<EventsCompactionConfig>,
68    pub context_cache_config: Option<ContextCacheConfig>,
69    pub cache_capable: Option<Arc<dyn CacheCapable>>,
70    pub span_exporter: Option<Arc<adk_telemetry::AdkSpanExporter>>,
71    pub backend_url: Option<String>,
72    pub security: SecurityConfig,
73    pub request_context_extractor: Option<Arc<dyn RequestContextExtractor>>,
74    /// Directories containing YAML agent definitions to watch for hot reload.
75    /// Only used when the `yaml-agent` feature is enabled.
76    #[cfg(feature = "yaml-agent")]
77    pub yaml_agent_dirs: Vec<PathBuf>,
78}
79
80impl ServerConfig {
81    pub fn new(
82        agent_loader: Arc<dyn adk_core::AgentLoader>,
83        session_service: Arc<dyn adk_session::SessionService>,
84    ) -> Self {
85        Self {
86            agent_loader,
87            session_service,
88            artifact_service: None,
89            memory_service: None,
90            compaction_config: None,
91            context_cache_config: None,
92            cache_capable: None,
93            span_exporter: None,
94            backend_url: None,
95            security: SecurityConfig::default(),
96            request_context_extractor: None,
97            #[cfg(feature = "yaml-agent")]
98            yaml_agent_dirs: Vec::new(),
99        }
100    }
101
102    pub fn with_artifact_service(
103        mut self,
104        artifact_service: Arc<dyn adk_artifact::ArtifactService>,
105    ) -> Self {
106        self.artifact_service = Some(artifact_service);
107        self
108    }
109
110    pub fn with_artifact_service_opt(
111        mut self,
112        artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
113    ) -> Self {
114        self.artifact_service = artifact_service;
115        self
116    }
117
118    /// Configure a memory service for semantic search across sessions.
119    ///
120    /// When set, the runner injects memory into the invocation context,
121    /// allowing agents to search previous conversation content via
122    /// `ToolContext::search_memory()`.
123    pub fn with_memory_service(mut self, memory_service: Arc<dyn adk_core::Memory>) -> Self {
124        self.memory_service = Some(memory_service);
125        self
126    }
127
128    /// Configure automatic context compaction for long-running sessions.
129    pub fn with_compaction(mut self, compaction_config: EventsCompactionConfig) -> Self {
130        self.compaction_config = Some(compaction_config);
131        self
132    }
133
134    /// Configure automatic prompt-cache lifecycle management for cache-capable models.
135    pub fn with_context_cache(
136        mut self,
137        context_cache_config: ContextCacheConfig,
138        cache_capable: Arc<dyn CacheCapable>,
139    ) -> Self {
140        self.context_cache_config = Some(context_cache_config);
141        self.cache_capable = Some(cache_capable);
142        self
143    }
144
145    pub fn with_backend_url(mut self, backend_url: impl Into<String>) -> Self {
146        self.backend_url = Some(backend_url.into());
147        self
148    }
149
150    pub fn with_security(mut self, security: SecurityConfig) -> Self {
151        self.security = security;
152        self
153    }
154
155    pub fn with_span_exporter(
156        mut self,
157        span_exporter: Arc<adk_telemetry::AdkSpanExporter>,
158    ) -> Self {
159        self.span_exporter = Some(span_exporter);
160        self
161    }
162
163    /// Configure allowed CORS origins
164    pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
165        self.security.allowed_origins = origins;
166        self
167    }
168
169    /// Configure maximum request body size
170    pub fn with_max_body_size(mut self, size: usize) -> Self {
171        self.security.max_body_size = size;
172        self
173    }
174
175    /// Configure request timeout
176    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
177        self.security.request_timeout = timeout;
178        self
179    }
180
181    /// Enable detailed error messages (for development only)
182    pub fn with_error_details(mut self, expose: bool) -> Self {
183        self.security.expose_error_details = expose;
184        self
185    }
186
187    /// Configure a request context extractor for auth middleware bridging.
188    ///
189    /// When set, the server invokes the extractor on each incoming request
190    /// to extract authenticated identity (user_id, scopes, metadata) from
191    /// HTTP headers. The extracted context flows into `InvocationContext`,
192    /// making scopes available via `ToolContext::user_scopes()`.
193    pub fn with_request_context(mut self, extractor: Arc<dyn RequestContextExtractor>) -> Self {
194        self.request_context_extractor = Some(extractor);
195        self
196    }
197
198    /// Configure directories containing YAML agent definitions to watch.
199    ///
200    /// When the `yaml-agent` feature is enabled and directories are configured,
201    /// the server starts a [`HotReloadWatcher`](crate::yaml_agent::HotReloadWatcher)
202    /// for each directory at startup, automatically loading and hot-reloading
203    /// YAML-defined agents.
204    #[cfg(feature = "yaml-agent")]
205    pub fn with_yaml_agent_dirs(mut self, dirs: Vec<PathBuf>) -> Self {
206        self.yaml_agent_dirs = dirs;
207        self
208    }
209
210    /// Add a single YAML agent directory to watch.
211    ///
212    /// Convenience method that appends one directory to the list.
213    #[cfg(feature = "yaml-agent")]
214    pub fn with_yaml_agent_dir(mut self, dir: impl Into<PathBuf>) -> Self {
215        self.yaml_agent_dirs.push(dir.into());
216        self
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::{
        Agent, BaseEventsSummarizer, Event, EventStream, InvocationContext, Result as AdkResult,
        SingleAgentLoader,
    };
    use adk_session::InMemorySessionService;
    use async_trait::async_trait;
    use futures::stream;

    /// Minimal agent that yields no events; only needed to build a `SingleAgentLoader`.
    struct TestAgent;

    #[async_trait]
    impl Agent for TestAgent {
        fn name(&self) -> &str {
            "server_config_test_agent"
        }

        fn description(&self) -> &str {
            "server config test agent"
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> AdkResult<EventStream> {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// Stub cache backend that always succeeds.
    struct TestCache;

    /// Stub summarizer that always returns a fixed summary event.
    struct TestSummarizer;

    #[async_trait]
    impl CacheCapable for TestCache {
        async fn create_cache(
            &self,
            _system_instruction: &str,
            _tools: &std::collections::HashMap<String, serde_json::Value>,
            _ttl_seconds: u32,
        ) -> adk_core::Result<String> {
            Ok("cache".to_string())
        }

        async fn delete_cache(&self, _name: &str) -> adk_core::Result<()> {
            Ok(())
        }
    }

    #[async_trait]
    impl BaseEventsSummarizer for TestSummarizer {
        async fn summarize_events(&self, _events: &[Event]) -> adk_core::Result<Option<Event>> {
            Ok(Some(Event::new("summary")))
        }
    }

    /// Build a config containing only the required services.
    fn test_config() -> ServerConfig {
        let agent_loader = Arc::new(SingleAgentLoader::new(Arc::new(TestAgent)));
        let session_service = Arc::new(InMemorySessionService::new());
        ServerConfig::new(agent_loader, session_service)
    }

    #[test]
    fn new_leaves_optional_fields_unset() {
        let config = test_config();

        assert!(config.artifact_service.is_none());
        assert!(config.memory_service.is_none());
        assert!(config.compaction_config.is_none());
        assert!(config.context_cache_config.is_none());
        assert!(config.cache_capable.is_none());
        assert!(config.span_exporter.is_none());
        assert!(config.backend_url.is_none());
        assert!(config.request_context_extractor.is_none());
        assert!(config.security.allowed_origins.is_empty());
        assert!(!config.security.expose_error_details);
        assert!(!config.security.expose_admin_debug);
    }

    #[test]
    fn with_compaction_sets_optional_config() {
        let compaction_config = EventsCompactionConfig {
            compaction_interval: 10,
            overlap_size: 2,
            summarizer: Arc::new(TestSummarizer),
        };

        let config = test_config().with_compaction(compaction_config);

        let compaction = config
            .compaction_config
            .as_ref()
            .expect("with_compaction should populate compaction_config");
        assert_eq!(compaction.compaction_interval, 10);
        assert_eq!(compaction.overlap_size, 2);
    }

    #[test]
    fn with_context_cache_sets_cache_fields() {
        let context_cache_config =
            ContextCacheConfig { min_tokens: 512, ttl_seconds: 300, cache_intervals: 2 };
        let cache_capable: Arc<dyn CacheCapable> = Arc::new(TestCache);

        let config =
            test_config().with_context_cache(context_cache_config, Arc::clone(&cache_capable));

        let cache_config = config
            .context_cache_config
            .as_ref()
            .expect("with_context_cache should populate context_cache_config");
        assert_eq!(cache_config.min_tokens, 512);
        assert_eq!(cache_config.ttl_seconds, 300);
        assert_eq!(cache_config.cache_intervals, 2);

        let configured = config
            .cache_capable
            .as_ref()
            .expect("with_context_cache should populate cache_capable");
        assert!(Arc::ptr_eq(configured, &cache_capable));
    }

    #[test]
    fn security_builders_update_nested_config() {
        let config = test_config()
            .with_allowed_origins(vec!["https://example.com".to_string()])
            .with_max_body_size(1024)
            .with_request_timeout(Duration::from_secs(5))
            .with_error_details(true)
            .with_backend_url("http://localhost:8080");

        assert_eq!(config.security.allowed_origins, vec!["https://example.com".to_string()]);
        assert_eq!(config.security.max_body_size, 1024);
        assert_eq!(config.security.request_timeout, Duration::from_secs(5));
        assert!(config.security.expose_error_details);
        assert_eq!(config.backend_url.as_deref(), Some("http://localhost:8080"));
    }

    #[test]
    fn with_security_replaces_security_config() {
        let config = test_config()
            .with_max_body_size(1)
            .with_security(SecurityConfig::development());

        // The wholesale replacement overrides the earlier fine-grained setter.
        assert_eq!(config.security.max_body_size, 10 * 1024 * 1024);
        assert_eq!(config.security.request_timeout, Duration::from_secs(60));
        assert!(config.security.expose_error_details);
        assert!(config.security.expose_admin_debug);
    }
}