boostr 0.1.0

ML framework built on numr - attention, quantization, model architectures
Documentation
//! Bucketed distributed trainer with backward/allreduce overlap
//!
//! Uses `backward_with_hooks` to fire allreduce on gradient buckets
//! during the backward pass, overlapping communication with computation.

use std::collections::HashMap;
use std::sync::Arc;

use crate::distributed::bucket_manager::GradientBucketManager;
use crate::distributed::grad_sync::broadcast_params;
use crate::error::Result;
use crate::ops::FusedOptimizerOps;
use crate::trainer::SimpleTrainer;
use crate::trainer::config::{TrainingConfig, TrainingMetrics};
use numr::autograd::{BackwardHook, Var, backward_with_hooks};
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, TensorOps, UnaryOps};
use numr::runtime::{Communicator, Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};

/// Distributed trainer that overlaps allreduce with the backward pass.
///
/// Unlike `DistributedTrainer` which allreduces all gradients after backward
/// completes, `BucketedTrainer` fires allreduce on each bucket of gradients
/// as soon as all gradients in that bucket are ready during backward.
/// This gives 30-40% throughput improvement from communication/computation
/// overlap (same technique as PyTorch DDP).
///
/// # Usage
///
/// ```ignore
/// // Provide param info in reverse-backward order for max overlap
/// let param_info: Vec<(TensorId, usize, DType)> = model.param_info_reversed();
///
/// let stream_handle = client.compute_stream_handle();
/// let mut trainer = BucketedTrainer::new(
///     config,
///     comm,
///     &param_info,
///     25 * 1024 * 1024, // 25 MiB buckets
///     stream_handle,    // enables compute-comm overlap on CUDA
/// )?;
///
/// trainer.broadcast_params(&params)?;
///
/// for micro_batch in local_data {
///     let loss = forward(micro_batch, &params);
///     if let Some(metrics) = trainer.backward_and_step(&loss, &client, &mut params)? {
///         println!("step {} loss={:.4}", metrics.step, metrics.loss);
///     }
/// }
/// ```
pub struct BucketedTrainer<R: Runtime<DType = DType>> {
    inner: SimpleTrainer<R>,
    bucket_manager: GradientBucketManager<R>,
    comm: Arc<dyn Communicator>,
}

impl<R: Runtime<DType = DType>> BucketedTrainer<R> {
    /// Create a new bucketed distributed trainer.
    ///
    /// # Arguments
    ///
    /// * `config` - Training configuration (LR, weight decay, accumulation, etc.)
    /// * `comm` - Communicator for allreduce operations
    /// * `param_info` - Parameter (id, numel, dtype) in reverse-backward order
    /// * `bucket_size_bytes` - Target bucket size in bytes (25 MiB is a good default)
    /// * `compute_stream_handle` - Optional compute stream handle from
    ///   `RuntimeClient::compute_stream_handle()`. Enables event-based
    ///   compute-communication overlap when the communicator supports it.
    pub fn new(
        config: TrainingConfig,
        comm: Arc<dyn Communicator>,
        param_info: &[(TensorId, usize, DType)],
        bucket_size_bytes: usize,
        compute_stream_handle: Option<u64>,
    ) -> Result<Self> {
        let inner = SimpleTrainer::new(config)?;
        let bucket_manager = GradientBucketManager::new(
            param_info,
            comm.clone(),
            bucket_size_bytes,
            compute_stream_handle,
        );

        Ok(Self {
            inner,
            bucket_manager,
            comm,
        })
    }

    /// Broadcast parameters from rank 0 to all other ranks.
    pub fn broadcast_params(&self, params: &HashMap<TensorId, Tensor<R>>) -> Result<()> {
        broadcast_params(self.comm.as_ref(), params, 0)
    }

