hashsigs_rs/
lib.rs

1// Copyright (C) 2024 quip.network
2//
3// This program is free software: you can redistribute it and/or modify
4// it under the terms of the GNU Affero General Public License as published by
5// the Free Software Foundation, either version 3 of the License, or
6// (at your option) any later version.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11// GNU Affero General Public License for more details.
12//
13// You should have received a copy of the GNU Affero General Public License
14// along with this program.  If not, see <https://www.gnu.org/licenses/>.
15//
16// SPDX-License-Identifier: AGPL-3.0-or-later
17
18//! WOTS+ (Winternitz One-Time Signature Plus) implementation
19
/// Hash primitive plugged into WOTS+: maps an arbitrary byte string to a
/// 32-byte digest (keccak256 in the reference parameter set).
type HashFn = fn(data: &[u8]) -> [u8; 32];
22
/// Constants from the WOTS+ implementation.
///
/// Every size and count used by the scheme is derived from three parameters:
/// `HASH_LEN` (`n`), `MESSAGE_LEN` (`m`) and `CHAIN_LEN` (`w`).
pub mod constants {
    /// HashLen: The WOTS+ `n` security parameter, i.e. the size of the hash
    /// function output in bytes.
    ///
    /// This is 32 for keccak256 (256 / 8 = 32).
    pub const HASH_LEN: usize = 32;

    /// MessageLen: The WOTS+ `m` parameter, i.e. the size in bytes of the
    /// message digest that is actually signed (equal to the hash output size).
    ///
    /// This is 32 for keccak256 (256 / 8 = 32).
    ///
    /// Note that this is not the length of the original message: as with most
    /// signature schemes, the message is hashed first and the signature is
    /// computed over that hash.
    pub const MESSAGE_LEN: usize = HASH_LEN;

    /// ChainLen: The WOTS+ `w` (Winternitz) parameter.
    ///
    /// The message digest and its checksum are written in base `w`, so every
    /// digit (chain index) lies in `0..w`, and each hash chain runs `w - 1`
    /// hash steps from its secret key segment to its public key segment.
    ///
    /// A larger value means a smaller signature (fewer chains) but a longer
    /// computation time (longer chains).
    ///
    /// XMSS (RFC 8391) restricts this value to 4 or 16 because those choices
    /// simplify the algorithm and offer the best trade-offs.
    pub const CHAIN_LEN: usize = 16;

    /// lg(ChainLen), evaluated at compile time (lg(16) == 4).
    pub const LG_CHAIN_LEN: usize = CHAIN_LEN.ilog2() as usize;

    /// NumMessageChunks: the `len_1` parameter, i.e. the number of base-w
    /// digits in the message digest:
    /// ceil(8n / lg(w)) -> ceil(8 * HASH_LEN / lg(CHAIN_LEN))
    /// = ceil(32*8 / lg(16)) = 256 / 4 = 64.
    ///
    /// Python: `math.ceil(32*8 / math.log(16, 2))`
    pub const NUM_MESSAGE_CHUNKS: usize = (8 * HASH_LEN).div_ceil(LG_CHAIN_LEN);

    /// NumChecksumChunks: the `len_2` parameter, i.e. the number of base-w
    /// digits needed to hold the largest possible checksum:
    /// floor(lg(len_1 * (w - 1)) / lg(w)) + 1
    /// -> floor(lg(NUM_MESSAGE_CHUNKS * (CHAIN_LEN - 1)) / lg(CHAIN_LEN)) + 1
    /// = floor(lg(64 * 15) / lg(16)) + 1 = floor(9.907 / 4) + 1 = 3.
    ///
    /// Python: `math.floor(math.log(64 * 15, 2) / math.log(16, 2)) + 1`
    pub const NUM_CHECKSUM_CHUNKS: usize = {
        // `ilog2` already floors; since lg(w) is an exact positive integer,
        // floor(floor(x) / k) == floor(x / k), so integer math is exact here.
        ((NUM_MESSAGE_CHUNKS * (CHAIN_LEN - 1)).ilog2() as usize / LG_CHAIN_LEN) + 1
    };

