// converge_analytics/engine.rs

// Copyright (c) 2026 Aprio One AB
// Author: Kenneth Pernyer, kenneth@pernyer.se

use anyhow::{anyhow, Result};
use converge_core::{Agent, AgentEffect, Context, ContextKey, Fact, ProposedFact};
use polars::prelude::*;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
9
/// A fact content representing computed features.
///
/// Holds a row-major matrix of `f32` values; `shape` is `[rows, cols]`.
/// The invariant `data.len() == rows * cols` is enforced by [`FeatureVector::new`].
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct FeatureVector {
    // Row-major feature values.
    pub data: Vec<f32>,
    // Matrix dimensions as [rows, cols].
    pub shape: [usize; 2],
}
16
17impl FeatureVector {
18    pub fn new(data: Vec<f32>, shape: [usize; 2]) -> Result<Self> {
19        let expected = shape
20            .get(0)
21            .and_then(|rows| shape.get(1).map(|cols| rows.saturating_mul(*cols)))
22            .unwrap_or(0);
23        if data.len() != expected {
24            return Err(anyhow!(
25                "feature data length {} does not match shape {:?}",
26                data.len(),
27                shape
28            ));
29        }
30        Ok(Self { data, shape })
31    }
32
33    pub fn row(data: Vec<f32>) -> Self {
34        let cols = data.len();
35        Self {
36            data,
37            shape: [1, cols],
38        }
39    }
40
41    pub fn rows(&self) -> usize {
42        self.shape[0]
43    }
44
45    pub fn cols(&self) -> usize {
46        self.shape[1]
47    }
48}
49
/// Names of the two DataFrame columns a feature row is derived from.
#[derive(Clone, Debug)]
pub struct FeatureColumns {
    // Column supplying the first (left) feature value.
    pub left: String,
    // Column supplying the second (right) feature value.
    pub right: String,
}
55
/// Agent that computes a `FeatureVector` from tabular data using Polars.
#[derive(Clone)]
pub struct FeatureAgent {
    // Optional CSV/Parquet source; when None a built-in demo frame is used.
    source_path: Option<PathBuf>,
    // Explicit column pair; when None the first two numeric columns are picked.
    columns: Option<FeatureColumns>,
}
61
62impl FeatureAgent {
63    pub fn new(source_path: Option<PathBuf>) -> Self {
64        Self {
65            source_path,
66            columns: None,
67        }
68    }
69
70    pub fn with_columns(mut self, left: impl Into<String>, right: impl Into<String>) -> Self {
71        self.columns = Some(FeatureColumns {
72            left: left.into(),
73            right: right.into(),
74        });
75        self
76    }
77
78    /// Internal Polars logic to compute features
79    fn compute_features(&self) -> Result<FeatureVector> {
80        let df = if let Some(path) = &self.source_path {
81            load_dataframe(path)?
82        } else {
83            df! [
84                "x1" => [1.0, 2.0, 3.0],
85                "x2" => [4.0, 5.0, 6.0],
86                "x3" => [7.0, 8.0, 9.0],
87            ]?
88        };
89        compute_features_from_df(&df, self.columns.as_ref())
90    }
91}
92
93impl Agent for FeatureAgent {
94    fn name(&self) -> &str {
95        "FeatureAgent (Polars)"
96    }
97
98    fn dependencies(&self) -> &[ContextKey] {
99        // Depends on Seeds to know WHAT to process
100        &[ContextKey::Seeds]
101    }
102
103    fn accepts(&self, ctx: &Context) -> bool {
104        // Run if we have Seeds but haven't produced Proposals yet
105        ctx.has(ContextKey::Seeds) && !ctx.has(ContextKey::Proposals)
106    }
107
108    fn execute(&self, _ctx: &Context) -> AgentEffect {
109        // 1. Compute features using Polars
110        let features = match self.compute_features() {
111            Ok(f) => f,
112            Err(e) => {
113                return AgentEffect::with_fact(Fact::new(
114                    ContextKey::Diagnostic,
115                    "feature-agent-error",
116                    e.to_string(),
117                ))
118            }
119        };
120
121        // 2. Serialize to Fact content
122        let content = serde_json::to_string(&features).unwrap_or_default();
123
124        // 3. Propose the features
125        let proposal = ProposedFact {
126            key: ContextKey::Proposals,
127            id: "features-001".into(),
128            content,
129            confidence: 1.0, // Deterministic computation
130            provenance: "polars-engine".into(),
131        };
132
133        // Note: In a real agent, we might emit a Fact directly if trusted, or a ProposedFact.
134        // converge_core usually requires TryFrom implementation or specific flow.
135        // For simplicity, we assume we can emit effects.
136        // Wait, AgentEffect::with_proposal?
137        // Let's check AgentEffect definition.
138
139        // Use the constructor for single proposal
140        AgentEffect::with_proposal(proposal)
141    }
142}
143
144fn compute_features_from_df(df: &DataFrame, columns: Option<&FeatureColumns>) -> Result<FeatureVector> {
145    let (left, right) = if let Some(columns) = columns {
146        let left = df
147            .column(&columns.left)
148            .map_err(|_| anyhow!("missing column {}", columns.left))?;
149        let right = df
150            .column(&columns.right)
151            .map_err(|_| anyhow!("missing column {}", columns.right))?;
152        (left.clone(), right.clone())
153    } else {
154        let mut numeric = df
155            .get_columns()
156            .iter()
157            .filter(|series| is_numeric_dtype(series.dtype()))
158            .cloned()
159            .collect::<Vec<_>>();
160        if numeric.len() < 2 {
161            return Err(anyhow!("need at least two numeric columns"));
162        }
163        (numeric.remove(0), numeric.remove(0))
164    };
165
166    if left.len() == 0 || right.len() == 0 {
167        return Err(anyhow!("input data is empty"));
168    }
169
170    let left = left.cast(&DataType::Float32)?;
171    let right = right.cast(&DataType::Float32)?;
172
173    let left_val = left
174        .f32()?
175        .get(0)
176        .ok_or_else(|| anyhow!("missing left value"))?;
177    let right_val = right
178        .f32()?
179        .get(0)
180        .ok_or_else(|| anyhow!("missing right value"))?;
181
182    let interaction = left_val * right_val;
183    Ok(FeatureVector::row(vec![left_val, right_val, interaction]))
184}
185
186fn load_dataframe(path: &PathBuf) -> Result<DataFrame> {
187    let extension = path
188        .extension()
189        .and_then(|ext| ext.to_str())
190        .unwrap_or("")
191        .to_ascii_lowercase();
192
193    let path_str = path
194        .to_str()
195        .ok_or_else(|| anyhow!("path is not valid utf-8: {:?}", path))?;
196
197    match extension.as_str() {
198        "parquet" => {
199            let pl_path = PlPath::from_str(path_str);
200            Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
201        }
202        "csv" => Ok(
203            CsvReadOptions::default()
204                .with_has_header(true)
205                .try_into_reader_with_file_path(Some(path.to_path_buf()))?
206                .finish()?,
207        ),
208        _ => Err(anyhow!(
209            "unsupported data format for path {:?} (expected .csv or .parquet)",
210            path
211        )),
212    }
213}
214
215fn is_numeric_dtype(dtype: &DataType) -> bool {
216    matches!(
217        dtype,
218        DataType::Int8
219            | DataType::Int16
220            | DataType::Int32
221            | DataType::Int64
222            | DataType::UInt8
223            | DataType::UInt16
224            | DataType::UInt32
225            | DataType::UInt64
226            | DataType::Float32
227            | DataType::Float64
228    )
229}
230
// Unit, property, and consistency tests for the feature engine.
// The `#[ignore]`d tests at the bottom are loose performance sanity checks;
// run them explicitly with `cargo test -- --ignored`.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashMap;
    use std::fs;
    use std::hint::black_box;
    use std::time::Instant;
    use std::time::{SystemTime, UNIX_EPOCH};

    // `new` accepts a matching data/shape pair and rejects a length mismatch.
    #[test]
    fn feature_vector_validates_shape() {
        let ok = FeatureVector::new(vec![1.0, 2.0], [1, 2]).unwrap();
        assert_eq!(ok.rows(), 1);
        assert_eq!(ok.cols(), 2);
        assert!(FeatureVector::new(vec![1.0], [1, 2]).is_err());
    }

    // Explicit column names select first-row values: [a0, b0, a0 * b0].
    #[test]
    fn compute_features_from_df_uses_named_columns() {
        let df = df![
            "a" => [2.0f32, 3.0],
            "b" => [4.0f32, 5.0],
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "a".into(),
            right: "b".into(),
        };
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![2.0, 4.0, 8.0]);
        assert_eq!(features.shape, [1, 3]);
    }

    // With no column selection, the non-numeric column is skipped and the
    // first two numeric columns are used.
    #[test]
    fn compute_features_from_df_falls_back_to_first_numeric_columns() {
        let df = df![
            "text" => ["x", "y"],
            "a" => [1.5f32, 2.5],
            "b" => [3.0f32, 4.0],
        ]
        .unwrap();

        let features = compute_features_from_df(&df, None).unwrap();
        assert_eq!(features.data, vec![1.5, 3.0, 4.5]);
    }

    // Only the first row is used, regardless of frame height.
    #[test]
    fn compute_features_handles_large_dataset() {
        let rows = 10_000;
        let left: Vec<f32> = (0..rows).map(|i| i as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| (i as f32) + 1.0).collect();
        let df = df![
            "left" => left,
            "right" => right,
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "left".into(),
            right: "right".into(),
        };
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![0.0, 1.0, 0.0]);
    }

    // Round-trips a small CSV through a nanosecond-unique temp file.
    // NOTE(review): the temp file is not deleted afterwards.
    #[test]
    fn load_dataframe_reads_csv() {
        let mut path = std::env::temp_dir();
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        path.push(format!("converge_analytics_{nanos}.csv"));

        let contents = "left,right\n2.0,4.0\n3.0,5.0\n";
        fs::write(&path, contents).unwrap();

        let df = load_dataframe(&path).unwrap();
        assert_eq!(df.height(), 2);
        assert_eq!(df.width(), 2);
    }

    proptest! {
        // Property: for any normal-f32 inputs, the output is exactly
        // [left[0], right[0], left[0] * right[0]].
        #[test]
        fn compute_features_matches_first_row(
            left in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
            right in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
        ) {
            // Truncate both vectors to a common length so df! accepts them.
            let len = left.len().min(right.len());
            let df = df![
                "left" => left[..len].to_vec(),
                "right" => right[..len].to_vec(),
            ]
            .unwrap();

            let columns = FeatureColumns {
                left: "left".into(),
                right: "right".into(),
            };
            let features = compute_features_from_df(&df, Some(&columns)).unwrap();
            let expected_left = left[0];
            let expected_right = right[0];
            prop_assert_eq!(features.data, vec![expected_left, expected_right, expected_left * expected_right]);
        }
    }

    // Cross-checks Polars' column multiply + sum against a scalar loop,
    // accumulating in f64 on both sides to avoid f32 rounding drift.
    #[test]
    fn polars_vectorized_dot_product_matches_naive() {
        let rows = 50_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 3) % 100) as f32).collect();
        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .cast(&DataType::Float64)
            .unwrap()
            .f64()
            .unwrap()
            .sum()
            .unwrap_or(0.0);

        let mut naive_sum = 0.0f64;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += (*l as f64) * (*r as f64);
        }

        assert!((polars_sum - naive_sum).abs() < 1e-6);
    }

    // Cross-checks a lazy group-by sum against a HashMap accumulation,
    // comparing per-key within a small f32 tolerance.
    #[test]
    fn polars_groupby_sum_matches_naive() {
        let rows = 10_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| if i % 3 == 0 { "alpha" } else if i % 3 == 1 { "beta" } else { "gamma" })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 7) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let keys_series = grouped.column("key").unwrap().str().unwrap();
        let sums_series = grouped.column("value_sum").unwrap().f32().unwrap();

        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }

        // Group-by output order is not guaranteed, so look up each key.
        for idx in 0..grouped.height() {
            if let Some(key) = keys_series.get(idx) {
                let polars_value = sums_series.get(idx).unwrap_or(0.0);
                let naive_value = naive.get(key).copied().unwrap_or(0.0);
                assert!((polars_value - naive_value).abs() < 1e-3);
            }
        }
    }

    // Timing smoke test: the vectorized path should be within 20x of the
    // naive loop. Ignored by default — wall-clock assertions are flaky on CI.
    #[test]
    #[ignore]
    fn polars_vectorized_dot_product_is_fast() {
        let rows = 300_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 5) % 100) as f32).collect();

        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .f32()
            .unwrap()
            .sum()
            .unwrap_or(0.0);
        let polars_elapsed = polars_start.elapsed();
        // black_box keeps the optimizer from discarding the measured work.
        black_box(polars_sum);

        let naive_start = Instant::now();
        let mut naive_sum = 0.0f32;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += l * r;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive_sum);

        println!(
            "polars dot product: {:?}, naive loop: {:?}",
            polars_elapsed, naive_elapsed
        );

        assert!(polars_elapsed <= naive_elapsed * 20);
    }

    // Timing smoke test for group-by, same loose 20x bound and same
    // rationale for being #[ignore]d as above.
    #[test]
    #[ignore]
    fn polars_groupby_is_fast() {
        let rows = 200_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| if i % 4 == 0 { "alpha" } else if i % 4 == 1 { "beta" } else if i % 4 == 2 { "gamma" } else { "delta" })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 9) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let polars_elapsed = polars_start.elapsed();
        black_box(grouped.height());

        let naive_start = Instant::now();
        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive.len());

        println!(
            "polars groupby: {:?}, naive hashmap: {:?}",
            polars_elapsed, naive_elapsed
        );

        assert!(polars_elapsed <= naive_elapsed * 20);
    }
}