// prism/engine.rs
// Copyright 2024-2026 Reflective Labs

use anyhow::{Result, anyhow};
use converge_pack::{
    AgentEffect, Context, ContextKey, FactPayload, ProvenanceSource, Suggestor, TextPayload,
};
use polars::prelude::*;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};

use crate::provenance::PRISM_PROVENANCE;

/// Typed payload representing computed features.
///
/// Serialized with `deny_unknown_fields` so stale or mistyped payloads are
/// rejected at deserialization time rather than silently accepted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct FeatureVector {
    // Feature values; `new` enforces `data.len() == shape[0] * shape[1]`.
    pub data: Vec<f32>,
    // `[rows, cols]` dimensions of the matrix held in `data`.
    pub shape: [usize; 2],
}

// Registers FeatureVector as a versioned fact payload family so it can be
// carried inside converge_pack facts/proposals.
impl FactPayload for FeatureVector {
    const FAMILY: &'static str = "prism.feature-vector";
    const VERSION: u16 = 1;
}

26impl FeatureVector {
27    pub fn new(data: Vec<f32>, shape: [usize; 2]) -> Result<Self> {
28        let expected = shape
29            .first()
30            .and_then(|rows| shape.get(1).map(|cols| rows.saturating_mul(*cols)))
31            .unwrap_or(0);
32        if data.len() != expected {
33            return Err(anyhow!(
34                "feature data length {} does not match shape {:?}",
35                data.len(),
36                shape
37            ));
38        }
39        Ok(Self { data, shape })
40    }
41
42    pub fn row(data: Vec<f32>) -> Self {
43        let cols = data.len();
44        Self {
45            data,
46            shape: [1, cols],
47        }
48    }
49
50    pub fn rows(&self) -> usize {
51        self.shape[0]
52    }
53
54    pub fn cols(&self) -> usize {
55        self.shape[1]
56    }
57}
58
/// Names of the two dataframe columns used as the left/right feature inputs.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FeatureColumns {
    pub left: String,
    pub right: String,
}

/// Agent that derives a small feature vector from tabular data using Polars.
#[derive(Clone, Debug)]
pub struct FeatureAgent {
    // Optional CSV/Parquet input; a built-in demo frame is used when `None`.
    source_path: Option<PathBuf>,
    // Explicit column pair; when `None`, the first two numeric columns of the
    // frame are used instead.
    columns: Option<FeatureColumns>,
}

