use crate::{error::AutogradError, Float, NdArray, Result};
/// Kind of collective communication operation that a `CommHandle` refers to.
///
/// Variant names follow the usual collective-communication vocabulary; the
/// per-variant docs below give the conventional meaning of each name.
/// NOTE(review): the actual semantics are defined by whatever backend consumes
/// these values — confirm against it.
///
/// `Hash` is derived so operations can be used as keys in hash-based
/// collections (e.g. per-operation statistics or pending-handle maps).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CommOp {
    /// Conventionally: one participant's data is sent to all participants.
    Broadcast,
    /// Conventionally: data from all participants is combined at one participant.
    Reduce,
    /// Conventionally: data is combined and the result is given to every participant.
    AllReduce,
    /// Conventionally: pieces from all participants are collected at one participant.
    Gather,
    /// Conventionally: one participant's data is split and distributed to all participants.
    Scatter,
    /// Conventionally: every participant exchanges a distinct piece with every other.
    AllToAll,
}
/// Handle to a communication operation of kind `CommOp`.
///
/// A handle starts out incomplete; calling [`CommHandle::wait`] marks it
/// complete.
///
/// NOTE(review): `wait` is currently a placeholder. It performs no
/// communication and always succeeds, so this handle does not yet provide any
/// real synchronization guarantee.
#[derive(Debug)]
pub struct CommHandle {
    /// The operation this handle was created for.
    pub op: CommOp,
    /// Set to `true` once `wait` has been called.
    completed: bool,
}

impl CommHandle {
    /// Creates a new, not-yet-completed handle for `op`.
    #[must_use]
    pub fn new(op: CommOp) -> Self {
        Self {
            op,
            completed: false,
        }
    }

    /// Waits for the operation and marks the handle as complete.
    ///
    /// The current implementation does no actual waiting. It only sets the
    /// completion flag and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn wait(&mut self) -> Result<()> {
        self.completed = true;
        Ok(())
    }

    /// Returns `true` once [`wait`](Self::wait) has been called on this handle.
    pub fn is_complete(&self) -> bool {
        self.completed
    }
}
/// Strategy used by `compress_gradient` to encode a gradient into bytes.
///
/// Only `None` is currently supported. For every other variant,
/// `compress_gradient` returns a not-implemented error.
///
/// `Hash` is derived so strategies can be used as keys in hash-based
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionStrategy {
    /// No compression: each element is written as 8 little-endian `f64` bytes.
    None,
    /// Quantization, judging by the name. Not yet implemented.
    Quantize,
    /// Sparsification, judging by the name. Not yet implemented.
    Sparsify,
    /// A combination of techniques, judging by the name. Not yet implemented.
    Hybrid,
}
/// Serializes `gradient` into a byte buffer using the given `strategy`.
///
/// With [`CompressionStrategy::None`], each element is converted to `f64` and
/// written as 8 little-endian bytes, in the order returned by
/// `gradient.as_slice()`. The output length is therefore 8 × the number of
/// elements.
///
/// # Errors
///
/// - A compute error if `gradient.as_slice()` returns `None` (the gradient is
///   not contiguous).
/// - A compute error if an element cannot be converted to `f64`. This is an
///   error rather than a silent substitution, so gradient data is never
///   corrupted quietly.
/// - A not-implemented error for `Quantize`, `Sparsify` and `Hybrid`.
pub fn compress_gradient<T: Float>(
    gradient: &NdArray<T>,
    strategy: CompressionStrategy,
) -> Result<Vec<u8>> {
    match strategy {
        CompressionStrategy::None => {
            let slice = gradient.as_slice().ok_or_else(|| {
                AutogradError::compute_error("Gradient is not contiguous".to_string())
            })?;
            // The exact output size is known, so reserve it once and append each
            // element's bytes in place instead of allocating a Vec per element.
            let mut bytes = Vec::with_capacity(slice.len() * std::mem::size_of::<f64>());
            for (index, &x) in slice.iter().enumerate() {
                let f = x.to_f64().ok_or_else(|| {
                    AutogradError::compute_error(format!(
                        "Gradient element {} cannot be represented as f64",
                        index
                    ))
                })?;
                bytes.extend_from_slice(&f.to_le_bytes());
            }
            Ok(bytes)
        }
        // Variants are listed explicitly (no `_` catch-all), so adding a new
        // strategy forces this match to be revisited.
        CompressionStrategy::Quantize
        | CompressionStrategy::Sparsify
        | CompressionStrategy::Hybrid => Err(AutogradError::not_implemented(format!(
            "Compression strategy {:?} not implemented",
            strategy
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created handle is pending, and becomes complete after `wait`.
    #[test]
    fn test_comm_handle() {
        let mut pending = CommHandle::new(CommOp::AllReduce);
        assert!(!pending.is_complete());

        let outcome = pending.wait();
        outcome.expect("Should wait");

        assert!(pending.is_complete());
    }

    /// Each variant equals itself and differs from other variants.
    #[test]
    fn test_comm_op_equality() {
        let (first, second) = (CommOp::Broadcast, CommOp::Reduce);
        assert_eq!(first, CommOp::Broadcast);
        assert_ne!(first, second);
    }
}