1use anyhow::{Context as _, Result, anyhow};
4use converge_pack::{AgentEffect, Context, ContextKey, ProposalId, ProposedFact, Suggestor};
5use polars::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fs::{File, create_dir_all};
9use std::io::Write;
10use std::path::{Path, PathBuf};
11
/// Parquet export of the California-housing dataset hosted on the HuggingFace Hub.
const DATASET_URL: &str = "https://huggingface.co/datasets/gvlassis/california_housing/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet";
/// Name of the regression target column in that dataset.
/// NOTE(review): not referenced in this chunk — presumably consumed by
/// `select_target_column`; confirm against its definition.
const TARGET_COLUMN: &str = "median_house_value";
14
15fn proposal(
16 provenance: &str,
17 key: ContextKey,
18 id: impl Into<String>,
19 content: impl Into<String>,
20) -> ProposedFact {
21 ProposedFact::new(key, ProposalId::new(id.into()), content, provenance)
22}
23
/// Plan for one training iteration: how much data to use, how to split it,
/// and the quality bar used downstream for deployment decisions.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainingPlan {
    /// Iteration this plan applies to (defaults to 1 when no plan fact exists).
    pub iteration: usize,
    /// Upper bound on rows taken from the source dataset.
    pub max_rows: usize,
    /// Fraction of `max_rows` used for training.
    pub train_fraction: f64,
    /// Fraction of `max_rows` used for validation.
    pub val_fraction: f64,
    /// Fraction of `max_rows` held out for sample inference.
    pub infer_fraction: f64,
    /// Minimum `success_ratio` an evaluation must reach to deploy.
    pub quality_threshold: f64,
}
33
/// Record of a materialized train/val/infer split, proposed under
/// `ContextKey::Signals` by the `DatasetAgent`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DatasetSplit {
    /// Path of the full downloaded dataset the split was cut from.
    pub source_path: String,
    pub train_path: String,
    pub val_path: String,
    pub infer_path: String,
    /// Rows in the source dataset before capping to `max_rows`.
    pub total_rows: usize,
    /// Rows actually used (source rows capped by the plan).
    pub max_rows: usize,
    pub train_rows: usize,
    pub val_rows: usize,
    pub infer_rows: usize,
    pub iteration: usize,
}
47
/// Trivial baseline "model": predicts the training mean of the target column
/// for every row. Serialized to JSON on disk by the training agent.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BaselineModel {
    pub target_column: String,
    /// Mean of the target over the training split.
    pub mean: f64,
}
53
/// Metadata fact describing a trained model, proposed under
/// `ContextKey::Strategies` so evaluation/inference agents can locate it.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ModelMetadata {
    /// Filesystem path of the serialized `BaselineModel` JSON.
    pub model_path: String,
    pub target_column: String,
    pub train_rows: usize,
    pub baseline_mean: f64,
    pub iteration: usize,
}
62
/// Validation-set evaluation of a model, proposed under
/// `ContextKey::Evaluations`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvaluationReport {
    pub model_path: String,
    /// Metric name; the evaluation agent in this file emits "mae".
    pub metric: String,
    /// Metric value (mean absolute error).
    pub value: f64,
    /// Mean absolute target value, used to normalize the error.
    pub mean_abs_target: f64,
    /// `clamp(1 - value / mean_abs_target, 0, 1)` — higher is better.
    pub success_ratio: f64,
    pub val_rows: usize,
    pub iteration: usize,
}
73
/// Small batch of predictions vs. actuals from the held-out inference split,
/// proposed under `ContextKey::Hypotheses`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct InferenceSample {
    pub model_path: String,
    pub target_column: String,
    /// Number of rows sampled (`predictions.len() == actuals.len() == rows`).
    pub rows: usize,
    pub predictions: Vec<f64>,
    pub actuals: Vec<f64>,
    pub iteration: usize,
}
83
/// Per-column profile of the training split produced by the
/// `DataValidationAgent`. The `kind` tag ("data_quality") disambiguates it
/// from other JSON facts stored under the same context key.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DataQualityReport {
    pub kind: String,
    pub iteration: usize,
    pub source_path: String,
    pub rows_checked: usize,
    /// Column name -> fraction of null values (0.0 when the frame is empty).
    pub missingness: HashMap<String, f64>,
    /// Column name -> mean, for numeric columns only.
    pub numeric_means: HashMap<String, f64>,
    /// Column name -> outlier count, for numeric columns only.
    pub outlier_counts: HashMap<String, usize>,
    /// Optional drift score vs. earlier iterations; `None` when unavailable.
    pub drift_score: Option<f64>,
}
95
/// Derived feature built from two source columns (e.g. their product).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FeatureInteraction {
    /// Name of the derived column (e.g. "a_x_b").
    pub name: String,
    pub left: String,
    pub right: String,
    /// Operation identifier; this file only emits "multiply".
    pub op: String,
}
103
/// Feature-engineering recipe for one iteration, proposed under
/// `ContextKey::Constraints` and applied by training/evaluation via
/// `apply_feature_spec`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FeatureSpec {
    /// Tag ("feature_spec") disambiguating this fact's JSON.
    pub kind: String,
    pub iteration: usize,
    pub target_column: String,
    pub numeric_features: Vec<String>,
    pub categorical_features: Vec<String>,
    /// Normalization strategy name; this file emits "standardize".
    pub normalization: String,
    pub interactions: Vec<FeatureInteraction>,
}
114
/// Declared hyperparameter search space for one iteration.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct HyperparameterSearchPlan {
    /// Tag ("hyperparam_plan") disambiguating this fact's JSON.
    pub kind: String,
    pub iteration: usize,
    pub max_trials: usize,
    pub early_stopping: bool,
    /// Parameter name -> candidate values (all encoded as f64).
    pub params: HashMap<String, Vec<f64>>,
}
123
/// Outcome of a hyperparameter search for one iteration.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct HyperparameterSearchResult {
    /// Tag ("hyperparam_result") disambiguating this fact's JSON.
    pub kind: String,
    pub iteration: usize,
    pub best_params: HashMap<String, f64>,
    pub score: f64,
}
131
/// Registry entry linking a trained model to its evaluation metrics.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ModelRegistryRecord {
    /// Tag ("model_registry") disambiguating this fact's JSON.
    pub kind: String,
    pub iteration: usize,
    pub model_path: String,
    /// Metric name -> value, copied from the iteration's evaluation report.
    pub metrics: HashMap<String, f64>,
    /// Origin of the record; this file emits "training_flow".
    pub provenance: String,
}
140
/// Health summary derived from the latest evaluation report.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MonitoringReport {
    /// Tag ("monitoring") disambiguating this fact's JSON.
    pub kind: String,
    pub iteration: usize,
    pub metric: String,
    pub value: f64,
    /// Reference level the metric is compared against (mean abs target).
    pub baseline: f64,
    /// "healthy" or "needs_attention".
    pub status: String,
}
150
/// Deploy/hold decision for one iteration's model.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeploymentDecision {
    /// Tag ("deployment_decision") disambiguating this fact's JSON.
    pub kind: String,
    pub iteration: usize,
    /// "deploy" or "hold".
    pub action: String,
    pub reason: String,
    /// Whether another training iteration should be triggered.
    pub retrain: bool,
}
159
/// Agent that downloads the dataset and materializes the train/val/infer
/// split as parquet files under `data_dir`.
#[derive(Debug)]
pub struct DatasetAgent {
    // Directory holding the downloaded dataset and the split files.
    data_dir: PathBuf,
}

impl DatasetAgent {
    /// Creates an agent that stores all dataset artifacts under `data_dir`.
    pub fn new(data_dir: PathBuf) -> Self {
        Self { data_dir }
    }

    /// Location of the full downloaded dataset.
    fn dataset_path(&self) -> PathBuf {
        self.data_dir.join("california_housing_train.parquet")
    }

    /// Locations of the (train, val, infer) split files, in that order.
    fn split_paths(&self) -> (PathBuf, PathBuf, PathBuf) {
        (
            self.data_dir.join("train.parquet"),
            self.data_dir.join("val.parquet"),
            self.data_dir.join("infer.parquet"),
        )
    }
}
182
/// Agent that profiles the freshest training split: per-column missingness,
/// numeric means, outlier counts, and an optional drift score.
#[derive(Debug, Default)]
pub struct DataValidationAgent;

