1use core::fmt;
10use core::marker::PhantomData;
11
12use block_buffer::BlockSizes;
13use digest::array::ArraySize;
14use digest::block_api::{
15 AlgorithmName,
16 Block,
17 BlockSizeUser,
18 Buffer,
19 BufferKindUser,
20 Eager,
21 ExtendableOutputCore,
22 FixedOutputCore,
23 OutputSizeUser,
24 Reset,
25 UpdateCore,
26 XofReaderCore,
27};
28use digest::common::hazmat::{
29 DeserializeStateError,
30 SerializableState,
31 SerializedState,
32};
33use digest::typenum::{
34 IsLessOrEqual,
35 True,
36 U0,
37 U200,
38};
39use digest::{
40 HashMarker,
41 Output,
42};
43
44pub use crate::cshake::{
45 CShake128Core,
46 CShake256Core,
47};
48use crate::{
49 DEFAULT_ROUND_COUNT,
50 PLEN,
51};
52
/// Padding/domain byte for the original (pre-FIPS 202) Keccak digests: only the
/// lowest bit is set, i.e. no extra domain-separation suffix bits.
pub const KECCAK_DIGEST_PAD: u8 = 0b0000_0001;
56
57#[derive(Clone)]
59pub struct SpongeHasherCore<
60 Rate,
61 OutputSize,
62 const PAD: u8,
63 const ROUNDS: usize = DEFAULT_ROUND_COUNT,
64> where
65 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
66 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
67{
68 state: [u64; PLEN],
69 _pd: PhantomData<(Rate, OutputSize)>,
70}
71
72impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> HashMarker
73 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
74where
75 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
76 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
77{
78}
79
80impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> BlockSizeUser
81 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
82where
83 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
84 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
85{
86 type BlockSize = Rate;
87}
88
89impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> BufferKindUser
90 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
91where
92 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
93 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
94{
95 type BufferKind = Eager;
96}
97
98impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> OutputSizeUser
99 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
100where
101 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
102 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
103{
104 type OutputSize = OutputSize;
105}
106
107impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> UpdateCore
108 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
109where
110 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
111 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
112{
113 #[inline]
114 fn update_blocks(&mut self, blocks: &[Block<Self>]) {
115 for block in blocks {
116 xor_block(&mut self.state, block);
117 lib_q_keccak::p1600(&mut self.state, ROUNDS);
118 }
119 }
120}
121
122impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> FixedOutputCore
123 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
124where
125 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
126 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
127{
128 #[inline]
129 fn finalize_fixed_core(&mut self, buffer: &mut Buffer<Self>, out: &mut Output<Self>) {
130 let pos = buffer.get_pos();
131 let mut block = buffer.pad_with_zeros();
132 block[pos] = PAD;
133 let n = block.len();
134 block[n - 1] |= 0x80;
135
136 xor_block(&mut self.state, &block);
137 lib_q_keccak::p1600(&mut self.state, ROUNDS);
138
139 for (o, s) in out.chunks_mut(8).zip(self.state.iter()) {
140 o.copy_from_slice(&s.to_le_bytes()[..o.len()]);
141 }
142 }
143}
144
145impl<Rate, const PAD: u8, const ROUNDS: usize> ExtendableOutputCore
146 for SpongeHasherCore<Rate, U0, PAD, ROUNDS>
147where
148 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
149{
150 type ReaderCore = SpongeReaderCore<Rate, ROUNDS>;
151
152 #[inline]
153 fn finalize_xof_core(&mut self, buffer: &mut Buffer<Self>) -> Self::ReaderCore {
154 let pos = buffer.get_pos();
155 let mut block = buffer.pad_with_zeros();
156 block[pos] = PAD;
157 let n = block.len();
158 block[n - 1] |= 0x80;
159
160 xor_block(&mut self.state, &block);
161 lib_q_keccak::p1600(&mut self.state, ROUNDS);
162
163 SpongeReaderCore::new(&self.state)
164 }
165}
166
167impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Default
168 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
169where
170 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
171 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
172{
173 #[inline]
174 fn default() -> Self {
175 Self {
176 state: Default::default(),
177 _pd: PhantomData,
178 }
179 }
180}
181
182impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Reset
183 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
184where
185 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
186 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
187{
188 #[inline]
189 fn reset(&mut self) {
190 *self = Default::default();
191 }
192}
193
194impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> AlgorithmName
195 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
196where
197 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
198 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
199{
200 fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.write_str("SpongeHasherCore")
203 }
204}
205
206impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> fmt::Debug
207 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
208where
209 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
210 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
211{
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 f.write_str("SpongeHasherCore { ... }")
214 }
215}
216
217impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> Drop
218 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
219where
220 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
221 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
222{
223 fn drop(&mut self) {
224 #[cfg(feature = "zeroize")]
225 {
226 use digest::zeroize::Zeroize;
227 self.state.zeroize();
228 }
229 }
230}
231
/// Advertises that the `Drop` impl zeroizes the state (only with `zeroize`).
#[cfg(feature = "zeroize")]
#[cfg_attr(docsrs, doc(cfg(feature = "zeroize")))]
impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> digest::zeroize::ZeroizeOnDrop
    for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
