use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use thiserror::Error;
/// Flat, serializable description of a [`DagMlError`].
///
/// Instances are built by [`DagMlError::descriptor`]; every field is an owned
/// value so the descriptor can outlive the error it was derived from.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DagMlErrorDescriptor {
    /// Taxonomy category label, as returned by [`DagMlError::category`].
    pub category: String,
    /// Taxonomy code label, as returned by [`DagMlError::code`].
    pub code: String,
    /// Severity label, as returned by [`DagMlError::severity`].
    pub severity: String,
    /// The error's `Display` rendering.
    pub message: String,
    /// Suggested fix, as returned by [`DagMlError::remediation_hint`].
    pub remediation_hint: String,
    /// Variant-specific details keyed by name, as returned by [`DagMlError::context`].
    pub context: BTreeMap<String, Value>,
}
/// One entry in the `violators` list of an [`OofLeakageReport`].
///
/// NOTE(review): the code that fills these fields is not in this file; the
/// field descriptions below follow the field names and should be confirmed
/// against the producer of the report.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OofLeakageViolation {
    /// Presumably the id of the node that produced the offending data.
    pub producer_node: String,
    /// Presumably the data partition involved (the tests in this file use `"train"`).
    pub partition: String,
    /// Optional fold identifier; `None` when no fold is attached to the violation.
    pub fold_id: Option<String>,
}
/// Payload carried by [`DagMlError::OofLeakage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OofLeakageReport {
    /// Node id quoted in the error message (``OOF leakage at `<node_id>` ``).
    pub node_id: String,
    /// Recorded violations; the error message reports how many there are.
    pub violators: Vec<OofLeakageViolation>,
    /// Flag copied verbatim into the error context under the same name.
    pub allow_train_predictions_as_features: bool,
    /// Remediation text; appended to the error message and returned as the
    /// remediation hint for this error.
    pub remediation: String,
}
#[derive(Debug, Error)]
pub enum DagMlError {
#[error("invalid identifier `{value}`: {reason}")]
InvalidIdentifier { value: String, reason: &'static str },
#[error("graph validation failed: {0}")]
GraphValidation(String),
#[error("controller validation failed: {0}")]
ControllerValidation(String),
#[error("campaign validation failed: {0}")]
CampaignValidation(String),
#[error("planning failed: {0}")]
Planning(String),
#[error("runtime validation failed: {0}")]
RuntimeValidation(String),
#[error("OOF validation failed: {0}")]
OofValidation(String),
#[error("OOF leakage at `{}`: {} violator(s); {}", .0.node_id, .0.violators.len(), .0.remediation)]
OofLeakage(Box<OofLeakageReport>),
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
}
impl DagMlError {
pub fn category(&self) -> &'static str {
self.taxonomy_parts().0
}
pub fn code(&self) -> &'static str {
self.taxonomy_parts().1
}
pub fn severity(&self) -> &'static str {
self.taxonomy_parts().2
}
pub fn remediation_hint(&self) -> String {
match self {
Self::InvalidIdentifier { .. } => {
"Use a non-empty stable identifier that matches the dag-ml identifier grammar."
.to_string()
}
Self::GraphValidation(_) => {
"Fix the graph contract violation before planning or running the pipeline."
.to_string()
}
Self::ControllerValidation(_) => {
"Register or configure the controller so it matches the graph contract."
.to_string()
}
Self::CampaignValidation(_) => {
"Fix the campaign template so its folds, metrics and graph references are consistent."
.to_string()
}
Self::Planning(_) => {
"Inspect the graph and campaign constraints, then re-plan with a compatible execution request."
.to_string()
}
Self::RuntimeValidation(_) => {
"Inspect the runtime inputs and produced artifacts, then rerun the failed step with compatible values."
.to_string()
}
Self::OofValidation(_) => {
"Use validated out-of-fold contracts and keep training predictions out of feature inputs unless explicitly allowed."
.to_string()
}
Self::OofLeakage(report) => report.remediation.clone(),
Self::Serialization(_) => {
"Check that the JSON or YAML payload matches the supported dag-ml contract version."
.to_string()
}
}
}
pub fn context(&self) -> BTreeMap<String, Value> {
let mut context = BTreeMap::new();
match self {
Self::InvalidIdentifier { value, reason } => {
context.insert("value".to_string(), json!(value));
context.insert("reason".to_string(), json!(reason));
}
Self::GraphValidation(detail)
| Self::ControllerValidation(detail)
| Self::CampaignValidation(detail)
| Self::Planning(detail)
| Self::RuntimeValidation(detail)
| Self::OofValidation(detail) => {
context.insert("detail".to_string(), json!(detail));
}
Self::OofLeakage(report) => {
context.insert("node_id".to_string(), json!(report.node_id));
context.insert("violator_count".to_string(), json!(report.violators.len()));
context.insert(
"allow_train_predictions_as_features".to_string(),
json!(report.allow_train_predictions_as_features),
);
context.insert("violators".to_string(), json!(report.violators));
}
Self::Serialization(error) => {
context.insert("detail".to_string(), json!(error.to_string()));
}
}
context
}
pub fn descriptor(&self) -> DagMlErrorDescriptor {
DagMlErrorDescriptor {
category: self.category().to_string(),
code: self.code().to_string(),
severity: self.severity().to_string(),
message: self.to_string(),
remediation_hint: self.remediation_hint(),
context: self.context(),
}
}
pub fn descriptor_json(&self) -> std::result::Result<String, serde_json::Error> {
serde_json::to_string(&self.descriptor())
}
pub fn error_code(&self) -> u32 {
let (category_id, code_id) = self.numeric_taxonomy();
(u32::from(category_id) << 16) | u32::from(code_id)
}
fn taxonomy_parts(&self) -> (&'static str, &'static str, &'static str) {
match self {
Self::InvalidIdentifier { .. } => ("validation", "invalid_identifier", "error"),
Self::GraphValidation(_) => ("validation", "graph_validation", "error"),
Self::ControllerValidation(_) => ("controller", "controller_validation", "error"),
Self::CampaignValidation(_) => ("validation", "campaign_validation", "error"),
Self::Planning(_) => ("runtime", "planning_failed", "error"),
Self::RuntimeValidation(_) => ("runtime", "runtime_validation", "error"),
Self::OofValidation(_) => ("validation", "oof_validation", "error"),
Self::OofLeakage(_) => ("validation", "oof_leakage", "error"),
Self::Serialization(_) => ("compatibility", "serialization_error", "error"),
}
}
fn numeric_taxonomy(&self) -> (u16, u16) {
match self {
Self::InvalidIdentifier { .. } => (0, 1),
Self::GraphValidation(_) => (0, 2),
Self::CampaignValidation(_) => (0, 3),
Self::OofValidation(_) => (0, 4),
Self::OofLeakage(_) => (0, 5),
Self::ControllerValidation(_) => (3, 1),
Self::Planning(_) => (1, 1),
Self::RuntimeValidation(_) => (1, 2),
Self::Serialization(_) => (8, 1),
}
}
}
/// Result alias whose error type is fixed to [`DagMlError`].
pub type Result<T> = std::result::Result<T, DagMlError>;
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeSet;

    /// Produces a real `serde_json::Error` by parsing truncated JSON.
    fn malformed_json_error() -> serde_json::Error {
        serde_json::from_str::<Value>("{").expect_err("invalid JSON")
    }

    /// Report fixture shared by the OOF leakage tests.
    fn sample_report() -> OofLeakageReport {
        OofLeakageReport {
            node_id: "node:model".to_string(),
            violators: vec![OofLeakageViolation {
                producer_node: "node:prep".to_string(),
                partition: "train".to_string(),
                fold_id: Some("fold:0".to_string()),
            }],
            allow_train_predictions_as_features: false,
            remediation: "Use validation-only OOF predictions.".to_string(),
        }
    }

    /// One representative of every `DagMlError` variant paired with the packed
    /// code it must report. Extend this list whenever a variant is added.
    fn variants_with_codes() -> Vec<(DagMlError, u32)> {
        vec![
            (
                DagMlError::InvalidIdentifier {
                    value: "x".to_string(),
                    reason: "bad",
                },
                0x0000_0001,
            ),
            (DagMlError::GraphValidation("x".to_string()), 0x0000_0002),
            (DagMlError::CampaignValidation("x".to_string()), 0x0000_0003),
            (DagMlError::OofValidation("x".to_string()), 0x0000_0004),
            (
                DagMlError::OofLeakage(Box::new(sample_report())),
                0x0000_0005,
            ),
            (DagMlError::Planning("x".to_string()), 0x0001_0001),
            (DagMlError::RuntimeValidation("x".to_string()), 0x0001_0002),
            (
                DagMlError::ControllerValidation("x".to_string()),
                0x0003_0001,
            ),
            (
                DagMlError::Serialization(malformed_json_error()),
                0x0008_0001,
            ),
        ]
    }

    #[test]
    fn invalid_identifier_descriptor_carries_taxonomy_and_context() {
        let error = DagMlError::InvalidIdentifier {
            value: "".to_string(),
            reason: "empty",
        };
        let descriptor = error.descriptor();
        assert_eq!(descriptor.category, "validation");
        assert_eq!(descriptor.code, "invalid_identifier");
        assert_eq!(descriptor.severity, "error");
        assert_eq!(descriptor.context["value"], json!(""));
        assert_eq!(descriptor.context["reason"], json!("empty"));
        assert!(descriptor.remediation_hint.contains("identifier"));
    }

    #[test]
    fn detail_variants_expose_detail_context() {
        let error = DagMlError::GraphValidation("cycle detected".to_string());
        let descriptor = error.descriptor();
        assert_eq!(descriptor.message, "graph validation failed: cycle detected");
        assert_eq!(descriptor.context["detail"], json!("cycle detected"));
        assert_eq!(descriptor.context.len(), 1);
        assert!(descriptor.remediation_hint.contains("graph contract"));
    }

    #[test]
    fn oof_leakage_descriptor_preserves_report_context() {
        let error = DagMlError::OofLeakage(Box::new(sample_report()));
        let descriptor = error.descriptor();
        assert_eq!(descriptor.category, "validation");
        assert_eq!(descriptor.code, "oof_leakage");
        assert_eq!(
            descriptor.message,
            "OOF leakage at `node:model`: 1 violator(s); Use validation-only OOF predictions."
        );
        assert_eq!(
            descriptor.remediation_hint,
            "Use validation-only OOF predictions."
        );
        assert_eq!(descriptor.context["node_id"], json!("node:model"));
        assert_eq!(descriptor.context["violator_count"], json!(1));
        assert_eq!(
            descriptor.context["allow_train_predictions_as_features"],
            json!(false)
        );
        // The full violator list must survive, not just its length.
        assert_eq!(
            descriptor.context["violators"],
            json!([{
                "producer_node": "node:prep",
                "partition": "train",
                "fold_id": "fold:0"
            }])
        );
    }

    #[test]
    fn error_code_packs_category_and_code() {
        let mut seen = BTreeSet::new();
        for (error, expected) in variants_with_codes() {
            let code = error.error_code();
            assert_eq!(code, expected, "unexpected code for {:?}", error);
            assert_ne!(code, 0, "zero code reported for {:?}", error);
            assert!(seen.insert(code), "duplicate code {:#010x}", code);
        }
    }

    #[test]
    fn category_labels_agree_with_numeric_category_ids() {
        // The textual and numeric taxonomies are separate matches; make sure a
        // given category label never maps to two different numeric ids.
        let mut ids: BTreeMap<&'static str, u32> = BTreeMap::new();
        for (error, _) in variants_with_codes() {
            let numeric_category = error.error_code() >> 16;
            let recorded = *ids.entry(error.category()).or_insert(numeric_category);
            assert_eq!(
                recorded, numeric_category,
                "category `{}` maps to more than one numeric id",
                error.category()
            );
        }
    }

    #[test]
    fn descriptor_json_is_stable_json_payload() {
        let error = DagMlError::Serialization(malformed_json_error());
        let payload = error.descriptor_json().expect("descriptor JSON");
        let parsed = serde_json::from_str::<Value>(&payload).expect("parse descriptor");
        assert_eq!(parsed["category"], json!("compatibility"));
        assert_eq!(parsed["code"], json!("serialization_error"));
        assert!(parsed["message"]
            .as_str()
            .expect("message")
            .contains("serialization error"));
        assert!(parsed["remediation_hint"]
            .as_str()
            .expect("hint")
            .contains("contract version"));
    }
}