use std::str::FromStr;
use axum::Json;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Row, SqlitePool};
use tracing::warn;
/// Length in seconds (one hour) of the `fleet` query window when `from` is omitted.
const DEFAULT_WINDOW_SECS: i64 = 60 * 60;
/// Bucket width in seconds (five minutes) used by `fleet` when `step` is omitted.
const DEFAULT_STEP_SECS: i64 = 5 * 60;
/// `fleet` rejects a request when `(to - from) / step` exceeds this value.
const MAX_BUCKETS: i64 = 10_000;
/// Look-back window in seconds (five minutes) used by `top` when `window` is omitted.
const DEFAULT_TOP_WINDOW_SECS: i64 = 5 * 60;
/// Number of rows `top` returns when `limit` is omitted.
const DEFAULT_TOP_LIMIT: i64 = 5;
/// Upper bound that `top` clamps `limit` to.
const MAX_TOP_LIMIT: i64 = 50;
/// Look-back window in seconds (five minutes) used by `active_investigations`.
const ACTIVE_INVESTIGATION_WINDOW_SECS: i64 = 5 * 60;
/// A host performance metric that the perf handlers can aggregate.
///
/// Every variant maps to exactly one column name (see [`Metric::column`]) and
/// can be parsed from that name or from a shorter alias via [`FromStr`].
#[derive(Clone, Copy, Debug)]
enum Metric {
    /// Column `cpu_pct`; alias `cpu`.
    CpuPct,
    /// Column `mem_used_bytes`; aliases `mem` and `memory`.
    MemUsedBytes,
    /// Column `disk_read_bytes_per_sec`; alias `disk_read`.
    DiskReadBytesPerSec,
    /// Column `disk_written_bytes_per_sec`; alias `disk_written`.
    DiskWrittenBytesPerSec,
    /// Column `net_rx_bytes_per_sec`; alias `net_rx`.
    NetRxBytesPerSec,
    /// Column `net_tx_bytes_per_sec`; alias `net_tx`.
    NetTxBytesPerSec,
}

impl Metric {
    /// Returns the column name for this metric.
    ///
    /// The result is always one of a fixed set of `'static` strings and never
    /// contains caller input; the handlers in this file splice it into SQL text.
    fn column(self) -> &'static str {
        match self {
            Metric::CpuPct => "cpu_pct",
            Metric::MemUsedBytes => "mem_used_bytes",
            Metric::DiskReadBytesPerSec => "disk_read_bytes_per_sec",
            Metric::DiskWrittenBytesPerSec => "disk_written_bytes_per_sec",
            Metric::NetRxBytesPerSec => "net_rx_bytes_per_sec",
            Metric::NetTxBytesPerSec => "net_tx_bytes_per_sec",
        }
    }
}

impl FromStr for Metric {
    type Err = ();

    /// Parses a metric from its column name or one of its aliases.
    ///
    /// Matching is exact and case-sensitive; anything else yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let metric = match s {
            "cpu" | "cpu_pct" => Metric::CpuPct,
            "mem" | "memory" | "mem_used_bytes" => Metric::MemUsedBytes,
            "disk_read" | "disk_read_bytes_per_sec" => Metric::DiskReadBytesPerSec,
            "disk_written" | "disk_written_bytes_per_sec" => Metric::DiskWrittenBytesPerSec,
            "net_rx" | "net_rx_bytes_per_sec" => Metric::NetRxBytesPerSec,
            "net_tx" | "net_tx_bytes_per_sec" => Metric::NetTxBytesPerSec,
            _ => return Err(()),
        };
        Ok(metric)
    }
}
/// Aggregate function applied to the samples that fall into one bucket.
#[derive(Clone, Copy, Debug)]
enum Aggregate {
    /// Arithmetic mean; parsed from `avg` or `mean`.
    Avg,
    /// Largest value; parsed from `max`.
    Max,
}

impl Aggregate {
    /// Returns the SQL function name for this aggregate.
    ///
    /// The result is one of two fixed `'static` strings, which `fleet` splices
    /// into its query text.
    fn sql(self) -> &'static str {
        match self {
            Aggregate::Max => "MAX",
            Aggregate::Avg => "AVG",
        }
    }
}

impl FromStr for Aggregate {
    type Err = ();

