Skip to main content

dag_ml_core/
controller.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::data::ModelInputSpec;
6use crate::error::{DagMlError, Result};
7use crate::graph::{NodeKind, NodeSpec, PortKind, PortSpec};
8use crate::ids::ControllerId;
9use crate::phase::Phase;
10use crate::policy::FitInfluencePolicy;
11
/// Version of the controller-manifest JSON contract implemented by this module.
///
/// Keep in sync with the `v1` segment of [`CONTROLLER_MANIFEST_SCHEMA_ID`].
pub const CONTROLLER_MANIFEST_SCHEMA_VERSION: u32 = 1;

/// `$id` of the published JSON Schema describing [`ControllerManifest`].
pub const CONTROLLER_MANIFEST_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/controller_manifest.v1.schema.json";
15
16#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum ControllerCapability {
19    Deterministic,
20    ThreadSafe,
21    ProcessSafe,
22    NeedsPythonGil,
23    EmitsPredictions,
24    ConsumesOofPredictions,
25    EmitsArtifacts,
26    Stateful,
27    EmitsRelation,
28    UsesCoreRng,
29    ShapeChanging,
30    GeneratesData,
31    GeneratesModel,
32    ExpandsVariants,
33    AggregatesPredictions,
34    SupportsSampleWeights,
35    SupportsRowResampling,
36    SupportsBackendLossWeights,
37    SupportsMissingMasks,
38}
39
40#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
41#[serde(rename_all = "snake_case")]
42pub enum ControllerFitScope {
43    Stateless,
44    FoldTrain,
45    FullTrain,
46    InferenceOnly,
47}
48
49#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51pub enum RngPolicy {
52    UsesCoreSeed,
53    IgnoresSeed,
54    ExternallyDeterministic,
55    Nondeterministic,
56}
57
58#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum ArtifactPolicy {
61    Serializable,
62    HostOnly,
63    ContentAddressed,
64    ReplayRequired,
65}
66
67#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
68#[serde(deny_unknown_fields)]
69pub struct OperatorSelector {
70    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
71    pub aliases: BTreeSet<String>,
72    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
73    pub classes: BTreeSet<String>,
74    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
75    pub class_prefixes: BTreeSet<String>,
76    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
77    pub functions: BTreeSet<String>,
78    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
79    pub refs: BTreeSet<String>,
80    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
81    pub types: BTreeSet<String>,
82}
83
84impl OperatorSelector {
85    fn validate(&self, controller_id: &ControllerId) -> Result<()> {
86        if self.aliases.is_empty()
87            && self.classes.is_empty()
88            && self.class_prefixes.is_empty()
89            && self.functions.is_empty()
90            && self.refs.is_empty()
91            && self.types.is_empty()
92        {
93            return Err(DagMlError::ControllerValidation(format!(
94                "controller `{controller_id}` has an empty operator selector"
95            )));
96        }
97        for (field, values) in [
98            ("aliases", &self.aliases),
99            ("classes", &self.classes),
100            ("class_prefixes", &self.class_prefixes),
101            ("functions", &self.functions),
102            ("refs", &self.refs),
103            ("types", &self.types),
104        ] {
105            if values.iter().any(|value| value.trim().is_empty()) {
106                return Err(DagMlError::ControllerValidation(format!(
107                    "controller `{controller_id}` operator selector `{field}` contains an empty value"
108                )));
109            }
110        }
111        Ok(())
112    }
113}
114
115#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
116#[serde(deny_unknown_fields)]
117pub struct ControllerManifest {
118    pub controller_id: ControllerId,
119    pub controller_version: String,
120    pub operator_kind: NodeKind,
121    #[serde(default)]
122    pub priority: u32,
123    #[serde(default)]
124    pub supported_phases: BTreeSet<Phase>,
125    #[serde(default)]
126    pub input_ports: Vec<PortSpec>,
127    #[serde(default)]
128    pub output_ports: Vec<PortSpec>,
129    #[serde(default)]
130    pub data_requirements: Option<serde_json::Value>,
131    #[serde(default)]
132    pub capabilities: BTreeSet<ControllerCapability>,
133    #[serde(default, skip_serializing_if = "Vec::is_empty")]
134    pub operator_selectors: Vec<OperatorSelector>,
135    pub fit_scope: ControllerFitScope,
136    pub rng_policy: RngPolicy,
137    pub artifact_policy: ArtifactPolicy,
138}
139
140impl ControllerManifest {
141    pub fn validate(&self) -> Result<()> {
142        if self.controller_version.trim().is_empty() {
143            return Err(DagMlError::ControllerValidation(format!(
144                "controller `{}` has an empty version",
145                self.controller_id
146            )));
147        }
148        if self.supported_phases.is_empty() {
149            return Err(DagMlError::ControllerValidation(format!(
150                "controller `{}` supports no phases",
151                self.controller_id
152            )));
153        }
154        if let Some(model_input) = self.model_input_spec()? {
155            model_input.validate().map_err(|error| {
156                DagMlError::ControllerValidation(format!(
157                    "controller `{}` data_requirements are not a valid ModelInputSpec: {error}",
158                    self.controller_id
159                ))
160            })?;
161        }
162        validate_ports(&self.controller_id, "input", &self.input_ports)?;
163        validate_ports(&self.controller_id, "output", &self.output_ports)?;
164        for selector in &self.operator_selectors {
165            selector.validate(&self.controller_id)?;
166        }
167        if self.rng_policy == RngPolicy::Nondeterministic
168            && self
169                .capabilities
170                .contains(&ControllerCapability::Deterministic)
171        {
172            return Err(DagMlError::ControllerValidation(format!(
173                "controller `{}` cannot be deterministic with nondeterministic RNG",
174                self.controller_id
175            )));
176        }
177        if self.fit_scope == ControllerFitScope::InferenceOnly
178            && (self.supported_phases.contains(&Phase::FitCv)
179                || self.supported_phases.contains(&Phase::Refit))
180        {
181            return Err(DagMlError::ControllerValidation(format!(
182                "controller `{}` is inference_only but supports training phases",
183                self.controller_id
184            )));
185        }
186        if self.supported_phases.contains(&Phase::FitCv)
187            && matches!(
188                self.fit_scope,
189                ControllerFitScope::FullTrain | ControllerFitScope::InferenceOnly
190            )
191        {
192            return Err(DagMlError::ControllerValidation(format!(
193                "controller `{}` supports FIT_CV but has fit_scope {:?}",
194                self.controller_id, self.fit_scope
195            )));
196        }
197        if self
198            .output_ports
199            .iter()
200            .any(|port| port.kind == PortKind::Prediction)
201            && !self
202                .capabilities
203                .contains(&ControllerCapability::EmitsPredictions)
204        {
205            return Err(DagMlError::ControllerValidation(format!(
206                "controller `{}` has prediction output ports but lacks emits_predictions",
207                self.controller_id
208            )));
209        }
210        if self
211            .output_ports
212            .iter()
213            .any(|port| port.kind == PortKind::Artifact)
214            && !self
215                .capabilities
216                .contains(&ControllerCapability::EmitsArtifacts)
217        {
218            return Err(DagMlError::ControllerValidation(format!(
219                "controller `{}` has artifact output ports but lacks emits_artifacts",
220                self.controller_id
221            )));
222        }
223        Ok(())
224    }
225
226    pub fn supports_phase(&self, phase: Phase) -> bool {
227        self.supported_phases.contains(&phase)
228    }
229
230    pub fn supports_parallel_invocation(&self) -> bool {
231        self.capabilities
232            .contains(&ControllerCapability::ThreadSafe)
233            || self
234                .capabilities
235                .contains(&ControllerCapability::ProcessSafe)
236    }
237
238    pub fn supports_fit_influence_policy(&self, policy: FitInfluencePolicy) -> bool {
239        capabilities_support_fit_influence(&self.capabilities, policy)
240    }
241
242    pub fn model_input_spec(&self) -> Result<Option<ModelInputSpec>> {
243        self.data_requirements
244            .as_ref()
245            .map(|value| {
246                serde_json::from_value::<ModelInputSpec>(value.clone()).map_err(|error| {
247                    DagMlError::ControllerValidation(format!(
248                        "controller `{}` data_requirements must be ModelInputSpec JSON: {error}",
249                        self.controller_id
250                    ))
251                })
252            })
253            .transpose()
254    }
255}
256
257pub fn capabilities_support_fit_influence(
258    capabilities: &BTreeSet<ControllerCapability>,
259    policy: FitInfluencePolicy,
260) -> bool {
261    match policy {
262        FitInfluencePolicy::Auto
263        | FitInfluencePolicy::UniformRows
264        | FitInfluencePolicy::ScorerOnly => true,
265        FitInfluencePolicy::EqualSampleInfluence => {
266            capabilities.contains(&ControllerCapability::SupportsSampleWeights)
267        }
268        FitInfluencePolicy::ResampleEqualized => {
269            capabilities.contains(&ControllerCapability::SupportsRowResampling)
270        }
271        FitInfluencePolicy::BackendLossWeight => {
272            capabilities.contains(&ControllerCapability::SupportsBackendLossWeights)
273        }
274        FitInfluencePolicy::StrictWeightSupport => {
275            capabilities.contains(&ControllerCapability::SupportsSampleWeights)
276                || capabilities.contains(&ControllerCapability::SupportsRowResampling)
277                || capabilities.contains(&ControllerCapability::SupportsBackendLossWeights)
278        }
279    }
280}
281
282#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
283pub struct ControllerRegistry {
284    manifests: BTreeMap<ControllerId, ControllerManifest>,
285}
286
287impl ControllerRegistry {
288    pub fn new() -> Self {
289        Self::default()
290    }
291
292    pub fn register(&mut self, manifest: ControllerManifest) -> Result<()> {
293        manifest.validate()?;
294        if self.manifests.contains_key(&manifest.controller_id) {
295            return Err(DagMlError::ControllerValidation(format!(
296                "duplicate controller id `{}`",
297                manifest.controller_id
298            )));
299        }
300        self.manifests
301            .insert(manifest.controller_id.clone(), manifest);
302        Ok(())
303    }
304
305    pub fn get(&self, controller_id: &ControllerId) -> Option<&ControllerManifest> {
306        self.manifests.get(controller_id)
307    }
308
309    pub fn manifests(&self) -> impl Iterator<Item = &ControllerManifest> {
310        self.manifests.values()
311    }
312
313    pub fn resolve_for_node(&self, node: &NodeSpec) -> Result<ControllerManifest> {
314        if let Some(requested) = requested_controller(node)? {
315            let manifest = self.get(&requested).ok_or_else(|| {
316                DagMlError::Planning(format!(
317                    "node `{}` requested unknown controller `{requested}`",
318                    node.id
319                ))
320            })?;
321            if manifest.operator_kind != node.kind {
322                return Err(DagMlError::Planning(format!(
323                    "node `{}` kind {:?} is incompatible with controller `{}` kind {:?}",
324                    node.id, node.kind, manifest.controller_id, manifest.operator_kind
325                )));
326            }
327            return Ok(manifest.clone());
328        }
329
330        let mut candidates = self
331            .manifests
332            .values()
333            .filter_map(|manifest| controller_candidate(manifest, node))
334            .collect::<Vec<_>>();
335        candidates.sort_by(|left, right| {
336            left.rank
337                .cmp(&right.rank)
338                .then_with(|| left.manifest.priority.cmp(&right.manifest.priority))
339                .then_with(|| {
340                    left.manifest
341                        .controller_id
342                        .cmp(&right.manifest.controller_id)
343                })
344        });
345        let Some(first) = candidates.first() else {
346            return Err(DagMlError::Planning(format!(
347                "no controller registered for node `{}` kind {:?}",
348                node.id, node.kind
349            )));
350        };
351        if candidates.get(1).is_some_and(|second| {
352            second.rank == first.rank && second.manifest.priority == first.manifest.priority
353        }) {
354            return Err(DagMlError::Planning(format!(
355                "node `{}` has ambiguous controllers for kind {:?}; set metadata.controller_id",
356                node.id, node.kind
357            )));
358        }
359        Ok(first.manifest.clone())
360    }
361
362    pub fn infer_operator_kind(&self, operator: &serde_json::Value) -> Result<Option<NodeKind>> {
363        let matches = self
364            .manifests
365            .values()
366            .filter(|manifest| {
367                !manifest.operator_selectors.is_empty()
368                    && manifest
369                        .operator_selectors
370                        .iter()
371                        .any(|selector| selector_matches_operator(selector, operator))
372            })
373            .collect::<Vec<_>>();
374        let Some(first) = matches.first() else {
375            return Ok(None);
376        };
377        let kind = first.operator_kind.clone();
378        let conflicting = matches
379            .iter()
380            .find(|manifest| manifest.operator_kind != kind);
381        if let Some(conflicting) = conflicting {
382            return Err(DagMlError::Planning(format!(
383                "minimal operator alias `{}` matches controllers with different node kinds ({:?} and {:?}); use explicit DSL syntax",
384                operator_label(operator),
385                kind,
386                conflicting.operator_kind
387            )));
388        }
389        Ok(Some(kind))
390    }
391}
392
/// How a candidate controller matched a node; smaller ranks are preferred.
///
/// The derived ordering follows declaration order, so `OperatorSelector` sorts before
/// `GenericKind`. Do not reorder the variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum ControllerMatchRank {
    /// One of the controller's operator selectors matched the node's operator.
    OperatorSelector,
    /// The controller has no selectors and matched on node kind alone.
    GenericKind,
}
398
399struct ControllerCandidate<'a> {
400    manifest: &'a ControllerManifest,
401    rank: ControllerMatchRank,
402}
403
404fn controller_candidate<'a>(
405    manifest: &'a ControllerManifest,
406    node: &NodeSpec,
407) -> Option<ControllerCandidate<'a>> {
408    if manifest.operator_kind != node.kind {
409        return None;
410    }
411    if manifest.operator_selectors.is_empty() {
412        return Some(ControllerCandidate {
413            manifest,
414            rank: ControllerMatchRank::GenericKind,
415        });
416    }
417    let operator = node.operator.as_ref()?;
418    manifest
419        .operator_selectors
420        .iter()
421        .any(|selector| selector_matches_operator(selector, operator))
422        .then_some(ControllerCandidate {
423            manifest,
424            rank: ControllerMatchRank::OperatorSelector,
425        })
426}
427
428fn selector_matches_operator(selector: &OperatorSelector, operator: &serde_json::Value) -> bool {
429    let descriptor = OperatorDescriptor::from_value(operator);
430    selector_matches_any(
431        &selector.aliases,
432        descriptor.alias_candidates.iter().copied(),
433    ) || descriptor
434        .class
435        .is_some_and(|class| selector_matches_exact(&selector.classes, class))
436        || descriptor.class.is_some_and(|class| {
437            selector
438                .class_prefixes
439                .iter()
440                .any(|prefix| normalized_starts_with(class, prefix))
441        })
442        || descriptor
443            .function
444            .is_some_and(|function| selector_matches_exact(&selector.functions, function))
445        || descriptor
446            .reference
447            .is_some_and(|reference| selector_matches_exact(&selector.refs, reference))
448        || descriptor
449            .operator_type
450            .is_some_and(|operator_type| selector_matches_exact(&selector.types, operator_type))
451}
452
453fn operator_label(operator: &serde_json::Value) -> String {
454    match operator {
455        serde_json::Value::String(value) => value.clone(),
456        serde_json::Value::Object(object) => ["type", "ref", "class", "function"]
457            .into_iter()
458            .find_map(|key| object.get(key).and_then(serde_json::Value::as_str))
459            .map(str::to_string)
460            .unwrap_or_else(|| operator.to_string()),
461        _ => operator.to_string(),
462    }
463}
464
465fn selector_matches_any<'a>(
466    values: &BTreeSet<String>,
467    mut candidates: impl Iterator<Item = &'a str>,
468) -> bool {
469    candidates.any(|candidate| selector_matches_exact(values, candidate))
470}
471
472fn selector_matches_exact(values: &BTreeSet<String>, candidate: &str) -> bool {
473    values
474        .iter()
475        .any(|value| normalized_eq(value.as_str(), candidate))
476}
477
/// Compares two strings after trimming surrounding whitespace, ignoring ASCII case.
fn normalized_eq(left: &str, right: &str) -> bool {
    let (left, right) = (left.trim(), right.trim());
    left.eq_ignore_ascii_case(right)
}
481
/// Returns `true` when trimmed `value` starts with trimmed `prefix`, ignoring ASCII case.
///
/// Compares bytes directly rather than building lowercase copies. ASCII case folding never
/// touches non-ASCII bytes, so this gives the same answer as lowercasing both strings
/// first.
fn normalized_starts_with(value: &str, prefix: &str) -> bool {
    let value = value.trim().as_bytes();
    let prefix = prefix.trim().as_bytes();
    value
        .get(..prefix.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
}
488
/// Identifying strings extracted from an operator specification, borrowed from its JSON.
struct OperatorDescriptor<'a> {
    /// String-valued `class` member of an object operator.
    class: Option<&'a str>,
    /// String-valued `function` member of an object operator.
    function: Option<&'a str>,
    /// String-valued `ref` member, or the whole operator when it is a JSON string.
    reference: Option<&'a str>,
    /// String-valued `type` member of an object operator.
    operator_type: Option<&'a str>,
    /// Names compared against selector aliases: each identifying string plus its short form.
    alias_candidates: Vec<&'a str>,
}
496
497impl<'a> OperatorDescriptor<'a> {
498    fn from_value(value: &'a serde_json::Value) -> Self {
499        let mut descriptor = Self {
500            class: None,
501            function: None,
502            reference: None,
503            operator_type: None,
504            alias_candidates: Vec::new(),
505        };
506        match value {
507            serde_json::Value::String(reference) => {
508                descriptor.reference = Some(reference);
509                descriptor.push_alias_candidates(reference);
510            }
511            serde_json::Value::Object(object) => {
512                descriptor.class = object.get("class").and_then(serde_json::Value::as_str);
513                descriptor.function = object.get("function").and_then(serde_json::Value::as_str);
514                descriptor.reference = object.get("ref").and_then(serde_json::Value::as_str);
515                descriptor.operator_type = object.get("type").and_then(serde_json::Value::as_str);
516                for value in [
517                    descriptor.operator_type,
518                    descriptor.reference,
519                    descriptor.class,
520                    descriptor.function,
521                ]
522                .into_iter()
523                .flatten()
524                {
525                    descriptor.push_alias_candidates(value);
526                }
527            }
528            _ => {}
529        }
530        descriptor
531    }
532
533    fn push_alias_candidates(&mut self, value: &'a str) {
534        self.alias_candidates.push(value);
535        if let Some(short) = value
536            .rsplit(['.', ':'])
537            .next()
538            .filter(|short| *short != value)
539        {
540            self.alias_candidates.push(short);
541        }
542    }
543}
544
545fn validate_ports(controller_id: &ControllerId, direction: &str, ports: &[PortSpec]) -> Result<()> {
546    let mut seen = BTreeSet::new();
547    for port in ports {
548        if port.name.trim().is_empty() {
549            return Err(DagMlError::ControllerValidation(format!(
550                "{direction} port on controller `{controller_id}` has an empty name"
551            )));
552        }
553        if !seen.insert(port.name.as_str()) {
554            return Err(DagMlError::ControllerValidation(format!(
555                "duplicate {direction} port `{}` on controller `{controller_id}`",
556                port.name
557            )));
558        }
559    }
560    Ok(())
561}
562
563fn requested_controller(node: &NodeSpec) -> Result<Option<ControllerId>> {
564    node.metadata
565        .get("controller_id")
566        .map(|value| {
567            value.as_str().ok_or_else(|| {
568                DagMlError::Planning(format!(
569                    "node `{}` metadata.controller_id must be a string",
570                    node.id
571                ))
572            })
573        })
574        .transpose()?
575        .map(ControllerId::new)
576        .transpose()
577}
578
#[cfg(test)]
mod tests {
    // Unit tests for manifest validation and controller resolution.
    use std::collections::{BTreeMap, BTreeSet};

