1use std::collections::HashMap;
8use std::future::{Future, Ready, ready};
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12
13use actix_web::cookie::{Cookie, SameSite};
14use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
15use actix_web::http::header;
16use actix_web::{HttpMessage, HttpRequest, HttpResponse, HttpResponseBuilder};
17use base64::Engine;
18use base64::engine::general_purpose::URL_SAFE;
19use chrono::{DateTime, Utc};
20use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
21use serde_json::Value;
22use tracing::warn;
23
24use crate::core::{Token, TokenStore};
25use crate::errors::JwtError;
26use crate::store::InMemoryRefreshTokenStore;
27
28#[derive(Debug, Clone)]
32pub struct JwtPayload(pub HashMap<String, Value>);
33
/// The raw token string extracted from the request.
///
/// Stored in the request extensions as soon as the token is extracted, before (and
/// regardless of) verification; read it back with [`get_token`].
#[derive(Clone, Debug)]
pub struct JwtTokenString(pub String);
39
40#[derive(Debug, Clone)]
44pub struct JwtIdentity(pub Value);
45
/// Configuration for JWT authentication: token issuing (login/refresh), verification,
/// cookie handling and request hooks.
///
/// Build with [`ActixJwtMiddleware::new`], adjust the public fields, call
/// [`ActixJwtMiddleware::init`], then put it in an `Arc` and use
/// [`ActixJwtMiddleware::middleware`] to obtain the actix `Transform`.
pub struct ActixJwtMiddleware {
    /// Realm advertised in the `WWW-Authenticate` header of error responses.
    pub realm: String,
    /// Claim name intended for the user identity; `init` resets an empty value to
    /// `"identity"`. NOTE(review): the default `identity_handler` reads the literal
    /// `"identity"` claim and does not consult this field.
    pub identity_key: String,

    /// Algorithm name: `HS256`/`HS384`/`HS512` (HMAC with `key`) or
    /// `RS256`/`RS384`/`RS512` (RSA with the private/public key settings below).
    pub signing_algorithm: String,
    /// HMAC secret used for HS* algorithms.
    pub key: Vec<u8>,
    /// Optional resolver of the verification key from a token's header. When set, token
    /// verification uses it instead of the stored decoding key.
    pub key_func:
        Option<Arc<dyn Fn(&jsonwebtoken::Header) -> Result<DecodingKey, JwtError> + Send + Sync>>,

    /// Access-token lifetime.
    pub timeout: Duration,
    /// Optional per-user override of `timeout`, called with the authenticator's user data.
    pub timeout_func: Option<Arc<dyn Fn(&Value) -> Duration + Send + Sync>>,
    /// NOTE(review): not read anywhere in this file — confirm whether it is still used.
    pub max_refresh: Duration,
    /// Clock used for `exp`/`orig_iat`, the token's `created_at` and refresh-token expiry.
    pub time_func: Arc<dyn Fn() -> DateTime<Utc> + Send + Sync>,

    /// Login callback: receives the request and raw body and returns the user data that
    /// is embedded in tokens. Required by `login_handler`.
    pub authenticator:
        Option<Arc<dyn Fn(&HttpRequest, &[u8]) -> Result<Value, JwtError> + Send + Sync>>,
    /// Access check run after a token is accepted, with the identity (or `Value::Null`);
    /// returning `false` produces a 403 response.
    pub authorizer: Arc<dyn Fn(&HttpRequest, &Value) -> bool + Send + Sync>,
    /// Builds extra claims from user data; keys `exp` and `orig_iat` are ignored.
    pub payload_func: Option<Arc<dyn Fn(&Value) -> HashMap<String, Value> + Send + Sync>>,
    /// Extracts the identity from a request after its claims are stored in extensions.
    pub identity_handler: Arc<dyn Fn(&HttpRequest) -> Option<Value> + Send + Sync>,

    /// Builds every error response from (request, HTTP status code, message).
    pub unauthorized: Arc<dyn Fn(&HttpRequest, u16, &str) -> HttpResponse + Send + Sync>,
    /// Builds the successful login response from the issued token pair.
    pub login_response: Arc<dyn Fn(&HttpRequest, &Token) -> HttpResponse + Send + Sync>,
    /// Builds the logout response.
    pub logout_response: Arc<dyn Fn(&HttpRequest) -> HttpResponse + Send + Sync>,
    /// Builds the successful refresh response from the newly issued token pair.
    pub refresh_response: Arc<dyn Fn(&HttpRequest, &Token) -> HttpResponse + Send + Sync>,
    /// Turns an error into the message handed to `unauthorized`.
    pub http_status_message_func: Arc<dyn Fn(&HttpRequest, &JwtError) -> String + Send + Sync>,

    /// Comma-separated `source:name` entries tried in order; sources are `header`,
    /// `query`, `cookie`, `param` and `form` (the last never yields a token).
    pub token_lookup: String,
    /// Scheme expected before the token in a header lookup, e.g. `Bearer`.
    pub token_head_name: String,
    /// Claim name under which the expiry timestamp is written.
    pub exp_field: String,