impl DataValidationAgent {
    pub fn new() -> Self {
        Self
    }
}
191
#[async_trait::async_trait]
impl Suggestor for DataValidationAgent {
    fn name(&self) -> &str {
        "DataValidationAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    /// Fires at most once per iteration: only when a dataset split exists and
    /// no data-quality report has been produced for that iteration yet.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_data_quality_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Profiles the training split and proposes a `DataQualityReport` fact
    /// under `Signals`. All failures are surfaced as `Diagnostic` proposals
    /// rather than error returns.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "data-validation-error",
                    err.to_string(),
                ));
            }
        };

        let df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "data-validation-error",
                    err.to_string(),
                ));
            }
        };

        let rows = df.height();
        let mut missingness = HashMap::new();
        let mut numeric_means = HashMap::new();
        let mut outlier_counts = HashMap::new();

        for series in df.get_columns() {
            let name = series.name().to_string();
            // Guard against division by zero on an empty frame.
            let null_ratio = if rows > 0 {
                series.null_count() as f64 / rows as f64
            } else {
                0.0
            };
            missingness.insert(name.clone(), null_ratio);

            // Means/outliers only make sense for numeric columns. The outlier
            // definition lives in `compute_numeric_stats` (defined elsewhere
            // in this file); the std-dev it returns is unused here.
            if is_numeric_dtype(series.dtype()) {
                if let Ok((mean, _std, outliers)) =
                    compute_numeric_stats(series.as_materialized_series())
                {
                    numeric_means.insert(name.clone(), mean);
                    outlier_counts.insert(name, outliers);
                }
            }
        }

        // Drift relative to earlier context state; see `drift_score_from_ctx`
        // (defined elsewhere in this file).
        let drift_score = drift_score_from_ctx(ctx, split.iteration, &numeric_means);

        let report = DataQualityReport {
            kind: "data_quality".to_string(),
            iteration: split.iteration,
            source_path: split.train_path.clone(),
            rows_checked: rows,
            missingness,
            numeric_means,
            outlier_counts,
            drift_score,
        };

        // Serialization failure degrades to an empty fact body rather than
        // aborting the agent.
        let content = serde_json::to_string(&report).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Signals,
            format!("data-quality-{}", split.iteration),
            content,
        ))
    }
}
281
/// Agent that derives a `FeatureSpec` (feature lists, normalization,
/// interactions) from the training split's schema.
#[derive(Debug, Default)]
pub struct FeatureEngineeringAgent;

impl FeatureEngineeringAgent {
    pub fn new() -> Self {
        Self
    }
}
290
#[async_trait::async_trait]
impl Suggestor for FeatureEngineeringAgent {
    fn name(&self) -> &str {
        "FeatureEngineeringAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    /// Fires once per iteration: when a split exists but no feature spec has
    /// been proposed for that iteration yet.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_feature_spec_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Inspects the training split's schema and proposes a `FeatureSpec`
    /// under `Constraints`. Failures become `Diagnostic` proposals.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ));
            }
        };

        let df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ));
            }
        };

        // Only the column name is needed; the series itself is discarded.
        let (target_column, _) = match select_target_column(&df) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ));
            }
        };

        let (numeric_features, categorical_features) = split_feature_columns(&df, &target_column);

        // Single fixed interaction: the product of the first two numeric
        // features, when at least two exist.
        let mut interactions = Vec::new();
        if numeric_features.len() >= 2 {
            interactions.push(FeatureInteraction {
                name: format!("{}_x_{}", numeric_features[0], numeric_features[1]),
                left: numeric_features[0].clone(),
                right: numeric_features[1].clone(),
                op: "multiply".to_string(),
            });
        }

        let spec = FeatureSpec {
            kind: "feature_spec".to_string(),
            iteration: split.iteration,
            target_column,
            numeric_features,
            categorical_features,
            normalization: "standardize".to_string(),
            interactions,
        };

        let content = serde_json::to_string(&spec).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Constraints,
            format!("feature-spec-{}", split.iteration),
            content,
        ))
    }
}
377
/// Agent that proposes a hyperparameter search plan and a (currently
/// simulated) search result for each iteration.
#[derive(Debug)]
pub struct HyperparameterSearchAgent {
    // Budget recorded in the emitted plan; no actual trials are executed.
    max_trials: usize,
}

impl HyperparameterSearchAgent {
    pub fn new(max_trials: usize) -> Self {
        Self { max_trials }
    }
}
388
#[async_trait::async_trait]
impl Suggestor for HyperparameterSearchAgent {
    fn name(&self) -> &str {
        "HyperparameterSearchAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Constraints, ContextKey::Signals]
    }

    /// Fires once per iteration when a split exists and no search result has
    /// been proposed yet.
    /// NOTE(review): only `Signals` is checked here even though `Constraints`
    /// is declared as a dependency — confirm whether `accepts` should also
    /// require `ctx.has(ContextKey::Constraints)`.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_hyperparam_result_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Proposes a search plan (under `Constraints`) and a placeholder result
    /// (under `Evaluations`). No trials are actually run: `best_params` is
    /// hard-coded and `score` is a deterministic formula over the plan.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "hyperparam-search-error",
                    err.to_string(),
                ));
            }
        };

        // Fall back to default fractions/threshold when no plan fact exists.
        let training_plan = read_latest_plan_from_ctx(ctx).unwrap_or(TrainingPlan {
            iteration: split.iteration,
            max_rows: split.max_rows,
            train_fraction: 0.8,
            val_fraction: 0.15,
            infer_fraction: 0.05,
            quality_threshold: 0.75,
        });

        // Fixed two-parameter search space.
        let mut params = HashMap::new();
        params.insert("learning_rate".to_string(), vec![0.001, 0.01, 0.1]);
        params.insert("hidden_size".to_string(), vec![8.0, 16.0, 32.0]);

        let plan = HyperparameterSearchPlan {
            kind: "hyperparam_plan".to_string(),
            iteration: split.iteration,
            max_trials: self.max_trials,
            early_stopping: true,
            params,
        };

        // Placeholder "best" configuration: middle value of each dimension.
        let mut best_params = HashMap::new();
        best_params.insert("learning_rate".to_string(), 0.01);
        best_params.insert("hidden_size".to_string(), 16.0);
        // Synthetic score; `.max(1)` guards against division by zero at
        // iteration 0.
        let score = (1.0 - training_plan.quality_threshold) * plan.max_trials as f64
            / plan.iteration.max(1) as f64;
        let result = HyperparameterSearchResult {
            kind: "hyperparam_result".to_string(),
            iteration: split.iteration,
            best_params,
            score,
        };

        let plan_content = serde_json::to_string(&plan).unwrap_or_default();
        let result_content = serde_json::to_string(&result).unwrap_or_default();

        // Two proposals in one effect: the plan and its result.
        let mut effect = AgentEffect::empty();
        effect.proposals.push(proposal(
            self.name(),
            ContextKey::Constraints,
            format!("hyperparam-plan-{}", split.iteration),
            plan_content,
        ));
        effect.proposals.push(proposal(
            self.name(),
            ContextKey::Evaluations,
            format!("hyperparam-result-{}", split.iteration),
            result_content,
        ));
        effect
    }
}
472
#[async_trait::async_trait]
impl Suggestor for DatasetAgent {
    fn name(&self) -> &str {
        "DatasetAgent (HuggingFace)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Seeds]
    }

