1use scirs2_core::ndarray::{Array1, Array2};
12
/// Maps a function over a batch of inputs, preserving input order.
///
/// A thread count of `0` means "auto": it is resolved to the system's
/// available parallelism at query time. NOTE(review): the processing
/// methods currently run sequentially; the configured thread count is
/// advisory only, matching the original implementation.
#[derive(Debug)]
pub struct BatchProcessor {
    // Configured worker count; 0 = auto-detect at query time.
    num_threads: usize,
}

impl Default for BatchProcessor {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchProcessor {
    /// Creates a processor in auto-detect mode (thread count resolved lazily).
    pub fn new() -> Self {
        Self { num_threads: 0 }
    }

    /// Creates a processor with an explicit worker count.
    pub fn with_threads(num_threads: usize) -> Self {
        Self { num_threads }
    }

    /// Effective thread count: the explicit setting, or the system's
    /// available parallelism (falling back to 1) when in auto mode.
    pub fn num_threads(&self) -> usize {
        match self.num_threads {
            0 => std::thread::available_parallelism()
                .map(|p| p.get())
                .unwrap_or(1),
            explicit => explicit,
        }
    }

    /// Applies `f` to every element of `inputs`, returning the results
    /// in the same order.
    pub fn process_batch<F, T, R>(&self, inputs: &[T], f: F) -> Vec<R>
    where
        F: Fn(&T) -> R,
    {
        let mut outputs = Vec::with_capacity(inputs.len());
        for item in inputs {
            outputs.push(f(item));
        }
        outputs
    }

    /// Applies `f` to each layer index in `0..num_layers`, returning the
    /// results in layer order.
    pub fn process_layers_parallel<F, R>(&self, num_layers: usize, f: F) -> Vec<R>
    where
        F: Fn(usize) -> R,
    {
        let mut outputs = Vec::with_capacity(num_layers);
        for layer in 0..num_layers {
            outputs.push(f(layer));
        }
        outputs
    }
}
74
75pub fn parallel_matvec_batch(
79 matrices: &[Array2<f32>],
80 vectors: &[Array1<f32>],
81) -> Vec<Array1<f32>> {
82 matrices
85 .iter()
86 .zip(vectors.iter())
87 .map(|(m, v)| m.dot(v))
88 .collect()
89}
90
/// Applies `f` to every element of `data` in place, front to back.
/// NOTE(review): despite the name, this runs sequentially.
pub fn parallel_map<F>(data: &mut [f32], f: F)
where
    F: Fn(f32) -> f32,
{
    for value in data.iter_mut() {
        *value = f(*value);
    }
}
102
/// Sums the elements of `data` with a left-to-right accumulation
/// starting at 0.0 (identical order to `Iterator::sum` for `f32`).
/// NOTE(review): despite the name, this runs sequentially.
pub fn parallel_sum(data: &[f32]) -> f32 {
    data.iter().fold(0.0, |acc, &x| acc + x)
}
111
/// Dot product of `a` and `b`, delegating to the crate's SIMD kernel.
///
/// NOTE(review): accumulation order, precision, and the handling of
/// mismatched slice lengths are defined by `crate::simd::dot_product`,
/// which is not visible here — confirm there before relying on them.
pub fn parallel_dot(a: &[f32], b: &[f32]) -> f32 {
    crate::simd::dot_product(a, b)
}
120
/// Best-effort CPU count: the system's available parallelism, or 1 when
/// the query fails (e.g. unsupported platform).
fn num_cpus_hint() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 1,
    }
}
129
/// Tunables deciding when parallel execution paths should be taken.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Allow parallelism across batch elements.
    pub parallel_batch: bool,
    /// Allow parallelism across attention heads.
    pub parallel_heads: bool,
    /// Smallest batch size considered worth parallelizing.
    pub min_batch_size: usize,
    /// Smallest vector length considered worth parallelizing.
    pub min_vector_size: usize,
}

impl Default for ParallelConfig {
    /// Balanced defaults: parallelism on, moderate thresholds.
    fn default() -> Self {
        Self {
            parallel_batch: true,
            parallel_heads: true,
            min_batch_size: 4,
            min_vector_size: 4096,
        }
    }
}

impl ParallelConfig {
    /// Preset favoring throughput: parallelize at lower thresholds.
    pub fn throughput() -> Self {
        Self {
            parallel_batch: true,
            parallel_heads: true,
            min_batch_size: 2,
            min_vector_size: 2048,
        }
    }

    /// Preset favoring latency: stay sequential, raise the thresholds.
    pub fn latency() -> Self {
        Self {
            parallel_batch: false,
            parallel_heads: false,
            min_batch_size: 16,
            min_vector_size: 8192,
        }
    }

    /// True when batch parallelism is enabled and the batch meets the
    /// configured minimum size.
    pub fn should_parallel_batch(&self, batch_size: usize) -> bool {
        if !self.parallel_batch {
            return false;
        }
        batch_size >= self.min_batch_size
    }

    /// True when head parallelism is enabled and there are at least two heads.
    pub fn should_parallel_heads(&self, num_heads: usize) -> bool {
        if !self.parallel_heads {
            return false;
        }
        num_heads >= 2
    }
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Sequential batch mapping preserves order and length.
    #[test]
    fn test_batch_processor() {
        let processor = BatchProcessor::new();
        let inputs = [1, 2, 3, 4, 5];
        let doubled = processor.process_batch(&inputs, |&x| x * 2);
        assert_eq!(doubled, vec![2, 4, 6, 8, 10]);
    }

    /// Default thresholds gate batch parallelism at size 4.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert!(config.should_parallel_batch(4));
        assert!(!config.should_parallel_batch(2));
    }

    /// Dot product of 0..100 with all-ones equals the plain sum.
    #[test]
    fn test_parallel_dot() {
        let a: Vec<f32> = (0..100).map(|x| x as f32).collect();
        let b = vec![1.0f32; 100];
        let expected: f32 = a.iter().sum();
        assert!((parallel_dot(&a, &b) - expected).abs() < 1e-3);
    }

    /// Sequential sum agrees with the iterator sum.
    #[test]
    fn test_parallel_sum() {
        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
        let expected: f32 = data.iter().sum();
        assert!((parallel_sum(&data) - expected).abs() < 1e-5);
    }

    /// Identity matrices map each vector to itself.
    #[test]
    fn test_parallel_matvec_batch() {
        let v1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let v2 = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let identities = [Array2::eye(3), Array2::eye(3)];

        let results = parallel_matvec_batch(&identities, &[v1.clone(), v2.clone()]);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], v1);
        assert_eq!(results[1], v2);
    }
}