    use serde_json::json;

    use super::*;
    use crate::graph::{NodeSpec, PortCardinality, PortSchema};
    use crate::ids::NodeId;

    /// Builds a valid, deterministic FIT_CV manifest with no ports or selectors.
    fn manifest(id: &str, kind: NodeKind, priority: u32) -> ControllerManifest {
        ControllerManifest {
            controller_id: ControllerId::new(id).unwrap(),
            controller_version: "0.1.0".to_string(),
            operator_kind: kind,
            priority,
            supported_phases: BTreeSet::from([Phase::FitCv]),
            input_ports: Vec::new(),
            output_ports: Vec::new(),
            data_requirements: None,
            capabilities: BTreeSet::from([ControllerCapability::Deterministic]),
            operator_selectors: Vec::new(),
            fit_scope: ControllerFitScope::FoldTrain,
            rng_policy: RngPolicy::UsesCoreSeed,
            artifact_policy: ArtifactPolicy::Serializable,
        }
    }

    /// Builds a node of `kind` with no operator and empty metadata.
    fn node(kind: NodeKind) -> NodeSpec {
        NodeSpec {
            id: NodeId::new("node:model").unwrap(),
            kind,
            operator: None,
            params: BTreeMap::new(),
            ports: PortSchema::default(),
            metadata: BTreeMap::new(),
            seed_label: None,
        }
    }

