use crate::advanced_monitoring::{AdvancedMetrics, AdvancedMonitor};
use crate::{TorshDistributedError, TorshResult};
use std::collections::HashMap;
use std::fmt::Write as FmtWrite;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Configuration for the Prometheus metrics exporter.
#[derive(Debug, Clone)]
pub struct PrometheusConfig {
    /// TCP port the embedded HTTP server listens on (default 9090).
    pub port: u16,
    /// HTTP path that serves the metrics (default "/metrics").
    pub path: String,
    /// Prefix prepended to every exported metric name (default "torsh").
    pub namespace: String,
    /// Static labels attached to every sample in addition to `rank`.
    pub labels: HashMap<String, String>,
    /// Whether latency histograms are exported alongside the gauges.
    pub enable_histograms: bool,
    /// Upper bucket boundaries for the histograms (the exported
    /// histogram metric names end in `_ms`, so these are milliseconds).
    pub histogram_buckets: Vec<f64>,
}
impl Default for PrometheusConfig {
fn default() -> Self {
Self {
port: 9090,
path: "/metrics".to_string(),
namespace: "torsh".to_string(),
labels: HashMap::new(),
enable_histograms: true,
histogram_buckets: vec![
0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0,
],
}
}
}
impl PrometheusConfig {
    /// Returns a builder that starts from all-unset fields; anything
    /// left unset falls back to the `Default` values at `build()` time.
    pub fn builder() -> PrometheusConfigBuilder {
        PrometheusConfigBuilder::default()
    }
}
/// Builder for [`PrometheusConfig`]. Fields left as `None` fall back to
/// the corresponding [`PrometheusConfig::default`] value in `build`.
#[derive(Default)]
pub struct PrometheusConfigBuilder {
    port: Option<u16>,
    path: Option<String>,
    namespace: Option<String>,
    // Labels accumulate directly (no default set to fall back to).
    labels: HashMap<String, String>,
    enable_histograms: Option<bool>,
    histogram_buckets: Option<Vec<f64>>,
}
impl PrometheusConfigBuilder {
    /// Sets the HTTP server port.
    pub fn port(self, port: u16) -> Self {
        Self {
            port: Some(port),
            ..self
        }
    }

    /// Sets the HTTP path that serves the metrics.
    pub fn path(self, path: impl Into<String>) -> Self {
        Self {
            path: Some(path.into()),
            ..self
        }
    }

    /// Sets the prefix used for every exported metric name.
    pub fn namespace(self, namespace: impl Into<String>) -> Self {
        Self {
            namespace: Some(namespace.into()),
            ..self
        }
    }

    /// Adds one static label that is attached to every exported sample.
    pub fn label(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.labels.insert(key.into(), value.into());
        self
    }

    /// Enables or disables histogram export.
    pub fn enable_histograms(self, enable: bool) -> Self {
        Self {
            enable_histograms: Some(enable),
            ..self
        }
    }

    /// Overrides the histogram bucket boundaries.
    pub fn histogram_buckets(self, buckets: Vec<f64>) -> Self {
        Self {
            histogram_buckets: Some(buckets),
            ..self
        }
    }