    /// Run backward pass with overlapped allreduce, then optimizer step.
    ///
    /// This is the main training step method. It:
    /// 1. Runs backward with hooks that fire allreduce on ready buckets
    /// 2. Waits for all allreduce ops to complete
    /// 3. Unflattens averaged gradients back into the grad store
    /// 4. Delegates to SimpleTrainer for accumulation, clipping, and optimizer step
    pub fn backward_and_step<C>(
        &mut self,
        loss: &Var<R>,
        client: &C,
        params: &mut HashMap<TensorId, Tensor<R>>,
    ) -> Result<Option<TrainingMetrics>>
    where
        C: RuntimeClient<R>
            + TensorOps<R>
            + BinaryOps<R>
            + UnaryOps<R>
            + ScalarOps<R>
            + ReduceOps<R>
            + FusedOptimizerOps<R>,
    {
        // Extract loss value before backward
        let loss_value =
            loss.tensor()
                .item::<f32>()
                .map_err(|e| crate::error::Error::DistributedError {
                    reason: format!("failed to extract scalar loss: {e}"),
                })? as f64;

        // Reset bucket state for this backward pass
        self.bucket_manager.reset();

        // Run backward with hooks — allreduce fires during backward
        //
        // SAFETY: We use a raw pointer to bucket_manager to work around a split-borrow
        // limitation: backward_with_hooks() takes `&mut hook` while the hook itself
        // needs `&mut self.bucket_manager`. The compiler cannot verify that these are
        // disjoint borrows of `self`. This is safe because:
        //
        // 1. backward_with_hooks() borrows the hook (`&mut H`) only for the duration
        //    of the call and does NOT store it or spawn threads with it.
        // 2. The backward pass is single-threaded — numr's backward_with_hooks() does
        //    not invoke the hook concurrently. BucketedTrainer is !Sync (no concurrent
        //    &self access) and the hook is scoped to this function.
        // 3. `manager_ptr` points to `self.bucket_manager` which is a field of
        //    `BucketedTrainer` and outlives the hook (both live on this stack frame).
        // 4. The hook is the sole accessor of bucket_manager during backward — no
        //    other code path touches `self.bucket_manager` until backward returns.
        //
        // If backward_with_hooks() ever becomes multi-threaded or stores the hook,
        // this must be replaced with RefCell or an interior-mutability wrapper.
        let manager_ptr = &mut self.bucket_manager as *mut GradientBucketManager<R>;
        let mut hook = BucketHook::<R, C> {
            manager: manager_ptr,
            client,
        };
        let mut grads = backward_with_hooks(loss, client, &mut hook)?;

        // Wait for all allreduce ops and unflatten into the grad store
        self.bucket_manager.wait_and_unflatten(client, &mut grads)?;

        // Delegate to SimpleTrainer for accumulation + clipping + optimizer
        self.inner.step(client, params, grads, loss_value)
    }

    /// Current global step count.
    pub fn global_step(&self) -> u64 {
        self.inner.global_step()
    }

    /// Reference to the underlying communicator.
    pub fn communicator(&self) -> &dyn Communicator {
        self.comm.as_ref()
    }

    /// Number of gradient buckets.
    pub fn num_buckets(&self) -> usize {
        self.bucket_manager.num_buckets()
    }
}

/// Hook that bridges backward notifications into the bucket manager.
struct BucketHook<'a, R: Runtime, C: RuntimeClient<R> + TensorOps<R>> {
    manager: *mut GradientBucketManager<R>,
    client: &'a C,
}

// SAFETY: BucketHook is only used within a single backward_with_hooks() call,
// which is single-threaded and synchronous. The raw pointer targets a field of
// BucketedTrainer that outlives the hook. No concurrent access is possible:
// - BucketedTrainer is !Sync (has Arc<dyn Communicator> but no shared &self path)
// - backward_with_hooks() does not spawn threads or send the hook across threads
// - R: Send is required because *mut GradientBucketManager<R> contains R-typed data
unsafe impl<R: Runtime + Send, C: RuntimeClient<R> + TensorOps<R>> Send for BucketHook<'_, R, C> {}