    /// Fires when a plan exists without a matching split, or — when no plan
    /// exists at all — exactly once, before any `Signals` facts appear.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        if !ctx.has(ContextKey::Seeds) {
            return false;
        }

        let plan = read_latest_plan_from_ctx(ctx);
        if let Some(plan) = plan {
            return !has_split_for_iteration(ctx, plan.iteration);
        }

        !ctx.has(ContextKey::Signals)
    }

    /// Downloads the dataset if absent, cuts it into contiguous
    /// train/val/infer parquet files per the latest `TrainingPlan`, and
    /// proposes a `DatasetSplit` fact. Failures become `Diagnostic`
    /// proposals.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        if let Err(err) = create_dir_all(&self.data_dir) {
            return AgentEffect::with_proposal(proposal(
                self.name(),
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        let dataset_path = self.dataset_path();
        if let Err(err) = download_dataset_if_missing(&dataset_path) {
            return AgentEffect::with_proposal(proposal(
                self.name(),
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        let df = match load_dataframe(&dataset_path) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "dataset-agent-error",
                    err.to_string(),
                ));
            }
        };

        let total_rows = df.height();
        // Need at least a handful of rows so each partition is non-trivial.
        if total_rows < 10 {
            return AgentEffect::with_proposal(proposal(
                self.name(),
                ContextKey::Diagnostic,
                "dataset-agent-error",
                "dataset too small for splitting",
            ));
        }

        // The very first run may happen before any plan fact exists.
        let plan = read_latest_plan_from_ctx(ctx).unwrap_or(TrainingPlan {
            iteration: 1,
            max_rows: total_rows,
            train_fraction: 0.8,
            val_fraction: 0.15,
            infer_fraction: 0.05,
            quality_threshold: 0.75,
        });

        // Cap to the plan's budget but never below the 10-row minimum.
        let max_rows = plan.max_rows.min(total_rows).max(10);
        let df = df.slice(0, max_rows);

        // `infer_fraction` is implicit: inference receives whatever remains
        // after the floor-rounded train and val partitions.
        let mut train_rows = ((max_rows as f64) * plan.train_fraction).floor() as usize;
        let mut val_rows = ((max_rows as f64) * plan.val_fraction).floor() as usize;
        let mut infer_rows = max_rows.saturating_sub(train_rows + val_rows);
        // Guarantee a non-empty inference partition by shrinking val (or,
        // failing that, train) when the fractions consumed every row.
        if infer_rows == 0 {
            if val_rows > 1 {
                val_rows -= 1;
            } else if train_rows > 1 {
                train_rows -= 1;
            }
            infer_rows = max_rows.saturating_sub(train_rows + val_rows).max(1);
        }

        let (train_path, val_path, infer_path) = self.split_paths();
        // Contiguous (non-shuffled) split: rows keep their source order.
        let train_df = df.slice(0, train_rows);
        let val_df = df.slice(train_rows as i64, val_rows);
        let infer_df = df.slice((train_rows + val_rows) as i64, infer_rows);

        if let Err(err) = write_parquet(&train_df, &train_path)
            .and_then(|_| write_parquet(&val_df, &val_path))
            .and_then(|_| write_parquet(&infer_df, &infer_path))
        {
            return AgentEffect::with_proposal(proposal(
                self.name(),
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        let split = DatasetSplit {
            source_path: dataset_path.to_string_lossy().to_string(),
            train_path: train_path.to_string_lossy().to_string(),
            val_path: val_path.to_string_lossy().to_string(),
            infer_path: infer_path.to_string_lossy().to_string(),
            total_rows,
            max_rows,
            train_rows,
            val_rows,
            infer_rows,
            iteration: plan.iteration,
        };

        let content = serde_json::to_string(&split).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Signals,
            format!("dataset-split-{}", plan.iteration),
            content,
        ))
    }
}
601
/// Agent that "trains" the mean-predictor baseline and persists it as JSON
/// under `model_dir`.
#[derive(Debug)]
pub struct ModelTrainingAgent {
    model_dir: PathBuf,
}

impl ModelTrainingAgent {
    pub fn new(model_dir: PathBuf) -> Self {
        Self { model_dir }
    }

    /// Fixed on-disk location of the serialized `BaselineModel`.
    /// NOTE(review): the path is not iteration-scoped, so each iteration
    /// overwrites the previous model file — confirm intentional.
    fn model_path(&self) -> PathBuf {
        self.model_dir.join("baseline_mean.json")
    }
}
616
#[async_trait::async_trait]
impl Suggestor for ModelTrainingAgent {
    fn name(&self) -> &str {
        "ModelTrainingAgent (Baseline)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    /// Fires once per iteration: when a split exists and no model metadata
    /// has been proposed for that iteration yet.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        if !ctx.has(ContextKey::Signals) {
            return false;
        }
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(_) => return false,
        };
        !has_model_for_iteration(ctx, split.iteration)
    }

    /// Computes the training-set mean of the target column (optionally after
    /// applying the iteration's feature spec), persists it as a
    /// `BaselineModel` JSON file, and proposes `ModelMetadata` under
    /// `Strategies`. Failures become `Diagnostic` proposals.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ));
            }
        };

        if let Err(err) = create_dir_all(&self.model_dir) {
            return AgentEffect::with_proposal(proposal(
                self.name(),
                ContextKey::Diagnostic,
                "model-training-error",
                err.to_string(),
            ));
        }

        let raw_train_df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ));
            }
        };

        // Apply the iteration's feature spec when one exists; unlike the
        // evaluation agent, a spec that fails to apply is a hard error here.
        let train_df = match read_feature_spec_from_ctx(ctx, split.iteration) {
            Some(spec) => match apply_feature_spec(&raw_train_df, &spec) {
                Ok(df) => df,
                Err(err) => {
                    return AgentEffect::with_proposal(proposal(
                        self.name(),
                        ContextKey::Diagnostic,
                        "model-training-error",
                        format!("feature spec application failed: {}", err),
                    ));
                }
            },
            None => raw_train_df,
        };

        let (target_name, target) = match select_target_column(&train_df) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ));
            }
        };

        // "Training" the baseline is just computing the target mean.
        let mean = match mean_of_series(&target) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ));
            }
        };

        let model = BaselineModel {
            target_column: target_name.clone(),
            mean,
        };

        let model_path = self.model_path();
        if let Err(err) = write_json(&model_path, &model) {
            return AgentEffect::with_proposal(proposal(
                self.name(),
                ContextKey::Diagnostic,
                "model-training-error",
                err.to_string(),
            ));
        }

        let meta = ModelMetadata {
            model_path: model_path.to_string_lossy().to_string(),
            target_column: target_name,
            train_rows: split.train_rows,
            baseline_mean: mean,
            iteration: split.iteration,
        };

        let content = serde_json::to_string(&meta).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Strategies,
            format!("trained-model-{}", split.iteration),
            content,
        ))
    }
}
744
/// Agent that scores the latest model on the validation split using mean
/// absolute error.
#[derive(Debug, Default)]
pub struct ModelEvaluationAgent;

impl ModelEvaluationAgent {
    pub fn new() -> Self {
        Self
    }
}
753
/// Agent that records a trained model and its metrics as a registry fact.
#[derive(Debug, Default)]
pub struct ModelRegistryAgent;

impl ModelRegistryAgent {
    pub fn new() -> Self {
        Self
    }
}
762
#[async_trait::async_trait]
impl Suggestor for ModelRegistryAgent {
    fn name(&self) -> &str {
        "ModelRegistryAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Strategies, ContextKey::Evaluations]
    }

    /// Fires once per iteration: when model metadata exists, evaluations
    /// exist, and no registry record has been proposed for that iteration.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        ctx.has(ContextKey::Strategies)
            && ctx.has(ContextKey::Evaluations)
            && match read_latest_model_meta_from_ctx(ctx) {
                Ok(meta) => !has_registry_record_for_iteration(ctx, meta.iteration),
                Err(_) => false,
            }
    }

    /// Builds a `ModelRegistryRecord` from the latest model metadata plus the
    /// matching evaluation report (when present) and proposes it under
    /// `Strategies`.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let meta = match read_latest_model_meta_from_ctx(ctx) {
            Ok(meta) => meta,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-registry-error",
                    err.to_string(),
                ));
            }
        };

        // The record is still emitted (with empty metrics) when no
        // evaluation report exists for this iteration.
        let report = latest_evaluation_report(ctx, meta.iteration);
        let mut metrics = HashMap::new();
        if let Some(report) = report {
            metrics.insert(report.metric, report.value);
            metrics.insert("success_ratio".to_string(), report.success_ratio);
        }

        let record = ModelRegistryRecord {
            kind: "model_registry".to_string(),
            iteration: meta.iteration,
            model_path: meta.model_path,
            metrics,
            provenance: "training_flow".to_string(),
        };

        let content = serde_json::to_string(&record).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Strategies,
            format!("model-registry-{}", record.iteration),
            content,
        ))
    }
}
819
/// Agent that turns the latest evaluation report into a health status.
#[derive(Debug, Default)]
pub struct MonitoringAgent;

impl MonitoringAgent {
    pub fn new() -> Self {
        Self
    }
}
828
#[async_trait::async_trait]
impl Suggestor for MonitoringAgent {
    fn name(&self) -> &str {
        "MonitoringAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Evaluations]
    }

