Skip to main content

dag_ml_core/
runtime.rs

1use std::cell::RefCell;
2use std::collections::{BTreeMap, BTreeSet};
3use std::fs;
4use std::io::Read;
5use std::path::{Path, PathBuf};
6
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10use crate::aggregation::{
11    aggregate_observation_predictions, aggregate_sample_predictions_by_unit,
12    AggregatedPredictionBlock, AggregationControllerInput, AggregationControllerOutput,
13    AggregationControllerResult, AggregationControllerTask, ObservationPredictionBlock,
14    PredictionUnitId,
15};
16use crate::bundle::{
17    build_aggregated_prediction_cache_payload, build_prediction_cache_payload,
18    bundle_prediction_requirement_key, validate_prediction_cache_payload_matches_record,
19    BundlePredictionCachePayload, BundlePredictionCachePayloadSet, BundlePredictionCacheRecord,
20    BundlePredictionRequirement, ExecutionBundle, RefitArtifactRecord, ReplayPhaseRequest,
21};
22use crate::campaign::stable_json_fingerprint;
23use crate::controller::{capabilities_support_fit_influence, ControllerCapability};
24use crate::data::{
25    DataBinding, DataRequestPartition, ExternalDataPlanEnvelope, RepresentationCompatibilityReport,
26    RepresentationPlan, RepresentationReplayManifest,
27};
28use crate::error::{DagMlError, Result};
29use crate::fold::{FoldAssignment, FoldSet};
30use crate::generation::{GenerationChoice, VariantPlan};
31use crate::graph::{EdgeSpec, PortKind};
32use crate::ids::{
33    ArtifactId, BranchId, BundleId, ControllerId, FoldId, LineageId, NodeId, RunId, SampleId,
34    VariantId,
35};
36use crate::oof::{PredictionBlock, PredictionPartition};
37use crate::phase::Phase;
38use crate::plan::{CampaignSpec, ExecutionPlan, NodePlan};
39use crate::policy::{
40    AggregationPolicy, FitInfluencePolicy, PredictionLevel, ShapeDelta, ShapeDeltaKind,
41};
42use crate::relation::SampleRelationSet;
43use crate::rng::SeedContext;
44
/// Kind of runtime object that a [`HandleRef`] points to.
///
/// Serialized in `snake_case` (for example `DataView` becomes `data_view`).
/// The derived `Ord` follows declaration order. Artifact handle records only
/// accept `Model` or `Artifact` (see `ArtifactHandleRecord::validate`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HandleKind {
    Data,
    DataView,
    Model,
    Artifact,
    Prediction,
    Relation,
}
55
/// Opaque reference to a runtime object, tagged with its kind and the
/// controller that owns it.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct HandleRef {
    /// Opaque numeric handle value.
    pub handle: u64,
    /// Kind of object the handle refers to.
    pub kind: HandleKind,
    /// Controller that owns the referenced object.
    pub owner_controller: ControllerId,
}
62
/// Backend an artifact payload is stored with, serialized in `snake_case`.
///
/// `ArtifactRef::validate` requires a backend whenever the artifact has a `uri`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ArtifactBackend {
    Joblib,
    Torch,
    Tensorflow,
    Onnx,
    Safetensors,
    Json,
    Raw,
}
74
/// Metadata describing an artifact produced by a controller.
///
/// Consistency rules are enforced by `ArtifactRef::validate`; persisted
/// manifests additionally require `ArtifactRef::validate_portable`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ArtifactRef {
    /// Artifact identifier.
    pub id: ArtifactId,
    /// Free-form artifact kind; must not be blank.
    pub kind: String,
    /// Controller that produced the artifact.
    pub controller_id: ControllerId,
    /// Payload backend; required when `uri` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub backend: Option<ArtifactBackend>,
    /// Payload location; portable artifacts need a safe relative URI.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
    /// Fingerprint of the payload content; required when `uri` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content_fingerprint: Option<String>,
    /// Payload size in bytes, when known.
    // NOTE(review): unlike the other optional fields, `size_bytes` has no
    // `skip_serializing_if`, so `None` is serialized as `null`. Changing that
    // would alter serialized output (and any fingerprints over it).
    pub size_bytes: Option<u64>,
    /// Plugin associated with the artifact, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub plugin: Option<String>,
    /// Plugin version; only allowed together with `plugin`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub plugin_version: Option<String>,
}
92
93impl ArtifactRef {
94    pub fn validate(&self) -> Result<()> {
95        if self.kind.trim().is_empty() {
96            return Err(DagMlError::RuntimeValidation(format!(
97                "artifact `{}` has empty kind",
98                self.id
99            )));
100        }
101        validate_artifact_optional_text("uri", &self.uri, &self.id)?;
102        validate_artifact_optional_text("plugin", &self.plugin, &self.id)?;
103        validate_artifact_optional_text("plugin_version", &self.plugin_version, &self.id)?;
104        if self.plugin_version.is_some() && self.plugin.is_none() {
105            return Err(DagMlError::RuntimeValidation(format!(
106                "artifact `{}` has plugin_version without plugin",
107                self.id
108            )));
109        }
110        if let Some(content_fingerprint) = &self.content_fingerprint {
111            validate_runtime_fingerprint("artifact content", content_fingerprint)?;
112        }
113        if self.uri.is_some() && self.backend.is_none() {
114            return Err(DagMlError::RuntimeValidation(format!(
115                "artifact `{}` has uri without backend",
116                self.id
117            )));
118        }
119        if self.uri.is_some() && self.content_fingerprint.is_none() {
120            return Err(DagMlError::RuntimeValidation(format!(
121                "artifact `{}` has uri without content_fingerprint",
122                self.id
123            )));
124        }
125        Ok(())
126    }
127
128    /// Validate that the artifact carries portable metadata: a backend, a safe
129    /// relative URI and a content fingerprint. Legacy artifacts that only carry
130    /// inline metadata stay readable through [`ArtifactRef::validate`] but are
131    /// refused here so persisted manifests can be moved with their payloads.
132    pub fn validate_portable(&self) -> Result<()> {
133        self.validate()?;
134        let Some(uri) = self.uri.as_deref() else {
135            return Err(DagMlError::RuntimeValidation(format!(
136                "artifact `{}` is not portable: requires backend, uri and content_fingerprint",
137                self.id
138            )));
139        };
140        // `validate` already guarantees that a present URI implies a backend and
141        // a 64-hex content fingerprint, so confirming the URI is enough here.
142        validate_relative_artifact_uri(&self.id, uri)
143    }
144}
145
146pub fn refit_artifact_input_key(artifact_id: &ArtifactId) -> String {
147    format!("artifact:{artifact_id}")
148}
149
/// Request passed to [`RuntimeArtifactStore::materialize`] to obtain a handle
/// for a stored artifact.
///
/// The stores in this module check `node_id`, `controller_id`, `artifact` and
/// `params_fingerprint` against their stored records.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ArtifactMaterializationRequest {
    pub run_id: RunId,
    pub bundle_id: BundleId,
    pub node_id: NodeId,
    pub phase: Phase,
    pub variant_id: Option<VariantId>,
    pub controller_id: ControllerId,
    pub artifact: ArtifactRef,
    pub params_fingerprint: String,
}
161
/// An artifact together with the handle it was registered under, as stored by
/// [`InMemoryArtifactStore`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ArtifactHandleRecord {
    /// Handle returned when the artifact is materialized.
    pub handle: HandleRef,
    /// Node that produced the artifact.
    pub node_id: NodeId,
    /// Controller that produced the artifact; must own `handle`.
    pub controller_id: ControllerId,
    /// Artifact metadata.
    pub artifact: ArtifactRef,
    /// Params fingerprint of the producing node; must not be blank.
    pub params_fingerprint: String,
}
170
171impl ArtifactHandleRecord {
172    pub fn validate(&self) -> Result<()> {
173        self.artifact.validate()?;
174        if !matches!(self.handle.kind, HandleKind::Model | HandleKind::Artifact) {
175            return Err(DagMlError::RuntimeValidation(format!(
176                "artifact `{}` is registered with non-artifact/model handle kind {:?}",
177                self.artifact.id, self.handle.kind
178            )));
179        }
180        if self.handle.owner_controller != self.controller_id {
181            return Err(DagMlError::RuntimeValidation(format!(
182                "artifact `{}` handle owner `{}` does not match controller `{}`",
183                self.artifact.id, self.handle.owner_controller, self.controller_id
184            )));
185        }
186        if self.artifact.controller_id != self.controller_id {
187            return Err(DagMlError::RuntimeValidation(format!(
188                "artifact `{}` controller `{}` does not match record controller `{}`",
189                self.artifact.id, self.artifact.controller_id, self.controller_id
190            )));
191        }
192        if self.params_fingerprint.trim().is_empty() {
193            return Err(DagMlError::RuntimeValidation(format!(
194                "artifact `{}` has empty params fingerprint",
195                self.artifact.id
196            )));
197        }
198        Ok(())
199    }
200}
201
/// Source of runtime handles for previously captured artifacts.
pub trait RuntimeArtifactStore {
    /// Resolve `request` to a handle, or return an error if the store cannot
    /// serve an artifact matching the request.
    fn materialize(&self, request: &ArtifactMaterializationRequest) -> Result<HandleRef>;
}
205
/// In-memory [`RuntimeArtifactStore`] populated from refit-phase node results.
///
/// Both maps are keyed by artifact id; `register` inserts into both.
#[derive(Clone, Debug, Default)]
pub struct InMemoryArtifactStore {
    /// Handle records served by `materialize`.
    records: BTreeMap<ArtifactId, ArtifactHandleRecord>,
    /// Original refit records, returned by `refit_artifacts`.
    refit_artifacts: BTreeMap<ArtifactId, RefitArtifactRecord>,
}
211
212impl InMemoryArtifactStore {
213    pub fn new() -> Self {
214        Self::default()
215    }
216
217    pub fn register(&mut self, artifact: &RefitArtifactRecord, handle: HandleRef) -> Result<()> {
218        artifact.validate()?;
219        let record = ArtifactHandleRecord {
220            handle,
221            node_id: artifact.node_id.clone(),
222            controller_id: artifact.controller_id.clone(),
223            artifact: artifact.artifact.clone(),
224            params_fingerprint: artifact.params_fingerprint.clone(),
225        };
226        record.validate()?;
227        if self.records.contains_key(&record.artifact.id)
228            || self.refit_artifacts.contains_key(&record.artifact.id)
229        {
230            return Err(DagMlError::RuntimeValidation(format!(
231                "duplicate artifact handle for `{}`",
232                artifact.artifact.id
233            )));
234        }
235        let previous_record = self.records.insert(record.artifact.id.clone(), record);
236        debug_assert!(previous_record.is_none());
237        let previous_artifact = self
238            .refit_artifacts
239            .insert(artifact.artifact.id.clone(), artifact.clone());
240        debug_assert!(previous_artifact.is_none());
241        Ok(())
242    }
243
244    pub fn capture_refit_artifacts(
245        &mut self,
246        task: &NodeTask,
247        result: &NodeResult,
248    ) -> Result<Vec<RefitArtifactRecord>> {
249        if task.phase != Phase::Refit {
250            return Err(DagMlError::RuntimeValidation(format!(
251                "cannot capture refit artifacts from phase {:?}",
252                task.phase
253            )));
254        }
255        let mut records = Vec::new();
256        for artifact in &result.artifacts {
257            let handle = result.artifact_handles.get(&artifact.id).ok_or_else(|| {
258                DagMlError::RuntimeValidation(format!(
259                    "node `{}` emitted artifact `{}` without artifact handle",
260                    task.node_plan.node_id, artifact.id
261                ))
262            })?;
263            let record = RefitArtifactRecord {
264                node_id: task.node_plan.node_id.clone(),
265                controller_id: task.node_plan.controller_id.clone(),
266                artifact: artifact.clone(),
267                params_fingerprint: task.node_plan.params_fingerprint.clone(),
268                data_requirement_keys: task
269                    .node_plan
270                    .data_bindings
271                    .iter()
272                    .map(|binding| format!("{}.{}", binding.node_id, binding.input_name))
273                    .collect(),
274                prediction_requirement_keys: task
275                    .prediction_inputs
276                    .values()
277                    .map(|spec| {
278                        bundle_prediction_requirement_key(
279                            &spec.producer_node,
280                            &spec.source_port,
281                            &task.node_plan.node_id,
282                            &spec.target_port,
283                        )
284                    })
285                    .collect(),
286            };
287            self.register(&record, handle.clone())?;
288            records.push(record);
289        }
290        Ok(records)
291    }
292
293    pub fn get(&self, artifact_id: &ArtifactId) -> Option<&ArtifactHandleRecord> {
294        self.records.get(artifact_id)
295    }
296
297    pub fn len(&self) -> usize {
298        self.records.len()
299    }
300
301    pub fn is_empty(&self) -> bool {
302        self.records.is_empty()
303    }
304
305    pub fn refit_artifacts(&self) -> Vec<RefitArtifactRecord> {
306        self.refit_artifacts.values().cloned().collect()
307    }
308}
309
310impl RuntimeArtifactStore for InMemoryArtifactStore {
311    fn materialize(&self, request: &ArtifactMaterializationRequest) -> Result<HandleRef> {
312        let record = self.records.get(&request.artifact.id).ok_or_else(|| {
313            DagMlError::RuntimeValidation(format!(
314                "artifact store is missing refit artifact `{}` for bundle `{}`",
315                request.artifact.id, request.bundle_id
316            ))
317        })?;
318        if record.node_id != request.node_id {
319            return Err(DagMlError::RuntimeValidation(format!(
320                "artifact `{}` is registered for node `{}` but requested for `{}`",
321                request.artifact.id, record.node_id, request.node_id
322            )));
323        }
324        if record.controller_id != request.controller_id {
325            return Err(DagMlError::RuntimeValidation(format!(
326                "artifact `{}` is registered for controller `{}` but requested for `{}`",
327                request.artifact.id, record.controller_id, request.controller_id
328            )));
329        }
330        if record.artifact != request.artifact {
331            return Err(DagMlError::RuntimeValidation(format!(
332                "artifact `{}` metadata does not match bundle record",
333                request.artifact.id
334            )));
335        }
336        if record.params_fingerprint != request.params_fingerprint {
337            return Err(DagMlError::RuntimeValidation(format!(
338                "artifact `{}` params fingerprint does not match bundle record",
339                request.artifact.id
340            )));
341        }
342        record.validate()?;
343        Ok(record.handle.clone())
344    }
345}
346
/// Schema version written into, and required from, every [`FileArtifactManifest`].
pub const FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION: u32 = 1;
/// File name of the manifest inside a manifest or payload store root directory.
pub const FILE_ARTIFACT_MANIFEST_FILE: &str = "artifact_manifest.json";

/// Serde default for `FileArtifactManifest::schema_version`: a manifest that
/// omits the field is read as the current schema version.
const fn default_file_artifact_manifest_schema_version() -> u32 {
    FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION
}
353
/// One persisted artifact entry. Mirrors the bundle [`RefitArtifactRecord`]
/// identity (node, controller, artifact and params fingerprint) while requiring
/// the [`ArtifactRef`] to be portable so the manifest stays movable with its
/// payloads.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct FileArtifactManifestEntry {
    /// Node that produced the artifact.
    pub node_id: NodeId,
    /// Controller that produced the artifact; must equal `artifact.controller_id`.
    pub controller_id: ControllerId,
    /// Portable artifact metadata.
    pub artifact: ArtifactRef,
    /// Params fingerprint of the producing node; must be a valid runtime fingerprint.
    pub params_fingerprint: String,
}
365
366impl FileArtifactManifestEntry {
367    fn from_refit_record(record: &RefitArtifactRecord) -> Result<Self> {
368        let entry = Self {
369            node_id: record.node_id.clone(),
370            controller_id: record.controller_id.clone(),
371            artifact: record.artifact.clone(),
372            params_fingerprint: record.params_fingerprint.clone(),
373        };
374        entry.validate()?;
375        Ok(entry)
376    }
377
378    pub fn validate(&self) -> Result<()> {
379        self.artifact.validate_portable()?;
380        if self.artifact.controller_id != self.controller_id {
381            return Err(DagMlError::RuntimeValidation(format!(
382                "artifact manifest entry `{}` controller `{}` does not match artifact controller `{}`",
383                self.artifact.id, self.controller_id, self.artifact.controller_id
384            )));
385        }
386        validate_runtime_fingerprint("artifact manifest params", &self.params_fingerprint)
387    }
388
389    fn matches_refit_record(&self, record: &RefitArtifactRecord) -> bool {
390        self.node_id == record.node_id
391            && self.controller_id == record.controller_id
392            && self.artifact == record.artifact
393            && self.params_fingerprint == record.params_fingerprint
394    }
395}
396
/// Versioned, file-backed artifact manifest. This is a manifest/portability
/// layer only: it records portable [`ArtifactRef`] metadata for a bundle's
/// refit artifacts. It does not deserialize ML objects or read artifact
/// payloads itself; copying and checking payload files is done by
/// [`FileArtifactPayloadStore`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct FileArtifactManifest {
    /// Bundle whose refit artifacts are listed.
    pub bundle_id: BundleId,
    /// Manifest schema version; defaults to the current version when absent.
    #[serde(default = "default_file_artifact_manifest_schema_version")]
    pub schema_version: u32,
    /// Entries for the bundle's refit artifacts; defaults to empty when absent.
    #[serde(default)]
    pub artifacts: Vec<FileArtifactManifestEntry>,
}
409
410impl FileArtifactManifest {
411    pub fn validate(&self) -> Result<()> {
412        if self.schema_version != FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION {
413            return Err(DagMlError::RuntimeValidation(format!(
414                "file artifact manifest for bundle `{}` uses unsupported schema_version {}, expected {}",
415                self.bundle_id, self.schema_version, FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION
416            )));
417        }
418        let mut artifact_ids = BTreeSet::new();
419        let mut uris = BTreeSet::new();
420        for entry in &self.artifacts {
421            entry.validate()?;
422            if !artifact_ids.insert(entry.artifact.id.as_str()) {
423                return Err(DagMlError::RuntimeValidation(format!(
424                    "file artifact manifest for bundle `{}` has duplicate artifact id `{}`",
425                    self.bundle_id, entry.artifact.id
426                )));
427            }
428            // `entry.validate` guarantees a portable URI is present.
429            if let Some(uri) = entry.artifact.uri.as_deref() {
430                if !uris.insert(uri) {
431                    return Err(DagMlError::RuntimeValidation(format!(
432                        "file artifact manifest for bundle `{}` has duplicate artifact uri `{}`",
433                        self.bundle_id, uri
434                    )));
435                }
436            }
437        }
438        Ok(())
439    }
440
441    pub fn validate_against_bundle(&self, bundle: &ExecutionBundle) -> Result<()> {
442        self.validate()?;
443        bundle.validate()?;
444        if self.bundle_id != bundle.bundle_id {
445            return Err(DagMlError::RuntimeValidation(format!(
446                "file artifact manifest bundle `{}` does not match bundle `{}`",
447                self.bundle_id, bundle.bundle_id
448            )));
449        }
450        if self.artifacts.len() != bundle.refit_artifacts.len() {
451            return Err(DagMlError::RuntimeValidation(format!(
452                "file artifact manifest for bundle `{}` has {} artifact(s) for {} bundle refit artifact(s)",
453                self.bundle_id,
454                self.artifacts.len(),
455                bundle.refit_artifacts.len()
456            )));
457        }
458        let entries_by_id = self
459            .artifacts
460            .iter()
461            .map(|entry| (entry.artifact.id.as_str(), entry))
462            .collect::<BTreeMap<_, _>>();
463        for record in &bundle.refit_artifacts {
464            let entry = entries_by_id
465                .get(record.artifact.id.as_str())
466                .ok_or_else(|| {
467                    DagMlError::RuntimeValidation(format!(
468                        "file artifact manifest for bundle `{}` is missing refit artifact `{}`",
469                        self.bundle_id, record.artifact.id
470                    ))
471                })?;
472            if !entry.matches_refit_record(record) {
473                return Err(DagMlError::RuntimeValidation(format!(
474                    "file artifact manifest entry `{}` does not match bundle refit artifact",
475                    entry.artifact.id
476                )));
477            }
478        }
479        Ok(())
480    }
481}
482
/// File-backed artifact manifest store rooted at a directory.
///
/// This is a portability/manifest layer: [`FileArtifactManifestStore::write`]
/// serializes portable artifact references from a validated bundle and
/// [`FileArtifactManifestStore::open`] reloads and revalidates them against the
/// bundle. It never reads, writes or deserializes artifact payloads.
#[derive(Clone, Debug)]
pub struct FileArtifactManifestStore {
    /// Directory containing the manifest file.
    root: PathBuf,
    /// Manifest loaded by `open`, already validated against its bundle.
    manifest: FileArtifactManifest,
}
494
495impl FileArtifactManifestStore {
496    pub fn write(root: impl AsRef<Path>, bundle: &ExecutionBundle) -> Result<FileArtifactManifest> {
497        bundle.validate()?;
498        let root = root.as_ref();
499        fs::create_dir_all(root).map_err(|err| {
500            DagMlError::RuntimeValidation(format!(
501                "failed to create artifact manifest store `{}`: {err}",
502                root.display()
503            ))
504        })?;
505        let mut entries = Vec::with_capacity(bundle.refit_artifacts.len());
506        for record in &bundle.refit_artifacts {
507            entries.push(FileArtifactManifestEntry::from_refit_record(record)?);
508        }
509        entries.sort_by(|left, right| left.artifact.id.cmp(&right.artifact.id));
510        let manifest = FileArtifactManifest {
511            bundle_id: bundle.bundle_id.clone(),
512            schema_version: FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION,
513            artifacts: entries,
514        };
515        manifest.validate_against_bundle(bundle)?;
516        write_runtime_json(
517            &root.join(FILE_ARTIFACT_MANIFEST_FILE),
518            &manifest,
519            "artifact manifest",
520        )?;
521        Ok(manifest)
522    }
523
524    pub fn open(root: impl Into<PathBuf>, bundle: &ExecutionBundle) -> Result<Self> {
525        bundle.validate()?;
526        let root = root.into();
527        let manifest: FileArtifactManifest =
528            read_runtime_json(&root.join(FILE_ARTIFACT_MANIFEST_FILE), "artifact manifest")?;
529        manifest.validate_against_bundle(bundle)?;
530        Ok(Self { root, manifest })
531    }
532
533    pub fn root(&self) -> &Path {
534        &self.root
535    }
536
537    pub fn manifest(&self) -> &FileArtifactManifest {
538        &self.manifest
539    }
540}
541
/// Audit entry appended by `FileArtifactPayloadStore::materialize` after each
/// successful materialization.
///
/// Records the request identity, the payload metadata returned by the payload
/// file validation, and the handle that was handed out.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ArtifactPayloadMaterializationRecord {
    pub run_id: RunId,
    pub bundle_id: BundleId,
    pub node_id: NodeId,
    pub phase: Phase,
    pub variant_id: Option<VariantId>,
    pub artifact_id: ArtifactId,
    pub payload_uri: String,
    pub content_fingerprint: String,
    pub size_bytes: u64,
    pub handle: HandleRef,
}
555
/// Payload metadata returned by `validate_artifact_payload_file`.
// NOTE(review): presumably describes the payload file as found under the
// checked root (uri, content fingerprint, size) — confirm in that helper.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ArtifactPayloadMetadata {
    uri: String,
    content_fingerprint: String,
    size_bytes: u64,
}
562
/// File-backed [`RuntimeArtifactStore`] that serves artifact payloads stored
/// under `root`, next to an artifact manifest.
///
/// Materializations are logged through a `RefCell`, so `materialize` can record
/// them through `&self`; as a consequence the store is not `Sync`.
#[derive(Clone, Debug)]
pub struct FileArtifactPayloadStore {
    /// Directory holding the manifest and the payload files.
    root: PathBuf,
    /// Manifest validated against the bundle when the store was opened.
    manifest: FileArtifactManifest,
    /// Bundle refit records keyed by artifact id.
    records_by_artifact_id: BTreeMap<ArtifactId, RefitArtifactRecord>,
    /// Log of successful materializations, in call order.
    materialization_records: RefCell<Vec<ArtifactPayloadMaterializationRecord>>,
}
570
571impl FileArtifactPayloadStore {
572    pub fn write_from_source(
573        output_root: impl AsRef<Path>,
574        source_root: impl AsRef<Path>,
575        bundle: &ExecutionBundle,
576    ) -> Result<Self> {
577        bundle.validate()?;
578        let output_root = output_root.as_ref();
579        let source_root = source_root.as_ref();
580        fs::create_dir_all(output_root).map_err(|err| {
581            DagMlError::RuntimeValidation(format!(
582                "failed to create artifact payload store `{}`: {err}",
583                output_root.display()
584            ))
585        })?;
586        for record in &bundle.refit_artifacts {
587            record.artifact.validate_portable()?;
588            validate_artifact_payload_file(source_root, &record.artifact)?;
589            let source_path = artifact_payload_path(source_root, &record.artifact)?;
590            let output_path = artifact_payload_path(output_root, &record.artifact)?;
591            if let Some(parent) = output_path.parent() {
592                fs::create_dir_all(parent).map_err(|err| {
593                    DagMlError::RuntimeValidation(format!(
594                        "failed to create artifact payload directory `{}`: {err}",
595                        parent.display()
596                    ))
597                })?;
598            }
599            if source_path != output_path {
600                fs::copy(&source_path, &output_path).map_err(|err| {
601                    DagMlError::RuntimeValidation(format!(
602                        "failed to copy artifact payload `{}` from {} to {}: {err}",
603                        record.artifact.id,
604                        source_path.display(),
605                        output_path.display()
606                    ))
607                })?;
608            }
609        }
610        FileArtifactManifestStore::write(output_root, bundle)?;
611        Self::open(output_root.to_path_buf(), bundle)
612    }
613
614    pub fn open(root: impl Into<PathBuf>, bundle: &ExecutionBundle) -> Result<Self> {
615        bundle.validate()?;
616        let root = root.into();
617        let manifest_store = FileArtifactManifestStore::open(root.clone(), bundle)?;
618        let records_by_artifact_id = bundle
619            .refit_artifacts
620            .iter()
621            .cloned()
622            .map(|record| (record.artifact.id.clone(), record))
623            .collect::<BTreeMap<_, _>>();
624        let store = Self {
625            root,
626            manifest: manifest_store.manifest().clone(),
627            records_by_artifact_id,
628            materialization_records: RefCell::new(Vec::new()),
629        };
630        store.validate_payloads()?;
631        Ok(store)
632    }
633
634    pub fn root(&self) -> &Path {
635        &self.root
636    }
637
638    pub fn manifest(&self) -> &FileArtifactManifest {
639        &self.manifest
640    }
641
642    pub fn payload_count(&self) -> usize {
643        self.manifest.artifacts.len()
644    }
645
646    pub fn materialization_records(&self) -> Vec<ArtifactPayloadMaterializationRecord> {
647        self.materialization_records.borrow().clone()
648    }
649
650    pub fn validate_payloads(&self) -> Result<()> {
651        self.manifest.validate()?;
652        for entry in &self.manifest.artifacts {
653            let record = self
654                .records_by_artifact_id
655                .get(&entry.artifact.id)
656                .ok_or_else(|| {
657                    DagMlError::RuntimeValidation(format!(
658                        "artifact payload store for bundle `{}` has no bundle record for `{}`",
659                        self.manifest.bundle_id, entry.artifact.id
660                    ))
661                })?;
662            if !entry.matches_refit_record(record) {
663                return Err(DagMlError::RuntimeValidation(format!(
664                    "artifact payload store entry `{}` does not match bundle refit artifact",
665                    entry.artifact.id
666                )));
667            }
668            validate_artifact_payload_file(&self.root, &entry.artifact)?;
669        }
670        Ok(())
671    }
672}
673
674impl RuntimeArtifactStore for FileArtifactPayloadStore {
675    fn materialize(&self, request: &ArtifactMaterializationRequest) -> Result<HandleRef> {
676        request.artifact.validate_portable()?;
677        let record = self
678            .records_by_artifact_id
679            .get(&request.artifact.id)
680            .ok_or_else(|| {
681                DagMlError::RuntimeValidation(format!(
682                    "artifact payload store is missing refit artifact `{}` for bundle `{}`",
683                    request.artifact.id, request.bundle_id
684                ))
685            })?;
686        if record.node_id != request.node_id {
687            return Err(DagMlError::RuntimeValidation(format!(
688                "artifact `{}` is registered for node `{}` but requested for `{}`",
689                request.artifact.id, record.node_id, request.node_id
690            )));
691        }
692        if record.controller_id != request.controller_id {
693            return Err(DagMlError::RuntimeValidation(format!(
694                "artifact `{}` is registered for controller `{}` but requested for `{}`",
695                request.artifact.id, record.controller_id, request.controller_id
696            )));
697        }
698        if record.artifact != request.artifact {
699            return Err(DagMlError::RuntimeValidation(format!(
700                "artifact `{}` metadata does not match bundle record",
701                request.artifact.id
702            )));
703        }
704        if record.params_fingerprint != request.params_fingerprint {
705            return Err(DagMlError::RuntimeValidation(format!(
706                "artifact `{}` params fingerprint does not match bundle record",
707                request.artifact.id
708            )));
709        }
710        let metadata = validate_artifact_payload_file(&self.root, &request.artifact)?;
711        let fingerprint = stable_json_fingerprint(&(
712            &request.run_id,
713            &request.bundle_id,
714            &request.node_id,
715            request.phase,
716            &request.variant_id,
717            &request.artifact.id,
718            &metadata.content_fingerprint,
719            &request.params_fingerprint,
720        ))?;
721        let handle = HandleRef {
722            handle: u64::from_str_radix(&fingerprint[..16], 16)
723                .expect("sha256 hex prefix should fit into u64"),
724            kind: HandleKind::Artifact,
725            owner_controller: request.controller_id.clone(),
726        };
727        self.materialization_records
728            .borrow_mut()
729            .push(ArtifactPayloadMaterializationRecord {
730                run_id: request.run_id.clone(),
731                bundle_id: request.bundle_id.clone(),
732                node_id: request.node_id.clone(),
733                phase: request.phase,
734                variant_id: request.variant_id.clone(),
735                artifact_id: request.artifact.id.clone(),
736                payload_uri: metadata.uri,
737                content_fingerprint: metadata.content_fingerprint,
738                size_bytes: metadata.size_bytes,
739                handle: handle.clone(),
740            });
741        Ok(handle)
742    }
743}
744
/// Lineage record for a node in a run, stored by [`InMemoryLineageRecorder`]
/// under `record_id`.
///
/// Fields marked `#[serde(default)]` deserialize as empty when absent.
/// `LineageRecord::validate` only checks that `params_fingerprint` is not blank
/// and that every entry of `artifact_refs` passes `ArtifactRef::validate`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LineageRecord {
    pub record_id: LineageId,
    pub run_id: RunId,
    pub node_id: NodeId,
    pub phase: Phase,
    pub controller_id: ControllerId,
    pub controller_version: String,
    pub variant_id: Option<VariantId>,
    pub fold_id: Option<FoldId>,
    #[serde(default)]
    pub branch_path: Vec<BranchId>,
    // NOTE(review): presumably the ids of upstream lineage records — confirm
    // with the producers of these records.
    #[serde(default)]
    pub input_lineage: Vec<LineageId>,
    #[serde(default)]
    pub artifact_refs: Vec<ArtifactRef>,
    pub params_fingerprint: String,
    pub data_model_shape_fingerprint: Option<String>,
    pub aggregation_policy_fingerprint: Option<String>,
    pub seed: Option<u64>,
    #[serde(default)]
    pub unsafe_flags: BTreeSet<String>,
    #[serde(default)]
    pub metrics: BTreeMap<String, f64>,
}
770
771impl LineageRecord {
772    pub fn validate(&self) -> Result<()> {
773        if self.params_fingerprint.trim().is_empty() {
774            return Err(DagMlError::RuntimeValidation(format!(
775                "lineage `{}` has empty params fingerprint",
776                self.record_id
777            )));
778        }
779        for artifact in &self.artifact_refs {
780            artifact.validate()?;
781        }
782        Ok(())
783    }
784}
785
786#[derive(Clone, Debug, Default)]
787pub struct InMemoryLineageRecorder {
788    records: BTreeMap<LineageId, LineageRecord>,
789}
790
791impl InMemoryLineageRecorder {
792    pub fn new() -> Self {
793        Self::default()
794    }
795
796    pub fn record(&mut self, record: LineageRecord) -> Result<()> {
797        record.validate()?;
798        if self
799            .records
800            .insert(record.record_id.clone(), record)
801            .is_some()
802        {
803            return Err(DagMlError::RuntimeValidation(
804                "duplicate lineage record id".to_string(),
805            ));
806        }
807        Ok(())
808    }
809
810    pub fn get(&self, id: &LineageId) -> Option<&LineageRecord> {
811        self.records.get(id)
812    }
813
814    pub fn len(&self) -> usize {
815        self.records.len()
816    }
817
818    pub fn is_empty(&self) -> bool {
819        self.records.is_empty()
820    }
821
822    pub fn records(&self) -> impl Iterator<Item = &LineageRecord> {
823        self.records.values()
824    }
825}
826
827#[derive(Clone, Debug, Default)]
828pub struct InMemoryPredictionStore {
829    blocks: Vec<PredictionBlock>,
830}
831
832impl InMemoryPredictionStore {
833    pub fn new() -> Self {
834        Self::default()
835    }
836
837    pub fn append(&mut self, block: PredictionBlock) -> Result<()> {
838        block.validate_shape()?;
839        self.blocks.push(block);
840        Ok(())
841    }
842
843    pub fn blocks(&self) -> &[PredictionBlock] {
844        &self.blocks
845    }
846
847    pub fn find(
848        &self,
849        producer_node: Option<&NodeId>,
850        phase_partition: Option<&crate::oof::PredictionPartition>,
851        fold_id: Option<&FoldId>,
852    ) -> Vec<&PredictionBlock> {
853        self.blocks
854            .iter()
855            .filter(|block| {
856                producer_node.is_none_or(|node_id| &block.producer_node == node_id)
857                    && phase_partition.is_none_or(|partition| &block.partition == partition)
858                    && fold_id.is_none_or(|requested| block.fold_id.as_ref() == Some(requested))
859            })
860            .collect()
861    }
862}
863
864#[derive(Clone, Debug, Default)]
865pub struct InMemoryAggregatedPredictionStore {
866    blocks: Vec<AggregatedPredictionBlock>,
867}
868
869impl InMemoryAggregatedPredictionStore {
870    pub fn new() -> Self {
871        Self::default()
872    }
873
874    pub fn append(&mut self, block: AggregatedPredictionBlock) -> Result<()> {
875        block.validate_shape()?;
876        self.blocks.push(block);
877        Ok(())
878    }
879
880    pub fn blocks(&self) -> &[AggregatedPredictionBlock] {
881        &self.blocks
882    }
883
884    pub fn find(
885        &self,
886        producer_node: Option<&NodeId>,
887        phase_partition: Option<&PredictionPartition>,
888        fold_id: Option<&FoldId>,
889        prediction_level: Option<PredictionLevel>,
890    ) -> Vec<&AggregatedPredictionBlock> {
891        self.blocks
892            .iter()
893            .filter(|block| {
894                producer_node.is_none_or(|node_id| &block.producer_node == node_id)
895                    && phase_partition.is_none_or(|partition| &block.partition == partition)
896                    && fold_id.is_none_or(|requested| block.fold_id.as_ref() == Some(requested))
897                    && prediction_level.is_none_or(|level| block.level == level)
898            })
899            .collect()
900    }
901}
902
/// Request to materialize one bundle prediction cache for a run.
///
/// `FilePredictionCacheStore` requires `cache` to equal the bundle's cache record
/// for the key of `requirement` before it issues a handle.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PredictionCacheMaterializationRequest {
    /// Run on whose behalf the cache is materialized.
    pub run_id: RunId,
    /// Bundle the request targets.
    pub bundle_id: BundleId,
    /// Phase in which the cache is requested.
    pub phase: Phase,
    /// Variant the request is scoped to, if any.
    pub variant_id: Option<VariantId>,
    /// Requirement whose key selects the cache record.
    pub requirement: BundlePredictionRequirement,
    /// Cache record the caller expects to be served.
    pub cache: BundlePredictionCacheRecord,
    /// Controller recorded as the owner of the returned prediction handle.
    pub producer_controller_id: ControllerId,
}
913
/// Log entry for one successful prediction cache materialization.
///
/// `FilePredictionCacheStore` appends one of these per successful `materialize` call.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PredictionCacheMaterializationRecord {
    /// Run copied from the request.
    pub run_id: RunId,
    /// Bundle id copied from the request.
    pub bundle_id: BundleId,
    /// Phase copied from the request.
    pub phase: Phase,
    /// Variant copied from the request, if any.
    pub variant_id: Option<VariantId>,
    /// Key of the materialized requirement.
    pub requirement_key: String,
    /// Id of the materialized cache record.
    pub cache_id: String,
    /// Handle returned to the caller.
    pub handle: HandleRef,
}
924
925pub trait RuntimePredictionCacheStore {
926    fn load_blocks(&self, requirement_key: &str) -> Result<Vec<PredictionBlock>>;
927    fn load_aggregated_blocks(
928        &self,
929        requirement_key: &str,
930    ) -> Result<Vec<AggregatedPredictionBlock>> {
931        Err(DagMlError::RuntimeValidation(format!(
932            "prediction cache store does not support aggregated requirement `{requirement_key}`"
933        )))
934    }
935    fn materialize(&self, request: &PredictionCacheMaterializationRequest) -> Result<HandleRef>;
936}
937
/// Schema version written into, and required from, file prediction cache manifests.
pub const FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION: u32 = 1;
/// File name of the manifest inside a file prediction cache store directory.
pub const FILE_PREDICTION_CACHE_MANIFEST_FILE: &str = "prediction_cache_manifest.json";

