Skip to main content

converge_analytics/
training.rs

1// Copyright 2024-2026 Reflective Labs
2
3use anyhow::{Context as _, Result, anyhow};
4use converge_pack::{AgentEffect, Context, ContextKey, ProposalId, ProposedFact, Suggestor};
5use polars::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fs::{File, create_dir_all};
9use std::io::Write;
10use std::path::{Path, PathBuf};
11
/// Remote location of the California housing training split, fetched from the
/// Hugging Face Parquet-conversion branch (`refs/convert/parquet`) as a single file.
const DATASET_URL: &str = "https://huggingface.co/datasets/gvlassis/california_housing/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet";

/// Name of the dataset column holding the house-value target (presumably the
/// column `select_target_column` prefers — that helper is not visible here).
const TARGET_COLUMN: &str = "median_house_value";
14
15fn proposal(
16    provenance: &str,
17    key: ContextKey,
18    id: impl Into<String>,
19    content: impl Into<String>,
20) -> ProposedFact {
21    ProposedFact::new(key, ProposalId::new(id.into()), content, provenance)
22}
23
24#[derive(Debug, Serialize, Deserialize, Clone)]
25pub struct TrainingPlan {
26    pub iteration: usize,
27    pub max_rows: usize,
28    pub train_fraction: f64,
29    pub val_fraction: f64,
30    pub infer_fraction: f64,
31    pub quality_threshold: f64,
32}
33
34#[derive(Debug, Serialize, Deserialize, Clone)]
35pub struct DatasetSplit {
36    pub source_path: String,
37    pub train_path: String,
38    pub val_path: String,
39    pub infer_path: String,
40    pub total_rows: usize,
41    pub max_rows: usize,
42    pub train_rows: usize,
43    pub val_rows: usize,
44    pub infer_rows: usize,
45    pub iteration: usize,
46}
47
48#[derive(Debug, Serialize, Deserialize, Clone)]
49pub struct BaselineModel {
50    pub target_column: String,
51    pub mean: f64,
52}
53
54#[derive(Debug, Serialize, Deserialize, Clone)]
55pub struct ModelMetadata {
56    pub model_path: String,
57    pub target_column: String,
58    pub train_rows: usize,
59    pub baseline_mean: f64,
60    pub iteration: usize,
61}
62
63#[derive(Debug, Serialize, Deserialize, Clone)]
64pub struct EvaluationReport {
65    pub model_path: String,
66    pub metric: String,
67    pub value: f64,
68    pub mean_abs_target: f64,
69    pub success_ratio: f64,
70    pub val_rows: usize,
71    pub iteration: usize,
72}
73
74#[derive(Debug, Serialize, Deserialize, Clone)]
75pub struct InferenceSample {
76    pub model_path: String,
77    pub target_column: String,
78    pub rows: usize,
79    pub predictions: Vec<f64>,
80    pub actuals: Vec<f64>,
81    pub iteration: usize,
82}
83
84#[derive(Debug, Serialize, Deserialize, Clone)]
85pub struct DataQualityReport {
86    pub kind: String,
87    pub iteration: usize,
88    pub source_path: String,
89    pub rows_checked: usize,
90    pub missingness: HashMap<String, f64>,
91    pub numeric_means: HashMap<String, f64>,
92    pub outlier_counts: HashMap<String, usize>,
93    pub drift_score: Option<f64>,
94}
95
96#[derive(Debug, Serialize, Deserialize, Clone)]
97pub struct FeatureInteraction {
98    pub name: String,
99    pub left: String,
100    pub right: String,
101    pub op: String,
102}
103
104#[derive(Debug, Serialize, Deserialize, Clone)]
105pub struct FeatureSpec {
106    pub kind: String,
107    pub iteration: usize,
108    pub target_column: String,
109    pub numeric_features: Vec<String>,
110    pub categorical_features: Vec<String>,
111    pub normalization: String,
112    pub interactions: Vec<FeatureInteraction>,
113}
114
115#[derive(Debug, Serialize, Deserialize, Clone)]
116pub struct HyperparameterSearchPlan {
117    pub kind: String,
118    pub iteration: usize,
119    pub max_trials: usize,
120    pub early_stopping: bool,
121    pub params: HashMap<String, Vec<f64>>,
122}
123
124#[derive(Debug, Serialize, Deserialize, Clone)]
125pub struct HyperparameterSearchResult {
126    pub kind: String,
127    pub iteration: usize,
128    pub best_params: HashMap<String, f64>,
129    pub score: f64,
130}
131
132#[derive(Debug, Serialize, Deserialize, Clone)]
133pub struct ModelRegistryRecord {
134    pub kind: String,
135    pub iteration: usize,
136    pub model_path: String,
137    pub metrics: HashMap<String, f64>,
138    pub provenance: String,
139}
140
141#[derive(Debug, Serialize, Deserialize, Clone)]
142pub struct MonitoringReport {
143    pub kind: String,
144    pub iteration: usize,
145    pub metric: String,
146    pub value: f64,
147    pub baseline: f64,
148    pub status: String,
149}
150
151#[derive(Debug, Serialize, Deserialize, Clone)]
152pub struct DeploymentDecision {
153    pub kind: String,
154    pub iteration: usize,
155    pub action: String,
156    pub reason: String,
157    pub retrain: bool,
158}
159
/// Fetches the California housing dataset and writes train / validation /
/// inference Parquet partitions under `data_dir`.
#[derive(Debug)]
pub struct DatasetAgent {
    data_dir: PathBuf,
}

impl DatasetAgent {
    /// Creates an agent that keeps the raw dataset and its partitions in `data_dir`.
    pub fn new(data_dir: PathBuf) -> Self {
        Self { data_dir }
    }

    /// Local path of the downloaded source dataset.
    fn dataset_path(&self) -> PathBuf {
        self.data_dir.join("california_housing_train.parquet")
    }

    /// Output paths for the `(train, validation, inference)` partitions.
    fn split_paths(&self) -> (PathBuf, PathBuf, PathBuf) {
        let [train, val, infer] = ["train.parquet", "val.parquet", "infer.parquet"]
            .map(|file_name| self.data_dir.join(file_name));
        (train, val, infer)
    }
}
182
/// Profiles the training partition of the latest dataset split and publishes
/// a [`DataQualityReport`].
#[derive(Default, Debug)]
pub struct DataValidationAgent;

impl DataValidationAgent {
    /// Creates the (stateless) validation agent.
    pub fn new() -> Self {
        Self::default()
    }
}
191
192#[async_trait::async_trait]
193impl Suggestor for DataValidationAgent {
194    fn name(&self) -> &str {
195        "DataValidationAgent"
196    }
197
198    fn dependencies(&self) -> &[ContextKey] {
199        &[ContextKey::Signals]
200    }
201
202    fn accepts(&self, ctx: &dyn Context) -> bool {
203        ctx.has(ContextKey::Signals)
204            && match read_latest_split_from_ctx(ctx) {
205                Ok(split) => !has_data_quality_for_iteration(ctx, split.iteration),
206                Err(_) => false,
207            }
208    }
209
210    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
211        let split = match read_latest_split_from_ctx(ctx) {
212            Ok(split) => split,
213            Err(err) => {
214                return AgentEffect::with_proposal(proposal(
215                    self.name(),
216                    ContextKey::Diagnostic,
217                    "data-validation-error",
218                    err.to_string(),
219                ));
220            }
221        };
222
223        let df = match load_dataframe(Path::new(&split.train_path)) {
224            Ok(df) => df,
225            Err(err) => {
226                return AgentEffect::with_proposal(proposal(
227                    self.name(),
228                    ContextKey::Diagnostic,
229                    "data-validation-error",
230                    err.to_string(),
231                ));
232            }
233        };
234
235        let rows = df.height();
236        let mut missingness = HashMap::new();
237        let mut numeric_means = HashMap::new();
238        let mut outlier_counts = HashMap::new();
239
240        for series in df.get_columns() {
241            let name = series.name().to_string();
242            let null_ratio = if rows > 0 {
243                series.null_count() as f64 / rows as f64
244            } else {
245                0.0
246            };
247            missingness.insert(name.clone(), null_ratio);
248
249            if is_numeric_dtype(series.dtype()) {
250                if let Ok((mean, _std, outliers)) =
251                    compute_numeric_stats(series.as_materialized_series())
252                {
253                    numeric_means.insert(name.clone(), mean);
254                    outlier_counts.insert(name, outliers);
255                }
256            }
257        }
258
259        let drift_score = drift_score_from_ctx(ctx, split.iteration, &numeric_means);
260
261        let report = DataQualityReport {
262            kind: "data_quality".to_string(),
263            iteration: split.iteration,
264            source_path: split.train_path.clone(),
265            rows_checked: rows,
266            missingness,
267            numeric_means,
268            outlier_counts,
269            drift_score,
270        };
271
272        let content = serde_json::to_string(&report).unwrap_or_default();
273        AgentEffect::with_proposal(proposal(
274            self.name(),
275            ContextKey::Signals,
276            format!("data-quality-{}", split.iteration),
277            content,
278        ))
279    }
280}
281
/// Derives a [`FeatureSpec`] from the columns of the latest training partition.
#[derive(Default, Debug)]
pub struct FeatureEngineeringAgent;

