/// Shape of the hash primitive that instantiates the scheme: it takes an
/// arbitrary-length byte string and returns a 32-byte digest.
///
/// Every hash evaluation performed by [`WOTSPlus`] goes through a function
/// of this type, so any function with this signature can be plugged in.
type HashFn = fn(input: &[u8]) -> [u8; 32];
22
/// Compile-time parameters of this WOTS+ instantiation.
///
/// Only `HASH_LEN` and `CHAIN_LEN` are chosen directly; everything else is
/// derived from them.
pub mod constants {
    /// Size of one hash output, in bytes.
    pub const HASH_LEN: usize = 32;

    /// Size of a signable message, in bytes: exactly one hash output.
    pub const MESSAGE_LEN: usize = HASH_LEN;

    /// Number of positions in each hash chain (the Winternitz parameter `w`).
    pub const CHAIN_LEN: usize = 16;

    /// Bits of message encoded per chain: `log2(CHAIN_LEN)`.
    pub const LG_CHAIN_LEN: usize = CHAIN_LEN.ilog2() as usize;

    /// Number of chains that encode the message (`len_1`): the message bit
    /// length divided by `LG_CHAIN_LEN`, rounded up.
    pub const NUM_MESSAGE_CHUNKS: usize = {
        let message_bits = HASH_LEN * 8;
        message_bits.div_ceil(LG_CHAIN_LEN)
    };

    /// Number of chains that encode the checksum (`len_2`).
    ///
    /// The checksum `sum(CHAIN_LEN - 1 - digit)` is largest when every message
    /// digit is zero, i.e. `NUM_MESSAGE_CHUNKS * (CHAIN_LEN - 1)`;
    /// `floor(log2(max) / LG_CHAIN_LEN) + 1` base-`CHAIN_LEN` digits hold it.
    pub const NUM_CHECKSUM_CHUNKS: usize = {
        let max_checksum = NUM_MESSAGE_CHUNKS * (CHAIN_LEN - 1);
        let top_bit = max_checksum.ilog2() as usize;
        top_bit / LG_CHAIN_LEN + 1
    };

    /// Total number of chains, i.e. number of segments in a signature.
    pub const NUM_SIGNATURE_CHUNKS: usize = NUM_CHECKSUM_CHUNKS + NUM_MESSAGE_CHUNKS;

    /// Size of all signature segments laid end to end, in bytes.
    pub const SIGNATURE_SIZE: usize = HASH_LEN * NUM_SIGNATURE_CHUNKS;

    /// Size of a serialized public key (`public_seed || public_key_hash`).
    pub const PUBLIC_KEY_SIZE: usize = 2 * HASH_LEN;

