architect_api/utils/
secrets.rs

1//! Types for working with the secret store
2
3#[cfg(feature = "netidx")]
4use bytes::{Buf, BufMut};
5#[cfg(feature = "netidx")]
6use netidx::pack::{Pack, PackError};
7use schemars::{
8    gen::SchemaGenerator,
9    schema::{InstanceType, Schema, SchemaObject},
10    JsonSchema,
11};
12use serde::{de::DeserializeOwned, Deserialize, Serialize};
13use std::{fmt::Display, str::FromStr};
14use zeroize::{Zeroize, Zeroizing};
15
16/// A type that is either a reference to a secret, serialized as
17/// a URI string like secrets://<key>, or a plain literal.
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub enum MaybeSecret<T: Zeroize> {
20    Secret(String),
21    Plain(Zeroizing<T>),
22}
23
24impl<T: Zeroize> MaybeSecret<T> {
25    pub fn key(&self) -> Option<String> {
26        match self {
27            MaybeSecret::Secret(s) => Some(s.clone()),
28            MaybeSecret::Plain(_) => None,
29        }
30    }
31
32    pub fn secret<S: AsRef<str>>(key: S) -> Self {
33        MaybeSecret::Secret(key.as_ref().to_string())
34    }
35
36    pub fn plain(t: T) -> Self {
37        MaybeSecret::Plain(Zeroizing::new(t))
38    }
39}
40
41impl<T: Clone + Zeroize> MaybeSecret<T> {
42    pub fn to_plain(&self) -> Option<Zeroizing<T>> {
43        match self {
44            MaybeSecret::Secret(_) => None,
45            MaybeSecret::Plain(t) => Some(t.clone()),
46        }
47    }
48}
49
/// Implements [`std::str::FromStr`] for a type by parsing the input as JSON.
///
/// Most useful choices of `T` for `MaybeSecret` need a `FromStr`
/// implementation.  When a type has none of its own, this macro supplies a
/// reasonable-ish one that delegates to `serde_json::from_str`.
#[macro_export]
macro_rules! from_str_json {
    ($t:ty) => {
        impl std::str::FromStr for $t {
            type Err = serde_json::Error;

            fn from_str(input: &str) -> Result<Self, Self::Err> {
                serde_json::from_str::<Self>(input)
            }
        }
    };
}
65
66impl<T: Serialize + Zeroize + JsonSchema> JsonSchema for MaybeSecret<T> {
67    fn schema_name() -> String {
68        // Exclude the module path to make the name in generated schemas clearer.
69        "MaybeSecret".to_owned()
70    }
71
72    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
73        SchemaObject {
74            instance_type: Some(InstanceType::String.into()),
75            ..Default::default()
76        }
77        .into()
78    }
79
80    fn is_referenceable() -> bool {
81        true
82    }
83}
84
85impl<T: Display + Serialize + Zeroize> Display for MaybeSecret<T> {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match &*self {
88            MaybeSecret::Secret(s) => write!(f, "secrets://{}", s),
89            MaybeSecret::Plain(s) => {
90                write!(f, "{}", serde_json::to_string(&**s).map_err(|_| std::fmt::Error)?)
91            }
92        }
93    }
94}
95
96impl<T: FromStr + Zeroize> FromStr for MaybeSecret<T> {
97    type Err = <T as FromStr>::Err;
98
99    fn from_str(s: &str) -> Result<Self, Self::Err> {
100        if s.starts_with("secrets://") {
101            Ok(MaybeSecret::Secret(s[10..].to_string()))
102        } else {
103            Ok(MaybeSecret::Plain(Zeroizing::new(s.parse()?)))
104        }
105    }
106}
107
108impl<T: Serialize + Zeroize> Serialize for MaybeSecret<T> {
109    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
110        match self {
111            MaybeSecret::Secret(s) => ser.serialize_str(&format!("secrets://{}", s)),
112            MaybeSecret::Plain(t) => (&*t).serialize(ser),
113        }
114    }
115}
116
117impl<'de, T: DeserializeOwned + FromStr + Zeroize> Deserialize<'de> for MaybeSecret<T> {
118    fn deserialize<D>(de: D) -> Result<Self, D::Error>
119    where
120        D: serde::Deserializer<'de>,
121    {
122        #[derive(Serialize, Deserialize)]
123        #[serde(untagged)]
124        enum Format<T> {
125            SecretOrString(String),
126            Plain(T),
127        }
128        match Format::<T>::deserialize(de)? {
129            Format::SecretOrString(s) => {
130                if s.starts_with("secrets://") {
131                    Ok(MaybeSecret::Secret(s[10..].to_string()))
132                } else {
133                    // using FromStr here is hacky but it works for the
134                    // important cases of T = String, &str, etc... at
135                    // the cost of requiring FromStr from structs
136                    //
137                    // if you're looking for some dumb FromStr to use
138                    // try the FromStrJson macro in derive
139                    //
140                    // maybe there's some trick leveraging auto(de)ref
141                    // specialization [https://lukaskalbertodt.github.io/2019/12/05/generalized-autoref-based-specialization.html]
142                    // that could help here?
143                    Ok(MaybeSecret::Plain(Zeroizing::new(
144                        T::from_str(&s)
145                            .map_err(|_| serde::de::Error::custom("could not FromStr"))?,
146                    )))
147                }
148            }
149            Format::Plain(t) => Ok(MaybeSecret::Plain(Zeroizing::new(t))),
150        }
151    }
152}
153
#[cfg(feature = "netidx")]
impl<T: Zeroize + Pack> Pack for MaybeSecret<T> {
    /// Encoded size: one tag byte plus the encoded length of the payload.
    fn encoded_len(&self) -> usize {
        const TAG_LEN: usize = 1;
        let clen = match self {
            MaybeSecret::Secret(s) => s.encoded_len(),
            MaybeSecret::Plain(t) => t.encoded_len(),
        };
        TAG_LEN + clen
    }

