architect_api/utils/
secrets.rs1#[cfg(feature = "netidx")]
4use bytes::{Buf, BufMut};
5#[cfg(feature = "netidx")]
6use netidx::pack::{Pack, PackError};
7use schemars::{
8 gen::SchemaGenerator,
9 schema::{InstanceType, Schema, SchemaObject},
10 JsonSchema,
11};
12use serde::{de::DeserializeOwned, Deserialize, Serialize};
13use std::{fmt::Display, str::FromStr};
14use zeroize::{Zeroize, Zeroizing};
15
/// A value that is either a reference to a secret, identified by its key, or
/// an inline ("plain") value.
///
/// In its string and serde forms a secret reference is written as
/// `secrets://<key>`; the key stored in [`MaybeSecret::Secret`] does not
/// include that prefix. Plain values are held in [`Zeroizing`], which wipes
/// them when dropped.
///
/// NOTE(review): the derived `Debug` appears to print the plain value itself
/// (via `Zeroizing`'s own `Debug`) — avoid `{:?}` on these values in logs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MaybeSecret<T: Zeroize> {
    /// Key of the referenced secret, without the `secrets://` prefix.
    Secret(String),
    /// An inline value, wrapped in [`Zeroizing`].
    Plain(Zeroizing<T>),
}
23
24impl<T: Zeroize> MaybeSecret<T> {
25 pub fn key(&self) -> Option<String> {
26 match self {
27 MaybeSecret::Secret(s) => Some(s.clone()),
28 MaybeSecret::Plain(_) => None,
29 }
30 }
31
32 pub fn secret<S: AsRef<str>>(key: S) -> Self {
33 MaybeSecret::Secret(key.as_ref().to_string())
34 }
35
36 pub fn plain(t: T) -> Self {
37 MaybeSecret::Plain(Zeroizing::new(t))
38 }
39}
40
41impl<T: Clone + Zeroize> MaybeSecret<T> {
42 pub fn to_plain(&self) -> Option<Zeroizing<T>> {
43 match self {
44 MaybeSecret::Secret(_) => None,
45 MaybeSecret::Plain(t) => Some(t.clone()),
46 }
47 }
48}
49
/// Implements [`std::str::FromStr`] for `$t` by parsing the input as JSON
/// with `serde_json::from_str`, using `serde_json::Error` as the error type.
///
/// This lets a `Deserialize` type satisfy the `T: FromStr` bound required by
/// `MaybeSecret<T>`'s `FromStr` and `Deserialize` impls.
///
/// Every path in the expansion is absolute (`::std::…`, `::serde_json::…`)
/// because `macro_rules!` resolves item paths at the call site: a bare
/// `Result<Self, Self::Err>` would fail to compile in any crate that defines
/// a one-parameter `type Result<T> = …` alias. The invoking crate must
/// depend on `serde_json`.
#[macro_export]
macro_rules! from_str_json {
    ($t:ty) => {
        impl ::std::str::FromStr for $t {
            type Err = ::serde_json::Error;

            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
                ::serde_json::from_str(s)
            }
        }
    };
}
65
66impl<T: Serialize + Zeroize + JsonSchema> JsonSchema for MaybeSecret<T> {
67 fn schema_name() -> String {
68 "MaybeSecret".to_owned()
70 }
71
72 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
73 SchemaObject {
74 instance_type: Some(InstanceType::String.into()),
75 ..Default::default()
76 }
77 .into()
78 }
79
80 fn is_referenceable() -> bool {
81 true
82 }
83}
84
85impl<T: Display + Serialize + Zeroize> Display for MaybeSecret<T> {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 match &*self {
88 MaybeSecret::Secret(s) => write!(f, "secrets://{}", s),
89 MaybeSecret::Plain(s) => {
90 write!(f, "{}", serde_json::to_string(&**s).map_err(|_| std::fmt::Error)?)
91 }
92 }
93 }
94}
95
96impl<T: FromStr + Zeroize> FromStr for MaybeSecret<T> {
97 type Err = <T as FromStr>::Err;
98
99 fn from_str(s: &str) -> Result<Self, Self::Err> {
100 if s.starts_with("secrets://") {
101 Ok(MaybeSecret::Secret(s[10..].to_string()))
102 } else {
103 Ok(MaybeSecret::Plain(Zeroizing::new(s.parse()?)))
104 }
105 }
106}
107
108impl<T: Serialize + Zeroize> Serialize for MaybeSecret<T> {
109 fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
110 match self {
111 MaybeSecret::Secret(s) => ser.serialize_str(&format!("secrets://{}", s)),
112 MaybeSecret::Plain(t) => (&*t).serialize(ser),
113 }
114 }
115}
116
117impl<'de, T: DeserializeOwned + FromStr + Zeroize> Deserialize<'de> for MaybeSecret<T> {
118 fn deserialize<D>(de: D) -> Result<Self, D::Error>
119 where
120 D: serde::Deserializer<'de>,
121 {
122 #[derive(Serialize, Deserialize)]
123 #[serde(untagged)]
124 enum Format<T> {
125 SecretOrString(String),
126 Plain(T),
127 }
128 match Format::<T>::deserialize(de)? {
129 Format::SecretOrString(s) => {
130 if s.starts_with("secrets://") {
131 Ok(MaybeSecret::Secret(s[10..].to_string()))
132 } else {
133 Ok(MaybeSecret::Plain(Zeroizing::new(
144 T::from_str(&s)
145 .map_err(|_| serde::de::Error::custom("could not FromStr"))?,
146 )))
147 }
148 }
149 Format::Plain(t) => Ok(MaybeSecret::Plain(Zeroizing::new(t))),
150 }
151 }
152}
153
/// netidx wire format: a one-byte tag (`0` = secret key, `1` = plain value)
/// followed by the packed `String` key or the packed `T`.
#[cfg(feature = "netidx")]
impl<T: Zeroize + Pack> Pack for MaybeSecret<T> {
    fn encoded_len(&self) -> usize {
        const TAG_LEN: usize = 1;
        let body_len = match self {
            MaybeSecret::Secret(key) => key.encoded_len(),
            MaybeSecret::Plain(value) => value.encoded_len(),
        };
        TAG_LEN + body_len
    }

