1use anyhow::{Result, anyhow};
4use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
5use polars::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::path::{Path, PathBuf};
8
/// A dense feature matrix with its flat data buffer and `[rows, cols]` shape.
///
/// `new` enforces `data.len() == shape[0] * shape[1]`; layout is presumably
/// row-major, but nothing in this file depends on the ordering — confirm
/// against consumers before relying on it.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct FeatureVector {
    // Flat values; length must equal the product of `shape` (checked in `new`).
    pub data: Vec<f32>,
    // [rows, cols]
    pub shape: [usize; 2],
}
15
16impl FeatureVector {
17 pub fn new(data: Vec<f32>, shape: [usize; 2]) -> Result<Self> {
18 let expected = shape
19 .first()
20 .and_then(|rows| shape.get(1).map(|cols| rows.saturating_mul(*cols)))
21 .unwrap_or(0);
22 if data.len() != expected {
23 return Err(anyhow!(
24 "feature data length {} does not match shape {:?}",
25 data.len(),
26 shape
27 ));
28 }
29 Ok(Self { data, shape })
30 }
31
32 pub fn row(data: Vec<f32>) -> Self {
33 let cols = data.len();
34 Self {
35 data,
36 shape: [1, cols],
37 }
38 }
39
40 pub fn rows(&self) -> usize {
41 self.shape[0]
42 }
43
44 pub fn cols(&self) -> usize {
45 self.shape[1]
46 }
47}
48
/// Names of the two DataFrame columns used as the left/right feature inputs.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FeatureColumns {
    pub left: String,
    pub right: String,
}
54
/// Suggestor that derives a small feature vector from a tabular data source.
#[derive(Clone, Debug)]
pub struct FeatureAgent {
    // CSV/parquet source file; `None` falls back to a built-in demo frame.
    source_path: Option<PathBuf>,
    // Explicit column pair; `None` selects the first two numeric columns.
    columns: Option<FeatureColumns>,
}
60
61impl FeatureAgent {
62 pub fn new(source_path: Option<PathBuf>) -> Self {
63 Self {
64 source_path,
65 columns: None,
66 }
67 }
68
69 pub fn with_columns(mut self, left: impl Into<String>, right: impl Into<String>) -> Self {
70 self.columns = Some(FeatureColumns {
71 left: left.into(),
72 right: right.into(),
73 });
74 self
75 }
76
77 fn compute_features(&self) -> Result<FeatureVector> {
79 let df = if let Some(path) = &self.source_path {
80 load_dataframe(path)?
81 } else {
82 df! [
83 "x1" => [1.0, 2.0, 3.0],
84 "x2" => [4.0, 5.0, 6.0],
85 "x3" => [7.0, 8.0, 9.0],
86 ]?
87 };
88 compute_features_from_df(&df, self.columns.as_ref())
89 }
90}
91
92#[async_trait::async_trait]
93impl Suggestor for FeatureAgent {
94 fn name(&self) -> &str {
95 "FeatureAgent (Polars)"
96 }
97
98 fn dependencies(&self) -> &[ContextKey] {
99 &[ContextKey::Seeds]
101 }
102
103 fn accepts(&self, ctx: &dyn Context) -> bool {
104 ctx.has(ContextKey::Seeds) && !ctx.has(ContextKey::Proposals)
106 }
107
108 async fn execute(&self, _ctx: &dyn Context) -> AgentEffect {
109 let features = match self.compute_features() {
111 Ok(f) => f,
112 Err(e) => {
113 return AgentEffect::with_proposal(ProposedFact::new(
114 ContextKey::Diagnostic,
115 "feature-agent-error",
116 e.to_string(),
117 self.name(),
118 ));
119 }
120 };
121
122 let content = serde_json::to_string(&features).unwrap_or_default();
124
125 let proposal = ProposedFact::new(
127 ContextKey::Proposals,
128 "features-001",
129 content,
130 "polars-engine",
131 );
132
133 AgentEffect::with_proposal(proposal)
141 }
142}
143
144fn compute_features_from_df(
145 df: &DataFrame,
146 columns: Option<&FeatureColumns>,
147) -> Result<FeatureVector> {
148 let (left, right) = if let Some(columns) = columns {
149 let left = df
150 .column(&columns.left)
151 .map_err(|_| anyhow!("missing column {}", columns.left))?;
152 let right = df
153 .column(&columns.right)
154 .map_err(|_| anyhow!("missing column {}", columns.right))?;
155 (left.clone(), right.clone())
156 } else {
157 let mut numeric = df
158 .get_columns()
159 .iter()
160 .filter(|series| is_numeric_dtype(series.dtype()))
161 .cloned()
162 .collect::<Vec<_>>();
163 if numeric.len() < 2 {
164 return Err(anyhow!("need at least two numeric columns"));
165 }
166 (numeric.remove(0), numeric.remove(0))
167 };
168
169 if left.len() == 0 || right.len() == 0 {
170 return Err(anyhow!("input data is empty"));
171 }
172
173 let left = left.cast(&DataType::Float32)?;
174 let right = right.cast(&DataType::Float32)?;
175
176 let left_val = left
177 .f32()?
178 .get(0)
179 .ok_or_else(|| anyhow!("missing left value"))?;
180 let right_val = right
181 .f32()?
182 .get(0)
183 .ok_or_else(|| anyhow!("missing right value"))?;
184
185 let interaction = left_val * right_val;
186 Ok(FeatureVector::row(vec![left_val, right_val, interaction]))
187}
188
/// Loads a `DataFrame` from `path`, dispatching on the file extension
/// (case-insensitive): `.parquet` via a lazy scan, `.csv` via the eager
/// reader with a header row.
///
/// # Errors
/// Fails when the path is not valid UTF-8, the extension is unsupported,
/// or the underlying polars reader fails.
fn load_dataframe(path: &Path) -> Result<DataFrame> {
    let extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or("")
        .to_ascii_lowercase();

    // The parquet scan needs a string path; non-UTF-8 paths are rejected up
    // front for both formats (a deliberate, uniform precondition).
    let path_str = path
        .to_str()
        .ok_or_else(|| anyhow!("path is not valid utf-8: {}", path.display()))?;

    match extension.as_str() {
        "parquet" => {
            let pl_path = PlPath::from_str(path_str);
            Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
        }
        "csv" => Ok(CsvReadOptions::default()
            .with_has_header(true)
            .try_into_reader_with_file_path(Some(path.to_path_buf()))?
            .finish()?),
        _ => Err(anyhow!(
            "unsupported data format for path {} (expected .csv or .parquet)",
            path.display()
        )),
    }
}
215
216fn is_numeric_dtype(dtype: &DataType) -> bool {
217 matches!(
218 dtype,
219 DataType::Int8
220 | DataType::Int16
221 | DataType::Int32
222 | DataType::Int64
223 | DataType::UInt8
224 | DataType::UInt16
225 | DataType::UInt32
226 | DataType::UInt64
227 | DataType::Float32
228 | DataType::Float64
229 )
230}
231
232#[cfg(test)]
233mod tests {
234 use super::*;
235 use proptest::prelude::*;
236 use std::collections::HashMap;
237 use std::fs;
238 use std::hint::black_box;
239 use std::time::Instant;
240 use std::time::{SystemTime, UNIX_EPOCH};
241
    // `FeatureVector::new` accepts a matching shape and rejects a mismatch.
    #[test]
    fn feature_vector_validates_shape() {
        let ok = FeatureVector::new(vec![1.0, 2.0], [1, 2]).unwrap();
        assert_eq!(ok.rows(), 1);
        assert_eq!(ok.cols(), 2);
        assert!(FeatureVector::new(vec![1.0], [1, 2]).is_err());
    }

    // Multi-row shapes are supported, not just single rows.
    #[test]
    fn feature_vector_new_multi_row() {
        let fv = FeatureVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]).unwrap();
        assert_eq!(fv.rows(), 2);
        assert_eq!(fv.cols(), 3);
        assert_eq!(fv.data.len(), 6);
    }

    // Any rows*cols / data-length disagreement is an error, zero shapes included.
    #[test]
    fn feature_vector_new_rejects_mismatched_length() {
        assert!(FeatureVector::new(vec![1.0, 2.0, 3.0], [2, 2]).is_err());
        assert!(FeatureVector::new(vec![], [1, 1]).is_err());
        assert!(FeatureVector::new(vec![1.0], [0, 1]).is_err());
    }

    // A fully empty matrix ([0, 0]) is valid.
    #[test]
    fn feature_vector_new_empty() {
        let fv = FeatureVector::new(vec![], [0, 0]).unwrap();
        assert_eq!(fv.rows(), 0);
        assert_eq!(fv.cols(), 0);
        assert!(fv.data.is_empty());
    }

    // Zero columns with nonzero rows still implies an empty data buffer.
    #[test]
    fn feature_vector_new_zero_cols() {
        let fv = FeatureVector::new(vec![], [5, 0]).unwrap();
        assert_eq!(fv.rows(), 5);
        assert_eq!(fv.cols(), 0);
    }
