use std::sync::OnceLock;
use std::time::Duration;
/// Process-wide slot holding the installed [`MetricsRecorder`].
///
/// Being a `OnceLock`, it can be written at most once: [`register_metrics`]
/// fills it and [`get_metrics`] reads it. It stays empty until a recorder has
/// been registered.
static METRICS: OnceLock<Box<dyn MetricsRecorder>> = OnceLock::new();
/// Sink that receives the metric observations emitted by this crate.
///
/// Every method has a default implementation, so an implementor overrides only
/// the metric kinds it cares about; calls to the others are ignored. The
/// `Send + Sync + 'static` bounds allow a single boxed instance to live in the
/// process-wide registry.
pub trait MetricsRecorder: Send + Sync + 'static {
    /// Hook for counter increments. The default implementation does nothing.
    fn increment_counter(&self, _name: &str, _value: u64, _labels: &[(&str, &str)]) {}

    /// Hook for gauge updates. The default implementation does nothing.
    fn set_gauge(&self, _name: &str, _value: f64, _labels: &[(&str, &str)]) {}

    /// Hook for histogram observations. The default implementation does nothing.
    fn record_histogram(&self, _name: &str, _value: f64, _labels: &[(&str, &str)]) {}

    /// Hook for elapsed-time observations.
    ///
    /// By default the duration is converted to fractional seconds and handed to
    /// [`MetricsRecorder::record_histogram`] under the same name and labels.
    fn record_timing(&self, name: &str, duration: Duration, labels: &[(&str, &str)]) {
        let seconds = duration.as_secs_f64();
        self.record_histogram(name, seconds, labels)
    }
}
/// Installs `recorder` as the process-wide metrics sink.
///
/// Only the first call has any effect. A recorder passed to a later call is
/// dropped without an error, leaving the originally installed one in place.
pub fn register_metrics(recorder: Box<dyn MetricsRecorder>) {
    // `OnceLock::set` hands the value back in `Err` when the slot is already
    // filled; discard it deliberately.
    if let Err(rejected) = METRICS.set(recorder) {
        drop(rejected);
    }
}
/// Returns the registered recorder, or `None` if no recorder has been
/// installed yet.
pub fn get_metrics() -> Option<&'static dyn MetricsRecorder> {
    let boxed: &'static Box<dyn MetricsRecorder> = METRICS.get()?;
    // Reborrow through the `Box` to expose the trait object itself.
    Some(&**boxed)
}
/// Metric name constants.
///
/// Every name carries the `ssm_` prefix. Following the Prometheus naming
/// convention, names ending in `_total` denote counters, and names ending in a
/// unit (`_seconds`, `_percent`) state the unit of the recorded value.
pub mod names {
    // Session lifecycle.
    pub const SESSIONS_STARTED: &str = "ssm_sessions_started_total";
    pub const SESSIONS_TERMINATED: &str = "ssm_sessions_terminated_total";
    pub const SESSION_ERRORS: &str = "ssm_session_errors_total";
    pub const ACTIVE_SESSIONS: &str = "ssm_active_sessions";

    // Message and byte traffic.
    pub const MESSAGES_SENT: &str = "ssm_messages_sent_total";
    pub const MESSAGES_RECEIVED: &str = "ssm_messages_received_total";
    pub const BYTES_SENT: &str = "ssm_bytes_sent_total";
    pub const BYTES_RECEIVED: &str = "ssm_bytes_received_total";

    // Delivery reliability and round-trip latency.
    pub const RETRANSMISSIONS: &str = "ssm_message_retransmissions_total";
    pub const ACKS_RECEIVED: &str = "ssm_acks_received_total";
    pub const RTT_SECONDS: &str = "ssm_rtt_seconds";

