1use std::collections::BTreeMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{json, Value};
5use thiserror::Error;
6
7#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
9pub struct DagMlErrorDescriptor {
10 pub category: String,
12 pub code: String,
14 pub severity: String,
16 pub message: String,
18 pub remediation_hint: String,
20 pub context: BTreeMap<String, Value>,
22}
23
24#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
25pub struct OofLeakageViolation {
26 pub producer_node: String,
27 pub partition: String,
28 pub fold_id: Option<String>,
29}
30
31#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
32pub struct OofLeakageReport {
33 pub node_id: String,
34 pub violators: Vec<OofLeakageViolation>,
35 pub allow_train_predictions_as_features: bool,
36 pub remediation: String,
37}
38
39#[derive(Debug, Error)]
40pub enum DagMlError {
41 #[error("invalid identifier `{value}`: {reason}")]
42 InvalidIdentifier { value: String, reason: &'static str },
43
44 #[error("graph validation failed: {0}")]
45 GraphValidation(String),
46
47 #[error("controller validation failed: {0}")]
48 ControllerValidation(String),
49
50 #[error("campaign validation failed: {0}")]
51 CampaignValidation(String),
52
53 #[error("planning failed: {0}")]
54 Planning(String),
55
56 #[error("runtime validation failed: {0}")]
57 RuntimeValidation(String),
58
59 #[error("OOF validation failed: {0}")]
60 OofValidation(String),
61
62 #[error("OOF leakage at `{}`: {} violator(s); {}", .0.node_id, .0.violators.len(), .0.remediation)]
63 OofLeakage(Box<OofLeakageReport>),
64
65 #[error("serialization error: {0}")]
66 Serialization(#[from] serde_json::Error),
67}
68
69impl DagMlError {
70 pub fn category(&self) -> &'static str {
72 self.taxonomy_parts().0
73 }
74
75 pub fn code(&self) -> &'static str {
77 self.taxonomy_parts().1
78 }
79
80 pub fn severity(&self) -> &'static str {
82 self.taxonomy_parts().2
83 }
84
85 pub fn remediation_hint(&self) -> String {
87 match self {
88 Self::InvalidIdentifier { .. } => {
89 "Use a non-empty stable identifier that matches the dag-ml identifier grammar."
90 .to_string()
91 }
92 Self::GraphValidation(_) => {
93 "Fix the graph contract violation before planning or running the pipeline."
94 .to_string()
95 }
96 Self::ControllerValidation(_) => {
97 "Register or configure the controller so it matches the graph contract."
98 .to_string()
99 }
100 Self::CampaignValidation(_) => {
101 "Fix the campaign template so its folds, metrics and graph references are consistent."
102 .to_string()
103 }
104 Self::Planning(_) => {
105 "Inspect the graph and campaign constraints, then re-plan with a compatible execution request."
106 .to_string()
107 }
108 Self::RuntimeValidation(_) => {
109 "Inspect the runtime inputs and produced artifacts, then rerun the failed step with compatible values."
110 .to_string()
111 }
112 Self::OofValidation(_) => {
113 "Use validated out-of-fold contracts and keep training predictions out of feature inputs unless explicitly allowed."
114 .to_string()
115 }
116 Self::OofLeakage(report) => report.remediation.clone(),
117 Self::Serialization(_) => {
118 "Check that the JSON or YAML payload matches the supported dag-ml contract version."
119 .to_string()
120 }
121 }
122 }
123
124 pub fn context(&self) -> BTreeMap<String, Value> {
126 let mut context = BTreeMap::new();
127 match self {
128 Self::InvalidIdentifier { value, reason } => {
129 context.insert("value".to_string(), json!(value));
130 context.insert("reason".to_string(), json!(reason));
131 }
132 Self::GraphValidation(detail)
133 | Self::ControllerValidation(detail)
134 | Self::CampaignValidation(detail)
135 | Self::Planning(detail)
136 | Self::RuntimeValidation(detail)
137 | Self::OofValidation(detail) => {
138 context.insert("detail".to_string(), json!(detail));
139 }
140 Self::OofLeakage(report) => {
141 context.insert("node_id".to_string(), json!(report.node_id));
142 context.insert("violator_count".to_string(), json!(report.violators.len()));
143 context.insert(
144 "allow_train_predictions_as_features".to_string(),
145 json!(report.allow_train_predictions_as_features),
146 );
147 context.insert("violators".to_string(), json!(report.violators));
148 }
149 Self::Serialization(error) => {
150 context.insert("detail".to_string(), json!(error.to_string()));
151 }
152 }
153 context
154 }
155
156 pub fn descriptor(&self) -> DagMlErrorDescriptor {
158 DagMlErrorDescriptor {
159 category: self.category().to_string(),
160 code: self.code().to_string(),
161 severity: self.severity().to_string(),
162 message: self.to_string(),
163 remediation_hint: self.remediation_hint(),
164 context: self.context(),
165 }
166 }
167
168 pub fn descriptor_json(&self) -> std::result::Result<String, serde_json::Error> {
170 serde_json::to_string(&self.descriptor())
171 }
172
173 pub fn error_code(&self) -> u32 {
177 let (category_id, code_id) = self.numeric_taxonomy();
178 (u32::from(category_id) << 16) | u32::from(code_id)
179 }
180
181 fn taxonomy_parts(&self) -> (&'static str, &'static str, &'static str) {
182 match self {
183 Self::InvalidIdentifier { .. } => ("validation", "invalid_identifier", "error"),
184 Self::GraphValidation(_) => ("validation", "graph_validation", "error"),
185 Self::ControllerValidation(_) => ("controller", "controller_validation", "error"),
186 Self::CampaignValidation(_) => ("validation", "campaign_validation", "error"),
187 Self::Planning(_) => ("runtime", "planning_failed", "error"),
188 Self::RuntimeValidation(_) => ("runtime", "runtime_validation", "error"),
189 Self::OofValidation(_) => ("validation", "oof_validation", "error"),
190 Self::OofLeakage(_) => ("validation", "oof_leakage", "error"),
191 Self::Serialization(_) => ("compatibility", "serialization_error", "error"),
192 }
193 }
194
195 fn numeric_taxonomy(&self) -> (u16, u16) {
203 match self {
204 Self::InvalidIdentifier { .. } => (0, 1),
205 Self::GraphValidation(_) => (0, 2),
206 Self::CampaignValidation(_) => (0, 3),
207 Self::OofValidation(_) => (0, 4),
208 Self::OofLeakage(_) => (0, 5),
209 Self::ControllerValidation(_) => (3, 1),
210 Self::Planning(_) => (1, 1),
211 Self::RuntimeValidation(_) => (1, 2),
212 Self::Serialization(_) => (8, 1),
213 }
214 }
215}
216
217pub type Result<T> = std::result::Result<T, DagMlError>;
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a genuine `serde_json::Error` by parsing truncated JSON.
    fn invalid_json_error() -> serde_json::Error {
        serde_json::from_str::<Value>("{").expect_err("invalid JSON")
    }

    /// Builds an `OofLeakage` error with a single violator.
    fn leakage_error() -> DagMlError {
        DagMlError::OofLeakage(Box::new(OofLeakageReport {
            node_id: "node:model".to_string(),
            violators: vec![OofLeakageViolation {
                producer_node: "node:prep".to_string(),
                partition: "train".to_string(),
                fold_id: Some("fold:0".to_string()),
            }],
            allow_train_predictions_as_features: false,
            remediation: "Use validation-only OOF predictions.".to_string(),
        }))
    }

    /// Returns exactly one instance of every `DagMlError` variant.
    fn one_of_each_variant() -> Vec<DagMlError> {
        vec![
            DagMlError::InvalidIdentifier {
                value: "x".to_string(),
                reason: "bad",
            },
            DagMlError::GraphValidation("x".to_string()),
            DagMlError::ControllerValidation("x".to_string()),
            DagMlError::CampaignValidation("x".to_string()),
            DagMlError::Planning("x".to_string()),
            DagMlError::RuntimeValidation("x".to_string()),
            DagMlError::OofValidation("x".to_string()),
            leakage_error(),
            DagMlError::Serialization(invalid_json_error()),
        ]
    }

    #[test]
    fn invalid_identifier_descriptor_carries_taxonomy_and_context() {
        let error = DagMlError::InvalidIdentifier {
            value: "".to_string(),
            reason: "empty",
        };

        let descriptor = error.descriptor();

        assert_eq!(descriptor.category, "validation");
        assert_eq!(descriptor.code, "invalid_identifier");
        assert_eq!(descriptor.severity, "error");
        assert_eq!(descriptor.context["value"], json!(""));
        assert_eq!(descriptor.context["reason"], json!("empty"));
        assert!(descriptor.remediation_hint.contains("identifier"));
    }

    #[test]
    fn oof_leakage_descriptor_preserves_report_context() {
        let error = leakage_error();

        let descriptor = error.descriptor();

        assert_eq!(descriptor.category, "validation");
        assert_eq!(descriptor.code, "oof_leakage");
        assert_eq!(
            descriptor.message,
            "OOF leakage at `node:model`: 1 violator(s); Use validation-only OOF predictions."
        );
        assert_eq!(
            descriptor.remediation_hint,
            "Use validation-only OOF predictions."
        );
        assert_eq!(descriptor.context["node_id"], json!("node:model"));
        assert_eq!(descriptor.context["violator_count"], json!(1));
        assert_eq!(
            descriptor.context["allow_train_predictions_as_features"],
            json!(false)
        );
        // The full violator list must survive, not just its length.
        assert_eq!(
            descriptor.context["violators"],
            json!([{
                "producer_node": "node:prep",
                "partition": "train",
                "fold_id": "fold:0"
            }])
        );
    }

    #[test]
    fn error_code_packs_category_and_code() {
        let invalid_identifier = DagMlError::InvalidIdentifier {
            value: "x".to_string(),
            reason: "bad",
        };
        assert_eq!(invalid_identifier.error_code(), 0x0000_0001);
        // Even the lowest-numbered variant must pack to a non-zero value.
        assert_ne!(invalid_identifier.error_code(), 0);

        assert_eq!(
            DagMlError::GraphValidation("x".to_string()).error_code(),
            0x0000_0002
        );
        assert_eq!(
            DagMlError::CampaignValidation("x".to_string()).error_code(),
            0x0000_0003
        );
        assert_eq!(
            DagMlError::OofValidation("x".to_string()).error_code(),
            0x0000_0004
        );
        assert_eq!(leakage_error().error_code(), 0x0000_0005);
        assert_eq!(
            DagMlError::Planning("x".to_string()).error_code(),
            0x0001_0001
        );
        assert_eq!(
            DagMlError::RuntimeValidation("x".to_string()).error_code(),
            0x0001_0002
        );
        assert_eq!(
            DagMlError::ControllerValidation("x".to_string()).error_code(),
            0x0003_0001
        );
        assert_eq!(
            DagMlError::Serialization(invalid_json_error()).error_code(),
            0x0008_0001
        );
    }

    #[test]
    fn numeric_taxonomy_mirrors_string_taxonomy() {
        let errors = one_of_each_variant();

        for (index, left) in errors.iter().enumerate() {
            for right in &errors[index + 1..] {
                // Codes identify a variant, in both representations.
                assert_ne!(
                    left.error_code(),
                    right.error_code(),
                    "`{}` and `{}` share a numeric error code",
                    left.code(),
                    right.code()
                );
                assert_ne!(left.code(), right.code());

                // Two variants share a numeric category exactly when they
                // share a category label.
                let same_label = left.category() == right.category();
                let same_number = (left.error_code() >> 16) == (right.error_code() >> 16);
                assert_eq!(
                    same_label,
                    same_number,
                    "category mismatch between `{}` and `{}`",
                    left.code(),
                    right.code()
                );
            }
        }
    }

    #[test]
    fn descriptor_json_is_stable_json_payload() {
        let error = DagMlError::Serialization(invalid_json_error());

        let payload = error.descriptor_json().expect("descriptor JSON");
        let parsed = serde_json::from_str::<Value>(&payload).expect("parse descriptor");

        assert_eq!(parsed["category"], json!("compatibility"));
        assert_eq!(parsed["code"], json!("serialization_error"));
        assert!(parsed["message"]
            .as_str()
            .expect("message")
            .contains("serialization error"));
        assert!(parsed["remediation_hint"]
            .as_str()
            .expect("hint")
            .contains("contract version"));

        // The payload must also deserialize back into an equal descriptor.
        let round_tripped = serde_json::from_str::<DagMlErrorDescriptor>(&payload)
            .expect("deserialize descriptor");
        assert_eq!(round_tripped, error.descriptor());
    }
}