/// Adaptive batch-count balancer for data-parallel ranks of unequal speed.
///
/// Tracks how many batches each rank should run between synchronisation
/// points: the slowest rank runs `anchor` batches, faster ranks run
/// proportionally more based on measured per-batch times, and the anchor
/// itself adapts so that synchronisation overhead stays near
/// `overhead_target`.
pub struct ElChe {
    world_size: usize,
    anchor: usize,                 // batches the slowest rank runs per round
    batch_counts: Vec<usize>,      // per-rank batch counts for the next round
    ms_per_batch: Vec<f64>,        // smoothed per-rank cost of one batch (ms)
    calibrated: bool,              // true once a timing report has been applied
    overhead_target: f64,          // desired sync-time / compute-time ratio
    min_anchor: usize,             // lower bound for the adaptive anchor
    max_anchor: usize,             // upper bound for the adaptive anchor
    max_batch_diff: Option<usize>, // optional cap on per-update count changes
}
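
// A sketch of the intended calling pattern; `world_size`, `rank`, the timing
// vectors, and the round-running step are placeholders for whatever the
// surrounding trainer provides:
//
//     let mut sched = ElChe::new(world_size, 8).with_overhead_target(0.10);
//     loop {
//         let n = sched.batches(rank);
//         // ... run `n` batches, time them, all-gather the per-rank results ...
//         sched.report_timing(&wall_ms_per_rank, &batches_per_rank, sync_ms);
//     }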
impl ElChe {
    /// Creates a balancer for `world_size` ranks (at least 2), with every rank
    /// initially assigned `anchor` batches (at least 1) per round.
    pub fn new(world_size: usize, anchor: usize) -> Self {
assert!(world_size >= 2, "El Che requires at least 2 devices");
assert!(anchor >= 1, "anchor must be >= 1");
ElChe {
world_size,
anchor,
batch_counts: vec![anchor; world_size],
ms_per_batch: vec![0.0; world_size],
calibrated: false,
overhead_target: 0.10,
min_anchor: anchor,
            max_anchor: 200.max(anchor), // keep min_anchor <= max_anchor even for a large initial anchor
max_batch_diff: None,
}
}
    /// Sets the target ratio of sync time to compute time; the value is
    /// clamped to `[0.01, 0.50]`.
    pub fn with_overhead_target(mut self, target: f64) -> Self {
self.overhead_target = target.clamp(0.01, 0.50);
self
}
    /// Sets the upper bound used when the anchor adapts (forced to at least 1).
    /// If the bound falls below the current minimum, the minimum is lowered to
    /// match and the anchor is clamped into the new range.
    pub fn with_max_anchor(mut self, max: usize) -> Self {
self.max_anchor = max.max(1);
if self.min_anchor > self.max_anchor {
self.min_anchor = self.max_anchor;
self.anchor = self.anchor.clamp(self.min_anchor, self.max_anchor);
}
self
}
    /// Caps how far any rank's batch count may move in a single rebalance once
    /// the balancer is calibrated.
    pub fn with_max_batch_diff(mut self, max: usize) -> Self {
self.max_batch_diff = Some(max);
self
}
    /// Returns the per-update change cap, if one was configured.
    pub fn max_batch_diff(&self) -> Option<usize> {
self.max_batch_diff
}
    /// Seeds the batch counts from a prior speed hint: `slow_rank` keeps the
    /// anchor while every other rank is assumed to be `ratio` times faster and
    /// gets `anchor * ratio` batches (ratios below 1.0 are treated as 1.0).
    pub fn with_speed_ratio(mut self, slow_rank: usize, ratio: f64) -> Self {
assert!(
slow_rank < self.world_size,
"slow_rank ({slow_rank}) out of bounds for world_size ({})",
self.world_size,
);
let ratio = ratio.max(1.0);
for rank in 0..self.world_size {
if rank == slow_rank {
self.batch_counts[rank] = self.anchor;
} else {
self.batch_counts[rank] =
(self.anchor as f64 * ratio).round().max(1.0) as usize;
}
}
self
}
    /// Returns the batch count assigned to `rank`; panics if `rank` is out of
    /// bounds.
    pub fn batches(&self, rank: usize) -> usize {
self.batch_counts[rank]
}
    /// Returns the planned batch counts for all ranks.
    pub fn batch_counts(&self) -> &[usize] {
&self.batch_counts
}
    /// Returns the total number of batches planned across all ranks for the
    /// next round.
    pub fn total_batches(&self) -> usize {
self.batch_counts.iter().sum()
}
    /// Returns the current anchor: the number of batches the slowest rank runs
    /// per round.
    pub fn anchor(&self) -> usize {
self.anchor
}
    /// Estimated duration of one round, in milliseconds: the anchor times the
    /// slowest measured per-batch cost, or 0.0 before calibration.
    pub fn anchor_wall_ms(&self) -> f64 {
if !self.calibrated {
return 0.0;
}
let slow_ms = self.ms_per_batch.iter().copied().fold(0.0_f64, f64::max);
self.anchor as f64 * slow_ms
}
    /// Shrinks the anchor by `factor` (clamped to `[0.1, 1.0]`), keeping it at
    /// least 1 and never larger than before, then rebalances the batch counts
    /// if per-batch timings are already known.
    pub fn nudge_anchor_down(&mut self, factor: f64) {
let new = (self.anchor as f64 * factor.clamp(0.1, 1.0)).ceil() as usize;
self.anchor = new.max(1).min(self.anchor);
let slow_ms = self.ms_per_batch
.iter()
.copied()
.fold(0.0_f64, f64::max);
if slow_ms > 0.0 {
self.recompute_batch_counts(slow_ms);
}
}
    /// Returns `true` once at least one timing report has been applied.
    pub fn is_calibrated(&self) -> bool {
self.calibrated
}
    /// Returns `true` if the batch counts are currently uneven across ranks
    /// (e.g. from `with_speed_ratio` or a previous calibration).
    pub fn has_speed_hint(&self) -> bool {
self.batch_counts.windows(2).any(|w| w[0] != w[1])
}
    /// Returns the smoothed per-batch cost estimates, in milliseconds, one per
    /// rank (0.0 for ranks never measured).
    pub fn ms_per_batch(&self) -> &[f64] {
&self.ms_per_batch
}
    /// Feeds back one round of measurements: `wall_ms[rank]` is that rank's
    /// compute wall time, `actual_batches[rank]` the number of batches it
    /// actually ran, and `sync_ms` the time spent synchronising. Per-batch
    /// costs are updated with an error-weighted moving average (ranks with no
    /// batches or no wall time keep their previous estimate), the anchor is
    /// adjusted toward the overhead target, and batch counts are rebalanced.
    pub fn report_timing(&mut self, wall_ms: &[f64], actual_batches: &[usize], sync_ms: f64) {
assert_eq!(
wall_ms.len(),
self.world_size,
"wall_ms length must match world_size",
);
for (rank, &wall) in wall_ms.iter().enumerate() {
let n = actual_batches.get(rank).copied().unwrap_or(0);
if n > 0 && wall > 0.0 {
let new_ms = wall / n as f64;
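                // Blend the new sample with the previous estimate: the larger
                // the relative error, the more weight the new sample receives
                // (alpha is clamped to 0.1..0.8 below).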
self.ms_per_batch[rank] = if self.calibrated && self.ms_per_batch[rank] > 0.0 {
let error = (new_ms - self.ms_per_batch[rank]).abs()
/ self.ms_per_batch[rank];
let alpha = error.clamp(0.1, 0.8);
alpha * new_ms + (1.0 - alpha) * self.ms_per_batch[rank]
} else {
new_ms
};
}
}
let slow_ms = self
.ms_per_batch
.iter()
.copied()
.fold(0.0_f64, f64::max);
        if slow_ms <= 0.0 {
            return;
        }
let compute_ms = wall_ms
.iter()
.copied()
.fold(0.0_f64, f64::max);
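        // Adapt the anchor so sync overhead trends toward the target: grow it
        // proportionally when overhead is too high, shrink it by one step when
        // overhead falls below half the target.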
if compute_ms > 0.0 && sync_ms > 0.0 {
let overhead = sync_ms / compute_ms;
if overhead > self.overhead_target {
let scale = overhead / self.overhead_target;
                let new_anchor = (self.anchor as f64 * scale).ceil() as usize;
                self.anchor = new_anchor.clamp(self.min_anchor, self.max_anchor);
} else if overhead < self.overhead_target * 0.5
&& self.anchor > self.min_anchor {
self.anchor -= 1;
}
}
self.recompute_batch_counts(slow_ms);
self.calibrated = true;
}
    /// Scales the planned batch counts down so their sum does not exceed
    /// `max_total`, keeping at least one batch per rank (so the result can
    /// still exceed `max_total` when `max_total < world_size`). Counts that
    /// already fit are returned unchanged.
    pub fn clamp_total(&self, max_total: usize) -> Vec<usize> {
let current_total = self.total_batches();
if current_total <= max_total {
return self.batch_counts.clone();
}
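        // Scale every rank down proportionally, but never below one batch,
        // then hand any rounding shortfall back out, one extra batch per rank
        // starting from rank 0.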
let scale = max_total as f64 / current_total as f64;
let mut clamped: Vec<usize> = self
.batch_counts
.iter()
.map(|&n| (n as f64 * scale).floor().max(1.0) as usize)
.collect();
let sum: usize = clamped.iter().sum();
let mut remainder = max_total.saturating_sub(sum);
for c in &mut clamped {
if remainder == 0 {
break;
}
*c += 1;
remainder -= 1;
}
clamped
}
    /// Recomputes each rank's batch count so that `count * ms_per_batch` is
    /// roughly equal across ranks, with the slowest rank pinned at the anchor.
    fn recompute_batch_counts(&mut self, slow_ms: f64) {
for rank in 0..self.world_size {
let ms = self.ms_per_batch[rank];
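            // Ranks with no timing yet, or matching the slowest rank, keep the
            // plain anchor; a rank that is `slow_ms / ms` times faster gets
            // that many times the anchor (at least 1).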
let target = if ms <= 0.0 || (ms - slow_ms).abs() < 1e-6 {
self.anchor
} else {
let ratio = slow_ms / ms;
(self.anchor as f64 * ratio).round().max(1.0) as usize
};
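            // Apply the new target only if it moves the count by more than 10 %
            // (or before the first calibration); once calibrated, each step can
            // additionally be capped by `max_batch_diff`.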
let current = self.batch_counts[rank];
let diff = (target as f64 - current as f64).abs();
if diff > current as f64 * 0.10 || !self.calibrated {
let clamped = match self.max_batch_diff {
Some(max) if self.calibrated => {
if target > current {
current.saturating_add(max).min(target)
} else {
current.saturating_sub(max).max(target).max(1)
}
}
_ => target,
};
self.batch_counts[rank] = clamped;
}
}
}
}
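
// Illustrative sketches of the balancer's behaviour. The scenario numbers
// (world size 2, anchor 10, the timings) and the module/test names are made
// up for illustration; the assertions follow from the arithmetic above.
#[cfg(test)]
mod elche_usage_sketch {
    use super::*;

    #[test]
    fn speed_ratio_hint_biases_batch_counts() {
        // Rank 0 is declared 2x slower, so it keeps the anchor (10 batches)
        // while rank 1 is assigned roughly twice as many.
        let sched = ElChe::new(2, 10).with_speed_ratio(0, 2.0);
        assert_eq!(sched.batches(0), 10);
        assert_eq!(sched.batches(1), 20);
        assert!(sched.has_speed_hint());
        assert!(!sched.is_calibrated());
    }

    #[test]
    fn report_timing_rebalances_toward_measured_speeds() {
        let mut sched = ElChe::new(2, 10);
        // Rank 0: 1000 ms for 10 batches (100 ms/batch); rank 1: 500 ms for
        // 10 batches (50 ms/batch); synchronisation cost is negligible.
        sched.report_timing(&[1000.0, 500.0], &[10, 10], 0.0);
        assert!(sched.is_calibrated());
        // The faster rank is now assigned twice as many batches per round.
        assert_eq!(sched.batches(0), 10);
        assert_eq!(sched.batches(1), 20);
    }

    #[test]
    fn clamp_total_respects_a_global_budget() {
        // 10 + 30 = 40 planned batches, but only 20 are allowed in total.
        let sched = ElChe::new(2, 10).with_speed_ratio(0, 3.0);
        assert_eq!(sched.clamp_total(20).iter().sum::<usize>(), 20);
    }
}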