#![deny(warnings)]
use std::{sync::Arc, time::Duration};
use axum::{
Router, middleware,
routing::{get, post},
};
use axum_bootstrap::{
TlsParam,
jwt::{ClaimsPayload, JwtConfig, jwt_auth_middleware},
};
use clap::Parser;
use hyper::StatusCode;
use tower_http::services::ServeDir;
use crate::handler::{get_current_user, login_handler, logout_handler};
type DynError = Box<dyn std::error::Error + Send + Sync>;
// Command-line options for the demo server, parsed with clap's derive API.
// NOTE: plain `//` comments are used here on purpose; clap's derive turns `///` doc
// comments into `--help` text, so adding them would change the program's output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// TCP port handed to `axum_bootstrap::new_server` (`-p` / `--port`, default 8080).
#[arg(short, long, default_value = "8080")]
port: u16,
// The single account name accepted by the login endpoint (`-u` / `--username`, required).
#[arg(short, long)]
username: String,
// Plain-text password for that account (`--password`, required); `main` hashes it with
// bcrypt at startup and only the hash is placed in `AppState`.
#[arg(long)]
password: String,
// Secret passed to `JwtConfig::new`. The default is a placeholder whose own text says it
// must be changed for production use.
#[arg(long, default_value = "your-secret-key-change-in-production")]
jwt_secret: String,
// Optional TLS certificate, forwarded as `TlsParam::cert`. TLS is configured only when
// both `--cert` and `--key` are supplied.
#[arg(long)]
cert: Option<String>,
// Optional TLS private key, forwarded as `TlsParam::key`; see `--cert`.
#[arg(long)]
key: Option<String>,
}
/// Process-wide command-line arguments; `Args::parse` runs lazily on first access.
pub(crate) static PARAM: std::sync::LazyLock<Args> = std::sync::LazyLock::new(Args::parse);
/// Crate name captured at compile time from Cargo's `CARGO_CRATE_NAME` environment
/// variable; `main` passes it to the logging initializer.
const CARGO_CRATE_NAME: &str = env!("CARGO_CRATE_NAME");
/// State shared with the request handlers (wrapped in an `Arc` by `main`).
#[derive(Clone)]
pub struct AppState {
    /// JWT configuration used when turning freshly issued claims into a cookie.
    pub jwt_config: JwtConfig,
    /// The one account name the login endpoint accepts.
    pub username: String,
    /// bcrypt hash of the configured password, computed once at startup.
    pub password_hash: String,
}
/// Entry point: initializes logging, builds the router and runs the server until shutdown.
///
/// Routes:
/// - `POST /api/login`, `POST /api/logout` and `GET /health` are public.
/// - `GET /api/me` sits behind `jwt_auth_middleware`.
/// - Anything else is served from the `static` directory.
///
/// # Errors
/// Returns an error when logging cannot be initialized, when only one of `--cert` / `--key`
/// is supplied, when hashing the configured password fails, or when the server itself fails.
#[tokio::main]
pub async fn main() -> Result<(), DynError> {
    axum_bootstrap::init_log::tracing::init(CARGO_CRATE_NAME)?;

    // TLS needs both files. A half-specified pair used to fall through to plain HTTP
    // without any notice, so reject it up front instead of silently downgrading.
    let tls_param = match (PARAM.cert.as_ref(), PARAM.key.as_ref()) {
        (Some(cert), Some(key)) => Some(TlsParam {
            tls: true,
            cert: cert.clone(),
            key: key.clone(),
        }),
        (None, None) => None,
        _ => return Err("--cert and --key must be provided together to enable TLS".into()),
    };

    // Only the bcrypt hash of the configured password is kept in the shared state.
    let password_hash = bcrypt::hash(&PARAM.password, bcrypt::DEFAULT_COST)?;
    let jwt_config = JwtConfig::new(&PARAM.jwt_secret);
    let state = Arc::new(AppState {
        jwt_config: jwt_config.clone(),
        username: PARAM.username.clone(),
        password_hash,
    });

    // Routes that require a valid JWT; this is the last use of `jwt_config`, so it is moved.
    let protected_routes = Router::new()
        .route("/api/me", get(get_current_user))
        .layer(middleware::from_fn_with_state(Arc::new(jwt_config), jwt_auth_middleware::<ClaimsPayload>));

    let app = Router::new()
        .route("/api/login", post(login_handler))
        .route("/api/logout", post(logout_handler))
        .route("/health", get(|| async { (StatusCode::OK, "OK") }))
        .merge(protected_routes)
        .fallback_service(ServeDir::new("static"))
        .layer((
            // One tracing span per request carrying method and path; the default
            // on-failure callback is replaced with a no-op.
            tower_http::trace::TraceLayer::new_for_http()
                .make_span_with(|req: &axum::extract::Request| {
                    let method = req.method();
                    let path = req.uri().path();
                    tracing::info_span!("request", %method, %path)
                })
                .on_failure(()),
            tower_http::cors::CorsLayer::permissive(),
            tower_http::timeout::TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(30)),
            tower_http::compression::CompressionLayer::new()
                .gzip(true)
                .br(true)
                .deflate(true)
                .zstd(true),
        ))
        .with_state(state);

    let server = axum_bootstrap::new_server(PARAM.port, app, axum_bootstrap::generate_shutdown_receiver());
    let server = server
        .with_timeout(Duration::from_secs(120))
        .with_tls_param(tls_param);
    server.run().await?;
    Ok(())
}
mod handler {
use std::sync::Arc;
use axum::{Json, extract::State};
use axum_bootstrap::jwt::{Claims, ClaimsPayload, LOGOUT_COOKIE};
use axum_extra::extract::CookieJar;
use axum_macros::debug_handler;
use hyper::StatusCode;
use log::error;
use serde::{Deserialize, Serialize};
use crate::AppState;
#[derive(Serialize)]
pub struct UserInfo {
username: String,
}
#[debug_handler]
pub async fn get_current_user(Claims { payload, .. }: Claims) -> Result<Json<UserInfo>, StatusCode> {
Ok(Json(UserInfo { username: payload.username }))
}
#[derive(Deserialize, Debug)]
pub struct LoginRequest {
username: String,
password: String,
}
#[derive(Serialize)]
pub struct LoginResponse {
success: bool,
message: String,
}
pub async fn login_handler(
State(state): State<Arc<AppState>>, Json(req): Json<LoginRequest>,
) -> Result<(StatusCode, CookieJar, Json<LoginResponse>), StatusCode> {
if req.username != state.username {
return Ok((
StatusCode::UNAUTHORIZED,
CookieJar::new(),
Json(LoginResponse {
success: false,
message: "用户名或密码错误".to_string(),
}),
));
}
let password_valid = bcrypt::verify(&req.password, &state.password_hash).map_err(|e| {
error!("密码验证失败: {:?}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
if !password_valid {
return Ok((
StatusCode::UNAUTHORIZED,
CookieJar::new(),
Json(LoginResponse {
success: false,
message: "用户名或密码错误".to_string(),
}),
));
}
let cookie = Claims::new(ClaimsPayload { username: req.username })
.to_cookie(&state.jwt_config)
.map_err(|e| {
error!("生成JWT token失败: {:?}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let jar = CookieJar::new().add(cookie);
Ok((
StatusCode::OK,
jar,
Json(LoginResponse {
success: true,
message: "登录成功".to_string(),
}),
))
}
pub async fn logout_handler() -> Result<(StatusCode, CookieJar, Json<LoginResponse>), StatusCode> {
let logout_cookie = LOGOUT_COOKIE;
let jar = CookieJar::new().add(logout_cookie.clone());
Ok((
StatusCode::OK,
jar,
Json(LoginResponse {
success: true,
message: "已退出登录".to_string(),
}),
))
}
}