impl FeatureEngineeringAgent {
    /// Creates the (stateless) feature-engineering agent.
    pub fn new() -> Self {
        Self::default()
    }
}
290
291#[async_trait::async_trait]
292impl Suggestor for FeatureEngineeringAgent {
293    fn name(&self) -> &str {
294        "FeatureEngineeringAgent"
295    }
296
297    fn dependencies(&self) -> &[ContextKey] {
298        &[ContextKey::Signals]
299    }
300
301    fn accepts(&self, ctx: &dyn Context) -> bool {
302        ctx.has(ContextKey::Signals)
303            && match read_latest_split_from_ctx(ctx) {
304                Ok(split) => !has_feature_spec_for_iteration(ctx, split.iteration),
305                Err(_) => false,
306            }
307    }
308
309    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
310        let split = match read_latest_split_from_ctx(ctx) {
311            Ok(split) => split,
312            Err(err) => {
313                return AgentEffect::with_proposal(proposal(
314                    self.name(),
315                    ContextKey::Diagnostic,
316                    "feature-engineering-error",
317                    err.to_string(),
318                ));
319            }
320        };
321
322        let df = match load_dataframe(Path::new(&split.train_path)) {
323            Ok(df) => df,
324            Err(err) => {
325                return AgentEffect::with_proposal(proposal(
326                    self.name(),
327                    ContextKey::Diagnostic,
328                    "feature-engineering-error",
329                    err.to_string(),
330                ));
331            }
332        };
333
334        let (target_column, _) = match select_target_column(&df) {
335            Ok(value) => value,
336            Err(err) => {
337                return AgentEffect::with_proposal(proposal(
338                    self.name(),
339                    ContextKey::Diagnostic,
340                    "feature-engineering-error",
341                    err.to_string(),
342                ));
343            }
344        };
345
346        let (numeric_features, categorical_features) = split_feature_columns(&df, &target_column);
347
348        let mut interactions = Vec::new();
349        if numeric_features.len() >= 2 {
350            interactions.push(FeatureInteraction {
351                name: format!("{}_x_{}", numeric_features[0], numeric_features[1]),
352                left: numeric_features[0].clone(),
353                right: numeric_features[1].clone(),
354                op: "multiply".to_string(),
355            });
356        }
357
358        let spec = FeatureSpec {
359            kind: "feature_spec".to_string(),
360            iteration: split.iteration,
361            target_column,
362            numeric_features,
363            categorical_features,
364            normalization: "standardize".to_string(),
365            interactions,
366        };
367
368        let content = serde_json::to_string(&spec).unwrap_or_default();
369        AgentEffect::with_proposal(proposal(
370            self.name(),
371            ContextKey::Constraints,
372            format!("feature-spec-{}", split.iteration),
373            content,
374        ))
375    }
376}
377
/// Proposes a hyperparameter search plan and result for each dataset split.
#[derive(Debug)]
pub struct HyperparameterSearchAgent {
    /// Trial budget copied into every proposed plan.
    max_trials: usize,
}

impl HyperparameterSearchAgent {
    /// Creates an agent whose plans allow at most `max_trials` trials.
    pub fn new(max_trials: usize) -> Self {
        Self { max_trials }
    }
}
388
389#[async_trait::async_trait]
390impl Suggestor for HyperparameterSearchAgent {
391    fn name(&self) -> &str {
392        "HyperparameterSearchAgent"
393    }
394
395    fn dependencies(&self) -> &[ContextKey] {
396        &[ContextKey::Constraints, ContextKey::Signals]
397    }
398
399    fn accepts(&self, ctx: &dyn Context) -> bool {
400        ctx.has(ContextKey::Signals)
401            && match read_latest_split_from_ctx(ctx) {
402                Ok(split) => !has_hyperparam_result_for_iteration(ctx, split.iteration),
403                Err(_) => false,
404            }
405    }
406
407    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
408        let split = match read_latest_split_from_ctx(ctx) {
409            Ok(split) => split,
410            Err(err) => {
411                return AgentEffect::with_proposal(proposal(
412                    self.name(),
413                    ContextKey::Diagnostic,
414                    "hyperparam-search-error",
415                    err.to_string(),
416                ));
417            }
418        };
419
420        let training_plan = read_latest_plan_from_ctx(ctx).unwrap_or(TrainingPlan {
421            iteration: split.iteration,
422            max_rows: split.max_rows,
423            train_fraction: 0.8,
424            val_fraction: 0.15,
425            infer_fraction: 0.05,
426            quality_threshold: 0.75,
427        });
428
429        let mut params = HashMap::new();
430        params.insert("learning_rate".to_string(), vec![0.001, 0.01, 0.1]);
431        params.insert("hidden_size".to_string(), vec![8.0, 16.0, 32.0]);
432
433        let plan = HyperparameterSearchPlan {
434            kind: "hyperparam_plan".to_string(),
435            iteration: split.iteration,
436            max_trials: self.max_trials,
437            early_stopping: true,
438            params,
439        };
440
441        let mut best_params = HashMap::new();
442        best_params.insert("learning_rate".to_string(), 0.01);
443        best_params.insert("hidden_size".to_string(), 16.0);
444        let score = (1.0 - training_plan.quality_threshold) * plan.max_trials as f64
445            / plan.iteration.max(1) as f64;
446        let result = HyperparameterSearchResult {
447            kind: "hyperparam_result".to_string(),
448            iteration: split.iteration,
449            best_params,
450            score,
451        };
452
453        let plan_content = serde_json::to_string(&plan).unwrap_or_default();
454        let result_content = serde_json::to_string(&result).unwrap_or_default();
455
456        let mut effect = AgentEffect::empty();
457        effect.proposals.push(proposal(
458            self.name(),
459            ContextKey::Constraints,
460            format!("hyperparam-plan-{}", split.iteration),
461            plan_content,
462        ));
463        effect.proposals.push(proposal(
464            self.name(),
465            ContextKey::Evaluations,
466            format!("hyperparam-result-{}", split.iteration),
467            result_content,
468        ));
469        effect
470    }
471}
472
473#[async_trait::async_trait]
474impl Suggestor for DatasetAgent {
475    fn name(&self) -> &str {
476        "DatasetAgent (HuggingFace)"
477    }
478
479    fn dependencies(&self) -> &[ContextKey] {
480        &[ContextKey::Seeds]
481    }
482
483    fn accepts(&self, ctx: &dyn Context) -> bool {
484        if !ctx.has(ContextKey::Seeds) {
485            return false;
486        }
487
488        let plan = read_latest_plan_from_ctx(ctx);
489        if let Some(plan) = plan {
490            return !has_split_for_iteration(ctx, plan.iteration);
491        }
492
493        !ctx.has(ContextKey::Signals)
494    }
495
496    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
497        if let Err(err) = create_dir_all(&self.data_dir) {
498            return AgentEffect::with_proposal(proposal(
499                self.name(),
500                ContextKey::Diagnostic,
501                "dataset-agent-error",
502                err.to_string(),
503            ));
504        }
505
506        let dataset_path = self.dataset_path();
507        if let Err(err) = download_dataset_if_missing(&dataset_path) {
508            return AgentEffect::with_proposal(proposal(
509                self.name(),
510                ContextKey::Diagnostic,
511                "dataset-agent-error",
512                err.to_string(),
513            ));
514        }
515
516        let df = match load_dataframe(&dataset_path) {
517            Ok(df) => df,
518            Err(err) => {
519                return AgentEffect::with_proposal(proposal(
520                    self.name(),
521                    ContextKey::Diagnostic,
522                    "dataset-agent-error",
523                    err.to_string(),
524                ));
525            }
526        };
527
528        let total_rows = df.height();
529        if total_rows < 10 {
530            return AgentEffect::with_proposal(proposal(
531                self.name(),
532                ContextKey::Diagnostic,
533                "dataset-agent-error",
534                "dataset too small for splitting",
535            ));
536        }
537
538        let plan = read_latest_plan_from_ctx(ctx).unwrap_or(TrainingPlan {
539            iteration: 1,
540            max_rows: total_rows,
541            train_fraction: 0.8,
542            val_fraction: 0.15,
543            infer_fraction: 0.05,
544            quality_threshold: 0.75,
545        });
546
547        let max_rows = plan.max_rows.min(total_rows).max(10);
548        let df = df.slice(0, max_rows);
549
550        let mut train_rows = ((max_rows as f64) * plan.train_fraction).floor() as usize;
551        let mut val_rows = ((max_rows as f64) * plan.val_fraction).floor() as usize;
552        let mut infer_rows = max_rows.saturating_sub(train_rows + val_rows);
553        if infer_rows == 0 {
554            if val_rows > 1 {
555                val_rows -= 1;
556            } else if train_rows > 1 {
557                train_rows -= 1;
558            }
559            infer_rows = max_rows.saturating_sub(train_rows + val_rows).max(1);
560        }
561
562        let (train_path, val_path, infer_path) = self.split_paths();
563        let train_df = df.slice(0, train_rows);
564        let val_df = df.slice(train_rows as i64, val_rows);
565        let infer_df = df.slice((train_rows + val_rows) as i64, infer_rows);
566
567        if let Err(err) = write_parquet(&train_df, &train_path)
568            .and_then(|_| write_parquet(&val_df, &val_path))
569            .and_then(|_| write_parquet(&infer_df, &infer_path))
570        {
571            return AgentEffect::with_proposal(proposal(
572                self.name(),
573                ContextKey::Diagnostic,
574                "dataset-agent-error",
575                err.to_string(),
576            ));
577        }
578
579        let split = DatasetSplit {
580            source_path: dataset_path.to_string_lossy().to_string(),
581            train_path: train_path.to_string_lossy().to_string(),
582            val_path: val_path.to_string_lossy().to_string(),
583            infer_path: infer_path.to_string_lossy().to_string(),
584            total_rows,
585            max_rows,
586            train_rows,
587            val_rows,
588            infer_rows,
589            iteration: plan.iteration,
590        };
591
592        let content = serde_json::to_string(&split).unwrap_or_default();
593        AgentEffect::with_proposal(proposal(
594            self.name(),
595            ContextKey::Signals,
596            format!("dataset-split-{}", plan.iteration),
597            content,
598        ))
599    }
600}
601
/// Trains a constant-mean baseline on the latest training partition and writes
/// it as JSON into `model_dir`.
#[derive(Debug)]
pub struct ModelTrainingAgent {
    model_dir: PathBuf,
}

