use rocket::{Data, Request, Response};
use rocket::fairing::{Fairing, Info, Kind};
use rocket::http::{ContentType, Status};
use sa_token_core::error::messages;
use sa_token_plugin_rocket_core::run_auth_flow;
use serde_json::json;
use crate::adapter::RocketCapturedRequest;
use crate::SaTokenState;
/// Global authentication fairing.
///
/// On every request it runs the shared auth flow and stores whatever was
/// resolved (token and/or login id) in Rocket's request-local cache. It never
/// rewrites the response itself.
pub struct SaTokenFairing {
    /// Shared Sa-Token state; `state.manager` drives the auth flow.
    state: SaTokenState,
}

impl SaTokenFairing {
    /// Builds the fairing around the shared Sa-Token state.
    pub fn new(state: SaTokenState) -> Self {
        SaTokenFairing { state }
    }
}
#[rocket::async_trait]
impl Fairing for SaTokenFairing {
    /// Registers this as a request-only fairing.
    fn info(&self) -> Info {
        let kind = Kind::Request;
        Info {
            name: "SaToken Authentication",
            kind,
        }
    }

    /// Resolves the caller's credentials and caches them for later consumers.
    ///
    /// When the auth flow produced a token and/or login id, each is cached in
    /// the request-local cache (keyed by its type). Nothing is cached for
    /// values the flow did not resolve.
    ///
    /// NOTE(review): `local_cache` is keyed by type. If `flow.token` and
    /// `flow.login_id` ever had the same type (e.g. both `String`), the second
    /// insert would be silently ignored. Presumably the token is a distinct
    /// type; confirm against `run_auth_flow`.
    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
        let token_name = self.state.manager.config.token_name.as_str();
        let captured = RocketCapturedRequest::capture(request, token_name);
        let outcome = run_auth_flow(&captured, &self.state.manager, None).await;

        if let Some(token) = &outcome.token {
            request.local_cache(|| Some(token.clone()));
        }
        if let Some(login_id) = &outcome.login_id {
            request.local_cache(|| Some(login_id.clone()));
        }
    }
}
/// Fairing that requires a logged-in user.
///
/// Requests for which the auth flow resolves no login id get their response
/// replaced with a `401 Unauthorized` JSON body.
pub struct SaCheckLoginFairing {
    /// Shared Sa-Token state; `state.manager` drives the auth flow.
    state: SaTokenState,
}

impl SaCheckLoginFairing {
    /// Builds the fairing around the shared Sa-Token state.
    pub fn new(state: SaTokenState) -> Self {
        SaCheckLoginFairing { state }
    }
}
#[rocket::async_trait]
impl Fairing for SaCheckLoginFairing {
fn info(&self) -> Info {
Info {
name: "SaToken Check Login",
kind: Kind::Request | Kind::Response,
}
}
async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
let adapter = RocketCapturedRequest::capture(
request,
self.state.manager.config.token_name.as_str(),
);
let flow = run_auth_flow(&adapter, &self.state.manager, None).await;
if flow.login_id.is_some() {
if let Some(ref t) = flow.token {
request.local_cache(|| Some(t.clone()));
}
if let Some(ref id) = flow.login_id {
request.local_cache(|| Some(id.clone()));
}
return;
}
request.local_cache(|| Some("unauthorized"));
}
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
if request.local_cache(|| None::<&str>).is_some()
&& *request.local_cache(|| None::<&str>) == Some("unauthorized") {
response.set_status(Status::Unauthorized);
response.set_sized_body(
None,
std::io::Cursor::new(
json!({
"code": 401,
"message": messages::AUTH_ERROR
})
.to_string(),
),
);
}
}
}
/// Fairing that requires the current login id to hold a given permission.
///
/// Failing requests get their response replaced with a `403 Forbidden` JSON body.
pub struct SaCheckPermissionFairing {
    // Not read by this fairing: the check goes through the global `StpUtil`.
    // Presumably kept for constructor symmetry with the other fairings.
    #[allow(dead_code)]
    state: SaTokenState,
    /// Permission identifier that must be granted to the current login id.
    permission: String,
}

impl SaCheckPermissionFairing {
    /// Builds the fairing for `permission` (anything convertible into `String`).
    pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
        let permission: String = permission.into();
        Self { permission, state }
    }
}
/// Request-local marker set by [`SaCheckPermissionFairing`] when the permission check fails.
const FORBIDDEN_PERMISSION_MARKER: &str = "forbidden";

