pub mod middleware;
pub mod upload;
use crate::api::Query;
use crate::request::{ApiClient, ApiResponse};
use axum::extract::State;
use axum::http::header;
use axum::http::HeaderMap;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Runtime configuration for the HTTP server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address of the interface to bind.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// Origin handed to `middleware::cors_layer`; `None` when not configured.
    pub cors_origin: Option<String>,
    /// Maximum number of requests per window; `0` disables rate limiting.
    pub rate_limit: u64,
    /// Window length passed to `middleware::RateLimiter::new`
    /// (NOTE(review): presumably seconds — confirm against the limiter).
    pub rate_limit_window: u64,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port: 3000,
            cors_origin: None,
            rate_limit: 0,
            rate_limit_window: 60,
        }
    }
}

impl ServerConfig {
    /// Builds a configuration from environment variables.
    ///
    /// Reads `NCM_HOST`, `NCM_PORT`, `CORS_ALLOW_ORIGIN`, `RATE_LIMIT` and
    /// `RATE_LIMIT_WINDOW`. Any variable that is missing, not valid Unicode,
    /// or (for numeric settings) does not parse falls back to the value from
    /// [`ServerConfig::default`], so the defaults live in exactly one place.
    pub fn from_env() -> Self {
        let defaults = Self::default();
        Self {
            host: std::env::var("NCM_HOST").unwrap_or(defaults.host),
            port: parse_env_or("NCM_PORT", defaults.port),
            cors_origin: std::env::var("CORS_ALLOW_ORIGIN")
                .ok()
                .or(defaults.cors_origin),
            rate_limit: parse_env_or("RATE_LIMIT", defaults.rate_limit),
            rate_limit_window: parse_env_or("RATE_LIMIT_WINDOW", defaults.rate_limit_window),
        }
    }
}

/// Reads environment variable `key` and parses it, returning `default` when it
/// is unset, not Unicode, or unparsable.
fn parse_env_or<T: std::str::FromStr>(key: &str, default: T) -> T {
    parse_or(std::env::var(key).ok().as_deref(), default)
}

