// converge_analytics/training.rs
1// Copyright (c) 2026 Aprio One AB
2// Author: Kenneth Pernyer, kenneth@pernyer.se
3
4use anyhow::{anyhow, Context as _, Result};
5use converge_core::{Agent, AgentEffect, Context, ContextKey, Fact};
6use polars::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs::{create_dir_all, File};
10use std::io::Write;
11use std::path::{Path, PathBuf};
12
/// HuggingFace auto-converted parquet export of the California Housing dataset (train split).
const DATASET_URL: &str = "https://huggingface.co/datasets/gvlassis/california_housing/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet";
/// Regression target column name — presumably consumed by `select_target_column`; TODO confirm.
const TARGET_COLUMN: &str = "median_house_value";
15
/// Per-iteration configuration for the training pipeline.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainingPlan {
    pub iteration: usize,
    /// Row cap applied to the raw dataset before splitting.
    pub max_rows: usize,
    // Split fractions are applied to `max_rows`; presumably they sum to ~1.0
    // — TODO confirm with whatever emits these plan facts.
    pub train_fraction: f64,
    pub val_fraction: f64,
    pub infer_fraction: f64,
    /// Minimum `success_ratio` required for a "deploy" decision.
    pub quality_threshold: f64,
}

/// Where each data split was written and how many rows landed in each file.
#[derive(Debug, Serialize, Deserialize)]
pub struct DatasetSplit {
    pub source_path: String,
    pub train_path: String,
    pub val_path: String,
    pub infer_path: String,
    /// Rows in the raw downloaded dataset (before the `max_rows` cap).
    pub total_rows: usize,
    pub max_rows: usize,
    pub train_rows: usize,
    pub val_rows: usize,
    pub infer_rows: usize,
    pub iteration: usize,
}

/// Trivial model: predicts the training-set mean of the target column.
#[derive(Debug, Serialize, Deserialize)]
pub struct BaselineModel {
    pub target_column: String,
    pub mean: f64,
}

/// Provenance record for a trained model artifact on disk.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelMetadata {
    pub model_path: String,
    pub target_column: String,
    pub train_rows: usize,
    pub baseline_mean: f64,
    pub iteration: usize,
}

/// Validation-set evaluation of a model (MAE of the baseline predictor).
#[derive(Debug, Serialize, Deserialize)]
pub struct EvaluationReport {
    pub model_path: String,
    /// Metric name, e.g. "mae".
    pub metric: String,
    pub value: f64,
    /// Mean absolute target value; used to normalize the metric.
    pub mean_abs_target: f64,
    /// `1 - value / mean_abs_target`, clamped to [0, 1].
    pub success_ratio: f64,
    pub val_rows: usize,
    pub iteration: usize,
}

/// Predictions vs. actuals over a sample of the inference split.
#[derive(Debug, Serialize, Deserialize)]
pub struct InferenceSample {
    pub model_path: String,
    pub target_column: String,
    pub rows: usize,
    pub predictions: Vec<f64>,
    pub actuals: Vec<f64>,
    pub iteration: usize,
}

/// Column-level quality profile of a training split.
#[derive(Debug, Serialize, Deserialize)]
pub struct DataQualityReport {
    /// Discriminator tag; always "data_quality".
    pub kind: String,
    pub iteration: usize,
    pub source_path: String,
    pub rows_checked: usize,
    /// Per-column null ratio in [0, 1].
    pub missingness: HashMap<String, f64>,
    pub numeric_means: HashMap<String, f64>,
    pub outlier_counts: HashMap<String, usize>,
    // None when no comparison baseline is available — exact semantics live
    // in `drift_score_from_ctx`; confirm there.
    pub drift_score: Option<f64>,
}

/// A derived feature combining two input columns with `op`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FeatureInteraction {
    pub name: String,
    pub left: String,
    pub right: String,
    /// Operator name, e.g. "multiply".
    pub op: String,
}

/// Feature-engineering recipe for one iteration.
#[derive(Debug, Serialize, Deserialize)]
pub struct FeatureSpec {
    pub kind: String,
    pub iteration: usize,
    pub target_column: String,
    pub numeric_features: Vec<String>,
    pub categorical_features: Vec<String>,
    /// Normalization strategy name, e.g. "standardize".
    pub normalization: String,
    pub interactions: Vec<FeatureInteraction>,
}

/// Search space and budget for hyperparameter tuning.
#[derive(Debug, Serialize, Deserialize)]
pub struct HyperparameterSearchPlan {
    pub kind: String,
    pub iteration: usize,
    pub max_trials: usize,
    pub early_stopping: bool,
    /// Candidate values per parameter name.
    pub params: HashMap<String, Vec<f64>>,
}

/// Outcome of a hyperparameter search.
#[derive(Debug, Serialize, Deserialize)]
pub struct HyperparameterSearchResult {
    pub kind: String,
    pub iteration: usize,
    pub best_params: HashMap<String, f64>,
    pub score: f64,
}

/// Registry entry tying a model artifact to its metrics and provenance.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelRegistryRecord {
    pub kind: String,
    pub iteration: usize,
    pub model_path: String,
    pub metrics: HashMap<String, f64>,
    pub provenance: String,
}

/// Health snapshot comparing a metric value against its baseline.
#[derive(Debug, Serialize, Deserialize)]
pub struct MonitoringReport {
    pub kind: String,
    pub iteration: usize,
    pub metric: String,
    pub value: f64,
    pub baseline: f64,
    /// "healthy" or "needs_attention".
    pub status: String,
}

/// Deploy/hold verdict for one iteration.
#[derive(Debug, Serialize, Deserialize)]
pub struct DeploymentDecision {
    pub kind: String,
    pub iteration: usize,
    /// "deploy" or "hold".
    pub action: String,
    pub reason: String,
    /// True when the pipeline should run another training iteration.
    pub retrain: bool,
}
/// Downloads the California Housing dataset and materializes the
/// train/val/infer parquet splits under `data_dir`.
pub struct DatasetAgent {
    data_dir: PathBuf,
}

