dusk_safe/
sponge.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4//
5// Copyright (c) DUSK NETWORK. All rights reserved.
6
7use alloc::vec::Vec;
8use zeroize::Zeroize;
9
10use crate::{tag_input, Call, Error};
11
12/// This trait defines the behavior of a sponge algorithm.
13///
14/// Note: The trait's specific implementation of addition enables usage within
15/// zero-knowledge circuits.
16pub trait Safe<T, const W: usize>
17where
18    T: Default + Copy + Zeroize,
19{
20    /// Apply one permutation to the state.
21    fn permute(&mut self, state: &mut [T; W]);
22
23    /// Create the tag by hashing the tag input to an element of type `T`.
24    ///
25    /// # Parameters
26    ///
27    /// - `input`: The domain-separator and IO-pattern encoded as a slice of
28    ///   bytes.
29    ///
30    /// # Returns
31    ///
32    /// A tag element as the hash of the input to a field element `T`.
33    fn tag(&mut self, input: &[u8]) -> T;
34
35    /// Add two values of type `T` and return the result.
36    ///
37    /// # Parameters
38    ///
39    /// - `right`: The right operand of type `T`.
40    /// - `left`: The left operand of type `T`.
41    ///
42    /// # Returns
43    ///
44    /// The result of the addition, of type `T`.
45    fn add(&mut self, right: &T, left: &T) -> T;
46
47    /// Create a state and initialize it with the tag and default values of `T`.
48    ///
49    /// # Parameters
50    ///
51    /// - `tag`: The initial tag value as computed by [`Self::tag`].
52    ///
53    /// # Returns
54    ///
55    /// An array of type `[T; W]` representing the initialized state.
56    fn initialized_state(tag: T) -> [T; W] {
57        let mut state = [T::default(); W];
58        state[0] = tag;
59        state
60    }
61}
62
63/// Struct that implements the Sponge API over field elements.
64///
65/// The capacity is fixed to one field element and the rate are `W - 1` field
66/// elements.
67#[derive(Debug, Clone, PartialEq)]
68pub struct Sponge<S, T, const W: usize>
69where
70    S: Safe<T, W>,
71    T: Default + Copy + Zeroize,
72{
73    state: [T; W],
74    pub(crate) safe: S,
75    pos_absorb: usize,
76    pos_squeeze: usize,
77    io_count: usize,
78    iopattern: Vec<Call>,
79    domain_sep: u64,
80    pub(crate) output: Vec<T>,
81}
82
83impl<S, T, const W: usize> Sponge<S, T, W>
84where
85    S: Safe<T, W>,
86    T: Default + Copy + Zeroize,
87{
88    /// The capacity of the sponge.
89    const CAPACITY: usize = 1;
90
91    /// The rate of the sponge.
92    const RATE: usize = W - Self::CAPACITY;
93
94    /// This initializes the sponge, setting the first element of the state to
95    /// the [`Safe::tag()`] and the other elements to the default value of
96    /// `T`. It’s done once in the lifetime of a sponge.
97    ///
98    /// # Parameters
99    ///
100    /// - `safe`: The sponge safe implementation.
101    /// - `iopattern`: The IO-pattern for the sponge.
102    /// - `domain_sep`: The domain separator to be used.
103    ///
104    /// # Returns
105    ///
106    /// A result containing the initialized Sponge on success, or an `Error` if
107    /// the IO-pattern is invalid.
108    pub fn start(
109        safe: S,
110        iopattern: impl Into<Vec<Call>>,
111        domain_sep: u64,
112    ) -> Result<Self, Error> {
113        // Compute the tag and initialize the state.
114        // Note: This will return an error if the IO-pattern is invalid.
115        let iopattern: Vec<Call> = iopattern.into();
116        let mut safe = safe;
117        let tag = safe.tag(&tag_input(&iopattern, domain_sep)?);
118        let state = S::initialized_state(tag);
119
120        Ok(Self {
121            state,
122            safe,
123            pos_absorb: 0,
124            pos_squeeze: 0,
125            io_count: 0,
126            iopattern,
127            domain_sep,
128            output: Vec::new(),
129        })
130    }
131
132    /// This marks the end of the sponge life, preventing any further operation.
133    /// In particular, the state is erased from memory.
134    ///
135    /// # Returns
136    ///
137    /// A result containing the output vector on success, or an `Error` if the
138    /// IO-pattern wasn't followed.
139    pub fn finish(mut self) -> Result<Vec<T>, Error> {
140        let ret = match self.io_count == self.iopattern.len() {
141            true => Ok(self.output.clone()),
142            false => Err(Error::IOPatternViolation),
143        };
144        // no matter the return, we erase the internal state of the sponge
145        self.zeroize();
146        ret
147    }
148
149    /// This absorbs `len` field elements from the input into the state with
150    /// interleaving calls to the permutation function. It also checks if the
151    /// call matches the IO-pattern.
152    ///
153    /// # Parameters
154    ///
155    /// - `len`: The number of field elements to absorb.
156    /// - `input`: The input slice of field elements.
157    ///
158    /// # Returns
159    ///
160    /// A result indicating success if the operation completes, or an `Error`
161    /// if the IO-pattern wasn't followed.
162    pub fn absorb(
163        &mut self,
164        len: usize,
165        input: impl AsRef<[T]>,
166    ) -> Result<(), Error> {
167        // Check that input yields enough elements
168        if input.as_ref().len() < len {
169            self.zeroize();
170            return Err(Error::TooFewInputElements);
171        }
172        // Check that the IO-pattern is followed
173        match self.iopattern.get(self.io_count) {
174            // only proceed if we expect a call to absorb with the correct
175            // length as per the IO-pattern
176            Some(Call::Absorb(call_len)) if *call_len == len => {}
177            Some(Call::Absorb(_)) => {
178                self.zeroize();
179                return Err(Error::IOPatternViolation);
180            }
181            _ => {
182                self.zeroize();
183                return Err(Error::IOPatternViolation);
184            }
185        }
186
187        // Absorb `len` elements into the state, calling [`permute`] when the
188        // absorb-position reached the rate.
189        for element in input.as_ref().iter().take(len) {
190            if self.pos_absorb == Self::RATE {
191                self.safe.permute(&mut self.state);
192
193                self.pos_absorb = 0;
194            }
195            // add the input to the state using `Safe::add`
196            let pos = self.pos_absorb + Self::CAPACITY;
197            let previous_value = self.state[pos];
198            let sum = self.safe.add(&previous_value, element);
199            self.state[pos] = sum;
200            self.pos_absorb += 1;
201        }
202
203        // Set squeeze position to rate to force a permutation at the next
204        // call to squeeze
205        self.pos_squeeze = Self::RATE;
206
207        // Increase the position for the IO-pattern
208        self.io_count += 1;
209
210        Ok(())
211    }
212
213    /// This extracts `len` field elements from the state with interleaving
214    /// calls to the permutation function. It also checks if the call matches
215    /// the IO-pattern.
216    ///
217    /// # Parameters
218    ///
219    /// - `len`: The number of field elements to squeeze.
220    ///
221    /// # Returns
222    ///
223    /// A result indicating success if the operation completes, or an `Error`
224    /// if the IO-pattern wasn't followed.
225    pub fn squeeze(&mut self, len: usize) -> Result<(), Error> {
226        // Check that the IO-pattern is followed
227        match self.iopattern.get(self.io_count) {
228            // only proceed if we expect a call to squeeze with the correct
229            // length as per the IO-pattern
230            Some(Call::Squeeze(call_len)) if *call_len == len => {}
231            Some(Call::Squeeze(_)) => {
232                self.zeroize();
233                return Err(Error::IOPatternViolation);
234            }
235            _ => {
236                self.zeroize();
237                return Err(Error::IOPatternViolation);
238            }
239        }
240
241        // Squeeze 'len` field elements from the state, calling [`permute`] when
242        // the squeeze-position reached the rate.
243        for _ in 0..len {
244            if self.pos_squeeze == Self::RATE {
245                self.safe.permute(&mut self.state);
246
247                self.pos_squeeze = 0;
248                self.pos_absorb = 0;
249            }
250            self.output
251                .push(self.state[self.pos_squeeze + Self::CAPACITY]);
252            self.pos_squeeze += 1;
253        }
254
255        // Increase the position for the IO-pattern
256        self.io_count += 1;
257
258        Ok(())
259    }
260}
261
262impl<S, T, const W: usize> Drop for Sponge<S, T, W>
263where
264    S: Safe<T, W>,
265    T: Default + Copy + Zeroize,
266{
267    fn drop(&mut self) {
268        self.zeroize();
269    }
270}
271
272impl<S, T, const W: usize> Zeroize for Sponge<S, T, W>
273where
274    S: Safe<T, W>,
275    T: Default + Copy + Zeroize,
276{
277    fn zeroize(&mut self) {
278        self.state.zeroize();
279        self.pos_absorb.zeroize();
280        self.pos_squeeze.zeroize();
281        self.output.zeroize();
282    }
283}