architect_api/utils/
secrets.rs

1//! Types for working with the secret store
2
3use schemars::{
4    r#gen::SchemaGenerator,
5    schema::{InstanceType, Schema, SchemaObject},
6    JsonSchema,
7};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use std::{fmt::Display, str::FromStr};
10use zeroize::{Zeroize, Zeroizing};
11
12/// A type that is either a reference to a secret, serialized as
13/// a URI string like secrets://<key>, or a plain literal.
14#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
15pub enum MaybeSecret<T: Zeroize> {
16    Secret(String),
17    Plain(Zeroizing<T>),
18}
19
20impl<T: Zeroize> MaybeSecret<T> {
21    pub fn key(&self) -> Option<String> {
22        match self {
23            MaybeSecret::Secret(s) => Some(s.clone()),
24            MaybeSecret::Plain(_) => None,
25        }
26    }
27
28    pub fn secret<S: AsRef<str>>(key: S) -> Self {
29        MaybeSecret::Secret(key.as_ref().to_string())
30    }
31
32    pub fn plain(t: T) -> Self {
33        MaybeSecret::Plain(Zeroizing::new(t))
34    }
35}
36
37impl<T: Clone + Zeroize> MaybeSecret<T> {
38    pub fn to_plain(&self) -> Option<Zeroizing<T>> {
39        match self {
40            MaybeSecret::Secret(_) => None,
41            MaybeSecret::Plain(t) => Some(t.clone()),
42        }
43    }
44}
45
/// Implements `FromStr` for a type by parsing the input as JSON.
///
/// Most useful choices of `T` for `MaybeSecret<T>` need a `FromStr`
/// implementation. When a type has none of its own, this macro provides a
/// reasonable-ish one that delegates to `serde_json::from_str`, using
/// `serde_json::Error` as the error type.
#[macro_export]
macro_rules! from_str_json {
    ($target:ty) => {
        impl std::str::FromStr for $target {
            type Err = serde_json::Error;

            fn from_str(input: &str) -> Result<Self, Self::Err> {
                serde_json::from_str::<Self>(input)
            }
        }
    };
}
61
62impl<T: Serialize + Zeroize + JsonSchema> JsonSchema for MaybeSecret<T> {
63    fn schema_name() -> String {
64        // Exclude the module path to make the name in generated schemas clearer.
65        "MaybeSecret".to_owned()
66    }
67
68    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
69        SchemaObject {
70            instance_type: Some(InstanceType::String.into()),
71            ..Default::default()
72        }
73        .into()
74    }
75
76    fn is_referenceable() -> bool {
77        true
78    }
79}
80
81impl<T: Display + Serialize + Zeroize> Display for MaybeSecret<T> {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        match self {
84            MaybeSecret::Secret(s) => write!(f, "secrets://{}", s),
85            MaybeSecret::Plain(s) => {
86                write!(f, "{}", serde_json::to_string(&**s).map_err(|_| std::fmt::Error)?)
87            }
88        }
89    }
90}
91
92impl<T: FromStr + Zeroize> FromStr for MaybeSecret<T> {
93    type Err = <T as FromStr>::Err;
94
95    fn from_str(s: &str) -> Result<Self, Self::Err> {
96        if let Some(key) = s.strip_prefix("secrets://") {
97            Ok(MaybeSecret::Secret(key.to_string()))
98        } else {
99            Ok(MaybeSecret::Plain(Zeroizing::new(s.parse()?)))
100        }
101    }
102}
103
104impl<T: Serialize + Zeroize> Serialize for MaybeSecret<T> {
105    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
106        match self {
107            MaybeSecret::Secret(s) => ser.serialize_str(&format!("secrets://{}", s)),
108            MaybeSecret::Plain(t) => (*t).serialize(ser),
109        }
110    }
111}
112
113impl<'de, T: DeserializeOwned + FromStr + Zeroize> Deserialize<'de> for MaybeSecret<T> {
114    fn deserialize<D>(de: D) -> Result<Self, D::Error>
115    where
116        D: serde::Deserializer<'de>,
117    {
118        #[derive(Serialize, Deserialize)]
119        #[serde(untagged)]
120        enum Format<T> {
121            SecretOrString(String),
122            Plain(T),
123        }
124        match Format::<T>::deserialize(de)? {
125            Format::SecretOrString(s) => {
126                if let Some(key) = s.strip_prefix("secrets://") {
127                    Ok(MaybeSecret::Secret(key.to_string()))
128                } else {
129                    // using FromStr here is hacky but it works for the
130                    // important cases of T = String, &str, etc... at
131                    // the cost of requiring FromStr from structs
132                    //
133                    // if you're looking for some dumb FromStr to use
134                    // try the FromStrJson macro in derive
135                    //
136                    // maybe there's some trick leveraging auto(de)ref
137                    // specialization [https://lukaskalbertodt.github.io/2019/12/05/generalized-autoref-based-specialization.html]
138                    // that could help here?
139                    Ok(MaybeSecret::Plain(Zeroizing::new(
140                        T::from_str(&s)
141                            .map_err(|_| serde::de::Error::custom("could not FromStr"))?,
142                    )))
143                }
144            }
145            Format::Plain(t) => Ok(MaybeSecret::Plain(Zeroizing::new(t))),
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use zeroize::ZeroizeOnDrop;

    /// Serializes `original` to JSON, reads it back, and asserts equality.
    fn assert_json_roundtrip<T>(original: &MaybeSecret<T>)
    where
        T: Serialize + DeserializeOwned + FromStr + Zeroize + PartialEq + std::fmt::Debug,
    {
        let encoded = serde_json::to_string(original).unwrap();
        let decoded: MaybeSecret<T> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, &decoded);
    }

    /// Serializes `original` to YAML, reads it back, and asserts equality.
    fn assert_yaml_roundtrip<T>(original: &MaybeSecret<T>)
    where
        T: Serialize + DeserializeOwned + FromStr + Zeroize + PartialEq + std::fmt::Debug,
    {
        let encoded = serde_yaml::to_string(original).unwrap();
        let decoded: MaybeSecret<T> = serde_yaml::from_str(&encoded).unwrap();
        assert_eq!(original, &decoded);
    }

    #[test]
    fn test_from_str() {
        let secret = "secrets://foo".parse::<MaybeSecret<u64>>().unwrap();
        assert_eq!(secret, MaybeSecret::secret("foo"));

        let number = "42".parse::<MaybeSecret<u64>>().unwrap();
        assert_eq!(number, MaybeSecret::plain(42u64));

        let text = "asdf".parse::<MaybeSecret<String>>().unwrap();
        assert_eq!(text, MaybeSecret::plain("asdf".to_string()));
    }

    #[test]
    fn test_serde() {
        assert_json_roundtrip(&MaybeSecret::<u64>::secret("asdf"));
        assert_json_roundtrip(&MaybeSecret::<u64>::plain(42));
        assert_json_roundtrip(&MaybeSecret::<String>::plain("hahaha".to_string()));
    }

    #[test]
    fn test_serde_yaml() {
        assert_yaml_roundtrip(&MaybeSecret::<u64>::secret("asdf"));
        assert_yaml_roundtrip(&MaybeSecret::<u64>::plain(42));
        assert_yaml_roundtrip(&MaybeSecret::<String>::plain("hahaha".to_string()));
    }

    #[test]
    fn test_serde_complex() {
        #[derive(
            Debug, PartialEq, Eq, Serialize, Deserialize, Zeroize, ZeroizeOnDrop,
        )]
        struct Foo {
            bar: u64,
            baz: String,
        }
        from_str_json!(Foo);

        // A struct-valued literal must survive both formats.
        let literal: MaybeSecret<Foo> =
            MaybeSecret::plain(Foo { bar: 42, baz: "asdf".to_string() });
        assert_json_roundtrip(&literal);
        assert_yaml_roundtrip(&literal);

        // So must a secret reference typed at the same struct.
        let reference: MaybeSecret<Foo> = MaybeSecret::secret("my_secret_key");
        assert_json_roundtrip(&reference);
        assert_yaml_roundtrip(&reference);
    }
}