1use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::{Result, SomaError};
9use somatize_core::filter::{Filter, StreamMode};
10use somatize_core::value::Value;
11use std::sync::Arc;
12
/// A pipeline stage: a filter together with the state it is run with.
///
/// `StreamExecutor` passes `state` as the second argument to the filter's
/// `forward` call for chunks routed through this stage.
pub struct FittedFilter {
    /// Label for this stage; it is one of the parts of the checkpoint cache key.
    pub name: String,
    /// The filter implementation, shared behind an `Arc`.
    pub filter: Arc<dyn Filter>,
    /// State handed to `forward`. Presumably the result of `Filter::fit` —
    /// TODO confirm against the code that builds `FittedFilter`s.
    pub state: Value,
}
19
/// Runs an ordered list of fitted filters over a stream of chunks.
///
/// How each filter is driven depends on the `StreamMode` reported by its
/// `meta()`: fixed-state filters run per chunk, evolving filters feed their
/// output back as state, and barrier filters buffer chunks until `flush`.
pub struct StreamExecutor {
    /// Pipeline stages, applied in index order.
    filters: Vec<FittedFilter>,
    /// Optional cache used for fixed-state outputs and evolving checkpoints.
    cache: Option<Arc<dyn CacheStore>>,
    /// One buffer per filter (indexed like `filters`) holding chunks a
    /// `Barrier` filter has received but not yet processed.
    barrier_buffers: Vec<Vec<Value>>,
    /// One slot per filter (indexed like `filters`) holding the latest output
    /// of an `Evolving` filter, which is used as its state for the next
    /// chunk; `None` until that filter has produced an output.
    evolving_states: Vec<Option<Value>>,
    /// Number of `process_chunk` calls made so far.
    chunk_count: usize,
}
36
37impl StreamExecutor {
38 pub fn new(filters: Vec<FittedFilter>) -> Self {
39 let n = filters.len();
40 Self {
41 filters,
42 cache: None,
43 barrier_buffers: vec![Vec::new(); n],
44 evolving_states: vec![None; n],
45 chunk_count: 0,
46 }
47 }
48
49 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
50 self.cache = Some(cache);
51 self
52 }
53
54 pub fn process_chunk(&mut self, chunk: Value) -> Result<Option<Value>> {
58 let mut current = chunk;
59 self.chunk_count += 1;
60
61 let n = self.filters.len();
62 for i in 0..n {
63 let mode = self.filters[i].filter.meta().stream_mode;
64
65 match mode {
66 StreamMode::FixedState => {
67 current = self.process_fixed_state(i, ¤t)?;
68 }
69 StreamMode::Evolving { checkpoint_every } => {
70 current = self.process_evolving(i, ¤t, checkpoint_every)?;
71 }
72 StreamMode::Barrier => {
73 self.barrier_buffers[i].push(current);
74 return Ok(None);
75 }
76 _ => {
77 current = self.process_fixed_state(i, ¤t)?;
78 }
79 }
80 }
81
82 Ok(Some(current))
83 }
84
85 pub fn flush(&mut self) -> Result<Option<Value>> {
89 let mut current: Option<Value> = None;
90 let n = self.filters.len();
91
92 for i in 0..n {
93 let mode = self.filters[i].filter.meta().stream_mode;
94
95 if mode == StreamMode::Barrier && !self.barrier_buffers[i].is_empty() {
96 let materialized = self.materialize_buffer(i)?;
97 let result = self.filters[i]
98 .filter
99 .forward(&materialized, &self.filters[i].state)?;
100 self.barrier_buffers[i].clear();
101 current = Some(result);
102 } else if let Some(val) = current.take() {
103 let result = self.filters[i]
104 .filter
105 .forward(&val, &self.filters[i].state)?;
106 current = Some(result);
107 }
108 }
109
110 Ok(current)
111 }
112
113 pub fn process_all(&mut self, chunks: Vec<Value>) -> Result<Vec<Value>> {
115 let mut outputs = Vec::new();
116
117 for chunk in chunks {
118 if let Some(output) = self.process_chunk(chunk)? {
119 outputs.push(output);
120 }
121 }
122
123 if let Some(flushed) = self.flush()? {
125 outputs.push(flushed);
126 }
127
128 Ok(outputs)
129 }
130
131 pub fn chunks_processed(&self) -> usize {
133 self.chunk_count
134 }
135
136 fn process_fixed_state(&self, filter_idx: usize, input: &Value) -> Result<Value> {
137 let fitted = &self.filters[filter_idx];
138
139 if let Some(cache) = &self.cache {
141 let chunk_hash = CacheKey::hash_data(&serde_json::to_vec(input).unwrap_or_default());
142 let cache_key = CacheKey::for_output(
143 &fitted.filter.config_hash(),
144 &CacheKey::hash_data(&serde_json::to_vec(&fitted.state).unwrap_or_default()),
145 &chunk_hash,
146 );
147 if let Some(cached) = cache.get(&cache_key)? {
148 return Ok(cached);
149 }
150 let result = fitted.filter.forward(input, &fitted.state)?;
151 let _ = cache.put(&cache_key, &result);
152 return Ok(result);
153 }
154
155 fitted.filter.forward(input, &fitted.state)
156 }
157
158 fn process_evolving(
159 &mut self,
160 filter_idx: usize,
161 input: &Value,
162 checkpoint_every: usize,
163 ) -> Result<Value> {
164 let fitted = &self.filters[filter_idx];
165
166 let state = self.evolving_states[filter_idx]
168 .as_ref()
169 .unwrap_or(&fitted.state);
170
171 let result = fitted.filter.forward(input, state)?;
172
173 self.evolving_states[filter_idx] = Some(result.clone());
176
177 if checkpoint_every > 0
179 && self.chunk_count.is_multiple_of(checkpoint_every)
180 && let Some(cache) = &self.cache
181 {
182 let checkpoint_key = CacheKey::from_parts(&[
183 b"checkpoint",
184 fitted.name.as_bytes(),
185 &(self.chunk_count as u64).to_le_bytes(),
186 ]);
187 let _ = cache.put(&checkpoint_key, &result);
188 }
189
190 Ok(result)
191 }
192
193 fn materialize_buffer(&self, filter_idx: usize) -> Result<Value> {
194 let buffer = &self.barrier_buffers[filter_idx];
195 if buffer.is_empty() {
196 return Ok(Value::Empty);
197 }
198
199 let mut all_data = Vec::new();
201 let mut total_rows = 0;
202 let mut cols = 0;
203
204 for chunk in buffer {
205 match chunk {
206 Value::Tensor { values, shape } => {
207 all_data.extend(values);
208 if shape.len() == 1 {
209 total_rows += shape[0];
210 cols = 1;
211 } else if shape.len() >= 2 {
212 total_rows += shape[0];
213 cols = shape[1];
214 }
215 }
216 _ => {
217 return Err(SomaError::Other(
218 "barrier buffer contains non-tensor values".into(),
219 ));
220 }
221 }
222 }
223
224 if cols <= 1 {
225 Ok(Value::tensor(all_data, vec![total_rows]))
226 } else {
227 Ok(Value::tensor(all_data, vec![total_rows, cols]))
228 }
229 }
230}
231
232#[cfg(test)]
233mod tests {
234 use super::*;
235 use somatize_core::cache::CacheKey;
236 use somatize_core::filter::{FilterKind, FilterMeta};
237
238 struct DoubleChunk;
241 impl Filter for DoubleChunk {
242 fn config_hash(&self) -> CacheKey {
243 CacheKey::from_parts(&[b"DoubleChunk"])
244 }
245 fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
246 Ok(Value::Empty)
247 }
248 fn forward(&self, x: &Value, _: &Value) -> Result<Value> {
249 match x {
250 Value::Tensor { values, shape } => Ok(Value::tensor(
251 values.iter().map(|v| v * 2.0).collect(),
252 shape.clone(),
253 )),
254 _ => Ok(x.clone()),
255 }
256 }
257 fn meta(&self) -> FilterMeta {
258 FilterMeta {
259 name: "DoubleChunk".into(),
260 kind: FilterKind::Stateless,
261 cacheable: true,
262 differentiable: true,
263 stream_mode: StreamMode::FixedState,
264 distribution: somatize_core::filter::Distribution::Local,
265 input_schema: None,
266 output_schema: None,
267 }
268 }
269
270 fn as_any(&self) -> &dyn std::any::Any {
271 self
272 }
273 }
274
275 struct Accumulator;
276 impl Filter for Accumulator {
277 fn config_hash(&self) -> CacheKey {
278 CacheKey::from_parts(&[b"Accumulator"])
279 }
280 fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
281 Ok(Value::Empty)
282 }
283 fn forward(&self, x: &Value, _: &Value) -> Result<Value> {
284 match x {
286 Value::Tensor { values, shape: _ } => {
287 let mean = values.iter().sum::<f64>() / values.len() as f64;
288 Ok(Value::tensor(vec![mean], vec![1]))
289 }
290 _ => Ok(x.clone()),
291 }
292 }
293 fn meta(&self) -> FilterMeta {
294 FilterMeta {
295 name: "Accumulator".into(),
296 kind: FilterKind::Trainable,
297 cacheable: false,
298 differentiable: false,
299 stream_mode: StreamMode::Barrier,
300 distribution: somatize_core::filter::Distribution::Local,
301 input_schema: None,
302 output_schema: None,
303 }
304 }
305
306 fn as_any(&self) -> &dyn std::any::Any {
307 self
308 }
309 }
310
311 struct RunningSum;
312 impl Filter for RunningSum {
313 fn config_hash(&self) -> CacheKey {
314 CacheKey::from_parts(&[b"RunningSum"])
315 }
316 fn fit(&self, _: &Value, _: Option<&Value>) -> Result<Value> {
317 Ok(Value::tensor(vec![0.0], vec![1]))
318 }
319 fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
320 let x_val = x.as_tensor().map(|(d, _)| d[0]).unwrap_or(0.0);
321 let s_val = state.as_tensor().map(|(d, _)| d[0]).unwrap_or(0.0);
322 Ok(Value::tensor(vec![x_val + s_val], vec![1]))
323 }
324 fn meta(&self) -> FilterMeta {
325 FilterMeta {
326 name: "RunningSum".into(),
327 kind: FilterKind::Trainable,
328 cacheable: false,
329 differentiable: false,
330 stream_mode: StreamMode::Evolving {
331 checkpoint_every: 3,
332 },
333 distribution: somatize_core::filter::Distribution::Local,
334 input_schema: None,
335 output_schema: None,
336 }
337 }
338
339 fn as_any(&self) -> &dyn std::any::Any {
340 self
341 }
342 }
343
/// A fixed-state filter emits one output per input chunk, in order.
#[test]
fn fixed_state_processes_each_chunk() {
    let stage = FittedFilter {
        name: "double".into(),
        filter: Arc::new(DoubleChunk),
        state: Value::Empty,
    };
    let mut executor = StreamExecutor::new(vec![stage]);

    let produced = executor
        .process_all(vec![
            Value::tensor(vec![1.0, 2.0], vec![2]),
            Value::tensor(vec![3.0, 4.0], vec![2]),
            Value::tensor(vec![5.0], vec![1]),
        ])
        .unwrap();

    assert_eq!(produced.len(), 3);
    assert_eq!(produced[0].as_tensor().unwrap().0, &[2.0, 4.0]);
    assert_eq!(produced[1].as_tensor().unwrap().0, &[6.0, 8.0]);
    assert_eq!(produced[2].as_tensor().unwrap().0, &[10.0]);
}
370
/// A barrier filter emits nothing per chunk and one value on flush.
#[test]
fn barrier_accumulates_then_flushes() {
    let stage = FittedFilter {
        name: "acc".into(),
        filter: Arc::new(Accumulator),
        state: Value::Empty,
    };
    let mut executor = StreamExecutor::new(vec![stage]);

    // Every chunk is buffered, so no call produces output.
    for pair in [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] {
        let emitted = executor
            .process_chunk(Value::tensor(pair.to_vec(), vec![2]))
            .unwrap();
        assert!(emitted.is_none());
    }

    // Mean of 1..=6 is 3.5.
    let flushed = executor.flush().unwrap().unwrap();
    let (data, _) = flushed.as_tensor().unwrap();
    assert!((data[0] - 3.5).abs() < 0.01);
}
404
/// An evolving filter carries its output forward as state between chunks.
#[test]
fn evolving_state_accumulates() {
    let stage = FittedFilter {
        name: "sum".into(),
        filter: Arc::new(RunningSum),
        state: Value::tensor(vec![0.0], vec![1]),
    };
    let mut executor = StreamExecutor::new(vec![stage]);

    // (input, expected running sum): 0+5, 5+3, 8+2.
    for (input, running_sum) in [(5.0, 5.0), (3.0, 8.0), (2.0, 10.0)] {
        let out = executor
            .process_chunk(Value::tensor(vec![input], vec![1]))
            .unwrap()
            .unwrap();
        assert_eq!(out.as_tensor().unwrap().0, &[running_sum]);
    }
}
431
/// A fixed-state filter feeding a barrier: chunks are doubled as they
/// arrive, buffered, and reduced to one value on flush.
#[test]
fn mixed_pipeline_fixed_then_barrier() {
    let doubler = FittedFilter {
        name: "double".into(),
        filter: Arc::new(DoubleChunk),
        state: Value::Empty,
    };
    let averager = FittedFilter {
        name: "acc".into(),
        filter: Arc::new(Accumulator),
        state: Value::Empty,
    };
    let mut executor = StreamExecutor::new(vec![doubler, averager]);

    let produced = executor
        .process_all(vec![
            Value::tensor(vec![1.0], vec![1]),
            Value::tensor(vec![2.0], vec![1]),
            Value::tensor(vec![3.0], vec![1]),
        ])
        .unwrap();

    // Doubled inputs are 2, 4, 6; their mean is 4.
    assert_eq!(produced.len(), 1);
    let (data, _) = produced[0].as_tensor().unwrap();
    assert!((data[0] - 4.0).abs() < 0.01);
}
460
/// With a cache attached, processing a chunk populates the cache and a
/// repeat of the same chunk yields the same output.
#[test]
fn fixed_state_with_cache() {
    let store = Arc::new(crate::MemoryCache::default());
    let stage = FittedFilter {
        name: "double".into(),
        filter: Arc::new(DoubleChunk),
        state: Value::Empty,
    };
    let mut executor = StreamExecutor::new(vec![stage]).with_cache(store.clone());

    let chunk = Value::tensor(vec![7.0], vec![1]);

    // First pass computes the result and leaves an entry in the cache.
    let first = executor.process_chunk(chunk.clone()).unwrap().unwrap();
    assert_eq!(first.as_tensor().unwrap().0, &[14.0]);
    assert!(!store.is_empty());

    // Same chunk again must give the same answer.
    let second = executor.process_chunk(chunk).unwrap().unwrap();
    assert_eq!(second.as_tensor().unwrap().0, &[14.0]);
}
482
/// The chunk counter starts at zero and grows by one per processed chunk.
#[test]
fn chunks_processed_counter() {
    let stage = FittedFilter {
        name: "double".into(),
        filter: Arc::new(DoubleChunk),
        state: Value::Empty,
    };
    let mut executor = StreamExecutor::new(vec![stage]);

    assert_eq!(executor.chunks_processed(), 0);
    for (already_seen, x) in [1.0, 2.0].into_iter().enumerate() {
        executor
            .process_chunk(Value::tensor(vec![x], vec![1]))
            .unwrap();
        assert_eq!(executor.chunks_processed(), already_seen + 1);
    }
}
501
/// Processing no chunks yields no outputs.
#[test]
fn empty_stream() {
    let stage = FittedFilter {
        name: "double".into(),
        filter: Arc::new(DoubleChunk),
        state: Value::Empty,
    };
    let mut executor = StreamExecutor::new(vec![stage]);

    let produced = executor.process_all(Vec::new()).unwrap();
    assert!(produced.is_empty());
}
513}