architect_api/utils/
secrets.rs1use schemars::{
4 r#gen::SchemaGenerator,
5 schema::{InstanceType, Schema, SchemaObject},
6 JsonSchema,
7};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use std::{fmt::Display, str::FromStr};
10use zeroize::{Zeroize, Zeroizing};
11
12#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
15pub enum MaybeSecret<T: Zeroize> {
16 Secret(String),
17 Plain(Zeroizing<T>),
18}
19
20impl<T: Zeroize> MaybeSecret<T> {
21 pub fn key(&self) -> Option<String> {
22 match self {
23 MaybeSecret::Secret(s) => Some(s.clone()),
24 MaybeSecret::Plain(_) => None,
25 }
26 }
27
28 pub fn secret<S: AsRef<str>>(key: S) -> Self {
29 MaybeSecret::Secret(key.as_ref().to_string())
30 }
31
32 pub fn plain(t: T) -> Self {
33 MaybeSecret::Plain(Zeroizing::new(t))
34 }
35}
36
37impl<T: Clone + Zeroize> MaybeSecret<T> {
38 pub fn to_plain(&self) -> Option<Zeroizing<T>> {
39 match self {
40 MaybeSecret::Secret(_) => None,
41 MaybeSecret::Plain(t) => Some(t.clone()),
42 }
43 }
44}
45
/// Implements `FromStr` for the given type by parsing the input as JSON with `serde_json`.
///
/// The error type is `serde_json::Error`. `serde_json` is referenced unqualified, so it must
/// be resolvable at the call site. `FromStr` and `Result` are written as `::std::…` paths
/// because `macro_rules!` resolves item paths at the call site: a caller with its own
/// `type Result<T> = …` alias in scope would otherwise fail to compile the expansion.
#[macro_export]
macro_rules! from_str_json {
    ($t:ty) => {
        impl ::std::str::FromStr for $t {
            type Err = serde_json::Error;

            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
                serde_json::from_str(s)
            }
        }
    };
}
61
62impl<T: Serialize + Zeroize + JsonSchema> JsonSchema for MaybeSecret<T> {
63 fn schema_name() -> String {
64 "MaybeSecret".to_owned()
66 }
67
68 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
69 SchemaObject {
70 instance_type: Some(InstanceType::String.into()),
71 ..Default::default()
72 }
73 .into()
74 }
75
76 fn is_referenceable() -> bool {
77 true
78 }
79}
80
81impl<T: Display + Serialize + Zeroize> Display for MaybeSecret<T> {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 match self {
84 MaybeSecret::Secret(s) => write!(f, "secrets://{}", s),
85 MaybeSecret::Plain(s) => {
86 write!(f, "{}", serde_json::to_string(&**s).map_err(|_| std::fmt::Error)?)
87 }
88 }
89 }
90}
91
92impl<T: FromStr + Zeroize> FromStr for MaybeSecret<T> {
93 type Err = <T as FromStr>::Err;
94
95 fn from_str(s: &str) -> Result<Self, Self::Err> {
96 if let Some(key) = s.strip_prefix("secrets://") {
97 Ok(MaybeSecret::Secret(key.to_string()))
98 } else {
99 Ok(MaybeSecret::Plain(Zeroizing::new(s.parse()?)))
100 }
101 }
102}
103
104impl<T: Serialize + Zeroize> Serialize for MaybeSecret<T> {
105 fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
106 match self {
107 MaybeSecret::Secret(s) => ser.serialize_str(&format!("secrets://{}", s)),
108 MaybeSecret::Plain(t) => (*t).serialize(ser),
109 }
110 }
111}
112
113impl<'de, T: DeserializeOwned + FromStr + Zeroize> Deserialize<'de> for MaybeSecret<T> {
114 fn deserialize<D>(de: D) -> Result<Self, D::Error>
115 where
116 D: serde::Deserializer<'de>,
117 {
118 #[derive(Serialize, Deserialize)]
119 #[serde(untagged)]
120 enum Format<T> {
121 SecretOrString(String),
122 Plain(T),
123 }
124 match Format::<T>::deserialize(de)? {
125 Format::SecretOrString(s) => {
126 if let Some(key) = s.strip_prefix("secrets://") {
127 Ok(MaybeSecret::Secret(key.to_string()))
128 } else {
129 Ok(MaybeSecret::Plain(Zeroizing::new(
140 T::from_str(&s)
141 .map_err(|_| serde::de::Error::custom("could not FromStr"))?,
142 )))
143 }
144 }
145 Format::Plain(t) => Ok(MaybeSecret::Plain(Zeroizing::new(t))),
146 }
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use zeroize::ZeroizeOnDrop;

    /// Encodes `original` as JSON, decodes it again, and asserts nothing was lost.
    fn check_json<T>(original: &MaybeSecret<T>)
    where
        T: Serialize + DeserializeOwned + FromStr + Zeroize + PartialEq + std::fmt::Debug,
    {
        let text = serde_json::to_string(original).unwrap();
        let back: MaybeSecret<T> = serde_json::from_str(&text).unwrap();
        assert_eq!(*original, back);
    }

    /// Encodes `original` as YAML, decodes it again, and asserts nothing was lost.
    fn check_yaml<T>(original: &MaybeSecret<T>)
    where
        T: Serialize + DeserializeOwned + FromStr + Zeroize + PartialEq + std::fmt::Debug,
    {
        let text = serde_yaml::to_string(original).unwrap();
        let back: MaybeSecret<T> = serde_yaml::from_str(&text).unwrap();
        assert_eq!(*original, back);
    }

    #[test]
    fn test_from_str() {
        let secret: MaybeSecret<u64> = "secrets://foo".parse().unwrap();
        assert_eq!(secret, MaybeSecret::secret("foo"));

        let number: MaybeSecret<u64> = "42".parse().unwrap();
        assert_eq!(number, MaybeSecret::plain(42u64));

        let text: MaybeSecret<String> = "asdf".parse().unwrap();
        assert_eq!(text, MaybeSecret::plain("asdf".to_string()));
    }

    #[test]
    fn test_serde() {
        check_json(&MaybeSecret::<u64>::secret("asdf"));
        check_json(&MaybeSecret::<u64>::plain(42));
        check_json(&MaybeSecret::plain("hahaha".to_string()));
    }

    #[test]
    fn test_serde_yaml() {
        check_yaml(&MaybeSecret::<u64>::secret("asdf"));
        check_yaml(&MaybeSecret::<u64>::plain(42));
        check_yaml(&MaybeSecret::plain("hahaha".to_string()));
    }

    #[test]
    fn test_serde_complex() {
        #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
        struct Foo {
            bar: u64,
            baz: String,
        }
        from_str_json!(Foo);

        let plain = MaybeSecret::plain(Foo { bar: 42, baz: "asdf".to_string() });
        check_json(&plain);
        check_yaml(&plain);

        let secret: MaybeSecret<Foo> = MaybeSecret::secret("my_secret_key");
        check_json(&secret);
        check_yaml(&secret);
    }
}