kizzasi_core/
scan.rs

1//! Parallel Scan Algorithms for SSMs
2//!
3//! Implements efficient parallel scan (prefix sum) operations for:
4//! - **Linear-time sequential scan**: O(N) for inference
//! - **Logarithmic-depth parallel scan**: O(log N) parallel time for training (the parallel path currently falls back to the sequential scan)
6//! - **Associative scan for SSMs**: Enables parallel state computation
7//!
8//! # Theory
9//!
10//! For SSM recurrence: h_t = A * h_{t-1} + B * x_t
11//!
12//! We can express this as an associative binary operation:
13//! ```text
14//! (A₁, B₁) ⊗ (A₂, B₂) = (A₂ ∘ A₁, A₂ ∘ B₁ + B₂)
15//! ```
16//!
17//! Where ∘ is element-wise multiplication for diagonal A (S4D).
18//!
//! This allows computing all states in parallel with a single inclusive scan, where h_t is the B component of the t-th prefix:
20//! ```text
21//! h_t = SCAN(⊗, [(A₁,B₁), (A₂,B₂), ..., (Aₙ,Bₙ)])
22//! ```
23//!
24//! # Complexity
25//!
26//! - Sequential: O(N) time, O(1) extra space
27//! - Parallel (work-efficient): O(N) work, O(log N) depth
28//! - Memory: O(N) for intermediate results
29
30use crate::error::{CoreError, CoreResult};
31use crate::parallel::ParallelConfig;
32use scirs2_core::ndarray::{Array1, Array2, Array3};
33
/// A binary operation `⊗` that the scans in this module fold over.
///
/// Implementations must be associative — `(a ⊗ b) ⊗ c == a ⊗ (b ⊗ c)` — because that is what
/// allows a scan to regroup combinations. Commutativity is *not* required: the scans always
/// pass the earlier (accumulated) value as `left` and the later value as `right`.
pub trait AssociativeOp<T>: Send + Sync {
    /// Combines an earlier value `left` with a later value `right`, producing `left ⊗ right`.
    fn combine(&self, left: &T, right: &T) -> T;

