use std::sync::Arc;
use sa_token_adapter::context::SaRequest;
use sa_token_adapter::utils::extract_bearer_or_value;
/// Shared, thread-safe predicate over a login id; returning `true` accepts it.
type LoginIdValidator = Arc<dyn Fn(&str) -> bool + Send + Sync>;
/// Returns `true` when `path` matches the route `pattern`.
///
/// Supported pattern forms, checked in this order:
/// - `/**` matches every path.
/// - `<prefix>/**` matches `<prefix>` itself and anything below it
///   (`<prefix>/...`). It does not match sibling paths that merely share the
///   prefix text, so `/api/**` does not match `/apiv2`.
/// - `*<suffix>` matches any path ending with `<suffix>` (e.g. `*.html`).
/// - `<prefix>/*` matches `<prefix>`, `<prefix>/`, and exactly one further
///   path segment below `<prefix>`. Repeated slashes directly after the
///   prefix are tolerated.
/// - Anything else is an exact string comparison.
pub fn match_path(path: &str, pattern: &str) -> bool {
    if pattern == "/**" {
        return true;
    }
    if let Some(prefix) = pattern.strip_suffix("/**") {
        // Require a segment boundary after the prefix so `/api/**` cannot
        // match `/apiv2/...`. This matters for exclude lists, where a looser
        // match would silently skip authentication for unrelated routes.
        return path
            .strip_prefix(prefix)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'));
    }
    if let Some(suffix) = pattern.strip_prefix('*') {
        return path.ends_with(suffix);
    }
    if let Some(prefix) = pattern.strip_suffix("/*") {
        let Some(rest) = path.strip_prefix(prefix) else {
            return false;
        };
        // Same segment-boundary rule as above: `/api/*` must not match `/apiv2`.
        if !rest.is_empty() && !rest.starts_with('/') {
            return false;
        }
        // After collapsing leading slashes, the remainder must be a single segment.
        return !rest.trim_start_matches('/').contains('/');
    }
    path == pattern
}
/// Returns `true` if `path` matches at least one entry in `patterns`.
///
/// Matching stops at the first hit; an empty pattern list never matches.
pub fn match_any(path: &str, patterns: &[&str]) -> bool {
    for pattern in patterns {
        if match_path(path, pattern) {
            return true;
        }
    }
    false
}
/// Decides whether `path` requires authentication.
///
/// A path needs auth when it matches some `include` pattern and no `exclude`
/// pattern, so excludes take precedence over includes. The exclude list is
/// only consulted for included paths.
pub fn need_auth(path: &str, include: &[&str], exclude: &[&str]) -> bool {
    if !match_any(path, include) {
        return false;
    }
    !match_any(path, exclude)
}
/// Route-based authentication settings: which paths require a valid token,
/// plus an optional extra check on the authenticated login id.
#[derive(Clone)]
pub struct PathAuthConfig {
    // Route patterns whose matching paths require authentication.
    include: Vec<String>,
    // Route patterns that exempt a path from authentication even when included.
    exclude: Vec<String>,
    // Optional login-id predicate; `None` accepts every login id.
    validator: Option<LoginIdValidator>,
}
impl PathAuthConfig {
pub fn new() -> Self {
Self {
include: Vec::new(),
exclude: Vec::new(),
validator: None,
}
}
pub fn include(mut self, patterns: Vec<String>) -> Self {
self.include = patterns;
self
}
pub fn exclude(mut self, patterns: Vec<String>) -> Self {
self.exclude = patterns;
self
}
pub fn validator<F>(mut self, f: F) -> Self
where
F: Fn(&str) -> bool + Send + Sync + 'static,
{
self.validator = Some(Arc::new(f));
self
}
pub fn check(&self, path: &str) -> bool {
let inc: Vec<&str> = self.include.iter().map(|s| s.as_str()).collect();
let exc: Vec<&str> = self.exclude.iter().map(|s| s.as_str()).collect();
need_auth(path, &inc, &exc)
}
pub fn validate_login_id(&self, login_id: &str) -> bool {
self.validator.as_ref().is_none_or(|v| v(login_id))
}
}
impl Default for PathAuthConfig {
fn default() -> Self {
Self::new()
}
}
use crate::{SaTokenManager, TokenValue, SaTokenContext, token::TokenInfo};
/// Outcome of resolving a request's token for a given path.
pub struct AuthResult {
    /// Whether the request path requires authentication.
    pub need_auth: bool,
    /// The token extracted from the request, if any.
    pub token: Option<TokenValue>,
    /// Token details. Present only when the manager accepted the token and
    /// its info could be loaded.
    pub token_info: Option<TokenInfo>,
    /// Whether the token is valid. On paths that need auth, this also
    /// requires loaded token info and a login id accepted by the validator.
    pub is_valid: bool,
}
impl AuthResult {
    /// Returns `true` when the request must be refused: the path requires
    /// authentication but no token was supplied or it did not validate.
    pub fn should_reject(&self) -> bool {
        if !self.need_auth {
            return false;
        }
        self.token.is_none() || !self.is_valid
    }

