tower_surf/
token.rs

1use std::sync::Arc;
2
3use base64::prelude::*;
4use hmac::{Hmac, Mac};
5use rand::prelude::*;
6use secstr::SecStr;
7use sha2::Sha256;
8use tower_cookies::{Cookie, Cookies};
9
10use crate::{error::Error, surf::Config};
11
12/// An extension providing a way to interact with a visitor's
13/// CSRF token.
14#[derive(Clone)]
15pub struct Token {
16    pub(crate) config: Arc<Config>,
17    pub(crate) cookies: Cookies,
18}
19
20impl Token {
21    pub(crate) fn create(&self) -> Result<(), Error> {
22        let identifier: i128 = thread_rng().gen();
23        let token = create_token(&self.config.secret, identifier.to_string())?;
24
25        let cookie = Cookie::build((self.config.cookie_name(), token))
26            .path("/")
27            .expires(self.config.expires)
28            .http_only(self.config.http_only)
29            .same_site(self.config.same_site)
30            .secure(self.config.secure)
31            .build();
32
33        self.cookies.add(cookie);
34
35        Ok(())
36    }
37
38    /// Updates the identifier used to sign the token. The value should only be valid for the
39    /// duration of the user's authenticated session and should be unique to that session.
40    ///
41    /// See: [OWASP's CSRF Prevention Cheat Sheet](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#employing-hmac-csrf-tokens).
42    ///
43    /// # Errors
44    ///
45    /// - [`Error::InvalidLength`]
46    pub fn set(&self, identifier: impl Into<String>) -> Result<(), Error> {
47        let token = create_token(&self.config.secret, identifier)?;
48
49        let cookie = Cookie::build((self.config.cookie_name(), token))
50            .path("/")
51            .expires(self.config.expires)
52            .http_only(self.config.http_only)
53            .same_site(self.config.same_site)
54            .secure(self.config.secure)
55            .build();
56
57        self.cookies.add(cookie);
58
59        Ok(())
60    }
61
62    /// Get the current visitor's token.
63    ///
64    /// # Errors
65    ///
66    /// - [`Error::NoCookie`]
67    pub fn get(&self) -> Result<String, Error> {
68        self.cookies
69            .get(&self.config.cookie_name())
70            .map(|cookie| cookie.value().to_owned())
71            .ok_or(Error::NoCookie)
72    }
73
74    /// Reset the token to an identifier generated by [Surf](`crate::Surf`).
75    pub fn reset(&self) {
76        let cookie = Cookie::build((self.config.cookie_name(), "")).build();
77
78        self.cookies.remove(cookie);
79    }
80}
81
82type HmacSha256 = Hmac<Sha256>;
83
84pub(crate) fn create_token(
85    secret: &SecStr,
86    identifier: impl Into<String>,
87) -> Result<String, Error> {
88    let random = BASE64_STANDARD.encode(get_random_value());
89    let message = format!("{}!{}", identifier.into(), random);
90    let result = sign_and_encode(secret, &message)?;
91    let token = format!("{}.{}", result, message);
92
93    Ok(token)
94}
95
96pub(crate) fn validate_token(secret: &SecStr, cookie: &str, token: &str) -> Result<bool, Error> {
97    let mut parts = token.splitn(2, '.');
98    let received_hmac = parts.next().unwrap_or("");
99
100    let message = parts.next().unwrap_or("");
101    let expected_hmac = sign_and_encode(secret, message)?;
102
103    Ok(received_hmac == expected_hmac && cookie == token)
104}
105
106#[cfg(not(test))]
107fn get_random_value() -> [u8; 64] {
108    let mut random = [0u8; 64];
109    thread_rng().fill(&mut random);
110
111    random
112}
113
/// Deterministic stand-in for the random nonce, so tests can predict token contents.
#[cfg(test)]
fn get_random_value() -> [u8; 64] {
    [0x2A; 64]
}
118
119fn sign_and_encode(secret: &SecStr, message: &str) -> Result<String, Error> {
120    let mut mac = HmacSha256::new_from_slice(secret.unsecure())?;
121    mac.update(message.as_bytes());
122    let result = BASE64_STANDARD.encode(mac.finalize().into_bytes());
123
124    Ok(result)
125}
126
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    #[test]
    fn create_token() -> Result<()> {
        let secret = SecStr::from("super-secret");
        let token = super::create_token(&secret, "identifier")?;

        let parts = token.splitn(2, '.').collect::<Vec<&str>>();
        assert_eq!(parts.len(), 2);

        let message = format!("{}!{}", "identifier", BASE64_STANDARD.encode([42u8; 64]));
        assert_eq!(parts[1], message);

        let signature = sign_and_encode(&secret, &message)?;
        assert_eq!(parts[0], signature);

        Ok(())
    }

    #[test]
    fn validate_token_accepts_matching_token() -> Result<()> {
        let secret = SecStr::from("super-secret");
        let token = super::create_token(&secret, "identifier")?;

        assert!(validate_token(&secret, &token, &token)?);

        Ok(())
    }

    #[test]
    fn validate_token_rejects_cookie_mismatch() -> Result<()> {
        let secret = SecStr::from("super-secret");
        let token = super::create_token(&secret, "identifier")?;
        let cookie = super::create_token(&secret, "someone-else")?;

        // Both values are validly signed, but the double-submit check must still fail.
        assert!(!validate_token(&secret, &cookie, &token)?);

        Ok(())
    }

    #[test]
    fn validate_token_rejects_tampered_message() -> Result<()> {
        let secret = SecStr::from("super-secret");
        let token = super::create_token(&secret, "identifier")?;
        let (signature, message) = token
            .split_once('.')
            .expect("create_token always emits a `.` separator");
        let forged = format!(
            "{}.{}",
            signature,
            message.replacen("identifier", "attacker", 1)
        );

        // Cookie and token agree, so only the HMAC check can reject this.
        assert!(!validate_token(&secret, &forged, &forged)?);

        Ok(())
    }

    #[test]
    fn validate_token_rejects_other_secret() -> Result<()> {
        let token = super::create_token(&SecStr::from("super-secret"), "identifier")?;

        assert!(!validate_token(
            &SecStr::from("other-secret"),
            &token,
            &token
        )?);

        Ok(())
    }

    #[test]
    fn validate_token_rejects_malformed_tokens() -> Result<()> {
        let secret = SecStr::from("super-secret");

        assert!(!validate_token(&secret, "", "")?);
        assert!(!validate_token(&secret, "no-separator", "no-separator")?);
        assert!(!validate_token(
            &secret,
            "not base64!.message",
            "not base64!.message"
        )?);

        Ok(())
    }
}