    /// Parses `avg`, `mean` or `max` (exact, case-sensitive); anything else is `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let aggregate = match s {
            "max" => Aggregate::Max,
            "mean" | "avg" => Aggregate::Avg,
            _ => return Err(()),
        };
        Ok(aggregate)
    }
}
/// Query-string parameters accepted by the `fleet` handler; every field is optional.
#[derive(Deserialize)]
pub struct FleetPerfQuery {
/// Metric name or alias understood by `Metric::from_str`; `fleet` uses `"cpu_pct"` when absent.
metric: Option<String>,
/// Aggregate name understood by `Aggregate::from_str`; `fleet` uses `"avg"` when absent.
agg: Option<String>,
/// Inclusive start of the window; `fleet` uses `to - DEFAULT_WINDOW_SECS` when absent.
from: Option<DateTime<Utc>>,
/// Exclusive end of the window; `fleet` uses the current time when absent.
to: Option<DateTime<Utc>>,
/// Bucket width, parsed by `humantime::parse_duration`; `fleet` uses `DEFAULT_STEP_SECS` when absent.
step: Option<String>,
}
/// One bucket of the time series returned by `fleet`.
#[derive(Serialize)]
pub struct FleetPerfPoint {
/// Bucket timestamp, built by `fleet` from the query's `bucket_unix` column (whole seconds).
pub at: DateTime<Utc>,
/// Aggregated value for the bucket; `None` when `fleet` could not read the `value` column as `f64`.
pub value: Option<f64>,
}
/// Response body of the `fleet` handler: the effective query parameters plus the series.
#[derive(Serialize)]
pub struct FleetPerfResponse {
/// Column name of the metric that was aggregated (the value of `Metric::column`).
pub metric: String,
/// Aggregate that was applied; `fleet` writes `"avg"` or `"max"`.
pub agg: String,
/// Start of the window that was queried (`at >= from`).
pub from: DateTime<Utc>,
/// End of the window that was queried (`at < to`).
pub to: DateTime<Utc>,
/// Bucket width in whole seconds.
pub step_seconds: i64,
/// One entry per bucket returned by the query, in ascending bucket order.
pub points: Vec<FleetPerfPoint>,
}
pub async fn fleet(
State(pool): State<SqlitePool>,
Query(q): Query<FleetPerfQuery>,
) -> Result<Json<FleetPerfResponse>, StatusCode> {
let metric = Metric::from_str(q.metric.as_deref().unwrap_or("cpu_pct"))
.map_err(|_| StatusCode::BAD_REQUEST)?;
let agg = Aggregate::from_str(q.agg.as_deref().unwrap_or("avg"))
.map_err(|_| StatusCode::BAD_REQUEST)?;
let to = q.to.unwrap_or_else(Utc::now);
let from = q
.from
.unwrap_or_else(|| to - Duration::seconds(DEFAULT_WINDOW_SECS));
let step_secs = match q.step.as_deref() {
None => DEFAULT_STEP_SECS,
Some(raw) => match humantime::parse_duration(raw) {
Ok(d) => i64::try_from(d.as_secs()).unwrap_or(DEFAULT_STEP_SECS),
Err(_) => return Err(StatusCode::BAD_REQUEST),
},
};
if step_secs <= 0 || from >= to {
return Err(StatusCode::BAD_REQUEST);
}
if (to - from).num_seconds() / step_secs > MAX_BUCKETS {
return Err(StatusCode::BAD_REQUEST);
}
let sql = format!(
"SELECT
(CAST(strftime('%s', at) AS INTEGER) / ?) * ? AS bucket_unix,
{agg}({metric}) AS value
FROM host_perf_samples
WHERE at >= ? AND at < ?
GROUP BY bucket_unix
ORDER BY bucket_unix ASC",
agg = agg.sql(),
metric = metric.column(),
);
let rows = sqlx::query(sqlx::AssertSqlSafe(sql))
.bind(step_secs)
.bind(step_secs)
.bind(from)
.bind(to)
.fetch_all(&pool)
.await
.map_err(|e| {
warn!(error = %e, metric = ?metric, "fleet perf query");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let points = rows
.into_iter()
.map(|r| {
let bucket: i64 = r.try_get("bucket_unix").unwrap_or(0);
FleetPerfPoint {
at: DateTime::<Utc>::from_timestamp(bucket, 0).unwrap_or(from),
value: r.try_get("value").ok(),
}
})
.collect();
Ok(Json(FleetPerfResponse {
metric: metric.column().to_string(),
agg: match agg {
Aggregate::Avg => "avg".into(),
Aggregate::Max => "max".into(),
},
from,
to,
step_seconds: step_secs,
points,
}))
}
/// Query-string parameters accepted by the `top` handler; every field is optional.
#[derive(Deserialize)]
pub struct TopPerfQuery {
/// Metric name or alias understood by `Metric::from_str`; `top` uses `"cpu_pct"` when absent.
metric: Option<String>,
/// Look-back window, parsed by `humantime::parse_duration`; `top` uses `DEFAULT_TOP_WINDOW_SECS` when absent.
window: Option<String>,
/// Maximum number of rows; `top` uses `DEFAULT_TOP_LIMIT` when absent and clamps to `1..=MAX_TOP_LIMIT`.
limit: Option<i64>,
}
/// One host in the ranking returned by `top`.
#[derive(Serialize)]
pub struct TopPerfRow {
/// Value of `host_perf_samples.pc_id` for this group of samples.
pub pc_id: String,
/// Hostname from the left-joined `agents` table; `None` when it could not be read as a string.
pub hostname: Option<String>,
/// Average of the requested metric over the samples in the window.
pub value: f64,
}
/// Response body of the `top` handler.
#[derive(Serialize)]
pub struct TopPerfResponse {
/// Column name of the metric that was ranked (the value of `Metric::column`).
pub metric: String,
/// Look-back window that was applied, in whole seconds.
pub window_seconds: i64,
/// Hosts ordered by descending average value, at most `limit` entries.
pub rows: Vec<TopPerfRow>,
}
pub async fn top(
State(pool): State<SqlitePool>,
Query(q): Query<TopPerfQuery>,
) -> Result<Json<TopPerfResponse>, StatusCode> {
let metric = Metric::from_str(q.metric.as_deref().unwrap_or("cpu_pct"))
.map_err(|_| StatusCode::BAD_REQUEST)?;
let window_secs = match q.window.as_deref() {
None => DEFAULT_TOP_WINDOW_SECS,
Some(raw) => match humantime::parse_duration(raw) {
Ok(d) => i64::try_from(d.as_secs()).unwrap_or(DEFAULT_TOP_WINDOW_SECS),
Err(_) => return Err(StatusCode::BAD_REQUEST),
},
};
if window_secs <= 0 {
return Err(StatusCode::BAD_REQUEST);
}
let limit = q.limit.unwrap_or(DEFAULT_TOP_LIMIT).clamp(1, MAX_TOP_LIMIT);
let from = Utc::now() - Duration::seconds(window_secs);
let sql = format!(
"SELECT h.pc_id,
a.hostname AS hostname,
AVG(h.{metric}) AS value
FROM host_perf_samples h
LEFT JOIN agents a ON a.pc_id = h.pc_id
WHERE h.at > ?
AND h.{metric} IS NOT NULL
GROUP BY h.pc_id
ORDER BY value DESC NULLS LAST
LIMIT ?",
metric = metric.column(),
);
let rows = sqlx::query(sqlx::AssertSqlSafe(sql))
.bind(from)
.bind(limit)
.fetch_all(&pool)
.await
.map_err(|e| {
warn!(error = %e, metric = ?metric, "top perf query");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let rows = rows
.into_iter()
.map(|r| TopPerfRow {
pc_id: r.try_get("pc_id").unwrap_or_default(),
hostname: r.try_get("hostname").ok(),
value: r.try_get("value").unwrap_or(0.0),
})
.collect();
Ok(Json(TopPerfResponse {
metric: metric.column().to_string(),
window_seconds: window_secs,
rows,
}))
}
/// One host returned by `active_investigations`: a `pc_id` that has at least one
/// row in `process_perf_samples` newer than the handler's cutoff.
#[derive(Serialize)]
pub struct ActiveInvestigation {
/// Value of `process_perf_samples.pc_id` for this group of samples.
pub pc_id: String,
/// Hostname from the left-joined `agents` table; `None` when it could not be read as a string.
pub hostname: Option<String>,
/// Most recent `at` among the host's samples in the window (`MAX(p.at)`).
pub latest_at: DateTime<Utc>,
}
/// Response body of the `active_investigations` handler.
#[derive(Serialize)]
pub struct ActiveInvestigationsResponse {
/// Look-back window that was applied; the handler always writes `ACTIVE_INVESTIGATION_WINDOW_SECS`.
pub window_seconds: i64,
/// Matching hosts, ordered by `latest_at` descending.
pub rows: Vec<ActiveInvestigation>,
}
/// Lists the hosts that have rows in `process_perf_samples` newer than
/// `ACTIVE_INVESTIGATION_WINDOW_SECS` seconds ago, most recent first.
///
/// Each host is reported once, with its hostname from the left-joined `agents`
/// table (when readable) and the timestamp of its newest sample. Rows whose
/// `pc_id` or `latest_at` cannot be decoded are left out of the response.
///
/// # Errors
///
/// `500 Internal Server Error` when the database query fails (logged).
pub async fn active_investigations(
    State(pool): State<SqlitePool>,
) -> Result<Json<ActiveInvestigationsResponse>, StatusCode> {
    let cutoff = Utc::now() - Duration::seconds(ACTIVE_INVESTIGATION_WINDOW_SECS);

    let statement = sqlx::query(
        "SELECT p.pc_id, a.hostname AS hostname, MAX(p.at) AS latest_at
FROM process_perf_samples p
LEFT JOIN agents a ON a.pc_id = p.pc_id
WHERE p.at > ?
GROUP BY p.pc_id
ORDER BY latest_at DESC",
    )
    .bind(cutoff);

    let fetched = match statement.fetch_all(&pool).await {
        Ok(fetched) => fetched,
        Err(e) => {
            warn!(error = %e, "active_investigations query");
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    let mut rows = Vec::with_capacity(fetched.len());
    for record in fetched {
        // Both columns are required; a record missing either one is skipped.
        let Ok(pc_id) = record.try_get::<String, _>("pc_id") else {
            continue;
        };
        let Ok(latest_at) = record.try_get::<DateTime<Utc>, _>("latest_at") else {
            continue;
        };
        // The join is a LEFT JOIN, so the hostname is optional.
        let hostname = record.try_get("hostname").ok();
        rows.push(ActiveInvestigation {
            pc_id,
            hostname,
            latest_at,
        });
    }

    Ok(Json(ActiveInvestigationsResponse {
        window_seconds: ACTIVE_INVESTIGATION_WINDOW_SECS,
        rows,
    }))
}