1use anyhow::{Context as _, Result, anyhow};
4use converge_pack::{AgentEffect, Context, ContextKey, Provenance, ProvenanceSource, Suggestor};
5use polars::prelude::*;
6
7use crate::provenance::CRUCIBLE_PROVENANCE;
8
9use super::io::{
10 compute_mean_std, compute_numeric_stats, is_numeric_dtype, load_dataframe,
11 select_target_column, split_feature_columns,
12};
13use super::types::{
14 DataQualityReport, FeatureInteraction, FeatureSpec, HyperparameterSearchPlan,
15 HyperparameterSearchResult, TrainingPlan, diagnostic, drift_score_from_ctx,
16 has_data_quality_for_iteration, has_feature_spec_for_iteration,
17 has_hyperparam_result_for_iteration, proposal, read_latest_plan_from_ctx,
18 read_latest_split_from_ctx,
19};
20
21use std::collections::HashMap;
22
/// Stateless agent that checks the latest training split and proposes a
/// [`DataQualityReport`] for it.
///
/// All inputs are read from the context handed to its [`Suggestor`] methods.
#[derive(Debug, Default)]
pub struct DataValidationAgent;

impl DataValidationAgent {
    /// Creates a new `DataValidationAgent`.
    pub fn new() -> Self {
        DataValidationAgent
    }
}
31
32#[async_trait::async_trait]
33impl Suggestor for DataValidationAgent {
34 fn name(&self) -> &'static str {
35 "DataValidationAgent"
36 }
37
38 fn dependencies(&self) -> &[ContextKey] {
39 &[ContextKey::Signals]
40 }
41
42 fn accepts(&self, ctx: &dyn Context) -> bool {
43 ctx.has(ContextKey::Signals)
44 && match read_latest_split_from_ctx(ctx) {
45 Ok(split) => !has_data_quality_for_iteration(ctx, split.iteration),
46 Err(_) => false,
47 }
48 }
49
50 fn provenance(&self) -> Provenance {
51 Provenance::from(CRUCIBLE_PROVENANCE.as_str())
52 }
53
54 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
55 let split = match read_latest_split_from_ctx(ctx) {
56 Ok(split) => split,
57 Err(err) => {
58 return AgentEffect::with_proposal(diagnostic(
59 self.name(),
60 ContextKey::Diagnostic,
61 "data-validation-error",
62 err.to_string(),
63 ));
64 }
65 };
66
67 let df = match load_dataframe(&split.train_path) {
68 Ok(df) => df,
69 Err(err) => {
70 return AgentEffect::with_proposal(diagnostic(
71 self.name(),
72 ContextKey::Diagnostic,
73 "data-validation-error",
74 err.to_string(),
75 ));
76 }
77 };
78
79 let rows = df.height();
80 let mut missingness = HashMap::new();
81 let mut numeric_means = HashMap::new();
82 let mut outlier_counts = HashMap::new();
83
84 for series in df.get_columns() {
85 let name = series.name().to_string();
86 let null_ratio = if rows > 0 {
87 series.null_count() as f64 / rows as f64
88 } else {
89 0.0
90 };
91 missingness.insert(name.clone(), null_ratio);
92
93 if is_numeric_dtype(series.dtype())
94 && let Ok((mean, _std, outliers)) =
95 compute_numeric_stats(series.as_materialized_series())
96 {
97 numeric_means.insert(name.clone(), mean);
98 outlier_counts.insert(name, outliers);
99 }
100 }
101
102 let drift_score = drift_score_from_ctx(ctx, split.iteration, &numeric_means);
103
104 let report = DataQualityReport {
105 kind: "data_quality".to_string(),
106 iteration: split.iteration,
107 source_path: split.train_path.clone(),
108 rows_checked: rows,
109 missingness,
110 numeric_means,
111 outlier_counts,
112 drift_score,
113 };
114
115 AgentEffect::with_proposal(proposal(
116 self.name(),
117 ContextKey::Signals,
118 format!("data-quality-{}", split.iteration),
119 report,
120 ))
121 }
122}
123
/// Stateless agent that derives a [`FeatureSpec`] from the latest training
/// split.
///
/// All inputs are read from the context handed to its [`Suggestor`] methods.
#[derive(Debug, Default)]
pub struct FeatureEngineeringAgent;