/// Parses `raw` after trimming surrounding whitespace, falling back to
/// `default` when `raw` is `None` or does not parse as `T`.
fn parse_or<T: std::str::FromStr>(raw: Option<&str>, default: T) -> T {
    raw.and_then(|s| s.trim().parse().ok()).unwrap_or(default)
}
/// State shared by every axum handler.
///
/// Cloning only clones the [`Arc`], so all clones share one client.
#[derive(Clone)]
pub struct AppState {
    /// Upstream API client used to serve every request.
    pub client: Arc<ApiClient>,
}
/// Builds a [`Query`] from every input channel of an incoming request.
///
/// Parameters are merged in this order, with later sources overwriting
/// earlier ones on key collisions:
/// 1. the URL query string (`application/x-www-form-urlencoded`);
/// 2. the request body — parsed as a JSON object when `content_type`
///    mentions `application/json` (case-insensitively), otherwise as
///    urlencoded form data. JSON non-string values are stored via their
///    JSON text (e.g. `123`, `true`, `null`).
///
/// Inputs that fail to parse are ignored rather than rejected.
///
/// The reserved keys `cookie`, `realIP` and `proxy` are moved out of the
/// parameter map into the dedicated `Query` fields. A `cookie` parameter
/// takes priority over the `Cookie` request header(s).
async fn extract_merged_query(
    headers: &HeaderMap,
    uri_query: Option<&str>,
    body: axum::body::Bytes,
    content_type: Option<&str>,
) -> Query {
    let mut query = Query::new();

    if let Some(qs) = uri_query {
        if let Ok(params) = serde_urlencoded::from_str::<HashMap<String, String>>(qs) {
            for (k, v) in params {
                query.params.insert(k, v);
            }
        }
    }

    if !body.is_empty() {
        // MIME type/subtype tokens are case-insensitive, so normalise before matching.
        let is_json = content_type
            .map(|ct| ct.to_ascii_lowercase().contains("application/json"))
            .unwrap_or(false);
        if is_json {
            if let Ok(map) = serde_json::from_slice::<HashMap<String, Value>>(&body) {
                for (k, v) in map {
                    let s = match &v {
                        Value::String(s) => s.clone(),
                        _ => v.to_string(),
                    };
                    query.params.insert(k, s);
                }
            }
        } else if let Ok(params) = serde_urlencoded::from_bytes::<HashMap<String, String>>(&body)
        {
            for (k, v) in params {
                query.params.insert(k, v);
            }
        }
    }

    if let Some(cookie_param) = query.params.remove("cookie") {
        query.cookie = Some(cookie_param);
    } else {
        // HTTP/2 clients may split the Cookie header into several fields
        // (RFC 9113 §8.2.3); rejoin them with "; " instead of keeping only the first.
        let parts: Vec<&str> = headers
            .get_all(header::COOKIE)
            .iter()
            .filter_map(|v| v.to_str().ok())
            .collect();
        if !parts.is_empty() {
            query.cookie = Some(parts.join("; "));
        }
    }

    if let Some(real_ip) = query.params.remove("realIP") {
        query.real_ip = Some(real_ip);
    }
    if let Some(proxy) = query.params.remove("proxy") {
        query.proxy = Some(proxy);
    }
    query
}
fn build_success_response(api_resp: ApiResponse) -> Response {
let status = axum::http::StatusCode::from_u16(api_resp.status as u16)
.unwrap_or(axum::http::StatusCode::OK);
let mut response = (status, Json(api_resp.body)).into_response();
for cookie_str in &api_resp.cookie {
if let Ok(val) = header::HeaderValue::from_str(cookie_str) {
response.headers_mut().append(header::SET_COOKIE, val);
}
}
response
}
/// Converts an [`NcmError`](crate::error::NcmError) into a JSON error response.
///
/// The body is always `{ "code": ..., "msg": ... }`. The HTTP status and body
/// code are chosen per variant; for `NcmError::Api` the upstream code doubles
/// as the HTTP status when it is a valid status code, otherwise `500`.
fn build_error_response(err: crate::error::NcmError) -> Response {
    use crate::error::NcmError;
    use axum::http::StatusCode;
    // Local import so `try_from` resolves on pre-2021 editions too.
    #[allow(unused_imports)]
    use std::convert::TryFrom;

    let (status, body) = match &err {
        NcmError::AuthRequired(msg) => (
            StatusCode::UNAUTHORIZED,
            json!({ "code": 301, "msg": msg }),
        ),
        NcmError::InvalidParam(msg) => (
            StatusCode::BAD_REQUEST,
            json!({ "code": 400, "msg": msg }),
        ),
        // NOTE(review): body code 503 differs from the 429 HTTP status —
        // presumably mirrors upstream conventions; confirm it is intentional.
        NcmError::RateLimited(msg) => (
            StatusCode::TOO_MANY_REQUESTS,
            json!({ "code": 503, "msg": msg }),
        ),
        NcmError::Timeout(msg) => (
            StatusCode::GATEWAY_TIMEOUT,
            json!({ "code": 504, "msg": msg }),
        ),
        NcmError::Api { code, msg } => {
            // Upstream codes can be negative or out of range; a plain `as u16`
            // cast could wrap them into an unrelated (even 2xx) status.
            let http_status = u16::try_from(*code)
                .ok()
                .and_then(|c| StatusCode::from_u16(c).ok())
                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
            (http_status, json!({ "code": code, "msg": msg }))
        }
        _ => (
            StatusCode::INTERNAL_SERVER_ERROR,
            json!({ "code": 500, "msg": err.to_string() }),
        ),
    };
    (status, Json(body)).into_response()
}
async fn handle_api_request<F>(
state: &AppState,
headers: HeaderMap,
uri: &axum::http::Uri,
body: axum::body::Bytes,
api_fn: F,
) -> Response
where
F: for<'a> FnOnce(
&'a ApiClient,
&'a Query,
) -> Pin<
Box<dyn Future<Output = crate::error::Result<ApiResponse>> + Send + 'a>,
>,
{
let path = uri.path().to_string();
let start = std::time::Instant::now();
let content_type = headers
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok());
let query = extract_merged_query(&headers, uri.query(), body, content_type).await;
match api_fn(&state.client, &query).await {
Ok(resp) => {
tracing::info!("{} -> {} ({:.1?})", path, resp.status, start.elapsed());
build_success_response(resp)
}
Err(e) => {
tracing::warn!("{} -> ERROR: {} ({:.1?})", path, e, start.elapsed());
build_error_response(e)
}
}
}
// Registers one route per `method => "/path"` pair on `$router` and evaluates
// to the resulting router.
//
// Every route answers both GET and POST:
// - GET passes an empty body, so only the URL query string and headers feed
//   the merged `Query`;
// - POST forwards the raw request body as well.
// Both delegate to `handle_api_request`, which calls `client.$method(q)` with
// the merged query. `$method` must therefore name an `ApiClient` method that
// takes the query and returns a future of `crate::error::Result<ApiResponse>`.
//
// NOTE(review): invocations presumably live in the `OUT_DIR` file included by
// `register_routes`; they are not visible from here.
macro_rules! api_routes {
($router:expr, $( $method:ident => $route:expr ),* $(,)?) => {{
let router = $router;
$(
// Shadow `router` once per entry so each route is chained onto the previous one.
let router = router.route(
$route,
get(|State(state): State<AppState>, headers: HeaderMap, uri: axum::http::Uri| async move {
handle_api_request(&state, headers, &uri, axum::body::Bytes::new(), |client, q| Box::pin(client.$method(q))).await
})
.post(|State(state): State<AppState>, headers: HeaderMap, uri: axum::http::Uri, body: axum::body::Bytes| async move {
handle_api_request(&state, headers, &uri, body, |client, q| Box::pin(client.$method(q))).await
}),
);
)*
router
}};
}
/// Adds the complete API route table to `router`.
///
/// Most routes come from a file generated into `OUT_DIR` (presumably by a
/// build script); the two upload endpoints are registered by hand afterwards.
fn register_routes(router: Router<AppState>) -> Router<AppState> {
    // NOTE(review): the generated file presumably refers to the `router`
    // binding by name, so the parameter must keep that exact name.
    let generated = { include!(concat!(env!("OUT_DIR"), "/api_routes_generated.rs")) };
    generated
        .route("/avatar/upload", post(upload::handle_avatar_upload))
        .route("/voice/upload", post(upload::handle_voice_upload))
}
pub fn build_app(client: ApiClient) -> Router {
let state = AppState {
client: Arc::new(client),
};
let router = Router::new();
let router = register_routes(router);
let router = router.route(
"/",
get(|| async {
Json(json!({
"code": 200,
"msg": "NCM API Rust Server is running",
}))
}),
);
router.layer(middleware::cors_layer(None)).with_state(state)
}
/// Builds the application router from `config`.
///
/// Routes: everything from `register_routes` plus a `GET /` health message.
///
/// Layers, outermost first:
/// 1. CORS via `middleware::cors_layer(config.cors_origin)`;
/// 2. the rate limiter, only when `config.rate_limit > 0`.
///
/// CORS is deliberately outermost. Responses generated by the rate limiter
/// (e.g. rejections) then still pass through CORS and carry its headers.
/// Otherwise browsers would see an opaque CORS failure instead of the real
/// status. If `cors_layer` answers preflight requests itself, those also no
/// longer count against the limit.
pub fn build_app_with_config(client: ApiClient, config: &ServerConfig) -> Router {
    let state = AppState {
        client: Arc::new(client),
    };
    let router = register_routes(Router::new()).route(
        "/",
        get(|| async {
            Json(json!({
                "code": 200,
                "msg": "NCM API Rust Server is running",
            }))
        }),
    );
    // In axum, each `.layer` wraps the previous stack, so the limiter is added
    // first (inner) and CORS last (outer).
    let router = if config.rate_limit > 0 {
        let limiter = middleware::RateLimiter::new(config.rate_limit, config.rate_limit_window);
        router.layer(axum::middleware::from_fn_with_state(
            limiter,
            middleware::rate_limit_middleware,
        ))
    } else {
        router
    };
    router
        .layer(middleware::cors_layer(config.cors_origin.as_deref()))
        .with_state(state)
}
/// Binds to `config.host:config.port` and serves the application until the
/// server stops.
///
/// # Panics
///
/// Panics if the address cannot be bound or if `axum::serve` returns an error.
pub async fn start_server(config: ServerConfig) {
    let client = ApiClient::new(None);
    let app = build_app_with_config(client, &config);
    // Bind via a (host, port) tuple so bare IPv6 literals such as "::" work
    // without the caller having to add brackets.
    let listener = tokio::net::TcpListener::bind((config.host.as_str(), config.port))
        .await
        .unwrap_or_else(|e| {
            panic!(
                "Failed to bind address {}:{}: {}",
                config.host, config.port, e
            )
        });
    // Log the address actually bound (correct for port 0 and IPv6 formatting).
    match listener.local_addr() {
        Ok(addr) => tracing::info!("NCM API Server listening on http://{}", addr),
        Err(_) => tracing::info!(
            "NCM API Server listening on http://{}:{}",
            config.host,
            config.port
        ),
    }
    axum::serve(listener, app).await.expect("Server error");
}