use anyhow::{anyhow, Result};
use converge_core::{Agent, AgentEffect, Context, ContextKey, Fact, ProposedFact};
use polars::prelude::*;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
9
10#[derive(Debug, Serialize, Deserialize, PartialEq)]
12pub struct FeatureVector {
13 pub data: Vec<f32>,
14 pub shape: [usize; 2],
15}
16
17impl FeatureVector {
18 pub fn new(data: Vec<f32>, shape: [usize; 2]) -> Result<Self> {
19 let expected = shape
20 .get(0)
21 .and_then(|rows| shape.get(1).map(|cols| rows.saturating_mul(*cols)))
22 .unwrap_or(0);
23 if data.len() != expected {
24 return Err(anyhow!(
25 "feature data length {} does not match shape {:?}",
26 data.len(),
27 shape
28 ));
29 }
30 Ok(Self { data, shape })
31 }
32
33 pub fn row(data: Vec<f32>) -> Self {
34 let cols = data.len();
35 Self {
36 data,
37 shape: [1, cols],
38 }
39 }
40
41 pub fn rows(&self) -> usize {
42 self.shape[0]
43 }
44
45 pub fn cols(&self) -> usize {
46 self.shape[1]
47 }
48}
49
/// Names of the two dataframe columns used to derive features.
#[derive(Clone, Debug)]
pub struct FeatureColumns {
    // Column supplying the first (left) feature value.
    pub left: String,
    // Column supplying the second (right) feature value.
    pub right: String,
}
55
/// Agent that computes a small feature vector from a CSV/Parquet file
/// (or a built-in demo dataframe when no path is configured).
#[derive(Clone)]
pub struct FeatureAgent {
    // Optional data file; `None` falls back to an in-memory demo dataframe.
    source_path: Option<PathBuf>,
    // Optional explicit column pair; `None` picks the first two numeric columns.
    columns: Option<FeatureColumns>,
}
61
62impl FeatureAgent {
63 pub fn new(source_path: Option<PathBuf>) -> Self {
64 Self {
65 source_path,
66 columns: None,
67 }
68 }
69
70 pub fn with_columns(mut self, left: impl Into<String>, right: impl Into<String>) -> Self {
71 self.columns = Some(FeatureColumns {
72 left: left.into(),
73 right: right.into(),
74 });
75 self
76 }
77
78 fn compute_features(&self) -> Result<FeatureVector> {
80 let df = if let Some(path) = &self.source_path {
81 load_dataframe(path)?
82 } else {
83 df! [
84 "x1" => [1.0, 2.0, 3.0],
85 "x2" => [4.0, 5.0, 6.0],
86 "x3" => [7.0, 8.0, 9.0],
87 ]?
88 };
89 compute_features_from_df(&df, self.columns.as_ref())
90 }
91}
92
93impl Agent for FeatureAgent {
94 fn name(&self) -> &str {
95 "FeatureAgent (Polars)"
96 }
97
98 fn dependencies(&self) -> &[ContextKey] {
99 &[ContextKey::Seeds]
101 }
102
103 fn accepts(&self, ctx: &Context) -> bool {
104 ctx.has(ContextKey::Seeds) && !ctx.has(ContextKey::Proposals)
106 }
107
108 fn execute(&self, _ctx: &Context) -> AgentEffect {
109 let features = match self.compute_features() {
111 Ok(f) => f,
112 Err(e) => {
113 return AgentEffect::with_fact(Fact::new(
114 ContextKey::Diagnostic,
115 "feature-agent-error",
116 e.to_string(),
117 ))
118 }
119 };
120
121 let content = serde_json::to_string(&features).unwrap_or_default();
123
124 let proposal = ProposedFact {
126 key: ContextKey::Proposals,
127 id: "features-001".into(),
128 content,
129 confidence: 1.0, provenance: "polars-engine".into(),
131 };
132
133 AgentEffect::with_proposal(proposal)
141 }
142}
143
144fn compute_features_from_df(df: &DataFrame, columns: Option<&FeatureColumns>) -> Result<FeatureVector> {
145 let (left, right) = if let Some(columns) = columns {
146 let left = df
147 .column(&columns.left)
148 .map_err(|_| anyhow!("missing column {}", columns.left))?;
149 let right = df
150 .column(&columns.right)
151 .map_err(|_| anyhow!("missing column {}", columns.right))?;
152 (left.clone(), right.clone())
153 } else {
154 let mut numeric = df
155 .get_columns()
156 .iter()
157 .filter(|series| is_numeric_dtype(series.dtype()))
158 .cloned()
159 .collect::<Vec<_>>();
160 if numeric.len() < 2 {
161 return Err(anyhow!("need at least two numeric columns"));
162 }
163 (numeric.remove(0), numeric.remove(0))
164 };
165
166 if left.len() == 0 || right.len() == 0 {
167 return Err(anyhow!("input data is empty"));
168 }
169
170 let left = left.cast(&DataType::Float32)?;
171 let right = right.cast(&DataType::Float32)?;
172
173 let left_val = left
174 .f32()?
175 .get(0)
176 .ok_or_else(|| anyhow!("missing left value"))?;
177 let right_val = right
178 .f32()?
179 .get(0)
180 .ok_or_else(|| anyhow!("missing right value"))?;
181
182 let interaction = left_val * right_val;
183 Ok(FeatureVector::row(vec![left_val, right_val, interaction]))
184}
185
186fn load_dataframe(path: &PathBuf) -> Result<DataFrame> {
187 let extension = path
188 .extension()
189 .and_then(|ext| ext.to_str())
190 .unwrap_or("")
191 .to_ascii_lowercase();
192
193 let path_str = path
194 .to_str()
195 .ok_or_else(|| anyhow!("path is not valid utf-8: {:?}", path))?;
196
197 match extension.as_str() {
198 "parquet" => {
199 let pl_path = PlPath::from_str(path_str);
200 Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
201 }
202 "csv" => Ok(
203 CsvReadOptions::default()
204 .with_has_header(true)
205 .try_into_reader_with_file_path(Some(path.to_path_buf()))?
206 .finish()?,
207 ),
208 _ => Err(anyhow!(
209 "unsupported data format for path {:?} (expected .csv or .parquet)",
210 path
211 )),
212 }
213}
214
215fn is_numeric_dtype(dtype: &DataType) -> bool {
216 matches!(
217 dtype,
218 DataType::Int8
219 | DataType::Int16
220 | DataType::Int32
221 | DataType::Int64
222 | DataType::UInt8
223 | DataType::UInt16
224 | DataType::UInt32
225 | DataType::UInt64
226 | DataType::Float32
227 | DataType::Float64
228 )
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashMap;
    use std::fs;
    use std::hint::black_box;
    use std::time::Instant;
    use std::time::{SystemTime, UNIX_EPOCH};

    #[test]
    fn feature_vector_validates_shape() {
        let ok = FeatureVector::new(vec![1.0, 2.0], [1, 2]).unwrap();
        assert_eq!(ok.rows(), 1);
        assert_eq!(ok.cols(), 2);
        // A data length that disagrees with the shape must be rejected.
        assert!(FeatureVector::new(vec![1.0], [1, 2]).is_err());
    }

    #[test]
    fn compute_features_from_df_uses_named_columns() {
        let df = df![
            "a" => [2.0f32, 3.0],
            "b" => [4.0f32, 5.0],
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "a".into(),
            right: "b".into(),
        };
        // First row: a=2, b=4, interaction=8.
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![2.0, 4.0, 8.0]);
        assert_eq!(features.shape, [1, 3]);
    }

    #[test]
    fn compute_features_from_df_falls_back_to_first_numeric_columns() {
        let df = df![
            "text" => ["x", "y"],
            "a" => [1.5f32, 2.5],
            "b" => [3.0f32, 4.0],
        ]
        .unwrap();

        // Without explicit columns, the non-numeric "text" column is skipped
        // and the first two numeric columns ("a", "b") are used.
        let features = compute_features_from_df(&df, None).unwrap();
        assert_eq!(features.data, vec![1.5, 3.0, 4.5]);
    }

    #[test]
    fn compute_features_handles_large_dataset() {
        let rows = 10_000;
        let left: Vec<f32> = (0..rows).map(|i| i as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| (i as f32) + 1.0).collect();
        let df = df![
            "left" => left,
            "right" => right,
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "left".into(),
            right: "right".into(),
        };
        // Only the first row matters: left=0, right=1, interaction=0.
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![0.0, 1.0, 0.0]);
    }

    #[test]
    fn load_dataframe_reads_csv() {
        // Unique temp path per run so parallel invocations don't collide.
        let mut path = std::env::temp_dir();
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        path.push(format!("converge_analytics_{nanos}.csv"));

        let contents = "left,right\n2.0,4.0\n3.0,5.0\n";
        fs::write(&path, contents).unwrap();

        let df = load_dataframe(&path).unwrap();
        // Best-effort cleanup so repeated runs don't leak temp files
        // (previously the file was never removed).
        let _ = fs::remove_file(&path);
        assert_eq!(df.height(), 2);
        assert_eq!(df.width(), 2);
    }

    proptest! {
        #[test]
        fn compute_features_matches_first_row(
            left in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
            right in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
        ) {
            // Truncate both vectors to a common length so the dataframe
            // columns line up; index 0 always exists because len >= 1.
            let len = left.len().min(right.len());
            let df = df![
                "left" => left[..len].to_vec(),
                "right" => right[..len].to_vec(),
            ]
            .unwrap();

            let columns = FeatureColumns {
                left: "left".into(),
                right: "right".into(),
            };
            let features = compute_features_from_df(&df, Some(&columns)).unwrap();
            let expected_left = left[0];
            let expected_right = right[0];
            prop_assert_eq!(features.data, vec![expected_left, expected_right, expected_left * expected_right]);
        }
    }

    #[test]
    fn polars_vectorized_dot_product_matches_naive() {
        let rows = 50_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 3) % 100) as f32).collect();
        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        // Both sides accumulate small integer products in f64, so the sums
        // are exact and comparable with a tight tolerance.
        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .cast(&DataType::Float64)
            .unwrap()
            .f64()
            .unwrap()
            .sum()
            .unwrap_or(0.0);

        let mut naive_sum = 0.0f64;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += (*l as f64) * (*r as f64);
        }

        assert!((polars_sum - naive_sum).abs() < 1e-6);
    }

    #[test]
    fn polars_groupby_sum_matches_naive() {
        let rows = 10_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| if i % 3 == 0 { "alpha" } else if i % 3 == 1 { "beta" } else { "gamma" })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 7) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let keys_series = grouped.column("key").unwrap().str().unwrap();
        let sums_series = grouped.column("value_sum").unwrap().f32().unwrap();

        // Reference aggregation with a plain hash map.
        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }

        // Polars group order is unspecified, so compare by key lookup.
        for idx in 0..grouped.height() {
            if let Some(key) = keys_series.get(idx) {
                let polars_value = sums_series.get(idx).unwrap_or(0.0);
                let naive_value = naive.get(key).copied().unwrap_or(0.0);
                assert!((polars_value - naive_value).abs() < 1e-3);
            }
        }
    }

    // Loose performance smoke test; `#[ignore]`d because wall-clock timing
    // is flaky on shared CI hardware. Run with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn polars_vectorized_dot_product_is_fast() {
        let rows = 300_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 5) % 100) as f32).collect();

        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .f32()
            .unwrap()
            .sum()
            .unwrap_or(0.0);
        let polars_elapsed = polars_start.elapsed();
        // black_box keeps the optimizer from discarding the timed work.
        black_box(polars_sum);

        let naive_start = Instant::now();
        let mut naive_sum = 0.0f32;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += l * r;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive_sum);

        println!(
            "polars dot product: {:?}, naive loop: {:?}",
            polars_elapsed, naive_elapsed
        );

        // Generous 20x bound: we only assert polars is not pathologically slow.
        assert!(polars_elapsed <= naive_elapsed * 20);
    }

    // Same flaky-timing caveat as above; `#[ignore]`d by default.
    #[test]
    #[ignore]
    fn polars_groupby_is_fast() {
        let rows = 200_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| if i % 4 == 0 { "alpha" } else if i % 4 == 1 { "beta" } else if i % 4 == 2 { "gamma" } else { "delta" })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 9) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let polars_elapsed = polars_start.elapsed();
        black_box(grouped.height());

        let naive_start = Instant::now();
        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive.len());

        println!(
            "polars groupby: {:?}, naive hashmap: {:?}",
            polars_elapsed, naive_elapsed
        );

        // Generous 20x bound: we only assert polars is not pathologically slow.
        assert!(polars_elapsed <= naive_elapsed * 20);
    }
}