impl ModelTrainingAgent {
    /// Creates an agent that writes model artifacts into `model_dir`.
    pub fn new(model_dir: PathBuf) -> Self {
        Self { model_dir }
    }

    /// Fixed artifact path. Each training run overwrites the previous file.
    fn model_path(&self) -> PathBuf {
        let mut path = self.model_dir.clone();
        path.push("baseline_mean.json");
        path
    }
}
616
617#[async_trait::async_trait]
618impl Suggestor for ModelTrainingAgent {
619    fn name(&self) -> &str {
620        "ModelTrainingAgent (Baseline)"
621    }
622
623    fn dependencies(&self) -> &[ContextKey] {
624        &[ContextKey::Signals]
625    }
626
627    fn accepts(&self, ctx: &dyn Context) -> bool {
628        if !ctx.has(ContextKey::Signals) {
629            return false;
630        }
631        let split = match read_latest_split_from_ctx(ctx) {
632            Ok(split) => split,
633            Err(_) => return false,
634        };
635        !has_model_for_iteration(ctx, split.iteration)
636    }
637
638    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
639        let split = match read_latest_split_from_ctx(ctx) {
640            Ok(split) => split,
641            Err(err) => {
642                return AgentEffect::with_proposal(proposal(
643                    self.name(),
644                    ContextKey::Diagnostic,
645                    "model-training-error",
646                    err.to_string(),
647                ));
648            }
649        };
650
651        if let Err(err) = create_dir_all(&self.model_dir) {
652            return AgentEffect::with_proposal(proposal(
653                self.name(),
654                ContextKey::Diagnostic,
655                "model-training-error",
656                err.to_string(),
657            ));
658        }
659
660        let raw_train_df = match load_dataframe(Path::new(&split.train_path)) {
661            Ok(df) => df,
662            Err(err) => {
663                return AgentEffect::with_proposal(proposal(
664                    self.name(),
665                    ContextKey::Diagnostic,
666                    "model-training-error",
667                    err.to_string(),
668                ));
669            }
670        };
671
672        // Apply FeatureSpec transformation if available
673        let train_df = match read_feature_spec_from_ctx(ctx, split.iteration) {
674            Some(spec) => match apply_feature_spec(&raw_train_df, &spec) {
675                Ok(df) => df,
676                Err(err) => {
677                    return AgentEffect::with_proposal(proposal(
678                        self.name(),
679                        ContextKey::Diagnostic,
680                        "model-training-error",
681                        format!("feature spec application failed: {}", err),
682                    ));
683                }
684            },
685            None => raw_train_df,
686        };
687
688        let (target_name, target) = match select_target_column(&train_df) {
689            Ok(value) => value,
690            Err(err) => {
691                return AgentEffect::with_proposal(proposal(
692                    self.name(),
693                    ContextKey::Diagnostic,
694                    "model-training-error",
695                    err.to_string(),
696                ));
697            }
698        };
699
700        let mean = match mean_of_series(&target) {
701            Ok(value) => value,
702            Err(err) => {
703                return AgentEffect::with_proposal(proposal(
704                    self.name(),
705                    ContextKey::Diagnostic,
706                    "model-training-error",
707                    err.to_string(),
708                ));
709            }
710        };
711
712        let model = BaselineModel {
713            target_column: target_name.clone(),
714            mean,
715        };
716
717        let model_path = self.model_path();
718        if let Err(err) = write_json(&model_path, &model) {
719            return AgentEffect::with_proposal(proposal(
720                self.name(),
721                ContextKey::Diagnostic,
722                "model-training-error",
723                err.to_string(),
724            ));
725        }
726
727        let meta = ModelMetadata {
728            model_path: model_path.to_string_lossy().to_string(),
729            target_column: target_name,
730            train_rows: split.train_rows,
731            baseline_mean: mean,
732            iteration: split.iteration,
733        };
734
735        let content = serde_json::to_string(&meta).unwrap_or_default();
736        AgentEffect::with_proposal(proposal(
737            self.name(),
738            ContextKey::Strategies,
739            format!("trained-model-{}", split.iteration),
740            content,
741        ))
742    }
743}
744
/// Scores the trained baseline on the validation partition.
#[derive(Default, Debug)]
pub struct ModelEvaluationAgent;

impl ModelEvaluationAgent {
    /// Creates the (stateless) evaluation agent.
    pub fn new() -> Self {
        Self::default()
    }
}
753
/// Records trained models and their evaluation metrics in the context.
#[derive(Default, Debug)]
pub struct ModelRegistryAgent;

impl ModelRegistryAgent {
    /// Creates the (stateless) registry agent.
    pub fn new() -> Self {
        Self::default()
    }
}
762
763#[async_trait::async_trait]
764impl Suggestor for ModelRegistryAgent {
765    fn name(&self) -> &str {
766        "ModelRegistryAgent"
767    }
768
769    fn dependencies(&self) -> &[ContextKey] {
770        &[ContextKey::Strategies, ContextKey::Evaluations]
771    }
772
773    fn accepts(&self, ctx: &dyn Context) -> bool {
774        ctx.has(ContextKey::Strategies)
775            && ctx.has(ContextKey::Evaluations)
776            && match read_latest_model_meta_from_ctx(ctx) {
777                Ok(meta) => !has_registry_record_for_iteration(ctx, meta.iteration),
778                Err(_) => false,
779            }
780    }
781
782    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
783        let meta = match read_latest_model_meta_from_ctx(ctx) {
784            Ok(meta) => meta,
785            Err(err) => {
786                return AgentEffect::with_proposal(proposal(
787                    self.name(),
788                    ContextKey::Diagnostic,
789                    "model-registry-error",
790                    err.to_string(),
791                ));
792            }
793        };
794
795        let report = latest_evaluation_report(ctx, meta.iteration);
796        let mut metrics = HashMap::new();
797        if let Some(report) = report {
798            metrics.insert(report.metric, report.value);
799            metrics.insert("success_ratio".to_string(), report.success_ratio);
800        }
801
802        let record = ModelRegistryRecord {
803            kind: "model_registry".to_string(),
804            iteration: meta.iteration,
805            model_path: meta.model_path,
806            metrics,
807            provenance: "training_flow".to_string(),
808        };
809
810        let content = serde_json::to_string(&record).unwrap_or_default();
811        AgentEffect::with_proposal(proposal(
812            self.name(),
813            ContextKey::Strategies,
814            format!("model-registry-{}", record.iteration),
815            content,
816        ))
817    }
818}
819
/// Turns the latest evaluation report into a health status.
#[derive(Default, Debug)]
pub struct MonitoringAgent;

impl MonitoringAgent {
    /// Creates the (stateless) monitoring agent.
    pub fn new() -> Self {
        Self::default()
    }
}
828
829#[async_trait::async_trait]
830impl Suggestor for MonitoringAgent {
831    fn name(&self) -> &str {
832        "MonitoringAgent"
833    }
834
835    fn dependencies(&self) -> &[ContextKey] {
836        &[ContextKey::Evaluations]
837    }
838
839    fn accepts(&self, ctx: &dyn Context) -> bool {
840        ctx.has(ContextKey::Evaluations)
841            && match latest_evaluation_report(ctx, 0) {
842                Some(report) => !has_monitoring_report_for_iteration(ctx, report.iteration),
843                None => false,
844            }
845    }
846
847    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
848        let report = match latest_evaluation_report(ctx, 0) {
849            Some(report) => report,
850            None => return AgentEffect::empty(),
851        };
852
853        let status = if report.success_ratio >= 0.75 {
854            "healthy"
855        } else {
856            "needs_attention"
857        };
858
859        let monitoring = MonitoringReport {
860            kind: "monitoring".to_string(),
861            iteration: report.iteration,
862            metric: report.metric,
863            value: report.value,
864            baseline: report.mean_abs_target,
865            status: status.to_string(),
866        };
867
868        let content = serde_json::to_string(&monitoring).unwrap_or_default();
869        AgentEffect::with_proposal(proposal(
870            self.name(),
871            ContextKey::Evaluations,
872            format!("monitoring-{}", report.iteration),
873            content,
874        ))
875    }
876}
877
/// Decides whether the latest evaluated model should be deployed or retrained.
#[derive(Default, Debug)]
pub struct DeploymentAgent;

