scram_rs/
scram_hashing.rs

1/*-
 * Scram-rs - a SCRAM authentication and authorization library
3 * 
4 * Copyright (C) 2021  Aleksandr Morozov
5 * Copyright (C) 2025 Aleksandr Morozov
6 * 
 * The scram-rs crate can be redistributed and/or modified
8 * under the terms of either of the following licenses:
9 *
10 *   1. the Mozilla Public License Version 2.0 (the “MPL”) OR
11 *
12 *   2. The MIT License (MIT)
13 *                     
14 *   3. EUROPEAN UNION PUBLIC LICENCE v. 1.2 EUPL © the European Union 2007, 2016
15 */
16
17
18#[cfg(feature = "std")]
19use std::num::NonZeroU32;
20#[cfg(not(feature = "std"))]
21use core::num::NonZeroU32;
22
23#[cfg(not(feature = "std"))]
24use alloc::vec::Vec;
25
26use crate::scram_error;
27use crate::ScramErrorCode;
28use crate::ScramServerError;
29
30use super::scram_error::ScramResult;
31
32/// A trait which implements different operations which required by the 
33/// SCRAM algorithm.
34pub trait ScramOps
35{
36    /// A function which performs the XOR of two arrays
37    #[cfg(not(feature = "xor_without_u128"))]
38    fn xor_arrays(a: &[u8], b: &[u8]) -> ScramResult<Vec<u8>>
39    {
40        if a.len() != b.len()
41        {
42            scram_error!(
43                ScramErrorCode::InternalError,
44                ScramServerError::OtherError,
45                "xor arrays size mismatch: a: '{}', b: '{}'", a.len(), b.len()
46            );
47        }
48
49        let mut ret = Vec::<u8>::with_capacity(a.len());
50        let len = a.len();
51        let mut i = 0;
52
53        while i < len
54        {
55            let leni = len-i;
56
57            if leni >= (u128::BITS/8) as usize && ((&a[i]) as *const _ as *const u128).is_aligned() == true
58            {
59                let a1 = u128::from_be_bytes(a[i..i+(u128::BITS/8) as usize].try_into().unwrap());
60                let b1 = u128::from_be_bytes(b[i..i+(u128::BITS/8) as usize].try_into().unwrap());
61
62                let c1 = a1 ^ b1;
63
64                ret.extend_from_slice(&c1.to_be_bytes());
65
66                i += (u128::BITS/8) as usize;
67            }
68            else if leni >= (u64::BITS/8) as usize
69            {
70                let a1 = u64::from_be_bytes(a[i..i+(u64::BITS/8) as usize].try_into().unwrap());
71                let b1 = u64::from_be_bytes(b[i..i+(u64::BITS/8) as usize].try_into().unwrap());
72
73                let c1 = a1 ^ b1;
74
75                ret.extend_from_slice(&c1.to_be_bytes());
76
77                i += (u64::BITS/8) as usize;
78            }
79            else if leni >= (u32::BITS/8) as usize
80            {
81                let a1 = u32::from_be_bytes(a[i..i+(u32::BITS/8) as usize].try_into().unwrap());
82                let b1 = u32::from_be_bytes(b[i..i+(u32::BITS/8) as usize].try_into().unwrap());
83
84                let c1 = a1 ^ b1;
85
86                ret.extend_from_slice(&c1.to_be_bytes());
87
88                i += (u32::BITS/8) as usize;
89            }
90            else if leni >= (u16::BITS/8) as usize 
91            {
92                let a1 = u16::from_be_bytes(a[i..i+(u16::BITS/8) as usize].try_into().unwrap());
93                let b1 = u16::from_be_bytes(b[i..i+(u16::BITS/8) as usize].try_into().unwrap());
94
95                let c1 = a1 ^ b1;
96
97                ret.extend_from_slice(&c1.to_be_bytes());
98
99                i += (u16::BITS/8) as usize;
100            }
101            else 
102            {
103                ret.push(a[i] ^ b[i]);
104
105                i += 1;
106            }   
107        }
108        
109        return Ok(ret);
110    }
111
112    #[cfg(feature = "xor_without_u128")]
113    fn xor_arrays(a: &[u8], b: &[u8]) -> ScramResult<Vec<u8>>
114    {
115        if a.len() != b.len()
116        {
117            scram_error!(
118                ScramErrorCode::InternalError,
119                ScramServerError::OtherError,
120                "xor arrays size mismatch: a: '{}', b: '{}'", a.len(), b.len()
121            );
122        }
123
124        let mut ret = Vec::<u8>::with_capacity(a.len());
125        let len = a.len();
126        let mut i = 0;
127
128        while i < len
129        {
130            let leni = len-i;
131
132            if leni >= (u64::BITS/8) as usize
133            {
134                let a1 = u64::from_be_bytes(a[i..i+(u64::BITS/8) as usize].try_into().unwrap());
135                let b1 = u64::from_be_bytes(b[i..i+(u64::BITS/8) as usize].try_into().unwrap());
136
137                let c1 = a1 ^ b1;
138
139                ret.extend_from_slice(&c1.to_be_bytes());
140
141                i += (u64::BITS/8) as usize;
142            }
143            else if leni >= (u32::BITS/8) as usize
144            {
145                let a1 = u32::from_be_bytes(a[i..i+(u32::BITS/8) as usize].try_into().unwrap());
146                let b1 = u32::from_be_bytes(b[i..i+(u32::BITS/8) as usize].try_into().unwrap());
147
148                let c1 = a1 ^ b1;
149
150                ret.extend_from_slice(&c1.to_be_bytes());
151
152                i += (u32::BITS/8) as usize;
153            }
154            else if leni >= (u16::BITS/8) as usize 
155            {
156                let a1 = u16::from_be_bytes(a[i..i+(u16::BITS/8) as usize].try_into().unwrap());
157                let b1 = u16::from_be_bytes(b[i..i+(u16::BITS/8) as usize].try_into().unwrap());
158
159                let c1 = a1 ^ b1;
160
161                ret.extend_from_slice(&c1.to_be_bytes());
162
163                i += (u16::BITS/8) as usize;
164            }
165            else 
166            {
167                ret.push(a[i] ^ b[i]);
168
169                i += 1;
170            }   
171        }
172        
173        return Ok(ret);
174    }
175}
176
177/// A reference function which works 100%. Can be used for testing against.
178#[inline]
179pub
180fn xor_arrays_reference(a: &[u8], b: &[u8]) -> ScramResult<Vec<u8>>
181{
182    if a.len() != b.len()
183    {
184        scram_error!(
185            ScramErrorCode::InternalError,
186            ScramServerError::OtherError,
187            "xor arrays size mismatch: a: '{}', b: '{}'", a.len(), b.len()
188        );
189    }
190
191    let ret = a.into_iter().zip(b).map(|(a, b)| a ^ b).collect::<Vec<u8>>();
192    return Ok(ret);
193}
194
195pub trait ScramHashing: ScramOps
196{
197    /// A function which hashes the data using the hash function.
198    fn hash(data: &[u8]) -> Vec<u8>;
199
200    /// A function which performs an HMAC using the hash function.
201    fn hmac(data: &[u8], key: &[u8]) -> ScramResult<Vec<u8>>;
202
203    /// A function which does PBKDF2 key derivation using the hash function.
204    fn derive(password: &[u8], salt: &[u8], iterations: NonZeroU32) -> ScramResult<Vec<u8>>;
205}
206
207/// All hashing code dedicated for SHA1. Both rust native and ring inplemetations.
208pub use super::scram_hashing_sha1::*;
209
210/// All hasing code dedicated for SHA256. Both rust native and ring inplemetations.
211pub use super::scram_hashing_sha2::*;
212
213/// All hashing code dedicated for SHA512. Both rust native and ring inplemetations.
214pub use super::scram_hashing_sha5::*;
215
216
#[cfg(feature = "std")]
#[cfg(test)]
mod tests
{
    use std::time::Instant;

