//! # Train Station
//!
//! Maximum-performance, zero-dependency Rust machine learning library.
//!
//! Train Station is designed as a **zero-dependency, maximum performance** Rust machine learning
//! library optimized for raw computational speed, zero-cost abstractions, and minimal memory
//! overhead. This makes it well suited to embedded applications and edge deployments, or to any
//! project that wants easy cross-platform compilation and static linking. The library provides
//! high-performance tensors with SIMD optimization, automatic differentiation, and a comprehensive
//! set of mathematical operations suitable for production ML workloads. It is CPU-bound today, with
//! GPU capability to come (the foundation for CUDA support is in place).
//! **Note**: CUDA support is not yet implemented, but the foundation is in place. Device support
//! exists but is not yet thoroughly tested or supported. Effectively, the library is currently
//! CPU-only until CUDA support is implemented. Feel free to contribute!
//!
//! The plan is to rapidly add functionality and operation support in the early stages of
//! development as the library matures. CUDA support will then follow.
//!
//! # Design Philosophy
//!
//! - **Zero Dependencies**: Standard library only; no external crates are required or used inside Train Station
//! - **Iterator Integration**: Implemented as a trait, so Rust's iterator system can be leveraged
//!   while preserving Train Station's functionality (gradtrack, etc.)
//! - **Raw Performance**: Direct memory management with unsafe optimizations justified by benchmarks
//! - **Zero-Cost Abstractions**: Compile-time optimization, enum dispatch, no virtual calls
//! - **Memory Safety**: RAII patterns with justified unsafe usage and comprehensive validation
//! - **Simplicity**: Minimal redundancy, direct implementations, clear API design
//! - **Thread Safety**: All public APIs are `Send + Sync` for concurrent usage
//!
//! # Core Features
//!
//! - **High-Performance Tensors**: SIMD-optimized multi-dimensional arrays with AVX2 support
//! - **Automatic Differentiation (GradTrack)**: Zero-overhead gradient tracking with computation graph optimization
//! - **Mathematical Operations**: Complete suite of tensor operations with broadcasting support
//!   (currently the add, sub, mul, and div operations are tested with broadcasting; a future TODO
//!   is to ensure all operations are tested with broadcasting)
//! - **Device Management**: Unified CPU/CUDA device abstraction with thread-safe context switching
//! - **Serialization Framework**: Binary and JSON serialization for model checkpointing
//!   (a very minimal framework; feel free to use serde_json or bincode for more complex use cases)
//! - **Optimizer Implementations**: Adam optimizer with SIMD-optimized parameter updates
//! - **Memory Management**: Thread-safe memory pool with global allocator and statistics
//!
//! # Organization
//!
//! The library is organized into specialized modules for maximum performance and maintainability:
//!
//! - **`tensor`**: Core tensor system with operations, transformations, and indexing
//! - **`gradtrack`**: Gradient tracking system with computation graph management
//! - **`device`**: Device management for CPU and CUDA operations
//! - **`optimizers`**: Optimization algorithms (Adam) with parameter management
//! - **`serialization`**: Binary and JSON serialization framework
//! - **`cuda`**: CUDA FFI for GPU acceleration (feature-gated)
//!
//! # Performance Characteristics
//!
//! - **Memory Overhead**: ~64 bytes per tensor (excluding data)
//! - **SIMD Alignment**: 32-byte alignment for AVX2 operations
//! - **Zero-Cost Operators**: Mathematical expressions with no runtime overhead
//! - **Thread Safety**: Lock-free operations with atomic ID generation
//! - **Memory Pool**: Thread-safe global allocator with statistics tracking
//! - **Gradient Tracking**: Zero overhead when disabled, optimized when enabled
//!
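//! The per-tensor overhead above refers to the `Tensor` handle itself; element data lives in
//! separately allocated, SIMD-aligned buffers. A quick way to check the handle size on a given
//! target (the exact figure is platform-dependent, so this is a sketch rather than a guarantee):
//!
//! ```rust
//! use train_station::Tensor;
//!
//! // Size of the tensor handle only; element storage is not counted here
//! println!("Tensor handle: {} bytes", std::mem::size_of::<Tensor>());
//! ```
//!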
//! # Examples
//!
//! ## Basic Tensor Operations
//!
//! ```rust
//! use train_station::{Tensor, Device};
//!
//! // Create tensors with different configurations
//! let tensor = Tensor::new(vec![2, 3, 4]);
//! let tensor_with_grad = Tensor::ones(vec![10, 10]).with_requires_grad();
//! let device_tensor = Tensor::zeros_on_device(vec![100, 100], Device::cpu());
//!
//! // Access tensor properties
//! assert_eq!(tensor.size(), 24);
//! assert_eq!(tensor.shape().dims, vec![2, 3, 4]);
//! assert!(tensor.is_contiguous());
//! assert!(tensor.is_simd_aligned());
//! ```
//!
//! ## Mathematical Operations with Operator Overloading
//!
//! ```rust
//! use train_station::Tensor;
//!
//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
//!
//! // Tensor operations with operators (each operation consumes the tensors)
//! let result1 = a + b; // Tensor addition
//!
//! let a2 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let b2 = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
//! let result2 = a2 * b2; // Element-wise multiplication
//!
//! let a3 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let b3 = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
//! let result3 = a3 - b3; // Tensor subtraction
//!
//! let a4 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let b4 = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
//! let result4 = a4 / b4; // Element-wise division
//!
//! // Scalar operations
//! let a5 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let result5 = a5 + 5.0; // Tensor + scalar
//!
//! let a6 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let result6 = 5.0 + a6; // Scalar + tensor
//!
//! let a7 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let result7 = a7 * 3.0; // Tensor * scalar
//!
//! let a8 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let result8 = 3.0 * a8; // Scalar * tensor
//!
//! // Compound expressions
//! let a9 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let b9 = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
//! let result9 = (a9 + b9) * 2.0 - 1.0; // Complex mathematical expressions
//!
//! // Assignment operators
//! let a10 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let b10 = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
//! let mut c = a10.clone();
//! c += b10; // In-place addition
//! c *= 2.0; // In-place scalar multiplication
//!
//! // Negation
//! let a11 = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let result11 = -a11; // Negate all elements
//! ```
//!
//! ## Automatic Differentiation
//!
//! ```rust
//! use train_station::{NoGradTrack, Tensor};
//!
//! // Enable gradient tracking
//! let a = Tensor::ones(vec![1000, 1000]).with_requires_grad();
//! let b = Tensor::zeros(vec![1000, 1000]);
//! let mut result = &a + &b + 5.0;
//!
//! // Compute gradients
//! result.backward(None);
//!
//! // Access gradients
//! if let Some(grad) = a.grad() {
//!     println!("Gradient shape: {:?}", grad.shape().dims);
//! }
//!
//! // Disable gradients for inference
//! {
//!     let _guard = NoGradTrack::new();
//!     let inference_result = &a + &b; // No gradients tracked
//! }
//! ```
//!
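//! The `NoGradTrack` guard also has a closure form, `with_no_grad`, re-exported at the crate
//! root. A minimal sketch, assuming the previous tracking state is restored when the closure
//! returns:
//!
//! ```rust
//! use train_station::{is_grad_enabled, with_no_grad, Tensor};
//!
//! let x = Tensor::ones(vec![4, 4]).with_requires_grad();
//! with_no_grad(|| {
//!     // Tracking is disabled for the duration of the closure
//!     assert!(!is_grad_enabled());
//!     let _inference = &x + &x;
//! });
//! ```
//!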
//! ## Device Management
//!
//! ```rust
//! use train_station::{Device, with_device, set_default_device, Tensor};
//!
//! // Basic device usage
//! let cpu_device = Device::cpu();
//! let tensor = Tensor::new_on_device(vec![2, 3], cpu_device);
//!
//! // Context management (similar to PyTorch)
//! with_device(Device::cpu(), || {
//!     let tensor = Tensor::new(vec![3, 4]); // Uses context device
//!     // ... operations
//! }); // Device automatically restored
//!
//! // CUDA usage (when feature enabled)
//! #[cfg(feature = "cuda")]
//! {
//!     if train_station::cuda_is_available() {
//!         let cuda_device = Device::cuda(0);
//!         let gpu_tensor = Tensor::new_on_device(vec![1000, 1000], cuda_device);
//!     }
//! }
//! ```
//!
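//! A process-wide default device can also be set and inspected through the re-exported helpers.
//! A small sketch (assuming `current_device` reports the default when no `with_device` context
//! is active):
//!
//! ```rust
//! use train_station::{current_device, get_default_device, set_default_device, Device};
//!
//! set_default_device(Device::cpu());
//! let default = get_default_device(); // the process-wide default
//! let current = current_device();     // the device in effect on this thread
//! ```
//!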
//! ## Optimization with Adam
//!
//! ```rust
//! use train_station::Tensor;
//! use train_station::optimizers::{Adam, Optimizer};
//!
//! // Create parameters
//! let mut param1 = Tensor::randn(vec![100, 100], None).with_requires_grad();
//! let mut param2 = Tensor::randn(vec![100, 100], None).with_requires_grad();
//!
//! // Create optimizer
//! let mut optimizer = Adam::with_learning_rate(0.001);
//! optimizer.add_parameter(&param1);
//! optimizer.add_parameter(&param2);
//!
//! // Training loop
//! for epoch in 0..100 {
//!     // Forward pass
//!     let mut loss = param1.matmul(&param2).sum();
//!
//!     // Backward pass
//!     loss.backward(None);
//!
//!     // Optimization step
//!     optimizer.step(&mut [&mut param1, &mut param2]);
//!     optimizer.zero_grad(&mut [&mut param1, &mut param2]);
//! }
//! ```
//!
//! ## Serialization
//!
//! ```rust
//! use train_station::Tensor;
//! use train_station::serialization::StructSerializable;
//!
//! let tensor = Tensor::new(vec![2, 3]);
//!
//! // Save in JSON format (human-readable)
//! tensor.save_json("model.json").unwrap();
//!
//! // Save in binary format (efficient)
//! tensor.save_binary("model.bin").unwrap();
//!
//! // Load from file
//! let loaded_tensor = Tensor::load_json("model.json").unwrap();
//! ```
//!
//! # Thread Safety
//!
//! All public APIs in Train Station are designed to be thread-safe:
//!
//! - **Tensor Operations**: All tensor operations are `Send + Sync`
//! - **Device Management**: Thread-safe device context switching with automatic restoration
//! - **Gradient Tracking**: Thread-local computation graph storage
//! - **Memory Management**: Thread-safe global memory pool with atomic operations
//! - **Optimizers**: Thread-safe parameter updates with exclusive access patterns
//! - **Serialization**: Thread-safe file operations with proper error handling
//!
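//! Because tensors are `Send`, per-thread work composes directly with `std::thread`. A minimal
//! sketch (the shapes and the four-way split are illustrative, not part of the API):
//!
//! ```rust
//! use std::thread;
//! use train_station::Tensor;
//!
//! let handles: Vec<_> = (0..4)
//!     .map(|_| {
//!         thread::spawn(|| {
//!             // Each thread builds and reduces its own tensors
//!             let a = Tensor::ones(vec![8, 8]);
//!             let b = Tensor::ones(vec![8, 8]);
//!             (a + b).sum()
//!         })
//!     })
//!     .collect();
//!
//! for handle in handles {
//!     let _partial = handle.join().unwrap();
//! }
//! ```
//!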
//! # Memory Safety
//!
//! Train Station prioritizes memory safety while maintaining maximum performance:
//!
//! - **RAII Patterns**: Automatic resource cleanup through `Drop` implementations
//! - **Justified Unsafe Code**: All unsafe operations validated against a LibTorch reference
//! - **Comprehensive Validation**: Mathematical equivalence proven for all operations
//! - **Memory Pool**: Thread-safe allocation with statistics and error detection
//! - **Zero-Copy Views**: Efficient tensor views with shared memory management
//!
//! # Feature Flags
//!
//! - **`cuda`**: Enables CUDA GPU acceleration support (foundational only; full support is a major future TODO)
//!
//! # Performance Benchmarks
//!
//! Train Station is designed to achieve maximum performance:
//!
//! - **Tensor Operations**: SIMD-optimized with AVX2 support for x86_64
//! - **Memory Allocation**: Thread-safe pool allocator with minimal overhead
//! - **Gradient Computation**: Zero-overhead tracking with optimized accumulation
//! - **Mathematical Expressions**: Zero-cost operator overloading
//! - **Serialization**: Optimized binary format for production deployment
//!
//! # Design Principles
//!
//! - **Performance First**: Every design decision optimized for speed
//! - **Zero Dependencies**: Only standard library dependencies
//! - **Memory Safety**: RAII patterns with justified unsafe usage
//! - **Thread Safety**: All public APIs `Send + Sync`
//! - **Simplicity**: Minimal redundancy, direct implementations
//! - **Future Proof**: Foundation for advanced ML operations
//! - **Natural API**: Operator overloading for intuitive mathematical expressions
//! - **Comprehensive Testing**: 100% coverage with mathematical validation

#[cfg(feature = "cuda")]
pub(crate) mod cuda;
pub(crate) mod device;
pub(crate) mod gradtrack;
pub mod optimizers;
pub mod serialization;
pub mod tensor;

pub use device::{
    cuda_device_count, cuda_is_available, current_device, get_default_device, set_default_device,
    with_device, Device, DeviceType,
};
pub use gradtrack::{
    clear_gradients, is_grad_enabled, set_grad_enabled, with_no_grad, NoGradTrack,
};
pub use tensor::Tensor;