    // Connection lifecycle and health.
    pub const CONNECTION_DURATION: &str = "ssm_connection_duration_seconds";
    pub const HANDSHAKE_DURATION: &str = "ssm_handshake_duration_seconds";
    pub const RECONNECTIONS: &str = "ssm_websocket_reconnections_total";
    pub const CONNECTION_HEALTH: &str = "ssm_connection_health";
    pub const PACKET_LOSS_PERCENT: &str = "ssm_packet_loss_percent";
    pub const RTT_JITTER_SECONDS: &str = "ssm_rtt_jitter_seconds";
}
/// Forwards a counter increment to the registered recorder, if any.
///
/// Does nothing while no recorder is installed.
#[inline]
// The allow keeps this crate-internal helper warning-free when no caller uses it.
#[allow(dead_code)]
pub(crate) fn counter(name: &str, value: u64, labels: &[(&str, &str)]) {
    let Some(recorder) = get_metrics() else {
        return;
    };
    recorder.increment_counter(name, value, labels);
}
/// Forwards a gauge update to the registered recorder, if any.
///
/// Does nothing while no recorder is installed.
#[inline]
// The allow keeps this crate-internal helper warning-free when no caller uses it.
#[allow(dead_code)]
pub(crate) fn gauge(name: &str, value: f64, labels: &[(&str, &str)]) {
    let Some(recorder) = get_metrics() else {
        return;
    };
    recorder.set_gauge(name, value, labels);
}
/// Forwards a histogram observation to the registered recorder, if any.
///
/// Does nothing while no recorder is installed.
#[inline]
// The allow keeps this crate-internal helper warning-free when no caller uses it.
#[allow(dead_code)]
pub(crate) fn histogram(name: &str, value: f64, labels: &[(&str, &str)]) {
    let Some(recorder) = get_metrics() else {
        return;
    };
    recorder.record_histogram(name, value, labels);
}
/// Forwards an elapsed duration to the registered recorder's
/// [`MetricsRecorder::record_timing`], if a recorder is installed.
///
/// Does nothing while no recorder is installed.
#[inline]
// The allow keeps this crate-internal helper warning-free when no caller uses it.
#[allow(dead_code)]
pub(crate) fn timing(name: &str, duration: Duration, labels: &[(&str, &str)]) {
    let Some(recorder) = get_metrics() else {
        return;
    };
    recorder.record_timing(name, duration, labels);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Recorder that sums the `value` of every counter increment it receives.
    struct TestRecorder {
        counter_calls: Arc<AtomicU64>,
    }

    impl MetricsRecorder for TestRecorder {
        fn increment_counter(&self, _name: &str, value: u64, _labels: &[(&str, &str)]) {
            self.counter_calls.fetch_add(value, Ordering::Relaxed);
        }
    }

    /// Recorder that logs histogram observations. It is used directly and never
    /// registered globally, so it cannot race with other tests.
    #[derive(Default)]
    struct HistogramLog {
        observations: Mutex<Vec<(String, f64)>>,
    }

    impl MetricsRecorder for HistogramLog {
        fn record_histogram(&self, name: &str, value: f64, _labels: &[(&str, &str)]) {
            self.observations
                .lock()
                .unwrap()
                .push((name.to_owned(), value));
        }
    }

    #[test]
    fn test_metrics_no_recorder() {
        // All tests share one process-wide recorder slot and run in parallel, so
        // a recorder may already be installed here. This only checks that every
        // helper is callable without panicking.
        counter("test", 1, &[]);
        gauge("test", 1.0, &[]);
        histogram("test", 1.0, &[]);
        timing("test", Duration::from_millis(1), &[]);
    }

    #[test]
    fn test_metrics_with_recorder() {
        let counter_calls = Arc::new(AtomicU64::new(0));
        let recorder = TestRecorder {
            counter_calls: counter_calls.clone(),
        };
        // NOTE: registration is first-wins, so this assumes no other test in the
        // binary installs a different recorder before this one.
        register_metrics(Box::new(recorder));
        counter("test_metric", 5, &[("label", "value")]);
        let calls = counter_calls.load(Ordering::Relaxed);
        // `>=` rather than `==`: a parallel test may also call `counter`.
        assert!(
            calls >= 5,
            "Expected at least 5 counter increments, got {}",
            calls
        );
    }

    #[test]
    fn test_default_record_timing_uses_seconds() {
        let log = HistogramLog::default();
        log.record_timing(names::RTT_SECONDS, Duration::from_millis(250), &[]);
        let observations = log.observations.lock().unwrap();
        assert_eq!(
            observations.as_slice(),
            &[(names::RTT_SECONDS.to_owned(), 0.25)]
        );
    }

    #[test]
    fn test_metric_names() {
        let counters = [
            names::SESSIONS_STARTED,
            names::SESSIONS_TERMINATED,
            names::SESSION_ERRORS,
            names::MESSAGES_SENT,
            names::MESSAGES_RECEIVED,
            names::BYTES_SENT,
            names::BYTES_RECEIVED,
            names::RETRANSMISSIONS,
            names::ACKS_RECEIVED,
            names::RECONNECTIONS,
        ];
        let non_counters = [
            names::ACTIVE_SESSIONS,
            names::RTT_SECONDS,
            names::CONNECTION_DURATION,
            names::HANDSHAKE_DURATION,
            names::CONNECTION_HEALTH,
            names::PACKET_LOSS_PERCENT,
            names::RTT_JITTER_SECONDS,
        ];
        let durations = [
            names::RTT_SECONDS,
            names::CONNECTION_DURATION,
            names::HANDSHAKE_DURATION,
            names::RTT_JITTER_SECONDS,
        ];

        let all: Vec<&str> = counters.iter().chain(non_counters.iter()).copied().collect();
        for name in &all {
            assert!(name.starts_with("ssm_"), "missing ssm_ prefix: {}", name);
        }
        let unique: HashSet<&str> = all.iter().copied().collect();
        assert_eq!(unique.len(), all.len(), "duplicate metric name");

        for name in &counters {
            assert!(name.ends_with("_total"), "counter without _total: {}", name);
        }
        for name in &non_counters {
            assert!(!name.ends_with("_total"), "non-counter with _total: {}", name);
        }
        for name in &durations {
            assert!(name.ends_with("_seconds"), "duration without _seconds: {}", name);
        }
    }
}