Skip to main content

dag_ml_core/
graph.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{DagMlError, Result};
6use crate::ids::NodeId;
7use crate::relation::EntityUnitLevel;
8
/// Version number of the graph spec schema.
pub const GRAPH_SPEC_SCHEMA_VERSION: u32 = 1;
/// Identifier (a URL string) of the JSON schema for the v1 graph spec.
pub const GRAPH_SPEC_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/graph_spec.v1.schema.json";
12
/// Kind tag carried by every [`NodeSpec`].
///
/// Variants are (de)serialized as `snake_case` strings (e.g. `y_transform`,
/// `feature_join`).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NodeKind {
    Transform,
    YTransform,
    Split,
    Model,
    Fork,
    Map,
    FeatureJoin,
    PredictionJoin,
    MixedJoin,
    SourceJoin,
    Tag,
    Exclude,
    Augmentation,
    Adapter,
    Aggregator,
    Generator,
    Restructure,
    Tuner,
    Subgraph,
    Chart,
}
37
/// Kind of payload carried by a port or an edge.
///
/// Graph validation requires both endpoints of an edge to have the same kind
/// as the edge contract. Variants are (de)serialized as `snake_case` strings.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PortKind {
    Data,
    Target,
    // Validation only accepts `EdgeContract::requires_oof` on edges of this kind.
    Prediction,
    Artifact,
    Metric,
    Control,
}
48
/// Cardinality declared on a [`PortSpec`].
///
/// Variants are (de)serialized as `snake_case` strings.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PortCardinality {
    One,
    Many,
    Optional,
}
56
/// Declaration of one named port on a node or on the graph interface.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PortSpec {
    /// Port name; validation rejects blank names and duplicates within one
    /// direction of the same node.
    pub name: String,
    /// Payload kind; must equal the kind of any edge contract attached to it.
    pub kind: PortKind,
    /// Optional representation label; rejected by validation when blank.
    pub representation: Option<String>,
    pub cardinality: PortCardinality,
    /// Optional unit level; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit_level: Option<EntityUnitLevel>,
    /// Optional alignment key; validation requires an identifier-shaped value.
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub alignment_key: Option<String>,
    /// Optional target level; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_level: Option<EntityUnitLevel>,
    /// Free-form description; defaults to the empty string when absent.
    #[serde(default)]
    pub description: String,
}
72
/// Input and output ports declared by a node.
///
/// Both lists default to empty when absent during deserialization.
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct PortSchema {
    #[serde(default)]
    pub inputs: Vec<PortSpec>,
    #[serde(default)]
    pub outputs: Vec<PortSpec>,
}
80
/// Reference to a port by owning node id and port name; used as an edge endpoint.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PortRef {
    pub node_id: NodeId,
    pub port_name: String,
}
86
/// Contract attached to an edge and checked during graph validation.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct EdgeContract {
    /// Payload kind; both endpoint ports must declare the same kind.
    pub kind: PortKind,
    /// Optional representation label; rejected by validation when blank.
    pub representation: Option<String>,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit_level: Option<EntityUnitLevel>,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub alignment_key: Option<String>,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_level: Option<EntityUnitLevel>,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub relation_contract: Option<RelationContract>,
    /// When set, unit-level and alignment-key mismatches between the edge and
    /// its ports are tolerated. Omitted from serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub allows_broadcast: bool,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub missingness_policy: Option<MissingnessPolicy>,
    /// Only accepted by validation on `PortKind::Prediction` edges.
    #[serde(default)]
    pub requires_oof: bool,
    #[serde(default)]
    pub requires_fold_alignment: bool,
    /// Defaults to `true` when absent during deserialization.
    #[serde(default = "default_true")]
    pub propagates_lineage: bool,
}
110
/// Relation requirements attached to an [`EdgeContract`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RelationContract {
    /// When present, validation requires 64 ASCII hex digits.
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub relation_fingerprint: Option<String>,
    /// When `true`, a missing fingerprint is a validation error.
    /// Omitted from serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub required: bool,
}
118
/// Missingness policy that an [`EdgeContract`] may declare.
///
/// Variants are (de)serialized as `snake_case` strings.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MissingnessPolicy {
    Strict,
    Warn,
    ImputeDeclared,
    Mask,
    PartialModel,
    PadRepresentation,
}
129
/// Serde default helper returning `true`; referenced by
/// `EdgeContract::propagates_lineage`.
fn default_true() -> bool {
    true
}
133
/// Serde `skip_serializing_if` predicate: `true` exactly when the flag is unset.
///
/// Takes `&bool` because serde passes the field by reference.
fn is_false(value: &bool) -> bool {
    !value
}
137
138impl EdgeContract {
139    pub fn new(kind: PortKind, representation: Option<String>) -> Self {
140        Self {
141            kind,
142            representation,
143            unit_level: None,
144            alignment_key: None,
145            target_level: None,
146            relation_contract: None,
147            allows_broadcast: false,
148            missingness_policy: None,
149            requires_oof: false,
150            requires_fold_alignment: false,
151            propagates_lineage: true,
152        }
153    }
154}
155
/// Directed edge from a source output port to a target input port, together
/// with the contract validated for that connection.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct EdgeSpec {
    pub source: PortRef,
    pub target: PortRef,
    pub contract: EdgeContract,
}
162
/// Ports exposed by the graph as a whole.
///
/// Both lists default to empty when absent during deserialization; validation
/// checks them for blank or duplicate names like node ports.
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct GraphInterface {
    #[serde(default)]
    pub inputs: Vec<PortSpec>,
    #[serde(default)]
    pub outputs: Vec<PortSpec>,
}
170
/// Declaration of one node in a [`GraphSpec`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct NodeSpec {
    /// Node identifier; validation rejects duplicates within a graph.
    pub id: NodeId,
    pub kind: NodeKind,
    /// Arbitrary JSON payload; not inspected by the validation in this file.
    pub operator: Option<serde_json::Value>,
    /// Defaults to an empty map when absent.
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
    /// Defaults to an empty schema when absent.
    #[serde(default)]
    pub ports: PortSchema,
    /// Defaults to an empty map when absent.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
    /// Defaults to `None` when absent.
    #[serde(default)]
    pub seed_label: Option<String>,
}
185
/// Serializable description of a graph: its interface, nodes and edges.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct GraphSpec {
    /// Graph identifier; validation rejects a blank value.
    pub id: String,
    #[serde(default)]
    pub interface: GraphInterface,
    /// Validation requires at least one node.
    #[serde(default)]
    pub nodes: Vec<NodeSpec>,
    #[serde(default)]
    pub edges: Vec<EdgeSpec>,
    /// Optional fingerprint; validation rejects a blank value when present.
    #[serde(default)]
    pub search_space_fingerprint: Option<String>,
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
200
201impl GraphSpec {
202    pub fn validate(&self) -> Result<()> {
203        if self.id.trim().is_empty() {
204            return Err(DagMlError::GraphValidation(
205                "graph id must not be empty".to_string(),
206            ));
207        }
208        if self.nodes.is_empty() {
209            return Err(DagMlError::GraphValidation(
210                "graph must contain at least one node".to_string(),
211            ));
212        }
213        if let Some(fingerprint) = &self.search_space_fingerprint {
214            if fingerprint.trim().is_empty() {
215                return Err(DagMlError::GraphValidation(format!(
216                    "graph `{}` has empty search_space_fingerprint",
217                    self.id
218                )));
219            }
220        }
221
222        let mut nodes = BTreeMap::new();
223        validate_unique_ports(
224            &NodeId::new("graph:interface").expect("static identifier is valid"),
225            "interface input",
226            &self.interface.inputs,
227        )?;
228        validate_unique_ports(
229            &NodeId::new("graph:interface").expect("static identifier is valid"),
230            "interface output",
231            &self.interface.outputs,
232        )?;
233        for node in &self.nodes {
234            if nodes.insert(node.id.clone(), node).is_some() {
235                return Err(DagMlError::GraphValidation(format!(
236                    "duplicate node id `{}`",
237                    node.id
238                )));
239            }
240            validate_unique_ports(&node.id, "input", &node.ports.inputs)?;
241            validate_unique_ports(&node.id, "output", &node.ports.outputs)?;
242        }
243
244        let mut adjacency: BTreeMap<NodeId, Vec<NodeId>> = nodes
245            .keys()
246            .cloned()
247            .map(|id| (id, Vec::new()))
248            .collect::<BTreeMap<_, _>>();
249        let mut indegree: BTreeMap<NodeId, usize> =
250            nodes.keys().cloned().map(|id| (id, 0)).collect();
251
252        for edge in &self.edges {
253            let source = nodes.get(&edge.source.node_id).ok_or_else(|| {
254                DagMlError::GraphValidation(format!(
255                    "edge source node `{}` does not exist",
256                    edge.source.node_id
257                ))
258            })?;
259            let target = nodes.get(&edge.target.node_id).ok_or_else(|| {
260                DagMlError::GraphValidation(format!(
261                    "edge target node `{}` does not exist",
262                    edge.target.node_id
263                ))
264            })?;
265
266            let source_port =
267                find_port(&source.ports.outputs, &edge.source.port_name).ok_or_else(|| {
268                    DagMlError::GraphValidation(format!(
269                        "source port `{}.{}` does not exist",
270                        edge.source.node_id, edge.source.port_name
271                    ))
272                })?;
273            let target_port =
274                find_port(&target.ports.inputs, &edge.target.port_name).ok_or_else(|| {
275                    DagMlError::GraphValidation(format!(
276                        "target port `{}.{}` does not exist",
277                        edge.target.node_id, edge.target.port_name
278                    ))
279                })?;
280
281            if source_port.kind != edge.contract.kind || target_port.kind != edge.contract.kind {
282                return Err(DagMlError::GraphValidation(format!(
283                    "edge `{}.{}` -> `{}.{}` has kind {:?}, but ports are {:?} and {:?}",
284                    edge.source.node_id,
285                    edge.source.port_name,
286                    edge.target.node_id,
287                    edge.target.port_name,
288                    edge.contract.kind,
289                    source_port.kind,
290                    target_port.kind
291                )));
292            }
293            validate_edge_contract(edge, source_port, target_port)?;
294            if edge.contract.requires_oof && edge.contract.kind != PortKind::Prediction {
295                return Err(DagMlError::GraphValidation(format!(
296                    "edge `{}.{}` -> `{}.{}` requires OOF but is not a prediction edge",
297                    edge.source.node_id,
298                    edge.source.port_name,
299                    edge.target.node_id,
300                    edge.target.port_name
301                )));
302            }
303
304            adjacency
305                .get_mut(&edge.source.node_id)
306                .expect("source exists")
307                .push(edge.target.node_id.clone());
308            *indegree
309                .get_mut(&edge.target.node_id)
310                .expect("target exists") += 1;
311        }
312
313        ensure_acyclic(adjacency, indegree)
314    }
315
316    pub fn topological_order(&self) -> Result<Vec<NodeId>> {
317        self.validate()?;
318        let nodes = self
319            .nodes
320            .iter()
321            .map(|node| node.id.clone())
322            .collect::<BTreeSet<_>>();
323        let mut adjacency = nodes
324            .iter()
325            .cloned()
326            .map(|id| (id, Vec::new()))
327            .collect::<BTreeMap<_, _>>();
328        let mut indegree: BTreeMap<NodeId, usize> =
329            nodes.iter().cloned().map(|id| (id, 0usize)).collect();
330        for edge in &self.edges {
331            adjacency
332                .get_mut(&edge.source.node_id)
333                .expect("source exists after validate")
334                .push(edge.target.node_id.clone());
335            *indegree
336                .get_mut(&edge.target.node_id)
337                .expect("target exists after validate") += 1;
338        }
339        topological_order(adjacency, indegree)
340    }
341
342    pub fn parallel_levels(&self) -> Result<Vec<Vec<NodeId>>> {
343        self.validate()?;
344        let nodes = self
345            .nodes
346            .iter()
347            .map(|node| node.id.clone())
348            .collect::<BTreeSet<_>>();
349        let mut adjacency = nodes
350            .iter()
351            .cloned()
352            .map(|id| (id, Vec::new()))
353            .collect::<BTreeMap<_, _>>();
354        let mut indegree: BTreeMap<NodeId, usize> =
355            nodes.iter().cloned().map(|id| (id, 0usize)).collect();
356        for edge in &self.edges {
357            adjacency
358                .get_mut(&edge.source.node_id)
359                .expect("source exists after validate")
360                .push(edge.target.node_id.clone());
361            *indegree
362                .get_mut(&edge.target.node_id)
363                .expect("target exists after validate") += 1;
364        }
365        topological_levels(adjacency, indegree)
366    }
367
368    pub fn upstream_nodes(&self, node_id: &NodeId) -> Vec<NodeId> {
369        let mut upstream = self
370            .edges
371            .iter()
372            .filter_map(|edge| {
373                (edge.target.node_id == *node_id).then_some(edge.source.node_id.clone())
374            })
375            .collect::<Vec<_>>();
376        upstream.sort();
377        upstream.dedup();
378        upstream
379    }
380
381    pub fn downstream_nodes(&self, node_id: &NodeId) -> Vec<NodeId> {
382        let mut downstream = self
383            .edges
384            .iter()
385            .filter_map(|edge| {
386                (edge.source.node_id == *node_id).then_some(edge.target.node_id.clone())
387            })
388            .collect::<Vec<_>>();
389        downstream.sort();
390        downstream.dedup();
391        downstream
392    }
393}
394
395fn validate_unique_ports(node_id: &NodeId, direction: &str, ports: &[PortSpec]) -> Result<()> {
396    let mut seen = BTreeSet::new();
397    for port in ports {
398        if port.name.trim().is_empty() {
399            return Err(DagMlError::GraphValidation(format!(
400                "{} port on node `{}` has an empty name",
401                direction, node_id
402            )));
403        }
404        if !seen.insert(port.name.as_str()) {
405            return Err(DagMlError::GraphValidation(format!(
406                "duplicate {} port `{}` on node `{}`",
407                direction, port.name, node_id
408            )));
409        }
410        validate_port_contract(node_id, direction, port)?;
411    }
412    Ok(())
413}
414
415fn find_port<'a>(ports: &'a [PortSpec], name: &str) -> Option<&'a PortSpec> {
416    ports.iter().find(|port| port.name == name)
417}
418
419fn validate_port_contract(node_id: &NodeId, direction: &str, port: &PortSpec) -> Result<()> {
420    validate_optional_non_empty(
421        &format!("{direction} port `{}` representation", port.name),
422        port.representation.as_deref(),
423    )?;
424    validate_optional_non_empty(
425        &format!("{direction} port `{}` alignment_key", port.name),
426        port.alignment_key.as_deref(),
427    )?;
428    if port
429        .alignment_key
430        .as_deref()
431        .is_some_and(|key| !is_identifier(key))
432    {
433        return Err(DagMlError::GraphValidation(format!(
434            "{direction} port `{}` on node `{node_id}` has invalid alignment_key",
435            port.name
436        )));
437    }
438    Ok(())
439}
440
441fn validate_edge_contract(
442    edge: &EdgeSpec,
443    source_port: &PortSpec,
444    target_port: &PortSpec,
445) -> Result<()> {
446    let label = format!(
447        "edge `{}.{}` -> `{}.{}`",
448        edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
449    );
450    validate_optional_non_empty(
451        &format!("{label} representation"),
452        edge.contract.representation.as_deref(),
453    )?;
454    validate_optional_non_empty(
455        &format!("{label} alignment_key"),
456        edge.contract.alignment_key.as_deref(),
457    )?;
458    if edge
459        .contract
460        .alignment_key
461        .as_deref()
462        .is_some_and(|key| !is_identifier(key))
463    {
464        return Err(DagMlError::GraphValidation(format!(
465            "{label} has invalid alignment_key"
466        )));
467    }
468    if let Some(relation_contract) = &edge.contract.relation_contract {
469        validate_relation_contract(&label, relation_contract)?;
470    }
471
472    validate_edge_unit_alignment(&label, edge, source_port, target_port)?;
473
474    if relation_aware_edge(edge, source_port, target_port) {
475        let relation_fingerprint = edge
476            .contract
477            .relation_contract
478            .as_ref()
479            .and_then(|contract| contract.relation_fingerprint.as_deref());
480        if relation_fingerprint.is_none() {
481            return Err(DagMlError::GraphValidation(format!(
482                "{label} is relation-aware but has no relation_fingerprint"
483            )));
484        }
485        if !has_effective_unit_level(edge, source_port, target_port) {
486            return Err(DagMlError::GraphValidation(format!(
487                "{label} is relation-aware but has no unit_level metadata"
488            )));
489        }
490        if !has_effective_alignment_key(edge, source_port, target_port) {
491            return Err(DagMlError::GraphValidation(format!(
492                "{label} is relation-aware but has no alignment_key"
493            )));
494        }
495    }
496    Ok(())
497}
498
499fn validate_relation_contract(label: &str, contract: &RelationContract) -> Result<()> {
500    if let Some(fingerprint) = &contract.relation_fingerprint {
501        validate_sha256(label, "relation_fingerprint", fingerprint)?;
502    } else if contract.required {
503        return Err(DagMlError::GraphValidation(format!(
504            "{label} relation_contract is required but has no relation_fingerprint"
505        )));
506    }
507    Ok(())
508}
509
510fn validate_edge_unit_alignment(
511    label: &str,
512    edge: &EdgeSpec,
513    source_port: &PortSpec,
514    target_port: &PortSpec,
515) -> Result<()> {
516    if let Some(contract_unit) = edge.contract.unit_level {
517        for (endpoint, unit) in [
518            ("source", source_port.unit_level),
519            ("target", target_port.unit_level),
520        ] {
521            if let Some(unit) = unit {
522                if unit != contract_unit && !edge.contract.allows_broadcast {
523                    return Err(DagMlError::GraphValidation(format!(
524                        "{label} {endpoint} unit {:?} does not match edge unit {:?}",
525                        unit, contract_unit
526                    )));
527                }
528            }
529        }
530    }
531
532    if let (Some(source_unit), Some(target_unit)) = (source_port.unit_level, target_port.unit_level)
533    {
534        if source_unit != target_unit && !edge.contract.allows_broadcast {
535            return Err(DagMlError::GraphValidation(format!(
536                "{label} joins incompatible unit levels {:?} and {:?}",
537                source_unit, target_unit
538            )));
539        }
540    }
541
542    if let (Some(source_target), Some(target_target)) =
543        (source_port.target_level, target_port.target_level)
544    {
545        if source_target != target_target {
546            return Err(DagMlError::GraphValidation(format!(
547                "{label} joins incompatible target levels {:?} and {:?}",
548                source_target, target_target
549            )));
550        }
551    }
552    if let Some(contract_target) = edge.contract.target_level {
553        for (endpoint, target_level) in [
554            ("source", source_port.target_level),
555            ("target", target_port.target_level),
556        ] {
557            if let Some(target_level) = target_level {
558                if target_level != contract_target {
559                    return Err(DagMlError::GraphValidation(format!(
560                        "{label} {endpoint} target level {:?} does not match edge target_level {:?}",
561                        target_level, contract_target
562                    )));
563                }
564            }
565        }
566    }
567
568    if let (Some(source_alignment), Some(target_alignment)) = (
569        source_port.alignment_key.as_deref(),
570        target_port.alignment_key.as_deref(),
571    ) {
572        if source_alignment != target_alignment && !edge.contract.allows_broadcast {
573            return Err(DagMlError::GraphValidation(format!(
574                "{label} joins incompatible alignment keys `{source_alignment}` and `{target_alignment}`"
575            )));
576        }
577    }
578
579    if let Some(edge_alignment) = edge.contract.alignment_key.as_deref() {
580        for (endpoint, alignment) in [
581            ("source", source_port.alignment_key.as_deref()),
582            ("target", target_port.alignment_key.as_deref()),
583        ] {
584            if let Some(alignment) = alignment {
585                if alignment != edge_alignment && !edge.contract.allows_broadcast {
586                    return Err(DagMlError::GraphValidation(format!(
587                        "{label} {endpoint} alignment `{alignment}` does not match edge alignment `{edge_alignment}`"
588                    )));
589                }
590            }
591        }
592    }
593
594    if edge.contract.allows_broadcast
595        && edge.contract.alignment_key.is_none()
596        && source_port.alignment_key.is_none()
597        && target_port.alignment_key.is_none()
598    {
599        return Err(DagMlError::GraphValidation(format!(
600            "{label} allows broadcast but declares no alignment_key"
601        )));
602    }
603    Ok(())
604}
605
606fn relation_aware_edge(edge: &EdgeSpec, source_port: &PortSpec, target_port: &PortSpec) -> bool {
607    edge.contract.relation_contract.is_some()
608        || edge.contract.allows_broadcast
609        || edge.contract.alignment_key.is_some()
610        || non_physical(edge.contract.unit_level)
611        || non_physical(edge.contract.target_level)
612        || non_physical(source_port.unit_level)
613        || non_physical(source_port.target_level)
614        || non_physical(target_port.unit_level)
615        || non_physical(target_port.target_level)
616        || source_port.alignment_key.is_some()
617        || target_port.alignment_key.is_some()
618}
619
620fn has_effective_unit_level(
621    edge: &EdgeSpec,
622    source_port: &PortSpec,
623    target_port: &PortSpec,
624) -> bool {
625    edge.contract.unit_level.is_some()
626        || source_port.unit_level.is_some()
627        || target_port.unit_level.is_some()
628}
629
630fn has_effective_alignment_key(
631    edge: &EdgeSpec,
632    source_port: &PortSpec,
633    target_port: &PortSpec,
634) -> bool {
635    edge.contract.alignment_key.is_some()
636        || source_port.alignment_key.is_some()
637        || target_port.alignment_key.is_some()
638}
639
640fn non_physical(unit_level: Option<EntityUnitLevel>) -> bool {
641    unit_level.is_some_and(|level| level != EntityUnitLevel::PhysicalSample)
642}
643
644fn validate_optional_non_empty(label: &str, value: Option<&str>) -> Result<()> {
645    if value.is_some_and(|value| value.trim().is_empty()) {
646        return Err(DagMlError::GraphValidation(format!(
647            "{label} must not be empty"
648        )));
649    }
650    Ok(())
651}
652
653fn validate_sha256(owner: &str, field: &str, value: &str) -> Result<()> {
654    if value.len() == 64 && value.bytes().all(|byte| byte.is_ascii_hexdigit()) {
655        Ok(())
656    } else {
657        Err(DagMlError::GraphValidation(format!(
658            "{owner} has invalid {field}"
659        )))
660    }
661}
662
/// Returns `true` for a 1..=128 byte string made only of ASCII letters,
/// ASCII digits, `_`, `-`, `.` and `:`.
fn is_identifier(value: &str) -> bool {
    let allowed = |byte: u8| byte.is_ascii_alphanumeric() || b"_-.:".contains(&byte);
    (1..=128).contains(&value.len()) && value.bytes().all(allowed)
}
670
671fn ensure_acyclic(
672    adjacency: BTreeMap<NodeId, Vec<NodeId>>,
673    indegree: BTreeMap<NodeId, usize>,
674) -> Result<()> {
675    topological_order(adjacency, indegree).map(|_| ())
676}
677
678fn topological_order(
679    adjacency: BTreeMap<NodeId, Vec<NodeId>>,
680    mut indegree: BTreeMap<NodeId, usize>,
681) -> Result<Vec<NodeId>> {
682    let mut queue = indegree
683        .iter()
684        .filter_map(|(id, degree)| (*degree == 0).then_some(id.clone()))
685        .collect::<BTreeSet<_>>();
686    let mut order = Vec::with_capacity(indegree.len());
687
688    while let Some(node) = queue.pop_first() {
689        order.push(node.clone());
690        if let Some(next_nodes) = adjacency.get(&node) {
691            for next in next_nodes {
692                let degree = indegree.get_mut(next).expect("node exists");
693                *degree -= 1;
694                if *degree == 0 {
695                    queue.insert(next.clone());
696                }
697            }
698        }
699    }
700
701    if order.len() == indegree.len() {
702        Ok(order)
703    } else {
704        Err(DagMlError::GraphValidation(
705            "graph contains at least one cycle".to_string(),
706        ))
707    }
708}
709
710fn topological_levels(
711    adjacency: BTreeMap<NodeId, Vec<NodeId>>,
712    mut indegree: BTreeMap<NodeId, usize>,
713) -> Result<Vec<Vec<NodeId>>> {
714    let mut queue = indegree
715        .iter()
716        .filter_map(|(id, degree)| (*degree == 0).then_some(id.clone()))
717        .collect::<BTreeSet<_>>();
718    let mut levels = Vec::new();
719    let mut visited = 0usize;
720
721    while !queue.is_empty() {
722        let level = queue.iter().cloned().collect::<Vec<_>>();
723        queue.clear();
724        for node in &level {
725            visited += 1;
726            if let Some(next_nodes) = adjacency.get(node) {
727                for next in next_nodes {
728                    let degree = indegree.get_mut(next).expect("node exists");
729                    *degree -= 1;
730                    if *degree == 0 {
731                        queue.insert(next.clone());
732                    }
733                }
734            }
735        }
736        levels.push(level);
737    }
738
739    if visited == indegree.len() {
740        Ok(levels)
741    } else {
742        Err(DagMlError::GraphValidation(
743            "graph contains at least one cycle".to_string(),
744        ))
745    }
746}
747
748#[cfg(test)]
749mod tests {
750    use super::*;
751
752    fn port(name: &str, kind: PortKind) -> PortSpec {
753        PortSpec {
754            name: name.to_string(),
755            kind,
756            representation: None,
757            cardinality: PortCardinality::One,
758            unit_level: None,
759            alignment_key: None,
760            target_level: None,
761            description: String::new(),
762        }
763    }
764
765    fn node(id: &str, inputs: Vec<PortSpec>, outputs: Vec<PortSpec>) -> NodeSpec {
766        NodeSpec {
767            id: NodeId::new(id).unwrap(),
768            kind: NodeKind::Model,
769            operator: None,
770            params: BTreeMap::new(),
771            ports: PortSchema { inputs, outputs },
772            metadata: BTreeMap::new(),
773            seed_label: None,
774        }
775    }
776
777    fn edge(source: &str, source_port: &str, target: &str, target_port: &str) -> EdgeSpec {
778        EdgeSpec {
779            source: PortRef {
780                node_id: NodeId::new(source).unwrap(),
781                port_name: source_port.to_string(),
782            },
783            target: PortRef {
784                node_id: NodeId::new(target).unwrap(),
785                port_name: target_port.to_string(),
786            },
787            contract: EdgeContract {
788                requires_oof: true,
789                requires_fold_alignment: true,
790                ..EdgeContract::new(PortKind::Prediction, None)
791            },
792        }
793    }
794
795    #[test]
796    fn validates_simple_graph() {
797        let graph = GraphSpec {
798            id: "g".to_string(),
799            interface: GraphInterface::default(),
800            nodes: vec![
801                node("model:a", vec![], vec![port("pred", PortKind::Prediction)]),
802                node("model:b", vec![port("pred", PortKind::Prediction)], vec![]),
803            ],
804            edges: vec![edge("model:a", "pred", "model:b", "pred")],
805            search_space_fingerprint: None,
806            metadata: BTreeMap::new(),
807        };
808
809        assert!(graph.validate().is_ok());
810    }
811
812    #[test]
813    fn computes_deterministic_parallel_levels() {
814        let graph = GraphSpec {
815            id: "g".to_string(),
816            interface: GraphInterface::default(),
817            nodes: vec![
818                node("model:a", vec![], vec![port("pred", PortKind::Prediction)]),
819                node(
820                    "model:b",
821                    vec![port("pred", PortKind::Prediction)],
822                    vec![port("pred", PortKind::Prediction)],
823                ),
824                node(
825                    "model:c",
826                    vec![port("pred", PortKind::Prediction)],
827                    vec![port("pred", PortKind::Prediction)],
828                ),
829                node("model:d", vec![port("pred", PortKind::Prediction)], vec![]),
830            ],
831            edges: vec![
832                edge("model:a", "pred", "model:b", "pred"),
833                edge("model:a", "pred", "model:c", "pred"),
834                edge("model:b", "pred", "model:d", "pred"),
835                edge("model:c", "pred", "model:d", "pred"),
836            ],
837            search_space_fingerprint: None,
838            metadata: BTreeMap::new(),
839        };
840
841        let levels = graph.parallel_levels().unwrap();
842
843        assert_eq!(
844            levels,
845            vec![
846                vec![NodeId::new("model:a").unwrap()],
847                vec![
848                    NodeId::new("model:b").unwrap(),
849                    NodeId::new("model:c").unwrap()
850                ],
851                vec![NodeId::new("model:d").unwrap()]
852            ]
853        );
854    }
855
856    #[test]
857    fn rejects_missing_edge_endpoint() {
858        let graph = GraphSpec {
859            id: "g".to_string(),
860            interface: GraphInterface::default(),
861            nodes: vec![node(
862                "model:a",
863                vec![],
864                vec![port("pred", PortKind::Prediction)],
865            )],
866            edges: vec![edge("model:a", "pred", "model:b", "pred")],
867            search_space_fingerprint: None,
868            metadata: BTreeMap::new(),
869        };
870
871        assert!(graph.validate().is_err());
872    }
873
874    #[test]
875    fn rejects_oof_contract_on_non_prediction_edge() {
876        let graph = GraphSpec {
877            id: "g".to_string(),
878            interface: GraphInterface::default(),
879            nodes: vec![
880                node("model:a", vec![], vec![port("x", PortKind::Data)]),
881                node("model:b", vec![port("x", PortKind::Data)], vec![]),
882            ],
883            edges: vec![EdgeSpec {
884                source: PortRef {
885                    node_id: NodeId::new("model:a").unwrap(),
886                    port_name: "x".to_string(),
887                },
888                target: PortRef {
889                    node_id: NodeId::new("model:b").unwrap(),
890                    port_name: "x".to_string(),
891                },
892                contract: EdgeContract {
893                    requires_oof: true,
894                    requires_fold_alignment: true,
895                    ..EdgeContract::new(PortKind::Data, None)
896                },
897            }],
898            search_space_fingerprint: None,
899            metadata: BTreeMap::new(),
900        };
901
902        let error = graph.validate().unwrap_err().to_string();
903
904        assert!(error.contains("requires OOF"));
905    }
906
907    fn unit_port(name: &str, kind: PortKind, unit_level: EntityUnitLevel) -> PortSpec {
908        let mut port = port(name, kind);
909        port.unit_level = Some(unit_level);
910        port.alignment_key = Some("sample_id".to_string());
911        port
912    }
913
914    fn data_edge_contract() -> EdgeContract {
915        EdgeContract::new(PortKind::Data, Some("tabular".to_string()))
916    }
917
918    fn relation_contract() -> RelationContract {
919        RelationContract {
920            relation_fingerprint: Some("a".repeat(64)),
921            required: true,
922        }
923    }
924
925    #[test]
926    fn rejects_unit_mismatch_without_explicit_broadcast() {
927        let graph = GraphSpec {
928            id: "g".to_string(),
929            interface: GraphInterface::default(),
930            nodes: vec![
931                node(
932                    "transform:obs",
933                    vec![],
934                    vec![unit_port("x", PortKind::Data, EntityUnitLevel::Observation)],
935                ),
936                node(
937                    "join:sample",
938                    vec![unit_port(
939                        "x",
940                        PortKind::Data,
941                        EntityUnitLevel::PhysicalSample,
942                    )],
943                    vec![],
944                ),
945            ],
946            edges: vec![EdgeSpec {
947                source: PortRef {
948                    node_id: NodeId::new("transform:obs").unwrap(),
949                    port_name: "x".to_string(),
950                },
951                target: PortRef {
952                    node_id: NodeId::new("join:sample").unwrap(),
953                    port_name: "x".to_string(),
954                },
955                contract: EdgeContract {
956                    relation_contract: Some(relation_contract()),
957                    ..data_edge_contract()
958                },
959            }],
960            search_space_fingerprint: None,
961            metadata: BTreeMap::new(),
962        };
963
964        let error = graph.validate().unwrap_err().to_string();
965
966        assert!(error.contains("incompatible unit levels"));
967    }
968
969    #[test]
970    fn relation_aware_edge_requires_relation_fingerprint() {
971        let graph = GraphSpec {
972            id: "g".to_string(),
973            interface: GraphInterface::default(),
974            nodes: vec![
975                node(
976                    "source:a",
977                    vec![],
978                    vec![unit_port("x", PortKind::Data, EntityUnitLevel::Observation)],
979                ),
980                node(
981                    "model:a",
982                    vec![unit_port("x", PortKind::Data, EntityUnitLevel::Observation)],
983                    vec![],
984                ),
985            ],
986            edges: vec![EdgeSpec {
987                source: PortRef {
988                    node_id: NodeId::new("source:a").unwrap(),
989                    port_name: "x".to_string(),
990                },
991                target: PortRef {
992                    node_id: NodeId::new("model:a").unwrap(),
993                    port_name: "x".to_string(),
994                },
995                contract: data_edge_contract(),
996            }],
997            search_space_fingerprint: None,
998            metadata: BTreeMap::new(),
999        };
1000
1001        let error = graph.validate().unwrap_err().to_string();
1002
1003        assert!(error.contains("relation-aware"));
1004    }
1005
1006    #[test]
1007    fn relation_aware_edge_requires_alignment_key() {
1008        let mut source_port = port("x", PortKind::Data);
1009        source_port.unit_level = Some(EntityUnitLevel::Observation);
1010        let mut target_port = port("x", PortKind::Data);
1011        target_port.unit_level = Some(EntityUnitLevel::Observation);
1012
1013        let graph = GraphSpec {
1014            id: "g".to_string(),
1015            interface: GraphInterface::default(),
1016            nodes: vec![
1017                node("source:a", vec![], vec![source_port]),
1018                node("model:a", vec![target_port], vec![]),
1019            ],
1020            edges: vec![EdgeSpec {
1021                source: PortRef {
1022                    node_id: NodeId::new("source:a").unwrap(),
1023                    port_name: "x".to_string(),
1024                },
1025                target: PortRef {
1026                    node_id: NodeId::new("model:a").unwrap(),
1027                    port_name: "x".to_string(),
1028                },
1029                contract: EdgeContract {
1030                    relation_contract: Some(relation_contract()),
1031                    ..data_edge_contract()
1032                },
1033            }],
1034            search_space_fingerprint: None,
1035            metadata: BTreeMap::new(),
1036        };
1037
1038        let error = graph.validate().unwrap_err().to_string();
1039
1040        assert!(error.contains("alignment_key"));
1041    }
1042
1043    #[test]
1044    fn explicit_broadcast_allows_sample_to_observation_edge() {
1045        let mut contract = data_edge_contract();
1046        contract.allows_broadcast = true;
1047        contract.alignment_key = Some("sample_id".to_string());
1048        contract.relation_contract = Some(relation_contract());
1049
1050        let graph = GraphSpec {
1051            id: "g".to_string(),
1052            interface: GraphInterface::default(),
1053            nodes: vec![
1054                node(
1055                    "source:sample",
1056                    vec![],
1057                    vec![unit_port(
1058                        "x",
1059                        PortKind::Data,
1060                        EntityUnitLevel::PhysicalSample,
1061                    )],
1062                ),
1063                node(
1064                    "adapter:broadcast",
1065                    vec![unit_port("x", PortKind::Data, EntityUnitLevel::Observation)],
1066                    vec![],
1067                ),
1068            ],
1069            edges: vec![EdgeSpec {
1070                source: PortRef {
1071                    node_id: NodeId::new("source:sample").unwrap(),
1072                    port_name: "x".to_string(),
1073                },
1074                target: PortRef {
1075                    node_id: NodeId::new("adapter:broadcast").unwrap(),
1076                    port_name: "x".to_string(),
1077                },
1078                contract,
1079            }],
1080            search_space_fingerprint: None,
1081            metadata: BTreeMap::new(),
1082        };
1083
1084        graph.validate().unwrap();
1085    }
1086
1087    #[test]
1088    fn rejects_cycles() {
1089        let graph = GraphSpec {
1090            id: "g".to_string(),
1091            interface: GraphInterface::default(),
1092            nodes: vec![
1093                node(
1094                    "model:a",
1095                    vec![port("pred", PortKind::Prediction)],
1096                    vec![port("pred", PortKind::Prediction)],
1097                ),
1098                node(
1099                    "model:b",
1100                    vec![port("pred", PortKind::Prediction)],
1101                    vec![port("pred", PortKind::Prediction)],
1102                ),
1103            ],
1104            edges: vec![
1105                edge("model:a", "pred", "model:b", "pred"),
1106                edge("model:b", "pred", "model:a", "pred"),
1107            ],
1108            search_space_fingerprint: None,
1109            metadata: BTreeMap::new(),
1110        };
1111
1112        assert!(graph.validate().is_err());
1113    }
1114
1115    #[test]
1116    fn published_graph_spec_schema_declares_current_contract() {
1117        let schema: serde_json::Value = serde_json::from_str(include_str!(
1118            "../../../docs/contracts/graph_spec.schema.json"
1119        ))
1120        .unwrap();
1121
1122        assert_eq!(schema["$id"], GRAPH_SPEC_SCHEMA_ID);
1123        assert!(schema["required"]
1124            .as_array()
1125            .unwrap()
1126            .iter()
1127            .any(|field| field.as_str() == Some("nodes")));
1128        assert_eq!(
1129            schema["$defs"]["node_kind"]["enum"]
1130                .as_array()
1131                .unwrap()
1132                .len(),
1133            20
1134        );
1135        assert!(schema["$defs"]["port_kind"]["enum"]
1136            .as_array()
1137            .unwrap()
1138            .iter()
1139            .any(|kind| kind.as_str() == Some("prediction")));
1140        assert!(schema["$defs"]["entity_unit_level"]["enum"]
1141            .as_array()
1142            .unwrap()
1143            .iter()
1144            .any(|level| level.as_str() == Some("combo")));
1145        assert!(schema["$defs"]["edge_contract"]["properties"]
1146            .as_object()
1147            .unwrap()
1148            .contains_key("relation_contract"));
1149    }
1150}