Skip to main content

secrets_rs/
source.rs

1use std::collections::HashMap;
2
3use crate::{
4    EnvSource,
5    error::{SourceError, SourceRegisterError},
6    urn::is_valid_source_id,
7};
8
9/// A source from which secret values can be retrieved.
10///
11/// Implementations are responsible for looking up a secret by `name` and
12/// returning its raw bytes. Case sensitivity of `name` is source-specific.
13pub trait Source: Send + Sync {
14    /// Retrieve the raw bytes for the secret identified by `name`.
15    fn get(&self, name: &str) -> Result<Vec<u8>, SourceError>;
16}
17
18/// A registry that maps source IDs to their [`Source`] implementations.
19///
20/// [`EnvSource`] is registered under `"env"` by default. Additional sources
21/// can be added with [`register`](SourceRegistry::register); registering an
22/// already-used ID replaces the previous source.
23///
24/// Pass a registry to [`Secret::bind`](crate::Secret::bind) or
25/// [`bind_all`](crate::bind_all) to resolve secrets.
26pub struct SourceRegistry {
27    sources: HashMap<String, Box<dyn Source>>,
28}
29
30impl SourceRegistry {
31    /// Creates a registry with [`EnvSource`] pre-registered under `"env"`.
32    pub fn new() -> Self {
33        let mut registry = Self {
34            sources: HashMap::new(),
35        };
36        registry
37            .register("env", EnvSource)
38            .expect("\"env\" is a valid source_id");
39        registry
40    }
41
42    /// Registers a source under the given `id`.
43    ///
44    /// The `id` must match the `source_id` component of the secret URN and
45    /// must consist only of characters valid in a URN NSS: ASCII letters,
46    /// digits, and `-._~!$&'()*+,;=@/`. Returns
47    /// [`SourceRegisterError::InvalidSourceId`] if the id is empty or contains
48    /// invalid characters.
49    ///
50    /// If `id` is already registered, the previous source is replaced.
51    pub fn register(
52        &mut self,
53        id: impl Into<String>,
54        source: impl Source + 'static,
55    ) -> Result<(), SourceRegisterError> {
56        let id = id.into();
57        if !is_valid_source_id(&id) {
58            return Err(SourceRegisterError::InvalidSourceId(id));
59        }
60        self.sources.insert(id, Box::new(source));
61        Ok(())
62    }
63
64    /// Returns the source registered under `source_id`, if any.
65    pub fn get(&self, source_id: &str) -> Option<&dyn Source> {
66        self.sources.get(source_id).map(|s| s.as_ref())
67    }
68}
69
70impl Default for SourceRegistry {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_pre_registers_env_source() {
        let registry = SourceRegistry::new();
        assert!(
            registry.get("env").is_some(),
            "expected \"env\" to be registered by default"
        );
    }

    #[test]
    fn new_does_not_register_file_by_default() {
        let registry = SourceRegistry::new();
        assert!(registry.get("file").is_none());
    }

    #[test]
    fn env_source_resolves_without_explicit_registration() {
        const VAR: &str = "REGISTRY_DEFAULT_ENV_TEST";
        let registry = SourceRegistry::new();

        // SAFETY: `set_var`/`remove_var` are only sound while no other thread
        // reads or writes the process environment. `VAR` is unique to this
        // test and no other test in this module touches the environment.
        // NOTE(review): tests elsewhere in the crate may still access the
        // environment concurrently — confirm, or serialize env-mutating tests.
        unsafe { std::env::set_var(VAR, "works") };
        // Capture the lookup without unwrapping so the variable is always
        // removed below, even if "env" turned out to be unregistered.
        let result = registry.get("env").map(|source| source.get(VAR));
        // SAFETY: same reasoning as for the `set_var` call above.
        unsafe { std::env::remove_var(VAR) };

        let result = result.expect("expected \"env\" to be registered by default");
        assert_eq!(result.unwrap(), b"works");
    }

    #[test]
    fn register_replaces_existing_id() {
        struct ConstSource(&'static [u8]);
        impl Source for ConstSource {
            fn get(&self, _name: &str) -> Result<Vec<u8>, SourceError> {
                Ok(self.0.to_vec())
            }
        }

        let mut registry = SourceRegistry::new();
        registry.register("env", ConstSource(b"replaced")).unwrap();
        let result = registry.get("env").unwrap().get("anything").unwrap();
        assert_eq!(result, b"replaced");
    }

    #[test]
    fn register_accepts_valid_ids() {
        let mut r = SourceRegistry::new();
        // Letters, digits, and a sample of the allowed punctuation.
        for id in [
            "file",
            "vault-sm",
            "aws.secrets",
            "gcp_sm",
            "my~source",
            "ns/sub",
            "s3",
        ] {
            assert!(r.register(id, EnvSource).is_ok(), "rejected valid id: {id}");
            assert!(r.get(id).is_some(), "accepted id not retrievable: {id}");
        }
    }

    #[test]
    fn register_rejects_empty_id() {
        let mut r = SourceRegistry::new();
        assert_eq!(
            r.register("", EnvSource).unwrap_err(),
            SourceRegisterError::InvalidSourceId(String::new()),
        );
        assert!(r.get("").is_none());
    }

    #[test]
    fn register_rejects_id_with_colon() {
        let mut r = SourceRegistry::new();
        let err = r.register("bad:id", EnvSource).unwrap_err();
        assert_eq!(
            err,
            SourceRegisterError::InvalidSourceId("bad:id".to_owned())
        );
        assert!(r.get("bad:id").is_none());
    }

    #[test]
    fn register_rejects_id_with_space() {
        let mut r = SourceRegistry::new();
        let err = r.register("bad id", EnvSource).unwrap_err();
        assert_eq!(
            err,
            SourceRegisterError::InvalidSourceId("bad id".to_owned())
        );
        assert!(r.get("bad id").is_none());
    }

    #[test]
    fn register_rejects_id_with_non_ascii() {
        let mut r = SourceRegistry::new();
        let err = r.register("café", EnvSource).unwrap_err();
        assert_eq!(err, SourceRegisterError::InvalidSourceId("café".to_owned()));
        assert!(r.get("café").is_none());
    }
}