    /// Login id carried by the token info, if token info is available.
    pub fn login_id(&self) -> Option<&str> {
        let info = self.token_info.as_ref()?;
        Some(info.login_id.as_str())
    }
}
/// Resolves the request token for `path` against `config` and `manager`.
///
/// - `need_auth` comes from [`PathAuthConfig::check`].
/// - A supplied token is validated with the manager. Its info is loaded only
///   when valid, and a failed info lookup is treated as "no info".
/// - When the path needs auth, `is_valid` additionally requires token info
///   whose login id passes [`PathAuthConfig::validate_login_id`].
pub async fn process_auth(
    path: &str,
    token_str: Option<String>,
    config: &PathAuthConfig,
    manager: &SaTokenManager,
) -> AuthResult {
    let need_auth = config.check(path);
    let token = token_str.map(TokenValue::new);

    let (token_valid, token_info) = match token.as_ref() {
        None => (false, None),
        Some(t) => {
            if manager.is_valid(t).await {
                (true, manager.get_token_info(t).await.ok())
            } else {
                (false, None)
            }
        }
    };

    // Public paths skip the login-id check entirely; protected paths need
    // loaded info whose login id the validator accepts.
    let is_valid = token_valid
        && (!need_auth
            || token_info
                .as_ref()
                .is_some_and(|info| config.validate_login_id(&info.login_id)));

    AuthResult {
        need_auth,
        token,
        token_info,
        is_valid,
    }
}
/// Builds the [`SaTokenContext`] for an auth result.
///
/// The token, token info and login id are filled in only when both the token
/// and its info are present; otherwise a fresh, empty context is returned.
pub fn create_context(result: &AuthResult) -> SaTokenContext {
    let mut context = SaTokenContext::new();
    let Some((token, info)) = result.token.as_ref().zip(result.token_info.as_ref()) else {
        return context;
    };
    context.token = Some(token.clone());
    context.token_info = Some(Arc::new(info.clone()));
    context.login_id = Some(info.login_id.clone());
    context
}
/// Extracts a raw token string from the request.
///
/// Sources are tried in order and the first non-empty value wins:
/// 1. Header `token_name`, parsed with `extract_bearer_or_value`.
/// 2. The `Authorization` header, parsed the same way, unless `token_name`
///    already names it (case-insensitively).
/// 3. Cookie `token_name`, trimmed.
/// 4. Request parameter `token_name` (via `get_param`), trimmed.
///
/// Later sources are only queried when earlier ones yield nothing.
pub fn extract_token<R: SaRequest>(req: &R, token_name: &str) -> Option<String> {
    // Turns an empty candidate into `None` so the next source is tried.
    fn non_empty(candidate: String) -> Option<String> {
        (!candidate.is_empty()).then_some(candidate)
    }

    req.get_header(token_name)
        .and_then(|raw| non_empty(extract_bearer_or_value(&raw)))
        .or_else(|| {
            if token_name.eq_ignore_ascii_case("authorization") {
                None
            } else {
                req.get_header("Authorization")
                    .and_then(|raw| non_empty(extract_bearer_or_value(&raw)))
            }
        })
        .or_else(|| {
            req.get_cookie(token_name)
                .and_then(|raw| non_empty(raw.trim().to_string()))
        })
        .or_else(|| {
            req.get_param(token_name)
                .and_then(|raw| non_empty(raw.trim().to_string()))
        })
}
/// Result of [`run_auth_flow`]: the auth decision plus the context in which
/// the downstream handler should run.
pub struct AuthFlowResult {
    /// Full authentication outcome.
    pub auth: AuthResult,
    /// Copy of the login id from `auth.token_info`, if any.
    pub login_id: Option<String>,
    /// Copy of `auth.token`.
    pub token: Option<TokenValue>,
    /// Context built from `auth`; installed by [`AuthFlowResult::run`].
    pub context: SaTokenContext,
}
impl AuthFlowResult {
    /// Whether the request should be refused; delegates to
    /// [`AuthResult::should_reject`].
    pub fn should_reject(&self) -> bool {
        let auth = &self.auth;
        auth.should_reject()
    }

    /// Consumes the flow result and awaits `fut` inside
    /// `SaTokenContext::scope` with this flow's context, returning the
    /// future's output.
    pub async fn run<F, R>(self, fut: F) -> R
    where
        F: std::future::Future<Output = R>,
    {
        let Self { context, .. } = self;
        SaTokenContext::scope(context, fut).await
    }
}
pub async fn run_auth_flow<R: SaRequest>(
req: &R,
manager: &SaTokenManager,
path_config: Option<&PathAuthConfig>,
) -> AuthFlowResult {
let token_name = manager.config.token_name.as_str();
let token_str = extract_token(req, token_name);
let path = req.get_path();
let (auth, ctx) = match path_config {
Some(cfg) => {
let auth = process_auth(path.as_str(), token_str.clone(), cfg, manager).await;
let ctx = create_context(&auth);
(auth, ctx)
}
None => {
let token = token_str.map(TokenValue::new);
let (is_valid, token_info) = if let Some(ref t) = token {
let valid = manager.is_valid(t).await;
let info = if valid {
manager.get_token_info(t).await.ok()
} else {
None
};
(valid, info)
} else {
(false, None)
};
let auth = AuthResult {
need_auth: false,
token: token.clone(),
token_info,
is_valid,
};
let ctx = create_context(&auth);
(auth, ctx)
}
};
let login_id = auth.login_id().map(str::to_string);
let token = auth.token.clone();
AuthFlowResult {
auth,
login_id,
token,
context: ctx,
}
}