279
    // `row` infers shape [1, len] from the data length.
    #[test]
    fn feature_vector_row_creates_single_row() {
        let fv = FeatureVector::row(vec![10.0, 20.0, 30.0]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 3);
        assert_eq!(fv.data, vec![10.0, 20.0, 30.0]);
    }

    // An empty row is legal: one row, zero columns.
    #[test]
    fn feature_vector_row_empty() {
        let fv = FeatureVector::row(vec![]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 0);
        assert!(fv.data.is_empty());
    }

    // Single-element rows keep both the value and the [1, 1] shape.
    #[test]
    fn feature_vector_row_single_element() {
        let fv = FeatureVector::row(vec![42.0]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 1);
        assert_eq!(fv.data, vec![42.0]);
    }
303
    // FeatureColumns is a plain pair of column names.
    #[test]
    fn feature_columns_construction() {
        let fc = FeatureColumns {
            left: "price".to_string(),
            right: "quantity".to_string(),
        };
        assert_eq!(fc.left, "price");
        assert_eq!(fc.right, "quantity");
    }

    // JSON round-trip preserves both column names.
    #[test]
    fn feature_columns_roundtrip_serde() {
        let fc = FeatureColumns {
            left: "a".to_string(),
            right: "b".to_string(),
        };
        let json = serde_json::to_string(&fc).unwrap();
        let deserialized: FeatureColumns = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.left, "a");
        assert_eq!(deserialized.right, "b");
    }

    // JSON round-trip preserves data and shape (relies on derived PartialEq).
    #[test]
    fn feature_vector_roundtrip_serde() {
        let fv = FeatureVector::new(vec![1.0, 2.0, 3.0, 4.0], [2, 2]).unwrap();
        let json = serde_json::to_string(&fv).unwrap();
        let deserialized: FeatureVector = serde_json::from_str(&json).unwrap();
        assert_eq!(fv, deserialized);
    }