    fn encode(&self, buf: &mut impl BufMut) -> Result<(), PackError> {
        match self {
            MaybeSecret::Secret(key) => {
                buf.put_u8(0);
                key.encode(buf)
            }
            MaybeSecret::Plain(value) => {
                buf.put_u8(1);
                value.encode(buf)
            }
        }
    }

    /// # Errors
    /// Fails on a truncated buffer, an unknown tag, or a payload that does
    /// not decode.
    fn decode(buf: &mut impl Buf) -> Result<Self, PackError>
    where
        Self: Sized,
    {
        // `Buf::get_u8` panics when no bytes remain; decoding the tag through
        // u8's Pack impl reports a short buffer as an error instead.
        match <u8 as Pack>::decode(buf)? {
            0 => Ok(MaybeSecret::Secret(String::decode(buf)?)),
            1 => Ok(MaybeSecret::Plain(Zeroizing::new(T::decode(buf)?))),
            _ => Err(PackError::UnknownTag),
        }
    }
}
191
#[cfg(test)]
mod tests {
    use super::*;
    use zeroize::ZeroizeOnDrop;

    /// Serializes `x` to JSON and back, asserting nothing is lost.
    fn assert_json_roundtrip<T>(x: &MaybeSecret<T>)
    where
        T: Serialize + DeserializeOwned + FromStr + Zeroize + PartialEq + std::fmt::Debug,
    {
        let json = serde_json::to_string(x).unwrap();
        let back: MaybeSecret<T> = serde_json::from_str(&json).unwrap();
        assert_eq!(x, &back);
    }

    /// Serializes `x` to YAML and back, asserting nothing is lost.
    fn assert_yaml_roundtrip<T>(x: &MaybeSecret<T>)
    where
        T: Serialize + DeserializeOwned + FromStr + Zeroize + PartialEq + std::fmt::Debug,
    {
        let yaml = serde_yaml::to_string(x).unwrap();
        let back: MaybeSecret<T> = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(x, &back);
    }

    #[test]
    fn test_from_str() {
        let x: MaybeSecret<u64> = "secrets://foo".parse().unwrap();
        assert_eq!(x, MaybeSecret::secret("foo"));
        let y: MaybeSecret<u64> = "42".parse().unwrap();
        assert_eq!(y, MaybeSecret::plain(42u64));
        let z: MaybeSecret<String> = "asdf".parse().unwrap();
        assert_eq!(z, MaybeSecret::plain("asdf".to_string()));
        // the bare prefix is a reference to the empty key, not a plain value
        let empty: MaybeSecret<u64> = "secrets://".parse().unwrap();
        assert_eq!(empty, MaybeSecret::secret(""));
        // non-reference input must parse as T
        assert!("nope".parse::<MaybeSecret<u64>>().is_err());
    }

    #[test]
    fn test_display_and_accessors() {
        let secret: MaybeSecret<u64> = MaybeSecret::secret("foo");
        assert_eq!(secret.to_string(), "secrets://foo");
        assert_eq!(secret.key(), Some("foo".to_string()));
        assert_eq!(secret.to_plain(), None);

        let number: MaybeSecret<u64> = MaybeSecret::plain(42);
        assert_eq!(number.to_string(), "42");
        assert_eq!(number.key(), None);
        assert_eq!(number.to_plain(), Some(Zeroizing::new(42u64)));

        // plain values are rendered as JSON, so strings come out quoted
        let text: MaybeSecret<String> = MaybeSecret::plain("hahaha".to_string());
        assert_eq!(text.to_string(), "\"hahaha\"");
    }

    #[test]
    fn test_serde() {
        assert_json_roundtrip(&MaybeSecret::<u64>::secret("asdf"));
        assert_json_roundtrip(&MaybeSecret::<u64>::plain(42));
        assert_json_roundtrip(&MaybeSecret::plain("hahaha".to_string()));
    }

    #[test]
    fn test_serde_yaml() {
        assert_yaml_roundtrip(&MaybeSecret::<u64>::secret("asdf"));
        assert_yaml_roundtrip(&MaybeSecret::<u64>::plain(42));
        assert_yaml_roundtrip(&MaybeSecret::plain("hahaha".to_string()));
    }

    #[test]
    fn test_serde_complex() {
        #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
        struct Foo {
            bar: u64,
            baz: String,
        }
        from_str_json!(Foo);

        let plain: MaybeSecret<Foo> =
            MaybeSecret::plain(Foo { bar: 42, baz: "asdf".to_string() });
        assert_json_roundtrip(&plain);
        assert_yaml_roundtrip(&plain);

        let secret: MaybeSecret<Foo> = MaybeSecret::secret("my_secret_key");
        assert_json_roundtrip(&secret);
        assert_yaml_roundtrip(&secret);
    }
}