impl DeploymentAgent {
    /// Creates the (stateless) deployment agent.
    pub fn new() -> Self {
        Self::default()
    }
}
886
887#[async_trait::async_trait]
888impl Suggestor for DeploymentAgent {
889    fn name(&self) -> &str {
890        "DeploymentAgent"
891    }
892
893    fn dependencies(&self) -> &[ContextKey] {
894        &[ContextKey::Evaluations, ContextKey::Strategies]
895    }
896
897    fn accepts(&self, ctx: &dyn Context) -> bool {
898        ctx.has(ContextKey::Evaluations)
899            && ctx.has(ContextKey::Strategies)
900            && match latest_evaluation_report(ctx, 0) {
901                Some(report) => !has_deployment_decision_for_iteration(ctx, report.iteration),
902                None => false,
903            }
904    }
905
906    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
907        let report = match latest_evaluation_report(ctx, 0) {
908            Some(report) => report,
909            None => return AgentEffect::empty(),
910        };
911
912        let quality_threshold = read_latest_plan_from_ctx(ctx)
913            .map(|plan| plan.quality_threshold)
914            .unwrap_or(0.75);
915
916        let (action, retrain, reason) = if report.success_ratio >= quality_threshold {
917            ("deploy", false, "meets quality threshold")
918        } else {
919            ("hold", true, "below quality threshold")
920        };
921
922        let decision = DeploymentDecision {
923            kind: "deployment_decision".to_string(),
924            iteration: report.iteration,
925            action: action.to_string(),
926            reason: reason.to_string(),
927            retrain,
928        };
929
930        let content = serde_json::to_string(&decision).unwrap_or_default();
931        AgentEffect::with_proposal(proposal(
932            self.name(),
933            ContextKey::Strategies,
934            format!("deployment-{}", report.iteration),
935            content,
936        ))
937    }
938}
939
940#[async_trait::async_trait]
941impl Suggestor for ModelEvaluationAgent {
942    fn name(&self) -> &str {
943        "ModelEvaluationAgent (MAE)"
944    }
945
946    fn dependencies(&self) -> &[ContextKey] {
947        &[ContextKey::Signals, ContextKey::Strategies]
948    }
949
950    fn accepts(&self, ctx: &dyn Context) -> bool {
951        ctx.has(ContextKey::Signals)
952            && ctx.has(ContextKey::Strategies)
953            && match read_latest_split_from_ctx(ctx) {
954                Ok(split) => !has_evaluation_for_iteration(ctx, split.iteration),
955                Err(_) => false,
956            }
957    }
958
959    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
960        let split = match read_latest_split_from_ctx(ctx) {
961            Ok(split) => split,
962            Err(err) => {
963                return AgentEffect::with_proposal(proposal(
964                    self.name(),
965                    ContextKey::Diagnostic,
966                    "model-eval-error",
967                    err.to_string(),
968                ));
969            }
970        };
971
972        let model = match read_model_from_ctx(ctx) {
973            Ok(model) => model,
974            Err(err) => {
975                return AgentEffect::with_proposal(proposal(
976                    self.name(),
977                    ContextKey::Diagnostic,
978                    "model-eval-error",
979                    err.to_string(),
980                ));
981            }
982        };
983
984        let raw_val_df = match load_dataframe(Path::new(&split.val_path)) {
985            Ok(df) => df,
986            Err(err) => {
987                return AgentEffect::with_proposal(proposal(
988                    self.name(),
989                    ContextKey::Diagnostic,
990                    "model-eval-error",
991                    err.to_string(),
992                ));
993            }
994        };
995
996        // Apply FeatureSpec transformation if available
997        let val_df = match read_feature_spec_from_ctx(ctx, split.iteration) {
998            Some(spec) => apply_feature_spec(&raw_val_df, &spec).unwrap_or(raw_val_df),
999            None => raw_val_df,
1000        };
1001
1002        let target = match get_numeric_series(&val_df, &model.target_column) {
1003            Ok(series) => series,
1004            Err(err) => {
1005                return AgentEffect::with_proposal(proposal(
1006                    self.name(),
1007                    ContextKey::Diagnostic,
1008                    "model-eval-error",
1009                    err.to_string(),
1010                ));
1011            }
1012        };
1013
1014        let mae = match mean_abs_error(&target, model.mean) {
1015            Ok(value) => value,
1016            Err(err) => {
1017                return AgentEffect::with_proposal(proposal(
1018                    self.name(),
1019                    ContextKey::Diagnostic,
1020                    "model-eval-error",
1021                    err.to_string(),
1022                ));
1023            }
1024        };
1025
1026        let mean_abs = match mean_abs_value(&target) {
1027            Ok(value) => value,
1028            Err(err) => {
1029                return AgentEffect::with_proposal(proposal(
1030                    self.name(),
1031                    ContextKey::Diagnostic,
1032                    "model-eval-error",
1033                    err.to_string(),
1034                ));
1035            }
1036        };
1037
1038        let success_ratio = if mean_abs > 0.0 {
1039            (1.0 - (mae / mean_abs)).clamp(0.0, 1.0)
1040        } else {
1041            0.0
1042        };
1043
1044        let report = EvaluationReport {
1045            model_path: read_model_path_from_ctx(ctx).unwrap_or_default(),
1046            metric: "mae".to_string(),
1047            value: mae,
1048            mean_abs_target: mean_abs,
1049            success_ratio,
1050            val_rows: split.val_rows,
1051            iteration: split.iteration,
1052        };
1053
1054        let content = serde_json::to_string(&report).unwrap_or_default();
1055        AgentEffect::with_proposal(proposal(
1056            self.name(),
1057            ContextKey::Evaluations,
1058            format!("model-eval-{}", split.iteration),
1059            content,
1060        ))
1061    }
1062}
1063
/// Suggestor that proposes a small sample of baseline predictions alongside
/// the actual target values from the latest inference split.
#[derive(Debug)]
pub struct SampleInferenceAgent {
    /// Maximum number of rows included in the proposed sample.
    pub max_rows: usize,
}

