scram_rs/
scram_hashing.rs1#[cfg(feature = "std")]
19use std::num::NonZeroU32;
20#[cfg(not(feature = "std"))]
21use core::num::NonZeroU32;
22
23#[cfg(not(feature = "std"))]
24use alloc::vec::Vec;
25
26use crate::scram_error;
27use crate::ScramErrorCode;
28use crate::ScramServerError;
29
30use super::scram_error::ScramResult;
31
32pub trait ScramOps
35{
36 #[cfg(not(feature = "xor_without_u128"))]
38 fn xor_arrays(a: &[u8], b: &[u8]) -> ScramResult<Vec<u8>>
39 {
40 if a.len() != b.len()
41 {
42 scram_error!(
43 ScramErrorCode::InternalError,
44 ScramServerError::OtherError,
45 "xor arrays size mismatch: a: '{}', b: '{}'", a.len(), b.len()
46 );
47 }
48
49 let mut ret = Vec::<u8>::with_capacity(a.len());
50 let len = a.len();
51 let mut i = 0;
52
53 while i < len
54 {
55 let leni = len-i;
56
57 if leni >= (u128::BITS/8) as usize && ((&a[i]) as *const _ as *const u128).is_aligned() == true
58 {
59 let a1 = u128::from_be_bytes(a[i..i+(u128::BITS/8) as usize].try_into().unwrap());
60 let b1 = u128::from_be_bytes(b[i..i+(u128::BITS/8) as usize].try_into().unwrap());
61
62 let c1 = a1 ^ b1;
63
64 ret.extend_from_slice(&c1.to_be_bytes());
65
66 i += (u128::BITS/8) as usize;
67 }
68 else if leni >= (u64::BITS/8) as usize
69 {
70 let a1 = u64::from_be_bytes(a[i..i+(u64::BITS/8) as usize].try_into().unwrap());
71 let b1 = u64::from_be_bytes(b[i..i+(u64::BITS/8) as usize].try_into().unwrap());
72
73 let c1 = a1 ^ b1;
74
75 ret.extend_from_slice(&c1.to_be_bytes());
76
77 i += (u64::BITS/8) as usize;
78 }
79 else if leni >= (u32::BITS/8) as usize
80 {
81 let a1 = u32::from_be_bytes(a[i..i+(u32::BITS/8) as usize].try_into().unwrap());
82 let b1 = u32::from_be_bytes(b[i..i+(u32::BITS/8) as usize].try_into().unwrap());
83
84 let c1 = a1 ^ b1;
85
86 ret.extend_from_slice(&c1.to_be_bytes());
87
88 i += (u32::BITS/8) as usize;
89 }
90 else if leni >= (u16::BITS/8) as usize
91 {
92 let a1 = u16::from_be_bytes(a[i..i+(u16::BITS/8) as usize].try_into().unwrap());
93 let b1 = u16::from_be_bytes(b[i..i+(u16::BITS/8) as usize].try_into().unwrap());
94
95 let c1 = a1 ^ b1;
96
97 ret.extend_from_slice(&c1.to_be_bytes());
98
99 i += (u16::BITS/8) as usize;
100 }
101 else
102 {
103 ret.push(a[i] ^ b[i]);
104
105 i += 1;
106 }
107 }
108
109 return Ok(ret);
110 }
111
112 #[cfg(feature = "xor_without_u128")]
113 fn xor_arrays(a: &[u8], b: &[u8]) -> ScramResult<Vec<u8>>
114 {
115 if a.len() != b.len()
116 {
117 scram_error!(
118 ScramErrorCode::InternalError,
119 ScramServerError::OtherError,
120 "xor arrays size mismatch: a: '{}', b: '{}'", a.len(), b.len()
121 );
122 }
123
124 let mut ret = Vec::<u8>::with_capacity(a.len());
125 let len = a.len();
126 let mut i = 0;
127
128 while i < len
129 {
130 let leni = len-i;
131
132 if leni >= (u64::BITS/8) as usize
133 {
134 let a1 = u64::from_be_bytes(a[i..i+(u64::BITS/8) as usize].try_into().unwrap());
135 let b1 = u64::from_be_bytes(b[i..i+(u64::BITS/8) as usize].try_into().unwrap());
136
137 let c1 = a1 ^ b1;
138
139 ret.extend_from_slice(&c1.to_be_bytes());
140
141 i += (u64::BITS/8) as usize;
142 }
143 else if leni >= (u32::BITS/8) as usize
144 {
145 let a1 = u32::from_be_bytes(a[i..i+(u32::BITS/8) as usize].try_into().unwrap());
146 let b1 = u32::from_be_bytes(b[i..i+(u32::BITS/8) as usize].try_into().unwrap());
147
148 let c1 = a1 ^ b1;
149
150 ret.extend_from_slice(&c1.to_be_bytes());
151
152 i += (u32::BITS/8) as usize;
153 }
154 else if leni >= (u16::BITS/8) as usize
155 {
156 let a1 = u16::from_be_bytes(a[i..i+(u16::BITS/8) as usize].try_into().unwrap());
157 let b1 = u16::from_be_bytes(b[i..i+(u16::BITS/8) as usize].try_into().unwrap());
158
159 let c1 = a1 ^ b1;
160
161 ret.extend_from_slice(&c1.to_be_bytes());
162
163 i += (u16::BITS/8) as usize;
164 }
165 else
166 {
167 ret.push(a[i] ^ b[i]);
168
169 i += 1;
170 }
171 }
172
173 return Ok(ret);
174 }
175}
176
177#[inline]
179pub
180fn xor_arrays_reference(a: &[u8], b: &[u8]) -> ScramResult<Vec<u8>>
181{
182 if a.len() != b.len()
183 {
184 scram_error!(
185 ScramErrorCode::InternalError,
186 ScramServerError::OtherError,
187 "xor arrays size mismatch: a: '{}', b: '{}'", a.len(), b.len()
188 );
189 }
190
191 let ret = a.into_iter().zip(b).map(|(a, b)| a ^ b).collect::<Vec<u8>>();
192 return Ok(ret);
193}
194
195pub trait ScramHashing: ScramOps
196{
197 fn hash(data: &[u8]) -> Vec<u8>;
199
200 fn hmac(data: &[u8], key: &[u8]) -> ScramResult<Vec<u8>>;
202
203 fn derive(password: &[u8], salt: &[u8], iterations: NonZeroU32) -> ScramResult<Vec<u8>>;
205}
206
207pub use super::scram_hashing_sha1::*;
209
210pub use super::scram_hashing_sha2::*;
212
213pub use super::scram_hashing_sha5::*;
215
216
#[cfg(feature = "std")]
#[cfg(test)]
mod tests
{
    use std::time::Instant;

