1use anyhow::{anyhow, Context as _, Result};
5use converge_core::{Agent, AgentEffect, Context, ContextKey, Fact};
6use polars::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs::{create_dir_all, File};
10use std::io::Write;
11use std::path::{Path, PathBuf};
12
/// Parquet export of the California housing dataset hosted on Hugging Face.
const DATASET_URL: &str = "https://huggingface.co/datasets/gvlassis/california_housing/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet";
/// Default regression target column name for the dataset.
/// NOTE(review): not referenced in this chunk; presumably consumed by
/// `select_target_column` elsewhere in the file — confirm.
const TARGET_COLUMN: &str = "median_house_value";
15
/// Plan describing how much data to use and how to split it for one
/// pipeline iteration. Serialized as JSON into `ContextKey::Constraints`
/// facts and recovered via `read_latest_plan_from_ctx`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainingPlan {
    /// Iteration this plan applies to.
    pub iteration: usize,
    /// Upper bound on the number of dataset rows to use this iteration.
    pub max_rows: usize,
    /// Fraction of `max_rows` assigned to the training split.
    pub train_fraction: f64,
    /// Fraction of `max_rows` assigned to the validation split.
    pub val_fraction: f64,
    /// Fraction of `max_rows` assigned to the inference split.
    pub infer_fraction: f64,
    /// Minimum evaluation success ratio required to deploy (see `DeploymentAgent`).
    pub quality_threshold: f64,
}
25
/// On-disk train/validation/inference split produced by `DatasetAgent`.
/// Serialized as JSON into `ContextKey::Signals` facts.
#[derive(Debug, Serialize, Deserialize)]
pub struct DatasetSplit {
    /// Path of the raw downloaded dataset file.
    pub source_path: String,
    /// Path of the training-split parquet file.
    pub train_path: String,
    /// Path of the validation-split parquet file.
    pub val_path: String,
    /// Path of the inference-split parquet file.
    pub infer_path: String,
    /// Row count of the full source dataset before capping.
    pub total_rows: usize,
    /// Row count actually used after applying the plan's cap.
    pub max_rows: usize,
    pub train_rows: usize,
    pub val_rows: usize,
    pub infer_rows: usize,
    /// Iteration of the `TrainingPlan` this split was built for.
    pub iteration: usize,
}
39
/// Trivial baseline "model": predicts the training-set mean of the target
/// column for every row. Serialized to disk by `ModelTrainingAgent`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BaselineModel {
    /// Column the model predicts.
    pub target_column: String,
    /// Training-set mean used as the constant prediction.
    pub mean: f64,
}
45
/// Metadata fact describing a trained model. Serialized as JSON into
/// `ContextKey::Strategies`; the model itself lives at `model_path`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Filesystem path of the serialized `BaselineModel`.
    pub model_path: String,
    pub target_column: String,
    /// Number of rows the model was trained on.
    pub train_rows: usize,
    /// The constant mean prediction of the baseline model.
    pub baseline_mean: f64,
    pub iteration: usize,
}
54
/// Validation-set evaluation result published by `ModelEvaluationAgent`
/// into `ContextKey::Evaluations`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EvaluationReport {
    pub model_path: String,
    /// Metric name; currently always "mae".
    pub metric: String,
    /// Metric value (mean absolute error on the validation split).
    pub value: f64,
    /// Mean absolute value of the target on the validation split (scale reference).
    pub mean_abs_target: f64,
    /// `1 - mae / mean_abs_target`, clamped to [0, 1]; higher is better.
    pub success_ratio: f64,
    pub val_rows: usize,
    pub iteration: usize,
}
65
/// Small prediction sample produced by `SampleInferenceAgent` from the
/// inference split. Serialized as JSON into `ContextKey::Hypotheses`.
#[derive(Debug, Serialize, Deserialize)]
pub struct InferenceSample {
    pub model_path: String,
    pub target_column: String,
    /// Number of sampled rows (`predictions` and `actuals` have this length).
    pub rows: usize,
    /// Model predictions (constant baseline mean, one per sampled row).
    pub predictions: Vec<f64>,
    /// Ground-truth target values for the same rows.
    pub actuals: Vec<f64>,
    pub iteration: usize,
}
75
/// Per-iteration data-quality summary of the training split, published by
/// `DataValidationAgent` into `ContextKey::Signals`.
#[derive(Debug, Serialize, Deserialize)]
pub struct DataQualityReport {
    /// Discriminator tag; always "data_quality" for this type.
    pub kind: String,
    pub iteration: usize,
    /// Path of the file that was checked (the training split).
    pub source_path: String,
    pub rows_checked: usize,
    /// Per-column null ratio in [0, 1].
    pub missingness: HashMap<String, f64>,
    /// Mean per numeric column.
    pub numeric_means: HashMap<String, f64>,
    /// Outlier count per numeric column (criterion defined by
    /// `compute_numeric_stats`, not visible in this chunk).
    pub outlier_counts: HashMap<String, usize>,
    /// Drift vs. earlier iterations, if computable (see `drift_score_from_ctx`,
    /// not visible in this chunk).
    pub drift_score: Option<f64>,
}
87
/// A derived feature combining two existing columns with an operation
/// (e.g. `op = "multiply"`), part of a `FeatureSpec`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FeatureInteraction {
    /// Name of the derived feature (e.g. "colA_x_colB").
    pub name: String,
    /// Left operand column.
    pub left: String,
    /// Right operand column.
    pub right: String,
    /// Combining operation identifier.
    pub op: String,
}
95
/// Feature-engineering specification published by `FeatureEngineeringAgent`
/// into `ContextKey::Constraints`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FeatureSpec {
    /// Discriminator tag; always "feature_spec" for this type.
    pub kind: String,
    pub iteration: usize,
    pub target_column: String,
    /// Numeric feature columns (target excluded).
    pub numeric_features: Vec<String>,
    /// Non-numeric feature columns.
    pub categorical_features: Vec<String>,
    /// Normalization strategy identifier; currently always "standardize".
    pub normalization: String,
    /// Derived interaction features to add.
    pub interactions: Vec<FeatureInteraction>,
}
106
/// Search-space description published by `HyperparameterSearchAgent`
/// into `ContextKey::Constraints`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HyperparameterSearchPlan {
    /// Discriminator tag; always "hyperparam_plan" for this type.
    pub kind: String,
    pub iteration: usize,
    /// Upper bound on the number of trials.
    pub max_trials: usize,
    pub early_stopping: bool,
    /// Candidate values per hyperparameter name.
    pub params: HashMap<String, Vec<f64>>,
}
115
/// Outcome of a hyperparameter search, published into
/// `ContextKey::Evaluations`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HyperparameterSearchResult {
    /// Discriminator tag; always "hyperparam_result" for this type.
    pub kind: String,
    pub iteration: usize,
    /// Chosen value per hyperparameter name.
    pub best_params: HashMap<String, f64>,
    /// Score associated with `best_params` (see `HyperparameterSearchAgent`).
    pub score: f64,
}
123
/// Registry entry tying a model artifact to its metrics and provenance,
/// published by `ModelRegistryAgent` into `ContextKey::Strategies`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelRegistryRecord {
    /// Discriminator tag; always "model_registry" for this type.
    pub kind: String,
    pub iteration: usize,
    pub model_path: String,
    /// Metric name -> value copied from the latest evaluation, if any.
    pub metrics: HashMap<String, f64>,
    /// Origin of the model; currently always "training_flow".
    pub provenance: String,
}
132
/// Health assessment derived from the latest evaluation, published by
/// `MonitoringAgent` into `ContextKey::Evaluations`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MonitoringReport {
    /// Discriminator tag; always "monitoring" for this type.
    pub kind: String,
    pub iteration: usize,
    /// Metric being monitored (copied from the evaluation report).
    pub metric: String,
    pub value: f64,
    /// Scale reference for `value` (the evaluation's `mean_abs_target`).
    pub baseline: f64,
    /// Either "healthy" or "needs_attention".
    pub status: String,
}
142
/// Go/no-go decision produced by `DeploymentAgent`, published into
/// `ContextKey::Strategies`.
#[derive(Debug, Serialize, Deserialize)]
pub struct DeploymentDecision {
    /// Discriminator tag; always "deployment_decision" for this type.
    pub kind: String,
    pub iteration: usize,
    /// Either "deploy" or "hold".
    pub action: String,
    /// Human-readable justification for `action`.
    pub reason: String,
    /// True when another training iteration should be triggered.
    pub retrain: bool,
}
151
/// Downloads the California-housing parquet file and materializes the
/// train/validation/inference splits under `data_dir`.
pub struct DatasetAgent {
    // Directory holding both the raw download and the split files.
    data_dir: PathBuf,
}