impl SampleInferenceAgent {
    /// Creates an agent whose samples contain at most `max_rows` rows.
    pub fn new(max_rows: usize) -> Self {
        SampleInferenceAgent { max_rows }
    }
}
1074
1075#[async_trait::async_trait]
1076impl Suggestor for SampleInferenceAgent {
1077    fn name(&self) -> &str {
1078        "SampleInferenceAgent (Baseline)"
1079    }
1080
1081    fn dependencies(&self) -> &[ContextKey] {
1082        &[ContextKey::Signals, ContextKey::Strategies]
1083    }
1084
1085    fn accepts(&self, ctx: &dyn Context) -> bool {
1086        ctx.has(ContextKey::Signals)
1087            && ctx.has(ContextKey::Strategies)
1088            && match read_latest_split_from_ctx(ctx) {
1089                Ok(split) => !has_inference_for_iteration(ctx, split.iteration),
1090                Err(_) => false,
1091            }
1092    }
1093
1094    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
1095        let split = match read_latest_split_from_ctx(ctx) {
1096            Ok(split) => split,
1097            Err(err) => {
1098                return AgentEffect::with_proposal(proposal(
1099                    self.name(),
1100                    ContextKey::Diagnostic,
1101                    "model-infer-error",
1102                    err.to_string(),
1103                ));
1104            }
1105        };
1106
1107        let model = match read_model_from_ctx(ctx) {
1108            Ok(model) => model,
1109            Err(err) => {
1110                return AgentEffect::with_proposal(proposal(
1111                    self.name(),
1112                    ContextKey::Diagnostic,
1113                    "model-infer-error",
1114                    err.to_string(),
1115                ));
1116            }
1117        };
1118
1119        let infer_df = match load_dataframe(Path::new(&split.infer_path)) {
1120            Ok(df) => df,
1121            Err(err) => {
1122                return AgentEffect::with_proposal(proposal(
1123                    self.name(),
1124                    ContextKey::Diagnostic,
1125                    "model-infer-error",
1126                    err.to_string(),
1127                ));
1128            }
1129        };
1130
1131        let target = match get_numeric_series(&infer_df, &model.target_column) {
1132            Ok(series) => series,
1133            Err(err) => {
1134                return AgentEffect::with_proposal(proposal(
1135                    self.name(),
1136                    ContextKey::Diagnostic,
1137                    "model-infer-error",
1138                    err.to_string(),
1139                ));
1140            }
1141        };
1142
1143        let sample_rows = self.max_rows.min(infer_df.height().max(1));
1144        let actuals = match target.f64() {
1145            Ok(series) => series
1146                .into_no_null_iter()
1147                .take(sample_rows)
1148                .collect::<Vec<_>>(),
1149            Err(err) => {
1150                return AgentEffect::with_proposal(proposal(
1151                    self.name(),
1152                    ContextKey::Diagnostic,
1153                    "model-infer-error",
1154                    err.to_string(),
1155                ));
1156            }
1157        };
1158
1159        let predictions = vec![model.mean; actuals.len()];
1160        let sample = InferenceSample {
1161            model_path: read_model_path_from_ctx(ctx).unwrap_or_default(),
1162            target_column: model.target_column,
1163            rows: actuals.len(),
1164            predictions,
1165            actuals,
1166            iteration: split.iteration,
1167        };
1168
1169        let content = serde_json::to_string(&sample).unwrap_or_default();
1170        AgentEffect::with_proposal(proposal(
1171            self.name(),
1172            ContextKey::Hypotheses,
1173            format!("inference-sample-{}", split.iteration),
1174            content,
1175        ))
1176    }
1177}
1178
1179fn download_dataset_if_missing(path: &Path) -> Result<()> {
1180    if path.exists() {
1181        return Ok(());
1182    }
1183
1184    let response = reqwest::blocking::get(DATASET_URL)?;
1185    let content = response.bytes()?;
1186
1187    let mut file = File::create(path)?;
1188    file.write_all(&content)?;
1189
1190    Ok(())
1191}
1192
1193fn load_dataframe(path: &Path) -> Result<DataFrame> {
1194    let extension = path
1195        .extension()
1196        .and_then(|ext| ext.to_str())
1197        .unwrap_or("")
1198        .to_ascii_lowercase();
1199
1200    let path_str = path
1201        .to_str()
1202        .ok_or_else(|| anyhow!("path is not valid utf-8: {}", path.display()))?;
1203
1204    match extension.as_str() {
1205        "parquet" => {
1206            let pl_path = PlPath::from_str(path_str);
1207            Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
1208        }
1209        "csv" => Ok(CsvReadOptions::default()
1210            .with_has_header(true)
1211            .try_into_reader_with_file_path(Some(path.to_path_buf()))?
1212            .finish()?),
1213        _ => Err(anyhow!(
1214            "unsupported data format for path {} (expected .csv or .parquet)",
1215            path.display()
1216        )),
1217    }
1218}
1219
1220fn write_parquet(df: &DataFrame, path: &Path) -> Result<()> {
1221    let mut file = File::create(path)?;
1222    let mut owned = df.clone();
1223    ParquetWriter::new(&mut file).finish(&mut owned)?;
1224    Ok(())
1225}
1226
1227fn write_json<T: Serialize>(path: &Path, value: &T) -> Result<()> {
1228    let content = serde_json::to_string_pretty(value)?;
1229    let mut file = File::create(path)?;
1230    file.write_all(content.as_bytes())?;
1231    Ok(())
1232}
1233
1234fn read_latest_split_from_ctx(ctx: &dyn Context) -> Result<DatasetSplit> {
1235    let facts = ctx.get(ContextKey::Signals);
1236    let mut latest: Option<DatasetSplit> = None;
1237    for fact in facts {
1238        if let Ok(split) = serde_json::from_str::<DatasetSplit>(&fact.content) {
1239            let should_replace = match &latest {
1240                Some(current) => split.iteration > current.iteration,
1241                None => true,
1242            };
1243            if should_replace {
1244                latest = Some(split);
1245            }
1246        }
1247    }
1248    latest.ok_or_else(|| anyhow!("missing dataset split"))
1249}
1250
1251fn read_model_path_from_ctx(ctx: &dyn Context) -> Result<String> {
1252    let meta = read_latest_model_meta_from_ctx(ctx)?;
1253    Ok(meta.model_path)
1254}
1255
1256fn read_model_from_ctx(ctx: &dyn Context) -> Result<BaselineModel> {
1257    let model_path = read_model_path_from_ctx(ctx)?;
1258    let content = std::fs::read_to_string(model_path)?;
1259    let model = serde_json::from_str(&content)?;
1260    Ok(model)
1261}
1262
1263fn read_latest_model_meta_from_ctx(ctx: &dyn Context) -> Result<ModelMetadata> {
1264    let facts = ctx.get(ContextKey::Strategies);
1265    let mut latest: Option<ModelMetadata> = None;
1266    for fact in facts {
1267        if let Ok(meta) = serde_json::from_str::<ModelMetadata>(&fact.content) {
1268            let should_replace = match &latest {
1269                Some(current) => meta.iteration > current.iteration,
1270                None => true,
1271            };
1272            if should_replace {
1273                latest = Some(meta);
1274            }
1275        }
1276    }
1277    latest.ok_or_else(|| anyhow!("missing model metadata"))
1278}
1279
1280fn read_latest_plan_from_ctx(ctx: &dyn Context) -> Option<TrainingPlan> {
1281    let facts = ctx.get(ContextKey::Constraints);
1282    let mut latest: Option<TrainingPlan> = None;
1283    for fact in facts {
1284        if let Ok(plan) = serde_json::from_str::<TrainingPlan>(&fact.content) {
1285            let should_replace = match &latest {
1286                Some(current) => plan.iteration > current.iteration,
1287                None => true,
1288            };
1289            if should_replace {
1290                latest = Some(plan);
1291            }
1292        }
1293    }
1294    latest
1295}
1296
1297fn has_split_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1298    ctx.get(ContextKey::Signals).iter().any(|fact| {
1299        serde_json::from_str::<DatasetSplit>(&fact.content)
1300            .map(|split| split.iteration == iteration)
1301            .unwrap_or(false)
1302    })
1303}
1304
1305fn has_model_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1306    ctx.get(ContextKey::Strategies).iter().any(|fact| {
1307        serde_json::from_str::<ModelMetadata>(&fact.content)
1308            .map(|meta| meta.iteration == iteration)
1309            .unwrap_or(false)
1310    })
1311}
1312
1313fn has_evaluation_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1314    ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1315        serde_json::from_str::<EvaluationReport>(&fact.content)
1316            .map(|report| report.iteration == iteration)
1317            .unwrap_or(false)
1318    })
1319}
1320
1321fn has_inference_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1322    ctx.get(ContextKey::Hypotheses).iter().any(|fact| {
1323        serde_json::from_str::<InferenceSample>(&fact.content)
1324            .map(|sample| sample.iteration == iteration)
1325            .unwrap_or(false)
1326    })
1327}
1328
1329fn has_data_quality_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1330    ctx.get(ContextKey::Signals).iter().any(|fact| {
1331        serde_json::from_str::<DataQualityReport>(&fact.content)
1332            .map(|report| report.iteration == iteration)
1333            .unwrap_or(false)
1334    })
1335}
1336
1337fn has_feature_spec_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1338    ctx.get(ContextKey::Constraints).iter().any(|fact| {
1339        serde_json::from_str::<FeatureSpec>(&fact.content)
1340            .map(|spec| spec.iteration == iteration)
1341            .unwrap_or(false)
1342    })
1343}
1344
1345fn has_hyperparam_result_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1346    ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1347        serde_json::from_str::<HyperparameterSearchResult>(&fact.content)
1348            .map(|result| result.iteration == iteration)
1349            .unwrap_or(false)
1350    })
1351}
1352
1353fn has_registry_record_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1354    ctx.get(ContextKey::Strategies).iter().any(|fact| {
1355        serde_json::from_str::<ModelRegistryRecord>(&fact.content)
1356            .map(|record| record.iteration == iteration)
1357            .unwrap_or(false)
1358    })
1359}
1360
1361fn has_monitoring_report_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1362    ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1363        serde_json::from_str::<MonitoringReport>(&fact.content)
1364            .map(|report| report.iteration == iteration)
1365            .unwrap_or(false)
1366    })
1367}
1368
1369fn has_deployment_decision_for_iteration(ctx: &dyn Context, iteration: usize) -> bool {
1370    ctx.get(ContextKey::Strategies).iter().any(|fact| {
1371        serde_json::from_str::<DeploymentDecision>(&fact.content)
1372            .map(|decision| decision.iteration == iteration)
1373            .unwrap_or(false)
1374    })
1375}
1376
1377fn latest_evaluation_report(ctx: &dyn Context, iteration: usize) -> Option<EvaluationReport> {
1378    let mut latest: Option<EvaluationReport> = None;
1379    for fact in ctx.get(ContextKey::Evaluations) {
1380        if let Ok(report) = serde_json::from_str::<EvaluationReport>(&fact.content) {
1381            if iteration > 0 {
1382                if report.iteration == iteration {
1383                    return Some(report);
1384                }
1385            } else if latest
1386                .as_ref()
1387                .map(|current| report.iteration > current.iteration)
1388                .unwrap_or(true)
1389            {
1390                latest = Some(report);
1391            }
1392        }
1393    }
1394    if iteration > 0 { None } else { latest }
1395}
1396
1397fn latest_data_quality_before_iteration(
1398    ctx: &dyn Context,
1399    iteration: usize,
1400) -> Option<DataQualityReport> {
1401    let mut latest: Option<DataQualityReport> = None;
1402    for fact in ctx.get(ContextKey::Signals) {
1403        if let Ok(report) = serde_json::from_str::<DataQualityReport>(&fact.content) {
1404            if report.iteration < iteration
1405                && latest
1406                    .as_ref()
1407                    .map(|current| report.iteration > current.iteration)
1408                    .unwrap_or(true)
1409            {
1410                latest = Some(report);
1411            }
1412        }
1413    }
1414    latest
1415}
1416
1417fn drift_score_from_ctx(
1418    ctx: &dyn Context,
1419    iteration: usize,
1420    numeric_means: &HashMap<String, f64>,
1421) -> Option<f64> {
1422    let previous = latest_data_quality_before_iteration(ctx, iteration)?;
1423    let mut total_delta = 0.0;
1424    let mut count = 0usize;
1425    for (name, mean) in numeric_means {
1426        if let Some(prev_mean) = previous.numeric_means.get(name) {
1427            total_delta += (mean - prev_mean).abs();
1428            count += 1;
1429        }
1430    }
1431    if count == 0 {
1432        None
1433    } else {
1434        Some(total_delta / count as f64)
1435    }
1436}
1437
1438fn compute_numeric_stats(series: &Series) -> Result<(f64, f64, usize)> {
1439    let casted = series.cast(&DataType::Float64)?;
1440    let values: Vec<f64> = casted
1441        .f64()
1442        .context("numeric series not f64")?
1443        .into_no_null_iter()
1444        .collect();
1445    if values.is_empty() {
1446        return Err(anyhow!("no numeric values to compute stats"));
1447    }
1448
1449    let mut total = 0.0;
1450    for value in &values {
1451        total += *value;
1452    }
1453    let mean = total / values.len() as f64;
1454
1455    let mut variance_sum = 0.0;
1456    for value in &values {
1457        let diff = *value - mean;
1458        variance_sum += diff * diff;
1459    }
1460    let std = (variance_sum / values.len() as f64).sqrt();
1461
1462    let outliers = if std > 0.0 {
1463        values
1464            .iter()
1465            .filter(|value| (*value - mean).abs() > 3.0 * std)
1466            .count()
1467    } else {
1468        0
1469    };
1470
1471    Ok((mean, std, outliers))
1472}
1473
1474fn split_feature_columns(df: &DataFrame, target: &str) -> (Vec<String>, Vec<String>) {
1475    let mut numeric = Vec::new();
1476    let mut categorical = Vec::new();
1477    for series in df.get_columns() {
1478        let name = series.name();
1479        if name == target {
1480            continue;
1481        }
1482        if is_numeric_dtype(series.dtype()) {
1483            numeric.push(name.to_string());
1484        } else {
1485            categorical.push(name.to_string());
1486        }
1487    }
1488    (numeric, categorical)
1489}
1490
1491fn select_target_column(df: &DataFrame) -> Result<(String, Series)> {
1492    if let Ok(col) = df.column(TARGET_COLUMN) {
1493        return Ok((
1494            TARGET_COLUMN.to_string(),
1495            col.as_materialized_series().clone(),
1496        ));
1497    }
1498
1499    let mut numeric = df
1500        .get_columns()
1501        .iter()
1502        .filter(|series| is_numeric_dtype(series.dtype()))
1503        .cloned()
1504        .collect::<Vec<_>>();
1505
1506    let fallback = numeric
1507        .pop()
1508        .ok_or_else(|| anyhow!("no numeric columns available for target"))?;
1509    let series = fallback.as_materialized_series().clone();
1510    Ok((series.name().to_string(), series))
1511}
1512
1513fn get_numeric_series(df: &DataFrame, name: &str) -> Result<Series> {
1514    let series = df
1515        .column(name)
1516        .map_err(|_| anyhow!("missing target column {}", name))?
1517        .as_materialized_series();
1518    let casted = series.cast(&DataType::Float64)?;
1519    Ok(casted)
1520}
1521
1522fn mean_of_series(series: &Series) -> Result<f64> {
1523    let casted = series.cast(&DataType::Float64)?;
1524    let values = casted
1525        .f64()
1526        .context("target column not f64")?
1527        .into_no_null_iter();
1528    let mut total = 0.0;
1529    let mut count = 0usize;
1530    for value in values {
1531        total += value;
1532        count += 1;
1533    }
1534    if count == 0 {
1535        return Err(anyhow!("no values to compute mean"));
1536    }
1537    Ok(total / count as f64)
1538}
1539
1540fn mean_abs_error(target: &Series, mean: f64) -> Result<f64> {
1541    let casted = target.cast(&DataType::Float64)?;
1542    let values = casted
1543        .f64()
1544        .context("target column not f64")?
1545        .into_no_null_iter();
1546    let mut total = 0.0;
1547    let mut count = 0usize;
1548    for value in values {
1549        total += (value - mean).abs();
1550        count += 1;
1551    }
1552    if count == 0 {
1553        return Err(anyhow!("no values to evaluate"));
1554    }
1555    Ok(total / count as f64)
1556}
1557
1558fn mean_abs_value(target: &Series) -> Result<f64> {
1559    let casted = target.cast(&DataType::Float64)?;
1560    let values = casted
1561        .f64()
1562        .context("target column not f64")?
1563        .into_no_null_iter();
1564    let mut total = 0.0;
1565    let mut count = 0usize;
1566    for value in values {
1567        total += value.abs();
1568        count += 1;
1569    }
1570    if count == 0 {
1571        return Err(anyhow!("no values to evaluate"));
1572    }
1573    Ok(total / count as f64)
1574}
1575fn is_numeric_dtype(dtype: &DataType) -> bool {
1576    matches!(
1577        dtype,
1578        DataType::Int8
1579            | DataType::Int16
1580            | DataType::Int32
1581            | DataType::Int64
1582            | DataType::UInt8
1583            | DataType::UInt16
1584            | DataType::UInt32
1585            | DataType::UInt64
1586            | DataType::Float32
1587            | DataType::Float64
1588    )
1589}
1590
1591/// Read the latest FeatureSpec from context for a given iteration
1592fn read_feature_spec_from_ctx(ctx: &dyn Context, iteration: usize) -> Option<FeatureSpec> {
1593    ctx.get(ContextKey::Constraints).iter().find_map(|fact| {
1594        serde_json::from_str::<FeatureSpec>(&fact.content)
1595            .ok()
1596            .filter(|spec| spec.iteration == iteration)
1597    })
1598}
1599
/// Apply a FeatureSpec to a DataFrame, creating interaction features and normalizing.
///
/// The frame is cloned and then transformed in two passes:
///
/// 1. **Interactions.** Each entry of `spec.interactions` is processed in order:
///    - `left` and `right` are cast to `Float64` and combined with `op`
///      (`"multiply"`, `"add"`, `"subtract"` or `"divide"`).
///    - The result is appended as a new column named `interaction.name`.
///    - Inputs are read from the frame being built, so a later interaction may
///      reference an earlier interaction's column.
///    - Division yields null wherever either input is null or `|right| <= 1e-10`.
/// 2. **Standardization.** Only when `spec.normalization == "standardize"`:
///    - Each column listed in `spec.numeric_features` that exists in the frame is
///      replaced by `(x - mean) / std`.
///    - The mean and std come from `compute_mean_std` over that column of *this*
///      frame.
///    - Columns whose std is not `> 0.0` are left unchanged.
///    - Listed columns that are absent are skipped silently.
///
/// NOTE(review): the statistics come from the frame being transformed, so a
/// validation or inference frame is standardized with its own mean/std rather
/// than the training split's. Confirm this is intended.
/// NOTE(review): a standardized column is dropped and re-appended with `hstack`,
/// so it moves to the end of the column order.
///
/// # Errors
/// Fails in any of these cases:
/// - an interaction references a missing column;
/// - an interaction uses an unsupported `op`;
/// - `compute_mean_std` fails for a column being standardized;
/// - any cast, drop or hstack fails.
pub fn apply_feature_spec(df: &DataFrame, spec: &FeatureSpec) -> Result<DataFrame> {
    let mut result = df.clone();

    // Apply feature interactions
    for interaction in &spec.interactions {
        // Both operands are read from `result`, which already includes earlier interactions.
        let left_col = result
            .column(&interaction.left)
            .map_err(|_| anyhow!("missing column {} for interaction", interaction.left))?
            .cast(&DataType::Float64)?;
        let right_col = result
            .column(&interaction.right)
            .map_err(|_| anyhow!("missing column {} for interaction", interaction.right))?
            .cast(&DataType::Float64)?;

        let left_vals = left_col.f64().context("left column not f64")?;
        let right_vals = right_col.f64().context("right column not f64")?;

        let interaction_series = match interaction.op.as_str() {
            "multiply" => left_vals * right_vals,
            "add" => left_vals + right_vals,
            "subtract" => left_vals - right_vals,
            "divide" => {
                // Element-wise division that yields null (instead of inf/NaN) when either
                // operand is null or the divisor's magnitude is at most 1e-10.
                left_vals
                    .into_iter()
                    .zip(right_vals.into_iter())
                    .map(|(l, r)| match (l, r) {
                        (Some(lv), Some(rv)) if rv.abs() > 1e-10 => Some(lv / rv),
                        _ => None,
                    })
                    .collect::<Float64Chunked>()
            }
            _ => return Err(anyhow!("unsupported interaction op: {}", interaction.op)),
        };

        // Append the derived feature under the interaction's configured name.
        let named_series = interaction_series.with_name(interaction.name.clone().into());
        result = result
            .hstack(&[named_series.into_series().into()])
            .context("failed to add interaction column")?;
    }

    // Apply normalization to numeric features
    if spec.normalization == "standardize" {
        for col_name in &spec.numeric_features {
            // Listed features missing from the frame are skipped rather than treated as errors.
            if let Ok(col) = result.column(col_name) {
                let casted = col.cast(&DataType::Float64)?;
                let values = casted.f64().context("column not f64")?;

                // Compute mean and std (population std, from this frame's values)
                let (mean, std) = compute_mean_std(values)?;

                // A constant column (std == 0) or a NaN std leaves the column untouched.
                if std > 0.0 {
                    // Standardize: (x - mean) / std; null entries stay null
                    let standardized = (values - mean) / std;
                    let named = standardized.with_name(col_name.clone().into());

                    // Replace the column (drop + hstack, so it moves to the last position)
                    result = result.drop(col_name)?;
                    result = result
                        .hstack(&[named.into_series().into()])
                        .context("failed to replace standardized column")?;
                }
            }
        }
    }

    Ok(result)
}
1669
1670fn compute_mean_std(values: &ChunkedArray<Float64Type>) -> Result<(f64, f64)> {
1671    let vals: Vec<f64> = values.into_no_null_iter().collect();
1672    if vals.is_empty() {
1673        return Err(anyhow!("no values for mean/std computation"));
1674    }
1675
1676    let mean = vals.iter().sum::<f64>() / vals.len() as f64;
1677    let variance = vals.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / vals.len() as f64;
1678    let std = variance.sqrt();
1679
1680    Ok((mean, std))
1681}
1682
1683#[cfg(test)]
1684mod tests {
1685    use super::*;
1686    use std::path::PathBuf;
1687
1688    #[test]
1689    fn proposal_helper_builds_correct_fact() {
1690        let p = proposal("my-agent", ContextKey::Diagnostic, "id-1", "content-1");
1691        assert_eq!(p.provenance, "my-agent");
1692        assert_eq!(p.key, ContextKey::Diagnostic);
1693        assert_eq!(p.id, "id-1");
1694        assert_eq!(p.content, "content-1");
1695    }
1696
1697    #[test]
1698    fn training_plan_serde_roundtrip() {
1699        let plan = TrainingPlan {
1700            iteration: 3,
1701            max_rows: 1000,
1702            train_fraction: 0.7,
1703            val_fraction: 0.2,
1704            infer_fraction: 0.1,
1705            quality_threshold: 0.8,
1706        };
1707        let json = serde_json::to_string(&plan).unwrap();
1708        let restored: TrainingPlan = serde_json::from_str(&json).unwrap();
1709        assert_eq!(restored.iteration, 3);
1710        assert_eq!(restored.max_rows, 1000);
1711        assert!((restored.train_fraction - 0.7).abs() < f64::EPSILON);
1712    }
1713
1714    #[test]
1715    fn dataset_split_serde_roundtrip() {
1716        let split = DatasetSplit {
1717            source_path: "/data/src.parquet".into(),
1718            train_path: "/data/train.parquet".into(),
1719            val_path: "/data/val.parquet".into(),
1720            infer_path: "/data/infer.parquet".into(),
1721            total_rows: 1000,
1722            max_rows: 800,
1723            train_rows: 640,
1724            val_rows: 120,
1725            infer_rows: 40,
1726            iteration: 1,
1727        };
1728        let json = serde_json::to_string(&split).unwrap();
1729        let restored: DatasetSplit = serde_json::from_str(&json).unwrap();
1730        assert_eq!(restored.total_rows, 1000);
1731        assert_eq!(
1732            restored.train_rows + restored.val_rows + restored.infer_rows,
1733            800
1734        );
1735    }
1736
1737    #[test]
1738    fn baseline_model_serde_roundtrip() {
1739        let model = BaselineModel {
1740            target_column: "price".into(),
1741            mean: 42.5,
1742        };
1743        let json = serde_json::to_string(&model).unwrap();
1744        let restored: BaselineModel = serde_json::from_str(&json).unwrap();
1745        assert_eq!(restored.target_column, "price");
1746        assert!((restored.mean - 42.5).abs() < f64::EPSILON);
1747    }
1748
1749    #[test]
1750    fn evaluation_report_success_ratio_bounds() {
1751        let report = EvaluationReport {
1752            model_path: "/model".into(),
1753            metric: "mae".into(),
1754            value: 10.0,
1755            mean_abs_target: 100.0,
1756            success_ratio: 0.9,
1757            val_rows: 50,
1758            iteration: 1,
1759        };
1760        assert!(report.success_ratio >= 0.0 && report.success_ratio <= 1.0);
1761    }
1762
1763    #[test]
1764    fn feature_interaction_construction() {
1765        let fi = FeatureInteraction {
1766            name: "a_x_b".into(),
1767            left: "a".into(),
1768            right: "b".into(),
1769            op: "multiply".into(),
1770        };
1771        assert_eq!(fi.name, "a_x_b");
1772        assert_eq!(fi.op, "multiply");
1773    }
1774
1775    #[test]
1776    fn feature_spec_serde_roundtrip() {
1777        let spec = FeatureSpec {
1778            kind: "feature_spec".into(),
1779            iteration: 2,
1780            target_column: "target".into(),
1781            numeric_features: vec!["a".into(), "b".into()],
1782            categorical_features: vec!["c".into()],
1783            normalization: "standardize".into(),
1784            interactions: vec![],
1785        };
1786        let json = serde_json::to_string(&spec).unwrap();
1787        let restored: FeatureSpec = serde_json::from_str(&json).unwrap();
1788        assert_eq!(restored.numeric_features.len(), 2);
1789        assert_eq!(restored.categorical_features.len(), 1);
1790    }
1791
1792    #[test]
1793    fn hyperparam_search_plan_construction() {
1794        let mut params = HashMap::new();
1795        params.insert("lr".to_string(), vec![0.001, 0.01]);
1796        let plan = HyperparameterSearchPlan {
1797            kind: "hyperparam_plan".into(),
1798            iteration: 1,
1799            max_trials: 10,
1800            early_stopping: true,
1801            params,
1802        };
1803        assert_eq!(plan.max_trials, 10);
1804        assert!(plan.early_stopping);
1805        assert_eq!(plan.params["lr"].len(), 2);
1806    }
1807
1808    #[test]
1809    fn hyperparam_search_result_serde_roundtrip() {
1810        let mut best = HashMap::new();
1811        best.insert("lr".to_string(), 0.01);
1812        let result = HyperparameterSearchResult {
1813            kind: "hyperparam_result".into(),
1814            iteration: 1,
1815            best_params: best,
1816            score: 0.85,
1817        };
1818        let json = serde_json::to_string(&result).unwrap();
1819        let restored: HyperparameterSearchResult = serde_json::from_str(&json).unwrap();
1820        assert!((restored.score - 0.85).abs() < f64::EPSILON);
1821        assert!((restored.best_params["lr"] - 0.01).abs() < f64::EPSILON);
1822    }
1823
1824    #[test]
1825    fn model_registry_record_construction() {
1826        let mut metrics = HashMap::new();
1827        metrics.insert("mae".to_string(), 5.0);
1828        let record = ModelRegistryRecord {
1829            kind: "model_registry".into(),
1830            iteration: 1,
1831            model_path: "/models/v1.json".into(),
1832            metrics,
1833            provenance: "test".into(),
1834        };
1835        assert_eq!(record.metrics["mae"], 5.0);
1836    }
1837
1838    #[test]
1839    fn monitoring_report_status_values() {
1840        let healthy = MonitoringReport {
1841            kind: "monitoring".into(),
1842            iteration: 1,
1843            metric: "mae".into(),
1844            value: 5.0,
1845            baseline: 100.0,
1846            status: "healthy".into(),
1847        };
1848        assert_eq!(healthy.status, "healthy");
1849
1850        let needs_attention = MonitoringReport {
1851            status: "needs_attention".into(),
1852            ..healthy.clone()
1853        };
1854        assert_eq!(needs_attention.status, "needs_attention");
1855    }
1856
1857    #[test]
1858    fn deployment_decision_deploy_vs_hold() {
1859        let deploy = DeploymentDecision {
1860            kind: "deployment_decision".into(),
1861            iteration: 1,
1862            action: "deploy".into(),
1863            reason: "meets threshold".into(),
1864            retrain: false,
1865        };
1866        assert!(!deploy.retrain);
1867
1868        let hold = DeploymentDecision {
1869            action: "hold".into(),
1870            retrain: true,
1871            ..deploy.clone()
1872        };
1873        assert!(hold.retrain);
1874        assert_eq!(hold.action, "hold");
1875    }
1876
1877    #[test]
1878    fn data_validation_agent_default() {
1879        let agent = DataValidationAgent::default();
1880        let agent2 = DataValidationAgent::new();
1881        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
1882    }
1883
1884    #[test]
1885    fn feature_engineering_agent_default() {
1886        let agent = FeatureEngineeringAgent::default();
1887        let agent2 = FeatureEngineeringAgent::new();
1888        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
1889    }
1890
1891    #[test]
1892    fn model_evaluation_agent_default() {
1893        let agent = ModelEvaluationAgent::default();
1894        let agent2 = ModelEvaluationAgent::new();
1895        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
1896    }
1897
1898    #[test]
1899    fn model_registry_agent_default() {
1900        let agent = ModelRegistryAgent::default();
1901        let agent2 = ModelRegistryAgent::new();
1902        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
1903    }
1904
1905    #[test]
1906    fn monitoring_agent_default() {
1907        let agent = MonitoringAgent::default();
1908        let agent2 = MonitoringAgent::new();
1909        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
1910    }
1911
1912    #[test]
1913    fn deployment_agent_default() {
1914        let agent = DeploymentAgent::default();
1915        let agent2 = DeploymentAgent::new();
1916        assert_eq!(format!("{:?}", agent), format!("{:?}", agent2));
1917    }
1918
1919    #[test]
1920    fn dataset_agent_paths() {
1921        let agent = DatasetAgent::new(PathBuf::from("/tmp/data"));
1922        assert_eq!(
1923            agent.dataset_path(),
1924            PathBuf::from("/tmp/data/california_housing_train.parquet")
1925        );
1926        let (train, val, infer) = agent.split_paths();
1927        assert_eq!(train, PathBuf::from("/tmp/data/train.parquet"));
1928        assert_eq!(val, PathBuf::from("/tmp/data/val.parquet"));
1929        assert_eq!(infer, PathBuf::from("/tmp/data/infer.parquet"));
1930    }
1931
1932    #[test]
1933    fn model_training_agent_model_path() {
1934        let agent = ModelTrainingAgent::new(PathBuf::from("/tmp/models"));
1935        assert_eq!(
1936            agent.model_path(),
1937            PathBuf::from("/tmp/models/baseline_mean.json")
1938        );
1939    }
1940
1941    #[test]
1942    fn sample_inference_agent_construction() {
1943        let agent = SampleInferenceAgent::new(100);
1944        assert_eq!(agent.max_rows, 100);
1945    }
1946
1947    #[test]
1948    fn hyperparameter_search_agent_construction() {
1949        let agent = HyperparameterSearchAgent::new(50);
1950        assert_eq!(agent.max_trials, 50);
1951    }
1952
1953    #[test]
1954    fn is_numeric_dtype_comprehensive() {
1955        assert!(is_numeric_dtype(&DataType::Float32));
1956        assert!(is_numeric_dtype(&DataType::Float64));
1957        assert!(is_numeric_dtype(&DataType::Int64));
1958        assert!(!is_numeric_dtype(&DataType::String));
1959        assert!(!is_numeric_dtype(&DataType::Boolean));
1960    }
1961
1962    #[test]
1963    fn split_feature_columns_separates_types() {
1964        let df = df! {
1965            "num1" => [1.0, 2.0],
1966            "num2" => [3i32, 4],
1967            "cat1" => ["a", "b"],
1968            "target" => [10.0, 20.0]
1969        }
1970        .unwrap();
1971        let (numeric, categorical) = split_feature_columns(&df, "target");
1972        assert!(numeric.contains(&"num1".to_string()));
1973        assert!(numeric.contains(&"num2".to_string()));
1974        assert!(categorical.contains(&"cat1".to_string()));
1975        assert!(!numeric.contains(&"target".to_string()));
1976        assert!(!categorical.contains(&"target".to_string()));
1977    }
1978
1979    #[test]
1980    fn select_target_column_prefers_named_target() {
1981        let df = df! {
1982            "x" => [1.0, 2.0],
1983            "median_house_value" => [100.0, 200.0]
1984        }
1985        .unwrap();
1986        let (name, _) = select_target_column(&df).unwrap();
1987        assert_eq!(name, "median_house_value");
1988    }
1989
1990    #[test]
1991    fn select_target_column_falls_back_to_last_numeric() {
1992        let df = df! {
1993            "a" => [1.0, 2.0],
1994            "b" => [3.0, 4.0]
1995        }
1996        .unwrap();
1997        let (name, _) = select_target_column(&df).unwrap();
1998        assert_eq!(name, "b");
1999    }
2000
2001    #[test]
2002    fn select_target_column_fails_with_no_numeric() {
2003        let df = df! {
2004            "text" => ["a", "b"]
2005        }
2006        .unwrap();
2007        assert!(select_target_column(&df).is_err());
2008    }
2009
2010    #[test]
2011    fn mean_of_series_computes_correctly() {
2012        let series = Series::new("v".into(), &[2.0f64, 4.0, 6.0]);
2013        let mean = mean_of_series(&series).unwrap();
2014        assert!((mean - 4.0).abs() < f64::EPSILON);
2015    }
2016
2017    #[test]
2018    fn mean_abs_error_computes_correctly() {
2019        let series = Series::new("v".into(), &[10.0f64, 20.0, 30.0]);
2020        let mae = mean_abs_error(&series, 20.0).unwrap();
2021        // |10-20| + |20-20| + |30-20| = 10+0+10 = 20, /3 = 6.666...
2022        assert!((mae - 20.0 / 3.0).abs() < 1e-10);
2023    }
2024
2025    #[test]
2026    fn mean_abs_value_computes_correctly() {
2027        let series = Series::new("v".into(), &[-5.0f64, 5.0, 10.0]);
2028        let mav = mean_abs_value(&series).unwrap();
2029        assert!((mav - 20.0 / 3.0).abs() < 1e-10);
2030    }
2031
2032    #[test]
2033    fn compute_numeric_stats_basic() {
2034        let series = Series::new("v".into(), &[2.0f64, 4.0, 6.0]);
2035        let casted = series.cast(&DataType::Float64).unwrap();
2036        let (mean, std, outliers) = compute_numeric_stats(&casted).unwrap();
2037        assert!((mean - 4.0).abs() < 1e-10);
2038        assert!(std > 0.0);
2039        assert_eq!(outliers, 0);
2040    }
2041
2042    #[test]
2043    fn compute_numeric_stats_empty_fails() {
2044        let series = Series::new("v".into(), Vec::<f64>::new());
2045        assert!(compute_numeric_stats(&series).is_err());
2046    }
2047
2048    #[test]
2049    fn compute_numeric_stats_constant_series() {
2050        let series = Series::new("v".into(), &[5.0f64, 5.0, 5.0]);
2051        let (mean, std, outliers) = compute_numeric_stats(&series).unwrap();
2052        assert!((mean - 5.0).abs() < f64::EPSILON);
2053        assert!((std - 0.0).abs() < f64::EPSILON);
2054        assert_eq!(outliers, 0);
2055    }
2056
2057    #[test]
2058    fn data_quality_report_serde_roundtrip() {
2059        let mut missingness = HashMap::new();
2060        missingness.insert("a".to_string(), 0.1);
2061        let report = DataQualityReport {
2062            kind: "data_quality".into(),
2063            iteration: 1,
2064            source_path: "/data/train.parquet".into(),
2065            rows_checked: 100,
2066            missingness,
2067            numeric_means: HashMap::new(),
2068            outlier_counts: HashMap::new(),
2069            drift_score: Some(0.05),
2070        };
2071        let json = serde_json::to_string(&report).unwrap();
2072        let restored: DataQualityReport = serde_json::from_str(&json).unwrap();
2073        assert_eq!(restored.rows_checked, 100);
2074        assert!((restored.drift_score.unwrap() - 0.05).abs() < f64::EPSILON);
2075    }
2076
2077    #[test]
2078    fn inference_sample_serde_roundtrip() {
2079        let sample = InferenceSample {
2080            model_path: "/model".into(),
2081            target_column: "target".into(),
2082            rows: 3,
2083            predictions: vec![1.0, 2.0, 3.0],
2084            actuals: vec![1.1, 2.1, 3.1],
2085            iteration: 1,
2086        };
2087        let json = serde_json::to_string(&sample).unwrap();
2088        let restored: InferenceSample = serde_json::from_str(&json).unwrap();
2089        assert_eq!(restored.predictions.len(), 3);
2090        assert_eq!(restored.actuals.len(), 3);
2091    }
2092
2093    #[test]
2094    fn apply_feature_spec_creates_interaction_column() {
2095        let df = df! {
2096            "a" => [1.0, 2.0, 3.0],
2097            "b" => [4.0, 5.0, 6.0],
2098            "target" => [10.0, 20.0, 30.0]
2099        }
2100        .unwrap();
2101
2102        let spec = FeatureSpec {
2103            kind: "feature_spec".to_string(),
2104            iteration: 1,
2105            target_column: "target".to_string(),
2106            numeric_features: vec!["a".to_string(), "b".to_string()],
2107            categorical_features: vec![],
2108            normalization: "none".to_string(),
2109            interactions: vec![FeatureInteraction {
2110                name: "a_x_b".to_string(),
2111                left: "a".to_string(),
2112                right: "b".to_string(),
2113                op: "multiply".to_string(),
2114            }],
2115        };
2116
2117        let result = apply_feature_spec(&df, &spec).unwrap();
2118
2119        // Check interaction column exists
2120        assert!(result.column("a_x_b").is_ok());
2121
2122        // Check values: 1*4=4, 2*5=10, 3*6=18
2123        let interaction = result.column("a_x_b").unwrap().f64().unwrap();
2124        let values: Vec<f64> = interaction.into_no_null_iter().collect();
2125        assert_eq!(values, vec![4.0, 10.0, 18.0]);
2126    }
2127
2128    #[test]
2129    fn apply_feature_spec_standardizes_numeric_features() {
2130        let df = df! {
2131            "a" => [1.0, 2.0, 3.0, 4.0, 5.0],
2132            "target" => [10.0, 20.0, 30.0, 40.0, 50.0]
2133        }
2134        .unwrap();
2135
2136        let spec = FeatureSpec {
2137            kind: "feature_spec".to_string(),
2138            iteration: 1,
2139            target_column: "target".to_string(),
2140            numeric_features: vec!["a".to_string()],
2141            categorical_features: vec![],
2142            normalization: "standardize".to_string(),
2143            interactions: vec![],
2144        };
2145
2146        let result = apply_feature_spec(&df, &spec).unwrap();
2147
2148        // Standardized values should have mean ~0 and std ~1
2149        let a_col = result.column("a").unwrap().f64().unwrap();
2150        let values: Vec<f64> = a_col.into_no_null_iter().collect();
2151
2152        let mean: f64 = values.iter().sum::<f64>() / values.len() as f64;
2153        assert!(mean.abs() < 1e-10, "mean should be ~0, got {}", mean);
2154
2155        let variance: f64 =
2156            values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
2157        let std = variance.sqrt();
2158        assert!((std - 1.0).abs() < 1e-10, "std should be ~1, got {}", std);
2159    }
2160
2161    #[test]
2162    fn apply_feature_spec_handles_add_operation() {
2163        let df = df! {
2164            "a" => [1.0, 2.0, 3.0],
2165            "b" => [10.0, 20.0, 30.0]
2166        }
2167        .unwrap();
2168
2169        let spec = FeatureSpec {
2170            kind: "feature_spec".to_string(),
2171            iteration: 1,
2172            target_column: "target".to_string(),
2173            numeric_features: vec![],
2174            categorical_features: vec![],
2175            normalization: "none".to_string(),
2176            interactions: vec![FeatureInteraction {
2177                name: "a_plus_b".to_string(),
2178                left: "a".to_string(),
2179                right: "b".to_string(),
2180                op: "add".to_string(),
2181            }],
2182        };
2183
2184        let result = apply_feature_spec(&df, &spec).unwrap();
2185        let col = result.column("a_plus_b").unwrap().f64().unwrap();
2186        let values: Vec<f64> = col.into_no_null_iter().collect();
2187        assert_eq!(values, vec![11.0, 22.0, 33.0]);
2188    }
2189}