71impl FeatureAgent {
72    pub fn new(source_path: Option<PathBuf>) -> Self {
73        Self {
74            source_path,
75            columns: None,
76        }
77    }
78
79    pub fn with_columns(mut self, left: impl Into<String>, right: impl Into<String>) -> Self {
80        self.columns = Some(FeatureColumns {
81            left: left.into(),
82            right: right.into(),
83        });
84        self
85    }
86
87    /// Internal Polars logic to compute features
88    fn compute_features(&self) -> Result<FeatureVector> {
89        let df = if let Some(path) = &self.source_path {
90            load_dataframe(path)?
91        } else {
92            df! [
93                "x1" => [1.0, 2.0, 3.0],
94                "x2" => [4.0, 5.0, 6.0],
95                "x3" => [7.0, 8.0, 9.0],
96            ]?
97        };
98        compute_features_from_df(&df, self.columns.as_ref())
99    }
100}
101
102#[async_trait::async_trait]
103impl Suggestor for FeatureAgent {
104    fn name(&self) -> &'static str {
105        "FeatureAgent (Polars)"
106    }
107
108    fn dependencies(&self) -> &[ContextKey] {
109        // Depends on Seeds to know WHAT to process
110        &[ContextKey::Seeds]
111    }
112
113    fn accepts(&self, ctx: &dyn Context) -> bool {
114        // Run if we have Seeds but haven't produced Proposals yet
115        ctx.has(ContextKey::Seeds) && !ctx.has(ContextKey::Proposals)
116    }
117
118    fn provenance(&self) -> &'static str {
119        PRISM_PROVENANCE.as_str()
120    }
121
122    async fn execute(&self, _ctx: &dyn Context) -> AgentEffect {
123        // 1. Compute features using Polars
124        let features = match self.compute_features() {
125            Ok(f) => f,
126            Err(e) => {
127                return AgentEffect::with_proposal(PRISM_PROVENANCE.proposed_fact(
128                    ContextKey::Diagnostic,
129                    "feature-agent-error",
130                    TextPayload::new(e.to_string()),
131                ));
132            }
133        };
134
135        // 2. Propose the features
136        let proposal =
137            PRISM_PROVENANCE.proposed_fact(ContextKey::Proposals, "features-001", features);
138
139        // Note: In a real agent, we might emit a Fact directly if trusted, or a ProposedFact.
140        // converge_core usually requires TryFrom implementation or specific flow.
141        // For simplicity, we assume we can emit effects.
142        // Wait, AgentEffect::with_proposal?
143        // Let's check AgentEffect definition.
144
145        // Use the constructor for single proposal
146        AgentEffect::with_proposal(proposal)
147    }
148}
149
150fn compute_features_from_df(
151    df: &DataFrame,
152    columns: Option<&FeatureColumns>,
153) -> Result<FeatureVector> {
154    let (left, right) = if let Some(columns) = columns {
155        let left = df
156            .column(&columns.left)
157            .map_err(|_| anyhow!("missing column {}", columns.left))?;
158        let right = df
159            .column(&columns.right)
160            .map_err(|_| anyhow!("missing column {}", columns.right))?;
161        (left.clone(), right.clone())
162    } else {
163        let mut numeric = df
164            .get_columns()
165            .iter()
166            .filter(|series| is_numeric_dtype(series.dtype()))
167            .cloned()
168            .collect::<Vec<_>>();
169        if numeric.len() < 2 {
170            return Err(anyhow!("need at least two numeric columns"));
171        }
172        (numeric.remove(0), numeric.remove(0))
173    };
174
175    if left.is_empty() || right.is_empty() {
176        return Err(anyhow!("input data is empty"));
177    }
178
179    let left = left.cast(&DataType::Float32)?;
180    let right = right.cast(&DataType::Float32)?;
181
182    let left_val = left
183        .f32()?
184        .get(0)
185        .ok_or_else(|| anyhow!("missing left value"))?;
186    let right_val = right
187        .f32()?
188        .get(0)
189        .ok_or_else(|| anyhow!("missing right value"))?;
190
191    let interaction = left_val * right_val;
192    Ok(FeatureVector::row(vec![left_val, right_val, interaction]))
193}
194
195fn load_dataframe(path: &Path) -> Result<DataFrame> {
196    let extension = path
197        .extension()
198        .and_then(|ext| ext.to_str())
199        .unwrap_or("")
200        .to_ascii_lowercase();
201
202    let path_str = path
203        .to_str()
204        .ok_or_else(|| anyhow!("path is not valid utf-8: {}", path.display()))?;
205
206    match extension.as_str() {
207        "parquet" => {
208            let pl_path = PlPath::from_str(path_str);
209            Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
210        }
211        "csv" => Ok(CsvReadOptions::default()
212            .with_has_header(true)
213            .try_into_reader_with_file_path(Some(path.to_path_buf()))?
214            .finish()?),
215        _ => Err(anyhow!(
216            "unsupported data format for path {} (expected .csv or .parquet)",
217            path.display()
218        )),
219    }
220}
221
222fn is_numeric_dtype(dtype: &DataType) -> bool {
223    matches!(
224        dtype,
225        DataType::Int8
226            | DataType::Int16
227            | DataType::Int32
228            | DataType::Int64
229            | DataType::UInt8
230            | DataType::UInt16
231            | DataType::UInt32
232            | DataType::UInt64
233            | DataType::Float32
234            | DataType::Float64
235    )
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashMap;
    use std::fs;
    use std::hint::black_box;
    use std::time::{Instant, SystemTime, UNIX_EPOCH};

    // --- FeatureVector construction and validation ---

    #[test]
    fn feature_vector_validates_shape() {
        let ok = FeatureVector::new(vec![1.0, 2.0], [1, 2]).unwrap();
        assert_eq!(ok.rows(), 1);
        assert_eq!(ok.cols(), 2);
        assert!(FeatureVector::new(vec![1.0], [1, 2]).is_err());
    }

    #[test]
    fn feature_vector_new_multi_row() {
        let fv = FeatureVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]).unwrap();
        assert_eq!(fv.rows(), 2);
        assert_eq!(fv.cols(), 3);
        assert_eq!(fv.data.len(), 6);
    }

    #[test]
    fn feature_vector_new_rejects_mismatched_length() {
        assert!(FeatureVector::new(vec![1.0, 2.0, 3.0], [2, 2]).is_err());
        assert!(FeatureVector::new(vec![], [1, 1]).is_err());
        assert!(FeatureVector::new(vec![1.0], [0, 1]).is_err());
    }

    #[test]
    fn feature_vector_new_empty() {
        let fv = FeatureVector::new(vec![], [0, 0]).unwrap();
        assert_eq!(fv.rows(), 0);
        assert_eq!(fv.cols(), 0);
        assert!(fv.data.is_empty());
    }

    #[test]
    fn feature_vector_new_zero_cols() {
        let fv = FeatureVector::new(vec![], [5, 0]).unwrap();
        assert_eq!(fv.rows(), 5);
        assert_eq!(fv.cols(), 0);
    }

    #[test]
    fn feature_vector_row_creates_single_row() {
        let fv = FeatureVector::row(vec![10.0, 20.0, 30.0]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 3);
        assert_eq!(fv.data, vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn feature_vector_row_empty() {
        let fv = FeatureVector::row(vec![]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 0);
        assert!(fv.data.is_empty());
    }

    #[test]
    fn feature_vector_row_single_element() {
        let fv = FeatureVector::row(vec![42.0]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 1);
        assert_eq!(fv.data, vec![42.0]);
    }

    // --- FeatureColumns / serde round-trips ---

    #[test]
    fn feature_columns_construction() {
        let fc = FeatureColumns {
            left: "price".to_string(),
            right: "quantity".to_string(),
        };
        assert_eq!(fc.left, "price");
        assert_eq!(fc.right, "quantity");
    }

    #[test]
    fn feature_columns_roundtrip_serde() {
        let fc = FeatureColumns {
            left: "a".to_string(),
            right: "b".to_string(),
        };
        let json = serde_json::to_string(&fc).unwrap();
        let deserialized: FeatureColumns = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.left, "a");
        assert_eq!(deserialized.right, "b");
    }

    #[test]
    fn feature_vector_roundtrip_serde() {
        let fv = FeatureVector::new(vec![1.0, 2.0, 3.0, 4.0], [2, 2]).unwrap();
        let json = serde_json::to_string(&fv).unwrap();
        let deserialized: FeatureVector = serde_json::from_str(&json).unwrap();
        assert_eq!(fv, deserialized);
    }

    // --- FeatureAgent builder ---

    #[test]
    fn feature_agent_new_without_columns() {
        let agent = FeatureAgent::new(None);
        assert!(agent.source_path.is_none());
        assert!(agent.columns.is_none());
    }

    #[test]
    fn feature_agent_with_columns() {
        let agent = FeatureAgent::new(None).with_columns("x", "y");
        let cols = agent.columns.unwrap();
        assert_eq!(cols.left, "x");
        assert_eq!(cols.right, "y");
    }

    #[test]
    fn feature_agent_with_source_path() {
        let agent = FeatureAgent::new(Some(PathBuf::from("/tmp/data.csv")));
        assert_eq!(agent.source_path.unwrap(), PathBuf::from("/tmp/data.csv"));
    }

    // --- dtype predicate ---

    #[test]
    fn is_numeric_dtype_covers_all_numeric_types() {
        let numeric = [
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
            DataType::Float32,
            DataType::Float64,
        ];
        for dt in &numeric {
            assert!(is_numeric_dtype(dt), "{dt:?} should be numeric");
        }
    }

    #[test]
    fn is_numeric_dtype_rejects_non_numeric() {
        assert!(!is_numeric_dtype(&DataType::String));
        assert!(!is_numeric_dtype(&DataType::Boolean));
        assert!(!is_numeric_dtype(&DataType::Date));
    }

    // --- compute_features_from_df error paths ---

    #[test]
    fn compute_features_rejects_empty_dataframe() {
        let df = df![
            "a" => Vec::<f32>::new(),
            "b" => Vec::<f32>::new(),
        ]
        .unwrap();
        let cols = FeatureColumns {
            left: "a".into(),
            right: "b".into(),
        };
        assert!(compute_features_from_df(&df, Some(&cols)).is_err());
    }

    #[test]
    fn compute_features_rejects_missing_column() {
        let df = df!["a" => [1.0f32]].unwrap();
        let cols = FeatureColumns {
            left: "a".into(),
            right: "missing".into(),
        };
        assert!(compute_features_from_df(&df, Some(&cols)).is_err());
    }

    #[test]
    fn compute_features_rejects_insufficient_numeric_columns() {
        let df = df!["text" => ["a", "b"]].unwrap();
        assert!(compute_features_from_df(&df, None).is_err());
    }

    proptest! {
        // Invariant: rows * cols always equals the stored data length.
        #[test]
        fn feature_vector_shape_invariant(
            rows in 0usize..50,
            cols in 0usize..50,
        ) {
            let len = rows.saturating_mul(cols);
            let data = vec![0.0f32; len];
            let fv = FeatureVector::new(data, [rows, cols]).unwrap();
            prop_assert_eq!(fv.rows() * fv.cols(), fv.data.len());
        }
    }

    // --- compute_features_from_df happy paths ---

    #[test]
    fn compute_features_from_df_uses_named_columns() {
        let df = df![
            "a" => [2.0f32, 3.0],
            "b" => [4.0f32, 5.0],
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "a".into(),
            right: "b".into(),
        };
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![2.0, 4.0, 8.0]);
        assert_eq!(features.shape, [1, 3]);
    }

    #[test]
    fn compute_features_from_df_falls_back_to_first_numeric_columns() {
        let df = df![
            "text" => ["x", "y"],
            "a" => [1.5f32, 2.5],
            "b" => [3.0f32, 4.0],
        ]
        .unwrap();

        let features = compute_features_from_df(&df, None).unwrap();
        assert_eq!(features.data, vec![1.5, 3.0, 4.5]);
    }

    #[test]
    fn compute_features_handles_large_dataset() {
        let rows = 10_000;
        let left: Vec<f32> = (0..rows).map(|i| i as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| (i as f32) + 1.0).collect();
        let df = df![
            "left" => left,
            "right" => right,
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "left".into(),
            right: "right".into(),
        };
        // Only the first row feeds the feature vector: 0.0 * 1.0 == 0.0.
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![0.0, 1.0, 0.0]);
    }

    #[test]
    fn load_dataframe_reads_csv() {
        // Unique temp file name to avoid collisions between parallel runs.
        let mut path = std::env::temp_dir();
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        path.push(format!("prism_{nanos}.csv"));

        let contents = "left,right\n2.0,4.0\n3.0,5.0\n";
        fs::write(&path, contents).unwrap();

        let df = load_dataframe(&path).unwrap();
        // Remove the temp file before asserting so a failed assertion does
        // not leave it behind; ignore cleanup errors.
        let _ = fs::remove_file(&path);
        assert_eq!(df.height(), 2);
        assert_eq!(df.width(), 2);
    }

    proptest! {
        // The feature vector must always reflect row 0 of the two inputs.
        #[test]
        fn compute_features_matches_first_row(
            left in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
            right in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
        ) {
            let len = left.len().min(right.len());
            let df = df![
                "left" => left[..len].to_vec(),
                "right" => right[..len].to_vec(),
            ]
            .unwrap();

            let columns = FeatureColumns {
                left: "left".into(),
                right: "right".into(),
            };
            let features = compute_features_from_df(&df, Some(&columns)).unwrap();
            let expected_left = left[0];
            let expected_right = right[0];
            prop_assert_eq!(features.data, vec![expected_left, expected_right, expected_left * expected_right]);
        }
    }

    // --- Polars-vs-naive cross checks ---

    #[test]
    fn polars_vectorized_dot_product_matches_naive() {
        let rows = 50_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 3) % 100) as f32).collect();
        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        // Sum in f64 on both sides to keep the comparison tolerance tight.
        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .cast(&DataType::Float64)
            .unwrap()
            .f64()
            .unwrap()
            .sum()
            .unwrap_or(0.0);

        let mut naive_sum = 0.0f64;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += (*l as f64) * (*r as f64);
        }

        assert!((polars_sum - naive_sum).abs() < 1e-6);
    }

    #[test]
    fn polars_groupby_sum_matches_naive() {
        let rows = 10_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| {
                if i % 3 == 0 {
                    "alpha"
                } else if i % 3 == 1 {
                    "beta"
                } else {
                    "gamma"
                }
            })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 7) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let keys_series = grouped.column("key").unwrap().str().unwrap();
        let sums_series = grouped.column("value_sum").unwrap().f32().unwrap();

        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }

        // Group order is not guaranteed, so compare per-key by lookup.
        for idx in 0..grouped.height() {
            if let Some(key) = keys_series.get(idx) {
                let polars_value = sums_series.get(idx).unwrap_or(0.0);
                let naive_value = naive.get(key).copied().unwrap_or(0.0);
                assert!((polars_value - naive_value).abs() < 1e-3);
            }
        }
    }

    // --- Perf smoke tests (ignored by default; run with `cargo test -- --ignored`) ---

    #[test]
    #[ignore]
    fn polars_vectorized_dot_product_is_fast() {
        let rows = 300_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 5) % 100) as f32).collect();

        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .f32()
            .unwrap()
            .sum()
            .unwrap_or(0.0);
        let polars_elapsed = polars_start.elapsed();
        black_box(polars_sum);

        let naive_start = Instant::now();
        let mut naive_sum = 0.0f32;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += l * r;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive_sum);

        println!(
            "polars dot product: {:?}, naive loop: {:?}",
            polars_elapsed, naive_elapsed
        );

        // Loose factor: guards against pathological regressions, not noise.
        assert!(polars_elapsed <= naive_elapsed * 20);
    }

    #[test]
    #[ignore]
    fn polars_groupby_is_fast() {
        let rows = 200_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| {
                if i % 4 == 0 {
                    "alpha"
                } else if i % 4 == 1 {
                    "beta"
                } else if i % 4 == 2 {
                    "gamma"
                } else {
                    "delta"
                }
            })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 9) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let polars_elapsed = polars_start.elapsed();
        black_box(grouped.height());

        let naive_start = Instant::now();
        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive.len());

        println!(
            "polars groupby: {:?}, naive hashmap: {:?}",
            polars_elapsed, naive_elapsed
        );

        // Loose factor: guards against pathological regressions, not noise.
        assert!(polars_elapsed <= naive_elapsed * 20);
    }
}