Skip to main content

crucible/training/
features.rs

1// Copyright 2024-2026 Reflective Labs
2
3use anyhow::{Context as _, Result, anyhow};
4use converge_pack::{AgentEffect, Context, ContextKey, Provenance, ProvenanceSource, Suggestor};
5use polars::prelude::*;
6
7use crate::provenance::CRUCIBLE_PROVENANCE;
8
9use super::io::{
10    compute_mean_std, compute_numeric_stats, is_numeric_dtype, load_dataframe,
11    select_target_column, split_feature_columns,
12};
13use super::types::{
14    DataQualityReport, FeatureInteraction, FeatureSpec, HyperparameterSearchPlan,
15    HyperparameterSearchResult, TrainingPlan, diagnostic, drift_score_from_ctx,
16    has_data_quality_for_iteration, has_feature_spec_for_iteration,
17    has_hyperparam_result_for_iteration, proposal, read_latest_plan_from_ctx,
18    read_latest_split_from_ctx,
19};
20
21use std::collections::HashMap;
22
/// Suggestor that profiles the latest training split for data quality
/// (per-column missingness, numeric means, outlier counts and a drift score).
#[derive(Debug, Default)]
pub struct DataValidationAgent;

impl DataValidationAgent {
    /// Creates the agent. It holds no configuration, so this is the same as
    /// [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
31
32#[async_trait::async_trait]
33impl Suggestor for DataValidationAgent {
34    fn name(&self) -> &'static str {
35        "DataValidationAgent"
36    }
37
38    fn dependencies(&self) -> &[ContextKey] {
39        &[ContextKey::Signals]
40    }
41
42    fn accepts(&self, ctx: &dyn Context) -> bool {
43        ctx.has(ContextKey::Signals)
44            && match read_latest_split_from_ctx(ctx) {
45                Ok(split) => !has_data_quality_for_iteration(ctx, split.iteration),
46                Err(_) => false,
47            }
48    }
49
50    fn provenance(&self) -> Provenance {
51        Provenance::from(CRUCIBLE_PROVENANCE.as_str())
52    }
53
54    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
55        let split = match read_latest_split_from_ctx(ctx) {
56            Ok(split) => split,
57            Err(err) => {
58                return AgentEffect::with_proposal(diagnostic(
59                    self.name(),
60                    ContextKey::Diagnostic,
61                    "data-validation-error",
62                    err.to_string(),
63                ));
64            }
65        };
66
67        let df = match load_dataframe(&split.train_path) {
68            Ok(df) => df,
69            Err(err) => {
70                return AgentEffect::with_proposal(diagnostic(
71                    self.name(),
72                    ContextKey::Diagnostic,
73                    "data-validation-error",
74                    err.to_string(),
75                ));
76            }
77        };
78
79        let rows = df.height();
80        let mut missingness = HashMap::new();
81        let mut numeric_means = HashMap::new();
82        let mut outlier_counts = HashMap::new();
83
84        for series in df.get_columns() {
85            let name = series.name().to_string();
86            let null_ratio = if rows > 0 {
87                series.null_count() as f64 / rows as f64
88            } else {
89                0.0
90            };
91            missingness.insert(name.clone(), null_ratio);
92
93            if is_numeric_dtype(series.dtype())
94                && let Ok((mean, _std, outliers)) =
95                    compute_numeric_stats(series.as_materialized_series())
96            {
97                numeric_means.insert(name.clone(), mean);
98                outlier_counts.insert(name, outliers);
99            }
100        }
101
102        let drift_score = drift_score_from_ctx(ctx, split.iteration, &numeric_means);
103
104        let report = DataQualityReport {
105            kind: "data_quality".to_string(),
106            iteration: split.iteration,
107            source_path: split.train_path.clone(),
108            rows_checked: rows,
109            missingness,
110            numeric_means,
111            outlier_counts,
112            drift_score,
113        };
114
115        AgentEffect::with_proposal(proposal(
116            self.name(),
117            ContextKey::Signals,
118            format!("data-quality-{}", split.iteration),
119            report,
120        ))
121    }
122}
123
/// Suggestor that derives a [`FeatureSpec`] (target, numeric/categorical
/// features, normalization and interactions) from the latest training split.
#[derive(Debug, Default)]
pub struct FeatureEngineeringAgent;