333
    // `new` starts with no column selection.
    #[test]
    fn feature_agent_new_without_columns() {
        let agent = FeatureAgent::new(None);
        assert!(agent.source_path.is_none());
        assert!(agent.columns.is_none());
    }

    // `with_columns` stores the left/right selection.
    #[test]
    fn feature_agent_with_columns() {
        let agent = FeatureAgent::new(None).with_columns("x", "y");
        let cols = agent.columns.unwrap();
        assert_eq!(cols.left, "x");
        assert_eq!(cols.right, "y");
    }

    // The source path is stored verbatim.
    #[test]
    fn feature_agent_with_source_path() {
        let agent = FeatureAgent::new(Some(PathBuf::from("/tmp/data.csv")));
        assert_eq!(agent.source_path.unwrap(), PathBuf::from("/tmp/data.csv"));
    }
354
    // Every integer and float dtype the helper claims to support is accepted.
    #[test]
    fn is_numeric_dtype_covers_all_numeric_types() {
        let numeric = [
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
            DataType::Float32,
            DataType::Float64,
        ];
        for dt in &numeric {
            assert!(is_numeric_dtype(dt), "{dt:?} should be numeric");
        }
    }

    // Spot-check non-numeric dtypes are rejected.
    #[test]
    fn is_numeric_dtype_rejects_non_numeric() {
        assert!(!is_numeric_dtype(&DataType::String));
        assert!(!is_numeric_dtype(&DataType::Boolean));
        assert!(!is_numeric_dtype(&DataType::Date));
    }