    /// NumSignatureChunks: `len = len_1 + len_2`, the total number of hash
    /// chains and therefore the number of `HASH_LEN`-byte segments in a
    /// signature (64 + 3 = 67).
    pub const NUM_SIGNATURE_CHUNKS: usize = NUM_MESSAGE_CHUNKS + NUM_CHECKSUM_CHUNKS;
    /// Size of a signature in bytes.
    pub const SIGNATURE_SIZE: usize = NUM_SIGNATURE_CHUNKS * HASH_LEN;
    /// Size of a serialized public key in bytes (public seed followed by the
    /// public key hash).
    pub const PUBLIC_KEY_SIZE: usize = HASH_LEN * 2;
    /// PRF input size: 1-byte prefix + `HASH_LEN`-byte seed + 2-byte index.
    pub const PRF_INPUT_SIZE: usize = 1 + HASH_LEN + 2;

    // Compile-time sanity checks on the parameter set, so an edit to the
    // values above fails to build instead of producing wrong signatures.
    const _: () = assert!(CHAIN_LEN.is_power_of_two(), "CHAIN_LEN must be a power of two");
    const _: () = assert!(
        LG_CHAIN_LEN > 0 && 8 % LG_CHAIN_LEN == 0,
        "lg(CHAIN_LEN) must divide 8 so every byte splits into whole base-w digits"
    );
    const _: () = assert!(
        NUM_SIGNATURE_CHUNKS <= u16::MAX as usize,
        "chain and PRF indexes are encoded as u16"
    );

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_num_message_chunks() {
            assert_eq!(NUM_MESSAGE_CHUNKS, 64);
        }