impl FeatureEngineeringAgent {
    /// Creates the agent. It holds no configuration, so this is the same as
    /// [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
132
133#[async_trait::async_trait]
134impl Suggestor for FeatureEngineeringAgent {
135    fn name(&self) -> &'static str {
136        "FeatureEngineeringAgent"
137    }
138
139    fn dependencies(&self) -> &[ContextKey] {
140        &[ContextKey::Signals]
141    }
142
143    fn accepts(&self, ctx: &dyn Context) -> bool {
144        ctx.has(ContextKey::Signals)
145            && match read_latest_split_from_ctx(ctx) {
146                Ok(split) => !has_feature_spec_for_iteration(ctx, split.iteration),
147                Err(_) => false,
148            }
149    }
150
151    fn provenance(&self) -> Provenance {
152        Provenance::from(CRUCIBLE_PROVENANCE.as_str())
153    }
154
155    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
156        let split = match read_latest_split_from_ctx(ctx) {
157            Ok(split) => split,
158            Err(err) => {
159                return AgentEffect::with_proposal(diagnostic(
160                    self.name(),
161                    ContextKey::Diagnostic,
162                    "feature-engineering-error",
163                    err.to_string(),
164                ));
165            }
166        };
167
168        let df = match load_dataframe(&split.train_path) {
169            Ok(df) => df,
170            Err(err) => {
171                return AgentEffect::with_proposal(diagnostic(
172                    self.name(),
173                    ContextKey::Diagnostic,
174                    "feature-engineering-error",
175                    err.to_string(),
176                ));
177            }
178        };
179
180        let (target_column, _) = match select_target_column(&df) {
181            Ok(value) => value,
182            Err(err) => {
183                return AgentEffect::with_proposal(diagnostic(
184                    self.name(),
185                    ContextKey::Diagnostic,
186                    "feature-engineering-error",
187                    err.to_string(),
188                ));
189            }
190        };
191
192        let (numeric_features, categorical_features) = split_feature_columns(&df, &target_column);
193
194        let mut interactions = Vec::new();
195        if numeric_features.len() >= 2 {
196            interactions.push(FeatureInteraction {
197                name: format!("{}_x_{}", numeric_features[0], numeric_features[1]),
198                left: numeric_features[0].clone(),
199                right: numeric_features[1].clone(),
200                op: "multiply".to_string(),
201            });
202        }
203
204        let spec = FeatureSpec {
205            kind: "feature_spec".to_string(),
206            iteration: split.iteration,
207            target_column,
208            numeric_features,
209            categorical_features,
210            normalization: "standardize".to_string(),
211            interactions,
212        };
213
214        AgentEffect::with_proposal(proposal(
215            self.name(),
216            ContextKey::Constraints,
217            format!("feature-spec-{}", split.iteration),
218            spec,
219        ))
220    }
221}
222
/// Suggestor that proposes a hyperparameter search plan, together with a search
/// result, for the latest training split.
#[derive(Debug)]
pub struct HyperparameterSearchAgent {
    /// Trial budget written into each proposed search plan; it also scales the
    /// reported score.
    pub max_trials: usize,
}