    /// Consumes the builder; every unset field takes its value from
    /// `PrometheusConfig::default()`.
    pub fn build(self) -> PrometheusConfig {
        let defaults = PrometheusConfig::default();
        PrometheusConfig {
            port: self.port.unwrap_or(defaults.port),
            path: self.path.unwrap_or(defaults.path),
            namespace: self.namespace.unwrap_or(defaults.namespace),
            labels: self.labels,
            enable_histograms: self
                .enable_histograms
                .unwrap_or(defaults.enable_histograms),
            histogram_buckets: self
                .histogram_buckets
                .unwrap_or(defaults.histogram_buckets),
        }
    }
}
/// Serves metrics collected by an [`AdvancedMonitor`] over HTTP in the
/// Prometheus text exposition format.
pub struct PrometheusExporter {
    /// Source of the per-rank metric snapshots.
    monitor: Arc<AdvancedMonitor>,
    /// Exporter settings (port, path, namespace, labels, histograms).
    config: PrometheusConfig,
    /// Bucket counts, rebuilt from the latest snapshot on each scrape
    /// that includes histograms; shared with spawned handler tasks.
    histogram_data: Arc<RwLock<HistogramData>>,
}
/// Cumulative histogram bucket counts stored as (upper bound, count)
/// pairs mirroring the configured `histogram_buckets`.
#[derive(Default)]
struct HistogramData {
    compute_forward_buckets: Vec<(f64, u64)>,
    compute_backward_buckets: Vec<(f64, u64)>,
    // Reserved for communication-time histograms; not populated yet.
    #[allow(dead_code)]
    communication_allreduce_buckets: Vec<(f64, u64)>,
    #[allow(dead_code)]
    communication_broadcast_buckets: Vec<(f64, u64)>,
}
impl PrometheusExporter {
pub fn new(monitor: Arc<AdvancedMonitor>, config: PrometheusConfig) -> TorshResult<Self> {
Ok(Self {
monitor,
config,
histogram_data: Arc::new(RwLock::new(HistogramData::default())),
})
}
pub async fn start(&self) -> TorshResult<()> {
let port = self.config.port;
let path = self.config.path.clone();
let path_for_log = path.clone();
let exporter = self.clone_for_handler();
tokio::spawn(async move {
if let Err(e) = exporter.run_server(port, &path).await {
log::error!("Prometheus exporter server error: {}", e);
}
});
log::info!(
"Prometheus exporter started on port {} at {}",
port,
path_for_log
);
Ok(())
}
fn clone_for_handler(&self) -> Self {
Self {
monitor: Arc::clone(&self.monitor),
config: self.config.clone(),
histogram_data: Arc::clone(&self.histogram_data),
}
}
async fn run_server(&self, port: u16, path: &str) -> TorshResult<()> {
use std::net::SocketAddr;
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let path = path.to_string();
let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
TorshDistributedError::io_error(format!("Failed to bind to {}: {}", addr, e))
})?;
log::info!("Prometheus metrics available at http://{}{}", addr, path);
loop {
match listener.accept().await {
Ok((stream, _)) => {
let exporter = self.clone_for_handler();
let path = path.clone();
tokio::spawn(async move {
if let Err(e) = exporter.handle_connection(stream, &path).await {
log::warn!("Error handling metrics request: {}", e);
}
});
}
Err(e) => {
log::error!("Error accepting connection: {}", e);
}
}
}
}
async fn handle_connection(
&self,
mut stream: tokio::net::TcpStream,
expected_path: &str,
) -> TorshResult<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut buffer = [0u8; 1024];
let n = stream.read(&mut buffer).await.map_err(|e| {
TorshDistributedError::io_error(format!("Failed to read request: {}", e))
})?;
let request = String::from_utf8_lossy(&buffer[..n]);
if let Some(first_line) = request.lines().next() {
let parts: Vec<&str> = first_line.split_whitespace().collect();
if parts.len() >= 2 && parts[0] == "GET" {
let requested_path = parts[1];
if requested_path == expected_path {
let metrics = self.export_metrics().await?;
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/plain; version=0.0.4\r\nContent-Length: {}\r\n\r\n{}",
metrics.len(),
metrics
);
stream.write_all(response.as_bytes()).await.map_err(|e| {
TorshDistributedError::io_error(format!("Failed to write response: {}", e))
})?;
return Ok(());
}
}
}
let not_found = "HTTP/1.1 404 Not Found\r\nContent-Length: 9\r\n\r\nNot Found";
stream.write_all(not_found.as_bytes()).await.map_err(|e| {
TorshDistributedError::io_error(format!("Failed to write 404 response: {}", e))
})?;
Ok(())
}
pub async fn export_metrics(&self) -> TorshResult<String> {
let mut output = String::with_capacity(8192);
let metrics = self.monitor.get_latest_metrics().await?;
let namespace = &self.config.namespace;
let format_labels = |rank: u32| -> String {
let mut labels = vec![format!("rank=\"{}\"", rank)];
for (key, value) in &self.config.labels {
labels.push(format!("{}=\"{}\"", key, value));
}
labels.join(",")
};
writeln!(
output,
"# HELP {}_compute_forward_time_ms Forward pass computation time in milliseconds",
namespace
)
.expect("writeln to String should not fail");
writeln!(output, "# TYPE {}_compute_forward_time_ms gauge", namespace)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_compute_forward_time_ms{{{}}} {}",
namespace,
format_labels(*rank),
metric.compute.forward_time_ms
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_compute_backward_time_ms Backward pass computation time in milliseconds",
namespace
)
.expect("writeln to String should not fail");
writeln!(
output,
"# TYPE {}_compute_backward_time_ms gauge",
namespace
)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_compute_backward_time_ms{{{}}} {}",
namespace,
format_labels(*rank),
metric.compute.backward_time_ms
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_communication_allreduce_time_ms All-reduce operation time in milliseconds",
namespace
)
.expect("writeln to String should not fail");
writeln!(
output,
"# TYPE {}_communication_allreduce_time_ms gauge",
namespace
)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_communication_allreduce_time_ms{{{}}} {}",
namespace,
format_labels(*rank),
metric.communication.all_reduce_time_ms
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_communication_broadcast_time_ms Broadcast operation time in milliseconds",
namespace
)
.expect("writeln to String should not fail");
writeln!(
output,
"# TYPE {}_communication_broadcast_time_ms gauge",
namespace
)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_communication_broadcast_time_ms{{{}}} {}",
namespace,
format_labels(*rank),
metric.communication.broadcast_time_ms
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_memory_gpu_used_mb GPU memory used in megabytes",
namespace
)
.expect("writeln to String should not fail");
writeln!(output, "# TYPE {}_memory_gpu_used_mb gauge", namespace)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_memory_gpu_used_mb{{{}}} {}",
namespace,
format_labels(*rank),
metric.memory.gpu_memory_used_mb
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_memory_peak_mb Peak memory usage in megabytes",
namespace
)
.expect("writeln to String should not fail");
writeln!(output, "# TYPE {}_memory_peak_mb gauge", namespace)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_memory_peak_mb{{{}}} {}",
namespace,
format_labels(*rank),
metric.memory.peak_memory_mb
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_io_data_load_time_ms Data loading time in milliseconds",
namespace
)
.expect("writeln to String should not fail");
writeln!(output, "# TYPE {}_io_data_load_time_ms gauge", namespace)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_io_data_load_time_ms{{{}}} {}",
namespace,
format_labels(*rank),
metric.io.data_load_time_ms
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_io_disk_read_mbps Disk read throughput in MB/s",
namespace
)
.expect("writeln to String should not fail");
writeln!(output, "# TYPE {}_io_disk_read_mbps gauge", namespace)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_io_disk_read_mbps{{{}}} {}",
namespace,
format_labels(*rank),
metric.io.disk_read_mbps
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_io_disk_write_mbps Disk write throughput in MB/s",
namespace
)
.expect("writeln to String should not fail");
writeln!(output, "# TYPE {}_io_disk_write_mbps gauge", namespace)
.expect("writeln to String should not fail");
for (rank, metric) in &metrics {
writeln!(
output,
"{}_io_disk_write_mbps{{{}}} {}",
namespace,
format_labels(*rank),
metric.io.disk_write_mbps
)
.expect("writeln to String should not fail");
}
if self.config.enable_histograms {
self.export_histograms(&mut output, &metrics, namespace)
.await?;
}
Ok(output)
}
async fn export_histograms(
&self,
output: &mut String,
metrics: &HashMap<u32, AdvancedMetrics>,
namespace: &str,
) -> TorshResult<()> {
self.update_histogram_data(metrics).await;
let histogram_data = self.histogram_data.read().await;
let _buckets = &self.config.histogram_buckets;
writeln!(
output,
"# HELP {}_compute_forward_time_histogram_ms Forward pass time distribution",
namespace
)
.expect("writeln to String should not fail");
writeln!(
output,
"# TYPE {}_compute_forward_time_histogram_ms histogram",
namespace
)
.expect("writeln to String should not fail");
for (bucket, count) in &histogram_data.compute_forward_buckets {
writeln!(
output,
"{}_compute_forward_time_histogram_ms_bucket{{le=\"{}\"}} {}",
namespace, bucket, count
)
.expect("writeln to String should not fail");
}
writeln!(
output,
"# HELP {}_compute_backward_time_histogram_ms Backward pass time distribution",
namespace
)
.expect("writeln to String should not fail");
writeln!(
output,
"# TYPE {}_compute_backward_time_histogram_ms histogram",
namespace
)
.expect("writeln to String should not fail");
for (bucket, count) in &histogram_data.compute_backward_buckets {
writeln!(
output,
"{}_compute_backward_time_histogram_ms_bucket{{le=\"{}\"}} {}",
namespace, bucket, count
)
.expect("writeln to String should not fail");
}
Ok(())
}
async fn update_histogram_data(&self, metrics: &HashMap<u32, AdvancedMetrics>) {
let mut histogram_data = self.histogram_data.write().await;
let buckets = &self.config.histogram_buckets;
histogram_data.compute_forward_buckets.clear();
histogram_data.compute_backward_buckets.clear();
for &boundary in buckets {
histogram_data.compute_forward_buckets.push((boundary, 0));
histogram_data.compute_backward_buckets.push((boundary, 0));
}
for metric in metrics.values() {
for (boundary, count) in &mut histogram_data.compute_forward_buckets {
if metric.compute.forward_time_ms <= *boundary {
*count += 1;
}
}
for (boundary, count) in &mut histogram_data.compute_backward_buckets {
if metric.compute.backward_time_ms <= *boundary {
*count += 1;
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::advanced_monitoring::{
        AdvancedMetrics, CommunicationMetrics, ComputeMetrics, IoMetrics, MemoryMetrics,
    };
    use crate::backend::BackendType;
    use crate::init_process_group;

    /// Builds an `AdvancedMetrics` fixture from the varying compute and
    /// communication parts; memory and I/O sections use shared values.
    fn sample_metrics(
        compute: ComputeMetrics,
        communication: CommunicationMetrics,
    ) -> AdvancedMetrics {
        AdvancedMetrics {
            timestamp: std::time::Duration::from_secs(0),
            compute,
            communication,
            memory: MemoryMetrics {
                gpu_memory_used_mb: 512.0,
                gpu_memory_total_mb: 1024.0,
                cpu_memory_used_mb: 2048.0,
                memory_bandwidth_gbps: 10.0,
                num_allocations: 50,
                peak_memory_mb: 768.0,
            },
            io: IoMetrics {
                data_load_time_ms: 20.0,
                disk_read_mbps: 100.0,
                disk_write_mbps: 50.0,
                preprocessing_time_ms: 5.0,
            },
            custom: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_prometheus_config_builder() {
        let config = PrometheusConfig::builder()
            .port(9091)
            .path("/custom_metrics")
            .namespace("test")
            .label("env", "dev")
            .label("cluster", "test-cluster")
            .enable_histograms(false)
            .build();
        assert_eq!(config.port, 9091);
        assert_eq!(config.path, "/custom_metrics");
        assert_eq!(config.namespace, "test");
        assert_eq!(config.labels.get("env"), Some(&"dev".to_string()));
        assert_eq!(
            config.labels.get("cluster"),
            Some(&"test-cluster".to_string())
        );
        assert!(!config.enable_histograms);
    }

    #[tokio::test]
    async fn test_prometheus_exporter_creation() {
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .expect("process group initialization should succeed");
        let monitor = Arc::new(AdvancedMonitor::new(Arc::new(pg)));
        let config = PrometheusConfig::default();
        let exporter = PrometheusExporter::new(monitor, config);
        assert!(exporter.is_ok());
    }

    #[tokio::test]
    async fn test_metrics_export_format() {
        let pg = init_process_group(BackendType::Gloo, 0, 2, "127.0.0.1", 29500)
            .await
            .expect("process group initialization should succeed");
        let monitor = Arc::new(AdvancedMonitor::new(Arc::new(pg)));
        let test_metrics = sample_metrics(
            ComputeMetrics {
                forward_time_ms: 10.5,
                backward_time_ms: 15.2,
                optimizer_time_ms: 2.3,
                gpu_utilization: 85.0,
                cpu_utilization: 60.0,
                tensor_core_utilization: 75.0,
                gflops: 100.5,
            },
            CommunicationMetrics {
                all_reduce_time_ms: 8.7,
                broadcast_time_ms: 3.2,
                all_gather_time_ms: 1.5,
                bandwidth_mbps: 1024.0,
                comm_comp_ratio: 0.3,
                num_operations: 100,
                avg_message_size: 10240,
            },
        );
        monitor
            .record_metrics(test_metrics)
            .expect("recording metrics should succeed");
        let config = PrometheusConfig::builder()
            .namespace("test")
            .enable_histograms(false)
            .build();
        let exporter =
            PrometheusExporter::new(monitor, config).expect("exporter creation should succeed");
        let output = exporter
            .export_metrics()
            .await
            .expect("metrics export should succeed");
        assert!(output.contains("# HELP test_compute_forward_time_ms"));
        assert!(output.contains("# TYPE test_compute_forward_time_ms gauge"));
        assert!(output.contains("test_compute_forward_time_ms{rank=\"0\"} 10.5"));
        assert!(output.contains("test_compute_backward_time_ms{rank=\"0\"} 15.2"));
        assert!(output.contains("test_communication_allreduce_time_ms{rank=\"0\"} 8.7"));
        assert!(output.contains("test_memory_gpu_used_mb{rank=\"0\"} 512"));
    }

    #[tokio::test]
    async fn test_custom_labels() {
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .expect("process group initialization should succeed");
        let monitor = Arc::new(AdvancedMonitor::new(Arc::new(pg)));
        let test_metrics = sample_metrics(
            ComputeMetrics {
                forward_time_ms: 10.0,
                backward_time_ms: 15.0,
                optimizer_time_ms: 2.0,
                gpu_utilization: 85.0,
                cpu_utilization: 60.0,
                tensor_core_utilization: 75.0,
                gflops: 100.0,
            },
            CommunicationMetrics {
                all_reduce_time_ms: 8.0,
                broadcast_time_ms: 3.0,
                all_gather_time_ms: 1.0,
                bandwidth_mbps: 1024.0,
                comm_comp_ratio: 0.3,
                num_operations: 100,
                avg_message_size: 10240,
            },
        );
        monitor
            .record_metrics(test_metrics)
            .expect("recording metrics should succeed");
        let config = PrometheusConfig::builder()
            .label("environment", "production")
            .label("cluster", "gpu-cluster-1")
            .enable_histograms(false)
            .build();
        let exporter =
            PrometheusExporter::new(monitor, config).expect("exporter creation should succeed");
        let output = exporter
            .export_metrics()
            .await
            .expect("metrics export should succeed");
        assert!(output.contains("environment=\"production\""));
        assert!(output.contains("cluster=\"gpu-cluster-1\""));
    }
}