impl DatasetAgent {
    /// Creates an agent that keeps all of its files under `data_dir`.
    pub fn new(data_dir: PathBuf) -> Self {
        DatasetAgent { data_dir }
    }

    /// Location of the raw downloaded dataset file.
    fn dataset_path(&self) -> PathBuf {
        self.data_dir.join("california_housing_train.parquet")
    }

    /// Locations of the train, validation and inference splits, in that order.
    fn split_paths(&self) -> (PathBuf, PathBuf, PathBuf) {
        let train = self.data_dir.join("train.parquet");
        let val = self.data_dir.join("val.parquet");
        let infer = self.data_dir.join("infer.parquet");
        (train, val, infer)
    }
}
173
/// Validates the latest training split and publishes a `DataQualityReport`.
pub struct DataValidationAgent;

impl DataValidationAgent {
    /// Creates a new `DataValidationAgent`.
    pub fn new() -> Self {
        Self
    }
}

// `Default` mirrors `new` so the agent can be constructed generically
// (also satisfies clippy's `new_without_default` lint).
impl Default for DataValidationAgent {
    fn default() -> Self {
        Self::new()
    }
}
181
impl Agent for DataValidationAgent {
    fn name(&self) -> &str {
        "DataValidationAgent"
    }

    /// Depends on the dataset-split facts, which live under `Signals`.
    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    /// Runs at most once per iteration: only when a split exists and no
    /// `DataQualityReport` has been published for that iteration yet.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_data_quality_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Profiles the training split (missingness, numeric means, outliers,
    /// optional drift) and publishes the report as a `Signals` fact.
    /// Any failure is reported as a `Diagnostic` fact instead of panicking.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "data-validation-error",
                    err.to_string(),
                ))
            }
        };

        let df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "data-validation-error",
                    err.to_string(),
                ))
            }
        };

        let rows = df.height();
        let mut missingness = HashMap::new();
        let mut numeric_means = HashMap::new();
        let mut outlier_counts = HashMap::new();

        for series in df.get_columns() {
            let name = series.name().to_string();
            // Guard against division by zero on an empty frame.
            let null_ratio = if rows > 0 {
                series.null_count() as f64 / rows as f64
            } else {
                0.0
            };
            missingness.insert(name.clone(), null_ratio);

            // Numeric columns additionally get mean and outlier statistics;
            // columns whose stats fail to compute are silently skipped.
            if is_numeric_dtype(series.dtype()) {
                if let Ok((mean, _std, outliers)) =
                    compute_numeric_stats(series.as_materialized_series())
                {
                    numeric_means.insert(name.clone(), mean);
                    outlier_counts.insert(name, outliers);
                }
            }
        }

        // Drift relative to earlier iterations; semantics live in
        // `drift_score_from_ctx` (defined elsewhere in this file).
        let drift_score = drift_score_from_ctx(ctx, split.iteration, &numeric_means);

        let report = DataQualityReport {
            kind: "data_quality".to_string(),
            iteration: split.iteration,
            source_path: split.train_path.clone(),
            rows_checked: rows,
            missingness,
            numeric_means,
            outlier_counts,
            drift_score,
        };

        let content = serde_json::to_string(&report).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Signals,
            format!("data-quality-{}", split.iteration),
            content,
        ))
    }
}
267
/// Derives a `FeatureSpec` (feature columns + interactions) from the
/// latest training split.
pub struct FeatureEngineeringAgent;

impl FeatureEngineeringAgent {
    /// Creates a new `FeatureEngineeringAgent`.
    pub fn new() -> Self {
        Self
    }
}

// `Default` mirrors `new` so the agent can be constructed generically
// (also satisfies clippy's `new_without_default` lint).
impl Default for FeatureEngineeringAgent {
    fn default() -> Self {
        Self::new()
    }
}
275
impl Agent for FeatureEngineeringAgent {
    fn name(&self) -> &str {
        "FeatureEngineeringAgent"
    }

