dusk_safe/sponge.rs
1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4//
5// Copyright (c) DUSK NETWORK. All rights reserved.
6
7use alloc::vec::Vec;
8use zeroize::Zeroize;
9
10use crate::{tag_input, Call, Error};
11
12/// This trait defines the behavior of a sponge algorithm.
13///
14/// Note: The trait's specific implementation of addition enables usage within
15/// zero-knowledge circuits.
16pub trait Safe<T, const W: usize>
17where
18 T: Default + Copy + Zeroize,
19{
20 /// Apply one permutation to the state.
21 fn permute(&mut self, state: &mut [T; W]);
22
23 /// Create the tag by hashing the tag input to an element of type `T`.
24 ///
25 /// # Parameters
26 ///
27 /// - `input`: The domain-separator and IO-pattern encoded as a slice of
28 /// bytes.
29 ///
30 /// # Returns
31 ///
32 /// A tag element as the hash of the input to a field element `T`.
33 fn tag(&mut self, input: &[u8]) -> T;
34
35 /// Add two values of type `T` and return the result.
36 ///
37 /// # Parameters
38 ///
39 /// - `right`: The right operand of type `T`.
40 /// - `left`: The left operand of type `T`.
41 ///
42 /// # Returns
43 ///
44 /// The result of the addition, of type `T`.
45 fn add(&mut self, right: &T, left: &T) -> T;
46
47 /// Create a state and initialize it with the tag and default values of `T`.
48 ///
49 /// # Parameters
50 ///
51 /// - `tag`: The initial tag value as computed by [`Self::tag`].
52 ///
53 /// # Returns
54 ///
55 /// An array of type `[T; W]` representing the initialized state.
56 fn initialized_state(tag: T) -> [T; W] {
57 let mut state = [T::default(); W];
58 state[0] = tag;
59 state
60 }
61}
62
63/// Struct that implements the Sponge API over field elements.
64///
65/// The capacity is fixed to one field element and the rate are `W - 1` field
66/// elements.
67#[derive(Debug, Clone, PartialEq)]
68pub struct Sponge<S, T, const W: usize>
69where
70 S: Safe<T, W>,
71 T: Default + Copy + Zeroize,
72{
73 state: [T; W],
74 pub(crate) safe: S,
75 pos_absorb: usize,
76 pos_squeeze: usize,
77 io_count: usize,
78 iopattern: Vec<Call>,
79 domain_sep: u64,
80 pub(crate) output: Vec<T>,
81}
82
83impl<S, T, const W: usize> Sponge<S, T, W>
84where
85 S: Safe<T, W>,
86 T: Default + Copy + Zeroize,
87{
88 /// The capacity of the sponge.
89 const CAPACITY: usize = 1;
90
91 /// The rate of the sponge.
92 const RATE: usize = W - Self::CAPACITY;
93
94 /// This initializes the sponge, setting the first element of the state to
95 /// the [`Safe::tag()`] and the other elements to the default value of
96 /// `T`. It’s done once in the lifetime of a sponge.
97 ///
98 /// # Parameters
99 ///
100 /// - `safe`: The sponge safe implementation.
101 /// - `iopattern`: The IO-pattern for the sponge.
102 /// - `domain_sep`: The domain separator to be used.
103 ///
104 /// # Returns
105 ///
106 /// A result containing the initialized Sponge on success, or an `Error` if
107 /// the IO-pattern is invalid.
108 pub fn start(
109 safe: S,
110 iopattern: impl Into<Vec<Call>>,
111 domain_sep: u64,
112 ) -> Result<Self, Error> {
113 // Compute the tag and initialize the state.
114 // Note: This will return an error if the IO-pattern is invalid.
115 let iopattern: Vec<Call> = iopattern.into();
116 let mut safe = safe;
117 let tag = safe.tag(&tag_input(&iopattern, domain_sep)?);
118 let state = S::initialized_state(tag);
119
120 Ok(Self {
121 state,
122 safe,
123 pos_absorb: 0,
124 pos_squeeze: 0,
125 io_count: 0,
126 iopattern,
127 domain_sep,
128 output: Vec::new(),
129 })
130 }
131
132 /// This marks the end of the sponge life, preventing any further operation.
133 /// In particular, the state is erased from memory.
134 ///
135 /// # Returns
136 ///
137 /// A result containing the output vector on success, or an `Error` if the
138 /// IO-pattern wasn't followed.
139 pub fn finish(mut self) -> Result<Vec<T>, Error> {
140 let ret = match self.io_count == self.iopattern.len() {
141 true => Ok(self.output.clone()),
142 false => Err(Error::IOPatternViolation),
143 };
144 // no matter the return, we erase the internal state of the sponge
145 self.zeroize();
146 ret
147 }
148
149 /// This absorbs `len` field elements from the input into the state with
150 /// interleaving calls to the permutation function. It also checks if the
151 /// call matches the IO-pattern.
152 ///
153 /// # Parameters
154 ///
155 /// - `len`: The number of field elements to absorb.
156 /// - `input`: The input slice of field elements.
157 ///
158 /// # Returns
159 ///
160 /// A result indicating success if the operation completes, or an `Error`
161 /// if the IO-pattern wasn't followed.
162 pub fn absorb(
163 &mut self,
164 len: usize,
165 input: impl AsRef<[T]>,
166 ) -> Result<(), Error> {
167 // Check that input yields enough elements
168 if input.as_ref().len() < len {
169 self.zeroize();
170 return Err(Error::TooFewInputElements);
171 }
172 // Check that the IO-pattern is followed
173 match self.iopattern.get(self.io_count) {
174 // only proceed if we expect a call to absorb with the correct
175 // length as per the IO-pattern
176 Some(Call::Absorb(call_len)) if *call_len == len => {}
177 Some(Call::Absorb(_)) => {
178 self.zeroize();
179 return Err(Error::IOPatternViolation);
180 }
181 _ => {
182 self.zeroize();
183 return Err(Error::IOPatternViolation);
184 }
185 }
186
187 // Absorb `len` elements into the state, calling [`permute`] when the
188 // absorb-position reached the rate.
189 for element in input.as_ref().iter().take(len) {
190 if self.pos_absorb == Self::RATE {
191 self.safe.permute(&mut self.state);
192
193 self.pos_absorb = 0;
194 }
195 // add the input to the state using `Safe::add`
196 let pos = self.pos_absorb + Self::CAPACITY;
197 let previous_value = self.state[pos];
198 let sum = self.safe.add(&previous_value, element);
199 self.state[pos] = sum;
200 self.pos_absorb += 1;
201 }
202
203 // Set squeeze position to rate to force a permutation at the next
204 // call to squeeze
205 self.pos_squeeze = Self::RATE;
206
207 // Increase the position for the IO-pattern
208 self.io_count += 1;
209
210 Ok(())
211 }
212
213 /// This extracts `len` field elements from the state with interleaving
214 /// calls to the permutation function. It also checks if the call matches
215 /// the IO-pattern.
216 ///
217 /// # Parameters
218 ///
219 /// - `len`: The number of field elements to squeeze.
220 ///
221 /// # Returns
222 ///
223 /// A result indicating success if the operation completes, or an `Error`
224 /// if the IO-pattern wasn't followed.
225 pub fn squeeze(&mut self, len: usize) -> Result<(), Error> {
226 // Check that the IO-pattern is followed
227 match self.iopattern.get(self.io_count) {
228 // only proceed if we expect a call to squeeze with the correct
229 // length as per the IO-pattern
230 Some(Call::Squeeze(call_len)) if *call_len == len => {}
231 Some(Call::Squeeze(_)) => {
232 self.zeroize();
233 return Err(Error::IOPatternViolation);
234 }
235 _ => {
236 self.zeroize();
237 return Err(Error::IOPatternViolation);
238 }
239 }
240
241 // Squeeze 'len` field elements from the state, calling [`permute`] when
242 // the squeeze-position reached the rate.
243 for _ in 0..len {
244 if self.pos_squeeze == Self::RATE {
245 self.safe.permute(&mut self.state);
246
247 self.pos_squeeze = 0;
248 self.pos_absorb = 0;
249 }
250 self.output
251 .push(self.state[self.pos_squeeze + Self::CAPACITY]);
252 self.pos_squeeze += 1;
253 }
254
255 // Increase the position for the IO-pattern
256 self.io_count += 1;
257
258 Ok(())
259 }
260}
261
262impl<S, T, const W: usize> Drop for Sponge<S, T, W>
263where
264 S: Safe<T, W>,
265 T: Default + Copy + Zeroize,
266{
267 fn drop(&mut self) {
268 self.zeroize();
269 }
270}
271
272impl<S, T, const W: usize> Zeroize for Sponge<S, T, W>
273where
274 S: Safe<T, W>,
275 T: Default + Copy + Zeroize,
276{
277 fn zeroize(&mut self) {
278 self.state.zeroize();
279 self.pos_absorb.zeroize();
280 self.pos_squeeze.zeroize();
281 self.output.zeroize();
282 }
283}