use axum::body::Body;
use axum::extract::{Request, State};
use axum::http::{StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use kanade_shared::secrets;
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use std::env;
use std::sync::OnceLock;
use tracing::{error, warn};
/// Access level of an authenticated caller.
///
/// Variants are listed from least to most privileged. The derived `Ord` follows
/// that order, and [`Role::allows`] relies on it. The explicit discriminants only
/// spell out that order; serde still uses the lowercase variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Lowest privilege.
    Viewer = 0,
    /// Middle privilege.
    Operator = 1,
    /// Highest privilege.
    Admin = 2,
}
impl Role {
pub fn as_str(self) -> &'static str {
match self {
Role::Viewer => "viewer",
Role::Operator => "operator",
Role::Admin => "admin",
}
}
pub fn parse(s: &str) -> Option<Role> {
match s {
"viewer" => Some(Role::Viewer),
"operator" => Some(Role::Operator),
"admin" => Some(Role::Admin),
_ => None,
}
}
pub fn allows(self, required: Role) -> bool {
self >= required
}
}
/// JWT payload used by this backend, also attached to request extensions after auth.
///
/// Field order is significant for serialization and is kept stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject. `verify` looks this up as `users.username`.
    pub sub: String,
    /// Expiry as a Unix timestamp in seconds.
    pub exp: i64,
    /// Audience. `verify` requires it to match `EXPECTED_AUDIENCE`.
    #[serde(default)]
    pub aud: Option<String>,
    /// Role names. Unrecognised names are ignored by [`Claims::role`].
    #[serde(default)]
    pub roles: Vec<String>,
}
impl Claims {
    /// Most privileged recognised role in `roles`.
    ///
    /// Unknown role names are skipped. If nothing is recognised (including an
    /// empty list), the result is [`Role::Viewer`].
    pub fn role(&self) -> Role {
        let highest = self
            .roles
            .iter()
            .map(String::as_str)
            .filter_map(Role::parse)
            .max();
        highest.unwrap_or(Role::Viewer)
    }

    /// Synthetic admin identity for non-user callers (e.g. the static service token).
    fn service(sub: &str) -> Self {
        // 2100-01-01T00:00:00Z, i.e. effectively non-expiring.
        const FAR_FUTURE_EXP: i64 = 4_102_444_800;
        Self {
            sub: String::from(sub),
            exp: FAR_FUTURE_EXP,
            aud: Some(String::from(EXPECTED_AUDIENCE)),
            roles: vec![String::from(Role::Admin.as_str())],
        }
    }
}
// --- Windows registry (HKLM) locations, consulted before environment variables ---

/// Registry key, read via `secrets::read_hklm_value`, that holds the backend secrets.
const REG_SUBKEY: &'static str = "SOFTWARE\\kanade\\backend";
/// Registry value name holding the static service bearer token.
const REG_STATIC_TOKEN: &'static str = "StaticToken";
/// Registry value name holding the JWT signing secret.
const REG_JWT_SECRET: &'static str = "JwtSecret";

// --- Environment variables ---

/// Environment variable that `verify` consults to bypass authentication.
const ENV_DISABLE: &'static str = "KANADE_AUTH_DISABLE";
/// Environment fallback for the static service bearer token.
const ENV_STATIC_TOKEN: &'static str = "KANADE_AUTH_STATIC_TOKEN";
/// Environment fallback for the JWT signing secret.
const ENV_SECRET: &'static str = "KANADE_JWT_SECRET";

/// `aud` claim value that every accepted JWT must carry.
pub const EXPECTED_AUDIENCE: &'static str = "kanade";
/// Reads a secret from registry value `reg_value` under [`REG_SUBKEY`].
/// If that is unavailable, it falls back to the environment variable `env_key`.
///
/// An empty string from either source counts as "not configured". Without this,
/// an empty registry value would shadow a valid environment variable. For the JWT
/// secret it would also be used verbatim as the HS256 key.
fn read_secret(reg_value: &'static str, env_key: &'static str) -> Option<String> {
    secrets::read_hklm_value(REG_SUBKEY, reg_value)
        .filter(|v| !v.is_empty())
        .or_else(|| env::var(env_key).ok().filter(|v| !v.is_empty()))
}

/// Static service bearer token, if one is configured (registry first, then env).
///
/// Resolved on first call and cached, so later registry/env changes are not observed.
fn resolve_static_token() -> Option<&'static str> {
    static CACHE: OnceLock<Option<String>> = OnceLock::new();
    CACHE
        .get_or_init(|| read_secret(REG_STATIC_TOKEN, ENV_STATIC_TOKEN))
        .as_deref()
}