    /// Depends on the dataset-split facts, which live under `Signals`.
    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    /// Runs at most once per iteration: only when a split exists and no
    /// `FeatureSpec` has been published for that iteration yet.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_feature_spec_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Inspects the training split, partitions columns into numeric vs.
    /// categorical features, adds one pairwise interaction, and publishes
    /// the resulting `FeatureSpec` into `Constraints`.
    /// Any failure is reported as a `Diagnostic` fact instead of panicking.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ))
            }
        };

        let df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ))
            }
        };

        // `select_target_column` is defined elsewhere in this file; only the
        // column name is needed here.
        let (target_column, _) = match select_target_column(&df) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "feature-engineering-error",
                    err.to_string(),
                ))
            }
        };

        let (numeric_features, categorical_features) =
            split_feature_columns(&df, &target_column);

        // Single fixed interaction: the product of the first two numeric
        // features, when at least two exist.
        let mut interactions = Vec::new();
        if numeric_features.len() >= 2 {
            interactions.push(FeatureInteraction {
                name: format!("{}_x_{}", numeric_features[0], numeric_features[1]),
                left: numeric_features[0].clone(),
                right: numeric_features[1].clone(),
                op: "multiply".to_string(),
            });
        }

        let spec = FeatureSpec {
            kind: "feature_spec".to_string(),
            iteration: split.iteration,
            target_column,
            numeric_features,
            categorical_features,
            normalization: "standardize".to_string(),
            interactions,
        };

        let content = serde_json::to_string(&spec).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Constraints,
            format!("feature-spec-{}", split.iteration),
            content,
        ))
    }
}
358
/// Plans a hyperparameter search for the latest dataset split and publishes
/// a (currently fixed, see `execute`) best-parameter result.
pub struct HyperparameterSearchAgent {
    // Upper bound on the number of trials advertised in the search plan.
    max_trials: usize,
}

impl HyperparameterSearchAgent {
    /// Creates an agent that will plan at most `max_trials` trials.
    pub fn new(max_trials: usize) -> Self {
        HyperparameterSearchAgent { max_trials }
    }
}
368
impl Agent for HyperparameterSearchAgent {
    fn name(&self) -> &str {
        "HyperparameterSearchAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Constraints, ContextKey::Signals]
    }

    /// Runs at most once per iteration: only when a split exists and no
    /// `HyperparameterSearchResult` has been published for that iteration.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_hyperparam_result_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Publishes a search plan (`Constraints`) and a result (`Evaluations`).
    ///
    /// NOTE(review): this is a stub search — `best_params` is hard-coded and
    /// no trials are actually executed; the score is a synthetic value
    /// derived from the plan, not a measured metric.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "hyperparam-search-error",
                    err.to_string(),
                ))
            }
        };

        // Fall back to a default plan mirroring the DatasetAgent defaults
        // when no TrainingPlan fact exists yet.
        let training_plan = read_latest_plan_from_ctx(ctx).unwrap_or_else(|| TrainingPlan {
            iteration: split.iteration,
            max_rows: split.max_rows,
            train_fraction: 0.8,
            val_fraction: 0.15,
            infer_fraction: 0.05,
            quality_threshold: 0.75,
        });

        let mut params = HashMap::new();
        params.insert("learning_rate".to_string(), vec![0.001, 0.01, 0.1]);
        params.insert("hidden_size".to_string(), vec![8.0, 16.0, 32.0]);

        let plan = HyperparameterSearchPlan {
            kind: "hyperparam_plan".to_string(),
            iteration: split.iteration,
            max_trials: self.max_trials,
            early_stopping: true,
            params,
        };

        // Fixed "best" parameters — the middle candidate of each axis.
        let mut best_params = HashMap::new();
        best_params.insert("learning_rate".to_string(), 0.01);
        best_params.insert("hidden_size".to_string(), 16.0);
        // Synthetic score; `.max(1)` guards the division for iteration 0.
        let score = (1.0 - training_plan.quality_threshold)
            * plan.max_trials as f64
            / plan.iteration.max(1) as f64;
        let result = HyperparameterSearchResult {
            kind: "hyperparam_result".to_string(),
            iteration: split.iteration,
            best_params,
            score,
        };

        let plan_content = serde_json::to_string(&plan).unwrap_or_default();
        let result_content = serde_json::to_string(&result).unwrap_or_default();

        // Two facts in one effect: the plan and its (stub) result.
        let mut effect = AgentEffect::empty();
        effect.facts.push(Fact::new(
            ContextKey::Constraints,
            format!("hyperparam-plan-{}", split.iteration),
            plan_content,
        ));
        effect.facts.push(Fact::new(
            ContextKey::Evaluations,
            format!("hyperparam-result-{}", split.iteration),
            result_content,
        ));
        effect
    }
}
449
impl Agent for DatasetAgent {
    fn name(&self) -> &str {
        "DatasetAgent (HuggingFace)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Seeds]
    }

    /// Runs when (a) a `TrainingPlan` exists without a matching split for
    /// its iteration, or (b) for bootstrap: no plan and no `Signals` facts
    /// at all (first iteration).
    fn accepts(&self, ctx: &Context) -> bool {
        if !ctx.has(ContextKey::Seeds) {
            return false;
        }

        let plan = read_latest_plan_from_ctx(ctx);
        if let Some(plan) = plan {
            return !has_split_for_iteration(ctx, plan.iteration);
        }

        !ctx.has(ContextKey::Signals)
    }

