axum_csrf_simple/
lib.rs

1use axum::{body::Body, extract::{OptionalFromRequestParts, Request}, http::{self, request::Parts, HeaderValue, Method, StatusCode, Uri}, middleware::Next, response::{IntoResponse, Response}, Json};
2use hex::encode;
3use hmac::{Hmac, Mac};
4use serde::{Deserialize, Serialize};
5use sha2::Sha256;
6use short_uuid::short;
7use std::{collections::HashMap, sync::Arc};
8use tokio::sync::RwLock;
9use tower_cookies::Cookie;
10use std::convert::Infallible;
11use lazy_static::lazy_static;
12use rand::{rng, Rng};
13use rand::distr::Alphanumeric;
14
15type HmacSha256 = Hmac<Sha256>;
16
17lazy_static! {
18    static ref KEY:Arc<RwLock<String>> = Arc::new(RwLock::new(generate_random_string(32)));
19    static ref SECURE_COOKIE: Arc<RwLock<bool>> = Arc::new(RwLock::new(false));
20}
21
22/// Generate a random string of size
23/// Possible keys are from alpha numeric, mixing upper/lower cases
24pub fn generate_random_string(length: usize) -> String {
25    rng()
26        .sample_iter(&Alphanumeric)
27        .take(length)
28        .map(char::from)
29        .collect()
30}
31
32/// Enable or disable secure cookie. Default is to disable so it works with HTTP and HTTPS.
33/// Enabling secure cookie will make it only works with HTTPS
34/// Default is disabled
35pub async fn set_csrf_secure_cookie_enable(secure: bool) {
36    let mut w = SECURE_COOKIE.write().await;
37    *w = secure;
38}
39
40/// Set the signing key for csrf token. 
41/// If not called, CSRF token will be signed by a random 32 char alphanumeric string.
42/// Recommend to set a key with at least 32 characters.
43/// Better to call before your server start. Otherwise some existing CSRF token will become invalid.
44pub async fn set_csrf_token_sign_key(key:&str) {
45    let mut w = KEY.write().await;
46    *w = key.into();
47}
48
49/// Sign a message with the previously set sign key.
50/// Used internally, but you could use it elsewhere too.
51/// Return hex encoded signed message
52pub async fn sign_message(message: &str) -> String {
53    // Create HMAC instance
54    let r = KEY.read().await;
55    let key = r.clone();
56    drop(r);
57    let mut mac = HmacSha256::new_from_slice(key.as_bytes())
58        .expect("HMAC can take key of any size");
59    mac.update(message.as_bytes());
60    let bytes = mac.finalize().into_bytes();
61    // Sign the message
62    return encode(bytes);
63}
64
65/// Generate a CSRF token in format of xxxx-yyyy
66/// xxxx is the short uuid generated using uuid-short.
67/// yyyy is the hmac signature of the uuid-short signed with the sign key set previously 
68/// (or default 32 char random key if not set)
69pub async fn generate_csrf_token() -> String {
70    let result = short!();
71    let token_raw = result.to_string();
72    return format!("{}-{}", token_raw, sign_message(&token_raw).await);
73}
74
75/// Given a CSRF key of xxx-yyy format, use the previously set sign key to validate the value.
76/// Return true if the token is valid and signature matches
77pub async fn validate_csrf_token(what:&str) -> bool {
78    let mut tokens = what.splitn(2, '-');
79    if let(Some(first), Some(second)) = (tokens.next(), tokens.next()) {
80        return sign_message(first).await == second;
81    }
82    return false;
83}
84
85/// This is a verification function for sign_message. You can give input text, and a signature.
86/// The code will in computed signature to match the signature and return true if signature maches.
87pub async fn verify_signature(message: &str, signature: &str) -> bool {
88    let computed_signature = sign_message(message).await;
89    computed_signature == signature
90}
91
92/// Represents a CSRFToken. You can use request extension to get it.
93/// If you enable CSRF protection, the extension will guarantee the CSRF token is either
94///    Freshly initialized and cookie is set
95///    Or cookie seen, cookie value is valid, so we reuse the cookie
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct CSRFToken{ 
98    pub token:String,
99    pub is_new:bool,
100}
101
102/// Represents a CSRFToken from x-csrf-token request header. If it is available, it will be availabe for you
103/// to use using the auto extractor. 
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct CSRFTokenFromRequest(String);
106
107impl<B> OptionalFromRequestParts<B> for CSRFTokenFromRequest 
108    where B:Send + Sync
109{
110    type Rejection = Infallible;
111    async fn from_request_parts(
112        parts: &mut Parts,
113        _state: &B,
114    ) -> Result<Option<Self>, Self::Rejection> {
115        let val = parts.headers.get("x-csrf-token").map(|x| x.to_str().ok()).flatten();
116        match val {
117            None => {
118                return Ok(None)
119            },
120            Some(inner) => {
121                let token = inner.to_string();
122                if token.len() > 0 && validate_csrf_token(&token).await {
123                    return Ok(Some(CSRFTokenFromRequest(token)));
124                } else {
125                    return Ok(None);
126                }
127            }
128        }
129    }
130}
131
132/// A handler for you to expose to client. You should expose to client using your router.
133pub async fn get_csrf_token(request:Request) -> Result<Json<CSRFToken>, impl IntoResponse> {
134    let token1:Option<&CSRFToken> = request.extensions().get();
135    match token1 {
136        None => {
137            return Err((StatusCode::INTERNAL_SERVER_ERROR, "This request did not enable csrf_protection middleware"));
138        },
139        Some(token) => {
140            return Ok(Json(token.clone()));
141        }
142    }
143}
144
145fn get_csrf_token_from_query(request:&Request) -> Option<String> {
146    let uri: &Uri = request.uri();
147
148    if let Some(query) = uri.query() {
149        let params: HashMap<String, String> = serde_urlencoded::from_str(query).unwrap_or_default();
150        
151        if let Some(value) = params.get("csrf_token") {
152            return Some(value.clone())
153        }
154    }
155    None
156}
157
158/// Middleware to protect CSRF
159/// ```rust,no_run
160/// use axum_csrf_simple as csrf;
161/// use axum::middleware;
162/// use axum::routing::get;
163/// use axum::Router;
164/// use std::net::SocketAddr;
165/// 
166/// #[tokio::main]
167/// async fn main() {
168///   csrf::set_csrf_token_sign_key("key").await;
169///   csrf::set_csrf_secure_cookie_enable(false).await;
170///   let app1 = Router::new()
171///      .route("/admin/endpoint1", get(handle1).post(handle1).put(handle1))
172///      .route("/api/csrf", get(csrf::get_csrf_token))
173///      .route_layer(middleware::from_fn(csrf::csrf_protect));
174///   let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
175///   axum_server::bind(addr)
176///     .serve(app1.into_make_service())
177///     .await
178///     .unwrap();
179/// }
180/// 
181/// async fn handle1() -> &'static str{
182///     "HELLO"
183/// }
184/// ```
185pub async fn csrf_protect(mut request: Request, next: Next) -> Result<Response, StatusCode> {
186    let csrf_token_from_cookie = request.headers().get(http::header::COOKIE)
187        .map(|x| x.to_str().ok()).flatten()
188        .map(|x| {
189            x.split(";").find(|y| y.starts_with("csrf_token="))
190        })
191        .flatten()
192        .map(|x| x.splitn(2, '='))
193        .map(|x| x.last())
194        .flatten()
195        .map(|x| x.trim())
196        .map(|x| x.to_string());
197    let need_new_token: bool;
198    let actual_token = match csrf_token_from_cookie {
199        Some(token) => {
200            let is_valid = validate_csrf_token(&token).await;
201            if !is_valid {
202                need_new_token = true;
203                println!("Found invalid token {}", token);
204                CSRFToken{ token: generate_csrf_token().await, is_new: true}
205            } else {
206                need_new_token = false;
207                CSRFToken{ token, is_new:false}
208            }
209        },
210        None => {
211            need_new_token = true;
212            CSRFToken{token: generate_csrf_token().await, is_new: true}
213        }
214    };
215    request.extensions_mut().insert(actual_token.clone());
216    let actual_token_str = actual_token.clone();
217    let csrf_token_from_cookie_str = actual_token.token;
218    let csrf_token_from_header_or_query = request.headers()
219        .get("x-csrf-token")
220        .map(|x| x.to_str()
221            .ok()
222            .map(|x| x.to_string())
223        ).flatten().or(get_csrf_token_from_query(&request));
224    //let csrf_token_from_query_string = request
225    let method = request.method();
226    let mut response = match method {
227        &Method::POST | &Method::PUT | &Method::DELETE | &Method::PATCH => {
228            let valid_csrf = match csrf_token_from_header_or_query {
229                Some(val) =>  {
230                    if val == csrf_token_from_cookie_str {
231                        true
232                    } else {
233                        false
234                    }
235                },
236                _ => false,
237            };
238            if valid_csrf {
239                next.run(request).await
240            } else {
241                let unauthorized_response:Response<Body> = Response::builder()
242                    .status(StatusCode::UNAUTHORIZED)
243                    .header("x-reject-reason", "invalid-csrf-token")
244                    .body("Unauthorized due to invalid csrf token.".into())
245                    .unwrap();
246                let response = unauthorized_response;
247                response
248            }
249        },
250        _ => {
251            next.run(request).await
252        }
253    };
254    if need_new_token {
255        // A new CSRF token is genreated, we need to set to cookie
256        let token = actual_token_str.token;
257        let secure = {
258            *SECURE_COOKIE.read().await
259        };
260        let csrf_cookie = Cookie::build(("csrf_token", token.clone()))
261            .http_only(true)
262            .secure(secure)
263            .same_site(cookie::SameSite::Strict)
264            .path("/");
265        let header_value = csrf_cookie.build().encoded().to_string();
266        response.headers_mut().append(http::header::SET_COOKIE, HeaderValue::from_str(&header_value).unwrap());
267    } 
268    return Ok(response);
269}