use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyTuple};

use crate::{error::PyResult, tensor::PyTensor};
/// Handle for a process group, exposed to Python as `ProcessGroup`.
///
/// `Clone` is derived because the constructor and module-level functions in
/// this file accept `Option<PyProcessGroup>` by value, and PyO3 can only
/// extract a `#[pyclass]` by value when the type is `Clone`.
#[pyclass(name = "ProcessGroup")]
#[derive(Clone, Debug)]
pub struct PyProcessGroup {
    /// Rank value supplied at construction; returned by the `rank` getter.
    rank: u32,
    /// World size supplied at construction; returned by the `world_size` getter.
    world_size: u32,
}
#[pymethods]
impl PyProcessGroup {
#[new]
fn new(rank: u32, world_size: u32) -> Self {
Self { rank, world_size }
}
#[getter]
fn rank(&self) -> u32 {
self.rank
}
#[getter]
fn world_size(&self) -> u32 {
self.world_size
}
fn all_reduce(&self, tensor: &PyTensor, op: Option<String>) -> PyResult<()> {
Ok(())
}
fn all_gather(&self, tensors: Vec<PyTensor>, tensor: &PyTensor) -> PyResult<()> {
Ok(())
}
fn broadcast(&self, tensor: &PyTensor, src: u32) -> PyResult<()> {
Ok(())
}
fn barrier(&self) -> PyResult<()> {
Ok(())
}
}
/// Wrapper around a Python object, exposed to Python as `DistributedDataParallel`.
///
/// The methods on this type look up `forward`, `parameters`, `named_parameters`
/// and `train` on the wrapped object and delegate to them.
#[pyclass(name = "DistributedDataParallel")]
pub struct PyDDP {
/// The wrapped Python object that every method call is forwarded to.
module: Py<PyAny>,
/// Process group passed to the constructor, if any.
/// NOTE(review): stored but not read anywhere in the visible code.
process_group: Option<PyProcessGroup>,
}
#[pymethods]
impl PyDDP {
#[new]
fn new(
module: Py<PyAny>,
device_ids: Option<Vec<u32>>,
output_device: Option<u32>,
broadcast_buffers: Option<bool>,
process_group: Option<PyProcessGroup>,
bucket_cap_mb: Option<f32>,
find_unused_parameters: Option<bool>,
check_reduction: Option<bool>,
gradient_as_bucket_view: Option<bool>,
) -> Self {
Self {
module,
process_group,
}
}
fn forward(&self, inputs: Vec<PyTensor>) -> PyResult<PyTensor> {
Python::attach(|py| {
let forward_method = self.module.getattr(py, "forward")?;
let result = forward_method.call1(py, PyTuple::new(py, &inputs))?;
result.extract::<PyTensor>(py)
})
}
fn __call__(&self, inputs: Vec<PyTensor>) -> PyResult<PyTensor> {
self.forward(inputs)
}
fn parameters(&self) -> PyResult<Vec<PyTensor>> {
Python::attach(|py| {
let params_method = self.module.getattr(py, "parameters")?;
let result = params_method.call0(py)?;
result.extract::<Vec<PyTensor>>(py)
})
}
fn named_parameters(&self) -> PyResult<std::collections::HashMap<String, PyTensor>> {
Python::attach(|py| {
let named_params_method = self.module.getattr(py, "named_parameters")?;
let result = named_params_method.call0(py)?;
result.extract::<std::collections::HashMap<String, PyTensor>>(py)
})
}
fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
Python::attach(|py| {
let train_method = self.module.getattr(py, "train")?;
train_method.call1(py, (mode.unwrap_or(true),))?;
Ok(())
})
}
fn eval(&mut self) -> PyResult<()> {
self.train(Some(false))
}
}
pub fn register_distributed_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyProcessGroup>()?;
m.add_class::<PyDDP>()?;
#[pyfn(m)]
fn init_process_group(
backend: String,
init_method: Option<String>,
world_size: Option<u32>,
rank: Option<u32>,
store: Option<Py<PyAny>>,
timeout: Option<f64>,
group_name: Option<String>,
pg_options: Option<Py<PyAny>>,
) -> PyResult<PyProcessGroup> {
let rank = rank.unwrap_or(0);
let world_size = world_size.unwrap_or(1);
Ok(PyProcessGroup::new(rank, world_size))
}
#[pyfn(m)]
fn destroy_process_group(group: Option<PyProcessGroup>) -> PyResult<()> {
Ok(())
}
#[pyfn(m)]
fn get_rank(group: Option<PyProcessGroup>) -> u32 {
group.map(|g| g.rank()).unwrap_or(0)
}
#[pyfn(m)]
fn get_world_size(group: Option<PyProcessGroup>) -> u32 {
group.map(|g| g.world_size()).unwrap_or(1)
}
#[pyfn(m)]
fn is_initialized() -> bool {
false
}
#[pyfn(m)]
fn is_available() -> bool {
true
}
#[pyfn(m)]
fn barrier(group: Option<PyProcessGroup>) -> PyResult<()> {
if let Some(g) = group {
g.barrier()
} else {
Ok(())
}
}
#[pyfn(m)]
fn all_reduce(tensor: &PyTensor, op: Option<String>, group: Option<PyProcessGroup>) -> PyResult<()> {
if let Some(g) = group {
g.all_reduce(tensor, op)
} else {
Ok(())
}
}
#[pyfn(m)]
fn all_gather(tensor_list: Vec<PyTensor>, tensor: &PyTensor, group: Option<PyProcessGroup>) -> PyResult<()> {
if let Some(g) = group {
g.all_gather(tensor_list, tensor)
} else {
Ok(())
}
}
#[pyfn(m)]
fn broadcast(tensor: &PyTensor, src: u32, group: Option<PyProcessGroup>) -> PyResult<()> {
if let Some(g) = group {
g.broadcast(tensor, src)
} else {
Ok(())
}
}
#[pyfn(m)]
fn reduce(tensor: &PyTensor, dst: u32, op: Option<String>, group: Option<PyProcessGroup>) -> PyResult<()> {
Ok(())
}
#[pyfn(m)]
fn scatter(tensor: &PyTensor, scatter_list: Option<Vec<PyTensor>>, src: u32, group: Option<PyProcessGroup>) -> PyResult<()> {
Ok(())
}
#[pyfn(m)]
fn gather(tensor: &PyTensor, gather_list: Option<Vec<PyTensor>>, dst: u32, group: Option<PyProcessGroup>) -> PyResult<()> {
Ok(())
}
Ok(())
}