impl DatasetAgent {
    /// Create an agent rooted at `data_dir`.
    pub fn new(data_dir: PathBuf) -> Self {
        Self { data_dir }
    }

    /// Location of the cached raw parquet download.
    fn dataset_path(&self) -> PathBuf {
        self.data_dir.join("california_housing_train.parquet")
    }

    /// Paths for the (train, val, infer) split files, in that order.
    fn split_paths(&self) -> (PathBuf, PathBuf, PathBuf) {
        let under = |file: &str| self.data_dir.join(file);
        (
            under("train.parquet"),
            under("val.parquet"),
            under("infer.parquet"),
        )
    }
}
173
/// Stateless agent that profiles the current training split for data quality.
///
/// Derives `Default` alongside `new()` (clippy `new_without_default`) and
/// `Debug` so the public type can be logged.
#[derive(Debug, Default)]
pub struct DataValidationAgent;

impl DataValidationAgent {
    /// Construct the (stateless) agent.
    pub fn new() -> Self {
        Self
    }
}
181
// Profiles the latest training split and publishes a `DataQualityReport`
// (missingness, numeric means, outlier counts, optional drift score) as a
// `data-quality-<iteration>` fact under Signals.
impl Agent for DataValidationAgent {
    fn name(&self) -> &str {
        "DataValidationAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    // Runs at most once per iteration: a split fact must exist and no
    // data-quality fact may have been recorded for that iteration yet.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_data_quality_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    fn execute(&self, ctx: &Context) -> AgentEffect {
        // Errors become Diagnostic facts instead of panics so the rest of
        // the agent pipeline can continue.
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "data-validation-error",
                    err.to_string(),
                ))
            }
        };

        let df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "data-validation-error",
                    err.to_string(),
                ))
            }
        };

        let rows = df.height();
        let mut missingness = HashMap::new();
        let mut numeric_means = HashMap::new();
        let mut outlier_counts = HashMap::new();

        // Per-column profile: null ratio for every column; mean and outlier
        // count for numeric columns only.
        for series in df.get_columns() {
            let name = series.name().to_string();
            // Guard against division by zero on an empty frame.
            let null_ratio = if rows > 0 {
                series.null_count() as f64 / rows as f64
            } else {
                0.0
            };
            missingness.insert(name.clone(), null_ratio);

            if is_numeric_dtype(series.dtype()) {
                // Stats failures on a column are silently skipped — that
                // column simply has no mean/outlier entry.
                if let Ok((mean, _std, outliers)) =
                    compute_numeric_stats(series.as_materialized_series())
                {
                    numeric_means.insert(name.clone(), mean);
                    outlier_counts.insert(name, outliers);
                }
            }
        }

        // Drift relative to earlier iterations — semantics live in
        // `drift_score_from_ctx` (not visible here).
        let drift_score = drift_score_from_ctx(ctx, split.iteration, &numeric_means);

        let report = DataQualityReport {
            kind: "data_quality".to_string(),
            iteration: split.iteration,
            source_path: split.train_path.clone(),
            rows_checked: rows,
            missingness,
            numeric_means,
            outlier_counts,
            drift_score,
        };

        // NOTE(review): serialization failure silently yields an empty fact body.
        let content = serde_json::to_string(&report).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Signals,
            format!("data-quality-{}", split.iteration),
            content,
        ))
    }
}
267
/// Stateless agent that derives a `FeatureSpec` from the training split.
///
/// Derives `Default` alongside `new()` (clippy `new_without_default`) and
/// `Debug` so the public type can be logged.
#[derive(Debug, Default)]
pub struct FeatureEngineeringAgent;

impl FeatureEngineeringAgent {
    /// Construct the (stateless) agent.
    pub fn new() -> Self {
        Self
    }
}
275
// Derives a `FeatureSpec` from the latest training split: numeric vs.
// categorical columns, a normalization choice, and one pairwise interaction.
impl Agent for FeatureEngineeringAgent {
    fn name(&self) -> &str {
        "FeatureEngineeringAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    // One spec per iteration.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_feature_spec_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    fn execute(&self, ctx: &Context) -> AgentEffect {
        // Failure paths emit a Diagnostic fact and bail out early.
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ))
            }
        };

        let df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ))
            }
        };

        let (target_column, _) = match select_target_column(&df) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ))
            }
        };

        let (numeric_features, categorical_features) =
            split_feature_columns(&df, &target_column);

        // Single hard-coded interaction: product of the first two numeric
        // features, only when at least two exist.
        let mut interactions = Vec::new();
        if numeric_features.len() >= 2 {
            interactions.push(FeatureInteraction {
                name: format!("{}_x_{}", numeric_features[0], numeric_features[1]),
                left: numeric_features[0].clone(),
                right: numeric_features[1].clone(),
                op: "multiply".to_string(),
            });
        }

        let spec = FeatureSpec {
            kind: "feature_spec".to_string(),
            iteration: split.iteration,
            target_column,
            numeric_features,
            categorical_features,
            normalization: "standardize".to_string(),
            interactions,
        };

        // Published under Constraints (unlike the data-quality report,
        // which goes to Signals).
        let content = serde_json::to_string(&spec).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Constraints,
            format!("feature-spec-{}", split.iteration),
            content,
        ))
    }
}
358
/// Plans a hyperparameter search of at most `max_trials` trials.
///
/// Adds `Debug`/`Clone` derives so the public type can be logged and
/// duplicated; the constructor interface is unchanged.
#[derive(Debug, Clone)]
pub struct HyperparameterSearchAgent {
    /// Budget advertised in the emitted `HyperparameterSearchPlan`.
    max_trials: usize,
}

