#![cfg(feature = "portal")]
use axum::{
extract::{Path, Query, Request, State},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{FromRow, PgPool, Row};
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::time::interval;
use uuid::Uuid;
use crate::portal::auth_db::PortalState;
/// One tracked API request, buffered in memory by `UsageTracker` and later
/// batch-inserted into the `api_key_usage` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
    /// API key that issued the request.
    pub api_key_id: Uuid,
    /// Request path (URI path component only, no query string).
    pub endpoint: String,
    /// HTTP method, e.g. "GET".
    pub method: String,
    /// Numeric HTTP status of the response.
    pub status_code: u16,
    /// Wall-clock handling time in milliseconds.
    pub response_time_ms: i32,
    /// Approximate request size. NOTE(review): the middleware currently only
    /// counts header bytes — confirm whether body bytes should be included.
    pub request_size_bytes: Option<i32>,
    /// Approximate response size (same caveat as `request_size_bytes`).
    pub response_size_bytes: Option<i32>,
    /// Client IP, taken from proxy headers when available.
    pub ip_address: Option<IpAddr>,
    /// Raw User-Agent header value, if present.
    pub user_agent: Option<String>,
    /// When the request was observed (stored in the `created_at` column).
    pub timestamp: DateTime<Utc>,
}
/// Aggregated usage over a reporting window, as produced by
/// `UsageRepository::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
    pub api_key_id: Uuid,
    /// Total number of requests in the window.
    pub total_requests: i64,
    /// Mean response time across all requests, in milliseconds.
    pub avg_response_time_ms: f64,
    /// Sum of request + response bytes (NULL sizes counted as 0).
    pub total_bytes: i64,
    /// Fraction of requests with status < 400 (0.0 when there are none).
    pub success_rate: f64,
    /// Fraction of requests with status >= 400 (0.0 when there are none).
    pub error_rate: f64,
    /// Up to 10 busiest endpoints, ordered by request count descending.
    pub top_endpoints: Vec<EndpointStats>,
    /// Inclusive window start used for the aggregation.
    pub period_start: DateTime<Utc>,
    /// Exclusive window end used for the aggregation.
    pub period_end: DateTime<Utc>,
}
/// Per-endpoint aggregation row (also mapped directly from SQL via `FromRow`).
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct EndpointStats {
    /// Request path this row aggregates.
    pub endpoint: String,
    /// Number of requests to this endpoint in the window.
    pub count: i64,
    /// Mean response time for this endpoint, in milliseconds.
    pub avg_response_time_ms: f64,
    /// Fraction of this endpoint's requests with status < 400.
    pub success_rate: f64,
}
/// One calendar day's usage totals, as produced by
/// `UsageRepository::get_daily_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DailySummary {
    /// Day as text (Postgres `DATE(created_at)::TEXT`, i.e. YYYY-MM-DD).
    pub date: String,
    /// Requests observed on that day.
    pub requests: i64,
    /// Mean response time that day, in milliseconds.
    pub avg_response_time_ms: f64,
    /// Total request + response bytes that day.
    pub bytes: i64,
    /// Fraction of requests with status < 400.
    pub success_rate: f64,
}
/// Tuning knobs for `UsageTracker` batching behavior.
#[derive(Debug, Clone)]
pub struct TrackerConfig {
    /// Buffered-record count that triggers an eager flush.
    pub batch_size: usize,
    /// Period of the background flush loop, in seconds.
    pub flush_interval_secs: u64,
    /// Hard cap on buffered records; beyond this, new records are dropped.
    pub max_buffer_capacity: usize,
}
impl Default for TrackerConfig {
fn default() -> Self {
Self {
batch_size: 100,
flush_interval_secs: 10,
max_buffer_capacity: 10_000,
}
}
}
/// Buffered, batch-writing recorder of API usage.
///
/// Cloning is cheap: the buffer and config are shared behind `Arc`, so all
/// clones feed the same in-memory buffer.
#[derive(Clone)]
pub struct UsageTracker {
    pool: PgPool,
    /// Pending records awaiting a batch insert.
    buffer: Arc<RwLock<Vec<UsageRecord>>>,
    config: Arc<TrackerConfig>,
}
impl UsageTracker {
pub fn new(pool: PgPool) -> Self {
Self {
pool,
buffer: Arc::new(RwLock::new(Vec::new())),
config: Arc::new(TrackerConfig::default()),
}
}
pub fn with_batch_size(mut self, size: usize) -> Self {
Arc::make_mut(&mut self.config).batch_size = size;
self
}
pub fn with_flush_interval(mut self, flush_interval: std::time::Duration) -> Self {
Arc::make_mut(&mut self.config).flush_interval_secs = flush_interval.as_secs();
self
}
pub fn with_max_capacity(mut self, capacity: usize) -> Self {
Arc::make_mut(&mut self.config).max_buffer_capacity = capacity;
self
}
pub fn start(self) -> Self {
let tracker = self.clone();
tokio::spawn(async move {
tracker.run_flush_loop().await;
});
self
}
pub async fn track(&self, record: UsageRecord) {
let mut buffer = self.buffer.write().await;
if buffer.len() >= self.config.max_buffer_capacity {
tracing::warn!(
"Usage tracker buffer at capacity ({}), dropping record",
self.config.max_buffer_capacity
);
return;
}
buffer.push(record);
if buffer.len() >= self.config.batch_size {
drop(buffer); self.flush().await;
}
}
async fn run_flush_loop(&self) {
let mut ticker = interval(std::time::Duration::from_secs(
self.config.flush_interval_secs,
));
loop {
ticker.tick().await;
self.flush().await;
}
}
async fn flush(&self) {
let records = {
let mut buffer = self.buffer.write().await;
std::mem::take(&mut *buffer)
};
if records.is_empty() {
return;
}
let count = records.len();
if let Err(e) = self.insert_batch(records).await {
tracing::error!("Failed to flush {} usage records: {}", count, e);
} else {
tracing::debug!("Flushed {} usage records to database", count);
}
}
async fn insert_batch(&self, records: Vec<UsageRecord>) -> Result<(), sqlx::Error> {
if records.is_empty() {
return Ok(());
}
let mut query_builder = sqlx::QueryBuilder::new(
r#"
INSERT INTO api_key_usage (
api_key_id, endpoint, method, status_code,
response_time_ms, request_size_bytes, response_size_bytes,
ip_address, user_agent, created_at
)
"#,
);
query_builder.push_values(records, |mut b, record| {
b.push_bind(record.api_key_id)
.push_bind(record.endpoint)
.push_bind(record.method)
.push_bind(record.status_code as i16)
.push_bind(record.response_time_ms)
.push_bind(record.request_size_bytes)
.push_bind(record.response_size_bytes)
.push_bind(record.ip_address.map(|ip| ip.to_string()))
.push_bind(record.user_agent)
.push_bind(record.timestamp);
});
let query = query_builder.build();
query.execute(&self.pool).await?;
Ok(())
}
pub async fn flush_now(&self) -> Result<(), sqlx::Error> {
self.flush().await;
Ok(())
}
}
/// Axum middleware that records one `UsageRecord` per authenticated request.
///
/// A request is tracked only when an upstream layer has stored the API key's
/// `Uuid` in the request extensions; anonymous requests pass through
/// untouched. The record is written on a spawned task so the response is
/// never delayed by tracking.
pub async fn track_api_usage(
    State(tracker): State<UsageTracker>,
    req: Request,
    next: Next,
) -> Response {
    let started_at = std::time::Instant::now();

    // Capture request-side metadata before the request is consumed downstream.
    let method = req.method().to_string();
    let endpoint = req.uri().path().to_string();
    let api_key_id = req.extensions().get::<Uuid>().copied();
    let ip_address = extract_ip_address(&req);
    let user_agent = req
        .headers()
        .get("user-agent")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);
    let request_size = calculate_request_size(&req);

    let response = next.run(req).await;

    let elapsed = started_at.elapsed();
    let status_code = response.status().as_u16();
    let response_size = calculate_response_size(&response);

    if let Some(key_id) = api_key_id {
        let record = UsageRecord {
            api_key_id: key_id,
            endpoint,
            method,
            status_code,
            response_time_ms: elapsed.as_millis() as i32,
            request_size_bytes: Some(request_size),
            response_size_bytes: Some(response_size),
            ip_address,
            user_agent,
            timestamp: Utc::now(),
        };
        // Off the hot path: buffering may itself trigger a database flush.
        tokio::spawn(async move {
            tracker.track(record).await;
        });
    }
    response
}
/// Best-effort client IP extraction from proxy headers.
///
/// Prefers the first entry of `X-Forwarded-For` (the originating client in a
/// proxy chain), falling back to `X-Real-IP`. Returns `None` when neither
/// header yields a parseable address.
fn extract_ip_address(req: &Request) -> Option<IpAddr> {
    let forwarded = req
        .headers()
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .and_then(|first| first.trim().parse().ok());
    forwarded.or_else(|| {
        req.headers()
            .get("x-real-ip")
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.parse().ok())
    })
}
/// Approximates the request size in bytes: header names + values, plus the
/// declared body size from `Content-Length` when present.
///
/// The body is never read (it may be a stream), so chunked requests without a
/// `Content-Length` header are counted as headers only. Previously the body
/// was ignored entirely, under-reporting every request with a payload.
fn calculate_request_size(req: &Request) -> i32 {
    let headers_size: usize = req
        .headers()
        .iter()
        .map(|(k, v)| k.as_str().len() + v.len())
        .sum();
    let body_size: usize = req
        .headers()
        .get("content-length")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse().ok())
        .unwrap_or(0);
    // Saturate rather than wrap if a pathological size exceeds i32::MAX.
    i32::try_from(headers_size.saturating_add(body_size)).unwrap_or(i32::MAX)
}
/// Approximates the response size in bytes: header names + values, plus the
/// declared body size from `Content-Length` when present.
///
/// The body is never read (it may be a stream), so responses without a
/// `Content-Length` header are counted as headers only. Previously the body
/// was ignored entirely, under-reporting every response with a payload.
fn calculate_response_size(res: &Response) -> i32 {
    let headers_size: usize = res
        .headers()
        .iter()
        .map(|(k, v)| k.as_str().len() + v.len())
        .sum();
    let body_size: usize = res
        .headers()
        .get("content-length")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse().ok())
        .unwrap_or(0);
    // Saturate rather than wrap if a pathological size exceeds i32::MAX.
    i32::try_from(headers_size.saturating_add(body_size)).unwrap_or(i32::MAX)
}
/// Read-side queries over the `api_key_usage` table, borrowing an existing
/// connection pool for the duration of the request.
pub struct UsageRepository<'a> {
    pool: &'a PgPool,
}
impl<'a> UsageRepository<'a> {
pub fn new(pool: &'a PgPool) -> Self {
Self { pool }
}
pub async fn get_stats(
&self,
api_key_id: Uuid,
from: DateTime<Utc>,
to: DateTime<Utc>,
) -> Result<UsageStats, sqlx::Error> {
let row = sqlx::query(
r#"
SELECT
COUNT(*)::BIGINT as total_requests,
COALESCE(AVG(response_time_ms), 0)::FLOAT8 as avg_response_time,
COALESCE(SUM(COALESCE(request_size_bytes, 0) + COALESCE(response_size_bytes, 0)), 0)::BIGINT as total_bytes,
COUNT(*) FILTER (WHERE status_code < 400)::BIGINT as success_count,
COUNT(*) FILTER (WHERE status_code >= 400)::BIGINT as error_count
FROM api_key_usage
WHERE api_key_id = $1
AND created_at >= $2
AND created_at < $3
"#,
)
.bind(api_key_id)
.bind(from)
.bind(to)
.fetch_one(self.pool)
.await?;
let total_requests: i64 = row.try_get("total_requests").unwrap_or(0);
let avg_response_time: f64 = row.try_get("avg_response_time").unwrap_or(0.0);
let total_bytes: i64 = row.try_get("total_bytes").unwrap_or(0);
let success_count: i64 = row.try_get("success_count").unwrap_or(0);
let error_count: i64 = row.try_get("error_count").unwrap_or(0);
let endpoint_rows = sqlx::query(
r#"
SELECT
endpoint,
COUNT(*)::BIGINT as count,
COALESCE(AVG(response_time_ms), 0)::FLOAT8 as avg_response_time,
CASE WHEN COUNT(*) > 0
THEN (COUNT(*) FILTER (WHERE status_code < 400))::FLOAT8 / COUNT(*)::FLOAT8
ELSE 0
END as success_rate
FROM api_key_usage
WHERE api_key_id = $1
AND created_at >= $2
AND created_at < $3
GROUP BY endpoint
ORDER BY count DESC
LIMIT 10
"#,
)
.bind(api_key_id)
.bind(from)
.bind(to)
.fetch_all(self.pool)
.await?;
let top_endpoints: Vec<EndpointStats> = endpoint_rows
.into_iter()
.map(|row| EndpointStats {
endpoint: row.try_get("endpoint").unwrap_or_default(),
count: row.try_get("count").unwrap_or(0),
avg_response_time_ms: row.try_get("avg_response_time").unwrap_or(0.0),
success_rate: row.try_get("success_rate").unwrap_or(0.0),
})
.collect();
let total = total_requests as f64;
let success = success_count as f64;
let error = error_count as f64;
Ok(UsageStats {
api_key_id,
total_requests,
avg_response_time_ms: avg_response_time,
total_bytes,
success_rate: if total > 0.0 { success / total } else { 0.0 },
error_rate: if total > 0.0 { error / total } else { 0.0 },
top_endpoints,
period_start: from,
period_end: to,
})
}
pub async fn get_usage_by_endpoint(
&self,
api_key_id: Uuid,
from: DateTime<Utc>,
to: DateTime<Utc>,
) -> Result<Vec<EndpointStats>, sqlx::Error> {
let rows = sqlx::query(
r#"
SELECT
endpoint,
COUNT(*)::BIGINT as count,
COALESCE(AVG(response_time_ms), 0)::FLOAT8 as avg_response_time,
CASE WHEN COUNT(*) > 0
THEN (COUNT(*) FILTER (WHERE status_code < 400))::FLOAT8 / COUNT(*)::FLOAT8
ELSE 0
END as success_rate
FROM api_key_usage
WHERE api_key_id = $1
AND created_at >= $2
AND created_at < $3
GROUP BY endpoint
ORDER BY count DESC
"#,
)
.bind(api_key_id)
.bind(from)
.bind(to)
.fetch_all(self.pool)
.await?;
Ok(rows
.into_iter()
.map(|row| EndpointStats {
endpoint: row.try_get("endpoint").unwrap_or_default(),
count: row.try_get("count").unwrap_or(0),
avg_response_time_ms: row.try_get("avg_response_time").unwrap_or(0.0),
success_rate: row.try_get("success_rate").unwrap_or(0.0),
})
.collect())
}
pub async fn get_daily_summary(
&self,
api_key_id: Uuid,
days: u32,
) -> Result<Vec<DailySummary>, sqlx::Error> {
let from = Utc::now() - Duration::days(days as i64);
let rows = sqlx::query(
r#"
SELECT
DATE(created_at)::TEXT as date,
COUNT(*)::BIGINT as requests,
COALESCE(AVG(response_time_ms), 0)::FLOAT8 as avg_response_time,
COALESCE(SUM(COALESCE(request_size_bytes, 0) + COALESCE(response_size_bytes, 0)), 0)::BIGINT as bytes,
CASE WHEN COUNT(*) > 0
THEN (COUNT(*) FILTER (WHERE status_code < 400))::FLOAT8 / COUNT(*)::FLOAT8
ELSE 0
END as success_rate
FROM api_key_usage
WHERE api_key_id = $1
AND created_at >= $2
GROUP BY DATE(created_at)
ORDER BY date DESC
"#,
)
.bind(api_key_id)
.bind(from)
.fetch_all(self.pool)
.await?;
Ok(rows
.into_iter()
.map(|row| DailySummary {
date: row.try_get("date").unwrap_or_default(),
requests: row.try_get("requests").unwrap_or(0),
avg_response_time_ms: row.try_get("avg_response_time").unwrap_or(0.0),
bytes: row.try_get("bytes").unwrap_or(0),
success_rate: row.try_get("success_rate").unwrap_or(0.0),
})
.collect())
}
}
/// Query-string parameters accepted by the usage endpoints.
#[derive(Debug, Deserialize)]
pub struct UsageStatsQuery {
    /// Explicit window start; when absent, handlers derive it from `days`.
    pub from: Option<DateTime<Utc>>,
    /// Explicit window end; when absent, handlers use the current time.
    pub to: Option<DateTime<Utc>>,
    /// Look-back window in days (default 7). NOTE(review): `get_daily_usage`
    /// uses only this field and ignores `from`/`to` — confirm that is intended.
    #[serde(default = "default_days")]
    pub days: u32,
}
/// Serde default for `UsageStatsQuery::days`: a one-week look-back window.
fn default_days() -> u32 {
    7
}
pub async fn get_usage_stats(
State(state): State<PortalState>,
Path(api_key_id): Path<Uuid>,
Query(params): Query<UsageStatsQuery>,
) -> impl IntoResponse {
let to = params.to.unwrap_or_else(Utc::now);
let from = params
.from
.unwrap_or_else(|| to - Duration::days(params.days as i64));
let repo = UsageRepository::new(state.db.pool());
match repo.get_stats(api_key_id, from, to).await {
Ok(stats) => (StatusCode::OK, Json(stats)).into_response(),
Err(e) => {
tracing::error!("Failed to get usage stats: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": "Failed to retrieve usage statistics"
})),
)
.into_response()
}
}
}
pub async fn get_daily_usage(
State(state): State<PortalState>,
Path(api_key_id): Path<Uuid>,
Query(params): Query<UsageStatsQuery>,
) -> impl IntoResponse {
let repo = UsageRepository::new(state.db.pool());
match repo.get_daily_summary(api_key_id, params.days).await {
Ok(summary) => (StatusCode::OK, Json(summary)).into_response(),
Err(e) => {
tracing::error!("Failed to get daily usage: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": "Failed to retrieve daily usage"
})),
)
.into_response()
}
}
}
pub async fn get_endpoint_usage(
State(state): State<PortalState>,
Path(api_key_id): Path<Uuid>,
Query(params): Query<UsageStatsQuery>,
) -> impl IntoResponse {
let to = params.to.unwrap_or_else(Utc::now);
let from = params
.from
.unwrap_or_else(|| to - Duration::days(params.days as i64));
let repo = UsageRepository::new(state.db.pool());
match repo.get_usage_by_endpoint(api_key_id, from, to).await {
Ok(endpoints) => (StatusCode::OK, Json(endpoints)).into_response(),
Err(e) => {
tracing::error!("Failed to get endpoint usage: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": "Failed to retrieve endpoint usage"
})),
)
.into_response()
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must match the documented batching policy.
    #[test]
    fn test_tracker_config_defaults() {
        let TrackerConfig {
            batch_size,
            flush_interval_secs,
            max_buffer_capacity,
        } = TrackerConfig::default();
        assert_eq!(batch_size, 100);
        assert_eq!(flush_interval_secs, 10);
        assert_eq!(max_buffer_capacity, 10_000);
    }

    /// Sanity check that a record can be assembled field by field.
    #[test]
    fn test_usage_record_creation() {
        let record = UsageRecord {
            api_key_id: Uuid::new_v4(),
            endpoint: String::from("/api/v1/test"),
            method: String::from("GET"),
            status_code: 200,
            response_time_ms: 150,
            request_size_bytes: Some(1024),
            response_size_bytes: Some(2048),
            ip_address: Some("127.0.0.1".parse().unwrap()),
            user_agent: Some(String::from("test-agent")),
            timestamp: Utc::now(),
        };
        assert_eq!(record.method, "GET");
        assert_eq!(record.status_code, 200);
    }
}