use serde_json::Value;
/// Default minimum acceptable DDP scaling efficiency (fraction of ideal linear scaling).
/// NOTE(review): presumably intended as the `threshold` argument of
/// `classify_scaling_efficiency`; the tests in this file pass the literal `0.85` instead.
pub const D11_DEFAULT_SCALING_FLOOR: f64 = 0.85;
/// Default maximum relative final-loss difference between 1-GPU and N-GPU runs.
/// NOTE(review): presumably intended as the `tolerance` argument of
/// `classify_loss_parity`; the tests in this file pass the literal `0.01` instead.
pub const D11_DEFAULT_LOSS_TOLERANCE: f64 = 0.01;
/// Result of [`classify_scaling_efficiency`]: either the measured scaling
/// efficiency or the first validation failure that was encountered.
#[derive(Debug, Clone, PartialEq)]
pub enum DdpScalingOutcome {
/// All inputs were valid and the efficiency was not below the threshold.
Ok {
/// `tn / (world_size * t1)`.
efficiency: f64,
},
/// `tokens_per_sec` was absent or not a JSON number in one of the bodies.
MissingTokensPerSec {
/// The input at fault: `"metrics_1gpu"` or `"metrics_ngpu"`.
which: &'static str,
},
/// `tokens_per_sec` was `<= 0.0` in one of the bodies.
NonPositiveTokensPerSec {
/// The input at fault: `"metrics_1gpu"` or `"metrics_ngpu"`.
which: &'static str,
/// The offending value.
got: f64,
},
/// `world_size` was `<= 1`.
InvalidWorldSize {
/// The rejected `world_size`.
got: i64,
},
/// The efficiency fell strictly below the requested threshold.
BelowThreshold {
/// The `world_size` the efficiency was computed against.
world_size: i64,
/// `tokens_per_sec` of the 1-GPU body.
t1: f64,
/// `tokens_per_sec` of the N-GPU body.
tn: f64,
/// `tn / (world_size * t1)`.
efficiency: f64,
/// The threshold that was not met.
threshold: f64,
},
}
/// Result of [`classify_loss_parity`]: either the relative final-loss difference
/// or the first validation failure that was encountered.
#[derive(Debug, Clone, PartialEq)]
pub enum DdpLossParityOutcome {
/// Both losses were present and the relative difference was within tolerance.
Ok {
/// `|ln - l1| / l1`.
rel_diff: f64,
},
/// `final_loss` was absent or not a JSON number in one of the bodies.
MissingFinalLoss {
/// The input at fault: `"metrics_1gpu"` or `"metrics_ngpu"`.
which: &'static str,
},
/// The 1-GPU baseline `final_loss` was `<= 0.0`; it is used as the divisor.
NonPositiveBaselineLoss {
/// The rejected baseline loss.
got: f64,
},
/// The relative difference was strictly greater than the tolerance.
Divergence {
/// 1-GPU `final_loss`.
l1: f64,
/// N-GPU `final_loss`.
ln: f64,
/// `|ln - l1| / l1`.
rel_diff: f64,
/// The tolerance that was exceeded.
tolerance: f64,
},
}
/// Result of [`classify_allreduce_bandwidth`]: either the number of validated
/// all-reduce bandwidth samples or the first problem that was found.
#[derive(Debug, Clone, PartialEq)]
pub enum DdpAllreduceOutcome {
    /// Every entry of `allreduce_bandwidth_gbps` was a positive number; `steps` is the entry count.
    Ok { steps: usize },
    /// The metrics body has no usable `ddp_metrics` section (e.g. the key is absent).
    MissingDdpMetrics,
    /// `ddp_metrics.allreduce_bandwidth_gbps` is absent or is not a JSON array.
    MissingBandwidthArray,
    /// `ddp_metrics.allreduce_bandwidth_gbps` is an empty array.
    EmptyBandwidthArray,
    /// The first entry that is `<= 0.0`; non-numeric entries are reported with `got: 0.0`.
    NonPositiveBandwidth { step_index: usize, got: f64 },
}
/// Classifies how well an N-GPU DDP run scales relative to a 1-GPU baseline.
///
/// Scaling efficiency is `tn / (world_size * t1)`, where `t1` and `tn` are the
/// numeric `"tokens_per_sec"` fields of `metrics_1gpu` and `metrics_ngpu`.
///
/// Validation happens in a fixed order and the first failure is returned:
///
/// 1. `world_size <= 1` → [`DdpScalingOutcome::InvalidWorldSize`];
/// 2. a missing or non-numeric `tokens_per_sec` (1-GPU body checked first) →
///    [`DdpScalingOutcome::MissingTokensPerSec`];
/// 3. a `tokens_per_sec` that is `<= 0.0` (1-GPU body checked first) →
///    [`DdpScalingOutcome::NonPositiveTokensPerSec`];
/// 4. efficiency strictly below `threshold` → [`DdpScalingOutcome::BelowThreshold`].
///
/// If every check passes, returns [`DdpScalingOutcome::Ok`] carrying the efficiency.
pub fn classify_scaling_efficiency(
    metrics_1gpu: &Value,
    metrics_ngpu: &Value,
    world_size: i64,
    threshold: f64,
) -> DdpScalingOutcome {
    if world_size <= 1 {
        return DdpScalingOutcome::InvalidWorldSize { got: world_size };
    }

    // Pull `tokens_per_sec` out of one body, tagging a miss with that body's name.
    let tokens_per_sec = |body: &Value, which: &'static str| {
        body.get("tokens_per_sec")
            .and_then(Value::as_f64)
            .ok_or(DdpScalingOutcome::MissingTokensPerSec { which })
    };
    let (t1, tn) = match (
        tokens_per_sec(metrics_1gpu, "metrics_1gpu"),
        tokens_per_sec(metrics_ngpu, "metrics_ngpu"),
    ) {
        (Ok(base), Ok(scaled)) => (base, scaled),
        // Alternatives are tried left to right, so a 1-GPU miss is reported first.
        (Err(missing), _) | (_, Err(missing)) => return missing,
    };

    for (which, got) in [("metrics_1gpu", t1), ("metrics_ngpu", tn)] {
        if got <= 0.0 {
            return DdpScalingOutcome::NonPositiveTokensPerSec { which, got };
        }
    }

    // Throughput the N-GPU run would reach under perfect linear scaling.
    let ideal = world_size as f64 * t1;
    let efficiency = tn / ideal;
    if efficiency < threshold {
        DdpScalingOutcome::BelowThreshold {
            world_size,
            t1,
            tn,
            efficiency,
            threshold,
        }
    } else {
        DdpScalingOutcome::Ok { efficiency }
    }
}
/// Checks that an N-GPU run's final loss matches the 1-GPU baseline.
///
/// Reads the numeric `"final_loss"` field of each body and computes the relative
/// difference `|ln - l1| / l1` against the baseline `l1`.
///
/// Returns, in order of precedence:
/// - [`DdpLossParityOutcome::MissingFinalLoss`] if either `final_loss` is absent or
///   non-numeric (1-GPU body checked first);
/// - [`DdpLossParityOutcome::NonPositiveBaselineLoss`] if `l1 <= 0.0`;
/// - [`DdpLossParityOutcome::Divergence`] if the relative difference is strictly
///   greater than `tolerance`;
/// - [`DdpLossParityOutcome::Ok`] otherwise, carrying the relative difference.
///
/// The sign of the N-GPU loss `ln` is not validated.
pub fn classify_loss_parity(
    metrics_1gpu: &Value,
    metrics_ngpu: &Value,
    tolerance: f64,
) -> DdpLossParityOutcome {
    let final_loss = |body: &Value| body.get("final_loss").and_then(Value::as_f64);

    let l1 = match final_loss(metrics_1gpu) {
        Some(loss) => loss,
        None => {
            return DdpLossParityOutcome::MissingFinalLoss {
                which: "metrics_1gpu",
            }
        }
    };
    let ln = match final_loss(metrics_ngpu) {
        Some(loss) => loss,
        None => {
            return DdpLossParityOutcome::MissingFinalLoss {
                which: "metrics_ngpu",
            }
        }
    };

    // The baseline is the divisor below, so it must be strictly positive.
    if l1 <= 0.0 {
        return DdpLossParityOutcome::NonPositiveBaselineLoss { got: l1 };
    }

    let rel_diff = (ln - l1).abs() / l1;
    if rel_diff > tolerance {
        DdpLossParityOutcome::Divergence {
            l1,
            ln,
            rel_diff,
            tolerance,
        }
    } else {
        DdpLossParityOutcome::Ok { rel_diff }
    }
}
/// Validates the per-step all-reduce bandwidth samples recorded in a metrics body.
///
/// Expects `metrics["ddp_metrics"]["allreduce_bandwidth_gbps"]` to be a non-empty
/// JSON array of positive numbers.
///
/// Returns, in order of precedence:
/// - [`DdpAllreduceOutcome::MissingDdpMetrics`] if `ddp_metrics` is absent or `null`;
/// - [`DdpAllreduceOutcome::MissingBandwidthArray`] if `allreduce_bandwidth_gbps` is
///   absent or not an array;
/// - [`DdpAllreduceOutcome::EmptyBandwidthArray`] if the array is empty;
/// - [`DdpAllreduceOutcome::NonPositiveBandwidth`] for the first entry that is `<= 0.0`.
///   Non-numeric entries (`null`, strings, ...) are reported as `got: 0.0`, because the
///   enum has no dedicated variant for them;
/// - [`DdpAllreduceOutcome::Ok`] with the number of samples otherwise.
pub fn classify_allreduce_bandwidth(metrics: &Value) -> DdpAllreduceOutcome {
    // An explicit JSON `null` (e.g. a Python `None`) means the section was not
    // recorded. Treat it like an absent key, just as the other classifiers treat
    // a `null` scalar field as missing.
    let Some(ddp) = metrics.get("ddp_metrics").filter(|v| !v.is_null()) else {
        return DdpAllreduceOutcome::MissingDdpMetrics;
    };
    let Some(samples) = ddp
        .get("allreduce_bandwidth_gbps")
        .and_then(Value::as_array)
    else {
        return DdpAllreduceOutcome::MissingBandwidthArray;
    };
    if samples.is_empty() {
        return DdpAllreduceOutcome::EmptyBandwidthArray;
    }
    let first_bad = samples
        .iter()
        .map(|v| v.as_f64().unwrap_or(0.0))
        .enumerate()
        .find(|&(_, bw)| bw <= 0.0);
    match first_bad {
        Some((step_index, got)) => DdpAllreduceOutcome::NonPositiveBandwidth { step_index, got },
        None => DdpAllreduceOutcome::Ok {
            steps: samples.len(),
        },
    }
}
/// Unit tests for the DDP classifiers. Each test pins one outcome variant or boundary.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// 1-GPU baseline: 1000 tok/s, final loss 2.5.
    fn t1_body() -> Value {
        json!({
            "tokens_per_sec": 1000.0,
            "final_loss": 2.5,
            "ddp_metrics": {"allreduce_bandwidth_gbps": [120.0, 118.0, 121.0]}
        })
    }

    /// 4-GPU run at 3500 tok/s → efficiency 0.875; loss within 0.4% of the baseline.
    fn tn_body_efficient() -> Value {
        json!({
            "tokens_per_sec": 3500.0,
            "final_loss": 2.51,
            "ddp_metrics": {"allreduce_bandwidth_gbps": [80.0, 82.0, 79.0]}
        })
    }

    /// 4-GPU run at 2000 tok/s → efficiency 0.5.
    fn tn_body_inefficient() -> Value {
        json!({
            "tokens_per_sec": 2000.0,
            "final_loss": 2.51,
            "ddp_metrics": {"allreduce_bandwidth_gbps": [40.0]}
        })
    }

    #[test]
    fn scaling_ok_at_0_875() {
        match classify_scaling_efficiency(&t1_body(), &tn_body_efficient(), 4, 0.85) {
            DdpScalingOutcome::Ok { efficiency } => {
                assert!((efficiency - 0.875).abs() < 1e-6, "got {efficiency}");
            }
            other => panic!("expected Ok, got {other:?}"),
        }
    }

    #[test]
    fn scaling_ok_exactly_at_threshold() {
        // The comparison is strict (`<`), so landing exactly on the floor passes.
        // 3400 / (4 * 1000) rounds to the same f64 as the literal 0.85.
        let at_floor = json!({"tokens_per_sec": 3400.0});
        assert!(matches!(
            classify_scaling_efficiency(&t1_body(), &at_floor, 4, 0.85),
            DdpScalingOutcome::Ok { .. }
        ));
    }

    #[test]
    fn scaling_accepts_integer_json_numbers() {
        let t1 = json!({"tokens_per_sec": 1000});
        let tn = json!({"tokens_per_sec": 4000});
        assert_eq!(
            classify_scaling_efficiency(&t1, &tn, 4, 0.85),
            DdpScalingOutcome::Ok { efficiency: 1.0 }
        );
    }

    #[test]
    fn scaling_rejects_below_threshold() {
        match classify_scaling_efficiency(&t1_body(), &tn_body_inefficient(), 4, 0.85) {
            DdpScalingOutcome::BelowThreshold { efficiency, .. } => {
                assert!((efficiency - 0.5).abs() < 1e-6, "got {efficiency}");
            }
            other => panic!("expected BelowThreshold, got {other:?}"),
        }
    }

    #[test]
    fn scaling_rejects_invalid_world_size() {
        let out = classify_scaling_efficiency(&t1_body(), &tn_body_efficient(), 1, 0.85);
        assert!(matches!(
            out,
            DdpScalingOutcome::InvalidWorldSize { got: 1 }
        ));
    }

    #[test]
    fn scaling_checks_world_size_before_metrics() {
        let empty = json!({});
        assert_eq!(
            classify_scaling_efficiency(&empty, &empty, 0, 0.85),
            DdpScalingOutcome::InvalidWorldSize { got: 0 }
        );
    }

    #[test]
    fn scaling_rejects_missing_tokens_per_sec() {
        let empty = json!({});
        assert!(matches!(
            classify_scaling_efficiency(&empty, &tn_body_efficient(), 4, 0.85),
            DdpScalingOutcome::MissingTokensPerSec {
                which: "metrics_1gpu"
            }
        ));
    }

    #[test]
    fn scaling_rejects_missing_ngpu_tokens_per_sec() {
        let empty = json!({});
        assert_eq!(
            classify_scaling_efficiency(&t1_body(), &empty, 4, 0.85),
            DdpScalingOutcome::MissingTokensPerSec {
                which: "metrics_ngpu"
            }
        );
    }

    #[test]
    fn scaling_rejects_zero_tokens_per_sec() {
        let zero = json!({"tokens_per_sec": 0.0});
        assert!(matches!(
            classify_scaling_efficiency(&t1_body(), &zero, 4, 0.85),
            DdpScalingOutcome::NonPositiveTokensPerSec {
                which: "metrics_ngpu",
                ..
            }
        ));
    }

    #[test]
    fn scaling_rejects_negative_baseline_tokens_per_sec() {
        let negative = json!({"tokens_per_sec": -5.0});
        assert_eq!(
            classify_scaling_efficiency(&negative, &tn_body_efficient(), 4, 0.85),
            DdpScalingOutcome::NonPositiveTokensPerSec {
                which: "metrics_1gpu",
                got: -5.0
            }
        );
    }

    #[test]
    fn loss_parity_ok_within_tolerance() {
        let out = classify_loss_parity(&t1_body(), &tn_body_efficient(), 0.01);
        match out {
            DdpLossParityOutcome::Ok { rel_diff } => {
                assert!((rel_diff - 0.004).abs() < 1e-6, "got {rel_diff}");
            }
            other => panic!("expected Ok, got {other:?}"),
        }
    }

    #[test]
    fn loss_parity_rejects_divergence() {
        let diverged = json!({"final_loss": 5.0, "tokens_per_sec": 1.0});
        assert!(matches!(
            classify_loss_parity(&t1_body(), &diverged, 0.01),
            DdpLossParityOutcome::Divergence { .. }
        ));
    }

    #[test]
    fn loss_parity_rejects_missing_final_loss() {
        let empty = json!({});
        assert!(matches!(
            classify_loss_parity(&empty, &tn_body_efficient(), 0.01),
            DdpLossParityOutcome::MissingFinalLoss {
                which: "metrics_1gpu"
            }
        ));
    }

    #[test]
    fn loss_parity_rejects_missing_ngpu_final_loss() {
        let empty = json!({});
        assert_eq!(
            classify_loss_parity(&t1_body(), &empty, 0.01),
            DdpLossParityOutcome::MissingFinalLoss {
                which: "metrics_ngpu"
            }
        );
    }

    #[test]
    fn loss_parity_rejects_non_positive_baseline() {
        let zero = json!({"final_loss": 0.0});
        assert_eq!(
            classify_loss_parity(&zero, &tn_body_efficient(), 0.01),
            DdpLossParityOutcome::NonPositiveBaselineLoss { got: 0.0 }
        );
    }

    #[test]
    fn allreduce_ok_on_positive_array() {
        match classify_allreduce_bandwidth(&tn_body_efficient()) {
            DdpAllreduceOutcome::Ok { steps } => assert_eq!(steps, 3),
            other => panic!("expected Ok(3), got {other:?}"),
        }
    }

    #[test]
    fn allreduce_rejects_missing_ddp_metrics() {
        let empty = json!({});
        assert_eq!(
            classify_allreduce_bandwidth(&empty),
            DdpAllreduceOutcome::MissingDdpMetrics
        );
    }

    #[test]
    fn allreduce_rejects_missing_bandwidth_array() {
        let body = json!({"ddp_metrics": {}});
        assert_eq!(
            classify_allreduce_bandwidth(&body),
            DdpAllreduceOutcome::MissingBandwidthArray
        );
    }

    #[test]
    fn allreduce_rejects_non_array_bandwidth() {
        let body = json!({"ddp_metrics": {"allreduce_bandwidth_gbps": 120.0}});
        assert_eq!(
            classify_allreduce_bandwidth(&body),
            DdpAllreduceOutcome::MissingBandwidthArray
        );
    }

    #[test]
    fn allreduce_rejects_empty_array() {
        let body = json!({"ddp_metrics": {"allreduce_bandwidth_gbps": []}});
        assert_eq!(
            classify_allreduce_bandwidth(&body),
            DdpAllreduceOutcome::EmptyBandwidthArray
        );
    }

    #[test]
    fn allreduce_rejects_zero_bandwidth() {
        let body = json!({"ddp_metrics": {"allreduce_bandwidth_gbps": [120.0, 0.0, 121.0]}});
        match classify_allreduce_bandwidth(&body) {
            DdpAllreduceOutcome::NonPositiveBandwidth { step_index: 1, .. } => {}
            other => panic!("expected NonPositiveBandwidth(step 1), got {other:?}"),
        }
    }

    #[test]
    fn allreduce_reports_non_numeric_entry_as_zero() {
        let body = json!({"ddp_metrics": {"allreduce_bandwidth_gbps": [120.0, "fast"]}});
        assert_eq!(
            classify_allreduce_bandwidth(&body),
            DdpAllreduceOutcome::NonPositiveBandwidth {
                step_index: 1,
                got: 0.0
            }
        );
    }
}