elliptic_curve/hash2curve/hash2field/
expand_msg.rs

//! `expand_message` interface for `hash_to_field`.
2
3pub(super) mod xmd;
4pub(super) mod xof;
5
6use crate::{Error, Result};
7use digest::{Digest, ExtendableOutput, Update, XofReader};
8use generic_array::typenum::{IsLess, U256};
9use generic_array::{ArrayLength, GenericArray};
10
/// Largest domain separation tag, in bytes, that is used verbatim.
///
/// The tag length is encoded in a single byte, so it can be at most `u8::MAX`.
/// Longer tags are hashed together with `OVERSIZE_DST_SALT` instead.
const MAX_DST_LEN: usize = u8::MAX as usize;
/// Prefix hashed in front of a domain separation tag longer than `MAX_DST_LEN`.
const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-";
15
16/// Trait for types implementing expand_message interface for `hash_to_field`.
17///
18/// # Errors
19/// See implementors of [`ExpandMsg`] for errors.
20pub trait ExpandMsg<'a> {
21    /// Type holding data for the [`Expander`].
22    type Expander: Expander + Sized;
23
24    /// Expands `msg` to the required number of bytes.
25    ///
26    /// Returns an expander that can be used to call `read` until enough
27    /// bytes have been consumed
28    fn expand_message(
29        msgs: &[&[u8]],
30        dsts: &'a [&'a [u8]],
31        len_in_bytes: usize,
32    ) -> Result<Self::Expander>;
33}
34
/// Source of expanded bytes produced by [`ExpandMsg::expand_message`].
///
/// Call [`Expander::fill_bytes`] until enough bytes have been consumed.
pub trait Expander {
    /// Fills all of `okm` with expanded bytes.
    fn fill_bytes(&mut self, okm: &mut [u8]);
}
40
41/// The domain separation tag
42///
43/// Implements [section 5.4.3 of `draft-irtf-cfrg-hash-to-curve-13`][dst].
44///
45/// [dst]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-13#section-5.4.3
46pub(crate) enum Domain<'a, L>
47where
48    L: ArrayLength<u8> + IsLess<U256>,
49{
50    /// > 255
51    Hashed(GenericArray<u8, L>),
52    /// <= 255
53    Array(&'a [&'a [u8]]),
54}
55
56impl<'a, L> Domain<'a, L>
57where
58    L: ArrayLength<u8> + IsLess<U256>,
59{
60    pub fn xof<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
61    where
62        X: Default + ExtendableOutput + Update,
63    {
64        if dsts.is_empty() {
65            Err(Error)
66        } else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
67            let mut data = GenericArray::<u8, L>::default();
68            let mut hash = X::default();
69            hash.update(OVERSIZE_DST_SALT);
70
71            for dst in dsts {
72                hash.update(dst);
73            }
74
75            hash.finalize_xof().read(&mut data);
76
77            Ok(Self::Hashed(data))
78        } else {
79            Ok(Self::Array(dsts))
80        }
81    }
82
83    pub fn xmd<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
84    where
85        X: Digest<OutputSize = L>,
86    {
87        if dsts.is_empty() {
88            Err(Error)
89        } else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
90            Ok(Self::Hashed({
91                let mut hash = X::new();
92                hash.update(OVERSIZE_DST_SALT);
93
94                for dst in dsts {
95                    hash.update(dst);
96                }
97
98                hash.finalize()
99            }))
100        } else {
101            Ok(Self::Array(dsts))
102        }
103    }
104
105    pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
106        match self {
107            Self::Hashed(d) => hash.update(d),
108            Self::Array(d) => {
109                for d in d.iter() {
110                    hash.update(d)
111                }
112            }
113        }
114    }
115
116    pub fn len(&self) -> u8 {
117        match self {
118            // Can't overflow because it's enforced on a type level.
119            Self::Hashed(_) => L::to_u8(),
120            // Can't overflow because it's checked on creation.
121            Self::Array(d) => {
122                u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
123            }
124        }
125    }
126
127    #[cfg(test)]
128    pub fn assert(&self, bytes: &[u8]) {
129        let data = match self {
130            Domain::Hashed(d) => d.to_vec(),
131            Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
132        };
133        assert_eq!(data, bytes);
134    }
135
136    #[cfg(test)]
137    pub fn assert_dst(&self, bytes: &[u8]) {
138        let data = match self {
139            Domain::Hashed(d) => d.to_vec(),
140            Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
141        };
142        assert_eq!(data, &bytes[..bytes.len() - 1]);
143        assert_eq!(self.len(), bytes[bytes.len() - 1]);
144    }
145}