use crate::MyError;
use dashmap::DashMap;
use rocket::{
Orbit, Rocket, Route,
fairing::{Fairing, Info, Kind},
get,
http::Method,
routes,
serde::json::Json,
};
use serde::Serialize;
use std::sync::{
Arc, OnceLock,
atomic::{AtomicU64, Ordering},
};
use tracing::{error, info};
/// Identity of a mounted route; used as the key of the statistics map.
#[derive(Debug, PartialEq, Eq, Hash)]
struct RouteAttributes {
    method: Method,
    path: String,
    mime: String,
    rank: isize,
}

impl From<&Route> for RouteAttributes {
    /// Captures the method, origin path, declared format (`"N/A"` when the
    /// route declares none) and rank of `route`.
    fn from(route: &Route) -> Self {
        let mime = route
            .format
            .as_ref()
            .map_or_else(|| "N/A".to_owned(), ToString::to_string);
        Self {
            method: route.method,
            path: route.uri.origin.path().to_string(),
            mime,
            rank: route.rank,
        }
    }
}
#[derive(Debug)]
struct RouteStats {
count: AtomicU64,
min: AtomicU64,
avg: AtomicU64,
max: AtomicU64,
}
impl Default for RouteStats {
fn default() -> Self {
Self {
count: Default::default(),
min: AtomicU64::new(u64::MAX),
avg: Default::default(),
max: Default::default(),
}
}
}
static ENDPOINTS: OnceLock<Arc<DashMap<RouteAttributes, RouteStats>>> = OnceLock::new();
fn endpoints() -> Arc<DashMap<RouteAttributes, RouteStats>> {
ENDPOINTS.get_or_init(|| Arc::new(DashMap::new())).clone()
}
pub(crate) struct StatsFairing;
#[rocket::async_trait]
impl Fairing for StatsFairing {
fn info(&self) -> Info {
Info {
name: "Routes Statistics",
kind: Kind::Liftoff | Kind::Shutdown,
}
}
async fn on_liftoff(&self, r: &Rocket<Orbit>) {
for route in r.routes() {
let key = RouteAttributes::from(route);
endpoints().insert(key, RouteStats::default());
}
}
async fn on_shutdown(&self, _: &Rocket<Orbit>) {
let stats = endpoints();
let (total_count, total_avg): (u64, u64) = stats
.iter()
.filter(|e| e.count.load(Ordering::Relaxed) > 0)
.fold((0, 0), |(sum_count, sum_avg), e| {
(
sum_count + e.count.load(Ordering::Relaxed),
sum_avg + e.avg.load(Ordering::Relaxed),
)
});
let average_duration = total_avg.checked_div(total_count).unwrap_or(0);
info!("LaRS stats\n{:?}", stats);
info!(
"*** Total calls = {}; Average duration = {} ns",
total_count, average_duration
);
}
}
pub(crate) fn update_stats(route: &Route, duration: u64) {
let key = RouteAttributes::from(route);
let tmp = endpoints();
let tmp = tmp.get_mut(&key);
match tmp {
Some(endpoint) => {
endpoint.min.fetch_min(duration, Ordering::Relaxed);
endpoint.max.fetch_max(duration, Ordering::Relaxed);
let old_count = endpoint.count.fetch_add(1, Ordering::Relaxed);
let old_avg = endpoint.avg.fetch_add(0, Ordering::Relaxed);
let new_avg = (old_count * old_avg + duration) / (old_count + 1);
endpoint.avg.store(new_avg, Ordering::Relaxed);
}
_ => error!("Failed finding stats for {}", route),
}
}
// Routes that expose the collected statistics; intentionally hidden from docs.
#[doc(hidden)]
pub fn routes() -> Vec<rocket::Route> {
    rocket::routes![stats]
}
/// One row of the JSON payload returned by the statistics endpoint.
///
/// Field order is significant: it is the key order of the serialized object.
#[derive(Serialize, Debug)]
struct StatsRecord {
    /// HTTP method, rendered as text.
    method: String,
    /// Origin path of the route.
    path: String,
    /// Declared format, or `"N/A"` when the route declares none.
    mime: String,
    /// Route rank.
    rank: isize,
    /// Number of recorded calls.
    count: u64,
    /// Shortest recorded duration.
    min: u64,
    /// Mean recorded duration.
    avg: u64,
    /// Longest recorded duration.
    max: u64,
}
/// Returns the statistics of every route that has been called at least once.
#[get("/")]
async fn stats() -> Result<Json<Vec<StatsRecord>>, MyError> {
    let table = endpoints();
    let mut records = Vec::new();
    for entry in table.iter() {
        let (attrs, route_stats) = (entry.key(), entry.value());
        let count = route_stats.count.load(Ordering::Relaxed);
        // Routes that were mounted but never hit are left out of the report.
        if count == 0 {
            continue;
        }
        records.push(StatsRecord {
            method: attrs.method.to_string(),
            path: attrs.path.clone(),
            mime: attrs.mime.clone(),
            rank: attrs.rank,
            count,
            min: route_stats.min.load(Ordering::Relaxed),
            avg: route_stats.avg.load(Ordering::Relaxed),
            max: route_stats.max.load(Ordering::Relaxed),
        });
    }
    Ok(Json(records))
}