use std::path::{Path, PathBuf};
use serde_json::Value;
use super::ddp_metrics_classifier::{
classify_allreduce_bandwidth, classify_loss_parity, classify_scaling_efficiency,
DdpAllreduceOutcome, DdpLossParityOutcome, DdpScalingOutcome, D11_DEFAULT_LOSS_TOLERANCE,
D11_DEFAULT_SCALING_FLOOR,
};
use crate::error::{CliError, Result};
/// Entry point for the `ddp-metrics-lint` check.
///
/// Parses both metrics files as JSON, runs the scaling-efficiency, loss-parity
/// and allreduce-bandwidth classifiers, prints the report, and only then
/// evaluates the gates, so the report is emitted even when a gate rejects.
///
/// # Errors
///
/// Propagates any error from `load_json` for either file (the 1-GPU file is
/// loaded first). Returns `CliError::ValidationFailed` for the first gate, in
/// the order scaling-efficiency, loss-parity, allreduce-bandwidth, whose
/// outcome is not the `Ok` variant.
pub(crate) fn run(
    metrics_1gpu_file: &Path,
    metrics_ngpu_file: &Path,
    world_size: i64,
    scaling_floor: f64,
    loss_tolerance: f64,
    json: bool,
) -> Result<()> {
    let single_gpu = load_json(metrics_1gpu_file)?;
    let multi_gpu = load_json(metrics_ngpu_file)?;

    let scaling = classify_scaling_efficiency(&single_gpu, &multi_gpu, world_size, scaling_floor);
    let parity = classify_loss_parity(&single_gpu, &multi_gpu, loss_tolerance);
    let allreduce = classify_allreduce_bandwidth(&multi_gpu);

    // The report is printed unconditionally, before any gate can fail.
    print_report(
        metrics_1gpu_file,
        metrics_ngpu_file,
        &scaling,
        &parity,
        &allreduce,
        json,
    );

    // Only the first failing gate is reported; later gates are not mentioned
    // in the returned error.
    let rejection = if !matches!(scaling, DdpScalingOutcome::Ok { .. }) {
        Some(format!(
            "ddp-metrics-lint scaling-efficiency gate rejected: {scaling:?}"
        ))
    } else if !matches!(parity, DdpLossParityOutcome::Ok { .. }) {
        Some(format!(
            "ddp-metrics-lint loss-parity gate rejected: {parity:?}"
        ))
    } else if !matches!(allreduce, DdpAllreduceOutcome::Ok { .. }) {
        Some(format!(
            "ddp-metrics-lint allreduce-bandwidth gate rejected: {allreduce:?}"
        ))
    } else {
        None
    };

    match rejection {
        Some(message) => Err(CliError::ValidationFailed(message)),
        None => Ok(()),
    }
}
/// Reads `path` and parses its contents as a JSON value.
///
/// # Errors
///
/// * `CliError::FileNotFound` when `path.exists()` reports `false`.
/// * The converted I/O error when reading the file fails.
/// * `CliError::InvalidFormat` (naming the path and the parser error) when the
///   contents are not valid JSON.
fn load_json(path: &Path) -> Result<Value> {
    if path.exists() {
        let raw = std::fs::read_to_string(path)?;
        match serde_json::from_str(&raw) {
            Ok(value) => Ok(value),
            Err(e) => Err(CliError::InvalidFormat(format!(
                "apr ddp-metrics-lint: failed to parse JSON from {}: {e}",
                path.display()
            ))),
        }
    } else {
        Err(CliError::FileNotFound(PathBuf::from(path)))
    }
}
/// Writes the lint report to stdout.
///
/// With `json` set, emits a pretty-printed JSON object holding both file paths
/// and the `Debug` rendering of each outcome; otherwise emits the same fields
/// as a plain-text listing. If JSON serialization fails, an empty line is
/// printed instead.
// NOTE(review): this function takes six parameters, which is under clippy's
// documented default threshold for `too_many_arguments`; the allow may be
// unnecessary — confirm against the project's clippy configuration.
#[allow(clippy::too_many_arguments)]
fn print_report(
    file1: &Path,
    file_n: &Path,
    scaling: &DdpScalingOutcome,
    parity: &DdpLossParityOutcome,
    allreduce: &DdpAllreduceOutcome,
    json: bool,
) {
    if json {
        let scaling_text = format!("{scaling:?}");
        let parity_text = format!("{parity:?}");
        let allreduce_text = format!("{allreduce:?}");
        let report = serde_json::json!({
            "metrics_1gpu_file": file1.display().to_string(),
            "metrics_ngpu_file": file_n.display().to_string(),
            "scaling_efficiency": scaling_text,
            "loss_parity": parity_text,
            "allreduce_bandwidth": allreduce_text,
        });
        let rendered = serde_json::to_string_pretty(&report).unwrap_or_default();
        println!("{rendered}");
    } else {
        println!("ddp-metrics-lint report");
        println!(" metrics_1gpu_file : {}", file1.display());
        println!(" metrics_ngpu_file : {}", file_n.display());
        println!(" scaling_efficiency : {scaling:?}");
        println!(" loss_parity : {parity:?}");
        println!(" allreduce_bandwidth : {allreduce:?}");
    }
}
/// Public alias for the classifier module's `D11_DEFAULT_SCALING_FLOOR`.
///
/// NOTE(review): presumably the default passed as `scaling_floor` to `run` —
/// confirm at the call site.
pub const DDP_METRICS_DEFAULT_SCALING_FLOOR: f64 = D11_DEFAULT_SCALING_FLOOR;
/// Public alias for the classifier module's `D11_DEFAULT_LOSS_TOLERANCE`.
///
/// NOTE(review): presumably the default passed as `loss_tolerance` to `run` —
/// confirm at the call site.
pub const DDP_METRICS_DEFAULT_LOSS_TOLERANCE: f64 = D11_DEFAULT_LOSS_TOLERANCE;