/// Serde default for `FilePredictionCacheManifest::schema_version`: a manifest that
/// omits the field is read as the current schema version.
fn default_file_prediction_cache_store_schema_version() -> u32 {
    FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION
}
944
/// Manifest entry describing one prediction cache payload file in a
/// `FilePredictionCacheStore` directory.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FilePredictionCacheEntry {
    /// Bundle requirement key served by this payload; unique within a manifest.
    pub requirement_key: String,
    /// Cache id of the payload; unique within a manifest.
    pub cache_id: String,
    /// Plain file name (no path separators) relative to the store directory.
    pub file_name: String,
    /// Prediction level of the payload; filled by `default_runtime_prediction_level`
    /// (defined elsewhere) when absent from serialized input.
    #[serde(default = "default_runtime_prediction_level")]
    pub prediction_level: PredictionLevel,
    /// Concatenated unit ids of the payload's aggregated blocks; required for
    /// non-sample levels and omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub unit_ids: Vec<PredictionUnitId>,
    /// Number of blocks in the payload; must be non-zero.
    pub block_count: usize,
    /// Number of rows in the payload; must be non-zero.
    pub row_count: usize,
    /// Fingerprint of the payload content, checked by `validate_runtime_fingerprint`.
    pub content_fingerprint: String,
}
958
959impl FilePredictionCacheEntry {
960    pub fn validate(&self) -> Result<()> {
961        validate_runtime_non_empty("requirement_key", &self.requirement_key)?;
962        validate_runtime_non_empty("cache_id", &self.cache_id)?;
963        validate_runtime_non_empty("file_name", &self.file_name)?;
964        validate_prediction_cache_file_name(&self.file_name)?;
965        if self.block_count == 0 {
966            return Err(DagMlError::RuntimeValidation(format!(
967                "file prediction cache `{}` has zero block_count",
968                self.cache_id
969            )));
970        }
971        if self.row_count == 0 {
972            return Err(DagMlError::RuntimeValidation(format!(
973                "file prediction cache `{}` has zero row_count",
974                self.cache_id
975            )));
976        }
977        if self.prediction_level != PredictionLevel::Sample && self.unit_ids.is_empty() {
978            return Err(DagMlError::RuntimeValidation(format!(
979                "file prediction cache `{}` has no aggregated unit ids",
980                self.cache_id
981            )));
982        }
983        if self
984            .unit_ids
985            .iter()
986            .any(|unit_id| unit_id.level() != self.prediction_level)
987        {
988            return Err(DagMlError::RuntimeValidation(format!(
989                "file prediction cache `{}` has unit ids outside {:?}",
990                self.cache_id, self.prediction_level
991            )));
992        }
993        validate_runtime_fingerprint("prediction cache content", &self.content_fingerprint)
994    }
995
996    fn from_payload(payload: &crate::bundle::BundlePredictionCachePayload) -> Result<Self> {
997        Ok(Self {
998            requirement_key: payload.requirement_key.clone(),
999            cache_id: payload.cache_id.clone(),
1000            file_name: prediction_cache_payload_file_name(payload)?,
1001            prediction_level: payload.prediction_level,
1002            unit_ids: payload
1003                .aggregated_blocks
1004                .iter()
1005                .flat_map(|block| block.unit_ids.iter().cloned())
1006                .collect(),
1007            block_count: payload.block_count,
1008            row_count: payload.row_count,
1009            content_fingerprint: payload.content_fingerprint.clone(),
1010        })
1011    }
1012
1013    fn matches_record(&self, record: &BundlePredictionCacheRecord) -> bool {
1014        self.requirement_key == record.requirement_key
1015            && self.cache_id == record.cache_id
1016            && self.prediction_level == record.prediction_level
1017            && self.unit_ids == record.unit_ids
1018            && self.block_count == record.block_count
1019            && self.row_count == record.row_count
1020            && self.content_fingerprint == record.content_fingerprint
1021    }
1022}
1023
/// Manifest listing every prediction cache payload file written for one bundle.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FilePredictionCacheManifest {
    /// Bundle whose cache records this manifest covers.
    pub bundle_id: BundleId,
    /// Manifest schema version; defaults to the current version when absent.
    #[serde(default = "default_file_prediction_cache_store_schema_version")]
    pub schema_version: u32,
    /// One entry per cache payload; empty when absent from serialized input.
    #[serde(default)]
    pub caches: Vec<FilePredictionCacheEntry>,
}
1032
1033impl FilePredictionCacheManifest {
1034    pub fn validate(&self) -> Result<()> {
1035        if self.schema_version != FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION {
1036            return Err(DagMlError::RuntimeValidation(format!(
1037                "file prediction cache manifest for bundle `{}` uses unsupported schema_version {}, expected {}",
1038                self.bundle_id,
1039                self.schema_version,
1040                FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION
1041            )));
1042        }
1043        let mut requirement_keys = BTreeSet::new();
1044        let mut cache_ids = BTreeSet::new();
1045        let mut file_names = BTreeSet::new();
1046        for entry in &self.caches {
1047            entry.validate()?;
1048            if !requirement_keys.insert(entry.requirement_key.as_str()) {
1049                return Err(DagMlError::RuntimeValidation(format!(
1050                    "file prediction cache manifest for bundle `{}` has duplicate requirement `{}`",
1051                    self.bundle_id, entry.requirement_key
1052                )));
1053            }
1054            if !cache_ids.insert(entry.cache_id.as_str()) {
1055                return Err(DagMlError::RuntimeValidation(format!(
1056                    "file prediction cache manifest for bundle `{}` has duplicate cache id `{}`",
1057                    self.bundle_id, entry.cache_id
1058                )));
1059            }
1060            if !file_names.insert(entry.file_name.as_str()) {
1061                return Err(DagMlError::RuntimeValidation(format!(
1062                    "file prediction cache manifest for bundle `{}` has duplicate file `{}`",
1063                    self.bundle_id, entry.file_name
1064                )));
1065            }
1066        }
1067        Ok(())
1068    }
1069
1070    pub fn validate_against_bundle(&self, bundle: &ExecutionBundle) -> Result<()> {
1071        self.validate()?;
1072        bundle.validate()?;
1073        if self.bundle_id != bundle.bundle_id {
1074            return Err(DagMlError::RuntimeValidation(format!(
1075                "file prediction cache manifest bundle `{}` does not match bundle `{}`",
1076                self.bundle_id, bundle.bundle_id
1077            )));
1078        }
1079        if self.caches.len() != bundle.prediction_caches.len() {
1080            return Err(DagMlError::RuntimeValidation(format!(
1081                "file prediction cache manifest for bundle `{}` has {} cache(s) for {} bundle cache record(s)",
1082                self.bundle_id,
1083                self.caches.len(),
1084                bundle.prediction_caches.len()
1085            )));
1086        }
1087        let entries_by_requirement = self
1088            .caches
1089            .iter()
1090            .map(|entry| (entry.requirement_key.as_str(), entry))
1091            .collect::<BTreeMap<_, _>>();
1092        for record in &bundle.prediction_caches {
1093            let entry = entries_by_requirement
1094                .get(record.requirement_key.as_str())
1095                .ok_or_else(|| {
1096                    DagMlError::RuntimeValidation(format!(
1097                        "file prediction cache manifest for bundle `{}` is missing requirement `{}`",
1098                        self.bundle_id, record.requirement_key
1099                    ))
1100                })?;
1101            if !entry.matches_record(record) {
1102                return Err(DagMlError::RuntimeValidation(format!(
1103                    "file prediction cache manifest entry `{}` does not match bundle cache record",
1104                    entry.cache_id
1105                )));
1106            }
1107        }
1108        Ok(())
1109    }
1110}
1111
/// Prediction cache store backed by a directory holding a JSON manifest plus one
/// JSON payload file per cache, bound to a single execution bundle.
#[derive(Clone, Debug)]
pub struct FilePredictionCacheStore {
    /// Directory containing the manifest and payload files.
    root: PathBuf,
    /// Manifest read from `root` and validated against the bundle in `open`.
    manifest: FilePredictionCacheManifest,
    /// The bundle's prediction cache records keyed by requirement key.
    records_by_requirement: BTreeMap<String, BundlePredictionCacheRecord>,
    /// Log of successful materializations; a `RefCell` so `materialize(&self)` can append.
    materialization_records: RefCell<Vec<PredictionCacheMaterializationRecord>>,
}
1119
1120impl FilePredictionCacheStore {
1121    pub fn write_payload_set(
1122        root: impl AsRef<Path>,
1123        bundle: &ExecutionBundle,
1124        payloads: &BundlePredictionCachePayloadSet,
1125    ) -> Result<FilePredictionCacheManifest> {
1126        payloads.validate_against_bundle(bundle)?;
1127        let root = root.as_ref();
1128        fs::create_dir_all(root).map_err(|err| {
1129            DagMlError::RuntimeValidation(format!(
1130                "failed to create prediction cache store `{}`: {err}",
1131                root.display()
1132            ))
1133        })?;
1134
1135        let mut entries = Vec::new();
1136        let records_by_requirement = bundle
1137            .prediction_caches
1138            .iter()
1139            .map(|record| (record.requirement_key.as_str(), record))
1140            .collect::<BTreeMap<_, _>>();
1141        for payload in &payloads.caches {
1142            let record = records_by_requirement
1143                .get(payload.requirement_key.as_str())
1144                .ok_or_else(|| {
1145                    DagMlError::RuntimeValidation(format!(
1146                        "prediction cache payload `{}` references unknown requirement `{}`",
1147                        payload.cache_id, payload.requirement_key
1148                    ))
1149                })?;
1150            validate_prediction_cache_payload_matches_record(payload, record)?;
1151            let entry = FilePredictionCacheEntry::from_payload(payload)?;
1152            let payload_path = root.join(&entry.file_name);
1153            write_runtime_json(&payload_path, payload, "prediction cache payload")?;
1154            entries.push(entry);
1155        }
1156        entries.sort_by(|left, right| left.requirement_key.cmp(&right.requirement_key));
1157        let manifest = FilePredictionCacheManifest {
1158            bundle_id: bundle.bundle_id.clone(),
1159            schema_version: FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION,
1160            caches: entries,
1161        };
1162        manifest.validate_against_bundle(bundle)?;
1163        write_runtime_json(
1164            &root.join(FILE_PREDICTION_CACHE_MANIFEST_FILE),
1165            &manifest,
1166            "prediction cache manifest",
1167        )?;
1168        Ok(manifest)
1169    }
1170
1171    pub fn open(root: impl Into<PathBuf>, bundle: &ExecutionBundle) -> Result<Self> {
1172        bundle.validate()?;
1173        let root = root.into();
1174        let manifest: FilePredictionCacheManifest = read_runtime_json(
1175            &root.join(FILE_PREDICTION_CACHE_MANIFEST_FILE),
1176            "prediction cache manifest",
1177        )?;
1178        manifest.validate_against_bundle(bundle)?;
1179        let records_by_requirement = bundle
1180            .prediction_caches
1181            .iter()
1182            .cloned()
1183            .map(|record| (record.requirement_key.clone(), record))
1184            .collect::<BTreeMap<_, _>>();
1185        Ok(Self {
1186            root,
1187            manifest,
1188            records_by_requirement,
1189            materialization_records: RefCell::new(Vec::new()),
1190        })
1191    }
1192
1193    pub fn manifest(&self) -> &FilePredictionCacheManifest {
1194        &self.manifest
1195    }
1196
1197    pub fn materialization_records(&self) -> Vec<PredictionCacheMaterializationRecord> {
1198        self.materialization_records.borrow().clone()
1199    }
1200
1201    fn payload_for_requirement(
1202        &self,
1203        requirement_key: &str,
1204    ) -> Result<crate::bundle::BundlePredictionCachePayload> {
1205        let entry = self
1206            .manifest
1207            .caches
1208            .iter()
1209            .find(|entry| entry.requirement_key == requirement_key)
1210            .ok_or_else(|| {
1211                DagMlError::RuntimeValidation(format!(
1212                    "file prediction cache store is missing requirement `{requirement_key}`"
1213                ))
1214            })?;
1215        let record = self
1216            .records_by_requirement
1217            .get(requirement_key)
1218            .ok_or_else(|| {
1219                DagMlError::RuntimeValidation(format!(
1220                    "file prediction cache store has no bundle record for requirement `{requirement_key}`"
1221                ))
1222            })?;
1223        let payload: crate::bundle::BundlePredictionCachePayload = read_runtime_json(
1224            &self.root.join(&entry.file_name),
1225            "prediction cache payload",
1226        )?;
1227        validate_prediction_cache_payload_matches_record(&payload, record)?;
1228        Ok(payload)
1229    }
1230}
1231
1232impl RuntimePredictionCacheStore for FilePredictionCacheStore {
1233    fn load_blocks(&self, requirement_key: &str) -> Result<Vec<PredictionBlock>> {
1234        let payload = self.payload_for_requirement(requirement_key)?;
1235        if payload.prediction_level != PredictionLevel::Sample {
1236            return Err(DagMlError::RuntimeValidation(format!(
1237                "file prediction cache store requirement `{requirement_key}` contains {:?} predictions, not sample blocks",
1238                payload.prediction_level
1239            )));
1240        }
1241        Ok(payload.blocks)
1242    }
1243
1244    fn load_aggregated_blocks(
1245        &self,
1246        requirement_key: &str,
1247    ) -> Result<Vec<AggregatedPredictionBlock>> {
1248        let payload = self.payload_for_requirement(requirement_key)?;
1249        if payload.prediction_level == PredictionLevel::Sample {
1250            return Err(DagMlError::RuntimeValidation(format!(
1251                "file prediction cache store requirement `{requirement_key}` contains sample predictions, not aggregated blocks"
1252            )));
1253        }
1254        Ok(payload.aggregated_blocks)
1255    }
1256
1257    fn materialize(&self, request: &PredictionCacheMaterializationRequest) -> Result<HandleRef> {
1258        request.requirement.validate()?;
1259        request.cache.validate()?;
1260        let requirement_key = request.requirement.key();
1261        let record = self
1262            .records_by_requirement
1263            .get(&requirement_key)
1264            .ok_or_else(|| {
1265                DagMlError::RuntimeValidation(format!(
1266                    "file prediction cache store is missing requirement `{requirement_key}`"
1267                ))
1268            })?;
1269        if record != &request.cache {
1270            return Err(DagMlError::RuntimeValidation(format!(
1271                "file prediction cache materialization request for `{requirement_key}` does not match bundle cache record"
1272            )));
1273        }
1274        let payload = self.payload_for_requirement(&requirement_key)?;
1275        validate_prediction_cache_payload_matches_record(&payload, record)?;
1276        let fingerprint = stable_json_fingerprint(&(
1277            &request.run_id,
1278            &request.bundle_id,
1279            request.phase,
1280            &request.variant_id,
1281            &request.cache.requirement_key,
1282            &request.cache.cache_id,
1283            request.cache.prediction_level,
1284            &request.cache.content_fingerprint,
1285        ))?;
1286        let handle = HandleRef {
1287            handle: u64::from_str_radix(&fingerprint[..16], 16)
1288                .expect("sha256 hex prefix should fit into u64"),
1289            kind: HandleKind::Prediction,
1290            owner_controller: request.producer_controller_id.clone(),
1291        };
1292        self.materialization_records
1293            .borrow_mut()
1294            .push(PredictionCacheMaterializationRecord {
1295                run_id: request.run_id.clone(),
1296                bundle_id: request.bundle_id.clone(),
1297                phase: request.phase,
1298                variant_id: request.variant_id.clone(),
1299                requirement_key,
1300                cache_id: request.cache.cache_id.clone(),
1301                handle: handle.clone(),
1302            });
1303        Ok(handle)
1304    }
1305}
1306
1307fn prediction_cache_payload_file_name(
1308    payload: &crate::bundle::BundlePredictionCachePayload,
1309) -> Result<String> {
1310    let fingerprint = stable_json_fingerprint(&(
1311        &payload.requirement_key,
1312        &payload.cache_id,
1313        payload.prediction_level,
1314        &payload.content_fingerprint,
1315        payload.block_count,
1316        payload.row_count,
1317    ))?;
1318    Ok(format!("prediction-cache-{}.json", &fingerprint[..16]))
1319}
1320
1321fn validate_prediction_cache_file_name(file_name: &str) -> Result<()> {
1322    if file_name == "." || file_name == ".." || file_name.contains('/') || file_name.contains('\\')
1323    {
1324        return Err(DagMlError::RuntimeValidation(format!(
1325            "prediction cache file name `{file_name}` must be a plain file name"
1326        )));
1327    }
1328    Ok(())
1329}
1330
/// Column-major form of a prediction block: `columns[j][i]` is the value of
/// prediction column `j` for row `i`.
///
/// Sample-level blocks identify rows by `sample_ids`; target/group-level blocks
/// identify rows by `unit_ids`. Observation level is not storable.
#[derive(Clone, Debug, PartialEq)]
pub struct ColumnarPredictionCacheBlock {
    /// Optional prediction id copied from the source block.
    pub prediction_id: Option<String>,
    /// Node that produced the predictions.
    pub producer_node: NodeId,
    /// Partition the predictions belong to.
    pub partition: PredictionPartition,
    /// Fold the predictions belong to, if any.
    pub fold_id: Option<FoldId>,
    /// Level of the stored predictions (sample, target, or group).
    pub prediction_level: PredictionLevel,
    /// Row identifiers for target/group-level blocks; empty for sample level.
    pub unit_ids: Vec<PredictionUnitId>,
    /// Row identifiers for sample-level blocks; empty for aggregated levels.
    pub sample_ids: Vec<SampleId>,
    /// Optional per-column target names; empty or exactly `width` long.
    pub target_names: Vec<String>,
    /// Number of prediction columns.
    pub width: usize,
    /// `width` columns, each holding one value per row.
    pub columns: Vec<Vec<f64>>,
}
1344
1345impl ColumnarPredictionCacheBlock {
1346    pub fn from_prediction_block(block: &PredictionBlock) -> Result<Self> {
1347        let width = block.validate_shape()?;
1348        let mut columns = vec![Vec::with_capacity(block.values.len()); width];
1349        for row in &block.values {
1350            for (column_idx, value) in row.iter().enumerate() {
1351                columns[column_idx].push(*value);
1352            }
1353        }
1354        Ok(Self {
1355            prediction_id: block.prediction_id.clone(),
1356            producer_node: block.producer_node.clone(),
1357            partition: block.partition.clone(),
1358            fold_id: block.fold_id.clone(),
1359            prediction_level: PredictionLevel::Sample,
1360            unit_ids: Vec::new(),
1361            sample_ids: block.sample_ids.clone(),
1362            target_names: block.target_names.clone(),
1363            width,
1364            columns,
1365        })
1366    }
1367
1368    pub fn from_aggregated_prediction_block(block: &AggregatedPredictionBlock) -> Result<Self> {
1369        let width = block.validate_shape()?;
1370        if block.level == PredictionLevel::Sample {
1371            return Err(DagMlError::RuntimeValidation(format!(
1372                "columnar aggregated prediction block for `{}` must use target/group level, got sample",
1373                block.producer_node
1374            )));
1375        }
1376        let mut columns = vec![Vec::with_capacity(block.values.len()); width];
1377        for row in &block.values {
1378            for (column_idx, value) in row.iter().enumerate() {
1379                columns[column_idx].push(*value);
1380            }
1381        }
1382        Ok(Self {
1383            prediction_id: block.prediction_id.clone(),
1384            producer_node: block.producer_node.clone(),
1385            partition: block.partition.clone(),
1386            fold_id: block.fold_id.clone(),
1387            prediction_level: block.level,
1388            unit_ids: block.unit_ids.clone(),
1389            sample_ids: Vec::new(),
1390            target_names: block.target_names.clone(),
1391            width,
1392            columns,
1393        })
1394    }
1395
1396    pub fn row_count(&self) -> usize {
1397        match self.prediction_level {
1398            PredictionLevel::Sample => self.sample_ids.len(),
1399            PredictionLevel::Target | PredictionLevel::Group => self.unit_ids.len(),
1400            PredictionLevel::Observation => 0,
1401        }
1402    }
1403
1404    pub fn value_count(&self) -> usize {
1405        self.columns.iter().map(Vec::len).sum()
1406    }
1407
1408    pub fn validate(&self) -> Result<()> {
1409        match self.prediction_level {
1410            PredictionLevel::Observation => {
1411                return Err(DagMlError::RuntimeValidation(format!(
1412                    "columnar prediction block for `{}` cannot store observation-level predictions",
1413                    self.producer_node
1414                )));
1415            }
1416            PredictionLevel::Sample => {
1417                if self.sample_ids.is_empty() {
1418                    return Err(DagMlError::RuntimeValidation(format!(
1419                        "columnar sample prediction block for `{}` has no sample ids",
1420                        self.producer_node
1421                    )));
1422                }
1423                if !self.unit_ids.is_empty() {
1424                    return Err(DagMlError::RuntimeValidation(format!(
1425                        "columnar sample prediction block for `{}` unexpectedly carries unit ids",
1426                        self.producer_node
1427                    )));
1428                }
1429            }
1430            PredictionLevel::Target | PredictionLevel::Group => {
1431                if !self.sample_ids.is_empty() {
1432                    return Err(DagMlError::RuntimeValidation(format!(
1433                        "columnar aggregated prediction block for `{}` unexpectedly carries sample ids",
1434                        self.producer_node
1435                    )));
1436                }
1437                if self.unit_ids.is_empty() {
1438                    return Err(DagMlError::RuntimeValidation(format!(
1439                        "columnar aggregated prediction block for `{}` has no unit ids",
1440                        self.producer_node
1441                    )));
1442                }
1443                if self
1444                    .unit_ids
1445                    .iter()
1446                    .any(|unit_id| unit_id.level() != self.prediction_level)
1447                {
1448                    return Err(DagMlError::RuntimeValidation(format!(
1449                        "columnar aggregated prediction block for `{}` carries unit ids outside {:?}",
1450                        self.producer_node, self.prediction_level
1451                    )));
1452                }
1453            }
1454        }
1455        if self.width == 0 {
1456            return Err(DagMlError::RuntimeValidation(format!(
1457                "columnar prediction block for `{}` has zero width",
1458                self.producer_node
1459            )));
1460        }
1461        if self.columns.len() != self.width {
1462            return Err(DagMlError::RuntimeValidation(format!(
1463                "columnar prediction block for `{}` has {} column(s), expected {}",
1464                self.producer_node,
1465                self.columns.len(),
1466                self.width
1467            )));
1468        }
1469        for (column_idx, column) in self.columns.iter().enumerate() {
1470            if column.len() != self.row_count() {
1471                return Err(DagMlError::RuntimeValidation(format!(
1472                    "columnar prediction block for `{}` column {} has {} value(s), expected {}",
1473                    self.producer_node,
1474                    column_idx,
1475                    column.len(),
1476                    self.row_count()
1477                )));
1478            }
1479        }
1480        if !self.target_names.is_empty() && self.target_names.len() != self.width {
1481            return Err(DagMlError::RuntimeValidation(format!(
1482                "columnar prediction block for `{}` has {} target names for width {}",
1483                self.producer_node,
1484                self.target_names.len(),
1485                self.width
1486            )));
1487        }
1488        Ok(())
1489    }
1490
1491    pub fn to_prediction_block(&self) -> Result<PredictionBlock> {
1492        self.validate()?;
1493        if self.prediction_level != PredictionLevel::Sample {
1494            return Err(DagMlError::RuntimeValidation(format!(
1495                "columnar prediction block for `{}` contains {:?} predictions, not sample predictions",
1496                self.producer_node, self.prediction_level
1497            )));
1498        }
1499        let values = (0..self.row_count())
1500            .map(|row_idx| {
1501                self.columns
1502                    .iter()
1503                    .map(|column| column[row_idx])
1504                    .collect::<Vec<_>>()
1505            })
1506            .collect();
1507        let block = PredictionBlock {
1508            prediction_id: self.prediction_id.clone(),
1509            producer_node: self.producer_node.clone(),
1510            partition: self.partition.clone(),
1511            fold_id: self.fold_id.clone(),
1512            sample_ids: self.sample_ids.clone(),
1513            values,
1514            target_names: self.target_names.clone(),
1515        };
1516        block.validate_shape()?;
1517        Ok(block)
1518    }
1519
1520    pub fn to_aggregated_prediction_block(&self) -> Result<AggregatedPredictionBlock> {
1521        self.validate()?;
1522        if self.prediction_level == PredictionLevel::Sample {
1523            return Err(DagMlError::RuntimeValidation(format!(
1524                "columnar prediction block for `{}` contains sample predictions, not aggregated predictions",
1525                self.producer_node
1526            )));
1527        }
1528        let values = (0..self.row_count())
1529            .map(|row_idx| {
1530                self.columns
1531                    .iter()
1532                    .map(|column| column[row_idx])
1533                    .collect::<Vec<_>>()
1534            })
1535            .collect();
1536        let block = AggregatedPredictionBlock {
1537            prediction_id: self.prediction_id.clone(),
1538            producer_node: self.producer_node.clone(),
1539            partition: self.partition.clone(),
1540            fold_id: self.fold_id.clone(),
1541            level: self.prediction_level,
1542            unit_ids: self.unit_ids.clone(),
1543            values,
1544            target_names: self.target_names.clone(),
1545        };
1546        block.validate_shape()?;
1547        Ok(block)
1548    }
1549}
1550
/// Summary of one prediction cache held in columnar form.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ColumnarPredictionCacheManifest {
    /// Bundle requirement key served by the cache.
    pub requirement_key: String,
    /// Cache id of the underlying bundle cache record.
    pub cache_id: String,
    /// Level of the cached predictions.
    pub prediction_level: PredictionLevel,
    /// Number of columnar blocks.
    pub block_count: usize,
    /// Total number of rows across blocks.
    pub row_count: usize,
    /// Prediction column count (presumably shared by all blocks — confirm where this
    /// manifest is built).
    pub prediction_width: usize,
    /// Total stored values (likely the sum of each block's `value_count()` — verify).
    pub value_count: usize,
    /// Estimated bytes occupied by the values (presumably `value_count` times the size
    /// of an `f64` — not computed in this chunk; verify).
    pub estimated_value_bytes: usize,
    /// Content fingerprint of the underlying cache record.
    pub content_fingerprint: String,
}
1563
/// A bundle prediction cache record paired with its predictions stored as
/// columnar blocks.
#[derive(Clone, Debug, PartialEq)]
struct ColumnarPredictionCacheEntry {
    /// Bundle record describing this cache (key, id, level, partition, counts,
    /// fingerprint).
    cache: BundlePredictionCacheRecord,
    /// Columnar blocks; `validate` requires exactly `cache.block_count` of them.
    blocks: Vec<ColumnarPredictionCacheBlock>,
}
1569
1570impl ColumnarPredictionCacheEntry {
1571    fn from_payload(
1572        payload: BundlePredictionCachePayload,
1573        cache: BundlePredictionCacheRecord,
1574    ) -> Result<Self> {
1575        validate_prediction_cache_payload_matches_record(&payload, &cache)?;
1576        let blocks = match payload.prediction_level {
1577            PredictionLevel::Sample => payload
1578                .blocks
1579                .iter()
1580                .map(ColumnarPredictionCacheBlock::from_prediction_block)
1581                .collect::<Result<Vec<_>>>()?,
1582            PredictionLevel::Target | PredictionLevel::Group => payload
1583                .aggregated_blocks
1584                .iter()
1585                .map(ColumnarPredictionCacheBlock::from_aggregated_prediction_block)
1586                .collect::<Result<Vec<_>>>()?,
1587            PredictionLevel::Observation => {
1588                return Err(DagMlError::RuntimeValidation(format!(
1589                    "columnar prediction cache payload `{}` cannot use observation-level predictions",
1590                    payload.cache_id
1591                )));
1592            }
1593        };
1594        let entry = Self { cache, blocks };
1595        entry.validate()?;
1596        Ok(entry)
1597    }
1598
1599    fn validate(&self) -> Result<()> {
1600        self.cache.validate()?;
1601        if self.blocks.len() != self.cache.block_count {
1602            return Err(DagMlError::RuntimeValidation(format!(
1603                "columnar prediction cache `{}` has {} block(s), expected {}",
1604                self.cache.cache_id,
1605                self.blocks.len(),
1606                self.cache.block_count
1607            )));
1608        }
1609        let mut row_count = 0usize;
1610        let mut value_count = 0usize;
1611        for block in &self.blocks {
1612            block.validate()?;
1613            if block.prediction_level != self.cache.prediction_level {
1614                return Err(DagMlError::RuntimeValidation(format!(
1615                    "columnar prediction cache `{}` contains a {:?} block, expected {:?}",
1616                    self.cache.cache_id, block.prediction_level, self.cache.prediction_level
1617                )));
1618            }
1619            if block.partition != self.cache.partition {
1620                return Err(DagMlError::RuntimeValidation(format!(
1621                    "columnar prediction cache `{}` contains a block from partition {:?}",
1622                    self.cache.cache_id, block.partition
1623                )));
1624            }
1625            row_count += block.row_count();
1626            value_count += block.value_count();
1627        }
1628        if row_count != self.cache.row_count {
1629            return Err(DagMlError::RuntimeValidation(format!(
1630                "columnar prediction cache `{}` has {} row(s), expected {}",
1631                self.cache.cache_id, row_count, self.cache.row_count
1632            )));
1633        }
1634        let expected_values = self
1635            .cache
1636            .row_count
1637            .checked_mul(self.cache.prediction_width)
1638            .ok_or_else(|| {
1639                DagMlError::RuntimeValidation(format!(
1640                    "columnar prediction cache `{}` value count overflow",
1641                    self.cache.cache_id
1642                ))
1643            })?;
1644        if value_count != expected_values {
1645            return Err(DagMlError::RuntimeValidation(format!(
1646                "columnar prediction cache `{}` has {} value(s), expected {}",
1647                self.cache.cache_id, value_count, expected_values
1648            )));
1649        }
1650        Ok(())
1651    }
1652
1653    fn to_blocks(&self) -> Result<Vec<PredictionBlock>> {
1654        self.validate()?;
1655        self.blocks
1656            .iter()
1657            .map(ColumnarPredictionCacheBlock::to_prediction_block)
1658            .collect()
1659    }
1660
1661    fn to_aggregated_blocks(&self) -> Result<Vec<AggregatedPredictionBlock>> {
1662        self.validate()?;
1663        self.blocks
1664            .iter()
1665            .map(ColumnarPredictionCacheBlock::to_aggregated_prediction_block)
1666            .collect()
1667    }
1668
1669    fn validate_against_cache_record(&self, cache: &BundlePredictionCacheRecord) -> Result<()> {
1670        if &self.cache != cache {
1671            return Err(DagMlError::RuntimeValidation(format!(
1672                "columnar prediction cache materialization request for `{}` does not match bundle cache record",
1673                cache.requirement_key
1674            )));
1675        }
1676        let (blocks, aggregated_blocks) = match self.cache.prediction_level {
1677            PredictionLevel::Sample => (self.to_blocks()?, Vec::new()),
1678            PredictionLevel::Target | PredictionLevel::Group => {
1679                (Vec::new(), self.to_aggregated_blocks()?)
1680            }
1681            PredictionLevel::Observation => {
1682                return Err(DagMlError::RuntimeValidation(format!(
1683                    "columnar prediction cache `{}` cannot materialize observation-level predictions",
1684                    self.cache.cache_id
1685                )));
1686            }
1687        };
1688        let payload = BundlePredictionCachePayload {
1689            requirement_key: self.cache.requirement_key.clone(),
1690            cache_id: self.cache.cache_id.clone(),
1691            format: self.cache.format.clone(),
1692            partition: self.cache.partition.clone(),
1693            prediction_level: self.cache.prediction_level,
1694            block_count: self.cache.block_count,
1695            row_count: self.cache.row_count,
1696            content_fingerprint: self.cache.content_fingerprint.clone(),
1697            blocks,
1698            aggregated_blocks,
1699        };
1700        validate_prediction_cache_payload_matches_record(&payload, cache)
1701    }
1702
1703    fn manifest(&self) -> ColumnarPredictionCacheManifest {
1704        let value_count = self
1705            .blocks
1706            .iter()
1707            .map(ColumnarPredictionCacheBlock::value_count)
1708            .sum::<usize>();
1709        ColumnarPredictionCacheManifest {
1710            requirement_key: self.cache.requirement_key.clone(),
1711            cache_id: self.cache.cache_id.clone(),
1712            prediction_level: self.cache.prediction_level,
1713            block_count: self.cache.block_count,
1714            row_count: self.cache.row_count,
1715            prediction_width: self.cache.prediction_width,
1716            value_count,
1717            estimated_value_bytes: value_count * std::mem::size_of::<f64>(),
1718            content_fingerprint: self.cache.content_fingerprint.clone(),
1719        }
1720    }
1721}
1722
/// Prediction cache store holding cache payloads as columnar entries keyed by
/// requirement key. It records every materialization request it serves.
#[derive(Clone, Debug, Default)]
pub struct ColumnarPredictionCacheStore {
    /// Entries keyed by bundle prediction requirement key.
    entries: BTreeMap<String, ColumnarPredictionCacheEntry>,
    /// Log of successful `materialize` calls. The `RefCell` lets `&self` trait
    /// methods append to it.
    materialization_records: RefCell<Vec<PredictionCacheMaterializationRecord>>,
}
1728
1729impl ColumnarPredictionCacheStore {
1730    pub fn from_payloads(
1731        bundle: &ExecutionBundle,
1732        payloads: BundlePredictionCachePayloadSet,
1733    ) -> Result<Self> {
1734        payloads.validate_against_bundle(bundle)?;
1735        let records_by_requirement = bundle
1736            .prediction_caches
1737            .iter()
1738            .cloned()
1739            .map(|cache| (cache.requirement_key.clone(), cache))
1740            .collect::<BTreeMap<_, _>>();
1741        let mut entries = BTreeMap::new();
1742        for payload in payloads.caches {
1743            let cache = records_by_requirement
1744                .get(&payload.requirement_key)
1745                .cloned()
1746                .ok_or_else(|| {
1747                    DagMlError::RuntimeValidation(format!(
1748                        "columnar prediction cache payload `{}` references unknown requirement `{}`",
1749                        payload.cache_id, payload.requirement_key
1750                    ))
1751                })?;
1752            let requirement_key = payload.requirement_key.clone();
1753            let previous = entries.insert(
1754                requirement_key,
1755                ColumnarPredictionCacheEntry::from_payload(payload, cache)?,
1756            );
1757            debug_assert!(previous.is_none());
1758        }
1759        Ok(Self {
1760            entries,
1761            materialization_records: RefCell::new(Vec::new()),
1762        })
1763    }
1764
1765    pub fn entry_count(&self) -> usize {
1766        self.entries.len()
1767    }
1768
1769    pub fn manifests(&self) -> Vec<ColumnarPredictionCacheManifest> {
1770        self.entries
1771            .values()
1772            .map(ColumnarPredictionCacheEntry::manifest)
1773            .collect()
1774    }
1775
1776    pub fn materialization_records(&self) -> Vec<PredictionCacheMaterializationRecord> {
1777        self.materialization_records.borrow().clone()
1778    }
1779}
1780
1781impl RuntimePredictionCacheStore for ColumnarPredictionCacheStore {
1782    fn load_blocks(&self, requirement_key: &str) -> Result<Vec<PredictionBlock>> {
1783        let entry = self.entries.get(requirement_key).ok_or_else(|| {
1784            DagMlError::RuntimeValidation(format!(
1785                "columnar prediction cache store is missing requirement `{requirement_key}`"
1786            ))
1787        })?;
1788        if entry.cache.prediction_level != PredictionLevel::Sample {
1789            return Err(DagMlError::RuntimeValidation(format!(
1790                "columnar prediction cache store requirement `{requirement_key}` contains {:?} predictions, not sample blocks",
1791                entry.cache.prediction_level
1792            )));
1793        }
1794        entry.validate_against_cache_record(&entry.cache)?;
1795        entry.to_blocks()
1796    }
1797
1798    fn load_aggregated_blocks(
1799        &self,
1800        requirement_key: &str,
1801    ) -> Result<Vec<AggregatedPredictionBlock>> {
1802        let entry = self.entries.get(requirement_key).ok_or_else(|| {
1803            DagMlError::RuntimeValidation(format!(
1804                "columnar prediction cache store is missing requirement `{requirement_key}`"
1805            ))
1806        })?;
1807        if entry.cache.prediction_level == PredictionLevel::Sample {
1808            return Err(DagMlError::RuntimeValidation(format!(
1809                "columnar prediction cache store requirement `{requirement_key}` contains sample predictions, not aggregated blocks"
1810            )));
1811        }
1812        entry.validate_against_cache_record(&entry.cache)?;
1813        entry.to_aggregated_blocks()
1814    }
1815
1816    fn materialize(&self, request: &PredictionCacheMaterializationRequest) -> Result<HandleRef> {
1817        request.requirement.validate()?;
1818        request.cache.validate()?;
1819        let requirement_key = request.requirement.key();
1820        if requirement_key != request.cache.requirement_key {
1821            return Err(DagMlError::RuntimeValidation(format!(
1822                "columnar prediction cache materialization request for `{}` uses cache `{}` with mismatched requirement `{}`",
1823                requirement_key, request.cache.cache_id, request.cache.requirement_key
1824            )));
1825        }
1826        let entry = self.entries.get(&requirement_key).ok_or_else(|| {
1827            DagMlError::RuntimeValidation(format!(
1828                "columnar prediction cache store is missing requirement `{requirement_key}`"
1829            ))
1830        })?;
1831        entry.validate_against_cache_record(&request.cache)?;
1832        let fingerprint = stable_json_fingerprint(&(
1833            &request.run_id,
1834            &request.bundle_id,
1835            request.phase,
1836            &request.variant_id,
1837            &request.cache.requirement_key,
1838            &request.cache.cache_id,
1839            request.cache.prediction_level,
1840            &request.cache.content_fingerprint,
1841        ))?;
1842        let handle = HandleRef {
1843            handle: u64::from_str_radix(&fingerprint[..16], 16)
1844                .expect("sha256 hex prefix should fit into u64"),
1845            kind: HandleKind::Prediction,
1846            owner_controller: request.producer_controller_id.clone(),
1847        };
1848        self.materialization_records
1849            .borrow_mut()
1850            .push(PredictionCacheMaterializationRecord {
1851                run_id: request.run_id.clone(),
1852                bundle_id: request.bundle_id.clone(),
1853                phase: request.phase,
1854                variant_id: request.variant_id.clone(),
1855                requirement_key,
1856                cache_id: request.cache.cache_id.clone(),
1857                handle: handle.clone(),
1858            });
1859        Ok(handle)
1860    }
1861}
1862
1863fn validate_runtime_non_empty(label: &str, value: &str) -> Result<()> {
1864    if value.trim().is_empty() {
1865        return Err(DagMlError::RuntimeValidation(format!("{label} is empty")));
1866    }
1867    Ok(())
1868}
1869
1870fn validate_artifact_optional_text(
1871    label: &str,
1872    value: &Option<String>,
1873    artifact_id: &ArtifactId,
1874) -> Result<()> {
1875    let Some(value) = value else {
1876        return Ok(());
1877    };
1878    if value.trim().is_empty() {
1879        return Err(DagMlError::RuntimeValidation(format!(
1880            "artifact `{artifact_id}` has empty {label}"
1881        )));
1882    }
1883    if value.chars().any(char::is_control) {
1884        return Err(DagMlError::RuntimeValidation(format!(
1885            "artifact `{artifact_id}` has control characters in {label}"
1886        )));
1887    }
1888    Ok(())
1889}
1890
1891fn artifact_payload_path(root: &Path, artifact: &ArtifactRef) -> Result<PathBuf> {
1892    artifact.validate_portable()?;
1893    let uri = artifact
1894        .uri
1895        .as_deref()
1896        .expect("portable artifact validation requires uri");
1897    Ok(root.join(uri))
1898}
1899
1900fn validate_artifact_payload_file(
1901    root: &Path,
1902    artifact: &ArtifactRef,
1903) -> Result<ArtifactPayloadMetadata> {
1904    artifact.validate_portable()?;
1905    let uri = artifact
1906        .uri
1907        .as_deref()
1908        .expect("portable artifact validation requires uri")
1909        .to_string();
1910    let path = artifact_payload_path(root, artifact)?;
1911    validate_payload_path_stays_within_root(root, &path, artifact)?;
1912    let metadata = fs::metadata(&path).map_err(|err| {
1913        DagMlError::RuntimeValidation(format!(
1914            "failed to stat artifact payload `{}` at {}: {err}",
1915            artifact.id,
1916            path.display()
1917        ))
1918    })?;
1919    if !metadata.is_file() {
1920        return Err(DagMlError::RuntimeValidation(format!(
1921            "artifact payload `{}` at {} is not a regular file",
1922            artifact.id,
1923            path.display()
1924        )));
1925    }
1926    let size_bytes = metadata.len();
1927    if let Some(expected_size) = artifact.size_bytes {
1928        if expected_size != size_bytes {
1929            return Err(DagMlError::RuntimeValidation(format!(
1930                "artifact payload `{}` size mismatch: expected {}, got {}",
1931                artifact.id, expected_size, size_bytes
1932            )));
1933        }
1934    }
1935    let content_fingerprint =
1936        sha256_file_hex(&path, &format!("artifact payload `{}`", artifact.id))?;
1937    let expected_fingerprint = artifact
1938        .content_fingerprint
1939        .as_deref()
1940        .expect("portable artifact validation requires content_fingerprint");
1941    if !content_fingerprint.eq_ignore_ascii_case(expected_fingerprint) {
1942        return Err(DagMlError::RuntimeValidation(format!(
1943            "artifact payload `{}` content fingerprint mismatch",
1944            artifact.id
1945        )));
1946    }
1947    Ok(ArtifactPayloadMetadata {
1948        uri,
1949        content_fingerprint,
1950        size_bytes,
1951    })
1952}
1953
1954fn validate_payload_path_stays_within_root(
1955    root: &Path,
1956    path: &Path,
1957    artifact: &ArtifactRef,
1958) -> Result<()> {
1959    let root = fs::canonicalize(root).map_err(|err| {
1960        DagMlError::RuntimeValidation(format!(
1961            "failed to canonicalize artifact payload root `{}`: {err}",
1962            root.display()
1963        ))
1964    })?;
1965    let path = fs::canonicalize(path).map_err(|err| {
1966        DagMlError::RuntimeValidation(format!(
1967            "failed to canonicalize artifact payload `{}` at {}: {err}",
1968            artifact.id,
1969            path.display()
1970        ))
1971    })?;
1972    if !path.starts_with(&root) {
1973        return Err(DagMlError::RuntimeValidation(format!(
1974            "artifact payload `{}` resolves outside store root `{}`",
1975            artifact.id,
1976            root.display()
1977        )));
1978    }
1979    Ok(())
1980}
1981
1982fn sha256_file_hex(path: &Path, label: &str) -> Result<String> {
1983    let mut file = fs::File::open(path).map_err(|err| {
1984        DagMlError::RuntimeValidation(format!(
1985            "failed to open {label} at {}: {err}",
1986            path.display()
1987        ))
1988    })?;
1989    let mut hasher = Sha256::new();
1990    let mut buffer = [0u8; 64 * 1024];
1991    loop {
1992        let read = file.read(&mut buffer).map_err(|err| {
1993            DagMlError::RuntimeValidation(format!(
1994                "failed to read {label} at {}: {err}",
1995                path.display()
1996            ))
1997        })?;
1998        if read == 0 {
1999            break;
2000        }
2001        hasher.update(&buffer[..read]);
2002    }
2003    Ok(bytes_to_hex(&hasher.finalize()))
2004}
2005
/// Test helper: lowercase hex SHA-256 digest of an in-memory byte slice.
#[cfg(test)]
fn sha256_bytes_hex(bytes: &[u8]) -> String {
    let digest = Sha256::digest(bytes);
    bytes_to_hex(&digest[..])
}
2010
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(HEX_DIGITS[usize::from(byte >> 4)] as char);
        encoded.push(HEX_DIGITS[usize::from(byte & 0x0f)] as char);
    }
    encoded
}
2019
2020/// Deterministic path safety for relative artifact URIs. Rejects empty values,
2021/// control characters, absolute paths (POSIX root, Windows root or drive
2022/// prefix), URI schemes such as `http://`, `s3://` or `file://` (any colon in
2023/// the leading path segment) and any `..` traversal component. Parsing is
2024/// platform-independent so portable manifests validate identically everywhere;
2025/// it adds no dependency.
2026fn validate_relative_artifact_uri(artifact_id: &ArtifactId, uri: &str) -> Result<()> {
2027    if uri.is_empty() {
2028        return Err(DagMlError::RuntimeValidation(format!(
2029            "artifact `{artifact_id}` has empty uri"
2030        )));
2031    }
2032    if uri.chars().any(char::is_control) {
2033        return Err(DagMlError::RuntimeValidation(format!(
2034            "artifact `{artifact_id}` uri has control characters"
2035        )));
2036    }
2037    if uri.starts_with('/') || uri.starts_with('\\') {
2038        return Err(DagMlError::RuntimeValidation(format!(
2039            "artifact `{artifact_id}` uri `{uri}` must be a relative path"
2040        )));
2041    }
2042    let mut prefix = uri.chars();
2043    if let (Some(drive), Some(':')) = (prefix.next(), prefix.next()) {
2044        if drive.is_ascii_alphabetic() {
2045            return Err(DagMlError::RuntimeValidation(format!(
2046                "artifact `{artifact_id}` uri `{uri}` must be a relative path"
2047            )));
2048        }
2049    }
2050    // Reject URI schemes (`http://`, `s3://`, `file://`, ...) and any other
2051    // colon in the leading path segment. A scheme always places a colon in the
2052    // first segment, so a strictly relative artifact path never carries one.
2053    let first_segment = uri.split(['/', '\\']).next().unwrap_or(uri);
2054    if first_segment.contains(':') {
2055        return Err(DagMlError::RuntimeValidation(format!(
2056            "artifact `{artifact_id}` uri `{uri}` must not include a scheme or colon in its first path segment"
2057        )));
2058    }
2059    for segment in uri.split(['/', '\\']) {
2060        if segment == ".." {
2061            return Err(DagMlError::RuntimeValidation(format!(
2062                "artifact `{artifact_id}` uri `{uri}` must not contain `..` components"
2063            )));
2064        }
2065    }
2066    Ok(())
2067}
2068
2069fn validate_runtime_fingerprint(label: &str, value: &str) -> Result<()> {
2070    if value.len() != 64 || !value.bytes().all(|byte| byte.is_ascii_hexdigit()) {
2071        return Err(DagMlError::RuntimeValidation(format!(
2072            "{label} fingerprint must be a 64-character hex digest"
2073        )));
2074    }
2075    Ok(())
2076}
2077
2078fn read_runtime_json<T: serde::de::DeserializeOwned>(path: &Path, label: &str) -> Result<T> {
2079    let data = fs::read(path).map_err(|err| {
2080        DagMlError::RuntimeValidation(format!(
2081            "failed to read {label} at {}: {err}",
2082            path.display()
2083        ))
2084    })?;
2085    serde_json::from_slice(&data).map_err(|err| {
2086        DagMlError::RuntimeValidation(format!(
2087            "failed to parse {label} at {}: {err}",
2088            path.display()
2089        ))
2090    })
2091}
2092
2093fn write_runtime_json<T: Serialize>(path: &Path, value: &T, label: &str) -> Result<()> {
2094    let mut data = serde_json::to_vec_pretty(value).map_err(|err| {
2095        DagMlError::RuntimeValidation(format!("failed to serialize {label}: {err}"))
2096    })?;
2097    data.push(b'\n');
2098    fs::write(path, data).map_err(|err| {
2099        DagMlError::RuntimeValidation(format!(
2100            "failed to write {label} at {}: {err}",
2101            path.display()
2102        ))
2103    })
2104}
2105
/// Prediction cache store that keeps row-oriented bundle cache payloads in
/// memory, keyed by requirement key. It records every materialization request
/// it serves.
#[derive(Clone, Debug, Default)]
pub struct InMemoryPredictionCacheStore {
    /// Payloads keyed by bundle prediction requirement key.
    payloads: BTreeMap<String, crate::bundle::BundlePredictionCachePayload>,
    /// Log of successful `materialize` calls. The `RefCell` lets `&self` trait
    /// methods append to it.
    materialization_records: RefCell<Vec<PredictionCacheMaterializationRecord>>,
}
2111
2112impl InMemoryPredictionCacheStore {
2113    pub fn from_payloads(
2114        bundle: &ExecutionBundle,
2115        payloads: BundlePredictionCachePayloadSet,
2116    ) -> Result<Self> {
2117        payloads.validate_against_bundle(bundle)?;
2118        Ok(Self {
2119            payloads: payloads
2120                .caches
2121                .into_iter()
2122                .map(|payload| (payload.requirement_key.clone(), payload))
2123                .collect(),
2124            materialization_records: RefCell::new(Vec::new()),
2125        })
2126    }
2127
2128    pub fn payload_count(&self) -> usize {
2129        self.payloads.len()
2130    }
2131
2132    pub fn materialization_records(&self) -> Vec<PredictionCacheMaterializationRecord> {
2133        self.materialization_records.borrow().clone()
2134    }
2135}
2136
2137impl RuntimePredictionCacheStore for InMemoryPredictionCacheStore {
2138    fn load_blocks(&self, requirement_key: &str) -> Result<Vec<PredictionBlock>> {
2139        let payload = self.payloads.get(requirement_key).ok_or_else(|| {
2140            DagMlError::RuntimeValidation(format!(
2141                "prediction cache store is missing requirement `{requirement_key}`"
2142            ))
2143        })?;
2144        payload.validate()?;
2145        if payload.prediction_level != PredictionLevel::Sample {
2146            return Err(DagMlError::RuntimeValidation(format!(
2147                "prediction cache store requirement `{requirement_key}` contains {:?} predictions, not sample blocks",
2148                payload.prediction_level
2149            )));
2150        }
2151        Ok(payload.blocks.clone())
2152    }
2153
2154    fn load_aggregated_blocks(
2155        &self,
2156        requirement_key: &str,
2157    ) -> Result<Vec<AggregatedPredictionBlock>> {
2158        let payload = self.payloads.get(requirement_key).ok_or_else(|| {
2159            DagMlError::RuntimeValidation(format!(
2160                "prediction cache store is missing requirement `{requirement_key}`"
2161            ))
2162        })?;
2163        payload.validate()?;
2164        if payload.prediction_level == PredictionLevel::Sample {
2165            return Err(DagMlError::RuntimeValidation(format!(
2166                "prediction cache store requirement `{requirement_key}` contains sample predictions, not aggregated blocks"
2167            )));
2168        }
2169        Ok(payload.aggregated_blocks.clone())
2170    }
2171
2172    fn materialize(&self, request: &PredictionCacheMaterializationRequest) -> Result<HandleRef> {
2173        request.requirement.validate()?;
2174        request.cache.validate()?;
2175        if request.requirement.key() != request.cache.requirement_key {
2176            return Err(DagMlError::RuntimeValidation(format!(
2177                "prediction cache materialization request for `{}` uses cache `{}` with mismatched requirement `{}`",
2178                request.requirement.key(),
2179                request.cache.cache_id,
2180                request.cache.requirement_key
2181            )));
2182        }
2183        let payload = self
2184            .payloads
2185            .get(&request.cache.requirement_key)
2186            .ok_or_else(|| {
2187                DagMlError::RuntimeValidation(format!(
2188                    "prediction cache store is missing requirement `{}`",
2189                    request.cache.requirement_key
2190                ))
2191            })?;
2192        validate_prediction_cache_payload_matches_record(payload, &request.cache)?;
2193        let fingerprint = stable_json_fingerprint(&(
2194            &request.run_id,
2195            &request.bundle_id,
2196            request.phase,
2197            &request.variant_id,
2198            &request.cache.requirement_key,
2199            &request.cache.cache_id,
2200            request.cache.prediction_level,
2201            &request.cache.content_fingerprint,
2202        ))?;
2203        let handle = HandleRef {
2204            handle: u64::from_str_radix(&fingerprint[..16], 16)
2205                .expect("sha256 hex prefix should fit into u64"),
2206            kind: HandleKind::Prediction,
2207            owner_controller: request.producer_controller_id.clone(),
2208        };
2209        self.materialization_records
2210            .borrow_mut()
2211            .push(PredictionCacheMaterializationRecord {
2212                run_id: request.run_id.clone(),
2213                bundle_id: request.bundle_id.clone(),
2214                phase: request.phase,
2215                variant_id: request.variant_id.clone(),
2216                requirement_key: request.cache.requirement_key.clone(),
2217                cache_id: request.cache.cache_id.clone(),
2218                handle: handle.clone(),
2219            });
2220        Ok(handle)
2221    }
2222}
2223
/// Describes one prediction input of a node task. It names the producer node
/// and ports, and the predictions carried: partition, level, folds,
/// units/samples, width and targets.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PredictionInputSpec {
    /// Node that produced the predictions.
    pub producer_node: NodeId,
    /// Name of the source port.
    pub source_port: String,
    /// Name of the target port.
    pub target_port: String,
    /// Prediction partition the input belongs to.
    pub partition: PredictionPartition,
    /// Defaults to `PredictionLevel::Sample` when missing from serialized input.
    #[serde(default = "default_runtime_prediction_level")]
    pub prediction_level: PredictionLevel,
    /// Optional single fold id.
    pub fold_id: Option<FoldId>,
    /// Fold ids; empty when missing from serialized input.
    #[serde(default)]
    pub fold_ids: Vec<FoldId>,
    /// Prediction unit ids; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub unit_ids: Vec<PredictionUnitId>,
    /// Sample ids; empty when missing from serialized input.
    #[serde(default)]
    pub sample_ids: Vec<SampleId>,
    /// Number of prediction values per row.
    pub prediction_width: usize,
    /// Target names; empty when missing from serialized input.
    #[serde(default)]
    pub target_names: Vec<String>,
}
2243
/// Artifact input of a node task. `from_refit_record` builds one from a refit
/// artifact record.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ArtifactInputSpec {
    /// Node the artifact belongs to.
    pub node_id: NodeId,
    /// Controller associated with the artifact.
    pub controller_id: ControllerId,
    /// Reference to the artifact itself.
    pub artifact: ArtifactRef,
    /// Parameter fingerprint recorded with the artifact.
    pub params_fingerprint: String,
    /// Data requirement keys; empty when missing from serialized input.
    #[serde(default)]
    pub data_requirement_keys: Vec<String>,
    /// Prediction requirement keys; empty when missing from serialized input.
    #[serde(default)]
    pub prediction_requirement_keys: Vec<String>,
}
2255
2256impl ArtifactInputSpec {
2257    fn from_refit_record(record: &RefitArtifactRecord) -> Result<Self> {
2258        record.validate()?;
2259        Ok(Self {
2260            node_id: record.node_id.clone(),
2261            controller_id: record.controller_id.clone(),
2262            artifact: record.artifact.clone(),
2263            params_fingerprint: record.params_fingerprint.clone(),
2264            data_requirement_keys: record.data_requirement_keys.clone(),
2265            prediction_requirement_keys: record.prediction_requirement_keys.clone(),
2266        })
2267    }
2268}
2269
/// Serde default for `PredictionInputSpec::prediction_level`: sample-level predictions.
fn default_runtime_prediction_level() -> PredictionLevel {
    PredictionLevel::Sample
}
2273
/// A unit of node execution work. It carries the node plan; its run, phase,
/// variant, fold and branch context; its wired inputs (handles, data views,
/// prediction inputs, artifact inputs); an optional nested-CV fold set;
/// fit-influence settings; and an optional seed.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct NodeTask {
    pub run_id: RunId,
    pub node_plan: NodePlan,
    pub phase: Phase,
    pub variant_id: Option<VariantId>,
    /// `None` when missing from serialized input.
    #[serde(default)]
    pub variant: Option<VariantExecutionSpec>,
    pub fold_id: Option<FoldId>,
    /// Empty when missing from serialized input.
    #[serde(default)]
    pub branch_path: Vec<BranchId>,
    /// Input handles by key; empty when missing from serialized input.
    #[serde(default)]
    pub input_handles: BTreeMap<String, HandleRef>,
    /// Data provider views by key; empty when missing from serialized input.
    #[serde(default)]
    pub data_views: BTreeMap<String, DataProviderViewSpec>,
    /// Prediction inputs by key; empty when missing from serialized input.
    #[serde(default)]
    pub prediction_inputs: BTreeMap<String, PredictionInputSpec>,
    /// Artifact inputs by key; empty when missing from serialized input.
    #[serde(default)]
    pub artifact_inputs: BTreeMap<String, ArtifactInputSpec>,
    /// Nested (inner) CV fold set for this node in the current outer fold, built
    /// by the runtime from the outer fold's training samples when an effective
    /// `inner_cv` policy applies (FIT_CV only). `None` otherwise. Leakage-safe by
    /// construction (inner ⊆ outer-train); see [`crate::fold::NestedCvSpec`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_fold_set: Option<FoldSet>,
    /// Omitted from serialized output when equal to `FitInfluenceTask::default()`.
    #[serde(default, skip_serializing_if = "FitInfluenceTask::is_default")]
    pub fit_influence: FitInfluenceTask,
    pub seed: Option<u64>,
}
2303
/// Mechanism by which fit influence is applied. Serialized in snake_case
/// (`uniform_rows`, `sample_weights`, ...). Variant order defines the derived
/// `Ord`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FitInfluenceMechanism {
    UniformRows,
    SampleWeights,
    RowResampling,
    BackendLossWeights,
    ScorerOnly,
}
2313
/// Fit-influence settings for a node task: the requested and effective policy,
/// the mechanism used, optional per-row weights, and warnings. The diagnostic
/// reports a fallback whenever `warnings` is non-empty.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FitInfluenceTask {
    /// Policy that was asked for.
    pub requested_policy: FitInfluencePolicy,
    /// Policy actually in effect.
    pub effective_policy: FitInfluencePolicy,
    /// Mechanism used to apply the effective policy.
    pub mechanism: FitInfluenceMechanism,
    /// Per-row weights (`validate` requires finite values > 0); omitted from
    /// serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub row_weights: Vec<f64>,
    /// Warnings (non-empty means a fallback was used); omitted from serialized
    /// output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub warnings: Vec<String>,
}
2324
2325impl Default for FitInfluenceTask {
2326    fn default() -> Self {
2327        Self {
2328            requested_policy: FitInfluencePolicy::UniformRows,
2329            effective_policy: FitInfluencePolicy::UniformRows,
2330            mechanism: FitInfluenceMechanism::UniformRows,
2331            row_weights: Vec::new(),
2332            warnings: Vec::new(),
2333        }
2334    }
2335}
2336
2337impl FitInfluenceTask {
2338    fn is_default(&self) -> bool {
2339        self == &Self::default()
2340    }
2341
2342    pub fn diagnostic(&self) -> FitInfluenceDiagnostic {
2343        FitInfluenceDiagnostic {
2344            requested_policy: self.requested_policy,
2345            effective_policy: self.effective_policy,
2346            mechanism: self.mechanism,
2347            fallback_used: !self.warnings.is_empty(),
2348            row_weight_count: self.row_weights.len(),
2349            warnings: self.warnings.clone(),
2350        }
2351    }
2352
2353    pub fn validate(&self) -> Result<()> {
2354        if !self
2355            .row_weights
2356            .iter()
2357            .all(|weight| weight.is_finite() && *weight > 0.0)
2358        {
2359            return Err(DagMlError::RuntimeValidation(
2360                "fit influence row_weights must be finite and > 0".to_string(),
2361            ));
2362        }
2363        if self
2364            .warnings
2365            .iter()
2366            .any(|warning| warning.trim().is_empty())
2367        {
2368            return Err(DagMlError::RuntimeValidation(
2369                "fit influence warnings must not be empty".to_string(),
2370            ));
2371        }
2372        match self.effective_policy {
2373            FitInfluencePolicy::EqualSampleInfluence | FitInfluencePolicy::BackendLossWeight
2374                if self.row_weights.is_empty() =>
2375            {
2376                return Err(DagMlError::RuntimeValidation(format!(
2377                    "fit influence {:?} requires row_weights",
2378                    self.effective_policy
2379                )));
2380            }
2381            _ => {}
2382        }
2383        if self.requested_policy == FitInfluencePolicy::StrictWeightSupport
2384            && self.effective_policy == FitInfluencePolicy::UniformRows
2385        {
2386            return Err(DagMlError::RuntimeValidation(
2387                "strict fit influence cannot fall back to uniform_rows".to_string(),
2388            ));
2389        }
2390        Ok(())
2391    }
2392}
2393
/// Fit-influence report a node returns in `NodeResult::fit_influence_diagnostics`.
///
/// [`FitInfluenceTask::diagnostic`] builds one from a task plan.
/// [`FitInfluenceDiagnostic::validate`] requires the returned diagnostic's
/// policies, mechanism, and row-weight count to match the task.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FitInfluenceDiagnostic {
    /// Policy the task requested.
    pub requested_policy: FitInfluencePolicy,
    /// Policy actually in effect.
    pub effective_policy: FitInfluencePolicy,
    /// Mechanism used to apply `effective_policy`.
    pub mechanism: FitInfluenceMechanism,
    /// Whether a fallback occurred; `FitInfluenceTask::diagnostic` sets this
    /// when the task recorded warnings. Defaults to `false` when absent.
    #[serde(default)]
    pub fallback_used: bool,
    /// Number of row weights in the task plan; defaults to 0 when absent.
    #[serde(default)]
    pub row_weight_count: usize,
    /// Non-blank warning messages; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub warnings: Vec<String>,
}
2406
2407impl FitInfluenceDiagnostic {
2408    pub fn validate(&self, task: &NodeTask) -> Result<()> {
2409        if self.requested_policy != task.fit_influence.requested_policy {
2410            return Err(DagMlError::RuntimeValidation(format!(
2411                "fit influence diagnostic requested_policy {:?} does not match task {:?}",
2412                self.requested_policy, task.fit_influence.requested_policy
2413            )));
2414        }
2415        if self.effective_policy != task.fit_influence.effective_policy {
2416            return Err(DagMlError::RuntimeValidation(format!(
2417                "fit influence diagnostic effective_policy {:?} does not match task {:?}",
2418                self.effective_policy, task.fit_influence.effective_policy
2419            )));
2420        }
2421        if self.mechanism != task.fit_influence.mechanism {
2422            return Err(DagMlError::RuntimeValidation(format!(
2423                "fit influence diagnostic mechanism {:?} does not match task {:?}",
2424                self.mechanism, task.fit_influence.mechanism
2425            )));
2426        }
2427        if self.row_weight_count != task.fit_influence.row_weights.len() {
2428            return Err(DagMlError::RuntimeValidation(format!(
2429                "fit influence diagnostic row_weight_count {} does not match task {}",
2430                self.row_weight_count,
2431                task.fit_influence.row_weights.len()
2432            )));
2433        }
2434        if self
2435            .warnings
2436            .iter()
2437            .any(|warning| warning.trim().is_empty())
2438        {
2439            return Err(DagMlError::RuntimeValidation(
2440                "fit influence diagnostic warnings must not be empty".to_string(),
2441            ));
2442        }
2443        Ok(())
2444    }
2445}
2446
/// Variant context handed to a node task: the variant's identity, its
/// generation choices, its fingerprint, and its seed.
///
/// Built from a [`VariantPlan`] by [`VariantExecutionSpec::from_plan`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct VariantExecutionSpec {
    /// Identifier of the variant this context describes.
    pub variant_id: VariantId,
    /// Chosen option per generation dimension, keyed by dimension name;
    /// defaults to empty when absent.
    #[serde(default)]
    pub choices: BTreeMap<String, GenerationChoice>,
    /// Variant fingerprint; must be non-blank (see `validate`).
    pub fingerprint: String,
    /// Optional seed copied from the variant plan.
    pub seed: Option<u64>,
}
2455
2456impl VariantExecutionSpec {
2457    pub fn from_plan(variant: &VariantPlan) -> Self {
2458        Self {
2459            variant_id: variant.variant_id.clone(),
2460            choices: variant.choices.clone(),
2461            fingerprint: variant.fingerprint.clone(),
2462            seed: variant.seed,
2463        }
2464    }
2465
2466    pub fn validate(&self) -> Result<()> {
2467        if self.fingerprint.trim().is_empty() {
2468            return Err(DagMlError::RuntimeValidation(format!(
2469                "variant `{}` has an empty fingerprint in task context",
2470                self.variant_id
2471            )));
2472        }
2473        for (dimension_name, choice) in &self.choices {
2474            if dimension_name.trim().is_empty() {
2475                return Err(DagMlError::RuntimeValidation(format!(
2476                    "variant `{}` has an empty generation dimension name",
2477                    self.variant_id
2478                )));
2479            }
2480            if choice.label.trim().is_empty() {
2481                return Err(DagMlError::RuntimeValidation(format!(
2482                    "variant `{}` has an empty choice label for dimension `{dimension_name}`",
2483                    self.variant_id
2484                )));
2485            }
2486            for override_spec in &choice.param_overrides {
2487                if override_spec.params.is_empty() {
2488                    return Err(DagMlError::RuntimeValidation(format!(
2489                        "variant `{}` has an empty param override for node `{}`",
2490                        self.variant_id, override_spec.node_id
2491                    )));
2492                }
2493                for param_key in override_spec.params.keys() {
2494                    if param_key.trim().is_empty() {
2495                        return Err(DagMlError::RuntimeValidation(format!(
2496                            "variant `{}` has an empty param override key for node `{}`",
2497                            self.variant_id, override_spec.node_id
2498                        )));
2499                    }
2500                }
2501            }
2502        }
2503        self.param_overrides_by_node()?;
2504        Ok(())
2505    }
2506
2507    pub fn effective_params_for_node(
2508        &self,
2509        node_id: &NodeId,
2510        base_params: &BTreeMap<String, serde_json::Value>,
2511    ) -> Result<BTreeMap<String, serde_json::Value>> {
2512        let overrides_by_node = self.param_overrides_by_node()?;
2513        let Some(overrides) = overrides_by_node.get(node_id) else {
2514            return Ok(base_params.clone());
2515        };
2516        let mut params = base_params.clone();
2517        params.extend(overrides.clone());
2518        Ok(params)
2519    }
2520
2521    fn param_overrides_by_node(
2522        &self,
2523    ) -> Result<BTreeMap<NodeId, BTreeMap<String, serde_json::Value>>> {
2524        let mut overrides = BTreeMap::<NodeId, BTreeMap<String, serde_json::Value>>::new();
2525        let mut owners = BTreeMap::<(NodeId, String), String>::new();
2526        for (dimension_name, choice) in &self.choices {
2527            for override_spec in &choice.param_overrides {
2528                for (param_key, value) in &override_spec.params {
2529                    let owner_key = (override_spec.node_id.clone(), param_key.clone());
2530                    if let Some(previous) =
2531                        owners.insert(owner_key, format!("{dimension_name}:{}", choice.label))
2532                    {
2533                        return Err(DagMlError::RuntimeValidation(format!(
2534                            "variant `{}` has conflicting generation overrides for `{}.{}` from `{previous}` and `{}:{}`",
2535                            self.variant_id,
2536                            override_spec.node_id,
2537                            param_key,
2538                            dimension_name,
2539                            choice.label
2540                        )));
2541                    }
2542                    overrides
2543                        .entry(override_spec.node_id.clone())
2544                        .or_default()
2545                        .insert(param_key.clone(), value.clone());
2546                }
2547            }
2548        }
2549        Ok(overrides)
2550    }
2551}
2552
/// An EXPLAIN-phase output block (ADR-12 explain contract). Explanations are a
/// node *output* returned in the [`NodeResult`] — like predictions, they cross as
/// data, not as an opaque host handle. The `payload` shape is controller-defined
/// (e.g. per-feature importances); the core does not interpret it. Explanations
/// are only valid in the `EXPLAIN` phase.
///
/// Intrinsic shape checks live in [`ExplanationBlock::validate`]; producer and
/// phase checks happen in [`NodeResult::validate_for_task`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExplanationBlock {
    /// Node whose model the explanation describes (must equal the producing node).
    pub producer_node: NodeId,
    /// Stable explanation method identifier, e.g. `shap`, `permutation_importance`.
    /// Must be non-blank.
    pub method: String,
    /// Optional target/output name the explanation pertains to. Must be
    /// non-blank when present; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_name: Option<String>,
    /// Controller-defined explanation payload as canonical JSON.
    pub payload: serde_json::Value,
}
2570
2571impl ExplanationBlock {
2572    /// Validate the intrinsic shape of the explanation block (method/target_name
2573    /// non-empty). Producer identity is checked against the node in
2574    /// [`NodeResult::validate_for_task`].
2575    pub fn validate(&self) -> Result<()> {
2576        if self.method.trim().is_empty() {
2577            return Err(DagMlError::RuntimeValidation(
2578                "explanation method must be a non-empty identifier".to_string(),
2579            ));
2580        }
2581        if let Some(name) = &self.target_name {
2582            if name.trim().is_empty() {
2583                return Err(DagMlError::RuntimeValidation(
2584                    "explanation target_name must be non-empty when present".to_string(),
2585                ));
2586            }
2587        }
2588        Ok(())
2589    }
2590}
2591
/// Everything a controller returns for one [`NodeTask`].
///
/// Checked against the originating task by [`NodeResult::validate_for_task`].
/// Collection fields marked `#[serde(default)]` deserialize as empty when absent.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct NodeResult {
    /// Node that produced this result; must equal the task's node.
    pub node_id: NodeId,
    /// Output handles keyed by port name; each must be owned by the task's controller.
    #[serde(default)]
    pub outputs: BTreeMap<String, HandleRef>,
    /// Sample-level prediction blocks produced by this node.
    #[serde(default)]
    pub predictions: Vec<PredictionBlock>,
    /// Observation-level prediction blocks produced by this node.
    #[serde(default)]
    pub observation_predictions: Vec<ObservationPredictionBlock>,
    /// Aggregated prediction blocks produced by this node.
    #[serde(default)]
    pub aggregated_predictions: Vec<AggregatedPredictionBlock>,
    /// Explanation blocks; only accepted in the EXPLAIN phase.
    #[serde(default)]
    pub explanations: Vec<ExplanationBlock>,
    /// Shape changes this node reports about itself.
    #[serde(default)]
    pub shape_deltas: Vec<ShapeDelta>,
    /// Artifacts emitted by this node. Each needs a handle in
    /// `artifact_handles` and an identical entry in `lineage.artifact_refs`.
    #[serde(default)]
    pub artifacts: Vec<ArtifactRef>,
    /// Model/artifact handles keyed by the id of a declared artifact.
    #[serde(default)]
    pub artifact_handles: BTreeMap<ArtifactId, HandleRef>,
    /// Fit-influence diagnostics; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub fit_influence_diagnostics: Vec<FitInfluenceDiagnostic>,
    /// Lineage record describing how this result was produced.
    pub lineage: LineageRecord,
}
2615
2616impl NodeResult {
2617    pub fn validate_for_task(&self, task: &NodeTask) -> Result<()> {
2618        if self.node_id != task.node_plan.node_id {
2619            return Err(DagMlError::RuntimeValidation(format!(
2620                "task for `{}` returned result for `{}`",
2621                task.node_plan.node_id, self.node_id
2622            )));
2623        }
2624        if self.lineage.node_id != task.node_plan.node_id {
2625            return Err(DagMlError::RuntimeValidation(format!(
2626                "lineage for task `{}` references node `{}`",
2627                task.node_plan.node_id, self.lineage.node_id
2628            )));
2629        }
2630        if self.lineage.phase != task.phase {
2631            return Err(DagMlError::RuntimeValidation(format!(
2632                "lineage for node `{}` has phase {:?}, expected {:?}",
2633                task.node_plan.node_id, self.lineage.phase, task.phase
2634            )));
2635        }
2636        if self.lineage.run_id != task.run_id {
2637            return Err(DagMlError::RuntimeValidation(format!(
2638                "lineage for node `{}` has run `{}`, expected `{}`",
2639                task.node_plan.node_id, self.lineage.run_id, task.run_id
2640            )));
2641        }
2642        if self.lineage.controller_id != task.node_plan.controller_id {
2643            return Err(DagMlError::RuntimeValidation(format!(
2644                "lineage for node `{}` has controller `{}`, expected `{}`",
2645                task.node_plan.node_id, self.lineage.controller_id, task.node_plan.controller_id
2646            )));
2647        }
2648        if self.lineage.controller_version != task.node_plan.controller_version {
2649            return Err(DagMlError::RuntimeValidation(format!(
2650                "lineage for node `{}` has controller version `{}`, expected `{}`",
2651                task.node_plan.node_id,
2652                self.lineage.controller_version,
2653                task.node_plan.controller_version
2654            )));
2655        }
2656        if self.lineage.variant_id != task.variant_id {
2657            return Err(DagMlError::RuntimeValidation(format!(
2658                "lineage for node `{}` has variant {:?}, expected {:?}",
2659                task.node_plan.node_id, self.lineage.variant_id, task.variant_id
2660            )));
2661        }
2662        if let Some(variant) = &task.variant {
2663            variant.validate()?;
2664            if Some(&variant.variant_id) != task.variant_id.as_ref() {
2665                return Err(DagMlError::RuntimeValidation(format!(
2666                    "task for node `{}` has variant context `{}` but variant_id {:?}",
2667                    task.node_plan.node_id, variant.variant_id, task.variant_id
2668                )));
2669            }
2670        }
2671        if self.lineage.fold_id != task.fold_id {
2672            return Err(DagMlError::RuntimeValidation(format!(
2673                "lineage for node `{}` has fold {:?}, expected {:?}",
2674                task.node_plan.node_id, self.lineage.fold_id, task.fold_id
2675            )));
2676        }
2677        if self.lineage.branch_path != task.branch_path {
2678            return Err(DagMlError::RuntimeValidation(format!(
2679                "lineage for node `{}` has branch path {:?}, expected {:?}",
2680                task.node_plan.node_id, self.lineage.branch_path, task.branch_path
2681            )));
2682        }
2683        if self.lineage.seed != task.seed {
2684            return Err(DagMlError::RuntimeValidation(format!(
2685                "lineage for node `{}` has seed {:?}, expected {:?}",
2686                task.node_plan.node_id, self.lineage.seed, task.seed
2687            )));
2688        }
2689        if self.lineage.params_fingerprint != task.node_plan.params_fingerprint {
2690            return Err(DagMlError::RuntimeValidation(format!(
2691                "lineage for node `{}` has params fingerprint `{}`, expected `{}`",
2692                task.node_plan.node_id,
2693                self.lineage.params_fingerprint,
2694                task.node_plan.params_fingerprint
2695            )));
2696        }
2697        task.fit_influence.validate()?;
2698        for diagnostic in &self.fit_influence_diagnostics {
2699            diagnostic.validate(task)?;
2700        }
2701        validate_lineage_shape_fingerprints(&self.lineage, task)?;
2702        if !self.explanations.is_empty() && task.phase != Phase::Explain {
2703            return Err(DagMlError::RuntimeValidation(format!(
2704                "node `{}` returned explanations outside the EXPLAIN phase",
2705                task.node_plan.node_id
2706            )));
2707        }
2708        for explanation in &self.explanations {
2709            explanation.validate()?;
2710            if explanation.producer_node != self.node_id {
2711                return Err(DagMlError::RuntimeValidation(format!(
2712                    "node `{}` returned an explanation produced by `{}`",
2713                    self.node_id, explanation.producer_node
2714                )));
2715            }
2716        }
2717        for (port, handle) in &self.outputs {
2718            if handle.owner_controller != task.node_plan.controller_id {
2719                return Err(DagMlError::RuntimeValidation(format!(
2720                    "node `{}` output `{port}` is owned by `{}`, expected `{}`",
2721                    task.node_plan.node_id, handle.owner_controller, task.node_plan.controller_id
2722                )));
2723            }
2724        }
2725        let mut artifact_ids = BTreeSet::new();
2726        for artifact in &self.artifacts {
2727            artifact.validate()?;
2728            if !artifact_ids.insert(artifact.id.clone()) {
2729                return Err(DagMlError::RuntimeValidation(format!(
2730                    "node `{}` emitted duplicate artifact `{}`",
2731                    task.node_plan.node_id, artifact.id
2732                )));
2733            }
2734            if artifact.controller_id != task.node_plan.controller_id {
2735                return Err(DagMlError::RuntimeValidation(format!(
2736                    "node `{}` emitted artifact `{}` for controller `{}`, expected `{}`",
2737                    task.node_plan.node_id,
2738                    artifact.id,
2739                    artifact.controller_id,
2740                    task.node_plan.controller_id
2741                )));
2742            }
2743            let handle = self.artifact_handles.get(&artifact.id).ok_or_else(|| {
2744                DagMlError::RuntimeValidation(format!(
2745                    "node `{}` emitted artifact `{}` without artifact handle",
2746                    task.node_plan.node_id, artifact.id
2747                ))
2748            })?;
2749            if !matches!(handle.kind, HandleKind::Model | HandleKind::Artifact) {
2750                return Err(DagMlError::RuntimeValidation(format!(
2751                    "node `{}` emitted artifact `{}` with non-artifact/model handle kind {:?}",
2752                    task.node_plan.node_id, artifact.id, handle.kind
2753                )));
2754            }
2755            if handle.owner_controller != task.node_plan.controller_id {
2756                return Err(DagMlError::RuntimeValidation(format!(
2757                    "node `{}` emitted artifact `{}` owned by `{}`, expected `{}`",
2758                    task.node_plan.node_id,
2759                    artifact.id,
2760                    handle.owner_controller,
2761                    task.node_plan.controller_id
2762                )));
2763            }
2764        }
2765        for artifact_id in self.artifact_handles.keys() {
2766            if !self
2767                .artifacts
2768                .iter()
2769                .any(|artifact| &artifact.id == artifact_id)
2770            {
2771                return Err(DagMlError::RuntimeValidation(format!(
2772                    "node `{}` emitted artifact handle for undeclared artifact `{artifact_id}`",
2773                    task.node_plan.node_id
2774                )));
2775            }
2776        }
2777        for artifact in &self.artifacts {
2778            if !self
2779                .lineage
2780                .artifact_refs
2781                .iter()
2782                .any(|lineage_artifact| lineage_artifact == artifact)
2783            {
2784                return Err(DagMlError::RuntimeValidation(format!(
2785                    "node `{}` emitted artifact `{}` without matching lineage artifact ref",
2786                    task.node_plan.node_id, artifact.id
2787                )));
2788            }
2789        }
2790        for artifact in &self.lineage.artifact_refs {
2791            if !self
2792                .artifacts
2793                .iter()
2794                .any(|emitted_artifact| emitted_artifact == artifact)
2795            {
2796                return Err(DagMlError::RuntimeValidation(format!(
2797                    "node `{}` lineage references undeclared artifact `{}`",
2798                    task.node_plan.node_id, artifact.id
2799                )));
2800            }
2801        }
2802        for prediction in &self.predictions {
2803            prediction.validate_shape()?;
2804            if prediction.producer_node != task.node_plan.node_id {
2805                return Err(DagMlError::RuntimeValidation(format!(
2806                    "node `{}` emitted prediction for producer `{}`",
2807                    task.node_plan.node_id, prediction.producer_node
2808                )));
2809            }
2810            validate_prediction_scope(prediction, task)?;
2811        }
2812        for prediction in &self.observation_predictions {
2813            prediction.validate_shape()?;
2814            if prediction.producer_node != task.node_plan.node_id {
2815                return Err(DagMlError::RuntimeValidation(format!(
2816                    "node `{}` emitted observation prediction for producer `{}`",
2817                    task.node_plan.node_id, prediction.producer_node
2818                )));
2819            }
2820            validate_observation_prediction_scope(prediction, task)?;
2821        }
2822        for prediction in &self.aggregated_predictions {
2823            prediction.validate_shape()?;
2824            if prediction.producer_node != task.node_plan.node_id {
2825                return Err(DagMlError::RuntimeValidation(format!(
2826                    "node `{}` emitted aggregated prediction for producer `{}`",
2827                    task.node_plan.node_id, prediction.producer_node
2828                )));
2829            }
2830            validate_aggregated_prediction_scope(prediction, task)?;
2831        }
2832        for delta in &self.shape_deltas {
2833            delta.validate()?;
2834            if delta.node_id != task.node_plan.node_id {
2835                return Err(DagMlError::RuntimeValidation(format!(
2836                    "node `{}` emitted shape delta for `{}`",
2837                    task.node_plan.node_id, delta.node_id
2838                )));
2839            }
2840            validate_shape_delta_for_task(delta, task)?;
2841        }
2842        self.lineage.validate()
2843    }
2844}
2845
2846fn validate_lineage_shape_fingerprints(lineage: &LineageRecord, task: &NodeTask) -> Result<()> {
2847    let Some(shape_plan) = &task.node_plan.shape_plan else {
2848        if lineage.data_model_shape_fingerprint.is_some()
2849            || lineage.aggregation_policy_fingerprint.is_some()
2850        {
2851            return Err(DagMlError::RuntimeValidation(format!(
2852                "lineage for node `{}` carries shape fingerprints but the node has no shape plan",
2853                task.node_plan.node_id
2854            )));
2855        }
2856        return Ok(());
2857    };
2858
2859    if let Some(actual) = &lineage.data_model_shape_fingerprint {
2860        let expected = stable_json_fingerprint(shape_plan)?;
2861        if actual != &expected {
2862            return Err(DagMlError::RuntimeValidation(format!(
2863                "lineage for node `{}` has data/model shape fingerprint `{actual}`, expected `{expected}`",
2864                task.node_plan.node_id
2865            )));
2866        }
2867    }
2868    if let Some(actual) = &lineage.aggregation_policy_fingerprint {
2869        let expected = stable_json_fingerprint(&shape_plan.aggregation_policy)?;
2870        if actual != &expected {
2871            return Err(DagMlError::RuntimeValidation(format!(
2872                "lineage for node `{}` has aggregation policy fingerprint `{actual}`, expected `{expected}`",
2873                task.node_plan.node_id
2874            )));
2875        }
2876    }
2877    Ok(())
2878}
2879
2880fn validate_shape_delta_for_task(delta: &ShapeDelta, task: &NodeTask) -> Result<()> {
2881    let Some(shape_plan) = &task.node_plan.shape_plan else {
2882        return Ok(());
2883    };
2884    if delta.kind == ShapeDeltaKind::Feature {
2885        if let Some(expected) = &shape_plan.feature_schema_fingerprint {
2886            if &delta.before_fingerprint != expected {
2887                return Err(DagMlError::RuntimeValidation(format!(
2888                    "node `{}` emitted feature shape delta from `{}`, expected current schema `{expected}`",
2889                    task.node_plan.node_id, delta.before_fingerprint
2890                )));
2891            }
2892        }
2893    }
2894    Ok(())
2895}
2896
2897fn validate_prediction_scope(prediction: &PredictionBlock, task: &NodeTask) -> Result<()> {
2898    if prediction.partition != PredictionPartition::Validation {
2899        return Ok(());
2900    }
2901    if prediction.fold_id != task.fold_id {
2902        return Err(DagMlError::RuntimeValidation(format!(
2903            "node `{}` emitted validation predictions for fold {:?}, expected {:?}",
2904            task.node_plan.node_id, prediction.fold_id, task.fold_id
2905        )));
2906    }
2907    if task.phase == Phase::FitCv
2908        && task.fold_id.is_some()
2909        && (!task.node_plan.data_bindings.is_empty() || !task.data_views.is_empty())
2910    {
2911        let validation_sample_ids = validation_view_sample_ids(task).ok_or_else(|| {
2912            DagMlError::RuntimeValidation(format!(
2913                "node `{}` emitted validation predictions without a fold-validation data view",
2914                task.node_plan.node_id
2915            ))
2916        })?;
2917        for sample_id in &prediction.sample_ids {
2918            if !validation_sample_ids.contains(sample_id) {
2919                return Err(DagMlError::RuntimeValidation(format!(
2920                    "node `{}` emitted validation prediction for sample `{}` outside its validation view",
2921                    task.node_plan.node_id, sample_id
2922                )));
2923            }
2924        }
2925    }
2926    Ok(())
2927}
2928
2929fn validate_observation_prediction_scope(
2930    prediction: &ObservationPredictionBlock,
2931    task: &NodeTask,
2932) -> Result<()> {
2933    if prediction.partition != PredictionPartition::Validation {
2934        return Ok(());
2935    }
2936    if prediction.fold_id != task.fold_id {
2937        return Err(DagMlError::RuntimeValidation(format!(
2938            "node `{}` emitted observation validation predictions for fold {:?}, expected {:?}",
2939            task.node_plan.node_id, prediction.fold_id, task.fold_id
2940        )));
2941    }
2942    Ok(())
2943}
2944
2945fn validate_aggregated_prediction_scope(
2946    prediction: &AggregatedPredictionBlock,
2947    task: &NodeTask,
2948) -> Result<()> {
2949    if prediction.partition != PredictionPartition::Validation {
2950        return Ok(());
2951    }
2952    if prediction.fold_id != task.fold_id {
2953        return Err(DagMlError::RuntimeValidation(format!(
2954            "node `{}` emitted aggregated validation predictions for fold {:?}, expected {:?}",
2955            task.node_plan.node_id, prediction.fold_id, task.fold_id
2956        )));
2957    }
2958    Ok(())
2959}
2960
2961fn validation_view_sample_ids(task: &NodeTask) -> Option<BTreeSet<SampleId>> {
2962    let mut sample_ids = BTreeSet::new();
2963    for view in task
2964        .data_views
2965        .values()
2966        .filter(|view| view.partition == DataRequestPartition::FoldValidation)
2967    {
2968        if let Some(view_sample_ids) = &view.sample_ids {
2969            sample_ids.extend(view_sample_ids.iter().cloned());
2970        }
2971    }
2972    (!sample_ids.is_empty()).then_some(sample_ids)
2973}
2974
2975fn fit_influence_task_for_node(
2976    plan: &ExecutionPlan,
2977    node_plan: &NodePlan,
2978    data_views: &BTreeMap<String, DataProviderViewSpec>,
2979) -> Result<FitInfluenceTask> {
2980    let manifest = plan
2981        .controller_manifests
2982        .get(&node_plan.controller_id)
2983        .ok_or_else(|| {
2984            DagMlError::RuntimeValidation(format!(
2985                "node `{}` references missing controller manifest `{}`",
2986                node_plan.node_id, node_plan.controller_id
2987            ))
2988        })?;
2989    let Some(model_input_spec) = manifest.model_input_spec()? else {
2990        return Ok(FitInfluenceTask::default());
2991    };
2992    let Some(requested_policy) = model_input_spec.fit_influence_policy else {
2993        return Ok(FitInfluenceTask::default());
2994    };
2995    resolve_fit_influence_task(
2996        requested_policy,
2997        &node_plan.controller_capabilities,
2998        data_views,
2999    )
3000}
3001
3002fn resolve_fit_influence_task(
3003    requested_policy: FitInfluencePolicy,
3004    capabilities: &BTreeSet<ControllerCapability>,
3005    data_views: &BTreeMap<String, DataProviderViewSpec>,
3006) -> Result<FitInfluenceTask> {
3007    let row_weights = equal_sample_influence_weights(data_views);
3008    match requested_policy {
3009        FitInfluencePolicy::UniformRows => Ok(FitInfluenceTask {
3010            requested_policy,
3011            effective_policy: FitInfluencePolicy::UniformRows,
3012            mechanism: FitInfluenceMechanism::UniformRows,
3013            row_weights: Vec::new(),
3014            warnings: Vec::new(),
3015        }),
3016        FitInfluencePolicy::ScorerOnly => Ok(FitInfluenceTask {
3017            requested_policy,
3018            effective_policy: FitInfluencePolicy::ScorerOnly,
3019            mechanism: FitInfluenceMechanism::ScorerOnly,
3020            row_weights: Vec::new(),
3021            warnings: Vec::new(),
3022        }),
3023        FitInfluencePolicy::EqualSampleInfluence => {
3024            require_fit_influence_support(capabilities, requested_policy)?;
3025            let weights = row_weights.ok_or_else(|| {
3026                DagMlError::RuntimeValidation(
3027                    "equal_sample_influence requires task row sample ids".to_string(),
3028                )
3029            })?;
3030            Ok(FitInfluenceTask {
3031                requested_policy,
3032                effective_policy: FitInfluencePolicy::EqualSampleInfluence,
3033                mechanism: FitInfluenceMechanism::SampleWeights,
3034                row_weights: weights,
3035                warnings: Vec::new(),
3036            })
3037        }
3038        FitInfluencePolicy::ResampleEqualized => {
3039            require_fit_influence_support(capabilities, requested_policy)?;
3040            Ok(FitInfluenceTask {
3041                requested_policy,
3042                effective_policy: FitInfluencePolicy::ResampleEqualized,
3043                mechanism: FitInfluenceMechanism::RowResampling,
3044                row_weights: Vec::new(),
3045                warnings: Vec::new(),
3046            })
3047        }
3048        FitInfluencePolicy::BackendLossWeight => {
3049            require_fit_influence_support(capabilities, requested_policy)?;
3050            let weights = row_weights.ok_or_else(|| {
3051                DagMlError::RuntimeValidation(
3052                    "backend_loss_weight requires task row sample ids".to_string(),
3053                )
3054            })?;
3055            Ok(FitInfluenceTask {
3056                requested_policy,
3057                effective_policy: FitInfluencePolicy::BackendLossWeight,
3058                mechanism: FitInfluenceMechanism::BackendLossWeights,
3059                row_weights: weights,
3060                warnings: Vec::new(),
3061            })
3062        }
3063        FitInfluencePolicy::StrictWeightSupport => {
3064            require_fit_influence_support(capabilities, requested_policy)?;
3065            strict_fit_influence_task(capabilities, row_weights, requested_policy)
3066        }
3067        FitInfluencePolicy::Auto => Ok(auto_fit_influence_task(capabilities, row_weights)),
3068    }
3069}
3070
3071fn require_fit_influence_support(
3072    capabilities: &BTreeSet<ControllerCapability>,
3073    policy: FitInfluencePolicy,
3074) -> Result<()> {
3075    if capabilities_support_fit_influence(capabilities, policy) {
3076        return Ok(());
3077    }
3078    Err(DagMlError::RuntimeValidation(format!(
3079        "controller capabilities do not support requested fit influence policy {:?}",
3080        policy
3081    )))
3082}
3083
3084fn strict_fit_influence_task(
3085    capabilities: &BTreeSet<ControllerCapability>,
3086    row_weights: Option<Vec<f64>>,
3087    requested_policy: FitInfluencePolicy,
3088) -> Result<FitInfluenceTask> {
3089    if capabilities.contains(&ControllerCapability::SupportsBackendLossWeights) {
3090        let weights = row_weights.ok_or_else(|| {
3091            DagMlError::RuntimeValidation(
3092                "strict_weight_support with backend loss weights requires task row sample ids"
3093                    .to_string(),
3094            )
3095        })?;
3096        return Ok(FitInfluenceTask {
3097            requested_policy,
3098            effective_policy: FitInfluencePolicy::BackendLossWeight,
3099            mechanism: FitInfluenceMechanism::BackendLossWeights,
3100            row_weights: weights,
3101            warnings: Vec::new(),
3102        });
3103    }
3104    if capabilities.contains(&ControllerCapability::SupportsSampleWeights) {
3105        let weights = row_weights.ok_or_else(|| {
3106            DagMlError::RuntimeValidation(
3107                "strict_weight_support with sample weights requires task row sample ids"
3108                    .to_string(),
3109            )
3110        })?;
3111        return Ok(FitInfluenceTask {
3112            requested_policy,
3113            effective_policy: FitInfluencePolicy::EqualSampleInfluence,
3114            mechanism: FitInfluenceMechanism::SampleWeights,
3115            row_weights: weights,
3116            warnings: Vec::new(),
3117        });
3118    }
3119    Ok(FitInfluenceTask {
3120        requested_policy,
3121        effective_policy: FitInfluencePolicy::ResampleEqualized,
3122        mechanism: FitInfluenceMechanism::RowResampling,
3123        row_weights: Vec::new(),
3124        warnings: Vec::new(),
3125    })
3126}
3127
3128fn auto_fit_influence_task(
3129    capabilities: &BTreeSet<ControllerCapability>,
3130    row_weights: Option<Vec<f64>>,
3131) -> FitInfluenceTask {
3132    if capabilities.contains(&ControllerCapability::SupportsSampleWeights) {
3133        if let Some(weights) = row_weights.clone() {
3134            return FitInfluenceTask {
3135                requested_policy: FitInfluencePolicy::Auto,
3136                effective_policy: FitInfluencePolicy::EqualSampleInfluence,
3137                mechanism: FitInfluenceMechanism::SampleWeights,
3138                row_weights: weights,
3139                warnings: Vec::new(),
3140            };
3141        }
3142    }
3143    if capabilities.contains(&ControllerCapability::SupportsRowResampling) {
3144        return FitInfluenceTask {
3145            requested_policy: FitInfluencePolicy::Auto,
3146            effective_policy: FitInfluencePolicy::ResampleEqualized,
3147            mechanism: FitInfluenceMechanism::RowResampling,
3148            row_weights: Vec::new(),
3149            warnings: Vec::new(),
3150        };
3151    }
3152    if capabilities.contains(&ControllerCapability::SupportsBackendLossWeights) {
3153        if let Some(weights) = row_weights {
3154            return FitInfluenceTask {
3155                requested_policy: FitInfluencePolicy::Auto,
3156                effective_policy: FitInfluencePolicy::BackendLossWeight,
3157                mechanism: FitInfluenceMechanism::BackendLossWeights,
3158                row_weights: weights,
3159                warnings: Vec::new(),
3160            };
3161        }
3162    }
3163    FitInfluenceTask {
3164        requested_policy: FitInfluencePolicy::Auto,
3165        effective_policy: FitInfluencePolicy::UniformRows,
3166        mechanism: FitInfluenceMechanism::UniformRows,
3167        row_weights: Vec::new(),
3168        warnings: vec![
3169            "auto fit influence fell back to uniform_rows because no supported weighting capability was usable".to_string(),
3170        ],
3171    }
3172}
3173
3174fn equal_sample_influence_weights(
3175    data_views: &BTreeMap<String, DataProviderViewSpec>,
3176) -> Option<Vec<f64>> {
3177    let row_sample_ids = data_views
3178        .values()
3179        .filter(|view| {
3180            matches!(
3181                view.partition,
3182                DataRequestPartition::FoldTrain | DataRequestPartition::FullTrain
3183            )
3184        })
3185        .filter_map(|view| view.sample_ids.as_ref())
3186        .find(|sample_ids| !sample_ids.is_empty())
3187        .or_else(|| {
3188            data_views
3189                .values()
3190                .filter_map(|view| view.sample_ids.as_ref())
3191                .find(|sample_ids| !sample_ids.is_empty())
3192        })?;
3193    let mut counts = BTreeMap::<&SampleId, usize>::new();
3194    for sample_id in row_sample_ids {
3195        *counts.entry(sample_id).or_default() += 1;
3196    }
3197    Some(
3198        row_sample_ids
3199            .iter()
3200            .map(|sample_id| 1.0 / *counts.get(sample_id).expect("counted sample id") as f64)
3201            .collect(),
3202    )
3203}
3204
3205fn record_fit_influence_diagnostic(task: &NodeTask, result: &mut NodeResult) {
3206    if task.fit_influence.is_default() || !result.fit_influence_diagnostics.is_empty() {
3207        return;
3208    }
3209    result
3210        .fit_influence_diagnostics
3211        .push(task.fit_influence.diagnostic());
3212}
3213
3214#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
3215pub struct DataMaterializationRequest {
3216    pub run_id: RunId,
3217    pub node_id: NodeId,
3218    pub input_name: String,
3219    pub phase: Phase,
3220    pub variant_id: Option<VariantId>,
3221    pub fold_id: Option<FoldId>,
3222    pub binding: crate::data::DataBinding,
3223}
3224
3225#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
3226pub struct DataProviderViewSpec {
3227    #[serde(default)]
3228    pub sample_ids: Option<Vec<SampleId>>,
3229    pub partition: DataRequestPartition,
3230    #[serde(default)]
3231    pub fold_id: Option<FoldId>,
3232    #[serde(default)]
3233    pub source_ids: Option<Vec<String>>,
3234    #[serde(default)]
3235    pub columns: Option<Vec<String>>,
3236    pub include_augmented: bool,
3237    pub include_excluded: bool,
3238    #[serde(default, skip_serializing_if = "Option::is_none")]
3239    pub branch_view: Option<crate::data::BranchViewPlan>,
3240    #[serde(default)]
3241    pub extra: BTreeMap<String, serde_json::Value>,
3242}
3243
3244pub const DATA_OUTPUT_PROVENANCE_KEY: &str = "dag_ml_output";
3245pub const DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION: u32 = 1;
3246pub const DATA_OUTPUT_PROVENANCE_SCHEMA_ID: &str =
3247    "https://github.com/GBeurier/dag-ml/schemas/data_output_provenance.v1.schema.json";
3248pub const NODE_TASK_SCHEMA_VERSION: u32 = 1;
3249pub const NODE_TASK_SCHEMA_ID: &str =
3250    "https://github.com/GBeurier/dag-ml/schemas/node_task.v1.schema.json";
3251pub const NODE_RESULT_SCHEMA_VERSION: u32 = 1;
3252pub const NODE_RESULT_SCHEMA_ID: &str =
3253    "https://github.com/GBeurier/dag-ml/schemas/node_result.v1.schema.json";
3254
3255fn default_data_output_provenance_schema_version() -> u32 {
3256    DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION
3257}
3258
3259impl DataProviderViewSpec {
3260    pub fn validate(&self) -> Result<()> {
3261        validate_optional_ids("sample id", &self.sample_ids)?;
3262        validate_optional_strings("source id", &self.source_ids)?;
3263        validate_optional_strings("column", &self.columns)?;
3264        match self.partition {
3265            DataRequestPartition::FoldTrain | DataRequestPartition::FoldValidation => {
3266                if self.sample_ids.is_some() && self.fold_id.is_none() {
3267                    return Err(DagMlError::RuntimeValidation(format!(
3268                        "data provider view {:?} with explicit sample ids requires a fold id",
3269                        self.partition
3270                    )));
3271                }
3272            }
3273            DataRequestPartition::FullTrain | DataRequestPartition::Predict => {
3274                if self.fold_id.is_some() {
3275                    return Err(DagMlError::RuntimeValidation(format!(
3276                        "data provider view {:?} must not carry a fold id",
3277                        self.partition
3278                    )));
3279                }
3280            }
3281        }
3282        for key in self.extra.keys() {
3283            if key.trim().is_empty() {
3284                return Err(DagMlError::RuntimeValidation(
3285                    "data provider view extra contains an empty key".to_string(),
3286                ));
3287            }
3288        }
3289        if let Some(branch_view) = &self.branch_view {
3290            branch_view.validate()?;
3291        }
3292        self.output_provenance()?;
3293        Ok(())
3294    }
3295
3296    pub fn output_provenance(&self) -> Result<Option<DataOutputProvenance>> {
3297        let Some(value) = self.extra.get(DATA_OUTPUT_PROVENANCE_KEY) else {
3298            return Ok(None);
3299        };
3300        let provenance: DataOutputProvenance = serde_json::from_value(value.clone())?;
3301        provenance.validate()?;
3302        Ok(Some(provenance))
3303    }
3304}
3305
3306#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
3307pub struct DataOutputProvenance {
3308    #[serde(default = "default_data_output_provenance_schema_version")]
3309    pub schema_version: u32,
3310    pub producer_node: NodeId,
3311    pub producer_port: String,
3312    pub producer_phase: Phase,
3313    #[serde(default)]
3314    pub variant_id: Option<VariantId>,
3315    #[serde(default)]
3316    pub fold_id: Option<FoldId>,
3317    #[serde(default)]
3318    pub shape_plan_fingerprint: Option<String>,
3319    #[serde(default)]
3320    pub aggregation_policy_fingerprint: Option<String>,
3321    #[serde(default)]
3322    pub feature_namespace: Option<String>,
3323    #[serde(default)]
3324    pub feature_schema_fingerprint: Option<String>,
3325    #[serde(default, skip_serializing_if = "Option::is_none")]
3326    pub representation_plan: Option<RepresentationPlan>,
3327    #[serde(default, skip_serializing_if = "Option::is_none")]
3328    pub representation_replay_manifest: Option<RepresentationReplayManifest>,
3329    #[serde(default, skip_serializing_if = "Option::is_none")]
3330    pub representation_compatibility: Option<RepresentationCompatibilityReport>,
3331    #[serde(default, skip_serializing_if = "Option::is_none")]
3332    pub relation_delta_fingerprint: Option<String>,
3333    #[serde(default)]
3334    pub shape_deltas: Vec<ShapeDelta>,
3335}
3336
3337impl DataOutputProvenance {
3338    pub fn validate(&self) -> Result<()> {
3339        if self.schema_version != DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION {
3340            return Err(DagMlError::RuntimeValidation(format!(
3341                "data output provenance for `{}` uses unsupported schema_version {}, expected {}",
3342                self.producer_node, self.schema_version, DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION
3343            )));
3344        }
3345        if self.producer_port.trim().is_empty() {
3346            return Err(DagMlError::RuntimeValidation(format!(
3347                "data output provenance for `{}` has empty producer_port",
3348                self.producer_node
3349            )));
3350        }
3351        validate_optional_fingerprint(
3352            "shape_plan_fingerprint",
3353            &self.shape_plan_fingerprint,
3354            &self.producer_node,
3355        )?;
3356        validate_optional_fingerprint(
3357            "aggregation_policy_fingerprint",
3358            &self.aggregation_policy_fingerprint,
3359            &self.producer_node,
3360        )?;
3361        validate_optional_fingerprint(
3362            "feature_schema_fingerprint",
3363            &self.feature_schema_fingerprint,
3364            &self.producer_node,
3365        )?;
3366        validate_optional_fingerprint(
3367            "relation_delta_fingerprint",
3368            &self.relation_delta_fingerprint,
3369            &self.producer_node,
3370        )?;
3371        if let Some(representation_plan) = &self.representation_plan {
3372            representation_plan.validate().map_err(|error| {
3373                DagMlError::RuntimeValidation(format!(
3374                    "data output provenance for `{}` has invalid representation_plan: {error}",
3375                    self.producer_node
3376                ))
3377            })?;
3378        }
3379        if let Some(replay_manifest) = &self.representation_replay_manifest {
3380            replay_manifest.validate().map_err(|error| {
3381                DagMlError::RuntimeValidation(format!(
3382                    "data output provenance for `{}` has invalid representation_replay_manifest: {error}",
3383                    self.producer_node
3384                ))
3385            })?;
3386        }
3387        if let Some(report) = &self.representation_compatibility {
3388            report.validate().map_err(|error| {
3389                DagMlError::RuntimeValidation(format!(
3390                    "data output provenance for `{}` has invalid representation_compatibility: {error}",
3391                    self.producer_node
3392                ))
3393            })?;
3394        }
3395        if self
3396            .feature_namespace
3397            .as_ref()
3398            .is_some_and(|namespace| namespace.trim().is_empty())
3399        {
3400            return Err(DagMlError::RuntimeValidation(format!(
3401                "data output provenance for `{}` has empty feature_namespace",
3402                self.producer_node
3403            )));
3404        }
3405        for delta in &self.shape_deltas {
3406            delta.validate()?;
3407            if delta.node_id != self.producer_node {
3408                return Err(DagMlError::RuntimeValidation(format!(
3409                    "data output provenance for `{}` contains shape delta for `{}`",
3410                    self.producer_node, delta.node_id
3411                )));
3412            }
3413        }
3414        if let Some(feature_schema_fingerprint) = &self.feature_schema_fingerprint {
3415            if let Some(last_feature_delta) = self
3416                .shape_deltas
3417                .iter()
3418                .rev()
3419                .find(|delta| delta.kind == ShapeDeltaKind::Feature)
3420            {
3421                if &last_feature_delta.after_fingerprint != feature_schema_fingerprint {
3422                    return Err(DagMlError::RuntimeValidation(format!(
3423                        "data output provenance for `{}` has feature_schema_fingerprint `{feature_schema_fingerprint}` but last feature delta ends at `{}`",
3424                        self.producer_node, last_feature_delta.after_fingerprint
3425                    )));
3426                }
3427            }
3428        }
3429        Ok(())
3430    }
3431}
3432
3433fn validate_optional_fingerprint(
3434    label: &str,
3435    fingerprint: &Option<String>,
3436    producer_node: &NodeId,
3437) -> Result<()> {
3438    let Some(fingerprint) = fingerprint else {
3439        return Ok(());
3440    };
3441    if fingerprint.len() != 64 || !fingerprint.bytes().all(|byte| byte.is_ascii_hexdigit()) {
3442        return Err(DagMlError::RuntimeValidation(format!(
3443            "data output provenance for `{producer_node}` has invalid {label}"
3444        )));
3445    }
3446    Ok(())
3447}
3448
3449fn validate_optional_ids<T>(label: &str, values: &Option<Vec<T>>) -> Result<()>
3450where
3451    T: Ord + ToString,
3452{
3453    let Some(values) = values else {
3454        return Ok(());
3455    };
3456    if values.is_empty() {
3457        return Err(DagMlError::RuntimeValidation(format!(
3458            "data provider view {label} list is empty"
3459        )));
3460    }
3461    let mut seen = BTreeSet::new();
3462    for value in values {
3463        if !seen.insert(value) {
3464            return Err(DagMlError::RuntimeValidation(format!(
3465                "data provider view has duplicate {label} `{}`",
3466                value.to_string()
3467            )));
3468        }
3469    }
3470    Ok(())
3471}
3472
3473fn validate_optional_strings(label: &str, values: &Option<Vec<String>>) -> Result<()> {
3474    let Some(values) = values else {
3475        return Ok(());
3476    };
3477    if values.is_empty() {
3478        return Err(DagMlError::RuntimeValidation(format!(
3479            "data provider view {label} list is empty"
3480        )));
3481    }
3482    let mut seen = BTreeSet::new();
3483    for value in values {
3484        if value.trim().is_empty() {
3485            return Err(DagMlError::RuntimeValidation(format!(
3486                "data provider view contains an empty {label}"
3487            )));
3488        }
3489        if !seen.insert(value.as_str()) {
3490            return Err(DagMlError::RuntimeValidation(format!(
3491                "data provider view has duplicate {label} `{value}`"
3492            )));
3493        }
3494    }
3495    Ok(())
3496}
3497
3498#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
3499pub struct DataViewRequest {
3500    pub run_id: RunId,
3501    pub node_id: NodeId,
3502    pub input_name: String,
3503    pub phase: Phase,
3504    pub variant_id: Option<VariantId>,
3505    pub fold_id: Option<FoldId>,
3506    pub binding: crate::data::DataBinding,
3507    pub data_handle: HandleRef,
3508    pub view: DataProviderViewSpec,
3509}
3510
3511pub trait RuntimeDataProvider {
3512    fn materialize(&self, request: &DataMaterializationRequest) -> Result<HandleRef>;
3513    fn make_view(&self, request: &DataViewRequest) -> Result<HandleRef>;
3514    fn coordinator_relations(&self, _binding: &DataBinding) -> Result<Option<SampleRelationSet>> {
3515        Ok(None)
3516    }
3517}
3518
3519pub trait RuntimeController: Send + Sync {
3520    fn controller_id(&self) -> &ControllerId;
3521    fn invoke(&self, task: &NodeTask) -> Result<NodeResult>;
3522
3523    fn invoke_aggregation(
3524        &self,
3525        task: &AggregationControllerTask,
3526    ) -> Result<AggregationControllerResult> {
3527        Err(DagMlError::RuntimeValidation(format!(
3528            "runtime controller `{}` does not implement aggregation task `{}`",
3529            self.controller_id(),
3530            task.task_id
3531        )))
3532    }
3533}
3534
3535pub struct BundleReplayExecution<'a> {
3536    pub plan: &'a ExecutionPlan,
3537    pub bundle: &'a ExecutionBundle,
3538    pub replay_request: &'a ReplayPhaseRequest,
3539    pub prediction_cache_store: Option<&'a dyn RuntimePredictionCacheStore>,
3540    pub controllers: &'a RuntimeControllerRegistry,
3541    pub data_provider: &'a dyn RuntimeDataProvider,
3542    pub artifact_store: &'a dyn RuntimeArtifactStore,
3543    pub data_envelopes: &'a BTreeMap<String, ExternalDataPlanEnvelope>,
3544}
3545
3546#[derive(Default)]
3547pub struct RuntimeControllerRegistry {
3548    controllers: BTreeMap<ControllerId, Box<dyn RuntimeController>>,
3549}
3550
3551impl RuntimeControllerRegistry {
3552    pub fn new() -> Self {
3553        Self::default()
3554    }
3555
3556    pub fn register(&mut self, controller: Box<dyn RuntimeController>) -> Result<()> {
3557        let id = controller.controller_id().clone();
3558        if self.controllers.insert(id.clone(), controller).is_some() {
3559            return Err(DagMlError::RuntimeValidation(format!(
3560                "duplicate runtime controller `{id}`"
3561            )));
3562        }
3563        Ok(())
3564    }
3565
3566    pub fn get(&self, controller_id: &ControllerId) -> Option<&dyn RuntimeController> {
3567        self.controllers.get(controller_id).map(Box::as_ref)
3568    }
3569}
3570
3571pub fn dispatch_custom_observation_aggregation(
3572    plan: &ExecutionPlan,
3573    controllers: &RuntimeControllerRegistry,
3574    task_id: impl Into<String>,
3575    block: ObservationPredictionBlock,
3576    relations: SampleRelationSet,
3577    policy: AggregationPolicy,
3578    requested_sample_order: Vec<SampleId>,
3579) -> Result<PredictionBlock> {
3580    let controller_id = custom_aggregation_controller_id(&policy)?;
3581    ensure_aggregation_controller_capability(plan, controller_id)?;
3582    let task = AggregationControllerTask {
3583        schema_version: crate::aggregation::AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
3584        task_id: task_id.into(),
3585        controller_id: controller_id.clone(),
3586        policy,
3587        reduction_plan: None,
3588        input: AggregationControllerInput::ObservationToSample {
3589            block,
3590            relations,
3591            requested_sample_order,
3592        },
3593    };
3594    let result = dispatch_custom_aggregation_task(controllers, &task)?;
3595    match result.output {
3596        AggregationControllerOutput::Sample { block } => Ok(block),
3597        AggregationControllerOutput::Unit { .. } => Err(DagMlError::RuntimeValidation(format!(
3598            "aggregation controller task `{}` returned unit output for observation input",
3599            task.task_id
3600        ))),
3601    }
3602}
3603
3604pub fn dispatch_custom_sample_aggregation(
3605    plan: &ExecutionPlan,
3606    controllers: &RuntimeControllerRegistry,
3607    task_id: impl Into<String>,
3608    block: PredictionBlock,
3609    relations: SampleRelationSet,
3610    policy: AggregationPolicy,
3611    requested_unit_order: Vec<PredictionUnitId>,
3612) -> Result<AggregatedPredictionBlock> {
3613    let controller_id = custom_aggregation_controller_id(&policy)?;
3614    ensure_aggregation_controller_capability(plan, controller_id)?;
3615    let task = AggregationControllerTask {
3616        schema_version: crate::aggregation::AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
3617        task_id: task_id.into(),
3618        controller_id: controller_id.clone(),
3619        policy,
3620        reduction_plan: None,
3621        input: AggregationControllerInput::SampleToUnit {
3622            block,
3623            relations,
3624            requested_unit_order,
3625        },
3626    };
3627    let result = dispatch_custom_aggregation_task(controllers, &task)?;
3628    match result.output {
3629        AggregationControllerOutput::Unit { block } => Ok(block),
3630        AggregationControllerOutput::Sample { .. } => Err(DagMlError::RuntimeValidation(format!(
3631            "aggregation controller task `{}` returned sample output for sample input",
3632            task.task_id
3633        ))),
3634    }
3635}
3636
3637pub fn dispatch_custom_aggregation_task(
3638    controllers: &RuntimeControllerRegistry,
3639    task: &AggregationControllerTask,
3640) -> Result<AggregationControllerResult> {
3641    task.validate()?;
3642    let controller = controllers.get(&task.controller_id).ok_or_else(|| {
3643        DagMlError::RuntimeValidation(format!(
3644            "aggregation runtime controller `{}` is not registered",
3645            task.controller_id
3646        ))
3647    })?;
3648    let result = controller.invoke_aggregation(task)?;
3649    result.validate_for_task(task)?;
3650    Ok(result)
3651}
3652
3653fn custom_aggregation_controller_id(policy: &AggregationPolicy) -> Result<&ControllerId> {
3654    policy.validate()?;
3655    policy
3656        .custom_controller
3657        .as_ref()
3658        .map(|controller| &controller.controller_id)
3659        .ok_or_else(|| {
3660            DagMlError::RuntimeValidation(
3661                "custom aggregation dispatch requires a custom_controller policy".to_string(),
3662            )
3663        })
3664}
3665
3666fn ensure_aggregation_controller_capability(
3667    plan: &ExecutionPlan,
3668    controller_id: &ControllerId,
3669) -> Result<()> {
3670    let manifest = plan
3671        .controller_manifests
3672        .get(controller_id)
3673        .ok_or_else(|| {
3674            DagMlError::Planning(format!(
3675                "missing aggregation controller manifest `{controller_id}`"
3676            ))
3677        })?;
3678    if !manifest
3679        .capabilities
3680        .contains(&ControllerCapability::AggregatesPredictions)
3681    {
3682        return Err(DagMlError::Planning(format!(
3683            "aggregation controller `{controller_id}` must declare aggregates_predictions"
3684        )));
3685    }
3686    Ok(())
3687}
3688
3689#[derive(Clone, Debug)]
3690pub struct RunContext {
3691    pub run_id: RunId,
3692    pub root_seed: Option<u64>,
3693    pub variant_id: Option<VariantId>,
3694    pub prediction_store: InMemoryPredictionStore,
3695    pub aggregated_prediction_store: InMemoryAggregatedPredictionStore,
3696    pub lineage: InMemoryLineageRecorder,
3697}
3698
3699impl RunContext {
3700    pub fn new(run_id: RunId, root_seed: Option<u64>) -> Self {
3701        Self {
3702            run_id,
3703            root_seed,
3704            variant_id: None,
3705            prediction_store: InMemoryPredictionStore::new(),
3706            aggregated_prediction_store: InMemoryAggregatedPredictionStore::new(),
3707            lineage: InMemoryLineageRecorder::new(),
3708        }
3709    }
3710}
3711
3712#[derive(Clone, Debug, Default)]
3713pub struct SequentialScheduler;
3714
3715#[derive(Clone, Debug)]
3716pub struct ParallelScheduler {
3717    max_workers: usize,
3718}
3719
3720impl ParallelScheduler {
3721    pub fn new(max_workers: usize) -> Result<Self> {
3722        if max_workers == 0 {
3723            return Err(DagMlError::RuntimeValidation(
3724                "parallel scheduler max_workers must be at least 1".to_string(),
3725            ));
3726        }
3727        Ok(Self { max_workers })
3728    }
3729
3730    pub fn max_workers(&self) -> usize {
3731        self.max_workers
3732    }
3733}
3734
3735#[derive(Clone, Debug)]
3736struct PhaseScope {
3737    phase: Phase,
3738    variant_id: Option<VariantId>,
3739    variant: Option<VariantExecutionSpec>,
3740    fold_id: Option<FoldId>,
3741    seed_root: Option<u64>,
3742}
3743
3744#[derive(Clone, Debug)]
3745struct ReplayPredictionCacheContract {
3746    requirement: BundlePredictionRequirement,
3747    cache: BundlePredictionCacheRecord,
3748}
3749
3750struct MaterializedReplayArtifacts {
3751    handles: BTreeMap<NodeId, BTreeMap<String, HandleRef>>,
3752    inputs: BTreeMap<NodeId, BTreeMap<String, ArtifactInputSpec>>,
3753}
3754
3755#[derive(Default)]
3756struct PhaseScopeResources<'a> {
3757    data_provider: Option<&'a dyn RuntimeDataProvider>,
3758    replay_artifact_handles: Option<&'a BTreeMap<NodeId, BTreeMap<String, HandleRef>>>,
3759    replay_artifact_inputs: Option<&'a BTreeMap<NodeId, BTreeMap<String, ArtifactInputSpec>>>,
3760    replay_bundle_id: Option<&'a BundleId>,
3761    data_envelopes: Option<&'a BTreeMap<String, ExternalDataPlanEnvelope>>,
3762    prediction_cache_store: Option<&'a dyn RuntimePredictionCacheStore>,
3763    prediction_cache_contracts: Option<&'a BTreeMap<String, ReplayPredictionCacheContract>>,
3764    artifact_store: Option<&'a mut InMemoryArtifactStore>,
3765}
3766
3767impl SequentialScheduler {
3768    pub fn execute_phase(
3769        &self,
3770        plan: &ExecutionPlan,
3771        controllers: &RuntimeControllerRegistry,
3772        ctx: &mut RunContext,
3773        phase: Phase,
3774    ) -> Result<Vec<NodeResult>> {
3775        plan.validate()?;
3776        let variant_id = ctx.variant_id.clone();
3777        let seed_root = ctx.root_seed;
3778        self.execute_phase_scope(
3779            plan,
3780            controllers,
3781            ctx,
3782            PhaseScope {
3783                phase,
3784                variant_id,
3785                variant: None,
3786                fold_id: None,
3787                seed_root,
3788            },
3789            PhaseScopeResources::default(),
3790        )
3791    }
3792
3793    pub fn execute_phase_with_data_provider(
3794        &self,
3795        plan: &ExecutionPlan,
3796        controllers: &RuntimeControllerRegistry,
3797        data_provider: &dyn RuntimeDataProvider,
3798        ctx: &mut RunContext,
3799        phase: Phase,
3800    ) -> Result<Vec<NodeResult>> {
3801        plan.validate()?;
3802        let variant_id = ctx.variant_id.clone();
3803        let seed_root = ctx.root_seed;
3804        self.execute_phase_scope(
3805            plan,
3806            controllers,
3807            ctx,
3808            PhaseScope {
3809                phase,
3810                variant_id,
3811                variant: None,
3812                fold_id: None,
3813                seed_root,
3814            },
3815            PhaseScopeResources {
3816                data_provider: Some(data_provider),
3817                ..Default::default()
3818            },
3819        )
3820    }
3821
3822    pub fn execute_campaign_phase(
3823        &self,
3824        plan: &ExecutionPlan,
3825        controllers: &RuntimeControllerRegistry,
3826        ctx: &mut RunContext,
3827        phase: Phase,
3828    ) -> Result<Vec<NodeResult>> {
3829        plan.validate()?;
3830        let mut results = Vec::new();
3831        let fold_ids = if phase == Phase::FitCv {
3832            plan.fold_set
3833                .as_ref()
3834                .map(|fold_set| {
3835                    fold_set
3836                        .folds
3837                        .iter()
3838                        .map(|fold| Some(fold.fold_id.clone()))
3839                        .collect::<Vec<_>>()
3840                })
3841                .unwrap_or_else(|| vec![None])
3842        } else {
3843            vec![None]
3844        };
3845        for variant in &plan.variants {
3846            if ctx
3847                .variant_id
3848                .as_ref()
3849                .is_some_and(|requested| requested != &variant.variant_id)
3850            {
3851                continue;
3852            }
3853            for fold_id in &fold_ids {
3854                let seed_root = variant.seed.or(ctx.root_seed);
3855                results.extend(self.execute_phase_scope(
3856                    plan,
3857                    controllers,
3858                    ctx,
3859                    PhaseScope {
3860                        phase,
3861                        variant_id: Some(variant.variant_id.clone()),
3862                        variant: Some(VariantExecutionSpec::from_plan(variant)),
3863                        fold_id: fold_id.clone(),
3864                        seed_root,
3865                    },
3866                    PhaseScopeResources::default(),
3867                )?);
3868            }
3869        }
3870        Ok(results)
3871    }
3872
3873    pub fn execute_campaign_phase_with_data_provider(
3874        &self,
3875        plan: &ExecutionPlan,
3876        controllers: &RuntimeControllerRegistry,
3877        data_provider: &dyn RuntimeDataProvider,
3878        ctx: &mut RunContext,
3879        phase: Phase,
3880    ) -> Result<Vec<NodeResult>> {
3881        plan.validate()?;
3882        let mut results = Vec::new();
3883        let fold_ids = if phase == Phase::FitCv {
3884            plan.fold_set
3885                .as_ref()
3886                .map(|fold_set| {
3887                    fold_set
3888                        .folds
3889                        .iter()
3890                        .map(|fold| Some(fold.fold_id.clone()))
3891                        .collect::<Vec<_>>()
3892                })
3893                .unwrap_or_else(|| vec![None])
3894        } else {
3895            vec![None]
3896        };
3897        for variant in &plan.variants {
3898            if ctx
3899                .variant_id
3900                .as_ref()
3901                .is_some_and(|requested| requested != &variant.variant_id)
3902            {
3903                continue;
3904            }
3905            for fold_id in &fold_ids {
3906                let seed_root = variant.seed.or(ctx.root_seed);
3907                results.extend(self.execute_phase_scope(
3908                    plan,
3909                    controllers,
3910                    ctx,
3911                    PhaseScope {
3912                        phase,
3913                        variant_id: Some(variant.variant_id.clone()),
3914                        variant: Some(VariantExecutionSpec::from_plan(variant)),
3915                        fold_id: fold_id.clone(),
3916                        seed_root,
3917                    },
3918                    PhaseScopeResources {
3919                        data_provider: Some(data_provider),
3920                        ..Default::default()
3921                    },
3922                )?);
3923            }
3924        }
3925        Ok(results)
3926    }
3927
3928    pub fn execute_campaign_phase_with_data_provider_and_artifact_store(
3929        &self,
3930        plan: &ExecutionPlan,
3931        controllers: &RuntimeControllerRegistry,
3932        data_provider: &dyn RuntimeDataProvider,
3933        artifact_store: &mut InMemoryArtifactStore,
3934        ctx: &mut RunContext,
3935        phase: Phase,
3936    ) -> Result<Vec<NodeResult>> {
3937        plan.validate()?;
3938        let mut results = Vec::new();
3939        let fold_ids = if phase == Phase::FitCv {
3940            plan.fold_set
3941                .as_ref()
3942                .map(|fold_set| {
3943                    fold_set
3944                        .folds
3945                        .iter()
3946                        .map(|fold| Some(fold.fold_id.clone()))
3947                        .collect::<Vec<_>>()
3948                })
3949                .unwrap_or_else(|| vec![None])
3950        } else {
3951            vec![None]
3952        };
3953        for variant in &plan.variants {
3954            if ctx
3955                .variant_id
3956                .as_ref()
3957                .is_some_and(|requested| requested != &variant.variant_id)
3958            {
3959                continue;
3960            }
3961            for fold_id in &fold_ids {
3962                let seed_root = variant.seed.or(ctx.root_seed);
3963                results.extend(self.execute_phase_scope(
3964                    plan,
3965                    controllers,
3966                    ctx,
3967                    PhaseScope {
3968                        phase,
3969                        variant_id: Some(variant.variant_id.clone()),
3970                        variant: Some(VariantExecutionSpec::from_plan(variant)),
3971                        fold_id: fold_id.clone(),
3972                        seed_root,
3973                    },
3974                    PhaseScopeResources {
3975                        data_provider: Some(data_provider),
3976                        artifact_store: Some(&mut *artifact_store),
3977                        ..Default::default()
3978                    },
3979                )?);
3980            }
3981        }
3982        Ok(results)
3983    }
3984
3985    pub fn execute_bundle_replay(
3986        &self,
3987        replay: BundleReplayExecution<'_>,
3988        ctx: &mut RunContext,
3989    ) -> Result<Vec<NodeResult>> {
3990        replay.bundle.validate_against_plan(replay.plan)?;
3991        replay
3992            .replay_request
3993            .validate_for_bundle_with_prediction_cache_store(
3994                replay.bundle,
3995                replay.prediction_cache_store.is_some(),
3996            )?;
3997        replay
3998            .bundle
3999            .validate_replay_envelopes(replay.data_envelopes)?;
4000        let prediction_cache_contracts = if replay.replay_request.phase == Phase::Refit {
4001            Some(replay_prediction_cache_contracts(replay.bundle)?)
4002        } else {
4003            None
4004        };
4005        if replay.replay_request.phase == Phase::Refit {
4006            preload_replay_prediction_cache_store(
4007                replay.bundle,
4008                replay.prediction_cache_store,
4009                ctx,
4010            )?;
4011        }
4012        let replay_artifacts = materialize_replay_artifact_handles(
4013            replay.plan,
4014            replay.bundle,
4015            replay.replay_request,
4016            replay.artifact_store,
4017            ctx,
4018        )?;
4019        let selected_variant = replay
4020            .bundle
4021            .selected_variant_id
4022            .as_ref()
4023            .map(|selected| {
4024                replay
4025                    .plan
4026                    .variants
4027                    .iter()
4028                    .find(|variant| &variant.variant_id == selected)
4029                    .map(VariantExecutionSpec::from_plan)
4030                    .ok_or_else(|| {
4031                        DagMlError::RuntimeValidation(format!(
4032                            "bundle `{}` selected unknown variant `{selected}`",
4033                            replay.bundle.bundle_id
4034                        ))
4035                    })
4036            })
4037            .transpose()?;
4038        let seed_root = selected_variant
4039            .as_ref()
4040            .and_then(|variant| variant.seed)
4041            .or(ctx.root_seed);
4042
4043        self.execute_phase_scope(
4044            replay.plan,
4045            replay.controllers,
4046            ctx,
4047            PhaseScope {
4048                phase: replay.replay_request.phase,
4049                variant_id: replay.bundle.selected_variant_id.clone(),
4050                variant: selected_variant,
4051                fold_id: None,
4052                seed_root,
4053            },
4054            PhaseScopeResources {
4055                data_provider: Some(replay.data_provider),
4056                replay_artifact_handles: Some(&replay_artifacts.handles),
4057                replay_artifact_inputs: Some(&replay_artifacts.inputs),
4058                replay_bundle_id: Some(&replay.bundle.bundle_id),
4059                data_envelopes: Some(replay.data_envelopes),
4060                prediction_cache_store: replay.prediction_cache_store,
4061                prediction_cache_contracts: prediction_cache_contracts.as_ref(),
4062                ..Default::default()
4063            },
4064        )
4065    }
4066
    /// Executes every node scheduled for `scope.phase` within one
    /// variant/fold scope, running the nodes one at a time.
    ///
    /// Levels returned by `plan.node_parallel_levels_for_phase` are walked in
    /// order, and the nodes inside each level are run in iteration order.
    /// For each node this method:
    /// - resolves its runtime controller;
    /// - gathers inputs from nodes already executed in this scope, plus any
    ///   replay artifact handles/metadata supplied in `resources`;
    /// - builds a `NodeTask` and invokes the controller;
    /// - validates and post-processes the result;
    /// - appends the result's predictions and lineage to `ctx`.
    ///
    /// When `resources.artifact_store` is set and the phase is `Phase::Refit`,
    /// the node's refit artifacts are also captured into that store.
    ///
    /// Returns the node results in execution order.
    ///
    /// # Errors
    /// Fails on the first of the following:
    /// - an unregistered controller;
    /// - a duplicate replay artifact input or metadata key;
    /// - any error from input collection, task preparation, controller
    ///   invocation, result validation, prediction aggregation, lineage
    ///   attachment, refit artifact capture, or the `ctx` store appends.
    ///
    /// # Panics
    /// Panics if the level schedule yields a node id that is missing from
    /// `plan.node_plans`. The `expect` below assumes the plan was validated.
    fn execute_phase_scope(
        &self,
        plan: &ExecutionPlan,
        controllers: &RuntimeControllerRegistry,
        ctx: &mut RunContext,
        scope: PhaseScope,
        mut resources: PhaseScopeResources<'_>,
    ) -> Result<Vec<NodeResult>> {
        // Phase-level tracing span; the guard stays alive until the function returns.
        let _phase_span = crate::observability::phase_span(
            ctx.run_id.as_str(),
            plan.id.as_str(),
            scope.phase.as_str(),
            scope.variant_id.as_ref().map(VariantId::as_str),
            scope.fold_id.as_ref().map(FoldId::as_str),
        )
        .entered();
        let mut results = Vec::new();
        // Bookkeeping for nodes already executed in this scope. Later nodes read
        // these maps when collecting their inputs and attaching input lineage.
        let mut output_handles = BTreeMap::<NodeId, BTreeMap<String, HandleRef>>::new();
        let mut output_data_views =
            BTreeMap::<NodeId, BTreeMap<String, DataProviderViewSpec>>::new();
        let mut input_lineage = BTreeMap::<NodeId, LineageId>::new();

        for level in plan.node_parallel_levels_for_phase(scope.phase)? {
            // Nodes of a level are run one after another here (no worker threads).
            for node_id in &level {
                let node_plan = plan
                    .node_plans
                    .get(node_id)
                    .expect("execution plan was validated");
                let controller = controllers.get(&node_plan.controller_id).ok_or_else(|| {
                    DagMlError::RuntimeValidation(format!(
                        "runtime controller `{}` is not registered",
                        node_plan.controller_id
                    ))
                })?;
                let collected_inputs = collect_input_handles(
                    plan,
                    node_plan,
                    &output_handles,
                    &output_data_views,
                    &resources,
                    ctx,
                    &scope,
                )?;
                let mut input_handles = collected_inputs.handles;
                let mut artifact_inputs = BTreeMap::new();
                // Replay artifact handles are merged into the node's input handles.
                // A key that collides with an already-collected input is rejected.
                if let Some(node_artifact_handles) = resources
                    .replay_artifact_handles
                    .and_then(|handles| handles.get(node_id))
                {
                    for (key, handle) in node_artifact_handles {
                        if input_handles.insert(key.clone(), handle.clone()).is_some() {
                            return Err(DagMlError::RuntimeValidation(format!(
                                "node `{node_id}` received duplicate replay artifact input `{key}`"
                            )));
                        }
                    }
                }
                // Replay artifact metadata is collected separately into `artifact_inputs`.
                if let Some(node_artifact_inputs) = resources
                    .replay_artifact_inputs
                    .and_then(|inputs| inputs.get(node_id))
                {
                    for (key, spec) in node_artifact_inputs {
                        if artifact_inputs.insert(key.clone(), spec.clone()).is_some() {
                            return Err(DagMlError::RuntimeValidation(format!(
                                "node `{node_id}` received duplicate replay artifact metadata `{key}`"
                            )));
                        }
                    }
                }
                // Scope-adjusted copy of the node plan. The task, the fit-influence
                // task, and the seed are all derived from it.
                let task_node_plan = effective_node_plan_for_scope(node_plan, &scope)?;
                // The inner fold set is computed from the original (unadjusted) node plan.
                let inner_fold_set = inner_fold_set_for_scope(
                    &plan.campaign,
                    plan.fold_set.as_ref(),
                    node_plan,
                    &scope,
                )?;
                let fit_influence = fit_influence_task_for_node(
                    plan,
                    &task_node_plan,
                    &collected_inputs.data_views,
                )?;
                let task = NodeTask {
                    inner_fold_set,
                    run_id: ctx.run_id.clone(),
                    node_plan: task_node_plan.clone(),
                    phase: scope.phase,
                    variant_id: scope.variant_id.clone(),
                    variant: scope.variant.clone(),
                    fold_id: scope.fold_id.clone(),
                    branch_path: Vec::new(),
                    input_handles,
                    data_views: collected_inputs.data_views,
                    prediction_inputs: collected_inputs.prediction_inputs,
                    artifact_inputs,
                    fit_influence,
                    // Per-task seed, derived from the scope seed root, variant,
                    // fold, node plan and phase.
                    seed: derive_task_seed(
                        scope.seed_root,
                        scope.variant_id.as_ref(),
                        scope.fold_id.as_ref(),
                        &task_node_plan,
                        scope.phase,
                    ),
                };
                let _node_span = crate::observability::node_span(
                    task.run_id.as_str(),
                    plan.id.as_str(),
                    task.phase.as_str(),
                    task.node_plan.node_id.as_str(),
                    task.node_plan.controller_id.as_str(),
                )
                .entered();
                let mut result = controller.invoke(&task)?;
                // The fit-influence diagnostic is recorded before the result is validated.
                record_fit_influence_diagnostic(&task, &mut result);
                result.validate_for_task(&task)?;
                apply_result_prediction_aggregation(
                    plan,
                    controllers,
                    &task,
                    &mut result,
                    &resources,
                )?;
                attach_coordinator_input_lineage(
                    &mut result,
                    plan,
                    &task.node_plan.node_id,
                    &input_lineage,
                )?;
                // Refit artifacts are captured only when a store was supplied and
                // this is the Refit phase.
                if let Some(store) = resources.artifact_store.as_deref_mut() {
                    if scope.phase == Phase::Refit {
                        store.capture_refit_artifacts(&task, &result)?;
                    }
                }
                // Publish the node's predictions and lineage into the run context.
                // The values are cloned because `result` is still returned to the caller.
                for prediction in &result.predictions {
                    ctx.prediction_store.append(prediction.clone())?;
                }
                for prediction in &result.aggregated_predictions {
                    ctx.aggregated_prediction_store.append(prediction.clone())?;
                }
                ctx.lineage.record(result.lineage.clone())?;
                // Remember this node's outputs, derived data views and lineage id
                // so that nodes in later levels can consume them.
                let data_views = derive_output_data_views(plan, &task, &result)?;
                output_handles.insert(node_id.clone(), result.outputs.clone());
                output_data_views.insert(node_id.clone(), data_views);
                input_lineage.insert(node_id.clone(), result.lineage.record_id.clone());
                results.push(result);
            }
        }

        Ok(results)
    }
4216}
4217
4218impl ParallelScheduler {
4219    pub fn execute_phase(
4220        &self,
4221        plan: &ExecutionPlan,
4222        controllers: &RuntimeControllerRegistry,
4223        ctx: &mut RunContext,
4224        phase: Phase,
4225    ) -> Result<Vec<NodeResult>> {
4226        plan.validate()?;
4227        let variant_id = ctx.variant_id.clone();
4228        let seed_root = ctx.root_seed;
4229        self.execute_phase_scope(
4230            plan,
4231            controllers,
4232            ctx,
4233            PhaseScope {
4234                phase,
4235                variant_id,
4236                variant: None,
4237                fold_id: None,
4238                seed_root,
4239            },
4240            PhaseScopeResources::default(),
4241        )
4242    }
4243
4244    pub fn execute_phase_with_data_provider(
4245        &self,
4246        plan: &ExecutionPlan,
4247        controllers: &RuntimeControllerRegistry,
4248        data_provider: &dyn RuntimeDataProvider,
4249        ctx: &mut RunContext,
4250        phase: Phase,
4251    ) -> Result<Vec<NodeResult>> {
4252        plan.validate()?;
4253        let variant_id = ctx.variant_id.clone();
4254        let seed_root = ctx.root_seed;
4255        self.execute_phase_scope(
4256            plan,
4257            controllers,
4258            ctx,
4259            PhaseScope {
4260                phase,
4261                variant_id,
4262                variant: None,
4263                fold_id: None,
4264                seed_root,
4265            },
4266            PhaseScopeResources {
4267                data_provider: Some(data_provider),
4268                ..Default::default()
4269            },
4270        )
4271    }
4272
4273    pub fn execute_campaign_phase(
4274        &self,
4275        plan: &ExecutionPlan,
4276        controllers: &RuntimeControllerRegistry,
4277        ctx: &mut RunContext,
4278        phase: Phase,
4279    ) -> Result<Vec<NodeResult>> {
4280        plan.validate()?;
4281        let mut results = Vec::new();
4282        let fold_ids = if phase == Phase::FitCv {
4283            plan.fold_set
4284                .as_ref()
4285                .map(|fold_set| {
4286                    fold_set
4287                        .folds
4288                        .iter()
4289                        .map(|fold| Some(fold.fold_id.clone()))
4290                        .collect::<Vec<_>>()
4291                })
4292                .unwrap_or_else(|| vec![None])
4293        } else {
4294            vec![None]
4295        };
4296        for variant in &plan.variants {
4297            if ctx
4298                .variant_id
4299                .as_ref()
4300                .is_some_and(|requested| requested != &variant.variant_id)
4301            {
4302                continue;
4303            }
4304            for fold_id in &fold_ids {
4305                let seed_root = variant.seed.or(ctx.root_seed);
4306                results.extend(self.execute_phase_scope(
4307                    plan,
4308                    controllers,
4309                    ctx,
4310                    PhaseScope {
4311                        phase,
4312                        variant_id: Some(variant.variant_id.clone()),
4313                        variant: Some(VariantExecutionSpec::from_plan(variant)),
4314                        fold_id: fold_id.clone(),
4315                        seed_root,
4316                    },
4317                    PhaseScopeResources::default(),
4318                )?);
4319            }
4320        }
4321        Ok(results)
4322    }
4323
4324    pub fn execute_campaign_phase_with_data_provider(
4325        &self,
4326        plan: &ExecutionPlan,
4327        controllers: &RuntimeControllerRegistry,
4328        data_provider: &dyn RuntimeDataProvider,
4329        ctx: &mut RunContext,
4330        phase: Phase,
4331    ) -> Result<Vec<NodeResult>> {
4332        plan.validate()?;
4333        let mut results = Vec::new();
4334        let fold_ids = if phase == Phase::FitCv {
4335            plan.fold_set
4336                .as_ref()
4337                .map(|fold_set| {
4338                    fold_set
4339                        .folds
4340                        .iter()
4341                        .map(|fold| Some(fold.fold_id.clone()))
4342                        .collect::<Vec<_>>()
4343                })
4344                .unwrap_or_else(|| vec![None])
4345        } else {
4346            vec![None]
4347        };
4348        for variant in &plan.variants {
4349            if ctx
4350                .variant_id
4351                .as_ref()
4352                .is_some_and(|requested| requested != &variant.variant_id)
4353            {
4354                continue;
4355            }
4356            for fold_id in &fold_ids {
4357                let seed_root = variant.seed.or(ctx.root_seed);
4358                results.extend(self.execute_phase_scope(
4359                    plan,
4360                    controllers,
4361                    ctx,
4362                    PhaseScope {
4363                        phase,
4364                        variant_id: Some(variant.variant_id.clone()),
4365                        variant: Some(VariantExecutionSpec::from_plan(variant)),
4366                        fold_id: fold_id.clone(),
4367                        seed_root,
4368                    },
4369                    PhaseScopeResources {
4370                        data_provider: Some(data_provider),
4371                        ..Default::default()
4372                    },
4373                )?);
4374            }
4375        }
4376        Ok(results)
4377    }
4378
4379    pub fn execute_campaign_phase_with_data_provider_and_artifact_store(
4380        &self,
4381        plan: &ExecutionPlan,
4382        controllers: &RuntimeControllerRegistry,
4383        data_provider: &dyn RuntimeDataProvider,
4384        artifact_store: &mut InMemoryArtifactStore,
4385        ctx: &mut RunContext,
4386        phase: Phase,
4387    ) -> Result<Vec<NodeResult>> {
4388        plan.validate()?;
4389        let mut results = Vec::new();
4390        let fold_ids = if phase == Phase::FitCv {
4391            plan.fold_set
4392                .as_ref()
4393                .map(|fold_set| {
4394                    fold_set
4395                        .folds
4396                        .iter()
4397                        .map(|fold| Some(fold.fold_id.clone()))
4398                        .collect::<Vec<_>>()
4399                })
4400                .unwrap_or_else(|| vec![None])
4401        } else {
4402            vec![None]
4403        };
4404        for variant in &plan.variants {
4405            if ctx
4406                .variant_id
4407                .as_ref()
4408                .is_some_and(|requested| requested != &variant.variant_id)
4409            {
4410                continue;
4411            }
4412            for fold_id in &fold_ids {
4413                let seed_root = variant.seed.or(ctx.root_seed);
4414                results.extend(self.execute_phase_scope(
4415                    plan,
4416                    controllers,
4417                    ctx,
4418                    PhaseScope {
4419                        phase,
4420                        variant_id: Some(variant.variant_id.clone()),
4421                        variant: Some(VariantExecutionSpec::from_plan(variant)),
4422                        fold_id: fold_id.clone(),
4423                        seed_root,
4424                    },
4425                    PhaseScopeResources {
4426                        data_provider: Some(data_provider),
4427                        artifact_store: Some(&mut *artifact_store),
4428                        ..Default::default()
4429                    },
4430                )?);
4431            }
4432        }
4433        Ok(results)
4434    }
4435
4436    pub fn execute_bundle_replay(
4437        &self,
4438        replay: BundleReplayExecution<'_>,
4439        ctx: &mut RunContext,
4440    ) -> Result<Vec<NodeResult>> {
4441        replay.bundle.validate_against_plan(replay.plan)?;
4442        replay
4443            .replay_request
4444            .validate_for_bundle_with_prediction_cache_store(
4445                replay.bundle,
4446                replay.prediction_cache_store.is_some(),
4447            )?;
4448        replay
4449            .bundle
4450            .validate_replay_envelopes(replay.data_envelopes)?;
4451        let prediction_cache_contracts = if replay.replay_request.phase == Phase::Refit {
4452            Some(replay_prediction_cache_contracts(replay.bundle)?)
4453        } else {
4454            None
4455        };
4456        if replay.replay_request.phase == Phase::Refit {
4457            preload_replay_prediction_cache_store(
4458                replay.bundle,
4459                replay.prediction_cache_store,
4460                ctx,
4461            )?;
4462        }
4463        let replay_artifacts = materialize_replay_artifact_handles(
4464            replay.plan,
4465            replay.bundle,
4466            replay.replay_request,
4467            replay.artifact_store,
4468            ctx,
4469        )?;
4470        let selected_variant = replay
4471            .bundle
4472            .selected_variant_id
4473            .as_ref()
4474            .map(|selected| {
4475                replay
4476                    .plan
4477                    .variants
4478                    .iter()
4479                    .find(|variant| &variant.variant_id == selected)
4480                    .map(VariantExecutionSpec::from_plan)
4481                    .ok_or_else(|| {
4482                        DagMlError::RuntimeValidation(format!(
4483                            "bundle `{}` selected unknown variant `{selected}`",
4484                            replay.bundle.bundle_id
4485                        ))
4486                    })
4487            })
4488            .transpose()?;
4489        let seed_root = selected_variant
4490            .as_ref()
4491            .and_then(|variant| variant.seed)
4492            .or(ctx.root_seed);
4493
4494        self.execute_phase_scope(
4495            replay.plan,
4496            replay.controllers,
4497            ctx,
4498            PhaseScope {
4499                phase: replay.replay_request.phase,
4500                variant_id: replay.bundle.selected_variant_id.clone(),
4501                variant: selected_variant,
4502                fold_id: None,
4503                seed_root,
4504            },
4505            PhaseScopeResources {
4506                data_provider: Some(replay.data_provider),
4507                replay_artifact_handles: Some(&replay_artifacts.handles),
4508                replay_artifact_inputs: Some(&replay_artifacts.inputs),
4509                replay_bundle_id: Some(&replay.bundle.bundle_id),
4510                data_envelopes: Some(replay.data_envelopes),
4511                prediction_cache_store: replay.prediction_cache_store,
4512                prediction_cache_contracts: prediction_cache_contracts.as_ref(),
4513                ..Default::default()
4514            },
4515        )
4516    }
4517
4518    fn execute_phase_scope(
4519        &self,
4520        plan: &ExecutionPlan,
4521        controllers: &RuntimeControllerRegistry,
4522        ctx: &mut RunContext,
4523        scope: PhaseScope,
4524        mut resources: PhaseScopeResources<'_>,
4525    ) -> Result<Vec<NodeResult>> {
4526        // Hold the phase span on the scheduler thread, and clone it into each
4527        // worker so worker-thread telemetry nests under the phase (tracing spans
4528        // are thread-local and do not auto-propagate across `thread::scope`).
4529        let phase_span = crate::observability::phase_span(
4530            ctx.run_id.as_str(),
4531            plan.id.as_str(),
4532            scope.phase.as_str(),
4533            scope.variant_id.as_ref().map(VariantId::as_str),
4534            scope.fold_id.as_ref().map(FoldId::as_str),
4535        );
4536        let _phase_entered = phase_span.clone().entered();
4537        // Borrowed for the `thread::scope` below; workers join before it ends.
4538        let plan_id = plan.id.as_str();
4539        plan.validate_parallel_controller_capabilities(self.max_workers, scope.phase)?;
4540        let mut results = Vec::new();
4541        let mut output_handles = BTreeMap::<NodeId, BTreeMap<String, HandleRef>>::new();
4542        let mut output_data_views =
4543            BTreeMap::<NodeId, BTreeMap<String, DataProviderViewSpec>>::new();
4544        let mut input_lineage = BTreeMap::<NodeId, LineageId>::new();
4545
4546        for level in plan.node_parallel_levels_for_phase(scope.phase)? {
4547            let mut prepared = Vec::<PreparedNodeTask>::new();
4548            for node_id in &level {
4549                let node_plan = plan
4550                    .node_plans
4551                    .get(node_id)
4552                    .expect("execution plan was validated");
4553                let collected_inputs = collect_input_handles(
4554                    plan,
4555                    node_plan,
4556                    &output_handles,
4557                    &output_data_views,
4558                    &resources,
4559                    ctx,
4560                    &scope,
4561                )?;
4562                let mut input_handles = collected_inputs.handles;
4563                let mut artifact_inputs = BTreeMap::new();
4564                if let Some(node_artifact_handles) = resources
4565                    .replay_artifact_handles
4566                    .and_then(|handles| handles.get(node_id))
4567                {
4568                    for (key, handle) in node_artifact_handles {
4569                        if input_handles.insert(key.clone(), handle.clone()).is_some() {
4570                            return Err(DagMlError::RuntimeValidation(format!(
4571                                "node `{node_id}` received duplicate replay artifact input `{key}`"
4572                            )));
4573                        }
4574                    }
4575                }
4576                if let Some(node_artifact_inputs) = resources
4577                    .replay_artifact_inputs
4578                    .and_then(|inputs| inputs.get(node_id))
4579                {
4580                    for (key, spec) in node_artifact_inputs {
4581                        if artifact_inputs.insert(key.clone(), spec.clone()).is_some() {
4582                            return Err(DagMlError::RuntimeValidation(format!(
4583                                "node `{node_id}` received duplicate replay artifact metadata `{key}`"
4584                            )));
4585                        }
4586                    }
4587                }
4588                let task_node_plan = effective_node_plan_for_scope(node_plan, &scope)?;
4589                let inner_fold_set = inner_fold_set_for_scope(
4590                    &plan.campaign,
4591                    plan.fold_set.as_ref(),
4592                    node_plan,
4593                    &scope,
4594                )?;
4595                let fit_influence = fit_influence_task_for_node(
4596                    plan,
4597                    &task_node_plan,
4598                    &collected_inputs.data_views,
4599                )?;
4600                prepared.push(PreparedNodeTask {
4601                    node_id: node_id.clone(),
4602                    task: NodeTask {
4603                        inner_fold_set,
4604                        run_id: ctx.run_id.clone(),
4605                        node_plan: task_node_plan.clone(),
4606                        phase: scope.phase,
4607                        variant_id: scope.variant_id.clone(),
4608                        variant: scope.variant.clone(),
4609                        fold_id: scope.fold_id.clone(),
4610                        branch_path: Vec::new(),
4611                        input_handles,
4612                        data_views: collected_inputs.data_views,
4613                        prediction_inputs: collected_inputs.prediction_inputs,
4614                        artifact_inputs,
4615                        fit_influence,
4616                        seed: derive_task_seed(
4617                            scope.seed_root,
4618                            scope.variant_id.as_ref(),
4619                            scope.fold_id.as_ref(),
4620                            &task_node_plan,
4621                            scope.phase,
4622                        ),
4623                    },
4624                });
4625            }
4626
4627            for chunk in prepared.chunks(self.max_workers) {
4628                let chunk_results =
4629                    std::thread::scope(|thread_scope| -> Result<Vec<NodeResult>> {
4630                        let mut handles = Vec::with_capacity(chunk.len());
4631                        for prepared_task in chunk {
4632                            let controller = controllers
4633                                .get(&prepared_task.task.node_plan.controller_id)
4634                                .ok_or_else(|| {
4635                                    DagMlError::RuntimeValidation(format!(
4636                                        "runtime controller `{}` is not registered",
4637                                        prepared_task.task.node_plan.controller_id
4638                                    ))
4639                                })?;
4640                            let worker_span = phase_span.clone();
4641                            handles.push(thread_scope.spawn(move || {
4642                                let _worker_span = worker_span.entered();
4643                                let _node_span = crate::observability::node_span(
4644                                    prepared_task.task.run_id.as_str(),
4645                                    plan_id,
4646                                    prepared_task.task.phase.as_str(),
4647                                    prepared_task.task.node_plan.node_id.as_str(),
4648                                    prepared_task.task.node_plan.controller_id.as_str(),
4649                                )
4650                                .entered();
4651                                let mut result = controller.invoke(&prepared_task.task)?;
4652                                record_fit_influence_diagnostic(&prepared_task.task, &mut result);
4653                                result.validate_for_task(&prepared_task.task)?;
4654                                Ok(result)
4655                            }));
4656                        }
4657                        handles
4658                            .into_iter()
4659                            .map(|handle| {
4660                                handle.join().map_err(|_| {
4661                                    DagMlError::RuntimeValidation(
4662                                        "parallel scheduler worker panicked".to_string(),
4663                                    )
4664                                })?
4665                            })
4666                            .collect()
4667                    })?;
4668
4669                for (prepared_task, mut result) in chunk.iter().zip(chunk_results) {
4670                    apply_result_prediction_aggregation(
4671                        plan,
4672                        controllers,
4673                        &prepared_task.task,
4674                        &mut result,
4675                        &resources,
4676                    )?;
4677                    attach_coordinator_input_lineage(
4678                        &mut result,
4679                        plan,
4680                        &prepared_task.task.node_plan.node_id,
4681                        &input_lineage,
4682                    )?;
4683                    if let Some(store) = resources.artifact_store.as_deref_mut() {
4684                        if scope.phase == Phase::Refit {
4685                            store.capture_refit_artifacts(&prepared_task.task, &result)?;
4686                        }
4687                    }
4688                    for prediction in &result.predictions {
4689                        ctx.prediction_store.append(prediction.clone())?;
4690                    }
4691                    for prediction in &result.aggregated_predictions {
4692                        ctx.aggregated_prediction_store.append(prediction.clone())?;
4693                    }
4694                    ctx.lineage.record(result.lineage.clone())?;
4695                    let data_views = derive_output_data_views(plan, &prepared_task.task, &result)?;
4696                    output_handles.insert(prepared_task.node_id.clone(), result.outputs.clone());
4697                    output_data_views.insert(prepared_task.node_id.clone(), data_views);
4698                    input_lineage.insert(
4699                        prepared_task.node_id.clone(),
4700                        result.lineage.record_id.clone(),
4701                    );
4702                    results.push(result);
4703                }
4704            }
4705        }
4706
4707        Ok(results)
4708    }
4709}
4710
/// A node task assembled by the coordinator, paired with the id of the node it
/// belongs to, ready to be handed to a worker.
struct PreparedNodeTask {
    /// Node the task was prepared for; the coordinator keys the task's recorded
    /// output handles, output data views, and lineage record id by this value.
    node_id: NodeId,
    /// Fully populated task passed to the node's registered controller.
    task: NodeTask,
}
4715
4716fn attach_coordinator_input_lineage(
4717    result: &mut NodeResult,
4718    plan: &ExecutionPlan,
4719    node_id: &NodeId,
4720    upstream_lineage: &BTreeMap<NodeId, LineageId>,
4721) -> Result<()> {
4722    let inferred = inferred_input_lineage_for_node(plan, node_id, upstream_lineage);
4723    if result.lineage.input_lineage.is_empty() {
4724        result.lineage.input_lineage = inferred;
4725        return Ok(());
4726    }
4727
4728    let declared = result
4729        .lineage
4730        .input_lineage
4731        .iter()
4732        .cloned()
4733        .collect::<BTreeSet<_>>()
4734        .into_iter()
4735        .collect::<Vec<_>>();
4736    if declared != inferred {
4737        return Err(DagMlError::RuntimeValidation(format!(
4738            "lineage for node `{}` declared input lineage {:?}, expected {:?}",
4739            result.node_id, declared, inferred
4740        )));
4741    }
4742    result.lineage.input_lineage = declared;
4743    Ok(())
4744}
4745
4746fn inferred_input_lineage_for_node(
4747    plan: &ExecutionPlan,
4748    node_id: &NodeId,
4749    upstream_lineage: &BTreeMap<NodeId, LineageId>,
4750) -> Vec<LineageId> {
4751    plan.graph_plan
4752        .graph
4753        .edges
4754        .iter()
4755        .filter(|edge| &edge.target.node_id == node_id && edge.contract.propagates_lineage)
4756        .filter_map(|edge| upstream_lineage.get(&edge.source.node_id).cloned())
4757        .collect::<BTreeSet<_>>()
4758        .into_iter()
4759        .collect()
4760}
4761
4762fn apply_result_prediction_aggregation(
4763    plan: &ExecutionPlan,
4764    controllers: &RuntimeControllerRegistry,
4765    task: &NodeTask,
4766    result: &mut NodeResult,
4767    resources: &PhaseScopeResources<'_>,
4768) -> Result<()> {
4769    let has_observation_predictions = !result.observation_predictions.is_empty();
4770    let has_sample_predictions = !result.predictions.is_empty();
4771    if !has_observation_predictions && !has_sample_predictions {
4772        return Ok(());
4773    }
4774    let Some(shape_plan) = &task.node_plan.shape_plan else {
4775        if !has_observation_predictions {
4776            return Ok(());
4777        }
4778        return Err(DagMlError::RuntimeValidation(format!(
4779            "node `{}` emitted observation predictions but has no data/model shape plan for aggregation",
4780            task.node_plan.node_id
4781        )));
4782    };
4783    let policy = &shape_plan.aggregation_policy;
4784    if !policy.store_aggregated_predictions {
4785        return Ok(());
4786    }
4787    if policy.aggregation_level == PredictionLevel::Observation {
4788        return Ok(());
4789    }
4790    if !has_observation_predictions && policy.aggregation_level == PredictionLevel::Sample {
4791        return Ok(());
4792    }
4793
4794    let mut derived_sample_blocks = Vec::new();
4795    if !result.observation_predictions.is_empty() {
4796        let relations = coordinator_relations_for_task(task, resources)?;
4797        let sample_policy = observation_to_sample_policy(policy);
4798        for block in result.observation_predictions.clone() {
4799            let requested_sample_order =
4800                requested_sample_order_for_observation_block(plan, task, &block, &relations)?;
4801            let sample_block =
4802                if sample_policy.method == crate::policy::AggregationMethod::CustomController {
4803                    dispatch_custom_observation_aggregation(
4804                        plan,
4805                        controllers,
4806                        aggregation_task_id(
4807                            task,
4808                            &block.producer_node,
4809                            block.fold_id.as_ref(),
4810                            "obs_to_sample",
4811                        ),
4812                        block,
4813                        relations.clone(),
4814                        sample_policy.clone(),
4815                        requested_sample_order,
4816                    )?
4817                } else {
4818                    aggregate_observation_predictions(
4819                        &block,
4820                        &relations,
4821                        &sample_policy,
4822                        &requested_sample_order,
4823                    )?
4824                };
4825            derived_sample_blocks.push(sample_block);
4826        }
4827    }
4828
4829    if policy.aggregation_level == PredictionLevel::Sample {
4830        result.predictions.extend(derived_sample_blocks);
4831        result.validate_for_task(task)?;
4832        return Ok(());
4833    }
4834
4835    if !result.aggregated_predictions.is_empty() {
4836        result.validate_for_task(task)?;
4837        return Ok(());
4838    }
4839
4840    let relations = coordinator_relations_for_task(task, resources)?;
4841    let sample_blocks = result
4842        .predictions
4843        .iter()
4844        .cloned()
4845        .chain(derived_sample_blocks)
4846        .collect::<Vec<_>>();
4847    for block in sample_blocks {
4848        let requested_unit_order =
4849            requested_unit_order_for_sample_block(policy.aggregation_level, &relations, &block)?;
4850        let aggregated = if policy.method == crate::policy::AggregationMethod::CustomController {
4851            dispatch_custom_sample_aggregation(
4852                plan,
4853                controllers,
4854                aggregation_task_id(
4855                    task,
4856                    &block.producer_node,
4857                    block.fold_id.as_ref(),
4858                    "sample_to_unit",
4859                ),
4860                block,
4861                relations.clone(),
4862                policy.clone(),
4863                requested_unit_order,
4864            )?
4865        } else {
4866            aggregate_sample_predictions_by_unit(&block, &relations, policy, &requested_unit_order)?
4867        };
4868        result.aggregated_predictions.push(aggregated);
4869    }
4870    result.validate_for_task(task)
4871}
4872
4873fn observation_to_sample_policy(policy: &AggregationPolicy) -> AggregationPolicy {
4874    let mut sample_policy = policy.clone();
4875    sample_policy.aggregation_level = PredictionLevel::Sample;
4876    sample_policy
4877}
4878
4879fn coordinator_relations_for_task(
4880    task: &NodeTask,
4881    resources: &PhaseScopeResources<'_>,
4882) -> Result<SampleRelationSet> {
4883    coordinator_relations_for_node(&task.node_plan, resources)?.ok_or_else(|| {
4884        DagMlError::RuntimeValidation(format!(
4885            "node `{}` needs coordinator relations for prediction aggregation but no matching data provider/envelope carries relations",
4886            task.node_plan.node_id
4887        ))
4888    })
4889}
4890
4891fn coordinator_relations_for_edge(
4892    plan: &ExecutionPlan,
4893    edge: &EdgeSpec,
4894    resources: &PhaseScopeResources<'_>,
4895) -> Result<SampleRelationSet> {
4896    let target_plan = plan.node_plans.get(&edge.target.node_id).ok_or_else(|| {
4897        DagMlError::Planning(format!(
4898            "OOF edge target node `{}` has no node plan",
4899            edge.target.node_id
4900        ))
4901    })?;
4902    if let Some(relations) = coordinator_relations_for_node(target_plan, resources)? {
4903        return Ok(relations);
4904    }
4905
4906    let source_plan = plan.node_plans.get(&edge.source.node_id).ok_or_else(|| {
4907        DagMlError::Planning(format!(
4908            "OOF edge source node `{}` has no node plan",
4909            edge.source.node_id
4910        ))
4911    })?;
4912    if let Some(relations) = coordinator_relations_for_node(source_plan, resources)? {
4913        return Ok(relations);
4914    }
4915
4916    Err(DagMlError::RuntimeValidation(format!(
4917        "edge `{}.{}` -> `{}.{}` needs coordinator relations for aggregated OOF validation but neither endpoint has a relation-carrying data binding",
4918        edge.source.node_id,
4919        edge.source.port_name,
4920        edge.target.node_id,
4921        edge.target.port_name
4922    )))
4923}
4924
4925fn coordinator_relations_for_node(
4926    node_plan: &NodePlan,
4927    resources: &PhaseScopeResources<'_>,
4928) -> Result<Option<SampleRelationSet>> {
4929    let mut selected: Option<SampleRelationSet> = None;
4930    for binding in &node_plan.data_bindings {
4931        if !binding.require_relations && binding.relation_fingerprint.is_none() {
4932            continue;
4933        }
4934        let relations = if let Some(envelopes) = resources.data_envelopes {
4935            let key = format!("{}.{}", binding.node_id, binding.input_name);
4936            let Some(envelope) = envelopes.get(&key) else {
4937                continue;
4938            };
4939            binding.validate_envelope(envelope)?;
4940            envelope.coordinator_relations.clone()
4941        } else if let Some(data_provider) = resources.data_provider {
4942            data_provider.coordinator_relations(binding)?
4943        } else {
4944            None
4945        };
4946        let Some(relations) = relations else {
4947            continue;
4948        };
4949        if let Some(previous) = &selected {
4950            if previous != &relations {
4951                return Err(DagMlError::RuntimeValidation(format!(
4952                    "node `{}` has multiple non-identical coordinator relation sets",
4953                    node_plan.node_id
4954                )));
4955            }
4956        } else {
4957            selected = Some(relations);
4958        }
4959    }
4960    Ok(selected)
4961}
4962
4963fn requested_sample_order_for_observation_block(
4964    plan: &ExecutionPlan,
4965    task: &NodeTask,
4966    block: &ObservationPredictionBlock,
4967    relations: &SampleRelationSet,
4968) -> Result<Vec<SampleId>> {
4969    if block.partition == PredictionPartition::Validation {
4970        if let Some(sample_ids) = validation_view_sample_ids(task) {
4971            return Ok(sample_ids.into_iter().collect());
4972        }
4973        if let (Some(fold_set), Some(fold_id)) = (plan.fold_set.as_ref(), block.fold_id.as_ref()) {
4974            if let Some(fold) = fold_set.folds.iter().find(|fold| &fold.fold_id == fold_id) {
4975                return Ok(fold.validation_sample_ids.clone());
4976            }
4977        }
4978    }
4979    first_seen_samples_for_observations(block, relations)
4980}
4981
4982fn first_seen_samples_for_observations(
4983    block: &ObservationPredictionBlock,
4984    relations: &SampleRelationSet,
4985) -> Result<Vec<SampleId>> {
4986    let mut seen = BTreeSet::new();
4987    let mut sample_order = Vec::new();
4988    for observation_id in &block.observation_ids {
4989        let sample_id = relations
4990            .sample_for_observation(observation_id)
4991            .ok_or_else(|| {
4992                DagMlError::OofValidation(format!(
4993                    "observation prediction `{observation_id}` has no sample relation"
4994                ))
4995            })?;
4996        if seen.insert(sample_id.clone()) {
4997            sample_order.push(sample_id.clone());
4998        }
4999    }
5000    Ok(sample_order)
5001}
5002
5003fn requested_unit_order_for_sample_block(
5004    level: PredictionLevel,
5005    relations: &SampleRelationSet,
5006    block: &PredictionBlock,
5007) -> Result<Vec<PredictionUnitId>> {
5008    let mut seen = BTreeSet::new();
5009    let mut unit_order = Vec::new();
5010    for sample_id in &block.sample_ids {
5011        let unit_id = match level {
5012            PredictionLevel::Sample => PredictionUnitId::Sample(sample_id.clone()),
5013            PredictionLevel::Target => relations
5014                .target_for_sample(sample_id)
5015                .cloned()
5016                .map(PredictionUnitId::Target)
5017                .ok_or_else(|| {
5018                    DagMlError::OofValidation(format!(
5019                        "sample `{sample_id}` is missing target id for target aggregation"
5020                    ))
5021                })?,
5022            PredictionLevel::Group => relations
5023                .group_for_sample(sample_id)
5024                .cloned()
5025                .map(PredictionUnitId::Group)
5026                .ok_or_else(|| {
5027                    DagMlError::OofValidation(format!(
5028                        "sample `{sample_id}` is missing group id for group aggregation"
5029                    ))
5030                })?,
5031            PredictionLevel::Observation => {
5032                return Err(DagMlError::OofValidation(
5033                    "sample prediction aggregation cannot output observation-level predictions"
5034                        .to_string(),
5035                ));
5036            }
5037        };
5038        if seen.insert(unit_id.clone()) {
5039            unit_order.push(unit_id);
5040        }
5041    }
5042    Ok(unit_order)
5043}
5044
5045fn aggregation_task_id(
5046    task: &NodeTask,
5047    producer_node: &NodeId,
5048    fold_id: Option<&FoldId>,
5049    stage: &str,
5050) -> String {
5051    let fold = fold_id
5052        .map(ToString::to_string)
5053        .unwrap_or_else(|| "nofold".to_string());
5054    format!(
5055        "aggregation:{}:{}:{}:{}:{}",
5056        task.run_id, task.node_plan.node_id, producer_node, fold, stage
5057    )
5058}
5059
5060fn collect_input_handles(
5061    plan: &ExecutionPlan,
5062    node_plan: &NodePlan,
5063    output_handles: &BTreeMap<NodeId, BTreeMap<String, HandleRef>>,
5064    output_data_views: &BTreeMap<NodeId, BTreeMap<String, DataProviderViewSpec>>,
5065    resources: &PhaseScopeResources<'_>,
5066    ctx: &RunContext,
5067    scope: &PhaseScope,
5068) -> Result<CollectedInputs> {
5069    let mut inputs = BTreeMap::new();
5070    let mut data_views = BTreeMap::new();
5071    let mut prediction_inputs = BTreeMap::new();
5072    let training_oof_edges = incoming_training_oof_edges(plan, node_plan, scope)?;
5073    let training_oof_sources = training_oof_edges
5074        .iter()
5075        .map(|edge| edge.source.node_id.clone())
5076        .collect::<BTreeSet<_>>();
5077    let bound_data_inputs = node_plan
5078        .data_bindings
5079        .iter()
5080        .map(|binding| binding.input_name.clone())
5081        .collect::<BTreeSet<_>>();
5082    for upstream in &node_plan.input_nodes {
5083        if training_oof_sources.contains(upstream) {
5084            continue;
5085        }
5086        if let Some(handles) = output_handles.get(upstream) {
5087            for (port, handle) in handles {
5088                inputs.insert(format!("{upstream}.{port}"), handle.clone());
5089            }
5090        }
5091    }
5092    for edge in plan
5093        .graph_plan
5094        .graph
5095        .edges
5096        .iter()
5097        .filter(|edge| edge.target.node_id == node_plan.node_id)
5098        .filter(|edge| edge.contract.kind == PortKind::Data && !edge.contract.requires_oof)
5099    {
5100        if bound_data_inputs.contains(&edge.target.port_name) {
5101            continue;
5102        }
5103        let Some(handles) = output_handles.get(&edge.source.node_id) else {
5104            continue;
5105        };
5106        let Some(handle) = handles.get(&edge.source.port_name) else {
5107            continue;
5108        };
5109        let key = data_view_key(&edge.target.port_name);
5110        if inputs.insert(key.clone(), handle.clone()).is_some() {
5111            return Err(DagMlError::RuntimeValidation(format!(
5112                "node `{}` received duplicate data edge input `{key}`",
5113                node_plan.node_id
5114            )));
5115        }
5116        if let Some(source_views) = output_data_views.get(&edge.source.node_id) {
5117            if let Some(view) = source_views.get(&edge.source.port_name) {
5118                if data_views.insert(key.clone(), view.clone()).is_some() {
5119                    return Err(DagMlError::RuntimeValidation(format!(
5120                        "node `{}` received duplicate data edge view `{key}`",
5121                        node_plan.node_id
5122                    )));
5123                }
5124            }
5125            let source_validation_key = validation_data_view_key(&edge.source.port_name);
5126            if let Some(view) = source_views.get(&source_validation_key) {
5127                let validation_key = format!("{key}:validation");
5128                if data_views
5129                    .insert(validation_key.clone(), view.clone())
5130                    .is_some()
5131                {
5132                    return Err(DagMlError::RuntimeValidation(format!(
5133                        "node `{}` received duplicate data edge validation view `{validation_key}`",
5134                        node_plan.node_id
5135                    )));
5136                }
5137            }
5138        }
5139    }
5140    for edge in training_oof_edges {
5141        let key = format!("{}.{}", edge.source.node_id, edge.source.port_name);
5142        let input = collect_oof_prediction_input(plan, edge, ctx, scope, resources)?;
5143        if inputs.insert(key.clone(), input.handle).is_some() {
5144            return Err(DagMlError::RuntimeValidation(format!(
5145                "node `{}` received duplicate OOF prediction input `{key}`",
5146                node_plan.node_id
5147            )));
5148        }
5149        if prediction_inputs.insert(key.clone(), input.spec).is_some() {
5150            return Err(DagMlError::RuntimeValidation(format!(
5151                "node `{}` received duplicate OOF prediction spec `{key}`",
5152                node_plan.node_id
5153            )));
5154        }
5155    }
5156    if !node_plan.data_bindings.is_empty() && resources.data_provider.is_none() {
5157        return Err(DagMlError::RuntimeValidation(format!(
5158            "node `{}` requires {} data binding(s) but no runtime data provider is registered",
5159            node_plan.node_id,
5160            node_plan.data_bindings.len()
5161        )));
5162    }
5163    if let Some(data_provider) = resources.data_provider {
5164        for binding in &node_plan.data_bindings {
5165            let materialized = data_provider.materialize(&DataMaterializationRequest {
5166                run_id: ctx.run_id.clone(),
5167                node_id: node_plan.node_id.clone(),
5168                input_name: binding.input_name.clone(),
5169                phase: scope.phase,
5170                variant_id: scope.variant_id.clone(),
5171                fold_id: scope.fold_id.clone(),
5172                binding: binding.clone(),
5173            })?;
5174            let branch_view_for_node = branch_view_from_node_metadata(plan, &node_plan.node_id)?;
5175            let view = data_view_for_scope(
5176                binding,
5177                plan.fold_set.as_ref(),
5178                scope,
5179                branch_view_for_node.as_ref(),
5180            )?;
5181            let key = data_view_key(&binding.input_name);
5182            let view_handle = make_data_view_handle(
5183                data_provider,
5184                ctx,
5185                node_plan,
5186                scope,
5187                binding,
5188                &materialized,
5189                &view,
5190            )?;
5191            if data_views.insert(key.clone(), view).is_some() {
5192                return Err(DagMlError::RuntimeValidation(format!(
5193                    "node `{}` received duplicate data view `{key}`",
5194                    node_plan.node_id
5195                )));
5196            }
5197            if inputs.insert(key.clone(), view_handle).is_some() {
5198                return Err(DagMlError::RuntimeValidation(format!(
5199                    "node `{}` received duplicate data input `{key}`",
5200                    node_plan.node_id
5201                )));
5202            }
5203
5204            if let Some(validation_view) = validation_data_view_for_scope(
5205                binding,
5206                plan.fold_set.as_ref(),
5207                scope,
5208                branch_view_for_node.as_ref(),
5209            )? {
5210                let validation_key = format!("{key}:validation");
5211                let validation_handle = make_data_view_handle(
5212                    data_provider,
5213                    ctx,
5214                    node_plan,
5215                    scope,
5216                    binding,
5217                    &materialized,
5218                    &validation_view,
5219                )?;
5220                if data_views
5221                    .insert(validation_key.clone(), validation_view)
5222                    .is_some()
5223                {
5224                    return Err(DagMlError::RuntimeValidation(format!(
5225                        "node `{}` received duplicate validation data view `{validation_key}`",
5226                        node_plan.node_id
5227                    )));
5228                }
5229                if inputs
5230                    .insert(validation_key.clone(), validation_handle)
5231                    .is_some()
5232                {
5233                    return Err(DagMlError::RuntimeValidation(format!(
5234                        "node `{}` received duplicate validation data input `{validation_key}`",
5235                        node_plan.node_id
5236                    )));
5237                }
5238            }
5239        }
5240    }
5241    Ok(CollectedInputs {
5242        handles: inputs,
5243        data_views,
5244        prediction_inputs,
5245    })
5246}
5247
5248fn effective_node_plan_for_scope(node_plan: &NodePlan, scope: &PhaseScope) -> Result<NodePlan> {
5249    let Some(variant) = &scope.variant else {
5250        return Ok(node_plan.clone());
5251    };
5252    let params = variant.effective_params_for_node(&node_plan.node_id, &node_plan.params)?;
5253    if params == node_plan.params {
5254        return Ok(node_plan.clone());
5255    }
5256    let mut node_plan = node_plan.clone();
5257    node_plan.params = params;
5258    node_plan.params_fingerprint = stable_json_fingerprint(&node_plan.params)?;
5259    Ok(node_plan)
5260}
5261
5262fn incoming_training_oof_edges<'a>(
5263    plan: &'a ExecutionPlan,
5264    node_plan: &NodePlan,
5265    scope: &PhaseScope,
5266) -> Result<Vec<&'a EdgeSpec>> {
5267    if !scope.phase.is_training() {
5268        return Ok(Vec::new());
5269    }
5270    plan.graph_plan
5271        .graph
5272        .edges
5273        .iter()
5274        .filter(|edge| edge.target.node_id == node_plan.node_id && edge.contract.requires_oof)
5275        .map(|edge| {
5276            if edge.contract.kind != PortKind::Prediction {
5277                return Err(DagMlError::RuntimeValidation(format!(
5278                    "edge `{}.{}` -> `{}.{}` requires OOF but is not a prediction edge",
5279                    edge.source.node_id,
5280                    edge.source.port_name,
5281                    edge.target.node_id,
5282                    edge.target.port_name
5283                )));
5284            }
5285            Ok(edge)
5286        })
5287        .collect()
5288}
5289
/// What one training OOF prediction edge contributes to a node's inputs: the
/// prediction handle and its prediction input spec.
struct CollectedPredictionInput {
    /// Handle inserted into the task's input handles.
    handle: HandleRef,
    /// Spec inserted into the task's prediction inputs.
    spec: PredictionInputSpec,
}
5294
5295fn collect_oof_prediction_input(
5296    plan: &ExecutionPlan,
5297    edge: &EdgeSpec,
5298    ctx: &RunContext,
5299    scope: &PhaseScope,
5300    resources: &PhaseScopeResources<'_>,
5301) -> Result<CollectedPredictionInput> {
5302    if scope.phase == Phase::Refit {
5303        if let Some(contract) = replay_prediction_cache_contract_for_edge(resources, edge) {
5304            if contract.requirement.prediction_level != PredictionLevel::Sample {
5305                let source_plan = plan
5306                    .node_plans
5307                    .get(&edge.source.node_id)
5308                    .expect("edge source has a node plan");
5309                let handle = materialize_oof_prediction_handle(
5310                    plan,
5311                    edge,
5312                    ctx,
5313                    scope,
5314                    resources,
5315                    &source_plan.controller_id,
5316                )?;
5317                return Ok(CollectedPredictionInput {
5318                    handle,
5319                    spec: prediction_input_spec_from_requirement(&contract.requirement, scope)?,
5320                });
5321            }
5322        }
5323    }
5324    let source_plan = plan
5325        .node_plans
5326        .get(&edge.source.node_id)
5327        .expect("edge source has a node plan");
5328    let prediction_level = oof_prediction_level_for_source(source_plan);
5329    if prediction_level != PredictionLevel::Sample {
5330        let blocks = match scope.phase {
5331            Phase::FitCv => validate_fit_cv_aggregated_oof_edge(
5332                plan,
5333                edge,
5334                ctx,
5335                scope,
5336                resources,
5337                prediction_level,
5338            )?,
5339            Phase::Refit => {
5340                validate_refit_aggregated_oof_edge(plan, edge, ctx, resources, prediction_level)?
5341            }
5342            _ => Vec::new(),
5343        };
5344        let handle = materialize_oof_prediction_handle(
5345            plan,
5346            edge,
5347            ctx,
5348            scope,
5349            resources,
5350            &source_plan.controller_id,
5351        )?;
5352        return Ok(CollectedPredictionInput {
5353            handle,
5354            spec: aggregated_prediction_input_spec(edge, scope, prediction_level, &blocks)?,
5355        });
5356    }
5357    let blocks = match scope.phase {
5358        Phase::FitCv => validate_fit_cv_oof_edge(plan, edge, ctx, scope)?,
5359        Phase::Refit => validate_refit_oof_edge(plan, edge, ctx)?,
5360        _ => Vec::new(),
5361    };
5362    let handle = materialize_oof_prediction_handle(
5363        plan,
5364        edge,
5365        ctx,
5366        scope,
5367        resources,
5368        &source_plan.controller_id,
5369    )?;
5370    Ok(CollectedPredictionInput {
5371        handle,
5372        spec: prediction_input_spec(edge, scope, &blocks)?,
5373    })
5374}
5375
5376fn oof_prediction_level_for_source(source_plan: &NodePlan) -> PredictionLevel {
5377    source_plan
5378        .shape_plan
5379        .as_ref()
5380        .map(|shape_plan| shape_plan.aggregation_policy.aggregation_level)
5381        .unwrap_or(PredictionLevel::Sample)
5382}
5383
5384fn replay_prediction_cache_contract_for_edge<'a>(
5385    resources: &'a PhaseScopeResources<'_>,
5386    edge: &EdgeSpec,
5387) -> Option<&'a ReplayPredictionCacheContract> {
5388    let contracts = resources.prediction_cache_contracts?;
5389    let key = bundle_prediction_requirement_key(
5390        &edge.source.node_id,
5391        &edge.source.port_name,
5392        &edge.target.node_id,
5393        &edge.target.port_name,
5394    );
5395    contracts.get(&key)
5396}
5397
5398fn materialize_oof_prediction_handle(
5399    plan: &ExecutionPlan,
5400    edge: &EdgeSpec,
5401    ctx: &RunContext,
5402    scope: &PhaseScope,
5403    resources: &PhaseScopeResources<'_>,
5404    producer_controller_id: &ControllerId,
5405) -> Result<HandleRef> {
5406    if scope.phase == Phase::Refit {
5407        if let (Some(store), Some(bundle_id), Some(contracts)) = (
5408            resources.prediction_cache_store,
5409            resources.replay_bundle_id,
5410            resources.prediction_cache_contracts,
5411        ) {
5412            let key = bundle_prediction_requirement_key(
5413                &edge.source.node_id,
5414                &edge.source.port_name,
5415                &edge.target.node_id,
5416                &edge.target.port_name,
5417            );
5418            let contract = contracts.get(&key).ok_or_else(|| {
5419                DagMlError::RuntimeValidation(format!(
5420                    "replay prediction cache store cannot materialize missing requirement `{key}`"
5421                ))
5422            })?;
5423            let handle = store.materialize(&PredictionCacheMaterializationRequest {
5424                run_id: ctx.run_id.clone(),
5425                bundle_id: bundle_id.clone(),
5426                phase: scope.phase,
5427                variant_id: scope.variant_id.clone(),
5428                requirement: contract.requirement.clone(),
5429                cache: contract.cache.clone(),
5430                producer_controller_id: producer_controller_id.clone(),
5431            })?;
5432            if handle.kind != HandleKind::Prediction {
5433                return Err(DagMlError::RuntimeValidation(format!(
5434                    "prediction cache store materialized requirement `{key}` as {:?}",
5435                    handle.kind
5436                )));
5437            }
5438            if &handle.owner_controller != producer_controller_id {
5439                return Err(DagMlError::RuntimeValidation(format!(
5440                    "prediction cache store materialized requirement `{key}` for controller `{}`, expected `{}`",
5441                    handle.owner_controller, producer_controller_id
5442                )));
5443            }
5444            return Ok(handle);
5445        }
5446    }
5447    Ok(HandleRef {
5448        handle: deterministic_oof_handle(plan, edge, ctx, scope)?,
5449        kind: HandleKind::Prediction,
5450        owner_controller: producer_controller_id.clone(),
5451    })
5452}
5453
5454fn validate_fit_cv_oof_edge<'a>(
5455    plan: &ExecutionPlan,
5456    edge: &EdgeSpec,
5457    ctx: &'a RunContext,
5458    scope: &PhaseScope,
5459) -> Result<Vec<&'a PredictionBlock>> {
5460    let fold_id = scope.fold_id.as_ref().ok_or_else(|| {
5461        DagMlError::RuntimeValidation(format!(
5462            "edge `{}.{}` -> `{}.{}` requires OOF predictions but FIT_CV has no fold scope",
5463            edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
5464        ))
5465    })?;
5466    let blocks = ctx.prediction_store.find(
5467        Some(&edge.source.node_id),
5468        Some(&PredictionPartition::Validation),
5469        Some(fold_id),
5470    );
5471    if blocks.is_empty() {
5472        return Err(missing_oof_edge_error(edge, Some(fold_id)));
5473    }
5474    if edge.contract.requires_fold_alignment {
5475        let fold_set = required_fold_set_for_oof(plan, edge)?;
5476        validate_oof_blocks_match_fold(edge, fold_set, fold_id, &blocks)?;
5477    }
5478    Ok(blocks)
5479}
5480
5481fn validate_refit_oof_edge<'a>(
5482    plan: &ExecutionPlan,
5483    edge: &EdgeSpec,
5484    ctx: &'a RunContext,
5485) -> Result<Vec<&'a PredictionBlock>> {
5486    let blocks = ctx.prediction_store.find(
5487        Some(&edge.source.node_id),
5488        Some(&PredictionPartition::Validation),
5489        None,
5490    );
5491    if blocks.is_empty() {
5492        return Err(missing_oof_edge_error(edge, None));
5493    }
5494    if edge.contract.requires_fold_alignment {
5495        let fold_set = required_fold_set_for_oof(plan, edge)?;
5496        validate_oof_blocks_cover_fold_set(edge, fold_set, &blocks)?;
5497    }
5498    Ok(blocks)
5499}
5500
5501fn validate_fit_cv_aggregated_oof_edge<'a>(
5502    plan: &ExecutionPlan,
5503    edge: &EdgeSpec,
5504    ctx: &'a RunContext,
5505    scope: &PhaseScope,
5506    resources: &PhaseScopeResources<'_>,
5507    prediction_level: PredictionLevel,
5508) -> Result<Vec<&'a AggregatedPredictionBlock>> {
5509    let fold_id = scope.fold_id.as_ref().ok_or_else(|| {
5510        DagMlError::RuntimeValidation(format!(
5511            "edge `{}.{}` -> `{}.{}` requires aggregated OOF predictions but FIT_CV has no fold scope",
5512            edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
5513        ))
5514    })?;
5515    let blocks = ctx.aggregated_prediction_store.find(
5516        Some(&edge.source.node_id),
5517        Some(&PredictionPartition::Validation),
5518        Some(fold_id),
5519        Some(prediction_level),
5520    );
5521    if blocks.is_empty() {
5522        return Err(missing_oof_edge_error(edge, Some(fold_id)));
5523    }
5524    validate_aggregated_blocks_basic(edge, prediction_level, &blocks)?;
5525    if edge.contract.requires_fold_alignment {
5526        let fold_set = required_fold_set_for_oof(plan, edge)?;
5527        let relations = coordinator_relations_for_edge(plan, edge, resources)?;
5528        validate_aggregated_oof_blocks_match_fold(
5529            edge,
5530            fold_set,
5531            &relations,
5532            prediction_level,
5533            fold_id,
5534            &blocks,
5535        )?;
5536    }
5537    Ok(blocks)
5538}
5539
5540fn validate_refit_aggregated_oof_edge<'a>(
5541    plan: &ExecutionPlan,
5542    edge: &EdgeSpec,
5543    ctx: &'a RunContext,
5544    resources: &PhaseScopeResources<'_>,
5545    prediction_level: PredictionLevel,
5546) -> Result<Vec<&'a AggregatedPredictionBlock>> {
5547    let blocks = ctx.aggregated_prediction_store.find(
5548        Some(&edge.source.node_id),
5549        Some(&PredictionPartition::Validation),
5550        None,
5551        Some(prediction_level),
5552    );
5553    if blocks.is_empty() {
5554        return Err(missing_oof_edge_error(edge, None));
5555    }
5556    validate_aggregated_blocks_basic(edge, prediction_level, &blocks)?;
5557    if edge.contract.requires_fold_alignment {
5558        let fold_set = required_fold_set_for_oof(plan, edge)?;
5559        let relations = coordinator_relations_for_edge(plan, edge, resources)?;
5560        validate_aggregated_oof_blocks_cover_fold_set(
5561            edge,
5562            fold_set,
5563            &relations,
5564            prediction_level,
5565            &blocks,
5566        )?;
5567    }
5568    Ok(blocks)
5569}
5570
5571fn validate_aggregated_blocks_basic(
5572    edge: &EdgeSpec,
5573    prediction_level: PredictionLevel,
5574    blocks: &[&AggregatedPredictionBlock],
5575) -> Result<()> {
5576    for block in blocks {
5577        block.validate_shape()?;
5578        if block.partition != PredictionPartition::Validation {
5579            return Err(DagMlError::RuntimeValidation(format!(
5580                "edge `{}.{}` -> `{}.{}` selected non-validation aggregated predictions",
5581                edge.source.node_id,
5582                edge.source.port_name,
5583                edge.target.node_id,
5584                edge.target.port_name
5585            )));
5586        }
5587        if block.level != prediction_level {
5588            return Err(DagMlError::RuntimeValidation(format!(
5589                "edge `{}.{}` -> `{}.{}` selected {:?} aggregated predictions, expected {:?}",
5590                edge.source.node_id,
5591                edge.source.port_name,
5592                edge.target.node_id,
5593                edge.target.port_name,
5594                block.level,
5595                prediction_level
5596            )));
5597        }
5598    }
5599    Ok(())
5600}
5601
5602fn prediction_input_spec(
5603    edge: &EdgeSpec,
5604    scope: &PhaseScope,
5605    blocks: &[&PredictionBlock],
5606) -> Result<PredictionInputSpec> {
5607    let sample_ids = collect_unique_oof_samples(edge, blocks)?
5608        .into_iter()
5609        .collect::<Vec<_>>();
5610    let fold_ids = blocks
5611        .iter()
5612        .filter_map(|block| block.fold_id.clone())
5613        .collect::<BTreeSet<_>>()
5614        .into_iter()
5615        .collect::<Vec<_>>();
5616    let mut prediction_width = None;
5617    let mut target_names = None;
5618    for block in blocks {
5619        let width = block.validate_shape()?;
5620        let block_target_names = if block.target_names.is_empty() {
5621            (0..width)
5622                .map(|index| format!("p{index}"))
5623                .collect::<Vec<_>>()
5624        } else {
5625            block.target_names.clone()
5626        };
5627        if prediction_width.is_some_and(|expected| expected != width) {
5628            return Err(DagMlError::RuntimeValidation(format!(
5629                "edge `{}.{}` -> `{}.{}` OOF prediction width is not stable across folds",
5630                edge.source.node_id,
5631                edge.source.port_name,
5632                edge.target.node_id,
5633                edge.target.port_name
5634            )));
5635        }
5636        if target_names
5637            .as_ref()
5638            .is_some_and(|expected| expected != &block_target_names)
5639        {
5640            return Err(DagMlError::RuntimeValidation(format!(
5641                "edge `{}.{}` -> `{}.{}` OOF target names are not stable across folds",
5642                edge.source.node_id,
5643                edge.source.port_name,
5644                edge.target.node_id,
5645                edge.target.port_name
5646            )));
5647        }
5648        prediction_width = Some(width);
5649        target_names = Some(block_target_names);
5650    }
5651    Ok(PredictionInputSpec {
5652        producer_node: edge.source.node_id.clone(),
5653        source_port: edge.source.port_name.clone(),
5654        target_port: edge.target.port_name.clone(),
5655        partition: PredictionPartition::Validation,
5656        prediction_level: PredictionLevel::Sample,
5657        fold_id: scope.fold_id.clone(),
5658        fold_ids,
5659        unit_ids: sample_ids
5660            .iter()
5661            .cloned()
5662            .map(PredictionUnitId::Sample)
5663            .collect(),
5664        sample_ids,
5665        prediction_width: prediction_width.unwrap_or_default(),
5666        target_names: target_names.unwrap_or_default(),
5667    })
5668}
5669
5670fn aggregated_prediction_input_spec(
5671    edge: &EdgeSpec,
5672    scope: &PhaseScope,
5673    prediction_level: PredictionLevel,
5674    blocks: &[&AggregatedPredictionBlock],
5675) -> Result<PredictionInputSpec> {
5676    let unit_ids = collect_unique_aggregated_oof_units(edge, prediction_level, blocks)?
5677        .into_iter()
5678        .collect::<Vec<_>>();
5679    let fold_ids = blocks
5680        .iter()
5681        .filter_map(|block| block.fold_id.clone())
5682        .collect::<BTreeSet<_>>()
5683        .into_iter()
5684        .collect::<Vec<_>>();
5685    let mut prediction_width = None;
5686    let mut target_names = None;
5687    for block in blocks {
5688        let width = block.validate_shape()?;
5689        let block_target_names = if block.target_names.is_empty() {
5690            (0..width)
5691                .map(|index| format!("p{index}"))
5692                .collect::<Vec<_>>()
5693        } else {
5694            block.target_names.clone()
5695        };
5696        if prediction_width.is_some_and(|expected| expected != width) {
5697            return Err(DagMlError::RuntimeValidation(format!(
5698                "edge `{}.{}` -> `{}.{}` aggregated OOF prediction width is not stable across folds",
5699                edge.source.node_id,
5700                edge.source.port_name,
5701                edge.target.node_id,
5702                edge.target.port_name
5703            )));
5704        }
5705        if target_names
5706            .as_ref()
5707            .is_some_and(|expected| expected != &block_target_names)
5708        {
5709            return Err(DagMlError::RuntimeValidation(format!(
5710                "edge `{}.{}` -> `{}.{}` aggregated OOF target names are not stable across folds",
5711                edge.source.node_id,
5712                edge.source.port_name,
5713                edge.target.node_id,
5714                edge.target.port_name
5715            )));
5716        }
5717        prediction_width = Some(width);
5718        target_names = Some(block_target_names);
5719    }
5720    Ok(PredictionInputSpec {
5721        producer_node: edge.source.node_id.clone(),
5722        source_port: edge.source.port_name.clone(),
5723        target_port: edge.target.port_name.clone(),
5724        partition: PredictionPartition::Validation,
5725        prediction_level,
5726        fold_id: scope.fold_id.clone(),
5727        fold_ids,
5728        unit_ids,
5729        sample_ids: Vec::new(),
5730        prediction_width: prediction_width.unwrap_or_default(),
5731        target_names: target_names.unwrap_or_default(),
5732    })
5733}
5734
5735fn prediction_input_spec_from_requirement(
5736    requirement: &BundlePredictionRequirement,
5737    scope: &PhaseScope,
5738) -> Result<PredictionInputSpec> {
5739    requirement.validate()?;
5740    Ok(PredictionInputSpec {
5741        producer_node: requirement.producer_node.clone(),
5742        source_port: requirement.source_port.clone(),
5743        target_port: requirement.target_port.clone(),
5744        partition: requirement.partition.clone(),
5745        prediction_level: requirement.prediction_level,
5746        fold_id: scope.fold_id.clone(),
5747        fold_ids: requirement.fold_ids.clone(),
5748        unit_ids: requirement.unit_ids.clone(),
5749        sample_ids: requirement.sample_ids.clone(),
5750        prediction_width: requirement.prediction_width,
5751        target_names: requirement.target_names.clone(),
5752    })
5753}
5754
5755fn missing_oof_edge_error(edge: &EdgeSpec, fold_id: Option<&FoldId>) -> DagMlError {
5756    DagMlError::RuntimeValidation(format!(
5757        "edge `{}.{}` -> `{}.{}` requires OOF validation predictions from `{}`{}",
5758        edge.source.node_id,
5759        edge.source.port_name,
5760        edge.target.node_id,
5761        edge.target.port_name,
5762        edge.source.node_id,
5763        fold_id
5764            .map(|fold_id| format!(" for fold `{fold_id}`"))
5765            .unwrap_or_default()
5766    ))
5767}
5768
5769fn required_fold_set_for_oof<'a>(plan: &'a ExecutionPlan, edge: &EdgeSpec) -> Result<&'a FoldSet> {
5770    plan.fold_set.as_ref().ok_or_else(|| {
5771        DagMlError::RuntimeValidation(format!(
5772            "edge `{}.{}` -> `{}.{}` requires fold-aligned OOF predictions but the plan has no fold set",
5773            edge.source.node_id,
5774            edge.source.port_name,
5775            edge.target.node_id,
5776            edge.target.port_name
5777        ))
5778    })
5779}
5780
5781fn validate_oof_blocks_match_fold(
5782    edge: &EdgeSpec,
5783    fold_set: &FoldSet,
5784    fold_id: &FoldId,
5785    blocks: &[&PredictionBlock],
5786) -> Result<()> {
5787    let fold = fold_set
5788        .folds
5789        .iter()
5790        .find(|fold| &fold.fold_id == fold_id)
5791        .ok_or_else(|| {
5792            DagMlError::RuntimeValidation(format!(
5793                "edge `{}.{}` -> `{}.{}` references unknown fold `{fold_id}`",
5794                edge.source.node_id,
5795                edge.source.port_name,
5796                edge.target.node_id,
5797                edge.target.port_name
5798            ))
5799        })?;
5800    let actual = collect_unique_oof_samples(edge, blocks)?;
5801    let expected = fold
5802        .validation_sample_ids
5803        .iter()
5804        .cloned()
5805        .collect::<BTreeSet<_>>();
5806    if actual != expected {
5807        return Err(DagMlError::RuntimeValidation(format!(
5808            "edge `{}.{}` -> `{}.{}` OOF predictions do not match validation samples for fold `{fold_id}`",
5809            edge.source.node_id,
5810            edge.source.port_name,
5811            edge.target.node_id,
5812            edge.target.port_name
5813        )));
5814    }
5815    Ok(())
5816}
5817
5818fn validate_oof_blocks_cover_fold_set(
5819    edge: &EdgeSpec,
5820    fold_set: &FoldSet,
5821    blocks: &[&PredictionBlock],
5822) -> Result<()> {
5823    let folds = fold_set
5824        .folds
5825        .iter()
5826        .map(|fold| (&fold.fold_id, fold))
5827        .collect::<BTreeMap<_, _>>();
5828    let mut all_samples = BTreeSet::new();
5829    for block in blocks {
5830        let fold_id = block.fold_id.as_ref().ok_or_else(|| {
5831            DagMlError::RuntimeValidation(format!(
5832                "edge `{}.{}` -> `{}.{}` has OOF predictions without a fold id",
5833                edge.source.node_id,
5834                edge.source.port_name,
5835                edge.target.node_id,
5836                edge.target.port_name
5837            ))
5838        })?;
5839        let fold = folds.get(fold_id).ok_or_else(|| {
5840            DagMlError::RuntimeValidation(format!(
5841                "edge `{}.{}` -> `{}.{}` references unknown fold `{fold_id}`",
5842                edge.source.node_id,
5843                edge.source.port_name,
5844                edge.target.node_id,
5845                edge.target.port_name
5846            ))
5847        })?;
5848        let block_samples = collect_unique_oof_samples(edge, &[*block])?;
5849        let expected = fold
5850            .validation_sample_ids
5851            .iter()
5852            .cloned()
5853            .collect::<BTreeSet<_>>();
5854        if block_samples != expected {
5855            return Err(DagMlError::RuntimeValidation(format!(
5856                "edge `{}.{}` -> `{}.{}` OOF predictions do not match validation samples for fold `{fold_id}`",
5857                edge.source.node_id,
5858                edge.source.port_name,
5859                edge.target.node_id,
5860                edge.target.port_name
5861            )));
5862        }
5863        for sample_id in block_samples {
5864            if !all_samples.insert(sample_id.clone()) {
5865                return Err(DagMlError::RuntimeValidation(format!(
5866                    "edge `{}.{}` -> `{}.{}` has duplicate OOF prediction for sample `{sample_id}`",
5867                    edge.source.node_id,
5868                    edge.source.port_name,
5869                    edge.target.node_id,
5870                    edge.target.port_name
5871                )));
5872            }
5873        }
5874    }
5875    let expected_all = fold_set.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
5876    if all_samples != expected_all {
5877        return Err(DagMlError::RuntimeValidation(format!(
5878            "edge `{}.{}` -> `{}.{}` OOF predictions do not cover the refit sample universe",
5879            edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
5880        )));
5881    }
5882    Ok(())
5883}
5884
5885fn validate_aggregated_oof_blocks_match_fold(
5886    edge: &EdgeSpec,
5887    fold_set: &FoldSet,
5888    relations: &SampleRelationSet,
5889    prediction_level: PredictionLevel,
5890    fold_id: &FoldId,
5891    blocks: &[&AggregatedPredictionBlock],
5892) -> Result<()> {
5893    let fold = fold_set
5894        .folds
5895        .iter()
5896        .find(|fold| &fold.fold_id == fold_id)
5897        .ok_or_else(|| {
5898            DagMlError::RuntimeValidation(format!(
5899                "edge `{}.{}` -> `{}.{}` references unknown fold `{fold_id}`",
5900                edge.source.node_id,
5901                edge.source.port_name,
5902                edge.target.node_id,
5903                edge.target.port_name
5904            ))
5905        })?;
5906    validate_aggregated_fold_unit_safety(edge, relations, prediction_level, fold)?;
5907    for block in blocks {
5908        if block.fold_id.as_ref() != Some(fold_id) {
5909            return Err(DagMlError::RuntimeValidation(format!(
5910                "edge `{}.{}` -> `{}.{}` selected aggregated OOF predictions outside fold `{fold_id}`",
5911                edge.source.node_id,
5912                edge.source.port_name,
5913                edge.target.node_id,
5914                edge.target.port_name
5915            )));
5916        }
5917    }
5918    let actual = collect_unique_aggregated_oof_units(edge, prediction_level, blocks)?;
5919    let expected = expected_prediction_units_for_samples(
5920        edge,
5921        relations,
5922        prediction_level,
5923        &fold.validation_sample_ids,
5924    )?;
5925    if actual != expected {
5926        return Err(DagMlError::RuntimeValidation(format!(
5927            "edge `{}.{}` -> `{}.{}` aggregated OOF predictions do not match {:?} validation units for fold `{fold_id}`",
5928            edge.source.node_id,
5929            edge.source.port_name,
5930            edge.target.node_id,
5931            edge.target.port_name,
5932            prediction_level
5933        )));
5934    }
5935    Ok(())
5936}
5937
5938fn validate_aggregated_oof_blocks_cover_fold_set(
5939    edge: &EdgeSpec,
5940    fold_set: &FoldSet,
5941    relations: &SampleRelationSet,
5942    prediction_level: PredictionLevel,
5943    blocks: &[&AggregatedPredictionBlock],
5944) -> Result<()> {
5945    let folds = fold_set
5946        .folds
5947        .iter()
5948        .map(|fold| (fold.fold_id.clone(), fold))
5949        .collect::<BTreeMap<_, _>>();
5950    let mut blocks_by_fold = BTreeMap::<FoldId, Vec<&AggregatedPredictionBlock>>::new();
5951    for block in blocks {
5952        let fold_id = block.fold_id.as_ref().ok_or_else(|| {
5953            DagMlError::RuntimeValidation(format!(
5954                "edge `{}.{}` -> `{}.{}` has aggregated OOF predictions without a fold id",
5955                edge.source.node_id,
5956                edge.source.port_name,
5957                edge.target.node_id,
5958                edge.target.port_name
5959            ))
5960        })?;
5961        if !folds.contains_key(fold_id) {
5962            return Err(DagMlError::RuntimeValidation(format!(
5963                "edge `{}.{}` -> `{}.{}` references unknown fold `{fold_id}`",
5964                edge.source.node_id,
5965                edge.source.port_name,
5966                edge.target.node_id,
5967                edge.target.port_name
5968            )));
5969        }
5970        blocks_by_fold
5971            .entry(fold_id.clone())
5972            .or_default()
5973            .push(*block);
5974    }
5975    for fold_id in folds.keys() {
5976        if !blocks_by_fold.contains_key(fold_id) {
5977            return Err(DagMlError::RuntimeValidation(format!(
5978                "edge `{}.{}` -> `{}.{}` is missing aggregated OOF predictions for fold `{fold_id}`",
5979                edge.source.node_id,
5980                edge.source.port_name,
5981                edge.target.node_id,
5982                edge.target.port_name
5983            )));
5984        }
5985    }
5986
5987    let mut all_units = BTreeSet::new();
5988    for (fold_id, fold_blocks) in blocks_by_fold {
5989        let fold = folds.get(&fold_id).expect("fold id was validated above");
5990        validate_aggregated_fold_unit_safety(edge, relations, prediction_level, fold)?;
5991        let fold_units = collect_unique_aggregated_oof_units(edge, prediction_level, &fold_blocks)?;
5992        let expected = expected_prediction_units_for_samples(
5993            edge,
5994            relations,
5995            prediction_level,
5996            &fold.validation_sample_ids,
5997        )?;
5998        if fold_units != expected {
5999            return Err(DagMlError::RuntimeValidation(format!(
6000                "edge `{}.{}` -> `{}.{}` aggregated OOF predictions do not match {:?} validation units for fold `{fold_id}`",
6001                edge.source.node_id,
6002                edge.source.port_name,
6003                edge.target.node_id,
6004                edge.target.port_name,
6005                prediction_level
6006            )));
6007        }
6008        for unit_id in fold_units {
6009            if !all_units.insert(unit_id.clone()) {
6010                return Err(DagMlError::RuntimeValidation(format!(
6011                    "edge `{}.{}` -> `{}.{}` has duplicate aggregated OOF prediction for unit `{unit_id}`",
6012                    edge.source.node_id,
6013                    edge.source.port_name,
6014                    edge.target.node_id,
6015                    edge.target.port_name
6016                )));
6017            }
6018        }
6019    }
6020
6021    let expected_all = expected_prediction_units_for_samples(
6022        edge,
6023        relations,
6024        prediction_level,
6025        &fold_set.sample_ids,
6026    )?;
6027    if all_units != expected_all {
6028        return Err(DagMlError::RuntimeValidation(format!(
6029            "edge `{}.{}` -> `{}.{}` aggregated OOF predictions do not cover the refit {:?} unit universe",
6030            edge.source.node_id,
6031            edge.source.port_name,
6032            edge.target.node_id,
6033            edge.target.port_name,
6034            prediction_level
6035        )));
6036    }
6037    Ok(())
6038}
6039
6040fn validate_aggregated_fold_unit_safety(
6041    edge: &EdgeSpec,
6042    relations: &SampleRelationSet,
6043    prediction_level: PredictionLevel,
6044    fold: &FoldAssignment,
6045) -> Result<()> {
6046    let train_units = expected_prediction_units_for_samples(
6047        edge,
6048        relations,
6049        prediction_level,
6050        &fold.train_sample_ids,
6051    )?;
6052    let validation_units = expected_prediction_units_for_samples(
6053        edge,
6054        relations,
6055        prediction_level,
6056        &fold.validation_sample_ids,
6057    )?;
6058    if let Some(unit_id) = train_units.intersection(&validation_units).next() {
6059        return Err(DagMlError::RuntimeValidation(format!(
6060            "edge `{}.{}` -> `{}.{}` fold `{}` has {:?} unit `{unit_id}` in both train and validation partitions",
6061            edge.source.node_id,
6062            edge.source.port_name,
6063            edge.target.node_id,
6064            edge.target.port_name,
6065            fold.fold_id,
6066            prediction_level
6067        )));
6068    }
6069    Ok(())
6070}
6071
6072fn collect_unique_oof_samples(
6073    edge: &EdgeSpec,
6074    blocks: &[&PredictionBlock],
6075) -> Result<BTreeSet<SampleId>> {
6076    let mut samples = BTreeSet::new();
6077    for block in blocks {
6078        if block.partition != PredictionPartition::Validation {
6079            return Err(DagMlError::RuntimeValidation(format!(
6080                "edge `{}.{}` -> `{}.{}` selected non-validation predictions",
6081                edge.source.node_id,
6082                edge.source.port_name,
6083                edge.target.node_id,
6084                edge.target.port_name
6085            )));
6086        }
6087        for sample_id in &block.sample_ids {
6088            if !samples.insert(sample_id.clone()) {
6089                return Err(DagMlError::RuntimeValidation(format!(
6090                    "edge `{}.{}` -> `{}.{}` has duplicate OOF prediction for sample `{sample_id}`",
6091                    edge.source.node_id,
6092                    edge.source.port_name,
6093                    edge.target.node_id,
6094                    edge.target.port_name
6095                )));
6096            }
6097        }
6098    }
6099    Ok(samples)
6100}
6101
6102fn collect_unique_aggregated_oof_units(
6103    edge: &EdgeSpec,
6104    prediction_level: PredictionLevel,
6105    blocks: &[&AggregatedPredictionBlock],
6106) -> Result<BTreeSet<PredictionUnitId>> {
6107    let mut unit_ids = BTreeSet::new();
6108    for block in blocks {
6109        block.validate_shape()?;
6110        if block.partition != PredictionPartition::Validation {
6111            return Err(DagMlError::RuntimeValidation(format!(
6112                "edge `{}.{}` -> `{}.{}` selected non-validation aggregated predictions",
6113                edge.source.node_id,
6114                edge.source.port_name,
6115                edge.target.node_id,
6116                edge.target.port_name
6117            )));
6118        }
6119        if block.level != prediction_level {
6120            return Err(DagMlError::RuntimeValidation(format!(
6121                "edge `{}.{}` -> `{}.{}` selected {:?} aggregated predictions, expected {:?}",
6122                edge.source.node_id,
6123                edge.source.port_name,
6124                edge.target.node_id,
6125                edge.target.port_name,
6126                block.level,
6127                prediction_level
6128            )));
6129        }
6130        for unit_id in &block.unit_ids {
6131            if !unit_ids.insert(unit_id.clone()) {
6132                return Err(DagMlError::RuntimeValidation(format!(
6133                    "edge `{}.{}` -> `{}.{}` has duplicate aggregated OOF prediction for unit `{unit_id}`",
6134                    edge.source.node_id,
6135                    edge.source.port_name,
6136                    edge.target.node_id,
6137                    edge.target.port_name
6138                )));
6139            }
6140        }
6141    }
6142    Ok(unit_ids)
6143}
6144
6145fn expected_prediction_units_for_samples(
6146    edge: &EdgeSpec,
6147    relations: &SampleRelationSet,
6148    prediction_level: PredictionLevel,
6149    sample_ids: &[SampleId],
6150) -> Result<BTreeSet<PredictionUnitId>> {
6151    sample_ids
6152        .iter()
6153        .map(|sample_id| prediction_unit_for_sample(edge, relations, prediction_level, sample_id))
6154        .collect()
6155}
6156
6157fn prediction_unit_for_sample(
6158    edge: &EdgeSpec,
6159    relations: &SampleRelationSet,
6160    prediction_level: PredictionLevel,
6161    sample_id: &SampleId,
6162) -> Result<PredictionUnitId> {
6163    match prediction_level {
6164        PredictionLevel::Sample => Ok(PredictionUnitId::Sample(sample_id.clone())),
6165        PredictionLevel::Target => relations
6166            .target_for_sample(sample_id)
6167            .cloned()
6168            .map(PredictionUnitId::Target)
6169            .ok_or_else(|| {
6170                DagMlError::RuntimeValidation(format!(
6171                    "edge `{}.{}` -> `{}.{}` needs target-level OOF predictions but sample `{sample_id}` has no target relation",
6172                    edge.source.node_id,
6173                    edge.source.port_name,
6174                    edge.target.node_id,
6175                    edge.target.port_name
6176                ))
6177            }),
6178        PredictionLevel::Group => relations
6179            .group_for_sample(sample_id)
6180            .cloned()
6181            .map(PredictionUnitId::Group)
6182            .ok_or_else(|| {
6183                DagMlError::RuntimeValidation(format!(
6184                    "edge `{}.{}` -> `{}.{}` needs group-level OOF predictions but sample `{sample_id}` has no group relation",
6185                    edge.source.node_id,
6186                    edge.source.port_name,
6187                    edge.target.node_id,
6188                    edge.target.port_name
6189                ))
6190            }),
6191        PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(format!(
6192            "edge `{}.{}` -> `{}.{}` cannot consume observation-level OOF predictions from sample folds",
6193            edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
6194        ))),
6195    }
6196}
6197
6198fn deterministic_oof_handle(
6199    plan: &ExecutionPlan,
6200    edge: &EdgeSpec,
6201    ctx: &RunContext,
6202    scope: &PhaseScope,
6203) -> Result<u64> {
6204    let fingerprint = stable_json_fingerprint(&(
6205        &plan.id,
6206        &ctx.run_id,
6207        &edge.source.node_id,
6208        &edge.source.port_name,
6209        &edge.target.node_id,
6210        &edge.target.port_name,
6211        scope.phase,
6212        &scope.variant_id,
6213        &scope.fold_id,
6214    ))?;
6215    Ok(u64::from_str_radix(&fingerprint[..16], 16).expect("sha256 hex prefix should fit into u64"))
6216}
6217
/// Inputs gathered for a single node task, split by kind.
///
/// NOTE(review): construction and consumption happen outside this view. The
/// `String` keys are presumably input names (possibly in the `data_view_key`
/// format for `data_views`); confirm against the callers.
struct CollectedInputs {
    /// Runtime handles for the task's inputs.
    handles: BTreeMap<String, HandleRef>,
    /// Data-provider view specs for the task's data inputs.
    data_views: BTreeMap<String, DataProviderViewSpec>,
    /// Specs describing the task's prediction inputs.
    prediction_inputs: BTreeMap<String, PredictionInputSpec>,
}
6223
/// Key under which the data view for input `input_name` is stored: the input
/// name prefixed with `data:`.
fn data_view_key(input_name: &str) -> String {
    ["data:", input_name].concat()
}
6227
/// Key for the fold-validation companion of a view: the given name suffixed
/// with `:validation`.
fn validation_data_view_key(input_name: &str) -> String {
    let mut key = String::with_capacity(input_name.len() + ":validation".len());
    key.push_str(input_name);
    key.push_str(":validation");
    key
}
6231
6232fn derive_output_data_views(
6233    plan: &ExecutionPlan,
6234    task: &NodeTask,
6235    result: &NodeResult,
6236) -> Result<BTreeMap<String, DataProviderViewSpec>> {
6237    let node = plan
6238        .graph_plan
6239        .graph
6240        .nodes
6241        .iter()
6242        .find(|node| node.id == task.node_plan.node_id)
6243        .expect("execution plan was validated");
6244    let mut views = BTreeMap::new();
6245    for port in node
6246        .ports
6247        .outputs
6248        .iter()
6249        .filter(|port| port.kind == PortKind::Data)
6250    {
6251        let Some(handle) = result.outputs.get(&port.name) else {
6252            continue;
6253        };
6254        if !matches!(handle.kind, HandleKind::Data | HandleKind::DataView) {
6255            return Err(DagMlError::RuntimeValidation(format!(
6256                "node `{}` emitted data output `{}` with non-data/data-view handle kind {:?}",
6257                task.node_plan.node_id, port.name, handle.kind
6258            )));
6259        }
6260        if let Some(view) = primary_output_data_view(task) {
6261            views.insert(
6262                port.name.clone(),
6263                output_data_view_for_port(task, result, &port.name, view)?,
6264            );
6265        }
6266        if let Some(validation_view) = validation_output_data_view(task) {
6267            views.insert(
6268                validation_data_view_key(&port.name),
6269                output_data_view_for_port(task, result, &port.name, validation_view)?,
6270            );
6271        }
6272    }
6273    Ok(views)
6274}
6275
/// Derive the data view advertised for output `port_name` of the node behind
/// `task`, starting from one of the task's input views (`base_view`).
///
/// The returned view is a clone of `base_view`. Its `extra` entry under
/// `DATA_OUTPUT_PROVENANCE_KEY` is replaced by fresh provenance for this
/// node's output, which records:
/// * producer node, port and phase;
/// * variant and fold;
/// * the shape deltas in `result` attributed to this node.
///
/// When the node plan carries a shape plan, the provenance also gets the
/// shape-plan and aggregation-policy fingerprints, the feature namespace and
/// the feature schema fingerprint. The representation fields and the
/// relation-delta fingerprint are left `None` here.
///
/// # Errors
///
/// Returns `DagMlError::RuntimeValidation` if the inherited upstream
/// provenance does not deserialize or fails validation. Propagates errors from
/// fingerprinting, provenance/view validation and JSON serialization.
fn output_data_view_for_port(
    task: &NodeTask,
    result: &NodeResult,
    port_name: &str,
    base_view: &DataProviderViewSpec,
) -> Result<DataProviderViewSpec> {
    let mut view = base_view.clone();
    // Provenance inherited from the upstream view is removed and only checked
    // for validity. The parsed value is not reused, because this node's own
    // provenance replaces it below.
    if let Some(upstream_provenance) = view.extra.remove(DATA_OUTPUT_PROVENANCE_KEY) {
        let provenance: DataOutputProvenance =
            serde_json::from_value(upstream_provenance).map_err(|error| {
                DagMlError::RuntimeValidation(format!(
                    "node `{}` cannot propagate data output `{port_name}` because upstream data output provenance is invalid JSON: {error}",
                    task.node_plan.node_id
                ))
            })?;
        provenance.validate().map_err(|error| {
            DagMlError::RuntimeValidation(format!(
                "node `{}` cannot propagate data output `{port_name}` because upstream data output provenance is invalid: {error}",
                task.node_plan.node_id
            ))
        })?;
    }
    // Keep only the shape deltas attributed to this node.
    let shape_deltas = result
        .shape_deltas
        .iter()
        .filter(|delta| delta.node_id == task.node_plan.node_id)
        .cloned()
        .collect::<Vec<_>>();
    let mut provenance = DataOutputProvenance {
        schema_version: DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION,
        producer_node: task.node_plan.node_id.clone(),
        producer_port: port_name.to_string(),
        producer_phase: task.phase,
        variant_id: task.variant_id.clone(),
        fold_id: task.fold_id.clone(),
        shape_plan_fingerprint: None,
        aggregation_policy_fingerprint: None,
        feature_namespace: None,
        feature_schema_fingerprint: None,
        representation_plan: None,
        representation_replay_manifest: None,
        representation_compatibility: None,
        relation_delta_fingerprint: None,
        shape_deltas,
    };
    // Shape-related provenance is only available when the node plan has a
    // shape plan.
    if let Some(shape_plan) = &task.node_plan.shape_plan {
        provenance.shape_plan_fingerprint = Some(stable_json_fingerprint(shape_plan)?);
        provenance.aggregation_policy_fingerprint =
            Some(stable_json_fingerprint(&shape_plan.aggregation_policy)?);
        provenance.feature_namespace = shape_plan.feature_namespace.clone();
        provenance.feature_schema_fingerprint =
            output_feature_schema_fingerprint(shape_plan, result);
    }
    provenance.validate()?;

    view.extra.insert(
        DATA_OUTPUT_PROVENANCE_KEY.to_string(),
        serde_json::to_value(provenance)?,
    );
    view.validate()?;
    Ok(view)
}
6338
6339fn output_feature_schema_fingerprint(
6340    shape_plan: &crate::policy::DataModelShapePlan,
6341    result: &NodeResult,
6342) -> Option<String> {
6343    result
6344        .shape_deltas
6345        .iter()
6346        .rev()
6347        .find(|delta| delta.kind == ShapeDeltaKind::Feature)
6348        .map(|delta| delta.after_fingerprint.clone())
6349        .or_else(|| shape_plan.feature_schema_fingerprint.clone())
6350}
6351
6352fn primary_output_data_view(task: &NodeTask) -> Option<&DataProviderViewSpec> {
6353    task.data_views
6354        .values()
6355        .find(|view| view.partition != DataRequestPartition::FoldValidation)
6356        .or_else(|| task.data_views.values().next())
6357}
6358
6359fn validation_output_data_view(task: &NodeTask) -> Option<&DataProviderViewSpec> {
6360    task.data_views
6361        .values()
6362        .find(|view| view.partition == DataRequestPartition::FoldValidation)
6363}
6364
6365fn make_data_view_handle(
6366    data_provider: &dyn RuntimeDataProvider,
6367    ctx: &RunContext,
6368    node_plan: &NodePlan,
6369    scope: &PhaseScope,
6370    binding: &DataBinding,
6371    data_handle: &HandleRef,
6372    view: &DataProviderViewSpec,
6373) -> Result<HandleRef> {
6374    view.validate()?;
6375    data_provider.make_view(&DataViewRequest {
6376        run_id: ctx.run_id.clone(),
6377        node_id: node_plan.node_id.clone(),
6378        input_name: binding.input_name.clone(),
6379        phase: scope.phase,
6380        variant_id: scope.variant_id.clone(),
6381        fold_id: scope.fold_id.clone(),
6382        binding: binding.clone(),
6383        data_handle: data_handle.clone(),
6384        view: view.clone(),
6385    })
6386}
6387
6388fn data_view_for_scope(
6389    binding: &DataBinding,
6390    fold_set: Option<&FoldSet>,
6391    scope: &PhaseScope,
6392    branch_view: Option<&crate::data::BranchViewPlan>,
6393) -> Result<DataProviderViewSpec> {
6394    let partition = data_partition_for_scope(binding, scope);
6395    data_view_for_partition(binding, fold_set, scope, partition, branch_view)
6396}
6397
6398fn validation_data_view_for_scope(
6399    binding: &DataBinding,
6400    fold_set: Option<&FoldSet>,
6401    scope: &PhaseScope,
6402    branch_view: Option<&crate::data::BranchViewPlan>,
6403) -> Result<Option<DataProviderViewSpec>> {
6404    if scope.phase != Phase::FitCv || scope.fold_id.is_none() {
6405        return Ok(None);
6406    }
6407    let partition = binding.view_policy.predict_partition;
6408    if partition == data_partition_for_scope(binding, scope) {
6409        return Ok(None);
6410    }
6411    data_view_for_partition(binding, fold_set, scope, partition, branch_view).map(Some)
6412}
6413
6414/// Extract the `BranchViewPlan` that the DSL compiler stashed in the graph
6415/// node's metadata under `dsl_branch_view_plan`, if any. Returns `None` when
6416/// the node was not produced by a separation branch; returns `Err` when the
6417/// stored value cannot be deserialized as a `BranchViewPlan`. This is the
6418/// scheduler-side bridge that activates the BranchView wiring at runtime;
6419/// without it, every `DataProviderViewSpec.branch_view` would stay `None`
6420/// even when the DSL compiled `branch_view_plans` into the campaign.
6421fn branch_view_from_node_metadata(
6422    plan: &ExecutionPlan,
6423    node_id: &NodeId,
6424) -> Result<Option<crate::data::BranchViewPlan>> {
6425    let node = match plan
6426        .graph_plan
6427        .graph
6428        .nodes
6429        .iter()
6430        .find(|node| &node.id == node_id)
6431    {
6432        Some(node) => node,
6433        None => return Ok(None),
6434    };
6435    let Some(value) = node.metadata.get("dsl_branch_view_plan") else {
6436        return Ok(None);
6437    };
6438    let plan: crate::data::BranchViewPlan =
6439        serde_json::from_value(value.clone()).map_err(|error| {
6440            DagMlError::RuntimeValidation(format!(
6441                "node `{node_id}` carries malformed `dsl_branch_view_plan` metadata: {error}"
6442            ))
6443        })?;
6444    plan.validate()?;
6445    Ok(Some(plan))
6446}
6447
6448fn data_view_for_partition(
6449    binding: &DataBinding,
6450    fold_set: Option<&FoldSet>,
6451    scope: &PhaseScope,
6452    partition: DataRequestPartition,
6453    branch_view: Option<&crate::data::BranchViewPlan>,
6454) -> Result<DataProviderViewSpec> {
6455    let fold = fold_for_scope(fold_set, scope.fold_id.as_ref())?;
6456    let sample_ids = sample_ids_for_partition(partition, fold_set, fold);
6457    if binding.view_policy.require_sample_ids
6458        && matches!(
6459            partition,
6460            DataRequestPartition::FoldTrain | DataRequestPartition::FoldValidation
6461        )
6462        && scope.fold_id.is_some()
6463        && sample_ids.as_ref().is_none_or(Vec::is_empty)
6464    {
6465        return Err(DagMlError::RuntimeValidation(format!(
6466            "data binding `{}` on `{}` requires sample ids for {:?}",
6467            binding.input_name, binding.node_id, partition
6468        )));
6469    }
6470    let include_augmented = match partition {
6471        DataRequestPartition::FoldTrain | DataRequestPartition::FullTrain => {
6472            binding.view_policy.include_augmented_train
6473        }
6474        DataRequestPartition::FoldValidation | DataRequestPartition::Predict => {
6475            binding.view_policy.include_augmented_validation
6476        }
6477    };
6478    let mut extra = BTreeMap::new();
6479    extra.insert(
6480        "feature_set_id".to_string(),
6481        serde_json::Value::String(binding.feature_set_id().to_string()),
6482    );
6483    if !binding.view_policy.unsafe_flags.is_empty() {
6484        extra.insert(
6485            "unsafe_flags".to_string(),
6486            serde_json::Value::Array(
6487                binding
6488                    .view_policy
6489                    .unsafe_flags
6490                    .iter()
6491                    .cloned()
6492                    .map(serde_json::Value::String)
6493                    .collect(),
6494            ),
6495        );
6496    }
6497    let view = DataProviderViewSpec {
6498        sample_ids,
6499        partition,
6500        fold_id: match partition {
6501            DataRequestPartition::FoldTrain | DataRequestPartition::FoldValidation => {
6502                scope.fold_id.clone()
6503            }
6504            DataRequestPartition::FullTrain | DataRequestPartition::Predict => None,
6505        },
6506        source_ids: (!binding.source_ids.is_empty()).then(|| binding.source_ids.clone()),
6507        columns: None,
6508        include_augmented,
6509        include_excluded: binding.view_policy.include_excluded,
6510        branch_view: branch_view.cloned(),
6511        extra,
6512    };
6513    view.validate()?;
6514    Ok(view)
6515}
6516
6517fn data_partition_for_scope(binding: &DataBinding, scope: &PhaseScope) -> DataRequestPartition {
6518    match scope.phase {
6519        Phase::FitCv => binding.view_policy.fit_partition,
6520        Phase::Refit => DataRequestPartition::FullTrain,
6521        Phase::Predict | Phase::Explain if scope.fold_id.is_none() => DataRequestPartition::Predict,
6522        Phase::Predict | Phase::Explain => binding.view_policy.predict_partition,
6523        Phase::Compile | Phase::Plan | Phase::Select => DataRequestPartition::FullTrain,
6524    }
6525}
6526
6527fn fold_for_scope<'a>(
6528    fold_set: Option<&'a FoldSet>,
6529    fold_id: Option<&FoldId>,
6530) -> Result<Option<&'a FoldAssignment>> {
6531    let Some(fold_id) = fold_id else {
6532        return Ok(None);
6533    };
6534    let fold_set = fold_set.ok_or_else(|| {
6535        DagMlError::RuntimeValidation(format!(
6536            "fold `{fold_id}` requested but execution plan has no fold set"
6537        ))
6538    })?;
6539    fold_set
6540        .folds
6541        .iter()
6542        .find(|fold| &fold.fold_id == fold_id)
6543        .map(Some)
6544        .ok_or_else(|| {
6545            DagMlError::RuntimeValidation(format!(
6546                "fold `{fold_id}` requested but is not present in fold set `{}`",
6547                fold_set.id
6548            ))
6549        })
6550}
6551
6552/// Build the inner (nested) `FoldSet` for `node_plan` in `scope`, when an
6553/// effective inner-CV policy applies. Gated to FIT_CV with an outer fold in
6554/// scope; returns `Ok(None)` otherwise (no inner CV, or no outer fold to nest
6555/// within). The inner folds are built from the outer fold's TRAINING samples
6556/// only, so they are a subset of outer-train by construction (no leakage).
6557fn inner_fold_set_for_scope(
6558    campaign: &CampaignSpec,
6559    outer_fold_set: Option<&FoldSet>,
6560    node_plan: &NodePlan,
6561    scope: &PhaseScope,
6562) -> Result<Option<FoldSet>> {
6563    if scope.phase != Phase::FitCv {
6564        return Ok(None);
6565    }
6566    let Some(spec) =
6567        crate::fold::resolve_inner_cv(node_plan.inner_cv.as_ref(), campaign.inner_cv.as_ref())
6568    else {
6569        return Ok(None);
6570    };
6571    // Nested CV needs an outer fold to nest within. `fold_for_scope` yields
6572    // `None` only when there is no outer fold in scope (skip), and errors if a
6573    // fold was requested but is missing from the fold set.
6574    let Some(outer) = fold_for_scope(outer_fold_set, scope.fold_id.as_ref())? else {
6575        return Ok(None);
6576    };
6577    let outer_groups = &outer_fold_set
6578        .expect("fold_for_scope returned a fold, so the outer fold set is present")
6579        .sample_groups;
6580    Ok(Some(spec.build_inner_fold_set(outer, outer_groups)?))
6581}
6582
6583fn sample_ids_for_partition(
6584    partition: DataRequestPartition,
6585    fold_set: Option<&FoldSet>,
6586    fold: Option<&FoldAssignment>,
6587) -> Option<Vec<SampleId>> {
6588    match partition {
6589        DataRequestPartition::FoldTrain => fold.map(|fold| fold.train_sample_ids.clone()),
6590        DataRequestPartition::FoldValidation => fold.map(|fold| fold.validation_sample_ids.clone()),
6591        DataRequestPartition::FullTrain => fold_set.map(|fold_set| fold_set.sample_ids.clone()),
6592        DataRequestPartition::Predict => None,
6593    }
6594}
6595
6596fn preload_replay_prediction_cache_store(
6597    bundle: &ExecutionBundle,
6598    prediction_cache_store: Option<&dyn RuntimePredictionCacheStore>,
6599    ctx: &mut RunContext,
6600) -> Result<()> {
6601    if bundle.prediction_requirements.is_empty() {
6602        return Ok(());
6603    }
6604    let store = prediction_cache_store.ok_or_else(|| {
6605        DagMlError::RuntimeValidation(format!(
6606            "bundle `{}` cannot preload OOF prediction caches without a prediction cache store",
6607            bundle.bundle_id
6608        ))
6609    })?;
6610    if !ctx.prediction_store.blocks().is_empty() {
6611        return Err(DagMlError::RuntimeValidation(format!(
6612            "bundle `{}` cannot preload OOF prediction caches into a non-empty prediction store",
6613            bundle.bundle_id
6614        )));
6615    }
6616    let contracts = replay_prediction_cache_contracts(bundle)?;
6617    for contract in contracts.values() {
6618        if contract.requirement.prediction_level == PredictionLevel::Sample {
6619            let blocks = store.load_blocks(&contract.cache.requirement_key)?;
6620            if blocks.iter().any(|block| {
6621                block.producer_node != contract.requirement.producer_node
6622                    || block.partition != contract.requirement.partition
6623            }) {
6624                return Err(DagMlError::RuntimeValidation(format!(
6625                    "prediction cache store returned blocks outside requirement `{}`",
6626                    contract.cache.requirement_key
6627                )));
6628            }
6629            let payload = build_prediction_cache_payload(&contract.requirement, &blocks)?;
6630            validate_prediction_cache_payload_matches_record(&payload, &contract.cache)?;
6631            for block in &payload.blocks {
6632                ctx.prediction_store.append(block.clone())?;
6633            }
6634        } else {
6635            let blocks = store.load_aggregated_blocks(&contract.cache.requirement_key)?;
6636            if blocks.iter().any(|block| {
6637                block.producer_node != contract.requirement.producer_node
6638                    || block.partition != contract.requirement.partition
6639                    || block.level != contract.requirement.prediction_level
6640            }) {
6641                return Err(DagMlError::RuntimeValidation(format!(
6642                    "prediction cache store returned aggregated blocks outside requirement `{}`",
6643                    contract.cache.requirement_key
6644                )));
6645            }
6646            let payload =
6647                build_aggregated_prediction_cache_payload(&contract.requirement, &blocks)?;
6648            validate_prediction_cache_payload_matches_record(&payload, &contract.cache)?;
6649        }
6650    }
6651    Ok(())
6652}
6653
6654fn replay_prediction_cache_contracts(
6655    bundle: &ExecutionBundle,
6656) -> Result<BTreeMap<String, ReplayPredictionCacheContract>> {
6657    bundle.validate()?;
6658    let requirements = bundle
6659        .prediction_requirements
6660        .iter()
6661        .map(|requirement| (requirement.key(), requirement))
6662        .collect::<BTreeMap<_, _>>();
6663    let mut contracts = BTreeMap::new();
6664    for cache in &bundle.prediction_caches {
6665        let requirement = requirements.get(&cache.requirement_key).ok_or_else(|| {
6666            DagMlError::RuntimeValidation(format!(
6667                "prediction cache `{}` references unknown prediction requirement `{}`",
6668                cache.cache_id, cache.requirement_key
6669            ))
6670        })?;
6671        contracts.insert(
6672            cache.requirement_key.clone(),
6673            ReplayPredictionCacheContract {
6674                requirement: (*requirement).clone(),
6675                cache: cache.clone(),
6676            },
6677        );
6678    }
6679    Ok(contracts)
6680}
6681
/// Materialize the refit artifacts recorded in `bundle` as runtime handles for
/// a replay phase.
///
/// Each `bundle.refit_artifacts` entry is processed in order:
/// 1. The record is validated.
/// 2. Its node must exist in `plan.node_plans` and support
///    `replay_request.phase`.
/// 3. The artifact is materialized through `artifact_store` for the bundle's
///    selected variant.
/// 4. The resulting handle must be a `Model` or `Artifact` handle owned by the
///    record's controller.
/// 5. The handle and its `ArtifactInputSpec` metadata are registered per node
///    under `refit_artifact_input_key(&artifact.artifact.id)`.
///
/// # Errors
///
/// Returns `DagMlError::RuntimeValidation` in any of these cases:
/// * an unknown node;
/// * an unsupported replay phase;
/// * an unexpected handle kind or owner;
/// * a duplicate input key on the same node.
///
/// Also propagates errors from record validation, the artifact store and
/// `ArtifactInputSpec::from_refit_record`.
fn materialize_replay_artifact_handles(
    plan: &ExecutionPlan,
    bundle: &ExecutionBundle,
    replay_request: &ReplayPhaseRequest,
    artifact_store: &dyn RuntimeArtifactStore,
    ctx: &RunContext,
) -> Result<MaterializedReplayArtifacts> {
    let mut handles = BTreeMap::<NodeId, BTreeMap<String, HandleRef>>::new();
    let mut inputs = BTreeMap::<NodeId, BTreeMap<String, ArtifactInputSpec>>::new();
    for artifact in &bundle.refit_artifacts {
        artifact.validate()?;
        let node_plan = plan.node_plans.get(&artifact.node_id).ok_or_else(|| {
            DagMlError::RuntimeValidation(format!(
                "bundle `{}` artifact references unknown node `{}`",
                bundle.bundle_id, artifact.node_id
            ))
        })?;
        if !node_plan.supported_phases.contains(&replay_request.phase) {
            return Err(DagMlError::RuntimeValidation(format!(
                "bundle `{}` artifact node `{}` does not support replay phase {:?}",
                bundle.bundle_id, artifact.node_id, replay_request.phase
            )));
        }
        // NOTE(review): materialization happens before the duplicate-key
        // checks below. A duplicate is therefore only detected after the store
        // has already produced a handle for it.
        let handle = artifact_store.materialize(&ArtifactMaterializationRequest {
            run_id: ctx.run_id.clone(),
            bundle_id: bundle.bundle_id.clone(),
            node_id: artifact.node_id.clone(),
            phase: replay_request.phase,
            variant_id: bundle.selected_variant_id.clone(),
            controller_id: artifact.controller_id.clone(),
            artifact: artifact.artifact.clone(),
            params_fingerprint: artifact.params_fingerprint.clone(),
        })?;
        // Only model/artifact handles are acceptable replay inputs.
        if !matches!(handle.kind, HandleKind::Model | HandleKind::Artifact) {
            return Err(DagMlError::RuntimeValidation(format!(
                "artifact `{}` materialized as unsupported handle kind {:?}",
                artifact.artifact.id, handle.kind
            )));
        }
        // The store must hand back a handle owned by the recorded controller.
        if handle.owner_controller != artifact.controller_id {
            return Err(DagMlError::RuntimeValidation(format!(
                "artifact `{}` handle owner `{}` does not match controller `{}`",
                artifact.artifact.id, handle.owner_controller, artifact.controller_id
            )));
        }
        let key = refit_artifact_input_key(&artifact.artifact.id);
        if handles
            .entry(artifact.node_id.clone())
            .or_default()
            .insert(key.clone(), handle)
            .is_some()
        {
            return Err(DagMlError::RuntimeValidation(format!(
                "duplicate replay artifact input `{key}` for node `{}`",
                artifact.node_id
            )));
        }
        // NOTE(review): this uses the same node/key as the `handles` insert
        // above, and both maps are filled together. So this check appears
        // unable to fire on its own; it acts as a consistency guard.
        if inputs
            .entry(artifact.node_id.clone())
            .or_default()
            .insert(key.clone(), ArtifactInputSpec::from_refit_record(artifact)?)
            .is_some()
        {
            return Err(DagMlError::RuntimeValidation(format!(
                "duplicate replay artifact metadata `{key}` for node `{}`",
                artifact.node_id
            )));
        }
    }
    Ok(MaterializedReplayArtifacts { handles, inputs })
}
6753
6754fn derive_task_seed(
6755    root_seed: Option<u64>,
6756    variant_id: Option<&VariantId>,
6757    fold_id: Option<&FoldId>,
6758    node_plan: &NodePlan,
6759    phase: Phase,
6760) -> Option<u64> {
6761    root_seed.map(|root| {
6762        let mut context = SeedContext::root(root);
6763        if let Some(variant_id) = variant_id {
6764            context = context.child(format!("variant:{variant_id}"));
6765        }
6766        if let Some(fold_id) = fold_id {
6767            context = context.child(format!("fold:{fold_id}"));
6768        }
6769        context
6770            .child(format!("node:{}", node_plan.node_id))
6771            .child(format!("phase:{phase:?}"))
6772            .derive_u64("task")
6773    })
6774}
6775
#[cfg(test)]
mod explain_contract_tests {
    use super::*;

