#![deny(warnings)]
use std::time::Duration;
use axum_bootstrap::TlsParam;
use http::init_http_client;
use clap::Parser;
type DynError = Box<dyn std::error::Error + Send + Sync>;
#[derive(Parser)]
#[command(author, version=None, about, long_about = None)]
pub struct Param {
#[arg(long, short, value_name = "port", default_value = "4000")]
port: u16,
#[arg(long, value_name = "reqwest client的代理", default_value = "")]
http_proxy: String,
#[arg(long, value_name = "CERT", default_value = "cert.pem")]
cert: String,
#[arg(long, value_name = "KEY", default_value = "privkey.pem")]
key: String,
#[arg(short, long, help = "if enable, server will listen on https")]
tls: bool,
}
pub(crate) static PARAM: std::sync::LazyLock<Param> = std::sync::LazyLock::new(Param::parse);
const CARGO_CRATE_NAME: &str = env!("CARGO_CRATE_NAME");
#[tokio::main]
pub async fn main() -> Result<(), DynError> {
axum_bootstrap::init_log::tracing::init(CARGO_CRATE_NAME)?;
log::info!("init http client...");
let client = init_http_client(&PARAM.http_proxy).await?;
#[cfg(feature = "mysql")]
{
log::info!("connecting to mysql...");
let pool: sqlx::Pool<sqlx::MySql> = sqlx_mysql::MySqlPoolOptions::new()
.max_connections(20)
.acquire_timeout(std::time::Duration::from_secs(10))
.connect_with(
sqlx_mysql::MySqlConnectOptions::new()
.host("127.0.0.1")
.username("root")
.password("xxxxxx")
.database("test")
.ssl_mode(sqlx_mysql::MySqlSslMode::Required)
.timezone(Some(String::from("+08:00"))),
)
.await?;
use axum_bootstrap::generate_shutdown_receiver;
let server = axum_bootstrap::new_server(PARAM.port, handler::build_router(handler::AppState { client, pool }), register_shutdown_receiver());
let server = server.with_timeout(Duration::from_secs(120)).with_tls_param(match PARAM.tls {
true => Some(TlsParam {
tls: true,
cert: PARAM.cert.to_string(),
key: PARAM.key.to_string(),
}),
false => None,
});
server.run().await?;
}
#[cfg(not(feature = "mysql"))]
{
use axum_bootstrap::generate_shutdown_receiver;
let server = axum_bootstrap::new_server(PARAM.port, handler::build_router(handler::AppState { client }), generate_shutdown_receiver());
let server = server.with_timeout(Duration::from_secs(120)).with_tls_param(match PARAM.tls {
true => Some(TlsParam {
tls: true,
cert: PARAM.cert.to_string(),
key: PARAM.key.to_string(),
}),
false => None,
});
server.run().await?;
}
Ok(())
}
mod handler {
#![allow(unused)]
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use axum::{
Json, Router,
extract::{ConnectInfo, MatchedPath, Request, State},
http::{self, HeaderValue},
routing::get,
};
use axum_macros::debug_handler;
use chrono::NaiveDateTime;
use hyper::{HeaderMap, StatusCode};
use log::info;
use prometheus_client::encoding::text::encode;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt};
use axum_bootstrap::{error::AppError, util::json::StupidValue};
use tokio::time::sleep;
use tower_http::{compression::CompressionLayer, cors::CorsLayer, timeout::TimeoutLayer, trace::TraceLayer};
use crate::metrics::{HttpReqLabel, METRIC};
pub(crate) struct AppState {
#[cfg(feature = "mysql")]
pub(crate) pool: sqlx::MySqlPool,
pub(crate) client: reqwest::Client,
}
pub(crate) fn build_router(app_state: AppState) -> Router {
Router::new()
.route("/", get(|ConnectInfo(addr): ConnectInfo<SocketAddr>| async move { (StatusCode::OK, format!("{addr}")) }))
.route(
"/time",
get(|| async {
sleep(Duration::from_secs(20)).await;
(StatusCode::OK, "OK")
}),
)
.route("/metrics", get(metrics_handler))
.route("/error", get(error_func))
.route("/data", get(data_handler).post(data_handler))
.layer((
TraceLayer::new_for_http() .make_span_with(make_span)
.on_failure(()),
CorsLayer::permissive(),
TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(30)),
CompressionLayer::new(),
))
.with_state(Arc::new(app_state))
}
fn make_span(req: &http::Request<axum::body::Body>) -> tracing::Span {
let method = req.method();
let path = req.uri().path();
let matched_path = req.extensions().get::<MatchedPath>().map(|matched_path| matched_path.as_str());
tracing::debug_span!("recv request", %method, %path, matched_path)
}
pub(crate) async fn metrics_handler() -> Result<(StatusCode, String), AppError> {
let mut buffer = String::new();
if let Err(e) = encode(&mut buffer, &METRIC.prom_registry) {
log::error!("Failed to encode metrics: {e:?}");
return Err(AppError::new(io::Error::other(e)));
}
Ok((StatusCode::OK, buffer))
}
pub(crate) async fn error_func() -> Result<(StatusCode, String), AppError> {
Err(AppError::new(io::Error::other("MOCK error")))
}
#[debug_handler]
pub(crate) async fn data_handler(
State(state): State<Arc<AppState>>, req: Json<DataRequest>,
) -> (StatusCode, HeaderMap, Json<Response<Vec<Data>>>) {
METRIC.req_count.get_or_create(&HttpReqLabel { path: "test".to_string() }).inc();
info!("req: {req:?}");
#[cfg(not(feature = "mysql"))]
return (StatusCode::INTERNAL_SERVER_ERROR, some_headers(), Json(Response::error("mysql not enabled".to_string())));
#[cfg(feature = "mysql")]
{
use std::borrow::Borrow;
let pool = state.pool.borrow();
match sqlx::query!(r"select now() as now_local, now() as now_naive, now() as now_utc;")
.fetch_one(pool)
.await
{
Ok(row) => (
StatusCode::OK,
some_headers(),
Json(Response::success(vec![Data {
now_local: row.now_local,
now_naive: row.now_naive,
now_utc: row.now_utc,
}])),
),
Err(e) => {
log::warn!("query now failed: {:?}", e);
(StatusCode::INTERNAL_SERVER_ERROR, some_headers(), Json(Response::error(format!("query now failed: {:?}", e))))
}
}
}
}
#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct DataRequest {
#[serde(rename = "startTime", with = "axum_bootstrap::util::json::my_date_format_option", default)]
pub(crate) start_time: Option<NaiveDateTime>,
#[serde(rename = "endTime", with = "axum_bootstrap::util::json::my_date_format_option", default)]
pub(crate) end_time: Option<NaiveDateTime>,
#[serde(rename = "distinctCode", default)]
pub(crate) distinct_code: StupidValue<bool>,
}
#[derive(serde::Serialize, Debug, FromRow)]
pub(crate) struct Data {
#[serde(with = "axum_bootstrap::util::json::my_date_format")]
pub(crate) now_local: NaiveDateTime,
#[serde(with = "axum_bootstrap::util::json::my_date_format")]
pub(crate) now_naive: NaiveDateTime,
#[serde(with = "axum_bootstrap::util::json::my_date_format")]
pub(crate) now_utc: NaiveDateTime,
}
#[derive(Serialize)]
pub(crate) struct Response<T: Serialize> {
code: i32,
message: String,
data: Option<T>,
}
impl<T: Serialize> Response<T> {
fn success(data: T) -> Response<T> {
Response {
code: 200,
message: "success".to_string(),
data: Some(data),
}
}
fn error(msg: String) -> Response<T> {
Response {
code: 500,
message: msg,
data: None,
}
}
}
pub fn some_headers() -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
headers
}
}
mod metrics {
use std::sync::LazyLock;
use prometheus_client::{
encoding::EncodeLabelSet,
metrics::{counter::Counter, family::Family},
registry::Registry,
};
pub(crate) static METRIC: LazyLock<Metrics> = LazyLock::new(|| {
let mut prom_registry = Registry::default();
let req_count = Family::<HttpReqLabel, Counter>::default();
prom_registry.register("req_count", "help", req_count.clone());
Metrics { prom_registry, req_count }
});
pub(crate) struct Metrics {
pub(crate) prom_registry: Registry,
pub(crate) req_count: Family<HttpReqLabel, Counter>,
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub(crate) struct HttpReqLabel {
pub(crate) path: String,
}
}
mod http {
use reqwest::Client;
use crate::DynError;
pub async fn init_http_client(http_proxy: &str) -> Result<Client, DynError> {
let client_builder = Client::builder().pool_max_idle_per_host(20);
if http_proxy.is_empty() {
Ok(client_builder.build()?)
} else {
Ok(client_builder.proxy(reqwest::Proxy::all(http_proxy)?).build()?)
}
}
}