    /// Downloads the dataset (if needed), caps it per the plan, splits it
    /// into contiguous train/val/infer slices, writes them as parquet, and
    /// publishes a `DatasetSplit` fact. Failures become `Diagnostic` facts.
    ///
    /// NOTE(review): the splits are contiguous row ranges — there is no
    /// shuffling, so any row-order structure in the source leaks into the
    /// split composition.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        if let Err(err) = create_dir_all(&self.data_dir) {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        let dataset_path = self.dataset_path();
        if let Err(err) = download_dataset_if_missing(&dataset_path) {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        let df = match load_dataframe(&dataset_path) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "dataset-agent-error",
                    err.to_string(),
                ))
            }
        };

        // A minimum of 10 rows keeps every split non-trivial below.
        let total_rows = df.height();
        if total_rows < 10 {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "dataset-agent-error",
                "dataset too small for splitting",
            ));
        }

        // First iteration (no plan yet): use the whole dataset with
        // default 80/15/5 fractions.
        let plan = read_latest_plan_from_ctx(ctx).unwrap_or_else(|| TrainingPlan {
            iteration: 1,
            max_rows: total_rows,
            train_fraction: 0.8,
            val_fraction: 0.15,
            infer_fraction: 0.05,
            quality_threshold: 0.75,
        });

        // Cap to the plan's budget, but never below the 10-row floor.
        let max_rows = plan.max_rows.min(total_rows).max(10);
        let df = df.slice(0, max_rows);

        let mut train_rows = ((max_rows as f64) * plan.train_fraction).floor() as usize;
        let mut val_rows = ((max_rows as f64) * plan.val_fraction).floor() as usize;
        let mut infer_rows = max_rows.saturating_sub(train_rows + val_rows);
        // Flooring can leave the inference split empty; borrow one row from
        // val (preferred) or train so it always has at least one row.
        if infer_rows == 0 {
            if val_rows > 1 {
                val_rows -= 1;
            } else if train_rows > 1 {
                train_rows -= 1;
            }
            infer_rows = max_rows.saturating_sub(train_rows + val_rows).max(1);
        }

        // Contiguous slices: [0, train), [train, train+val), [train+val, ...).
        let (train_path, val_path, infer_path) = self.split_paths();
        let train_df = df.slice(0, train_rows);
        let val_df = df.slice(train_rows as i64, val_rows);
        let infer_df = df.slice((train_rows + val_rows) as i64, infer_rows);

        if let Err(err) = write_parquet(&train_df, &train_path)
            .and_then(|_| write_parquet(&val_df, &val_path))
            .and_then(|_| write_parquet(&infer_df, &infer_path))
        {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "dataset-agent-error",
                err.to_string(),
            ));
        }

        let split = DatasetSplit {
            source_path: dataset_path.to_string_lossy().to_string(),
            train_path: train_path.to_string_lossy().to_string(),
            val_path: val_path.to_string_lossy().to_string(),
            infer_path: infer_path.to_string_lossy().to_string(),
            total_rows,
            max_rows,
            train_rows,
            val_rows,
            infer_rows,
            iteration: plan.iteration,
        };

        let content = serde_json::to_string(&split).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Signals,
            format!("dataset-split-{}", plan.iteration),
            content,
        ))
    }
}
571
/// Trains the baseline mean-predictor model and records its metadata.
pub struct ModelTrainingAgent {
    // Directory where the serialized model file is written.
    model_dir: PathBuf,
}

impl ModelTrainingAgent {
    /// Creates an agent that writes model files under `model_dir`.
    pub fn new(model_dir: PathBuf) -> Self {
        ModelTrainingAgent { model_dir }
    }

    /// Location of the serialized baseline model.
    ///
    /// NOTE(review): the file name is fixed, so every iteration overwrites
    /// the previous model on disk — only metadata facts retain history.
    fn model_path(&self) -> PathBuf {
        self.model_dir.join("baseline_mean.json")
    }
}
585
impl Agent for ModelTrainingAgent {
    fn name(&self) -> &str {
        "ModelTrainingAgent (Baseline)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals]
    }

    /// Runs at most once per iteration: only when a split exists and no
    /// `ModelMetadata` has been published for that iteration yet.
    fn accepts(&self, ctx: &Context) -> bool {
        if !ctx.has(ContextKey::Signals) {
            return false;
        }
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(_) => return false,
        };
        !has_model_for_iteration(ctx, split.iteration)
    }

    /// "Trains" the baseline by computing the target mean over the training
    /// split, serializes the model to disk, and publishes `ModelMetadata`
    /// into `Strategies`. Failures become `Diagnostic` facts.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ))
            }
        };

        if let Err(err) = create_dir_all(&self.model_dir) {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "model-training-error",
                err.to_string(),
            ));
        }

        let train_df = match load_dataframe(Path::new(&split.train_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ))
            }
        };

        // `select_target_column` is defined elsewhere in this file.
        let (target_name, target) = match select_target_column(&train_df) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ))
            }
        };

        // The entire "model" is the mean of the target column.
        let mean = match mean_of_series(&target) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-training-error",
                    err.to_string(),
                ))
            }
        };

        let model = BaselineModel {
            target_column: target_name.clone(),
            mean,
        };

        let model_path = self.model_path();
        if let Err(err) = write_json(&model_path, &model) {
            return AgentEffect::with_fact(Fact::new(
                ContextKey::Diagnostic,
                "model-training-error",
                err.to_string(),
            ));
        }

        let meta = ModelMetadata {
            model_path: model_path.to_string_lossy().to_string(),
            target_column: target_name,
            train_rows: split.train_rows,
            baseline_mean: mean,
            iteration: split.iteration,
        };

        let content = serde_json::to_string(&meta).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Strategies,
            format!("trained-model-{}", split.iteration),
            content,
        ))
    }
}
689
/// Evaluates the baseline model on the validation split (MAE).
pub struct ModelEvaluationAgent;

impl ModelEvaluationAgent {
    /// Creates a new `ModelEvaluationAgent`.
    pub fn new() -> Self {
        Self
    }
}

// `Default` mirrors `new` so the agent can be constructed generically
// (also satisfies clippy's `new_without_default` lint).
impl Default for ModelEvaluationAgent {
    fn default() -> Self {
        Self::new()
    }
}
697
/// Records trained models plus their evaluation metrics in the registry.
pub struct ModelRegistryAgent;

impl ModelRegistryAgent {
    /// Creates a new `ModelRegistryAgent`.
    pub fn new() -> Self {
        Self
    }
}

// `Default` mirrors `new` so the agent can be constructed generically
// (also satisfies clippy's `new_without_default` lint).
impl Default for ModelRegistryAgent {
    fn default() -> Self {
        Self::new()
    }
}
705
impl Agent for ModelRegistryAgent {
    fn name(&self) -> &str {
        "ModelRegistryAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Strategies, ContextKey::Evaluations]
    }