impl HyperparameterSearchAgent {
    /// Create an agent with the given trial budget.
    pub fn new(max_trials: usize) -> Self {
        Self { max_trials }
    }
}
368
// Emits a hyperparameter search plan plus a result. NOTE(review): this is a
// stub — the candidate grid and "best" parameters are hard-coded and no
// trials are actually executed.
impl Agent for HyperparameterSearchAgent {
    fn name(&self) -> &str {
        "HyperparameterSearchAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Constraints, ContextKey::Signals]
    }

    // NOTE(review): only Signals presence is checked even though Constraints
    // is declared as a dependency — confirm this asymmetry is intentional.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_hyperparam_result_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    fn execute(&self, ctx: &Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "hyperparam-search-error",
                    err.to_string(),
                ))
            }
        };

        // Fall back to a default plan when no TrainingPlan fact exists.
        let training_plan = read_latest_plan_from_ctx(ctx).unwrap_or_else(|| TrainingPlan {
            iteration: split.iteration,
            max_rows: split.max_rows,
            train_fraction: 0.8,
            val_fraction: 0.15,
            infer_fraction: 0.05,
            quality_threshold: 0.75,
        });

        let mut params = HashMap::new();
        params.insert("learning_rate".to_string(), vec![0.001, 0.01, 0.1]);
        params.insert("hidden_size".to_string(), vec![8.0, 16.0, 32.0]);

        let plan = HyperparameterSearchPlan {
            kind: "hyperparam_plan".to_string(),
            iteration: split.iteration,
            max_trials: self.max_trials,
            early_stopping: true,
            params,
        };

        // Hard-coded "best" parameters and a synthetic score derived from
        // the plan's threshold/budget/iteration — no search is performed.
        let mut best_params = HashMap::new();
        best_params.insert("learning_rate".to_string(), 0.01);
        best_params.insert("hidden_size".to_string(), 16.0);
        let score = (1.0 - training_plan.quality_threshold)
            * plan.max_trials as f64
            / plan.iteration.max(1) as f64;
        let result = HyperparameterSearchResult {
            kind: "hyperparam_result".to_string(),
            iteration: split.iteration,
            best_params,
            score,
        };

        let plan_content = serde_json::to_string(&plan).unwrap_or_default();
        let result_content = serde_json::to_string(&result).unwrap_or_default();

        // Two facts: the plan under Constraints, the result under Evaluations.
        let mut effect = AgentEffect::empty();
        effect.facts.push(Fact::new(
            ContextKey::Constraints,
            format!("hyperparam-plan-{}", split.iteration),
            plan_content,
        ));
        effect.facts.push(Fact::new(
            ContextKey::Evaluations,
            format!("hyperparam-result-{}", split.iteration),
            result_content,
        ));
        effect
    }
}
449
// Downloads the raw dataset (if not already cached), caps it at the plan's
// `max_rows`, writes train/val/infer parquet splits, and publishes a
// `DatasetSplit` fact under Signals.
impl Agent for DatasetAgent {
    fn name(&self) -> &str {
        "DatasetAgent (HuggingFace)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Seeds]
    }

    fn accepts(&self, ctx: &Context) -> bool {
        if !ctx.has(ContextKey::Seeds) {
            return false;
        }

        // With a plan: run once per plan iteration. Without one: run only
        // until the first fact lands under Signals.
        let plan = read_latest_plan_from_ctx(ctx);
        if let Some(plan) = plan {
            return !has_split_for_iteration(ctx, plan.iteration);
        }

        !ctx.has(ContextKey::Signals)
    }

    fn execute(&self, ctx: &Context) -> AgentEffect {
        if let Err(err) = create_dir_all(&self.data_dir) {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        // Network fetch happens only when the cached parquet is absent.
        let dataset_path = self.dataset_path();
        if let Err(err) = download_dataset_if_missing(&dataset_path) {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        let df = match load_dataframe(&dataset_path) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "dataset-agent-error",
                    err.to_string(),
                ))
            }
        };

        // Require at least 10 rows so all three splits can be non-trivial.
        let total_rows = df.height();
        if total_rows < 10 {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "dataset-agent-error",
                "dataset too small for splitting",
            ));
        }

        // Default plan when no TrainingPlan fact is present (first run).
        let plan = read_latest_plan_from_ctx(ctx).unwrap_or_else(|| TrainingPlan {
            iteration: 1,
            max_rows: total_rows,
            train_fraction: 0.8,
            val_fraction: 0.15,
            infer_fraction: 0.05,
            quality_threshold: 0.75,
        });

        // Clamp the row cap into [10, total_rows].
        let max_rows = plan.max_rows.min(total_rows).max(10);
        let df = df.slice(0, max_rows);

        // Floor the train/val fractions; the remainder goes to inference.
        let mut train_rows = ((max_rows as f64) * plan.train_fraction).floor() as usize;
        let mut val_rows = ((max_rows as f64) * plan.val_fraction).floor() as usize;
        let mut infer_rows = max_rows.saturating_sub(train_rows + val_rows);
        // Guarantee a non-empty inference split by borrowing one row from
        // val (preferred) or train.
        if infer_rows == 0 {
            if val_rows > 1 {
                val_rows -= 1;
            } else if train_rows > 1 {
                train_rows -= 1;
            }
            infer_rows = max_rows.saturating_sub(train_rows + val_rows).max(1);
        }

        // Contiguous, deterministic split — no shuffling is performed here.
        let (train_path, val_path, infer_path) = self.split_paths();
        let train_df = df.slice(0, train_rows);
        let val_df = df.slice(train_rows as i64, val_rows);
        let infer_df = df.slice((train_rows + val_rows) as i64, infer_rows);

        if let Err(err) = write_parquet(&train_df, &train_path)
            .and_then(|_| write_parquet(&val_df, &val_path))
            .and_then(|_| write_parquet(&infer_df, &infer_path))
        {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        let split = DatasetSplit {
            source_path: dataset_path.to_string_lossy().to_string(),
            train_path: train_path.to_string_lossy().to_string(),
            val_path: val_path.to_string_lossy().to_string(),
            infer_path: infer_path.to_string_lossy().to_string(),
            total_rows,
            max_rows,
            train_rows,
            val_rows,
            infer_rows,
            iteration: plan.iteration,
        };

        let content = serde_json::to_string(&split).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Signals,
            format!("dataset-split-{}", plan.iteration),
            content,
        ))
    }
}
571
/// Trains the baseline model (target-column mean) and persists it as JSON
/// inside `model_dir`.
pub struct ModelTrainingAgent {
    model_dir: PathBuf,
}

