1use std::collections::BTreeMap;
5use std::sync::Arc;
6
7use crate::factory::{
8 AuditSinkFactory, CredentialStoreFactory, LifecycleStoreFactory, PolicyStoreFactory, RateLimitCounterFactory,
9 SessionStoreFactory,
10};
11use crate::{
12 AuditSink, ConfigError, CredentialStore, DriverName, LifecycleStore, PolicyStore, RateLimitCounter, SessionStore,
13 StorageConfig,
14};
15
16#[derive(Default)]
23pub struct Registry {
24 policy_stores: BTreeMap<DriverName, Box<dyn PolicyStoreFactory>>,
25 audit_sinks: BTreeMap<DriverName, Box<dyn AuditSinkFactory>>,
26 session_stores: BTreeMap<DriverName, Box<dyn SessionStoreFactory>>,
27 credential_stores: BTreeMap<DriverName, Box<dyn CredentialStoreFactory>>,
28 rate_limit_counters: BTreeMap<DriverName, Box<dyn RateLimitCounterFactory>>,
29 lifecycle_stores: BTreeMap<DriverName, Box<dyn LifecycleStoreFactory>>,
30}
31
32impl Registry {
33 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn register_policy_store(&mut self, name: impl Into<DriverName>, factory: Box<dyn PolicyStoreFactory>) {
40 self.policy_stores.insert(name.into(), factory);
41 }
42
43 pub fn register_audit_sink(&mut self, name: impl Into<DriverName>, factory: Box<dyn AuditSinkFactory>) {
45 self.audit_sinks.insert(name.into(), factory);
46 }
47
48 pub fn register_session_store(&mut self, name: impl Into<DriverName>, factory: Box<dyn SessionStoreFactory>) {
50 self.session_stores.insert(name.into(), factory);
51 }
52
53 pub fn register_credential_store(&mut self, name: impl Into<DriverName>, factory: Box<dyn CredentialStoreFactory>) {
55 self.credential_stores.insert(name.into(), factory);
56 }
57
58 pub fn register_rate_limit_counter(
60 &mut self,
61 name: impl Into<DriverName>,
62 factory: Box<dyn RateLimitCounterFactory>,
63 ) {
64 self.rate_limit_counters.insert(name.into(), factory);
65 }
66
67 pub fn register_lifecycle_store(&mut self, name: impl Into<DriverName>, factory: Box<dyn LifecycleStoreFactory>) {
69 self.lifecycle_stores.insert(name.into(), factory);
70 }
71
72 pub fn policy_store_names(&self) -> Vec<&str> {
74 self.policy_stores.keys().map(DriverName::as_str).collect()
75 }
76
77 pub fn audit_sink_names(&self) -> Vec<&str> {
79 self.audit_sinks.keys().map(DriverName::as_str).collect()
80 }
81
82 pub fn session_store_names(&self) -> Vec<&str> {
84 self.session_stores.keys().map(DriverName::as_str).collect()
85 }
86
87 pub fn credential_store_names(&self) -> Vec<&str> {
89 self.credential_stores.keys().map(DriverName::as_str).collect()
90 }
91
92 pub fn rate_limit_counter_names(&self) -> Vec<&str> {
94 self.rate_limit_counters.keys().map(DriverName::as_str).collect()
95 }
96
97 pub fn lifecycle_store_names(&self) -> Vec<&str> {
99 self.lifecycle_stores.keys().map(DriverName::as_str).collect()
100 }
101
102 fn check<F: ?Sized>(
105 kind: &'static str,
106 name: &DriverName,
107 factories: &BTreeMap<DriverName, Box<F>>,
108 config: &StorageConfig,
109 ) -> Result<(), ConfigError> {
110 if !factories.contains_key(name) {
111 return Err(ConfigError::UnknownDriver {
112 kind,
113 name: name.to_string(),
114 available: factories.keys().map(DriverName::to_string).collect(),
115 });
116 }
117 if config.driver_section(name).is_none() {
118 return Err(ConfigError::MissingDriverSection {
119 kind,
120 name: name.to_string(),
121 });
122 }
123 Ok(())
124 }
125
126 pub fn validate(&self, config: &StorageConfig) -> Result<(), ConfigError> {
131 Self::check("policy_store", &config.policy_store, &self.policy_stores, config)?;
132 Self::check("audit_sink", &config.audit_sink, &self.audit_sinks, config)?;
133 Self::check("session_store", &config.session_store, &self.session_stores, config)?;
134 Self::check(
135 "credential_store",
136 &config.credential_store,
137 &self.credential_stores,
138 config,
139 )?;
140 Self::check(
141 "rate_limit_counter",
142 &config.rate_limit_counter,
143 &self.rate_limit_counters,
144 config,
145 )?;
146 Self::check(
147 "lifecycle_store",
148 &config.lifecycle_store,
149 &self.lifecycle_stores,
150 config,
151 )?;
152 Ok(())
153 }
154
155 pub fn build_policy_store(&self, config: &StorageConfig) -> Result<Arc<dyn PolicyStore>, ConfigError> {
157 let name = &config.policy_store;
158 Self::check("policy_store", name, &self.policy_stores, config)?;
159 let section = config.driver_section(name).expect("subsection checked by `check`");
160 self.policy_stores[name]
161 .build(section)
162 .map_err(|source| ConfigError::Build {
163 kind: "policy_store",
164 name: name.to_string(),
165 source,
166 })
167 }
168
169 pub fn build_audit_sink(&self, config: &StorageConfig) -> Result<Arc<dyn AuditSink>, ConfigError> {
171 let name = &config.audit_sink;
172 Self::check("audit_sink", name, &self.audit_sinks, config)?;
173 let section = config.driver_section(name).expect("subsection checked by `check`");
174 self.audit_sinks[name]
175 .build(section)
176 .map_err(|source| ConfigError::Build {
177 kind: "audit_sink",
178 name: name.to_string(),
179 source,
180 })
181 }
182
183 pub fn build_session_store(&self, config: &StorageConfig) -> Result<Arc<dyn SessionStore>, ConfigError> {
185 let name = &config.session_store;
186 Self::check("session_store", name, &self.session_stores, config)?;
187 let section = config.driver_section(name).expect("subsection checked by `check`");
188 self.session_stores[name]
189 .build(section)
190 .map_err(|source| ConfigError::Build {
191 kind: "session_store",
192 name: name.to_string(),
193 source,
194 })
195 }
196
197 pub fn build_credential_store(&self, config: &StorageConfig) -> Result<Arc<dyn CredentialStore>, ConfigError> {
199 let name = &config.credential_store;
200 Self::check("credential_store", name, &self.credential_stores, config)?;
201 let section = config.driver_section(name).expect("subsection checked by `check`");
202 self.credential_stores[name]
203 .build(section)
204 .map_err(|source| ConfigError::Build {
205 kind: "credential_store",
206 name: name.to_string(),
207 source,
208 })
209 }
210
211 pub fn build_rate_limit_counter(&self, config: &StorageConfig) -> Result<Arc<dyn RateLimitCounter>, ConfigError> {
213 let name = &config.rate_limit_counter;
214 Self::check("rate_limit_counter", name, &self.rate_limit_counters, config)?;
215 let section = config.driver_section(name).expect("subsection checked by `check`");
216 self.rate_limit_counters[name]
217 .build(section)
218 .map_err(|source| ConfigError::Build {
219 kind: "rate_limit_counter",
220 name: name.to_string(),
221 source,
222 })
223 }
224
225 pub fn build_lifecycle_store(&self, config: &StorageConfig) -> Result<Arc<dyn LifecycleStore>, ConfigError> {
227 let name = &config.lifecycle_store;
228 Self::check("lifecycle_store", name, &self.lifecycle_stores, config)?;
229 let section = config.driver_section(name).expect("subsection checked by `check`");
230 self.lifecycle_stores[name]
231 .build(section)
232 .map_err(|source| ConfigError::Build {
233 kind: "lifecycle_store",
234 name: name.to_string(),
235 source,
236 })
237 }
238}
239
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent module's own imports (the factory
    // traits, `Arc`, `ConfigError`, ...). Only names the parent does not import
    // are listed explicitly.
    use super::*;
    use crate::{AgentId, PolicyDocument, StorageError};

    /// Policy store whose lookups always report `NotFound`.
    struct FakePolicyStore;

    #[async_trait::async_trait]
    impl PolicyStore for FakePolicyStore {
        async fn get_policy(&self, _agent_id: &AgentId) -> crate::Result<PolicyDocument> {
            Err(StorageError::NotFound("fake".into()))
        }
        async fn invalidate(&self, _agent_id: &AgentId) -> crate::Result<()> {
            Ok(())
        }
    }

    /// One factory type that is registered for every storage kind.
    struct FakeFactory;

    // Policy stores build successfully, so `build_policy_store` can succeed.
    impl PolicyStoreFactory for FakeFactory {
        fn build(&self, _config: &toml::Value) -> crate::Result<Arc<dyn PolicyStore>> {
            Ok(Arc::new(FakePolicyStore))
        }
    }

    // Every other kind gets a `build` that always fails. This lets the tests
    // exercise the `ConfigError::Build` wrapping path.
    macro_rules! failing_factory {
        ($trait:ident, $store:ident) => {
            impl $trait for FakeFactory {
                fn build(&self, _config: &toml::Value) -> crate::Result<Arc<dyn crate::$store>> {
                    Err(StorageError::Backend("unused in tests".into()))
                }
            }
        };
    }
    failing_factory!(AuditSinkFactory, AuditSink);
    failing_factory!(SessionStoreFactory, SessionStore);
    failing_factory!(CredentialStoreFactory, CredentialStore);
    failing_factory!(RateLimitCounterFactory, RateLimitCounter);
    failing_factory!(LifecycleStoreFactory, LifecycleStore);

    /// Builds a registry with a `"memory"` driver registered for every storage kind.
    fn registry_with_memory() -> Registry {
        let mut r = Registry::new();
        r.register_policy_store("memory", Box::new(FakeFactory));
        r.register_audit_sink("memory", Box::new(FakeFactory));
        r.register_session_store("memory", Box::new(FakeFactory));
        r.register_credential_store("memory", Box::new(FakeFactory));
        r.register_rate_limit_counter("memory", Box::new(FakeFactory));
        r.register_lifecycle_store("memory", Box::new(FakeFactory));
        r
    }

    /// Parses a TOML fixture into a `StorageConfig`, panicking on malformed input.
    fn parse(toml_str: &str) -> StorageConfig {
        toml::from_str(toml_str).expect("fixture parses")
    }

    /// Every kind selects `memory`, and a `[memory]` section is present.
    const VALID: &str = r#"