    /// Path of the PEM RSA private key (takes precedence over `priv_key_bytes`).
    pub priv_key_file: Option<String>,
    /// In-memory RSA private key data.
    pub priv_key_bytes: Option<Vec<u8>>,
    /// Path of the PEM RSA public key (takes precedence over `pub_key_bytes`).
    pub pub_key_file: Option<String>,
    /// In-memory PEM RSA public key.
    pub pub_key_bytes: Option<Vec<u8>>,
    /// Passphrase for an encrypted PKCS#8 private key.
    pub private_key_passphrase: Option<String>,
    // Signing key derived during `init`.
    encoding_key: Option<EncodingKey>,
    // Verification key derived during `init` (unused when `key_func` is set).
    decoding_key: Option<DecodingKey>,

    /// When true, login/refresh also set token cookies and logout clears them.
    pub send_cookie: bool,
    /// Max-Age of the access-token cookie; `init` replaces zero with `timeout`.
    pub cookie_max_age: Duration,
    /// `Secure` flag of the access-token cookie (the refresh cookie is always Secure).
    pub secure_cookie: bool,
    /// `HttpOnly` flag of the access-token cookie (the refresh cookie is always HttpOnly).
    pub cookie_http_only: bool,
    /// Optional `Domain` for both cookies.
    pub cookie_domain: Option<String>,
    /// Name of the access-token cookie.
    pub cookie_name: String,
    /// `SameSite` policy for both cookies.
    pub cookie_same_site: SameSite,
    /// When true, successful requests get `Authorization: <token_head_name> <token>`
    /// echoed on the response.
    pub send_authorization: bool,

    /// Refresh-token lifetime in the store, also used as the refresh cookie's Max-Age.
    pub refresh_token_timeout: Duration,
    /// Name of the refresh-token cookie.
    pub refresh_token_cookie_name: String,
    /// Number of random bytes in a refresh token (before base64 encoding).
    pub refresh_token_length: usize,
    /// Backend that stores refresh tokens with their user data.
    pub refresh_token_store: Arc<dyn TokenStore>,

