rustorch 0.6.29

Production-ready PyTorch-compatible deep learning library in Rust with special mathematical functions (gamma, Bessel, error functions), statistical distributions, Fourier transforms (FFT/RFFT), matrix decomposition (SVD/QR/LU/eigenvalue), automatic differentiation, neural networks, computer vision transforms, complete GPU acceleration (CUDA/Metal/OpenCL), SIMD optimizations, parallel processing, WebAssembly browser support, comprehensive distributed learning support, and performance validation
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
//! Distributed training support for RusTorch
//! RusTorchの分散学習サポート
//!
//! This module provides comprehensive distributed training capabilities including:
//! - Data parallel training across multiple GPUs
//! - Model parallel training for large models
//! - Multi-machine cluster support
//! - Communication backends (NCCL, Gloo, and a custom TCP backend; MPI is declared but not yet supported)
//! - Gradient synchronization and aggregation
//!
//! このモジュールは包括的な分散学習機能を提供します:
//! - 複数GPU間でのデータ並列学習
//! - 大規模モデル向けのモデル並列学習
//! - 複数マシンクラスターサポート
//! - 通信バックエンド(NCCL、Gloo、MPI)
//! - 勾配同期と集約

use crate::autograd::Variable;
use crate::error::RusTorchResult;
use crate::gpu::DeviceType;
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Removed problematic DistributedFloat type alias - use DistributedScalar trait instead

/// Marker trait for the floating-point element types usable in distributed training.
///
/// This is a trait, not a type alias. It bundles every bound the distributed code
/// relies on, so generic items can write `T: DistributedScalar` instead of repeating
/// the full bound list. It is implemented for `f32` and `f64` below.
pub trait DistributedScalar:
    Float
    + Send
    + Sync
    + 'static
    + std::fmt::Debug
    + ndarray::ScalarOperand
    + num_traits::FromPrimitive
{
}

// The standard float types are the supported distributed element types.
impl DistributedScalar for f32 {}
impl DistributedScalar for f64 {}

/// Common interface for distributed data-parallel (DDP) implementations.
///
/// `T` is the element type of the `Variable`s passed through the wrapper and must
/// satisfy [`DistributedScalar`].
// NOTE(review): presumably implemented by the re-exported `DistributedDataParallel`
// and `SimpleDistributedDataParallel` types — confirm in `ddp` / `simple_ddp`.
pub trait DistributedDataParallelTrait<T: DistributedScalar> {
    /// Return the device IDs used by this DDP instance.
    fn device_ids(&self) -> &[usize];

    /// Perform the distributed forward pass for `input` and return its output.
    ///
    /// # Errors
    /// Failure conditions are implementation-defined; the trait only fixes the
    /// `RusTorchResult` return type.
    fn distributed_forward(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>>;

    /// Synchronize gradients across all participating processes.
    ///
    /// # Errors
    /// Implementation-defined (e.g. a communication failure in the backend).
    fn sync_gradients(&self) -> RusTorchResult<()>;
}

// Re-export multi-GPU validation components
pub use multi_gpu_validation::{
    BenchmarkResults, GpuDeviceInfo, MemoryUsage, MultiGpuValidator, ValidationMetrics,
};

/// Communication backend used by a [`ProcessGroup`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributedBackend {
    /// NVIDIA Collective Communications Library.
    ///
    /// `ProcessGroup::init` can only succeed for this backend when the crate is
    /// built with the `nccl` feature; otherwise it returns an error.
    NCCL,
    /// Gloo, Facebook's collective communications library.
    Gloo,
    /// Message Passing Interface.
    ///
    /// Currently unsupported: `ProcessGroup::init` always returns an error for it.
    MPI,
    /// Custom TCP backend.
    TCP,
}

/// Describes this process's membership in a distributed-training process group.
///
/// Constructing a value performs no I/O or communication; call
/// `ProcessGroup::init` to initialize the selected backend.
#[derive(Debug, Clone)]
pub struct ProcessGroup {
    /// Rank (index) of the current process within the group.
    // NOTE(review): expected to be smaller than `world_size`; the struct itself does
    // not enforce this.
    pub rank: usize,
    /// Total number of processes in the group.
    pub world_size: usize,
    /// Backend used for collective communication.
    pub backend: DistributedBackend,
    /// Address of the master process used for coordination.
    pub master_addr: String,
    /// Port of the master process used for coordination.
    pub master_port: u16,
}

