hbs_lms/lm_ots/
parameters.rs1use core::marker::PhantomData;
2
3use tinyvec::ArrayVec;
4
5use crate::constants::get_hash_chain_count;
6use crate::{
7 constants::{FastVerifyCached, MAX_HASH_SIZE},
8 hasher::HashChain,
9 util::coef::coef,
10};
11
12use crate::util::coef::coef_helper;
13
14#[derive(Clone, Copy, PartialEq, Eq)]
16pub enum LmotsAlgorithm {
17 LmotsReserved = 0,
18 LmotsW1 = 1,
19 LmotsW2 = 2,
20 LmotsW4 = 3,
21 LmotsW8 = 4,
22}
23
24impl Default for LmotsAlgorithm {
25 fn default() -> Self {
26 LmotsAlgorithm::LmotsReserved
27 }
28}
29
30impl From<u32> for LmotsAlgorithm {
31 fn from(_type: u32) -> Self {
32 match _type {
33 1 => LmotsAlgorithm::LmotsW1,
34 2 => LmotsAlgorithm::LmotsW2,
35 3 => LmotsAlgorithm::LmotsW4,
36 4 => LmotsAlgorithm::LmotsW8,
37 _ => LmotsAlgorithm::LmotsReserved,
38 }
39 }
40}
41
42impl LmotsAlgorithm {
43 pub fn construct_default_parameter<H: HashChain>() -> LmotsParameter<H> {
44 LmotsAlgorithm::LmotsW1.construct_parameter().unwrap()
45 }
46
47 pub fn construct_parameter<H: HashChain>(&self) -> Option<LmotsParameter<H>> {
48 match *self {
49 LmotsAlgorithm::LmotsReserved => None,
50 LmotsAlgorithm::LmotsW1 => Some(LmotsParameter::new(
51 1,
52 1,
53 get_hash_chain_count(1, H::OUTPUT_SIZE as usize) as u16,
54 7,
55 )),
56 LmotsAlgorithm::LmotsW2 => Some(LmotsParameter::new(
57 2,
58 2,
59 get_hash_chain_count(2, H::OUTPUT_SIZE as usize) as u16,
60 6,
61 )),
62 LmotsAlgorithm::LmotsW4 => Some(LmotsParameter::new(
63 3,
64 4,
65 get_hash_chain_count(4, H::OUTPUT_SIZE as usize) as u16,
66 4,
67 )),
68 LmotsAlgorithm::LmotsW8 => Some(LmotsParameter::new(
69 4,
70 8,
71 get_hash_chain_count(8, H::OUTPUT_SIZE as usize) as u16,
72 0,
73 )),
74 }
75 }
76
77 pub fn get_from_type<H: HashChain>(_type: u32) -> Option<LmotsParameter<H>> {
78 match _type {
79 1 => LmotsAlgorithm::LmotsW1.construct_parameter(),
80 2 => LmotsAlgorithm::LmotsW2.construct_parameter(),
81 3 => LmotsAlgorithm::LmotsW4.construct_parameter(),
82 4 => LmotsAlgorithm::LmotsW8.construct_parameter(),
83 _ => None,
84 }
85 }
86}
87
88#[derive(Debug, Clone, PartialEq, Eq)]
89pub struct LmotsParameter<H: HashChain> {
90 type_id: u32,
91 winternitz: u8,
92 hash_chain_count: u16,
93 checksum_left_shift: u8,
94 phantom_data: PhantomData<H>,
95}
96
97impl<H: HashChain> Copy for LmotsParameter<H> {}
100
101impl<H: HashChain> LmotsParameter<H> {
102 const HASH_FUNCTION_OUTPUT_SIZE: u16 = H::OUTPUT_SIZE;
103
104 pub fn new(
105 type_id: u32,
106 winternitz: u8,
107 hash_chain_count: u16,
108 checksum_left_shift: u8,
109 ) -> Self {
110 Self {
111 type_id,
112 winternitz,
113 hash_chain_count,
114 checksum_left_shift,
115 phantom_data: PhantomData,
116 }
117 }
118
119 pub fn get_type_id(&self) -> u32 {
120 self.type_id
121 }
122
123 pub fn get_winternitz(&self) -> u8 {
124 self.winternitz
125 }
126
127 pub fn get_hash_chain_count(&self) -> u16 {
128 self.hash_chain_count
129 }
130
131 pub fn get_checksum_left_shift(&self) -> u8 {
132 self.checksum_left_shift
133 }
134
135 pub fn get_hash_function_output_size(&self) -> usize {
136 Self::HASH_FUNCTION_OUTPUT_SIZE as usize
137 }
138
139 pub fn fast_verify_eval_init(&self) -> FastVerifyCached {
140 let max = (Self::HASH_FUNCTION_OUTPUT_SIZE * 8) / self.get_winternitz() as u16;
141
142 let max_word_size = (1 << self.get_winternitz()) - 1;
143 let sum = max * max_word_size;
144
145 let mut coef = ArrayVec::new();
146 for i in 0..self.get_hash_chain_count() {
147 coef.push(coef_helper(i, self.get_winternitz()));
148 }
149
150 (max, sum, coef)
151 }
152
153 pub fn fast_verify_eval(
154 &self,
155 byte_string: &[u8],
156 fast_verify_cached: &FastVerifyCached,
157 ) -> u16 {
158 let (max, sum, coef) = fast_verify_cached;
159 let mut total_hash_chain_iterations = 0;
160
161 for i in 0..*max {
162 let (index, shift, mask) = coef[i as usize];
163 let hash_chain_length = ((byte_string[index] as u64 >> shift) & mask) as u16;
164 total_hash_chain_iterations += hash_chain_length;
165 }
166
167 let mut checksum = *sum - total_hash_chain_iterations;
168 checksum <<= self.get_checksum_left_shift();
169 let checksum = [(checksum >> 8 & 0xff) as u8, (checksum & 0xff) as u8];
170
171 for i in *max..self.get_hash_chain_count() {
172 let (index, shift, mask) = coef[i as usize];
173 let hash_chain_length = ((checksum[index - 32] as u64 >> shift) & mask) as u16;
174 total_hash_chain_iterations += hash_chain_length;
175 }
176
177 total_hash_chain_iterations
178 }
179
180 fn checksum(&self, byte_string: &[u8]) -> u16 {
181 let mut sum = 0_u16;
182
183 let max = (Self::HASH_FUNCTION_OUTPUT_SIZE * 8) / self.get_winternitz() as u16;
184
185 let max_word_size: u64 = (1 << self.get_winternitz()) - 1;
186
187 for i in 0..max {
188 sum += (max_word_size - coef(byte_string, i, self.get_winternitz())) as u16;
189 }
190
191 sum << self.get_checksum_left_shift()
192 }
193
194 pub fn append_checksum_to(&self, byte_string: &[u8]) -> ArrayVec<[u8; MAX_HASH_SIZE + 2]> {
195 let mut result = ArrayVec::new();
196
197 let checksum = self.checksum(byte_string);
198
199 result.extend_from_slice(byte_string);
200
201 result.extend_from_slice(&[(checksum >> 8 & 0xff) as u8]);
202 result.extend_from_slice(&[(checksum & 0xff) as u8]);
203
204 result
205 }
206
207 pub fn get_hasher(&self) -> H {
208 H::default()
209 }
210}
211
212impl<H: HashChain> Default for LmotsParameter<H> {
213 fn default() -> Self {
214 LmotsAlgorithm::LmotsW1.construct_parameter().unwrap()
215 }
216}