1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{DagMlError, Result};
6use crate::ids::NodeId;
7use crate::relation::EntityUnitLevel;
8
/// Version number of the graph-spec schema.
/// NOTE(review): presumably tracks the `v1` segment of [`GRAPH_SPEC_SCHEMA_ID`] — confirm.
pub const GRAPH_SPEC_SCHEMA_VERSION: u32 = 1;
/// Identifier of the published graph-spec JSON schema. The test module in this
/// file asserts that the schema document's `$id` equals this string.
pub const GRAPH_SPEC_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/graph_spec.v1.schema.json";
12
/// Kind tag carried by every [`NodeSpec`].
///
/// Variant names are serialized in `snake_case`. The schema test in this file
/// expects the published schema's `node_kind` enum to list 20 entries, which is
/// the number of variants declared here; keep the two in sync.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NodeKind {
    Transform,
    YTransform,
    Split,
    Model,
    Fork,
    Map,
    FeatureJoin,
    PredictionJoin,
    MixedJoin,
    SourceJoin,
    Tag,
    Exclude,
    Augmentation,
    Adapter,
    Aggregator,
    Generator,
    Restructure,
    Tuner,
    Subgraph,
    Chart,
}
37
/// Kind of value a port exposes or an edge carries (serialized in `snake_case`).
///
/// `GraphSpec::validate` requires an edge contract's kind to equal the kind of
/// both endpoint ports, and only allows `requires_oof` on `Prediction` edges.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PortKind {
    Data,
    Target,
    Prediction,
    Artifact,
    Metric,
    Control,
}
48
/// Cardinality declared on a [`PortSpec`] (serialized in `snake_case`).
///
/// The validation code in this file does not inspect this value.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PortCardinality {
    One,
    Many,
    Optional,
}
56
/// Declaration of a single input or output port.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PortSpec {
    /// Port name; validation rejects blank names and duplicates within the same
    /// direction of one node (or of the graph interface).
    pub name: String,
    /// Kind of value the port carries; edges must match it on both endpoints.
    pub kind: PortKind,
    /// Optional representation label; when present it must not be blank.
    pub representation: Option<String>,
    /// Declared cardinality (not inspected by validation in this file).
    pub cardinality: PortCardinality,
    /// Optional unit level; compared against the other endpoint and the edge
    /// contract during edge validation. Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit_level: Option<EntityUnitLevel>,
    /// Optional alignment key; must be non-blank and pass `is_identifier`.
    /// Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub alignment_key: Option<String>,
    /// Optional target level; compared against the other endpoint and the edge
    /// contract during edge validation. Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_level: Option<EntityUnitLevel>,
    /// Free-text description; defaults to the empty string when absent.
    #[serde(default)]
    pub description: String,
}
72
/// Input and output ports declared by a node; both lists default to empty.
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct PortSchema {
    /// Ports that edges may target on this node.
    #[serde(default)]
    pub inputs: Vec<PortSpec>,
    /// Ports that edges may originate from on this node.
    #[serde(default)]
    pub outputs: Vec<PortSpec>,
}
80
/// Reference to a named port on a specific node, used as an edge endpoint.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PortRef {
    /// Id of the node that owns the port; must exist in the graph.
    pub node_id: NodeId,
    /// Name of the port on that node.
    pub port_name: String,
}
86
/// Contract attached to an edge, checked by `GraphSpec::validate`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct EdgeContract {
    /// Kind of value carried; must equal the kind of both endpoint ports.
    pub kind: PortKind,
    /// Optional representation label; when present it must not be blank.
    pub representation: Option<String>,
    /// Optional unit level; any unit level declared on an endpoint port must
    /// equal it unless `allows_broadcast` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit_level: Option<EntityUnitLevel>,
    /// Optional alignment key; must pass `is_identifier`, and any key declared
    /// on an endpoint port must equal it unless `allows_broadcast` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub alignment_key: Option<String>,
    /// Optional target level; any target level declared on an endpoint port
    /// must equal it (broadcast does not relax this check).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_level: Option<EntityUnitLevel>,
    /// Optional relation contract; its presence makes the edge relation-aware.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub relation_contract: Option<RelationContract>,
    /// Relaxes the unit-level and alignment-key equality checks. A broadcast
    /// edge must still have an alignment key on the edge or on one of its
    /// ports. Omitted from output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub allows_broadcast: bool,
    /// Optional missingness policy (not inspected by validation in this file).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub missingness_policy: Option<MissingnessPolicy>,
    /// Only accepted on `PortKind::Prediction` edges.
    #[serde(default)]
    pub requires_oof: bool,
    /// Not inspected by validation in this file.
    #[serde(default)]
    pub requires_fold_alignment: bool,
    /// Defaults to `true` both when deserializing and in `EdgeContract::new`.
    #[serde(default = "default_true")]
    pub propagates_lineage: bool,
}
110
/// Relation requirements attached to an [`EdgeContract`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RelationContract {
    /// When present, must be exactly 64 ASCII hex digits (see `validate_sha256`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub relation_fingerprint: Option<String>,
    /// When `true`, a missing fingerprint is a validation error. Omitted from
    /// output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub required: bool,
}
118
/// Missingness policy that an [`EdgeContract`] may declare (serialized in
/// `snake_case`). The validation code in this file does not inspect it.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MissingnessPolicy {
    Strict,
    Warn,
    ImputeDeclared,
    Mask,
    PartialModel,
    PadRepresentation,
}
129
/// Serde default provider for boolean fields that default to `true`
/// (referenced by `EdgeContract::propagates_lineage`).
fn default_true() -> bool {
    true
}
133
/// Serde `skip_serializing_if` predicate: returns `true` when the flag is
/// `false`, so unset flags are left out of the serialized output.
fn is_false(value: &bool) -> bool {
    // `Not` is implemented for `&bool`, so no explicit dereference is needed.
    !value
}
137
138impl EdgeContract {
139 pub fn new(kind: PortKind, representation: Option<String>) -> Self {
140 Self {
141 kind,
142 representation,
143 unit_level: None,
144 alignment_key: None,
145 target_level: None,
146 relation_contract: None,
147 allows_broadcast: false,
148 missingness_policy: None,
149 requires_oof: false,
150 requires_fold_alignment: false,
151 propagates_lineage: true,
152 }
153 }
154}
155
/// Directed connection between two ports, together with its contract.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct EdgeSpec {
    /// Originating endpoint; resolved against the source node's output ports.
    pub source: PortRef,
    /// Receiving endpoint; resolved against the target node's input ports.
    pub target: PortRef,
    /// Contract checked against both endpoint ports during validation.
    pub contract: EdgeContract,
}
162
/// Ports exposed by the graph as a whole; both lists default to empty.
///
/// Validation applies the same per-port checks as for node ports, reporting
/// problems against the pseudo node id `graph:interface`.
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct GraphInterface {
    /// Graph-level input ports.
    #[serde(default)]
    pub inputs: Vec<PortSpec>,
    /// Graph-level output ports.
    #[serde(default)]
    pub outputs: Vec<PortSpec>,
}
170
/// Declaration of one node of the graph.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct NodeSpec {
    /// Node id; must be unique within the graph.
    pub id: NodeId,
    /// Kind tag of the node.
    pub kind: NodeKind,
    /// Optional operator description as raw JSON (not inspected by validation
    /// in this file).
    pub operator: Option<serde_json::Value>,
    /// Parameters as raw JSON values keyed by name; defaults to empty.
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
    /// Input and output ports; defaults to no ports.
    #[serde(default)]
    pub ports: PortSchema,
    /// Metadata as raw JSON values keyed by name; defaults to empty.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
    /// Optional seed label (not inspected by validation in this file).
    #[serde(default)]
    pub seed_label: Option<String>,
}
185
/// Complete graph specification: interface, nodes and the edges between them.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct GraphSpec {
    /// Graph id; must not be blank.
    pub id: String,
    /// Graph-level ports; defaults to none.
    #[serde(default)]
    pub interface: GraphInterface,
    /// Nodes of the graph. Although the field defaults to empty on
    /// deserialization, validation rejects a graph with no nodes.
    #[serde(default)]
    pub nodes: Vec<NodeSpec>,
    /// Edges between node ports; defaults to none.
    #[serde(default)]
    pub edges: Vec<EdgeSpec>,
    /// Optional fingerprint; when present it must not be blank.
    #[serde(default)]
    pub search_space_fingerprint: Option<String>,
    /// Metadata as raw JSON values keyed by name; defaults to empty.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