impl ProcessGroup {
    /// Create a new process group
    /// 新しいプロセスグループを作成
    pub fn new(
        rank: usize,
        world_size: usize,
        backend: DistributedBackend,
        master_addr: String,
        master_port: u16,
    ) -> Self {
        Self {
            rank,
            world_size,
            backend,
            master_addr,
            master_port,
        }
    }

    /// Initialize the process group
    /// プロセスグループを初期化
    pub fn init(&self) -> crate::error::RusTorchResult<()> {
        match self.backend {
            DistributedBackend::NCCL => self.init_nccl(),
            DistributedBackend::Gloo => self.init_gloo(),
            DistributedBackend::MPI => self.init_mpi(),
            DistributedBackend::TCP => self.init_tcp(),
        }
    }

    fn init_nccl(&self) -> crate::error::RusTorchResult<()> {
        // NCCL initialization implementation
        // NCCL初期化実装
        #[cfg(feature = "nccl")]
        {
            // Initialize NCCL communicator
            // NCCL通信器を初期化
            Ok(())
        }
        #[cfg(not(feature = "nccl"))]
        {
            Err(crate::error::RusTorchError::distributed(
                "NCCL not compiled",
            ))
        }
    }

    fn init_gloo(&self) -> crate::error::RusTorchResult<()> {
        // Gloo initialization implementation
        // Gloo初期化実装
        Ok(())
    }

    fn init_mpi(&self) -> crate::error::RusTorchResult<()> {
        // MPI not supported - use TCP or NCCL instead
        Err(crate::error::RusTorchError::distributed(
            "MPI not supported - use TCP or NCCL",
        ))
    }

    fn init_tcp(&self) -> crate::error::RusTorchResult<()> {
        // TCP backend initialization
        // TCPバックエンド初期化
        Ok(())
    }
}

// DistributedError enum removed - now using unified RusTorchError system
// DistributedErrorエナム削除 - 統一RusTorchErrorシステムを使用

// Result type unified in error module - no need for local alias

/// Collective communication operations for distributed training.
///
/// `T` is the tensor element type. Every operation reports failure through
/// `RusTorchResult`; the concrete error conditions are implementation-defined.
pub trait DistributedOps<T: Float> {
    /// All-reduce `tensor` across all processes using `op`.
    // NOTE(review): `tensor` is `&mut` and nothing else is returned, so it presumably
    // receives the reduced result on every rank — confirm against implementations.
    fn all_reduce(&self, tensor: &mut Tensor<T>, op: ReduceOp) -> RusTorchResult<()>;

    /// All-gather `tensor` from every process and return the collected tensors.
    // NOTE(review): presumably one tensor per rank, ordered by rank — confirm.
    fn all_gather(&self, tensor: &Tensor<T>) -> RusTorchResult<Vec<Tensor<T>>>;

    /// Broadcast `tensor` from the process with rank `root` to all processes.
    fn broadcast(&self, tensor: &mut Tensor<T>, root: usize) -> RusTorchResult<()>;

    /// Reduce `tensor` from all processes onto the process with rank `root` using `op`.
    fn reduce(&self, tensor: &mut Tensor<T>, root: usize, op: ReduceOp) -> RusTorchResult<()>;

    /// Scatter `tensors` from the process with rank `root`; each process gets one tensor back.
    fn scatter(&self, tensors: &[Tensor<T>], root: usize) -> RusTorchResult<Tensor<T>>;

    /// Gather `tensor` from all processes onto the process with rank `root`.
    // NOTE(review): what non-root ranks receive in the returned Vec is not fixed by
    // this signature — confirm in implementations.
    fn gather(&self, tensor: &Tensor<T>, root: usize) -> RusTorchResult<Vec<Tensor<T>>>;
}

