use crate::error::Result;
use crate::res;
use axum::RequestExt;
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::Response;
use axum_extra::TypedHeader;
use derive_more::{Deref, From, Into};
use futures::TryFutureExt;
use headers::Authorization;
use headers::authorization::Bearer;
use jiff::{SignedDuration, Zoned};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
use nil_core::player::PlayerId;
use nil_core::ruler::Ruler;
use nil_server_database::sql_types::player_id::db_PlayerId as DbPlayerId;
use nil_server_types::auth::Token;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::LazyLock;
use tokio::task::spawn_blocking;
static JWT_SECRET: LazyLock<Box<str>> = LazyLock::new(|| {
env::var("NIL_JWT_SECRET")
.map(String::into_boxed_str)
.unwrap_or_else(|_| Box::from("CALL-OF-NIL"))
});
/// Axum middleware that authenticates a request via its `Authorization: Bearer <token>` header.
///
/// Steps:
/// 1. Extract the bearer token from the request headers.
/// 2. Decode and validate it with [`decode_jwt`].
/// 3. Store the token's `sub` claim in the request extensions as a [`CurrentPlayer`].
/// 4. Forward the request to `next`.
///
/// If the header cannot be extracted, or the token fails to decode, a warning is logged and
/// the chain is short-circuited with `UNAUTHORIZED`.
pub async fn authorization(mut request: Request, next: Next) -> Response {
    let extracted = request
        .extract_parts::<TypedHeader<Authorization<Bearer>>>()
        .await;

    let token = match extracted {
        Ok(header) => Token::new(header.token()),
        Err(_) => {
            tracing::warn!("Missing authorization header");
            return res!(UNAUTHORIZED);
        }
    };

    let claims = match decode_jwt(token).await {
        Ok(token_data) => token_data.claims,
        Err(err) => {
            tracing::warn!("Failed to decode token: {err}");
            return res!(UNAUTHORIZED);
        }
    };

    // Downstream handlers read the authenticated player from the request extensions.
    request.extensions_mut().insert(CurrentPlayer(claims.sub));
    next.run(request).await
}
/// JWT payload issued by `encode_jwt` and read back by `decode_jwt`.
///
/// Field order and names determine the serialized token payload, so keep them stable.
#[derive(Serialize, Deserialize)]
pub(crate) struct Claims {
/// Subject: the id of the authenticated player.
pub sub: PlayerId,
/// Expiry time in Unix seconds (set by `encode_jwt` to one week after `iat`).
pub exp: usize,
/// Issued-at time in Unix seconds.
pub iat: usize,
}
/// Issues a signed session token for `player`.
///
/// The token's claims are:
/// - `sub`: the player id.
/// - `iat`: the current time.
/// - `exp`: one week (24 * 7 hours) later, using saturating time arithmetic.
///
/// Both timestamps are Unix seconds. The token uses the default [`Header`] and is signed
/// with [`JWT_SECRET`]. The work runs on tokio's blocking thread pool via `spawn_blocking`.
///
/// # Errors
///
/// Fails if either timestamp does not fit in `usize`, if encoding fails, or if the
/// blocking task cannot be joined.
pub(crate) async fn encode_jwt(player: PlayerId) -> Result<Token> {
    let handle = spawn_blocking(move || -> anyhow::Result<Token> {
        let issued_at = Zoned::now();
        let expires_at = issued_at.saturating_add(SignedDuration::from_hours(24 * 7));

        // Field initializers run in source order, so `iat` is converted before `exp`.
        let claims = Claims {
            sub: player,
            iat: issued_at.timestamp().as_second().try_into()?,
            exp: expires_at.timestamp().as_second().try_into()?,
        };

        let key = EncodingKey::from_secret(JWT_SECRET.as_bytes());
        let signed = encode(&Header::default(), &claims, &key)?;
        Ok(Token::new(signed))
    });

    // The outer `?` handles a join failure; the inner one handles the encoding result.
    let token = handle.await??;
    Ok(token)
}
pub(crate) async fn decode_jwt(token: Token) -> Result<TokenData<Claims>> {
let claims = spawn_blocking(move || {
decode(
&token,
&DecodingKey::from_secret(JWT_SECRET.as_bytes()),
&Validation::default(),
)
.map_err(Into::<anyhow::Error>::into)
})
.await??;
Ok(claims)
}
/// The player authenticated for the current request.
///
/// Inserted into the request extensions by the `authorization` middleware, using the `sub`
/// claim of a successfully decoded token. Derefs to, and converts from/into, the wrapped
/// [`PlayerId`].
#[derive(Clone, Debug, Deref, From, Into, PartialEq, Eq)]
pub struct CurrentPlayer(pub(crate) PlayerId);
impl From<CurrentPlayer> for Ruler {
fn from(player: CurrentPlayer) -> Self {
Ruler::Player { id: player.0 }
}
}
impl From<CurrentPlayer> for DbPlayerId {
fn from(player: CurrentPlayer) -> Self {
DbPlayerId::from(player.0)
}
}
impl PartialEq<PlayerId> for CurrentPlayer {
fn eq(&self, other: &PlayerId) -> bool {
self.0.eq(other)
}
}