impl FeatureEngineeringAgent {
    /// Creates a new `FeatureEngineeringAgent`.
    pub fn new() -> Self {
        FeatureEngineeringAgent
    }
}
132
133#[async_trait::async_trait]
134impl Suggestor for FeatureEngineeringAgent {
135 fn name(&self) -> &'static str {
136 "FeatureEngineeringAgent"
137 }
138
139 fn dependencies(&self) -> &[ContextKey] {
140 &[ContextKey::Signals]
141 }
142
143 fn accepts(&self, ctx: &dyn Context) -> bool {
144 ctx.has(ContextKey::Signals)
145 && match read_latest_split_from_ctx(ctx) {
146 Ok(split) => !has_feature_spec_for_iteration(ctx, split.iteration),
147 Err(_) => false,
148 }
149 }
150
151 fn provenance(&self) -> Provenance {
152 Provenance::from(CRUCIBLE_PROVENANCE.as_str())
153 }
154
155 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
156 let split = match read_latest_split_from_ctx(ctx) {
157 Ok(split) => split,
158 Err(err) => {
159 return AgentEffect::with_proposal(diagnostic(
160 self.name(),
161 ContextKey::Diagnostic,
162 "feature-engineering-error",
163 err.to_string(),
164 ));
165 }
166 };
167
168 let df = match load_dataframe(&split.train_path) {
169 Ok(df) => df,
170 Err(err) => {
171 return AgentEffect::with_proposal(diagnostic(
172 self.name(),
173 ContextKey::Diagnostic,
174 "feature-engineering-error",
175 err.to_string(),
176 ));
177 }
178 };
179
180 let (target_column, _) = match select_target_column(&df) {
181 Ok(value) => value,
182 Err(err) => {
183 return AgentEffect::with_proposal(diagnostic(
184 self.name(),
185 ContextKey::Diagnostic,
186 "feature-engineering-error",
187 err.to_string(),
188 ));
189 }
190 };
191
192 let (numeric_features, categorical_features) = split_feature_columns(&df, &target_column);
193
194 let mut interactions = Vec::new();
195 if numeric_features.len() >= 2 {
196 interactions.push(FeatureInteraction {
197 name: format!("{}_x_{}", numeric_features[0], numeric_features[1]),
198 left: numeric_features[0].clone(),
199 right: numeric_features[1].clone(),
200 op: "multiply".to_string(),
201 });
202 }
203
204 let spec = FeatureSpec {
205 kind: "feature_spec".to_string(),
206 iteration: split.iteration,
207 target_column,
208 numeric_features,
209 categorical_features,
210 normalization: "standardize".to_string(),
211 interactions,
212 };
213
214 AgentEffect::with_proposal(proposal(
215 self.name(),
216 ContextKey::Constraints,
217 format!("feature-spec-{}", split.iteration),
218 spec,
219 ))
220 }
221}
222
/// Agent that proposes a hyperparameter search plan and a matching result for
/// the latest split.
#[derive(Debug)]
pub struct HyperparameterSearchAgent {
    /// Trial budget recorded in the emitted search plan; it is also a factor
    /// in the score of the emitted result.
    pub max_trials: usize,
}

