1use anyhow::{Result, anyhow};
4use converge_pack::{
5 AgentEffect, Context, ContextKey, FactPayload, ProvenanceSource, Suggestor, TextPayload,
6};
7use polars::prelude::*;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10
11use crate::provenance::PRISM_PROVENANCE;
12
/// A dense, row-major matrix of `f32` features emitted as a fact payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct FeatureVector {
    // Row-major feature values; invariant: data.len() == shape[0] * shape[1]
    // (enforced by `FeatureVector::new`).
    pub data: Vec<f32>,
    // Matrix dimensions as [rows, cols].
    pub shape: [usize; 2],
}
20
/// Registers `FeatureVector` as a versioned fact payload under the
/// `prism.feature-vector` family (version 1).
impl FactPayload for FeatureVector {
    const FAMILY: &'static str = "prism.feature-vector";
    const VERSION: u16 = 1;
}
25
26impl FeatureVector {
27 pub fn new(data: Vec<f32>, shape: [usize; 2]) -> Result<Self> {
28 let expected = shape
29 .first()
30 .and_then(|rows| shape.get(1).map(|cols| rows.saturating_mul(*cols)))
31 .unwrap_or(0);
32 if data.len() != expected {
33 return Err(anyhow!(
34 "feature data length {} does not match shape {:?}",
35 data.len(),
36 shape
37 ));
38 }
39 Ok(Self { data, shape })
40 }
41
42 pub fn row(data: Vec<f32>) -> Self {
43 let cols = data.len();
44 Self {
45 data,
46 shape: [1, cols],
47 }
48 }
49
50 pub fn rows(&self) -> usize {
51 self.shape[0]
52 }
53
54 pub fn cols(&self) -> usize {
55 self.shape[1]
56 }
57}
58
/// Names of the two source columns used to derive the feature row.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FeatureColumns {
    // Column providing the first ("left") feature value.
    pub left: String,
    // Column providing the second ("right") feature value.
    pub right: String,
}
64
/// Suggestor that loads a data frame (or a built-in demo frame) and proposes
/// a `FeatureVector` fact derived from two of its columns.
#[derive(Clone, Debug)]
pub struct FeatureAgent {
    // Optional .csv/.parquet input; when `None` a small built-in frame is used.
    source_path: Option<PathBuf>,
    // Explicit column selection; when `None` the first two numeric columns win.
    columns: Option<FeatureColumns>,
}
70
71impl FeatureAgent {
72 pub fn new(source_path: Option<PathBuf>) -> Self {
73 Self {
74 source_path,
75 columns: None,
76 }
77 }
78
79 pub fn with_columns(mut self, left: impl Into<String>, right: impl Into<String>) -> Self {
80 self.columns = Some(FeatureColumns {
81 left: left.into(),
82 right: right.into(),
83 });
84 self
85 }
86
87 fn compute_features(&self) -> Result<FeatureVector> {
89 let df = if let Some(path) = &self.source_path {
90 load_dataframe(path)?
91 } else {
92 df! [
93 "x1" => [1.0, 2.0, 3.0],
94 "x2" => [4.0, 5.0, 6.0],
95 "x3" => [7.0, 8.0, 9.0],
96 ]?
97 };
98 compute_features_from_df(&df, self.columns.as_ref())
99 }
100}
101
102#[async_trait::async_trait]
103impl Suggestor for FeatureAgent {
104 fn name(&self) -> &'static str {
105 "FeatureAgent (Polars)"
106 }
107
108 fn dependencies(&self) -> &[ContextKey] {
109 &[ContextKey::Seeds]
111 }
112
113 fn accepts(&self, ctx: &dyn Context) -> bool {
114 ctx.has(ContextKey::Seeds) && !ctx.has(ContextKey::Proposals)
116 }
117
118 fn provenance(&self) -> &'static str {
119 PRISM_PROVENANCE.as_str()
120 }
121
122 async fn execute(&self, _ctx: &dyn Context) -> AgentEffect {
123 let features = match self.compute_features() {
125 Ok(f) => f,
126 Err(e) => {
127 return AgentEffect::with_proposal(PRISM_PROVENANCE.proposed_fact(
128 ContextKey::Diagnostic,
129 "feature-agent-error",
130 TextPayload::new(e.to_string()),
131 ));
132 }
133 };
134
135 let proposal =
137 PRISM_PROVENANCE.proposed_fact(ContextKey::Proposals, "features-001", features);
138
139 AgentEffect::with_proposal(proposal)
147 }
148}
149
150fn compute_features_from_df(
151 df: &DataFrame,
152 columns: Option<&FeatureColumns>,
153) -> Result<FeatureVector> {
154 let (left, right) = if let Some(columns) = columns {
155 let left = df
156 .column(&columns.left)
157 .map_err(|_| anyhow!("missing column {}", columns.left))?;
158 let right = df
159 .column(&columns.right)
160 .map_err(|_| anyhow!("missing column {}", columns.right))?;
161 (left.clone(), right.clone())
162 } else {
163 let mut numeric = df
164 .get_columns()
165 .iter()
166 .filter(|series| is_numeric_dtype(series.dtype()))
167 .cloned()
168 .collect::<Vec<_>>();
169 if numeric.len() < 2 {
170 return Err(anyhow!("need at least two numeric columns"));
171 }
172 (numeric.remove(0), numeric.remove(0))
173 };
174
175 if left.is_empty() || right.is_empty() {
176 return Err(anyhow!("input data is empty"));
177 }
178
179 let left = left.cast(&DataType::Float32)?;
180 let right = right.cast(&DataType::Float32)?;
181
182 let left_val = left
183 .f32()?
184 .get(0)
185 .ok_or_else(|| anyhow!("missing left value"))?;
186 let right_val = right
187 .f32()?
188 .get(0)
189 .ok_or_else(|| anyhow!("missing right value"))?;
190
191 let interaction = left_val * right_val;
192 Ok(FeatureVector::row(vec![left_val, right_val, interaction]))
193}
194
195fn load_dataframe(path: &Path) -> Result<DataFrame> {
196 let extension = path
197 .extension()
198 .and_then(|ext| ext.to_str())
199 .unwrap_or("")
200 .to_ascii_lowercase();
201
202 let path_str = path
203 .to_str()
204 .ok_or_else(|| anyhow!("path is not valid utf-8: {}", path.display()))?;
205
206 match extension.as_str() {
207 "parquet" => {
208 let pl_path = PlPath::from_str(path_str);
209 Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
210 }
211 "csv" => Ok(CsvReadOptions::default()
212 .with_has_header(true)
213 .try_into_reader_with_file_path(Some(path.to_path_buf()))?
214 .finish()?),
215 _ => Err(anyhow!(
216 "unsupported data format for path {} (expected .csv or .parquet)",
217 path.display()
218 )),
219 }
220}
221
222fn is_numeric_dtype(dtype: &DataType) -> bool {
223 matches!(
224 dtype,
225 DataType::Int8
226 | DataType::Int16
227 | DataType::Int32
228 | DataType::Int64
229 | DataType::UInt8
230 | DataType::UInt16
231 | DataType::UInt32
232 | DataType::UInt64
233 | DataType::Float32
234 | DataType::Float64
235 )
236}
237
#[cfg(test)]
mod tests {
    //! Unit, property, and cross-check tests: `FeatureVector` invariants,
    //! `FeatureAgent` builders, dataframe feature extraction, and
    //! Polars-vs-naive numerical agreement (plus two `#[ignore]`d
    //! perf smoke tests).
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashMap;
    use std::fs;
    use std::hint::black_box;
    use std::time::Instant;
    use std::time::{SystemTime, UNIX_EPOCH};

