Skip to main content

aa_storage/
registry.rs

1//! [`Registry`] — maps driver names to backend factories and resolves a
2//! [`StorageConfig`] into concrete stores.
3
4use std::collections::BTreeMap;
5use std::sync::Arc;
6
7use crate::factory::{
8    AuditSinkFactory, CredentialStoreFactory, LifecycleStoreFactory, PolicyStoreFactory, RateLimitCounterFactory,
9    SessionStoreFactory,
10};
11use crate::{
12    AuditSink, ConfigError, CredentialStore, DriverName, LifecycleStore, PolicyStore, RateLimitCounter, SessionStore,
13    StorageConfig,
14};
15
16/// Registry of storage-driver factories, keyed by [`DriverName`] per kind.
17///
18/// Each driver crate calls the `register_*` methods (typically from a single
19/// `register(&mut Registry)` entry point) to make its backends selectable by
20/// name. The loader then uses [`validate`](Registry::validate) to check a
21/// [`StorageConfig`] and the `build_*` methods to instantiate the chosen stores.
22#[derive(Default)]
23pub struct Registry {
24    policy_stores: BTreeMap<DriverName, Box<dyn PolicyStoreFactory>>,
25    audit_sinks: BTreeMap<DriverName, Box<dyn AuditSinkFactory>>,
26    session_stores: BTreeMap<DriverName, Box<dyn SessionStoreFactory>>,
27    credential_stores: BTreeMap<DriverName, Box<dyn CredentialStoreFactory>>,
28    rate_limit_counters: BTreeMap<DriverName, Box<dyn RateLimitCounterFactory>>,
29    lifecycle_stores: BTreeMap<DriverName, Box<dyn LifecycleStoreFactory>>,
30}
31
32impl Registry {
33    /// Create an empty registry with no drivers registered.
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// Register a policy-store driver under `name`.
39    pub fn register_policy_store(&mut self, name: impl Into<DriverName>, factory: Box<dyn PolicyStoreFactory>) {
40        self.policy_stores.insert(name.into(), factory);
41    }
42
43    /// Register an audit-sink driver under `name`.
44    pub fn register_audit_sink(&mut self, name: impl Into<DriverName>, factory: Box<dyn AuditSinkFactory>) {
45        self.audit_sinks.insert(name.into(), factory);
46    }
47
48    /// Register a session-store driver under `name`.
49    pub fn register_session_store(&mut self, name: impl Into<DriverName>, factory: Box<dyn SessionStoreFactory>) {
50        self.session_stores.insert(name.into(), factory);
51    }
52
53    /// Register a credential-store driver under `name`.
54    pub fn register_credential_store(&mut self, name: impl Into<DriverName>, factory: Box<dyn CredentialStoreFactory>) {
55        self.credential_stores.insert(name.into(), factory);
56    }
57
58    /// Register a rate-limit-counter driver under `name`.
59    pub fn register_rate_limit_counter(
60        &mut self,
61        name: impl Into<DriverName>,
62        factory: Box<dyn RateLimitCounterFactory>,
63    ) {
64        self.rate_limit_counters.insert(name.into(), factory);
65    }
66
67    /// Register a lifecycle-store driver under `name`.
68    pub fn register_lifecycle_store(&mut self, name: impl Into<DriverName>, factory: Box<dyn LifecycleStoreFactory>) {
69        self.lifecycle_stores.insert(name.into(), factory);
70    }
71
72    /// Names of all registered policy-store drivers, sorted.
73    pub fn policy_store_names(&self) -> Vec<&str> {
74        self.policy_stores.keys().map(DriverName::as_str).collect()
75    }
76
77    /// Names of all registered audit-sink drivers, sorted.
78    pub fn audit_sink_names(&self) -> Vec<&str> {
79        self.audit_sinks.keys().map(DriverName::as_str).collect()
80    }
81
82    /// Names of all registered session-store drivers, sorted.
83    pub fn session_store_names(&self) -> Vec<&str> {
84        self.session_stores.keys().map(DriverName::as_str).collect()
85    }
86
87    /// Names of all registered credential-store drivers, sorted.
88    pub fn credential_store_names(&self) -> Vec<&str> {
89        self.credential_stores.keys().map(DriverName::as_str).collect()
90    }
91
92    /// Names of all registered rate-limit-counter drivers, sorted.
93    pub fn rate_limit_counter_names(&self) -> Vec<&str> {
94        self.rate_limit_counters.keys().map(DriverName::as_str).collect()
95    }
96
97    /// Names of all registered lifecycle-store drivers, sorted.
98    pub fn lifecycle_store_names(&self) -> Vec<&str> {
99        self.lifecycle_stores.keys().map(DriverName::as_str).collect()
100    }
101
102    /// Check that `name` is registered in `factories` and that `config` carries
103    /// its `[storage.<name>]` subsection.
104    fn check<F: ?Sized>(
105        kind: &'static str,
106        name: &DriverName,
107        factories: &BTreeMap<DriverName, Box<F>>,
108        config: &StorageConfig,
109    ) -> Result<(), ConfigError> {
110        if !factories.contains_key(name) {
111            return Err(ConfigError::UnknownDriver {
112                kind,
113                name: name.to_string(),
114                available: factories.keys().map(DriverName::to_string).collect(),
115            });
116        }
117        if config.driver_section(name).is_none() {
118            return Err(ConfigError::MissingDriverSection {
119                kind,
120                name: name.to_string(),
121            });
122        }
123        Ok(())
124    }
125
126    /// Check every driver named in `config` is registered and has a subsection.
127    ///
128    /// Returns the first [`ConfigError`] encountered. This is the entry point
129    /// behind `aasm config validate`.
130    pub fn validate(&self, config: &StorageConfig) -> Result<(), ConfigError> {
131        Self::check("policy_store", &config.policy_store, &self.policy_stores, config)?;
132        Self::check("audit_sink", &config.audit_sink, &self.audit_sinks, config)?;
133        Self::check("session_store", &config.session_store, &self.session_stores, config)?;
134        Self::check(
135            "credential_store",
136            &config.credential_store,
137            &self.credential_stores,
138            config,
139        )?;
140        Self::check(
141            "rate_limit_counter",
142            &config.rate_limit_counter,
143            &self.rate_limit_counters,
144            config,
145        )?;
146        Self::check(
147            "lifecycle_store",
148            &config.lifecycle_store,
149            &self.lifecycle_stores,
150            config,
151        )?;
152        Ok(())
153    }
154
155    /// Build the configured [`PolicyStore`].
156    pub fn build_policy_store(&self, config: &StorageConfig) -> Result<Arc<dyn PolicyStore>, ConfigError> {
157        let name = &config.policy_store;
158        Self::check("policy_store", name, &self.policy_stores, config)?;
159        let section = config.driver_section(name).expect("subsection checked by `check`");
160        self.policy_stores[name]
161            .build(section)
162            .map_err(|source| ConfigError::Build {
163                kind: "policy_store",
164                name: name.to_string(),
165                source,
166            })
167    }
168
169    /// Build the configured [`AuditSink`].
170    pub fn build_audit_sink(&self, config: &StorageConfig) -> Result<Arc<dyn AuditSink>, ConfigError> {
171        let name = &config.audit_sink;
172        Self::check("audit_sink", name, &self.audit_sinks, config)?;
173        let section = config.driver_section(name).expect("subsection checked by `check`");
174        self.audit_sinks[name]
175            .build(section)
176            .map_err(|source| ConfigError::Build {
177                kind: "audit_sink",
178                name: name.to_string(),
179                source,
180            })
181    }
182
183    /// Build the configured [`SessionStore`].
184    pub fn build_session_store(&self, config: &StorageConfig) -> Result<Arc<dyn SessionStore>, ConfigError> {
185        let name = &config.session_store;
186        Self::check("session_store", name, &self.session_stores, config)?;
187        let section = config.driver_section(name).expect("subsection checked by `check`");
188        self.session_stores[name]
189            .build(section)
190            .map_err(|source| ConfigError::Build {
191                kind: "session_store",
192                name: name.to_string(),
193                source,
194            })
195    }
196
197    /// Build the configured [`CredentialStore`].
198    pub fn build_credential_store(&self, config: &StorageConfig) -> Result<Arc<dyn CredentialStore>, ConfigError> {
199        let name = &config.credential_store;
200        Self::check("credential_store", name, &self.credential_stores, config)?;
201        let section = config.driver_section(name).expect("subsection checked by `check`");
202        self.credential_stores[name]
203            .build(section)
204            .map_err(|source| ConfigError::Build {
205                kind: "credential_store",
206                name: name.to_string(),
207                source,
208            })
209    }
210
211    /// Build the configured [`RateLimitCounter`].
212    pub fn build_rate_limit_counter(&self, config: &StorageConfig) -> Result<Arc<dyn RateLimitCounter>, ConfigError> {
213        let name = &config.rate_limit_counter;
214        Self::check("rate_limit_counter", name, &self.rate_limit_counters, config)?;
215        let section = config.driver_section(name).expect("subsection checked by `check`");
216        self.rate_limit_counters[name]
217            .build(section)
218            .map_err(|source| ConfigError::Build {
219                kind: "rate_limit_counter",
220                name: name.to_string(),
221                source,
222            })
223    }
224
225    /// Build the configured [`LifecycleStore`].
226    pub fn build_lifecycle_store(&self, config: &StorageConfig) -> Result<Arc<dyn LifecycleStore>, ConfigError> {
227        let name = &config.lifecycle_store;
228        Self::check("lifecycle_store", name, &self.lifecycle_stores, config)?;
229        let section = config.driver_section(name).expect("subsection checked by `check`");
230        self.lifecycle_stores[name]
231            .build(section)
232            .map_err(|source| ConfigError::Build {
233                kind: "lifecycle_store",
234                name: name.to_string(),
235                source,
236            })
237    }
238}
239
#[cfg(test)]
mod tests {
    // `use super::*` already brings in every factory trait imported by the
    // parent module (`PolicyStoreFactory` included), so they are not
    // re-imported here.
    use super::*;
    use crate::{AgentId, PolicyDocument, StorageError};

