1use axum::{body::Body, extract::{OptionalFromRequestParts, Request}, http::{self, request::Parts, HeaderValue, Method, StatusCode, Uri}, middleware::Next, response::{IntoResponse, Response}, Json};
2use hex::encode;
3use hmac::{Hmac, Mac};
4use serde::{Deserialize, Serialize};
5use sha2::Sha256;
6use short_uuid::short;
7use std::{collections::HashMap, sync::Arc};
8use tokio::sync::RwLock;
9use tower_cookies::Cookie;
10use std::convert::Infallible;
11use lazy_static::lazy_static;
12use rand::{rng, Rng};
13use rand::distr::Alphanumeric;
14
/// HMAC-SHA256 instance used to sign and verify CSRF tokens.
type HmacSha256 = Hmac<Sha256>;

lazy_static! {
    /// Secret key used by `sign_message` to sign CSRF tokens.
    ///
    /// Defaults to a random 32-character alphanumeric string generated on first access.
    /// Tokens signed in one process therefore do not validate in another unless
    /// `set_csrf_token_sign_key` installs a shared key.
    static ref KEY:Arc<RwLock<String>> = Arc::new(RwLock::new(generate_random_string(32)));
    /// Whether the `csrf_token` cookie emitted by `csrf_protect` carries the `Secure`
    /// attribute. Defaults to `false`; toggled by `set_csrf_secure_cookie_enable`.
    static ref SECURE_COOKIE: Arc<RwLock<bool>> = Arc::new(RwLock::new(false));
}
21
22pub fn generate_random_string(length: usize) -> String {
25 rng()
26 .sample_iter(&Alphanumeric)
27 .take(length)
28 .map(char::from)
29 .collect()
30}
31
32pub async fn set_csrf_secure_cookie_enable(secure: bool) {
36 let mut w = SECURE_COOKIE.write().await;
37 *w = secure;
38}
39
40pub async fn set_csrf_token_sign_key(key:&str) {
45 let mut w = KEY.write().await;
46 *w = key.into();
47}
48
49pub async fn sign_message(message: &str) -> String {
53 let r = KEY.read().await;
55 let key = r.clone();
56 drop(r);
57 let mut mac = HmacSha256::new_from_slice(key.as_bytes())
58 .expect("HMAC can take key of any size");
59 mac.update(message.as_bytes());
60 let bytes = mac.finalize().into_bytes();
61 return encode(bytes);
63}
64
65pub async fn generate_csrf_token() -> String {
70 let result = short!();
71 let token_raw = result.to_string();
72 return format!("{}-{}", token_raw, sign_message(&token_raw).await);
73}
74
75pub async fn validate_csrf_token(what:&str) -> bool {
78 let mut tokens = what.splitn(2, '-');
79 if let(Some(first), Some(second)) = (tokens.next(), tokens.next()) {
80 return sign_message(first).await == second;
81 }
82 return false;
83}
84
85pub async fn verify_signature(message: &str, signature: &str) -> bool {
88 let computed_signature = sign_message(message).await;
89 computed_signature == signature
90}
91
/// CSRF token that `csrf_protect` stores in the request extensions and that
/// `get_csrf_token` returns as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CSRFToken{
    /// Token value in the `<short-uuid>-<signature>` form produced by `generate_csrf_token`.
    pub token:String,
    /// `true` when no valid `csrf_token` cookie was present and the token was generated
    /// for this request; `csrf_protect` then also sends it back in a `Set-Cookie` header.
    pub is_new:bool,
}
101
/// Optional extractor holding the value of a validly signed `x-csrf-token` request header.
///
/// It resolves to `Some` only when the header is present, convertible to `&str`, non-empty
/// and accepted by `validate_csrf_token`.
/// NOTE(review): only the signature is checked. Unlike `csrf_protect`, the header is not
/// compared with the caller's `csrf_token` cookie; confirm this is the intended contract.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CSRFTokenFromRequest(String);
106
107impl<B> OptionalFromRequestParts<B> for CSRFTokenFromRequest
108 where B:Send + Sync
109{
110 type Rejection = Infallible;
111 async fn from_request_parts(
112 parts: &mut Parts,
113 _state: &B,
114 ) -> Result<Option<Self>, Self::Rejection> {
115 let val = parts.headers.get("x-csrf-token").map(|x| x.to_str().ok()).flatten();
116 match val {
117 None => {
118 return Ok(None)
119 },
120 Some(inner) => {
121 let token = inner.to_string();
122 if token.len() > 0 && validate_csrf_token(&token).await {
123 return Ok(Some(CSRFTokenFromRequest(token)));
124 } else {
125 return Ok(None);
126 }
127 }
128 }
129 }
130}
131
132pub async fn get_csrf_token(request:Request) -> Result<Json<CSRFToken>, impl IntoResponse> {
134 let token1:Option<&CSRFToken> = request.extensions().get();
135 match token1 {
136 None => {
137 return Err((StatusCode::INTERNAL_SERVER_ERROR, "This request did not enable csrf_protection middleware"));
138 },
139 Some(token) => {
140 return Ok(Json(token.clone()));
141 }
142 }
143}
144
145fn get_csrf_token_from_query(request:&Request) -> Option<String> {
146 let uri: &Uri = request.uri();
147
148 if let Some(query) = uri.query() {
149 let params: HashMap<String, String> = serde_urlencoded::from_str(query).unwrap_or_default();
150
151 if let Some(value) = params.get("csrf_token") {
152 return Some(value.clone())
153 }
154 }
155 None
156}
157
158pub async fn csrf_protect(mut request: Request, next: Next) -> Result<Response, StatusCode> {
186 let csrf_token_from_cookie = request.headers().get(http::header::COOKIE)
187 .map(|x| x.to_str().ok()).flatten()
188 .map(|x| {
189 x.split(";").find(|y| y.starts_with("csrf_token="))
190 })
191 .flatten()
192 .map(|x| x.splitn(2, '='))
193 .map(|x| x.last())
194 .flatten()
195 .map(|x| x.trim())
196 .map(|x| x.to_string());
197 let need_new_token: bool;
198 let actual_token = match csrf_token_from_cookie {
199 Some(token) => {
200 let is_valid = validate_csrf_token(&token).await;
201 if !is_valid {
202 need_new_token = true;
203 println!("Found invalid token {}", token);
204 CSRFToken{ token: generate_csrf_token().await, is_new: true}
205 } else {
206 need_new_token = false;
207 CSRFToken{ token, is_new:false}
208 }
209 },
210 None => {
211 need_new_token = true;
212 CSRFToken{token: generate_csrf_token().await, is_new: true}
213 }
214 };
215 request.extensions_mut().insert(actual_token.clone());
216 let actual_token_str = actual_token.clone();
217 let csrf_token_from_cookie_str = actual_token.token;
218 let csrf_token_from_header_or_query = request.headers()
219 .get("x-csrf-token")
220 .map(|x| x.to_str()
221 .ok()
222 .map(|x| x.to_string())
223 ).flatten().or(get_csrf_token_from_query(&request));
224 let method = request.method();
226 let mut response = match method {
227 &Method::POST | &Method::PUT | &Method::DELETE | &Method::PATCH => {
228 let valid_csrf = match csrf_token_from_header_or_query {
229 Some(val) => {
230 if val == csrf_token_from_cookie_str {
231 true
232 } else {
233 false
234 }
235 },
236 _ => false,
237 };
238 if valid_csrf {
239 next.run(request).await
240 } else {
241 let unauthorized_response:Response<Body> = Response::builder()
242 .status(StatusCode::UNAUTHORIZED)
243 .header("x-reject-reason", "invalid-csrf-token")
244 .body("Unauthorized due to invalid csrf token.".into())
245 .unwrap();
246 let response = unauthorized_response;
247 response
248 }
249 },
250 _ => {
251 next.run(request).await
252 }
253 };
254 if need_new_token {
255 let token = actual_token_str.token;
257 let secure = {
258 *SECURE_COOKIE.read().await
259 };
260 let csrf_cookie = Cookie::build(("csrf_token", token.clone()))
261 .http_only(true)
262 .secure(secure)
263 .same_site(cookie::SameSite::Strict)
264 .path("/");
265 let header_value = csrf_cookie.build().encoded().to_string();
266 response.headers_mut().append(http::header::SET_COOKIE, HeaderValue::from_str(&header_value).unwrap());
267 }
268 return Ok(response);
269}