// kizzasi_core/parallel.rs

//! Parallel computation utilities for multi-layer SSM processing
//!
//! Provides parallel execution strategies for:
//! - Batch processing of multiple inputs
//! - Parallel layer computation where data dependencies allow
//! - Multi-threaded matrix operations
//!
//! Uses scirs2-core parallel abstractions (NOT rayon directly, per KIZZASI_POLICY.md).
//! Parallel features are enabled via scirs2-core's `parallel` feature.

use scirs2_core::ndarray::{Array1, Array2};

/// Batch processor for parallel input processing
///
/// Processing is currently sequential; it will switch to multi-threaded
/// execution once scirs2-core's parallel API stabilizes.
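///
/// # Examples
///
/// A minimal usage sketch, assuming this module is exported as
/// `kizzasi_core::parallel`:
///
/// ```
/// use kizzasi_core::parallel::BatchProcessor;
///
/// let processor = BatchProcessor::new(); // auto-detect thread count
/// let doubled = processor.process_batch(&[1, 2, 3], |&x| x * 2);
/// assert_eq!(doubled, vec![2, 4, 6]);
/// ```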
#[derive(Debug)]
pub struct BatchProcessor {
    /// Number of worker threads (0 = auto-detect)
    num_threads: usize,
}

impl Default for BatchProcessor {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchProcessor {
    /// Create a new batch processor with automatic thread count
    pub fn new() -> Self {
        Self { num_threads: 0 }
    }

    /// Create a batch processor with a specific thread count
    pub fn with_threads(num_threads: usize) -> Self {
        Self { num_threads }
    }

    /// Get the effective number of threads (auto-detected when set to 0)
    pub fn num_threads(&self) -> usize {
        if self.num_threads == 0 {
            num_cpus_hint()
        } else {
            self.num_threads
        }
    }

    /// Process a batch of inputs
    ///
    /// Currently sequential; will use scirs2-core parallel processing once
    /// that API is stabilized.
    pub fn process_batch<F, T, R>(&self, inputs: &[T], f: F) -> Vec<R>
    where
        F: Fn(&T) -> R,
    {
        // Currently using sequential processing
        // Will use scirs2_core::parallel when API is stabilized
        inputs.iter().map(f).collect()
    }

    /// Process multiple layers by index
    ///
    /// Intended for independent layer computations (e.g., different attention
    /// heads); layers with sequential data dependencies must still be
    /// processed in order.
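    ///
    /// A sketch of per-head use (the closure body is illustrative):
    ///
    /// ```
    /// use kizzasi_core::parallel::BatchProcessor;
    ///
    /// let processor = BatchProcessor::new();
    /// // One result per head index.
    /// let head_outputs = processor.process_layers_parallel(4, |head| head * 10);
    /// assert_eq!(head_outputs, vec![0, 10, 20, 30]);
    /// ```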
    pub fn process_layers_parallel<F, R>(&self, num_layers: usize, f: F) -> Vec<R>
    where
        F: Fn(usize) -> R,
    {
        // Currently using sequential processing
        // Will use scirs2_core::parallel when API is stabilized
        (0..num_layers).map(f).collect()
    }
}

/// Matrix-vector multiplication for batched operations
///
/// Pairs each matrix with the vector at the same index. Currently
/// sequential; will use scirs2-core parallel processing once that API is
/// stabilized.
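///
/// A minimal sketch, assuming this module is exported as
/// `kizzasi_core::parallel`:
///
/// ```
/// use kizzasi_core::parallel::parallel_matvec_batch;
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let ms = vec![Array2::<f32>::eye(2), Array2::<f32>::eye(2)];
/// let vs = vec![
///     Array1::from_vec(vec![1.0, 2.0]),
///     Array1::from_vec(vec![3.0, 4.0]),
/// ];
/// let out = parallel_matvec_batch(&ms, &vs);
/// // Identity matrices leave the vectors unchanged.
/// assert_eq!(out[1], vs[1]);
/// ```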
pub fn parallel_matvec_batch(
    matrices: &[Array2<f32>],
    vectors: &[Array1<f32>],
) -> Vec<Array1<f32>> {
    // `zip` would silently truncate a mismatched batch, so catch that in debug builds
    debug_assert_eq!(matrices.len(), vectors.len());
    // Currently using sequential processing
    // Will use scirs2_core::parallel when API is stabilized
    matrices
        .iter()
        .zip(vectors.iter())
        .map(|(m, v)| m.dot(v))
        .collect()
}

/// Element-wise map applied to a slice in place
///
/// Currently sequential; will use scirs2-core parallel processing once that
/// API is stabilized.
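///
/// For example, an in-place ReLU (illustrative):
///
/// ```
/// use kizzasi_core::parallel::parallel_map;
///
/// let mut data = vec![-1.0_f32, 0.5, 2.0];
/// parallel_map(&mut data, |x| x.max(0.0));
/// assert_eq!(data, vec![0.0, 0.5, 2.0]);
/// ```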
pub fn parallel_map<F>(data: &mut [f32], f: F)
where
    F: Fn(f32) -> f32,
{
    // Currently using sequential processing
    // Will use scirs2_core::parallel when API is stabilized
    data.iter_mut().for_each(|x| *x = f(*x));
}

/// Sum reduction over a slice
///
/// Currently sequential; will use scirs2-core parallel processing once that
/// API is stabilized.
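///
/// A quick sanity check:
///
/// ```
/// use kizzasi_core::parallel::parallel_sum;
///
/// assert_eq!(parallel_sum(&[1.0, 2.0, 3.0]), 6.0);
/// ```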
pub fn parallel_sum(data: &[f32]) -> f32 {
    // Currently using sequential processing
    // Will use scirs2_core::parallel when API is stabilized
    data.iter().sum()
}

/// Dot product for large vectors
///
/// Delegates to the SIMD-optimized implementation in `crate::simd`; will add
/// parallel processing via scirs2-core for very large vectors once that API
/// is stabilized.
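///
/// SIMD accumulation order may differ from a naive loop, so compare against
/// a tolerance:
///
/// ```
/// use kizzasi_core::parallel::parallel_dot;
///
/// let a = vec![1.0_f32, 2.0, 3.0];
/// let b = vec![4.0_f32, 5.0, 6.0];
/// assert!((parallel_dot(&a, &b) - 32.0).abs() < 1e-5);
/// ```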
pub fn parallel_dot(a: &[f32], b: &[f32]) -> f32 {
    // Use the SIMD version (already optimized)
    crate::simd::dot_product(a, b)
}

/// Hint for the number of available CPUs
fn num_cpus_hint() -> usize {
    // Will use scirs2_core::parallel::num_threads() when API is stabilized
    // For now, use a reasonable default
    std::thread::available_parallelism()
        .map(|p| p.get())
        .unwrap_or(1)
}

/// Configuration for parallel execution
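///
/// A sketch of how a caller might gate work on this config (the batch size
/// here is illustrative):
///
/// ```
/// use kizzasi_core::parallel::ParallelConfig;
///
/// let config = ParallelConfig::default();
/// let batch_size = 8;
/// if config.should_parallel_batch(batch_size) {
///     // dispatch to the parallel path
/// } else {
///     // fall back to sequential processing
/// }
/// ```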
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Enable parallel batch processing
    pub parallel_batch: bool,
    /// Enable parallel layer computation (for independent heads)
    pub parallel_heads: bool,
    /// Minimum batch size to trigger parallel processing
    pub min_batch_size: usize,
    /// Minimum vector size for parallel operations
    pub min_vector_size: usize,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            parallel_batch: true,
            parallel_heads: true,
            min_batch_size: 4,
            min_vector_size: 4096,
        }
    }
}

impl ParallelConfig {
    /// Create configuration optimized for throughput
    pub fn throughput() -> Self {
        Self {
            parallel_batch: true,
            parallel_heads: true,
            min_batch_size: 2,
            min_vector_size: 2048,
        }
    }

    /// Create configuration optimized for latency (less parallelism)
    pub fn latency() -> Self {
        Self {
            parallel_batch: false,
            parallel_heads: false,
            min_batch_size: 16,
            min_vector_size: 8192,
        }
    }

    /// Should use parallel batch processing for this batch size?
    pub fn should_parallel_batch(&self, batch_size: usize) -> bool {
        self.parallel_batch && batch_size >= self.min_batch_size
    }

    /// Should use parallel heads for this number of heads?
    pub fn should_parallel_heads(&self, num_heads: usize) -> bool {
        self.parallel_heads && num_heads >= 2
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_processor() {
        let processor = BatchProcessor::new();
        let inputs = vec![1, 2, 3, 4, 5];
        let results = processor.process_batch(&inputs, |&x| x * 2);
        assert_eq!(results, vec![2, 4, 6, 8, 10]);
    }

    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert!(config.should_parallel_batch(4));
        assert!(!config.should_parallel_batch(2));
    }

    #[test]
    fn test_parallel_dot() {
        let a: Vec<f32> = (0..100).map(|x| x as f32).collect();
        let b: Vec<f32> = vec![1.0; 100];
        let result = parallel_dot(&a, &b);
        let expected: f32 = (0..100).map(|x| x as f32).sum();
        assert!((result - expected).abs() < 1e-3);
    }

    #[test]
    fn test_parallel_sum() {
        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
        let result = parallel_sum(&data);
        let expected: f32 = (0..100).map(|x| x as f32).sum();
        assert!((result - expected).abs() < 1e-5);
    }

    #[test]
    fn test_parallel_matvec_batch() {
        let m1 = Array2::eye(3);
        let m2 = Array2::eye(3);
        let v1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let v2 = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let results = parallel_matvec_batch(&[m1, m2], &[v1.clone(), v2.clone()]);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], v1);
        assert_eq!(results[1], v2);
    }
}
235}