    /// PRF input size: one prefix byte, a `HASH_LEN`-byte seed and a
    /// two-byte (u16) index.
    pub const PRF_INPUT_SIZE: usize = 1 + HASH_LEN + 2;

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_num_message_chunks() {
            assert_eq!(64, NUM_MESSAGE_CHUNKS);
        }
    }
}
104
105#[derive(Debug, Clone, Copy)]
109pub struct PublicKey {
110 pub public_seed: [u8; constants::HASH_LEN],
111 pub public_key_hash: [u8; constants::HASH_LEN],
112}
113
114impl PublicKey {
115 pub fn to_bytes(&self) -> [u8; constants::PUBLIC_KEY_SIZE] {
118 let mut result = [0u8; constants::PUBLIC_KEY_SIZE];
119 result[..constants::HASH_LEN].copy_from_slice(&self.public_seed);
120 result[constants::HASH_LEN..].copy_from_slice(&self.public_key_hash);
121 result
122 }
123
124 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
127 if bytes.len() != constants::PUBLIC_KEY_SIZE {
128 return None;
129 }
130 let mut public_seed = [0u8; constants::HASH_LEN];
131 let mut public_key_hash = [0u8; constants::HASH_LEN];
132
133 public_seed.copy_from_slice(&bytes[..constants::HASH_LEN]);
134 public_key_hash.copy_from_slice(&bytes[constants::HASH_LEN..]);
135
136 Some(PublicKey {
137 public_seed,
138 public_key_hash,
139 })
140 }
141}
142
143pub struct WOTSPlus {
144 hash_fn: HashFn,
145}
146
147impl WOTSPlus {
148 pub fn new(hash_fn: HashFn) -> Self {
150 Self { hash_fn }
151 }
152
153 fn prf(&self, seed: &[u8; constants::HASH_LEN], index: u16) -> [u8; constants::HASH_LEN] {
157 let mut input = [0u8; constants::PRF_INPUT_SIZE];
158 input[0] = 0x03; input[1..33].copy_from_slice(seed); input[33..].copy_from_slice(&index.to_be_bytes()); (self.hash_fn)(&input)
162 }
163
164 pub fn generate_randomization_elements(
167 &self,
168 public_seed: &[u8; constants::HASH_LEN]
169 ) -> Vec<[u8; constants::HASH_LEN]> {
170 let mut elements = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS);
171 for i in 0..constants::NUM_SIGNATURE_CHUNKS {
172 elements.push(self.prf(public_seed, i as u16));
173 }
174 elements
175 }
176
177 fn xor(a: &[u8; constants::HASH_LEN], b: &[u8; constants::HASH_LEN]) -> [u8; constants::HASH_LEN] {
179 let mut result = [0u8; constants::HASH_LEN];
180 for i in 0..constants::HASH_LEN {
181 result[i] = a[i] ^ b[i];
182 }
183 result
184 }
185
186 fn chain(
191 &self,
192 prev_chain_out: &[u8; constants::HASH_LEN],
193 randomization_elements: &[[u8; constants::HASH_LEN]],
194 index: u16,
195 steps: u16,
196 ) -> [u8; constants::HASH_LEN] {
197 let mut chain_out = *prev_chain_out;
198 for i in 1..=steps {
199 let xored = Self::xor(&chain_out, &randomization_elements[(i + index) as usize]);
200 chain_out = (self.hash_fn)(&xored);
201 }
202 chain_out
203 }
204
205 fn compute_message_hash_chain_indexes(&self, message: &[u8]) -> Vec<u8> {
214 if message.len() != constants::MESSAGE_LEN {
215 panic!("Message length must be {} bytes", constants::MESSAGE_LEN);
216 }
217
218 let mut chain_segments_indexes = vec![0u8; constants::NUM_SIGNATURE_CHUNKS];
219 let mut idx = 0;
220
221 for byte in message {
223 chain_segments_indexes[idx] = byte >> 4;
224 chain_segments_indexes[idx + 1] = byte & 0x0f;
225 idx += 2;
226 }
227
228 let mut checksum: u32 = 0;
230 for &value in &chain_segments_indexes[..constants::NUM_MESSAGE_CHUNKS] {
231 checksum += constants::CHAIN_LEN as u32 - 1 - value as u32
232 }
233
234 for i in (0..constants::NUM_CHECKSUM_CHUNKS).rev() {
238 let shift = i * constants::LG_CHAIN_LEN as usize;
239 chain_segments_indexes[idx] = ((checksum >> shift) & (constants::CHAIN_LEN as u32 - 1)) as u8;
240 idx += 1;
241 }
242
243 chain_segments_indexes
244 }
245
246 pub fn get_public_key(&self, private_key: &[u8; constants::HASH_LEN]) -> PublicKey {
248 let public_seed = self.prf(private_key, 0);
249 self.get_public_key_with_public_seed(private_key, &public_seed)
250 }
251 pub fn get_public_key_with_public_seed(&self, private_key: &[u8; constants::HASH_LEN], public_seed: &[u8; constants::HASH_LEN]) -> PublicKey {
252 let randomization_elements = self.generate_randomization_elements(&public_seed);
253 let function_key = randomization_elements[0];
254
255 let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE);
256
257 for i in 0..constants::NUM_SIGNATURE_CHUNKS {
258 let mut to_hash = vec![0u8; constants::HASH_LEN * 2];
259 to_hash[..constants::HASH_LEN].copy_from_slice(&function_key);
260 to_hash[constants::HASH_LEN..].copy_from_slice(&self.prf(private_key, (i + 1) as u16));
261
262 let secret_key_segment = (self.hash_fn)(&to_hash);
263 let segment = self.chain(
264 &secret_key_segment,
265 &randomization_elements,
266 0,
267 (constants::CHAIN_LEN - 1) as u16,
268 );
269
270 public_key_segments.extend_from_slice(&segment);
271 }
272
273 let public_key_hash = (self.hash_fn)(&public_key_segments);
274
275 PublicKey {
276 public_seed: *public_seed,
277 public_key_hash,
278 }
279 }
280
281
282 pub fn generate_key_pair(&self, private_seed: &[u8; constants::HASH_LEN]) -> (PublicKey, [u8; constants::HASH_LEN]) {
292 let private_key = (self.hash_fn)(private_seed);
293 let public_key = self.get_public_key(&private_key);
294 (public_key, private_key)
295 }
296
297 pub fn sign(&self, private_key: &[u8; constants::HASH_LEN], message: &[u8]) -> Vec<[u8; constants::HASH_LEN]> {
306 if message.len() != constants::MESSAGE_LEN {
307 panic!("Message length must be {} bytes", constants::MESSAGE_LEN);
308 }
309
310 let public_seed = self.prf(private_key, 0);
311 let randomization_elements = self.generate_randomization_elements(&public_seed);
312 let function_key = randomization_elements[0];
313
314 let chain_segments = self.compute_message_hash_chain_indexes(message);
315 let mut signature = Vec::with_capacity(constants::NUM_SIGNATURE_CHUNKS);
316
317 for (i, &chain_idx) in chain_segments.iter().enumerate() {
318 let mut to_hash = vec![0u8; constants::HASH_LEN * 2];
319 to_hash[..constants::HASH_LEN].copy_from_slice(&function_key);
320 to_hash[constants::HASH_LEN..].copy_from_slice(&self.prf(private_key, (i + 1) as u16));
321
322 let secret_key_segment = (self.hash_fn)(&to_hash);
323 let sig_segment = self.chain(
324 &secret_key_segment,
325 &randomization_elements,
326 0,
327 chain_idx as u16,
328 );
329 signature.push(sig_segment);
330 }
331
332 signature
333 }
334
335 pub fn verify(&self, public_key: &PublicKey, message: &[u8], signature: &Vec<[u8; constants::HASH_LEN]>) -> bool {
344
345 if message.len() != constants::MESSAGE_LEN {
346 return false;
347 }
348 if signature.len() != constants::NUM_SIGNATURE_CHUNKS {
349 return false;
350 }
351
352 let randomization_elements = self.generate_randomization_elements(&public_key.public_seed);
353
354 let chain_segments = self.compute_message_hash_chain_indexes(message);
355
356 let mut public_key_segments = Vec::with_capacity(constants::SIGNATURE_SIZE);
357
358 for (i, &chain_idx) in chain_segments.iter().enumerate() {
361 let num_iterations = (constants::CHAIN_LEN - 1 - chain_idx as usize) as u16;
362 let segment = self.chain(
363 &signature[i],
364 &randomization_elements,
365 chain_idx as u16,
366 num_iterations,
367 );
368
369 public_key_segments.extend_from_slice(&segment);
370 }
371
372 let computed_hash = (self.hash_fn)(&public_key_segments);
374
375 computed_hash == public_key.public_key_hash
377 }
378
379 pub fn verify_with_randomization_elements(
383 &self,
384 public_key_hash: &[u8; constants::HASH_LEN],
385 message: &[u8],
386 signature: &Vec<[u8; constants::HASH_LEN]>,
387 randomization_elements: &Vec<[u8; constants::HASH_LEN]>,
388 ) -> bool {
389 if message.len() != constants::MESSAGE_LEN {
390 return false;
391 }
392 if signature.len() != constants::NUM_SIGNATURE_CHUNKS {
393 return false;
394 }
395 if randomization_elements.len() != constants::NUM_SIGNATURE_CHUNKS {
396 return false;
397 }
398
399 let chain_segments = self.compute_message_hash_chain_indexes(message);
400 let mut public_key_segments = [0u8; constants::SIGNATURE_SIZE];
401
402 for (i, &chain_idx) in chain_segments.iter().enumerate() {
404 let num_iterations = (constants::CHAIN_LEN - 1 - chain_idx as usize) as u16;
405 let segment = self.chain(
406 &signature[i],
407 randomization_elements,
408 chain_idx as u16,
409 num_iterations,
410 );
411
412 let offset = i * constants::HASH_LEN;
413 public_key_segments[offset..offset + constants::HASH_LEN].copy_from_slice(&segment);
414 }
415
416 let computed_hash = (self.hash_fn)(&public_key_segments);
418 computed_hash == *public_key_hash
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Toy hash: the first 32 bytes of the input, zero-padded.
    ///
    /// Easy to reason about but not collision resistant. The PRF index sits at
    /// bytes 33..35 of the PRF input and is truncated away, so every
    /// randomization element is identical. The compressed public key is just
    /// the first chain's end-point. Only suitable for plumbing tests.
    fn mock_hash(data: &[u8]) -> [u8; 32] {
        let mut output = [0u8; 32];
        for (i, &byte) in data.iter().enumerate().take(32) {
            output[i] = byte;
        }
        output
    }

    /// Deterministic, dependency-free mixing function: four seeded FNV-1a
    /// lanes with a final avalanche step. Not cryptographic, but every input
    /// byte affects every output lane, which is enough for tamper-detection
    /// tests.
    fn mixing_hash(data: &[u8]) -> [u8; 32] {
        let mut output = [0u8; 32];
        for (lane, chunk) in output.chunks_mut(8).enumerate() {
            let mut h: u64 =
                0xcbf2_9ce4_8422_2325 ^ (lane as u64 + 1).wrapping_mul(0x9e37_79b9_7f4a_7c15);
            for &byte in data {
                h ^= u64::from(byte);
                h = h.wrapping_mul(0x0000_0100_0000_01b3);
            }
            h ^= h >> 33;
            h = h.wrapping_mul(0xff51_afd7_ed55_8ccd);
            h ^= h >> 33;
            chunk.copy_from_slice(&h.to_le_bytes());
        }
        output
    }

    /// A message whose nibbles vary, so many different chain positions are used.
    fn varied_message() -> [u8; constants::MESSAGE_LEN] {
        let mut message = [0u8; constants::MESSAGE_LEN];
        for (i, byte) in message.iter_mut().enumerate() {
            *byte = (i as u8).wrapping_mul(37);
        }
        message
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::HASH_LEN, 32);
        assert_eq!(constants::MESSAGE_LEN, 32);
        assert_eq!(constants::CHAIN_LEN, 16);
        assert_eq!(constants::NUM_MESSAGE_CHUNKS, 64);
        assert_eq!(constants::NUM_CHECKSUM_CHUNKS, 3);
        assert_eq!(constants::NUM_SIGNATURE_CHUNKS, 67);
    }

    #[test]
    fn test_num_message_chunks() {
        assert_eq!(constants::NUM_MESSAGE_CHUNKS, 64);
    }

    #[test]
    fn test_num_checksum_chunks() {
        assert_eq!(constants::NUM_CHECKSUM_CHUNKS, 3);
    }

    #[test]
    fn test_key_generation_and_signing() {
        let wots = WOTSPlus::new(mock_hash);
        let private_seed = [1u8; 32];
        let (public_key, private_key) = wots.generate_key_pair(&private_seed);

        let message = [2u8; constants::MESSAGE_LEN];
        let signature = wots.sign(&private_key, &message);

        assert!(wots.verify(&public_key, &message, &signature));
    }

    #[test]
    fn test_sign_verify_with_mixing_hash() {
        let wots = WOTSPlus::new(mixing_hash);
        let (public_key, private_key) = wots.generate_key_pair(&[7u8; 32]);
        let message = varied_message();

        let signature = wots.sign(&private_key, &message);
        assert_eq!(signature.len(), constants::NUM_SIGNATURE_CHUNKS);
        assert!(wots.verify(&public_key, &message, &signature));
    }

    #[test]
    fn test_tampered_message_rejected() {
        let wots = WOTSPlus::new(mixing_hash);
        let (public_key, private_key) = wots.generate_key_pair(&[7u8; 32]);
        let message = varied_message();
        let signature = wots.sign(&private_key, &message);

        let mut tampered = message;
        tampered[0] ^= 0x10;
        assert!(!wots.verify(&public_key, &tampered, &signature));
    }

    #[test]
    fn test_tampered_signature_rejected() {
        let wots = WOTSPlus::new(mixing_hash);
        let (public_key, private_key) = wots.generate_key_pair(&[7u8; 32]);
        let message = varied_message();
        let mut signature = wots.sign(&private_key, &message);

        signature[3][0] ^= 1;
        assert!(!wots.verify(&public_key, &message, &signature));
    }

    #[test]
    fn test_verify_with_randomization_elements() {
        let wots = WOTSPlus::new(mixing_hash);
        let (public_key, private_key) = wots.generate_key_pair(&[7u8; 32]);
        let message = varied_message();
        let signature = wots.sign(&private_key, &message);
        let elements = wots.generate_randomization_elements(&public_key.public_seed);

        assert!(wots.verify_with_randomization_elements(
            &public_key.public_key_hash,
            &message,
            &signature,
            &elements
        ));

        let mut too_few_elements = elements.clone();
        too_few_elements.truncate(constants::NUM_SIGNATURE_CHUNKS - 1);
        assert!(!wots.verify_with_randomization_elements(
            &public_key.public_key_hash,
            &message,
            &signature,
            &too_few_elements
        ));
    }

    #[test]
    fn test_invalid_message_length() {
        let wots = WOTSPlus::new(mock_hash);
        let private_seed = [1u8; 32];
        let (public_key, _) = wots.generate_key_pair(&private_seed);

        let invalid_message = [2u8; constants::MESSAGE_LEN + 1];
        let signature: Vec<[u8; 32]> = vec![[0u8; 32]; constants::NUM_SIGNATURE_CHUNKS];
        assert!(!wots.verify(&public_key, &invalid_message, &signature));
    }

    #[test]
    fn test_invalid_signature_length() {
        let wots = WOTSPlus::new(mock_hash);
        let (public_key, private_key) = wots.generate_key_pair(&[1u8; 32]);
        let message = [2u8; constants::MESSAGE_LEN];
        let signature = wots.sign(&private_key, &message);
        // The untouched signature is valid, so any rejection below is due to length.
        assert!(wots.verify(&public_key, &message, &signature));

        let mut too_short = signature.clone();
        too_short.truncate(constants::NUM_SIGNATURE_CHUNKS - 1);
        assert!(!wots.verify(&public_key, &message, &too_short));

        let mut too_long = signature.clone();
        too_long.push([0u8; 32]);
        assert!(!wots.verify(&public_key, &message, &too_long));
    }

    #[test]
    fn test_forged_zero_signature_rejected() {
        let wots = WOTSPlus::new(mock_hash);
        let private_seed = [1u8; 32];
        let (public_key, _) = wots.generate_key_pair(&private_seed);

        let message = [2u8; constants::MESSAGE_LEN];
        let signature: Vec<[u8; 32]> = vec![[0u8; 32]; constants::NUM_SIGNATURE_CHUNKS];
        assert!(!wots.verify(&public_key, &message, &signature));
    }

    #[test]
    #[should_panic(expected = "Message length must be 32 bytes")]
    fn test_sign_rejects_wrong_message_length() {
        let wots = WOTSPlus::new(mock_hash);
        let _ = wots.sign(&[1u8; 32], &[2u8; constants::MESSAGE_LEN - 1]);
    }

    #[test]
    fn test_public_key_serialization() {
        let public_key = PublicKey {
            public_seed: [1u8; constants::HASH_LEN],
            public_key_hash: [2u8; constants::HASH_LEN],
        };

        let bytes = public_key.to_bytes();
        let recovered = PublicKey::from_bytes(&bytes).unwrap();

        assert_eq!(recovered.public_seed, public_key.public_seed);
        assert_eq!(recovered.public_key_hash, public_key.public_key_hash);

        assert!(PublicKey::from_bytes(&bytes[..constants::PUBLIC_KEY_SIZE - 1]).is_none());
        assert!(PublicKey::from_bytes(&[]).is_none());
    }
}