Skip to main content

dag_ml_core/
error.rs

1use std::collections::BTreeMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{json, Value};
5use thiserror::Error;
6
7/// A stable ADR-11 error payload that can be serialized across bindings.
8#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
9pub struct DagMlErrorDescriptor {
10    /// ADR-11 category, for example `validation`, `runtime` or `controller`.
11    pub category: String,
12    /// Stable machine-readable code inside the category.
13    pub code: String,
14    /// Error severity. Current failing variants use `error`.
15    pub severity: String,
16    /// Human-readable error message.
17    pub message: String,
18    /// One-sentence remediation hint suitable for user-facing diagnostics.
19    pub remediation_hint: String,
20    /// Structured debug fields that remain stable enough for logs and tests.
21    pub context: BTreeMap<String, Value>,
22}
23
24#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
25pub struct OofLeakageViolation {
26    pub producer_node: String,
27    pub partition: String,
28    pub fold_id: Option<String>,
29}
30
31#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
32pub struct OofLeakageReport {
33    pub node_id: String,
34    pub violators: Vec<OofLeakageViolation>,
35    pub allow_train_predictions_as_features: bool,
36    pub remediation: String,
37}
38
39#[derive(Debug, Error)]
40pub enum DagMlError {
41    #[error("invalid identifier `{value}`: {reason}")]
42    InvalidIdentifier { value: String, reason: &'static str },
43
44    #[error("graph validation failed: {0}")]
45    GraphValidation(String),
46
47    #[error("controller validation failed: {0}")]
48    ControllerValidation(String),
49
50    #[error("campaign validation failed: {0}")]
51    CampaignValidation(String),
52
53    #[error("planning failed: {0}")]
54    Planning(String),
55
56    #[error("runtime validation failed: {0}")]
57    RuntimeValidation(String),
58
59    #[error("OOF validation failed: {0}")]
60    OofValidation(String),
61
62    #[error("OOF leakage at `{}`: {} violator(s); {}", .0.node_id, .0.violators.len(), .0.remediation)]
63    OofLeakage(Box<OofLeakageReport>),
64
65    #[error("serialization error: {0}")]
66    Serialization(#[from] serde_json::Error),
67}
68
69impl DagMlError {
70    /// Return the stable ADR-11 category for this error.
71    pub fn category(&self) -> &'static str {
72        self.taxonomy_parts().0
73    }
74
75    /// Return the stable ADR-11 code for this error.
76    pub fn code(&self) -> &'static str {
77        self.taxonomy_parts().1
78    }
79
80    /// Return the ADR-11 severity for this error.
81    pub fn severity(&self) -> &'static str {
82        self.taxonomy_parts().2
83    }
84
85    /// Return the remediation hint associated with this error.
86    pub fn remediation_hint(&self) -> String {
87        match self {
88            Self::InvalidIdentifier { .. } => {
89                "Use a non-empty stable identifier that matches the dag-ml identifier grammar."
90                    .to_string()
91            }
92            Self::GraphValidation(_) => {
93                "Fix the graph contract violation before planning or running the pipeline."
94                    .to_string()
95            }
96            Self::ControllerValidation(_) => {
97                "Register or configure the controller so it matches the graph contract."
98                    .to_string()
99            }
100            Self::CampaignValidation(_) => {
101                "Fix the campaign template so its folds, metrics and graph references are consistent."
102                    .to_string()
103            }
104            Self::Planning(_) => {
105                "Inspect the graph and campaign constraints, then re-plan with a compatible execution request."
106                    .to_string()
107            }
108            Self::RuntimeValidation(_) => {
109                "Inspect the runtime inputs and produced artifacts, then rerun the failed step with compatible values."
110                    .to_string()
111            }
112            Self::OofValidation(_) => {
113                "Use validated out-of-fold contracts and keep training predictions out of feature inputs unless explicitly allowed."
114                    .to_string()
115            }
116            Self::OofLeakage(report) => report.remediation.clone(),
117            Self::Serialization(_) => {
118                "Check that the JSON or YAML payload matches the supported dag-ml contract version."
119                    .to_string()
120            }
121        }
122    }
123
124    /// Return structured context fields for logs, bindings and tests.
125    pub fn context(&self) -> BTreeMap<String, Value> {
126        let mut context = BTreeMap::new();
127        match self {
128            Self::InvalidIdentifier { value, reason } => {
129                context.insert("value".to_string(), json!(value));
130                context.insert("reason".to_string(), json!(reason));
131            }
132            Self::GraphValidation(detail)
133            | Self::ControllerValidation(detail)
134            | Self::CampaignValidation(detail)
135            | Self::Planning(detail)
136            | Self::RuntimeValidation(detail)
137            | Self::OofValidation(detail) => {
138                context.insert("detail".to_string(), json!(detail));
139            }
140            Self::OofLeakage(report) => {
141                context.insert("node_id".to_string(), json!(report.node_id));
142                context.insert("violator_count".to_string(), json!(report.violators.len()));
143                context.insert(
144                    "allow_train_predictions_as_features".to_string(),
145                    json!(report.allow_train_predictions_as_features),
146                );
147                context.insert("violators".to_string(), json!(report.violators));
148            }
149            Self::Serialization(error) => {
150                context.insert("detail".to_string(), json!(error.to_string()));
151            }
152        }
153        context
154    }
155
156    /// Build the serializable ADR-11 descriptor for this error.
157    pub fn descriptor(&self) -> DagMlErrorDescriptor {
158        DagMlErrorDescriptor {
159            category: self.category().to_string(),
160            code: self.code().to_string(),
161            severity: self.severity().to_string(),
162            message: self.to_string(),
163            remediation_hint: self.remediation_hint(),
164            context: self.context(),
165        }
166    }
167
168    /// Serialize the ADR-11 descriptor as compact JSON.
169    pub fn descriptor_json(&self) -> std::result::Result<String, serde_json::Error> {
170        serde_json::to_string(&self.descriptor())
171    }
172
173    /// Stable ADR-11 numeric error code for FFI consumers: the high 16 bits are
174    /// the taxonomy category id and the low 16 bits are the per-category code id,
175    /// mirroring the `(category << 16) | code` convention from ADR-11.
176    pub fn error_code(&self) -> u32 {
177        let (category_id, code_id) = self.numeric_taxonomy();
178        (u32::from(category_id) << 16) | u32::from(code_id)
179    }
180
181    fn taxonomy_parts(&self) -> (&'static str, &'static str, &'static str) {
182        match self {
183            Self::InvalidIdentifier { .. } => ("validation", "invalid_identifier", "error"),
184            Self::GraphValidation(_) => ("validation", "graph_validation", "error"),
185            Self::ControllerValidation(_) => ("controller", "controller_validation", "error"),
186            Self::CampaignValidation(_) => ("validation", "campaign_validation", "error"),
187            Self::Planning(_) => ("runtime", "planning_failed", "error"),
188            Self::RuntimeValidation(_) => ("runtime", "runtime_validation", "error"),
189            Self::OofValidation(_) => ("validation", "oof_validation", "error"),
190            Self::OofLeakage(_) => ("validation", "oof_leakage", "error"),
191            Self::Serialization(_) => ("compatibility", "serialization_error", "error"),
192        }
193    }
194
195    /// Stable `(category_id, code_id)` pair backing [`error_code`](Self::error_code).
196    ///
197    /// Category ids follow ADR-11: validation=0, runtime=1, data=2, controller=3,
198    /// bundle=4, lineage=5, replay=6, security=7, compatibility=8, internal=9.
199    /// Code ids are **1-based** so a packed `error_code()` is never `0` — `0` is
200    /// reserved as the "no error" sentinel for `dagml_last_error_code()`. Code ids
201    /// are stable within their category; never renumber a shipped pair.
202    fn numeric_taxonomy(&self) -> (u16, u16) {
203        match self {
204            Self::InvalidIdentifier { .. } => (0, 1),
205            Self::GraphValidation(_) => (0, 2),
206            Self::CampaignValidation(_) => (0, 3),
207            Self::OofValidation(_) => (0, 4),
208            Self::OofLeakage(_) => (0, 5),
209            Self::ControllerValidation(_) => (3, 1),
210            Self::Planning(_) => (1, 1),
211            Self::RuntimeValidation(_) => (1, 2),
212            Self::Serialization(_) => (8, 1),
213        }
214    }
215}
216
217pub type Result<T> = std::result::Result<T, DagMlError>;
218
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use super::*;

    /// Build a `Serialization` error from a genuine `serde_json` parse failure.
    fn serialization_error() -> DagMlError {
        let source = serde_json::from_str::<Value>("{").expect_err("invalid JSON");
        DagMlError::Serialization(source)
    }

    /// Build an `OofLeakage` error with a single violator.
    fn leakage_error() -> DagMlError {
        DagMlError::OofLeakage(Box::new(OofLeakageReport {
            node_id: "node:model".to_string(),
            violators: vec![OofLeakageViolation {
                producer_node: "node:prep".to_string(),
                partition: "train".to_string(),
                fold_id: Some("fold:0".to_string()),
            }],
            allow_train_predictions_as_features: false,
            remediation: "Use validation-only OOF predictions.".to_string(),
        }))
    }

    /// One instance of every `DagMlError` variant, in declaration order.
    ///
    /// The compiler cannot enforce that this list is exhaustive: extend it
    /// (and the expectations below) whenever a variant is added.
    fn one_of_each_variant() -> Vec<DagMlError> {
        vec![
            DagMlError::InvalidIdentifier {
                value: "x".to_string(),
                reason: "bad",
            },
            DagMlError::GraphValidation("x".to_string()),
            DagMlError::ControllerValidation("x".to_string()),
            DagMlError::CampaignValidation("x".to_string()),
            DagMlError::Planning("x".to_string()),
            DagMlError::RuntimeValidation("x".to_string()),
            DagMlError::OofValidation("x".to_string()),
            leakage_error(),
            serialization_error(),
        ]
    }

    #[test]
    fn invalid_identifier_descriptor_carries_taxonomy_and_context() {
        let error = DagMlError::InvalidIdentifier {
            value: "".to_string(),
            reason: "empty",
        };

        let descriptor = error.descriptor();

        assert_eq!(descriptor.category, "validation");
        assert_eq!(descriptor.code, "invalid_identifier");
        assert_eq!(descriptor.severity, "error");
        assert_eq!(descriptor.context["value"], json!(""));
        assert_eq!(descriptor.context["reason"], json!("empty"));
        assert!(descriptor.remediation_hint.contains("identifier"));
    }

    #[test]
    fn oof_leakage_descriptor_preserves_report_context() {
        let descriptor = leakage_error().descriptor();

        assert_eq!(descriptor.category, "validation");
        assert_eq!(descriptor.code, "oof_leakage");
        assert_eq!(
            descriptor.message,
            "OOF leakage at `node:model`: 1 violator(s); Use validation-only OOF predictions."
        );
        assert_eq!(
            descriptor.remediation_hint,
            "Use validation-only OOF predictions."
        );
        assert_eq!(descriptor.context["node_id"], json!("node:model"));
        assert_eq!(descriptor.context["violator_count"], json!(1));
        assert_eq!(
            descriptor.context["allow_train_predictions_as_features"],
            json!(false)
        );
        // The full violator list must survive, not just its length.
        assert_eq!(
            descriptor.context["violators"],
            json!([{
                "producer_node": "node:prep",
                "partition": "train",
                "fold_id": "fold:0",
            }])
        );
    }

    #[test]
    fn error_code_packs_category_and_code() {
        // Shipped `(category << 16) | code` values are a stable FFI contract,
        // so every variant is pinned here, in declaration order.
        let expected: Vec<(&str, u32)> = vec![
            ("invalid_identifier", 0x0000_0001),
            ("graph_validation", 0x0000_0002),
            ("controller_validation", 0x0003_0001),
            ("campaign_validation", 0x0000_0003),
            ("planning_failed", 0x0001_0001),
            ("runtime_validation", 0x0001_0002),
            ("oof_validation", 0x0000_0004),
            ("oof_leakage", 0x0000_0005),
            ("serialization_error", 0x0008_0001),
        ];

        let actual: Vec<(&str, u32)> = one_of_each_variant()
            .iter()
            .map(|error| (error.code(), error.error_code()))
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn error_codes_are_nonzero_and_unique() {
        let errors = one_of_each_variant();
        let codes: BTreeSet<u32> = errors.iter().map(DagMlError::error_code).collect();

        // Code ids are 1-based so no real error packs to the 0 "no error" sentinel.
        for error in &errors {
            assert_ne!(error.error_code(), 0, "sentinel collision for {:?}", error);
        }
        assert_eq!(
            codes.len(),
            errors.len(),
            "two variants share a packed error code"
        );
    }

    #[test]
    fn numeric_category_matches_category_name() {
        // The string and numeric taxonomy tables are maintained separately;
        // this keeps them from drifting apart.
        for error in one_of_each_variant() {
            let expected_id: u32 = match error.category() {
                "validation" => 0,
                "runtime" => 1,
                "controller" => 3,
                "compatibility" => 8,
                other => panic!("unexpected category `{}`", other),
            };
            assert_eq!(
                error.error_code() >> 16,
                expected_id,
                "category id mismatch for {:?}",
                error
            );
        }
    }

    #[test]
    fn descriptor_json_is_stable_json_payload() {
        let error = serialization_error();

        let payload = error.descriptor_json().expect("descriptor JSON");
        let parsed = serde_json::from_str::<Value>(&payload).expect("parse descriptor");

        assert_eq!(parsed["category"], json!("compatibility"));
        assert_eq!(parsed["code"], json!("serialization_error"));
        assert!(parsed["message"]
            .as_str()
            .expect("message")
            .contains("serialization error"));
        assert!(parsed["remediation_hint"]
            .as_str()
            .expect("hint")
            .contains("contract version"));
    }
}