use crate::utils::error::{GatewayError, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
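/// Collects request, cache, and provider-health metrics, renders them in
/// Prometheus text exposition format, and optionally forwards them to
/// DataDog and an OpenTelemetry collector.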
pub struct MetricsCollector {
prometheus_metrics: Arc<RwLock<PrometheusMetrics>>,
datadog_client: Option<DataDogClient>,
otel_exporter: Option<OtelExporter>,
custom_metrics: Arc<RwLock<HashMap<String, MetricValue>>>,
}
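/// In-memory metric snapshot; the per-request maps are keyed "provider:model".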
#[derive(Debug, Default)]
pub struct PrometheusMetrics {
pub request_total: HashMap<String, u64>,
pub request_duration: HashMap<String, Vec<f64>>,
pub error_total: HashMap<String, u64>,
pub token_usage: HashMap<String, u64>,
pub cost_total: HashMap<String, f64>,
pub provider_health: HashMap<String, f64>,
pub cache_hits: u64,
pub cache_misses: u64,
pub active_connections: u64,
pub queue_size: HashMap<String, u64>,
}
pub struct DataDogClient {
api_key: String,
base_url: String,
client: reqwest::Client,
default_tags: Vec<String>,
}
pub struct OtelExporter {
endpoint: String,
headers: HashMap<String, String>,
client: reqwest::Client,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetricValue {
Counter(u64),
Gauge(f64),
Histogram(Vec<f64>),
Summary { sum: f64, count: u64 },
}
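/// Buffers structured log entries and fans them out to every configured
/// destination, flushing when the buffer reaches 100 entries or on a timer.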
pub struct LogAggregator {
destinations: Vec<LogDestination>,
buffer: Arc<RwLock<Vec<LogEntry>>>,
flush_interval: Duration,
}
#[derive(Debug, Clone)]
pub enum LogDestination {
Elasticsearch {
url: String,
index: String,
auth: Option<String>,
},
Splunk {
url: String,
token: String,
index: String,
},
CloudWatch {
region: String,
log_group: String,
log_stream: String,
},
GCPLogging {
project_id: String,
log_name: String,
},
DatadogLogs {
api_key: String,
site: String,
},
Webhook {
url: String,
headers: HashMap<String, String>,
},
}
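/// One structured log record; the optional fields are populated when the
/// entry describes a completed LLM request.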
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
pub timestamp: DateTime<Utc>,
pub level: LogLevel,
pub message: String,
pub request_id: Option<String>,
pub user_id: Option<String>,
pub provider: Option<String>,
pub model: Option<String>,
pub duration_ms: Option<u64>,
pub tokens: Option<TokenUsage>,
pub cost: Option<f64>,
pub error: Option<ErrorDetails>,
pub fields: HashMap<String, serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LogLevel {
Error,
Warn,
Info,
Debug,
Trace,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorDetails {
pub error_type: String,
pub error_message: String,
pub error_code: Option<String>,
pub stack_trace: Option<String>,
}
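/// Evaluates threshold rules against sampled metric values and tracks
/// per-rule firing state; notifications are meant to go out over `channels`.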
pub struct AlertManager {
channels: Vec<AlertChannel>,
rules: Vec<AlertRule>,
alert_states: Arc<RwLock<HashMap<String, AlertState>>>,
}
#[derive(Debug, Clone)]
pub enum AlertChannel {
Slack {
webhook_url: String,
channel: String,
username: String,
},
Email {
smtp_host: String,
smtp_port: u16,
username: String,
password: String,
from: String,
to: Vec<String>,
},
PagerDuty {
integration_key: String,
severity: String,
},
Discord {
webhook_url: String,
},
Teams {
webhook_url: String,
},
Webhook {
url: String,
headers: HashMap<String, String>,
},
}
#[derive(Debug, Clone)]
pub struct AlertRule {
pub id: String,
pub name: String,
pub metric: String,
pub condition: AlertCondition,
pub threshold: f64,
pub window: Duration,
pub severity: AlertSeverity,
pub channels: Vec<String>,
pub enabled: bool,
}
#[derive(Debug, Clone)]
pub enum AlertCondition {
GreaterThan,
LessThan,
Equal,
NotEqual,
GreaterThanOrEqual,
LessThanOrEqual,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlertSeverity {
Critical,
High,
Medium,
Low,
Info,
}
#[derive(Debug, Clone)]
pub struct AlertState {
pub firing: bool,
pub fired_at: Option<DateTime<Utc>>,
pub last_notification: Option<DateTime<Utc>>,
pub notification_count: u32,
}
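
// The alert types above carry no behavior in this module. The impl below is
// a minimal evaluation sketch under assumed semantics (one sampled value per
// metric, compared against each enabled rule); notification dispatch over
// `channels` and windowed aggregation over `rule.window` are left out.
impl AlertManager {
    pub async fn evaluate(&self, metric: &str, value: f64) {
        let mut states = self.alert_states.write().await;
        for rule in self.rules.iter().filter(|r| r.enabled && r.metric == metric) {
            let breached = match rule.condition {
                AlertCondition::GreaterThan => value > rule.threshold,
                AlertCondition::LessThan => value < rule.threshold,
                AlertCondition::Equal => (value - rule.threshold).abs() < f64::EPSILON,
                AlertCondition::NotEqual => (value - rule.threshold).abs() >= f64::EPSILON,
                AlertCondition::GreaterThanOrEqual => value >= rule.threshold,
                AlertCondition::LessThanOrEqual => value <= rule.threshold,
            };
            let state = states.entry(rule.id.clone()).or_insert(AlertState {
                firing: false,
                fired_at: None,
                last_notification: None,
                notification_count: 0,
            });
            if breached && !state.firing {
                state.firing = true;
                state.fired_at = Some(Utc::now());
                warn!("Alert '{}' fired: {} {:?} threshold {}", rule.name, value, rule.condition, rule.threshold);
            } else if !breached && state.firing {
                state.firing = false;
                info!("Alert '{}' resolved", rule.name);
            }
        }
    }
}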
pub struct PerformanceTracer {
traces: Arc<RwLock<HashMap<String, TraceSpan>>>,
exporters: Vec<TraceExporter>,
}
#[derive(Debug, Clone)]
pub struct TraceSpan {
pub span_id: String,
pub parent_id: Option<String>,
pub trace_id: String,
pub operation: String,
pub start_time: Instant,
pub end_time: Option<Instant>,
pub tags: HashMap<String, String>,
pub logs: Vec<SpanLog>,
}
#[derive(Debug, Clone)]
pub struct SpanLog {
pub timestamp: Instant,
pub message: String,
pub fields: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub enum TraceExporter {
Jaeger {
endpoint: String,
service_name: String,
},
Zipkin {
endpoint: String,
service_name: String,
},
OpenTelemetry {
endpoint: String,
headers: HashMap<String, String>,
},
DataDogAPM {
api_key: String,
service_name: String,
},
}
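
// PerformanceTracer likewise ships only data types here. This is a minimal
// sketch of the span lifecycle; export to the configured TraceExporter
// backends is intentionally stubbed, and the timestamp-based span id is a
// placeholder for proper random ids.
impl PerformanceTracer {
    pub async fn start_span(
        &self,
        trace_id: &str,
        operation: &str,
        parent_id: Option<String>,
    ) -> String {
        let span_id = format!("{}-{}", operation, Utc::now().timestamp_millis());
        let span = TraceSpan {
            span_id: span_id.clone(),
            parent_id,
            trace_id: trace_id.to_string(),
            operation: operation.to_string(),
            start_time: Instant::now(),
            end_time: None,
            tags: HashMap::new(),
            logs: Vec::new(),
        };
        self.traces.write().await.insert(span_id.clone(), span);
        span_id
    }

    pub async fn end_span(&self, span_id: &str) {
        let mut traces = self.traces.write().await;
        if let Some(span) = traces.get_mut(span_id) {
            let end = Instant::now();
            span.end_time = Some(end);
            debug!(
                "Span '{}' finished in {:?}; {} exporter(s) configured",
                span.operation,
                end - span.start_time,
                self.exporters.len()
            );
        }
    }
}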
impl MetricsCollector {
pub fn new() -> Self {
Self {
prometheus_metrics: Arc::new(RwLock::new(PrometheusMetrics::default())),
datadog_client: None,
otel_exporter: None,
custom_metrics: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn with_datadog(mut self, api_key: String, site: String) -> Self {
self.datadog_client = Some(DataDogClient {
api_key,
base_url: format!("https://api.{}", site),
client: reqwest::Client::new(),
default_tags: vec![
"service:litellm-gateway".to_string(),
"env:production".to_string(),
],
});
self
}
pub fn with_otel(mut self, endpoint: String, headers: HashMap<String, String>) -> Self {
self.otel_exporter = Some(OtelExporter {
endpoint,
headers,
client: reqwest::Client::new(),
});
self
}
pub async fn record_request(
&self,
provider: &str,
model: &str,
duration: Duration,
tokens: Option<TokenUsage>,
cost: Option<f64>,
success: bool,
) {
let mut metrics = self.prometheus_metrics.write().await;
let key = format!("{}:{}", provider, model);
*metrics.request_total.entry(key.clone()).or_insert(0) += 1;
metrics.request_duration.entry(key.clone())
.or_insert_with(Vec::new)
.push(duration.as_secs_f64());
if !success {
*metrics.error_total.entry(key.clone()).or_insert(0) += 1;
}
if let Some(token_usage) = tokens {
*metrics.token_usage.entry(format!("{}:prompt", key)).or_insert(0) += token_usage.prompt_tokens as u64;
*metrics.token_usage.entry(format!("{}:completion", key)).or_insert(0) += token_usage.completion_tokens as u64;
}
if let Some(request_cost) = cost {
*metrics.cost_total.entry(key).or_insert(0.0) += request_cost;
}
}
pub async fn record_cache_hit(&self, hit: bool) {
let mut metrics = self.prometheus_metrics.write().await;
if hit {
metrics.cache_hits += 1;
} else {
metrics.cache_misses += 1;
}
}
pub async fn update_provider_health(&self, provider: &str, health_score: f64) {
let mut metrics = self.prometheus_metrics.write().await;
metrics.provider_health.insert(provider.to_string(), health_score);
}
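
    /// Stores or overwrites a named custom sample. A minimal helper sketch
    /// (not a confirmed API) so `custom_metrics` is reachable by callers.
    pub async fn record_custom_metric(&self, name: &str, value: MetricValue) {
        self.custom_metrics.write().await.insert(name.to_string(), value);
    }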
pub async fn export_prometheus(&self) -> String {
let metrics = self.prometheus_metrics.read().await;
let mut output = String::new();
output.push_str("# HELP litellm_requests_total Total number of requests\n");
output.push_str("# TYPE litellm_requests_total counter\n");
for (key, value) in &metrics.request_total {
            // splitn keeps any ':' inside the model name attached to parts[1]
            let parts: Vec<&str> = key.splitn(2, ':').collect();
if parts.len() == 2 {
output.push_str(&format!(
"litellm_requests_total{{provider=\"{}\",model=\"{}\"}} {}\n",
parts[0], parts[1], value
));
}
}
output.push_str("# HELP litellm_errors_total Total number of errors\n");
output.push_str("# TYPE litellm_errors_total counter\n");
for (key, value) in &metrics.error_total {
            let parts: Vec<&str> = key.splitn(2, ':').collect();
if parts.len() == 2 {
output.push_str(&format!(
"litellm_errors_total{{provider=\"{}\",model=\"{}\"}} {}\n",
parts[0], parts[1], value
));
}
}
output.push_str("# HELP litellm_cache_hits_total Total cache hits\n");
output.push_str("# TYPE litellm_cache_hits_total counter\n");
output.push_str(&format!("litellm_cache_hits_total {}\n", metrics.cache_hits));
output.push_str("# HELP litellm_cache_misses_total Total cache misses\n");
output.push_str("# TYPE litellm_cache_misses_total counter\n");
output.push_str(&format!("litellm_cache_misses_total {}\n", metrics.cache_misses));
output.push_str("# HELP litellm_provider_health Provider health score\n");
output.push_str("# TYPE litellm_provider_health gauge\n");
for (provider, health) in &metrics.provider_health {
output.push_str(&format!(
"litellm_provider_health{{provider=\"{}\"}} {}\n",
provider, health
));
}
        // request_duration, token_usage, cost_total, and queue_size are
        // collected but not exported in this text format yet.
        output
}
    pub async fn send_to_datadog(&self) -> Result<()> {
        if let Some(dd) = &self.datadog_client {
            let metrics = self.prometheus_metrics.read().await;
            // Stub: a full implementation would serialize this snapshot into
            // DataDog's series payload and POST it to dd.base_url with the
            // DD-API-KEY header.
            debug!(
                "Would send {} request series to DataDog at {}",
                metrics.request_total.len(),
                dd.base_url
            );
        }
        Ok(())
    }
}
impl LogAggregator {
pub fn new() -> Self {
Self {
destinations: vec![],
buffer: Arc::new(RwLock::new(Vec::new())),
flush_interval: Duration::from_secs(10),
}
}
pub fn add_destination(mut self, destination: LogDestination) -> Self {
self.destinations.push(destination);
self
}
    pub async fn log(&self, entry: LogEntry) {
        // Take the write guard in a short scope: flush_buffer re-acquires the
        // lock, and tokio's RwLock is not reentrant, so flushing while still
        // holding the guard would deadlock.
        let should_flush = {
            let mut buffer = self.buffer.write().await;
            buffer.push(entry);
            buffer.len() >= 100
        };
        if should_flush {
            self.flush_buffer().await;
        }
    }
async fn flush_buffer(&self) {
let mut buffer = self.buffer.write().await;
if buffer.is_empty() {
return;
}
let entries = buffer.drain(..).collect::<Vec<_>>();
drop(buffer);
for destination in &self.destinations {
if let Err(e) = self.send_to_destination(destination, &entries).await {
error!("Failed to send logs to destination: {}", e);
}
}
}
async fn send_to_destination(
&self,
destination: &LogDestination,
entries: &[LogEntry],
) -> Result<()> {
match destination {
            // The non-webhook destinations are stubs for now; credentials are
            // deliberately not bound here so they never reach the logs.
            LogDestination::Elasticsearch { url, index, .. } => {
                debug!("Sending {} logs to Elasticsearch index '{}' at {}", entries.len(), index, url);
            }
            LogDestination::Splunk { url, index, .. } => {
                debug!("Sending {} logs to Splunk index '{}' at {}", entries.len(), index, url);
            }
            LogDestination::DatadogLogs { site, .. } => {
                debug!("Sending {} logs to Datadog site {}", entries.len(), site);
            }
LogDestination::Webhook { url, headers } => {
let client = reqwest::Client::new();
let mut request = client.post(url).json(entries);
for (key, value) in headers {
request = request.header(key, value);
}
                request.send().await
                    .map_err(|e| GatewayError::Network(e.to_string()))?
                    .error_for_status()
                    .map_err(|e| GatewayError::Network(e.to_string()))?;
}
_ => {
debug!("Sending {} logs to destination", entries.len());
}
}
Ok(())
}
pub async fn start_background_flush(&self) {
let aggregator = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(aggregator.flush_interval);
loop {
interval.tick().await;
aggregator.flush_buffer().await;
}
});
}
}
// Clones share the buffer Arc, so the background task spawned by
// start_background_flush drains the same queue that log() fills.
impl Clone for LogAggregator {
fn clone(&self) -> Self {
Self {
destinations: self.destinations.clone(),
buffer: self.buffer.clone(),
flush_interval: self.flush_interval,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_metrics_collection() {
let collector = MetricsCollector::new();
collector.record_request(
"openai",
"gpt-4",
Duration::from_millis(500),
Some(TokenUsage {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
}),
Some(0.01),
true,
).await;
let prometheus_output = collector.export_prometheus().await;
assert!(prometheus_output.contains("litellm_requests_total"));
assert!(prometheus_output.contains("provider=\"openai\""));
assert!(prometheus_output.contains("model=\"gpt-4\""));
}
#[tokio::test]
async fn test_log_aggregation() {
let aggregator = LogAggregator::new();
let entry = LogEntry {
timestamp: Utc::now(),
level: LogLevel::Info,
message: "Test log entry".to_string(),
request_id: Some("req-123".to_string()),
user_id: Some("user-456".to_string()),
provider: Some("openai".to_string()),
model: Some("gpt-4".to_string()),
duration_ms: Some(500),
tokens: None,
cost: None,
error: None,
fields: HashMap::new(),
};
aggregator.log(entry).await;
let buffer = aggregator.buffer.read().await;
assert_eq!(buffer.len(), 1);
assert_eq!(buffer[0].message, "Test log entry");
}
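
    // Exercises the AlertManager evaluation sketch above; the rule and the
    // sampled value are arbitrary test fixtures.
    #[tokio::test]
    async fn test_alert_evaluation() {
        let manager = AlertManager {
            channels: vec![],
            rules: vec![AlertRule {
                id: "rule-1".to_string(),
                name: "High error rate".to_string(),
                metric: "error_rate".to_string(),
                condition: AlertCondition::GreaterThan,
                threshold: 0.05,
                window: Duration::from_secs(60),
                severity: AlertSeverity::High,
                channels: vec![],
                enabled: true,
            }],
            alert_states: Arc::new(RwLock::new(HashMap::new())),
        };
        manager.evaluate("error_rate", 0.10).await;
        let states = manager.alert_states.read().await;
        assert!(states.get("rule-1").map(|s| s.firing).unwrap_or(false));
    }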
}