pub(crate) mod api;
pub(crate) mod client;
mod ops;
pub(crate) mod server;
pub use api::*;
pub use ops::*;
use serde::{Deserialize, Serialize};
use crate::tensor::FloatTensor;
/// Identifier of a parameter taking part in distributed operations.
///
/// A thin newtype over a raw `u64`; build one with `DistributedParamId::from(id)`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub struct DistributedParamId(u64);
impl From<u64> for DistributedParamId {
fn from(value: u64) -> Self {
Self(value)
}
}
/// Metadata describing a parameter that participates in distributed operations.
///
/// Currently this carries only the parameter's identifier.
#[derive(Clone, Debug)]
pub struct DistributedParams {
    /// Identifier of the parameter this metadata refers to.
    pub param_id: DistributedParamId,
}
/// How contributions are combined by a reduction (e.g. an all-reduce).
///
/// `Eq` and `Hash` are derived alongside `PartialEq` (consistent with
/// `DistributedParamId`) so the operation can be used as a map/set key.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
pub enum ReduceOperation {
    /// Combine contributions by summing them.
    /// NOTE(review): exact semantics are implemented by the backend — confirm there.
    Sum,
    /// Combine contributions by averaging them.
    /// NOTE(review): presumably sum divided by participant count — confirm in backend.
    Mean,
}
/// Configuration for distributed operations.
///
/// Derives `Debug` (and `PartialEq`) like the other public types in this module,
/// so it can be logged and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct DistributedConfig {
    /// Reduction used for all-reduce operations.
    pub all_reduce_op: ReduceOperation,
}
/// A float tensor produced by a collective operation that may still need
/// synchronization before use.
///
/// Obtain the tensor with `resolve` (synchronizes first) or the unsafe
/// `assume_resolved` (skips synchronization). `CollectiveTensor::new(handle)`
/// is generated by `derive(new)`.
#[derive(Clone, new)]
pub struct CollectiveTensor<B: DistributedBackend> {
    /// The wrapped backend tensor.
    handle: FloatTensor<B>,
}
impl<B: DistributedBackend> CollectiveTensor<B> {
    /// Synchronizes collective work on the tensor's device, then returns the tensor.
    ///
    /// Calls `B::sync_collective` on the device that holds the wrapped tensor
    /// (presumably waiting for in-flight collective operations to complete —
    /// confirm against the backend) before handing the tensor back.
    pub fn resolve(self) -> FloatTensor<B> {
        let device = B::float_device(&self.handle);
        B::sync_collective(&device);
        self.handle
    }

    /// Returns the wrapped tensor without synchronizing collective operations.
    ///
    /// Unlike `resolve`, this does not call `B::sync_collective`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that collective operations on this tensor's device
    /// have already been synchronized (for example by calling `B::sync_collective`
    /// on that device, as `resolve` does) before the returned tensor is used.
    /// NOTE(review): the consequences of violating this are backend-specific —
    /// presumably the tensor may not yet hold the collective's result; confirm
    /// against the backend implementation.
    pub unsafe fn assume_resolved(self) -> FloatTensor<B> {
        self.handle
    }
}