impl ModelTrainingAgent {
    /// Create an agent that writes model artifacts under `model_dir`.
    pub fn new(model_dir: PathBuf) -> Self {
        Self { model_dir }
    }

    /// Artifact location of the serialized baseline model.
    fn model_path(&self) -> PathBuf {
        let mut path = self.model_dir.clone();
        path.push("baseline_mean.json");
        path
    }
}
585
// "Trains" the baseline: computes the mean of the target column over the
// training split, persists it as JSON, and publishes `ModelMetadata`.
impl Agent for ModelTrainingAgent {
    fn name(&self) -> &str {
        "ModelTrainingAgent (Baseline)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    // One trained model per split iteration.
    fn accepts(&self, ctx: &Context) -> bool {
        if !ctx.has(ContextKey::Signals) {
            return false;
        }
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(_) => return false,
        };
        !has_model_for_iteration(ctx, split.iteration)
    }

    fn execute(&self, ctx: &Context) -> AgentEffect {
        // All failure paths emit a Diagnostic fact and bail out early.
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ))
            }
        };

        if let Err(err) = create_dir_all(&self.model_dir) {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "model-training-error",
                err.to_string(),
            ));
        }

        let train_df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ))
            }
        };

        let (target_name, target) = match select_target_column(&train_df) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ))
            }
        };

        // The entire "model" is the mean of the target column.
        let mean = match mean_of_series(&target) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ))
            }
        };

        let model = BaselineModel {
            target_column: target_name.clone(),
            mean,
        };

        // Persist the model artifact before announcing it.
        let model_path = self.model_path();
        if let Err(err) = write_json(&model_path, &model) {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "model-training-error",
                err.to_string(),
            ));
        }

        let meta = ModelMetadata {
            model_path: model_path.to_string_lossy().to_string(),
            target_column: target_name,
            train_rows: split.train_rows,
            baseline_mean: mean,
            iteration: split.iteration,
        };

        let content = serde_json::to_string(&meta).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Strategies,
            format!("trained-model-{}", split.iteration),
            content,
        ))
    }
}
689
/// Stateless agent that evaluates the baseline model on the validation split.
///
/// Derives `Default` alongside `new()` (clippy `new_without_default`) and
/// `Debug` so the public type can be logged.
#[derive(Debug, Default)]
pub struct ModelEvaluationAgent;

impl ModelEvaluationAgent {
    /// Construct the (stateless) agent.
    pub fn new() -> Self {
        Self
    }
}
697
/// Stateless agent that records trained models in the registry facts.
///
/// Derives `Default` alongside `new()` (clippy `new_without_default`) and
/// `Debug` so the public type can be logged.
#[derive(Debug, Default)]
pub struct ModelRegistryAgent;

impl ModelRegistryAgent {
    /// Construct the (stateless) agent.
    pub fn new() -> Self {
        Self
    }
}
705
// Records the latest trained model plus its evaluation metrics as a
// `ModelRegistryRecord` fact under Strategies.
impl Agent for ModelRegistryAgent {
    fn name(&self) -> &str {
        "ModelRegistryAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Strategies, ContextKey::Evaluations]
    }

    // One registry record per trained-model iteration.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Strategies)
            && ctx.has(ContextKey::Evaluations)
            && match read_latest_model_meta_from_ctx(ctx) {
                Ok(meta) => !has_registry_record_for_iteration(ctx, meta.iteration),
                Err(_) => false,
            }
    }

    fn execute(&self, ctx: &Context) -> AgentEffect {
        let meta = match read_latest_model_meta_from_ctx(ctx) {
            Ok(meta) => meta,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-registry-error",
                    err.to_string(),
                ))
            }
        };

        // Metrics are best-effort: a missing evaluation report yields an
        // empty metrics map rather than an error.
        let report = latest_evaluation_report(ctx, meta.iteration);
        let mut metrics = HashMap::new();
        if let Some(report) = report {
            metrics.insert(report.metric, report.value);
            metrics.insert("success_ratio".to_string(), report.success_ratio);
        }

        let record = ModelRegistryRecord {
            kind: "model_registry".to_string(),
            iteration: meta.iteration,
            model_path: meta.model_path,
            metrics,
            provenance: "training_flow".to_string(),
        };

        let content = serde_json::to_string(&record).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Strategies,
            format!("model-registry-{}", record.iteration),
            content,
        ))
    }
}
759
/// Stateless agent that turns evaluation reports into monitoring facts.
///
/// Derives `Default` alongside `new()` (clippy `new_without_default`) and
/// `Debug` so the public type can be logged.
#[derive(Debug, Default)]
pub struct MonitoringAgent;

impl MonitoringAgent {
    /// Construct the (stateless) agent.
    pub fn new() -> Self {
        Self
    }
}
767
// Converts the latest evaluation report into a `MonitoringReport` with a
// coarse healthy / needs_attention status.
impl Agent for MonitoringAgent {
    fn name(&self) -> &str {
        "MonitoringAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Evaluations]
    }