    /// Runs at most once per iteration: only when model metadata exists and
    /// no `ModelRegistryRecord` has been published for that iteration yet.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Strategies)
            && ctx.has(ContextKey::Evaluations)
            && match read_latest_model_meta_from_ctx(ctx) {
                Ok(meta) => !has_registry_record_for_iteration(ctx, meta.iteration),
                Err(_) => false,
            }
    }

    /// Copies the latest model metadata (plus evaluation metrics, when
    /// available) into a `ModelRegistryRecord` fact under `Strategies`.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let meta = match read_latest_model_meta_from_ctx(ctx) {
            Ok(meta) => meta,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-registry-error",
                    err.to_string(),
                ))
            }
        };

        // Metrics are optional: the record is still written when no
        // evaluation exists for this iteration yet.
        let report = latest_evaluation_report(ctx, meta.iteration);
        let mut metrics = HashMap::new();
        if let Some(report) = report {
            // Keyed by the metric's own name (e.g. "mae") plus success ratio.
            metrics.insert(report.metric, report.value);
            metrics.insert("success_ratio".to_string(), report.success_ratio);
        }

        let record = ModelRegistryRecord {
            kind: "model_registry".to_string(),
            iteration: meta.iteration,
            model_path: meta.model_path,
            metrics,
            provenance: "training_flow".to_string(),
        };

        let content = serde_json::to_string(&record).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Strategies,
            format!("model-registry-{}", record.iteration),
            content,
        ))
    }
}
759
/// Turns the latest evaluation report into a health-status report.
pub struct MonitoringAgent;

impl MonitoringAgent {
    /// Creates a new `MonitoringAgent`.
    pub fn new() -> Self {
        Self
    }
}

// `Default` mirrors `new` so the agent can be constructed generically
// (also satisfies clippy's `new_without_default` lint).
impl Default for MonitoringAgent {
    fn default() -> Self {
        Self::new()
    }
}
767
impl Agent for MonitoringAgent {
    fn name(&self) -> &str {
        "MonitoringAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Evaluations]
    }

    /// Runs at most once per evaluated iteration.
    ///
    /// NOTE(review): the `0` passed to `latest_evaluation_report` (defined
    /// elsewhere in this file) presumably means "any iteration"; confirm
    /// against that function's definition.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Evaluations)
            && match latest_evaluation_report(ctx, 0) {
                Some(report) => !has_monitoring_report_for_iteration(ctx, report.iteration),
                None => false,
            }
    }

    /// Classifies the latest evaluation as "healthy" (success ratio >= 0.75)
    /// or "needs_attention" and publishes a `MonitoringReport` fact.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let report = match latest_evaluation_report(ctx, 0) {
            Some(report) => report,
            None => return AgentEffect::empty(),
        };

        // Hard-coded health threshold, independent of the plan's
        // `quality_threshold` used by DeploymentAgent.
        let status = if report.success_ratio >= 0.75 {
            "healthy"
        } else {
            "needs_attention"
        };

        let monitoring = MonitoringReport {
            kind: "monitoring".to_string(),
            iteration: report.iteration,
            metric: report.metric,
            value: report.value,
            baseline: report.mean_abs_target,
            status: status.to_string(),
        };

        let content = serde_json::to_string(&monitoring).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Evaluations,
            format!("monitoring-{}", report.iteration),
            content,
        ))
    }
}
814
/// Decides whether to deploy the latest model or hold and retrain.
pub struct DeploymentAgent;

impl DeploymentAgent {
    /// Creates a new `DeploymentAgent`.
    pub fn new() -> Self {
        Self
    }
}

// `Default` mirrors `new` so the agent can be constructed generically
// (also satisfies clippy's `new_without_default` lint).
impl Default for DeploymentAgent {
    fn default() -> Self {
        Self::new()
    }
}
822
impl Agent for DeploymentAgent {
    fn name(&self) -> &str {
        "DeploymentAgent"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Evaluations, ContextKey::Strategies]
    }

    /// Runs at most once per evaluated iteration: only when an evaluation
    /// exists and no `DeploymentDecision` was published for it yet.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Evaluations)
            && ctx.has(ContextKey::Strategies)
            && match latest_evaluation_report(ctx, 0) {
                Some(report) => !has_deployment_decision_for_iteration(ctx, report.iteration),
                None => false,
            }
    }

    /// Compares the evaluation's success ratio against the plan's quality
    /// threshold (default 0.75) and publishes a deploy/hold decision; a
    /// "hold" also requests a retrain.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let report = match latest_evaluation_report(ctx, 0) {
            Some(report) => report,
            None => return AgentEffect::empty(),
        };

        let quality_threshold = read_latest_plan_from_ctx(ctx)
            .map(|plan| plan.quality_threshold)
            .unwrap_or(0.75);

        let (action, retrain, reason) = if report.success_ratio >= quality_threshold {
            ("deploy", false, "meets quality threshold")
        } else {
            ("hold", true, "below quality threshold")
        };

        let decision = DeploymentDecision {
            kind: "deployment_decision".to_string(),
            iteration: report.iteration,
            action: action.to_string(),
            reason: reason.to_string(),
            retrain,
        };

        let content = serde_json::to_string(&decision).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Strategies,
            format!("deployment-{}", report.iteration),
            content,
        ))
    }
}
873
impl Agent for ModelEvaluationAgent {
    fn name(&self) -> &str {
        "ModelEvaluationAgent (MAE)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals, ContextKey::Strategies]
    }

