torsh-distributed 0.1.2

Distributed training and inference for ToRSh
// Framework infrastructure - components designed for future use
#![allow(dead_code)]
#![allow(clippy::await_holding_lock)]
use crate::expert_parallelism::{
    config::{ExpertParallelismConfig, ExpertShardingStrategy},
    router::RoutingDecision,
};
use crate::ProcessGroup;
use crate::TorshResult;
use log::{debug, info};
use scirs2_core::random::thread_rng;
use std::collections::HashMap;
use std::sync::Arc;
use tokio;
use torsh_core::DeviceType;
use torsh_tensor::Tensor;

/// Expert assignment for a single token
#[derive(Debug, Clone)]
pub struct ExpertAssignment {
    pub expert_id: usize,
    pub probability: f32,
    pub token_idx: usize,
    pub expert_rank: usize, // Rank among selected experts (0 = highest probability)
}

/// Expert shard information
#[derive(Debug, Clone)]
pub struct ExpertShardInfo {
    pub expert_id: usize,
    pub owner_rank: usize,
    pub is_local: bool,
    pub replicas: Vec<usize>, // Ranks that have copies of this expert
}
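
// Illustrative sketch (not part of the crate API): how `replicas` is populated
// for a fully replicated expert versus a sharded expert, mirroring the
// DataParallel and ModelParallel branches of `create_expert_sharding_map` below.
fn example_shard_infos(world_size: usize, rank: usize) -> (ExpertShardInfo, ExpertShardInfo) {
    // Replicated: every rank owns a copy, so `replicas` lists all ranks.
    let replicated = ExpertShardInfo {
        expert_id: 0,
        owner_rank: rank,
        is_local: true,
        replicas: (0..world_size).collect(),
    };
    // Sharded: a single owner rank holds the only copy.
    let sharded = ExpertShardInfo {
        expert_id: 1,
        owner_rank: 0,
        is_local: rank == 0,
        replicas: vec![0],
    };
    (replicated, sharded)
}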

/// Individual expert model
pub struct Expert {
    pub expert_id: usize,
    pub weights: Tensor<f32>,
    pub bias: Tensor<f32>,
    pub input_dim: usize,
    pub hidden_dim: usize,
    pub output_dim: usize,
}

impl Expert {
    pub fn new(expert_id: usize, params: &ExpertParameters) -> TorshResult<Self> {
        let mut rng = thread_rng();
        let weights_data: Vec<f32> = (0..(params.input_dim * params.hidden_dim))
            .map(|_| rng.random::<f32>() * 0.02)
            .collect();
        let weights = Tensor::from_vec(weights_data, &[params.input_dim, params.hidden_dim])?;
        let bias = Tensor::zeros(&[params.hidden_dim], DeviceType::Cpu)?;

        Ok(Self {
            expert_id,
            weights,
            bias,
            input_dim: params.input_dim,
            hidden_dim: params.hidden_dim,
            output_dim: params.output_dim,
        })
    }

    pub async fn forward_async(&self, input: Tensor<f32>) -> TorshResult<Tensor<f32>> {
        // Simulate async computation
        tokio::task::yield_now().await;

        // Expert computation: input @ weights + bias
        let output = input.matmul(&self.weights)?;
        let output = output.add(&self.bias)?;

        // Apply activation (ReLU for simplicity)
        let output = output.relu()?;

        Ok(output)
    }
}

/// Expert parameter configuration
#[derive(Debug, Clone)]
pub struct ExpertParameters {
    pub input_dim: usize,
    pub hidden_dim: usize,
    pub output_dim: usize,
    pub activation: String,
}

impl Default for ExpertParameters {
    fn default() -> Self {
        Self {
            input_dim: 512,
            hidden_dim: 2048,
            output_dim: 512,
            activation: "relu".to_string(),
        }
    }
}
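
// Minimal usage sketch (assumption: executed on a Tokio runtime; this helper is
// illustrative only and not part of the public API). It builds one expert from
// the default parameters and runs a forward pass on two tokens.
async fn example_single_expert_forward() -> TorshResult<Tensor<f32>> {
    let params = ExpertParameters::default();
    let expert = Expert::new(0, &params)?;

    // Two tokens, each with `input_dim` features.
    let input = Tensor::from_vec(vec![0.5f32; 2 * params.input_dim], &[2, params.input_dim])?;
    expert.forward_async(input).await
}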

/// All-to-All communication scheduler for token routing
pub struct AllToAllScheduler {
    process_group: Arc<ProcessGroup>,
}

impl AllToAllScheduler {
    pub fn new(process_group: Arc<ProcessGroup>) -> Self {
        Self { process_group }
    }

    pub async fn route_tokens_to_experts(
        &self,
        tokens: &Tensor<f32>,
        routing_decision: &RoutingDecision,
        sharding_map: &HashMap<usize, ExpertShardInfo>,
    ) -> TorshResult<HashMap<usize, Tensor<f32>>> {
        info!("All-to-All: Routing tokens to experts");

        // Enhanced all-to-all communication implementation for token routing
        // This involves:
        // 1. Grouping tokens by destination rank based on expert assignment
        // 2. Performing all-to-all scatter to send tokens to expert owners
        // 3. Receiving tokens assigned to local experts

        let start_time = std::time::Instant::now();
        let backend = self.process_group.backend();
        #[allow(clippy::await_holding_lock)]
        let backend_guard = backend.read();

        // Step 1: Group tokens by destination rank
        let mut tokens_by_rank: HashMap<usize, Vec<Vec<f32>>> = HashMap::new();
        let token_data = tokens.to_vec()?;
        let tokens_per_row = tokens.shape().dims()[1];

        debug!(
            "Grouping {} tokens by destination rank",
            routing_decision.total_tokens
        );

        // Process each token and determine its destination rank
        for (token_idx, assignments) in routing_decision.expert_assignments.iter().enumerate() {
            if let Some(assignment) = assignments.first() {
                let expert_id = assignment.expert_id;
                if let Some(shard_info) = sharding_map.get(&expert_id) {
                    let dest_rank = shard_info.owner_rank;

                    // Extract token data for this token
                    let token_start = token_idx * tokens_per_row;
                    let token_end = token_start + tokens_per_row;
                    if token_end <= token_data.len() {
                        let token_values = token_data[token_start..token_end].to_vec();
                        tokens_by_rank
                            .entry(dest_rank)
                            .or_default()
                            .push(token_values);
                    }
                }
            }
        }

        // Step 2: Perform all-to-all scatter simulation
        debug!(
            "Performing all-to-all scatter: {} destination ranks",
            tokens_by_rank.len()
        );

        // Simulate all-to-all communication latency
        let total_elements: usize = tokens_by_rank
            .values()
            .map(|v| v.len() * tokens_per_row)
            .sum();
        let world_size = backend_guard.world_size() as usize;
        let latency_us = (total_elements as f64 * world_size as f64 * 0.01).max(50.0);
        tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;

        debug!(
            "All-to-all scatter: {} elements across {} ranks",
            total_elements, world_size
        );

        // Step 3: Receive and organize tokens for local experts
        let mut routed_tokens = HashMap::new();
        let current_rank = backend_guard.rank() as usize;

        for (&expert_id, shard_info) in sharding_map {
            if shard_info.is_local && shard_info.owner_rank == current_rank {
                // Simplification: every token routed to this rank is handed to each
                // local expert; a full implementation would group tokens per expert id.
                if let Some(expert_tokens) = tokens_by_rank.get(&current_rank) {
                    // Flatten the token vectors into a single tensor
                    let mut flattened_tokens = Vec::new();
                    for token in expert_tokens {
                        flattened_tokens.extend(token);
                    }

                    if !flattened_tokens.is_empty() {
                        let num_tokens = expert_tokens.len();
                        let tensor_shape = vec![num_tokens, tokens_per_row];
                        let expert_tensor = Tensor::from_vec(flattened_tokens, &tensor_shape)?;
                        routed_tokens.insert(expert_id, expert_tensor);

                        debug!(
                            "Routed {} tokens to local expert {} ({} elements)",
                            num_tokens,
                            expert_id,
                            num_tokens * tokens_per_row
                        );
                    }
                } else {
                    // Create empty tensor for expert with no assigned tokens
                    let empty_tensor = Tensor::zeros(&[0, tokens_per_row], DeviceType::Cpu)?;
                    routed_tokens.insert(expert_id, empty_tensor);
                }
            }
        }

        let duration = start_time.elapsed();
        info!(
            "All-to-all token routing completed: {} local experts in {:?}",
            routed_tokens.len(),
            duration
        );

        Ok(routed_tokens)
    }

    pub async fn route_results_back(
        &self,
        expert_outputs: &HashMap<usize, Tensor<f32>>,
        routing_decision: &RoutingDecision,
        sharding_map: &HashMap<usize, ExpertShardInfo>,
    ) -> TorshResult<Tensor<f32>> {
        info!("All-to-All: Routing expert results back");

        // Enhanced all-to-all gather implementation for expert result collection
        // This involves:
        // 1. Performing all-to-all gather to collect results from all experts
        // 2. Reassembling tokens in their original order
        // 3. Combining results from multiple experts per token

        let start_time = std::time::Instant::now();
        let backend = self.process_group.backend();
        #[allow(clippy::await_holding_lock)]
        let backend_guard = backend.read();

        debug!("Performing all-to-all gather: collecting expert results");

        // Step 1: Prepare expert results for all-to-all gather
        let mut results_by_rank: HashMap<usize, Vec<Vec<f32>>> = HashMap::new();
        let mut total_output_elements = 0;

        // Process expert results
        for (&expert_id, expert_result) in expert_outputs {
            if let Some(shard_info) = sharding_map.get(&expert_id) {
                let expert_data = expert_result.to_vec()?;
                results_by_rank
                    .entry(shard_info.owner_rank)
                    .or_default()
                    .push(expert_data.clone());
                total_output_elements += expert_data.len();
            }
        }

        // Step 2: Perform all-to-all gather simulation
        let world_size = backend_guard.world_size() as usize;
        // Simulate all-to-all gather latency (typically more expensive than scatter)
        let latency_us = (total_output_elements as f64 * world_size as f64 * 0.02).max(100.0);
        tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;

        debug!(
            "All-to-all gather: {} elements from {} ranks",
            total_output_elements,
            results_by_rank.len()
        );

        // Step 3: Reassemble tokens in their original order
        let output_dim = if let Some(first_result) = expert_outputs.values().next() {
            first_result.shape().dims()[1]
        } else {
            512 // Default output dimension
        };

        let mut final_output_data = vec![0.0f32; routing_decision.total_tokens * output_dim];
        let mut tokens_processed = 0;

        // Process each token according to routing decision
        for (token_idx, assignments) in routing_decision.expert_assignments.iter().enumerate() {
            if let Some(assignment) = assignments.first() {
                let expert_id = assignment.expert_id;
                if let Some(expert_result) = expert_outputs.get(&expert_id) {
                    let expert_data = expert_result.to_vec()?;
                    let tokens_in_result = expert_data.len() / output_dim;

                    // Find the appropriate token result within this expert's output
                    let token_in_expert = token_idx % tokens_in_result.max(1);
                    let result_start = token_in_expert * output_dim;
                    let result_end = result_start + output_dim;

                    if result_end <= expert_data.len() {
                        let output_start = token_idx * output_dim;
                        let output_end = output_start + output_dim;

                        if output_end <= final_output_data.len() {
                            final_output_data[output_start..output_end]
                                .copy_from_slice(&expert_data[result_start..result_end]);
                            tokens_processed += 1;
                        }
                    }
                }
            }
        }

        // Step 4: Create final output tensor
        let output_shape = [routing_decision.total_tokens, output_dim];
        let final_output = Tensor::from_vec(final_output_data, &output_shape)?;

        let duration = start_time.elapsed();
        info!(
            "All-to-all result gathering completed: {} tokens processed in {:?}",
            tokens_processed, duration
        );

        Ok(final_output)
    }
}
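
// Sketch of the latency model used by the simulations above: the scatter path
// charges ~0.01 µs per element per rank with a 50 µs floor, and the gather path
// ~0.02 µs per element per rank with a 100 µs floor (values taken from the code
// above; illustrative only).
fn example_all_to_all_latency_us(total_elements: usize, world_size: usize, gather: bool) -> f64 {
    let (per_element_us, floor_us) = if gather { (0.02, 100.0) } else { (0.01, 50.0) };
    (total_elements as f64 * world_size as f64 * per_element_us).max(floor_us)
}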

/// Expert gradient aggregation
pub struct ExpertGradientAggregator {
    process_group: Arc<ProcessGroup>,
}

impl ExpertGradientAggregator {
    pub fn new(process_group: Arc<ProcessGroup>) -> Self {
        Self { process_group }
    }

    pub async fn aggregate_gradients(
        &self,
        expert_gradients: &HashMap<usize, Tensor<f32>>,
        sharding_map: &HashMap<usize, ExpertShardInfo>,
    ) -> TorshResult<()> {
        info!(
            "Aggregating expert gradients across {} experts",
            expert_gradients.len()
        );

        for (&expert_id, gradient) in expert_gradients {
            if let Some(shard_info) = sharding_map.get(&expert_id) {
                match shard_info.replicas.len() {
                    1 => {
                        // Expert is sharded, no aggregation needed
                        continue;
                    }
                    _ => {
                        // Expert is replicated, need to aggregate gradients
                        self.aggregate_replicated_expert_gradients(expert_id, gradient, shard_info)
                            .await?;
                    }
                }
            }
        }

        Ok(())
    }

    async fn aggregate_replicated_expert_gradients(
        &self,
        expert_id: usize,
        gradient: &Tensor<f32>,
        shard_info: &ExpertShardInfo,
    ) -> TorshResult<()> {
        info!(
            "    Aggregating gradients for replicated expert {} across {} replicas",
            expert_id,
            shard_info.replicas.len()
        );

        // Enhanced gradient aggregation using all-reduce for replicated experts
        // For replicated experts, we need to:
        // 1. All-reduce gradients across all replicas
        // 2. Average the gradients
        // 3. Update expert parameters consistently

        let start_time = std::time::Instant::now();
        let backend = self.process_group.backend();
        #[allow(clippy::await_holding_lock)]
        let _backend_guard = backend.read();

        let _aggregated_gradient = if shard_info.replicas.len() > 1 {
            // Extract gradient data for all-reduce
            let grad_data = gradient.to_vec()?;

            info!(
                "      All-reducing gradients across {} replicas",
                shard_info.replicas.len()
            );

            // Simulate all-reduce operation across expert replicas
            // In production, this would use a subgroup communicator for the replica set
            let summed_gradients: Vec<f32> = grad_data
                .iter()
                .map(|&g| g * shard_info.replicas.len() as f32) // Simulate sum across replicas
                .collect();

            // Average the gradients
            let averaged_gradients: Vec<f32> = summed_gradients
                .iter()
                .map(|&g| g / shard_info.replicas.len() as f32)
                .collect();

            // Simulate network latency for replica all-reduce
            let latency_us =
                (grad_data.len() as f64 * shard_info.replicas.len() as f64 * 0.01).max(20.0);
            tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;

            // Create aggregated gradient tensor
            let result = Tensor::from_vec(averaged_gradients, gradient.shape().dims())?;

            info!(
                "      Expert {} gradient all-reduce: {} elements across {} replicas",
                expert_id,
                grad_data.len(),
                shard_info.replicas.len()
            );

            result
        } else {
            info!(
                "       Single replica expert {}, no aggregation needed",
                expert_id
            );
            gradient.clone()
        };

        let duration = start_time.elapsed();
        info!(
            "      Expert {} gradient aggregation completed in {:?}",
            expert_id, duration
        );

        Ok(())
    }
}
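
// Sketch of the arithmetic a real replica all-reduce would perform: each replica
// contributes its local gradient and the result is the element-wise mean. (The
// simulation above scales one local gradient by the replica count and divides it
// back out, leaving the values unchanged.) Illustrative only.
fn example_average_replica_gradients(replica_grads: &[Vec<f32>]) -> Vec<f32> {
    let num_replicas = replica_grads.len().max(1) as f32;
    let grad_len = replica_grads.first().map_or(0, |g| g.len());
    (0..grad_len)
        .map(|i| replica_grads.iter().map(|g| g[i]).sum::<f32>() / num_replicas)
        .collect()
}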

/// Distributed expert execution manager
pub struct DistributedExpertManager {
    config: ExpertParallelismConfig,
    process_group: Arc<ProcessGroup>,
    local_experts: Vec<Expert>,
    expert_sharding_map: HashMap<usize, ExpertShardInfo>,
    all_to_all_scheduler: AllToAllScheduler,
    gradient_aggregator: ExpertGradientAggregator,
}

impl DistributedExpertManager {
    pub fn new(
        config: ExpertParallelismConfig,
        process_group: Arc<ProcessGroup>,
        expert_params: &ExpertParameters,
    ) -> TorshResult<Self> {
        let world_size = process_group.world_size() as usize;
        let rank = process_group.rank() as usize;

        // Create expert sharding map
        let expert_sharding_map = Self::create_expert_sharding_map(&config, world_size, rank);

        // Initialize local experts
        let local_experts =
            Self::initialize_local_experts(&config, &expert_sharding_map, expert_params)?;

        let all_to_all_scheduler = AllToAllScheduler::new(process_group.clone());
        let gradient_aggregator = ExpertGradientAggregator::new(process_group.clone());

        Ok(Self {
            config,
            process_group,
            local_experts,
            expert_sharding_map,
            all_to_all_scheduler,
            gradient_aggregator,
        })
    }

    pub fn create_expert_sharding_map(
        config: &ExpertParallelismConfig,
        world_size: usize,
        rank: usize,
    ) -> HashMap<usize, ExpertShardInfo> {
        let mut sharding_map = HashMap::new();

        match config.sharding_strategy {
            ExpertShardingStrategy::DataParallel => {
                // All experts on all devices
                for expert_id in 0..config.num_experts {
                    sharding_map.insert(
                        expert_id,
                        ExpertShardInfo {
                            expert_id,
                            owner_rank: rank,
                            is_local: true,
                            replicas: (0..world_size).collect(),
                        },
                    );
                }
            }
            ExpertShardingStrategy::ModelParallel => {
                // Distribute experts across devices
                let experts_per_device = config.num_experts.div_ceil(world_size);
                let start_expert = rank * experts_per_device;
                let end_expert = ((rank + 1) * experts_per_device).min(config.num_experts);

                for expert_id in 0..config.num_experts {
                    let owner_rank = expert_id / experts_per_device;
                    let is_local = expert_id >= start_expert && expert_id < end_expert;

                    sharding_map.insert(
                        expert_id,
                        ExpertShardInfo {
                            expert_id,
                            owner_rank,
                            is_local,
                            replicas: vec![owner_rank],
                        },
                    );
                }
            }
            ExpertShardingStrategy::Hybrid => {
                // Mix of replicated and sharded experts
                let replicated_experts = config.num_experts / 2;

                for expert_id in 0..config.num_experts {
                    if expert_id < replicated_experts {
                        // Replicated experts
                        sharding_map.insert(
                            expert_id,
                            ExpertShardInfo {
                                expert_id,
                                owner_rank: rank,
                                is_local: true,
                                replicas: (0..world_size).collect(),
                            },
                        );
                    } else {
                        // Sharded experts
                        let sharded_id = expert_id - replicated_experts;
                        let experts_per_device =
                            (config.num_experts - replicated_experts).div_ceil(world_size);
                        let owner_rank = sharded_id / experts_per_device;
                        let is_local = owner_rank == rank;

                        sharding_map.insert(
                            expert_id,
                            ExpertShardInfo {
                                expert_id,
                                owner_rank,
                                is_local,
                                replicas: vec![owner_rank],
                            },
                        );
                    }
                }
            }
            ExpertShardingStrategy::Dynamic => {
                // Dynamic expert migration based on load balancing and communication patterns
                // This implements intelligent expert placement and migration

                // Initialize with model parallel as baseline
                let experts_per_device = config.num_experts.div_ceil(world_size);

                // Simulate load-based expert migration decisions
                // In a real implementation, this would use historical routing statistics
                for expert_id in 0..config.num_experts {
                    // Calculate optimal placement based on simulated usage patterns
                    let base_owner = expert_id / experts_per_device;

                    // Dynamic migration logic: redistribute based on load patterns
                    let optimal_owner = if config.num_experts > 32 {
                        // For large numbers of experts, use load-balancing migration
                        // Simulate expert usage frequency (in practice, would use real statistics)
                        let usage_frequency = ((expert_id as f32 * 7.0).sin().abs() * 100.0) as u32;

                        // High-usage experts get moved to less loaded devices
                        if usage_frequency > 70 {
                            // Move high-usage experts to spread load
                            (base_owner + 1) % world_size
                        } else if usage_frequency < 30 {
                            // Consolidate low-usage experts
                            (base_owner + world_size / 2) % world_size
                        } else {
                            base_owner
                        }
                    } else {
                        // For smaller numbers, use communication-aware placement
                        // Group related experts on the same device for better cache locality
                        if expert_id % 4 == rank % 4 {
                            rank // Keep related experts local
                        } else {
                            base_owner
                        }
                    };

                    // Implement memory-aware migration: don't overload any single device
                    let final_owner = if config.memory_per_expert_mb > 0 {
                        let memory_per_device = config.memory_per_expert_mb * experts_per_device;
                        let max_memory_mb = 16 * 1024; // 16GB limit per device

                        if memory_per_device > max_memory_mb {
                            // Redistribute to prevent memory overflow
                            expert_id % world_size
                        } else {
                            optimal_owner
                        }
                    } else {
                        optimal_owner
                    };

                    let is_local = final_owner == rank;

                    // Determine replication strategy for dynamic experts
                    let replicas = if config.num_experts <= 16 {
                        // Small number of experts: replicate critical ones
                        if expert_id < 4 {
                            // Replicate first few experts across all devices
                            (0..world_size).collect()
                        } else {
                            // Single owner for others
                            vec![final_owner]
                        }
                    } else {
                        // Large number of experts: selective replication
                        if expert_id % 8 == 0 {
                            // Replicate every 8th expert for load distribution
                            vec![final_owner, (final_owner + 1) % world_size]
                        } else {
                            vec![final_owner]
                        }
                    };

                    sharding_map.insert(
                        expert_id,
                        ExpertShardInfo {
                            expert_id,
                            owner_rank: final_owner,
                            is_local,
                            replicas,
                        },
                    );
                }

                info!(
                    " Dynamic expert migration completed: {} experts distributed across {} devices",
                    config.num_experts, world_size
                );
            }
        }

        sharding_map
    }

    fn initialize_local_experts(
        _config: &ExpertParallelismConfig,
        sharding_map: &HashMap<usize, ExpertShardInfo>,
        expert_params: &ExpertParameters,
    ) -> TorshResult<Vec<Expert>> {
        let mut local_experts = Vec::new();

        for (&expert_id, shard_info) in sharding_map {
            if shard_info.is_local {
                let expert = Expert::new(expert_id, expert_params)?;
                local_experts.push(expert);
            }
        }

        info!(" Initialized {} local experts", local_experts.len());
        Ok(local_experts)
    }

    /// Execute distributed expert computation
    pub async fn execute_experts(
        &mut self,
        tokens: &Tensor<f32>,
        routing_decision: &RoutingDecision,
    ) -> TorshResult<Tensor<f32>> {
        // Step 1: All-to-All communication to route tokens to expert owners
        let routed_tokens = self
            .all_to_all_scheduler
            .route_tokens_to_experts(tokens, routing_decision, &self.expert_sharding_map)
            .await?;

        // Step 2: Execute local experts in parallel
        let local_outputs = self.execute_local_experts(&routed_tokens).await?;

        // Step 3: All-to-All communication to route results back to original positions
        let final_output = self
            .all_to_all_scheduler
            .route_results_back(&local_outputs, routing_decision, &self.expert_sharding_map)
            .await?;

        Ok(final_output)
    }

    async fn execute_local_experts(
        &mut self,
        routed_tokens: &HashMap<usize, Tensor<f32>>,
    ) -> TorshResult<HashMap<usize, Tensor<f32>>> {
        let mut outputs = HashMap::new();

        // Build futures for all local experts up front, then await them in turn
        // (each future yields to the scheduler, but execution is sequential on this task).
        let mut futures = Vec::new();

        for expert in &mut self.local_experts {
            if let Some(expert_tokens) = routed_tokens.get(&expert.expert_id) {
                let future = expert.forward_async(expert_tokens.clone());
                futures.push((expert.expert_id, future));
            }
        }

        // Await all expert computations
        for (expert_id, future) in futures {
            let output = future.await?;
            outputs.insert(expert_id, output);
        }

        Ok(outputs)
    }

    /// Aggregate gradients across distributed experts
    pub async fn aggregate_expert_gradients(
        &mut self,
        expert_gradients: &HashMap<usize, Tensor<f32>>,
    ) -> TorshResult<()> {
        self.gradient_aggregator
            .aggregate_gradients(expert_gradients, &self.expert_sharding_map)
            .await
    }

    /// Get expert sharding information
    pub fn get_expert_sharding_map(&self) -> &HashMap<usize, ExpertShardInfo> {
        &self.expert_sharding_map
    }

    /// Get local experts
    pub fn get_local_experts(&self) -> &Vec<Expert> {
        &self.local_experts
    }

    /// Get configuration
    pub fn get_config(&self) -> &ExpertParallelismConfig {
        &self.config
    }

    /// Get the total number of experts across all ranks
    pub fn get_num_experts(&self) -> usize {
        self.config.num_experts
    }
}
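
// End-to-end usage sketch (assumptions: `config`, `process_group`, the token
// batch, and the `RoutingDecision` are produced elsewhere, e.g. by the router in
// this crate; this function is illustrative and not part of the public API).
async fn example_moe_step(
    config: ExpertParallelismConfig,
    process_group: Arc<ProcessGroup>,
    tokens: Tensor<f32>,
    routing: RoutingDecision,
) -> TorshResult<Tensor<f32>> {
    let params = ExpertParameters::default();
    let mut manager = DistributedExpertManager::new(config, process_group, &params)?;
    debug!("rank holds {} local experts", manager.get_local_experts().len());

    // Route tokens to their expert owners, run the local experts, and gather
    // the results back into the original token order.
    manager.execute_experts(&tokens, &routing).await
}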