/// Reduction operator passed to collective operations such as
/// `DistributedOps::all_reduce` and `DistributedOps::reduce`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum of the contributions.
    Sum,
    /// Product of the contributions.
    Product,
    /// Minimum of the contributions.
    Min,
    /// Maximum of the contributions.
    Max,
    /// Average (mean) of the contributions.
    Average,
}

/// Process-wide distributed state, created lazily by [`get_distributed_state`].
///
/// `OnceLock` provides the same run-exactly-once initialization as the former
/// `static mut` + `Once` pair, but without `unsafe` access to a mutable static.
/// References to a `static mut` trip the `static_mut_refs` lint, which is
/// deny-by-default in the Rust 2024 edition.
static DISTRIBUTED_STATE: std::sync::OnceLock<Arc<Mutex<DistributedState>>> =
    std::sync::OnceLock::new();

/// Bookkeeping for this process's distributed-training setup.
#[derive(Debug)]
pub struct DistributedState {
    /// Registered process group; `None` until `set_process_group` is called.
    pub process_group: Option<ProcessGroup>,
    /// Devices available for distributed training.
    pub devices: Vec<DeviceType>,
    /// Devices assigned to each rank, keyed by rank.
    pub device_map: HashMap<usize, Vec<DeviceType>>,
}

impl Default for DistributedState {
    fn default() -> Self {
        Self::new()
    }
}

impl DistributedState {
    /// Create an empty, uninitialized state: no process group and no devices.
    pub fn new() -> Self {
        Self {
            process_group: None,
            devices: Vec::new(),
            device_map: HashMap::new(),
        }
    }

    /// Register `pg` as the current process group, replacing any previous one.
    pub fn set_process_group(&mut self, pg: ProcessGroup) {
        self.process_group = Some(pg);
    }

    /// Rank of this process, or `None` if no process group is registered.
    pub fn rank(&self) -> Option<usize> {
        self.process_group.as_ref().map(|pg| pg.rank)
    }

    /// World size of the registered process group, or `None` if there is none.
    pub fn world_size(&self) -> Option<usize> {
        self.process_group.as_ref().map(|pg| pg.world_size)
    }

    /// Whether a process group has been registered.
    pub fn is_initialized(&self) -> bool {
        self.process_group.is_some()
    }
}

/// Return the process-wide distributed state, creating it on first use.
///
/// Every call returns the same `Arc<Mutex<DistributedState>>`. Concurrent first
/// calls are safe, and the state is initialized exactly once.
pub fn get_distributed_state() -> &'static Arc<Mutex<DistributedState>> {
    DISTRIBUTED_STATE.get_or_init(|| Arc::new(Mutex::new(DistributedState::new())))
}

/// Initialize distributed training for this process.
///
/// Builds a [`ProcessGroup`] from the arguments, initializes its backend, and
/// registers it in the global distributed state, replacing any previous group.
///
/// # Errors
///
/// Propagates any error from `ProcessGroup::init`. In that case the global state
/// is left unchanged.
pub fn init_distributed(
    backend: DistributedBackend,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: u16,
) -> RusTorchResult<()> {
    let process_group = ProcessGroup::new(rank, world_size, backend, master_addr, master_port);
    process_group.init()?;

    // Recover from a poisoned lock instead of panicking. This call overwrites
    // `process_group` regardless of what a panicking holder left behind, and the
    // visible code enforces no cross-field invariant on the state.
    get_distributed_state()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .set_process_group(process_group);

    Ok(())
}

/// Report whether distributed training has been initialized in this process.
///
/// Returns `true` once a process group has been registered (e.g. by
/// `init_distributed`) and `false` again after `finalize` clears it.
/// A poisoned state lock is recovered rather than turned into a panic.
pub fn is_available() -> bool {
    get_distributed_state()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .is_initialized()
}

/// Rank of this process in distributed training, or `None` if not initialized.
///
/// A poisoned state lock is recovered rather than turned into a panic.
pub fn get_rank() -> Option<usize> {
    get_distributed_state()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .rank()
}

/// World size of distributed training, or `None` if not initialized.
///
/// A poisoned state lock is recovered rather than turned into a panic.
pub fn get_world_size() -> Option<usize> {
    get_distributed_state()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .world_size()
}