/// JWT signing secret, if one is configured (registry first, then env).
///
/// Not cached here; `signing_secret` caches the final value.
fn resolve_jwt_secret() -> Option<String> {
    read_secret(REG_JWT_SECRET, ENV_SECRET)
}
/// HS256 key material used to sign and verify JWTs.
///
/// The value is resolved once and cached for the rest of the process. If no secret
/// is configured, a hard-coded development secret is used and a warning is logged.
pub fn signing_secret() -> &'static str {
    static SECRET: OnceLock<String> = OnceLock::new();
    let secret = SECRET.get_or_init(|| match resolve_jwt_secret() {
        Some(configured) => configured,
        None => {
            warn!(
                "no JwtSecret registry value and no $KANADE_JWT_SECRET — using a hard-coded dev fallback (NEVER in production)"
            );
            String::from("dev-secret-please-override")
        }
    });
    secret.as_str()
}
async fn lookup_user(
pool: &SqlitePool,
username: &str,
) -> Result<Option<(Role, bool)>, sqlx::Error> {
let row =
sqlx::query_as::<_, (String, i64)>("SELECT role, disabled FROM users WHERE username = ?")
.bind(username)
.fetch_optional(pool)
.await?;
Ok(row.and_then(|(role, disabled)| Role::parse(&role).map(|r| (r, disabled != 0))))
}
/// Axum middleware that authenticates requests under `/api/`.
///
/// Steps, in order:
/// 1. If `KANADE_AUTH_DISABLE` is set to an "on" value (see `env_flag_enabled`),
///    every request gets an admin service identity (`auth-disabled`) and passes.
/// 2. Paths outside `/api/`, and `/api/auth/login`, pass through without [`Claims`].
/// 3. A bearer token equal to the configured static token gets an admin service
///    identity (`service-token`).
/// 4. Otherwise the token must be an HS256 JWT signed with [`signing_secret`] whose
///    audience is [`EXPECTED_AUDIENCE`]. Its `sub` must name an enabled row in
///    `users`, and that row's role replaces whatever roles the token carried.
///
/// On success the resolved [`Claims`] are inserted into the request extensions.
///
/// # Errors
/// Returns a `401 Unauthorized` response in these cases:
/// - the bearer token is missing or invalid;
/// - the account is unknown or disabled;
/// - the user lookup fails.
pub async fn verify(
    State(pool): State<SqlitePool>,
    mut req: Request,
    next: Next,
) -> Result<Response, Response> {
    if auth_disabled() {
        req.extensions_mut()
            .insert(Claims::service("auth-disabled"));
        return Ok(next.run(req).await);
    }

    let path = req.uri().path();
    if !path.starts_with("/api/") || path == "/api/auth/login" {
        return Ok(next.run(req).await);
    }

    let token = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_credential);
    let Some(token) = token else {
        return Err(unauth("missing bearer token"));
    };

    // The static service token is compared in constant time to avoid a timing oracle.
    if let Some(expected) = resolve_static_token()
        && constant_time_eq(token.as_bytes(), expected.as_bytes())
    {
        req.extensions_mut()
            .insert(Claims::service("service-token"));
        return Ok(next.run(req).await);
    }

    let key = DecodingKey::from_secret(signing_secret().as_bytes());
    let mut validation = Validation::new(Algorithm::HS256);
    validation.set_audience(&[EXPECTED_AUDIENCE]);
    let mut claims = match decode::<Claims>(token, &key, &validation) {
        Ok(data) => data.claims,
        Err(e) => {
            warn!(error = %e, path, "JWT verify failed");
            return Err(unauth(&format!("invalid token: {e}")));
        }
    };

    match lookup_user(&pool, &claims.sub).await {
        Ok(Some((_, true))) => Err(unauth("account disabled")),
        Ok(Some((role, false))) => {
            // The database is authoritative: discard any roles claimed in the token.
            claims.roles = vec![role.as_str().to_string()];
            req.extensions_mut().insert(claims);
            Ok(next.run(req).await)
        }
        Ok(None) => Err(unauth("unknown account")),
        Err(e) => {
            error!(error = %e, sub = %claims.sub, "user lookup failed");
            Err(unauth("auth backend unavailable"))
        }
    }
}

