Skip to main content

adk_server/
config.rs

1use std::{sync::Arc, time::Duration};
2
3use crate::auth_bridge::RequestContextExtractor;
4use adk_core::{CacheCapable, ContextCacheConfig, EventsCompactionConfig};
5
/// Security configuration for the ADK server.
///
/// Start from [`SecurityConfig::default`], or use one of the
/// [`SecurityConfig::development`] / [`SecurityConfig::production`] presets.
/// The type implements `PartialEq`/`Eq`, so presets and custom configurations
/// can be compared directly (e.g. in tests or startup sanity checks).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SecurityConfig {
    /// Allowed origins for CORS. An empty list allows every origin, which is
    /// NOT recommended for production.
    pub allowed_origins: Vec<String>,
    /// Maximum request body size in bytes (default: 10 MiB).
    pub max_body_size: usize,
    /// Request timeout duration (default: 30 seconds).
    pub request_timeout: Duration,
    /// Whether to include detailed error messages in responses
    /// (default: `false`; keep it off in production).
    pub expose_error_details: bool,
    /// Whether to expose admin-only debug endpoints when auth is configured
    /// (default: `false`).
    pub expose_admin_debug: bool,
}
20
21impl Default for SecurityConfig {
22    fn default() -> Self {
23        Self {
24            allowed_origins: Vec::new(), // Empty = permissive (for dev), should be configured for prod
25            max_body_size: 10 * 1024 * 1024, // 10MB
26            request_timeout: Duration::from_secs(30),
27            expose_error_details: false,
28            expose_admin_debug: false,
29        }
30    }
31}
32
33impl SecurityConfig {
34    /// Create a development configuration (permissive CORS, detailed errors)
35    pub fn development() -> Self {
36        Self {
37            allowed_origins: Vec::new(),
38            max_body_size: 10 * 1024 * 1024,
39            request_timeout: Duration::from_secs(60),
40            expose_error_details: true,
41            expose_admin_debug: true,
42        }
43    }
44
45    /// Create a production configuration with specific allowed origins
46    pub fn production(allowed_origins: Vec<String>) -> Self {
47        Self {
48            allowed_origins,
49            max_body_size: 10 * 1024 * 1024,
50            request_timeout: Duration::from_secs(30),
51            expose_error_details: false,
52            expose_admin_debug: false,
53        }
54    }
55}
56
57/// Configuration for the ADK server.
58#[derive(Clone)]
59pub struct ServerConfig {
60    pub agent_loader: Arc<dyn adk_core::AgentLoader>,
61    pub session_service: Arc<dyn adk_session::SessionService>,
62    pub artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
63    pub memory_service: Option<Arc<dyn adk_core::Memory>>,
64    pub compaction_config: Option<EventsCompactionConfig>,
65    pub context_cache_config: Option<ContextCacheConfig>,
66    pub cache_capable: Option<Arc<dyn CacheCapable>>,
67    pub span_exporter: Option<Arc<adk_telemetry::AdkSpanExporter>>,
68    pub backend_url: Option<String>,
69    pub security: SecurityConfig,
70    pub request_context_extractor: Option<Arc<dyn RequestContextExtractor>>,
71}
72
73impl ServerConfig {
74    pub fn new(
75        agent_loader: Arc<dyn adk_core::AgentLoader>,
76        session_service: Arc<dyn adk_session::SessionService>,
77    ) -> Self {
78        Self {
79            agent_loader,
80            session_service,
81            artifact_service: None,
82            memory_service: None,
83            compaction_config: None,
84            context_cache_config: None,
85            cache_capable: None,
86            span_exporter: None,
87            backend_url: None,
88            security: SecurityConfig::default(),
89            request_context_extractor: None,
90        }
91    }
92
93    pub fn with_artifact_service(
94        mut self,
95        artifact_service: Arc<dyn adk_artifact::ArtifactService>,
96    ) -> Self {
97        self.artifact_service = Some(artifact_service);
98        self
99    }
100
101    pub fn with_artifact_service_opt(
102        mut self,
103        artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
104    ) -> Self {
105        self.artifact_service = artifact_service;
106        self
107    }
108
109    /// Configure a memory service for semantic search across sessions.
110    ///
111    /// When set, the runner injects memory into the invocation context,
112    /// allowing agents to search previous conversation content via
113    /// `ToolContext::search_memory()`.
114    pub fn with_memory_service(mut self, memory_service: Arc<dyn adk_core::Memory>) -> Self {
115        self.memory_service = Some(memory_service);
116        self
117    }
118
119    /// Configure automatic context compaction for long-running sessions.
120    pub fn with_compaction(mut self, compaction_config: EventsCompactionConfig) -> Self {
121        self.compaction_config = Some(compaction_config);
122        self
123    }
124
125    /// Configure automatic prompt-cache lifecycle management for cache-capable models.
126    pub fn with_context_cache(
127        mut self,
128        context_cache_config: ContextCacheConfig,
129        cache_capable: Arc<dyn CacheCapable>,
130    ) -> Self {
131        self.context_cache_config = Some(context_cache_config);
132        self.cache_capable = Some(cache_capable);
133        self
134    }
135
136    pub fn with_backend_url(mut self, backend_url: impl Into<String>) -> Self {
137        self.backend_url = Some(backend_url.into());
138        self
139    }
140
141    pub fn with_security(mut self, security: SecurityConfig) -> Self {
142        self.security = security;
143        self
144    }
145
146    pub fn with_span_exporter(
147        mut self,
148        span_exporter: Arc<adk_telemetry::AdkSpanExporter>,
149    ) -> Self {
150        self.span_exporter = Some(span_exporter);
151        self
152    }
153
154    /// Configure allowed CORS origins
155    pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
156        self.security.allowed_origins = origins;
157        self
158    }
159
160    /// Configure maximum request body size
161    pub fn with_max_body_size(mut self, size: usize) -> Self {
162        self.security.max_body_size = size;
163        self
164    }
165
166    /// Configure request timeout
167    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
168        self.security.request_timeout = timeout;
169        self
170    }
171
172    /// Enable detailed error messages (for development only)
173    pub fn with_error_details(mut self, expose: bool) -> Self {
174        self.security.expose_error_details = expose;
175        self
176    }
177
178    /// Configure a request context extractor for auth middleware bridging.
179    ///
180    /// When set, the server invokes the extractor on each incoming request
181    /// to extract authenticated identity (user_id, scopes, metadata) from
182    /// HTTP headers. The extracted context flows into `InvocationContext`,
183    /// making scopes available via `ToolContext::user_scopes()`.
184    pub fn with_request_context(mut self, extractor: Arc<dyn RequestContextExtractor>) -> Self {
185        self.request_context_extractor = Some(extractor);
186        self
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::{
        Agent, BaseEventsSummarizer, Event, EventStream, InvocationContext, Result as AdkResult,
        SingleAgentLoader,
    };
    use adk_session::InMemorySessionService;
    use async_trait::async_trait;
    use futures::stream;

    /// Minimal agent that yields no events; only needed to build a `ServerConfig`.
    struct TestAgent;

    #[async_trait]
    impl Agent for TestAgent {
        fn name(&self) -> &str {
            "server_config_test_agent"
        }

        fn description(&self) -> &str {
            "server config test agent"
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> AdkResult<EventStream> {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// Cache stub whose operations always succeed.
    struct TestCache;

    /// Summarizer stub that always returns a fixed summary event.
    struct TestSummarizer;

    #[async_trait]
    impl CacheCapable for TestCache {
        async fn create_cache(
            &self,
            _system_instruction: &str,
            _tools: &std::collections::HashMap<String, serde_json::Value>,
            _ttl_seconds: u32,
        ) -> adk_core::Result<String> {
            Ok("cache".to_string())
        }

        async fn delete_cache(&self, _name: &str) -> adk_core::Result<()> {
            Ok(())
        }
    }

    #[async_trait]
    impl BaseEventsSummarizer for TestSummarizer {
        async fn summarize_events(&self, _events: &[Event]) -> adk_core::Result<Option<Event>> {
            Ok(Some(Event::new("summary")))
        }
    }

    /// Builds a config with only the required services set.
    fn test_config() -> ServerConfig {
        let agent_loader = Arc::new(SingleAgentLoader::new(Arc::new(TestAgent)));
        let session_service = Arc::new(InMemorySessionService::new());
        ServerConfig::new(agent_loader, session_service)
    }

    #[test]
    fn new_leaves_optional_services_unset_and_uses_default_security() {
        let config = test_config();

        assert!(config.artifact_service.is_none());
        assert!(config.memory_service.is_none());
        assert!(config.compaction_config.is_none());
        assert!(config.context_cache_config.is_none());
        assert!(config.cache_capable.is_none());
        assert!(config.span_exporter.is_none());
        assert!(config.backend_url.is_none());
        assert!(config.request_context_extractor.is_none());

        assert!(config.security.allowed_origins.is_empty());
        assert_eq!(config.security.max_body_size, 10 * 1024 * 1024);
        assert_eq!(config.security.request_timeout, Duration::from_secs(30));
        assert!(!config.security.expose_error_details);
        assert!(!config.security.expose_admin_debug);
    }

    #[test]
    fn security_builders_update_nested_config() {
        let config = test_config()
            .with_allowed_origins(vec!["https://example.com".to_string()])
            .with_max_body_size(1024)
            .with_request_timeout(Duration::from_secs(5))
            .with_error_details(true);

        assert_eq!(config.security.allowed_origins, vec!["https://example.com".to_string()]);
        assert_eq!(config.security.max_body_size, 1024);
        assert_eq!(config.security.request_timeout, Duration::from_secs(5));
        assert!(config.security.expose_error_details);
    }

    #[test]
    fn with_security_replaces_earlier_security_settings() {
        let config = test_config()
            .with_max_body_size(1)
            .with_error_details(true)
            .with_security(SecurityConfig::production(vec!["https://app.example".to_string()]));

        assert_eq!(config.security.allowed_origins, vec!["https://app.example".to_string()]);
        assert_eq!(config.security.max_body_size, 10 * 1024 * 1024);
        assert!(!config.security.expose_error_details);
    }

    #[test]
    fn security_presets_have_expected_values() {
        let dev = SecurityConfig::development();
        assert!(dev.allowed_origins.is_empty());
        assert_eq!(dev.request_timeout, Duration::from_secs(60));
        assert!(dev.expose_error_details);
        assert!(dev.expose_admin_debug);

        let prod = SecurityConfig::production(vec!["https://app.example".to_string()]);
        assert_eq!(prod.allowed_origins, vec!["https://app.example".to_string()]);
        assert_eq!(prod.request_timeout, Duration::from_secs(30));
        assert!(!prod.expose_error_details);
        assert!(!prod.expose_admin_debug);
    }

    #[test]
    fn with_backend_url_accepts_str() {
        let config = test_config().with_backend_url("http://localhost:8080");
        assert_eq!(config.backend_url.as_deref(), Some("http://localhost:8080"));
    }

    #[test]
    fn with_compaction_sets_optional_config() {
        let compaction_config = EventsCompactionConfig {
            compaction_interval: 10,
            overlap_size: 2,
            summarizer: Arc::new(TestSummarizer),
        };

        let config = test_config().with_compaction(compaction_config);

        let configured =
            config.compaction_config.as_ref().expect("compaction config should be set");
        assert_eq!(configured.compaction_interval, 10);
        assert_eq!(configured.overlap_size, 2);
    }

    #[test]
    fn with_context_cache_sets_cache_fields() {
        let context_cache_config =
            ContextCacheConfig { min_tokens: 512, ttl_seconds: 300, cache_intervals: 2 };
        let cache_capable = Arc::new(TestCache);

        // `cache_capable` is cloned because the original handle is needed below
        // for the pointer-identity check.
        let config = test_config().with_context_cache(context_cache_config, cache_capable.clone());

        let cache_config =
            config.context_cache_config.as_ref().expect("context cache config should be set");
        assert_eq!(cache_config.min_tokens, 512);
        assert_eq!(cache_config.ttl_seconds, 300);
        assert_eq!(cache_config.cache_intervals, 2);

        let configured = config.cache_capable.as_ref().expect("cache_capable should be set");
        let expected: Arc<dyn CacheCapable> = cache_capable;
        assert!(Arc::ptr_eq(configured, &expected));
    }
}