Skip to main content

somatize_runtime/
stream.rs

//! Streaming executor — processes data in chunks through fitted filters.
//!
//! Respects each filter's [`StreamMode`]: `FixedState` filters process chunks
//! independently, `Evolving` filters update their state on every chunk and
//! write periodic checkpoints, and `Barrier` filters accumulate all chunks
//! before processing them as one batch.
6
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::{Result, SomaError};
9use somatize_core::filter::{Filter, StreamMode};
10use somatize_core::value::Value;
11use std::sync::Arc;
12
13/// A fitted filter with its learned state, ready for streaming.
14pub struct FittedFilter {
15    pub name: String,
16    pub filter: Arc<dyn Filter>,
17    pub state: Value,
18}
19
20/// Processes a stream of chunks through a sequence of fitted filters.
21///
22/// Respects each filter's StreamMode:
23/// - FixedState: each chunk processed independently, cacheable per chunk
24/// - Evolving: state mutates with each chunk, periodic checkpoints
25/// - Barrier: accumulates all chunks, processes as batch
26pub struct StreamExecutor {
27    filters: Vec<FittedFilter>,
28    cache: Option<Arc<dyn CacheStore>>,
29    /// Accumulated chunks for Barrier filters (keyed by filter index).
30    barrier_buffers: Vec<Vec<Value>>,
31    /// Evolving states (keyed by filter index, mutated on each chunk).
32    evolving_states: Vec<Option<Value>>,
33    /// Chunk counter for checkpoint scheduling.
34    chunk_count: usize,
35}
36
37impl StreamExecutor {
38    pub fn new(filters: Vec<FittedFilter>) -> Self {
39        let n = filters.len();
40        Self {
41            filters,
42            cache: None,
43            barrier_buffers: vec![Vec::new(); n],
44            evolving_states: vec![None; n],
45            chunk_count: 0,
46        }
47    }
48
49    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
50        self.cache = Some(cache);
51        self
52    }
53
54    /// Process a single chunk through the pipeline.
55    ///
56    /// Returns the output chunk, or None if a Barrier filter is still accumulating.
57    pub fn process_chunk(&mut self, chunk: Value) -> Result<Option<Value>> {
58        let mut current = chunk;
59        self.chunk_count += 1;
60
61        let n = self.filters.len();
62        for i in 0..n {
63            let mode = self.filters[i].filter.meta().stream_mode;
64
65            match mode {
66                StreamMode::FixedState => {
67                    current = self.process_fixed_state(i, &current)?;
68                }
69                StreamMode::Evolving { checkpoint_every } => {
70                    current = self.process_evolving(i, &current, checkpoint_every)?;
71                }
72                StreamMode::Barrier => {
73                    self.barrier_buffers[i].push(current);
74                    return Ok(None);
75                }
76                _ => {
77                    current = self.process_fixed_state(i, &current)?;
78                }
79            }
80        }
81
82        Ok(Some(current))
83    }
84
85    /// Flush barrier filters and process remaining data as batch.
86    ///
87    /// Call this after the stream ends to materialize barrier outputs.
88    pub fn flush(&mut self) -> Result<Option<Value>> {
89        let mut current: Option<Value> = None;
90        let n = self.filters.len();
91
92        for i in 0..n {
93            let mode = self.filters[i].filter.meta().stream_mode;
94
95            if mode == StreamMode::Barrier && !self.barrier_buffers[i].is_empty() {
96                let materialized = self.materialize_buffer(i)?;
97                let result = self.filters[i]
98                    .filter
99                    .forward(&materialized, &self.filters[i].state)?;
100                self.barrier_buffers[i].clear();
101                current = Some(result);
102            } else if let Some(val) = current.take() {
103                let result = self.filters[i]
104                    .filter
105                    .forward(&val, &self.filters[i].state)?;
106                current = Some(result);
107            }
108        }
109
110        Ok(current)
111    }
112
113    /// Process multiple chunks and collect outputs.
114    pub fn process_all(&mut self, chunks: Vec<Value>) -> Result<Vec<Value>> {
115        let mut outputs = Vec::new();
116
117        for chunk in chunks {
118            if let Some(output) = self.process_chunk(chunk)? {
119                outputs.push(output);
120            }
121        }
122
123        // Flush any barrier buffers
124        if let Some(flushed) = self.flush()? {
125            outputs.push(flushed);
126        }
127
128        Ok(outputs)
129    }
130
131    /// Number of chunks processed so far.
132    pub fn chunks_processed(&self) -> usize {
133        self.chunk_count
134    }
135
136    fn process_fixed_state(&self, filter_idx: usize, input: &Value) -> Result<Value> {
137        let fitted = &self.filters[filter_idx];
138
139        // Try cache
140        if let Some(cache) = &self.cache {
141            let chunk_hash = CacheKey::hash_data(&serde_json::to_vec(input).unwrap_or_default());
142            let cache_key = CacheKey::for_output(
143                &fitted.filter.config_hash(),
144                &CacheKey::hash_data(&serde_json::to_vec(&fitted.state).unwrap_or_default()),
145                &chunk_hash,
146            );
147            if let Some(cached) = cache.get(&cache_key)? {
148                return Ok(cached);
149            }
150            let result = fitted.filter.forward(input, &fitted.state)?;
151            let _ = cache.put(&cache_key, &result);
152            return Ok(result);
153        }
154
155        fitted.filter.forward(input, &fitted.state)
156    }
157
158    fn process_evolving(
159        &mut self,
160        filter_idx: usize,
161        input: &Value,
162        checkpoint_every: usize,
163    ) -> Result<Value> {
164        let fitted = &self.filters[filter_idx];
165
166        // Use evolving state if available, else initial state
167        let state = self.evolving_states[filter_idx]
168            .as_ref()
169            .unwrap_or(&fitted.state);
170
171        let result = fitted.filter.forward(input, state)?;
172
173        // For evolving: the output becomes the new state for next chunk
174        // (simplified model: state = last output)
175        self.evolving_states[filter_idx] = Some(result.clone());
176
177        // Checkpoint
178        if checkpoint_every > 0
179            && self.chunk_count.is_multiple_of(checkpoint_every)
180            && let Some(cache) = &self.cache
181        {
182            let checkpoint_key = CacheKey::from_parts(&[
183                b"checkpoint",
184                fitted.name.as_bytes(),
185                &(self.chunk_count as u64).to_le_bytes(),
186            ]);
187            let _ = cache.put(&checkpoint_key, &result);
188        }
189
190        Ok(result)
191    }
192
193    fn materialize_buffer(&self, filter_idx: usize) -> Result<Value> {
194        let buffer = &self.barrier_buffers[filter_idx];
195        if buffer.is_empty() {
196            return Ok(Value::Empty);
197        }
198
199        // Concatenate tensor chunks along first dimension
200        let mut all_data = Vec::new();
201        let mut total_rows = 0;
202        let mut cols = 0;
203
204        for chunk in buffer {
205            match chunk {
206                Value::Tensor { values, shape } => {
207                    all_data.extend(values);
208                    if shape.len() == 1 {
209                        total_rows += shape[0];
210                        cols = 1;
211                    } else if shape.len() >= 2 {
212                        total_rows += shape[0];
213                        cols = shape[1];
214                    }
215                }
216                _ => {
217                    return Err(SomaError::Other(
218                        "barrier buffer contains non-tensor values".into(),
219                    ));
220                }
221            }
222        }
223
224        if cols <= 1 {
225            Ok(Value::tensor(all_data, vec![total_rows]))
226        } else {
227            Ok(Value::tensor(all_data, vec![total_rows, cols]))
228        }
229    }
230}
231
232#[cfg(test)]
233mod tests {
234    use super::*;
235    use somatize_core::cache::CacheKey;
236    use somatize_core::filter::{FilterKind, FilterMeta};
237
238    // ── Test filters ──
239
240    struct DoubleChunk;
241    impl Filter for DoubleChunk {
242        fn config_hash(&self) -> CacheKey {
243            CacheKey::from_parts(&[b"DoubleChunk"])
244        }
245        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
246            Ok(Value::Empty)
247        }
248        fn forward(&self, x: &Value, _: &Value) -> Result<Value> {
249            match x {
250                Value::Tensor { values, shape } => Ok(Value::tensor(
251                    values.iter().map(|v| v * 2.0).collect(),
252                    shape.clone(),
253                )),
254                _ => Ok(x.clone()),
255            }
256        }
257        fn meta(&self) -> FilterMeta {
258            FilterMeta {
259                name: "DoubleChunk".into(),
260                kind: FilterKind::Stateless,
261                cacheable: true,
262                differentiable: true,
263                stream_mode: StreamMode::FixedState,
264                distribution: somatize_core::filter::Distribution::Local,
265                input_schema: None,
266                output_schema: None,
267            }
268        }
269
270        fn as_any(&self) -> &dyn std::any::Any {
271            self
272        }
273    }
274
275    struct Accumulator;
276    impl Filter for Accumulator {
277        fn config_hash(&self) -> CacheKey {
278            CacheKey::from_parts(&[b"Accumulator"])
279        }
280        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
281            Ok(Value::Empty)
282        }
283        fn forward(&self, x: &Value, _: &Value) -> Result<Value> {
284            // For barrier: receives concatenated tensor, computes mean
285            match x {
286                Value::Tensor { values, shape: _ } => {
287                    let mean = values.iter().sum::<f64>() / values.len() as f64;
288                    Ok(Value::tensor(vec![mean], vec![1]))
289                }
290                _ => Ok(x.clone()),
291            }
292        }
293        fn meta(&self) -> FilterMeta {
294            FilterMeta {
295                name: "Accumulator".into(),
296                kind: FilterKind::Trainable,
297                cacheable: false,
298                differentiable: false,
299                stream_mode: StreamMode::Barrier,
300                distribution: somatize_core::filter::Distribution::Local,
301                input_schema: None,
302                output_schema: None,
303            }
304        }
305
306        fn as_any(&self) -> &dyn std::any::Any {
307            self
308        }
309    }
310
311    struct RunningSum;
312    impl Filter for RunningSum {
313        fn config_hash(&self) -> CacheKey {
314            CacheKey::from_parts(&[b"RunningSum"])
315        }
316        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
317            Ok(Value::tensor(vec![0.0], vec![1]))
318        }
319        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
320            let x_val = x.as_tensor().map(|(d, _)| d[0]).unwrap_or(0.0);
321            let s_val = state.as_tensor().map(|(d, _)| d[0]).unwrap_or(0.0);
322            Ok(Value::tensor(vec![x_val + s_val], vec![1]))
323        }
324        fn meta(&self) -> FilterMeta {
325            FilterMeta {
326                name: "RunningSum".into(),
327                kind: FilterKind::Trainable,
328                cacheable: false,
329                differentiable: false,
330                stream_mode: StreamMode::Evolving {
331                    checkpoint_every: 3,
332                },
333                distribution: somatize_core::filter::Distribution::Local,
334                input_schema: None,
335                output_schema: None,
336            }
337        }
338
339        fn as_any(&self) -> &dyn std::any::Any {
340            self
341        }
342    }
343
344    // ── Tests ──
345
346    #[test]
347    fn fixed_state_processes_each_chunk() {
348        let mut executor = StreamExecutor::new(vec![FittedFilter {
349            name: "double".into(),
350            filter: Arc::new(DoubleChunk),
351            state: Value::Empty,
352        }]);
353
354        let chunks = vec![
355            Value::tensor(vec![1.0, 2.0], vec![2]),
356            Value::tensor(vec![3.0, 4.0], vec![2]),
357            Value::tensor(vec![5.0], vec![1]),
358        ];
359
360        let outputs = executor.process_all(chunks).unwrap();
361        assert_eq!(outputs.len(), 3);
362
363        let (d0, _) = outputs[0].as_tensor().unwrap();
364        assert_eq!(d0, &[2.0, 4.0]);
365        let (d1, _) = outputs[1].as_tensor().unwrap();
366        assert_eq!(d1, &[6.0, 8.0]);
367        let (d2, _) = outputs[2].as_tensor().unwrap();
368        assert_eq!(d2, &[10.0]);
369    }
370
371    #[test]
372    fn barrier_accumulates_then_flushes() {
373        let mut executor = StreamExecutor::new(vec![FittedFilter {
374            name: "acc".into(),
375            filter: Arc::new(Accumulator),
376            state: Value::Empty,
377        }]);
378
379        // Process chunks: barrier should return None for each
380        assert!(
381            executor
382                .process_chunk(Value::tensor(vec![1.0, 2.0], vec![2]))
383                .unwrap()
384                .is_none()
385        );
386        assert!(
387            executor
388                .process_chunk(Value::tensor(vec![3.0, 4.0], vec![2]))
389                .unwrap()
390                .is_none()
391        );
392        assert!(
393            executor
394                .process_chunk(Value::tensor(vec![5.0, 6.0], vec![2]))
395                .unwrap()
396                .is_none()
397        );
398
399        // Flush: should materialize and compute mean of [1,2,3,4,5,6]
400        let result = executor.flush().unwrap().unwrap();
401        let (data, _) = result.as_tensor().unwrap();
402        assert!((data[0] - 3.5).abs() < 0.01); // mean of 1..6
403    }
404
405    #[test]
406    fn evolving_state_accumulates() {
407        let mut executor = StreamExecutor::new(vec![FittedFilter {
408            name: "sum".into(),
409            filter: Arc::new(RunningSum),
410            state: Value::tensor(vec![0.0], vec![1]), // initial sum = 0
411        }]);
412
413        let r1 = executor
414            .process_chunk(Value::tensor(vec![5.0], vec![1]))
415            .unwrap()
416            .unwrap();
417        assert_eq!(r1.as_tensor().unwrap().0, &[5.0]); // 0+5=5
418
419        let r2 = executor
420            .process_chunk(Value::tensor(vec![3.0], vec![1]))
421            .unwrap()
422            .unwrap();
423        assert_eq!(r2.as_tensor().unwrap().0, &[8.0]); // 5+3=8
424
425        let r3 = executor
426            .process_chunk(Value::tensor(vec![2.0], vec![1]))
427            .unwrap()
428            .unwrap();
429        assert_eq!(r3.as_tensor().unwrap().0, &[10.0]); // 8+2=10
430    }
431
432    #[test]
433    fn mixed_pipeline_fixed_then_barrier() {
434        let mut executor = StreamExecutor::new(vec![
435            FittedFilter {
436                name: "double".into(),
437                filter: Arc::new(DoubleChunk),
438                state: Value::Empty,
439            },
440            FittedFilter {
441                name: "acc".into(),
442                filter: Arc::new(Accumulator),
443                state: Value::Empty,
444            },
445        ]);
446
447        let chunks = vec![
448            Value::tensor(vec![1.0], vec![1]),
449            Value::tensor(vec![2.0], vec![1]),
450            Value::tensor(vec![3.0], vec![1]),
451        ];
452
453        let outputs = executor.process_all(chunks).unwrap();
454        // DoubleChunk doubles: [2,4,6]. Accumulator sees barrier after double.
455        // After flush: mean of [2,4,6] = 4.0
456        assert_eq!(outputs.len(), 1);
457        let (data, _) = outputs[0].as_tensor().unwrap();
458        assert!((data[0] - 4.0).abs() < 0.01);
459    }
460
461    #[test]
462    fn fixed_state_with_cache() {
463        let cache = Arc::new(crate::MemoryCache::default());
464        let mut executor = StreamExecutor::new(vec![FittedFilter {
465            name: "double".into(),
466            filter: Arc::new(DoubleChunk),
467            state: Value::Empty,
468        }])
469        .with_cache(cache.clone());
470
471        let chunk = Value::tensor(vec![7.0], vec![1]);
472
473        // First call: cache miss
474        let r1 = executor.process_chunk(chunk.clone()).unwrap().unwrap();
475        assert_eq!(r1.as_tensor().unwrap().0, &[14.0]);
476        assert!(!cache.is_empty()); // cached
477
478        // Second call with same chunk: cache hit
479        let r2 = executor.process_chunk(chunk).unwrap().unwrap();
480        assert_eq!(r2.as_tensor().unwrap().0, &[14.0]);
481    }
482
483    #[test]
484    fn chunks_processed_counter() {
485        let mut executor = StreamExecutor::new(vec![FittedFilter {
486            name: "double".into(),
487            filter: Arc::new(DoubleChunk),
488            state: Value::Empty,
489        }]);
490
491        assert_eq!(executor.chunks_processed(), 0);
492        executor
493            .process_chunk(Value::tensor(vec![1.0], vec![1]))
494            .unwrap();
495        assert_eq!(executor.chunks_processed(), 1);
496        executor
497            .process_chunk(Value::tensor(vec![2.0], vec![1]))
498            .unwrap();
499        assert_eq!(executor.chunks_processed(), 2);
500    }
501
502    #[test]
503    fn empty_stream() {
504        let mut executor = StreamExecutor::new(vec![FittedFilter {
505            name: "double".into(),
506            filter: Arc::new(DoubleChunk),
507            state: Value::Empty,
508        }]);
509
510        let outputs = executor.process_all(vec![]).unwrap();
511        assert!(outputs.is_empty());
512    }
513}