    use crate::{xor_arrays_reference, ScramCommon, ScramOps};

    /// Minimal implementor that relies only on the provided `xor_arrays`.
    struct XorTest;
    impl ScramOps for XorTest {}

    /// Runs both XOR implementations on the same input, prints how long each
    /// call took and asserts that they produce identical bytes.
    fn perf_xor(a: &[u8], b: &[u8])
    {
        let started = Instant::now();
        let res_ref = xor_arrays_reference(a, b).unwrap();
        let e1 = started.elapsed();

        let started = Instant::now();
        let xor_res = XorTest::xor_arrays(a, b).unwrap();
        let e2 = started.elapsed();

        // Informational only; a single call is far too short to be a benchmark.
        println!("ref: {:?}, fast: {:?}", e1, e2);

        assert_eq!(xor_res.as_slice(), res_ref.as_slice());
    }

    /// Checks both implementations against an independently computed result.
    fn check_xor(a: &[u8], b: &[u8])
    {
        let expected: Vec<u8> = a.iter().zip(b).map(|(x, y)| x ^ y).collect();

        assert_eq!(XorTest::xor_arrays(a, b).unwrap(), expected, "len = {}", a.len());
        assert_eq!(xor_arrays_reference(a, b).unwrap(), expected, "len = {}", a.len());
    }

    #[test]
    fn test_xor()
    {
        for _ in 0..20
        {
            let a = ScramCommon::sc_random(128).unwrap();
            let b = ScramCommon::sc_random(128).unwrap();
            perf_xor(&a, &b);
        }
    }

    #[test]
    fn test_xor2()
    {
        for _ in 0..20
        {
            let a = ScramCommon::sc_random(1).unwrap();
            let b = ScramCommon::sc_random(1).unwrap();
            perf_xor(&a, &b);
        }
    }

    /// Covers the empty input and every length up to several full words plus
    /// a tail, with deterministic data so that a failure is reproducible.
    #[test]
    fn test_xor_all_lengths()
    {
        for len in 0..=67usize
        {
            let a: Vec<u8> =
                (0..len).map(|i| (i as u8).wrapping_mul(37).wrapping_add(11)).collect();
            let b: Vec<u8> =
                (0..len).map(|i| (i as u8).wrapping_mul(101) ^ 0xA5).collect();

            check_xor(&a, &b);

            // Start one byte in so the input does not begin on the
            // allocation boundary.
            if len > 0
            {
                check_xor(&a[1..], &b[1..]);
            }
        }
    }

    /// Inputs of different lengths must be rejected by both implementations.
    #[test]
    fn test_xor_len_mismatch()
    {
        assert!(XorTest::xor_arrays(&[0u8; 3], &[0u8; 4]).is_err());
        assert!(XorTest::xor_arrays(&[0u8; 17], &[0u8; 16]).is_err());
        assert!(xor_arrays_reference(&[0u8; 4], &[0u8; 3]).is_err());
    }
}