Skip to main content

actix_web_csp/security/
hash.rs

1use crate::constants::{HASH_PREFIX_SHA256, HASH_PREFIX_SHA384, HASH_PREFIX_SHA512};
2use crate::core::source::Source;
3use crate::error::CspError;
4use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
5use ring::digest::{self, Context, SHA256, SHA384, SHA512};
6use smallvec::SmallVec;
7use std::fmt;
8
/// Digest algorithm used to build CSP hash sources (see `Source::Hash`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum HashAlgorithm {
    /// SHA-256.
    Sha256,
    /// SHA-384.
    Sha384,
    /// SHA-512.
    Sha512,
}
15
16impl HashAlgorithm {
17    #[inline(always)]
18    pub fn digest_algorithm(&self) -> &'static digest::Algorithm {
19        match self {
20            HashAlgorithm::Sha256 => &SHA256,
21            HashAlgorithm::Sha384 => &SHA384,
22            HashAlgorithm::Sha512 => &SHA512,
23        }
24    }
25
26    #[inline(always)]
27    pub const fn name(&self) -> &'static str {
28        match self {
29            HashAlgorithm::Sha256 => "sha256",
30            HashAlgorithm::Sha384 => "sha384",
31            HashAlgorithm::Sha512 => "sha512",
32        }
33    }
34
35    #[inline(always)]
36    pub const fn prefix(&self) -> &'static str {
37        match self {
38            HashAlgorithm::Sha256 => HASH_PREFIX_SHA256,
39            HashAlgorithm::Sha384 => HASH_PREFIX_SHA384,
40            HashAlgorithm::Sha512 => HASH_PREFIX_SHA512,
41        }
42    }
43
44    #[inline]
45    pub fn from_digest_algorithm(algo: &'static digest::Algorithm) -> Option<Self> {
46        if algo == &SHA256 {
47            Some(Self::Sha256)
48        } else if algo == &SHA384 {
49            Some(Self::Sha384)
50        } else if algo == &SHA512 {
51            Some(Self::Sha512)
52        } else {
53            None
54        }
55    }
56}
57
58impl fmt::Display for HashAlgorithm {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.write_str(self.name())
61    }
62}
63
64impl TryFrom<&str> for HashAlgorithm {
65    type Error = CspError;
66
67    fn try_from(s: &str) -> Result<Self, Self::Error> {
68        match s {
69            "sha256" => Ok(HashAlgorithm::Sha256),
70            "sha384" => Ok(HashAlgorithm::Sha384),
71            "sha512" => Ok(HashAlgorithm::Sha512),
72            _ => Err(CspError::InvalidHashAlgorithm(s.to_string())),
73        }
74    }
75}
76
77thread_local! {
78    static HASH_CONTEXTS: std::cell::RefCell<HashContextPool> = std::cell::RefCell::new(HashContextPool::new());
79}
80
81struct HashContextPool {
82    sha256_contexts: SmallVec<[Context; 4]>,
83    sha384_contexts: SmallVec<[Context; 4]>,
84    sha512_contexts: SmallVec<[Context; 4]>,
85}
86
87impl HashContextPool {
88    fn new() -> Self {
89        Self {
90            sha256_contexts: SmallVec::new(),
91            sha384_contexts: SmallVec::new(),
92            sha512_contexts: SmallVec::new(),
93        }
94    }
95
96    fn get_context(&mut self, algorithm: HashAlgorithm) -> Context {
97        match algorithm {
98            HashAlgorithm::Sha256 => self
99                .sha256_contexts
100                .pop()
101                .unwrap_or_else(|| Context::new(&SHA256)),
102            HashAlgorithm::Sha384 => self
103                .sha384_contexts
104                .pop()
105                .unwrap_or_else(|| Context::new(&SHA384)),
106            HashAlgorithm::Sha512 => self
107                .sha512_contexts
108                .pop()
109                .unwrap_or_else(|| Context::new(&SHA512)),
110        }
111    }
112
113    fn return_context(&mut self, _context: Context, algorithm: HashAlgorithm) {
114        match algorithm {
115            HashAlgorithm::Sha256 => {
116                if self.sha256_contexts.len() < 4 {
117                    let new_context = Context::new(&SHA256);
118                    self.sha256_contexts.push(new_context);
119                }
120            }
121            HashAlgorithm::Sha384 => {
122                if self.sha384_contexts.len() < 4 {
123                    let new_context = Context::new(&SHA384);
124                    self.sha384_contexts.push(new_context);
125                }
126            }
127            HashAlgorithm::Sha512 => {
128                if self.sha512_contexts.len() < 4 {
129                    let new_context = Context::new(&SHA512);
130                    self.sha512_contexts.push(new_context);
131                }
132            }
133        }
134    }
135}
136
137#[derive(Debug)]
138pub struct HashGenerator;
139
140impl HashGenerator {
141    #[inline]
142    pub fn generate(algorithm: HashAlgorithm, data: &[u8]) -> String {
143        if data.len() < 64 {
144            Self::generate_small(algorithm, data)
145        } else {
146            Self::generate_large(algorithm, data)
147        }
148    }
149
150    #[inline]
151    fn generate_small(algorithm: HashAlgorithm, data: &[u8]) -> String {
152        let digest = digest::digest(algorithm.digest_algorithm(), data);
153        BASE64.encode(digest.as_ref())
154    }
155
156    #[inline]
157    fn generate_large(algorithm: HashAlgorithm, data: &[u8]) -> String {
158        HASH_CONTEXTS.with(|pool| {
159            let mut pool = pool.borrow_mut();
160            let mut context = pool.get_context(algorithm);
161
162            const CHUNK_SIZE: usize = 16384;
163            if data.len() > CHUNK_SIZE {
164                for chunk in data.chunks(CHUNK_SIZE) {
165                    context.update(chunk);
166                }
167            } else {
168                context.update(data);
169            }
170
171            let digest = context.finish();
172            let result = BASE64.encode(digest.as_ref());
173            pool.return_context(Context::new(algorithm.digest_algorithm()), algorithm);
174            result
175        })
176    }
177
178    #[inline]
179    pub fn generate_source(algorithm: HashAlgorithm, data: &[u8]) -> Source {
180        let hash = Self::generate(algorithm, data);
181        Source::Hash {
182            algorithm,
183            value: hash.into(),
184        }
185    }
186
187    #[inline]
188    pub fn generate_multiple(requests: &[(HashAlgorithm, &[u8])]) -> Vec<String> {
189        let mut results = Vec::with_capacity(requests.len());
190
191        HASH_CONTEXTS.with(|pool| {
192            let mut pool = pool.borrow_mut();
193
194            for &(algorithm, data) in requests {
195                let mut context = pool.get_context(algorithm);
196                context.update(data);
197                let digest = context.finish();
198                results.push(BASE64.encode(digest.as_ref()));
199                pool.return_context(Context::new(algorithm.digest_algorithm()), algorithm);
200            }
201        });
202
203        results
204    }
205
206    #[inline]
207    pub fn verify_hash(algorithm: HashAlgorithm, data: &[u8], hash: &str) -> bool {
208        let calculated = Self::generate(algorithm, data);
209        crate::utils::fast_string_compare(&calculated, hash)
210    }
211
212    #[inline]
213    pub fn generate_with_nonce(algorithm: HashAlgorithm, data: &[u8], nonce: &str) -> String {
214        HASH_CONTEXTS.with(|pool| {
215            let mut pool = pool.borrow_mut();
216            let mut context = pool.get_context(algorithm);
217            context.update(data);
218            context.update(nonce.as_bytes());
219            let digest = context.finish();
220            let result = BASE64.encode(digest.as_ref());
221            pool.return_context(Context::new(algorithm.digest_algorithm()), algorithm);
222            result
223        })
224    }
225
226    #[inline]
227    pub fn batch_verify(requests: &[(HashAlgorithm, &[u8], &str)]) -> Vec<bool> {
228        if requests.is_empty() {
229            return Vec::new();
230        }
231
232        let mut results = Vec::with_capacity(requests.len());
233
234        let mut sha256_requests = Vec::new();
235        let mut sha384_requests = Vec::new();
236        let mut sha512_requests = Vec::new();
237
238        for (i, &(algorithm, data, expected_hash)) in requests.iter().enumerate() {
239            match algorithm {
240                HashAlgorithm::Sha256 => sha256_requests.push((i, data, expected_hash)),
241                HashAlgorithm::Sha384 => sha384_requests.push((i, data, expected_hash)),
242                HashAlgorithm::Sha512 => sha512_requests.push((i, data, expected_hash)),
243            }
244        }
245
246        results.resize(requests.len(), false);
247
248        HASH_CONTEXTS.with(|pool| {
249            let mut pool = pool.borrow_mut();
250
251            if !sha256_requests.is_empty() {
252                let mut context = pool.get_context(HashAlgorithm::Sha256);
253                for &(i, data, expected_hash) in &sha256_requests {
254                    context.update(data);
255                    let digest = context.finish();
256                    let calculated = BASE64.encode(digest.as_ref());
257                    results[i] = crate::utils::fast_string_compare(&calculated, expected_hash);
258
259                    context = Context::new(&SHA256);
260                }
261                pool.return_context(context, HashAlgorithm::Sha256);
262            }
263
264            if !sha384_requests.is_empty() {
265                let mut context = pool.get_context(HashAlgorithm::Sha384);
266                for &(i, data, expected_hash) in &sha384_requests {
267                    context.update(data);
268                    let digest = context.finish();
269                    let calculated = BASE64.encode(digest.as_ref());
270                    results[i] = crate::utils::fast_string_compare(&calculated, expected_hash);
271
272                    context = Context::new(&SHA384);
273                }
274                pool.return_context(context, HashAlgorithm::Sha384);
275            }
276
277            if !sha512_requests.is_empty() {
278                let mut context = pool.get_context(HashAlgorithm::Sha512);
279                for &(i, data, expected_hash) in &sha512_requests {
280                    context.update(data);
281                    let digest = context.finish();
282                    let calculated = BASE64.encode(digest.as_ref());
283                    results[i] = crate::utils::fast_string_compare(&calculated, expected_hash);
284
285                    context = Context::new(&SHA512);
286                }
287                pool.return_context(context, HashAlgorithm::Sha512);
288            }
289        });
290
291        results
292    }
293
294    #[inline]
295    pub fn generate_hash(&self, content: &str) -> Result<String, CspError> {
296        Ok(Self::generate(HashAlgorithm::Sha256, content.as_bytes()))
297    }
298}