1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::data::ModelInputSpec;
6use crate::error::{DagMlError, Result};
7use crate::graph::{NodeKind, NodeSpec, PortKind, PortSpec};
8use crate::ids::ControllerId;
9use crate::phase::Phase;
10use crate::policy::FitInfluencePolicy;
11
/// Version number of the controller manifest schema.
pub const CONTROLLER_MANIFEST_SCHEMA_VERSION: u32 = 1;
/// Identifier of the published controller manifest JSON schema; the unit
/// tests compare it with the `$id` field of the schema document.
pub const CONTROLLER_MANIFEST_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/controller_manifest.v1.schema.json";
15
/// Capability flags a controller declares in its manifest.
///
/// Values are serialized with `snake_case` names (for example
/// `emits_predictions`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ControllerCapability {
    /// Rejected by `ControllerManifest::validate` when combined with
    /// `RngPolicy::Nondeterministic`.
    Deterministic,
    /// Makes `ControllerManifest::supports_parallel_invocation` return `true`.
    ThreadSafe,
    /// Makes `ControllerManifest::supports_parallel_invocation` return `true`.
    ProcessSafe,
    NeedsPythonGil,
    /// Required by manifest validation when an output port has kind
    /// `PortKind::Prediction`.
    EmitsPredictions,
    ConsumesOofPredictions,
    /// Required by manifest validation when an output port has kind
    /// `PortKind::Artifact`.
    EmitsArtifacts,
    Stateful,
    EmitsRelation,
    UsesCoreRng,
    ShapeChanging,
    GeneratesData,
    GeneratesModel,
    ExpandsVariants,
    AggregatesPredictions,
    /// Satisfies `FitInfluencePolicy::EqualSampleInfluence` and
    /// `FitInfluencePolicy::StrictWeightSupport`.
    SupportsSampleWeights,
    /// Satisfies `FitInfluencePolicy::ResampleEqualized` and
    /// `FitInfluencePolicy::StrictWeightSupport`.
    SupportsRowResampling,
    /// Satisfies `FitInfluencePolicy::BackendLossWeight` and
    /// `FitInfluencePolicy::StrictWeightSupport`.
    SupportsBackendLossWeights,
    SupportsMissingMasks,
}
39
/// Fit scope declared by a controller manifest (serialized in `snake_case`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ControllerFitScope {
    Stateless,
    FoldTrain,
    /// Manifest validation rejects this scope together with `Phase::FitCv`.
    FullTrain,
    /// Manifest validation rejects this scope together with `Phase::FitCv`
    /// or `Phase::Refit`.
    InferenceOnly,
}
48
/// Random-number policy declared by a controller manifest (serialized in
/// `snake_case`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RngPolicy {
    UsesCoreSeed,
    IgnoresSeed,
    ExternallyDeterministic,
    /// Manifest validation rejects this policy together with the
    /// `ControllerCapability::Deterministic` capability.
    Nondeterministic,
}
57
/// Artifact policy declared by a controller manifest (serialized in
/// `snake_case`).
///
/// NOTE(review): nothing in this file interprets the individual variants;
/// confirm their exact semantics against the code that consumes them.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ArtifactPolicy {
    Serializable,
    HostOnly,
    ContentAddressed,
    ReplayRequired,
}
66
/// Criteria used to match a controller against a node's operator value.
///
/// A selector matches as soon as any one of its criteria matches. All
/// comparisons ignore ASCII case and surrounding whitespace. Unknown fields
/// are rejected on deserialization and empty sets are omitted on
/// serialization.
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct OperatorSelector {
    /// Compared with every alias candidate of the operator: each of its
    /// `type`, `ref`, `class` and `function` strings (or the bare string
    /// operator), plus the segment after the last `.` or `:` of each.
    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
    pub aliases: BTreeSet<String>,
    /// Compared for equality with the operator's `class` string.
    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
    pub classes: BTreeSet<String>,
    /// Matches when the operator's `class` string starts with one of these.
    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
    pub class_prefixes: BTreeSet<String>,
    /// Compared for equality with the operator's `function` string.
    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
    pub functions: BTreeSet<String>,
    /// Compared for equality with the operator's `ref` string, or with the
    /// operator itself when it is a bare JSON string.
    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
    pub refs: BTreeSet<String>,
    /// Compared for equality with the operator's `type` string.
    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
    pub types: BTreeSet<String>,
}
83
84impl OperatorSelector {
85 fn validate(&self, controller_id: &ControllerId) -> Result<()> {
86 if self.aliases.is_empty()
87 && self.classes.is_empty()
88 && self.class_prefixes.is_empty()
89 && self.functions.is_empty()
90 && self.refs.is_empty()
91 && self.types.is_empty()
92 {
93 return Err(DagMlError::ControllerValidation(format!(
94 "controller `{controller_id}` has an empty operator selector"
95 )));
96 }
97 for (field, values) in [
98 ("aliases", &self.aliases),
99 ("classes", &self.classes),
100 ("class_prefixes", &self.class_prefixes),
101 ("functions", &self.functions),
102 ("refs", &self.refs),
103 ("types", &self.types),
104 ] {
105 if values.iter().any(|value| value.trim().is_empty()) {
106 return Err(DagMlError::ControllerValidation(format!(
107 "controller `{controller_id}` operator selector `{field}` contains an empty value"
108 )));
109 }
110 }
111 Ok(())
112 }
113}
114
/// Declarative description of a controller: which nodes it handles, which
/// phases it supports, its ports and its declared capabilities/policies.
///
/// Unknown fields are rejected on deserialization. Use
/// `ControllerManifest::validate` to check internal consistency.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ControllerManifest {
    /// Identifier of the controller; used as the registry key.
    pub controller_id: ControllerId,
    /// Version string; validation rejects a blank value.
    pub controller_version: String,
    /// Node kind this controller handles; must equal the node's kind for the
    /// controller to be resolved for it.
    pub operator_kind: NodeKind,
    /// Tie-breaker among matching controllers: the lowest value wins.
    #[serde(default)]
    pub priority: u32,
    /// Phases the controller supports; validation rejects an empty set.
    #[serde(default)]
    pub supported_phases: BTreeSet<Phase>,
    /// Input ports; names must be non-blank and unique.
    #[serde(default)]
    pub input_ports: Vec<PortSpec>,
    /// Output ports; names must be non-blank and unique.
    #[serde(default)]
    pub output_ports: Vec<PortSpec>,
    /// Optional JSON that must deserialize into a `ModelInputSpec`.
    #[serde(default)]
    pub data_requirements: Option<serde_json::Value>,
    /// Declared capability flags.
    #[serde(default)]
    pub capabilities: BTreeSet<ControllerCapability>,
    /// When empty the controller matches any node of `operator_kind`;
    /// otherwise the node's operator must match at least one selector.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub operator_selectors: Vec<OperatorSelector>,
    /// Declared fit scope, cross-checked against `supported_phases`.
    pub fit_scope: ControllerFitScope,
    /// Declared RNG policy, cross-checked against `capabilities`.
    pub rng_policy: RngPolicy,
    /// Declared artifact policy.
    pub artifact_policy: ArtifactPolicy,
}
139
140impl ControllerManifest {
141 pub fn validate(&self) -> Result<()> {
142 if self.controller_version.trim().is_empty() {
143 return Err(DagMlError::ControllerValidation(format!(
144 "controller `{}` has an empty version",
145 self.controller_id
146 )));
147 }
148 if self.supported_phases.is_empty() {
149 return Err(DagMlError::ControllerValidation(format!(
150 "controller `{}` supports no phases",
151 self.controller_id
152 )));
153 }
154 if let Some(model_input) = self.model_input_spec()? {
155 model_input.validate().map_err(|error| {
156 DagMlError::ControllerValidation(format!(
157 "controller `{}` data_requirements are not a valid ModelInputSpec: {error}",
158 self.controller_id
159 ))
160 })?;
161 }
162 validate_ports(&self.controller_id, "input", &self.input_ports)?;
163 validate_ports(&self.controller_id, "output", &self.output_ports)?;
164 for selector in &self.operator_selectors {
165 selector.validate(&self.controller_id)?;
166 }
167 if self.rng_policy == RngPolicy::Nondeterministic
168 && self
169 .capabilities
170 .contains(&ControllerCapability::Deterministic)
171 {
172 return Err(DagMlError::ControllerValidation(format!(
173 "controller `{}` cannot be deterministic with nondeterministic RNG",
174 self.controller_id
175 )));
176 }
177 if self.fit_scope == ControllerFitScope::InferenceOnly
178 && (self.supported_phases.contains(&Phase::FitCv)
179 || self.supported_phases.contains(&Phase::Refit))
180 {
181 return Err(DagMlError::ControllerValidation(format!(
182 "controller `{}` is inference_only but supports training phases",
183 self.controller_id
184 )));
185 }
186 if self.supported_phases.contains(&Phase::FitCv)
187 && matches!(
188 self.fit_scope,
189 ControllerFitScope::FullTrain | ControllerFitScope::InferenceOnly
190 )
191 {
192 return Err(DagMlError::ControllerValidation(format!(
193 "controller `{}` supports FIT_CV but has fit_scope {:?}",
194 self.controller_id, self.fit_scope
195 )));
196 }
197 if self
198 .output_ports
199 .iter()
200 .any(|port| port.kind == PortKind::Prediction)
201 && !self
202 .capabilities
203 .contains(&ControllerCapability::EmitsPredictions)
204 {
205 return Err(DagMlError::ControllerValidation(format!(
206 "controller `{}` has prediction output ports but lacks emits_predictions",
207 self.controller_id
208 )));
209 }
210 if self
211 .output_ports
212 .iter()
213 .any(|port| port.kind == PortKind::Artifact)
214 && !self
215 .capabilities
216 .contains(&ControllerCapability::EmitsArtifacts)
217 {
218 return Err(DagMlError::ControllerValidation(format!(
219 "controller `{}` has artifact output ports but lacks emits_artifacts",
220 self.controller_id
221 )));
222 }
223 Ok(())
224 }
225
226 pub fn supports_phase(&self, phase: Phase) -> bool {
227 self.supported_phases.contains(&phase)
228 }
229
230 pub fn supports_parallel_invocation(&self) -> bool {
231 self.capabilities
232 .contains(&ControllerCapability::ThreadSafe)
233 || self
234 .capabilities
235 .contains(&ControllerCapability::ProcessSafe)
236 }
237
238 pub fn supports_fit_influence_policy(&self, policy: FitInfluencePolicy) -> bool {
239 capabilities_support_fit_influence(&self.capabilities, policy)
240 }
241
242 pub fn model_input_spec(&self) -> Result<Option<ModelInputSpec>> {
243 self.data_requirements
244 .as_ref()
245 .map(|value| {
246 serde_json::from_value::<ModelInputSpec>(value.clone()).map_err(|error| {
247 DagMlError::ControllerValidation(format!(
248 "controller `{}` data_requirements must be ModelInputSpec JSON: {error}",
249 self.controller_id
250 ))
251 })
252 })
253 .transpose()
254 }
255}
256
257pub fn capabilities_support_fit_influence(
258 capabilities: &BTreeSet<ControllerCapability>,
259 policy: FitInfluencePolicy,
260) -> bool {
261 match policy {
262 FitInfluencePolicy::Auto
263 | FitInfluencePolicy::UniformRows
264 | FitInfluencePolicy::ScorerOnly => true,
265 FitInfluencePolicy::EqualSampleInfluence => {
266 capabilities.contains(&ControllerCapability::SupportsSampleWeights)
267 }
268 FitInfluencePolicy::ResampleEqualized => {
269 capabilities.contains(&ControllerCapability::SupportsRowResampling)
270 }
271 FitInfluencePolicy::BackendLossWeight => {
272 capabilities.contains(&ControllerCapability::SupportsBackendLossWeights)
273 }
274 FitInfluencePolicy::StrictWeightSupport => {
275 capabilities.contains(&ControllerCapability::SupportsSampleWeights)
276 || capabilities.contains(&ControllerCapability::SupportsRowResampling)
277 || capabilities.contains(&ControllerCapability::SupportsBackendLossWeights)
278 }
279 }
280}
281
/// Collection of controller manifests keyed by controller id.
///
/// `ControllerRegistry::register` validates a manifest before inserting it.
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ControllerRegistry {
    // Ordered map, so iteration over manifests follows controller id order.
    manifests: BTreeMap<ControllerId, ControllerManifest>,
}
286
287impl ControllerRegistry {
288 pub fn new() -> Self {
289 Self::default()
290 }
291
292 pub fn register(&mut self, manifest: ControllerManifest) -> Result<()> {
293 manifest.validate()?;
294 if self.manifests.contains_key(&manifest.controller_id) {
295 return Err(DagMlError::ControllerValidation(format!(
296 "duplicate controller id `{}`",
297 manifest.controller_id
298 )));
299 }
300 self.manifests
301 .insert(manifest.controller_id.clone(), manifest);
302 Ok(())
303 }
304
305 pub fn get(&self, controller_id: &ControllerId) -> Option<&ControllerManifest> {
306 self.manifests.get(controller_id)
307 }
308
309 pub fn manifests(&self) -> impl Iterator<Item = &ControllerManifest> {
310 self.manifests.values()
311 }
312
313 pub fn resolve_for_node(&self, node: &NodeSpec) -> Result<ControllerManifest> {
314 if let Some(requested) = requested_controller(node)? {
315 let manifest = self.get(&requested).ok_or_else(|| {
316 DagMlError::Planning(format!(
317 "node `{}` requested unknown controller `{requested}`",
318 node.id
319 ))
320 })?;
321 if manifest.operator_kind != node.kind {
322 return Err(DagMlError::Planning(format!(
323 "node `{}` kind {:?} is incompatible with controller `{}` kind {:?}",
324 node.id, node.kind, manifest.controller_id, manifest.operator_kind
325 )));
326 }
327 return Ok(manifest.clone());
328 }
329
330 let mut candidates = self
331 .manifests
332 .values()
333 .filter_map(|manifest| controller_candidate(manifest, node))
334 .collect::<Vec<_>>();
335 candidates.sort_by(|left, right| {
336 left.rank
337 .cmp(&right.rank)
338 .then_with(|| left.manifest.priority.cmp(&right.manifest.priority))
339 .then_with(|| {
340 left.manifest
341 .controller_id
342 .cmp(&right.manifest.controller_id)
343 })
344 });
345 let Some(first) = candidates.first() else {
346 return Err(DagMlError::Planning(format!(
347 "no controller registered for node `{}` kind {:?}",
348 node.id, node.kind
349 )));
350 };
351 if candidates.get(1).is_some_and(|second| {
352 second.rank == first.rank && second.manifest.priority == first.manifest.priority
353 }) {
354 return Err(DagMlError::Planning(format!(
355 "node `{}` has ambiguous controllers for kind {:?}; set metadata.controller_id",
356 node.id, node.kind
357 )));
358 }
359 Ok(first.manifest.clone())
360 }
361
362 pub fn infer_operator_kind(&self, operator: &serde_json::Value) -> Result<Option<NodeKind>> {
363 let matches = self
364 .manifests
365 .values()
366 .filter(|manifest| {
367 !manifest.operator_selectors.is_empty()
368 && manifest
369 .operator_selectors
370 .iter()
371 .any(|selector| selector_matches_operator(selector, operator))
372 })
373 .collect::<Vec<_>>();
374 let Some(first) = matches.first() else {
375 return Ok(None);
376 };
377 let kind = first.operator_kind.clone();
378 let conflicting = matches
379 .iter()
380 .find(|manifest| manifest.operator_kind != kind);
381 if let Some(conflicting) = conflicting {
382 return Err(DagMlError::Planning(format!(
383 "minimal operator alias `{}` matches controllers with different node kinds ({:?} and {:?}); use explicit DSL syntax",
384 operator_label(operator),
385 kind,
386 conflicting.operator_kind
387 )));
388 }
389 Ok(Some(kind))
390 }
391}
392
/// How a controller came to match a node.
///
/// The derived ordering follows declaration order, so `OperatorSelector`
/// sorts before `GenericKind` when candidates are ranked.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum ControllerMatchRank {
    /// One of the controller's operator selectors matched the node's operator.
    OperatorSelector,
    /// The controller declares no selectors and matched on node kind alone.
    GenericKind,
}
398
/// A manifest that matches a node, together with the rank of the match.
struct ControllerCandidate<'a> {
    /// Manifest borrowed from the registry.
    manifest: &'a ControllerManifest,
    /// How the manifest matched the node.
    rank: ControllerMatchRank,
}
403
404fn controller_candidate<'a>(
405 manifest: &'a ControllerManifest,
406 node: &NodeSpec,
407) -> Option<ControllerCandidate<'a>> {
408 if manifest.operator_kind != node.kind {
409 return None;
410 }
411 if manifest.operator_selectors.is_empty() {
412 return Some(ControllerCandidate {
413 manifest,
414 rank: ControllerMatchRank::GenericKind,
415 });
416 }
417 let operator = node.operator.as_ref()?;
418 manifest
419 .operator_selectors
420 .iter()
421 .any(|selector| selector_matches_operator(selector, operator))
422 .then_some(ControllerCandidate {
423 manifest,
424 rank: ControllerMatchRank::OperatorSelector,
425 })
426}
427
428fn selector_matches_operator(selector: &OperatorSelector, operator: &serde_json::Value) -> bool {
429 let descriptor = OperatorDescriptor::from_value(operator);
430 selector_matches_any(
431 &selector.aliases,
432 descriptor.alias_candidates.iter().copied(),
433 ) || descriptor
434 .class
435 .is_some_and(|class| selector_matches_exact(&selector.classes, class))
436 || descriptor.class.is_some_and(|class| {
437 selector
438 .class_prefixes
439 .iter()
440 .any(|prefix| normalized_starts_with(class, prefix))
441 })
442 || descriptor
443 .function
444 .is_some_and(|function| selector_matches_exact(&selector.functions, function))
445 || descriptor
446 .reference
447 .is_some_and(|reference| selector_matches_exact(&selector.refs, reference))
448 || descriptor
449 .operator_type
450 .is_some_and(|operator_type| selector_matches_exact(&selector.types, operator_type))
451}
452
453fn operator_label(operator: &serde_json::Value) -> String {
454 match operator {
455 serde_json::Value::String(value) => value.clone(),
456 serde_json::Value::Object(object) => ["type", "ref", "class", "function"]
457 .into_iter()
458 .find_map(|key| object.get(key).and_then(serde_json::Value::as_str))
459 .map(str::to_string)
460 .unwrap_or_else(|| operator.to_string()),
461 _ => operator.to_string(),
462 }
463}
464
465fn selector_matches_any<'a>(
466 values: &BTreeSet<String>,
467 mut candidates: impl Iterator<Item = &'a str>,
468) -> bool {
469 candidates.any(|candidate| selector_matches_exact(values, candidate))
470}
471
472fn selector_matches_exact(values: &BTreeSet<String>, candidate: &str) -> bool {
473 values
474 .iter()
475 .any(|value| normalized_eq(value.as_str(), candidate))
476}
477
/// Compares two strings, ignoring surrounding whitespace and ASCII case.
fn normalized_eq(left: &str, right: &str) -> bool {
    let (lhs, rhs) = (left.trim(), right.trim());
    lhs.len() == rhs.len()
        && lhs
            .bytes()
            .zip(rhs.bytes())
            .all(|(a, b)| a.eq_ignore_ascii_case(&b))
}
481
/// Returns `true` when `value` starts with `prefix`, ignoring surrounding
/// whitespace on both arguments and ASCII case.
///
/// An empty (or blank) `prefix` matches every value.
fn normalized_starts_with(value: &str, prefix: &str) -> bool {
    let prefix = prefix.trim().as_bytes();
    // ASCII case folding never changes byte lengths, so comparing the leading
    // bytes directly gives the same answer as lowercasing both strings first,
    // without allocating two temporary `String`s per comparison.
    value
        .trim()
        .as_bytes()
        .get(..prefix.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
}
488
/// Borrowed view of the string fields of an operator JSON value that
/// selectors are matched against.
struct OperatorDescriptor<'a> {
    /// String under the `class` key of an object operator.
    class: Option<&'a str>,
    /// String under the `function` key of an object operator.
    function: Option<&'a str>,
    /// String under the `ref` key of an object operator, or the operator
    /// itself when it is a bare JSON string.
    reference: Option<&'a str>,
    /// String under the `type` key of an object operator.
    operator_type: Option<&'a str>,
    /// Names that alias selectors are compared with: each source string plus,
    /// when it contains `.` or `:`, the segment after the last separator.
    alias_candidates: Vec<&'a str>,
}
496
497impl<'a> OperatorDescriptor<'a> {
498 fn from_value(value: &'a serde_json::Value) -> Self {
499 let mut descriptor = Self {
500 class: None,
501 function: None,
502 reference: None,
503 operator_type: None,
504 alias_candidates: Vec::new(),
505 };
506 match value {
507 serde_json::Value::String(reference) => {
508 descriptor.reference = Some(reference);
509 descriptor.push_alias_candidates(reference);
510 }
511 serde_json::Value::Object(object) => {
512 descriptor.class = object.get("class").and_then(serde_json::Value::as_str);
513 descriptor.function = object.get("function").and_then(serde_json::Value::as_str);
514 descriptor.reference = object.get("ref").and_then(serde_json::Value::as_str);
515 descriptor.operator_type = object.get("type").and_then(serde_json::Value::as_str);
516 for value in [
517 descriptor.operator_type,
518 descriptor.reference,
519 descriptor.class,
520 descriptor.function,
521 ]
522 .into_iter()
523 .flatten()
524 {
525 descriptor.push_alias_candidates(value);
526 }
527 }
528 _ => {}
529 }
530 descriptor
531 }
532
533 fn push_alias_candidates(&mut self, value: &'a str) {
534 self.alias_candidates.push(value);
535 if let Some(short) = value
536 .rsplit(['.', ':'])
537 .next()
538 .filter(|short| *short != value)
539 {
540 self.alias_candidates.push(short);
541 }
542 }
543}
544
545fn validate_ports(controller_id: &ControllerId, direction: &str, ports: &[PortSpec]) -> Result<()> {
546 let mut seen = BTreeSet::new();
547 for port in ports {
548 if port.name.trim().is_empty() {
549 return Err(DagMlError::ControllerValidation(format!(
550 "{direction} port on controller `{controller_id}` has an empty name"
551 )));
552 }
553 if !seen.insert(port.name.as_str()) {
554 return Err(DagMlError::ControllerValidation(format!(
555 "duplicate {direction} port `{}` on controller `{controller_id}`",
556 port.name
557 )));
558 }
559 }
560 Ok(())
561}
562
563fn requested_controller(node: &NodeSpec) -> Result<Option<ControllerId>> {
564 node.metadata
565 .get("controller_id")
566 .map(|value| {
567 value.as_str().ok_or_else(|| {
568 DagMlError::Planning(format!(
569 "node `{}` metadata.controller_id must be a string",
570 node.id
571 ))
572 })
573 })
574 .transpose()?
575 .map(ControllerId::new)
576 .transpose()
577}
578
// Unit tests for manifest validation, registry resolution and the published
// JSON schema.
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use serde_json::json;

    use super::*;
    use crate::graph::{NodeSpec, PortCardinality, PortSchema};
    use crate::ids::NodeId;

    /// Builds a manifest that passes validation: FIT_CV phase, fold-train
    /// scope, deterministic capability, no ports and no selectors.
    fn manifest(id: &str, kind: NodeKind, priority: u32) -> ControllerManifest {
        ControllerManifest {
            controller_id: ControllerId::new(id).unwrap(),
            controller_version: "0.1.0".to_string(),
            operator_kind: kind,
            priority,
            supported_phases: BTreeSet::from([Phase::FitCv]),
            input_ports: Vec::new(),
            output_ports: Vec::new(),
            data_requirements: None,
            capabilities: BTreeSet::from([ControllerCapability::Deterministic]),
            operator_selectors: Vec::new(),
            fit_scope: ControllerFitScope::FoldTrain,
            rng_policy: RngPolicy::UsesCoreSeed,
            artifact_policy: ArtifactPolicy::Serializable,
        }
    }

    /// Builds a node of the given kind without operator or metadata.
    fn node(kind: NodeKind) -> NodeSpec {
        NodeSpec {
            id: NodeId::new("node:model").unwrap(),
            kind,
            operator: None,
            params: BTreeMap::new(),
            ports: PortSchema::default(),
            metadata: BTreeMap::new(),
            seed_label: None,
        }
    }

    /// Builds a node of the given kind carrying `operator`.
    fn node_with_operator(kind: NodeKind, operator: serde_json::Value) -> NodeSpec {
        NodeSpec {
            operator: Some(operator),
            ..node(kind)
        }
    }

    /// Builds a selector that matches on a single alias.
    fn alias_selector(alias: &str) -> OperatorSelector {
        OperatorSelector {
            aliases: BTreeSet::from([alias.to_string()]),
            ..OperatorSelector::default()
        }
    }

    // Among generic controllers of the same kind the lowest priority value wins.
    #[test]
    fn registry_resolves_lowest_priority_manifest() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:slow", NodeKind::Model, 10))
            .unwrap();
        registry
            .register(manifest("controller:fast", NodeKind::Model, 1))
            .unwrap();

        let resolved = registry.resolve_for_node(&node(NodeKind::Model)).unwrap();

        assert_eq!(resolved.controller_id.as_str(), "controller:fast");
    }

    // An explicit `controller_id` in the node metadata bypasses ranking.
    #[test]
    fn explicit_controller_id_disambiguates() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:a", NodeKind::Model, 1))
            .unwrap();
        registry
            .register(manifest("controller:b", NodeKind::Model, 1))
            .unwrap();
        let mut node = node(NodeKind::Model);
        node.metadata
            .insert("controller_id".to_string(), json!("controller:b"));

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(resolved.controller_id.as_str(), "controller:b");
    }

    // Two generic controllers with equal priority cannot be resolved implicitly.
    #[test]
    fn equal_priority_requires_explicit_controller() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:a", NodeKind::Model, 1))
            .unwrap();
        registry
            .register(manifest("controller:b", NodeKind::Model, 1))
            .unwrap();

        assert!(registry.resolve_for_node(&node(NodeKind::Model)).is_err());
    }

    // A selector match outranks a generic controller of equal priority.
    #[test]
    fn operator_selector_prefers_specific_controller_over_generic() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest(
                "controller:transform.generic",
                NodeKind::Transform,
                0,
            ))
            .unwrap();
        let mut specific = manifest("controller:transform.snv", NodeKind::Transform, 0);
        specific.operator_selectors.push(alias_selector("SNV"));
        registry.register(specific).unwrap();
        let node = node_with_operator(NodeKind::Transform, json!("SNV"));

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(resolved.controller_id.as_str(), "controller:transform.snv");
    }

    // An alias matches the last dotted segment of a fully qualified class.
    #[test]
    fn operator_selector_matches_plain_class_basename_alias() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest(
                "controller:transform.generic",
                NodeKind::Transform,
                0,
            ))
            .unwrap();
        let mut specific = manifest("controller:transform.mixin", NodeKind::Transform, 0);
        specific
            .operator_selectors
            .push(alias_selector("StandardScaler"));
        registry.register(specific).unwrap();
        let node = node_with_operator(
            NodeKind::Transform,
            json!({"class": "sklearn.preprocessing.StandardScaler"}),
        );

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(
            resolved.controller_id.as_str(),
            "controller:transform.mixin"
        );
    }

    // The node kind is inferred from the controller whose alias matches.
    #[test]
    fn registry_infers_operator_kind_from_alias_selector() {
        let mut registry = ControllerRegistry::new();
        let mut model = manifest("controller:model.custom", NodeKind::Model, 0);
        model
            .operator_selectors
            .push(alias_selector("ElasticSpectra"));
        registry.register(model).unwrap();

        let kind = registry
            .infer_operator_kind(&json!("ElasticSpectra"))
            .unwrap()
            .unwrap();

        assert_eq!(kind, NodeKind::Model);
    }

    // The same alias on controllers of different kinds is an inference error.
    #[test]
    fn registry_refuses_cross_kind_alias_inference() {
        let mut registry = ControllerRegistry::new();
        let mut transform = manifest("controller:transform.custom", NodeKind::Transform, 0);
        transform
            .operator_selectors
            .push(alias_selector("AmbiguousAlias"));
        let mut model = manifest("controller:model.custom", NodeKind::Model, 0);
        model
            .operator_selectors
            .push(alias_selector("AmbiguousAlias"));
        registry.register(transform).unwrap();
        registry.register(model).unwrap();

        let error = registry
            .infer_operator_kind(&json!("AmbiguousAlias"))
            .unwrap_err()
            .to_string();

        assert!(error.contains("different node kinds"));
    }

    // A class prefix selector matches any class below that prefix.
    #[test]
    fn operator_selector_matches_class_prefix() {
        let mut registry = ControllerRegistry::new();
        let mut sklearn = manifest("controller:sklearn.transform", NodeKind::Transform, 0);
        sklearn.operator_selectors.push(OperatorSelector {
            class_prefixes: BTreeSet::from(["sklearn.preprocessing.".to_string()]),
            ..OperatorSelector::default()
        });
        registry.register(sklearn).unwrap();
        let node = node_with_operator(
            NodeKind::Transform,
            json!({"class": "sklearn.preprocessing.MinMaxScaler"}),
        );

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(
            resolved.controller_id.as_str(),
            "controller:sklearn.transform"
        );
    }

    // Two selector matches with equal priority are reported as ambiguous.
    #[test]
    fn equal_priority_operator_selector_matches_are_ambiguous() {
        let mut registry = ControllerRegistry::new();
        let mut first = manifest("controller:snv.a", NodeKind::Transform, 0);
        first.operator_selectors.push(alias_selector("SNV"));
        let mut second = manifest("controller:snv.b", NodeKind::Transform, 0);
        second.operator_selectors.push(alias_selector("SNV"));
        registry.register(first).unwrap();
        registry.register(second).unwrap();
        let node = node_with_operator(NodeKind::Transform, json!({"type": "SNV"}));

        let error = registry.resolve_for_node(&node).unwrap_err().to_string();

        assert!(error.contains("ambiguous controllers"));
    }

    // A controller with selectors is not a fallback for other operators.
    #[test]
    fn selector_only_controller_does_not_catch_unmatched_operator() {
        let mut registry = ControllerRegistry::new();
        let mut snv = manifest("controller:transform.snv", NodeKind::Transform, 0);
        snv.operator_selectors.push(alias_selector("SNV"));
        registry.register(snv).unwrap();
        let node = node_with_operator(NodeKind::Transform, json!("MSC"));

        let error = registry.resolve_for_node(&node).unwrap_err().to_string();

        assert!(error.contains("no controller registered"));
    }

    // A prediction output port requires the emits_predictions capability.
    #[test]
    fn manifest_rejects_prediction_output_without_capability() {
        let mut manifest = manifest("controller:predictor", NodeKind::Model, 0);
        manifest.output_ports.push(PortSpec {
            name: "pred".to_string(),
            kind: PortKind::Prediction,
            representation: None,
            cardinality: PortCardinality::One,
            unit_level: None,
            alignment_key: None,
            target_level: None,
            description: String::new(),
        });

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("lacks emits_predictions"));
    }

    // The fixture supports FIT_CV, which an inference-only scope forbids.
    #[test]
    fn manifest_rejects_training_phases_for_inference_only_controller() {
        let mut manifest = manifest("controller:predict-only", NodeKind::Model, 0);
        manifest.fit_scope = ControllerFitScope::InferenceOnly;

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("inference_only"));
    }

    // Well-formed data requirements parse into a ModelInputSpec and validate.
    #[test]
    fn manifest_validates_model_input_spec_data_requirements() {
        let mut manifest = manifest("controller:data-aware", NodeKind::Model, 0);
        manifest.data_requirements = Some(json!({
            "schema_version": 1,
            "ports": [{
                "name": "x",
                "accepted_representations": ["tabular_numeric"],
                "accepted_types": ["f64"],
                "rank": 2
            }]
        }));

        let input_spec = manifest.model_input_spec().unwrap().unwrap();
        assert_eq!(input_spec.ports[0].name, "x");
        manifest.validate().unwrap();
    }

    // An empty accepted_representations list is rejected during validation.
    #[test]
    fn manifest_rejects_invalid_model_input_spec_data_requirements() {
        let mut manifest = manifest("controller:data-aware", NodeKind::Model, 0);
        manifest.data_requirements = Some(json!({
            "schema_version": 1,
            "ports": [{
                "name": "x",
                "accepted_representations": [],
                "accepted_types": ["f64"]
            }]
        }));

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("data_requirements"));
        assert!(error.contains("accepted_representations"));
    }

    // A selector without any criterion is rejected.
    #[test]
    fn manifest_rejects_empty_operator_selector() {
        let mut manifest = manifest("controller:empty-selector", NodeKind::Transform, 0);
        manifest
            .operator_selectors
            .push(OperatorSelector::default());

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("empty operator selector"));
    }

    // Parallel invocation is reported once process_safe is declared.
    #[test]
    fn manifest_reports_parallel_invocation_support() {
        let mut manifest = manifest("controller:parallel", NodeKind::Model, 0);
        assert!(!manifest.supports_parallel_invocation());
        manifest
            .capabilities
            .insert(ControllerCapability::ProcessSafe);
        assert!(manifest.supports_parallel_invocation());
    }

    // Spot-checks that the published JSON schema agrees with this module.
    #[test]
    fn published_controller_manifest_schema_declares_current_contract() {
        let schema: serde_json::Value = serde_json::from_str(include_str!(
            "../../../docs/contracts/controller_manifest.schema.json"
        ))
        .unwrap();

        assert_eq!(schema["$id"], CONTROLLER_MANIFEST_SCHEMA_ID);
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .iter()
            .any(|field| field.as_str() == Some("controller_id")));
        assert!(schema["$defs"]["controller_capability"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .any(|capability| capability.as_str() == Some("emits_predictions")));
        assert!(schema["$defs"]["controller_capability"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .any(|capability| capability.as_str() == Some("aggregates_predictions")));
        assert!(schema["properties"]
            .as_object()
            .unwrap()
            .contains_key("operator_selectors"));
        assert_eq!(
            schema["$defs"]["model_input_spec"]["properties"]["schema_version"]["const"].as_u64(),
            Some(crate::data::MODEL_INPUT_SPEC_SCHEMA_VERSION as u64)
        );
    }
}