    // `latest_evaluation_report(ctx, 0)` — presumably 0 means "any/latest
    // iteration"; TODO confirm against the helper's definition.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Evaluations)
            && match latest_evaluation_report(ctx, 0) {
                Some(report) => !has_monitoring_report_for_iteration(ctx, report.iteration),
                None => false,
            }
    }

    fn execute(&self, ctx: &Context) -> AgentEffect {
        let report = match latest_evaluation_report(ctx, 0) {
            Some(report) => report,
            None => return AgentEffect::empty(),
        };

        // NOTE(review): the 0.75 health cutoff is hard-coded here and does
        // not read the plan's `quality_threshold`, unlike DeploymentAgent.
        let status = if report.success_ratio >= 0.75 {
            "healthy"
        } else {
            "needs_attention"
        };

        let monitoring = MonitoringReport {
            kind: "monitoring".to_string(),
            iteration: report.iteration,
            metric: report.metric,
            value: report.value,
            baseline: report.mean_abs_target,
            status: status.to_string(),
        };

        let content = serde_json::to_string(&monitoring).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Evaluations,
            format!("monitoring-{}", report.iteration),
            content,
        ))
    }
}
814
/// Stateless agent that issues deploy/hold decisions from evaluations.
///
/// Derives `Default` alongside `new()` (clippy `new_without_default`) and
/// `Debug` so the public type can be logged.
#[derive(Debug, Default)]
pub struct DeploymentAgent;

impl DeploymentAgent {
    /// Construct the (stateless) agent.
    pub fn new() -> Self {
        Self
    }
}
822
823impl Agent for DeploymentAgent {
824    fn name(&self) -> &str {
825        "DeploymentAgent"
826    }
827
828    fn dependencies(&self) -> &[ContextKey] {
829        &[ContextKey::Evaluations, ContextKey::Strategies]
830    }
831
832    fn accepts(&self, ctx: &Context) -> bool {
833        ctx.has(ContextKey::Evaluations)
834            && ctx.has(ContextKey::Strategies)
835            && match latest_evaluation_report(ctx, 0) {
836                Some(report) => !has_deployment_decision_for_iteration(ctx, report.iteration),
837                None => false,
838            }
839    }
840
841    fn execute(&self, ctx: &Context) -> AgentEffect {
842        let report = match latest_evaluation_report(ctx, 0) {
843            Some(report) => report,
844            None => return AgentEffect::empty(),
845        };
846
847        let quality_threshold = read_latest_plan_from_ctx(ctx)
848            .map(|plan| plan.quality_threshold)
849            .unwrap_or(0.75);
850
851        let (action, retrain, reason) = if report.success_ratio >= quality_threshold {
852            ("deploy", false, "meets quality threshold")
853        } else {
854            ("hold", true, "below quality threshold")
855        };
856
857        let decision = DeploymentDecision {
858            kind: "deployment_decision".to_string(),
859            iteration: report.iteration,
860            action: action.to_string(),
861            reason: reason.to_string(),
862            retrain,
863        };
864
865        let content = serde_json::to_string(&decision).unwrap_or_default();
866        AgentEffect::with_fact(Fact::new(
867            ContextKey::Strategies,
868            format!("deployment-{}", report.iteration),
869            content,
870        ))
871    }
872}
873
// Scores the baseline model on the validation split: MAE of the constant
// mean prediction, normalized into a [0, 1] `success_ratio`.
impl Agent for ModelEvaluationAgent {
    fn name(&self) -> &str {
        "ModelEvaluationAgent (MAE)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals, ContextKey::Strategies]
    }

