1use crate::error::{CoreError, CoreResult};
31use crate::parallel::ParallelConfig;
32use scirs2_core::ndarray::{Array1, Array2, Array3};
33
/// A binary operator whose applications may be regrouped freely.
///
/// Implementors are expected to be associative:
/// `combine(&combine(a, b), c)` must equal `combine(a, &combine(b, c))`.
/// The scan helpers in this module always call `combine(prefix, next)`, i.e. the
/// left operand is the earlier element. Commutativity is therefore not required.
pub trait AssociativeOp<T>: Send + Sync {
    /// Merges an earlier value `lhs` with a later value `rhs`.
    fn combine(&self, lhs: &T, rhs: &T) -> T;

    /// Returns the neutral element of [`combine`](Self::combine), or `None` when the
    /// operator cannot produce one on its own.
    fn identity(&self) -> Option<T>;
}
42
43pub fn parallel_scan<T, Op>(data: &[T], op: &Op, parallel: bool) -> Vec<T>
52where
53 T: Clone + Send + Sync,
54 Op: AssociativeOp<T>,
55{
56 if data.is_empty() {
57 return Vec::new();
58 }
59
60 if data.len() == 1 {
61 return vec![data[0].clone()];
62 }
63
64 if !parallel || data.len() < 64 {
65 return sequential_scan(data, op);
67 }
68
69 parallel_scan_impl(data, op)
71}
72
73fn sequential_scan<T, Op>(data: &[T], op: &Op) -> Vec<T>
75where
76 T: Clone,
77 Op: AssociativeOp<T>,
78{
79 let mut result = Vec::with_capacity(data.len());
80 result.push(data[0].clone());
81
82 for i in 1..data.len() {
83 let combined = op.combine(&result[i - 1], &data[i]);
84 result.push(combined);
85 }
86
87 result
88}
89
90fn parallel_scan_impl<T, Op>(data: &[T], op: &Op) -> Vec<T>
95where
96 T: Clone + Send + Sync,
97 Op: AssociativeOp<T>,
98{
99 sequential_scan(data, op)
102}
103
104#[derive(Clone, Debug)]
106pub struct SSMElement {
107 pub a_bar: Array1<f32>,
109 pub b_bar: Array1<f32>,
111}
112
113pub struct SSMScanOp;
115
116impl AssociativeOp<SSMElement> for SSMScanOp {
117 fn combine(&self, left: &SSMElement, right: &SSMElement) -> SSMElement {
118 let a_combined = &right.a_bar * &left.a_bar;
120 let b_combined = &right.a_bar * &left.b_bar + &right.b_bar;
121
122 SSMElement {
123 a_bar: a_combined,
124 b_bar: b_combined,
125 }
126 }
127
128 fn identity(&self) -> Option<SSMElement> {
129 None }
131}
132
133pub fn parallel_ssm_scan(
146 a_bars: &Array2<f32>,
147 b_bars: &Array2<f32>,
148 c: &Array1<f32>,
149 parallel_config: &ParallelConfig,
150) -> CoreResult<Array2<f32>> {
151 let (seq_len, state_dim) = a_bars.dim();
152
153 if b_bars.dim() != (seq_len, state_dim) {
154 return Err(CoreError::DimensionMismatch {
155 expected: seq_len,
156 got: b_bars.nrows(),
157 });
158 }
159
160 if c.len() != state_dim {
161 return Err(CoreError::DimensionMismatch {
162 expected: state_dim,
163 got: c.len(),
164 });
165 }
166
167 let elements: Vec<SSMElement> = (0..seq_len)
169 .map(|t| SSMElement {
170 a_bar: a_bars.row(t).to_owned(),
171 b_bar: b_bars.row(t).to_owned(),
172 })
173 .collect();
174
175 let op = SSMScanOp;
177 let scanned = parallel_scan(&elements, &op, parallel_config.parallel_batch);
178
179 let mut states = Array2::zeros((seq_len, state_dim));
181 for (t, elem) in scanned.iter().enumerate() {
182 states.row_mut(t).assign(&elem.b_bar);
183 }
184
185 Ok(states)
186}
187
188pub fn parallel_ssm_batch(
201 a_bars: &Array3<f32>,
202 b_bars: &Array3<f32>,
203 c: &Array1<f32>,
204 d: f32,
205 parallel_config: &ParallelConfig,
206) -> CoreResult<Array2<f32>> {
207 let (batch_size, seq_len, state_dim) = a_bars.dim();
208
209 if b_bars.dim() != (batch_size, seq_len, state_dim) {
210 return Err(CoreError::InvalidConfig(
211 "b_bars shape mismatch".to_string(),
212 ));
213 }
214
215 let outputs: Vec<Array1<f32>> = (0..batch_size)
218 .map(|b| {
219 let a_batch = a_bars.slice(s![b, .., ..]).to_owned();
221 let b_batch = b_bars.slice(s![b, .., ..]).to_owned();
222
223 let states = parallel_ssm_scan(&a_batch, &b_batch, c, parallel_config).unwrap();
225
226 let mut output = Array1::zeros(seq_len);
229 for t in 0..seq_len {
230 let h_t = states.row(t);
231 output[t] = c.dot(&h_t) + d;
232 }
233
234 output
235 })
236 .collect();
237
238 let mut result = Array2::zeros((batch_size, seq_len));
240 for (b, output) in outputs.iter().enumerate() {
241 result.row_mut(b).assign(output);
242 }
243
244 Ok(result)
245}
246
247pub fn segmented_scan<T, Op>(data: &[T], segment_ids: &[usize], op: &Op, parallel: bool) -> Vec<T>
260where
261 T: Clone + Send + Sync,
262 Op: AssociativeOp<T>,
263{
264 if data.len() != segment_ids.len() {
265 panic!("data and segment_ids must have same length");
266 }
267
268 if !parallel {
269 return segmented_scan_sequential(data, segment_ids, op);
270 }
271
272 segmented_scan_sequential(data, segment_ids, op)
275}
276
277fn segmented_scan_sequential<T, Op>(data: &[T], segment_ids: &[usize], op: &Op) -> Vec<T>
278where
279 T: Clone,
280 Op: AssociativeOp<T>,
281{
282 if data.is_empty() {
283 return Vec::new();
284 }
285
286 let mut result = Vec::with_capacity(data.len());
287 result.push(data[0].clone());
288
289 for i in 1..data.len() {
290 if segment_ids[i] != segment_ids[i - 1] {
291 result.push(data[i].clone());
293 } else {
294 let combined = op.combine(&result[i - 1], &data[i]);
296 result.push(combined);
297 }
298 }
299
300 result
301}
302
303use scirs2_core::ndarray::s;
305
#[cfg(test)]
mod tests {
    use super::*;