impl HyperparameterSearchAgent {
    /// Creates an agent whose search plans carry `max_trials` as their trial budget.
    pub fn new(max_trials: usize) -> Self {
        HyperparameterSearchAgent { max_trials }
    }
}
233
234#[async_trait::async_trait]
235impl Suggestor for HyperparameterSearchAgent {
236    fn name(&self) -> &'static str {
237        "HyperparameterSearchAgent"
238    }
239
240    fn dependencies(&self) -> &[ContextKey] {
241        &[ContextKey::Constraints, ContextKey::Signals]
242    }
243
244    fn accepts(&self, ctx: &dyn Context) -> bool {
245        ctx.has(ContextKey::Signals)
246            && match read_latest_split_from_ctx(ctx) {
247                Ok(split) => !has_hyperparam_result_for_iteration(ctx, split.iteration),
248                Err(_) => false,
249            }
250    }
251
252    fn provenance(&self) -> Provenance {
253        Provenance::from(CRUCIBLE_PROVENANCE.as_str())
254    }
255
256    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
257        let split = match read_latest_split_from_ctx(ctx) {
258            Ok(split) => split,
259            Err(err) => {
260                return AgentEffect::with_proposal(diagnostic(
261                    self.name(),
262                    ContextKey::Diagnostic,
263                    "hyperparam-search-error",
264                    err.to_string(),
265                ));
266            }
267        };
268
269        let training_plan = read_latest_plan_from_ctx(ctx).unwrap_or(TrainingPlan {
270            iteration: split.iteration,
271            max_rows: split.max_rows,
272            train_fraction: 0.8,
273            val_fraction: 0.15,
274            infer_fraction: 0.05,
275            quality_threshold: 0.75,
276        });
277
278        let mut params = HashMap::new();
279        params.insert("learning_rate".to_string(), vec![0.001, 0.01, 0.1]);
280        params.insert("hidden_size".to_string(), vec![8.0, 16.0, 32.0]);
281
282        let plan = HyperparameterSearchPlan {
283            kind: "hyperparam_plan".to_string(),
284            iteration: split.iteration,
285            max_trials: self.max_trials,
286            early_stopping: true,
287            params,
288        };
289
290        let mut best_params = HashMap::new();
291        best_params.insert("learning_rate".to_string(), 0.01);
292        best_params.insert("hidden_size".to_string(), 16.0);
293        let score = (1.0 - training_plan.quality_threshold) * plan.max_trials as f64
294            / plan.iteration.max(1) as f64;
295        let result = HyperparameterSearchResult {
296            kind: "hyperparam_result".to_string(),
297            iteration: split.iteration,
298            best_params,
299            score,
300        };
301
302        AgentEffect::builder()
303            .proposal(proposal(
304                self.name(),
305                ContextKey::Constraints,
306                format!("hyperparam-plan-{}", split.iteration),
307                plan,
308            ))
309            .proposal(proposal(
310                self.name(),
311                ContextKey::Evaluations,
312                format!("hyperparam-result-{}", split.iteration),
313                result,
314            ))
315            .build()
316    }
317}
318
319/// Apply a FeatureSpec to a DataFrame, creating interaction features and normalizing
320pub fn apply_feature_spec(df: &DataFrame, spec: &FeatureSpec) -> Result<DataFrame> {
321    let mut result = df.clone();
322
323    // Apply feature interactions
324    for interaction in &spec.interactions {
325        let left_col = result
326            .column(&interaction.left)
327            .map_err(|_| anyhow!("missing column {} for interaction", interaction.left))?
328            .cast(&DataType::Float64)?;
329        let right_col = result
330            .column(&interaction.right)
331            .map_err(|_| anyhow!("missing column {} for interaction", interaction.right))?
332            .cast(&DataType::Float64)?;
333
334        let left_vals = left_col.f64().context("left column not f64")?;
335        let right_vals = right_col.f64().context("right column not f64")?;
336
337        let interaction_series = match interaction.op.as_str() {
338            "multiply" => left_vals * right_vals,
339            "add" => left_vals + right_vals,
340            "subtract" => left_vals - right_vals,
341            "divide" => {
342                // Safe division: use map to handle division safely
343                left_vals
344                    .into_iter()
345                    .zip(right_vals)
346                    .map(|(l, r)| match (l, r) {
347                        (Some(lv), Some(rv)) if rv.abs() > 1e-10 => Some(lv / rv),
348                        _ => None,
349                    })
350                    .collect::<Float64Chunked>()
351            }
352            _ => return Err(anyhow!("unsupported interaction op: {}", interaction.op)),
353        };
354
355        let named_series = interaction_series.with_name(interaction.name.clone().into());
356        result = result
357            .hstack(&[named_series.into_series().into()])
358            .context("failed to add interaction column")?;
359    }
360
361    // Apply normalization to numeric features
362    if spec.normalization == "standardize" {
363        for col_name in &spec.numeric_features {
364            if let Ok(col) = result.column(col_name) {
365                let casted = col.cast(&DataType::Float64)?;
366                let values = casted.f64().context("column not f64")?;
367
368                // Compute mean and std
369                let (mean, std) = compute_mean_std(values)?;
370
371                if std > 0.0 {
372                    // Standardize: (x - mean) / std
373                    let standardized = (values - mean) / std;
374                    let named = standardized.with_name(col_name.clone().into());
375
376                    // Replace the column
377                    result = result.drop(col_name)?;
378                    result = result
379                        .hstack(&[named.into_series().into()])
380                        .context("failed to replace standardized column")?;
381                }
382            }
383        }
384    }
385
386    Ok(result)
387}