    /// A well-formed explanation from `model:base` for target `y`, using the
    /// given method name.
    fn explanation(method: &str) -> ExplanationBlock {
        ExplanationBlock {
            producer_node: NodeId::new("model:base").unwrap(),
            method: String::from(method),
            target_name: Some(String::from("y")),
            payload: serde_json::json!({"feature_importance": [0.5, 0.3, 0.2]}),
        }
    }

    #[test]
    fn validates_well_formed_explanation() {
        let outcome = explanation("shap").validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn rejects_empty_method() {
        let blank_method = explanation("  ");
        assert!(blank_method.validate().is_err());
    }

    #[test]
    fn rejects_empty_target_name() {
        let blank_target = ExplanationBlock {
            target_name: Some(String::new()),
            ..explanation("shap")
        };
        assert!(blank_target.validate().is_err());
    }

    #[test]
    fn round_trips_through_json() {
        let original = explanation("permutation_importance");
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: ExplanationBlock = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, original);

        // A missing `target_name` must not appear in the serialized form.
        let untargeted = ExplanationBlock {
            target_name: None,
            ..explanation("shap")
        };
        let encoded = serde_json::to_string(&untargeted).expect("serialize");
        assert!(!encoded.contains("target_name"));
    }
}
6819
// Further unit tests for this module live in a separate `tests` module file.
// They are compiled only for test builds.
#[cfg(test)]
mod tests;