/// Tear down distributed training state for this process.
///
/// Clears the registered process group, the device list, and the per-rank device
/// map. A poisoned state lock is recovered rather than turned into a panic,
/// because this function resets every field anyway. Currently always returns `Ok(())`.
pub fn finalize() -> RusTorchResult<()> {
    let mut state = get_distributed_state()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    state.process_group = None;
    state.devices.clear();
    state.device_map.clear();

    Ok(())
}

// Submodules of the distributed training subsystem.
// (A `///` doc comment here would attach only to `api`, so a plain comment is used.)
pub mod api;
pub mod async_gradient;
pub mod backends;
pub mod cluster;
pub mod common;
pub mod data_parallel;
pub mod ddp;
pub mod model_parallel;
pub mod multi_gpu_validation;
pub mod nccl_integration;
pub mod optimizer;
pub mod performance;
pub mod simple_ddp;

// Re-export core distributed functionality
pub use api::*;

// Re-export DDP implementations with shared trait
pub use ddp::{wrap_module, DistributedDataParallel};
pub use simple_ddp::{wrap_simple, SimpleDistributedDataParallel};

// Re-export async gradient synchronization
pub use async_gradient::{AsyncConfig, AsyncGradientSynchronizer, Priority};

// Traits are already public - no need for re-export

#[cfg(feature = "nccl")]
pub use nccl_integration::{NCCLBackendOptimized, NCCLOps, NCCLOptimizations, NCCLProfiler};

#[cfg(test)]
mod tests {
    use super::*;

    /// `ProcessGroup::new` stores every argument verbatim.
    #[test]
    fn test_process_group_creation() {
        let pg = ProcessGroup::new(
            0,
            4,
            DistributedBackend::TCP,
            "localhost".to_string(),
            12345,
        );

        assert_eq!(pg.rank, 0);
        assert_eq!(pg.world_size, 4);
        assert_eq!(pg.backend, DistributedBackend::TCP);
        assert_eq!(pg.master_addr, "localhost");
        assert_eq!(pg.master_port, 12345);
    }

    /// TCP initialization succeeds for a valid topology; MPI is rejected.
    #[test]
    fn test_process_group_init_backends() {
        let tcp = ProcessGroup::new(
            0,
            4,
            DistributedBackend::TCP,
            "localhost".to_string(),
            12345,
        );
        assert!(tcp.init().is_ok());

        let mpi = ProcessGroup::new(
            0,
            4,
            DistributedBackend::MPI,
            "localhost".to_string(),
            12345,
        );
        assert!(mpi.init().is_err());
    }

    /// A fresh state reports no process group and holds no devices.
    #[test]
    fn test_distributed_state_default_is_uninitialized() {
        let state = DistributedState::default();
        assert!(!state.is_initialized());
        assert_eq!(state.rank(), None);
        assert_eq!(state.world_size(), None);
        assert!(state.devices.is_empty());
        assert!(state.device_map.is_empty());
    }

    /// Registering a process group exposes its rank and world size.
    #[test]
    fn test_distributed_state() {
        let mut state = DistributedState::new();
        assert!(!state.is_initialized());

        let pg = ProcessGroup::new(
            1,
            2,
            DistributedBackend::Gloo,
            "127.0.0.1".to_string(),
            29500,
        );

        state.set_process_group(pg);
        assert!(state.is_initialized());
        assert_eq!(state.rank(), Some(1));
        assert_eq!(state.world_size(), Some(2));
    }

    /// Every `ReduceOp` variant equals itself and differs from all the others.
    /// The previous version only checked a `matches!` over all variants, which
    /// could never fail.
    #[test]
    fn test_reduce_op_variants() {
        let ops = [
            ReduceOp::Sum,
            ReduceOp::Product,
            ReduceOp::Min,
            ReduceOp::Max,
            ReduceOp::Average,
        ];

        for (i, a) in ops.iter().enumerate() {
            for (j, b) in ops.iter().enumerate() {
                assert_eq!(a == b, i == j, "{a:?} vs {b:?}");
            }
        }

        // `ReduceOp` is `Copy`: using a value does not move it.
        let op = ReduceOp::Average;
        let copied = op;
        assert_eq!(op, copied);
    }
}