policy_store = "memory"
audit_sink = "memory"
session_store = "memory"
credential_store = "memory"
rate_limit_counter = "memory"
lifecycle_store = "memory"

[memory]
flush_every = 100
"#;

    /// `policy_store` selects `mongodb`, which is not registered (its section exists).
    const UNKNOWN_DRIVER: &str = r#"
policy_store = "mongodb"
audit_sink = "memory"
session_store = "memory"
credential_store = "memory"
rate_limit_counter = "memory"
lifecycle_store = "memory"

[memory]
flush_every = 100

[mongodb]
url = "mongodb://localhost"
"#;

    /// Every kind selects `memory`, but there is no `[memory]` section.
    const MISSING_SUBSECTION: &str = r#"
policy_store = "memory"
audit_sink = "memory"
session_store = "memory"
credential_store = "memory"
rate_limit_counter = "memory"
lifecycle_store = "memory"
"#;

    #[test]
    fn storage_section_flattens_known_keys_and_subsections() {
        let config = parse(VALID);
        assert_eq!(config.policy_store.as_str(), "memory");
        assert_eq!(config.lifecycle_store.as_str(), "memory");
        assert!(config.driver_section(&DriverName::new("memory")).is_some());
    }

    #[test]
    fn valid_combination_passes_validate_and_builds() {
        let registry = registry_with_memory();
        let config = parse(VALID);
        assert!(registry.validate(&config).is_ok());
        assert!(registry.build_policy_store(&config).is_ok());
    }

    #[test]
    fn unknown_driver_reports_kind_name_and_available() {
        let registry = registry_with_memory();
        let config = parse(UNKNOWN_DRIVER);
        let err = registry.validate(&config).unwrap_err();
        // Render before `match` consumes `err`, instead of validating a second time.
        let rendered = err.to_string();
        match err {
            ConfigError::UnknownDriver { kind, name, available } => {
                assert_eq!(kind, "policy_store");
                assert_eq!(name, "mongodb");
                assert_eq!(available, vec!["memory".to_string()]);
            }
            other => panic!("expected UnknownDriver, got {other:?}"),
        }
        assert!(rendered.contains("memory"), "error lists valid names: {rendered}");
    }

    #[test]
    fn build_reports_unknown_driver_for_its_kind() {
        let registry = registry_with_memory();
        let config = parse(UNKNOWN_DRIVER);
        match registry.build_policy_store(&config) {
            Err(ConfigError::UnknownDriver { kind, name, .. }) => {
                assert_eq!(kind, "policy_store");
                assert_eq!(name, "mongodb");
            }
            Err(other) => panic!("expected UnknownDriver, got {other:?}"),
            Ok(_) => panic!("expected UnknownDriver, got a policy store"),
        }
    }

    #[test]
    fn factory_failure_is_wrapped_with_kind_and_name() {
        let registry = registry_with_memory();
        let config = parse(VALID);
        match registry.build_audit_sink(&config) {
            Err(ConfigError::Build { kind, name, .. }) => {
                assert_eq!(kind, "audit_sink");
                assert_eq!(name, "memory");
            }
            Err(other) => panic!("expected Build, got {other:?}"),
            Ok(_) => panic!("expected Build, got an audit sink"),
        }
    }

    #[test]
    fn driver_name_listings_follow_registrations() {
        assert!(Registry::new().policy_store_names().is_empty());

        let mut registry = registry_with_memory();
        registry.register_policy_store("alpha", Box::new(FakeFactory));
        // Re-registering an existing name replaces it rather than duplicating it.
        registry.register_policy_store("memory", Box::new(FakeFactory));

        let names = registry.policy_store_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha") && names.contains(&"memory"));
        assert_eq!(registry.audit_sink_names(), vec!["memory"]);
        assert_eq!(registry.session_store_names(), vec!["memory"]);
        assert_eq!(registry.credential_store_names(), vec!["memory"]);
        assert_eq!(registry.rate_limit_counter_names(), vec!["memory"]);
        assert_eq!(registry.lifecycle_store_names(), vec!["memory"]);
    }

    #[test]
    fn missing_per_driver_subsection_is_rejected() {
        let registry = registry_with_memory();
        let config = parse(MISSING_SUBSECTION);
        let err = registry.validate(&config).unwrap_err();
        assert!(
            matches!(err, ConfigError::MissingDriverSection { ref name, .. } if name == "memory"),
            "expected MissingDriverSection, got {err:?}"
        );
    }

    #[test]
    fn builtin_registry_accepts_known_oss_driver_names() {
        let mut registry = Registry::new();
        crate::builtin::register_builtin_drivers(&mut registry);
        let config = parse(
            r#"
policy_store = "redis"
audit_sink = "postgres"
session_store = "redis"
credential_store = "postgres"
rate_limit_counter = "redis"
lifecycle_store = "postgres"

[redis]
url = "redis://localhost:6379"

[postgres]
url = "postgresql://localhost:5432/assembly"
"#,
        );
        assert!(registry.validate(&config).is_ok());
        // NOTE(review): presumably fails because no live backend is reachable
        // from unit tests — confirm against the builtin redis driver.
        assert!(registry.build_policy_store(&config).is_err());
    }
}