// serde_encrypted_value/deserializer.rs

//  Copyright 2017 Palantir Technologies, Inc.
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
14
15use serde::de;
16use std::fmt;
17
18use crate::Key;
19
20/// A deserializer which automatically decrypts strings.
21///
22/// Encrypted strings should be formatted like `${enc:<base64 ciphertext here>}`.
23pub struct Deserializer<'a, D, T> {
24    deserializer: D,
25    key: Option<&'a Key<T>>,
26}
27
28impl<'a, 'de, D, T> Deserializer<'a, D, T>
29where
30    D: de::Deserializer<'de>,
31{
32    /// Creates a new `Deserializer` wrapping another deserializer and decrypting string values.
33    ///
34    /// If `key` is `None`, deserialization will fail if an encrypted string is encountered.
35    pub fn new(deserializer: D, key: Option<&'a Key<T>>) -> Deserializer<'a, D, T> {
36        Deserializer { deserializer, key }
37    }
38}
39
// Generates a `de::Deserializer` method that wraps the incoming visitor in a decrypting
// `Visitor` and forwards the call (plus any extra `name => Type` arguments) to the inner
// deserializer.
macro_rules! forward_deserialize {
    ($method:ident $(, $param:tt => $param_ty:ty)* $(,)?) => {
        fn $method<V>(self, $($param: $param_ty,)* visitor: V) -> Result<V::Value, D::Error>
        where
            V: de::Visitor<'de>,
        {
            let key = self.key;
            self.deserializer.$method($($param,)* Visitor { visitor, key })
        }
    };
}
54
55impl<'de, D, T> de::Deserializer<'de> for Deserializer<'_, D, T>
56where
57    D: de::Deserializer<'de>,
58{
59    type Error = D::Error;
60
61    forward_deserialize!(deserialize_any);
62    forward_deserialize!(deserialize_bool);
63    forward_deserialize!(deserialize_u8);
64    forward_deserialize!(deserialize_u16);
65    forward_deserialize!(deserialize_u32);
66    forward_deserialize!(deserialize_u64);
67    forward_deserialize!(deserialize_i8);
68    forward_deserialize!(deserialize_i16);
69    forward_deserialize!(deserialize_i32);
70    forward_deserialize!(deserialize_i64);
71    forward_deserialize!(deserialize_f32);
72    forward_deserialize!(deserialize_f64);
73    forward_deserialize!(deserialize_char);
74    forward_deserialize!(deserialize_str);
75    forward_deserialize!(deserialize_string);
76    forward_deserialize!(deserialize_unit);
77    forward_deserialize!(deserialize_option);
78    forward_deserialize!(deserialize_seq);
79    forward_deserialize!(deserialize_bytes);
80    forward_deserialize!(deserialize_byte_buf);
81    forward_deserialize!(deserialize_map);
82    forward_deserialize!(deserialize_unit_struct, name => &'static str);
83    forward_deserialize!(deserialize_newtype_struct, name => &'static str);
84    forward_deserialize!(deserialize_tuple_struct, name => &'static str, len => usize);
85    forward_deserialize!(deserialize_struct,
86                         name => &'static str,
87                         fields => &'static [&'static str]);
88    forward_deserialize!(deserialize_identifier);
89    forward_deserialize!(deserialize_tuple, len => usize);
90    forward_deserialize!(deserialize_enum,
91                         name => &'static str,
92                         variants => &'static [&'static str]);
93    forward_deserialize!(deserialize_ignored_any);
94}
95
96struct Visitor<'a, V, T> {
97    visitor: V,
98    key: Option<&'a Key<T>>,
99}
100
101impl<V, T> Visitor<'_, V, T> {
102    fn expand_str<E>(&self, s: &str) -> Result<Option<String>, E>
103    where
104        E: de::Error,
105    {
106        if s.starts_with("${enc:") && s.ends_with('}') {
107            match self.key {
108                Some(key) => match key.decrypt(&s[6..s.len() - 1]) {
109                    Ok(s) => Ok(Some(s)),
110                    Err(e) => Err(E::custom(e.to_string())),
111                },
112                None => Err(E::custom("missing encryption key")),
113            }
114        } else {
115            Ok(None)
116        }
117    }
118}
119
// Generates a `de::Visitor` callback that hands a non-string value straight to the wrapped
// visitor; only the string callbacks need to look for encrypted payloads.
macro_rules! forward_visit {
    ($method:ident, $value_ty:ty $(,)?) => {
        fn $method<E>(self, value: $value_ty) -> Result<V::Value, E>
        where
            E: de::Error,
        {
            self.visitor.$method(value)
        }
    };
}
130
131impl<'de, V, T> de::Visitor<'de> for Visitor<'_, V, T>
132where
133    V: de::Visitor<'de>,
134{
135    type Value = V::Value;
136
137    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
138        self.visitor.expecting(formatter)
139    }
140
141    forward_visit!(visit_bool, bool);
142    forward_visit!(visit_i8, i8);
143    forward_visit!(visit_i16, i16);
144    forward_visit!(visit_i32, i32);
145    forward_visit!(visit_i64, i64);
146    forward_visit!(visit_u8, u8);
147    forward_visit!(visit_u16, u16);
148    forward_visit!(visit_u32, u32);
149    forward_visit!(visit_u64, u64);
150    forward_visit!(visit_f32, f32);
151    forward_visit!(visit_f64, f64);
152    forward_visit!(visit_char, char);
153    forward_visit!(visit_bytes, &[u8]);
154    forward_visit!(visit_byte_buf, Vec<u8>);
155
156    fn visit_str<E>(self, v: &str) -> Result<V::Value, E>
157    where
158        E: de::Error,
159    {
160        match self.expand_str(v)? {
161            Some(s) => self.visitor.visit_string(s),
162            None => self.visitor.visit_str(v),
163        }
164    }
165
166    fn visit_string<E>(self, v: String) -> Result<V::Value, E>
167    where
168        E: de::Error,
169    {
170        match self.expand_str(&v)? {
171            Some(s) => self.visitor.visit_string(s),
172            None => self.visitor.visit_string(v),
173        }
174    }
175
176    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<V::Value, E>
177    where
178        E: de::Error,
179    {
180        match self.expand_str(v)? {
181            Some(s) => self.visitor.visit_string(s),
182            None => self.visitor.visit_borrowed_str(v),
183        }
184    }
185
186    fn visit_unit<E>(self) -> Result<V::Value, E>
187    where
188        E: de::Error,
189    {
190        self.visitor.visit_unit()
191    }
192
193    fn visit_none<E>(self) -> Result<V::Value, E>
194    where
195        E: de::Error,
196    {
197        self.visitor.visit_none()
198    }
199
200    fn visit_some<D>(self, deserializer: D) -> Result<V::Value, D::Error>
201    where
202        D: de::Deserializer<'de>,
203    {
204        let deserializer = Deserializer::new(deserializer, self.key);
205        self.visitor.visit_some(deserializer)
206    }
207
208    fn visit_newtype_struct<D>(self, deserializer: D) -> Result<V::Value, D::Error>
209    where
210        D: de::Deserializer<'de>,
211    {
212        let deserializer = Deserializer::new(deserializer, self.key);
213        self.visitor.visit_newtype_struct(deserializer)
214    }
215
216    fn visit_seq<V2>(self, visitor: V2) -> Result<V::Value, V2::Error>
217    where
218        V2: de::SeqAccess<'de>,
219    {
220        let visitor = Visitor {
221            visitor,
222            key: self.key,
223        };
224        self.visitor.visit_seq(visitor)
225    }
226
227    fn visit_map<V2>(self, visitor: V2) -> Result<V::Value, V2::Error>
228    where
229        V2: de::MapAccess<'de>,
230    {
231        let visitor = Visitor {
232            visitor,
233            key: self.key,
234        };
235        self.visitor.visit_map(visitor)
236    }
237
238    fn visit_enum<V2>(self, visitor: V2) -> Result<V::Value, V2::Error>
239    where
240        V2: de::EnumAccess<'de>,
241    {
242        let visitor = Visitor {
243            visitor,
244            key: self.key,
245        };
246        self.visitor.visit_enum(visitor)
247    }
248}
249
250impl<'de, V, T> de::SeqAccess<'de> for Visitor<'_, V, T>
251where
252    V: de::SeqAccess<'de>,
253{
254    type Error = V::Error;
255
256    fn next_element_seed<S>(&mut self, seed: S) -> Result<Option<S::Value>, V::Error>
257    where
258        S: de::DeserializeSeed<'de>,
259    {
260        let seed = DeserializeSeed {
261            seed,
262            key: self.key,
263        };
264        self.visitor.next_element_seed(seed)
265    }
266
267    fn size_hint(&self) -> Option<usize> {
268        self.visitor.size_hint()
269    }
270}
271
272impl<'de, V, T> de::MapAccess<'de> for Visitor<'_, V, T>
273where
274    V: de::MapAccess<'de>,
275{
276    type Error = V::Error;
277
278    fn next_key_seed<S>(&mut self, seed: S) -> Result<Option<S::Value>, V::Error>
279    where
280        S: de::DeserializeSeed<'de>,
281    {
282        let seed = DeserializeSeed {
283            seed,
284            key: self.key,
285        };
286        self.visitor.next_key_seed(seed)
287    }
288
289    fn next_value_seed<S>(&mut self, seed: S) -> Result<S::Value, V::Error>
290    where
291        S: de::DeserializeSeed<'de>,
292    {
293        let seed = DeserializeSeed {
294            seed,
295            key: self.key,
296        };
297        self.visitor.next_value_seed(seed)
298    }
299
300    #[allow(clippy::type_complexity)]
301    fn next_entry_seed<K, V2>(
302        &mut self,
303        kseed: K,
304        vseed: V2,
305    ) -> Result<Option<(K::Value, V2::Value)>, V::Error>
306    where
307        K: de::DeserializeSeed<'de>,
308        V2: de::DeserializeSeed<'de>,
309    {
310        let kseed = DeserializeSeed {
311            seed: kseed,
312            key: self.key,
313        };
314        let vseed = DeserializeSeed {
315            seed: vseed,
316            key: self.key,
317        };
318        self.visitor.next_entry_seed(kseed, vseed)
319    }
320
321    fn size_hint(&self) -> Option<usize> {
322        self.visitor.size_hint()
323    }
324}
325
326impl<'a, 'de, V, T> de::EnumAccess<'de> for Visitor<'a, V, T>
327where
328    V: de::EnumAccess<'de>,
329{
330    type Error = V::Error;
331    type Variant = Visitor<'a, V::Variant, T>;
332
333    #[allow(clippy::type_complexity)]
334    fn variant_seed<S>(self, seed: S) -> Result<(S::Value, Visitor<'a, V::Variant, T>), V::Error>
335    where
336        S: de::DeserializeSeed<'de>,
337    {
338        let seed = DeserializeSeed {
339            seed,
340            key: self.key,
341        };
342        match self.visitor.variant_seed(seed) {
343            Ok((value, variant)) => {
344                let variant = Visitor {
345                    visitor: variant,
346                    key: self.key,
347                };
348                Ok((value, variant))
349            }
350            Err(e) => Err(e),
351        }
352    }
353}
354
355impl<'de, V, T> de::VariantAccess<'de> for Visitor<'_, V, T>
356where
357    V: de::VariantAccess<'de>,
358{
359    type Error = V::Error;
360
361    fn unit_variant(self) -> Result<(), V::Error> {
362        self.visitor.unit_variant()
363    }
364
365    fn newtype_variant_seed<S>(self, seed: S) -> Result<S::Value, V::Error>
366    where
367        S: de::DeserializeSeed<'de>,
368    {
369        let seed = DeserializeSeed {
370            seed,
371            key: self.key,
372        };
373        self.visitor.newtype_variant_seed(seed)
374    }
375
376    fn tuple_variant<V2>(self, len: usize, visitor: V2) -> Result<V2::Value, V::Error>
377    where
378        V2: de::Visitor<'de>,
379    {
380        let visitor = Visitor {
381            visitor,
382            key: self.key,
383        };
384        self.visitor.tuple_variant(len, visitor)
385    }
386
387    fn struct_variant<V2>(
388        self,
389        fields: &'static [&'static str],
390        visitor: V2,
391    ) -> Result<V2::Value, V::Error>
392    where
393        V2: de::Visitor<'de>,
394    {
395        let visitor = Visitor {
396            visitor,
397            key: self.key,
398        };
399        self.visitor.struct_variant(fields, visitor)
400    }
401}
402
403struct DeserializeSeed<'a, S, T> {
404    seed: S,
405    key: Option<&'a Key<T>>,
406}
407
408impl<'de, S, T> de::DeserializeSeed<'de> for DeserializeSeed<'_, S, T>
409where
410    S: de::DeserializeSeed<'de>,
411{
412    type Value = S::Value;
413
414    fn deserialize<D>(self, deserializer: D) -> Result<S::Value, D::Error>
415    where
416        D: de::Deserializer<'de>,
417    {
418        let deserializer = Deserializer::new(deserializer, self.key);
419        self.seed.deserialize(deserializer)
420    }
421}