380
    // Zero-row columns must fail (there is no first row to extract).
    #[test]
    fn compute_features_rejects_empty_dataframe() {
        let df = df![
            "a" => Vec::<f32>::new(),
            "b" => Vec::<f32>::new(),
        ]
        .unwrap();
        let cols = FeatureColumns {
            left: "a".into(),
            right: "b".into(),
        };
        assert!(compute_features_from_df(&df, Some(&cols)).is_err());
    }

    // A named column absent from the frame yields an error, not a panic.
    #[test]
    fn compute_features_rejects_missing_column() {
        let df = df!["a" => [1.0f32]].unwrap();
        let cols = FeatureColumns {
            left: "a".into(),
            right: "missing".into(),
        };
        assert!(compute_features_from_df(&df, Some(&cols)).is_err());
    }

    // The numeric-column fallback needs at least two numeric columns.
    #[test]
    fn compute_features_rejects_insufficient_numeric_columns() {
        let df = df!["text" => ["a", "b"]].unwrap();
        assert!(compute_features_from_df(&df, None).is_err());
    }
410
    proptest! {
        // Property: any shape whose product matches the data length is
        // accepted, and rows() * cols() always equals the stored data length.
        #[test]
        fn feature_vector_shape_invariant(
            rows in 0usize..50,
            cols in 0usize..50,
        ) {
            let len = rows.saturating_mul(cols);
            let data = vec![0.0f32; len];
            let fv = FeatureVector::new(data, [rows, cols]).unwrap();
            prop_assert_eq!(fv.rows() * fv.cols(), fv.data.len());
        }
    }
423
    // Named columns: first-row values 2 and 4 give [2, 4, 2*4].
    #[test]
    fn compute_features_from_df_uses_named_columns() {
        let df = df![
            "a" => [2.0f32, 3.0],
            "b" => [4.0f32, 5.0],
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "a".into(),
            right: "b".into(),
        };
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![2.0, 4.0, 8.0]);
        assert_eq!(features.shape, [1, 3]);
    }

    // Without a selection, the non-numeric "text" column is skipped and the
    // first two numeric columns are used: 1.5, 3.0, 1.5*3.0.
    #[test]
    fn compute_features_from_df_falls_back_to_first_numeric_columns() {
        let df = df![
            "text" => ["x", "y"],
            "a" => [1.5f32, 2.5],
            "b" => [3.0f32, 4.0],
        ]
        .unwrap();

        let features = compute_features_from_df(&df, None).unwrap();
        assert_eq!(features.data, vec![1.5, 3.0, 4.5]);
    }
453
    // Only the first row matters regardless of frame size: 0 * 1 = 0.
    #[test]
    fn compute_features_handles_large_dataset() {
        let rows = 10_000;
        let left: Vec<f32> = (0..rows).map(|i| i as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| (i as f32) + 1.0).collect();
        let df = df![
            "left" => left,
            "right" => right,
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "left".into(),
            right: "right".into(),
        };
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![0.0, 1.0, 0.0]);
    }
