diem_crypto/
validatable.rs

// Copyright (c) The Diem Core Contributors
// SPDX-License-Identifier: Apache-2.0

//! This module provides the `Validate` trait and `Validatable` type in order to aid in deferred
//! validation.
6
7use anyhow::Result;
8use once_cell::sync::OnceCell;
9use serde::{Deserialize, Serialize};
10use std::{convert::TryFrom, hash::Hash};
11
12/// The `Validate` trait is used in tandem with the `Validatable` type in order to provide deferred
13/// validation for types.
14///
15/// ## Trait Contract
16///
17/// Any type `V` which implement this trait must adhere to the following contract:
18///
19/// * `V` and `V::Unvalidated` are byte-for-byte equivalent.
20/// * `V` and `V::Unvalidated` have equivalent `Hash` implementations.
21/// * `V` and `V::Unvalidated` must have equivalent `Serialize` and `Deserialize` implementation.
22///   This means that `V` and `V:Unvalidated` have equivalent serialized formats and that you can
23///   deserialize a `V::Unvalidated` from a `V` that was previously serialized.
24pub trait Validate: Sized {
25    /// The unvalidated form of some type `V`
26    type Unvalidated;
27
28    /// Attempt to validate a `V::Unvalidated` and returning a validated `V` on success
29    fn validate(unvalidated: &Self::Unvalidated) -> Result<Self>;
30
31    /// Return the unvalidated form of type `V`
32    fn to_unvalidated(&self) -> Self::Unvalidated;
33}
34
35/// Used in connection with the `Validate` trait to be able to represent types which can benefit
36/// from deferred validation as a performance optimization.
37#[derive(Clone, Debug)]
38pub struct Validatable<V: Validate> {
39    unvalidated: V::Unvalidated,
40    maybe_valid: OnceCell<V>,
41}
42
43impl<V: Validate> Validatable<V> {
44    /// Create a new `Validatable` from a valid type
45    pub fn new_valid(valid: V) -> Self {
46        let unvalidated = valid.to_unvalidated();
47
48        let maybe_valid = OnceCell::new();
49        maybe_valid.set(valid).unwrap_or_else(|_| unreachable!());
50
51        Self {
52            unvalidated,
53            maybe_valid,
54        }
55    }
56
57    /// Create a new `Validatable` from an unvalidated type
58    pub fn new_unvalidated(unvalidated: V::Unvalidated) -> Self {
59        Self {
60            unvalidated,
61            maybe_valid: OnceCell::new(),
62        }
63    }
64
65    /// Return a reference to the unvalidated form `V::Unvalidated`
66    pub fn unvalidated(&self) -> &V::Unvalidated {
67        &self.unvalidated
68    }
69
70    /// Try to validate the unvalidated form returning `Some(&V)` on success and `None` on failure.
71    pub fn valid(&self) -> Option<&V> {
72        self.validate().ok()
73    }
74
75    // TODO maybe optimize to only try once and keep track when we fail
76    /// Attempt to validate `V::Unvalidated` and return a reference to a valid `V`
77    pub fn validate(&self) -> Result<&V> {
78        self.maybe_valid
79            .get_or_try_init(|| V::validate(&self.unvalidated))
80    }
81}
82
83impl<V> Serialize for Validatable<V>
84where
85    V: Validate + Serialize,
86    V::Unvalidated: Serialize,
87{
88    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
89    where
90        S: serde::Serializer,
91    {
92        self.unvalidated.serialize(serializer)
93    }
94}
95
96impl<'de, V> Deserialize<'de> for Validatable<V>
97where
98    V: Validate,
99    V::Unvalidated: Deserialize<'de>,
100{
101    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102    where
103        D: serde::Deserializer<'de>,
104    {
105        let unvalidated = <V::Unvalidated>::deserialize(deserializer)?;
106        Ok(Self::new_unvalidated(unvalidated))
107    }
108}
109
110impl<V> PartialEq for Validatable<V>
111where
112    V: Validate,
113    V::Unvalidated: PartialEq,
114{
115    fn eq(&self, other: &Self) -> bool {
116        self.unvalidated == other.unvalidated
117    }
118}
119
120impl<V> Eq for Validatable<V>
121where
122    V: Validate,
123    V::Unvalidated: Eq,
124{
125}
126
127impl<V> Hash for Validatable<V>
128where
129    V: Validate,
130    V::Unvalidated: Hash,
131{
132    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
133        self.unvalidated.hash(state);
134    }
135}
136
//
// `Validate` implementation for Ed25519 public keys
//
140
141use crate::ed25519::{Ed25519PublicKey, ED25519_PUBLIC_KEY_LENGTH};
142
143/// An unvalidated `Ed25519PublicKey`
144#[derive(Debug, Clone, Eq)]
145pub struct UnvalidatedEd25519PublicKey([u8; ED25519_PUBLIC_KEY_LENGTH]);
146
147impl UnvalidatedEd25519PublicKey {
148    /// Return key as bytes
149    pub fn to_bytes(&self) -> [u8; ED25519_PUBLIC_KEY_LENGTH] {
150        self.0
151    }
152}
153
154impl Serialize for UnvalidatedEd25519PublicKey {
155    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
156    where
157        S: serde::Serializer,
158    {
159        if serializer.is_human_readable() {
160            let encoded = ::hex::encode(&self.0);
161            serializer.serialize_str(&encoded)
162        } else {
163            // See comment in deserialize_key.
164            serializer.serialize_newtype_struct(
165                "Ed25519PublicKey",
166                serde_bytes::Bytes::new(self.0.as_ref()),
167            )
168        }
169    }
170}
171
172impl<'de> Deserialize<'de> for UnvalidatedEd25519PublicKey {
173    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
174    where
175        D: serde::Deserializer<'de>,
176    {
177        use serde::de::Error;
178
179        if deserializer.is_human_readable() {
180            let encoded_key = <String>::deserialize(deserializer)?;
181            let bytes_out = ::hex::decode(encoded_key).map_err(D::Error::custom)?;
182            <[u8; ED25519_PUBLIC_KEY_LENGTH]>::try_from(bytes_out.as_ref())
183                .map(UnvalidatedEd25519PublicKey)
184                .map_err(D::Error::custom)
185        } else {
186            // In order to preserve the Serde data model and help analysis tools,
187            // make sure to wrap our value in a container with the same name
188            // as the original type.
189            #[derive(Deserialize)]
190            #[serde(rename = "Ed25519PublicKey")]
191            struct Value<'a>(&'a [u8]);
192
193            let value = Value::deserialize(deserializer)?;
194            <[u8; ED25519_PUBLIC_KEY_LENGTH]>::try_from(value.0)
195                .map(UnvalidatedEd25519PublicKey)
196                .map_err(D::Error::custom)
197        }
198    }
199}
200
201impl Hash for UnvalidatedEd25519PublicKey {
202    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
203        state.write(&self.0)
204    }
205}
206
207impl PartialEq for UnvalidatedEd25519PublicKey {
208    fn eq(&self, other: &Self) -> bool {
209        self.0 == other.0
210    }
211}
212
213impl Validate for Ed25519PublicKey {
214    type Unvalidated = UnvalidatedEd25519PublicKey;
215
216    fn validate(unvalidated: &Self::Unvalidated) -> Result<Self> {
217        Self::try_from(unvalidated.0.as_ref()).map_err(Into::into)
218    }
219
220    fn to_unvalidated(&self) -> Self::Unvalidated {
221        UnvalidatedEd25519PublicKey(self.to_bytes())
222    }
223}
224
#[cfg(test)]
mod test {
    use crate::{
        ed25519::{Ed25519PrivateKey, Ed25519PublicKey},
        test_utils::uniform_keypair_strategy,
        validatable::{UnvalidatedEd25519PublicKey, Validate},
    };
    use proptest::prelude::*;
    use std::{
        collections::hash_map::DefaultHasher,
        hash::{Hash, Hasher},
    };