#[rocket::async_trait]
impl Fairing for SaCheckPermissionFairing {
    /// Registers for both phases: `on_request` decides, `on_response` enforces.
    fn info(&self) -> Info {
        Info {
            name: "SaToken Check Permission",
            kind: Kind::Request | Kind::Response,
        }
    }

    /// Checks the cached login id against `self.permission`.
    ///
    /// Caches [`FORBIDDEN_PERMISSION_MARKER`] when no login id is cached or the
    /// permission is not held.
    ///
    /// NOTE(review): relies on an earlier fairing (e.g. `SaTokenFairing`) having
    /// cached the login id as `Option<String>`. If it is attached later, or the
    /// login id is not a `String`, every request is marked forbidden (403, not 401).
    /// The marker slot is shared with the other SaCheck* fairings, and the first
    /// writer wins. The route handler still runs, because Rocket request fairings
    /// cannot abort.
    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
        // Cloned out of the cache so no borrow of `request` is held across `.await`.
        let login_id = request.local_cache(|| None::<String>).clone();
        let permitted = match login_id {
            Some(ref id) => sa_token_core::StpUtil::has_permission(id, &self.permission).await,
            None => false,
        };
        if !permitted {
            request.local_cache(|| Some(FORBIDDEN_PERMISSION_MARKER));
        }
    }

    /// Replaces the response with a `403` JSON error when `on_request` marked
    /// the request as lacking the permission.
    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
        // One lookup is enough; the `Some(..)` comparison already rules out `None`.
        if *request.local_cache(|| None::<&'static str>) != Some(FORBIDDEN_PERMISSION_MARKER) {
            return;
        }

        response.set_status(Status::Forbidden);
        response.set_header(ContentType::JSON);
        let body = json!({
            "code": 403,
            "message": messages::PERMISSION_REQUIRED
        })
        .to_string();
        response.set_sized_body(None, std::io::Cursor::new(body));
    }
}
/// Fairing that requires the current login id to hold a given role.
///
/// Failing requests get their response replaced with a `403 Forbidden` JSON body.
pub struct SaCheckRoleFairing {
    // Not read by this fairing: the check goes through the global `StpUtil`.
    // Presumably kept for constructor symmetry with the other fairings.
    #[allow(dead_code)]
    state: SaTokenState,
    /// Role identifier that must be held by the current login id.
    role: String,
}

impl SaCheckRoleFairing {
    /// Builds the fairing for `role` (anything convertible into `String`).
    pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
        let role: String = role.into();
        Self { role, state }
    }
}
/// Request-local marker set by [`SaCheckRoleFairing`] when the role check fails.
const FORBIDDEN_ROLE_MARKER: &str = "forbidden_role";

#[rocket::async_trait]
impl Fairing for SaCheckRoleFairing {
    /// Registers for both phases: `on_request` decides, `on_response` enforces.
    fn info(&self) -> Info {
        Info {
            name: "SaToken Check Role",
            kind: Kind::Request | Kind::Response,
        }
    }

    /// Checks the cached login id against `self.role`.
    ///
    /// Caches [`FORBIDDEN_ROLE_MARKER`] when no login id is cached or the role
    /// is not held.
    ///
    /// NOTE(review): relies on an earlier fairing (e.g. `SaTokenFairing`) having
    /// cached the login id as `Option<String>`. If it is attached later, or the
    /// login id is not a `String`, every request is marked forbidden (403, not 401).
    /// The marker slot is shared with the other SaCheck* fairings, and the first
    /// writer wins. The route handler still runs, because Rocket request fairings
    /// cannot abort.
    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
        // Cloned out of the cache so no borrow of `request` is held across `.await`.
        let login_id = request.local_cache(|| None::<String>).clone();
        let permitted = match login_id {
            Some(ref id) => sa_token_core::StpUtil::has_role(id, &self.role).await,
            None => false,
        };
        if !permitted {
            request.local_cache(|| Some(FORBIDDEN_ROLE_MARKER));
        }
    }

    /// Replaces the response with a `403` JSON error when `on_request` marked
    /// the request as lacking the role.
    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
        // One lookup is enough; the `Some(..)` comparison already rules out `None`.
        if *request.local_cache(|| None::<&'static str>) != Some(FORBIDDEN_ROLE_MARKER) {
            return;
        }

        response.set_status(Status::Forbidden);
        response.set_header(ContentType::JSON);
        let body = json!({
            "code": 403,
            "message": messages::ROLE_REQUIRED
        })
        .to_string();
        response.set_sized_body(None, std::io::Cursor::new(body));
    }
}