use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use crate::error::ServerlessError;
/// Process-wide counters describing function invocations.
///
/// Both counters are plain atomics, so a shared `&MetricsCollector` can be
/// updated from several threads without extra locking.
#[derive(Debug)]
pub struct MetricsCollector {
    /// Number of invocations recorded so far.
    invocation_count: AtomicU64,
    /// Sum of all recorded invocation durations, in milliseconds.
    total_duration_ms: AtomicU64,
}
impl MetricsCollector {
    /// Creates a collector whose invocation count and total duration are zero.
    #[must_use]
    pub fn new() -> Self {
        Self {
            total_duration_ms: AtomicU64::new(0),
            invocation_count: AtomicU64::new(0),
        }
    }

    /// Records one invocation of `function_name` lasting `duration_ms`
    /// milliseconds and logs it at info level.
    ///
    /// This is a no-op unless the `telemetry` feature is enabled; with the
    /// feature off the counters never move. `fetch_add` wraps on overflow, so
    /// the running total is not saturated.
    pub fn record_invocation(&self, function_name: &str, duration_ms: u64, success: bool) {
        if !cfg!(feature = "telemetry") {
            return;
        }

        // Bump the call count before adding the duration to the running total.
        self.invocation_count.fetch_add(1, Ordering::SeqCst);
        self.total_duration_ms.fetch_add(duration_ms, Ordering::SeqCst);

        tracing::info!(
            "Function invocation: {} - Success: {}, Duration: {}ms",
            function_name,
            success,
            duration_ms
        );
    }

    /// Mean recorded duration in milliseconds, or `0.0` when nothing has
    /// been recorded yet.
    ///
    /// The count and the total are loaded one after the other rather than as
    /// a single snapshot, so while other threads are recording, the result may
    /// combine values from slightly different moments.
    #[must_use]
    pub fn avg_duration_ms(&self) -> f64 {
        match self.invocation_count.load(Ordering::SeqCst) {
            0 => 0.0,
            calls => {
                let summed_ms = self.total_duration_ms.load(Ordering::SeqCst);
                summed_ms as f64 / calls as f64
            }
        }
    }

    /// Number of invocations recorded so far.
    #[must_use]
    pub fn invocation_count(&self) -> u64 {
        let calls = self.invocation_count.load(Ordering::SeqCst);
        calls
    }
}
impl Default for MetricsCollector {
fn default() -> Self {
Self::new()
}
}
/// Per-invocation telemetry state: when the invocation started, which
/// function it belongs to, and the trace id attached to its log records.
#[derive(Debug)]
pub struct TelemetryContext {
    /// Moment the context was created; elapsed time is measured from here.
    start_time: Instant,
    /// Name of the function being measured.
    function_name: String,
    /// Identifier emitted with every log record for this invocation.
    trace_id: String,
}
impl TelemetryContext {
    /// Starts timing an invocation of `function_name`.
    ///
    /// The clock starts now, and the context is tagged with a freshly
    /// generated v4 UUID (rendered as a string) used as its trace id.
    #[must_use]
    pub fn new(function_name: &str) -> Self {
        Self {
            start_time: Instant::now(),
            function_name: function_name.to_string(),
            trace_id: uuid::Uuid::new_v4().to_string(),
        }
    }