    /// Runs at most once per iteration: only when a split exists and no
    /// `EvaluationReport` has been published for that iteration yet.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && ctx.has(ContextKey::Strategies)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_evaluation_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Computes MAE of the constant-mean baseline on the validation split,
    /// normalizes it into a success ratio, and publishes an
    /// `EvaluationReport` fact. Failures become `Diagnostic` facts.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        // Loads the serialized BaselineModel referenced by the latest
        // model metadata fact.
        let model = match read_model_from_ctx(ctx) {
            Ok(model) => model,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        let val_df = match load_dataframe(Path::new(&split.val_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        let target = match get_numeric_series(&val_df, &model.target_column) {
            Ok(series) => series,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        // MAE of predicting the constant `model.mean` for every row.
        let mae = match mean_abs_error(&target, model.mean) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        // Scale reference: mean |target| on the validation split.
        let mean_abs = match mean_abs_value(&target) {
            Ok(value) => value,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-eval-error",
                    err.to_string(),
                ))
            }
        };

        // success_ratio = 1 - mae / mean|y|, clamped to [0, 1]; a zero scale
        // reference degenerates to 0 rather than dividing by zero.
        let success_ratio = if mean_abs > 0.0 {
            (1.0 - (mae / mean_abs)).clamp(0.0, 1.0)
        } else {
            0.0
        };

        let report = EvaluationReport {
            model_path: read_model_path_from_ctx(ctx).unwrap_or_default(),
            metric: "mae".to_string(),
            value: mae,
            mean_abs_target: mean_abs,
            success_ratio,
            val_rows: split.val_rows,
            iteration: split.iteration,
        };

        let content = serde_json::to_string(&report).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Evaluations,
            format!("model-eval-{}", split.iteration),
            content,
        ))
    }
}
983
/// Produces a small sample of baseline predictions from the inference split.
pub struct SampleInferenceAgent {
    /// Maximum number of rows included in the published sample.
    pub max_rows: usize,
}