/// Whether `KANADE_AUTH_DISABLE` switches authentication off.
///
/// A value such as `0` or `false` must not silently disable auth and grant admin.
/// Only values that `env_flag_enabled` treats as "on" count. An unset or non-UTF-8
/// variable never disables auth.
fn auth_disabled() -> bool {
    env::var(ENV_DISABLE).is_ok_and(|v| env_flag_enabled(&v))
}

/// Interprets an environment flag value.
///
/// The empty string, `0`, `false`, `no` and `off` mean "off". Matching is ASCII
/// case-insensitive and ignores surrounding whitespace. Any other value means "on".
fn env_flag_enabled(value: &str) -> bool {
    let value = value.trim();
    !value.is_empty()
        && !["0", "false", "no", "off"]
            .iter()
            .any(|off| value.eq_ignore_ascii_case(off))
}

/// Extracts the credential from an `Authorization: Bearer <token>` header value.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1), and the token is
/// trimmed. Returns `None` for other schemes or an empty token.
fn bearer_credential(header_value: &str) -> Option<&str> {
    let (scheme, rest) = header_value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim();
    (!token.is_empty()).then_some(token)
}
/// Checks that the authenticated caller holds at least `required`.
///
/// Returns `None` when the request may proceed. Returns a `403 Forbidden` response
/// when no [`Claims`] are attached to the request or the caller's role is too low.
fn gate(req: &Request, required: Role) -> Option<Response> {
    // Borrow the claims in place; there is no need to clone them per request.
    let Some(claims) = req.extensions().get::<Claims>() else {
        return Some(forbidden("no authenticated identity"));
    };
    // Resolve the effective role once; it serves both the check and the message.
    let role = claims.role();
    if role.allows(required) {
        return None;
    }
    Some(forbidden(&format!(
        "{} role required (caller is {})",
        required.as_str(),
        role.as_str()
    )))
}
/// Route middleware that only admits callers with at least [`Role::Operator`].
///
/// Relies on [`Claims`] already being in the request extensions (e.g. inserted by
/// `verify`). If they are missing, the request is rejected with 403.
pub async fn require_operator(req: Request, next: Next) -> Result<Response, Response> {
    let verdict = gate(&req, Role::Operator);
    match verdict {
        Some(rejection) => Err(rejection),
        None => Ok(next.run(req).await),
    }
}
/// Route middleware that only admits callers with [`Role::Admin`].
///
/// Relies on [`Claims`] already being in the request extensions (e.g. inserted by
/// `verify`). If they are missing, the request is rejected with 403.
pub async fn require_admin(req: Request, next: Next) -> Result<Response, Response> {
    let verdict = gate(&req, Role::Admin);
    match verdict {
        Some(rejection) => Err(rejection),
        None => Ok(next.run(req).await),
    }
}
/// Compares two byte strings without stopping at the first differing byte.
///
/// Lengths are compared first and a mismatch returns `false` immediately, so the
/// length itself is not hidden. For equal lengths, every byte pair is folded into
/// an accumulator before the result is decided.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let difference = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    difference == 0
}
fn unauth(msg: &str) -> Response {
(StatusCode::UNAUTHORIZED, Body::from(msg.to_owned())).into_response()
}
fn forbidden(msg: &str) -> Response {
(StatusCode::FORBIDDEN, Body::from(msg.to_owned())).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds throwaway claims where only `roles` matters.
    fn claims_with_roles(roles: &[&str]) -> Claims {
        Claims {
            sub: "x".into(),
            exp: 0,
            aud: None,
            roles: roles.iter().map(|r| (*r).to_string()).collect(),
        }
    }

    #[test]
    fn role_hierarchy() {
        let cases = [
            (Role::Admin, Role::Operator, true),
            (Role::Admin, Role::Viewer, true),
            (Role::Operator, Role::Viewer, true),
            (Role::Operator, Role::Admin, false),
            (Role::Viewer, Role::Operator, false),
            (Role::Viewer, Role::Viewer, true),
        ];
        for (held, required, expected) in cases {
            assert_eq!(
                held.allows(required),
                expected,
                "{held:?}.allows({required:?})"
            );
        }
    }

    #[test]
    fn role_roundtrip() {
        let every_role = [Role::Viewer, Role::Operator, Role::Admin];
        assert!(every_role
            .iter()
            .all(|&r| Role::parse(r.as_str()) == Some(r)));
        assert!(Role::parse("root").is_none());
    }

    #[test]
    fn claims_role_picks_highest() {
        assert_eq!(claims_with_roles(&["viewer", "admin"]).role(), Role::Admin);
        assert_eq!(claims_with_roles(&[]).role(), Role::Viewer);
    }
}