    /// A policy store that exists only so a factory has something to return.
    struct FakePolicyStore;

    #[async_trait::async_trait]
    impl PolicyStore for FakePolicyStore {
        async fn get_policy(&self, _agent_id: &AgentId) -> crate::Result<PolicyDocument> {
            Err(StorageError::NotFound("fake".into()))
        }
        async fn invalidate(&self, _agent_id: &AgentId) -> crate::Result<()> {
            Ok(())
        }
    }

    /// One factory registerable for every kind. Only the policy-store build is
    /// exercised; the other kinds only need to be *present* for `validate`.
    struct FakeFactory;

    impl PolicyStoreFactory for FakeFactory {
        fn build(&self, _config: &toml::Value) -> crate::Result<Arc<dyn PolicyStore>> {
            Ok(Arc::new(FakePolicyStore))
        }
    }

    /// Implement factory trait `$trait` for `FakeFactory` with a `build` that
    /// always fails. These kinds are registered but never built in this module.
    macro_rules! unused_factory {
        ($trait:ident, $store:ident) => {
            impl $trait for FakeFactory {
                fn build(&self, _config: &toml::Value) -> crate::Result<Arc<dyn crate::$store>> {
                    Err(StorageError::Backend("unused in tests".into()))
                }
            }
        };
    }
    unused_factory!(AuditSinkFactory, AuditSink);
    unused_factory!(SessionStoreFactory, SessionStore);
    unused_factory!(CredentialStoreFactory, CredentialStore);
    unused_factory!(RateLimitCounterFactory, RateLimitCounter);
    unused_factory!(LifecycleStoreFactory, LifecycleStore);