impl SampleInferenceAgent {
    /// Creates an agent that samples at most `max_rows` rows.
    pub fn new(max_rows: usize) -> Self {
        SampleInferenceAgent { max_rows }
    }
}
993
impl Agent for SampleInferenceAgent {
    fn name(&self) -> &str {
        "SampleInferenceAgent (Baseline)"
    }

    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Signals, ContextKey::Strategies]
    }

    /// Runs at most once per iteration: only when a split exists and no
    /// `InferenceSample` has been published for that iteration yet.
    fn accepts(&self, ctx: &Context) -> bool {
        ctx.has(ContextKey::Signals)
            && ctx.has(ContextKey::Strategies)
            && match read_latest_split_from_ctx(ctx) {
                Ok(split) => !has_inference_for_iteration(ctx, split.iteration),
                Err(_) => false,
            }
    }

    /// Runs the baseline over (at most `max_rows` of) the inference split
    /// and publishes predictions alongside actual target values as an
    /// `InferenceSample` fact. Failures become `Diagnostic` facts.
    fn execute(&self, ctx: &Context) -> AgentEffect {
        let split = match read_latest_split_from_ctx(ctx) {
            Ok(split) => split,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ))
            }
        };

        let model = match read_model_from_ctx(ctx) {
            Ok(model) => model,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ))
            }
        };

        let infer_df = match load_dataframe(Path::new(&split.infer_path)) {
            Ok(df) => df,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ))
            }
        };

        let target = match get_numeric_series(&infer_df, &model.target_column) {
            Ok(series) => series,
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ))
            }
        };

        // Clamp the sample size to the available rows; the `.max(1)` keeps
        // the bound at least 1, though `take` on an empty column still
        // yields zero rows.
        let sample_rows = self.max_rows.min(infer_df.height().max(1));
        // `into_no_null_iter` skips null entries, so `actuals` may hold
        // fewer than `sample_rows` values when the column has nulls.
        let actuals = match target.f64() {
            Ok(series) => series
                .into_no_null_iter()
                .take(sample_rows)
                .collect::<Vec<_>>(),
            Err(err) => {
                return AgentEffect::with_fact(Fact::new(
                    ContextKey::Diagnostic,
                    "model-infer-error",
                    err.to_string(),
                ))
            }
        };

        // Baseline "inference": the constant training mean for every row.
        let predictions = vec![model.mean; actuals.len()];
        let sample = InferenceSample {
            model_path: read_model_path_from_ctx(ctx).unwrap_or_default(),
            target_column: model.target_column,
            rows: actuals.len(),
            predictions,
            actuals,
            iteration: split.iteration,
        };

        let content = serde_json::to_string(&sample).unwrap_or_default();
        AgentEffect::with_fact(Fact::new(
            ContextKey::Hypotheses,
            format!("inference-sample-{}", split.iteration),
            content,
        ))
    }
}
1090
1091fn download_dataset_if_missing(path: &Path) -> Result<()> {
1092 if path.exists() {
1093 return Ok(());
1094 }
1095
1096 let response = reqwest::blocking::get(DATASET_URL)?;
1097 let content = response.bytes()?;
1098
1099 let mut file = File::create(path)?;
1100 file.write_all(&content)?;
1101
1102 Ok(())
1103}
1104
/// Loads a `DataFrame` from `path`, dispatching on the file extension
/// (case-insensitive).
///
/// Supports `.parquet` (lazy scan collected eagerly) and `.csv` (with a
/// header row); any other extension is an error.
///
/// # Errors
/// Fails for non-UTF-8 paths, unsupported extensions, or reader errors.
fn load_dataframe(path: &Path) -> Result<DataFrame> {
    let extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or("")
        .to_ascii_lowercase();

    // Polars' path wrapper needs UTF-8; reject non-UTF-8 paths explicitly.
    let path_str = path
        .to_str()
        .ok_or_else(|| anyhow!("path is not valid utf-8: {:?}", path))?;

    match extension.as_str() {
        "parquet" => {
            let pl_path = PlPath::from_str(path_str);
            Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
        }
        "csv" => Ok(
            CsvReadOptions::default()
                .with_has_header(true)
                .try_into_reader_with_file_path(Some(path.to_path_buf()))?
                .finish()?,
        ),
        _ => Err(anyhow!(
            "unsupported data format for path {:?} (expected .csv or .parquet)",
            path
        )),
    }
}
1133
1134fn write_parquet(df: &DataFrame, path: &Path) -> Result<()> {
1135 let mut file = File::create(path)?;
1136 let mut owned = df.clone();
1137 ParquetWriter::new(&mut file).finish(&mut owned)?;
1138 Ok(())
1139}
1140
1141fn write_json<T: Serialize>(path: &Path, value: &T) -> Result<()> {
1142 let content = serde_json::to_string_pretty(value)?;
1143 let mut file = File::create(path)?;
1144 file.write_all(content.as_bytes())?;
1145 Ok(())
1146}
1147
1148fn read_latest_split_from_ctx(ctx: &Context) -> Result<DatasetSplit> {
1149 let facts = ctx.get(ContextKey::Signals);
1150 let mut latest: Option<DatasetSplit> = None;
1151 for fact in facts {
1152 if let Ok(split) = serde_json::from_str::<DatasetSplit>(&fact.content) {
1153 let should_replace = match &latest {
1154 Some(current) => split.iteration > current.iteration,
1155 None => true,
1156 };
1157 if should_replace {
1158 latest = Some(split);
1159 }
1160 }
1161 }
1162 latest.ok_or_else(|| anyhow!("missing dataset split"))
1163}
1164
1165fn read_model_path_from_ctx(ctx: &Context) -> Result<String> {
1166 let meta = read_latest_model_meta_from_ctx(ctx)?;
1167 Ok(meta.model_path)
1168}
1169
1170fn read_model_from_ctx(ctx: &Context) -> Result<BaselineModel> {
1171 let model_path = read_model_path_from_ctx(ctx)?;
1172 let content = std::fs::read_to_string(model_path)?;
1173 let model = serde_json::from_str(&content)?;
1174 Ok(model)
1175}
1176
1177fn read_latest_model_meta_from_ctx(ctx: &Context) -> Result<ModelMetadata> {
1178 let facts = ctx.get(ContextKey::Strategies);
1179 let mut latest: Option<ModelMetadata> = None;
1180 for fact in facts {
1181 if let Ok(meta) = serde_json::from_str::<ModelMetadata>(&fact.content) {
1182 let should_replace = match &latest {
1183 Some(current) => meta.iteration > current.iteration,
1184 None => true,
1185 };
1186 if should_replace {
1187 latest = Some(meta);
1188 }
1189 }
1190 }
1191 latest.ok_or_else(|| anyhow!("missing model metadata"))
1192}
1193
1194fn read_latest_plan_from_ctx(ctx: &Context) -> Option<TrainingPlan> {
1195 let facts = ctx.get(ContextKey::Constraints);
1196 let mut latest: Option<TrainingPlan> = None;
1197 for fact in facts {
1198 if let Ok(plan) = serde_json::from_str::<TrainingPlan>(&fact.content) {
1199 let should_replace = match &latest {
1200 Some(current) => plan.iteration > current.iteration,
1201 None => true,
1202 };
1203 if should_replace {
1204 latest = Some(plan);
1205 }
1206 }
1207 }
1208 latest
1209}
1210
1211fn has_split_for_iteration(ctx: &Context, iteration: usize) -> bool {
1212 ctx.get(ContextKey::Signals).iter().any(|fact| {
1213 serde_json::from_str::<DatasetSplit>(&fact.content)
1214 .map(|split| split.iteration == iteration)
1215 .unwrap_or(false)
1216 })
1217}
1218
1219fn has_model_for_iteration(ctx: &Context, iteration: usize) -> bool {
1220 ctx.get(ContextKey::Strategies).iter().any(|fact| {
1221 serde_json::from_str::<ModelMetadata>(&fact.content)
1222 .map(|meta| meta.iteration == iteration)
1223 .unwrap_or(false)
1224 })
1225}
1226
1227fn has_evaluation_for_iteration(ctx: &Context, iteration: usize) -> bool {
1228 ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1229 serde_json::from_str::<EvaluationReport>(&fact.content)
1230 .map(|report| report.iteration == iteration)
1231 .unwrap_or(false)
1232 })
1233}
1234
1235fn has_inference_for_iteration(ctx: &Context, iteration: usize) -> bool {
1236 ctx.get(ContextKey::Hypotheses).iter().any(|fact| {
1237 serde_json::from_str::<InferenceSample>(&fact.content)
1238 .map(|sample| sample.iteration == iteration)
1239 .unwrap_or(false)
1240 })
1241}
1242
1243fn has_data_quality_for_iteration(ctx: &Context, iteration: usize) -> bool {
1244 ctx.get(ContextKey::Signals).iter().any(|fact| {
1245 serde_json::from_str::<DataQualityReport>(&fact.content)
1246 .map(|report| report.iteration == iteration)
1247 .unwrap_or(false)
1248 })
1249}
1250
1251fn has_feature_spec_for_iteration(ctx: &Context, iteration: usize) -> bool {
1252 ctx.get(ContextKey::Constraints).iter().any(|fact| {
1253 serde_json::from_str::<FeatureSpec>(&fact.content)
1254 .map(|spec| spec.iteration == iteration)
1255 .unwrap_or(false)
1256 })
1257}
1258
1259fn has_hyperparam_result_for_iteration(ctx: &Context, iteration: usize) -> bool {
1260 ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1261 serde_json::from_str::<HyperparameterSearchResult>(&fact.content)
1262 .map(|result| result.iteration == iteration)
1263 .unwrap_or(false)
1264 })
1265}
1266
1267fn has_registry_record_for_iteration(ctx: &Context, iteration: usize) -> bool {
1268 ctx.get(ContextKey::Strategies).iter().any(|fact| {
1269 serde_json::from_str::<ModelRegistryRecord>(&fact.content)
1270 .map(|record| record.iteration == iteration)
1271 .unwrap_or(false)
1272 })
1273}
1274
1275fn has_monitoring_report_for_iteration(ctx: &Context, iteration: usize) -> bool {
1276 ctx.get(ContextKey::Evaluations).iter().any(|fact| {
1277 serde_json::from_str::<MonitoringReport>(&fact.content)
1278 .map(|report| report.iteration == iteration)
1279 .unwrap_or(false)
1280 })
1281}
1282
1283fn has_deployment_decision_for_iteration(ctx: &Context, iteration: usize) -> bool {
1284 ctx.get(ContextKey::Strategies).iter().any(|fact| {
1285 serde_json::from_str::<DeploymentDecision>(&fact.content)
1286 .map(|decision| decision.iteration == iteration)
1287 .unwrap_or(false)
1288 })
1289}
1290
1291fn latest_evaluation_report(ctx: &Context, iteration: usize) -> Option<EvaluationReport> {
1292 let mut latest: Option<EvaluationReport> = None;
1293 for fact in ctx.get(ContextKey::Evaluations) {
1294 if let Ok(report) = serde_json::from_str::<EvaluationReport>(&fact.content) {
1295 if iteration > 0 {
1296 if report.iteration == iteration {
1297 return Some(report);
1298 }
1299 } else if latest
1300 .as_ref()
1301 .map(|current| report.iteration > current.iteration)
1302 .unwrap_or(true)
1303 {
1304 latest = Some(report);
1305 }
1306 }
1307 }
1308 if iteration > 0 {
1309 None
1310 } else {
1311 latest
1312 }
1313}
1314
1315fn latest_data_quality_before_iteration(
1316 ctx: &Context,
1317 iteration: usize,
1318) -> Option<DataQualityReport> {
1319 let mut latest: Option<DataQualityReport> = None;
1320 for fact in ctx.get(ContextKey::Signals) {
1321 if let Ok(report) = serde_json::from_str::<DataQualityReport>(&fact.content) {
1322 if report.iteration < iteration
1323 && latest
1324 .as_ref()
1325 .map(|current| report.iteration > current.iteration)
1326 .unwrap_or(true)
1327 {
1328 latest = Some(report);
1329 }
1330 }
1331 }
1332 latest
1333}
1334
1335fn drift_score_from_ctx(
1336 ctx: &Context,
1337 iteration: usize,
1338 numeric_means: &HashMap<String, f64>,
1339) -> Option<f64> {
1340 let previous = latest_data_quality_before_iteration(ctx, iteration)?;
1341 let mut total_delta = 0.0;
1342 let mut count = 0usize;
1343 for (name, mean) in numeric_means {
1344 if let Some(prev_mean) = previous.numeric_means.get(name) {
1345 total_delta += (mean - prev_mean).abs();
1346 count += 1;
1347 }
1348 }
1349 if count == 0 {
1350 None
1351 } else {
1352 Some(total_delta / count as f64)
1353 }
1354}
1355
1356fn compute_numeric_stats(series: &Series) -> Result<(f64, f64, usize)> {
1357 let casted = series.cast(&DataType::Float64)?;
1358 let values: Vec<f64> = casted
1359 .f64()
1360 .context("numeric series not f64")?
1361 .into_no_null_iter()
1362 .collect();
1363 if values.is_empty() {
1364 return Err(anyhow!("no numeric values to compute stats"));
1365 }
1366
1367 let mut total = 0.0;
1368 for value in &values {
1369 total += *value;
1370 }
1371 let mean = total / values.len() as f64;
1372
1373 let mut variance_sum = 0.0;
1374 for value in &values {
1375 let diff = *value - mean;
1376 variance_sum += diff * diff;
1377 }
1378 let std = (variance_sum / values.len() as f64).sqrt();
1379
1380 let outliers = if std > 0.0 {
1381 values
1382 .iter()
1383 .filter(|value| (*value - mean).abs() > 3.0 * std)
1384 .count()
1385 } else {
1386 0
1387 };
1388
1389 Ok((mean, std, outliers))
1390}
1391
1392fn split_feature_columns(df: &DataFrame, target: &str) -> (Vec<String>, Vec<String>) {
1393 let mut numeric = Vec::new();
1394 let mut categorical = Vec::new();
1395 for series in df.get_columns() {
1396 let name = series.name();
1397 if name == target {
1398 continue;
1399 }
1400 if is_numeric_dtype(series.dtype()) {
1401 numeric.push(name.to_string());
1402 } else {
1403 categorical.push(name.to_string());
1404 }
1405 }
1406 (numeric, categorical)
1407}
1408
1409fn select_target_column(df: &DataFrame) -> Result<(String, Series)> {
1410 if let Ok(col) = df.column(TARGET_COLUMN) {
1411 return Ok((
1412 TARGET_COLUMN.to_string(),
1413 col.as_materialized_series().clone(),
1414 ));
1415 }
1416
1417 let mut numeric = df
1418 .get_columns()
1419 .iter()
1420 .filter(|series| is_numeric_dtype(series.dtype()))
1421 .cloned()
1422 .collect::<Vec<_>>();
1423
1424 let fallback = numeric
1425 .pop()
1426 .ok_or_else(|| anyhow!("no numeric columns available for target"))?;
1427 let series = fallback.as_materialized_series().clone();
1428 Ok((series.name().to_string(), series))
1429}
1430
1431fn get_numeric_series(df: &DataFrame, name: &str) -> Result<Series> {
1432 let series = df
1433 .column(name)
1434 .map_err(|_| anyhow!("missing target column {}", name))?
1435 .as_materialized_series();
1436 let casted = series.cast(&DataType::Float64)?;
1437 Ok(casted)
1438}
1439
1440fn mean_of_series(series: &Series) -> Result<f64> {
1441 let casted = series.cast(&DataType::Float64)?;
1442 let values = casted
1443 .f64()
1444 .context("target column not f64")?
1445 .into_no_null_iter();
1446 let mut total = 0.0;
1447 let mut count = 0usize;
1448 for value in values {
1449 total += value;
1450 count += 1;
1451 }
1452 if count == 0 {
1453 return Err(anyhow!("no values to compute mean"));
1454 }
1455 Ok(total / count as f64)
1456}
1457
1458fn mean_abs_error(target: &Series, mean: f64) -> Result<f64> {
1459 let casted = target.cast(&DataType::Float64)?;
1460 let values = casted
1461 .f64()
1462 .context("target column not f64")?
1463 .into_no_null_iter();
1464 let mut total = 0.0;
1465 let mut count = 0usize;
1466 for value in values {
1467 total += (value - mean).abs();
1468 count += 1;
1469 }
1470 if count == 0 {
1471 return Err(anyhow!("no values to evaluate"));
1472 }
1473 Ok(total / count as f64)
1474}
1475
1476fn mean_abs_value(target: &Series) -> Result<f64> {
1477 let casted = target.cast(&DataType::Float64)?;
1478 let values = casted
1479 .f64()
1480 .context("target column not f64")?
1481 .into_no_null_iter();
1482 let mut total = 0.0;
1483 let mut count = 0usize;
1484 for value in values {
1485 total += value.abs();
1486 count += 1;
1487 }
1488 if count == 0 {
1489 return Err(anyhow!("no values to evaluate"));
1490 }
1491 Ok(total / count as f64)
1492}
1493fn is_numeric_dtype(dtype: &DataType) -> bool {
1494 matches!(
1495 dtype,
1496 DataType::Int8
1497 | DataType::Int16
1498 | DataType::Int32
1499 | DataType::Int64
1500 | DataType::UInt8
1501 | DataType::UInt16
1502 | DataType::UInt32
1503 | DataType::UInt64
1504 | DataType::Float32
1505 | DataType::Float64
1506 )
1507}