use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
/// Thread-safe runtime metrics collector.
///
/// Counters are lock-free atomics updated with `Ordering::Relaxed` (each is an
/// independent statistic, not a synchronization point); the custom-metric map
/// and the performance-sample buffer sit behind async `tokio::sync::RwLock`s.
#[derive(Debug)]
pub struct MetricsCollector {
    // Monotonic anchor for uptime; set once in `new` and never reset.
    start_time: Instant,
    // Count of messages sent since startup (or last reset).
    messages_sent: AtomicU64,
    // Count of messages received since startup (or last reset).
    messages_received: AtomicU64,
    // Count of broadcast messages sent.
    broadcasts_sent: AtomicU64,
    // Count of recorded errors.
    errors: AtomicU64,
    // Accumulated processing time in nanoseconds (see `record_processing_time`).
    total_processing_time: AtomicU64,
    // Gauge of currently active connections.
    connections: AtomicU64,
    // Arbitrary named gauges set by callers via `set_custom_metric`.
    custom_metrics: Arc<RwLock<HashMap<String, f64>>>,
    // Rolling buffer of per-operation timing samples (capped in `record_sample`).
    samples: Arc<RwLock<Vec<PerformanceSample>>>,
}
impl MetricsCollector {
    /// Creates a collector with every counter at zero and the uptime clock
    /// started at the moment of construction.
    pub fn new() -> Self {
        Self {
            start_time: Instant::now(),
            messages_sent: AtomicU64::new(0),
            messages_received: AtomicU64::new(0),
            broadcasts_sent: AtomicU64::new(0),
            errors: AtomicU64::new(0),
            total_processing_time: AtomicU64::new(0),
            connections: AtomicU64::new(0),
            custom_metrics: Arc::new(RwLock::new(HashMap::new())),
            samples: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Increments the sent-message counter.
    pub fn record_message_sent(&self) {
        self.messages_sent.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the received-message counter.
    pub fn record_message_received(&self) {
        self.messages_received.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the broadcast counter.
    pub fn record_broadcast_sent(&self) {
        self.broadcasts_sent.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the error counter.
    pub fn record_error(&self) {
        self.errors.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds `duration` to the accumulated processing time.
    ///
    /// Stored as nanoseconds in a `u64`; the cast truncates only for a single
    /// duration longer than ~584 years, so it is safe in practice.
    pub fn record_processing_time(&self, duration: Duration) {
        self.total_processing_time
            .fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
    }

    /// Zeroes the throughput counters at startup.
    ///
    /// NOTE: does not touch `start_time`, the `connections` gauge, custom
    /// metrics, or the sample buffer — use `reset` for a full wipe.
    pub fn record_startup(&self) {
        self.messages_sent.store(0, Ordering::Relaxed);
        self.messages_received.store(0, Ordering::Relaxed);
        self.broadcasts_sent.store(0, Ordering::Relaxed);
        self.errors.store(0, Ordering::Relaxed);
        self.total_processing_time.store(0, Ordering::Relaxed);
    }

    /// Shutdown hook; currently a no-op kept for API symmetry with
    /// `record_startup`.
    pub fn record_shutdown(&self) {}

    /// Increments the active-connection gauge.
    pub fn add_connection(&self) {
        self.connections.fetch_add(1, Ordering::Relaxed);
    }

    /// Decrements the active-connection gauge, saturating at zero.
    ///
    /// The previous `fetch_sub(1, ...)` wrapped the gauge to `u64::MAX` when a
    /// removal was recorded without a matching `add_connection`; the CAS loop
    /// via `fetch_update` leaves the gauge at zero instead (`checked_sub`
    /// yields `None`, which aborts the update).
    pub fn remove_connection(&self) {
        let _ = self
            .connections
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }

    /// Sets (or overwrites) a named custom gauge.
    pub async fn set_custom_metric(&self, name: String, value: f64) {
        let mut metrics = self.custom_metrics.write().await;
        metrics.insert(name, value);
    }

    /// Snapshots all counters plus derived metrics into a flat name → value map.
    ///
    /// Derived keys: `messages_per_second` and `error_rate` (only once uptime
    /// is positive) and `avg_processing_time_ms` (only once at least one
    /// message has been counted).
    pub fn get_metrics(&self) -> HashMap<String, f64> {
        let uptime = self.start_time.elapsed().as_secs_f64();
        let messages_sent = self.messages_sent.load(Ordering::Relaxed) as f64;
        let messages_received = self.messages_received.load(Ordering::Relaxed) as f64;
        let broadcasts_sent = self.broadcasts_sent.load(Ordering::Relaxed) as f64;
        let errors = self.errors.load(Ordering::Relaxed) as f64;
        // Stored in nanoseconds; convert to seconds for reporting.
        let total_processing_time =
            self.total_processing_time.load(Ordering::Relaxed) as f64 / 1_000_000_000.0;
        let connections = self.connections.load(Ordering::Relaxed) as f64;
        let mut metrics = HashMap::new();
        metrics.insert("uptime_seconds".to_string(), uptime);
        metrics.insert("messages_sent".to_string(), messages_sent);
        metrics.insert("messages_received".to_string(), messages_received);
        metrics.insert("broadcasts_sent".to_string(), broadcasts_sent);
        metrics.insert("errors".to_string(), errors);
        metrics.insert("total_processing_time_seconds".to_string(), total_processing_time);
        metrics.insert("connections".to_string(), connections);
        if uptime > 0.0 {
            metrics.insert(
                "messages_per_second".to_string(),
                (messages_sent + messages_received) / uptime,
            );
            // The +1.0 in the denominator avoids division by zero when no
            // messages have been counted; the reported rate is therefore
            // slightly below the raw errors/messages ratio.
            metrics.insert(
                "error_rate".to_string(),
                errors / (messages_sent + messages_received + 1.0),
            );
        }
        if messages_sent + messages_received > 0.0 {
            metrics.insert(
                "avg_processing_time_ms".to_string(),
                (total_processing_time * 1000.0) / (messages_sent + messages_received),
            );
        }
        metrics
    }

    /// Seconds elapsed since the collector was constructed.
    pub fn get_uptime(&self) -> f64 {
        self.start_time.elapsed().as_secs_f64()
    }

    /// Current sent-message count.
    pub fn get_messages_sent(&self) -> u64 {
        self.messages_sent.load(Ordering::Relaxed)
    }

    /// Current received-message count.
    pub fn get_messages_received(&self) -> u64 {
        self.messages_received.load(Ordering::Relaxed)
    }

    /// Current error count.
    pub fn get_errors(&self) -> u64 {
        self.errors.load(Ordering::Relaxed)
    }

    /// Current active-connection gauge value.
    pub fn get_connections(&self) -> u64 {
        self.connections.load(Ordering::Relaxed)
    }

    /// Appends a timing sample, trimming the buffer to the newest 1000 entries.
    pub async fn record_sample(&self, operation: String, duration: Duration, success: bool) {
        let sample = PerformanceSample {
            // Unix seconds; defaults to 0 if the system clock is before the epoch.
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            operation,
            // as_secs_f64()*1000.0 keeps sub-millisecond precision; the old
            // `as_millis() as f64` truncated to whole milliseconds.
            duration_ms: duration.as_secs_f64() * 1000.0,
            success,
        };
        let mut samples = self.samples.write().await;
        samples.push(sample);
        if samples.len() > 1000 {
            // Drop the oldest entries so the buffer stays bounded.
            let drain_count = samples.len() - 1000;
            samples.drain(0..drain_count);
        }
    }

    /// Aggregates the sample buffer (plus derived counters) into summary stats.
    ///
    /// Average and p95 are computed over successful samples only; with an
    /// empty buffer all fields are zero via `PerformanceStats::default()`.
    pub async fn get_performance_stats(&self) -> PerformanceStats {
        let samples = self.samples.read().await;
        let metrics = self.get_metrics();
        if samples.is_empty() {
            return PerformanceStats::default();
        }
        let total = samples.len() as u64;
        // Collect successful durations once; the counts fall out of its length
        // rather than materializing two Vecs of sample references.
        let mut durations: Vec<f64> = samples
            .iter()
            .filter(|s| s.success)
            .map(|s| s.duration_ms)
            .collect();
        let successful = durations.len() as u64;
        let failed = total - successful;
        let avg_duration = if durations.is_empty() {
            0.0
        } else {
            durations.iter().sum::<f64>() / durations.len() as f64
        };
        // total_cmp is a total order over f64, so this cannot panic on NaN the
        // way `partial_cmp(..).unwrap()` did.
        durations.sort_by(f64::total_cmp);
        let p95_duration = if durations.is_empty() {
            0.0
        } else {
            // Approximate percentile: floor(n * 0.95), clamped to the last index.
            let index = (durations.len() as f64 * 0.95) as usize;
            durations[index.min(durations.len() - 1)]
        };
        PerformanceStats {
            total_samples: total,
            successful_samples: successful,
            failed_samples: failed,
            success_rate: successful as f64 / total as f64,
            avg_duration_ms: avg_duration,
            p95_duration_ms: p95_duration,
            messages_per_second: metrics.get("messages_per_second").copied().unwrap_or(0.0),
            error_rate: metrics.get("error_rate").copied().unwrap_or(0.0),
            uptime_seconds: metrics.get("uptime_seconds").copied().unwrap_or(0.0),
        }
    }

    /// Clears every counter, gauge, custom metric, and sample.
    ///
    /// NOTE: `start_time` is not reset, so uptime keeps accumulating.
    pub async fn reset(&self) {
        self.messages_sent.store(0, Ordering::Relaxed);
        self.messages_received.store(0, Ordering::Relaxed);
        self.broadcasts_sent.store(0, Ordering::Relaxed);
        self.errors.store(0, Ordering::Relaxed);
        self.total_processing_time.store(0, Ordering::Relaxed);
        self.connections.store(0, Ordering::Relaxed);
        let mut custom_metrics = self.custom_metrics.write().await;
        custom_metrics.clear();
        let mut samples = self.samples.write().await;
        samples.clear();
    }
}
/// `Default` delegates to `new()`, so a zeroed collector can be obtained via
/// `MetricsCollector::default()`.
impl Default for MetricsCollector {
    fn default() -> Self {
        Self::new()
    }
}
/// One timed operation, as stored in the collector's rolling sample buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceSample {
    // Unix timestamp (seconds) when the sample was recorded.
    pub timestamp: u64,
    // Caller-supplied operation label.
    pub operation: String,
    // Duration of the operation in milliseconds.
    pub duration_ms: f64,
    // Whether the operation completed successfully.
    pub success: bool,
}
/// Aggregated summary produced by `MetricsCollector::get_performance_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceStats {
    // Number of samples in the buffer at aggregation time.
    pub total_samples: u64,
    // Samples with `success == true`.
    pub successful_samples: u64,
    // Samples with `success == false`.
    pub failed_samples: u64,
    // successful_samples / total_samples.
    pub success_rate: f64,
    // Mean duration over successful samples, in milliseconds.
    pub avg_duration_ms: f64,
    // Approximate 95th-percentile duration over successful samples.
    pub p95_duration_ms: f64,
    // Derived throughput copied from `get_metrics`.
    pub messages_per_second: f64,
    // Derived error rate copied from `get_metrics`.
    pub error_rate: f64,
    // Collector uptime copied from `get_metrics`.
    pub uptime_seconds: f64,
}
/// All-zero stats, returned when the sample buffer is empty.
///
/// NOTE(review): every field is numeric, so `#[derive(Default)]` on the struct
/// would be equivalent and shorter — kept manual to avoid touching the struct
/// definition here.
impl Default for PerformanceStats {
    fn default() -> Self {
        Self {
            total_samples: 0,
            successful_samples: 0,
            failed_samples: 0,
            success_rate: 0.0,
            avg_duration_ms: 0.0,
            p95_duration_ms: 0.0,
            messages_per_second: 0.0,
            error_rate: 0.0,
            uptime_seconds: 0.0,
        }
    }
}
/// Stateless helpers that serialize collected metrics into common wire formats.
pub struct MetricsExporter;

impl MetricsExporter {
    /// Serializes performance stats as pretty-printed JSON.
    pub fn to_json(stats: &PerformanceStats) -> serde_json::Result<String> {
        serde_json::to_string_pretty(stats)
    }

    /// Renders gauges in the Prometheus text exposition format, with each
    /// metric name prefixed `odin_` and dots/spaces sanitized to underscores.
    ///
    /// Keys are emitted in sorted order so the output is deterministic —
    /// iterating the `HashMap` directly (as before) produced a different
    /// ordering on every run, which breaks diff-based monitoring of scrapes.
    pub fn to_prometheus(metrics: &HashMap<String, f64>) -> String {
        let mut names: Vec<&String> = metrics.keys().collect();
        names.sort();
        let mut output = String::new();
        for name in names {
            let sanitized_name = name.replace('.', "_").replace(' ', "_");
            output.push_str(&format!(
                "# TYPE odin_{} gauge\nodin_{} {}\n",
                sanitized_name, sanitized_name, metrics[name]
            ));
        }
        output
    }

    /// Renders samples as CSV with a header row.
    ///
    /// Operation names containing a comma, quote, or newline are quoted per
    /// RFC 4180 (embedded quotes doubled) — previously such names silently
    /// corrupted the row structure. Clean names are emitted unchanged.
    pub fn to_csv(samples: &[PerformanceSample]) -> String {
        let mut output = String::from("timestamp,operation,duration_ms,success\n");
        for sample in samples {
            let needs_quoting = sample
                .operation
                .contains(|c| c == ',' || c == '"' || c == '\n');
            let operation = if needs_quoting {
                format!("\"{}\"", sample.operation.replace('"', "\"\""))
            } else {
                sample.operation.clone()
            };
            output.push_str(&format!(
                "{},{},{},{}\n",
                sample.timestamp, operation, sample.duration_ms, sample.success
            ));
        }
        output
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration as TokioDuration};

    // Basic counters are reflected in the get_metrics snapshot.
    #[tokio::test]
    async fn test_metrics_collection() {
        let collector = MetricsCollector::new();
        collector.record_message_sent();
        collector.record_message_received();
        collector.record_error();
        let metrics = collector.get_metrics();
        assert_eq!(metrics.get("messages_sent"), Some(&1.0));
        assert_eq!(metrics.get("messages_received"), Some(&1.0));
        assert_eq!(metrics.get("errors"), Some(&1.0));
    }

    // Smoke test: setting a custom metric must not panic.
    // NOTE(review): no assertion — get_metrics does not expose custom metrics,
    // and there is no public getter for them visible in this file.
    #[tokio::test]
    async fn test_custom_metrics() {
        let collector = MetricsCollector::new();
        collector.set_custom_metric("custom_metric".to_string(), 42.0).await;
    }

    // One success + one failure yields a 50% success rate over two samples.
    #[tokio::test]
    async fn test_performance_samples() {
        let collector = MetricsCollector::new();
        collector.record_sample(
            "test_operation".to_string(),
            Duration::from_millis(100),
            true,
        ).await;
        collector.record_sample(
            "test_operation".to_string(),
            Duration::from_millis(200),
            false,
        ).await;
        let stats = collector.get_performance_stats().await;
        assert_eq!(stats.total_samples, 2);
        assert_eq!(stats.successful_samples, 1);
        assert_eq!(stats.failed_samples, 1);
        assert_eq!(stats.success_rate, 0.5);
    }

    // reset() zeroes the counters reported by get_metrics.
    #[tokio::test]
    async fn test_metrics_reset() {
        let collector = MetricsCollector::new();
        collector.record_message_sent();
        collector.record_message_received();
        let metrics_before = collector.get_metrics();
        assert_eq!(metrics_before.get("messages_sent"), Some(&1.0));
        collector.reset().await;
        let metrics_after = collector.get_metrics();
        assert_eq!(metrics_after.get("messages_sent"), Some(&0.0));
    }

    // Prometheus output carries the odin_ prefix; CSV carries header + row.
    #[test]
    fn test_metrics_exporter() {
        let mut metrics = HashMap::new();
        metrics.insert("messages_sent".to_string(), 100.0);
        metrics.insert("uptime_seconds".to_string(), 3600.0);
        let prometheus = MetricsExporter::to_prometheus(&metrics);
        assert!(prometheus.contains("odin_messages_sent"));
        assert!(prometheus.contains("odin_uptime_seconds"));
        let samples = vec![
            PerformanceSample {
                timestamp: 1234567890,
                operation: "test".to_string(),
                duration_ms: 100.0,
                success: true,
            }
        ];
        let csv = MetricsExporter::to_csv(&samples);
        assert!(csv.contains("timestamp,operation,duration_ms,success"));
        assert!(csv.contains("1234567890,test,100,true"));
    }

    // Derived metrics appear once uptime is positive; the sleep guarantees
    // a nonzero elapsed time before the snapshot is taken.
    #[tokio::test]
    async fn test_derived_metrics() {
        let collector = MetricsCollector::new();
        for _ in 0..10 {
            collector.record_message_sent();
        }
        for _ in 0..5 {
            collector.record_message_received();
        }
        collector.record_error();
        sleep(TokioDuration::from_millis(10)).await;
        let metrics = collector.get_metrics();
        assert!(metrics.contains_key("messages_per_second"));
        assert!(metrics.contains_key("error_rate"));
        assert!(metrics.get("uptime_seconds").unwrap() > &0.0);
        // error_rate = 1 / (10 + 5 + 1) = 0.0625, strictly inside (0, 1).
        let error_rate = metrics.get("error_rate").unwrap();
        assert!(*error_rate > 0.0 && *error_rate < 1.0);
    }
}