    /// Fires when an evaluation report exists without a matching monitoring
    /// report.
    /// NOTE(review): iteration `0` appears to act as a "latest, any
    /// iteration" sentinel for `latest_evaluation_report` — confirm against
    /// its definition (not visible in this chunk).
    fn accepts(&self, ctx: &dyn Context) -> bool {
        ctx.has(ContextKey::Evaluations)
            && match latest_evaluation_report(ctx, 0) {
                Some(report) => !has_monitoring_report_for_iteration(ctx, report.iteration),
                None => false,
            }
    }

    /// Labels the latest evaluation "healthy" or "needs_attention" against a
    /// fixed 0.75 success-ratio bar and proposes a `MonitoringReport`.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let report = match latest_evaluation_report(ctx, 0) {
            Some(report) => report,
            None => return AgentEffect::empty(),
        };

        let status = if report.success_ratio >= 0.75 {
            "healthy"
        } else {
            "needs_attention"
        };

        let monitoring = MonitoringReport {
            kind: "monitoring".to_string(),
            iteration: report.iteration,
            metric: report.metric,
            value: report.value,
            baseline: report.mean_abs_target,
            status: status.to_string(),
        };

        let content = serde_json::to_string(&monitoring).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Evaluations,
            format!("monitoring-{}", report.iteration),
            content,
        ))
    }
}
877
/// Agent that decides whether to deploy or hold the latest model.
#[derive(Debug, Default)]
pub struct DeploymentAgent;

impl DeploymentAgent {
    pub fn new() -> Self {
        Self
    }
}
886
#[async_trait::async_trait]
impl Suggestor for DeploymentAgent {
    fn name(&self) -> &str {
        "DeploymentAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Evaluations, ContextKey::Strategies]
    }

    /// Fires when an evaluation report exists without a matching deployment
    /// decision (iteration `0` presumably means "latest" — see the note on
    /// `MonitoringAgent::accepts`).
    fn accepts(&self, ctx: &dyn Context) -> bool {
        ctx.has(ContextKey::Evaluations)
            && ctx.has(ContextKey::Strategies)
            && match latest_evaluation_report(ctx, 0) {
                Some(report) => !has_deployment_decision_for_iteration(ctx, report.iteration),
                None => false,
            }
    }

    /// Compares the latest success ratio against the plan's quality
    /// threshold (default 0.75) and proposes a deploy/hold decision; a
    /// "hold" also requests retraining.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let report = match latest_evaluation_report(ctx, 0) {
            Some(report) => report,
            None => return AgentEffect::empty(),
        };

        let quality_threshold = read_latest_plan_from_ctx(ctx)
            .map(|plan| plan.quality_threshold)
            .unwrap_or(0.75);

        let (action, retrain, reason) = if report.success_ratio >= quality_threshold {
            ("deploy", false, "meets quality threshold")
        } else {
            ("hold", true, "below quality threshold")
        };

        let decision = DeploymentDecision {
            kind: "deployment_decision".to_string(),
            iteration: report.iteration,
            action: action.to_string(),
            reason: reason.to_string(),
            retrain,
        };

        let content = serde_json::to_string(&decision).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Strategies,
            format!("deployment-{}", report.iteration),
            content,
        ))
    }
}
939
#[async_trait::async_trait]
impl Suggestor for ModelEvaluationAgent {
    fn name(&self) -> &str {
        "ModelEvaluationAgent (MAE)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals, ContextKey::Strategies]
    }

    /// Fires once per iteration: when a split and a model exist and the
    /// iteration has not been evaluated yet.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        ctx.has(ContextKey::Signals)
            && ctx.has(ContextKey::Strategies)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_evaluation_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Scores the baseline (constant-mean) model on the validation split with
    /// MAE, normalizes it into a success ratio, and proposes an
    /// `EvaluationReport` under `Evaluations`. Failures become `Diagnostic`
    /// proposals.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ));
            }
        };

        let model = match read_model_from_ctx(ctx) {
            Ok(model) => model,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ));
            }
        };

        let raw_val_df = match load_dataframe(Path::new(&split.val_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ));
            }
        };

        // NOTE(review): a failing feature spec silently falls back to the raw
        // frame here, while the training agent treats the same failure as a
        // hard error — confirm this asymmetry is intentional.
        let val_df = match read_feature_spec_from_ctx(ctx, split.iteration) {
            Some(spec) => apply_feature_spec(&raw_val_df, &spec).unwrap_or(raw_val_df),
            None => raw_val_df,
        };

        let target = match get_numeric_series(&val_df, &model.target_column) {
            Ok(series) => series,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ));
            }
        };

        // MAE of the constant prediction `model.mean` against the targets.
        let mae = match mean_abs_error(&target, model.mean) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ));
            }
        };

        let mean_abs = match mean_abs_value(&target) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ));
            }
        };

        // Relative accuracy in [0, 1]; degenerate all-zero targets score 0.
        let success_ratio = if mean_abs > 0.0 {
            (1.0 - (mae / mean_abs)).clamp(0.0, 1.0)
        } else {
            0.0
        };

        let report = EvaluationReport {
            model_path: read_model_path_from_ctx(ctx).unwrap_or_default(),
            metric: "mae".to_string(),
            value: mae,
            mean_abs_target: mean_abs,
            success_ratio,
            val_rows: split.val_rows,
            iteration: split.iteration,
        };

        let content = serde_json::to_string(&report).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Evaluations,
            format!("model-eval-{}", split.iteration),
            content,
        ))
    }
}
1063
/// Agent that runs the baseline model over (at most `max_rows` rows of) the
/// held-out inference split and records predictions next to actuals.
#[derive(Debug)]
pub struct SampleInferenceAgent {
    /// Upper bound on rows sampled from the inference split.
    pub max_rows: usize,
}

impl SampleInferenceAgent {
    pub fn new(max_rows: usize) -> Self {
        Self { max_rows }
    }
}
1074
#[async_trait::async_trait]
impl Suggestor for SampleInferenceAgent {
    fn name(&self) -> &str {
        "SampleInferenceAgent (Baseline)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals, ContextKey::Strategies]
    }