where
    OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
{
}
241
242impl<Rate, OutputSize, const PAD: u8, const ROUNDS: usize> SerializableState
243 for SpongeHasherCore<Rate, OutputSize, PAD, ROUNDS>
244where
245 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
246 OutputSize: ArraySize + IsLessOrEqual<U200, Output = True>,
247{
248 type SerializedStateSize = U200;
249
250 fn serialize(&self) -> SerializedState<Self> {
251 let mut serialized_state = SerializedState::<Self>::default();
252 let chunks = serialized_state.chunks_exact_mut(8);
253 for (val, chunk) in self.state.iter().zip(chunks) {
254 chunk.copy_from_slice(&val.to_le_bytes());
255 }
256
257 serialized_state
258 }
259
260 fn deserialize(
261 serialized_state: &SerializedState<Self>,
262 ) -> Result<Self, DeserializeStateError> {
263 let bytes: &[u8] = serialized_state.as_ref();
264 let (lanes, remainder) = bytes.as_chunks::<8>();
265 if !remainder.is_empty() || lanes.len() != PLEN {
266 return Err(DeserializeStateError);
267 }
268 let state = core::array::from_fn(|i| u64::from_le_bytes(lanes[i]));
269
270 Ok(Self {
271 state,
272 _pd: PhantomData,
273 })
274 }
275}
276
277#[derive(Clone)]
279pub struct SpongeReaderCore<Rate, const ROUNDS: usize = DEFAULT_ROUND_COUNT>
280where
281 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
282{
283 state: [u64; PLEN],
284 _pd: PhantomData<Rate>,
285}
286
287impl<Rate, const ROUNDS: usize> SpongeReaderCore<Rate, ROUNDS>
288where
289 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
290{
291 pub(crate) fn new(state: &[u64; PLEN]) -> Self {
292 Self {
293 state: *state,
294 _pd: PhantomData,
295 }
296 }
297}
298
299impl<Rate, const ROUNDS: usize> BlockSizeUser for SpongeReaderCore<Rate, ROUNDS>
300where
301 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
302{
303 type BlockSize = Rate;
304}
305
306impl<Rate, const ROUNDS: usize> XofReaderCore for SpongeReaderCore<Rate, ROUNDS>
307where
308 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
309{
310 #[inline]
311 fn read_block(&mut self) -> Block<Self> {
312 let mut block = Block::<Self>::default();
313 for (src, dst) in self.state.iter().zip(block.chunks_mut(8)) {
314 dst.copy_from_slice(&src.to_le_bytes()[..dst.len()]);
315 }
316 lib_q_keccak::p1600(&mut self.state, ROUNDS);
317 block
318 }
319}
320
321impl<Rate, const ROUNDS: usize> Drop for SpongeReaderCore<Rate, ROUNDS>
322where
323 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
324{
325 fn drop(&mut self) {
326 #[cfg(feature = "zeroize")]
327 {
328 use digest::zeroize::Zeroize;
329 self.state.zeroize();
330 }
331 }
332}
333
334impl<Rate, const ROUNDS: usize> fmt::Debug for SpongeReaderCore<Rate, ROUNDS>
335where
336 Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
337{
338 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
339 f.write_str("SpongeReaderCore { ... }")
340 }
341}
342
/// Advertises that the reader's `Drop` impl zeroizes its state (only with `zeroize`).
#[cfg(feature = "zeroize")]
#[cfg_attr(docsrs, doc(cfg(feature = "zeroize")))]
impl<Rate, const ROUNDS: usize> digest::zeroize::ZeroizeOnDrop for SpongeReaderCore<Rate, ROUNDS>
where
    Rate: BlockSizes + IsLessOrEqual<U200, Output = True>,
{
}
349
350pub(crate) fn xor_block(state: &mut [u64; PLEN], block: &[u8]) {
351 assert!(block.len() < 8 * PLEN);
352
353 let mut chunks = block.chunks_exact(8);
354 for (s, chunk) in state.iter_mut().zip(&mut chunks) {
355 *s ^= u64::from_le_bytes(
356 *chunk
357 .first_chunk::<8>()
358 .expect("8 bytes from chunk_exact(8)"),
359 );
360 }
361
362 let rem = chunks.remainder();
363 if !rem.is_empty() {
364 let mut buf = [0u8; 8];
365 buf[..rem.len()].copy_from_slice(rem);
366 let n = block.len() / 8;
367 state[n] ^= u64::from_le_bytes(buf);
368 }
369}