    // new() accepts a matching shape and rejects a mismatched one.
    #[test]
    fn feature_vector_validates_shape() {
        let ok = FeatureVector::new(vec![1.0, 2.0], [1, 2]).unwrap();
        assert_eq!(ok.rows(), 1);
        assert_eq!(ok.cols(), 2);
        assert!(FeatureVector::new(vec![1.0], [1, 2]).is_err());
    }

    // Multi-row shapes are supported, not just single rows.
    #[test]
    fn feature_vector_new_multi_row() {
        let fv = FeatureVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]).unwrap();
        assert_eq!(fv.rows(), 2);
        assert_eq!(fv.cols(), 3);
        assert_eq!(fv.data.len(), 6);
    }

    // Length/shape disagreement must error in both directions (too short, too long).
    #[test]
    fn feature_vector_new_rejects_mismatched_length() {
        assert!(FeatureVector::new(vec![1.0, 2.0, 3.0], [2, 2]).is_err());
        assert!(FeatureVector::new(vec![], [1, 1]).is_err());
        assert!(FeatureVector::new(vec![1.0], [0, 1]).is_err());
    }

    // The degenerate 0x0 matrix is valid.
    #[test]
    fn feature_vector_new_empty() {
        let fv = FeatureVector::new(vec![], [0, 0]).unwrap();
        assert_eq!(fv.rows(), 0);
        assert_eq!(fv.cols(), 0);
        assert!(fv.data.is_empty());
    }

    // Rows with zero columns are valid too (expected length is 5 * 0 = 0).
    #[test]
    fn feature_vector_new_zero_cols() {
        let fv = FeatureVector::new(vec![], [5, 0]).unwrap();
        assert_eq!(fv.rows(), 5);
        assert_eq!(fv.cols(), 0);
    }

    // row() infers shape [1, len] from the data.
    #[test]
    fn feature_vector_row_creates_single_row() {
        let fv = FeatureVector::row(vec![10.0, 20.0, 30.0]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 3);
        assert_eq!(fv.data, vec![10.0, 20.0, 30.0]);
    }

    // An empty row is [1, 0], not [0, 0].
    #[test]
    fn feature_vector_row_empty() {
        let fv = FeatureVector::row(vec![]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 0);
        assert!(fv.data.is_empty());
    }

    #[test]
    fn feature_vector_row_single_element() {
        let fv = FeatureVector::row(vec![42.0]);
        assert_eq!(fv.rows(), 1);
        assert_eq!(fv.cols(), 1);
        assert_eq!(fv.data, vec![42.0]);
    }

    #[test]
    fn feature_columns_construction() {
        let fc = FeatureColumns {
            left: "price".to_string(),
            right: "quantity".to_string(),
        };
        assert_eq!(fc.left, "price");
        assert_eq!(fc.right, "quantity");
    }

    // JSON serde round-trips for both serializable types.
    #[test]
    fn feature_columns_roundtrip_serde() {
        let fc = FeatureColumns {
            left: "a".to_string(),
            right: "b".to_string(),
        };
        let json = serde_json::to_string(&fc).unwrap();
        let deserialized: FeatureColumns = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.left, "a");
        assert_eq!(deserialized.right, "b");
    }

    #[test]
    fn feature_vector_roundtrip_serde() {
        let fv = FeatureVector::new(vec![1.0, 2.0, 3.0, 4.0], [2, 2]).unwrap();
        let json = serde_json::to_string(&fv).unwrap();
        let deserialized: FeatureVector = serde_json::from_str(&json).unwrap();
        assert_eq!(fv, deserialized);
    }

    // Builder defaults: no source path, no column selection.
    #[test]
    fn feature_agent_new_without_columns() {
        let agent = FeatureAgent::new(None);
        assert!(agent.source_path.is_none());
        assert!(agent.columns.is_none());
    }

    #[test]
    fn feature_agent_with_columns() {
        let agent = FeatureAgent::new(None).with_columns("x", "y");
        let cols = agent.columns.unwrap();
        assert_eq!(cols.left, "x");
        assert_eq!(cols.right, "y");
    }

    #[test]
    fn feature_agent_with_source_path() {
        let agent = FeatureAgent::new(Some(PathBuf::from("/tmp/data.csv")));
        assert_eq!(agent.source_path.unwrap(), PathBuf::from("/tmp/data.csv"));
    }

    // Keep this list in sync with the match arms in is_numeric_dtype.
    #[test]
    fn is_numeric_dtype_covers_all_numeric_types() {
        let numeric = [
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
            DataType::Float32,
            DataType::Float64,
        ];
        for dt in &numeric {
            assert!(is_numeric_dtype(dt), "{dt:?} should be numeric");
        }
    }

    #[test]
    fn is_numeric_dtype_rejects_non_numeric() {
        assert!(!is_numeric_dtype(&DataType::String));
        assert!(!is_numeric_dtype(&DataType::Boolean));
        assert!(!is_numeric_dtype(&DataType::Date));
    }

    // Zero-height frames hit the "input data is empty" guard.
    #[test]
    fn compute_features_rejects_empty_dataframe() {
        let df = df![
            "a" => Vec::<f32>::new(),
            "b" => Vec::<f32>::new(),
        ]
        .unwrap();
        let cols = FeatureColumns {
            left: "a".into(),
            right: "b".into(),
        };
        assert!(compute_features_from_df(&df, Some(&cols)).is_err());
    }

    #[test]
    fn compute_features_rejects_missing_column() {
        let df = df!["a" => [1.0f32]].unwrap();
        let cols = FeatureColumns {
            left: "a".into(),
            right: "missing".into(),
        };
        assert!(compute_features_from_df(&df, Some(&cols)).is_err());
    }

    // Fallback selection needs two numeric columns; a lone string column fails.
    #[test]
    fn compute_features_rejects_insufficient_numeric_columns() {
        let df = df!["text" => ["a", "b"]].unwrap();
        assert!(compute_features_from_df(&df, None).is_err());
    }

    // Property: any consistent (rows, cols, data) triple satisfies
    // rows * cols == data.len() after construction.
    proptest! {
        #[test]
        fn feature_vector_shape_invariant(
            rows in 0usize..50,
            cols in 0usize..50,
        ) {
            let len = rows.saturating_mul(cols);
            let data = vec![0.0f32; len];
            let fv = FeatureVector::new(data, [rows, cols]).unwrap();
            prop_assert_eq!(fv.rows() * fv.cols(), fv.data.len());
        }
    }

    // Named selection uses row 0: [2, 4, 2*4].
    #[test]
    fn compute_features_from_df_uses_named_columns() {
        let df = df![
            "a" => [2.0f32, 3.0],
            "b" => [4.0f32, 5.0],
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "a".into(),
            right: "b".into(),
        };
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![2.0, 4.0, 8.0]);
        assert_eq!(features.shape, [1, 3]);
    }

    // With no selection, the leading string column is skipped.
    #[test]
    fn compute_features_from_df_falls_back_to_first_numeric_columns() {
        let df = df![
            "text" => ["x", "y"],
            "a" => [1.5f32, 2.5],
            "b" => [3.0f32, 4.0],
        ]
        .unwrap();

        let features = compute_features_from_df(&df, None).unwrap();
        assert_eq!(features.data, vec![1.5, 3.0, 4.5]);
    }

    // Row count doesn't change the result: only row 0 is read ([0, 1, 0*1]).
    #[test]
    fn compute_features_handles_large_dataset() {
        let rows = 10_000;
        let left: Vec<f32> = (0..rows).map(|i| i as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| (i as f32) + 1.0).collect();
        let df = df![
            "left" => left,
            "right" => right,
        ]
        .unwrap();

        let columns = FeatureColumns {
            left: "left".into(),
            right: "right".into(),
        };
        let features = compute_features_from_df(&df, Some(&columns)).unwrap();
        assert_eq!(features.data, vec![0.0, 1.0, 0.0]);
    }

    // End-to-end CSV load using a nanosecond-unique temp file name.
    // NOTE(review): the temp file is never removed; consider fs::remove_file
    // (or the tempfile crate) for cleanup.
    #[test]
    fn load_dataframe_reads_csv() {
        let mut path = std::env::temp_dir();
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        path.push(format!("prism_{nanos}.csv"));

        let contents = "left,right\n2.0,4.0\n3.0,5.0\n";
        fs::write(&path, contents).unwrap();

        let df = load_dataframe(&path).unwrap();
        assert_eq!(df.height(), 2);
        assert_eq!(df.width(), 2);
    }

    // Property: features always equal the first row of each (truncated) input;
    // NORMAL floats exclude NaN so exact equality is well-defined.
    proptest! {
        #[test]
        fn compute_features_matches_first_row(
            left in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
            right in proptest::collection::vec(prop::num::f32::NORMAL, 1..50),
        ) {
            let len = left.len().min(right.len());
            let df = df![
                "left" => left[..len].to_vec(),
                "right" => right[..len].to_vec(),
            ]
            .unwrap();

            let columns = FeatureColumns {
                left: "left".into(),
                right: "right".into(),
            };
            let features = compute_features_from_df(&df, Some(&columns)).unwrap();
            let expected_left = left[0];
            let expected_right = right[0];
            prop_assert_eq!(features.data, vec![expected_left, expected_right, expected_left * expected_right]);
        }
    }

    // Cross-check: Polars column product + f64 sum vs a plain f64 loop.
    // All partial sums are integers < 2^53, so the comparison is exact.
    #[test]
    fn polars_vectorized_dot_product_matches_naive() {
        let rows = 50_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 3) % 100) as f32).collect();
        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .cast(&DataType::Float64)
            .unwrap()
            .f64()
            .unwrap()
            .sum()
            .unwrap_or(0.0);

        let mut naive_sum = 0.0f64;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += (*l as f64) * (*r as f64);
        }

        assert!((polars_sum - naive_sum).abs() < 1e-6);
    }

    // Cross-check: lazy group_by/sum against a HashMap accumulation.
    #[test]
    fn polars_groupby_sum_matches_naive() {
        let rows = 10_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| {
                if i % 3 == 0 {
                    "alpha"
                } else if i % 3 == 1 {
                    "beta"
                } else {
                    "gamma"
                }
            })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 7) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let keys_series = grouped.column("key").unwrap().str().unwrap();
        let sums_series = grouped.column("value_sum").unwrap().f32().unwrap();

        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }

        // Group order is unspecified, so look each polars key up in the map.
        for idx in 0..grouped.height() {
            if let Some(key) = keys_series.get(idx) {
                let polars_value = sums_series.get(idx).unwrap_or(0.0);
                let naive_value = naive.get(key).copied().unwrap_or(0.0);
                assert!((polars_value - naive_value).abs() < 1e-3);
            }
        }
    }

    // Perf smoke test (ignored by default): Polars should be within 20x of
    // a naive loop; black_box prevents the optimizer from removing the work.
    #[test]
    #[ignore]
    fn polars_vectorized_dot_product_is_fast() {
        let rows = 300_000;
        let left: Vec<f32> = (0..rows).map(|i| (i % 100) as f32).collect();
        let right: Vec<f32> = (0..rows).map(|i| ((i + 5) % 100) as f32).collect();

        let df = df![
            "left" => left.clone(),
            "right" => right.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let product = (df.column("left").unwrap() * df.column("right").unwrap()).unwrap();
        let polars_sum = product
            .as_materialized_series()
            .f32()
            .unwrap()
            .sum()
            .unwrap_or(0.0);
        let polars_elapsed = polars_start.elapsed();
        black_box(polars_sum);

        let naive_start = Instant::now();
        let mut naive_sum = 0.0f32;
        for (l, r) in left.iter().zip(right.iter()) {
            naive_sum += l * r;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive_sum);

        println!(
            "polars dot product: {:?}, naive loop: {:?}",
            polars_elapsed, naive_elapsed
        );

        assert!(polars_elapsed <= naive_elapsed * 20);
    }

    // Perf smoke test (ignored by default) for the group_by path.
    #[test]
    #[ignore]
    fn polars_groupby_is_fast() {
        let rows = 200_000;
        let keys: Vec<&str> = (0..rows)
            .map(|i| {
                if i % 4 == 0 {
                    "alpha"
                } else if i % 4 == 1 {
                    "beta"
                } else if i % 4 == 2 {
                    "gamma"
                } else {
                    "delta"
                }
            })
            .collect();
        let values: Vec<f32> = (0..rows).map(|i| (i % 9) as f32).collect();
        let df = df![
            "key" => keys.clone(),
            "value" => values.clone(),
        ]
        .unwrap();

        let polars_start = Instant::now();
        let grouped = df
            .lazy()
            .group_by([col("key")])
            .agg([col("value").sum().alias("value_sum")])
            .collect()
            .unwrap();
        let polars_elapsed = polars_start.elapsed();
        black_box(grouped.height());

        let naive_start = Instant::now();
        let mut naive = HashMap::<&str, f32>::new();
        for (key, value) in keys.iter().zip(values.iter()) {
            *naive.entry(*key).or_insert(0.0) += value;
        }
        let naive_elapsed = naive_start.elapsed();
        black_box(naive.len());

        println!(
            "polars groupby: {:?}, naive hashmap: {:?}",
            polars_elapsed, naive_elapsed
        );

        assert!(polars_elapsed <= naive_elapsed * 20);
    }
}