    /// Fires once per iteration: when a split and a model exist and no
    /// inference sample has been proposed for that iteration.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        ctx.has(ContextKey::Signals)
            && ctx.has(ContextKey::Strategies)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_inference_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Predicts the constant baseline mean for a sample of the inference
    /// split and proposes an `InferenceSample` under `Hypotheses`.
    /// NOTE(review): unlike training and evaluation, the inference frame is
    /// NOT passed through the iteration's feature spec — confirm intentional.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ));
            }
        };

        let model = match read_model_from_ctx(ctx) {
            Ok(model) => model,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ));
            }
        };

        let infer_df = match load_dataframe(Path::new(&split.infer_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ));
            }
        };

        let target = match get_numeric_series(&infer_df, &model.target_column) {
            Ok(series) => series,
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ));
            }
        };

        // `.max(1)` keeps the bound positive; `take` on a shorter (or empty)
        // iterator simply yields fewer rows.
        let sample_rows = self.max_rows.min(infer_df.height().max(1));
        // Iterate present (non-null) target values only.
        let actuals = match target.f64() {
            Ok(series) => series
                .into_no_null_iter()
                .take(sample_rows)
                .collect::<Vec<_>>(),
            Err(err) => {
                return AgentEffect::with_proposal(proposal(
                    self.name(),
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ));
            }
        };

        // The baseline predicts its training mean for every row.
        let predictions = vec![model.mean; actuals.len()];
        let sample = InferenceSample {
            model_path: read_model_path_from_ctx(ctx).unwrap_or_default(),
            target_column: model.target_column,
            rows: actuals.len(),
            predictions,
            actuals,
            iteration: split.iteration,
        };

        let content = serde_json::to_string(&sample).unwrap_or_default();
        AgentEffect::with_proposal(proposal(
            self.name(),
            ContextKey::Hypotheses,
            format!("inference-sample-{}", split.iteration),
            content,
        ))
    }
}
1178
1179fn download_dataset_if_missing(path: &Path) -> Result<()> {
1180 if path.exists() {
1181 return Ok(());
1182 }
1183
1184 let response = reqwest::blocking::get(DATASET_URL)?;
1185 let content = response.bytes()?;
1186
1187 let mut file = File::create(path)?;
1188 file.write_all(&content)?;
1189
1190 Ok(())
1191}
1192
/// Loads a polars `DataFrame` from `path`, dispatching on the file extension
/// (case-insensitive).
///
/// Supports `.parquet` (lazy scan, collected eagerly) and `.csv` (header row
/// expected); any other extension is rejected.
///
/// # Errors
/// Fails when the path is not valid UTF-8, the extension is unsupported, or
/// polars fails to read/parse the file.
fn load_dataframe(path: &Path) -> Result<DataFrame> {
    let extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or("")
        .to_ascii_lowercase();

    // The parquet scanner needs a `PlPath`, which is built from a UTF-8 str.
    let path_str = path
        .to_str()
        .ok_or_else(|| anyhow!("path is not valid utf-8: {}", path.display()))?;

    match extension.as_str() {
        "parquet" => {
            let pl_path = PlPath::from_str(path_str);
            Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
        }
        "csv" => Ok(CsvReadOptions::default()
            .with_has_header(true)
            .try_into_reader_with_file_path(Some(path.to_path_buf()))?
            .finish()?),
        _ => Err(anyhow!(
            "unsupported data format for path {} (expected .csv or .parquet)",
            path.display()
        )),
    }
}
1219
/// Writes `df` to `path` as parquet.
///
/// `ParquetWriter::finish` takes `&mut DataFrame`, so the borrowed frame is
/// cloned into a scratch copy first; the caller's frame is left untouched.
fn write_parquet(df: &DataFrame, path: &Path) -> Result<()> {
    let mut file = File::create(path)?;
    let mut owned = df.clone();
    ParquetWriter::new(&mut file).finish(&mut owned)?;
    Ok(())
}
1226
1227fn write_json<T: Serialize>(path: &Path, value: &T) -> Result<()> {
1228 let content = serde_json::to_string_pretty(value)?;
1229 let mut file = File::create(path)?;
1230 file.write_all(content.as_bytes())?;
1231 Ok(())
1232}
1233
1234fn read_latest_split_from_ctx(ctx: &dyn Context) -> Result<DatasetSplit> {
1235 let facts = ctx.get(ContextKey::Signals);
1236 let mut latest: Option<DatasetSplit> = None;
1237 for fact in facts {
1238 if let Ok(split) = serde_json::from_str::<DatasetSplit>(&fact.content) {
1239 let should_replace = match &latest {
1240 Some(current) => split.iteration > current.iteration,
1241 None => true,
1242 };
1243 if should_replace {
1244 latest = Some(split);
1245 }
1246 }
1247 }
1248 latest.ok_or_else(|| anyhow!("missing dataset split"))
1249}
1250
1251fn read_model_path_from_ctx(ctx: &dyn Context) -> Result<String> {
1252 let meta = read_latest_model_meta_from_ctx(ctx)?;
1253 Ok(meta.model_path)
1254}
1255
1256fn read_model_from_ctx(ctx: &dyn Context) -> Result<BaselineModel> {
1257 let model_path = read_model_path_from_ctx(ctx)?;
1258 let content = std::fs::read_to_string(model_path)?;
1259 let model = serde_json::from_str(&content)?;
1260 Ok(model)
1261}
1262
1263fn read_latest_model_meta_from_ctx(ctx: &dyn Context) -> Result<ModelMetadata> {
1264 let facts = ctx.get(ContextKey::Strategies);
1265 let mut latest: Option<ModelMetadata> = None;
1266 for fact in facts {
1267 if let Ok(meta) = serde_json::from_str::<ModelMetadata>(&fact.content) {
1268 let should_replace = match &latest {
1269 Some(current) => meta.iteration > current.iteration,
1270 None => true,
1271 };
1272 if should_replace {
1273 latest = Some(meta);
1274 }
1275 }
1276 }
1277 latest.ok_or_else(|| anyhow!("missing model metadata"))
1278}
1279
1280fn read_latest_plan_from_ctx(ctx: &dyn Context) -> Option<TrainingPlan> {
1281 let facts = ctx.get(ContextKey::Constraints);
1282 let mut latest: Option<TrainingPlan> = None;
1283 for fact in facts {
1284 if let Ok(plan) = serde_json::from_str::<TrainingPlan>(&fact.content) {
1285 let should_replace = match &latest {
1286 Some(current) => plan.iteration > current.iteration,
1287 None => true,
1288 };
1289 if should_replace {
1290 latest = Some(plan);
1291 }
1292 }
1293 }
1294 latest
1295}
1296
1297fn has_split_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1298 ctx.get(ContextKey::Signals).iter().any(|fact| {
1299 serde_json::from_str::<DatasetSplit>(&fact.content)
1300 .map(|split| split.iteration == iteration)
1301 .unwrap_or(false)
1302 })
1303}
1304
1305fn has_model_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1306 ctx.get(ContextKey::Strategies).iter().any(|fact| {
1307 serde_json::from_str::<ModelMetadata>(&fact.content)
1308 .map(|meta| meta.iteration == iteration)
1309 .unwrap_or(false)
1310 })
1311}
1312
1313fn has_evaluation_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1314 ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1315 serde_json::from_str::<EvaluationReport>(&fact.content)
1316 .map(|report| report.iteration == iteration)
1317 .unwrap_or(false)
1318 })
1319}
1320
1321fn has_inference_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1322 ctx.get(ContextKey::Hypotheses).iter().any(|fact| {
1323 serde_json::from_str::<InferenceSample>(&fact.content)
1324 .map(|sample| sample.iteration == iteration)
1325 .unwrap_or(false)
1326 })
1327}
1328
1329fn has_data_quality_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1330 ctx.get(ContextKey::Signals).iter().any(|fact| {
1331 serde_json::from_str::<DataQualityReport>(&fact.content)
1332 .map(|report| report.iteration == iteration)
1333 .unwrap_or(false)
1334 })
1335}
1336
1337fn has_feature_spec_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1338 ctx.get(ContextKey::Constraints).iter().any(|fact| {
1339 serde_json::from_str::<FeatureSpec>(&fact.content)
1340 .map(|spec| spec.iteration == iteration)
1341 .unwrap_or(false)
1342 })
1343}
1344
1345fn has_hyperparam_result_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1346 ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1347 serde_json::from_str::<HyperparameterSearchResult>(&fact.content)
1348 .map(|result| result.iteration == iteration)
1349 .unwrap_or(false)
1350 })
1351}
1352
1353fn has_registry_record_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1354 ctx.get(ContextKey::Strategies).iter().any(|fact| {
1355 serde_json::from_str::<ModelRegistryRecord>(&fact.content)
1356 .map(|record| record.iteration == iteration)
1357 .unwrap_or(false)
1358 })
1359}
1360
1361fn has_monitoring_report_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1362 ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1363 serde_json::from_str::<MonitoringReport>(&fact.content)
1364 .map(|report| report.iteration == iteration)
1365 .unwrap_or(false)
1366 })
1367}
1368
1369fn has_deployment_decision_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1370 ctx.get(ContextKey::Strategies).iter().any(|fact| {
1371 serde_json::from_str::<DeploymentDecision>(&fact.content)
1372 .map(|decision| decision.iteration == iteration)
1373 .unwrap_or(false)
1374 })
1375}
1376
1377fn latest_evaluation_report(ctx: &dyn Context, iteration: usize) -> Option<EvaluationReport> {
1378 let mut latest: Option<EvaluationReport> = None;
1379 for fact in ctx.get(ContextKey::Evaluations) {
1380 if let Ok(report) = serde_json::from_str::<EvaluationReport>(&fact.content) {
1381 if iteration > 0 {
1382 if report.iteration == iteration {
1383 return Some(report);
1384 }
1385 } else if latest
1386 .as_ref()
1387 .map(|current| report.iteration > current.iteration)
1388 .unwrap_or(true)
1389 {
1390 latest = Some(report);
1391 }
1392 }
1393 }
1394 if iteration > 0 { None } else { latest }
1395}
1396
1397fn latest_data_quality_before_iteration(
1398 ctx: &dyn Context,
1399 iteration: usize,
1400) -> Option<DataQualityReport> {
1401 let mut latest: Option<DataQualityReport> = None;
1402 for fact in ctx.get(ContextKey::Signals) {
1403 if let Ok(report) = serde_json::from_str::<DataQualityReport>(&fact.content) {
1404 if report.iteration < iteration
1405 && latest
1406 .as_ref()
1407 .map(|current| report.iteration > current.iteration)
1408 .unwrap_or(true)
1409 {
1410 latest = Some(report);
1411 }
1412 }
1413 }
1414 latest
1415}
1416
1417fn drift_score_from_ctx(
1418 ctx: &dyn Context,
1419 iteration: usize,
1420 numeric_means: &HashMap<String, f64>,
1421) -> Option<f64> {
1422 let previous = latest_data_quality_before_iteration(ctx, iteration)?;
1423 let mut total_delta = 0.0;
1424 let mut count = 0usize;
1425 for (name, mean) in numeric_means {
1426 if let Some(prev_mean) = previous.numeric_means.get(name) {
1427 total_delta += (mean - prev_mean).abs();
1428 count += 1;
1429 }
1430 }
1431 if count == 0 {
1432 None
1433 } else {
1434 Some(total_delta / count as f64)
1435 }
1436}
1437
1438fn compute_numeric_stats(series: &Series) -> Result<(f64, f64, usize)> {
1439 let casted = series.cast(&DataType::Float64)?;
1440 let values: Vec<f64> = casted
1441 .f64()
1442 .context("numeric series not f64")?
1443 .into_no_null_iter()
1444 .collect();
1445 if values.is_empty() {
1446 return Err(anyhow!("no numeric values to compute stats"));
1447 }
1448
1449 let mut total = 0.0;
1450 for value in &values {
1451 total += *value;
1452 }
1453 let mean = total / values.len() as f64;
1454
1455 let mut variance_sum = 0.0;
1456 for value in &values {
1457 let diff = *value - mean;
1458 variance_sum += diff * diff;
1459 }
1460 let std = (variance_sum / values.len() as f64).sqrt();
1461
1462 let outliers = if std > 0.0 {
1463 values
1464 .iter()
1465 .filter(|value| (*value - mean).abs() > 3.0 * std)
1466 .count()
1467 } else {
1468 0
1469 };
1470
1471 Ok((mean, std, outliers))
1472}
1473
1474fn split_feature_columns(df: &DataFrame, target: &str) -> (Vec<String>, Vec<String>) {
1475 let mut numeric = Vec::new();
1476 let mut categorical = Vec::new();
1477 for series in df.get_columns() {
1478 let name = series.name();
1479 if name == target {
1480 continue;
1481 }
1482 if is_numeric_dtype(series.dtype()) {
1483 numeric.push(name.to_string());
1484 } else {
1485 categorical.push(name.to_string());
1486 }
1487 }
1488 (numeric, categorical)
1489}
1490
1491fn select_target_column(df: &DataFrame) -> Result<(String, Series)> {
1492 if let Ok(col) = df.column(TARGET_COLUMN) {
1493 return Ok((
1494 TARGET_COLUMN.to_string(),
1495 col.as_materialized_series().clone(),
1496 ));
1497 }
1498
1499 let mut numeric = df
1500 .get_columns()
1501 .iter()
1502 .filter(|series| is_numeric_dtype(series.dtype()))
1503 .cloned()
1504 .collect::<Vec<_>>();
1505
1506 let fallback = numeric
1507 .pop()
1508 .ok_or_else(|| anyhow!("no numeric columns available for target"))?;
1509 let series = fallback.as_materialized_series().clone();
1510 Ok((series.name().to_string(), series))
1511}
1512
1513fn get_numeric_series(df: &DataFrame, name: &str) -> Result<Series> {
1514 let series = df
1515 .column(name)
1516 .map_err(|_| anyhow!("missing target column {}", name))?
1517 .as_materialized_series();
1518 let casted = series.cast(&DataType::Float64)?;
1519 Ok(casted)
1520}
1521
1522fn mean_of_series(series: &Series) -> Result<f64> {
1523 let casted = series.cast(&DataType::Float64)?;
1524 let values = casted
1525 .f64()
1526 .context("target column not f64")?
1527 .into_no_null_iter();
1528 let mut total = 0.0;
1529 let mut count = 0usize;
1530 for value in values {
1531 total += value;
1532 count += 1;
1533 }
1534 if count == 0 {
1535 return Err(anyhow!("no values to compute mean"));
1536 }
1537 Ok(total / count as f64)
1538}
1539
1540fn mean_abs_error(target: &Series, mean: f64) -> Result<f64> {
1541 let casted = target.cast(&DataType::Float64)?;
1542 let values = casted
1543 .f64()
1544 .context("target column not f64")?
1545 .into_no_null_iter();
1546 let mut total = 0.0;
1547 let mut count = 0usize;
1548 for value in values {
1549 total += (value - mean).abs();
1550 count += 1;
1551 }
1552 if count == 0 {
1553 return Err(anyhow!("no values to evaluate"));
1554 }
1555 Ok(total / count as f64)
1556}
1557
1558fn mean_abs_value(target: &Series) -> Result<f64> {
1559 let casted = target.cast(&DataType::Float64)?;
1560 let values = casted
1561 .f64()
1562 .context("target column not f64")?
1563 .into_no_null_iter();
1564 let mut total = 0.0;
1565 let mut count = 0usize;
1566 for value in values {
1567 total += value.abs();
1568 count += 1;
1569 }
1570 if count == 0 {
1571 return Err(anyhow!("no values to evaluate"));
1572 }
1573 Ok(total / count as f64)
1574}
1575fn is_numeric_dtype(dtype: &DataType) -> bool {
1576 matches!(
1577 dtype,
1578 DataType::Int8
1579 | DataType::Int16
1580 | DataType::Int32
1581 | DataType::Int64
1582 | DataType::UInt8
1583 | DataType::UInt16
1584 | DataType::UInt32
1585 | DataType::UInt64
1586 | DataType::Float32
1587 | DataType::Float64
1588 )
1589}
1590
1591fn read_feature_spec_from_ctx(ctx: &dyn Context, iteration: usize) -> Option<FeatureSpec> {
1593 ctx.get(ContextKey::Constraints).iter().find_map(|fact| {
1594 serde_json::from_str::<FeatureSpec>(&fact.content)
1595 .ok()
1596 .filter(|spec| spec.iteration == iteration)
1597 })
1598}
1599
/// Apply `spec` to `df`: append one column per configured interaction,
/// then (when `normalization == "standardize"`) standardize each listed
/// numeric feature column to zero mean / unit std in place.
///
/// Supported interaction ops: "multiply", "add", "subtract", "divide".
/// Divide yields null where the divisor is within 1e-10 of zero, to
/// avoid producing inf/NaN. Unknown ops and missing columns are errors.
///
/// NOTE(review): standardization drops the column and re-appends it via
/// hstack, so standardized columns move to the end of the frame —
/// confirm no caller depends on column order.
pub fn apply_feature_spec(df: &DataFrame, spec: &FeatureSpec) -> Result<DataFrame> {
    let mut result = df.clone();

    for interaction in &spec.interactions {
        // Cast both operands to f64 so mixed int/float columns combine.
        let left_col = result
            .column(&interaction.left)
            .map_err(|_| anyhow!("missing column {} for interaction", interaction.left))?
            .cast(&DataType::Float64)?;
        let right_col = result
            .column(&interaction.right)
            .map_err(|_| anyhow!("missing column {} for interaction", interaction.right))?
            .cast(&DataType::Float64)?;

        let left_vals = left_col.f64().context("left column not f64")?;
        let right_vals = right_col.f64().context("right column not f64")?;

        let interaction_series = match interaction.op.as_str() {
            "multiply" => left_vals * right_vals,
            "add" => left_vals + right_vals,
            "subtract" => left_vals - right_vals,
            "divide" => {
                // Element-wise division with a null result wherever the
                // divisor is (near-)zero or either side is null.
                left_vals
                    .into_iter()
                    .zip(right_vals.into_iter())
                    .map(|(l, r)| match (l, r) {
                        (Some(lv), Some(rv)) if rv.abs() > 1e-10 => Some(lv / rv),
                        _ => None,
                    })
                    .collect::<Float64Chunked>()
            }
            _ => return Err(anyhow!("unsupported interaction op: {}", interaction.op)),
        };

        let named_series = interaction_series.with_name(interaction.name.clone().into());
        result = result
            .hstack(&[named_series.into_series().into()])
            .context("failed to add interaction column")?;
    }

    if spec.normalization == "standardize" {
        for col_name in &spec.numeric_features {
            // Columns listed in the spec but absent from the frame are
            // silently skipped.
            if let Ok(col) = result.column(col_name) {
                let casted = col.cast(&DataType::Float64)?;
                let values = casted.f64().context("column not f64")?;

                let (mean, std) = compute_mean_std(values)?;

                // Constant columns (std == 0) are left untouched.
                if std > 0.0 {
                    let standardized = (values - mean) / std;
                    let named = standardized.with_name(col_name.clone().into());

                    result = result.drop(col_name)?;
                    result = result
                        .hstack(&[named.into_series().into()])
                        .context("failed to replace standardized column")?;
                }
            }
        }
    }

    Ok(result)
}
1669
1670fn compute_mean_std(values: &ChunkedArray<Float64Type>) -> Result<(f64, f64)> {
1671 let vals: Vec<f64> = values.into_no_null_iter().collect();
1672 if vals.is_empty() {
1673 return Err(anyhow!("no values for mean/std computation"));
1674 }
1675
1676 let mean = vals.iter().sum::<f64>() / vals.len() as f64;
1677 let variance = vals.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / vals.len() as f64;
1678 let std = variance.sqrt();
1679
1680 Ok((mean, std))
1681}
1682
// Unit tests for the pipeline's data types and pure helper functions.
// Grouped roughly as: serde roundtrips / struct construction, agent
// constructors, and numeric/DataFrame helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // --- proposal helper and serde roundtrips ---

    #[test]
    fn proposal_helper_builds_correct_fact() {
        let p = proposal("my-agent", ContextKey::Diagnostic, "id-1", "content-1");
        assert_eq!(p.provenance, "my-agent");
        assert_eq!(p.key, ContextKey::Diagnostic);
        assert_eq!(p.id, "id-1");
        assert_eq!(p.content, "content-1");
    }

    #[test]
    fn training_plan_serde_roundtrip() {
        let plan = TrainingPlan {
            iteration: 3,
            max_rows: 1000,
            train_fraction: 0.7,
            val_fraction: 0.2,
            infer_fraction: 0.1,
            quality_threshold: 0.8,
        };
        let json = serde_json::to_string(&plan).unwrap();
        let restored: TrainingPlan = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.iteration, 3);
        assert_eq!(restored.max_rows, 1000);
        assert!((restored.train_fraction - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn dataset_split_serde_roundtrip() {
        let split = DatasetSplit {
            source_path: "/data/src.parquet".into(),
            train_path: "/data/train.parquet".into(),
            val_path: "/data/val.parquet".into(),
            infer_path: "/data/infer.parquet".into(),
            total_rows: 1000,
            max_rows: 800,
            train_rows: 640,
            val_rows: 120,
            infer_rows: 40,
            iteration: 1,
        };
        let json = serde_json::to_string(&split).unwrap();
        let restored: DatasetSplit = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.total_rows, 1000);
        // The three split partitions must add up to max_rows.
        assert_eq!(
            restored.train_rows + restored.val_rows + restored.infer_rows,
            800
        );
    }

    #[test]
    fn baseline_model_serde_roundtrip() {
        let model = BaselineModel {
            target_column: "price".into(),
            mean: 42.5,
        };
        let json = serde_json::to_string(&model).unwrap();
        let restored: BaselineModel = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.target_column, "price");
        assert!((restored.mean - 42.5).abs() < f64::EPSILON);
    }

    #[test]
    fn evaluation_report_success_ratio_bounds() {
        let report = EvaluationReport {
            model_path: "/model".into(),
            metric: "mae".into(),
            value: 10.0,
            mean_abs_target: 100.0,
            success_ratio: 0.9,
            val_rows: 50,
            iteration: 1,
        };
        assert!(report.success_ratio >= 0.0 && report.success_ratio <= 1.0);
    }

    #[test]
    fn feature_interaction_construction() {
        let fi = FeatureInteraction {
            name: "a_x_b".into(),
            left: "a".into(),
            right: "b".into(),
            op: "multiply".into(),
        };
        assert_eq!(fi.name, "a_x_b");
        assert_eq!(fi.op, "multiply");
    }

    #[test]
    fn feature_spec_serde_roundtrip() {
        let spec = FeatureSpec {
            kind: "feature_spec".into(),
            iteration: 2,
            target_column: "target".into(),
            numeric_features: vec!["a".into(), "b".into()],
            categorical_features: vec!["c".into()],
            normalization: "standardize".into(),
            interactions: vec![],
        };
        let json = serde_json::to_string(&spec).unwrap();
        let restored: FeatureSpec = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.numeric_features.len(), 2);
        assert_eq!(restored.categorical_features.len(), 1);
    }

    #[test]
    fn hyperparam_search_plan_construction() {
        let mut params = HashMap::new();
        params.insert("lr".to_string(), vec![0.001, 0.01]);
        let plan = HyperparameterSearchPlan {
            kind: "hyperparam_plan".into(),
            iteration: 1,
            max_trials: 10,
            early_stopping: true,
            params,
        };
        assert_eq!(plan.max_trials, 10);
        assert!(plan.early_stopping);
        assert_eq!(plan.params["lr"].len(), 2);
    }

    #[test]
    fn hyperparam_search_result_serde_roundtrip() {
        let mut best = HashMap::new();
        best.insert("lr".to_string(), 0.01);
        let result = HyperparameterSearchResult {
            kind: "hyperparam_result".into(),
            iteration: 1,
            best_params: best,
            score: 0.85,
        };
        let json = serde_json::to_string(&result).unwrap();
        let restored: HyperparameterSearchResult = serde_json::from_str(&json).unwrap();
        assert!((restored.score - 0.85).abs() < f64::EPSILON);
        assert!((restored.best_params["lr"] - 0.01).abs() < f64::EPSILON);
    }

    #[test]
    fn model_registry_record_construction() {
        let mut metrics = HashMap::new();
        metrics.insert("mae".to_string(), 5.0);
        let record = ModelRegistryRecord {
            kind: "model_registry".into(),
            iteration: 1,
            model_path: "/models/v1.json".into(),
            metrics,
            provenance: "test".into(),
        };
        assert_eq!(record.metrics["mae"], 5.0);
    }

    #[test]
    fn monitoring_report_status_values() {
        let healthy = MonitoringReport {
            kind: "monitoring".into(),
            iteration: 1,
            metric: "mae".into(),
            value: 5.0,
            baseline: 100.0,
            status: "healthy".into(),
        };
        assert_eq!(healthy.status, "healthy");

        let needs_attention = MonitoringReport {
            status: "needs_attention".into(),
            ..healthy.clone()
        };
        assert_eq!(needs_attention.status, "needs_attention");
    }

    #[test]
    fn deployment_decision_deploy_vs_hold() {
        let deploy = DeploymentDecision {
            kind: "deployment_decision".into(),
            iteration: 1,
            action: "deploy".into(),
            reason: "meets threshold".into(),
            retrain: false,
        };
        assert!(!deploy.retrain);

        let hold = DeploymentDecision {
            action: "hold".into(),
            retrain: true,
            ..deploy.clone()
        };
        assert!(hold.retrain);
        assert_eq!(hold.action, "hold");
    }

    // --- agent constructors: Default must agree with new() ---

    #[test]
    fn data_validation_agent_default() {
        let agent = DataValidationAgent::default();
        let agent2 = DataValidationAgent::new();
        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
    }

    #[test]
    fn feature_engineering_agent_default() {
        let agent = FeatureEngineeringAgent::default();
        let agent2 = FeatureEngineeringAgent::new();
        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
    }

    #[test]
    fn model_evaluation_agent_default() {
        let agent = ModelEvaluationAgent::default();
        let agent2 = ModelEvaluationAgent::new();
        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
    }

    #[test]
    fn model_registry_agent_default() {
        let agent = ModelRegistryAgent::default();
        let agent2 = ModelRegistryAgent::new();
        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
    }

    #[test]
    fn monitoring_agent_default() {
        let agent = MonitoringAgent::default();
        let agent2 = MonitoringAgent::new();
        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
    }

    #[test]
    fn deployment_agent_default() {
        let agent = DeploymentAgent::default();
        let agent2 = DeploymentAgent::new();
        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
    }

    #[test]
    fn dataset_agent_paths() {
        let agent = DatasetAgent::new(PathBuf::from("/tmp/data"));
        assert_eq!(
            agent.dataset_path(),
            PathBuf::from("/tmp/data/california_housing_train.parquet")
        );
        let (train, val, infer) = agent.split_paths();
        assert_eq!(train, PathBuf::from("/tmp/data/train.parquet"));
        assert_eq!(val, PathBuf::from("/tmp/data/val.parquet"));
        assert_eq!(infer, PathBuf::from("/tmp/data/infer.parquet"));
    }

    #[test]
    fn model_training_agent_model_path() {
        let agent = ModelTrainingAgent::new(PathBuf::from("/tmp/models"));
        assert_eq!(
            agent.model_path(),
            PathBuf::from("/tmp/models/baseline_mean.json")
        );
    }

    #[test]
    fn sample_inference_agent_construction() {
        let agent = SampleInferenceAgent::new(100);
        assert_eq!(agent.max_rows, 100);
    }

    #[test]
    fn hyperparameter_search_agent_construction() {
        let agent = HyperparameterSearchAgent::new(50);
        assert_eq!(agent.max_trials, 50);
    }

    // --- numeric and DataFrame helpers ---

    #[test]
    fn is_numeric_dtype_comprehensive() {
        assert!(is_numeric_dtype(&DataType::Float32));
        assert!(is_numeric_dtype(&DataType::Float64));
        assert!(is_numeric_dtype(&DataType::Int64));
        assert!(!is_numeric_dtype(&DataType::String));
        assert!(!is_numeric_dtype(&DataType::Boolean));
    }

    #[test]
    fn split_feature_columns_separates_types() {
        let df = df! {
            "num1" => [1.0, 2.0],
            "num2" => [3i32, 4],
            "cat1" => ["a", "b"],
            "target" => [10.0, 20.0]
        }
        .unwrap();
        let (numeric, categorical) = split_feature_columns(&df, "target");
        assert!(numeric.contains(&"num1".to_string()));
        assert!(numeric.contains(&"num2".to_string()));
        assert!(categorical.contains(&"cat1".to_string()));
        // The target column must appear in neither feature list.
        assert!(!numeric.contains(&"target".to_string()));
        assert!(!categorical.contains(&"target".to_string()));
    }

    #[test]
    fn select_target_column_prefers_named_target() {
        let df = df! {
            "x" => [1.0, 2.0],
            "median_house_value" => [100.0, 200.0]
        }
        .unwrap();
        let (name, _) = select_target_column(&df).unwrap();
        assert_eq!(name, "median_house_value");
    }

    #[test]
    fn select_target_column_falls_back_to_last_numeric() {
        let df = df! {
            "a" => [1.0, 2.0],
            "b" => [3.0, 4.0]
        }
        .unwrap();
        let (name, _) = select_target_column(&df).unwrap();
        assert_eq!(name, "b");
    }

    #[test]
    fn select_target_column_fails_with_no_numeric() {
        let df = df! {
            "text" => ["a", "b"]
        }
        .unwrap();
        assert!(select_target_column(&df).is_err());
    }

    #[test]
    fn mean_of_series_computes_correctly() {
        let series = Series::new("v".into(), &[2.0f64, 4.0, 6.0]);
        let mean = mean_of_series(&series).unwrap();
        assert!((mean - 4.0).abs() < f64::EPSILON);
    }

    #[test]
    fn mean_abs_error_computes_correctly() {
        let series = Series::new("v".into(), &[10.0f64, 20.0, 30.0]);
        let mae = mean_abs_error(&series, 20.0).unwrap();
        assert!((mae - 20.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn mean_abs_value_computes_correctly() {
        let series = Series::new("v".into(), &[-5.0f64, 5.0, 10.0]);
        let mav = mean_abs_value(&series).unwrap();
        assert!((mav - 20.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn compute_numeric_stats_basic() {
        let series = Series::new("v".into(), &[2.0f64, 4.0, 6.0]);
        let casted = series.cast(&DataType::Float64).unwrap();
        let (mean, std, outliers) = compute_numeric_stats(&casted).unwrap();
        assert!((mean - 4.0).abs() < 1e-10);
        assert!(std > 0.0);
        assert_eq!(outliers, 0);
    }

    #[test]
    fn compute_numeric_stats_empty_fails() {
        let series = Series::new("v".into(), Vec::<f64>::new());
        assert!(compute_numeric_stats(&series).is_err());
    }

    #[test]
    fn compute_numeric_stats_constant_series() {
        // Constant input: std is exactly 0 and no outliers are counted.
        let series = Series::new("v".into(), &[5.0f64, 5.0, 5.0]);
        let (mean, std, outliers) = compute_numeric_stats(&series).unwrap();
        assert!((mean - 5.0).abs() < f64::EPSILON);
        assert!((std - 0.0).abs() < f64::EPSILON);
        assert_eq!(outliers, 0);
    }

    #[test]
    fn data_quality_report_serde_roundtrip() {
        let mut missingness = HashMap::new();
        missingness.insert("a".to_string(), 0.1);
        let report = DataQualityReport {
            kind: "data_quality".into(),
            iteration: 1,
            source_path: "/data/train.parquet".into(),
            rows_checked: 100,
            missingness,
            numeric_means: HashMap::new(),
            outlier_counts: HashMap::new(),
            drift_score: Some(0.05),
        };
        let json = serde_json::to_string(&report).unwrap();
        let restored: DataQualityReport = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.rows_checked, 100);
        assert!((restored.drift_score.unwrap() - 0.05).abs() < f64::EPSILON);
    }

    #[test]
    fn inference_sample_serde_roundtrip() {
        let sample = InferenceSample {
            model_path: "/model".into(),
            target_column: "target".into(),
            rows: 3,
            predictions: vec![1.0, 2.0, 3.0],
            actuals: vec![1.1, 2.1, 3.1],
            iteration: 1,
        };
        let json = serde_json::to_string(&sample).unwrap();
        let restored: InferenceSample = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.predictions.len(), 3);
        assert_eq!(restored.actuals.len(), 3);
    }

    // --- apply_feature_spec behavior ---

    #[test]
    fn apply_feature_spec_creates_interaction_column() {
        let df = df! {
            "a" => [1.0, 2.0, 3.0],
            "b" => [4.0, 5.0, 6.0],
            "target" => [10.0, 20.0, 30.0]
        }
        .unwrap();

        let spec = FeatureSpec {
            kind: "feature_spec".to_string(),
            iteration: 1,
            target_column: "target".to_string(),
            numeric_features: vec!["a".to_string(), "b".to_string()],
            categorical_features: vec![],
            normalization: "none".to_string(),
            interactions: vec![FeatureInteraction {
                name: "a_x_b".to_string(),
                left: "a".to_string(),
                right: "b".to_string(),
                op: "multiply".to_string(),
            }],
        };

        let result = apply_feature_spec(&df, &spec).unwrap();

        assert!(result.column("a_x_b").is_ok());

        let interaction = result.column("a_x_b").unwrap().f64().unwrap();
        let values: Vec<f64> = interaction.into_no_null_iter().collect();
        assert_eq!(values, vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn apply_feature_spec_standardizes_numeric_features() {
        let df = df! {
            "a" => [1.0, 2.0, 3.0, 4.0, 5.0],
            "target" => [10.0, 20.0, 30.0, 40.0, 50.0]
        }
        .unwrap();

        let spec = FeatureSpec {
            kind: "feature_spec".to_string(),
            iteration: 1,
            target_column: "target".to_string(),
            numeric_features: vec!["a".to_string()],
            categorical_features: vec![],
            normalization: "standardize".to_string(),
            interactions: vec![],
        };

        let result = apply_feature_spec(&df, &spec).unwrap();

        let a_col = result.column("a").unwrap().f64().unwrap();
        let values: Vec<f64> = a_col.into_no_null_iter().collect();

        // Standardized column must have ~zero mean and ~unit std.
        let mean: f64 = values.iter().sum::<f64>() / values.len() as f64;
        assert!(mean.abs() < 1e-10, "mean should be ~0, got {}", mean);

        let variance: f64 =
            values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
        let std = variance.sqrt();
        assert!((std - 1.0).abs() < 1e-10, "std should be ~1, got {}", std);
    }

    #[test]
    fn apply_feature_spec_handles_add_operation() {
        let df = df! {
            "a" => [1.0, 2.0, 3.0],
            "b" => [10.0, 20.0, 30.0]
        }
        .unwrap();

        let spec = FeatureSpec {
            kind: "feature_spec".to_string(),
            iteration: 1,
            target_column: "target".to_string(),
            numeric_features: vec![],
            categorical_features: vec![],
            normalization: "none".to_string(),
            interactions: vec![FeatureInteraction {
                name: "a_plus_b".to_string(),
                left: "a".to_string(),
                right: "b".to_string(),
                op: "add".to_string(),
            }],
        };

        let result = apply_feature_spec(&df, &spec).unwrap();
        let col = result.column("a_plus_b").unwrap().f64().unwrap();
        let values: Vec<f64> = col.into_no_null_iter().collect();
        assert_eq!(values, vec![11.0, 22.0, 33.0]);
    }
}