    /// Registry with the `"memory"` driver registered for every kind.
    fn registry_with_memory() -> Registry {
        let mut r = Registry::new();
        r.register_policy_store("memory", Box::new(FakeFactory));
        r.register_audit_sink("memory", Box::new(FakeFactory));
        r.register_session_store("memory", Box::new(FakeFactory));
        r.register_credential_store("memory", Box::new(FakeFactory));
        r.register_rate_limit_counter("memory", Box::new(FakeFactory));
        r.register_lifecycle_store("memory", Box::new(FakeFactory));
        r
    }

    /// Parse a `[storage]`-section fixture, panicking if the fixture is malformed.
    fn parse(toml_str: &str) -> StorageConfig {
        toml::from_str(toml_str).expect("fixture parses")
    }

    const VALID: &str = r#"
policy_store       = "memory"
audit_sink         = "memory"
session_store      = "memory"
credential_store   = "memory"
rate_limit_counter = "memory"
lifecycle_store    = "memory"

[memory]
flush_every = 100
"#;

    const UNKNOWN_DRIVER: &str = r#"
policy_store       = "mongodb"
audit_sink         = "memory"
session_store      = "memory"
credential_store   = "memory"
rate_limit_counter = "memory"
lifecycle_store    = "memory"

[memory]
flush_every = 100

[mongodb]
url = "mongodb://localhost"
"#;