    /// Builds a node of `kind` carrying `operator`.
    fn node_with_operator(kind: NodeKind, operator: serde_json::Value) -> NodeSpec {
        NodeSpec {
            operator: Some(operator),
            ..node(kind)
        }
    }

    /// Builds a selector with a single alias.
    fn alias_selector(alias: &str) -> OperatorSelector {
        OperatorSelector {
            aliases: BTreeSet::from([alias.to_string()]),
            ..OperatorSelector::default()
        }
    }

    /// Builds a single-cardinality port with the given name and kind.
    fn port(name: &str, kind: PortKind) -> PortSpec {
        PortSpec {
            name: name.to_string(),
            kind,
            representation: None,
            cardinality: PortCardinality::One,
            unit_level: None,
            alignment_key: None,
            target_level: None,
            description: String::new(),
        }
    }

    #[test]
    fn registry_resolves_lowest_priority_manifest() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:slow", NodeKind::Model, 10))
            .unwrap();
        registry
            .register(manifest("controller:fast", NodeKind::Model, 1))
            .unwrap();

        let resolved = registry.resolve_for_node(&node(NodeKind::Model)).unwrap();

        assert_eq!(resolved.controller_id.as_str(), "controller:fast");
    }

    #[test]
    fn registry_rejects_duplicate_controller_id() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:a", NodeKind::Model, 1))
            .unwrap();

        let error = registry
            .register(manifest("controller:a", NodeKind::Model, 2))
            .unwrap_err()
            .to_string();

        assert!(error.contains("duplicate controller id"));
        let kept = registry
            .get(&ControllerId::new("controller:a").unwrap())
            .unwrap();
        assert_eq!(kept.priority, 1);
    }

    #[test]
    fn explicit_controller_id_disambiguates() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:a", NodeKind::Model, 1))
            .unwrap();
        registry
            .register(manifest("controller:b", NodeKind::Model, 1))
            .unwrap();
        let mut node = node(NodeKind::Model);
        node.metadata
            .insert("controller_id".to_string(), json!("controller:b"));

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(resolved.controller_id.as_str(), "controller:b");
    }

    #[test]
    fn explicit_controller_must_be_registered() {
        let registry = ControllerRegistry::new();
        let mut node = node(NodeKind::Model);
        node.metadata
            .insert("controller_id".to_string(), json!("controller:missing"));

        let error = registry.resolve_for_node(&node).unwrap_err().to_string();

        assert!(error.contains("requested unknown controller"));
    }

    #[test]
    fn explicit_controller_must_match_node_kind() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:transform", NodeKind::Transform, 0))
            .unwrap();
        let mut node = node(NodeKind::Model);
        node.metadata
            .insert("controller_id".to_string(), json!("controller:transform"));

        let error = registry.resolve_for_node(&node).unwrap_err().to_string();

        assert!(error.contains("is incompatible with controller"));
    }

    #[test]
    fn non_string_controller_id_metadata_is_rejected() {
        let registry = ControllerRegistry::new();
        let mut node = node(NodeKind::Model);
        node.metadata
            .insert("controller_id".to_string(), json!(42));

        let error = registry.resolve_for_node(&node).unwrap_err().to_string();

        assert!(error.contains("metadata.controller_id must be a string"));
    }

    #[test]
    fn equal_priority_requires_explicit_controller() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:a", NodeKind::Model, 1))
            .unwrap();
        registry
            .register(manifest("controller:b", NodeKind::Model, 1))
            .unwrap();

        assert!(registry.resolve_for_node(&node(NodeKind::Model)).is_err());
    }

    #[test]
    fn operator_selector_prefers_specific_controller_over_generic() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest(
                "controller:transform.generic",
                NodeKind::Transform,
                0,
            ))
            .unwrap();
        let mut specific = manifest("controller:transform.snv", NodeKind::Transform, 0);
        specific.operator_selectors.push(alias_selector("SNV"));
        registry.register(specific).unwrap();
        let node = node_with_operator(NodeKind::Transform, json!("SNV"));

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(resolved.controller_id.as_str(), "controller:transform.snv");
    }

    #[test]
    fn operator_selector_matches_plain_class_basename_alias() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest(
                "controller:transform.generic",
                NodeKind::Transform,
                0,
            ))
            .unwrap();
        let mut specific = manifest("controller:transform.mixin", NodeKind::Transform, 0);
        specific
            .operator_selectors
            .push(alias_selector("StandardScaler"));
        registry.register(specific).unwrap();
        let node = node_with_operator(
            NodeKind::Transform,
            json!({"class": "sklearn.preprocessing.StandardScaler"}),
        );

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(
            resolved.controller_id.as_str(),
            "controller:transform.mixin"
        );
    }

    #[test]
    fn registry_infers_operator_kind_from_alias_selector() {
        let mut registry = ControllerRegistry::new();
        let mut model = manifest("controller:model.custom", NodeKind::Model, 0);
        model
            .operator_selectors
            .push(alias_selector("ElasticSpectra"));
        registry.register(model).unwrap();

        let kind = registry
            .infer_operator_kind(&json!("ElasticSpectra"))
            .unwrap()
            .unwrap();

        assert_eq!(kind, NodeKind::Model);
    }

    #[test]
    fn registry_infers_no_kind_without_selector_match() {
        let mut registry = ControllerRegistry::new();
        registry
            .register(manifest("controller:model.generic", NodeKind::Model, 0))
            .unwrap();
        let mut model = manifest("controller:model.custom", NodeKind::Model, 0);
        model
            .operator_selectors
            .push(alias_selector("ElasticSpectra"));
        registry.register(model).unwrap();

        let kind = registry
            .infer_operator_kind(&json!("SomethingElse"))
            .unwrap();

        assert!(kind.is_none());
    }

    #[test]
    fn registry_refuses_cross_kind_alias_inference() {
        let mut registry = ControllerRegistry::new();
        let mut transform = manifest("controller:transform.custom", NodeKind::Transform, 0);
        transform
            .operator_selectors
            .push(alias_selector("AmbiguousAlias"));
        let mut model = manifest("controller:model.custom", NodeKind::Model, 0);
        model
            .operator_selectors
            .push(alias_selector("AmbiguousAlias"));
        registry.register(transform).unwrap();
        registry.register(model).unwrap();

        let error = registry
            .infer_operator_kind(&json!("AmbiguousAlias"))
            .unwrap_err()
            .to_string();

        assert!(error.contains("different node kinds"));
    }

    #[test]
    fn operator_selector_matches_class_prefix() {
        let mut registry = ControllerRegistry::new();
        let mut sklearn = manifest("controller:sklearn.transform", NodeKind::Transform, 0);
        sklearn.operator_selectors.push(OperatorSelector {
            class_prefixes: BTreeSet::from(["sklearn.preprocessing.".to_string()]),
            ..OperatorSelector::default()
        });
        registry.register(sklearn).unwrap();
        let node = node_with_operator(
            NodeKind::Transform,
            json!({"class": "sklearn.preprocessing.MinMaxScaler"}),
        );

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(
            resolved.controller_id.as_str(),
            "controller:sklearn.transform"
        );
    }

    #[test]
    fn operator_selector_class_prefix_ignores_ascii_case() {
        let mut registry = ControllerRegistry::new();
        let mut sklearn = manifest("controller:sklearn.transform", NodeKind::Transform, 0);
        sklearn.operator_selectors.push(OperatorSelector {
            class_prefixes: BTreeSet::from(["SKLEARN.Preprocessing.".to_string()]),
            ..OperatorSelector::default()
        });
        registry.register(sklearn).unwrap();
        let node = node_with_operator(
            NodeKind::Transform,
            json!({"class": "sklearn.preprocessing.MinMaxScaler"}),
        );

        let resolved = registry.resolve_for_node(&node).unwrap();

        assert_eq!(
            resolved.controller_id.as_str(),
            "controller:sklearn.transform"
        );
    }

    #[test]
    fn equal_priority_operator_selector_matches_are_ambiguous() {
        let mut registry = ControllerRegistry::new();
        let mut first = manifest("controller:snv.a", NodeKind::Transform, 0);
        first.operator_selectors.push(alias_selector("SNV"));
        let mut second = manifest("controller:snv.b", NodeKind::Transform, 0);
        second.operator_selectors.push(alias_selector("SNV"));
        registry.register(first).unwrap();
        registry.register(second).unwrap();
        let node = node_with_operator(NodeKind::Transform, json!({"type": "SNV"}));

        let error = registry.resolve_for_node(&node).unwrap_err().to_string();

        assert!(error.contains("ambiguous controllers"));
    }

    #[test]
    fn selector_only_controller_does_not_catch_unmatched_operator() {
        let mut registry = ControllerRegistry::new();
        let mut snv = manifest("controller:transform.snv", NodeKind::Transform, 0);
        snv.operator_selectors.push(alias_selector("SNV"));
        registry.register(snv).unwrap();
        let node = node_with_operator(NodeKind::Transform, json!("MSC"));

        let error = registry.resolve_for_node(&node).unwrap_err().to_string();

        assert!(error.contains("no controller registered"));
    }

    #[test]
    fn manifest_rejects_prediction_output_without_capability() {
        let mut manifest = manifest("controller:predictor", NodeKind::Model, 0);
        manifest
            .output_ports
            .push(port("pred", PortKind::Prediction));

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("lacks emits_predictions"));
    }

    #[test]
    fn manifest_rejects_duplicate_port_names() {
        let mut manifest = manifest("controller:ports", NodeKind::Model, 0);
        manifest.input_ports.push(port("x", PortKind::Prediction));
        manifest.input_ports.push(port("x", PortKind::Prediction));

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("duplicate input port `x`"));
    }

    #[test]
    fn manifest_rejects_training_phases_for_inference_only_controller() {
        let mut manifest = manifest("controller:predict-only", NodeKind::Model, 0);
        manifest.fit_scope = ControllerFitScope::InferenceOnly;

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("inference_only"));
    }

    #[test]
    fn manifest_rejects_fit_cv_with_full_train_scope() {
        let mut manifest = manifest("controller:full-train", NodeKind::Model, 0);
        manifest.fit_scope = ControllerFitScope::FullTrain;

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("supports FIT_CV"));
    }

    #[test]
    fn manifest_rejects_deterministic_capability_with_nondeterministic_rng() {
        let mut manifest = manifest("controller:random", NodeKind::Model, 0);
        manifest.rng_policy = RngPolicy::Nondeterministic;

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("cannot be deterministic"));
    }

    #[test]
    fn manifest_validates_model_input_spec_data_requirements() {
        let mut manifest = manifest("controller:data-aware", NodeKind::Model, 0);
        manifest.data_requirements = Some(json!({
            "schema_version": 1,
            "ports": [{
                "name": "x",
                "accepted_representations": ["tabular_numeric"],
                "accepted_types": ["f64"],
                "rank": 2
            }]
        }));

        let input_spec = manifest.model_input_spec().unwrap().unwrap();
        assert_eq!(input_spec.ports[0].name, "x");
        manifest.validate().unwrap();
    }

    #[test]
    fn manifest_rejects_invalid_model_input_spec_data_requirements() {
        let mut manifest = manifest("controller:data-aware", NodeKind::Model, 0);
        manifest.data_requirements = Some(json!({
            "schema_version": 1,
            "ports": [{
                "name": "x",
                "accepted_representations": [],
                "accepted_types": ["f64"]
            }]
        }));

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("data_requirements"));
        assert!(error.contains("accepted_representations"));
    }

    #[test]
    fn manifest_rejects_empty_operator_selector() {
        let mut manifest = manifest("controller:empty-selector", NodeKind::Transform, 0);
        manifest
            .operator_selectors
            .push(OperatorSelector::default());

        let error = manifest.validate().unwrap_err().to_string();

        assert!(error.contains("empty operator selector"));
    }

    #[test]
    fn manifest_reports_parallel_invocation_support() {
        let mut manifest = manifest("controller:parallel", NodeKind::Model, 0);
        assert!(!manifest.supports_parallel_invocation());
        manifest
            .capabilities
            .insert(ControllerCapability::ProcessSafe);
        assert!(manifest.supports_parallel_invocation());
    }

    #[test]
    fn fit_influence_policies_follow_capabilities() {
        let mut capabilities = BTreeSet::new();
        assert!(capabilities_support_fit_influence(
            &capabilities,
            FitInfluencePolicy::Auto
        ));
        assert!(!capabilities_support_fit_influence(
            &capabilities,
            FitInfluencePolicy::EqualSampleInfluence
        ));
        assert!(!capabilities_support_fit_influence(
            &capabilities,
            FitInfluencePolicy::StrictWeightSupport
        ));

        capabilities.insert(ControllerCapability::SupportsRowResampling);

        assert!(capabilities_support_fit_influence(
            &capabilities,
            FitInfluencePolicy::ResampleEqualized
        ));
        assert!(capabilities_support_fit_influence(
            &capabilities,
            FitInfluencePolicy::StrictWeightSupport
        ));
        assert!(!capabilities_support_fit_influence(
            &capabilities,
            FitInfluencePolicy::BackendLossWeight
        ));
    }

    #[test]
    fn published_controller_manifest_schema_declares_current_contract() {
        let schema: serde_json::Value = serde_json::from_str(include_str!(
            "../../../docs/contracts/controller_manifest.schema.json"
        ))
        .unwrap();

        assert_eq!(schema["$id"], CONTROLLER_MANIFEST_SCHEMA_ID);
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .iter()
            .any(|field| field.as_str() == Some("controller_id")));
        assert!(schema["$defs"]["controller_capability"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .any(|capability| capability.as_str() == Some("emits_predictions")));
        assert!(schema["$defs"]["controller_capability"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .any(|capability| capability.as_str() == Some("aggregates_predictions")));
        assert!(schema["properties"]
            .as_object()
            .unwrap()
            .contains_key("operator_selectors"));
        assert_eq!(
            schema["$defs"]["model_input_spec"]["properties"]["schema_version"]["const"].as_u64(),
            Some(crate::data::MODEL_INPUT_SPEC_SCHEMA_VERSION as u64)
        );
    }
}