Skip to main content

somatize_runtime/
stream.rs

1//! Streaming executor — processes data in chunks through fitted filters.
2//!
3//! Respects each filter's [`StreamMode`]: FixedState processes chunks
4//! independently, Evolving updates state per chunk with checkpoints,
5//! Barrier accumulates all chunks before processing.
6
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::{Result, SomaError};
9use somatize_core::filter::{Filter, StreamMode};
10use somatize_core::value::Value;
11use std::sync::Arc;
12
13/// A fitted filter with its learned state, ready for streaming.
14pub struct FittedFilter {
15    pub name: String,
16    pub filter: Arc<dyn Filter>,
17    pub state: Value,
18}
19
20/// Processes a stream of chunks through a sequence of fitted filters.
21///
22/// Respects each filter's StreamMode:
23/// - FixedState: each chunk processed independently, cacheable per chunk
24/// - Evolving: state mutates with each chunk, periodic checkpoints
25/// - Barrier: accumulates all chunks, processes as batch
26pub struct StreamExecutor {
27    filters: Vec<FittedFilter>,
28    cache: Option<Arc<dyn CacheStore>>,
29    /// Accumulated chunks for Barrier filters (keyed by filter index).
30    barrier_buffers: Vec<Vec<Value>>,
31    /// Evolving states (keyed by filter index, mutated on each chunk).
32    evolving_states: Vec<Option<Value>>,
33    /// Chunk counter for checkpoint scheduling.
34    chunk_count: usize,
35}
36
37impl StreamExecutor {
38    pub fn new(filters: Vec<FittedFilter>) -> Self {
39        let n = filters.len();
40        Self {
41            filters,
42            cache: None,
43            barrier_buffers: vec![Vec::new(); n],
44            evolving_states: vec![None; n],
45            chunk_count: 0,
46        }
47    }
48
49    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
50        self.cache = Some(cache);
51        self
52    }
53
54    /// Process a single chunk through the pipeline.
55    ///
56    /// Returns the output chunk, or None if a Barrier filter is still accumulating.
57    pub fn process_chunk(&mut self, chunk: Value) -> Result<Option<Value>> {
58        let mut current = chunk;
59        self.chunk_count += 1;
60
61        let n = self.filters.len();
62        for i in 0..n {
63            let mode = self.filters[i].filter.meta().stream_mode;
64
65            match mode {
66                StreamMode::FixedState => {
67                    current = self.process_fixed_state(i, &current)?;
68                }
69                StreamMode::Evolving { checkpoint_every } => {
70                    current = self.process_evolving(i, &current, checkpoint_every)?;
71                }
72                StreamMode::Barrier => {
73                    self.barrier_buffers[i].push(current);
74                    return Ok(None);
75                }
76                _ => {
77                    current = self.process_fixed_state(i, &current)?;
78                }
79            }
80        }
81
82        Ok(Some(current))
83    }
84
85    /// Flush barrier filters and process remaining data as batch.
86    ///
87    /// Call this after the stream ends to materialize barrier outputs.
88    pub fn flush(&mut self) -> Result<Option<Value>> {
89        let mut current: Option<Value> = None;
90        let n = self.filters.len();
91
92        for i in 0..n {
93            let mode = self.filters[i].filter.meta().stream_mode;
94
95            if mode == StreamMode::Barrier && !self.barrier_buffers[i].is_empty() {
96                let materialized = self.materialize_buffer(i)?;
97                let result = self.filters[i]
98                    .filter
99                    .forward(&materialized, &self.filters[i].state)?;
100                self.barrier_buffers[i].clear();
101                current = Some(result);
102            } else if let Some(val) = current.take() {
103                let result = self.filters[i]
104                    .filter
105                    .forward(&val, &self.filters[i].state)?;
106                current = Some(result);
107            }
108        }
109
110        Ok(current)
111    }
112
113    /// Process multiple chunks and collect outputs.
114    pub fn process_all(&mut self, chunks: Vec<Value>) -> Result<Vec<Value>> {
115        let mut outputs = Vec::new();
116
117        for chunk in chunks {
118            if let Some(output) = self.process_chunk(chunk)? {
119                outputs.push(output);
120            }
121        }
122
123        // Flush any barrier buffers
124        if let Some(flushed) = self.flush()? {
125            outputs.push(flushed);
126        }
127
128        Ok(outputs)
129    }
130
131    /// Number of chunks processed so far.
132    pub fn chunks_processed(&self) -> usize {
133        self.chunk_count
134    }
135
136    fn process_fixed_state(&self, filter_idx: usize, input: &Value) -> Result<Value> {
137        let fitted = &self.filters[filter_idx];
138
139        // Try cache
140        if let Some(cache) = &self.cache {
141            let chunk_hash = CacheKey::hash_data(&serde_json::to_vec(input).unwrap_or_default());
142            let cache_key = CacheKey::for_output(
143                &fitted.filter.config_hash(),
144                &CacheKey::hash_data(&serde_json::to_vec(&fitted.state).unwrap_or_default()),
145                &chunk_hash,
146            );
147            if let Some(cached) = cache.get(&cache_key)? {
148                return Ok(cached);
149            }
150            let result = fitted.filter.forward(input, &fitted.state)?;
151            let _ = cache.put(&cache_key, &result);
152            return Ok(result);
153        }
154
155        fitted.filter.forward(input, &fitted.state)
156    }
157
158    fn process_evolving(
159        &mut self,
160        filter_idx: usize,
161        input: &Value,
162        checkpoint_every: usize,
163    ) -> Result<Value> {
164        let fitted = &self.filters[filter_idx];
165
166        // Use evolving state if available, else initial state
167        let state = self.evolving_states[filter_idx]
168            .as_ref()
169            .unwrap_or(&fitted.state);
170
171        let result = fitted.filter.forward(input, state)?;
172
173        // For evolving: the output becomes the new state for next chunk
174        // (simplified model: state = last output)
175        self.evolving_states[filter_idx] = Some(result.clone());
176
177        // Checkpoint
178        if checkpoint_every > 0
179            && self.chunk_count.is_multiple_of(checkpoint_every)
180            && let Some(cache) = &self.cache
181        {
182            let checkpoint_key = CacheKey::from_parts(&[
183                b"checkpoint",
184                fitted.name.as_bytes(),
185                &(self.chunk_count as u64).to_le_bytes(),
186            ]);
187            let _ = cache.put(&checkpoint_key, &result);
188        }
189
190        Ok(result)
191    }
192
193    fn materialize_buffer(&self, filter_idx: usize) -> Result<Value> {
194        let buffer = &self.barrier_buffers[filter_idx];
195        if buffer.is_empty() {
196            return Ok(Value::Empty);
197        }
198
199        // Concatenate tensor chunks along first dimension
200        let mut all_data = Vec::new();
201        let mut total_rows = 0;
202        let mut cols = 0;
203
204        for chunk in buffer {
205            match chunk {
206                Value::Tensor { values, shape } => {
207                    all_data.extend(values);
208                    if shape.len() == 1 {
209                        total_rows += shape[0];
210                        cols = 1;
211                    } else if shape.len() >= 2 {
212                        total_rows += shape[0];
213                        cols = shape[1];
214                    }
215                }
216                _ => {
217                    return Err(SomaError::Other(
218                        "barrier buffer contains non-tensor values".into(),
219                    ));
220                }
221            }
222        }
223
224        if cols <= 1 {
225            Ok(Value::tensor(all_data, vec![total_rows]))
226        } else {
227            Ok(Value::tensor(all_data, vec![total_rows, cols]))
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    // Unit tests for `StreamExecutor`, covering each `StreamMode` path, a
    // mixed pipeline, cache use, the chunk counter and an empty stream.
    use super::*;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{FilterKind, FilterMeta};

    // ── Test filters ──

    /// FixedState, cacheable fixture: doubles every tensor element and
    /// passes non-tensor values through unchanged. Ignores its state.
    struct DoubleChunk;
    impl Filter for DoubleChunk {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"DoubleChunk"])
        }
        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => Ok(Value::tensor(
                    values.iter().map(|v| v * 2.0).collect(),
                    shape.clone(),
                )),
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "DoubleChunk".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Barrier fixture: reduces its (concatenated) input tensor to a
    /// one-element tensor holding the mean of all values. An empty tensor
    /// would yield NaN (0.0 / 0.0). Non-tensors pass through unchanged.
    struct Accumulator;
    impl Filter for Accumulator {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Accumulator"])
        }
        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _: &Value) -> Result<Value> {
            // For barrier: receives concatenated tensor, computes mean
            match x {
                Value::Tensor { values, shape: _ } => {
                    let mean = values.iter().sum::<f64>() / values.len() as f64;
                    Ok(Value::tensor(vec![mean], vec![1]))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Accumulator".into(),
                kind: FilterKind::Trainable,
                cacheable: false,
                differentiable: false,
                stream_mode: StreamMode::Barrier,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Evolving fixture (checkpoint every 3 chunks): outputs a one-element
    /// tensor equal to the first element of the input plus the first element
    /// of the state. A non-tensor input or state contributes 0.0.
    /// NOTE(review): `d[0]` would panic on an empty tensor.
    struct RunningSum;
    impl Filter for RunningSum {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"RunningSum"])
        }
        fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
            Ok(Value::tensor(vec![0.0], vec![1]))
        }
        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
            let x_val = x.as_tensor().map(|(d, _)| d[0]).unwrap_or(0.0);
            let s_val = state.as_tensor().map(|(d, _)| d[0]).unwrap_or(0.0);
            Ok(Value::tensor(vec![x_val + s_val], vec![1]))
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "RunningSum".into(),
                kind: FilterKind::Trainable,
                cacheable: false,
                differentiable: false,
                stream_mode: StreamMode::Evolving {
                    checkpoint_every: 3,
                },
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    // ── Tests ──

    // A single FixedState stage emits one output per input chunk, in order.
    #[test]
    fn fixed_state_processes_each_chunk() {
        let mut executor = StreamExecutor::new(vec![FittedFilter {
            name: "double".into(),
            filter: Arc::new(DoubleChunk),
            state: Value::Empty,
        }]);

        let chunks = vec![
            Value::tensor(vec![1.0, 2.0], vec![2]),
            Value::tensor(vec![3.0, 4.0], vec![2]),
            Value::tensor(vec![5.0], vec![1]),
        ];

        let outputs = executor.process_all(chunks).unwrap();
        assert_eq!(outputs.len(), 3);

        let (d0, _) = outputs[0].as_tensor().unwrap();
        assert_eq!(d0, &[2.0, 4.0]);
        let (d1, _) = outputs[1].as_tensor().unwrap();
        assert_eq!(d1, &[6.0, 8.0]);
        let (d2, _) = outputs[2].as_tensor().unwrap();
        assert_eq!(d2, &[10.0]);
    }

    // A Barrier stage swallows every chunk (None) and only produces output
    // when flushed, over the concatenation of all buffered chunks.
    #[test]
    fn barrier_accumulates_then_flushes() {
        let mut executor = StreamExecutor::new(vec![FittedFilter {
            name: "acc".into(),
            filter: Arc::new(Accumulator),
            state: Value::Empty,
        }]);

        // Process chunks: barrier should return None for each
        assert!(
            executor
                .process_chunk(Value::tensor(vec![1.0, 2.0], vec![2]))
                .unwrap()
                .is_none()
        );
        assert!(
            executor
                .process_chunk(Value::tensor(vec![3.0, 4.0], vec![2]))
                .unwrap()
                .is_none()
        );
        assert!(
            executor
                .process_chunk(Value::tensor(vec![5.0, 6.0], vec![2]))
                .unwrap()
                .is_none()
        );

        // Flush: should materialize and compute mean of [1,2,3,4,5,6]
        let result = executor.flush().unwrap().unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert!((data[0] - 3.5).abs() < 0.01); // mean of 1..6
    }

    // An Evolving stage carries its previous output forward as state.
    // No cache is attached, so the checkpoint due on chunk 3 is not written.
    #[test]
    fn evolving_state_accumulates() {
        let mut executor = StreamExecutor::new(vec![FittedFilter {
            name: "sum".into(),
            filter: Arc::new(RunningSum),
            state: Value::tensor(vec![0.0], vec![1]), // initial sum = 0
        }]);

        let r1 = executor
            .process_chunk(Value::tensor(vec![5.0], vec![1]))
            .unwrap()
            .unwrap();
        assert_eq!(r1.as_tensor().unwrap().0, &[5.0]); // 0+5=5

        let r2 = executor
            .process_chunk(Value::tensor(vec![3.0], vec![1]))
            .unwrap()
            .unwrap();
        assert_eq!(r2.as_tensor().unwrap().0, &[8.0]); // 5+3=8

        let r3 = executor
            .process_chunk(Value::tensor(vec![2.0], vec![1]))
            .unwrap()
            .unwrap();
        assert_eq!(r3.as_tensor().unwrap().0, &[10.0]); // 8+2=10
    }

    // FixedState followed by Barrier: each chunk is doubled before being
    // buffered, so only the single flushed output appears.
    #[test]
    fn mixed_pipeline_fixed_then_barrier() {
        let mut executor = StreamExecutor::new(vec![
            FittedFilter {
                name: "double".into(),
                filter: Arc::new(DoubleChunk),
                state: Value::Empty,
            },
            FittedFilter {
                name: "acc".into(),
                filter: Arc::new(Accumulator),
                state: Value::Empty,
            },
        ]);

        let chunks = vec![
            Value::tensor(vec![1.0], vec![1]),
            Value::tensor(vec![2.0], vec![1]),
            Value::tensor(vec![3.0], vec![1]),
        ];

        let outputs = executor.process_all(chunks).unwrap();
        // DoubleChunk doubles: [2,4,6]. Accumulator sees barrier after double.
        // After flush: mean of [2,4,6] = 4.0
        assert_eq!(outputs.len(), 1);
        let (data, _) = outputs[0].as_tensor().unwrap();
        assert!((data[0] - 4.0).abs() < 0.01);
    }

    // With a cache attached, the first FixedState result is stored.
    #[test]
    fn fixed_state_with_cache() {
        let cache = Arc::new(crate::MemoryCache::default());
        let mut executor = StreamExecutor::new(vec![FittedFilter {
            name: "double".into(),
            filter: Arc::new(DoubleChunk),
            state: Value::Empty,
        }])
        .with_cache(cache.clone());

        let chunk = Value::tensor(vec![7.0], vec![1]);

        // First call: cache miss
        let r1 = executor.process_chunk(chunk.clone()).unwrap().unwrap();
        assert_eq!(r1.as_tensor().unwrap().0, &[14.0]);
        assert!(!cache.is_empty()); // cached

        // Second call with same chunk: cache hit
        // NOTE(review): r2 would also be 14.0 on a miss, so this assertion
        // alone does not prove the cached value was served.
        let r2 = executor.process_chunk(chunk).unwrap().unwrap();
        assert_eq!(r2.as_tensor().unwrap().0, &[14.0]);
    }

    // `chunks_processed` counts every `process_chunk` call.
    #[test]
    fn chunks_processed_counter() {
        let mut executor = StreamExecutor::new(vec![FittedFilter {
            name: "double".into(),
            filter: Arc::new(DoubleChunk),
            state: Value::Empty,
        }]);

        assert_eq!(executor.chunks_processed(), 0);
        executor
            .process_chunk(Value::tensor(vec![1.0], vec![1]))
            .unwrap();
        assert_eq!(executor.chunks_processed(), 1);
        executor
            .process_chunk(Value::tensor(vec![2.0], vec![1]))
            .unwrap();
        assert_eq!(executor.chunks_processed(), 2);
    }

    // No chunks and no barriers: nothing is produced, not even by the flush.
    #[test]
    fn empty_stream() {
        let mut executor = StreamExecutor::new(vec![FittedFilter {
            name: "double".into(),
            filter: Arc::new(DoubleChunk),
            state: Value::Empty,
        }]);

        let outputs = executor.process_all(vec![]).unwrap();
        assert!(outputs.is_empty());
    }
}