        #[test]
        fn test_num_checksum_chunks() {
            assert_eq!(NUM_CHECKSUM_CHUNKS, 3);
            assert_eq!(NUM_SIGNATURE_CHUNKS, 67);
        }
    }
}
104
105/// PublicKey consists of two parts:
106/// 1. The public seed used to generate randomization elements
107/// 2. The hash of all public key segments concatenated together
108#[derive(Debug, Clone, Copy)]
109pub struct PublicKey {
110    pub public_seed: [u8; constants::HASH_LEN],
111    pub public_key_hash: [u8; constants::HASH_LEN],
112}
113
114impl PublicKey {
115    /// Convert the public key to bytes
116    /// Returns a byte array of size PUBLIC_KEY_SIZE containing the public seed followed by the public key hash
117    pub fn to_bytes(&self) -> [u8; constants::PUBLIC_KEY_SIZE] {
118        let mut result = [0u8; constants::PUBLIC_KEY_SIZE];
119        result[..constants::HASH_LEN].copy_from_slice(&self.public_seed);
120        result[constants::HASH_LEN..].copy_from_slice(&self.public_key_hash);
121        result
122    }
123
124    /// Create a PublicKey from bytes
125    /// Returns None if the input is not of the correct length
126    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
127        if bytes.len() != constants::PUBLIC_KEY_SIZE {
128            return None;
129        }
130        let mut public_seed = [0u8; constants::HASH_LEN];
131        let mut public_key_hash = [0u8; constants::HASH_LEN];
132        
133        public_seed.copy_from_slice(&bytes[..constants::HASH_LEN]);
134        public_key_hash.copy_from_slice(&bytes[constants::HASH_LEN..]);
135        
136        Some(PublicKey {
137            public_seed,
138            public_key_hash,
139        })
140    }
141}
142
143pub struct WOTSPlus {
144    hash_fn: HashFn,
145}
146
147impl WOTSPlus {
148    /// Create a new WOTS+ instance with the specified hash function
149    pub fn new(hash_fn: HashFn) -> Self {
150        Self { hash_fn }
151    }
152
153    /// Generate randomization elements from seed and index
154    /// Similar to XMSS RFC 8391 section 5.1
155    /// Uses a prefix byte (0x03) to domain separate the PRF
156    fn prf(&self, seed: &[u8; constants::HASH_LEN], index: u16) -> [u8; constants::HASH_LEN] {
157        let mut input = [0u8; constants::PRF_INPUT_SIZE];
158        input[0] = 0x03; // prefix to domain separate
159        input[1..33].copy_from_slice(seed); // the seed input
160        input[33..].copy_from_slice(&index.to_be_bytes()); // the index/position
161        (self.hash_fn)(&input)
162    }
163
164    /// Generate randomization elements from public seed
165    /// These elements are used in the chain function to randomize each hash
166    pub fn generate_randomization_elements(
167        &self,
168        public_seed: &[u8; constants::HASH_LEN]
169    ) -> Vec<[u8; constants::HASH_LEN]> {
170        let mut elements = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS);
171        for i in 0..constants::NUM_SIGNATURE_CHUNKS {
172            elements.push(self.prf(public_seed, i as u16));
173        }
174        elements
175    }
176
177    /// XOR two 32-byte arrays
178    fn xor(a: &[u8; constants::HASH_LEN], b: &[u8; constants::HASH_LEN]) -> [u8; constants::HASH_LEN] {
179        let mut result = [0u8; constants::HASH_LEN];
180        for i in 0..constants::HASH_LEN {
181            result[i] = a[i] ^ b[i];
182        }
183        result
184    }
185
186    /// Chain function (c_k^i function)
187    /// This is the core of WOTS+, implementing the hash chain with randomization
188    /// The chain function takes the previous chain output, XORs it with a randomization element,
189    /// and then hashes the result. This is repeated 'steps' times.
190    fn chain(
191        &self,
192        prev_chain_out: &[u8; constants::HASH_LEN],
193        randomization_elements: &[[u8; constants::HASH_LEN]],
194        index: u16,
195        steps: u16,
196    ) -> [u8; constants::HASH_LEN] {
197        let mut chain_out = *prev_chain_out;
198        for i in 1..=steps {
199            let xored = Self::xor(&chain_out, &randomization_elements[(i + index) as usize]);
200            chain_out = (self.hash_fn)(&xored);
201        }
202        chain_out
203    }
204
205    /// Compute message hash chain indexes
206    /// This function performs two main tasks:
207    /// 1. Convert the message to base-w representation (or base of CHAIN_LEN representation)
208    /// 2. Compute and append the checksum in base-w representation
209    /// 
210    /// These numbers are used to index into each hash chain which is rooted at a secret key segment
211    /// and produces a public key segment at the end of the chain. Verification of a signature means
212    /// using these indexes into each hash chain to recompute the corresponding public key segment.
213    fn compute_message_hash_chain_indexes(&self, message: &[u8]) -> Vec<u8> {
214        if message.len() != constants::MESSAGE_LEN {
215            panic!("Message length must be {} bytes", constants::MESSAGE_LEN);
216        }
217
218        let mut chain_segments_indexes = vec![0u8; constants::NUM_SIGNATURE_CHUNKS];
219        let mut idx = 0;
220        
221        // Convert message to base-w representation
222        for byte in message {
223            chain_segments_indexes[idx] = byte >> 4;
224            chain_segments_indexes[idx + 1] = byte & 0x0f;
225            idx += 2;
226        }
227
228        // Compute checksum
229        let mut checksum: u32 = 0;
230        for &value in &chain_segments_indexes[..constants::NUM_MESSAGE_CHUNKS] {
231            checksum += constants::CHAIN_LEN as u32 - 1 - value as u32
232        }
233
234        // Convert checksum to base-w and append
235        // This is left-shifting the checksum to ensure proper alignment when
236        // converting to base-w representation
237        for i in (0..constants::NUM_CHECKSUM_CHUNKS).rev() {
238            let shift = i * constants::LG_CHAIN_LEN as usize;
239            chain_segments_indexes[idx] = ((checksum >> shift) & (constants::CHAIN_LEN as u32 - 1)) as u8;
240            idx += 1;
241        }
242
243        chain_segments_indexes
244    }
245
246    /// Generate public key from a private key
247    pub fn get_public_key(&self, private_key: &[u8; constants::HASH_LEN]) -> PublicKey {
248        let public_seed = self.prf(private_key, 0);
249        self.get_public_key_with_public_seed(private_key, &public_seed)
250    }
251    pub fn get_public_key_with_public_seed(&self, private_key: &[u8; constants::HASH_LEN], public_seed: &[u8; constants::HASH_LEN]) -> PublicKey {
252        let randomization_elements = self.generate_randomization_elements(&public_seed);
253        let function_key = randomization_elements[0];
254
255        let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE);
256
257        for i in 0..constants::NUM_SIGNATURE_CHUNKS {
258            let mut to_hash = vec![0u8; constants::HASH_LEN * 2];
259            to_hash[..constants::HASH_LEN].copy_from_slice(&function_key);
260            to_hash[constants::HASH_LEN..].copy_from_slice(&self.prf(private_key, (i + 1) as u16));
261            
262            let secret_key_segment = (self.hash_fn)(&to_hash);
263            let segment = self.chain(
264                &secret_key_segment,
265                &randomization_elements,
266                0,
267                (constants::CHAIN_LEN - 1) as u16,
268            );
269            
270            public_key_segments.extend_from_slice(&segment);
271        }
272
273        let public_key_hash = (self.hash_fn)(&public_key_segments);
274        
275        PublicKey {
276            public_seed: *public_seed,
277            public_key_hash,
278        }
279    }
280
281
282    /// Generate a WOTS+ key pair
283    /// The process works as follows:
284    /// 1. Generate private key from seed
285    /// 2. Generate public seed from private key
286    /// 3. Generate randomization elements from public seed
287    /// 4. For each signature chunk:
288    ///    a. Generate a secret key segment
289    ///    b. Run the chain function to the end to get the public key segment
290    /// 5. Hash all public key segments together to get the final public key
291    pub fn generate_key_pair(&self, private_seed: &[u8; constants::HASH_LEN]) -> (PublicKey, [u8; constants::HASH_LEN]) {
292        let private_key = (self.hash_fn)(private_seed);
293        let public_key = self.get_public_key(&private_key);
294        (public_key, private_key)
295    }
296
297    /// Sign a message with a WOTS+ private key
298    /// The process works as follows:
299    /// 1. Generate public seed from private key
300    /// 2. Generate randomization elements from public seed
301    /// 3. Convert message to chain indexes (including checksum)
302    /// 4. For each chain index:
303    ///    a. Generate the secret key segment
304    ///    b. Run the chain function to the index position
305    pub fn sign(&self, private_key: &[u8; constants::HASH_LEN], message: &[u8]) -> Vec<[u8; constants::HASH_LEN]> {
306        if message.len() != constants::MESSAGE_LEN {
307            panic!("Message length must be {} bytes", constants::MESSAGE_LEN);
308        }
309
310        let public_seed = self.prf(private_key, 0);
311        let randomization_elements = self.generate_randomization_elements(&public_seed);
312        let function_key = randomization_elements[0];
313        
314        let chain_segments = self.compute_message_hash_chain_indexes(message);
315        let mut signature = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS);
316
317        for (i, &chain_idx) in chain_segments.iter().enumerate() {
318            let mut to_hash = vec![0u8; constants::HASH_LEN * 2];
319            to_hash[..constants::HASH_LEN].copy_from_slice(&function_key);
320            to_hash[constants::HASH_LEN..].copy_from_slice(&self.prf(private_key, (i + 1) as u16));
321            
322            let secret_key_segment = (self.hash_fn)(&to_hash);
323            let sig_segment = self.chain(
324                &secret_key_segment,
325                &randomization_elements,
326                0,
327                chain_idx as u16,
328            );
329            signature.push(sig_segment);
330        }
331
332        signature
333    }
334
335    /// Verify a WOTS+ signature
336    /// The verification process works as follows:
337    /// 1. The first part of the publicKey is a public seed used to regenerate the randomization elements
338    /// 2. The second part of the publicKey is the hash of the NumMessageChunks + NumChecksumChunks public key segments
339    /// 3. Convert the Message to "base-w" representation (or base of ChainLen representation)
340    /// 4. Compute and add the checksum
341    /// 5. Run the chain function on each segment to reproduce each public key segment
342    /// 6. Hash all public key segments together to recreate the original public key
343    pub fn verify(&self, public_key: &PublicKey, message: &[u8], signature: &Vec<[u8; constants::HASH_LEN]>) -> bool {
344        
345        if message.len() != constants::MESSAGE_LEN {
346            return false;
347        }
348        if signature.len() != constants::NUM_SIGNATURE_CHUNKS {
349            return false;
350        }
351
352        let randomization_elements = self.generate_randomization_elements(&public_key.public_seed);
353        
354        let chain_segments = self.compute_message_hash_chain_indexes(message);
355        
356        let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE);
357
358        // Compute each public key segment. These are done by taking the signature, which is prevChainOut at chainIdx,
359        // and completing the hash chain via the chain function to recompute the public key segment.
360        for (i, &chain_idx) in chain_segments.iter().enumerate() {
361            let num_iterations = (constants::CHAIN_LEN - 1 - chain_idx as usize) as u16;
362            let segment = self.chain(
363                &signature[i],
364                &randomization_elements,
365                chain_idx as u16,
366                num_iterations,
367            );
368            
369            public_key_segments.extend_from_slice(&segment);
370        }
371
372        // Hash all public key segments together to recreate the original public key
373        let computed_hash = (self.hash_fn)(&public_key_segments);
374        
375        // Compare computed hash with stored public key hash
376        computed_hash == public_key.public_key_hash
377    }
378
379    /// Verify a WOTS+ signature using pre-computed randomization elements
380    /// This is an optimization that allows reusing the randomization elements
381    /// when verifying multiple signatures with the same public seed
382    pub fn verify_with_randomization_elements(
383        &self,
384        public_key_hash: &[u8; constants::HASH_LEN],
385        message: &[u8],
386        signature: &Vec<[u8; constants::HASH_LEN]>,
387        randomization_elements: &Vec<[u8; constants::HASH_LEN]>,
388    ) -> bool {
389        if message.len() != constants::MESSAGE_LEN {
390            return false;
391        }
392        if signature.len() != constants::NUM_SIGNATURE_CHUNKS {
393            return false;
394        }
395        if randomization_elements.len() != constants::NUM_SIGNATURE_CHUNKS {
396            return false;
397        }
398
399        let chain_segments = self.compute_message_hash_chain_indexes(message);
400        let mut public_key_segments = [0u8; constants::SIGNATURE_SIZE];
401        
402        // Compute each public key segment using the pre-computed randomization elements
403        for (i, &chain_idx) in chain_segments.iter().enumerate() {
404            let num_iterations = (constants::CHAIN_LEN - 1 - chain_idx as usize) as u16;
405            let segment = self.chain(
406                &signature[i],
407                randomization_elements,
408                chain_idx as u16,
409                num_iterations,
410            );
411            
412            let offset = i * constants::HASH_LEN;
413            public_key_segments[offset..offset + constants::HASH_LEN].copy_from_slice(&segment);
414        }
415
416        // Hash all public key segments together and compare with the provided hash
417        let computed_hash = (self.hash_fn)(&public_key_segments);
418        computed_hash == *public_key_hash
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic stand-in for keccak256 built on the standard library's
    /// `DefaultHasher`. It is NOT a cryptographic hash.
    ///
    /// Each 8-byte output lane hashes a lane counter together with the whole
    /// input. Unlike simply truncating the input, this makes the output
    /// depend on every input byte, so the tests can catch bugs that feed the
    /// wrong seed, index, or randomization element into a hash.
    fn mock_hash(data: &[u8]) -> [u8; 32] {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let mut output = [0u8; 32];
        for (lane, chunk) in output.chunks_exact_mut(8).enumerate() {
            let mut hasher = DefaultHasher::new();
            hasher.write_u64(lane as u64);
            hasher.write(data);
            chunk.copy_from_slice(&hasher.finish().to_le_bytes());
        }
        output
    }

    /// Build an instance plus a key pair derived from a fixed seed.
    fn setup() -> (WOTSPlus, PublicKey, [u8; constants::HASH_LEN]) {
        let wots = WOTSPlus::new(mock_hash);
        let (public_key, private_key) = wots.generate_key_pair(&[1u8; constants::HASH_LEN]);
        (wots, public_key, private_key)
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::HASH_LEN, 32);
        assert_eq!(constants::MESSAGE_LEN, 32);
        assert_eq!(constants::CHAIN_LEN, 16);
        assert_eq!(constants::LG_CHAIN_LEN, 4);
        assert_eq!(constants::NUM_MESSAGE_CHUNKS, 64);
        assert_eq!(constants::NUM_CHECKSUM_CHUNKS, 3);
        assert_eq!(constants::NUM_SIGNATURE_CHUNKS, 67);
        assert_eq!(constants::SIGNATURE_SIZE, 67 * 32);
        assert_eq!(constants::PUBLIC_KEY_SIZE, 64);
    }

    #[test]
    fn test_message_hash_chain_indexes() {
        let wots = WOTSPlus::new(mock_hash);
        let split = constants::NUM_MESSAGE_CHUNKS;

        // 0xAB -> digits 10, 11; checksum = 32*(15-10) + 32*(15-11) = 288 = 0x120.
        let indexes = wots.compute_message_hash_chain_indexes(&[0xAB; constants::MESSAGE_LEN]);
        assert_eq!(indexes.len(), constants::NUM_SIGNATURE_CHUNKS);
        assert!(indexes[..split].chunks_exact(2).all(|pair| pair == [10, 11]));
        assert_eq!(&indexes[split..], &[1, 2, 0]);

        // All-zero digits give the maximal checksum 64*15 = 960 = 0x3C0.
        let indexes = wots.compute_message_hash_chain_indexes(&[0x00; constants::MESSAGE_LEN]);
        assert!(indexes[..split].iter().all(|&d| d == 0));
        assert_eq!(&indexes[split..], &[3, 12, 0]);

        // All-0xF digits give a zero checksum.
        let indexes = wots.compute_message_hash_chain_indexes(&[0xFF; constants::MESSAGE_LEN]);
        assert!(indexes[..split].iter().all(|&d| d == 15));
        assert_eq!(&indexes[split..], &[0, 0, 0]);
    }

    #[test]
    fn test_key_generation_and_signing() {
        let (wots, public_key, private_key) = setup();
        let message = [2u8; constants::MESSAGE_LEN];
        let signature = wots.sign(&private_key, &message);

        assert_eq!(signature.len(), constants::NUM_SIGNATURE_CHUNKS);
        assert!(wots.verify(&public_key, &message, &signature));
    }

    #[test]
    fn test_key_generation_is_deterministic() {
        let (wots, public_key, private_key) = setup();
        let (again_public, again_private) = wots.generate_key_pair(&[1u8; constants::HASH_LEN]);

        assert_eq!(again_private, private_key);
        assert_eq!(again_public.to_bytes(), public_key.to_bytes());
        assert_eq!(wots.get_public_key(&private_key).to_bytes(), public_key.to_bytes());
    }

    #[test]
    fn test_get_public_key_with_public_seed_matches_default_seed() {
        let (wots, public_key, private_key) = setup();
        let derived = wots.get_public_key_with_public_seed(&private_key, &public_key.public_seed);
        assert_eq!(derived.public_key_hash, public_key.public_key_hash);
    }

    #[test]
    fn test_wrong_message_rejected() {
        let (wots, public_key, private_key) = setup();
        let signature = wots.sign(&private_key, &[2u8; constants::MESSAGE_LEN]);
        assert!(!wots.verify(&public_key, &[3u8; constants::MESSAGE_LEN], &signature));
    }

    #[test]
    fn test_tampered_signature_rejected() {
        let (wots, public_key, private_key) = setup();
        let message = [2u8; constants::MESSAGE_LEN];
        let mut signature = wots.sign(&private_key, &message);
        signature[0][0] ^= 0x01;
        assert!(!wots.verify(&public_key, &message, &signature));
    }

    #[test]
    fn test_forged_zero_signature_rejected() {
        let (wots, public_key, _) = setup();
        let message = [2u8; constants::MESSAGE_LEN];
        let signature: Vec<[u8; 32]> = vec![[0u8; 32]; constants::NUM_SIGNATURE_CHUNKS];
        assert!(!wots.verify(&public_key, &message, &signature));
    }

    #[test]
    fn test_invalid_message_length() {
        let (wots, public_key, _) = setup();
        let invalid_message = [2u8; constants::MESSAGE_LEN + 1];
        let signature: Vec<[u8; 32]> = vec![[0u8; 32]; constants::NUM_SIGNATURE_CHUNKS];
        assert!(!wots.verify(&public_key, &invalid_message, &signature));
    }

    #[test]
    fn test_invalid_signature_length() {
        let (wots, public_key, private_key) = setup();
        let message = [2u8; constants::MESSAGE_LEN];

        // A genuine signature with its last segment removed.
        let mut signature = wots.sign(&private_key, &message);
        signature.pop();
        assert!(!wots.verify(&public_key, &message, &signature));

        let too_long: Vec<[u8; 32]> = vec![[0u8; 32]; constants::NUM_SIGNATURE_CHUNKS + 1];
        assert!(!wots.verify(&public_key, &message, &too_long));
    }

    #[test]
    #[should_panic(expected = "Message length must be 32 bytes")]
    fn test_sign_rejects_invalid_message_length() {
        let (wots, _, private_key) = setup();
        wots.sign(&private_key, &[2u8; constants::MESSAGE_LEN - 1]);
    }

    #[test]
    fn test_verify_with_randomization_elements() {
        let (wots, public_key, private_key) = setup();
        let message = [2u8; constants::MESSAGE_LEN];
        let signature = wots.sign(&private_key, &message);
        let elements = wots.generate_randomization_elements(&public_key.public_seed);

        assert!(wots.verify_with_randomization_elements(
            &public_key.public_key_hash,
            &message,
            &signature,
            &elements,
        ));

        let short_elements = elements[..elements.len() - 1].to_vec();
        assert!(!wots.verify_with_randomization_elements(
            &public_key.public_key_hash,
            &message,
            &signature,
            &short_elements,
        ));
    }

    #[test]
    fn test_public_key_serialization() {
        let public_key = PublicKey {
            public_seed: [1u8; constants::HASH_LEN],
            public_key_hash: [2u8; constants::HASH_LEN],
        };

        let bytes = public_key.to_bytes();
        let recovered = PublicKey::from_bytes(&bytes).unwrap();

        assert_eq!(recovered.public_seed, public_key.public_seed);
        assert_eq!(recovered.public_key_hash, public_key.public_key_hash);
        assert!(PublicKey::from_bytes(&bytes[..constants::PUBLIC_KEY_SIZE - 1]).is_none());
    }
}