hbs_lms/lm_ots/
parameters.rs

1use core::marker::PhantomData;
2
3use tinyvec::ArrayVec;
4
5use crate::constants::get_hash_chain_count;
6use crate::{
7    constants::{FastVerifyCached, MAX_HASH_SIZE},
8    hasher::HashChain,
9    util::coef::coef,
10};
11
12use crate::util::coef::coef_helper;
13
/// Specifies the used Winternitz parameter.
///
/// The explicit discriminants equal the LM-OTS type identifiers that
/// `From<u32>` maps back to variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LmotsAlgorithm {
    /// Reserved identifier; no parameter set is constructed for it.
    LmotsReserved = 0,
    /// Winternitz parameter `w = 1`.
    LmotsW1 = 1,
    /// Winternitz parameter `w = 2`.
    LmotsW2 = 2,
    /// Winternitz parameter `w = 4`.
    LmotsW4 = 3,
    /// Winternitz parameter `w = 8`.
    LmotsW8 = 4,
}
23
24impl Default for LmotsAlgorithm {
25    fn default() -> Self {
26        LmotsAlgorithm::LmotsReserved
27    }
28}
29
30impl From<u32> for LmotsAlgorithm {
31    fn from(_type: u32) -> Self {
32        match _type {
33            1 => LmotsAlgorithm::LmotsW1,
34            2 => LmotsAlgorithm::LmotsW2,
35            3 => LmotsAlgorithm::LmotsW4,
36            4 => LmotsAlgorithm::LmotsW8,
37            _ => LmotsAlgorithm::LmotsReserved,
38        }
39    }
40}
41
42impl LmotsAlgorithm {
43    pub fn construct_default_parameter<H: HashChain>() -> LmotsParameter<H> {
44        LmotsAlgorithm::LmotsW1.construct_parameter().unwrap()
45    }
46
47    pub fn construct_parameter<H: HashChain>(&self) -> Option<LmotsParameter<H>> {
48        match *self {
49            LmotsAlgorithm::LmotsReserved => None,
50            LmotsAlgorithm::LmotsW1 => Some(LmotsParameter::new(
51                1,
52                1,
53                get_hash_chain_count(1, H::OUTPUT_SIZE as usize) as u16,
54                7,
55            )),
56            LmotsAlgorithm::LmotsW2 => Some(LmotsParameter::new(
57                2,
58                2,
59                get_hash_chain_count(2, H::OUTPUT_SIZE as usize) as u16,
60                6,
61            )),
62            LmotsAlgorithm::LmotsW4 => Some(LmotsParameter::new(
63                3,
64                4,
65                get_hash_chain_count(4, H::OUTPUT_SIZE as usize) as u16,
66                4,
67            )),
68            LmotsAlgorithm::LmotsW8 => Some(LmotsParameter::new(
69                4,
70                8,
71                get_hash_chain_count(8, H::OUTPUT_SIZE as usize) as u16,
72                0,
73            )),
74        }
75    }
76
77    pub fn get_from_type<H: HashChain>(_type: u32) -> Option<LmotsParameter<H>> {
78        match _type {
79            1 => LmotsAlgorithm::LmotsW1.construct_parameter(),
80            2 => LmotsAlgorithm::LmotsW2.construct_parameter(),
81            3 => LmotsAlgorithm::LmotsW4.construct_parameter(),
82            4 => LmotsAlgorithm::LmotsW8.construct_parameter(),
83            _ => None,
84        }
85    }
86}
87
88#[derive(Debug, Clone, PartialEq, Eq)]
89pub struct LmotsParameter<H: HashChain> {
90    type_id: u32,
91    winternitz: u8,
92    hash_chain_count: u16,
93    checksum_left_shift: u8,
94    phantom_data: PhantomData<H>,
95}
96
97// Manually implement Copy because HashChain trait does not.
98// However, it does not make a difference, because we don't hold a instance for HashChain.
99impl<H: HashChain> Copy for LmotsParameter<H> {}
100
101impl<H: HashChain> LmotsParameter<H> {
102    const HASH_FUNCTION_OUTPUT_SIZE: u16 = H::OUTPUT_SIZE;
103
104    pub fn new(
105        type_id: u32,
106        winternitz: u8,
107        hash_chain_count: u16,
108        checksum_left_shift: u8,
109    ) -> Self {
110        Self {
111            type_id,
112            winternitz,
113            hash_chain_count,
114            checksum_left_shift,
115            phantom_data: PhantomData,
116        }
117    }
118
119    pub fn get_type_id(&self) -> u32 {
120        self.type_id
121    }
122
123    pub fn get_winternitz(&self) -> u8 {
124        self.winternitz
125    }
126
127    pub fn get_hash_chain_count(&self) -> u16 {
128        self.hash_chain_count
129    }
130
131    pub fn get_checksum_left_shift(&self) -> u8 {
132        self.checksum_left_shift
133    }
134
135    pub fn get_hash_function_output_size(&self) -> usize {
136        Self::HASH_FUNCTION_OUTPUT_SIZE as usize
137    }
138
139    pub fn fast_verify_eval_init(&self) -> FastVerifyCached {
140        let max = (Self::HASH_FUNCTION_OUTPUT_SIZE * 8) / self.get_winternitz() as u16;
141
142        let max_word_size = (1 << self.get_winternitz()) - 1;
143        let sum = max * max_word_size;
144
145        let mut coef = ArrayVec::new();
146        for i in 0..self.get_hash_chain_count() {
147            coef.push(coef_helper(i, self.get_winternitz()));
148        }
149
150        (max, sum, coef)
151    }
152
153    pub fn fast_verify_eval(
154        &self,
155        byte_string: &[u8],
156        fast_verify_cached: &FastVerifyCached,
157    ) -> u16 {
158        let (max, sum, coef) = fast_verify_cached;
159        let mut total_hash_chain_iterations = 0;
160
161        for i in 0..*max {
162            let (index, shift, mask) = coef[i as usize];
163            let hash_chain_length = ((byte_string[index] as u64 >> shift) & mask) as u16;
164            total_hash_chain_iterations += hash_chain_length;
165        }
166
167        let mut checksum = *sum - total_hash_chain_iterations;
168        checksum <<= self.get_checksum_left_shift();
169        let checksum = [(checksum >> 8 & 0xff) as u8, (checksum & 0xff) as u8];
170
171        for i in *max..self.get_hash_chain_count() {
172            let (index, shift, mask) = coef[i as usize];
173            let hash_chain_length = ((checksum[index - 32] as u64 >> shift) & mask) as u16;
174            total_hash_chain_iterations += hash_chain_length;
175        }
176
177        total_hash_chain_iterations
178    }
179
180    fn checksum(&self, byte_string: &[u8]) -> u16 {
181        let mut sum = 0_u16;
182
183        let max = (Self::HASH_FUNCTION_OUTPUT_SIZE * 8) / self.get_winternitz() as u16;
184
185        let max_word_size: u64 = (1 << self.get_winternitz()) - 1;
186
187        for i in 0..max {
188            sum += (max_word_size - coef(byte_string, i, self.get_winternitz())) as u16;
189        }
190
191        sum << self.get_checksum_left_shift()
192    }
193
194    pub fn append_checksum_to(&self, byte_string: &[u8]) -> ArrayVec<[u8; MAX_HASH_SIZE + 2]> {
195        let mut result = ArrayVec::new();
196
197        let checksum = self.checksum(byte_string);
198
199        result.extend_from_slice(byte_string);
200
201        result.extend_from_slice(&[(checksum >> 8 & 0xff) as u8]);
202        result.extend_from_slice(&[(checksum & 0xff) as u8]);
203
204        result
205    }
206
207    pub fn get_hasher(&self) -> H {
208        H::default()
209    }
210}
211
212impl<H: HashChain> Default for LmotsParameter<H> {
213    fn default() -> Self {
214        LmotsAlgorithm::LmotsW1.construct_parameter().unwrap()
215    }
216}