    /// Requests for which this returns `true` are forwarded without authentication.
    pub skipper: Option<Arc<dyn Fn(&ServiceRequest) -> bool + Send + Sync>>,
    /// Called for every non-skipped request before authentication.
    pub before_func: Option<Arc<dyn Fn(&ServiceRequest) + Send + Sync>>,
    /// Called after successful authentication; an `Err` is rendered as an error response.
    pub success_handler: Option<Arc<dyn Fn(&HttpRequest) -> Result<(), JwtError> + Send + Sync>>,
    /// Receives authentication errors; return `Some(err)` to respond with that error or
    /// `None` to ignore it (see `continue_on_ignored_error`).
    pub error_handler:
        Option<Arc<dyn Fn(&HttpRequest, JwtError) -> Option<JwtError> + Send + Sync>>,
    /// For ignored errors: `true` calls the wrapped service, `false` answers with an
    /// empty 200 response.
    pub continue_on_ignored_error: bool,
}
224
225impl ActixJwtMiddleware {
226 pub fn new() -> Self {
233 Self {
234 realm: "actix jwt".to_string(),
235 identity_key: "identity".to_string(),
236
237 signing_algorithm: "HS256".to_string(),
238 key: Vec::new(),
239 key_func: None,
240
241 timeout: Duration::from_secs(3600), timeout_func: None,
243 max_refresh: Duration::ZERO,
244 time_func: Arc::new(Utc::now),
245
246 authenticator: None,
247 authorizer: Arc::new(|_req, _data| true),
248 payload_func: None,
249 identity_handler: Arc::new(|req| {
250 let ext = req.extensions();
251 let payload = ext.get::<JwtPayload>()?;
252 payload.0.get("identity").cloned()
253 }),
254
255 unauthorized: Arc::new(|_req, code, message| {
256 HttpResponse::build(
257 actix_web::http::StatusCode::from_u16(code)
258 .unwrap_or(actix_web::http::StatusCode::UNAUTHORIZED),
259 )
260 .json(serde_json::json!({
261 "code": code,
262 "message": message,
263 }))
264 }),
265 login_response: Arc::new(|_req, token| {
266 HttpResponse::Ok().json(Self::generate_token_response_static(token))
267 }),
268 logout_response: Arc::new(|_req| {
269 HttpResponse::Ok().json(serde_json::json!({ "code": 200 }))
270 }),
271 refresh_response: Arc::new(|_req, token| {
272 HttpResponse::Ok().json(Self::generate_token_response_static(token))
273 }),
274 http_status_message_func: Arc::new(|_req, err| err.to_string()),
275
276 token_lookup: "header:Authorization".to_string(),
277 token_head_name: "Bearer".to_string(),
278 exp_field: "exp".to_string(),
279
280 priv_key_file: None,
281 priv_key_bytes: None,
282 pub_key_file: None,
283 pub_key_bytes: None,
284 private_key_passphrase: None,
285 encoding_key: None,
286 decoding_key: None,
287
288 send_cookie: false,
289 cookie_max_age: Duration::from_secs(3600),
290 secure_cookie: false,
291 cookie_http_only: false,
292 cookie_domain: None,
293 cookie_name: "jwt".to_string(),
294 cookie_same_site: SameSite::Lax,
295 send_authorization: false,
296
297 refresh_token_timeout: Duration::from_secs(30 * 24 * 3600), refresh_token_cookie_name: "refresh_token".to_string(),
299 refresh_token_length: 32,
300 refresh_token_store: Arc::new(InMemoryRefreshTokenStore::new()),
301
302 skipper: None,
303 before_func: None,
304 success_handler: None,
305 error_handler: None,
306 continue_on_ignored_error: false,
307 }
308 }
309
310 pub fn init(&mut self) -> Result<(), JwtError> {
321 if self.token_lookup.is_empty() {
322 self.token_lookup = "header:Authorization".to_string();
323 }
324
325 if self.signing_algorithm.is_empty() {
326 self.signing_algorithm = "HS256".to_string();
327 }
328
329 if self.timeout == Duration::ZERO {
330 self.timeout = Duration::from_secs(3600);
331 }
332
333 let token_head = self.token_head_name.trim().to_string();
334 self.token_head_name = if token_head.is_empty() {
335 "Bearer".to_string()
336 } else {
337 token_head
338 };
339
340 if self.realm.is_empty() {
341 self.realm = "actix jwt".to_string();
342 }
343
344 if self.cookie_max_age == Duration::ZERO {
345 self.cookie_max_age = self.timeout;
346 }
347
348 if self.cookie_name.is_empty() {
349 self.cookie_name = "jwt".to_string();
350 }
351
352 if self.refresh_token_cookie_name.is_empty() {
353 self.refresh_token_cookie_name = "refresh_token".to_string();
354 }
355
356 if self.exp_field.is_empty() {
357 self.exp_field = "exp".to_string();
358 }
359
360 if self.identity_key.is_empty() {
361 self.identity_key = "identity".to_string();
362 }
363
364 if self.refresh_token_timeout == Duration::ZERO {
365 self.refresh_token_timeout = Duration::from_secs(30 * 24 * 3600);
366 }
367
368 if self.refresh_token_length == 0 {
369 self.refresh_token_length = 32;
370 }
371
372 if self.key_func.is_some() {
374 return Ok(());
375 }
376
377 if self.using_public_key_algo() {
378 return self.read_keys();
379 }
380
381 if self.key.is_empty() {
382 return Err(JwtError::MissingSecretKey);
383 }
384
385 self.encoding_key = Some(EncodingKey::from_secret(&self.key));
386 self.decoding_key = Some(DecodingKey::from_secret(&self.key));
387
388 Ok(())
389 }
390
391 pub fn using_public_key_algo(&self) -> bool {
393 matches!(self.signing_algorithm.as_str(), "RS256" | "RS384" | "RS512")
394 }
395
396 fn algorithm(&self) -> Result<Algorithm, JwtError> {
398 match self.signing_algorithm.as_str() {
399 "HS256" => Ok(Algorithm::HS256),
400 "HS384" => Ok(Algorithm::HS384),
401 "HS512" => Ok(Algorithm::HS512),
402 "RS256" => Ok(Algorithm::RS256),
403 "RS384" => Ok(Algorithm::RS384),
404 "RS512" => Ok(Algorithm::RS512),
405 _ => Err(JwtError::InvalidSigningAlgorithm),
406 }
407 }
408
409 fn read_keys(&mut self) -> Result<(), JwtError> {
410 self.load_private_key()?;
411 self.load_public_key()?;
412 Ok(())
413 }
414
415 fn load_private_key(&mut self) -> Result<(), JwtError> {
416 let key_data = if let Some(ref path) = self.priv_key_file {
417 std::fs::read(path).map_err(|e| {
418 warn!("Failed to read private key file {}: {}", path, e);
419 JwtError::NoPrivKeyFile
420 })?
421 } else if let Some(ref bytes) = self.priv_key_bytes {
422 bytes.clone()
423 } else {
424 return Err(JwtError::NoPrivKeyFile);
425 };
426
427 if let Some(ref passphrase) = self.private_key_passphrase {
428 let pem_str = std::str::from_utf8(&key_data).map_err(|_| JwtError::InvalidPrivKey)?;
430 let doc = pkcs8::EncryptedPrivateKeyInfo::try_from(pem_str.as_bytes())
431 .map_err(|_| JwtError::InvalidPrivKey)?;
432
433 let decrypted = doc
434 .decrypt(passphrase.as_bytes())
435 .map_err(|_| JwtError::InvalidPrivKey)?;
436
437 let der_bytes = decrypted.as_bytes();
438
439 let pem = pem::encode(&pem::Pem::new("PRIVATE KEY", der_bytes.to_vec()));
441 self.encoding_key = Some(
442 EncodingKey::from_rsa_pem(pem.as_bytes()).map_err(|_| JwtError::InvalidPrivKey)?,
443 );
444 } else {
445 self.encoding_key =
446 Some(EncodingKey::from_rsa_pem(&key_data).map_err(|_| JwtError::InvalidPrivKey)?);
447 }
448
449 Ok(())
450 }
451
452 fn load_public_key(&mut self) -> Result<(), JwtError> {
453 let key_data = if let Some(ref path) = self.pub_key_file {
454 std::fs::read(path).map_err(|e| {
455 warn!("Failed to read public key file {}: {}", path, e);
456 JwtError::NoPubKeyFile
457 })?
458 } else if let Some(ref bytes) = self.pub_key_bytes {
459 bytes.clone()
460 } else {
461 return Err(JwtError::NoPubKeyFile);
462 };
463
464 self.decoding_key =
465 Some(DecodingKey::from_rsa_pem(&key_data).map_err(|_| JwtError::InvalidPubKey)?);
466
467 Ok(())
468 }
469
470 pub fn generate_access_token(&self, data: &Value) -> Result<(String, DateTime<Utc>), JwtError> {
472 let alg = self.algorithm()?;
473
474 let mut claims = serde_json::Map::new();
475
476 let framework_claims: &[&str] = &["exp", "orig_iat"];
478
479 if let Some(ref pf) = self.payload_func {
480 for (k, v) in pf(data) {
481 if !framework_claims.contains(&k.as_str()) {
482 claims.insert(k, v);
483 }
484 }
485 }
486
487 let now = (self.time_func)();
488 let timeout = self
489 .timeout_func
490 .as_ref()
491 .map(|f| f(data))
492 .unwrap_or(self.timeout);
493 let expire = now
494 + chrono::Duration::from_std(timeout)
495 .unwrap_or_else(|_| chrono::Duration::seconds(3600));
496
497 claims.insert(
498 self.exp_field.clone(),
499 Value::Number(expire.timestamp().into()),
500 );
501 claims.insert(
502 "orig_iat".to_string(),
503 Value::Number(now.timestamp().into()),
504 );
505
506 let header = Header::new(alg);
507 let claims_value = Value::Object(claims);
508
509 let encoding_key = self
510 .encoding_key
511 .as_ref()
512 .ok_or(JwtError::MissingSecretKey)?;
513
514 let token_string = jsonwebtoken::encode(&header, &claims_value, encoding_key)
515 .map_err(|_| JwtError::FailedTokenCreation)?;
516
517 Ok((token_string, expire))
518 }
519
520 pub fn generate_refresh_token(&self) -> Result<String, JwtError> {
522 use rand::RngCore;
523 let mut buf = vec![0u8; self.refresh_token_length];
524 rand::thread_rng()
525 .try_fill_bytes(&mut buf)
526 .map_err(|e| JwtError::Internal(format!("RNG failure: {e}")))?;
527 Ok(URL_SAFE.encode(&buf))
528 }
529
530 async fn store_refresh_token(&self, token: &str, user_data: &Value) -> Result<(), JwtError> {
532 let expiry = (self.time_func)()
533 + chrono::Duration::from_std(self.refresh_token_timeout)
534 .unwrap_or_else(|_| chrono::Duration::days(30));
535 self.refresh_token_store
536 .set(token, user_data.clone(), expiry)
537 .await
538 }
539
540 async fn validate_refresh_token(&self, token: &str) -> Result<Value, JwtError> {
542 self.refresh_token_store
543 .get(token)
544 .await
545 .map_err(|e| match e {
546 JwtError::RefreshTokenNotFound => JwtError::InvalidRefreshToken,
547 other => other,
548 })
549 }
550
551 async fn revoke_refresh_token(&self, token: &str) -> Result<(), JwtError> {
553 self.refresh_token_store.delete(token).await
554 }
555
556 pub async fn token_generator(&self, data: &Value) -> Result<Token, JwtError> {
560 let (access_token, expire) = self.generate_access_token(data)?;
561 let refresh_token = self.generate_refresh_token()?;
562
563 self.store_refresh_token(&refresh_token, data).await?;
564
565 let now = (self.time_func)();
566 Ok(Token {
567 access_token,
568 token_type: "Bearer".to_string(),
569 refresh_token: Some(refresh_token),
570 expires_at: expire.timestamp(),
571 created_at: now.timestamp(),
572 })
573 }
574
575 pub async fn token_generator_with_revocation(
577 &self,
578 data: &Value,
579 old_refresh_token: &str,
580 ) -> Result<Token, JwtError> {
581 let token_pair = self.token_generator(data).await?;
582
583 if let Err(e) = self.revoke_refresh_token(old_refresh_token).await {
585 if !matches!(e, JwtError::RefreshTokenNotFound) {
586 return Err(e);
587 }
588 }
589
590 Ok(token_pair)
591 }
592
593 pub fn parse_token_from_request(
595 &self,
596 req: &HttpRequest,
597 ) -> Result<TokenData<Value>, JwtError> {
598 let token_str = self.extract_token_string(req)?;
599
600 req.extensions_mut()
602 .insert(JwtTokenString(token_str.clone()));
603
604 self.parse_token_string(&token_str)
605 }
606
607 pub fn parse_token_string(&self, token: &str) -> Result<TokenData<Value>, JwtError> {
609 let alg = self.algorithm()?;
610
611 if let Some(ref kf) = self.key_func {
612 let header = jsonwebtoken::decode_header(token)
614 .map_err(|e| JwtError::TokenParsing(e.to_string()))?;
615 let dk = kf(&header)?;
616 let mut validation = Validation::new(alg);
617 validation.validate_exp = true;
618 validation.validate_aud = false;
619 validation.required_spec_claims.clear();
620 return jsonwebtoken::decode::<Value>(token, &dk, &validation)
621 .map_err(|e| JwtError::TokenParsing(e.to_string()));
622 }
623
624 let decoding_key = self
625 .decoding_key
626 .as_ref()
627 .ok_or(JwtError::MissingSecretKey)?;
628
629 let mut validation = Validation::new(alg);
630 validation.validate_exp = true;
631 validation.validate_aud = false;
632 validation.required_spec_claims.clear();
633
634 jsonwebtoken::decode::<Value>(token, decoding_key, &validation)
635 .map_err(|e| JwtError::TokenParsing(e.to_string()))
636 }
637
638 fn extract_token_string(&self, req: &HttpRequest) -> Result<String, JwtError> {
640 let methods: Vec<&str> = self.token_lookup.split(',').collect();
641 let mut last_err: Option<JwtError> = None;
642
643 for method in methods {
644 let parts: Vec<&str> = method.trim().splitn(2, ':').collect();
645 if parts.len() != 2 {
646 continue;
647 }
648 let source = parts[0].trim();
649 let name = parts[1].trim();
650
651 let result = match source {
652 "header" => self.jwt_from_header(req, name),
653 "query" => self.jwt_from_query(req, name),
654 "cookie" => self.jwt_from_cookie(req, name),
655 "param" => self.jwt_from_param(req, name),
656 "form" => self.jwt_from_form(req, name),
657 _ => continue,
658 };
659
660 match result {
661 Ok(t) if !t.is_empty() => return Ok(t),
662 Ok(_) => {}
663 Err(e) => {
664 last_err = Some(e);
665 }
666 }
667 }
668
669 Err(last_err.unwrap_or(JwtError::TokenExtraction(
670 "no token found in request".to_string(),
671 )))
672 }
673
674 fn jwt_from_header(&self, req: &HttpRequest, key: &str) -> Result<String, JwtError> {
675 let auth_header = req
676 .headers()
677 .get(key)
678 .and_then(|v| v.to_str().ok())
679 .unwrap_or("");
680
681 if auth_header.is_empty() {
682 return Err(JwtError::EmptyAuthHeader);
683 }
684
685 let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
686 if parts.len() != 2 || parts[0] != self.token_head_name {
687 return Err(JwtError::InvalidAuthHeader);
688 }
689
690 Ok(parts[1].to_string())
691 }
692
693 fn jwt_from_query(&self, req: &HttpRequest, key: &str) -> Result<String, JwtError> {
694 let qs = req.query_string();
695 for pair in qs.split('&') {
697 let mut kv = pair.splitn(2, '=');
698 if let (Some(k), Some(v)) = (kv.next(), kv.next()) {
699 if k == key && !v.is_empty() {
700 return Ok(v.to_string());
701 }
702 }
703 }
704 Err(JwtError::EmptyQueryToken)
705 }
706
707 fn jwt_from_cookie(&self, req: &HttpRequest, key: &str) -> Result<String, JwtError> {
708 req.cookie(key)
709 .map(|c| c.value().to_string())
710 .filter(|v| !v.is_empty())
711 .ok_or(JwtError::EmptyCookieToken)
712 }
713
714 fn jwt_from_param(&self, req: &HttpRequest, key: &str) -> Result<String, JwtError> {
715 let val = req.match_info().get(key).unwrap_or("");
716 if val.is_empty() {
717 return Err(JwtError::EmptyParamToken);
718 }
719 Ok(val.to_string())
720 }
721
722 fn jwt_from_form(&self, _req: &HttpRequest, _key: &str) -> Result<String, JwtError> {
723 Err(JwtError::EmptyParamToken)
727 }
728
729 fn get_claims_from_jwt(&self, req: &HttpRequest) -> Result<HashMap<String, Value>, JwtError> {
731 let token_data = self.parse_token_from_request(req)?;
732
733 let claims_map = match token_data.claims {
738 Value::Object(map) => map.into_iter().collect(),
739 _ => HashMap::new(),
740 };
741
742 Ok(claims_map)
743 }
744
745 fn middleware_impl(&self, req: &HttpRequest) -> Result<(), JwtError> {
747 let claims = self
748 .get_claims_from_jwt(req)
749 .map_err(|e| JwtError::TokenParsing(e.to_string()))?;
750
751 if !claims.contains_key("exp") {
753 return Err(JwtError::TokenExtraction(
754 JwtError::MissingExpField.to_string(),
755 ));
756 }
757
758 req.extensions_mut().insert(JwtPayload(claims));
759
760 let identity = (self.identity_handler)(req);
761 if let Some(ref id) = identity {
762 req.extensions_mut().insert(JwtIdentity(id.clone()));
763 }
764
765 let auth_data = identity.unwrap_or(Value::Null);
766 if !(self.authorizer)(req, &auth_data) {
767 return Err(JwtError::Forbidden);
768 }
769
770 Ok(())
771 }
772
773 fn unauthorized_response(&self, req: &HttpRequest, code: u16, message: &str) -> HttpResponse {
775 let mut resp = (self.unauthorized)(req, code, message);
776 resp.headers_mut().insert(
777 header::WWW_AUTHENTICATE,
778 format!("Bearer realm=\"{}\"", self.realm).parse().unwrap(),
779 );
780 resp
781 }
782
783 fn handle_middleware_error(&self, req: &HttpRequest, err: &JwtError) -> HttpResponse {
785 match err {
786 JwtError::Forbidden => {
787 let msg = (self.http_status_message_func)(req, &JwtError::Forbidden);
788 self.unauthorized_response(req, 403, &msg)
789 }
790 JwtError::TokenParsing(inner) => self.handle_token_error(req, inner),
791 JwtError::TokenExtraction(inner) => {
792 let msg = inner.clone();
793 self.unauthorized_response(req, 400, &msg)
794 }
795 other => {
796 let msg = (self.http_status_message_func)(req, other);
797 self.unauthorized_response(req, 401, &msg)
798 }
799 }
800 }
801
802 fn handle_token_error(&self, req: &HttpRequest, detail: &str) -> HttpResponse {
803 let lower = detail.to_lowercase();
804 if lower.contains("expired") {
805 let msg = (self.http_status_message_func)(req, &JwtError::ExpiredToken);
806 self.unauthorized_response(req, 401, &msg)
807 } else if lower.contains("exp") && lower.contains("invalid") {
808 let msg = (self.http_status_message_func)(req, &JwtError::WrongFormatOfExp);
809 self.unauthorized_response(req, 400, &msg)
810 } else if lower.contains("exp") && lower.contains("required") {
811 let msg = (self.http_status_message_func)(req, &JwtError::MissingExpField);
812 self.unauthorized_response(req, 400, &msg)
813 } else {
814 let err = JwtError::TokenParsing(detail.to_string());
815 let msg = (self.http_status_message_func)(req, &err);
816 self.unauthorized_response(req, 401, &msg)
817 }
818 }
819
820 pub fn set_cookie(builder: &mut HttpResponseBuilder, config: &CookieConfig, value: &str) {
822 let mut cookie = Cookie::build(config.name.clone(), value.to_string())
823 .path("/")
824 .max_age(actix_web::cookie::time::Duration::seconds(
825 config.max_age.as_secs() as i64,
826 ))
827 .secure(config.secure)
828 .http_only(config.http_only)
829 .same_site(config.same_site)
830 .finish();
831
832 if let Some(ref domain) = config.domain {
833 cookie.set_domain(domain.clone());
834 }
835
836 builder.cookie(cookie);
837 }
838
839 pub fn access_cookie_config(&self) -> CookieConfig {
841 CookieConfig {
842 name: self.cookie_name.clone(),
843 max_age: self.cookie_max_age,
844 secure: self.secure_cookie,
845 http_only: self.cookie_http_only,
846 domain: self.cookie_domain.clone(),
847 same_site: self.cookie_same_site,
848 }
849 }
850
851 pub fn refresh_cookie_config(&self) -> CookieConfig {
853 CookieConfig {
854 name: self.refresh_token_cookie_name.clone(),
855 max_age: self.refresh_token_timeout,
856 secure: true, http_only: true, domain: self.cookie_domain.clone(),
859 same_site: self.cookie_same_site,
860 }
861 }
862
863 fn append_cookie(
868 headers: &mut actix_web::http::header::HeaderMap,
869 config: &CookieConfig,
870 value: &str,
871 ) {
872 let mut cookie = Cookie::build(config.name.clone(), value.to_string())
873 .path("/")
874 .max_age(actix_web::cookie::time::Duration::seconds(
875 config.max_age.as_secs() as i64,
876 ))
877 .secure(config.secure)
878 .http_only(config.http_only)
879 .same_site(config.same_site)
880 .finish();
881
882 if let Some(ref domain) = config.domain {
883 cookie.set_domain(domain.clone());
884 }
885
886 headers.append(header::SET_COOKIE, cookie.to_string().parse().unwrap());
887 }
888
889 fn append_delete_cookie(
892 headers: &mut actix_web::http::header::HeaderMap,
893 config: &CookieConfig,
894 ) {
895 let mut cookie = Cookie::build(config.name.clone(), "")
896 .path("/")
897 .max_age(actix_web::cookie::time::Duration::seconds(-1))
898 .secure(config.secure)
899 .http_only(config.http_only)
900 .same_site(config.same_site)
901 .finish();
902
903 if let Some(ref domain) = config.domain {
904 cookie.set_domain(domain.clone());
905 }
906
907 headers.append(header::SET_COOKIE, cookie.to_string().parse().unwrap());
908 }
909
910 pub fn delete_cookie(builder: &mut HttpResponseBuilder, config: &CookieConfig) {
912 let mut cookie = Cookie::build(config.name.clone(), "")
913 .path("/")
914 .max_age(actix_web::cookie::time::Duration::seconds(-1))
915 .secure(config.secure)
916 .http_only(config.http_only)
917 .same_site(config.same_site)
918 .finish();
919
920 if let Some(ref domain) = config.domain {
921 cookie.set_domain(domain.clone());
922 }
923
924 builder.cookie(cookie);
925 }
926
927 pub async fn login_handler(&self, req: &HttpRequest, body: &[u8]) -> HttpResponse {
940 let authenticator = match self.authenticator {
941 Some(ref auth) => auth,
942 None => {
943 let msg = (self.http_status_message_func)(req, &JwtError::MissingAuthenticator);
944 return self.unauthorized_response(req, 500, &msg);
945 }
946 };
947
948 let data = match authenticator(req, body) {
949 Ok(d) => d,
950 Err(e) => {
951 let msg = (self.http_status_message_func)(req, &e);
952 return self.unauthorized_response(req, 401, &msg);
953 }
954 };
955
956 let token_pair = match self.token_generator(&data).await {
957 Ok(t) => t,
958 Err(_) => {
959 let msg = (self.http_status_message_func)(req, &JwtError::FailedTokenCreation);
960 return self.unauthorized_response(req, 500, &msg);
961 }
962 };
963
964 let mut resp = (self.login_response)(req, &token_pair);
965
966 if self.send_cookie {
967 Self::append_cookie(
968 resp.headers_mut(),
969 &self.access_cookie_config(),
970 &token_pair.access_token,
971 );
972 if let Some(ref rt) = token_pair.refresh_token {
973 Self::append_cookie(resp.headers_mut(), &self.refresh_cookie_config(), rt);
974 }
975 }
976
977 resp
978 }
979
980 pub fn extract_refresh_token(&self, req: &HttpRequest, body: &[u8]) -> Option<String> {
982 if let Some(cookie) = req.cookie(&self.refresh_token_cookie_name) {
984 let val = cookie.value().to_string();
985 if !val.is_empty() {
986 return Some(val);
987 }
988 }
989
990 let content_type = req
992 .headers()
993 .get(header::CONTENT_TYPE)
994 .and_then(|v| v.to_str().ok())
995 .unwrap_or("");
996
997 if content_type.contains("application/x-www-form-urlencoded")
998 || content_type.contains("multipart/form-data")
999 {
1000 let body_str = std::str::from_utf8(body).unwrap_or("");
1002 for pair in body_str.split('&') {
1003 let mut kv = pair.splitn(2, '=');
1004 if let (Some(k), Some(v)) = (kv.next(), kv.next()) {
1005 if k == "refresh_token" && !v.is_empty() {
1006 return Some(v.to_string());
1007 }
1008 }
1009 }
1010 } else if content_type.contains("application/json") {
1011 #[derive(serde::Deserialize)]
1012 struct RefreshBody {
1013 refresh_token: Option<String>,
1014 }
1015 if let Ok(parsed) = serde_json::from_slice::<RefreshBody>(body) {
1016 if let Some(rt) = parsed.refresh_token {
1017 if !rt.is_empty() {
1018 return Some(rt);
1019 }
1020 }
1021 }
1022 }
1023
1024 None
1025 }
1026
1027 pub async fn logout_handler(&self, req: &HttpRequest, body: &[u8]) -> HttpResponse {
1029 if let Ok(claims) = self.get_claims_from_jwt(req) {
1031 req.extensions_mut().insert(JwtPayload(claims));
1032 let identity = (self.identity_handler)(req);
1033 if let Some(ref id) = identity {
1034 req.extensions_mut().insert(JwtIdentity(id.clone()));
1035 }
1036 }
1037
1038 if let Some(ref rt) = self.extract_refresh_token(req, body) {
1040 if let Err(e) = self.revoke_refresh_token(rt).await {
1041 warn!("Failed to revoke refresh token on logout: {}", e);
1042 }
1043 }
1044
1045 let mut resp = (self.logout_response)(req);
1046
1047 if self.send_cookie {
1048 Self::append_delete_cookie(resp.headers_mut(), &self.access_cookie_config());
1049 Self::append_delete_cookie(resp.headers_mut(), &self.refresh_cookie_config());
1050 }
1051
1052 resp
1053 }
1054
1055 pub async fn refresh_handler(&self, req: &HttpRequest, body: &[u8]) -> HttpResponse {
1058 let refresh_token = match self.extract_refresh_token(req, body) {
1059 Some(rt) => rt,
1060 None => {
1061 let msg = (self.http_status_message_func)(req, &JwtError::MissingRefreshToken);
1062 return self.unauthorized_response(req, 400, &msg);
1063 }
1064 };
1065
1066 let user_data = match self.validate_refresh_token(&refresh_token).await {
1067 Ok(d) => d,
1068 Err(e) => {
1069 let msg = (self.http_status_message_func)(req, &e);
1070 return self.unauthorized_response(req, 401, &msg);
1071 }
1072 };
1073
1074 let token_pair = match self
1075 .token_generator_with_revocation(&user_data, &refresh_token)
1076 .await
1077 {
1078 Ok(t) => t,
1079 Err(e) => {
1080 let msg = (self.http_status_message_func)(req, &e);
1081 return self.unauthorized_response(req, 500, &msg);
1082 }
1083 };
1084
1085 let mut resp = (self.refresh_response)(req, &token_pair);
1086
1087 if self.send_cookie {
1088 Self::append_cookie(
1089 resp.headers_mut(),
1090 &self.access_cookie_config(),
1091 &token_pair.access_token,
1092 );
1093 if let Some(ref rt) = token_pair.refresh_token {
1094 Self::append_cookie(resp.headers_mut(), &self.refresh_cookie_config(), rt);
1095 }
1096 }
1097
1098 resp
1099 }
1100
1101 pub fn generate_token_response(token: &Token) -> serde_json::Map<String, Value> {
1105 let mut map = serde_json::Map::new();
1106 map.insert(
1107 "access_token".into(),
1108 Value::String(token.access_token.clone()),
1109 );
1110 map.insert("token_type".into(), Value::String(token.token_type.clone()));
1111 map.insert(
1112 "expires_in".into(),
1113 Value::Number(token.expires_in().into()),
1114 );
1115
1116 if let Some(ref rt) = token.refresh_token {
1117 map.insert("refresh_token".into(), Value::String(rt.clone()));
1118 }
1119
1120 map
1121 }
1122
1123 fn generate_token_response_static(token: &Token) -> Value {
1126 let map = Self::generate_token_response(token);
1127 Value::Object(map)
1128 }
1129
1130 pub fn middleware(self: &Arc<Self>) -> JwtAuth {
1134 JwtAuth {
1135 inner: self.clone(),
1136 }
1137 }
1138}
1139
1140impl Default for ActixJwtMiddleware {
1141 fn default() -> Self {
1142 Self::new()
1143 }
1144}
1145
/// Attributes of a cookie written or cleared by the middleware; the cookie path is
/// always `/`.
pub struct CookieConfig {
    /// Cookie name.
    pub name: String,
    /// Max-Age used when setting the cookie (deletion always uses Max-Age=-1).
    pub max_age: Duration,
    /// `Secure` attribute.
    pub secure: bool,
    /// `HttpOnly` attribute.
    pub http_only: bool,
    /// Optional `Domain` attribute.
    pub domain: Option<String>,
    /// `SameSite` attribute.
    pub same_site: SameSite,
}
1165
1166pub fn extract_claims(req: &HttpRequest) -> HashMap<String, Value> {
1183 req.extensions()
1184 .get::<JwtPayload>()
1185 .map(|p| p.0.clone())
1186 .unwrap_or_default()
1187}
1188
1189pub fn get_token(req: &HttpRequest) -> Option<String> {
1193 req.extensions()
1194 .get::<JwtTokenString>()
1195 .map(|t| t.0.clone())
1196}
1197
1198pub fn get_identity(req: &HttpRequest) -> Option<Value> {
1203 req.extensions().get::<JwtIdentity>().map(|i| i.0.clone())
1204}
1205
/// Actix `Transform` factory that wraps services with JWT authentication.
///
/// Obtained from [`ActixJwtMiddleware::middleware`]; every wrapped service shares the
/// same configuration through the `Arc`.
pub struct JwtAuth {
    // Shared middleware configuration.
    inner: Arc<ActixJwtMiddleware>,
}
1214
1215impl<S, B> Transform<S, ServiceRequest> for JwtAuth
1216where
1217 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
1218 B: 'static,
1219{
1220 type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
1221 type Error = actix_web::Error;
1222 type Transform = JwtAuthMiddleware<S>;
1223 type InitError = ();
1224 type Future = Ready<Result<Self::Transform, Self::InitError>>;
1225
1226 fn new_transform(&self, service: S) -> Self::Future {
1227 ready(Ok(JwtAuthMiddleware {
1228 service: Arc::new(service),
1229 inner: self.inner.clone(),
1230 }))
1231 }
1232}
1233
/// Per-service middleware produced by [`JwtAuth`].
pub struct JwtAuthMiddleware<S> {
    // Wrapped service; an `Arc` so the request future can own a handle to it.
    service: Arc<S>,
    // Shared middleware configuration.
    inner: Arc<ActixJwtMiddleware>,
}
1239
1240impl<S, B> Service<ServiceRequest> for JwtAuthMiddleware<S>
1241where
1242 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
1243 B: 'static,
1244{
1245 type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
1246 type Error = actix_web::Error;
1247 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
1248
1249 fn poll_ready(
1250 &self,
1251 ctx: &mut std::task::Context<'_>,
1252 ) -> std::task::Poll<Result<(), Self::Error>> {
1253 self.service.poll_ready(ctx)
1254 }
1255
1256 fn call(&self, req: ServiceRequest) -> Self::Future {
1257 let mw = self.inner.clone();
1258 let service = self.service.clone();
1259
1260 Box::pin(async move {
1261 if let Some(ref skipper) = mw.skipper {
1263 if skipper(&req) {
1264 let res = service.call(req).await?;
1265 return Ok(res.map_into_left_body());
1266 }
1267 }
1268
1269 if let Some(ref bf) = mw.before_func {
1271 bf(&req);
1272 }
1273
1274 let mw_result = mw.middleware_impl(req.request());
1281
1282 if let Err(err) = mw_result {
1283 if let Some(ref eh) = mw.error_handler {
1285 let maybe_err = eh(req.request(), err);
1286 if maybe_err.is_none() && mw.continue_on_ignored_error {
1287 let res = service.call(req).await?;
1291 return Ok(res.map_into_left_body());
1292 }
1293 if let Some(e) = maybe_err {
1294 let resp = mw.handle_middleware_error(req.request(), &e);
1295 return Ok(req.into_response(resp).map_into_right_body());
1296 }
1297 return Ok(req
1299 .into_response(HttpResponse::Ok().finish())
1300 .map_into_right_body());
1301 }
1302
1303 let resp = mw.handle_middleware_error(req.request(), &err);
1305 return Ok(req.into_response(resp).map_into_right_body());
1306 }
1307
1308 if let Some(ref sh) = mw.success_handler {
1310 if let Err(e) = sh(req.request()) {
1311 let resp = mw.handle_middleware_error(req.request(), &e);
1312 return Ok(req.into_response(resp).map_into_right_body());
1313 }
1314 }
1315
1316 let send_auth = if mw.send_authorization {
1318 let ext = req.extensions();
1319 ext.get::<JwtTokenString>()
1320 .map(|t| format!("{} {}", mw.token_head_name, t.0))
1321 } else {
1322 None
1323 };
1324
1325 let mut res = service.call(req).await?;
1327
1328 if let Some(val) = send_auth {
1329 res.headers_mut()
1330 .insert(header::AUTHORIZATION, val.parse().unwrap());
1331 }
1332
1333 Ok(res.map_into_left_body())
1334 })
1335 }
1336}