impl<R: Runtime<DType = DType>, C: RuntimeClient<R> + TensorOps<R>> BackwardHook<R>
    for BucketHook<'_, R, C>
{
    fn on_leaf_grad_ready(&mut self, id: TensorId, grad: &Tensor<R>) {
        // SAFETY: single-threaded backward pass, manager pointer is valid (see
        // safety comment in backward_and_step).
        let manager = unsafe { &mut *self.manager };

        // BackwardHook::on_leaf_grad_ready cannot return errors, so failures here
        // will surface later in wait_and_unflatten (missing bucket data). In practice
        // mark_grad_ready only fails on dtype mismatch or missing grads — both of
        // which indicate a bug in bucket construction, not a transient error.
        if let Err(_e) = manager.mark_grad_ready(id, grad, self.client) {
            debug_assert!(false, "mark_grad_ready failed for {id:?}: {_e}");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::bucket_manager::GradientBucketManager;
    use crate::test_utils::cpu_setup;
    use crate::trainer::config::TrainingConfig;
    use numr::autograd::{backward, var_mul, var_sum};
    use numr::runtime::NoOpCommunicator;
    use numr::runtime::cpu::CpuRuntime;

    #[test]
    fn test_bucketed_trainer_creation() {
        let comm = Arc::new(NoOpCommunicator);
        let config = TrainingConfig {
            learning_rate: 1e-3,
            weight_decay: 0.01,
            grad_accum_steps: 1,
            max_grad_norm: Some(1.0),
        };
        let id = TensorId::new();
        let params = vec![(id, 100, DType::F32)];
        let trainer =
            BucketedTrainer::<CpuRuntime>::new(config, comm, &params, 25 * 1024 * 1024, None)
                .unwrap();
        assert_eq!(trainer.global_step(), 0);
        assert_eq!(trainer.num_buckets(), 1);
    }

    #[test]
    fn test_bucketed_backward_produces_same_grads_as_regular() {
        let (client, device) = cpu_setup();

        // Create a parameter
        let w_id = TensorId::new();
        let w_tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);

        // Regular backward: loss = sum(w * w), grad = 2w
        let w_var = Var::with_id(w_tensor.clone(), w_id, true);
        let sq = var_mul(&w_var, &w_var, &client).unwrap();
        let loss = var_sum(&sq, &[0], false, &client).unwrap();
        let regular_grads = backward(&loss, &client).unwrap();
        let regular_grad: Vec<f32> = regular_grads
            .get(w_id)
            .expect("regular grad should exist")
            .to_vec();

        // Bucketed backward (NoOp comm = world_size 1)
        let w_var2 = Var::with_id(w_tensor.clone(), w_id, true);
        let sq2 = var_mul(&w_var2, &w_var2, &client).unwrap();
        let loss2 = var_sum(&sq2, &[0], false, &client).unwrap();

        let comm = Arc::new(NoOpCommunicator);
        let param_info = vec![(w_id, 3, DType::F32)];
        let mut mgr =
            GradientBucketManager::<CpuRuntime>::new(&param_info, comm, 25 * 1024 * 1024, None);

        let manager_ptr = &mut mgr as *mut GradientBucketManager<CpuRuntime>;
        let mut hook = BucketHook {
            manager: manager_ptr,
            client: &client,
        };
        let mut grads = backward_with_hooks(&loss2, &client, &mut hook).unwrap();
        mgr.wait_and_unflatten(&client, &mut grads).unwrap();

        let bucketed_grad: Vec<f32> = grads
            .get(w_id)
            .expect("bucketed grad should exist")
            .to_vec();

        for (a, b) in regular_grad.iter().zip(bucketed_grad.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Grad mismatch: regular={a}, bucketed={b}"
            );
        }
    }

    #[test]
    fn test_bucketed_trainer_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BucketedTrainer<CpuRuntime>>();
    }
}