impl HyperparameterSearchAgent {
    /// Creates an agent with the given trial budget.
    pub fn new(max_trials: usize) -> Self {
        HyperparameterSearchAgent { max_trials }
    }
}
233
234#[async_trait::async_trait]
235impl Suggestor for HyperparameterSearchAgent {
236 fn name(&self) -> &'static str {
237 "HyperparameterSearchAgent"
238 }
239
240 fn dependencies(&self) -> &[ContextKey] {
241 &[ContextKey::Constraints, ContextKey::Signals]
242 }
243
244 fn accepts(&self, ctx: &dyn Context) -> bool {
245 ctx.has(ContextKey::Signals)
246 && match read_latest_split_from_ctx(ctx) {
247 Ok(split) => !has_hyperparam_result_for_iteration(ctx, split.iteration),
248 Err(_) => false,
249 }
250 }
251
252 fn provenance(&self) -> Provenance {
253 Provenance::from(CRUCIBLE_PROVENANCE.as_str())
254 }
255
256 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
257 let split = match read_latest_split_from_ctx(ctx) {
258 Ok(split) => split,
259 Err(err) => {
260 return AgentEffect::with_proposal(diagnostic(
261 self.name(),
262 ContextKey::Diagnostic,
263 "hyperparam-search-error",
264 err.to_string(),
265 ));
266 }
267 };
268
269 let training_plan = read_latest_plan_from_ctx(ctx).unwrap_or(TrainingPlan {
270 iteration: split.iteration,
271 max_rows: split.max_rows,
272 train_fraction: 0.8,
273 val_fraction: 0.15,
274 infer_fraction: 0.05,
275 quality_threshold: 0.75,
276 });
277
278 let mut params = HashMap::new();
279 params.insert("learning_rate".to_string(), vec![0.001, 0.01, 0.1]);
280 params.insert("hidden_size".to_string(), vec![8.0, 16.0, 32.0]);
281
282 let plan = HyperparameterSearchPlan {
283 kind: "hyperparam_plan".to_string(),
284 iteration: split.iteration,
285 max_trials: self.max_trials,
286 early_stopping: true,
287 params,
288 };
289
290 let mut best_params = HashMap::new();
291 best_params.insert("learning_rate".to_string(), 0.01);
292 best_params.insert("hidden_size".to_string(), 16.0);
293 let score = (1.0 - training_plan.quality_threshold) * plan.max_trials as f64
294 / plan.iteration.max(1) as f64;
295 let result = HyperparameterSearchResult {
296 kind: "hyperparam_result".to_string(),
297 iteration: split.iteration,
298 best_params,
299 score,
300 };
301
302 AgentEffect::builder()
303 .proposal(proposal(
304 self.name(),
305 ContextKey::Constraints,
306 format!("hyperparam-plan-{}", split.iteration),
307 plan,
308 ))
309 .proposal(proposal(
310 self.name(),
311 ContextKey::Evaluations,
312 format!("hyperparam-result-{}", split.iteration),
313 result,
314 ))
315 .build()
316 }
317}
318
319pub fn apply_feature_spec(df: &DataFrame, spec: &FeatureSpec) -> Result<DataFrame> {
321 let mut result = df.clone();
322
323 for interaction in &spec.interactions {
325 let left_col = result
326 .column(&interaction.left)
327 .map_err(|_| anyhow!("missing column {} for interaction", interaction.left))?
328 .cast(&DataType::Float64)?;
329 let right_col = result
330 .column(&interaction.right)
331 .map_err(|_| anyhow!("missing column {} for interaction", interaction.right))?
332 .cast(&DataType::Float64)?;
333
334 let left_vals = left_col.f64().context("left column not f64")?;
335 let right_vals = right_col.f64().context("right column not f64")?;
336
337 let interaction_series = match interaction.op.as_str() {
338 "multiply" => left_vals * right_vals,
339 "add" => left_vals + right_vals,
340 "subtract" => left_vals - right_vals,
341 "divide" => {
342 left_vals
344 .into_iter()
345 .zip(right_vals)
346 .map(|(l, r)| match (l, r) {
347 (Some(lv), Some(rv)) if rv.abs() > 1e-10 => Some(lv / rv),
348 _ => None,
349 })
350 .collect::<Float64Chunked>()
351 }
352 _ => return Err(anyhow!("unsupported interaction op: {}", interaction.op)),
353 };
354
355 let named_series = interaction_series.with_name(interaction.name.clone().into());
356 result = result
357 .hstack(&[named_series.into_series().into()])
358 .context("failed to add interaction column")?;
359 }
360
361 if spec.normalization == "standardize" {
363 for col_name in &spec.numeric_features {
364 if let Ok(col) = result.column(col_name) {
365 let casted = col.cast(&DataType::Float64)?;
366 let values = casted.f64().context("column not f64")?;
367
368 let (mean, std) = compute_mean_std(values)?;
370
371 if std > 0.0 {
372 let standardized = (values - mean) / std;
374 let named = standardized.with_name(col_name.clone().into());
375
376 result = result.drop(col_name)?;
378 result = result
379 .hstack(&[named.into_series().into()])
380 .context("failed to replace standardized column")?;
381 }
382 }
383 }
384 }
385
386 Ok(result)
387}