    const MISSING_SUBSECTION: &str = r#"
policy_store       = "memory"
audit_sink         = "memory"
session_store      = "memory"
credential_store   = "memory"
rate_limit_counter = "memory"
lifecycle_store    = "memory"
"#;

    #[test]
    fn storage_section_flattens_known_keys_and_subsections() {
        let config = parse(VALID);
        assert_eq!(config.policy_store.as_str(), "memory");
        assert_eq!(config.lifecycle_store.as_str(), "memory");
        // The `[memory]` table is captured as a per-driver subsection.
        assert!(config.driver_section(&DriverName::new("memory")).is_some());
    }

    #[test]
    fn valid_combination_passes_validate_and_builds() {
        let registry = registry_with_memory();
        let config = parse(VALID);
        // `expect` (rather than `assert!(… .is_ok())`) shows the `ConfigError`
        // if either step ever regresses.
        registry.validate(&config).expect("valid config passes validate");
        let _store = registry
            .build_policy_store(&config)
            .expect("memory policy store builds");
    }

    #[test]
    fn unknown_driver_reports_kind_name_and_available() {
        let registry = registry_with_memory();
        let config = parse(UNKNOWN_DRIVER);
        let err = registry.validate(&config).unwrap_err();

        // The valid names appear in the rendered error message. Render before
        // `match` moves `err`, instead of calling `validate` a second time.
        let rendered = err.to_string();
        assert!(rendered.contains("memory"), "error lists valid names: {rendered}");

        match err {
            ConfigError::UnknownDriver { kind, name, available } => {
                assert_eq!(kind, "policy_store");
                assert_eq!(name, "mongodb");
                assert_eq!(available, vec!["memory".to_string()]);
            }
            other => panic!("expected UnknownDriver, got {other:?}"),
        }
    }

    #[test]
    fn missing_per_driver_subsection_is_rejected() {
        let registry = registry_with_memory();
        let config = parse(MISSING_SUBSECTION);
        let err = registry.validate(&config).unwrap_err();
        // `validate` checks the policy store first, so that is the kind reported.
        assert!(
            matches!(
                err,
                ConfigError::MissingDriverSection { kind: "policy_store", ref name } if name == "memory"
            ),
            "expected MissingDriverSection, got {err:?}"
        );
    }

    #[test]
    fn builtin_registry_accepts_known_oss_driver_names() {
        let mut registry = Registry::new();
        crate::builtin::register_builtin_drivers(&mut registry);
        let config = parse(
            r#"
policy_store       = "redis"
audit_sink         = "postgres"
session_store      = "redis"
credential_store   = "postgres"
rate_limit_counter = "redis"
lifecycle_store    = "postgres"

[redis]
url = "redis://localhost:6379"

[postgres]
url = "postgresql://localhost:5432/assembly"
"#,
        );
        registry
            .validate(&config)
            .expect("built-in driver names validate");
        // But building a not-yet-implemented backend surfaces an error.
        assert!(registry.build_policy_store(&config).is_err());
    }
}
407}