472
473 #[test]
474 fn load_dataframe_reads_csv() {
475 let mut path = std::env::temp_dir();
476 let nanos = SystemTime::now()
477 .duration_since(UNIX_EPOCH)
478 .unwrap()
479 .as_nanos();
480 path.push(format!("converge_analytics_{nanos}.csv"));
481
482 let contents = "left,right\n2.0,4.0\n3.0,5.0\n";
483 fs::write(&path, contents).unwrap();
484
485 let df = load_dataframe(&path).unwrap();
486 assert_eq!(df.height(), 2);
487 assert_eq!(df.width(), 2);
488 }
489
    proptest! {
        // Property: for any pair of normal-f32 columns (truncated to equal
        // length), the feature vector is exactly [left[0], right[0], product].
        // NORMAL values avoid NaN/inf, so exact equality is well-defined.
        #[test]
        fn compute_features_matches_first_row(
            left in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
            right in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
        ) {
            let len = left.len().min(right.len());
            let df = df![
                "left" => left[..len].to_vec(),
                "right" => right[..len].to_vec(),
            ]
            .unwrap();

            let columns = FeatureColumns {
                left: "left".into(),
                right: "right".into(),
            };
            let features = compute_features_from_df(&df, Some(&columns)).unwrap();
            let expected_left = left[0];
            let expected_right = right[0];
            prop_assert_eq!(features.data, vec![expected_left, expected_right, expected_left * expected_right]);
        }
    }
513
    // Cross-checks polars' vectorized column product against a scalar loop,
    // accumulating in f64 on both sides to avoid f32 rounding drift.
    #[test]
    fn polars_vectorized_dot_product_matches_naive() {
        let rows = 50_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 3) % 100) as f32).collect();
        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .cast(&DataType::Float64)
            .unwrap()
            .f64()
            .unwrap()
            .sum()
            .unwrap_or(0.0);

        let mut naive_sum = 0.0f64;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += (*l as f64) * (*r as f64);
        }

        assert!((polars_sum - naive_sum).abs() < 1e-6);
    }
542
    // Cross-checks a lazy group_by/sum against a naive HashMap accumulation
    // over three cyclic keys; compared per group with an f32-sized tolerance.
    #[test]
    fn polars_groupby_sum_matches_naive() {
        let rows = 10_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| {
                if i % 3 == 0 {
                    "alpha"
                } else if i % 3 == 1 {
                    "beta"
                } else {
                    "gamma"
                }
            })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 7) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let keys_series = grouped.column("key").unwrap().str().unwrap();
        let sums_series = grouped.column("value_sum").unwrap().f32().unwrap();

        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }

        // Group order is not deterministic, so look each polars group up in
        // the naive map by key rather than by position.
        for idx in 0..grouped.height() {
            if let Some(key) = keys_series.get(idx) {
                let polars_value = sums_series.get(idx).unwrap_or(0.0);
                let naive_value = naive.get(key).copied().unwrap_or(0.0);
                assert!((polars_value - naive_value).abs() < 1e-3);
            }
        }
    }
586
    // Smoke benchmark (ignored by default; timings are unreliable in debug
    // builds and on shared CI). Asserts only that polars is not absurdly
    // (>20x) slower than a naive loop. `black_box` keeps the optimizer from
    // eliding either computation.
    #[test]
    #[ignore]
    fn polars_vectorized_dot_product_is_fast() {
        let rows = 300_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 5) % 100) as f32).collect();

        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .f32()
            .unwrap()
            .sum()
            .unwrap_or(0.0);
        let polars_elapsed = polars_start.elapsed();
        black_box(polars_sum);

        let naive_start = Instant::now();
        let mut naive_sum = 0.0f32;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += l * r;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive_sum);

        println!(
            "polars dot product: {:?}, naive loop: {:?}",
            polars_elapsed, naive_elapsed
        );

        assert!(polars_elapsed <= naive_elapsed * 20);
    }
626
    // Smoke benchmark for lazy group_by (ignored by default; see the dot
    // product benchmark above for why). Loose 20x bound only.
    #[test]
    #[ignore]
    fn polars_groupby_is_fast() {
        let rows = 200_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| {
                if i % 4 == 0 {
                    "alpha"
                } else if i % 4 == 1 {
                    "beta"
                } else if i % 4 == 2 {
                    "gamma"
                } else {
                    "delta"
                }
            })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 9) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let polars_elapsed = polars_start.elapsed();
        black_box(grouped.height());

        let naive_start = Instant::now();
        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive.len());

        println!(
            "polars groupby: {:?}, naive hashmap: {:?}",
            polars_elapsed, naive_elapsed
        );

        assert!(polars_elapsed <= naive_elapsed * 20);
    }
676}