200
201impl GraphSpec {
202 pub fn validate(&self) -> Result<()> {
203 if self.id.trim().is_empty() {
204 return Err(DagMlError::GraphValidation(
205 "graph id must not be empty".to_string(),
206 ));
207 }
208 if self.nodes.is_empty() {
209 return Err(DagMlError::GraphValidation(
210 "graph must contain at least one node".to_string(),
211 ));
212 }
213 if let Some(fingerprint) = &self.search_space_fingerprint {
214 if fingerprint.trim().is_empty() {
215 return Err(DagMlError::GraphValidation(format!(
216 "graph `{}` has empty search_space_fingerprint",
217 self.id
218 )));
219 }
220 }
221
222 let mut nodes = BTreeMap::new();
223 validate_unique_ports(
224 &NodeId::new("graph:interface").expect("static identifier is valid"),
225 "interface input",
226 &self.interface.inputs,
227 )?;
228 validate_unique_ports(
229 &NodeId::new("graph:interface").expect("static identifier is valid"),
230 "interface output",
231 &self.interface.outputs,
232 )?;
233 for node in &self.nodes {
234 if nodes.insert(node.id.clone(), node).is_some() {
235 return Err(DagMlError::GraphValidation(format!(
236 "duplicate node id `{}`",
237 node.id
238 )));
239 }
240 validate_unique_ports(&node.id, "input", &node.ports.inputs)?;
241 validate_unique_ports(&node.id, "output", &node.ports.outputs)?;
242 }
243
244 let mut adjacency: BTreeMap<NodeId, Vec<NodeId>> = nodes
245 .keys()
246 .cloned()
247 .map(|id| (id, Vec::new()))
248 .collect::<BTreeMap<_, _>>();
249 let mut indegree: BTreeMap<NodeId, usize> =
250 nodes.keys().cloned().map(|id| (id, 0)).collect();
251
252 for edge in &self.edges {
253 let source = nodes.get(&edge.source.node_id).ok_or_else(|| {
254 DagMlError::GraphValidation(format!(
255 "edge source node `{}` does not exist",
256 edge.source.node_id
257 ))
258 })?;
259 let target = nodes.get(&edge.target.node_id).ok_or_else(|| {
260 DagMlError::GraphValidation(format!(
261 "edge target node `{}` does not exist",
262 edge.target.node_id
263 ))
264 })?;
265
266 let source_port =
267 find_port(&source.ports.outputs, &edge.source.port_name).ok_or_else(|| {
268 DagMlError::GraphValidation(format!(
269 "source port `{}.{}` does not exist",
270 edge.source.node_id, edge.source.port_name
271 ))
272 })?;
273 let target_port =
274 find_port(&target.ports.inputs, &edge.target.port_name).ok_or_else(|| {
275 DagMlError::GraphValidation(format!(
276 "target port `{}.{}` does not exist",
277 edge.target.node_id, edge.target.port_name
278 ))
279 })?;
280
281 if source_port.kind != edge.contract.kind || target_port.kind != edge.contract.kind {
282 return Err(DagMlError::GraphValidation(format!(
283 "edge `{}.{}` -> `{}.{}` has kind {:?}, but ports are {:?} and {:?}",
284 edge.source.node_id,
285 edge.source.port_name,
286 edge.target.node_id,
287 edge.target.port_name,
288 edge.contract.kind,
289 source_port.kind,
290 target_port.kind
291 )));
292 }
293 validate_edge_contract(edge, source_port, target_port)?;
294 if edge.contract.requires_oof && edge.contract.kind != PortKind::Prediction {
295 return Err(DagMlError::GraphValidation(format!(
296 "edge `{}.{}` -> `{}.{}` requires OOF but is not a prediction edge",
297 edge.source.node_id,
298 edge.source.port_name,
299 edge.target.node_id,
300 edge.target.port_name
301 )));
302 }
303
304 adjacency
305 .get_mut(&edge.source.node_id)
306 .expect("source exists")
307 .push(edge.target.node_id.clone());
308 *indegree
309 .get_mut(&edge.target.node_id)
310 .expect("target exists") += 1;
311 }
312
313 ensure_acyclic(adjacency, indegree)
314 }
315
316 pub fn topological_order(&self) -> Result<Vec<NodeId>> {
317 self.validate()?;
318 let nodes = self
319 .nodes
320 .iter()
321 .map(|node| node.id.clone())
322 .collect::<BTreeSet<_>>();
323 let mut adjacency = nodes
324 .iter()
325 .cloned()
326 .map(|id| (id, Vec::new()))
327 .collect::<BTreeMap<_, _>>();
328 let mut indegree: BTreeMap<NodeId, usize> =
329 nodes.iter().cloned().map(|id| (id, 0usize)).collect();
330 for edge in &self.edges {
331 adjacency
332 .get_mut(&edge.source.node_id)
333 .expect("source exists after validate")
334 .push(edge.target.node_id.clone());
335 *indegree
336 .get_mut(&edge.target.node_id)
337 .expect("target exists after validate") += 1;
338 }
339 topological_order(adjacency, indegree)
340 }
341
342 pub fn parallel_levels(&self) -> Result<Vec<Vec<NodeId>>> {
343 self.validate()?;
344 let nodes = self
345 .nodes
346 .iter()
347 .map(|node| node.id.clone())
348 .collect::<BTreeSet<_>>();
349 let mut adjacency = nodes
350 .iter()
351 .cloned()
352 .map(|id| (id, Vec::new()))
353 .collect::<BTreeMap<_, _>>();
354 let mut indegree: BTreeMap<NodeId, usize> =
355 nodes.iter().cloned().map(|id| (id, 0usize)).collect();
356 for edge in &self.edges {
357 adjacency
358 .get_mut(&edge.source.node_id)
359 .expect("source exists after validate")
360 .push(edge.target.node_id.clone());
361 *indegree
362 .get_mut(&edge.target.node_id)
363 .expect("target exists after validate") += 1;
364 }
365 topological_levels(adjacency, indegree)
366 }
367
368 pub fn upstream_nodes(&self, node_id: &NodeId) -> Vec<NodeId> {
369 let mut upstream = self
370 .edges
371 .iter()
372 .filter_map(|edge| {
373 (edge.target.node_id == *node_id).then_some(edge.source.node_id.clone())
374 })
375 .collect::<Vec<_>>();
376 upstream.sort();
377 upstream.dedup();
378 upstream
379 }
380
381 pub fn downstream_nodes(&self, node_id: &NodeId) -> Vec<NodeId> {
382 let mut downstream = self
383 .edges
384 .iter()
385 .filter_map(|edge| {
386 (edge.source.node_id == *node_id).then_some(edge.target.node_id.clone())
387 })
388 .collect::<Vec<_>>();
389 downstream.sort();
390 downstream.dedup();
391 downstream
392 }
393}
394
395fn validate_unique_ports(node_id: &NodeId, direction: &str, ports: &[PortSpec]) -> Result<()> {
396 let mut seen = BTreeSet::new();
397 for port in ports {
398 if port.name.trim().is_empty() {
399 return Err(DagMlError::GraphValidation(format!(
400 "{} port on node `{}` has an empty name",
401 direction, node_id
402 )));
403 }
404 if !seen.insert(port.name.as_str()) {
405 return Err(DagMlError::GraphValidation(format!(
406 "duplicate {} port `{}` on node `{}`",
407 direction, port.name, node_id
408 )));
409 }
410 validate_port_contract(node_id, direction, port)?;
411 }
412 Ok(())
413}
414
415fn find_port<'a>(ports: &'a [PortSpec], name: &str) -> Option<&'a PortSpec> {
416 ports.iter().find(|port| port.name == name)
417}
418
419fn validate_port_contract(node_id: &NodeId, direction: &str, port: &PortSpec) -> Result<()> {
420 validate_optional_non_empty(
421 &format!("{direction} port `{}` representation", port.name),
422 port.representation.as_deref(),
423 )?;
424 validate_optional_non_empty(
425 &format!("{direction} port `{}` alignment_key", port.name),
426 port.alignment_key.as_deref(),
427 )?;
428 if port
429 .alignment_key
430 .as_deref()
431 .is_some_and(|key| !is_identifier(key))
432 {
433 return Err(DagMlError::GraphValidation(format!(
434 "{direction} port `{}` on node `{node_id}` has invalid alignment_key",
435 port.name
436 )));
437 }
438 Ok(())
439}
440
441fn validate_edge_contract(
442 edge: &EdgeSpec,
443 source_port: &PortSpec,
444 target_port: &PortSpec,
445) -> Result<()> {
446 let label = format!(
447 "edge `{}.{}` -> `{}.{}`",
448 edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
449 );
450 validate_optional_non_empty(
451 &format!("{label} representation"),
452 edge.contract.representation.as_deref(),
453 )?;
454 validate_optional_non_empty(
455 &format!("{label} alignment_key"),
456 edge.contract.alignment_key.as_deref(),
457 )?;
458 if edge
459 .contract
460 .alignment_key
461 .as_deref()
462 .is_some_and(|key| !is_identifier(key))
463 {
464 return Err(DagMlError::GraphValidation(format!(
465 "{label} has invalid alignment_key"
466 )));
467 }
468 if let Some(relation_contract) = &edge.contract.relation_contract {
469 validate_relation_contract(&label, relation_contract)?;
470 }
471
472 validate_edge_unit_alignment(&label, edge, source_port, target_port)?;
473
474 if relation_aware_edge(edge, source_port, target_port) {
475 let relation_fingerprint = edge
476 .contract
477 .relation_contract
478 .as_ref()
479 .and_then(|contract| contract.relation_fingerprint.as_deref());
480 if relation_fingerprint.is_none() {
481 return Err(DagMlError::GraphValidation(format!(
482 "{label} is relation-aware but has no relation_fingerprint"
483 )));
484 }
485 if !has_effective_unit_level(edge, source_port, target_port) {
486 return Err(DagMlError::GraphValidation(format!(
487 "{label} is relation-aware but has no unit_level metadata"
488 )));
489 }
490 if !has_effective_alignment_key(edge, source_port, target_port) {
491 return Err(DagMlError::GraphValidation(format!(
492 "{label} is relation-aware but has no alignment_key"
493 )));
494 }
495 }
496 Ok(())
497}
498
499fn validate_relation_contract(label: &str, contract: &RelationContract) -> Result<()> {
500 if let Some(fingerprint) = &contract.relation_fingerprint {
501 validate_sha256(label, "relation_fingerprint", fingerprint)?;
502 } else if contract.required {
503 return Err(DagMlError::GraphValidation(format!(
504 "{label} relation_contract is required but has no relation_fingerprint"
505 )));
506 }
507 Ok(())
508}
509
510fn validate_edge_unit_alignment(
511 label: &str,
512 edge: &EdgeSpec,
513 source_port: &PortSpec,
514 target_port: &PortSpec,
515) -> Result<()> {
516 if let Some(contract_unit) = edge.contract.unit_level {
517 for (endpoint, unit) in [
518 ("source", source_port.unit_level),
519 ("target", target_port.unit_level),
520 ] {
521 if let Some(unit) = unit {
522 if unit != contract_unit && !edge.contract.allows_broadcast {
523 return Err(DagMlError::GraphValidation(format!(
524 "{label} {endpoint} unit {:?} does not match edge unit {:?}",
525 unit, contract_unit
526 )));
527 }
528 }
529 }
530 }
531
532 if let (Some(source_unit), Some(target_unit)) = (source_port.unit_level, target_port.unit_level)
533 {
534 if source_unit != target_unit && !edge.contract.allows_broadcast {
535 return Err(DagMlError::GraphValidation(format!(
536 "{label} joins incompatible unit levels {:?} and {:?}",
537 source_unit, target_unit
538 )));
539 }
540 }
541
542 if let (Some(source_target), Some(target_target)) =
543 (source_port.target_level, target_port.target_level)
544 {
545 if source_target != target_target {
546 return Err(DagMlError::GraphValidation(format!(
547 "{label} joins incompatible target levels {:?} and {:?}",
548 source_target, target_target
549 )));
550 }
551 }
552 if let Some(contract_target) = edge.contract.target_level {
553 for (endpoint, target_level) in [
554 ("source", source_port.target_level),
555 ("target", target_port.target_level),
556 ] {
557 if let Some(target_level) = target_level {
558 if target_level != contract_target {
559 return Err(DagMlError::GraphValidation(format!(
560 "{label} {endpoint} target level {:?} does not match edge target_level {:?}",
561 target_level, contract_target
562 )));
563 }
564 }
565 }
566 }
567
568 if let (Some(source_alignment), Some(target_alignment)) = (
569 source_port.alignment_key.as_deref(),
570 target_port.alignment_key.as_deref(),
571 ) {
572 if source_alignment != target_alignment && !edge.contract.allows_broadcast {
573 return Err(DagMlError::GraphValidation(format!(
574 "{label} joins incompatible alignment keys `{source_alignment}` and `{target_alignment}`"
575 )));
576 }
577 }
578
579 if let Some(edge_alignment) = edge.contract.alignment_key.as_deref() {
580 for (endpoint, alignment) in [
581 ("source", source_port.alignment_key.as_deref()),
582 ("target", target_port.alignment_key.as_deref()),
583 ] {
584 if let Some(alignment) = alignment {
585 if alignment != edge_alignment && !edge.contract.allows_broadcast {
586 return Err(DagMlError::GraphValidation(format!(
587 "{label} {endpoint} alignment `{alignment}` does not match edge alignment `{edge_alignment}`"
588 )));
589 }
590 }
591 }
592 }
593
594 if edge.contract.allows_broadcast
595 && edge.contract.alignment_key.is_none()
596 && source_port.alignment_key.is_none()
597 && target_port.alignment_key.is_none()
598 {
599 return Err(DagMlError::GraphValidation(format!(
600 "{label} allows broadcast but declares no alignment_key"
601 )));
602 }
603 Ok(())
604}
605
606fn relation_aware_edge(edge: &EdgeSpec, source_port: &PortSpec, target_port: &PortSpec) -> bool {
607 edge.contract.relation_contract.is_some()
608 || edge.contract.allows_broadcast
609 || edge.contract.alignment_key.is_some()
610 || non_physical(edge.contract.unit_level)
611 || non_physical(edge.contract.target_level)
612 || non_physical(source_port.unit_level)
613 || non_physical(source_port.target_level)
614 || non_physical(target_port.unit_level)
615 || non_physical(target_port.target_level)
616 || source_port.alignment_key.is_some()
617 || target_port.alignment_key.is_some()
618}
619
620fn has_effective_unit_level(
621 edge: &EdgeSpec,
622 source_port: &PortSpec,
623 target_port: &PortSpec,
624) -> bool {
625 edge.contract.unit_level.is_some()
626 || source_port.unit_level.is_some()
627 || target_port.unit_level.is_some()
628}
629
630fn has_effective_alignment_key(
631 edge: &EdgeSpec,
632 source_port: &PortSpec,
633 target_port: &PortSpec,
634) -> bool {
635 edge.contract.alignment_key.is_some()
636 || source_port.alignment_key.is_some()
637 || target_port.alignment_key.is_some()
638}
639
640fn non_physical(unit_level: Option<EntityUnitLevel>) -> bool {
641 unit_level.is_some_and(|level| level != EntityUnitLevel::PhysicalSample)
642}
643
644fn validate_optional_non_empty(label: &str, value: Option<&str>) -> Result<()> {
645 if value.is_some_and(|value| value.trim().is_empty()) {
646 return Err(DagMlError::GraphValidation(format!(
647 "{label} must not be empty"
648 )));
649 }
650 Ok(())
651}
652
653fn validate_sha256(owner: &str, field: &str, value: &str) -> Result<()> {
654 if value.len() == 64 && value.bytes().all(|byte| byte.is_ascii_hexdigit()) {
655 Ok(())
656 } else {
657 Err(DagMlError::GraphValidation(format!(
658 "{owner} has invalid {field}"
659 )))
660 }
661}
662
/// Returns `true` when `value` is 1 to 128 bytes long and consists solely of
/// ASCII letters, ASCII digits, `_`, `-`, `.` or `:`.
fn is_identifier(value: &str) -> bool {
    if value.is_empty() || value.len() > 128 {
        return false;
    }
    value.bytes().all(|byte| match byte {
        b'_' | b'-' | b'.' | b':' => true,
        other => other.is_ascii_alphanumeric(),
    })
}
670
671fn ensure_acyclic(
672 adjacency: BTreeMap<NodeId, Vec<NodeId>>,
673 indegree: BTreeMap<NodeId, usize>,
674) -> Result<()> {
675 topological_order(adjacency, indegree).map(|_| ())
676}
677
678fn topological_order(
679 adjacency: BTreeMap<NodeId, Vec<NodeId>>,
680 mut indegree: BTreeMap<NodeId, usize>,
681) -> Result<Vec<NodeId>> {
682 let mut queue = indegree
683 .iter()
684 .filter_map(|(id, degree)| (*degree == 0).then_some(id.clone()))
685 .collect::<BTreeSet<_>>();
686 let mut order = Vec::with_capacity(indegree.len());
687
688 while let Some(node) = queue.pop_first() {
689 order.push(node.clone());
690 if let Some(next_nodes) = adjacency.get(&node) {
691 for next in next_nodes {
692 let degree = indegree.get_mut(next).expect("node exists");
693 *degree -= 1;
694 if *degree == 0 {
695 queue.insert(next.clone());
696 }
697 }
698 }
699 }
700
701 if order.len() == indegree.len() {
702 Ok(order)
703 } else {
704 Err(DagMlError::GraphValidation(
705 "graph contains at least one cycle".to_string(),
706 ))
707 }
708}
709
710fn topological_levels(
711 adjacency: BTreeMap<NodeId, Vec<NodeId>>,
712 mut indegree: BTreeMap<NodeId, usize>,
713) -> Result<Vec<Vec<NodeId>>> {
714 let mut queue = indegree
715 .iter()
716 .filter_map(|(id, degree)| (*degree == 0).then_some(id.clone()))
717 .collect::<BTreeSet<_>>();
718 let mut levels = Vec::new();
719 let mut visited = 0usize;
720
721 while !queue.is_empty() {
722 let level = queue.iter().cloned().collect::<Vec<_>>();
723 queue.clear();
724 for node in &level {
725 visited += 1;
726 if let Some(next_nodes) = adjacency.get(node) {
727 for next in next_nodes {
728 let degree = indegree.get_mut(next).expect("node exists");
729 *degree -= 1;
730 if *degree == 0 {
731 queue.insert(next.clone());
732 }
733 }
734 }
735 }
736 levels.push(level);
737 }
738
739 if visited == indegree.len() {
740 Ok(levels)
741 } else {
742 Err(DagMlError::GraphValidation(
743 "graph contains at least one cycle".to_string(),
744 ))
745 }
746}
747
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a node id literal, panicking on invalid input.
    fn node_id(raw: &str) -> NodeId {
        NodeId::new(raw).unwrap()
    }

    /// Port with the given name and kind and no optional metadata.
    fn port(name: &str, kind: PortKind) -> PortSpec {
        PortSpec {
            name: name.to_string(),
            kind,
            representation: None,
            cardinality: PortCardinality::One,
            unit_level: None,
            alignment_key: None,
            target_level: None,
            description: String::new(),
        }
    }

    /// Port that declares a unit level and the `sample_id` alignment key.
    fn unit_port(name: &str, kind: PortKind, unit_level: EntityUnitLevel) -> PortSpec {
        PortSpec {
            unit_level: Some(unit_level),
            alignment_key: Some("sample_id".to_string()),
            ..port(name, kind)
        }
    }

    /// Port that declares only a unit level (no alignment key).
    fn unit_only_port(name: &str, kind: PortKind, unit_level: EntityUnitLevel) -> PortSpec {
        PortSpec {
            unit_level: Some(unit_level),
            ..port(name, kind)
        }
    }

    /// Model node with the given ports and no operator, params or metadata.
    fn node(id: &str, inputs: Vec<PortSpec>, outputs: Vec<PortSpec>) -> NodeSpec {
        NodeSpec {
            id: node_id(id),
            kind: NodeKind::Model,
            operator: None,
            params: BTreeMap::new(),
            ports: PortSchema { inputs, outputs },
            metadata: BTreeMap::new(),
            seed_label: None,
        }
    }

    /// Edge between two named ports carrying an explicit contract.
    fn edge_with(
        source: &str,
        source_port: &str,
        target: &str,
        target_port: &str,
        contract: EdgeContract,
    ) -> EdgeSpec {
        EdgeSpec {
            source: PortRef {
                node_id: node_id(source),
                port_name: source_port.to_string(),
            },
            target: PortRef {
                node_id: node_id(target),
                port_name: target_port.to_string(),
            },
            contract,
        }
    }

    /// Prediction edge that requires OOF and fold alignment.
    fn edge(source: &str, source_port: &str, target: &str, target_port: &str) -> EdgeSpec {
        edge_with(
            source,
            source_port,
            target,
            target_port,
            EdgeContract {
                requires_oof: true,
                requires_fold_alignment: true,
                ..EdgeContract::new(PortKind::Prediction, None)
            },
        )
    }

    /// Graph named `g` with a default interface and no fingerprint or metadata.
    fn graph(nodes: Vec<NodeSpec>, edges: Vec<EdgeSpec>) -> GraphSpec {
        GraphSpec {
            id: "g".to_string(),
            interface: GraphInterface::default(),
            nodes,
            edges,
            search_space_fingerprint: None,
            metadata: BTreeMap::new(),
        }
    }

    fn data_edge_contract() -> EdgeContract {
        EdgeContract::new(PortKind::Data, Some("tabular".to_string()))
    }

    fn relation_contract() -> RelationContract {
        RelationContract {
            relation_fingerprint: Some("a".repeat(64)),
            required: true,
        }
    }

    #[test]
    fn validates_simple_graph() {
        let spec = graph(
            vec![
                node("model:a", vec![], vec![port("pred", PortKind::Prediction)]),
                node("model:b", vec![port("pred", PortKind::Prediction)], vec![]),
            ],
            vec![edge("model:a", "pred", "model:b", "pred")],
        );

        assert!(spec.validate().is_ok());
    }

    #[test]
    fn computes_deterministic_parallel_levels() {
        let spec = graph(
            vec![
                node("model:a", vec![], vec![port("pred", PortKind::Prediction)]),
                node(
                    "model:b",
                    vec![port("pred", PortKind::Prediction)],
                    vec![port("pred", PortKind::Prediction)],
                ),
                node(
                    "model:c",
                    vec![port("pred", PortKind::Prediction)],
                    vec![port("pred", PortKind::Prediction)],
                ),
                node("model:d", vec![port("pred", PortKind::Prediction)], vec![]),
            ],
            vec![
                edge("model:a", "pred", "model:b", "pred"),
                edge("model:a", "pred", "model:c", "pred"),
                edge("model:b", "pred", "model:d", "pred"),
                edge("model:c", "pred", "model:d", "pred"),
            ],
        );

        let levels = spec.parallel_levels().unwrap();

        assert_eq!(
            levels,
            vec![
                vec![node_id("model:a")],
                vec![node_id("model:b"), node_id("model:c")],
                vec![node_id("model:d")]
            ]
        );
    }

    #[test]
    fn rejects_missing_edge_endpoint() {
        let spec = graph(
            vec![node(
                "model:a",
                vec![],
                vec![port("pred", PortKind::Prediction)],
            )],
            vec![edge("model:a", "pred", "model:b", "pred")],
        );

        assert!(spec.validate().is_err());
    }

    #[test]
    fn rejects_oof_contract_on_non_prediction_edge() {
        let spec = graph(
            vec![
                node("model:a", vec![], vec![port("x", PortKind::Data)]),
                node("model:b", vec![port("x", PortKind::Data)], vec![]),
            ],
            vec![edge_with(
                "model:a",
                "x",
                "model:b",
                "x",
                EdgeContract {
                    requires_oof: true,
                    requires_fold_alignment: true,
                    ..EdgeContract::new(PortKind::Data, None)
                },
            )],
        );

        let error = spec.validate().unwrap_err().to_string();

        assert!(error.contains("requires OOF"));
    }

    #[test]
    fn rejects_unit_mismatch_without_explicit_broadcast() {
        let spec = graph(
            vec![
                node(
                    "transform:obs",
                    vec![],
                    vec![unit_port("x", PortKind::Data, EntityUnitLevel::Observation)],
                ),
                node(
                    "join:sample",
                    vec![unit_port(
                        "x",
                        PortKind::Data,
                        EntityUnitLevel::PhysicalSample,
                    )],
                    vec![],
                ),
            ],
            vec![edge_with(
                "transform:obs",
                "x",
                "join:sample",
                "x",
                EdgeContract {
                    relation_contract: Some(relation_contract()),
                    ..data_edge_contract()
                },
            )],
        );

        let error = spec.validate().unwrap_err().to_string();

        assert!(error.contains("incompatible unit levels"));
    }

    #[test]
    fn relation_aware_edge_requires_relation_fingerprint() {
        let spec = graph(
            vec![
                node(
                    "source:a",
                    vec![],
                    vec![unit_port("x", PortKind::Data, EntityUnitLevel::Observation)],
                ),
                node(
                    "model:a",
                    vec![unit_port("x", PortKind::Data, EntityUnitLevel::Observation)],
                    vec![],
                ),
            ],
            vec![edge_with(
                "source:a",
                "x",
                "model:a",
                "x",
                data_edge_contract(),
            )],
        );

        let error = spec.validate().unwrap_err().to_string();

        assert!(error.contains("relation-aware"));
    }

    #[test]
    fn relation_aware_edge_requires_alignment_key() {
        let spec = graph(
            vec![
                node(
                    "source:a",
                    vec![],
                    vec![unit_only_port(
                        "x",
                        PortKind::Data,
                        EntityUnitLevel::Observation,
                    )],
                ),
                node(
                    "model:a",
                    vec![unit_only_port(
                        "x",
                        PortKind::Data,
                        EntityUnitLevel::Observation,
                    )],
                    vec![],
                ),
            ],
            vec![edge_with(
                "source:a",
                "x",
                "model:a",
                "x",
                EdgeContract {
                    relation_contract: Some(relation_contract()),
                    ..data_edge_contract()
                },
            )],
        );

        let error = spec.validate().unwrap_err().to_string();

        assert!(error.contains("alignment_key"));
    }

    #[test]
    fn explicit_broadcast_allows_sample_to_observation_edge() {
        let contract = EdgeContract {
            allows_broadcast: true,
            alignment_key: Some("sample_id".to_string()),
            relation_contract: Some(relation_contract()),
            ..data_edge_contract()
        };

        let spec = graph(
            vec![
                node(
                    "source:sample",
                    vec![],
                    vec![unit_port(
                        "x",
                        PortKind::Data,
                        EntityUnitLevel::PhysicalSample,
                    )],
                ),
                node(
                    "adapter:broadcast",
                    vec![unit_port("x", PortKind::Data, EntityUnitLevel::Observation)],
                    vec![],
                ),
            ],
            vec![edge_with(
                "source:sample",
                "x",
                "adapter:broadcast",
                "x",
                contract,
            )],
        );

        spec.validate().unwrap();
    }

    #[test]
    fn rejects_cycles() {
        let spec = graph(
            vec![
                node(
                    "model:a",
                    vec![port("pred", PortKind::Prediction)],
                    vec![port("pred", PortKind::Prediction)],
                ),
                node(
                    "model:b",
                    vec![port("pred", PortKind::Prediction)],
                    vec![port("pred", PortKind::Prediction)],
                ),
            ],
            vec![
                edge("model:a", "pred", "model:b", "pred"),
                edge("model:b", "pred", "model:a", "pred"),
            ],
        );

        assert!(spec.validate().is_err());
    }

    #[test]
    fn published_graph_spec_schema_declares_current_contract() {
        let schema: serde_json::Value = serde_json::from_str(include_str!(
            "../../../docs/contracts/graph_spec.schema.json"
        ))
        .unwrap();

        /// Returns the JSON array stored at `value`, panicking if it is not one.
        fn entries(value: &serde_json::Value) -> &Vec<serde_json::Value> {
            value.as_array().unwrap()
        }

        assert_eq!(schema["$id"], GRAPH_SPEC_SCHEMA_ID);
        assert!(entries(&schema["required"])
            .iter()
            .any(|field| field.as_str() == Some("nodes")));
        assert_eq!(entries(&schema["$defs"]["node_kind"]["enum"]).len(), 20);
        assert!(entries(&schema["$defs"]["port_kind"]["enum"])
            .iter()
            .any(|kind| kind.as_str() == Some("prediction")));
        assert!(entries(&schema["$defs"]["entity_unit_level"]["enum"])
            .iter()
            .any(|level| level.as_str() == Some("combo")));
        assert!(schema["$defs"]["edge_contract"]["properties"]
            .as_object()
            .unwrap()
            .contains_key("relation_contract"));
    }
}