training_flow/
training_flow.rs1use anyhow::Result;
5use converge_analytics::training::{
6 DataValidationAgent, DatasetAgent, DatasetSplit, DeploymentAgent, EvaluationReport,
7 FeatureEngineeringAgent, HyperparameterSearchAgent, ModelEvaluationAgent, ModelRegistryAgent,
8 ModelTrainingAgent, MonitoringAgent, SampleInferenceAgent, TrainingPlan,
9};
10use converge_core::{Agent, Context, ContextKey, Fact};
11use serde_json;
12use std::path::PathBuf;
13
14fn main() -> Result<()> {
15 println!("Initializing Converge Analytics training flow...");
16
17 let mut ctx = Context::new();
19 ctx.add_fact(Fact::new(
20 ContextKey::Seeds,
21 "job-1",
22 "Train a baseline model on California Housing",
23 ))?;
24
25 let data_dir = PathBuf::from("data");
27 let model_dir = PathBuf::from("models");
28
29 let dataset_agent = DatasetAgent::new(data_dir);
30 let validator = DataValidationAgent::new();
31 let featurizer = FeatureEngineeringAgent::new();
32 let hyperparam = HyperparameterSearchAgent::new(12);
33 let trainer = ModelTrainingAgent::new(model_dir);
34 let evaluator = ModelEvaluationAgent::new();
35 let registry = ModelRegistryAgent::new();
36 let monitor = MonitoringAgent::new();
37 let inference = SampleInferenceAgent::new(5);
38 let deployment = DeploymentAgent::new();
39
40 let agents: Vec<Box<dyn Agent>> = vec![
41 Box::new(dataset_agent),
42 Box::new(validator),
43 Box::new(featurizer),
44 Box::new(hyperparam),
45 Box::new(trainer),
46 Box::new(evaluator),
47 Box::new(registry),
48 Box::new(monitor),
49 Box::new(inference),
50 Box::new(deployment),
51 ];
52
53 println!("Starting execution loop...");
55 let mut iteration = 1usize;
56 let mut max_rows = 500usize;
57 let quality_threshold = 0.75;
58 let max_iterations = 5;
59 let mut last_iteration = iteration;
60
61 while iteration <= max_iterations {
62 let plan = TrainingPlan {
63 iteration,
64 max_rows,
65 train_fraction: 0.8,
66 val_fraction: 0.15,
67 infer_fraction: 0.05,
68 quality_threshold,
69 };
70 let plan_content = serde_json::to_string(&plan)?;
71 ctx.add_fact(Fact::new(
72 ContextKey::Constraints,
73 format!("training-plan-{}", iteration),
74 plan_content,
75 ))?;
76
77 println!(
78 "\n--- Iteration {} (max_rows={}) ---",
79 iteration, max_rows
80 );
81
82 for cycle in 1..=8 {
83 println!("Cycle {}", cycle);
84 let mut changes = 0;
85
86 for agent in &agents {
87 if agent.accepts(&ctx) {
88 println!("Agent {} is active", agent.name());
89 let effect = agent.execute(&ctx);
90
91 if !effect.is_empty() {
92 for fact in effect.facts {
93 if ctx.add_fact(fact).unwrap_or(false) {
94 changes += 1;
95 }
96 }
97 }
98 }
99 }
100
101 if changes == 0 {
102 break;
103 }
104 }
105
106 let report = match latest_evaluation_for_iteration(&ctx, iteration) {
107 Some(report) => report,
108 None => {
109 println!("No evaluation report for iteration {}", iteration);
110 break;
111 }
112 };
113
114 println!(
115 "Validation MAE: {:.4} | success_ratio: {:.3} | val_rows: {}",
116 report.value, report.success_ratio, report.val_rows
117 );
118
119 if report.success_ratio >= quality_threshold {
120 println!("Quality threshold met. Converged.");
121 last_iteration = iteration;
122 break;
123 }
124
125 let total_rows = latest_total_rows(&ctx).unwrap_or(max_rows);
126 if max_rows >= total_rows {
127 println!("No more data available to improve quality.");
128 break;
129 }
130
131 max_rows = (max_rows * 2).min(total_rows);
132 println!("Quality below threshold, expanding training rows to {}", max_rows);
133 last_iteration = iteration;
134 iteration += 1;
135 }
136
137 if let Some(eval) = latest_evaluation_for_iteration(&ctx, last_iteration) {
138 println!("\nFinal evaluation: {:?}", eval);
139 }
140
141 if let Some(hypo) = ctx.get(ContextKey::Hypotheses).last() {
142 println!("\nInference sample: {}", hypo.content);
143 }
144
145 Ok(())
146}
147
148fn latest_evaluation_for_iteration(
149 ctx: &Context,
150 iteration: usize,
151) -> Option<EvaluationReport> {
152 ctx.get(ContextKey::Evaluations)
153 .iter()
154 .filter_map(|fact| serde_json::from_str::<EvaluationReport>(&fact.content).ok())
155 .find(|report| report.iteration == iteration)
156}
157
/// Returns `total_rows` from the most recently recorded dataset split, if there is one.
///
/// Facts stored under `ContextKey::Signals` whose content does not deserialize as a
/// `DatasetSplit` are skipped. Returns `None` when no signal fact holds a dataset split.
fn latest_total_rows(ctx: &Context) -> Option<usize> {
    ctx.get(ContextKey::Signals)
        .iter()
        // Scan from the end so the newest split is used rather than the first one recorded.
        // NOTE(review): assumes facts are kept in insertion order -- confirm against `Context`.
        .rev()
        .find_map(|fact| {
            serde_json::from_str::<DatasetSplit>(&fact.content)
                .ok()
                .map(|split| split.total_rows)
        })
}