    /// `f32` addition; its prefix sums are easy to check by hand.
    struct AddOp;

    impl AssociativeOp<f32> for AddOp {
        fn combine(&self, a: &f32, b: &f32) -> f32 {
            a + b
        }

        fn identity(&self) -> Option<f32> {
            Some(0.0)
        }
    }

    /// Asserts `got` is within `1e-5` of `want`, reporting both values on failure.
    fn assert_close(got: f32, want: f32) {
        assert!(
            (got - want).abs() < 1e-5,
            "got {}, want {}",
            got,
            want
        );
    }

    #[test]
    fn test_sequential_scan() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let op = AddOp;

        let result = sequential_scan(&data, &op);
        assert_eq!(result, vec![1.0, 3.0, 6.0, 10.0, 15.0]);
    }

    #[test]
    fn test_parallel_scan() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let op = AddOp;

        let result = parallel_scan(&data, &op, true);
        assert_eq!(result, vec![1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0]);
    }

    #[test]
    fn test_parallel_scan_above_parallel_cutoff() {
        // 100 elements is above the 64-element cutoff in `parallel_scan`, so this goes
        // through `parallel_scan_impl` rather than the small-input fallback.
        let data: Vec<f32> = (1..=100).map(|x| x as f32).collect();
        let op = AddOp;

        let result = parallel_scan(&data, &op, true);
        assert_eq!(result, sequential_scan(&data, &op));
        // Integer prefix sums up to 5050 are exactly representable in f32.
        assert_eq!(result[99], 5050.0);
    }

    #[test]
    fn test_ssm_scan_op() {
        let elem1 = SSMElement {
            a_bar: Array1::from_vec(vec![0.9, 0.8]),
            b_bar: Array1::from_vec(vec![0.1, 0.2]),
        };

        let elem2 = SSMElement {
            a_bar: Array1::from_vec(vec![0.9, 0.8]),
            b_bar: Array1::from_vec(vec![0.1, 0.2]),
        };

        let op = SSMScanOp;
        let result = op.combine(&elem1, &elem2);

        assert!((result.a_bar[0] - 0.81).abs() < 1e-6);
        assert!((result.a_bar[1] - 0.64).abs() < 1e-6);
        assert!((result.b_bar[0] - 0.19).abs() < 1e-6);
        assert!((result.b_bar[1] - 0.36).abs() < 1e-6);
    }

    #[test]
    fn test_parallel_ssm_scan() {
        let seq_len = 4;
        let state_dim = 2;

        let a_bars = Array2::from_shape_vec(
            (seq_len, state_dim),
            vec![0.9, 0.8, 0.9, 0.8, 0.9, 0.8, 0.9, 0.8],
        )
        .unwrap();

        let b_bars = Array2::from_shape_vec(
            (seq_len, state_dim),
            vec![0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        )
        .unwrap();

        let c = Array1::from_vec(vec![1.0, 1.0]);

        let config = ParallelConfig::latency();
        let states = parallel_ssm_scan(&a_bars, &b_bars, &c, &config).unwrap();
        assert_eq!(states.dim(), (seq_len, state_dim));

        // h_t = a ⊙ h_{t-1} + b with a zero initial state, worked out per column:
        // column 0 uses a = 0.9, column 1 uses a = 0.8, and b = 0.1 throughout.
        let expected: [[f32; 2]; 4] = [
            [0.1, 0.1],
            [0.19, 0.18],
            [0.271, 0.244],
            [0.3439, 0.2952],
        ];
        for (t, row) in expected.iter().enumerate() {
            for (k, &want) in row.iter().enumerate() {
                assert_close(states[[t, k]], want);
            }
        }
    }

    #[test]
    fn test_parallel_ssm_scan_rejects_mismatched_shapes() {
        let config = ParallelConfig::latency();
        let a_bars = Array2::<f32>::zeros((4, 2));
        let c = Array1::<f32>::ones(2);

        // Same row count, wrong column count.
        let b_wrong_cols = Array2::<f32>::zeros((4, 3));
        assert!(parallel_ssm_scan(&a_bars, &b_wrong_cols, &c, &config).is_err());

        // Wrong row count.
        let b_wrong_rows = Array2::<f32>::zeros((3, 2));
        assert!(parallel_ssm_scan(&a_bars, &b_wrong_rows, &c, &config).is_err());

        // `c` does not match the state dimension.
        let b_ok = Array2::<f32>::zeros((4, 2));
        let c_wrong = Array1::<f32>::ones(3);
        assert!(parallel_ssm_scan(&a_bars, &b_ok, &c_wrong, &config).is_err());
    }

    #[test]
    fn test_parallel_ssm_batch() {
        // Two sequences of length 2 with state_dim 2; every a is 0.5.
        let a_bars = Array3::from_elem((2, 2, 2), 0.5f32);
        // Batch 0 feeds b = 1.0 at every step, batch 1 feeds b = 2.0.
        let b_bars = Array3::from_shape_vec(
            (2, 2, 2),
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0],
        )
        .unwrap();
        let c = Array1::from_vec(vec![1.0, 1.0]);
        let config = ParallelConfig::latency();

        let out = parallel_ssm_batch(&a_bars, &b_bars, &c, 0.25, &config).unwrap();
        assert_eq!(out.dim(), (2, 2));

        // Batch 0: h = [1, 1] then [1.5, 1.5]; batch 1: h = [2, 2] then [3, 3].
        // Each output is c · h_t + d with d = 0.25.
        let expected: [[f32; 2]; 2] = [[2.25, 3.25], [4.25, 6.25]];
        for (b, row) in expected.iter().enumerate() {
            for (t, &want) in row.iter().enumerate() {
                assert_close(out[[b, t]], want);
            }
        }
    }

    #[test]
    fn test_segmented_scan() {
        let data = vec![1.0, 2.0, 3.0, 1.0, 2.0];
        let segments = vec![0, 0, 0, 1, 1];
        let op = AddOp;

        let result = segmented_scan(&data, &segments, &op, false);

        assert_eq!(result.len(), data.len());
        assert_eq!(result[0], 1.0);
        assert_eq!(result[1], 3.0);
        assert_eq!(result[2], 6.0);
        // A new segment id restarts the running sum.
        assert_eq!(result[3], 1.0);
        assert_eq!(result[4], 3.0);
    }

    #[test]
    fn test_empty_scan() {
        let data: Vec<f32> = vec![];
        let op = AddOp;

        let result = parallel_scan(&data, &op, false);
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_single_element_scan() {
        let data = vec![42.0];
        let op = AddOp;

        let result = parallel_scan(&data, &op, true);
        assert_eq!(result, vec![42.0]);
    }
}