Skip to main content

crucible/training/
pipeline.rs

1// Copyright 2024-2026 Reflective Labs
2
3use converge_pack::{AgentEffect, Context, ContextKey, Provenance, ProvenanceSource, Suggestor};
4use std::fs::create_dir_all;
5use std::path::PathBuf;
6
7use crate::provenance::CRUCIBLE_PROVENANCE;
8
9use super::features::apply_feature_spec;
10use super::io::{load_dataframe, mean_of_series, select_target_column, write_json};
11use super::types::{
12    BaselineModel, ModelMetadata, diagnostic, has_model_for_iteration, proposal,
13    read_feature_spec_from_ctx, read_latest_split_from_ctx,
14};
15
/// Agent that trains the baseline model and persists it under a configured directory.
///
/// The agent itself only remembers where the model file should live.
#[derive(Debug)]
pub struct ModelTrainingAgent {
    /// Directory the serialized model file is written into.
    model_dir: PathBuf,
}

impl ModelTrainingAgent {
    /// Name of the model file inside `model_dir`; it is the same for every run.
    const MODEL_FILE_NAME: &'static str = "baseline_mean.json";

    /// Builds an agent that stores its model under `model_dir`.
    ///
    /// The directory is only recorded here; nothing is touched on disk.
    pub fn new(model_dir: PathBuf) -> Self {
        ModelTrainingAgent { model_dir }
    }

    /// Returns the full path of the model file: `<model_dir>/baseline_mean.json`.
    fn model_path(&self) -> PathBuf {
        let mut path = self.model_dir.clone();
        path.push(Self::MODEL_FILE_NAME);
        path
    }
}
30
31#[async_trait::async_trait]
32impl Suggestor for ModelTrainingAgent {
33    fn name(&self) -> &'static str {
34        "ModelTrainingAgent (Baseline)"
35    }
36
37    fn dependencies(&self) -> &[ContextKey] {
38        &[ContextKey::Signals]
39    }
40
41    fn accepts(&self, ctx: &dyn Context) -> bool {
42        if !ctx.has(ContextKey::Signals) {
43            return false;
44        }
45        let split = match read_latest_split_from_ctx(ctx) {
46            Ok(split) => split,
47            Err(_) => return false,
48        };
49        !has_model_for_iteration(ctx, split.iteration)
50    }
51
52    fn provenance(&self) -> Provenance {
53        Provenance::from(CRUCIBLE_PROVENANCE.as_str())
54    }
55
56    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
57        let split = match read_latest_split_from_ctx(ctx) {
58            Ok(split) => split,
59            Err(err) => {
60                return AgentEffect::with_proposal(diagnostic(
61                    self.name(),
62                    ContextKey::Diagnostic,
63                    "model-training-error",
64                    err.to_string(),
65                ));
66            }
67        };
68
69        if let Err(err) = create_dir_all(&self.model_dir) {
70            return AgentEffect::with_proposal(diagnostic(
71                self.name(),
72                ContextKey::Diagnostic,
73                "model-training-error",
74                err.to_string(),
75            ));
76        }
77
78        let raw_train_df = match load_dataframe(&split.train_path) {
79            Ok(df) => df,
80            Err(err) => {
81                return AgentEffect::with_proposal(diagnostic(
82                    self.name(),
83                    ContextKey::Diagnostic,
84                    "model-training-error",
85                    err.to_string(),
86                ));
87            }
88        };
89
90        // Apply FeatureSpec transformation if available
91        let train_df = match read_feature_spec_from_ctx(ctx, split.iteration) {
92            Some(spec) => match apply_feature_spec(&raw_train_df, &spec) {
93                Ok(df) => df,
94                Err(err) => {
95                    return AgentEffect::with_proposal(diagnostic(
96                        self.name(),
97                        ContextKey::Diagnostic,
98                        "model-training-error",
99                        format!("feature spec application failed: {}", err),
100                    ));
101                }
102            },
103            None => raw_train_df,
104        };
105
106        let (target_name, target) = match select_target_column(&train_df) {
107            Ok(value) => value,
108            Err(err) => {
109                return AgentEffect::with_proposal(diagnostic(
110                    self.name(),
111                    ContextKey::Diagnostic,
112                    "model-training-error",
113                    err.to_string(),
114                ));
115            }
116        };
117
118        let mean = match mean_of_series(&target) {
119            Ok(value) => value,
120            Err(err) => {
121                return AgentEffect::with_proposal(diagnostic(
122                    self.name(),
123                    ContextKey::Diagnostic,
124                    "model-training-error",
125                    err.to_string(),
126                ));
127            }
128        };
129
130        let model = BaselineModel {
131            target_column: target_name.clone(),
132            mean,
133        };
134
135        let model_path = self.model_path();
136        if let Err(err) = write_json(&model_path, &model) {
137            return AgentEffect::with_proposal(diagnostic(
138                self.name(),
139                ContextKey::Diagnostic,
140                "model-training-error",
141                err.to_string(),
142            ));
143        }
144
145        let meta = ModelMetadata {
146            model_path,
147            target_column: target_name,
148            train_rows: split.train_rows,
149            baseline_mean: mean,
150            iteration: split.iteration,
151        };
152
153        AgentEffect::with_proposal(proposal(
154            self.name(),
155            ContextKey::Strategies,
156            format!("trained-model-{}", split.iteration),
157            meta,
158        ))
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// The model file sits directly under the configured directory with a fixed name.
    #[test]
    fn model_training_agent_model_path() {
        let model_dir = PathBuf::from("/tmp/models");
        let expected_path = PathBuf::from("/tmp/models/baseline_mean.json");

        let actual_path = ModelTrainingAgent::new(model_dir).model_path();

        assert_eq!(actual_path, expected_path);
    }
}
174}