    use crate::{xor_arrays_reference, ScramCommon, ScramOps};

    /// A type that only uses the default `ScramOps` implementations.
    struct XorTest;
    impl ScramOps for XorTest {}

    /// Asserts that the optimized XOR agrees with the reference implementation.
    fn check_xor(a: &[u8], b: &[u8])
    {
        let res_ref = xor_arrays_reference(a, b).unwrap();
        let xor_res = XorTest::xor_arrays(a, b).unwrap();

        assert_eq!(xor_res.as_slice(), res_ref.as_slice(), "len = {}", a.len());
    }

    /// Times both implementations on the same input and checks that they agree.
    fn perf_xor(a: &[u8], b: &[u8])
    {
        let s = Instant::now();
        let res_ref = xor_arrays_reference(a, b).unwrap();
        let e1 = s.elapsed();

        let s = Instant::now();
        let xor_res = XorTest::xor_arrays(a, b).unwrap();
        let e2 = s.elapsed();

        println!("ref: {:?}, fast: {:?}", e1, e2);

        assert_eq!(xor_res.as_slice(), res_ref.as_slice());
    }

    #[test]
    fn test_xor()
    {
        for _ in 0..20
        {
            let a = ScramCommon::sc_random(128).unwrap();
            let b = ScramCommon::sc_random(128).unwrap();
            perf_xor(&a, &b);
        }
    }

    #[test]
    fn test_xor2()
    {
        for _ in 0..20
        {
            let a = ScramCommon::sc_random(1).unwrap();
            let b = ScramCommon::sc_random(1).unwrap();
            perf_xor(&a, &b);
        }
    }

    /// Deterministically covers the empty input, every tail length and
    /// several start offsets (so aligned and unaligned sub-slices are both
    /// exercised), instead of relying on random buffers of fixed size.
    #[test]
    fn test_xor_all_lengths_and_offsets()
    {
        let a: Vec<u8> = (0..96u8).map(|v| v.wrapping_mul(37).wrapping_add(11)).collect();
        let b: Vec<u8> = (0..96u8).map(|v| v.wrapping_mul(101).wrapping_add(3)).collect();

        for offset in 0..16
        {
            for len in 0..=(a.len() - offset)
            {
                check_xor(&a[offset..offset + len], &b[offset..offset + len]);
            }
        }
    }

    #[test]
    fn test_xor_known_value()
    {
        let res = XorTest::xor_arrays(&[0x0f, 0xf0, 0xaa], &[0xff, 0xff, 0x55]).unwrap();
        assert_eq!(res, vec![0xf0, 0x0f, 0xff]);
    }

    #[test]
    fn test_xor_length_mismatch()
    {
        assert!(XorTest::xor_arrays(&[1, 2, 3], &[1, 2]).is_err());
        assert!(xor_arrays_reference(&[1, 2, 3], &[1, 2]).is_err());
    }
}