crucible/training/
pipeline.rs1use converge_pack::{AgentEffect, Context, ContextKey, Provenance, ProvenanceSource, Suggestor};
4use std::fs::create_dir_all;
5use std::path::PathBuf;
6
7use crate::provenance::CRUCIBLE_PROVENANCE;
8
9use super::features::apply_feature_spec;
10use super::io::{load_dataframe, mean_of_series, select_target_column, write_json};
11use super::types::{
12 BaselineModel, ModelMetadata, diagnostic, has_model_for_iteration, proposal,
13 read_feature_spec_from_ctx, read_latest_split_from_ctx,
14};
15
/// Agent that fits a mean baseline model and records it as a strategy proposal.
///
/// The only state it carries is the directory into which the serialized
/// model file is written.
#[derive(Debug)]
pub struct ModelTrainingAgent {
    /// Directory that receives the serialized baseline model.
    model_dir: PathBuf,
}
20
21impl ModelTrainingAgent {
22 pub fn new(model_dir: PathBuf) -> Self {
23 Self { model_dir }
24 }
25
26 fn model_path(&self) -> PathBuf {
27 self.model_dir.join("baseline_mean.json")
28 }
29}
30
31#[async_trait::async_trait]
32impl Suggestor for ModelTrainingAgent {
33 fn name(&self) -> &'static str {
34 "ModelTrainingAgent (Baseline)"
35 }
36
37 fn dependencies(&self) -> &[ContextKey] {
38 &[ContextKey::Signals]
39 }
40
41 fn accepts(&self, ctx: &dyn Context) -> bool {
42 if !ctx.has(ContextKey::Signals) {
43 return false;
44 }
45 let split = match read_latest_split_from_ctx(ctx) {
46 Ok(split) => split,
47 Err(_) => return false,
48 };
49 !has_model_for_iteration(ctx, split.iteration)
50 }
51
52 fn provenance(&self) -> Provenance {
53 Provenance::from(CRUCIBLE_PROVENANCE.as_str())
54 }
55
56 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
57 let split = match read_latest_split_from_ctx(ctx) {
58 Ok(split) => split,
59 Err(err) => {
60 return AgentEffect::with_proposal(diagnostic(
61 self.name(),
62 ContextKey::Diagnostic,
63 "model-training-error",
64 err.to_string(),
65 ));
66 }
67 };
68
69 if let Err(err) = create_dir_all(&self.model_dir) {
70 return AgentEffect::with_proposal(diagnostic(
71 self.name(),
72 ContextKey::Diagnostic,
73 "model-training-error",
74 err.to_string(),
75 ));
76 }
77
78 let raw_train_df = match load_dataframe(&split.train_path) {
79 Ok(df) => df,
80 Err(err) => {
81 return AgentEffect::with_proposal(diagnostic(
82 self.name(),
83 ContextKey::Diagnostic,
84 "model-training-error",
85 err.to_string(),
86 ));
87 }
88 };
89
90 let train_df = match read_feature_spec_from_ctx(ctx, split.iteration) {
92 Some(spec) => match apply_feature_spec(&raw_train_df, &spec) {
93 Ok(df) => df,
94 Err(err) => {
95 return AgentEffect::with_proposal(diagnostic(
96 self.name(),
97 ContextKey::Diagnostic,
98 "model-training-error",
99 format!("feature spec application failed: {}", err),
100 ));
101 }
102 },
103 None => raw_train_df,
104 };
105
106 let (target_name, target) = match select_target_column(&train_df) {
107 Ok(value) => value,
108 Err(err) => {
109 return AgentEffect::with_proposal(diagnostic(
110 self.name(),
111 ContextKey::Diagnostic,
112 "model-training-error",
113 err.to_string(),
114 ));
115 }
116 };
117
118 let mean = match mean_of_series(&target) {
119 Ok(value) => value,
120 Err(err) => {
121 return AgentEffect::with_proposal(diagnostic(
122 self.name(),
123 ContextKey::Diagnostic,
124 "model-training-error",
125 err.to_string(),
126 ));
127 }
128 };
129
130 let model = BaselineModel {
131 target_column: target_name.clone(),
132 mean,
133 };
134
135 let model_path = self.model_path();
136 if let Err(err) = write_json(&model_path, &model) {
137 return AgentEffect::with_proposal(diagnostic(
138 self.name(),
139 ContextKey::Diagnostic,
140 "model-training-error",
141 err.to_string(),
142 ));
143 }
144
145 let meta = ModelMetadata {
146 model_path,
147 target_column: target_name,
148 train_rows: split.train_rows,
149 baseline_mean: mean,
150 iteration: split.iteration,
151 };
152
153 AgentEffect::with_proposal(proposal(
154 self.name(),
155 ContextKey::Strategies,
156 format!("trained-model-{}", split.iteration),
157 meta,
158 ))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// The model file always lives directly under the configured directory
    /// and is named `baseline_mean.json`.
    #[test]
    fn model_training_agent_model_path() {
        let model_dir = PathBuf::from("/tmp/models");
        let expected = model_dir.join("baseline_mean.json");

        let agent = ModelTrainingAgent::new(model_dir);

        assert_eq!(agent.model_path(), expected);
    }
}