    /// Writes a one-byte tag (`0` for a secret key, `1` for a plain value)
    /// followed by the payload's own encoding.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the payload's `encode`.
    fn encode(&self, buf: &mut impl BufMut) -> Result<(), PackError> {
        match self {
            MaybeSecret::Secret(s) => {
                buf.put_u8(0);
                s.encode(buf)?;
            }
            MaybeSecret::Plain(t) => {
                buf.put_u8(1);
                t.encode(buf)?;
            }
        }
        Ok(())
    }

    /// Reads the tag byte and then the matching payload.
    ///
    /// # Errors
    ///
    /// Returns `PackError::UnknownTag` for a tag other than `0` or `1`, and
    /// propagates any error from decoding the tag or the payload.
    fn decode(buf: &mut impl Buf) -> Result<Self, PackError>
    where
        Self: Sized,
    {
        // Decode the tag through `u8`'s `Pack` impl instead of `Buf::get_u8`,
        // which panics when the buffer is exhausted; truncated input should be
        // reported to the caller as an error like any other decode failure.
        match <u8 as Pack>::decode(buf)? {
            0 => Ok(MaybeSecret::Secret(String::decode(buf)?)),
            1 => Ok(MaybeSecret::Plain(Zeroizing::new(T::decode(buf)?))),
            _ => Err(PackError::UnknownTag),
        }
    }
}
191
#[cfg(test)]
mod tests {
    use super::*;
    use zeroize::ZeroizeOnDrop;

    /// Encodes `original` as JSON, decodes it again, and asserts the result is equal.
    fn assert_json_round_trip<T>(original: &MaybeSecret<T>)
    where
        T: std::fmt::Debug + PartialEq + Serialize + DeserializeOwned + FromStr + Zeroize,
    {
        let encoded = serde_json::to_string(original).unwrap();
        let decoded: MaybeSecret<T> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, &decoded);
    }

    /// Encodes `original` as YAML, decodes it again, and asserts the result is equal.
    fn assert_yaml_round_trip<T>(original: &MaybeSecret<T>)
    where
        T: std::fmt::Debug + PartialEq + Serialize + DeserializeOwned + FromStr + Zeroize,
    {
        let encoded = serde_yaml::to_string(original).unwrap();
        let decoded: MaybeSecret<T> = serde_yaml::from_str(&encoded).unwrap();
        assert_eq!(original, &decoded);
    }

    #[test]
    fn test_from_str() {
        let secret = "secrets://foo".parse::<MaybeSecret<u64>>().unwrap();
        assert_eq!(secret, MaybeSecret::secret("foo"));

        let number = "42".parse::<MaybeSecret<u64>>().unwrap();
        assert_eq!(number, MaybeSecret::plain(42u64));

        let text = "asdf".parse::<MaybeSecret<String>>().unwrap();
        assert_eq!(text, MaybeSecret::plain("asdf".to_string()));
    }

    #[test]
    fn test_serde() {
        assert_json_round_trip(&MaybeSecret::<u64>::secret("asdf"));
        assert_json_round_trip(&MaybeSecret::<u64>::plain(42));
        assert_json_round_trip(&MaybeSecret::<String>::plain("hahaha".to_string()));
    }

    #[test]
    fn test_serde_yaml() {
        assert_yaml_round_trip(&MaybeSecret::<u64>::secret("asdf"));
        assert_yaml_round_trip(&MaybeSecret::<u64>::plain(42));
        assert_yaml_round_trip(&MaybeSecret::<String>::plain("hahaha".to_string()));
    }

    #[test]
    fn test_serde_complex() {
        #[derive(
            Debug, PartialEq, Eq, Serialize, Deserialize, Zeroize, ZeroizeOnDrop,
        )]
        struct Foo {
            bar: u64,
            baz: String,
        }
        from_str_json!(Foo);

        // A struct payload must survive both formats ...
        let inline: MaybeSecret<Foo> =
            MaybeSecret::plain(Foo { bar: 42, baz: "asdf".to_string() });
        assert_json_round_trip(&inline);
        assert_yaml_round_trip(&inline);

        // ... and so must a secret reference typed over that struct.
        let reference: MaybeSecret<Foo> = MaybeSecret::secret("my_secret_key");
        assert_json_round_trip(&reference);
        assert_yaml_round_trip(&reference);
    }
}