    /// Returns the identity element `e` (with `e ⊗ x == x ⊗ e == x`), or `None` when the
    /// operation has no identity or one cannot be constructed without extra context.
    fn identity(&self) -> Option<T>;
}
42
43/// Generic parallel scan (prefix sum) using work-efficient algorithm
44///
45/// Computes: [x₀, x₀⊗x₁, x₀⊗x₁⊗x₂, ..., x₀⊗x₁⊗...⊗xₙ]
46///
47/// Uses Blelloch's work-efficient parallel scan algorithm:
48/// - Up-sweep (reduce) phase: O(log N) depth
49/// - Down-sweep (propagate) phase: O(log N) depth
50/// - Total work: O(N)
51pub fn parallel_scan<T, Op>(data: &[T], op: &Op, parallel: bool) -> Vec<T>
52where
53    T: Clone + Send + Sync,
54    Op: AssociativeOp<T>,
55{
56    if data.is_empty() {
57        return Vec::new();
58    }
59
60    if data.len() == 1 {
61        return vec![data[0].clone()];
62    }
63
64    if !parallel || data.len() < 64 {
65        // Sequential scan for small inputs
66        return sequential_scan(data, op);
67    }
68
69    // Parallel work-efficient scan
70    parallel_scan_impl(data, op)
71}
72
73/// Sequential inclusive scan
74fn sequential_scan<T, Op>(data: &[T], op: &Op) -> Vec<T>
75where
76    T: Clone,
77    Op: AssociativeOp<T>,
78{
79    let mut result = Vec::with_capacity(data.len());
80    result.push(data[0].clone());
81
82    for i in 1..data.len() {
83        let combined = op.combine(&result[i - 1], &data[i]);
84        result.push(combined);
85    }
86
87    result
88}
89
90/// Work-efficient parallel scan implementation
91///
92/// NOTE: Currently uses sequential implementation.
93/// Parallel version will be implemented when scirs2-core parallel API is stable.
94fn parallel_scan_impl<T, Op>(data: &[T], op: &Op) -> Vec<T>
95where
96    T: Clone + Send + Sync,
97    Op: AssociativeOp<T>,
98{
99    // For now, use sequential scan
100    // TODO: Implement true parallel Blelloch scan when scirs2-core API is ready
101    sequential_scan(data, op)
102}
103
/// SSM Scan Element: (A_bar, B_bar) for diagonal SSM
///
/// One step of a diagonal SSM, used as a scan element by `SSMScanOp`. Both fields are
/// combined element-wise, so they are expected to have the same length (the state dimension).
#[derive(Clone, Debug)]
pub struct SSMElement {
    /// Discretized A (diagonal): exp(Δ * λ)
    ///
    /// After scanning, this holds the element-wise product of every Ā in the prefix.
    pub a_bar: Array1<f32>,
    /// Discretized B: B̄ = (exp(Δλ) - 1)/λ * B
    ///
    /// NOTE(review): for the scan to yield h_t of `h_t = Ā h_{t-1} + B̄ x_t`, this field must
    /// hold the input term B̄·x_t rather than B̄ alone — confirm with callers. After scanning,
    /// this field holds the accumulated hidden state for the prefix.
    pub b_bar: Array1<f32>,
}
112
113/// Associative operation for SSM elements
114pub struct SSMScanOp;
115
116impl AssociativeOp<SSMElement> for SSMScanOp {
117    fn combine(&self, left: &SSMElement, right: &SSMElement) -> SSMElement {
118        // (A₁, B₁) ⊗ (A₂, B₂) = (A₂ ∘ A₁, A₂ ∘ B₁ + B₂)
119        let a_combined = &right.a_bar * &left.a_bar;
120        let b_combined = &right.a_bar * &left.b_bar + &right.b_bar;
121
122        SSMElement {
123            a_bar: a_combined,
124            b_bar: b_combined,
125        }
126    }
127
128    fn identity(&self) -> Option<SSMElement> {
129        None // No identity for SSM scan
130    }
131}
132
133/// Parallel SSM scan for efficient state computation
134///
135/// Given sequences of (A_bar, B_bar) elements, computes all hidden states in parallel.
136///
137/// # Arguments
138/// * `a_bars` - Discretized A matrices (seq_len, state_dim)
139/// * `b_bars` - Discretized B vectors (seq_len, state_dim)
140/// * `c` - Output projection vector (state_dim,)
141/// * `parallel_config` - Parallelization configuration
142///
143/// # Returns
144/// Output sequence (seq_len, state_dim) where each position contains the cumulative state
145pub fn parallel_ssm_scan(
146    a_bars: &Array2<f32>,
147    b_bars: &Array2<f32>,
148    c: &Array1<f32>,
149    parallel_config: &ParallelConfig,
150) -> CoreResult<Array2<f32>> {
151    let (seq_len, state_dim) = a_bars.dim();
152
153    if b_bars.dim() != (seq_len, state_dim) {
154        return Err(CoreError::DimensionMismatch {
155            expected: seq_len,
156            got: b_bars.nrows(),
157        });
158    }
159
160    if c.len() != state_dim {
161        return Err(CoreError::DimensionMismatch {
162            expected: state_dim,
163            got: c.len(),
164        });
165    }
166
167    // Create SSM elements
168    let elements: Vec<SSMElement> = (0..seq_len)
169        .map(|t| SSMElement {
170            a_bar: a_bars.row(t).to_owned(),
171            b_bar: b_bars.row(t).to_owned(),
172        })
173        .collect();
174
175    // Perform parallel scan
176    let op = SSMScanOp;
177    let scanned = parallel_scan(&elements, &op, parallel_config.parallel_batch);
178
179    // Extract B_bar components (which contain the cumulative states)
180    let mut states = Array2::zeros((seq_len, state_dim));
181    for (t, elem) in scanned.iter().enumerate() {
182        states.row_mut(t).assign(&elem.b_bar);
183    }
184
185    Ok(states)
186}
187
188/// Parallel SSM forward pass for batch of sequences
189///
190/// Computes outputs for multiple sequences in parallel using associative scan.
191///
192/// # Arguments
193/// * `a_bars` - Discretized A matrices (batch_size, seq_len, state_dim)
194/// * `b_bars` - Discretized B vectors (batch_size, seq_len, state_dim)
195/// * `c` - Output projection (state_dim,)
196/// * `d` - Skip connection
197///
198/// # Returns
199/// Outputs (batch_size, seq_len)
200pub fn parallel_ssm_batch(
201    a_bars: &Array3<f32>,
202    b_bars: &Array3<f32>,
203    c: &Array1<f32>,
204    d: f32,
205    parallel_config: &ParallelConfig,
206) -> CoreResult<Array2<f32>> {
207    let (batch_size, seq_len, state_dim) = a_bars.dim();
208
209    if b_bars.dim() != (batch_size, seq_len, state_dim) {
210        return Err(CoreError::InvalidConfig(
211            "b_bars shape mismatch".to_string(),
212        ));
213    }
214
215    // Process each batch item
216    // TODO: Use scirs2-core parallel when API is stable
217    let outputs: Vec<Array1<f32>> = (0..batch_size)
218        .map(|b| {
219            // Get this batch's A and B
220            let a_batch = a_bars.slice(s![b, .., ..]).to_owned();
221            let b_batch = b_bars.slice(s![b, .., ..]).to_owned();
222
223            // Perform scan for this sequence
224            let states = parallel_ssm_scan(&a_batch, &b_batch, c, parallel_config).unwrap();
225
226            // Compute outputs: y_t = C · h_t + D · x_t
227            // (for simplicity, assuming D*x is already included in b_bar)
228            let mut output = Array1::zeros(seq_len);
229            for t in 0..seq_len {
230                let h_t = states.row(t);
231                output[t] = c.dot(&h_t) + d;
232            }
233
234            output
235        })
236        .collect();
237
238    // Stack into output array
239    let mut result = Array2::zeros((batch_size, seq_len));
240    for (b, output) in outputs.iter().enumerate() {
241        result.row_mut(b).assign(output);
242    }
243
244    Ok(result)
245}
246
247/// Segmented parallel scan for variable-length sequences
248///
249/// Performs parallel scan where sequences are separated by segment boundaries.
250/// This is useful for processing multiple variable-length sequences in one batch.
251///
252/// # Arguments
253/// * `data` - Packed data elements
254/// * `segment_ids` - Segment ID for each element (resets scan at boundaries)
255/// * `op` - Associative operation
256///
257/// # Returns
258/// Scanned result with scan reset at segment boundaries
259pub fn segmented_scan<T, Op>(data: &[T], segment_ids: &[usize], op: &Op, parallel: bool) -> Vec<T>
260where
261    T: Clone + Send + Sync,
262    Op: AssociativeOp<T>,
263{
264    if data.len() != segment_ids.len() {
265        panic!("data and segment_ids must have same length");
266    }
267
268    if !parallel {
269        return segmented_scan_sequential(data, segment_ids, op);
270    }
271
272    // For parallel version, we need more sophisticated handling
273    // For now, fall back to sequential for correctness
274    segmented_scan_sequential(data, segment_ids, op)
275}
276
277fn segmented_scan_sequential<T, Op>(data: &[T], segment_ids: &[usize], op: &Op) -> Vec<T>
278where
279    T: Clone,
280    Op: AssociativeOp<T>,
281{
282    if data.is_empty() {
283        return Vec::new();
284    }
285
286    let mut result = Vec::with_capacity(data.len());
287    result.push(data[0].clone());
288
289    for i in 1..data.len() {
290        if segment_ids[i] != segment_ids[i - 1] {
291            // New segment - reset scan
292            result.push(data[i].clone());
293        } else {
294            // Same segment - continue scan
295            let combined = op.combine(&result[i - 1], &data[i]);
296            result.push(combined);
297        }
298    }
299
300    result
301}
302
// Bring the `s!` slicing macro into scope (used by `parallel_ssm_batch`); this is a private import, not a re-export.
304use scirs2_core::ndarray::s;
305
#[cfg(test)]
mod tests {
    use super::*;

