use crate::{FxGraph, Node, TorshResult};
use petgraph::graph::NodeIndex;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use torsh_core::{device::DeviceType, error::TorshError};
use torsh_tensor::Tensor;
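/// Configuration for a distributed run: process-group topology (`world_size`,
/// `rank`), the rendezvous endpoint (`master_addr`, `master_port`), the
/// communication backend, and a timeout (presumably in seconds, given the
/// default of 300; the stub backends do not enforce it yet).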
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
pub world_size: usize,
pub rank: usize,
pub master_addr: String,
pub master_port: u16,
pub backend: CommunicationBackendType,
pub timeout: u64,
}
impl Default for DistributedConfig {
fn default() -> Self {
Self {
world_size: 1,
rank: 0,
master_addr: "localhost".to_string(),
master_port: 23456,
backend: CommunicationBackendType::Nccl,
timeout: 300,
}
}
}
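/// Available communication backends. Only `Tcp` has an implementation in this
/// module; `ProcessGroup::new` rejects the others.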
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunicationBackendType {
Nccl,
Gloo,
Mpi,
Tcp,
}
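/// Collective and point-to-point communication primitives that can be
/// attached to graph nodes.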
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CollectiveOp {
AllReduce,
AllGather,
ReduceScatter,
Broadcast,
Send,
Recv,
Barrier,
}
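/// Reduction operators for `AllReduce` and `ReduceScatter`.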
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReduceOp {
Sum,
Product,
Min,
Max,
Average,
}
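/// A single scheduled communication operation. `reduce_op` applies only to
/// reductions, `src_rank`/`dst_rank` only to point-to-point ops, and `tag`
/// disambiguates concurrent transfers.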
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommOp {
pub op_type: CollectiveOp,
pub reduce_op: Option<ReduceOp>,
pub src_rank: Option<usize>,
pub dst_rank: Option<usize>,
pub tag: u32,
}
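/// How the graph is split across ranks.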
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistributionStrategy {
DataParallel,
ModelParallel,
PipelineParallel,
HybridParallel,
}
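/// Placement metadata: which rank owns each node, each rank's device type,
/// and the communication groups used for collectives.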
#[derive(Debug, Clone)]
pub struct DeviceMapping {
pub node_to_device: HashMap<NodeIndex, usize>,
pub rank_to_device_type: HashMap<usize, DeviceType>,
pub comm_groups: Vec<Vec<usize>>,
}
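/// One rank's slice of the graph together with the cross-rank edges it must
/// satisfy: `external_inputs` maps a node to the rank it receives from,
/// `external_outputs` to the ranks it sends to, and `comm_ops` lists the
/// scheduled transfers.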
#[derive(Debug, Clone)]
pub struct DistributedPartition {
pub nodes: HashSet<NodeIndex>,
pub external_inputs: HashMap<NodeIndex, usize>,
pub external_outputs: HashMap<NodeIndex, Vec<usize>>,
pub comm_ops: Vec<(NodeIndex, CommOp)>,
pub rank: usize,
}
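/// The complete plan produced by partitioning: per-rank partitions, a staged
/// execution order, a per-rank communication schedule, and device placement.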
#[derive(Debug, Clone)]
pub struct DistributedExecutionPlan {
pub partitions: HashMap<usize, DistributedPartition>,
pub execution_order: Vec<Vec<NodeIndex>>,
pub comm_schedule: HashMap<usize, Vec<CommOp>>,
pub device_mapping: DeviceMapping,
}
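/// Splits an `FxGraph` into per-rank `DistributedPartition`s according to the
/// configured `DistributionStrategy`.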
pub struct DistributedPartitioner {
config: DistributedConfig,
strategy: DistributionStrategy,
}
impl DistributedPartitioner {
pub fn new(config: DistributedConfig, strategy: DistributionStrategy) -> Self {
Self { config, strategy }
}
pub fn partition(&self, graph: &FxGraph) -> TorshResult<DistributedExecutionPlan> {
match self.strategy {
DistributionStrategy::DataParallel => self.partition_data_parallel(graph),
DistributionStrategy::ModelParallel => self.partition_model_parallel(graph),
DistributionStrategy::PipelineParallel => self.partition_pipeline_parallel(graph),
DistributionStrategy::HybridParallel => self.partition_hybrid_parallel(graph),
}
}
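/// Data parallelism: every rank holds a full replica of the graph. Nodes
/// whose op name contains "backward" or "grad" (a heuristic for
/// gradient-producing ops) are tagged with a sum `AllReduce` so gradients
/// stay synchronized across replicas.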
fn partition_data_parallel(&self, graph: &FxGraph) -> TorshResult<DistributedExecutionPlan> {
let mut partitions = HashMap::new();
let mut device_mapping = DeviceMapping {
node_to_device: HashMap::new(),
rank_to_device_type: HashMap::new(),
comm_groups: vec![],
};
for rank in 0..self.config.world_size {
let mut partition = DistributedPartition {
nodes: graph.nodes().map(|(idx, _)| idx).collect(),
external_inputs: HashMap::new(),
external_outputs: HashMap::new(),
comm_ops: vec![],
rank,
};
for (node_idx, node) in graph.nodes() {
match node {
Node::Call(op_name, _)
if op_name.contains("backward") || op_name.contains("grad") =>
{
partition.comm_ops.push((
node_idx,
CommOp {
op_type: CollectiveOp::AllReduce,
reduce_op: Some(ReduceOp::Sum),
src_rank: None,
dst_rank: None,
tag: node_idx.index() as u32,
},
));
}
_ => {}
}
device_mapping.node_to_device.insert(node_idx, rank);
}
device_mapping
.rank_to_device_type
.insert(rank, DeviceType::Cpu);
partitions.insert(rank, partition);
}
device_mapping
.comm_groups
.push((0..self.config.world_size).collect());
Ok(DistributedExecutionPlan {
partitions,
execution_order: self.compute_execution_order(graph)?,
comm_schedule: self.compute_comm_schedule(graph)?,
device_mapping,
})
}
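/// Model parallelism: nodes are split into contiguous chunks of
/// ceil(n / world_size) in iteration order, and every edge that crosses a
/// chunk boundary gets a matching `Send`/`Recv` pair.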
fn partition_model_parallel(&self, graph: &FxGraph) -> TorshResult<DistributedExecutionPlan> {
let nodes: Vec<_> = graph.nodes().collect();
let nodes_per_rank = (nodes.len() + self.config.world_size - 1) / self.config.world_size;
let mut partitions = HashMap::new();
let mut device_mapping = DeviceMapping {
node_to_device: HashMap::new(),
rank_to_device_type: HashMap::new(),
comm_groups: vec![],
};
for rank in 0..self.config.world_size {
let start_idx = rank * nodes_per_rank;
let end_idx = ((rank + 1) * nodes_per_rank).min(nodes.len());
let mut partition = DistributedPartition {
nodes: HashSet::new(),
external_inputs: HashMap::new(),
external_outputs: HashMap::new(),
comm_ops: vec![],
rank,
};
for i in start_idx..end_idx {
let (node_idx, _) = nodes[i];
partition.nodes.insert(node_idx);
device_mapping.node_to_device.insert(node_idx, rank);
}
for &node_idx in &partition.nodes {
let predecessors: Vec<_> = graph
.graph
.neighbors_directed(node_idx, petgraph::Direction::Incoming)
.collect();
for pred_idx in predecessors {
if let Some(&src_rank) = device_mapping.node_to_device.get(&pred_idx) {
if src_rank != rank {
partition.external_inputs.insert(node_idx, src_rank);
partition.comm_ops.push((
node_idx,
CommOp {
op_type: CollectiveOp::Recv,
reduce_op: None,
src_rank: Some(src_rank),
dst_rank: Some(rank),
tag: node_idx.index() as u32,
},
));
}
}
}
let successors: Vec<_> = graph
.graph
.neighbors_directed(node_idx, petgraph::Direction::Outgoing)
.collect();
let mut dst_ranks = vec![];
for succ_idx in successors {
if let Some(&dst_rank) = device_mapping.node_to_device.get(&succ_idx) {
if dst_rank != rank && !dst_ranks.contains(&dst_rank) {
dst_ranks.push(dst_rank);
}
}
}
if !dst_ranks.is_empty() {
partition
.external_outputs
.insert(node_idx, dst_ranks.clone());
for &dst_rank in &dst_ranks {
partition.comm_ops.push((
node_idx,
CommOp {
op_type: CollectiveOp::Send,
reduce_op: None,
src_rank: Some(rank),
dst_rank: Some(dst_rank),
tag: node_idx.index() as u32,
},
));
}
}
}
device_mapping
.rank_to_device_type
.insert(rank, DeviceType::Cpu);
partitions.insert(rank, partition);
}
device_mapping
.comm_groups
.push((0..self.config.world_size).collect());
Ok(DistributedExecutionPlan {
partitions,
execution_order: self.compute_execution_order(graph)?,
comm_schedule: self.compute_comm_schedule(graph)?,
device_mapping,
})
}
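/// Pipeline parallelism: the staged execution order is divided into
/// contiguous runs of ceil(stages / world_size) stages per rank, with
/// point-to-point `Send`/`Recv` pairs at each stage boundary. Tags encode the
/// boundary (`rank * 1000 + node index`) to keep transfers distinct.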
fn partition_pipeline_parallel(
&self,
graph: &FxGraph,
) -> TorshResult<DistributedExecutionPlan> {
let execution_order = self.compute_execution_order(graph)?;
let stages_per_rank =
(execution_order.len() + self.config.world_size - 1) / self.config.world_size;
let mut partitions = HashMap::new();
let mut device_mapping = DeviceMapping {
node_to_device: HashMap::new(),
rank_to_device_type: HashMap::new(),
comm_groups: vec![],
};
for rank in 0..self.config.world_size {
let start_stage = rank * stages_per_rank;
let end_stage = ((rank + 1) * stages_per_rank).min(execution_order.len());
let mut partition = DistributedPartition {
nodes: HashSet::new(),
external_inputs: HashMap::new(),
external_outputs: HashMap::new(),
comm_ops: vec![],
rank,
};
for stage_idx in start_stage..end_stage {
for &node_idx in &execution_order[stage_idx] {
partition.nodes.insert(node_idx);
device_mapping.node_to_device.insert(node_idx, rank);
}
}
// Guard against empty partitions when world_size exceeds the stage count.
if rank > 0 && start_stage < end_stage {
for &node_idx in &execution_order[start_stage] {
partition.external_inputs.insert(node_idx, rank - 1);
partition.comm_ops.push((
node_idx,
CommOp {
op_type: CollectiveOp::Recv,
reduce_op: None,
src_rank: Some(rank - 1),
dst_rank: Some(rank),
tag: (rank * 1000 + node_idx.index()) as u32,
},
));
}
}
if rank < self.config.world_size - 1
&& start_stage < end_stage
&& end_stage < execution_order.len()
{
for &node_idx in &execution_order[end_stage - 1] {
partition.external_outputs.insert(node_idx, vec![rank + 1]);
partition.comm_ops.push((
node_idx,
CommOp {
op_type: CollectiveOp::Send,
reduce_op: None,
src_rank: Some(rank),
dst_rank: Some(rank + 1),
tag: ((rank + 1) * 1000 + node_idx.index()) as u32,
},
));
}
}
device_mapping
.rank_to_device_type
.insert(rank, DeviceType::Cpu);
partitions.insert(rank, partition);
}
for rank in 0..self.config.world_size - 1 {
device_mapping.comm_groups.push(vec![rank, rank + 1]);
}
Ok(DistributedExecutionPlan {
partitions,
execution_order,
comm_schedule: self.compute_comm_schedule(graph)?,
device_mapping,
})
}
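/// Hybrid parallelism: for more than two ranks, the lower half of the ranks
/// is partitioned model-parallel and the upper half replicates those
/// partitions data-parallel (adding gradient `AllReduce`s); for two or fewer
/// ranks this degenerates to plain data parallelism.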
fn partition_hybrid_parallel(&self, graph: &FxGraph) -> TorshResult<DistributedExecutionPlan> {
if self.config.world_size <= 2 {
self.partition_data_parallel(graph)
} else {
let model_parallel_ranks = self.config.world_size / 2;
let mut base_plan = self.partition_model_parallel(graph)?;
let mut new_partitions = base_plan.partitions.clone();
for rank in model_parallel_ranks..self.config.world_size {
let base_rank = rank % model_parallel_ranks;
if let Some(base_partition) = base_plan.partitions.get(&base_rank) {
let mut new_partition = base_partition.clone();
new_partition.rank = rank;
for (node_idx, node) in graph.nodes() {
if new_partition.nodes.contains(&node_idx) {
if let Node::Call(op_name, _) = node {
if op_name.contains("backward") || op_name.contains("grad") {
new_partition.comm_ops.push((
node_idx,
CommOp {
op_type: CollectiveOp::AllReduce,
reduce_op: Some(ReduceOp::Sum),
src_rank: None,
dst_rank: None,
tag: (rank * 10000 + node_idx.index()) as u32,
},
));
}
}
}
}
new_partitions.insert(rank, new_partition);
}
}
base_plan.partitions = new_partitions;
Ok(base_plan)
}
}
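/// Groups nodes into dependency stages via level scheduling on a topological
/// order: a node's stage is one past the latest stage of its predecessors, so
/// the nodes within a stage are mutually independent.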
fn compute_execution_order(&self, graph: &FxGraph) -> TorshResult<Vec<Vec<NodeIndex>>> {
use petgraph::algo::toposort;
let topo_order = toposort(&graph.graph, None)
.map_err(|_| TorshError::InvalidArgument("Graph contains cycles".to_string()))?;
// Level scheduling: a node's stage is one past the latest stage of any
// predecessor (stage 0 for sources), so nodes sharing a stage have no
// dependencies on one another and may execute concurrently.
let mut stage_of: HashMap<NodeIndex, usize> = HashMap::new();
let mut stages: Vec<Vec<NodeIndex>> = vec![];
for node_idx in topo_order {
let stage = graph
.graph
.neighbors_directed(node_idx, petgraph::Direction::Incoming)
.filter_map(|pred| stage_of.get(&pred).map(|&s| s + 1))
.max()
.unwrap_or(0);
if stage == stages.len() {
stages.push(vec![]);
}
stages[stage].push(node_idx);
stage_of.insert(node_idx, stage);
}
Ok(stages)
}
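/// Placeholder: returns an empty communication schedule for every rank until
/// schedule computation is implemented.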
fn compute_comm_schedule(&self, _graph: &FxGraph) -> TorshResult<HashMap<usize, Vec<CommOp>>> {
let mut schedule = HashMap::new();
for rank in 0..self.config.world_size {
schedule.insert(rank, vec![]);
}
Ok(schedule)
}
}
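/// A group of communicating processes bound to a concrete backend.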
pub struct ProcessGroup {
config: DistributedConfig,
backend: Box<dyn CommunicationBackend + Send + Sync>,
}
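/// Backend-agnostic interface over collective and point-to-point primitives,
/// in the style of MPI/NCCL process groups.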
pub trait CommunicationBackend {
fn init(&mut self, config: &DistributedConfig) -> TorshResult<()>;
fn finalize(&mut self) -> TorshResult<()>;
fn all_reduce(&self, tensor: &mut Tensor, op: ReduceOp) -> TorshResult<()>;
fn all_gather(&self, input: &Tensor, outputs: &mut [Tensor]) -> TorshResult<()>;
fn broadcast(&self, tensor: &mut Tensor, root: usize) -> TorshResult<()>;
fn send(&self, tensor: &Tensor, dst: usize, tag: u32) -> TorshResult<()>;
fn recv(&self, tensor: &mut Tensor, src: usize, tag: u32) -> TorshResult<()>;
fn barrier(&self) -> TorshResult<()>;
fn rank(&self) -> usize;
fn world_size(&self) -> usize;
}
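/// Minimal TCP backend stub: it tracks rank/world size and validates
/// initialization, but performs no actual data movement yet.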
pub struct TcpBackend {
rank: usize,
world_size: usize,
initialized: bool,
}
impl TcpBackend {
pub fn new() -> Self {
Self {
rank: 0,
world_size: 1,
initialized: false,
}
}
}
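// Convenience `Default` mirroring `new()`; also satisfies clippy's
// `new_without_default` lint.
impl Default for TcpBackend {
fn default() -> Self {
Self::new()
}
}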
impl CommunicationBackend for TcpBackend {
fn init(&mut self, config: &DistributedConfig) -> TorshResult<()> {
self.rank = config.rank;
self.world_size = config.world_size;
self.initialized = true;
Ok(())
}
fn finalize(&mut self) -> TorshResult<()> {
self.initialized = false;
Ok(())
}
fn all_reduce(&self, _tensor: &mut Tensor, _op: ReduceOp) -> TorshResult<()> {
if !self.initialized {
return Err(TorshError::InvalidArgument(
"Backend not initialized".to_string(),
));
}
if self.world_size == 1 {
// Single-process world: all-reduce is the identity.
return Ok(());
}
// Stub: multi-process TCP reduction is not implemented yet; the tensor is
// left unchanged.
Ok(())
}
fn all_gather(&self, _input: &Tensor, _outputs: &mut [Tensor]) -> TorshResult<()> {
if !self.initialized {
return Err(TorshError::InvalidArgument(
"Backend not initialized".to_string(),
));
}
Ok(())
}
fn broadcast(&self, _tensor: &mut Tensor, _root: usize) -> TorshResult<()> {
if !self.initialized {
return Err(TorshError::InvalidArgument(
"Backend not initialized".to_string(),
));
}
Ok(())
}
fn send(&self, _tensor: &Tensor, _dst: usize, _tag: u32) -> TorshResult<()> {
if !self.initialized {
return Err(TorshError::InvalidArgument(
"Backend not initialized".to_string(),
));
}
Ok(())
}
fn recv(&self, _tensor: &mut Tensor, _src: usize, _tag: u32) -> TorshResult<()> {
if !self.initialized {
return Err(TorshError::InvalidArgument(
"Backend not initialized".to_string(),
));
}
Ok(())
}
fn barrier(&self) -> TorshResult<()> {
if !self.initialized {
return Err(TorshError::InvalidArgument(
"Backend not initialized".to_string(),
));
}
Ok(())
}
fn rank(&self) -> usize {
self.rank
}
fn world_size(&self) -> usize {
self.world_size
}
}
impl ProcessGroup {
pub fn new(config: DistributedConfig) -> TorshResult<Self> {
let backend: Box<dyn CommunicationBackend + Send + Sync> = match config.backend {
CommunicationBackendType::Tcp => Box::new(TcpBackend::new()),
_ => {
return Err(TorshError::InvalidArgument(format!(
"Backend {:?} not implemented",
config.backend
)));
}
};
Ok(Self { config, backend })
}
pub fn init(&mut self) -> TorshResult<()> {
self.backend.init(&self.config)
}
pub fn finalize(&mut self) -> TorshResult<()> {
self.backend.finalize()
}
pub fn rank(&self) -> usize {
self.backend.rank()
}
pub fn world_size(&self) -> usize {
self.backend.world_size()
}
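/// Dispatches a `CommOp` to the backend. `AllGather` and `ReduceScatter` are
/// rejected as unimplemented because they need multiple tensor buffers rather
/// than the single in-place tensor taken here.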
pub fn execute_collective(&self, op: &CommOp, tensor: &mut Tensor) -> TorshResult<()> {
match op.op_type {
CollectiveOp::AllReduce => {
let reduce_op = op.reduce_op.unwrap_or(ReduceOp::Sum);
self.backend.all_reduce(tensor, reduce_op)
}
CollectiveOp::Broadcast => {
let root = op.src_rank.unwrap_or(0);
self.backend.broadcast(tensor, root)
}
CollectiveOp::Send => {
let dst = op.dst_rank.ok_or_else(|| {
TorshError::InvalidArgument("Send operation requires dst_rank".to_string())
})?;
self.backend.send(tensor, dst, op.tag)
}
CollectiveOp::Recv => {
let src = op.src_rank.ok_or_else(|| {
TorshError::InvalidArgument("Recv operation requires src_rank".to_string())
})?;
self.backend.recv(tensor, src, op.tag)
}
CollectiveOp::Barrier => self.backend.barrier(),
_ => Err(TorshError::InvalidArgument(format!(
"Collective operation {:?} not implemented",
op.op_type
))),
}
}
}
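/// Drives execution of a partitioned graph on the local rank, delegating
/// communication to a shared `ProcessGroup`.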
pub struct DistributedExecutor {
config: DistributedConfig,
process_group: Arc<RwLock<ProcessGroup>>,
execution_plan: Option<DistributedExecutionPlan>,
}
impl DistributedExecutor {
pub fn new(config: DistributedConfig) -> TorshResult<Self> {
let process_group = ProcessGroup::new(config.clone())?;
Ok(Self {
config,
process_group: Arc::new(RwLock::new(process_group)),
execution_plan: None,
})
}
pub fn init(&mut self) -> TorshResult<()> {
let mut pg = self
.process_group
.write()
.map_err(|_| TorshError::InvalidArgument("Failed to acquire write lock".to_string()))?;
pg.init()
}
pub fn set_execution_plan(&mut self, plan: DistributedExecutionPlan) {
self.execution_plan = Some(plan);
}
pub fn execute(
&self,
graph: &FxGraph,
inputs: HashMap<String, Tensor>,
) -> TorshResult<Vec<Tensor>> {
let plan = self
.execution_plan
.as_ref()
.ok_or_else(|| TorshError::InvalidArgument("No execution plan set".to_string()))?;
let partition = plan.partitions.get(&self.config.rank).ok_or_else(|| {
TorshError::InvalidArgument(format!("No partition for rank {}", self.config.rank))
})?;
self.execute_partition(graph, partition, inputs)
}
fn execute_partition(
&self,
graph: &FxGraph,
partition: &DistributedPartition,
inputs: HashMap<String, Tensor>,
) -> TorshResult<Vec<Tensor>> {
let mut interpreter = crate::interpreter::GraphInterpreter::new(DeviceType::Cpu);
let local_graph = self.create_local_graph(graph, partition)?;
let mut local_inputs = inputs;
for (&node_idx, &_src_rank) in &partition.external_inputs {
for (comm_node_idx, comm_op) in &partition.comm_ops {
if *comm_node_idx == node_idx && comm_op.op_type == CollectiveOp::Recv {
// Stub: a real `Recv` would fill this tensor from `src_rank`; until the
// backend moves data, a zero placeholder stands in for the remote value.
let placeholder = torsh_tensor::creation::zeros(&[1])?;
let node_index = node_idx.index();
local_inputs.insert(format!("external_{node_index}"), placeholder);
break;
}
}
}
let outputs = interpreter.run(&local_graph, local_inputs)?;
// Stub: once the backend moves data, each entry in `external_outputs` would
// trigger its matching `Send` op here; for now the interpreter's outputs are
// simply returned to the caller.
for (_node_idx, comm_op) in &partition.comm_ops {
match comm_op.op_type {
CollectiveOp::Barrier => {
let pg = self.process_group.read().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire read lock".to_string())
})?;
// Barrier carries no payload, so a throwaway tensor satisfies the
// `execute_collective` signature.
let mut temp_tensor = torsh_tensor::creation::zeros(&[1])?;
pg.execute_collective(comm_op, &mut temp_tensor)?;
}
CollectiveOp::AllReduce | CollectiveOp::Broadcast => {
// Stub: these need the tensors produced by the interpreter, which
// are not yet threaded through to this point.
}
// Send/Recv are represented by the external input/output stubs above.
_ => {}
}
}
Ok(outputs)
}
fn create_local_graph(
&self,
graph: &FxGraph,
_partition: &DistributedPartition,
) -> TorshResult<FxGraph> {
// Stub: a faithful implementation would prune the graph to the partition's
// nodes and add placeholder inputs for cross-rank edges; for now every rank
// interprets the full graph.
Ok(graph.clone())
}
pub fn finalize(&mut self) -> TorshResult<()> {
let mut pg = self
.process_group
.write()
.map_err(|_| TorshError::InvalidArgument("Failed to acquire write lock".to_string()))?;
pg.finalize()
}
}
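/// Creates a `DistributedExecutor` and initializes its process group.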
pub fn init_distributed(config: DistributedConfig) -> TorshResult<DistributedExecutor> {
let mut executor = DistributedExecutor::new(config)?;
executor.init()?;
Ok(executor)
}
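/// Partitions `graph` for the given configuration and strategy.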
pub fn create_execution_plan(
graph: &FxGraph,
config: DistributedConfig,
strategy: DistributionStrategy,
) -> TorshResult<DistributedExecutionPlan> {
let partitioner = DistributedPartitioner::new(config, strategy);
partitioner.partition(graph)
}
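/// Convenience end-to-end entry point: initialize, plan, execute, finalize.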
pub fn execute_distributed(
graph: &FxGraph,
inputs: HashMap<String, Tensor>,
config: DistributedConfig,
strategy: DistributionStrategy,
) -> TorshResult<Vec<Tensor>> {
let mut executor = init_distributed(config.clone())?;
let plan = create_execution_plan(graph, config, strategy)?;
executor.set_execution_plan(plan);
let outputs = executor.execute(graph, inputs)?;
executor.finalize()?;
Ok(outputs)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tracer::ModuleTracer;
use torsh_tensor::creation::ones;
#[test]
fn test_distributed_config() {
let config = DistributedConfig::default();
assert_eq!(config.world_size, 1);
assert_eq!(config.rank, 0);
assert_eq!(config.master_addr, "localhost");
}
#[test]
fn test_process_group_creation() {
let config = DistributedConfig::default();
let result = ProcessGroup::new(config);
// The default backend (Nccl) has no implementation in this module, so
// `ProcessGroup::new` is expected to reject it.
assert!(result.is_err());
}
#[test]
fn test_distributed_partitioner_data_parallel() {
let config = DistributedConfig {
world_size: 2,
rank: 0,
..Default::default()
};
let partitioner = DistributedPartitioner::new(config, DistributionStrategy::DataParallel);
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let result = partitioner.partition(&graph);
assert!(result.is_ok());
let plan = result.unwrap();
assert_eq!(plan.partitions.len(), 2);
}
#[test]
fn test_distributed_partitioner_model_parallel() {
let config = DistributedConfig {
world_size: 2,
rank: 0,
..Default::default()
};
let partitioner = DistributedPartitioner::new(config, DistributionStrategy::ModelParallel);
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("linear", vec!["x".to_string()]);
tracer.add_call("relu", vec!["node_0".to_string()]);
tracer.add_output("node_1");
let graph = tracer.finalize();
let result = partitioner.partition(&graph);
assert!(result.is_ok());
let plan = result.unwrap();
assert_eq!(plan.partitions.len(), 2);
}
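#[test]
fn test_distributed_partitioner_pipeline_parallel() {
// Mirrors the model-parallel test above: a two-op chain split across two
// ranks, with the stage boundary carrying a Send/Recv pair.
let config = DistributedConfig {
world_size: 2,
rank: 0,
..Default::default()
};
let partitioner =
DistributedPartitioner::new(config, DistributionStrategy::PipelineParallel);
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("linear", vec!["x".to_string()]);
tracer.add_call("relu", vec!["node_0".to_string()]);
tracer.add_output("node_1");
let graph = tracer.finalize();
let result = partitioner.partition(&graph);
assert!(result.is_ok());
let plan = result.unwrap();
assert_eq!(plan.partitions.len(), 2);
}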
#[test]
fn test_distributed_executor_creation() {
let config = DistributedConfig::default();
let result = DistributedExecutor::new(config);
// Construction goes through `ProcessGroup::new`, which rejects the default
// (Nccl) backend as unimplemented.
assert!(result.is_err());
}
#[test]
fn test_tcp_backend() {
let mut backend = TcpBackend::new();
let config = DistributedConfig::default();
assert!(backend.init(&config).is_ok());
assert_eq!(backend.rank(), 0);
assert_eq!(backend.world_size(), 1);
assert!(backend.finalize().is_ok());
}
#[test]
fn test_comm_op_serialization() {
let comm_op = CommOp {
op_type: CollectiveOp::AllReduce,
reduce_op: Some(ReduceOp::Sum),
src_rank: None,
dst_rank: None,
tag: 42,
};
let serialized = serde_json::to_string(&comm_op).unwrap();
let deserialized: CommOp = serde_json::from_str(&serialized).unwrap();
assert_eq!(comm_op.tag, deserialized.tag);
match (comm_op.op_type, deserialized.op_type) {
(CollectiveOp::AllReduce, CollectiveOp::AllReduce) => {}
_ => panic!("Serialization failed"),
}
}
#[test]
fn test_execution_plan_creation() {
let config = DistributedConfig {
world_size: 2,
rank: 0,
..Default::default()
};
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let result = create_execution_plan(&graph, config, DistributionStrategy::DataParallel);
assert!(result.is_ok());
}
#[test]
fn test_distributed_execution_single_rank() {
// Use the TCP backend so initialization succeeds and execution is actually
// attempted; the default (Nccl) backend would fail at construction.
let config = DistributedConfig {
backend: CommunicationBackendType::Tcp,
..Default::default()
};
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let mut inputs = HashMap::new();
inputs.insert("x".to_string(), ones(&[2, 3]).unwrap());
let result =
execute_distributed(&graph, inputs, config, DistributionStrategy::DataParallel);
match result {
Ok(outputs) => {
assert!(!outputs.is_empty());
}
Err(_) => {
// Acceptable: execution can still fail if the interpreter lacks
// support for an op in the traced graph.
}
}
}
}