pub trait GradientSync: Send {
    // Required method
    fn sync_gradients(&self, gradients: &mut Array1<f32>) -> ModelResult<()>;

    // Provided methods
    fn is_distributed(&self) -> bool { ... }
    fn num_workers(&self) -> usize { ... }
}
Trait for gradient synchronization strategies.
Implementations are responsible for aggregating gradients across workers (e.g., averaging them in an all-reduce) and writing the result back in place.
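To make the in-place aggregation concrete, here is a minimal single-process sketch of an averaging strategy. It is not the crate's implementation. It mirrors the trait with Vec<f32> standing in for ndarray's Array1<f32>, a hypothetical ModelResult alias for Result<T, String>, and assumed bodies for the provided methods (the real bodies are elided in the documentation). SimulatedAllReduce and its peer_gradients field are hypothetical: the peers' gradients are assumed to be already available locally, whereas a real all-reduce would exchange them over a communication backend.

```rust
/// Hypothetical stand-in for the crate's result type (assumption).
type ModelResult<T> = Result<T, String>;

/// Mirror of the documented trait; Vec<f32> stands in for Array1<f32>.
pub trait GradientSync: Send {
    fn sync_gradients(&self, gradients: &mut Vec<f32>) -> ModelResult<()>;
    // Assumed defaults: the documented bodies are elided.
    fn is_distributed(&self) -> bool {
        self.num_workers() > 1
    }
    fn num_workers(&self) -> usize {
        1
    }
}

/// Hypothetical strategy: averages the local gradient with gradients
/// already received from peer workers (a single-process stand-in for all-reduce).
pub struct SimulatedAllReduce {
    pub peer_gradients: Vec<Vec<f32>>,
}

impl GradientSync for SimulatedAllReduce {
    fn sync_gradients(&self, gradients: &mut Vec<f32>) -> ModelResult<()> {
        // Every peer must contribute a gradient of the same length.
        for peer in &self.peer_gradients {
            if peer.len() != gradients.len() {
                return Err(format!(
                    "length mismatch: peer has {}, local has {}",
                    peer.len(),
                    gradients.len()
                ));
            }
        }
        let n = self.num_workers() as f32;
        for (i, g) in gradients.iter_mut().enumerate() {
            // Sum the local and peer contributions, then divide by the worker count.
            let sum: f32 = *g + self.peer_gradients.iter().map(|p| p[i]).sum::<f32>();
            *g = sum / n;
        }
        Ok(())
    }

    // Local worker plus every peer.
    fn num_workers(&self) -> usize {
        self.peer_gradients.len() + 1
    }
}

fn main() {
    let sync = SimulatedAllReduce {
        peer_gradients: vec![vec![3.0, 0.0], vec![5.0, 6.0]],
    };
    let mut local = vec![1.0, 3.0];
    sync.sync_gradients(&mut local).unwrap();
    println!("{:?} across {} workers", local, sync.num_workers());
}
```

Because the trait has Send as a supertrait, every implementation must be safe to move to another thread; SimulatedAllReduce satisfies this automatically because Vec<Vec<f32>> is Send.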
Required Methods
fn sync_gradients(&self, gradients: &mut Array1<f32>) -> ModelResult<()>
Synchronize (aggregate) gradients across all workers.
On return, gradients holds the post-synchronization values; an Err is returned if synchronization fails.
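Because synchronization happens in place, a caller can pass the same buffer straight to the optimizer once the call returns. The sketch below assumes a plain SGD update; NoSync and sgd_step are hypothetical names, Vec<f32> stands in for Array1<f32>, ModelResult is a hypothetical alias for Result<T, String>, and the provided-method bodies are assumed defaults rather than the crate's.

```rust
/// Hypothetical stand-in for the crate's result type (assumption).
type ModelResult<T> = Result<T, String>;

/// Mirror of the documented trait; Vec<f32> stands in for Array1<f32>.
pub trait GradientSync: Send {
    fn sync_gradients(&self, gradients: &mut Vec<f32>) -> ModelResult<()>;
    // Assumed defaults: the documented bodies are elided.
    fn is_distributed(&self) -> bool {
        self.num_workers() > 1
    }
    fn num_workers(&self) -> usize {
        1
    }
}

/// Hypothetical single-worker strategy: there is nothing to aggregate,
/// so the gradients are left untouched.
pub struct NoSync;

impl GradientSync for NoSync {
    fn sync_gradients(&self, _gradients: &mut Vec<f32>) -> ModelResult<()> {
        Ok(())
    }
}

/// Hypothetical SGD step: synchronize first, then apply the
/// post-synchronization gradients to the parameters.
pub fn sgd_step(
    sync: &dyn GradientSync,
    params: &mut [f32],
    gradients: &mut Vec<f32>,
    lr: f32,
) -> ModelResult<()> {
    // Propagate a failed sync instead of stepping with unsynchronized values.
    sync.sync_gradients(gradients)?;
    for (p, g) in params.iter_mut().zip(gradients.iter()) {
        *p -= lr * g;
    }
    Ok(())
}

fn main() {
    let mut params = vec![1.0_f32, 2.0];
    let mut grads = vec![0.5_f32, -1.0];
    sgd_step(&NoSync, &mut params, &mut grads, 0.1).unwrap();
    println!("{:?}", params);
}
```

Propagating the error with ? means the parameters are left unchanged whenever synchronization fails.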
Provided Methods
fn is_distributed(&self) -> bool
Returns true if this sync implementation involves multiple workers.
fn num_workers(&self) -> usize
Number of workers participating in synchronization.
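A caller can use the provided methods to skip communication entirely on a single worker. The real default bodies are elided in the documentation, so this sketch assumes num_workers defaults to 1 and is_distributed to num_workers() > 1. MockSync and sync_if_distributed are hypothetical, Vec<f32> stands in for Array1<f32>, and ModelResult is a hypothetical alias for Result<T, String>.

```rust
use std::cell::Cell;

/// Hypothetical stand-in for the crate's result type (assumption).
type ModelResult<T> = Result<T, String>;

/// Mirror of the documented trait; Vec<f32> stands in for Array1<f32>.
pub trait GradientSync: Send {
    fn sync_gradients(&self, gradients: &mut Vec<f32>) -> ModelResult<()>;
    // Assumed defaults: the documented bodies are elided.
    fn is_distributed(&self) -> bool {
        self.num_workers() > 1
    }
    fn num_workers(&self) -> usize {
        1
    }
}

/// Hypothetical mock that reports a fixed worker count and counts sync calls.
/// Cell<usize> is Send, so the Send supertrait is still satisfied.
pub struct MockSync {
    world_size: usize,
    calls: Cell<usize>,
}

impl GradientSync for MockSync {
    fn sync_gradients(&self, _gradients: &mut Vec<f32>) -> ModelResult<()> {
        self.calls.set(self.calls.get() + 1);
        Ok(())
    }
    // Only num_workers is overridden; is_distributed uses the assumed default.
    fn num_workers(&self) -> usize {
        self.world_size
    }
}

/// Hypothetical helper: skip the sync call when only one worker participates.
/// Returns whether synchronization actually ran.
pub fn sync_if_distributed(
    sync: &dyn GradientSync,
    gradients: &mut Vec<f32>,
) -> ModelResult<bool> {
    if !sync.is_distributed() {
        return Ok(false);
    }
    sync.sync_gradients(gradients)?;
    Ok(true)
}

fn main() {
    let mut grads = vec![0.1_f32, 0.2];
    for world_size in [1_usize, 4] {
        let sync = MockSync { world_size, calls: Cell::new(0) };
        let ran = sync_if_distributed(&sync, &mut grads).unwrap();
        println!(
            "workers={} distributed={} ran={} calls={}",
            sync.num_workers(),
            sync.is_distributed(),
            ran,
            sync.calls.get()
        );
    }
}
```

Because the branch depends only on the worker count, every worker takes the same path, which matters for collective operations such as all-reduce where all participants must join the call.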