    // Simple addition operation for testing
    struct AddOp;

    impl AssociativeOp<f32> for AddOp {
        fn combine(&self, a: &f32, b: &f32) -> f32 {
            a + b
        }

        fn identity(&self) -> Option<f32> {
            Some(0.0)
        }
    }

    /// Asserts that two f32 values agree within `tol`, with a readable failure message.
    fn assert_close(got: f32, want: f32, tol: f32) {
        assert!((got - want).abs() < tol, "expected {want}, got {got}");
    }

    #[test]
    fn test_sequential_scan() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let op = AddOp;

        let result = sequential_scan(&data, &op);
        assert_eq!(result, vec![1.0, 3.0, 6.0, 10.0, 15.0]);
    }

    #[test]
    fn test_parallel_scan() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let op = AddOp;

        let result = parallel_scan(&data, &op, true);
        assert_eq!(result, vec![1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0]);
    }

    #[test]
    fn test_parallel_scan_large_input() {
        // `parallel_scan` only leaves its sequential fast path for inputs of 64+ elements,
        // so 100 elements exercise the parallel branch. Integer prefix sums are exact in f32.
        let data: Vec<f32> = (1..=100u32).map(|k| k as f32).collect();
        let op = AddOp;

        let expected: Vec<f32> = (1..=100u32).map(|k| (k * (k + 1) / 2) as f32).collect();
        assert_eq!(parallel_scan(&data, &op, true), expected);
        assert_eq!(parallel_scan(&data, &op, false), expected);
    }

    #[test]
    fn test_ssm_scan_op() {
        let elem1 = SSMElement {
            a_bar: Array1::from_vec(vec![0.9, 0.8]),
            b_bar: Array1::from_vec(vec![0.1, 0.2]),
        };

        let elem2 = SSMElement {
            a_bar: Array1::from_vec(vec![0.9, 0.8]),
            b_bar: Array1::from_vec(vec![0.1, 0.2]),
        };

        let op = SSMScanOp;
        let result = op.combine(&elem1, &elem2);

        // (A₂ ∘ A₁, A₂ ∘ B₁ + B₂)
        assert!((result.a_bar[0] - 0.81).abs() < 1e-6); // 0.9 * 0.9
        assert!((result.a_bar[1] - 0.64).abs() < 1e-6); // 0.8 * 0.8
        assert!((result.b_bar[0] - 0.19).abs() < 1e-6); // 0.9 * 0.1 + 0.1
        assert!((result.b_bar[1] - 0.36).abs() < 1e-6); // 0.8 * 0.2 + 0.2
    }

    #[test]
    fn test_parallel_ssm_scan() {
        let seq_len = 4;
        let state_dim = 2;

        let a_bars = Array2::from_shape_vec(
            (seq_len, state_dim),
            vec![0.9, 0.8, 0.9, 0.8, 0.9, 0.8, 0.9, 0.8],
        )
        .unwrap();

        let b_bars = Array2::from_shape_vec(
            (seq_len, state_dim),
            vec![0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        )
        .unwrap();

        let c = Array1::from_vec(vec![1.0, 1.0]);

        let config = ParallelConfig::latency(); // Use sequential for determinism

        let states = parallel_ssm_scan(&a_bars, &b_bars, &c, &config).unwrap();
        assert_eq!(states.dim(), (seq_len, state_dim));

        // h_t = a ∘ h_{t-1} + b with h_{-1} = 0, independently per state dimension.
        let expected: [[f32; 2]; 4] = [
            [0.1, 0.1],
            [0.19, 0.18],
            [0.271, 0.244],
            [0.3439, 0.2952],
        ];
        for (t, row) in expected.iter().enumerate() {
            for (i, &want) in row.iter().enumerate() {
                assert_close(states[[t, i]], want, 1e-5);
            }
        }
    }

    #[test]
    fn test_parallel_ssm_scan_rejects_mismatched_shapes() {
        let config = ParallelConfig::latency();
        let a_bars = Array2::<f32>::zeros((4, 2));
        let c = Array1::<f32>::ones(2);

        let wrong_state_dim = Array2::<f32>::zeros((4, 3));
        assert!(parallel_ssm_scan(&a_bars, &wrong_state_dim, &c, &config).is_err());

        let wrong_seq_len = Array2::<f32>::zeros((5, 2));
        assert!(parallel_ssm_scan(&a_bars, &wrong_seq_len, &c, &config).is_err());

        let wrong_c = Array1::<f32>::ones(3);
        assert!(parallel_ssm_scan(&a_bars, &a_bars, &wrong_c, &config).is_err());
    }

    #[test]
    fn test_parallel_ssm_batch() {
        let a_bars = Array3::from_elem((2, 3, 1), 0.5f32);
        let b_bars = Array3::from_elem((2, 3, 1), 1.0f32);
        let c = Array1::from_vec(vec![2.0f32]);
        let config = ParallelConfig::latency();

        let outputs = parallel_ssm_batch(&a_bars, &b_bars, &c, 0.25, &config).unwrap();
        assert_eq!(outputs.dim(), (2, 3));

        // States are 1.0, 1.5, 1.75; outputs are 2 * h_t + 0.25.
        for b in 0..2 {
            for (t, &want) in [2.25f32, 3.25, 3.75].iter().enumerate() {
                assert_close(outputs[[b, t]], want, 1e-6);
            }
        }
    }

    #[test]
    fn test_segmented_scan() {
        let data = vec![1.0, 2.0, 3.0, 1.0, 2.0];
        let segments = vec![0, 0, 0, 1, 1]; // Two segments
        let op = AddOp;

        let result = segmented_scan(&data, &segments, &op, false);

        // First segment: [1, 3, 6]
        // Second segment: [1, 3]
        assert_eq!(result[0], 1.0);
        assert_eq!(result[1], 3.0);
        assert_eq!(result[2], 6.0);
        assert_eq!(result[3], 1.0); // Reset
        assert_eq!(result[4], 3.0);

        // Requesting the parallel path must not change the result.
        let parallel_result = segmented_scan(&data, &segments, &op, true);
        assert_eq!(parallel_result, result);
    }

    #[test]
    fn test_empty_scan() {
        let data: Vec<f32> = vec![];
        let op = AddOp;

        let result = parallel_scan(&data, &op, false);
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_single_element_scan() {
        let data = vec![42.0];
        let op = AddOp;

        let result = parallel_scan(&data, &op, true);
        assert_eq!(result, vec![42.0]);
    }
}