elliptic_curve/hash2curve/hash2field/
expand_msg.rs1pub(super) mod xmd;
4pub(super) mod xof;
5
6use crate::{Error, Result};
7use digest::{Digest, ExtendableOutput, Update, XofReader};
8use generic_array::typenum::{IsLess, U256};
9use generic_array::{ArrayLength, GenericArray};
10
/// Prefix hashed in front of an over-long domain separation tag (DST) when
/// deriving its short replacement (see `Domain::xmd` / `Domain::xof`).
const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-";

/// Largest DST, in bytes, that is used verbatim. Longer tags are hashed down.
/// Spelled as `u8::MAX` because the tag length must fit in a single byte.
const MAX_DST_LEN: usize = u8::MAX as usize;

16pub trait ExpandMsg<'a> {
21 type Expander: Expander + Sized;
23
24 fn expand_message(
29 msgs: &[&[u8]],
30 dsts: &'a [&'a [u8]],
31 len_in_bytes: usize,
32 ) -> Result<Self::Expander>;
33}
34
/// Source of the bytes produced by [`ExpandMsg::expand_message`].
pub trait Expander {
    /// Fills all of `okm` with expanded output bytes.
    fn fill_bytes(&mut self, okm: &mut [u8]);
}

41pub(crate) enum Domain<'a, L>
47where
48 L: ArrayLength<u8> + IsLess<U256>,
49{
50 Hashed(GenericArray<u8, L>),
52 Array(&'a [&'a [u8]]),
54}
55
56impl<'a, L> Domain<'a, L>
57where
58 L: ArrayLength<u8> + IsLess<U256>,
59{
60 pub fn xof<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
61 where
62 X: Default + ExtendableOutput + Update,
63 {
64 if dsts.is_empty() {
65 Err(Error)
66 } else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
67 let mut data = GenericArray::<u8, L>::default();
68 let mut hash = X::default();
69 hash.update(OVERSIZE_DST_SALT);
70
71 for dst in dsts {
72 hash.update(dst);
73 }
74
75 hash.finalize_xof().read(&mut data);
76
77 Ok(Self::Hashed(data))
78 } else {
79 Ok(Self::Array(dsts))
80 }
81 }
82
83 pub fn xmd<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
84 where
85 X: Digest<OutputSize = L>,
86 {
87 if dsts.is_empty() {
88 Err(Error)
89 } else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
90 Ok(Self::Hashed({
91 let mut hash = X::new();
92 hash.update(OVERSIZE_DST_SALT);
93
94 for dst in dsts {
95 hash.update(dst);
96 }
97
98 hash.finalize()
99 }))
100 } else {
101 Ok(Self::Array(dsts))
102 }
103 }
104
105 pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
106 match self {
107 Self::Hashed(d) => hash.update(d),
108 Self::Array(d) => {
109 for d in d.iter() {
110 hash.update(d)
111 }
112 }
113 }
114 }
115
116 pub fn len(&self) -> u8 {
117 match self {
118 Self::Hashed(_) => L::to_u8(),
120 Self::Array(d) => {
122 u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
123 }
124 }
125 }
126
127 #[cfg(test)]
128 pub fn assert(&self, bytes: &[u8]) {
129 let data = match self {
130 Domain::Hashed(d) => d.to_vec(),
131 Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
132 };
133 assert_eq!(data, bytes);
134 }
135
136 #[cfg(test)]
137 pub fn assert_dst(&self, bytes: &[u8]) {
138 let data = match self {
139 Domain::Hashed(d) => d.to_vec(),
140 Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
141 };
142 assert_eq!(data, &bytes[..bytes.len() - 1]);
143 assert_eq!(self.len(), bytes[bytes.len() - 1]);
144 }
145}