    // One evaluation per split iteration, once a model has been trained.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && ctx.has(ContextKey::Strategies)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_evaluation_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    fn execute(&self, ctx: &Context) -> AgentEffect {
        // Failure paths emit a Diagnostic fact and bail out early.
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        let model = match read_model_from_ctx(ctx) {
            Ok(model) => model,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        let val_df = match load_dataframe(Path::new(&split.val_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        let target = match get_numeric_series(&val_df, &model.target_column) {
            Ok(series) => series,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        // MAE of predicting the constant training mean for every row.
        let mae = match mean_abs_error(&target, model.mean) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        let mean_abs = match mean_abs_value(&target) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        // Normalize: 1 - mae / mean(|target|), clamped to [0, 1]; a target
        // with zero mean magnitude scores 0.
        let success_ratio = if mean_abs > 0.0 {
            (1.0 - (mae / mean_abs)).clamp(0.0, 1.0)
        } else {
            0.0
        };

        let report = EvaluationReport {
            model_path: read_model_path_from_ctx(ctx).unwrap_or_default(),
            metric: "mae".to_string(),
            value: mae,
            mean_abs_target: mean_abs,
            success_ratio,
            val_rows: split.val_rows,
            iteration: split.iteration,
        };

        let content = serde_json::to_string(&report).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Evaluations,
            format!("model-eval-{}", split.iteration),
            content,
        ))
    }
}
983
/// Runs baseline inference on a capped sample of the inference split.
///
/// Adds `Debug`/`Clone` derives so the public type can be logged and
/// duplicated; the constructor interface is unchanged.
#[derive(Debug, Clone)]
pub struct SampleInferenceAgent {
    /// Upper bound on rows sampled for the inference output.
    pub max_rows: usize,
}

impl SampleInferenceAgent {
    /// Create an agent that samples at most `max_rows` rows.
    pub fn new(max_rows: usize) -> Self {
        Self { max_rows }
    }
}
993
994impl Agent for SampleInferenceAgent {
995    fn name(&self) -> &str {
996        "SampleInferenceAgent (Baseline)"
997    }
998
999    fn dependencies(&self) -> &[ContextKey] {
1000        &[ContextKey::Signals, ContextKey::Strategies]
1001    }
1002
1003    fn accepts(&self, ctx: &Context) -> bool {
1004        ctx.has(ContextKey::Signals)
1005            && ctx.has(ContextKey::Strategies)
1006            && match read_latest_split_from_ctx(ctx) {
1007                Ok(split) => !has_inference_for_iteration(ctx, split.iteration),
1008                Err(_) => false,
1009            }
1010    }
1011
1012    fn execute(&self, ctx: &Context) -> AgentEffect {
1013        let split = match read_latest_split_from_ctx(ctx) {
1014            Ok(split) => split,
1015            Err(err) => {
1016                return AgentEffect::with_fact(Fact::new(
1017                    ContextKey::Diagnostic,
1018                    "model-infer-error",
1019                    err.to_string(),
1020                ))
1021            }
1022        };
1023
1024        let model = match read_model_from_ctx(ctx) {
1025            Ok(model) => model,
1026            Err(err) => {
1027                return AgentEffect::with_fact(Fact::new(
1028                    ContextKey::Diagnostic,
1029                    "model-infer-error",
1030                    err.to_string(),
1031                ))
1032            }
1033        };
1034
1035        let infer_df = match load_dataframe(Path::new(&split.infer_path)) {
1036            Ok(df) => df,
1037            Err(err) => {
1038                return AgentEffect::with_fact(Fact::new(
1039                    ContextKey::Diagnostic,
1040                    "model-infer-error",
1041                    err.to_string(),
1042                ))
1043            }
1044        };
1045
1046        let target = match get_numeric_series(&infer_df, &model.target_column) {
1047            Ok(series) => series,
1048            Err(err) => {
1049                return AgentEffect::with_fact(Fact::new(
1050                    ContextKey::Diagnostic,
1051                    "model-infer-error",
1052                    err.to_string(),
1053                ))
1054            }
1055        };
1056
1057        let sample_rows = self.max_rows.min(infer_df.height().max(1));
1058        let actuals = match target.f64() {
1059            Ok(series) => series
1060                .into_no_null_iter()
1061                .take(sample_rows)
1062                .collect::<Vec<_>>(),
1063            Err(err) => {
1064                return AgentEffect::with_fact(Fact::new(
1065                    ContextKey::Diagnostic,
1066                    "model-infer-error",
1067                    err.to_string(),
1068                ))
1069            }
1070        };
1071
1072        let predictions = vec![model.mean; actuals.len()];
1073        let sample = InferenceSample {
1074            model_path: read_model_path_from_ctx(ctx).unwrap_or_default(),
1075            target_column: model.target_column,
1076            rows: actuals.len(),
1077            predictions,
1078            actuals,
1079            iteration: split.iteration,
1080        };
1081
1082        let content = serde_json::to_string(&sample).unwrap_or_default();
1083        AgentEffect::with_fact(Fact::new(
1084            ContextKey::Hypotheses,
1085            format!("inference-sample-{}", split.iteration),
1086            content,
1087        ))
1088    }
1089}
1090
1091fn download_dataset_if_missing(path: &Path) -> Result<()> {
1092    if path.exists() {
1093        return Ok(());
1094    }
1095
1096    let response = reqwest::blocking::get(DATASET_URL)?;
1097    let content = response.bytes()?;
1098
1099    let mut file = File::create(path)?;
1100    file.write_all(&content)?;
1101
1102    Ok(())
1103}
1104
1105fn load_dataframe(path: &Path) -> Result<DataFrame> {
1106    let extension = path
1107        .extension()
1108        .and_then(|ext| ext.to_str())
1109        .unwrap_or("")
1110        .to_ascii_lowercase();
1111
1112    let path_str = path
1113        .to_str()
1114        .ok_or_else(|| anyhow!("path is not valid utf-8: {:?}", path))?;
1115
1116    match extension.as_str() {
1117        "parquet" => {
1118            let pl_path = PlPath::from_str(path_str);
1119            Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
1120        }
1121        "csv" => Ok(
1122            CsvReadOptions::default()
1123                .with_has_header(true)
1124                .try_into_reader_with_file_path(Some(path.to_path_buf()))?
1125                .finish()?,
1126        ),
1127        _ => Err(anyhow!(
1128            "unsupported data format for path {:?} (expected .csv or .parquet)",
1129            path
1130        )),
1131    }
1132}
1133
1134fn write_parquet(df: &DataFrame, path: &Path) -> Result<()> {
1135    let mut file = File::create(path)?;
1136    let mut owned = df.clone();
1137    ParquetWriter::new(&mut file).finish(&mut owned)?;
1138    Ok(())
1139}
1140
1141fn write_json<T: Serialize>(path: &Path, value: &T) -> Result<()> {
1142    let content = serde_json::to_string_pretty(value)?;
1143    let mut file = File::create(path)?;
1144    file.write_all(content.as_bytes())?;
1145    Ok(())
1146}
1147
1148fn read_latest_split_from_ctx(ctx: &Context) -> Result<DatasetSplit> {
1149    let facts = ctx.get(ContextKey::Signals);
1150    let mut latest: Option<DatasetSplit> = None;
1151    for fact in facts {
1152        if let Ok(split) = serde_json::from_str::<DatasetSplit>(&fact.content) {
1153            let should_replace = match &latest {
1154                Some(current) => split.iteration > current.iteration,
1155                None => true,
1156            };
1157            if should_replace {
1158                latest = Some(split);
1159            }
1160        }
1161    }
1162    latest.ok_or_else(|| anyhow!("missing dataset split"))
1163}
1164
1165fn read_model_path_from_ctx(ctx: &Context) -> Result<String> {
1166    let meta = read_latest_model_meta_from_ctx(ctx)?;
1167    Ok(meta.model_path)
1168}
1169
1170fn read_model_from_ctx(ctx: &Context) -> Result<BaselineModel> {
1171    let model_path = read_model_path_from_ctx(ctx)?;
1172    let content = std::fs::read_to_string(model_path)?;
1173    let model = serde_json::from_str(&content)?;
1174    Ok(model)
1175}
1176
1177fn read_latest_model_meta_from_ctx(ctx: &Context) -> Result<ModelMetadata> {
1178    let facts = ctx.get(ContextKey::Strategies);
1179    let mut latest: Option<ModelMetadata> = None;
1180    for fact in facts {
1181        if let Ok(meta) = serde_json::from_str::<ModelMetadata>(&fact.content) {
1182            let should_replace = match &latest {
1183                Some(current) => meta.iteration > current.iteration,
1184                None => true,
1185            };
1186            if should_replace {
1187                latest = Some(meta);
1188            }
1189        }
1190    }
1191    latest.ok_or_else(|| anyhow!("missing model metadata"))
1192}
1193
1194fn read_latest_plan_from_ctx(ctx: &Context) -> Option<TrainingPlan> {
1195    let facts = ctx.get(ContextKey::Constraints);
1196    let mut latest: Option<TrainingPlan> = None;
1197    for fact in facts {
1198        if let Ok(plan) = serde_json::from_str::<TrainingPlan>(&fact.content) {
1199            let should_replace = match &latest {
1200                Some(current) => plan.iteration > current.iteration,
1201                None => true,
1202            };
1203            if should_replace {
1204                latest = Some(plan);
1205            }
1206        }
1207    }
1208    latest
1209}
1210
1211fn has_split_for_iteration(ctx: &Context, iteration: usize) -> bool {
1212    ctx.get(ContextKey::Signals).iter().any(|fact| {
1213        serde_json::from_str::<DatasetSplit>(&fact.content)
1214            .map(|split| split.iteration == iteration)
1215            .unwrap_or(false)
1216    })
1217}
1218
1219fn has_model_for_iteration(ctx: &Context, iteration: usize) -> bool {
1220    ctx.get(ContextKey::Strategies).iter().any(|fact| {
1221        serde_json::from_str::<ModelMetadata>(&fact.content)
1222            .map(|meta| meta.iteration == iteration)
1223            .unwrap_or(false)
1224    })
1225}
1226
1227fn has_evaluation_for_iteration(ctx: &Context, iteration: usize) -> bool {
1228    ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1229        serde_json::from_str::<EvaluationReport>(&fact.content)
1230            .map(|report| report.iteration == iteration)
1231            .unwrap_or(false)
1232    })
1233}
1234
1235fn has_inference_for_iteration(ctx: &Context, iteration: usize) -> bool {
1236    ctx.get(ContextKey::Hypotheses).iter().any(|fact| {
1237        serde_json::from_str::<InferenceSample>(&fact.content)
1238            .map(|sample| sample.iteration == iteration)
1239            .unwrap_or(false)
1240    })
1241}
1242
1243fn has_data_quality_for_iteration(ctx: &Context, iteration: usize) -> bool {
1244    ctx.get(ContextKey::Signals).iter().any(|fact| {
1245        serde_json::from_str::<DataQualityReport>(&fact.content)
1246            .map(|report| report.iteration == iteration)
1247            .unwrap_or(false)
1248    })
1249}
1250
1251fn has_feature_spec_for_iteration(ctx: &Context, iteration: usize) -> bool {
1252    ctx.get(ContextKey::Constraints).iter().any(|fact| {
1253        serde_json::from_str::<FeatureSpec>(&fact.content)
1254            .map(|spec| spec.iteration == iteration)
1255            .unwrap_or(false)
1256    })
1257}
1258
1259fn has_hyperparam_result_for_iteration(ctx: &Context, iteration: usize) -> bool {
1260    ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1261        serde_json::from_str::<HyperparameterSearchResult>(&fact.content)
1262            .map(|result| result.iteration == iteration)
1263            .unwrap_or(false)
1264    })
1265}
1266
1267fn has_registry_record_for_iteration(ctx: &Context, iteration: usize) -> bool {
1268    ctx.get(ContextKey::Strategies).iter().any(|fact| {
1269        serde_json::from_str::<ModelRegistryRecord>(&fact.content)
1270            .map(|record| record.iteration == iteration)
1271            .unwrap_or(false)
1272    })
1273}
1274
1275fn has_monitoring_report_for_iteration(ctx: &Context, iteration: usize) -> bool {
1276    ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1277        serde_json::from_str::<MonitoringReport>(&fact.content)
1278            .map(|report| report.iteration == iteration)
1279            .unwrap_or(false)
1280    })
1281}
1282
1283fn has_deployment_decision_for_iteration(ctx: &Context, iteration: usize) -> bool {
1284    ctx.get(ContextKey::Strategies).iter().any(|fact| {
1285        serde_json::from_str::<DeploymentDecision>(&fact.content)
1286            .map(|decision| decision.iteration == iteration)
1287            .unwrap_or(false)
1288    })
1289}
1290
1291fn latest_evaluation_report(ctx: &Context, iteration: usize) -> Option<EvaluationReport> {
1292    let mut latest: Option<EvaluationReport> = None;
1293    for fact in ctx.get(ContextKey::Evaluations) {
1294        if let Ok(report) = serde_json::from_str::<EvaluationReport>(&fact.content) {
1295            if iteration > 0 {
1296                if report.iteration == iteration {
1297                    return Some(report);
1298                }
1299            } else if latest
1300                .as_ref()
1301                .map(|current| report.iteration > current.iteration)
1302                .unwrap_or(true)
1303            {
1304                latest = Some(report);
1305            }
1306        }
1307    }
1308    if iteration > 0 {
1309        None
1310    } else {
1311        latest
1312    }
1313}
1314
1315fn latest_data_quality_before_iteration(
1316    ctx: &Context,
1317    iteration: usize,
1318) -> Option<DataQualityReport> {
1319    let mut latest: Option<DataQualityReport> = None;
1320    for fact in ctx.get(ContextKey::Signals) {
1321        if let Ok(report) = serde_json::from_str::<DataQualityReport>(&fact.content) {
1322            if report.iteration < iteration
1323                && latest
1324                    .as_ref()
1325                    .map(|current| report.iteration > current.iteration)
1326                    .unwrap_or(true)
1327            {
1328                latest = Some(report);
1329            }
1330        }
1331    }
1332    latest
1333}
1334
1335fn drift_score_from_ctx(
1336    ctx: &Context,
1337    iteration: usize,
1338    numeric_means: &HashMap<String, f64>,
1339) -> Option<f64> {
1340    let previous = latest_data_quality_before_iteration(ctx, iteration)?;
1341    let mut total_delta = 0.0;
1342    let mut count = 0usize;
1343    for (name, mean) in numeric_means {
1344        if let Some(prev_mean) = previous.numeric_means.get(name) {
1345            total_delta += (mean - prev_mean).abs();
1346            count += 1;
1347        }
1348    }
1349    if count == 0 {
1350        None
1351    } else {
1352        Some(total_delta / count as f64)
1353    }
1354}
1355
1356fn compute_numeric_stats(series: &Series) -> Result<(f64, f64, usize)> {
1357    let casted = series.cast(&DataType::Float64)?;
1358    let values: Vec<f64> = casted
1359        .f64()
1360        .context("numeric series not f64")?
1361        .into_no_null_iter()
1362        .collect();
1363    if values.is_empty() {
1364        return Err(anyhow!("no numeric values to compute stats"));
1365    }
1366
1367    let mut total = 0.0;
1368    for value in &values {
1369        total += *value;
1370    }
1371    let mean = total / values.len() as f64;
1372
1373    let mut variance_sum = 0.0;
1374    for value in &values {
1375        let diff = *value - mean;
1376        variance_sum += diff * diff;
1377    }
1378    let std = (variance_sum / values.len() as f64).sqrt();
1379
1380    let outliers = if std > 0.0 {
1381        values
1382            .iter()
1383            .filter(|value| (*value - mean).abs() > 3.0 * std)
1384            .count()
1385    } else {
1386        0
1387    };
1388
1389    Ok((mean, std, outliers))
1390}
1391
1392fn split_feature_columns(df: &DataFrame, target: &str) -> (Vec<String>, Vec<String>) {
1393    let mut numeric = Vec::new();
1394    let mut categorical = Vec::new();
1395    for series in df.get_columns() {
1396        let name = series.name();
1397        if name == target {
1398            continue;
1399        }
1400        if is_numeric_dtype(series.dtype()) {
1401            numeric.push(name.to_string());
1402        } else {
1403            categorical.push(name.to_string());
1404        }
1405    }
1406    (numeric, categorical)
1407}
1408
1409fn select_target_column(df: &DataFrame) -> Result<(String, Series)> {
1410    if let Ok(col) = df.column(TARGET_COLUMN) {
1411        return Ok((
1412            TARGET_COLUMN.to_string(),
1413            col.as_materialized_series().clone(),
1414        ));
1415    }
1416
1417    let mut numeric = df
1418        .get_columns()
1419        .iter()
1420        .filter(|series| is_numeric_dtype(series.dtype()))
1421        .cloned()
1422        .collect::<Vec<_>>();
1423
1424    let fallback = numeric
1425        .pop()
1426        .ok_or_else(|| anyhow!("no numeric columns available for target"))?;
1427    let series = fallback.as_materialized_series().clone();
1428    Ok((series.name().to_string(), series))
1429}
1430
1431fn get_numeric_series(df: &DataFrame, name: &str) -> Result<Series> {
1432    let series = df
1433        .column(name)
1434        .map_err(|_| anyhow!("missing target column {}", name))?
1435        .as_materialized_series();
1436    let casted = series.cast(&DataType::Float64)?;
1437    Ok(casted)
1438}
1439
1440fn mean_of_series(series: &Series) -> Result<f64> {
1441    let casted = series.cast(&DataType::Float64)?;
1442    let values = casted
1443        .f64()
1444        .context("target column not f64")?
1445        .into_no_null_iter();
1446    let mut total = 0.0;
1447    let mut count = 0usize;
1448    for value in values {
1449        total += value;
1450        count += 1;
1451    }
1452    if count == 0 {
1453        return Err(anyhow!("no values to compute mean"));
1454    }
1455    Ok(total / count as f64)
1456}
1457
1458fn mean_abs_error(target: &Series, mean: f64) -> Result<f64> {
1459    let casted = target.cast(&DataType::Float64)?;
1460    let values = casted
1461        .f64()
1462        .context("target column not f64")?
1463        .into_no_null_iter();
1464    let mut total = 0.0;
1465    let mut count = 0usize;
1466    for value in values {
1467        total += (value - mean).abs();
1468        count += 1;
1469    }
1470    if count == 0 {
1471        return Err(anyhow!("no values to evaluate"));
1472    }
1473    Ok(total / count as f64)
1474}
1475
1476fn mean_abs_value(target: &Series) -> Result<f64> {
1477    let casted = target.cast(&DataType::Float64)?;
1478    let values = casted
1479        .f64()
1480        .context("target column not f64")?
1481        .into_no_null_iter();
1482    let mut total = 0.0;
1483    let mut count = 0usize;
1484    for value in values {
1485        total += value.abs();
1486        count += 1;
1487    }
1488    if count == 0 {
1489        return Err(anyhow!("no values to evaluate"));
1490    }
1491    Ok(total / count as f64)
1492}
1493fn is_numeric_dtype(dtype: &DataType) -> bool {
1494    matches!(
1495        dtype,
1496        DataType::Int8
1497            | DataType::Int16
1498            | DataType::Int32
1499            | DataType::Int64
1500            | DataType::UInt8
1501            | DataType::UInt16
1502            | DataType::UInt32
1503            | DataType::UInt64
1504            | DataType::Float32
1505            | DataType::Float64
1506    )
1507}