    /// Hash `value` with a fresh `DefaultHasher` and return the digest.
    fn default_hash<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    proptest! {
        // Per the `Validate` contract, the validated and unvalidated key types must be
        // interchangeable: same bytes, same serialized form in both directions, same hash.
        #[test]
        fn unvalidated_ed25519_public_key_equivalence(
            keypair in uniform_keypair_strategy::<Ed25519PrivateKey, Ed25519PublicKey>()
        ) {
            let valid = keypair.public_key;
            let unvalidated = valid.to_unvalidated();

            // Converting to the unvalidated form and validating again preserves the key.
            prop_assert_eq!(&unvalidated, &UnvalidatedEd25519PublicKey(valid.to_bytes()));
            prop_assert_eq!(&valid, &Ed25519PublicKey::validate(&unvalidated).unwrap());

            // BCS: a non-human-readable format.
            {
                let valid_bytes = bcs::to_bytes(&valid).unwrap();
                let unvalidated_bytes = bcs::to_bytes(&unvalidated).unwrap();
                prop_assert_eq!(&valid_bytes, &unvalidated_bytes);

                let valid_from_unvalidated: Ed25519PublicKey =
                    bcs::from_bytes(&unvalidated_bytes).unwrap();
                let unvalidated_from_valid: UnvalidatedEd25519PublicKey =
                    bcs::from_bytes(&valid_bytes).unwrap();

                prop_assert_eq!(&valid, &valid_from_unvalidated);
                prop_assert_eq!(&unvalidated, &unvalidated_from_valid);
            }

            // JSON: a human-readable format.
            {
                let valid_json = serde_json::to_string(&valid).unwrap();
                let unvalidated_json = serde_json::to_string(&unvalidated).unwrap();
                prop_assert_eq!(&valid_json, &unvalidated_json);

                let valid_from_unvalidated: Ed25519PublicKey =
                    serde_json::from_str(&unvalidated_json).unwrap();
                let unvalidated_from_valid: UnvalidatedEd25519PublicKey =
                    serde_json::from_str(&valid_json).unwrap();

                prop_assert_eq!(&valid, &valid_from_unvalidated);
                prop_assert_eq!(&unvalidated, &unvalidated_from_valid);
            }

            // Both types must also produce the same hash.
            prop_assert_eq!(default_hash(&valid), default_hash(&unvalidated));
        }
    }
}