    /// Logs the outcome of this invocation and adds it to the global metrics.
    ///
    /// Does nothing unless the `telemetry` feature is enabled. Otherwise the
    /// time elapsed since [`TelemetryContext::new`] is logged together with the
    /// trace id and function name. That duration is then passed, with
    /// `success`, to `record_invocation` on the collector returned by
    /// `get_metrics_collector()`.
    ///
    /// * `success` - whether the invocation succeeded.
    /// * `error` - the error that caused a failure, when one is available.
    ///
    /// The call is logged as a failure at error level whenever `error` is
    /// `Some`. It is also logged as a failure when `success` is `false` but no
    /// error is supplied, so a failed invocation is never reported as
    /// "completed".
    pub fn record_completion(&self, success: bool, error: Option<&ServerlessError>) {
        if !cfg!(feature = "telemetry") {
            return;
        }

        // `as_millis()` returns u128; an `as u64` cast would silently truncate.
        // Build the millisecond count from whole seconds plus sub-second millis
        // with saturating arithmetic instead.
        let elapsed = self.start_time.elapsed();
        let duration_ms = elapsed
            .as_secs()
            .saturating_mul(1000)
            .saturating_add(u64::from(elapsed.subsec_millis()));

        match error {
            Some(err) => {
                tracing::error!(
                    trace_id = %self.trace_id,
                    function = %self.function_name,
                    duration_ms = duration_ms,
                    error = %err,
                    "Function execution failed"
                );
            }
            // Failure reported without a concrete error value.
            None if !success => {
                tracing::error!(
                    trace_id = %self.trace_id,
                    function = %self.function_name,
                    duration_ms = duration_ms,
                    "Function execution failed"
                );
            }
            None => {
                tracing::info!(
                    trace_id = %self.trace_id,
                    function = %self.function_name,
                    duration_ms = duration_ms,
                    "Function execution completed"
                );
            }
        }

        get_metrics_collector().record_invocation(&self.function_name, duration_ms, success);
    }
}
/// Lazily created, process-wide metrics collector backing [`get_metrics_collector`].
static METRICS_COLLECTOR: OnceLock<MetricsCollector> = OnceLock::new();

/// Returns the shared [`MetricsCollector`], creating it on the first call.
///
/// Every call returns a reference to the same `'static` instance.
#[must_use]
pub fn get_metrics_collector() -> &'static MetricsCollector {
    METRICS_COLLECTOR.get_or_init(MetricsCollector::default)
}
/// Announces that telemetry is active by emitting a single info-level log line.
///
/// Does nothing unless the `telemetry` feature is enabled. This function only
/// logs; it does not install or configure a `tracing` subscriber.
pub fn init_telemetry() {
    if !cfg!(feature = "telemetry") {
        return;
    }
    tracing::info!("Telemetry system initialized");
}
/// Awaits `operation` and, when telemetry is enabled, records its outcome.
///
/// With the `telemetry` feature disabled, `operation` is awaited directly with
/// no timing or logging. Otherwise a [`TelemetryContext`] is created before the
/// first poll, so the measured duration covers the whole await. The result is
/// then reported through `record_completion`.
///
/// # Errors
///
/// Returns the error produced by `operation`, unchanged.
pub async fn telemetry_middleware<F, R>(
    function_name: &str,
    operation: F,
) -> Result<R, ServerlessError>
where
    F: std::future::Future<Output = Result<R, ServerlessError>>,
{
    if !cfg!(feature = "telemetry") {
        return operation.await;
    }

    let context = TelemetryContext::new(function_name);
    let outcome = operation.await;
    match &outcome {
        Ok(_) => context.record_completion(true, None),
        Err(failure) => context.record_completion(false, Some(failure)),
    }
    outcome
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built collector reports no invocations and a zero average.
    #[test]
    fn test_metrics_collector_initial_state() {
        let fresh = MetricsCollector::new();
        assert_eq!(fresh.invocation_count(), 0);
        assert_eq!(fresh.avg_duration_ms(), 0.0);
    }

    /// Recording only takes effect when the `telemetry` feature is enabled.
    #[test]
    fn test_metrics_collector_records_invocation() {
        let collector = MetricsCollector::new();
        for duration in [100, 200] {
            collector.record_invocation("test_func", duration, true);
        }

        let observed = (collector.invocation_count(), collector.avg_duration_ms());
        let expected = if cfg!(feature = "telemetry") {
            (2, 150.0)
        } else {
            (0, 0.0)
        };
        assert_eq!(observed.0, expected.0);
        assert_eq!(observed.1, expected.1);
    }
}