1use std::cell::RefCell;
2use std::collections::{BTreeMap, BTreeSet};
3use std::fs;
4use std::io::Read;
5use std::path::{Path, PathBuf};
6
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10use crate::aggregation::{
11 aggregate_observation_predictions, aggregate_sample_predictions_by_unit,
12 AggregatedPredictionBlock, AggregationControllerInput, AggregationControllerOutput,
13 AggregationControllerResult, AggregationControllerTask, ObservationPredictionBlock,
14 PredictionUnitId,
15};
16use crate::bundle::{
17 build_aggregated_prediction_cache_payload, build_prediction_cache_payload,
18 bundle_prediction_requirement_key, validate_prediction_cache_payload_matches_record,
19 BundlePredictionCachePayload, BundlePredictionCachePayloadSet, BundlePredictionCacheRecord,
20 BundlePredictionRequirement, ExecutionBundle, RefitArtifactRecord, ReplayPhaseRequest,
21};
22use crate::campaign::stable_json_fingerprint;
23use crate::controller::{capabilities_support_fit_influence, ControllerCapability};
24use crate::data::{
25 DataBinding, DataRequestPartition, ExternalDataPlanEnvelope, RepresentationCompatibilityReport,
26 RepresentationPlan, RepresentationReplayManifest,
27};
28use crate::error::{DagMlError, Result};
29use crate::fold::{FoldAssignment, FoldSet};
30use crate::generation::{GenerationChoice, VariantPlan};
31use crate::graph::{EdgeSpec, PortKind};
32use crate::ids::{
33 ArtifactId, BranchId, BundleId, ControllerId, FoldId, LineageId, NodeId, RunId, SampleId,
34 VariantId,
35};
36use crate::oof::{PredictionBlock, PredictionPartition};
37use crate::phase::Phase;
38use crate::plan::{CampaignSpec, ExecutionPlan, NodePlan};
39use crate::policy::{
40 AggregationPolicy, FitInfluencePolicy, PredictionLevel, ShapeDelta, ShapeDeltaKind,
41};
42use crate::relation::SampleRelationSet;
43use crate::rng::SeedContext;
44
45#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
46#[serde(rename_all = "snake_case")]
47pub enum HandleKind {
48 Data,
49 DataView,
50 Model,
51 Artifact,
52 Prediction,
53 Relation,
54}
55
56#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
57pub struct HandleRef {
58 pub handle: u64,
59 pub kind: HandleKind,
60 pub owner_controller: ControllerId,
61}
62
63#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
64#[serde(rename_all = "snake_case")]
65pub enum ArtifactBackend {
66 Joblib,
67 Torch,
68 Tensorflow,
69 Onnx,
70 Safetensors,
71 Json,
72 Raw,
73}
74
75#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
76pub struct ArtifactRef {
77 pub id: ArtifactId,
78 pub kind: String,
79 pub controller_id: ControllerId,
80 #[serde(default, skip_serializing_if = "Option::is_none")]
81 pub backend: Option<ArtifactBackend>,
82 #[serde(default, skip_serializing_if = "Option::is_none")]
83 pub uri: Option<String>,
84 #[serde(default, skip_serializing_if = "Option::is_none")]
85 pub content_fingerprint: Option<String>,
86 pub size_bytes: Option<u64>,
87 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub plugin: Option<String>,
89 #[serde(default, skip_serializing_if = "Option::is_none")]
90 pub plugin_version: Option<String>,
91}
92
93impl ArtifactRef {
94 pub fn validate(&self) -> Result<()> {
95 if self.kind.trim().is_empty() {
96 return Err(DagMlError::RuntimeValidation(format!(
97 "artifact `{}` has empty kind",
98 self.id
99 )));
100 }
101 validate_artifact_optional_text("uri", &self.uri, &self.id)?;
102 validate_artifact_optional_text("plugin", &self.plugin, &self.id)?;
103 validate_artifact_optional_text("plugin_version", &self.plugin_version, &self.id)?;
104 if self.plugin_version.is_some() && self.plugin.is_none() {
105 return Err(DagMlError::RuntimeValidation(format!(
106 "artifact `{}` has plugin_version without plugin",
107 self.id
108 )));
109 }
110 if let Some(content_fingerprint) = &self.content_fingerprint {
111 validate_runtime_fingerprint("artifact content", content_fingerprint)?;
112 }
113 if self.uri.is_some() && self.backend.is_none() {
114 return Err(DagMlError::RuntimeValidation(format!(
115 "artifact `{}` has uri without backend",
116 self.id
117 )));
118 }
119 if self.uri.is_some() && self.content_fingerprint.is_none() {
120 return Err(DagMlError::RuntimeValidation(format!(
121 "artifact `{}` has uri without content_fingerprint",
122 self.id
123 )));
124 }
125 Ok(())
126 }
127
128 pub fn validate_portable(&self) -> Result<()> {
133 self.validate()?;
134 let Some(uri) = self.uri.as_deref() else {
135 return Err(DagMlError::RuntimeValidation(format!(
136 "artifact `{}` is not portable: requires backend, uri and content_fingerprint",
137 self.id
138 )));
139 };
140 validate_relative_artifact_uri(&self.id, uri)
143 }
144}
145
146pub fn refit_artifact_input_key(artifact_id: &ArtifactId) -> String {
147 format!("artifact:{artifact_id}")
148}
149
150#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
151pub struct ArtifactMaterializationRequest {
152 pub run_id: RunId,
153 pub bundle_id: BundleId,
154 pub node_id: NodeId,
155 pub phase: Phase,
156 pub variant_id: Option<VariantId>,
157 pub controller_id: ControllerId,
158 pub artifact: ArtifactRef,
159 pub params_fingerprint: String,
160}
161
162#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
163pub struct ArtifactHandleRecord {
164 pub handle: HandleRef,
165 pub node_id: NodeId,
166 pub controller_id: ControllerId,
167 pub artifact: ArtifactRef,
168 pub params_fingerprint: String,
169}
170
171impl ArtifactHandleRecord {
172 pub fn validate(&self) -> Result<()> {
173 self.artifact.validate()?;
174 if !matches!(self.handle.kind, HandleKind::Model | HandleKind::Artifact) {
175 return Err(DagMlError::RuntimeValidation(format!(
176 "artifact `{}` is registered with non-artifact/model handle kind {:?}",
177 self.artifact.id, self.handle.kind
178 )));
179 }
180 if self.handle.owner_controller != self.controller_id {
181 return Err(DagMlError::RuntimeValidation(format!(
182 "artifact `{}` handle owner `{}` does not match controller `{}`",
183 self.artifact.id, self.handle.owner_controller, self.controller_id
184 )));
185 }
186 if self.artifact.controller_id != self.controller_id {
187 return Err(DagMlError::RuntimeValidation(format!(
188 "artifact `{}` controller `{}` does not match record controller `{}`",
189 self.artifact.id, self.artifact.controller_id, self.controller_id
190 )));
191 }
192 if self.params_fingerprint.trim().is_empty() {
193 return Err(DagMlError::RuntimeValidation(format!(
194 "artifact `{}` has empty params fingerprint",
195 self.artifact.id
196 )));
197 }
198 Ok(())
199 }
200}
201
202pub trait RuntimeArtifactStore {
203 fn materialize(&self, request: &ArtifactMaterializationRequest) -> Result<HandleRef>;
204}
205
206#[derive(Clone, Debug, Default)]
207pub struct InMemoryArtifactStore {
208 records: BTreeMap<ArtifactId, ArtifactHandleRecord>,
209 refit_artifacts: BTreeMap<ArtifactId, RefitArtifactRecord>,
210}
211
212impl InMemoryArtifactStore {
213 pub fn new() -> Self {
214 Self::default()
215 }
216
217 pub fn register(&mut self, artifact: &RefitArtifactRecord, handle: HandleRef) -> Result<()> {
218 artifact.validate()?;
219 let record = ArtifactHandleRecord {
220 handle,
221 node_id: artifact.node_id.clone(),
222 controller_id: artifact.controller_id.clone(),
223 artifact: artifact.artifact.clone(),
224 params_fingerprint: artifact.params_fingerprint.clone(),
225 };
226 record.validate()?;
227 if self.records.contains_key(&record.artifact.id)
228 || self.refit_artifacts.contains_key(&record.artifact.id)
229 {
230 return Err(DagMlError::RuntimeValidation(format!(
231 "duplicate artifact handle for `{}`",
232 artifact.artifact.id
233 )));
234 }
235 let previous_record = self.records.insert(record.artifact.id.clone(), record);
236 debug_assert!(previous_record.is_none());
237 let previous_artifact = self
238 .refit_artifacts
239 .insert(artifact.artifact.id.clone(), artifact.clone());
240 debug_assert!(previous_artifact.is_none());
241 Ok(())
242 }
243
244 pub fn capture_refit_artifacts(
245 &mut self,
246 task: &NodeTask,
247 result: &NodeResult,
248 ) -> Result<Vec<RefitArtifactRecord>> {
249 if task.phase != Phase::Refit {
250 return Err(DagMlError::RuntimeValidation(format!(
251 "cannot capture refit artifacts from phase {:?}",
252 task.phase
253 )));
254 }
255 let mut records = Vec::new();
256 for artifact in &result.artifacts {
257 let handle = result.artifact_handles.get(&artifact.id).ok_or_else(|| {
258 DagMlError::RuntimeValidation(format!(
259 "node `{}` emitted artifact `{}` without artifact handle",
260 task.node_plan.node_id, artifact.id
261 ))
262 })?;
263 let record = RefitArtifactRecord {
264 node_id: task.node_plan.node_id.clone(),
265 controller_id: task.node_plan.controller_id.clone(),
266 artifact: artifact.clone(),
267 params_fingerprint: task.node_plan.params_fingerprint.clone(),
268 data_requirement_keys: task
269 .node_plan
270 .data_bindings
271 .iter()
272 .map(|binding| format!("{}.{}", binding.node_id, binding.input_name))
273 .collect(),
274 prediction_requirement_keys: task
275 .prediction_inputs
276 .values()
277 .map(|spec| {
278 bundle_prediction_requirement_key(
279 &spec.producer_node,
280 &spec.source_port,
281 &task.node_plan.node_id,
282 &spec.target_port,
283 )
284 })
285 .collect(),
286 };
287 self.register(&record, handle.clone())?;
288 records.push(record);
289 }
290 Ok(records)
291 }
292
293 pub fn get(&self, artifact_id: &ArtifactId) -> Option<&ArtifactHandleRecord> {
294 self.records.get(artifact_id)
295 }
296
297 pub fn len(&self) -> usize {
298 self.records.len()
299 }
300
301 pub fn is_empty(&self) -> bool {
302 self.records.is_empty()
303 }
304
305 pub fn refit_artifacts(&self) -> Vec<RefitArtifactRecord> {
306 self.refit_artifacts.values().cloned().collect()
307 }
308}
309
310impl RuntimeArtifactStore for InMemoryArtifactStore {
311 fn materialize(&self, request: &ArtifactMaterializationRequest) -> Result<HandleRef> {
312 let record = self.records.get(&request.artifact.id).ok_or_else(|| {
313 DagMlError::RuntimeValidation(format!(
314 "artifact store is missing refit artifact `{}` for bundle `{}`",
315 request.artifact.id, request.bundle_id
316 ))
317 })?;
318 if record.node_id != request.node_id {
319 return Err(DagMlError::RuntimeValidation(format!(
320 "artifact `{}` is registered for node `{}` but requested for `{}`",
321 request.artifact.id, record.node_id, request.node_id
322 )));
323 }
324 if record.controller_id != request.controller_id {
325 return Err(DagMlError::RuntimeValidation(format!(
326 "artifact `{}` is registered for controller `{}` but requested for `{}`",
327 request.artifact.id, record.controller_id, request.controller_id
328 )));
329 }
330 if record.artifact != request.artifact {
331 return Err(DagMlError::RuntimeValidation(format!(
332 "artifact `{}` metadata does not match bundle record",
333 request.artifact.id
334 )));
335 }
336 if record.params_fingerprint != request.params_fingerprint {
337 return Err(DagMlError::RuntimeValidation(format!(
338 "artifact `{}` params fingerprint does not match bundle record",
339 request.artifact.id
340 )));
341 }
342 record.validate()?;
343 Ok(record.handle.clone())
344 }
345}
346
/// Schema version accepted for file artifact manifests.
pub const FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION: u32 = 1;
/// File name of the artifact manifest inside a store root.
pub const FILE_ARTIFACT_MANIFEST_FILE: &str = "artifact_manifest.json";

/// Serde default for a manifest's `schema_version`: the current schema version.
fn default_file_artifact_manifest_schema_version() -> u32 {
    let version: u32 = FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION;
    version
}
353
354#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
359pub struct FileArtifactManifestEntry {
360 pub node_id: NodeId,
361 pub controller_id: ControllerId,
362 pub artifact: ArtifactRef,
363 pub params_fingerprint: String,
364}
365
366impl FileArtifactManifestEntry {
367 fn from_refit_record(record: &RefitArtifactRecord) -> Result<Self> {
368 let entry = Self {
369 node_id: record.node_id.clone(),
370 controller_id: record.controller_id.clone(),
371 artifact: record.artifact.clone(),
372 params_fingerprint: record.params_fingerprint.clone(),
373 };
374 entry.validate()?;
375 Ok(entry)
376 }
377
378 pub fn validate(&self) -> Result<()> {
379 self.artifact.validate_portable()?;
380 if self.artifact.controller_id != self.controller_id {
381 return Err(DagMlError::RuntimeValidation(format!(
382 "artifact manifest entry `{}` controller `{}` does not match artifact controller `{}`",
383 self.artifact.id, self.controller_id, self.artifact.controller_id
384 )));
385 }
386 validate_runtime_fingerprint("artifact manifest params", &self.params_fingerprint)
387 }
388
389 fn matches_refit_record(&self, record: &RefitArtifactRecord) -> bool {
390 self.node_id == record.node_id
391 && self.controller_id == record.controller_id
392 && self.artifact == record.artifact
393 && self.params_fingerprint == record.params_fingerprint
394 }
395}
396
397#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
402pub struct FileArtifactManifest {
403 pub bundle_id: BundleId,
404 #[serde(default = "default_file_artifact_manifest_schema_version")]
405 pub schema_version: u32,
406 #[serde(default)]
407 pub artifacts: Vec<FileArtifactManifestEntry>,
408}
409
410impl FileArtifactManifest {
411 pub fn validate(&self) -> Result<()> {
412 if self.schema_version != FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION {
413 return Err(DagMlError::RuntimeValidation(format!(
414 "file artifact manifest for bundle `{}` uses unsupported schema_version {}, expected {}",
415 self.bundle_id, self.schema_version, FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION
416 )));
417 }
418 let mut artifact_ids = BTreeSet::new();
419 let mut uris = BTreeSet::new();
420 for entry in &self.artifacts {
421 entry.validate()?;
422 if !artifact_ids.insert(entry.artifact.id.as_str()) {
423 return Err(DagMlError::RuntimeValidation(format!(
424 "file artifact manifest for bundle `{}` has duplicate artifact id `{}`",
425 self.bundle_id, entry.artifact.id
426 )));
427 }
428 if let Some(uri) = entry.artifact.uri.as_deref() {
430 if !uris.insert(uri) {
431 return Err(DagMlError::RuntimeValidation(format!(
432 "file artifact manifest for bundle `{}` has duplicate artifact uri `{}`",
433 self.bundle_id, uri
434 )));
435 }
436 }
437 }
438 Ok(())
439 }
440
441 pub fn validate_against_bundle(&self, bundle: &ExecutionBundle) -> Result<()> {
442 self.validate()?;
443 bundle.validate()?;
444 if self.bundle_id != bundle.bundle_id {
445 return Err(DagMlError::RuntimeValidation(format!(
446 "file artifact manifest bundle `{}` does not match bundle `{}`",
447 self.bundle_id, bundle.bundle_id
448 )));
449 }
450 if self.artifacts.len() != bundle.refit_artifacts.len() {
451 return Err(DagMlError::RuntimeValidation(format!(
452 "file artifact manifest for bundle `{}` has {} artifact(s) for {} bundle refit artifact(s)",
453 self.bundle_id,
454 self.artifacts.len(),
455 bundle.refit_artifacts.len()
456 )));
457 }
458 let entries_by_id = self
459 .artifacts
460 .iter()
461 .map(|entry| (entry.artifact.id.as_str(), entry))
462 .collect::<BTreeMap<_, _>>();
463 for record in &bundle.refit_artifacts {
464 let entry = entries_by_id
465 .get(record.artifact.id.as_str())
466 .ok_or_else(|| {
467 DagMlError::RuntimeValidation(format!(
468 "file artifact manifest for bundle `{}` is missing refit artifact `{}`",
469 self.bundle_id, record.artifact.id
470 ))
471 })?;
472 if !entry.matches_refit_record(record) {
473 return Err(DagMlError::RuntimeValidation(format!(
474 "file artifact manifest entry `{}` does not match bundle refit artifact",
475 entry.artifact.id
476 )));
477 }
478 }
479 Ok(())
480 }
481}
482
483#[derive(Clone, Debug)]
490pub struct FileArtifactManifestStore {
491 root: PathBuf,
492 manifest: FileArtifactManifest,
493}
494
495impl FileArtifactManifestStore {
496 pub fn write(root: impl AsRef<Path>, bundle: &ExecutionBundle) -> Result<FileArtifactManifest> {
497 bundle.validate()?;
498 let root = root.as_ref();
499 fs::create_dir_all(root).map_err(|err| {
500 DagMlError::RuntimeValidation(format!(
501 "failed to create artifact manifest store `{}`: {err}",
502 root.display()
503 ))
504 })?;
505 let mut entries = Vec::with_capacity(bundle.refit_artifacts.len());
506 for record in &bundle.refit_artifacts {
507 entries.push(FileArtifactManifestEntry::from_refit_record(record)?);
508 }
509 entries.sort_by(|left, right| left.artifact.id.cmp(&right.artifact.id));
510 let manifest = FileArtifactManifest {
511 bundle_id: bundle.bundle_id.clone(),
512 schema_version: FILE_ARTIFACT_MANIFEST_SCHEMA_VERSION,
513 artifacts: entries,
514 };
515 manifest.validate_against_bundle(bundle)?;
516 write_runtime_json(
517 &root.join(FILE_ARTIFACT_MANIFEST_FILE),
518 &manifest,
519 "artifact manifest",
520 )?;
521 Ok(manifest)
522 }
523
524 pub fn open(root: impl Into<PathBuf>, bundle: &ExecutionBundle) -> Result<Self> {
525 bundle.validate()?;
526 let root = root.into();
527 let manifest: FileArtifactManifest =
528 read_runtime_json(&root.join(FILE_ARTIFACT_MANIFEST_FILE), "artifact manifest")?;
529 manifest.validate_against_bundle(bundle)?;
530 Ok(Self { root, manifest })
531 }
532
533 pub fn root(&self) -> &Path {
534 &self.root
535 }
536
537 pub fn manifest(&self) -> &FileArtifactManifest {
538 &self.manifest
539 }
540}
541
542#[derive(Clone, Debug, Eq, PartialEq)]
543pub struct ArtifactPayloadMaterializationRecord {
544 pub run_id: RunId,
545 pub bundle_id: BundleId,
546 pub node_id: NodeId,
547 pub phase: Phase,
548 pub variant_id: Option<VariantId>,
549 pub artifact_id: ArtifactId,
550 pub payload_uri: String,
551 pub content_fingerprint: String,
552 pub size_bytes: u64,
553 pub handle: HandleRef,
554}
555
/// Metadata describing a payload file: its uri, content fingerprint and size.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ArtifactPayloadMetadata {
    /// Uri of the payload.
    uri: String,
    /// Fingerprint of the payload content.
    content_fingerprint: String,
    /// Payload size in bytes.
    size_bytes: u64,
}
562
563#[derive(Clone, Debug)]
564pub struct FileArtifactPayloadStore {
565 root: PathBuf,
566 manifest: FileArtifactManifest,
567 records_by_artifact_id: BTreeMap<ArtifactId, RefitArtifactRecord>,
568 materialization_records: RefCell<Vec<ArtifactPayloadMaterializationRecord>>,
569}
570
571impl FileArtifactPayloadStore {
572 pub fn write_from_source(
573 output_root: impl AsRef<Path>,
574 source_root: impl AsRef<Path>,
575 bundle: &ExecutionBundle,
576 ) -> Result<Self> {
577 bundle.validate()?;
578 let output_root = output_root.as_ref();
579 let source_root = source_root.as_ref();
580 fs::create_dir_all(output_root).map_err(|err| {
581 DagMlError::RuntimeValidation(format!(
582 "failed to create artifact payload store `{}`: {err}",
583 output_root.display()
584 ))
585 })?;
586 for record in &bundle.refit_artifacts {
587 record.artifact.validate_portable()?;
588 validate_artifact_payload_file(source_root, &record.artifact)?;
589 let source_path = artifact_payload_path(source_root, &record.artifact)?;
590 let output_path = artifact_payload_path(output_root, &record.artifact)?;
591 if let Some(parent) = output_path.parent() {
592 fs::create_dir_all(parent).map_err(|err| {
593 DagMlError::RuntimeValidation(format!(
594 "failed to create artifact payload directory `{}`: {err}",
595 parent.display()
596 ))
597 })?;
598 }
599 if source_path != output_path {
600 fs::copy(&source_path, &output_path).map_err(|err| {
601 DagMlError::RuntimeValidation(format!(
602 "failed to copy artifact payload `{}` from {} to {}: {err}",
603 record.artifact.id,
604 source_path.display(),
605 output_path.display()
606 ))
607 })?;
608 }
609 }
610 FileArtifactManifestStore::write(output_root, bundle)?;
611 Self::open(output_root.to_path_buf(), bundle)
612 }
613
614 pub fn open(root: impl Into<PathBuf>, bundle: &ExecutionBundle) -> Result<Self> {
615 bundle.validate()?;
616 let root = root.into();
617 let manifest_store = FileArtifactManifestStore::open(root.clone(), bundle)?;
618 let records_by_artifact_id = bundle
619 .refit_artifacts
620 .iter()
621 .cloned()
622 .map(|record| (record.artifact.id.clone(), record))
623 .collect::<BTreeMap<_, _>>();
624 let store = Self {
625 root,
626 manifest: manifest_store.manifest().clone(),
627 records_by_artifact_id,
628 materialization_records: RefCell::new(Vec::new()),
629 };
630 store.validate_payloads()?;
631 Ok(store)
632 }
633
634 pub fn root(&self) -> &Path {
635 &self.root
636 }
637
638 pub fn manifest(&self) -> &FileArtifactManifest {
639 &self.manifest
640 }
641
642 pub fn payload_count(&self) -> usize {
643 self.manifest.artifacts.len()
644 }
645
646 pub fn materialization_records(&self) -> Vec<ArtifactPayloadMaterializationRecord> {
647 self.materialization_records.borrow().clone()
648 }
649
650 pub fn validate_payloads(&self) -> Result<()> {
651 self.manifest.validate()?;
652 for entry in &self.manifest.artifacts {
653 let record = self
654 .records_by_artifact_id
655 .get(&entry.artifact.id)
656 .ok_or_else(|| {
657 DagMlError::RuntimeValidation(format!(
658 "artifact payload store for bundle `{}` has no bundle record for `{}`",
659 self.manifest.bundle_id, entry.artifact.id
660 ))
661 })?;
662 if !entry.matches_refit_record(record) {
663 return Err(DagMlError::RuntimeValidation(format!(
664 "artifact payload store entry `{}` does not match bundle refit artifact",
665 entry.artifact.id
666 )));
667 }
668 validate_artifact_payload_file(&self.root, &entry.artifact)?;
669 }
670 Ok(())
671 }
672}
673
674impl RuntimeArtifactStore for FileArtifactPayloadStore {
675 fn materialize(&self, request: &ArtifactMaterializationRequest) -> Result<HandleRef> {
676 request.artifact.validate_portable()?;
677 let record = self
678 .records_by_artifact_id
679 .get(&request.artifact.id)
680 .ok_or_else(|| {
681 DagMlError::RuntimeValidation(format!(
682 "artifact payload store is missing refit artifact `{}` for bundle `{}`",
683 request.artifact.id, request.bundle_id
684 ))
685 })?;
686 if record.node_id != request.node_id {
687 return Err(DagMlError::RuntimeValidation(format!(
688 "artifact `{}` is registered for node `{}` but requested for `{}`",
689 request.artifact.id, record.node_id, request.node_id
690 )));
691 }
692 if record.controller_id != request.controller_id {
693 return Err(DagMlError::RuntimeValidation(format!(
694 "artifact `{}` is registered for controller `{}` but requested for `{}`",
695 request.artifact.id, record.controller_id, request.controller_id
696 )));
697 }
698 if record.artifact != request.artifact {
699 return Err(DagMlError::RuntimeValidation(format!(
700 "artifact `{}` metadata does not match bundle record",
701 request.artifact.id
702 )));
703 }
704 if record.params_fingerprint != request.params_fingerprint {
705 return Err(DagMlError::RuntimeValidation(format!(
706 "artifact `{}` params fingerprint does not match bundle record",
707 request.artifact.id
708 )));
709 }
710 let metadata = validate_artifact_payload_file(&self.root, &request.artifact)?;
711 let fingerprint = stable_json_fingerprint(&(
712 &request.run_id,
713 &request.bundle_id,
714 &request.node_id,
715 request.phase,
716 &request.variant_id,
717 &request.artifact.id,
718 &metadata.content_fingerprint,
719 &request.params_fingerprint,
720 ))?;
721 let handle = HandleRef {
722 handle: u64::from_str_radix(&fingerprint[..16], 16)
723 .expect("sha256 hex prefix should fit into u64"),
724 kind: HandleKind::Artifact,
725 owner_controller: request.controller_id.clone(),
726 };
727 self.materialization_records
728 .borrow_mut()
729 .push(ArtifactPayloadMaterializationRecord {
730 run_id: request.run_id.clone(),
731 bundle_id: request.bundle_id.clone(),
732 node_id: request.node_id.clone(),
733 phase: request.phase,
734 variant_id: request.variant_id.clone(),
735 artifact_id: request.artifact.id.clone(),
736 payload_uri: metadata.uri,
737 content_fingerprint: metadata.content_fingerprint,
738 size_bytes: metadata.size_bytes,
739 handle: handle.clone(),
740 });
741 Ok(handle)
742 }
743}
744
745#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
746pub struct LineageRecord {
747 pub record_id: LineageId,
748 pub run_id: RunId,
749 pub node_id: NodeId,
750 pub phase: Phase,
751 pub controller_id: ControllerId,
752 pub controller_version: String,
753 pub variant_id: Option<VariantId>,
754 pub fold_id: Option<FoldId>,
755 #[serde(default)]
756 pub branch_path: Vec<BranchId>,
757 #[serde(default)]
758 pub input_lineage: Vec<LineageId>,
759 #[serde(default)]
760 pub artifact_refs: Vec<ArtifactRef>,
761 pub params_fingerprint: String,
762 pub data_model_shape_fingerprint: Option<String>,
763 pub aggregation_policy_fingerprint: Option<String>,
764 pub seed: Option<u64>,
765 #[serde(default)]
766 pub unsafe_flags: BTreeSet<String>,
767 #[serde(default)]
768 pub metrics: BTreeMap<String, f64>,
769}
770
771impl LineageRecord {
772 pub fn validate(&self) -> Result<()> {
773 if self.params_fingerprint.trim().is_empty() {
774 return Err(DagMlError::RuntimeValidation(format!(
775 "lineage `{}` has empty params fingerprint",
776 self.record_id
777 )));
778 }
779 for artifact in &self.artifact_refs {
780 artifact.validate()?;
781 }
782 Ok(())
783 }
784}
785
786#[derive(Clone, Debug, Default)]
787pub struct InMemoryLineageRecorder {
788 records: BTreeMap<LineageId, LineageRecord>,
789}
790
791impl InMemoryLineageRecorder {
792 pub fn new() -> Self {
793 Self::default()
794 }
795
796 pub fn record(&mut self, record: LineageRecord) -> Result<()> {
797 record.validate()?;
798 if self
799 .records
800 .insert(record.record_id.clone(), record)
801 .is_some()
802 {
803 return Err(DagMlError::RuntimeValidation(
804 "duplicate lineage record id".to_string(),
805 ));
806 }
807 Ok(())
808 }
809
810 pub fn get(&self, id: &LineageId) -> Option<&LineageRecord> {
811 self.records.get(id)
812 }
813
814 pub fn len(&self) -> usize {
815 self.records.len()
816 }
817
818 pub fn is_empty(&self) -> bool {
819 self.records.is_empty()
820 }
821
822 pub fn records(&self) -> impl Iterator<Item = &LineageRecord> {
823 self.records.values()
824 }
825}
826
827#[derive(Clone, Debug, Default)]
828pub struct InMemoryPredictionStore {
829 blocks: Vec<PredictionBlock>,
830}
831
832impl InMemoryPredictionStore {
833 pub fn new() -> Self {
834 Self::default()
835 }
836
837 pub fn append(&mut self, block: PredictionBlock) -> Result<()> {
838 block.validate_shape()?;
839 self.blocks.push(block);
840 Ok(())
841 }
842
843 pub fn blocks(&self) -> &[PredictionBlock] {
844 &self.blocks
845 }
846
847 pub fn find(
848 &self,
849 producer_node: Option<&NodeId>,
850 phase_partition: Option<&crate::oof::PredictionPartition>,
851 fold_id: Option<&FoldId>,
852 ) -> Vec<&PredictionBlock> {
853 self.blocks
854 .iter()
855 .filter(|block| {
856 producer_node.is_none_or(|node_id| &block.producer_node == node_id)
857 && phase_partition.is_none_or(|partition| &block.partition == partition)
858 && fold_id.is_none_or(|requested| block.fold_id.as_ref() == Some(requested))
859 })
860 .collect()
861 }
862}
863
864#[derive(Clone, Debug, Default)]
865pub struct InMemoryAggregatedPredictionStore {
866 blocks: Vec<AggregatedPredictionBlock>,
867}
868
869impl InMemoryAggregatedPredictionStore {
870 pub fn new() -> Self {
871 Self::default()
872 }
873
874 pub fn append(&mut self, block: AggregatedPredictionBlock) -> Result<()> {
875 block.validate_shape()?;
876 self.blocks.push(block);
877 Ok(())
878 }
879
880 pub fn blocks(&self) -> &[AggregatedPredictionBlock] {
881 &self.blocks
882 }
883
884 pub fn find(
885 &self,
886 producer_node: Option<&NodeId>,
887 phase_partition: Option<&PredictionPartition>,
888 fold_id: Option<&FoldId>,
889 prediction_level: Option<PredictionLevel>,
890 ) -> Vec<&AggregatedPredictionBlock> {
891 self.blocks
892 .iter()
893 .filter(|block| {
894 producer_node.is_none_or(|node_id| &block.producer_node == node_id)
895 && phase_partition.is_none_or(|partition| &block.partition == partition)
896 && fold_id.is_none_or(|requested| block.fold_id.as_ref() == Some(requested))
897 && prediction_level.is_none_or(|level| block.level == level)
898 })
899 .collect()
900 }
901}
902
903#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
904pub struct PredictionCacheMaterializationRequest {
905 pub run_id: RunId,
906 pub bundle_id: BundleId,
907 pub phase: Phase,
908 pub variant_id: Option<VariantId>,
909 pub requirement: BundlePredictionRequirement,
910 pub cache: BundlePredictionCacheRecord,
911 pub producer_controller_id: ControllerId,
912}
913
914#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
915pub struct PredictionCacheMaterializationRecord {
916 pub run_id: RunId,
917 pub bundle_id: BundleId,
918 pub phase: Phase,
919 pub variant_id: Option<VariantId>,
920 pub requirement_key: String,
921 pub cache_id: String,
922 pub handle: HandleRef,
923}
924
925pub trait RuntimePredictionCacheStore {
926 fn load_blocks(&self, requirement_key: &str) -> Result<Vec<PredictionBlock>>;
927 fn load_aggregated_blocks(
928 &self,
929 requirement_key: &str,
930 ) -> Result<Vec<AggregatedPredictionBlock>> {
931 Err(DagMlError::RuntimeValidation(format!(
932 "prediction cache store does not support aggregated requirement `{requirement_key}`"
933 )))
934 }
935 fn materialize(&self, request: &PredictionCacheMaterializationRequest) -> Result<HandleRef>;
936}
937
/// Schema version accepted for file prediction cache manifests.
pub const FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION: u32 = 1;
/// File name of the prediction cache manifest.
pub const FILE_PREDICTION_CACHE_MANIFEST_FILE: &str = "prediction_cache_manifest.json";

/// Serde default for a manifest's `schema_version`: the current schema version.
fn default_file_prediction_cache_store_schema_version() -> u32 {
    let version: u32 = FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION;
    version
}
944
945#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
946pub struct FilePredictionCacheEntry {
947 pub requirement_key: String,
948 pub cache_id: String,
949 pub file_name: String,
950 #[serde(default = "default_runtime_prediction_level")]
951 pub prediction_level: PredictionLevel,
952 #[serde(default, skip_serializing_if = "Vec::is_empty")]
953 pub unit_ids: Vec<PredictionUnitId>,
954 pub block_count: usize,
955 pub row_count: usize,
956 pub content_fingerprint: String,
957}
958
959impl FilePredictionCacheEntry {
960 pub fn validate(&self) -> Result<()> {
961 validate_runtime_non_empty("requirement_key", &self.requirement_key)?;
962 validate_runtime_non_empty("cache_id", &self.cache_id)?;
963 validate_runtime_non_empty("file_name", &self.file_name)?;
964 validate_prediction_cache_file_name(&self.file_name)?;
965 if self.block_count == 0 {
966 return Err(DagMlError::RuntimeValidation(format!(
967 "file prediction cache `{}` has zero block_count",
968 self.cache_id
969 )));
970 }
971 if self.row_count == 0 {
972 return Err(DagMlError::RuntimeValidation(format!(
973 "file prediction cache `{}` has zero row_count",
974 self.cache_id
975 )));
976 }
977 if self.prediction_level != PredictionLevel::Sample && self.unit_ids.is_empty() {
978 return Err(DagMlError::RuntimeValidation(format!(
979 "file prediction cache `{}` has no aggregated unit ids",
980 self.cache_id
981 )));
982 }
983 if self
984 .unit_ids
985 .iter()
986 .any(|unit_id| unit_id.level() != self.prediction_level)
987 {
988 return Err(DagMlError::RuntimeValidation(format!(
989 "file prediction cache `{}` has unit ids outside {:?}",
990 self.cache_id, self.prediction_level
991 )));
992 }
993 validate_runtime_fingerprint("prediction cache content", &self.content_fingerprint)
994 }
995
996 fn from_payload(payload: &crate::bundle::BundlePredictionCachePayload) -> Result<Self> {
997 Ok(Self {
998 requirement_key: payload.requirement_key.clone(),
999 cache_id: payload.cache_id.clone(),
1000 file_name: prediction_cache_payload_file_name(payload)?,
1001 prediction_level: payload.prediction_level,
1002 unit_ids: payload
1003 .aggregated_blocks
1004 .iter()
1005 .flat_map(|block| block.unit_ids.iter().cloned())
1006 .collect(),
1007 block_count: payload.block_count,
1008 row_count: payload.row_count,
1009 content_fingerprint: payload.content_fingerprint.clone(),
1010 })
1011 }
1012
1013 fn matches_record(&self, record: &BundlePredictionCacheRecord) -> bool {
1014 self.requirement_key == record.requirement_key
1015 && self.cache_id == record.cache_id
1016 && self.prediction_level == record.prediction_level
1017 && self.unit_ids == record.unit_ids
1018 && self.block_count == record.block_count
1019 && self.row_count == record.row_count
1020 && self.content_fingerprint == record.content_fingerprint
1021 }
1022}
1023
/// Serializable manifest listing the prediction cache entries of a file-backed
/// store for a single bundle.
///
/// `FilePredictionCacheStore::write_payload_set` writes it under
/// `FILE_PREDICTION_CACHE_MANIFEST_FILE`, and `FilePredictionCacheStore::open`
/// reads it back and validates it against the bundle.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FilePredictionCacheManifest {
    /// Bundle the manifest belongs to; `validate_against_bundle` requires it to
    /// equal the bundle's own id.
    pub bundle_id: BundleId,
    /// Store schema version. When the field is absent during deserialization it
    /// is filled by `default_file_prediction_cache_store_schema_version`;
    /// `validate` rejects any value other than
    /// `FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION`.
    #[serde(default = "default_file_prediction_cache_store_schema_version")]
    pub schema_version: u32,
    /// One entry per stored cache payload; defaults to empty when absent during
    /// deserialization.
    #[serde(default)]
    pub caches: Vec<FilePredictionCacheEntry>,
}
1032
1033impl FilePredictionCacheManifest {
1034 pub fn validate(&self) -> Result<()> {
1035 if self.schema_version != FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION {
1036 return Err(DagMlError::RuntimeValidation(format!(
1037 "file prediction cache manifest for bundle `{}` uses unsupported schema_version {}, expected {}",
1038 self.bundle_id,
1039 self.schema_version,
1040 FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION
1041 )));
1042 }
1043 let mut requirement_keys = BTreeSet::new();
1044 let mut cache_ids = BTreeSet::new();
1045 let mut file_names = BTreeSet::new();
1046 for entry in &self.caches {
1047 entry.validate()?;
1048 if !requirement_keys.insert(entry.requirement_key.as_str()) {
1049 return Err(DagMlError::RuntimeValidation(format!(
1050 "file prediction cache manifest for bundle `{}` has duplicate requirement `{}`",
1051 self.bundle_id, entry.requirement_key
1052 )));
1053 }
1054 if !cache_ids.insert(entry.cache_id.as_str()) {
1055 return Err(DagMlError::RuntimeValidation(format!(
1056 "file prediction cache manifest for bundle `{}` has duplicate cache id `{}`",
1057 self.bundle_id, entry.cache_id
1058 )));
1059 }
1060 if !file_names.insert(entry.file_name.as_str()) {
1061 return Err(DagMlError::RuntimeValidation(format!(
1062 "file prediction cache manifest for bundle `{}` has duplicate file `{}`",
1063 self.bundle_id, entry.file_name
1064 )));
1065 }
1066 }
1067 Ok(())
1068 }
1069
1070 pub fn validate_against_bundle(&self, bundle: &ExecutionBundle) -> Result<()> {
1071 self.validate()?;
1072 bundle.validate()?;
1073 if self.bundle_id != bundle.bundle_id {
1074 return Err(DagMlError::RuntimeValidation(format!(
1075 "file prediction cache manifest bundle `{}` does not match bundle `{}`",
1076 self.bundle_id, bundle.bundle_id
1077 )));
1078 }
1079 if self.caches.len() != bundle.prediction_caches.len() {
1080 return Err(DagMlError::RuntimeValidation(format!(
1081 "file prediction cache manifest for bundle `{}` has {} cache(s) for {} bundle cache record(s)",
1082 self.bundle_id,
1083 self.caches.len(),
1084 bundle.prediction_caches.len()
1085 )));
1086 }
1087 let entries_by_requirement = self
1088 .caches
1089 .iter()
1090 .map(|entry| (entry.requirement_key.as_str(), entry))
1091 .collect::<BTreeMap<_, _>>();
1092 for record in &bundle.prediction_caches {
1093 let entry = entries_by_requirement
1094 .get(record.requirement_key.as_str())
1095 .ok_or_else(|| {
1096 DagMlError::RuntimeValidation(format!(
1097 "file prediction cache manifest for bundle `{}` is missing requirement `{}`",
1098 self.bundle_id, record.requirement_key
1099 ))
1100 })?;
1101 if !entry.matches_record(record) {
1102 return Err(DagMlError::RuntimeValidation(format!(
1103 "file prediction cache manifest entry `{}` does not match bundle cache record",
1104 entry.cache_id
1105 )));
1106 }
1107 }
1108 Ok(())
1109 }
1110}
1111
/// Prediction cache store backed by JSON files in a directory.
///
/// Created by `open`, which reads the manifest from `root` and validates it
/// against the bundle; payload files are read on demand.
#[derive(Clone, Debug)]
pub struct FilePredictionCacheStore {
    /// Directory the manifest and payload file names are joined onto.
    root: PathBuf,
    /// Manifest loaded from `root` and validated against the bundle in `open`.
    manifest: FilePredictionCacheManifest,
    /// The bundle's prediction cache records, keyed by requirement key.
    records_by_requirement: BTreeMap<String, BundlePredictionCacheRecord>,
    /// Log of every successful `materialize` call; kept in a `RefCell` because
    /// `materialize` only receives `&self`.
    materialization_records: RefCell<Vec<PredictionCacheMaterializationRecord>>,
}
1119
1120impl FilePredictionCacheStore {
1121 pub fn write_payload_set(
1122 root: impl AsRef<Path>,
1123 bundle: &ExecutionBundle,
1124 payloads: &BundlePredictionCachePayloadSet,
1125 ) -> Result<FilePredictionCacheManifest> {
1126 payloads.validate_against_bundle(bundle)?;
1127 let root = root.as_ref();
1128 fs::create_dir_all(root).map_err(|err| {
1129 DagMlError::RuntimeValidation(format!(
1130 "failed to create prediction cache store `{}`: {err}",
1131 root.display()
1132 ))
1133 })?;
1134
1135 let mut entries = Vec::new();
1136 let records_by_requirement = bundle
1137 .prediction_caches
1138 .iter()
1139 .map(|record| (record.requirement_key.as_str(), record))
1140 .collect::<BTreeMap<_, _>>();
1141 for payload in &payloads.caches {
1142 let record = records_by_requirement
1143 .get(payload.requirement_key.as_str())
1144 .ok_or_else(|| {
1145 DagMlError::RuntimeValidation(format!(
1146 "prediction cache payload `{}` references unknown requirement `{}`",
1147 payload.cache_id, payload.requirement_key
1148 ))
1149 })?;
1150 validate_prediction_cache_payload_matches_record(payload, record)?;
1151 let entry = FilePredictionCacheEntry::from_payload(payload)?;
1152 let payload_path = root.join(&entry.file_name);
1153 write_runtime_json(&payload_path, payload, "prediction cache payload")?;
1154 entries.push(entry);
1155 }
1156 entries.sort_by(|left, right| left.requirement_key.cmp(&right.requirement_key));
1157 let manifest = FilePredictionCacheManifest {
1158 bundle_id: bundle.bundle_id.clone(),
1159 schema_version: FILE_PREDICTION_CACHE_STORE_SCHEMA_VERSION,
1160 caches: entries,
1161 };
1162 manifest.validate_against_bundle(bundle)?;
1163 write_runtime_json(
1164 &root.join(FILE_PREDICTION_CACHE_MANIFEST_FILE),
1165 &manifest,
1166 "prediction cache manifest",
1167 )?;
1168 Ok(manifest)
1169 }
1170
1171 pub fn open(root: impl Into<PathBuf>, bundle: &ExecutionBundle) -> Result<Self> {
1172 bundle.validate()?;
1173 let root = root.into();
1174 let manifest: FilePredictionCacheManifest = read_runtime_json(
1175 &root.join(FILE_PREDICTION_CACHE_MANIFEST_FILE),
1176 "prediction cache manifest",
1177 )?;
1178 manifest.validate_against_bundle(bundle)?;
1179 let records_by_requirement = bundle
1180 .prediction_caches
1181 .iter()
1182 .cloned()
1183 .map(|record| (record.requirement_key.clone(), record))
1184 .collect::<BTreeMap<_, _>>();
1185 Ok(Self {
1186 root,
1187 manifest,
1188 records_by_requirement,
1189 materialization_records: RefCell::new(Vec::new()),
1190 })
1191 }
1192
1193 pub fn manifest(&self) -> &FilePredictionCacheManifest {
1194 &self.manifest
1195 }
1196
1197 pub fn materialization_records(&self) -> Vec<PredictionCacheMaterializationRecord> {
1198 self.materialization_records.borrow().clone()
1199 }
1200
1201 fn payload_for_requirement(
1202 &self,
1203 requirement_key: &str,
1204 ) -> Result<crate::bundle::BundlePredictionCachePayload> {
1205 let entry = self
1206 .manifest
1207 .caches
1208 .iter()
1209 .find(|entry| entry.requirement_key == requirement_key)
1210 .ok_or_else(|| {
1211 DagMlError::RuntimeValidation(format!(
1212 "file prediction cache store is missing requirement `{requirement_key}`"
1213 ))
1214 })?;
1215 let record = self
1216 .records_by_requirement
1217 .get(requirement_key)
1218 .ok_or_else(|| {
1219 DagMlError::RuntimeValidation(format!(
1220 "file prediction cache store has no bundle record for requirement `{requirement_key}`"
1221 ))
1222 })?;
1223 let payload: crate::bundle::BundlePredictionCachePayload = read_runtime_json(
1224 &self.root.join(&entry.file_name),
1225 "prediction cache payload",
1226 )?;
1227 validate_prediction_cache_payload_matches_record(&payload, record)?;
1228 Ok(payload)
1229 }
1230}
1231
1232impl RuntimePredictionCacheStore for FilePredictionCacheStore {
1233 fn load_blocks(&self, requirement_key: &str) -> Result<Vec<PredictionBlock>> {
1234 let payload = self.payload_for_requirement(requirement_key)?;
1235 if payload.prediction_level != PredictionLevel::Sample {
1236 return Err(DagMlError::RuntimeValidation(format!(
1237 "file prediction cache store requirement `{requirement_key}` contains {:?} predictions, not sample blocks",
1238 payload.prediction_level
1239 )));
1240 }
1241 Ok(payload.blocks)
1242 }
1243
1244 fn load_aggregated_blocks(
1245 &self,
1246 requirement_key: &str,
1247 ) -> Result<Vec<AggregatedPredictionBlock>> {
1248 let payload = self.payload_for_requirement(requirement_key)?;
1249 if payload.prediction_level == PredictionLevel::Sample {
1250 return Err(DagMlError::RuntimeValidation(format!(
1251 "file prediction cache store requirement `{requirement_key}` contains sample predictions, not aggregated blocks"
1252 )));
1253 }
1254 Ok(payload.aggregated_blocks)
1255 }
1256
1257 fn materialize(&self, request: &PredictionCacheMaterializationRequest) -> Result<HandleRef> {
1258 request.requirement.validate()?;
1259 request.cache.validate()?;
1260 let requirement_key = request.requirement.key();
1261 let record = self
1262 .records_by_requirement
1263 .get(&requirement_key)
1264 .ok_or_else(|| {
1265 DagMlError::RuntimeValidation(format!(
1266 "file prediction cache store is missing requirement `{requirement_key}`"
1267 ))
1268 })?;
1269 if record != &request.cache {
1270 return Err(DagMlError::RuntimeValidation(format!(
1271 "file prediction cache materialization request for `{requirement_key}` does not match bundle cache record"
1272 )));
1273 }
1274 let payload = self.payload_for_requirement(&requirement_key)?;
1275 validate_prediction_cache_payload_matches_record(&payload, record)?;
1276 let fingerprint = stable_json_fingerprint(&(
1277 &request.run_id,
1278 &request.bundle_id,
1279 request.phase,
1280 &request.variant_id,
1281 &request.cache.requirement_key,
1282 &request.cache.cache_id,
1283 request.cache.prediction_level,
1284 &request.cache.content_fingerprint,
1285 ))?;
1286 let handle = HandleRef {
1287 handle: u64::from_str_radix(&fingerprint[..16], 16)
1288 .expect("sha256 hex prefix should fit into u64"),
1289 kind: HandleKind::Prediction,
1290 owner_controller: request.producer_controller_id.clone(),
1291 };
1292 self.materialization_records
1293 .borrow_mut()
1294 .push(PredictionCacheMaterializationRecord {
1295 run_id: request.run_id.clone(),
1296 bundle_id: request.bundle_id.clone(),
1297 phase: request.phase,
1298 variant_id: request.variant_id.clone(),
1299 requirement_key,
1300 cache_id: request.cache.cache_id.clone(),
1301 handle: handle.clone(),
1302 });
1303 Ok(handle)
1304 }
1305}
1306
1307fn prediction_cache_payload_file_name(
1308 payload: &crate::bundle::BundlePredictionCachePayload,
1309) -> Result<String> {
1310 let fingerprint = stable_json_fingerprint(&(
1311 &payload.requirement_key,
1312 &payload.cache_id,
1313 payload.prediction_level,
1314 &payload.content_fingerprint,
1315 payload.block_count,
1316 payload.row_count,
1317 ))?;
1318 Ok(format!("prediction-cache-{}.json", &fingerprint[..16]))
1319}
1320
1321fn validate_prediction_cache_file_name(file_name: &str) -> Result<()> {
1322 if file_name == "." || file_name == ".." || file_name.contains('/') || file_name.contains('\\')
1323 {
1324 return Err(DagMlError::RuntimeValidation(format!(
1325 "prediction cache file name `{file_name}` must be a plain file name"
1326 )));
1327 }
1328 Ok(())
1329}
1330
/// Column-major representation of one prediction block, holding either sample
/// predictions or target/group aggregated predictions.
///
/// Built from a `PredictionBlock` or an `AggregatedPredictionBlock` and
/// convertible back to either; `validate` enforces the invariants noted on the
/// fields below.
#[derive(Clone, Debug, PartialEq)]
pub struct ColumnarPredictionCacheBlock {
    /// Prediction id copied from the source block, if it had one.
    pub prediction_id: Option<String>,
    /// Node that produced the predictions.
    pub producer_node: NodeId,
    /// Partition copied from the source block.
    pub partition: PredictionPartition,
    /// Fold id copied from the source block, if it had one.
    pub fold_id: Option<FoldId>,
    /// Level of the stored predictions; `validate` rejects
    /// `PredictionLevel::Observation`.
    pub prediction_level: PredictionLevel,
    /// Row identifiers for target/group blocks; must be empty for sample blocks.
    pub unit_ids: Vec<PredictionUnitId>,
    /// Row identifiers for sample blocks; must be empty for target/group blocks.
    pub sample_ids: Vec<SampleId>,
    /// Target names; `validate` requires this to be empty or exactly `width` long.
    pub target_names: Vec<String>,
    /// Number of columns; must be non-zero and equal to `columns.len()`.
    pub width: usize,
    /// Values stored column by column: `columns[c][r]` is column `c` of row `r`.
    /// Every column must hold `row_count()` values.
    pub columns: Vec<Vec<f64>>,
}
1344
1345impl ColumnarPredictionCacheBlock {
1346 pub fn from_prediction_block(block: &PredictionBlock) -> Result<Self> {
1347 let width = block.validate_shape()?;
1348 let mut columns = vec![Vec::with_capacity(block.values.len()); width];
1349 for row in &block.values {
1350 for (column_idx, value) in row.iter().enumerate() {
1351 columns[column_idx].push(*value);
1352 }
1353 }
1354 Ok(Self {
1355 prediction_id: block.prediction_id.clone(),
1356 producer_node: block.producer_node.clone(),
1357 partition: block.partition.clone(),
1358 fold_id: block.fold_id.clone(),
1359 prediction_level: PredictionLevel::Sample,
1360 unit_ids: Vec::new(),
1361 sample_ids: block.sample_ids.clone(),
1362 target_names: block.target_names.clone(),
1363 width,
1364 columns,
1365 })
1366 }
1367
1368 pub fn from_aggregated_prediction_block(block: &AggregatedPredictionBlock) -> Result<Self> {
1369 let width = block.validate_shape()?;
1370 if block.level == PredictionLevel::Sample {
1371 return Err(DagMlError::RuntimeValidation(format!(
1372 "columnar aggregated prediction block for `{}` must use target/group level, got sample",
1373 block.producer_node
1374 )));
1375 }
1376 let mut columns = vec![Vec::with_capacity(block.values.len()); width];
1377 for row in &block.values {
1378 for (column_idx, value) in row.iter().enumerate() {
1379 columns[column_idx].push(*value);
1380 }
1381 }
1382 Ok(Self {
1383 prediction_id: block.prediction_id.clone(),
1384 producer_node: block.producer_node.clone(),
1385 partition: block.partition.clone(),
1386 fold_id: block.fold_id.clone(),
1387 prediction_level: block.level,
1388 unit_ids: block.unit_ids.clone(),
1389 sample_ids: Vec::new(),
1390 target_names: block.target_names.clone(),
1391 width,
1392 columns,
1393 })
1394 }
1395
1396 pub fn row_count(&self) -> usize {
1397 match self.prediction_level {
1398 PredictionLevel::Sample => self.sample_ids.len(),
1399 PredictionLevel::Target | PredictionLevel::Group => self.unit_ids.len(),
1400 PredictionLevel::Observation => 0,
1401 }
1402 }
1403
1404 pub fn value_count(&self) -> usize {
1405 self.columns.iter().map(Vec::len).sum()
1406 }
1407
1408 pub fn validate(&self) -> Result<()> {
1409 match self.prediction_level {
1410 PredictionLevel::Observation => {
1411 return Err(DagMlError::RuntimeValidation(format!(
1412 "columnar prediction block for `{}` cannot store observation-level predictions",
1413 self.producer_node
1414 )));
1415 }
1416 PredictionLevel::Sample => {
1417 if self.sample_ids.is_empty() {
1418 return Err(DagMlError::RuntimeValidation(format!(
1419 "columnar sample prediction block for `{}` has no sample ids",
1420 self.producer_node
1421 )));
1422 }
1423 if !self.unit_ids.is_empty() {
1424 return Err(DagMlError::RuntimeValidation(format!(
1425 "columnar sample prediction block for `{}` unexpectedly carries unit ids",
1426 self.producer_node
1427 )));
1428 }
1429 }
1430 PredictionLevel::Target | PredictionLevel::Group => {
1431 if !self.sample_ids.is_empty() {
1432 return Err(DagMlError::RuntimeValidation(format!(
1433 "columnar aggregated prediction block for `{}` unexpectedly carries sample ids",
1434 self.producer_node
1435 )));
1436 }
1437 if self.unit_ids.is_empty() {
1438 return Err(DagMlError::RuntimeValidation(format!(
1439 "columnar aggregated prediction block for `{}` has no unit ids",
1440 self.producer_node
1441 )));
1442 }
1443 if self
1444 .unit_ids
1445 .iter()
1446 .any(|unit_id| unit_id.level() != self.prediction_level)
1447 {
1448 return Err(DagMlError::RuntimeValidation(format!(
1449 "columnar aggregated prediction block for `{}` carries unit ids outside {:?}",
1450 self.producer_node, self.prediction_level
1451 )));
1452 }
1453 }
1454 }
1455 if self.width == 0 {
1456 return Err(DagMlError::RuntimeValidation(format!(
1457 "columnar prediction block for `{}` has zero width",
1458 self.producer_node
1459 )));
1460 }
1461 if self.columns.len() != self.width {
1462 return Err(DagMlError::RuntimeValidation(format!(
1463 "columnar prediction block for `{}` has {} column(s), expected {}",
1464 self.producer_node,
1465 self.columns.len(),
1466 self.width
1467 )));
1468 }
1469 for (column_idx, column) in self.columns.iter().enumerate() {
1470 if column.len() != self.row_count() {
1471 return Err(DagMlError::RuntimeValidation(format!(
1472 "columnar prediction block for `{}` column {} has {} value(s), expected {}",
1473 self.producer_node,
1474 column_idx,
1475 column.len(),
1476 self.row_count()
1477 )));
1478 }
1479 }
1480 if !self.target_names.is_empty() && self.target_names.len() != self.width {
1481 return Err(DagMlError::RuntimeValidation(format!(
1482 "columnar prediction block for `{}` has {} target names for width {}",
1483 self.producer_node,
1484 self.target_names.len(),
1485 self.width
1486 )));
1487 }
1488 Ok(())
1489 }
1490
1491 pub fn to_prediction_block(&self) -> Result<PredictionBlock> {
1492 self.validate()?;
1493 if self.prediction_level != PredictionLevel::Sample {
1494 return Err(DagMlError::RuntimeValidation(format!(
1495 "columnar prediction block for `{}` contains {:?} predictions, not sample predictions",
1496 self.producer_node, self.prediction_level
1497 )));
1498 }
1499 let values = (0..self.row_count())
1500 .map(|row_idx| {
1501 self.columns
1502 .iter()
1503 .map(|column| column[row_idx])
1504 .collect::<Vec<_>>()
1505 })
1506 .collect();
1507 let block = PredictionBlock {
1508 prediction_id: self.prediction_id.clone(),
1509 producer_node: self.producer_node.clone(),
1510 partition: self.partition.clone(),
1511 fold_id: self.fold_id.clone(),
1512 sample_ids: self.sample_ids.clone(),
1513 values,
1514 target_names: self.target_names.clone(),
1515 };
1516 block.validate_shape()?;
1517 Ok(block)
1518 }
1519
1520 pub fn to_aggregated_prediction_block(&self) -> Result<AggregatedPredictionBlock> {
1521 self.validate()?;
1522 if self.prediction_level == PredictionLevel::Sample {
1523 return Err(DagMlError::RuntimeValidation(format!(
1524 "columnar prediction block for `{}` contains sample predictions, not aggregated predictions",
1525 self.producer_node
1526 )));
1527 }
1528 let values = (0..self.row_count())
1529 .map(|row_idx| {
1530 self.columns
1531 .iter()
1532 .map(|column| column[row_idx])
1533 .collect::<Vec<_>>()
1534 })
1535 .collect();
1536 let block = AggregatedPredictionBlock {
1537 prediction_id: self.prediction_id.clone(),
1538 producer_node: self.producer_node.clone(),
1539 partition: self.partition.clone(),
1540 fold_id: self.fold_id.clone(),
1541 level: self.prediction_level,
1542 unit_ids: self.unit_ids.clone(),
1543 values,
1544 target_names: self.target_names.clone(),
1545 };
1546 block.validate_shape()?;
1547 Ok(block)
1548 }
1549}
1550
/// Serializable summary of one cache held by a `ColumnarPredictionCacheStore`,
/// as produced by `ColumnarPredictionCacheEntry::manifest`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ColumnarPredictionCacheManifest {
    /// Requirement key copied from the bundle cache record.
    pub requirement_key: String,
    /// Cache id copied from the bundle cache record.
    pub cache_id: String,
    /// Prediction level copied from the bundle cache record.
    pub prediction_level: PredictionLevel,
    /// Block count copied from the bundle cache record.
    pub block_count: usize,
    /// Row count copied from the bundle cache record.
    pub row_count: usize,
    /// Prediction width copied from the bundle cache record.
    pub prediction_width: usize,
    /// Number of values actually stored, summed over the entry's blocks.
    pub value_count: usize,
    /// `value_count` multiplied by the size of an `f64` in bytes.
    pub estimated_value_bytes: usize,
    /// Content fingerprint copied from the bundle cache record.
    pub content_fingerprint: String,
}
1563
/// One cache of a `ColumnarPredictionCacheStore`: the bundle's cache record
/// together with the columnar blocks converted from its payload.
#[derive(Clone, Debug, PartialEq)]
struct ColumnarPredictionCacheEntry {
    /// Bundle cache record this entry was built for.
    cache: BundlePredictionCacheRecord,
    /// Columnar blocks; `validate` requires their count, level, partition, rows
    /// and values to agree with `cache`.
    blocks: Vec<ColumnarPredictionCacheBlock>,
}
1569
1570impl ColumnarPredictionCacheEntry {
1571 fn from_payload(
1572 payload: BundlePredictionCachePayload,
1573 cache: BundlePredictionCacheRecord,
1574 ) -> Result<Self> {
1575 validate_prediction_cache_payload_matches_record(&payload, &cache)?;
1576 let blocks = match payload.prediction_level {
1577 PredictionLevel::Sample => payload
1578 .blocks
1579 .iter()
1580 .map(ColumnarPredictionCacheBlock::from_prediction_block)
1581 .collect::<Result<Vec<_>>>()?,
1582 PredictionLevel::Target | PredictionLevel::Group => payload
1583 .aggregated_blocks
1584 .iter()
1585 .map(ColumnarPredictionCacheBlock::from_aggregated_prediction_block)
1586 .collect::<Result<Vec<_>>>()?,
1587 PredictionLevel::Observation => {
1588 return Err(DagMlError::RuntimeValidation(format!(
1589 "columnar prediction cache payload `{}` cannot use observation-level predictions",
1590 payload.cache_id
1591 )));
1592 }
1593 };
1594 let entry = Self { cache, blocks };
1595 entry.validate()?;
1596 Ok(entry)
1597 }
1598
1599 fn validate(&self) -> Result<()> {
1600 self.cache.validate()?;
1601 if self.blocks.len() != self.cache.block_count {
1602 return Err(DagMlError::RuntimeValidation(format!(
1603 "columnar prediction cache `{}` has {} block(s), expected {}",
1604 self.cache.cache_id,
1605 self.blocks.len(),
1606 self.cache.block_count
1607 )));
1608 }
1609 let mut row_count = 0usize;
1610 let mut value_count = 0usize;
1611 for block in &self.blocks {
1612 block.validate()?;
1613 if block.prediction_level != self.cache.prediction_level {
1614 return Err(DagMlError::RuntimeValidation(format!(
1615 "columnar prediction cache `{}` contains a {:?} block, expected {:?}",
1616 self.cache.cache_id, block.prediction_level, self.cache.prediction_level
1617 )));
1618 }
1619 if block.partition != self.cache.partition {
1620 return Err(DagMlError::RuntimeValidation(format!(
1621 "columnar prediction cache `{}` contains a block from partition {:?}",
1622 self.cache.cache_id, block.partition
1623 )));
1624 }
1625 row_count += block.row_count();
1626 value_count += block.value_count();
1627 }
1628 if row_count != self.cache.row_count {
1629 return Err(DagMlError::RuntimeValidation(format!(
1630 "columnar prediction cache `{}` has {} row(s), expected {}",
1631 self.cache.cache_id, row_count, self.cache.row_count
1632 )));
1633 }
1634 let expected_values = self
1635 .cache
1636 .row_count
1637 .checked_mul(self.cache.prediction_width)
1638 .ok_or_else(|| {
1639 DagMlError::RuntimeValidation(format!(
1640 "columnar prediction cache `{}` value count overflow",
1641 self.cache.cache_id
1642 ))
1643 })?;
1644 if value_count != expected_values {
1645 return Err(DagMlError::RuntimeValidation(format!(
1646 "columnar prediction cache `{}` has {} value(s), expected {}",
1647 self.cache.cache_id, value_count, expected_values
1648 )));
1649 }
1650 Ok(())
1651 }
1652
1653 fn to_blocks(&self) -> Result<Vec<PredictionBlock>> {
1654 self.validate()?;
1655 self.blocks
1656 .iter()
1657 .map(ColumnarPredictionCacheBlock::to_prediction_block)
1658 .collect()
1659 }
1660
1661 fn to_aggregated_blocks(&self) -> Result<Vec<AggregatedPredictionBlock>> {
1662 self.validate()?;
1663 self.blocks
1664 .iter()
1665 .map(ColumnarPredictionCacheBlock::to_aggregated_prediction_block)
1666 .collect()
1667 }
1668
1669 fn validate_against_cache_record(&self, cache: &BundlePredictionCacheRecord) -> Result<()> {
1670 if &self.cache != cache {
1671 return Err(DagMlError::RuntimeValidation(format!(
1672 "columnar prediction cache materialization request for `{}` does not match bundle cache record",
1673 cache.requirement_key
1674 )));
1675 }
1676 let (blocks, aggregated_blocks) = match self.cache.prediction_level {
1677 PredictionLevel::Sample => (self.to_blocks()?, Vec::new()),
1678 PredictionLevel::Target | PredictionLevel::Group => {
1679 (Vec::new(), self.to_aggregated_blocks()?)
1680 }
1681 PredictionLevel::Observation => {
1682 return Err(DagMlError::RuntimeValidation(format!(
1683 "columnar prediction cache `{}` cannot materialize observation-level predictions",
1684 self.cache.cache_id
1685 )));
1686 }
1687 };
1688 let payload = BundlePredictionCachePayload {
1689 requirement_key: self.cache.requirement_key.clone(),
1690 cache_id: self.cache.cache_id.clone(),
1691 format: self.cache.format.clone(),
1692 partition: self.cache.partition.clone(),
1693 prediction_level: self.cache.prediction_level,
1694 block_count: self.cache.block_count,
1695 row_count: self.cache.row_count,
1696 content_fingerprint: self.cache.content_fingerprint.clone(),
1697 blocks,
1698 aggregated_blocks,
1699 };
1700 validate_prediction_cache_payload_matches_record(&payload, cache)
1701 }
1702
1703 fn manifest(&self) -> ColumnarPredictionCacheManifest {
1704 let value_count = self
1705 .blocks
1706 .iter()
1707 .map(ColumnarPredictionCacheBlock::value_count)
1708 .sum::<usize>();
1709 ColumnarPredictionCacheManifest {
1710 requirement_key: self.cache.requirement_key.clone(),
1711 cache_id: self.cache.cache_id.clone(),
1712 prediction_level: self.cache.prediction_level,
1713 block_count: self.cache.block_count,
1714 row_count: self.cache.row_count,
1715 prediction_width: self.cache.prediction_width,
1716 value_count,
1717 estimated_value_bytes: value_count * std::mem::size_of::<f64>(),
1718 content_fingerprint: self.cache.content_fingerprint.clone(),
1719 }
1720 }
1721}
1722
/// In-memory prediction cache store that keeps every cache in columnar form.
#[derive(Clone, Debug, Default)]
pub struct ColumnarPredictionCacheStore {
    /// Cache entries keyed by requirement key.
    entries: BTreeMap<String, ColumnarPredictionCacheEntry>,
    /// Log of every successful `materialize` call; kept in a `RefCell` because
    /// `materialize` only receives `&self`.
    materialization_records: RefCell<Vec<PredictionCacheMaterializationRecord>>,
}
1728
1729impl ColumnarPredictionCacheStore {
1730 pub fn from_payloads(
1731 bundle: &ExecutionBundle,
1732 payloads: BundlePredictionCachePayloadSet,
1733 ) -> Result<Self> {
1734 payloads.validate_against_bundle(bundle)?;
1735 let records_by_requirement = bundle
1736 .prediction_caches
1737 .iter()
1738 .cloned()
1739 .map(|cache| (cache.requirement_key.clone(), cache))
1740 .collect::<BTreeMap<_, _>>();
1741 let mut entries = BTreeMap::new();
1742 for payload in payloads.caches {
1743 let cache = records_by_requirement
1744 .get(&payload.requirement_key)
1745 .cloned()
1746 .ok_or_else(|| {
1747 DagMlError::RuntimeValidation(format!(
1748 "columnar prediction cache payload `{}` references unknown requirement `{}`",
1749 payload.cache_id, payload.requirement_key
1750 ))
1751 })?;
1752 let requirement_key = payload.requirement_key.clone();
1753 let previous = entries.insert(
1754 requirement_key,
1755 ColumnarPredictionCacheEntry::from_payload(payload, cache)?,
1756 );
1757 debug_assert!(previous.is_none());
1758 }
1759 Ok(Self {
1760 entries,
1761 materialization_records: RefCell::new(Vec::new()),
1762 })
1763 }
1764
1765 pub fn entry_count(&self) -> usize {
1766 self.entries.len()
1767 }
1768
1769 pub fn manifests(&self) -> Vec<ColumnarPredictionCacheManifest> {
1770 self.entries
1771 .values()
1772 .map(ColumnarPredictionCacheEntry::manifest)
1773 .collect()
1774 }
1775
1776 pub fn materialization_records(&self) -> Vec<PredictionCacheMaterializationRecord> {
1777 self.materialization_records.borrow().clone()
1778 }
1779}
1780
1781impl RuntimePredictionCacheStore for ColumnarPredictionCacheStore {
1782 fn load_blocks(&self, requirement_key: &str) -> Result<Vec<PredictionBlock>> {
1783 let entry = self.entries.get(requirement_key).ok_or_else(|| {
1784 DagMlError::RuntimeValidation(format!(
1785 "columnar prediction cache store is missing requirement `{requirement_key}`"
1786 ))
1787 })?;
1788 if entry.cache.prediction_level != PredictionLevel::Sample {
1789 return Err(DagMlError::RuntimeValidation(format!(
1790 "columnar prediction cache store requirement `{requirement_key}` contains {:?} predictions, not sample blocks",
1791 entry.cache.prediction_level
1792 )));
1793 }
1794 entry.validate_against_cache_record(&entry.cache)?;
1795 entry.to_blocks()
1796 }
1797
1798 fn load_aggregated_blocks(
1799 &self,
1800 requirement_key: &str,
1801 ) -> Result<Vec<AggregatedPredictionBlock>> {
1802 let entry = self.entries.get(requirement_key).ok_or_else(|| {
1803 DagMlError::RuntimeValidation(format!(
1804 "columnar prediction cache store is missing requirement `{requirement_key}`"
1805 ))
1806 })?;
1807 if entry.cache.prediction_level == PredictionLevel::Sample {
1808 return Err(DagMlError::RuntimeValidation(format!(
1809 "columnar prediction cache store requirement `{requirement_key}` contains sample predictions, not aggregated blocks"
1810 )));
1811 }
1812 entry.validate_against_cache_record(&entry.cache)?;
1813 entry.to_aggregated_blocks()
1814 }
1815
1816 fn materialize(&self, request: &PredictionCacheMaterializationRequest) -> Result<HandleRef> {
1817 request.requirement.validate()?;
1818 request.cache.validate()?;
1819 let requirement_key = request.requirement.key();
1820 if requirement_key != request.cache.requirement_key {
1821 return Err(DagMlError::RuntimeValidation(format!(
1822 "columnar prediction cache materialization request for `{}` uses cache `{}` with mismatched requirement `{}`",
1823 requirement_key, request.cache.cache_id, request.cache.requirement_key
1824 )));
1825 }
1826 let entry = self.entries.get(&requirement_key).ok_or_else(|| {
1827 DagMlError::RuntimeValidation(format!(
1828 "columnar prediction cache store is missing requirement `{requirement_key}`"
1829 ))
1830 })?;
1831 entry.validate_against_cache_record(&request.cache)?;
1832 let fingerprint = stable_json_fingerprint(&(
1833 &request.run_id,
1834 &request.bundle_id,
1835 request.phase,
1836 &request.variant_id,
1837 &request.cache.requirement_key,
1838 &request.cache.cache_id,
1839 request.cache.prediction_level,
1840 &request.cache.content_fingerprint,
1841 ))?;
1842 let handle = HandleRef {
1843 handle: u64::from_str_radix(&fingerprint[..16], 16)
1844 .expect("sha256 hex prefix should fit into u64"),
1845 kind: HandleKind::Prediction,
1846 owner_controller: request.producer_controller_id.clone(),
1847 };
1848 self.materialization_records
1849 .borrow_mut()
1850 .push(PredictionCacheMaterializationRecord {
1851 run_id: request.run_id.clone(),
1852 bundle_id: request.bundle_id.clone(),
1853 phase: request.phase,
1854 variant_id: request.variant_id.clone(),
1855 requirement_key,
1856 cache_id: request.cache.cache_id.clone(),
1857 handle: handle.clone(),
1858 });
1859 Ok(handle)
1860 }
1861}
1862
1863fn validate_runtime_non_empty(label: &str, value: &str) -> Result<()> {
1864 if value.trim().is_empty() {
1865 return Err(DagMlError::RuntimeValidation(format!("{label} is empty")));
1866 }
1867 Ok(())
1868}
1869
1870fn validate_artifact_optional_text(
1871 label: &str,
1872 value: &Option<String>,
1873 artifact_id: &ArtifactId,
1874) -> Result<()> {
1875 let Some(value) = value else {
1876 return Ok(());
1877 };
1878 if value.trim().is_empty() {
1879 return Err(DagMlError::RuntimeValidation(format!(
1880 "artifact `{artifact_id}` has empty {label}"
1881 )));
1882 }
1883 if value.chars().any(char::is_control) {
1884 return Err(DagMlError::RuntimeValidation(format!(
1885 "artifact `{artifact_id}` has control characters in {label}"
1886 )));
1887 }
1888 Ok(())
1889}
1890
1891fn artifact_payload_path(root: &Path, artifact: &ArtifactRef) -> Result<PathBuf> {
1892 artifact.validate_portable()?;
1893 let uri = artifact
1894 .uri
1895 .as_deref()
1896 .expect("portable artifact validation requires uri");
1897 Ok(root.join(uri))
1898}
1899
1900fn validate_artifact_payload_file(
1901 root: &Path,
1902 artifact: &ArtifactRef,
1903) -> Result<ArtifactPayloadMetadata> {
1904 artifact.validate_portable()?;
1905 let uri = artifact
1906 .uri
1907 .as_deref()
1908 .expect("portable artifact validation requires uri")
1909 .to_string();
1910 let path = artifact_payload_path(root, artifact)?;
1911 validate_payload_path_stays_within_root(root, &path, artifact)?;
1912 let metadata = fs::metadata(&path).map_err(|err| {
1913 DagMlError::RuntimeValidation(format!(
1914 "failed to stat artifact payload `{}` at {}: {err}",
1915 artifact.id,
1916 path.display()
1917 ))
1918 })?;
1919 if !metadata.is_file() {
1920 return Err(DagMlError::RuntimeValidation(format!(
1921 "artifact payload `{}` at {} is not a regular file",
1922 artifact.id,
1923 path.display()
1924 )));
1925 }
1926 let size_bytes = metadata.len();
1927 if let Some(expected_size) = artifact.size_bytes {
1928 if expected_size != size_bytes {
1929 return Err(DagMlError::RuntimeValidation(format!(
1930 "artifact payload `{}` size mismatch: expected {}, got {}",
1931 artifact.id, expected_size, size_bytes
1932 )));
1933 }
1934 }
1935 let content_fingerprint =
1936 sha256_file_hex(&path, &format!("artifact payload `{}`", artifact.id))?;
1937 let expected_fingerprint = artifact
1938 .content_fingerprint
1939 .as_deref()
1940 .expect("portable artifact validation requires content_fingerprint");
1941 if !content_fingerprint.eq_ignore_ascii_case(expected_fingerprint) {
1942 return Err(DagMlError::RuntimeValidation(format!(
1943 "artifact payload `{}` content fingerprint mismatch",
1944 artifact.id
1945 )));
1946 }
1947 Ok(ArtifactPayloadMetadata {
1948 uri,
1949 content_fingerprint,
1950 size_bytes,
1951 })
1952}
1953
1954fn validate_payload_path_stays_within_root(
1955 root: &Path,
1956 path: &Path,
1957 artifact: &ArtifactRef,
1958) -> Result<()> {
1959 let root = fs::canonicalize(root).map_err(|err| {
1960 DagMlError::RuntimeValidation(format!(
1961 "failed to canonicalize artifact payload root `{}`: {err}",
1962 root.display()
1963 ))
1964 })?;
1965 let path = fs::canonicalize(path).map_err(|err| {
1966 DagMlError::RuntimeValidation(format!(
1967 "failed to canonicalize artifact payload `{}` at {}: {err}",
1968 artifact.id,
1969 path.display()
1970 ))
1971 })?;
1972 if !path.starts_with(&root) {
1973 return Err(DagMlError::RuntimeValidation(format!(
1974 "artifact payload `{}` resolves outside store root `{}`",
1975 artifact.id,
1976 root.display()
1977 )));
1978 }
1979 Ok(())
1980}
1981
1982fn sha256_file_hex(path: &Path, label: &str) -> Result<String> {
1983 let mut file = fs::File::open(path).map_err(|err| {
1984 DagMlError::RuntimeValidation(format!(
1985 "failed to open {label} at {}: {err}",
1986 path.display()
1987 ))
1988 })?;
1989 let mut hasher = Sha256::new();
1990 let mut buffer = [0u8; 64 * 1024];
1991 loop {
1992 let read = file.read(&mut buffer).map_err(|err| {
1993 DagMlError::RuntimeValidation(format!(
1994 "failed to read {label} at {}: {err}",
1995 path.display()
1996 ))
1997 })?;
1998 if read == 0 {
1999 break;
2000 }
2001 hasher.update(&buffer[..read]);
2002 }
2003 Ok(bytes_to_hex(&hasher.finalize()))
2004}
2005
/// Test-only helper: SHA-256 of an in-memory byte slice as lowercase hex.
#[cfg(test)]
fn sha256_bytes_hex(bytes: &[u8]) -> String {
    let digest = Sha256::digest(bytes);
    bytes_to_hex(&digest)
}
2010
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
///
/// An empty slice yields an empty string.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        hex.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
2019
2020fn validate_relative_artifact_uri(artifact_id: &ArtifactId, uri: &str) -> Result<()> {
2027 if uri.is_empty() {
2028 return Err(DagMlError::RuntimeValidation(format!(
2029 "artifact `{artifact_id}` has empty uri"
2030 )));
2031 }
2032 if uri.chars().any(char::is_control) {
2033 return Err(DagMlError::RuntimeValidation(format!(
2034 "artifact `{artifact_id}` uri has control characters"
2035 )));
2036 }
2037 if uri.starts_with('/') || uri.starts_with('\\') {
2038 return Err(DagMlError::RuntimeValidation(format!(
2039 "artifact `{artifact_id}` uri `{uri}` must be a relative path"
2040 )));
2041 }
2042 let mut prefix = uri.chars();
2043 if let (Some(drive), Some(':')) = (prefix.next(), prefix.next()) {
2044 if drive.is_ascii_alphabetic() {
2045 return Err(DagMlError::RuntimeValidation(format!(
2046 "artifact `{artifact_id}` uri `{uri}` must be a relative path"
2047 )));
2048 }
2049 }
2050 let first_segment = uri.split(['/', '\\']).next().unwrap_or(uri);
2054 if first_segment.contains(':') {
2055 return Err(DagMlError::RuntimeValidation(format!(
2056 "artifact `{artifact_id}` uri `{uri}` must not include a scheme or colon in its first path segment"
2057 )));
2058 }
2059 for segment in uri.split(['/', '\\']) {
2060 if segment == ".." {
2061 return Err(DagMlError::RuntimeValidation(format!(
2062 "artifact `{artifact_id}` uri `{uri}` must not contain `..` components"
2063 )));
2064 }
2065 }
2066 Ok(())
2067}
2068
2069fn validate_runtime_fingerprint(label: &str, value: &str) -> Result<()> {
2070 if value.len() != 64 || !value.bytes().all(|byte| byte.is_ascii_hexdigit()) {
2071 return Err(DagMlError::RuntimeValidation(format!(
2072 "{label} fingerprint must be a 64-character hex digest"
2073 )));
2074 }
2075 Ok(())
2076}
2077
2078fn read_runtime_json<T: serde::de::DeserializeOwned>(path: &Path, label: &str) -> Result<T> {
2079 let data = fs::read(path).map_err(|err| {
2080 DagMlError::RuntimeValidation(format!(
2081 "failed to read {label} at {}: {err}",
2082 path.display()
2083 ))
2084 })?;
2085 serde_json::from_slice(&data).map_err(|err| {
2086 DagMlError::RuntimeValidation(format!(
2087 "failed to parse {label} at {}: {err}",
2088 path.display()
2089 ))
2090 })
2091}
2092
2093fn write_runtime_json<T: Serialize>(path: &Path, value: &T, label: &str) -> Result<()> {
2094 let mut data = serde_json::to_vec_pretty(value).map_err(|err| {
2095 DagMlError::RuntimeValidation(format!("failed to serialize {label}: {err}"))
2096 })?;
2097 data.push(b'\n');
2098 fs::write(path, data).map_err(|err| {
2099 DagMlError::RuntimeValidation(format!(
2100 "failed to write {label} at {}: {err}",
2101 path.display()
2102 ))
2103 })
2104}
2105
/// Prediction cache store that keeps all payloads in memory.
///
/// Payloads are indexed by their `requirement_key`. Every successful
/// `materialize` call appends a record to `materialization_records`; the
/// `RefCell` allows that append through a shared `&self` reference.
#[derive(Clone, Debug, Default)]
pub struct InMemoryPredictionCacheStore {
    /// Cached payloads keyed by requirement key.
    payloads: BTreeMap<String, crate::bundle::BundlePredictionCachePayload>,
    /// Log of materializations performed through this store, in call order.
    materialization_records: RefCell<Vec<PredictionCacheMaterializationRecord>>,
}
2111
2112impl InMemoryPredictionCacheStore {
2113 pub fn from_payloads(
2114 bundle: &ExecutionBundle,
2115 payloads: BundlePredictionCachePayloadSet,
2116 ) -> Result<Self> {
2117 payloads.validate_against_bundle(bundle)?;
2118 Ok(Self {
2119 payloads: payloads
2120 .caches
2121 .into_iter()
2122 .map(|payload| (payload.requirement_key.clone(), payload))
2123 .collect(),
2124 materialization_records: RefCell::new(Vec::new()),
2125 })
2126 }
2127
2128 pub fn payload_count(&self) -> usize {
2129 self.payloads.len()
2130 }
2131
2132 pub fn materialization_records(&self) -> Vec<PredictionCacheMaterializationRecord> {
2133 self.materialization_records.borrow().clone()
2134 }
2135}
2136
2137impl RuntimePredictionCacheStore for InMemoryPredictionCacheStore {
2138 fn load_blocks(&self, requirement_key: &str) -> Result<Vec<PredictionBlock>> {
2139 let payload = self.payloads.get(requirement_key).ok_or_else(|| {
2140 DagMlError::RuntimeValidation(format!(
2141 "prediction cache store is missing requirement `{requirement_key}`"
2142 ))
2143 })?;
2144 payload.validate()?;
2145 if payload.prediction_level != PredictionLevel::Sample {
2146 return Err(DagMlError::RuntimeValidation(format!(
2147 "prediction cache store requirement `{requirement_key}` contains {:?} predictions, not sample blocks",
2148 payload.prediction_level
2149 )));
2150 }
2151 Ok(payload.blocks.clone())
2152 }
2153
2154 fn load_aggregated_blocks(
2155 &self,
2156 requirement_key: &str,
2157 ) -> Result<Vec<AggregatedPredictionBlock>> {
2158 let payload = self.payloads.get(requirement_key).ok_or_else(|| {
2159 DagMlError::RuntimeValidation(format!(
2160 "prediction cache store is missing requirement `{requirement_key}`"
2161 ))
2162 })?;
2163 payload.validate()?;
2164 if payload.prediction_level == PredictionLevel::Sample {
2165 return Err(DagMlError::RuntimeValidation(format!(
2166 "prediction cache store requirement `{requirement_key}` contains sample predictions, not aggregated blocks"
2167 )));
2168 }
2169 Ok(payload.aggregated_blocks.clone())
2170 }
2171
2172 fn materialize(&self, request: &PredictionCacheMaterializationRequest) -> Result<HandleRef> {
2173 request.requirement.validate()?;
2174 request.cache.validate()?;
2175 if request.requirement.key() != request.cache.requirement_key {
2176 return Err(DagMlError::RuntimeValidation(format!(
2177 "prediction cache materialization request for `{}` uses cache `{}` with mismatched requirement `{}`",
2178 request.requirement.key(),
2179 request.cache.cache_id,
2180 request.cache.requirement_key
2181 )));
2182 }
2183 let payload = self
2184 .payloads
2185 .get(&request.cache.requirement_key)
2186 .ok_or_else(|| {
2187 DagMlError::RuntimeValidation(format!(
2188 "prediction cache store is missing requirement `{}`",
2189 request.cache.requirement_key
2190 ))
2191 })?;
2192 validate_prediction_cache_payload_matches_record(payload, &request.cache)?;
2193 let fingerprint = stable_json_fingerprint(&(
2194 &request.run_id,
2195 &request.bundle_id,
2196 request.phase,
2197 &request.variant_id,
2198 &request.cache.requirement_key,
2199 &request.cache.cache_id,
2200 request.cache.prediction_level,
2201 &request.cache.content_fingerprint,
2202 ))?;
2203 let handle = HandleRef {
2204 handle: u64::from_str_radix(&fingerprint[..16], 16)
2205 .expect("sha256 hex prefix should fit into u64"),
2206 kind: HandleKind::Prediction,
2207 owner_controller: request.producer_controller_id.clone(),
2208 };
2209 self.materialization_records
2210 .borrow_mut()
2211 .push(PredictionCacheMaterializationRecord {
2212 run_id: request.run_id.clone(),
2213 bundle_id: request.bundle_id.clone(),
2214 phase: request.phase,
2215 variant_id: request.variant_id.clone(),
2216 requirement_key: request.cache.requirement_key.clone(),
2217 cache_id: request.cache.cache_id.clone(),
2218 handle: handle.clone(),
2219 });
2220 Ok(handle)
2221 }
2222}
2223
/// Describes one prediction input delivered to a node task.
///
/// Fields marked `#[serde(default)]` may be absent from serialized input.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PredictionInputSpec {
    /// Node that produced the predictions.
    pub producer_node: NodeId,
    /// Name of the producer-side port.
    pub source_port: String,
    /// Name of the consumer-side port.
    pub target_port: String,
    /// Partition the predictions belong to.
    pub partition: PredictionPartition,
    /// Level of the predictions; falls back to `PredictionLevel::Sample` when
    /// missing from serialized input.
    #[serde(default = "default_runtime_prediction_level")]
    pub prediction_level: PredictionLevel,
    /// Optional single fold identifier.
    pub fold_id: Option<FoldId>,
    /// Fold identifiers; empty when absent from serialized input.
    #[serde(default)]
    pub fold_ids: Vec<FoldId>,
    /// Prediction unit identifiers; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub unit_ids: Vec<PredictionUnitId>,
    /// Sample identifiers; empty when absent from serialized input.
    #[serde(default)]
    pub sample_ids: Vec<SampleId>,
    /// Width of each prediction. NOTE(review): presumably the number of
    /// predicted values per row — confirm against the producer.
    pub prediction_width: usize,
    /// Target names; empty when absent from serialized input.
    #[serde(default)]
    pub target_names: Vec<String>,
}
2243
/// Describes one artifact input delivered to a node task.
///
/// Within this file it is built from a `RefitArtifactRecord` via
/// `ArtifactInputSpec::from_refit_record`, which copies every field.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ArtifactInputSpec {
    /// Node the artifact is associated with.
    pub node_id: NodeId,
    /// Controller the artifact is associated with.
    pub controller_id: ControllerId,
    /// Reference to the artifact itself.
    pub artifact: ArtifactRef,
    /// Fingerprint of the parameters recorded alongside the artifact.
    pub params_fingerprint: String,
    /// Data requirement keys; empty when absent from serialized input.
    #[serde(default)]
    pub data_requirement_keys: Vec<String>,
    /// Prediction requirement keys; empty when absent from serialized input.
    #[serde(default)]
    pub prediction_requirement_keys: Vec<String>,
}
2255
2256impl ArtifactInputSpec {
2257 fn from_refit_record(record: &RefitArtifactRecord) -> Result<Self> {
2258 record.validate()?;
2259 Ok(Self {
2260 node_id: record.node_id.clone(),
2261 controller_id: record.controller_id.clone(),
2262 artifact: record.artifact.clone(),
2263 params_fingerprint: record.params_fingerprint.clone(),
2264 data_requirement_keys: record.data_requirement_keys.clone(),
2265 prediction_requirement_keys: record.prediction_requirement_keys.clone(),
2266 })
2267 }
2268}
2269
/// Serde default for `PredictionInputSpec::prediction_level`: sample-level
/// predictions.
fn default_runtime_prediction_level() -> PredictionLevel {
    PredictionLevel::Sample
}
2273
/// A unit of work for one planned node in a given run, phase, and optional
/// variant/fold context.
///
/// `NodeResult::validate_for_task` compares a result's lineage against the
/// identifying fields of this task (run, phase, variant, fold, branch path,
/// seed, and the node plan's controller and params fingerprint).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct NodeTask {
    /// Run this task belongs to.
    pub run_id: RunId,
    /// Plan of the node being executed.
    pub node_plan: NodePlan,
    /// Phase the task executes in.
    pub phase: Phase,
    /// Variant identifier, if the task runs under a variant.
    pub variant_id: Option<VariantId>,
    /// Full variant context; when present its `variant_id` must equal
    /// `variant_id` above (checked in `NodeResult::validate_for_task`).
    #[serde(default)]
    pub variant: Option<VariantExecutionSpec>,
    /// Fold identifier, if the task is fold-scoped.
    pub fold_id: Option<FoldId>,
    /// Branch path of the task; empty when absent from serialized input.
    #[serde(default)]
    pub branch_path: Vec<BranchId>,
    /// Input handles by name. NOTE(review): presumably keyed by input port —
    /// confirm against the task builder.
    #[serde(default)]
    pub input_handles: BTreeMap<String, HandleRef>,
    /// Data provider views by name; views with the `FoldValidation` partition
    /// bound which samples validation predictions may cover.
    #[serde(default)]
    pub data_views: BTreeMap<String, DataProviderViewSpec>,
    /// Prediction inputs by name.
    #[serde(default)]
    pub prediction_inputs: BTreeMap<String, PredictionInputSpec>,
    /// Artifact inputs by name.
    #[serde(default)]
    pub artifact_inputs: BTreeMap<String, ArtifactInputSpec>,
    /// Optional inner fold set; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_fold_set: Option<FoldSet>,
    /// Fit-influence settings; omitted from serialized output when equal to
    /// `FitInfluenceTask::default()`.
    #[serde(default, skip_serializing_if = "FitInfluenceTask::is_default")]
    pub fit_influence: FitInfluenceTask,
    /// Optional seed for the task.
    pub seed: Option<u64>,
}
2303
/// Mechanism recorded on a `FitInfluenceTask` and its diagnostic.
///
/// Serialized in `snake_case` (for example `uniform_rows`).
/// `FitInfluenceTask::default()` uses `UniformRows`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FitInfluenceMechanism {
    UniformRows,
    SampleWeights,
    RowResampling,
    BackendLossWeights,
    ScorerOnly,
}
2313
/// Fit-influence settings attached to a node task.
///
/// See `FitInfluenceTask::validate` for the invariants enforced on these
/// fields.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FitInfluenceTask {
    /// Policy that was requested.
    pub requested_policy: FitInfluencePolicy,
    /// Policy that is actually in effect.
    pub effective_policy: FitInfluencePolicy,
    /// Mechanism associated with the effective policy.
    pub mechanism: FitInfluenceMechanism,
    /// Row weights; every entry must be finite and strictly positive.
    /// Omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub row_weights: Vec<f64>,
    /// Warnings; entries must not be blank. A non-empty list is reported as
    /// `fallback_used` by `diagnostic()`. Omitted from serialized output when
    /// empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub warnings: Vec<String>,
}
2324
2325impl Default for FitInfluenceTask {
2326 fn default() -> Self {
2327 Self {
2328 requested_policy: FitInfluencePolicy::UniformRows,
2329 effective_policy: FitInfluencePolicy::UniformRows,
2330 mechanism: FitInfluenceMechanism::UniformRows,
2331 row_weights: Vec::new(),
2332 warnings: Vec::new(),
2333 }
2334 }
2335}
2336
2337impl FitInfluenceTask {
2338 fn is_default(&self) -> bool {
2339 self == &Self::default()
2340 }
2341
2342 pub fn diagnostic(&self) -> FitInfluenceDiagnostic {
2343 FitInfluenceDiagnostic {
2344 requested_policy: self.requested_policy,
2345 effective_policy: self.effective_policy,
2346 mechanism: self.mechanism,
2347 fallback_used: !self.warnings.is_empty(),
2348 row_weight_count: self.row_weights.len(),
2349 warnings: self.warnings.clone(),
2350 }
2351 }
2352
2353 pub fn validate(&self) -> Result<()> {
2354 if !self
2355 .row_weights
2356 .iter()
2357 .all(|weight| weight.is_finite() && *weight > 0.0)
2358 {
2359 return Err(DagMlError::RuntimeValidation(
2360 "fit influence row_weights must be finite and > 0".to_string(),
2361 ));
2362 }
2363 if self
2364 .warnings
2365 .iter()
2366 .any(|warning| warning.trim().is_empty())
2367 {
2368 return Err(DagMlError::RuntimeValidation(
2369 "fit influence warnings must not be empty".to_string(),
2370 ));
2371 }
2372 match self.effective_policy {
2373 FitInfluencePolicy::EqualSampleInfluence | FitInfluencePolicy::BackendLossWeight
2374 if self.row_weights.is_empty() =>
2375 {
2376 return Err(DagMlError::RuntimeValidation(format!(
2377 "fit influence {:?} requires row_weights",
2378 self.effective_policy
2379 )));
2380 }
2381 _ => {}
2382 }
2383 if self.requested_policy == FitInfluencePolicy::StrictWeightSupport
2384 && self.effective_policy == FitInfluencePolicy::UniformRows
2385 {
2386 return Err(DagMlError::RuntimeValidation(
2387 "strict fit influence cannot fall back to uniform_rows".to_string(),
2388 ));
2389 }
2390 Ok(())
2391 }
2392}
2393
/// Summary of the fit-influence settings applied to a task.
///
/// `FitInfluenceTask::diagnostic` produces one of these, and
/// `FitInfluenceDiagnostic::validate` compares it against a task.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FitInfluenceDiagnostic {
    /// Policy that was requested.
    pub requested_policy: FitInfluencePolicy,
    /// Policy that was actually in effect.
    pub effective_policy: FitInfluencePolicy,
    /// Mechanism that was used.
    pub mechanism: FitInfluenceMechanism,
    /// Whether a fallback occurred; `false` when absent from serialized input.
    #[serde(default)]
    pub fallback_used: bool,
    /// Number of row weights; `0` when absent from serialized input.
    #[serde(default)]
    pub row_weight_count: usize,
    /// Warnings; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub warnings: Vec<String>,
}
2406
2407impl FitInfluenceDiagnostic {
2408 pub fn validate(&self, task: &NodeTask) -> Result<()> {
2409 if self.requested_policy != task.fit_influence.requested_policy {
2410 return Err(DagMlError::RuntimeValidation(format!(
2411 "fit influence diagnostic requested_policy {:?} does not match task {:?}",
2412 self.requested_policy, task.fit_influence.requested_policy
2413 )));
2414 }
2415 if self.effective_policy != task.fit_influence.effective_policy {
2416 return Err(DagMlError::RuntimeValidation(format!(
2417 "fit influence diagnostic effective_policy {:?} does not match task {:?}",
2418 self.effective_policy, task.fit_influence.effective_policy
2419 )));
2420 }
2421 if self.mechanism != task.fit_influence.mechanism {
2422 return Err(DagMlError::RuntimeValidation(format!(
2423 "fit influence diagnostic mechanism {:?} does not match task {:?}",
2424 self.mechanism, task.fit_influence.mechanism
2425 )));
2426 }
2427 if self.row_weight_count != task.fit_influence.row_weights.len() {
2428 return Err(DagMlError::RuntimeValidation(format!(
2429 "fit influence diagnostic row_weight_count {} does not match task {}",
2430 self.row_weight_count,
2431 task.fit_influence.row_weights.len()
2432 )));
2433 }
2434 if self
2435 .warnings
2436 .iter()
2437 .any(|warning| warning.trim().is_empty())
2438 {
2439 return Err(DagMlError::RuntimeValidation(
2440 "fit influence diagnostic warnings must not be empty".to_string(),
2441 ));
2442 }
2443 Ok(())
2444 }
2445}
2446
/// Variant context carried by a node task.
///
/// Built from a `VariantPlan` by `VariantExecutionSpec::from_plan`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct VariantExecutionSpec {
    /// Identifier of the variant.
    pub variant_id: VariantId,
    /// Chosen generation choice per dimension name; empty when absent from
    /// serialized input.
    #[serde(default)]
    pub choices: BTreeMap<String, GenerationChoice>,
    /// Variant fingerprint; must not be blank (see `validate`).
    pub fingerprint: String,
    /// Optional seed for the variant.
    pub seed: Option<u64>,
}
2455
2456impl VariantExecutionSpec {
2457 pub fn from_plan(variant: &VariantPlan) -> Self {
2458 Self {
2459 variant_id: variant.variant_id.clone(),
2460 choices: variant.choices.clone(),
2461 fingerprint: variant.fingerprint.clone(),
2462 seed: variant.seed,
2463 }
2464 }
2465
2466 pub fn validate(&self) -> Result<()> {
2467 if self.fingerprint.trim().is_empty() {
2468 return Err(DagMlError::RuntimeValidation(format!(
2469 "variant `{}` has an empty fingerprint in task context",
2470 self.variant_id
2471 )));
2472 }
2473 for (dimension_name, choice) in &self.choices {
2474 if dimension_name.trim().is_empty() {
2475 return Err(DagMlError::RuntimeValidation(format!(
2476 "variant `{}` has an empty generation dimension name",
2477 self.variant_id
2478 )));
2479 }
2480 if choice.label.trim().is_empty() {
2481 return Err(DagMlError::RuntimeValidation(format!(
2482 "variant `{}` has an empty choice label for dimension `{dimension_name}`",
2483 self.variant_id
2484 )));
2485 }
2486 for override_spec in &choice.param_overrides {
2487 if override_spec.params.is_empty() {
2488 return Err(DagMlError::RuntimeValidation(format!(
2489 "variant `{}` has an empty param override for node `{}`",
2490 self.variant_id, override_spec.node_id
2491 )));
2492 }
2493 for param_key in override_spec.params.keys() {
2494 if param_key.trim().is_empty() {
2495 return Err(DagMlError::RuntimeValidation(format!(
2496 "variant `{}` has an empty param override key for node `{}`",
2497 self.variant_id, override_spec.node_id
2498 )));
2499 }
2500 }
2501 }
2502 }
2503 self.param_overrides_by_node()?;
2504 Ok(())
2505 }
2506
2507 pub fn effective_params_for_node(
2508 &self,
2509 node_id: &NodeId,
2510 base_params: &BTreeMap<String, serde_json::Value>,
2511 ) -> Result<BTreeMap<String, serde_json::Value>> {
2512 let overrides_by_node = self.param_overrides_by_node()?;
2513 let Some(overrides) = overrides_by_node.get(node_id) else {
2514 return Ok(base_params.clone());
2515 };
2516 let mut params = base_params.clone();
2517 params.extend(overrides.clone());
2518 Ok(params)
2519 }
2520
2521 fn param_overrides_by_node(
2522 &self,
2523 ) -> Result<BTreeMap<NodeId, BTreeMap<String, serde_json::Value>>> {
2524 let mut overrides = BTreeMap::<NodeId, BTreeMap<String, serde_json::Value>>::new();
2525 let mut owners = BTreeMap::<(NodeId, String), String>::new();
2526 for (dimension_name, choice) in &self.choices {
2527 for override_spec in &choice.param_overrides {
2528 for (param_key, value) in &override_spec.params {
2529 let owner_key = (override_spec.node_id.clone(), param_key.clone());
2530 if let Some(previous) =
2531 owners.insert(owner_key, format!("{dimension_name}:{}", choice.label))
2532 {
2533 return Err(DagMlError::RuntimeValidation(format!(
2534 "variant `{}` has conflicting generation overrides for `{}.{}` from `{previous}` and `{}:{}`",
2535 self.variant_id,
2536 override_spec.node_id,
2537 param_key,
2538 dimension_name,
2539 choice.label
2540 )));
2541 }
2542 overrides
2543 .entry(override_spec.node_id.clone())
2544 .or_default()
2545 .insert(param_key.clone(), value.clone());
2546 }
2547 }
2548 }
2549 Ok(overrides)
2550 }
2551}
2552
/// An explanation emitted by a node.
///
/// `NodeResult::validate_for_task` only accepts explanations in the `Explain`
/// phase and requires `producer_node` to match the result's node.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExplanationBlock {
    /// Node that produced the explanation.
    pub producer_node: NodeId,
    /// Identifier of the explanation method; must not be blank.
    pub method: String,
    /// Optional target name; must not be blank when present. Omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_name: Option<String>,
    /// Arbitrary JSON payload; its structure is not validated here.
    pub payload: serde_json::Value,
}
2570
2571impl ExplanationBlock {
2572 pub fn validate(&self) -> Result<()> {
2576 if self.method.trim().is_empty() {
2577 return Err(DagMlError::RuntimeValidation(
2578 "explanation method must be a non-empty identifier".to_string(),
2579 ));
2580 }
2581 if let Some(name) = &self.target_name {
2582 if name.trim().is_empty() {
2583 return Err(DagMlError::RuntimeValidation(
2584 "explanation target_name must be non-empty when present".to_string(),
2585 ));
2586 }
2587 }
2588 Ok(())
2589 }
2590}
2591
/// Everything a node returns after executing a `NodeTask`.
///
/// `NodeResult::validate_for_task` cross-checks these fields against the task
/// that produced them.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct NodeResult {
    /// Node that produced this result.
    pub node_id: NodeId,
    /// Output handles by port name; each must be owned by the task's controller.
    #[serde(default)]
    pub outputs: BTreeMap<String, HandleRef>,
    /// Sample-level prediction blocks.
    #[serde(default)]
    pub predictions: Vec<PredictionBlock>,
    /// Observation-level prediction blocks.
    #[serde(default)]
    pub observation_predictions: Vec<ObservationPredictionBlock>,
    /// Aggregated prediction blocks.
    #[serde(default)]
    pub aggregated_predictions: Vec<AggregatedPredictionBlock>,
    /// Explanations; only permitted in the `Explain` phase.
    #[serde(default)]
    pub explanations: Vec<ExplanationBlock>,
    /// Shape deltas emitted by the node.
    #[serde(default)]
    pub shape_deltas: Vec<ShapeDelta>,
    /// Artifacts emitted by the node; ids must be unique and each artifact
    /// needs an entry in `artifact_handles` and in the lineage artifact refs.
    #[serde(default)]
    pub artifacts: Vec<ArtifactRef>,
    /// Handle for each emitted artifact, keyed by artifact id.
    #[serde(default)]
    pub artifact_handles: BTreeMap<ArtifactId, HandleRef>,
    /// Fit-influence diagnostics; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub fit_influence_diagnostics: Vec<FitInfluenceDiagnostic>,
    /// Lineage record describing how this result was produced.
    pub lineage: LineageRecord,
}
2615
2616impl NodeResult {
2617 pub fn validate_for_task(&self, task: &NodeTask) -> Result<()> {
2618 if self.node_id != task.node_plan.node_id {
2619 return Err(DagMlError::RuntimeValidation(format!(
2620 "task for `{}` returned result for `{}`",
2621 task.node_plan.node_id, self.node_id
2622 )));
2623 }
2624 if self.lineage.node_id != task.node_plan.node_id {
2625 return Err(DagMlError::RuntimeValidation(format!(
2626 "lineage for task `{}` references node `{}`",
2627 task.node_plan.node_id, self.lineage.node_id
2628 )));
2629 }
2630 if self.lineage.phase != task.phase {
2631 return Err(DagMlError::RuntimeValidation(format!(
2632 "lineage for node `{}` has phase {:?}, expected {:?}",
2633 task.node_plan.node_id, self.lineage.phase, task.phase
2634 )));
2635 }
2636 if self.lineage.run_id != task.run_id {
2637 return Err(DagMlError::RuntimeValidation(format!(
2638 "lineage for node `{}` has run `{}`, expected `{}`",
2639 task.node_plan.node_id, self.lineage.run_id, task.run_id
2640 )));
2641 }
2642 if self.lineage.controller_id != task.node_plan.controller_id {
2643 return Err(DagMlError::RuntimeValidation(format!(
2644 "lineage for node `{}` has controller `{}`, expected `{}`",
2645 task.node_plan.node_id, self.lineage.controller_id, task.node_plan.controller_id
2646 )));
2647 }
2648 if self.lineage.controller_version != task.node_plan.controller_version {
2649 return Err(DagMlError::RuntimeValidation(format!(
2650 "lineage for node `{}` has controller version `{}`, expected `{}`",
2651 task.node_plan.node_id,
2652 self.lineage.controller_version,
2653 task.node_plan.controller_version
2654 )));
2655 }
2656 if self.lineage.variant_id != task.variant_id {
2657 return Err(DagMlError::RuntimeValidation(format!(
2658 "lineage for node `{}` has variant {:?}, expected {:?}",
2659 task.node_plan.node_id, self.lineage.variant_id, task.variant_id
2660 )));
2661 }
2662 if let Some(variant) = &task.variant {
2663 variant.validate()?;
2664 if Some(&variant.variant_id) != task.variant_id.as_ref() {
2665 return Err(DagMlError::RuntimeValidation(format!(
2666 "task for node `{}` has variant context `{}` but variant_id {:?}",
2667 task.node_plan.node_id, variant.variant_id, task.variant_id
2668 )));
2669 }
2670 }
2671 if self.lineage.fold_id != task.fold_id {
2672 return Err(DagMlError::RuntimeValidation(format!(
2673 "lineage for node `{}` has fold {:?}, expected {:?}",
2674 task.node_plan.node_id, self.lineage.fold_id, task.fold_id
2675 )));
2676 }
2677 if self.lineage.branch_path != task.branch_path {
2678 return Err(DagMlError::RuntimeValidation(format!(
2679 "lineage for node `{}` has branch path {:?}, expected {:?}",
2680 task.node_plan.node_id, self.lineage.branch_path, task.branch_path
2681 )));
2682 }
2683 if self.lineage.seed != task.seed {
2684 return Err(DagMlError::RuntimeValidation(format!(
2685 "lineage for node `{}` has seed {:?}, expected {:?}",
2686 task.node_plan.node_id, self.lineage.seed, task.seed
2687 )));
2688 }
2689 if self.lineage.params_fingerprint != task.node_plan.params_fingerprint {
2690 return Err(DagMlError::RuntimeValidation(format!(
2691 "lineage for node `{}` has params fingerprint `{}`, expected `{}`",
2692 task.node_plan.node_id,
2693 self.lineage.params_fingerprint,
2694 task.node_plan.params_fingerprint
2695 )));
2696 }
2697 task.fit_influence.validate()?;
2698 for diagnostic in &self.fit_influence_diagnostics {
2699 diagnostic.validate(task)?;
2700 }
2701 validate_lineage_shape_fingerprints(&self.lineage, task)?;
2702 if !self.explanations.is_empty() && task.phase != Phase::Explain {
2703 return Err(DagMlError::RuntimeValidation(format!(
2704 "node `{}` returned explanations outside the EXPLAIN phase",
2705 task.node_plan.node_id
2706 )));
2707 }
2708 for explanation in &self.explanations {
2709 explanation.validate()?;
2710 if explanation.producer_node != self.node_id {
2711 return Err(DagMlError::RuntimeValidation(format!(
2712 "node `{}` returned an explanation produced by `{}`",
2713 self.node_id, explanation.producer_node
2714 )));
2715 }
2716 }
2717 for (port, handle) in &self.outputs {
2718 if handle.owner_controller != task.node_plan.controller_id {
2719 return Err(DagMlError::RuntimeValidation(format!(
2720 "node `{}` output `{port}` is owned by `{}`, expected `{}`",
2721 task.node_plan.node_id, handle.owner_controller, task.node_plan.controller_id
2722 )));
2723 }
2724 }
2725 let mut artifact_ids = BTreeSet::new();
2726 for artifact in &self.artifacts {
2727 artifact.validate()?;
2728 if !artifact_ids.insert(artifact.id.clone()) {
2729 return Err(DagMlError::RuntimeValidation(format!(
2730 "node `{}` emitted duplicate artifact `{}`",
2731 task.node_plan.node_id, artifact.id
2732 )));
2733 }
2734 if artifact.controller_id != task.node_plan.controller_id {
2735 return Err(DagMlError::RuntimeValidation(format!(
2736 "node `{}` emitted artifact `{}` for controller `{}`, expected `{}`",
2737 task.node_plan.node_id,
2738 artifact.id,
2739 artifact.controller_id,
2740 task.node_plan.controller_id
2741 )));
2742 }
2743 let handle = self.artifact_handles.get(&artifact.id).ok_or_else(|| {
2744 DagMlError::RuntimeValidation(format!(
2745 "node `{}` emitted artifact `{}` without artifact handle",
2746 task.node_plan.node_id, artifact.id
2747 ))
2748 })?;
2749 if !matches!(handle.kind, HandleKind::Model | HandleKind::Artifact) {
2750 return Err(DagMlError::RuntimeValidation(format!(
2751 "node `{}` emitted artifact `{}` with non-artifact/model handle kind {:?}",
2752 task.node_plan.node_id, artifact.id, handle.kind
2753 )));
2754 }
2755 if handle.owner_controller != task.node_plan.controller_id {
2756 return Err(DagMlError::RuntimeValidation(format!(
2757 "node `{}` emitted artifact `{}` owned by `{}`, expected `{}`",
2758 task.node_plan.node_id,
2759 artifact.id,
2760 handle.owner_controller,
2761 task.node_plan.controller_id
2762 )));
2763 }
2764 }
2765 for artifact_id in self.artifact_handles.keys() {
2766 if !self
2767 .artifacts
2768 .iter()
2769 .any(|artifact| &artifact.id == artifact_id)
2770 {
2771 return Err(DagMlError::RuntimeValidation(format!(
2772 "node `{}` emitted artifact handle for undeclared artifact `{artifact_id}`",
2773 task.node_plan.node_id
2774 )));
2775 }
2776 }
2777 for artifact in &self.artifacts {
2778 if !self
2779 .lineage
2780 .artifact_refs
2781 .iter()
2782 .any(|lineage_artifact| lineage_artifact == artifact)
2783 {
2784 return Err(DagMlError::RuntimeValidation(format!(
2785 "node `{}` emitted artifact `{}` without matching lineage artifact ref",
2786 task.node_plan.node_id, artifact.id
2787 )));
2788 }
2789 }
2790 for artifact in &self.lineage.artifact_refs {
2791 if !self
2792 .artifacts
2793 .iter()
2794 .any(|emitted_artifact| emitted_artifact == artifact)
2795 {
2796 return Err(DagMlError::RuntimeValidation(format!(
2797 "node `{}` lineage references undeclared artifact `{}`",
2798 task.node_plan.node_id, artifact.id
2799 )));
2800 }
2801 }
2802 for prediction in &self.predictions {
2803 prediction.validate_shape()?;
2804 if prediction.producer_node != task.node_plan.node_id {
2805 return Err(DagMlError::RuntimeValidation(format!(
2806 "node `{}` emitted prediction for producer `{}`",
2807 task.node_plan.node_id, prediction.producer_node
2808 )));
2809 }
2810 validate_prediction_scope(prediction, task)?;
2811 }
2812 for prediction in &self.observation_predictions {
2813 prediction.validate_shape()?;
2814 if prediction.producer_node != task.node_plan.node_id {
2815 return Err(DagMlError::RuntimeValidation(format!(
2816 "node `{}` emitted observation prediction for producer `{}`",
2817 task.node_plan.node_id, prediction.producer_node
2818 )));
2819 }
2820 validate_observation_prediction_scope(prediction, task)?;
2821 }
2822 for prediction in &self.aggregated_predictions {
2823 prediction.validate_shape()?;
2824 if prediction.producer_node != task.node_plan.node_id {
2825 return Err(DagMlError::RuntimeValidation(format!(
2826 "node `{}` emitted aggregated prediction for producer `{}`",
2827 task.node_plan.node_id, prediction.producer_node
2828 )));
2829 }
2830 validate_aggregated_prediction_scope(prediction, task)?;
2831 }
2832 for delta in &self.shape_deltas {
2833 delta.validate()?;
2834 if delta.node_id != task.node_plan.node_id {
2835 return Err(DagMlError::RuntimeValidation(format!(
2836 "node `{}` emitted shape delta for `{}`",
2837 task.node_plan.node_id, delta.node_id
2838 )));
2839 }
2840 validate_shape_delta_for_task(delta, task)?;
2841 }
2842 self.lineage.validate()
2843 }
2844}
2845
2846fn validate_lineage_shape_fingerprints(lineage: &LineageRecord, task: &NodeTask) -> Result<()> {
2847 let Some(shape_plan) = &task.node_plan.shape_plan else {
2848 if lineage.data_model_shape_fingerprint.is_some()
2849 || lineage.aggregation_policy_fingerprint.is_some()
2850 {
2851 return Err(DagMlError::RuntimeValidation(format!(
2852 "lineage for node `{}` carries shape fingerprints but the node has no shape plan",
2853 task.node_plan.node_id
2854 )));
2855 }
2856 return Ok(());
2857 };
2858
2859 if let Some(actual) = &lineage.data_model_shape_fingerprint {
2860 let expected = stable_json_fingerprint(shape_plan)?;
2861 if actual != &expected {
2862 return Err(DagMlError::RuntimeValidation(format!(
2863 "lineage for node `{}` has data/model shape fingerprint `{actual}`, expected `{expected}`",
2864 task.node_plan.node_id
2865 )));
2866 }
2867 }
2868 if let Some(actual) = &lineage.aggregation_policy_fingerprint {
2869 let expected = stable_json_fingerprint(&shape_plan.aggregation_policy)?;
2870 if actual != &expected {
2871 return Err(DagMlError::RuntimeValidation(format!(
2872 "lineage for node `{}` has aggregation policy fingerprint `{actual}`, expected `{expected}`",
2873 task.node_plan.node_id
2874 )));
2875 }
2876 }
2877 Ok(())
2878}
2879
2880fn validate_shape_delta_for_task(delta: &ShapeDelta, task: &NodeTask) -> Result<()> {
2881 let Some(shape_plan) = &task.node_plan.shape_plan else {
2882 return Ok(());
2883 };
2884 if delta.kind == ShapeDeltaKind::Feature {
2885 if let Some(expected) = &shape_plan.feature_schema_fingerprint {
2886 if &delta.before_fingerprint != expected {
2887 return Err(DagMlError::RuntimeValidation(format!(
2888 "node `{}` emitted feature shape delta from `{}`, expected current schema `{expected}`",
2889 task.node_plan.node_id, delta.before_fingerprint
2890 )));
2891 }
2892 }
2893 }
2894 Ok(())
2895}
2896
2897fn validate_prediction_scope(prediction: &PredictionBlock, task: &NodeTask) -> Result<()> {
2898 if prediction.partition != PredictionPartition::Validation {
2899 return Ok(());
2900 }
2901 if prediction.fold_id != task.fold_id {
2902 return Err(DagMlError::RuntimeValidation(format!(
2903 "node `{}` emitted validation predictions for fold {:?}, expected {:?}",
2904 task.node_plan.node_id, prediction.fold_id, task.fold_id
2905 )));
2906 }
2907 if task.phase == Phase::FitCv
2908 && task.fold_id.is_some()
2909 && (!task.node_plan.data_bindings.is_empty() || !task.data_views.is_empty())
2910 {
2911 let validation_sample_ids = validation_view_sample_ids(task).ok_or_else(|| {
2912 DagMlError::RuntimeValidation(format!(
2913 "node `{}` emitted validation predictions without a fold-validation data view",
2914 task.node_plan.node_id
2915 ))
2916 })?;
2917 for sample_id in &prediction.sample_ids {
2918 if !validation_sample_ids.contains(sample_id) {
2919 return Err(DagMlError::RuntimeValidation(format!(
2920 "node `{}` emitted validation prediction for sample `{}` outside its validation view",
2921 task.node_plan.node_id, sample_id
2922 )));
2923 }
2924 }
2925 }
2926 Ok(())
2927}
2928
2929fn validate_observation_prediction_scope(
2930 prediction: &ObservationPredictionBlock,
2931 task: &NodeTask,
2932) -> Result<()> {
2933 if prediction.partition != PredictionPartition::Validation {
2934 return Ok(());
2935 }
2936 if prediction.fold_id != task.fold_id {
2937 return Err(DagMlError::RuntimeValidation(format!(
2938 "node `{}` emitted observation validation predictions for fold {:?}, expected {:?}",
2939 task.node_plan.node_id, prediction.fold_id, task.fold_id
2940 )));
2941 }
2942 Ok(())
2943}
2944
2945fn validate_aggregated_prediction_scope(
2946 prediction: &AggregatedPredictionBlock,
2947 task: &NodeTask,
2948) -> Result<()> {
2949 if prediction.partition != PredictionPartition::Validation {
2950 return Ok(());
2951 }
2952 if prediction.fold_id != task.fold_id {
2953 return Err(DagMlError::RuntimeValidation(format!(
2954 "node `{}` emitted aggregated validation predictions for fold {:?}, expected {:?}",
2955 task.node_plan.node_id, prediction.fold_id, task.fold_id
2956 )));
2957 }
2958 Ok(())
2959}
2960
2961fn validation_view_sample_ids(task: &NodeTask) -> Option<BTreeSet<SampleId>> {
2962 let mut sample_ids = BTreeSet::new();
2963 for view in task
2964 .data_views
2965 .values()
2966 .filter(|view| view.partition == DataRequestPartition::FoldValidation)
2967 {
2968 if let Some(view_sample_ids) = &view.sample_ids {
2969 sample_ids.extend(view_sample_ids.iter().cloned());
2970 }
2971 }
2972 (!sample_ids.is_empty()).then_some(sample_ids)
2973}
2974
2975fn fit_influence_task_for_node(
2976 plan: &ExecutionPlan,
2977 node_plan: &NodePlan,
2978 data_views: &BTreeMap<String, DataProviderViewSpec>,
2979) -> Result<FitInfluenceTask> {
2980 let manifest = plan
2981 .controller_manifests
2982 .get(&node_plan.controller_id)
2983 .ok_or_else(|| {
2984 DagMlError::RuntimeValidation(format!(
2985 "node `{}` references missing controller manifest `{}`",
2986 node_plan.node_id, node_plan.controller_id
2987 ))
2988 })?;
2989 let Some(model_input_spec) = manifest.model_input_spec()? else {
2990 return Ok(FitInfluenceTask::default());
2991 };
2992 let Some(requested_policy) = model_input_spec.fit_influence_policy else {
2993 return Ok(FitInfluenceTask::default());
2994 };
2995 resolve_fit_influence_task(
2996 requested_policy,
2997 &node_plan.controller_capabilities,
2998 data_views,
2999 )
3000}
3001
3002fn resolve_fit_influence_task(
3003 requested_policy: FitInfluencePolicy,
3004 capabilities: &BTreeSet<ControllerCapability>,
3005 data_views: &BTreeMap<String, DataProviderViewSpec>,
3006) -> Result<FitInfluenceTask> {
3007 let row_weights = equal_sample_influence_weights(data_views);
3008 match requested_policy {
3009 FitInfluencePolicy::UniformRows => Ok(FitInfluenceTask {
3010 requested_policy,
3011 effective_policy: FitInfluencePolicy::UniformRows,
3012 mechanism: FitInfluenceMechanism::UniformRows,
3013 row_weights: Vec::new(),
3014 warnings: Vec::new(),
3015 }),
3016 FitInfluencePolicy::ScorerOnly => Ok(FitInfluenceTask {
3017 requested_policy,
3018 effective_policy: FitInfluencePolicy::ScorerOnly,
3019 mechanism: FitInfluenceMechanism::ScorerOnly,
3020 row_weights: Vec::new(),
3021 warnings: Vec::new(),
3022 }),
3023 FitInfluencePolicy::EqualSampleInfluence => {
3024 require_fit_influence_support(capabilities, requested_policy)?;
3025 let weights = row_weights.ok_or_else(|| {
3026 DagMlError::RuntimeValidation(
3027 "equal_sample_influence requires task row sample ids".to_string(),
3028 )
3029 })?;
3030 Ok(FitInfluenceTask {
3031 requested_policy,
3032 effective_policy: FitInfluencePolicy::EqualSampleInfluence,
3033 mechanism: FitInfluenceMechanism::SampleWeights,
3034 row_weights: weights,
3035 warnings: Vec::new(),
3036 })
3037 }
3038 FitInfluencePolicy::ResampleEqualized => {
3039 require_fit_influence_support(capabilities, requested_policy)?;
3040 Ok(FitInfluenceTask {
3041 requested_policy,
3042 effective_policy: FitInfluencePolicy::ResampleEqualized,
3043 mechanism: FitInfluenceMechanism::RowResampling,
3044 row_weights: Vec::new(),
3045 warnings: Vec::new(),
3046 })
3047 }
3048 FitInfluencePolicy::BackendLossWeight => {
3049 require_fit_influence_support(capabilities, requested_policy)?;
3050 let weights = row_weights.ok_or_else(|| {
3051 DagMlError::RuntimeValidation(
3052 "backend_loss_weight requires task row sample ids".to_string(),
3053 )
3054 })?;
3055 Ok(FitInfluenceTask {
3056 requested_policy,
3057 effective_policy: FitInfluencePolicy::BackendLossWeight,
3058 mechanism: FitInfluenceMechanism::BackendLossWeights,
3059 row_weights: weights,
3060 warnings: Vec::new(),
3061 })
3062 }
3063 FitInfluencePolicy::StrictWeightSupport => {
3064 require_fit_influence_support(capabilities, requested_policy)?;
3065 strict_fit_influence_task(capabilities, row_weights, requested_policy)
3066 }
3067 FitInfluencePolicy::Auto => Ok(auto_fit_influence_task(capabilities, row_weights)),
3068 }
3069}
3070
3071fn require_fit_influence_support(
3072 capabilities: &BTreeSet<ControllerCapability>,
3073 policy: FitInfluencePolicy,
3074) -> Result<()> {
3075 if capabilities_support_fit_influence(capabilities, policy) {
3076 return Ok(());
3077 }
3078 Err(DagMlError::RuntimeValidation(format!(
3079 "controller capabilities do not support requested fit influence policy {:?}",
3080 policy
3081 )))
3082}
3083
3084fn strict_fit_influence_task(
3085 capabilities: &BTreeSet<ControllerCapability>,
3086 row_weights: Option<Vec<f64>>,
3087 requested_policy: FitInfluencePolicy,
3088) -> Result<FitInfluenceTask> {
3089 if capabilities.contains(&ControllerCapability::SupportsBackendLossWeights) {
3090 let weights = row_weights.ok_or_else(|| {
3091 DagMlError::RuntimeValidation(
3092 "strict_weight_support with backend loss weights requires task row sample ids"
3093 .to_string(),
3094 )
3095 })?;
3096 return Ok(FitInfluenceTask {
3097 requested_policy,
3098 effective_policy: FitInfluencePolicy::BackendLossWeight,
3099 mechanism: FitInfluenceMechanism::BackendLossWeights,
3100 row_weights: weights,
3101 warnings: Vec::new(),
3102 });
3103 }
3104 if capabilities.contains(&ControllerCapability::SupportsSampleWeights) {
3105 let weights = row_weights.ok_or_else(|| {
3106 DagMlError::RuntimeValidation(
3107 "strict_weight_support with sample weights requires task row sample ids"
3108 .to_string(),
3109 )
3110 })?;
3111 return Ok(FitInfluenceTask {
3112 requested_policy,
3113 effective_policy: FitInfluencePolicy::EqualSampleInfluence,
3114 mechanism: FitInfluenceMechanism::SampleWeights,
3115 row_weights: weights,
3116 warnings: Vec::new(),
3117 });
3118 }
3119 Ok(FitInfluenceTask {
3120 requested_policy,
3121 effective_policy: FitInfluencePolicy::ResampleEqualized,
3122 mechanism: FitInfluenceMechanism::RowResampling,
3123 row_weights: Vec::new(),
3124 warnings: Vec::new(),
3125 })
3126}
3127
3128fn auto_fit_influence_task(
3129 capabilities: &BTreeSet<ControllerCapability>,
3130 row_weights: Option<Vec<f64>>,
3131) -> FitInfluenceTask {
3132 if capabilities.contains(&ControllerCapability::SupportsSampleWeights) {
3133 if let Some(weights) = row_weights.clone() {
3134 return FitInfluenceTask {
3135 requested_policy: FitInfluencePolicy::Auto,
3136 effective_policy: FitInfluencePolicy::EqualSampleInfluence,
3137 mechanism: FitInfluenceMechanism::SampleWeights,
3138 row_weights: weights,
3139 warnings: Vec::new(),
3140 };
3141 }
3142 }
3143 if capabilities.contains(&ControllerCapability::SupportsRowResampling) {
3144 return FitInfluenceTask {
3145 requested_policy: FitInfluencePolicy::Auto,
3146 effective_policy: FitInfluencePolicy::ResampleEqualized,
3147 mechanism: FitInfluenceMechanism::RowResampling,
3148 row_weights: Vec::new(),
3149 warnings: Vec::new(),
3150 };
3151 }
3152 if capabilities.contains(&ControllerCapability::SupportsBackendLossWeights) {
3153 if let Some(weights) = row_weights {
3154 return FitInfluenceTask {
3155 requested_policy: FitInfluencePolicy::Auto,
3156 effective_policy: FitInfluencePolicy::BackendLossWeight,
3157 mechanism: FitInfluenceMechanism::BackendLossWeights,
3158 row_weights: weights,
3159 warnings: Vec::new(),
3160 };
3161 }
3162 }
3163 FitInfluenceTask {
3164 requested_policy: FitInfluencePolicy::Auto,
3165 effective_policy: FitInfluencePolicy::UniformRows,
3166 mechanism: FitInfluenceMechanism::UniformRows,
3167 row_weights: Vec::new(),
3168 warnings: vec![
3169 "auto fit influence fell back to uniform_rows because no supported weighting capability was usable".to_string(),
3170 ],
3171 }
3172}
3173
3174fn equal_sample_influence_weights(
3175 data_views: &BTreeMap<String, DataProviderViewSpec>,
3176) -> Option<Vec<f64>> {
3177 let row_sample_ids = data_views
3178 .values()
3179 .filter(|view| {
3180 matches!(
3181 view.partition,
3182 DataRequestPartition::FoldTrain | DataRequestPartition::FullTrain
3183 )
3184 })
3185 .filter_map(|view| view.sample_ids.as_ref())
3186 .find(|sample_ids| !sample_ids.is_empty())
3187 .or_else(|| {
3188 data_views
3189 .values()
3190 .filter_map(|view| view.sample_ids.as_ref())
3191 .find(|sample_ids| !sample_ids.is_empty())
3192 })?;
3193 let mut counts = BTreeMap::<&SampleId, usize>::new();
3194 for sample_id in row_sample_ids {
3195 *counts.entry(sample_id).or_default() += 1;
3196 }
3197 Some(
3198 row_sample_ids
3199 .iter()
3200 .map(|sample_id| 1.0 / *counts.get(sample_id).expect("counted sample id") as f64)
3201 .collect(),
3202 )
3203}
3204
3205fn record_fit_influence_diagnostic(task: &NodeTask, result: &mut NodeResult) {
3206 if task.fit_influence.is_default() || !result.fit_influence_diagnostics.is_empty() {
3207 return;
3208 }
3209 result
3210 .fit_influence_diagnostics
3211 .push(task.fit_influence.diagnostic());
3212}
3213
/// Request handed to [`RuntimeDataProvider::materialize`] to obtain a data
/// handle for one node input.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DataMaterializationRequest {
    /// Run the request belongs to.
    pub run_id: RunId,
    /// Node whose input is being materialized.
    pub node_id: NodeId,
    /// Name of the node input the binding is attached to.
    pub input_name: String,
    /// Phase in which the request is issued.
    pub phase: Phase,
    /// Variant scope of the request, if any.
    pub variant_id: Option<VariantId>,
    /// Fold scope of the request, if any.
    pub fold_id: Option<FoldId>,
    /// Data binding to materialize.
    pub binding: crate::data::DataBinding,
}
3224
/// Description of a data view requested from a data provider.
///
/// `validate` enforces the structural rules: optional lists must be non-empty
/// and duplicate-free, fold partitions with explicit sample ids need a fold
/// id, and `FullTrain`/`Predict` views must not carry one.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DataProviderViewSpec {
    /// Explicit sample ids of the view; `None` when not restricted explicitly.
    #[serde(default)]
    pub sample_ids: Option<Vec<SampleId>>,
    /// Partition the view belongs to.
    pub partition: DataRequestPartition,
    /// Fold the view is scoped to; only allowed on fold partitions.
    #[serde(default)]
    pub fold_id: Option<FoldId>,
    /// Optional source id restriction (non-blank, unique entries).
    #[serde(default)]
    pub source_ids: Option<Vec<String>>,
    /// Optional column restriction (non-blank, unique entries).
    #[serde(default)]
    pub columns: Option<Vec<String>>,
    /// Presumably whether augmented rows are part of the view — confirm with the provider contract.
    pub include_augmented: bool,
    /// Presumably whether excluded rows are part of the view — confirm with the provider contract.
    pub include_excluded: bool,
    /// Optional branch view plan; validated when present and omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub branch_view: Option<crate::data::BranchViewPlan>,
    /// Free-form extension values; keys must not be blank. The
    /// `DATA_OUTPUT_PROVENANCE_KEY` entry, when present, must deserialize into
    /// a valid `DataOutputProvenance`.
    #[serde(default)]
    pub extra: BTreeMap<String, serde_json::Value>,
}
3243
/// Key under which output provenance is stored in `DataProviderViewSpec::extra`.
pub const DATA_OUTPUT_PROVENANCE_KEY: &str = "dag_ml_output";
/// Only `DataOutputProvenance::schema_version` accepted by validation.
pub const DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION: u32 = 1;
/// Schema identifier (URL) of the data output provenance document, version 1.
pub const DATA_OUTPUT_PROVENANCE_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/data_output_provenance.v1.schema.json";
/// Schema version of the node task document.
pub const NODE_TASK_SCHEMA_VERSION: u32 = 1;
/// Schema identifier (URL) of the node task document, version 1.
pub const NODE_TASK_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/node_task.v1.schema.json";
/// Schema version of the node result document.
pub const NODE_RESULT_SCHEMA_VERSION: u32 = 1;
/// Schema identifier (URL) of the node result document, version 1.
pub const NODE_RESULT_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/node_result.v1.schema.json";
3254
/// Serde default for `DataOutputProvenance::schema_version`: the current
/// provenance schema version, used when the field is absent from the input.
fn default_data_output_provenance_schema_version() -> u32 {
    DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION
}
3258
3259impl DataProviderViewSpec {
3260 pub fn validate(&self) -> Result<()> {
3261 validate_optional_ids("sample id", &self.sample_ids)?;
3262 validate_optional_strings("source id", &self.source_ids)?;
3263 validate_optional_strings("column", &self.columns)?;
3264 match self.partition {
3265 DataRequestPartition::FoldTrain | DataRequestPartition::FoldValidation => {
3266 if self.sample_ids.is_some() && self.fold_id.is_none() {
3267 return Err(DagMlError::RuntimeValidation(format!(
3268 "data provider view {:?} with explicit sample ids requires a fold id",
3269 self.partition
3270 )));
3271 }
3272 }
3273 DataRequestPartition::FullTrain | DataRequestPartition::Predict => {
3274 if self.fold_id.is_some() {
3275 return Err(DagMlError::RuntimeValidation(format!(
3276 "data provider view {:?} must not carry a fold id",
3277 self.partition
3278 )));
3279 }
3280 }
3281 }
3282 for key in self.extra.keys() {
3283 if key.trim().is_empty() {
3284 return Err(DagMlError::RuntimeValidation(
3285 "data provider view extra contains an empty key".to_string(),
3286 ));
3287 }
3288 }
3289 if let Some(branch_view) = &self.branch_view {
3290 branch_view.validate()?;
3291 }
3292 self.output_provenance()?;
3293 Ok(())
3294 }
3295
3296 pub fn output_provenance(&self) -> Result<Option<DataOutputProvenance>> {
3297 let Some(value) = self.extra.get(DATA_OUTPUT_PROVENANCE_KEY) else {
3298 return Ok(None);
3299 };
3300 let provenance: DataOutputProvenance = serde_json::from_value(value.clone())?;
3301 provenance.validate()?;
3302 Ok(Some(provenance))
3303 }
3304}
3305
/// Provenance record describing which node output produced a piece of data.
///
/// It is read from `DataProviderViewSpec::extra` under
/// `DATA_OUTPUT_PROVENANCE_KEY` and checked by `validate`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DataOutputProvenance {
    /// Schema version; defaults to the current version when absent and must
    /// equal `DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION`.
    #[serde(default = "default_data_output_provenance_schema_version")]
    pub schema_version: u32,
    /// Node that produced the output.
    pub producer_node: NodeId,
    /// Output port on the producer; must not be blank.
    pub producer_port: String,
    /// Phase in which the output was produced.
    pub producer_phase: Phase,
    /// Variant scope of the output, if any.
    #[serde(default)]
    pub variant_id: Option<VariantId>,
    /// Fold scope of the output, if any.
    #[serde(default)]
    pub fold_id: Option<FoldId>,
    /// Optional fingerprint; when present must be 64 ASCII hex digits.
    #[serde(default)]
    pub shape_plan_fingerprint: Option<String>,
    /// Optional fingerprint; when present must be 64 ASCII hex digits.
    #[serde(default)]
    pub aggregation_policy_fingerprint: Option<String>,
    /// Optional feature namespace; must not be blank when present.
    #[serde(default)]
    pub feature_namespace: Option<String>,
    /// Optional fingerprint (64 ASCII hex digits); must equal the
    /// `after_fingerprint` of the last feature shape delta, if one exists.
    #[serde(default)]
    pub feature_schema_fingerprint: Option<String>,
    /// Optional representation plan, validated when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub representation_plan: Option<RepresentationPlan>,
    /// Optional representation replay manifest, validated when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub representation_replay_manifest: Option<RepresentationReplayManifest>,
    /// Optional representation compatibility report, validated when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub representation_compatibility: Option<RepresentationCompatibilityReport>,
    /// Optional fingerprint; when present must be 64 ASCII hex digits.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub relation_delta_fingerprint: Option<String>,
    /// Shape deltas attributed to the producer; each must name `producer_node`.
    #[serde(default)]
    pub shape_deltas: Vec<ShapeDelta>,
}
3336
3337impl DataOutputProvenance {
3338 pub fn validate(&self) -> Result<()> {
3339 if self.schema_version != DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION {
3340 return Err(DagMlError::RuntimeValidation(format!(
3341 "data output provenance for `{}` uses unsupported schema_version {}, expected {}",
3342 self.producer_node, self.schema_version, DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION
3343 )));
3344 }
3345 if self.producer_port.trim().is_empty() {
3346 return Err(DagMlError::RuntimeValidation(format!(
3347 "data output provenance for `{}` has empty producer_port",
3348 self.producer_node
3349 )));
3350 }
3351 validate_optional_fingerprint(
3352 "shape_plan_fingerprint",
3353 &self.shape_plan_fingerprint,
3354 &self.producer_node,
3355 )?;
3356 validate_optional_fingerprint(
3357 "aggregation_policy_fingerprint",
3358 &self.aggregation_policy_fingerprint,
3359 &self.producer_node,
3360 )?;
3361 validate_optional_fingerprint(
3362 "feature_schema_fingerprint",
3363 &self.feature_schema_fingerprint,
3364 &self.producer_node,
3365 )?;
3366 validate_optional_fingerprint(
3367 "relation_delta_fingerprint",
3368 &self.relation_delta_fingerprint,
3369 &self.producer_node,
3370 )?;
3371 if let Some(representation_plan) = &self.representation_plan {
3372 representation_plan.validate().map_err(|error| {
3373 DagMlError::RuntimeValidation(format!(
3374 "data output provenance for `{}` has invalid representation_plan: {error}",
3375 self.producer_node
3376 ))
3377 })?;
3378 }
3379 if let Some(replay_manifest) = &self.representation_replay_manifest {
3380 replay_manifest.validate().map_err(|error| {
3381 DagMlError::RuntimeValidation(format!(
3382 "data output provenance for `{}` has invalid representation_replay_manifest: {error}",
3383 self.producer_node
3384 ))
3385 })?;
3386 }
3387 if let Some(report) = &self.representation_compatibility {
3388 report.validate().map_err(|error| {
3389 DagMlError::RuntimeValidation(format!(
3390 "data output provenance for `{}` has invalid representation_compatibility: {error}",
3391 self.producer_node
3392 ))
3393 })?;
3394 }
3395 if self
3396 .feature_namespace
3397 .as_ref()
3398 .is_some_and(|namespace| namespace.trim().is_empty())
3399 {
3400 return Err(DagMlError::RuntimeValidation(format!(
3401 "data output provenance for `{}` has empty feature_namespace",
3402 self.producer_node
3403 )));
3404 }
3405 for delta in &self.shape_deltas {
3406 delta.validate()?;
3407 if delta.node_id != self.producer_node {
3408 return Err(DagMlError::RuntimeValidation(format!(
3409 "data output provenance for `{}` contains shape delta for `{}`",
3410 self.producer_node, delta.node_id
3411 )));
3412 }
3413 }
3414 if let Some(feature_schema_fingerprint) = &self.feature_schema_fingerprint {
3415 if let Some(last_feature_delta) = self
3416 .shape_deltas
3417 .iter()
3418 .rev()
3419 .find(|delta| delta.kind == ShapeDeltaKind::Feature)
3420 {
3421 if &last_feature_delta.after_fingerprint != feature_schema_fingerprint {
3422 return Err(DagMlError::RuntimeValidation(format!(
3423 "data output provenance for `{}` has feature_schema_fingerprint `{feature_schema_fingerprint}` but last feature delta ends at `{}`",
3424 self.producer_node, last_feature_delta.after_fingerprint
3425 )));
3426 }
3427 }
3428 }
3429 Ok(())
3430 }
3431}
3432
3433fn validate_optional_fingerprint(
3434 label: &str,
3435 fingerprint: &Option<String>,
3436 producer_node: &NodeId,
3437) -> Result<()> {
3438 let Some(fingerprint) = fingerprint else {
3439 return Ok(());
3440 };
3441 if fingerprint.len() != 64 || !fingerprint.bytes().all(|byte| byte.is_ascii_hexdigit()) {
3442 return Err(DagMlError::RuntimeValidation(format!(
3443 "data output provenance for `{producer_node}` has invalid {label}"
3444 )));
3445 }
3446 Ok(())
3447}
3448
3449fn validate_optional_ids<T>(label: &str, values: &Option<Vec<T>>) -> Result<()>
3450where
3451 T: Ord + ToString,
3452{
3453 let Some(values) = values else {
3454 return Ok(());
3455 };
3456 if values.is_empty() {
3457 return Err(DagMlError::RuntimeValidation(format!(
3458 "data provider view {label} list is empty"
3459 )));
3460 }
3461 let mut seen = BTreeSet::new();
3462 for value in values {
3463 if !seen.insert(value) {
3464 return Err(DagMlError::RuntimeValidation(format!(
3465 "data provider view has duplicate {label} `{}`",
3466 value.to_string()
3467 )));
3468 }
3469 }
3470 Ok(())
3471}
3472
3473fn validate_optional_strings(label: &str, values: &Option<Vec<String>>) -> Result<()> {
3474 let Some(values) = values else {
3475 return Ok(());
3476 };
3477 if values.is_empty() {
3478 return Err(DagMlError::RuntimeValidation(format!(
3479 "data provider view {label} list is empty"
3480 )));
3481 }
3482 let mut seen = BTreeSet::new();
3483 for value in values {
3484 if value.trim().is_empty() {
3485 return Err(DagMlError::RuntimeValidation(format!(
3486 "data provider view contains an empty {label}"
3487 )));
3488 }
3489 if !seen.insert(value.as_str()) {
3490 return Err(DagMlError::RuntimeValidation(format!(
3491 "data provider view has duplicate {label} `{value}`"
3492 )));
3493 }
3494 }
3495 Ok(())
3496}
3497
/// Request handed to [`RuntimeDataProvider::make_view`] to derive a view from
/// an existing data handle.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DataViewRequest {
    /// Run the request belongs to.
    pub run_id: RunId,
    /// Node whose input the view is built for.
    pub node_id: NodeId,
    /// Name of the node input the binding is attached to.
    pub input_name: String,
    /// Phase in which the request is issued.
    pub phase: Phase,
    /// Variant scope of the request, if any.
    pub variant_id: Option<VariantId>,
    /// Fold scope of the request, if any.
    pub fold_id: Option<FoldId>,
    /// Data binding the view belongs to.
    pub binding: crate::data::DataBinding,
    /// Handle of the data the view is taken from (presumably produced by `materialize` — confirm).
    pub data_handle: HandleRef,
    /// Specification of the requested view.
    pub view: DataProviderViewSpec,
}
3510
/// Host-side provider that turns data requests into handles.
pub trait RuntimeDataProvider {
    /// Materializes the data described by `request` and returns a handle to it.
    fn materialize(&self, request: &DataMaterializationRequest) -> Result<HandleRef>;
    /// Builds the view described by `request` and returns a handle to it.
    fn make_view(&self, request: &DataViewRequest) -> Result<HandleRef>;
    /// Optionally supplies sample relations for a binding.
    ///
    /// The default implementation provides none and returns `Ok(None)`.
    fn coordinator_relations(&self, _binding: &DataBinding) -> Result<Option<SampleRelationSet>> {
        Ok(None)
    }
}
3518
/// Executable controller that can be registered in a
/// `RuntimeControllerRegistry`; implementations must be `Send + Sync`.
pub trait RuntimeController: Send + Sync {
    /// Identifier under which the controller is registered and looked up.
    fn controller_id(&self) -> &ControllerId;
    /// Executes a node task and returns its result.
    fn invoke(&self, task: &NodeTask) -> Result<NodeResult>;

    /// Executes a custom aggregation task.
    ///
    /// The default implementation rejects the task with a
    /// `DagMlError::RuntimeValidation` naming the controller and the task, so
    /// only controllers that aggregate predictions need to override it.
    fn invoke_aggregation(
        &self,
        task: &AggregationControllerTask,
    ) -> Result<AggregationControllerResult> {
        Err(DagMlError::RuntimeValidation(format!(
            "runtime controller `{}` does not implement aggregation task `{}`",
            self.controller_id(),
            task.task_id
        )))
    }
}
3534
/// Borrowed inputs for `SequentialScheduler::execute_bundle_replay`.
pub struct BundleReplayExecution<'a> {
    /// Execution plan the bundle is validated against.
    pub plan: &'a ExecutionPlan,
    /// Bundle being replayed.
    pub bundle: &'a ExecutionBundle,
    /// Replay phase request.
    pub replay_request: &'a ReplayPhaseRequest,
    /// Optional prediction cache store.
    pub prediction_cache_store: Option<&'a dyn RuntimePredictionCacheStore>,
    /// Registry of runtime controllers available to the replay.
    pub controllers: &'a RuntimeControllerRegistry,
    /// Data provider used by the replay.
    pub data_provider: &'a dyn RuntimeDataProvider,
    /// Artifact store used by the replay.
    pub artifact_store: &'a dyn RuntimeArtifactStore,
    /// External data plan envelopes by string key (key meaning not visible here — confirm).
    pub data_envelopes: &'a BTreeMap<String, ExternalDataPlanEnvelope>,
}
3545
/// Registry of runtime controllers keyed by their controller id.
#[derive(Default)]
pub struct RuntimeControllerRegistry {
    /// Registered controllers; each id maps to exactly one controller.
    controllers: BTreeMap<ControllerId, Box<dyn RuntimeController>>,
}
3550
3551impl RuntimeControllerRegistry {
3552 pub fn new() -> Self {
3553 Self::default()
3554 }
3555
3556 pub fn register(&mut self, controller: Box<dyn RuntimeController>) -> Result<()> {
3557 let id = controller.controller_id().clone();
3558 if self.controllers.insert(id.clone(), controller).is_some() {
3559 return Err(DagMlError::RuntimeValidation(format!(
3560 "duplicate runtime controller `{id}`"
3561 )));
3562 }
3563 Ok(())
3564 }
3565
3566 pub fn get(&self, controller_id: &ControllerId) -> Option<&dyn RuntimeController> {
3567 self.controllers.get(controller_id).map(Box::as_ref)
3568 }
3569}
3570
3571pub fn dispatch_custom_observation_aggregation(
3572 plan: &ExecutionPlan,
3573 controllers: &RuntimeControllerRegistry,
3574 task_id: impl Into<String>,
3575 block: ObservationPredictionBlock,
3576 relations: SampleRelationSet,
3577 policy: AggregationPolicy,
3578 requested_sample_order: Vec<SampleId>,
3579) -> Result<PredictionBlock> {
3580 let controller_id = custom_aggregation_controller_id(&policy)?;
3581 ensure_aggregation_controller_capability(plan, controller_id)?;
3582 let task = AggregationControllerTask {
3583 schema_version: crate::aggregation::AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
3584 task_id: task_id.into(),
3585 controller_id: controller_id.clone(),
3586 policy,
3587 reduction_plan: None,
3588 input: AggregationControllerInput::ObservationToSample {
3589 block,
3590 relations,
3591 requested_sample_order,
3592 },
3593 };
3594 let result = dispatch_custom_aggregation_task(controllers, &task)?;
3595 match result.output {
3596 AggregationControllerOutput::Sample { block } => Ok(block),
3597 AggregationControllerOutput::Unit { .. } => Err(DagMlError::RuntimeValidation(format!(
3598 "aggregation controller task `{}` returned unit output for observation input",
3599 task.task_id
3600 ))),
3601 }
3602}
3603
3604pub fn dispatch_custom_sample_aggregation(
3605 plan: &ExecutionPlan,
3606 controllers: &RuntimeControllerRegistry,
3607 task_id: impl Into<String>,
3608 block: PredictionBlock,
3609 relations: SampleRelationSet,
3610 policy: AggregationPolicy,
3611 requested_unit_order: Vec<PredictionUnitId>,
3612) -> Result<AggregatedPredictionBlock> {
3613 let controller_id = custom_aggregation_controller_id(&policy)?;
3614 ensure_aggregation_controller_capability(plan, controller_id)?;
3615 let task = AggregationControllerTask {
3616 schema_version: crate::aggregation::AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
3617 task_id: task_id.into(),
3618 controller_id: controller_id.clone(),
3619 policy,
3620 reduction_plan: None,
3621 input: AggregationControllerInput::SampleToUnit {
3622 block,
3623 relations,
3624 requested_unit_order,
3625 },
3626 };
3627 let result = dispatch_custom_aggregation_task(controllers, &task)?;
3628 match result.output {
3629 AggregationControllerOutput::Unit { block } => Ok(block),
3630 AggregationControllerOutput::Sample { .. } => Err(DagMlError::RuntimeValidation(format!(
3631 "aggregation controller task `{}` returned sample output for sample input",
3632 task.task_id
3633 ))),
3634 }
3635}
3636
3637pub fn dispatch_custom_aggregation_task(
3638 controllers: &RuntimeControllerRegistry,
3639 task: &AggregationControllerTask,
3640) -> Result<AggregationControllerResult> {
3641 task.validate()?;
3642 let controller = controllers.get(&task.controller_id).ok_or_else(|| {
3643 DagMlError::RuntimeValidation(format!(
3644 "aggregation runtime controller `{}` is not registered",
3645 task.controller_id
3646 ))
3647 })?;
3648 let result = controller.invoke_aggregation(task)?;
3649 result.validate_for_task(task)?;
3650 Ok(result)
3651}
3652
3653fn custom_aggregation_controller_id(policy: &AggregationPolicy) -> Result<&ControllerId> {
3654 policy.validate()?;
3655 policy
3656 .custom_controller
3657 .as_ref()
3658 .map(|controller| &controller.controller_id)
3659 .ok_or_else(|| {
3660 DagMlError::RuntimeValidation(
3661 "custom aggregation dispatch requires a custom_controller policy".to_string(),
3662 )
3663 })
3664}
3665
3666fn ensure_aggregation_controller_capability(
3667 plan: &ExecutionPlan,
3668 controller_id: &ControllerId,
3669) -> Result<()> {
3670 let manifest = plan
3671 .controller_manifests
3672 .get(controller_id)
3673 .ok_or_else(|| {
3674 DagMlError::Planning(format!(
3675 "missing aggregation controller manifest `{controller_id}`"
3676 ))
3677 })?;
3678 if !manifest
3679 .capabilities
3680 .contains(&ControllerCapability::AggregatesPredictions)
3681 {
3682 return Err(DagMlError::Planning(format!(
3683 "aggregation controller `{controller_id}` must declare aggregates_predictions"
3684 )));
3685 }
3686 Ok(())
3687}
3688
/// Mutable state shared across the phases of one run.
#[derive(Clone, Debug)]
pub struct RunContext {
    /// Identifier of the run.
    pub run_id: RunId,
    /// Run-level seed; campaign execution uses it when a variant has no seed of its own.
    pub root_seed: Option<u64>,
    /// When set, campaign execution skips every variant with a different id.
    pub variant_id: Option<VariantId>,
    /// In-memory prediction store of the run.
    pub prediction_store: InMemoryPredictionStore,
    /// In-memory aggregated prediction store of the run.
    pub aggregated_prediction_store: InMemoryAggregatedPredictionStore,
    /// In-memory lineage recorder of the run.
    pub lineage: InMemoryLineageRecorder,
}
3698
3699impl RunContext {
3700 pub fn new(run_id: RunId, root_seed: Option<u64>) -> Self {
3701 Self {
3702 run_id,
3703 root_seed,
3704 variant_id: None,
3705 prediction_store: InMemoryPredictionStore::new(),
3706 aggregated_prediction_store: InMemoryAggregatedPredictionStore::new(),
3707 lineage: InMemoryLineageRecorder::new(),
3708 }
3709 }
3710}
3711
/// Stateless scheduler whose `execute_*` methods run phase scopes one after
/// another.
#[derive(Clone, Debug, Default)]
pub struct SequentialScheduler;
3714
/// Scheduler configured with a worker limit; construct it through
/// `ParallelScheduler::new`, which rejects a limit of zero.
#[derive(Clone, Debug)]
pub struct ParallelScheduler {
    /// Worker limit; at least 1 when built through `new`.
    max_workers: usize,
}
3719
3720impl ParallelScheduler {
3721 pub fn new(max_workers: usize) -> Result<Self> {
3722 if max_workers == 0 {
3723 return Err(DagMlError::RuntimeValidation(
3724 "parallel scheduler max_workers must be at least 1".to_string(),
3725 ));
3726 }
3727 Ok(Self { max_workers })
3728 }
3729
3730 pub fn max_workers(&self) -> usize {
3731 self.max_workers
3732 }
3733}
3734
/// One unit of phase execution: a phase, optionally narrowed to a variant
/// and a fold.
#[derive(Clone, Debug)]
struct PhaseScope {
    /// Phase to execute.
    phase: Phase,
    /// Variant the scope is bound to, if any.
    variant_id: Option<VariantId>,
    /// Execution spec of that variant; set by campaign execution, `None` otherwise.
    variant: Option<VariantExecutionSpec>,
    /// Fold the scope is bound to; campaign execution sets it per fold during `Phase::FitCv`.
    fold_id: Option<FoldId>,
    /// Seed root for the scope: the variant seed when present, else the run's root seed.
    seed_root: Option<u64>,
}
3743
/// Pairs a bundle prediction requirement with the prediction cache record
/// associated with it (presumably the record that satisfies it during replay
/// — confirm at the construction site).
#[derive(Clone, Debug)]
struct ReplayPredictionCacheContract {
    /// The prediction requirement.
    requirement: BundlePredictionRequirement,
    /// The cache record paired with the requirement.
    cache: BundlePredictionCacheRecord,
}
3749
/// Replay artifacts grouped per node and then by a string key (presumably the
/// artifact input name — confirm at the construction site).
struct MaterializedReplayArtifacts {
    /// Artifact handles per node.
    handles: BTreeMap<NodeId, BTreeMap<String, HandleRef>>,
    /// Artifact input specs per node, keyed the same way as `handles`.
    inputs: BTreeMap<NodeId, BTreeMap<String, ArtifactInputSpec>>,
}
3754
/// Optional borrowed resources made available to one phase-scope execution.
///
/// Every field defaults to `None`; callers set only what their entry point
/// provides and fill the rest with `..Default::default()`.
#[derive(Default)]
struct PhaseScopeResources<'a> {
    /// Data provider, when the caller supplies one.
    data_provider: Option<&'a dyn RuntimeDataProvider>,
    /// Replay artifact handles per node (presumably set only for bundle replay — confirm).
    replay_artifact_handles: Option<&'a BTreeMap<NodeId, BTreeMap<String, HandleRef>>>,
    /// Replay artifact input specs per node (presumably set only for bundle replay — confirm).
    replay_artifact_inputs: Option<&'a BTreeMap<NodeId, BTreeMap<String, ArtifactInputSpec>>>,
    /// Id of the bundle being replayed, if any.
    replay_bundle_id: Option<&'a BundleId>,
    /// External data plan envelopes by string key.
    data_envelopes: Option<&'a BTreeMap<String, ExternalDataPlanEnvelope>>,
    /// Prediction cache store, if any.
    prediction_cache_store: Option<&'a dyn RuntimePredictionCacheStore>,
    /// Prediction cache contracts by string key.
    prediction_cache_contracts: Option<&'a BTreeMap<String, ReplayPredictionCacheContract>>,
    /// Mutable in-memory artifact store, when the caller supplies one.
    artifact_store: Option<&'a mut InMemoryArtifactStore>,
}
3766
3767impl SequentialScheduler {
3768 pub fn execute_phase(
3769 &self,
3770 plan: &ExecutionPlan,
3771 controllers: &RuntimeControllerRegistry,
3772 ctx: &mut RunContext,
3773 phase: Phase,
3774 ) -> Result<Vec<NodeResult>> {
3775 plan.validate()?;
3776 let variant_id = ctx.variant_id.clone();
3777 let seed_root = ctx.root_seed;
3778 self.execute_phase_scope(
3779 plan,
3780 controllers,
3781 ctx,
3782 PhaseScope {
3783 phase,
3784 variant_id,
3785 variant: None,
3786 fold_id: None,
3787 seed_root,
3788 },
3789 PhaseScopeResources::default(),
3790 )
3791 }
3792
3793 pub fn execute_phase_with_data_provider(
3794 &self,
3795 plan: &ExecutionPlan,
3796 controllers: &RuntimeControllerRegistry,
3797 data_provider: &dyn RuntimeDataProvider,
3798 ctx: &mut RunContext,
3799 phase: Phase,
3800 ) -> Result<Vec<NodeResult>> {
3801 plan.validate()?;
3802 let variant_id = ctx.variant_id.clone();
3803 let seed_root = ctx.root_seed;
3804 self.execute_phase_scope(
3805 plan,
3806 controllers,
3807 ctx,
3808 PhaseScope {
3809 phase,
3810 variant_id,
3811 variant: None,
3812 fold_id: None,
3813 seed_root,
3814 },
3815 PhaseScopeResources {
3816 data_provider: Some(data_provider),
3817 ..Default::default()
3818 },
3819 )
3820 }
3821
3822 pub fn execute_campaign_phase(
3823 &self,
3824 plan: &ExecutionPlan,
3825 controllers: &RuntimeControllerRegistry,
3826 ctx: &mut RunContext,
3827 phase: Phase,
3828 ) -> Result<Vec<NodeResult>> {
3829 plan.validate()?;
3830 let mut results = Vec::new();
3831 let fold_ids = if phase == Phase::FitCv {
3832 plan.fold_set
3833 .as_ref()
3834 .map(|fold_set| {
3835 fold_set
3836 .folds
3837 .iter()
3838 .map(|fold| Some(fold.fold_id.clone()))
3839 .collect::<Vec<_>>()
3840 })
3841 .unwrap_or_else(|| vec![None])
3842 } else {
3843 vec![None]
3844 };
3845 for variant in &plan.variants {
3846 if ctx
3847 .variant_id
3848 .as_ref()
3849 .is_some_and(|requested| requested != &variant.variant_id)
3850 {
3851 continue;
3852 }
3853 for fold_id in &fold_ids {
3854 let seed_root = variant.seed.or(ctx.root_seed);
3855 results.extend(self.execute_phase_scope(
3856 plan,
3857 controllers,
3858 ctx,
3859 PhaseScope {
3860 phase,
3861 variant_id: Some(variant.variant_id.clone()),
3862 variant: Some(VariantExecutionSpec::from_plan(variant)),
3863 fold_id: fold_id.clone(),
3864 seed_root,
3865 },
3866 PhaseScopeResources::default(),
3867 )?);
3868 }
3869 }
3870 Ok(results)
3871 }
3872
3873 pub fn execute_campaign_phase_with_data_provider(
3874 &self,
3875 plan: &ExecutionPlan,
3876 controllers: &RuntimeControllerRegistry,
3877 data_provider: &dyn RuntimeDataProvider,
3878 ctx: &mut RunContext,
3879 phase: Phase,
3880 ) -> Result<Vec<NodeResult>> {
3881 plan.validate()?;
3882 let mut results = Vec::new();
3883 let fold_ids = if phase == Phase::FitCv {
3884 plan.fold_set
3885 .as_ref()
3886 .map(|fold_set| {
3887 fold_set
3888 .folds
3889 .iter()
3890 .map(|fold| Some(fold.fold_id.clone()))
3891 .collect::<Vec<_>>()
3892 })
3893 .unwrap_or_else(|| vec![None])
3894 } else {
3895 vec![None]
3896 };
3897 for variant in &plan.variants {
3898 if ctx
3899 .variant_id
3900 .as_ref()
3901 .is_some_and(|requested| requested != &variant.variant_id)
3902 {
3903 continue;
3904 }
3905 for fold_id in &fold_ids {
3906 let seed_root = variant.seed.or(ctx.root_seed);
3907 results.extend(self.execute_phase_scope(
3908 plan,
3909 controllers,
3910 ctx,
3911 PhaseScope {
3912 phase,
3913 variant_id: Some(variant.variant_id.clone()),
3914 variant: Some(VariantExecutionSpec::from_plan(variant)),
3915 fold_id: fold_id.clone(),
3916 seed_root,
3917 },
3918 PhaseScopeResources {
3919 data_provider: Some(data_provider),
3920 ..Default::default()
3921 },
3922 )?);
3923 }
3924 }
3925 Ok(results)
3926 }
3927
3928 pub fn execute_campaign_phase_with_data_provider_and_artifact_store(
3929 &self,
3930 plan: &ExecutionPlan,
3931 controllers: &RuntimeControllerRegistry,
3932 data_provider: &dyn RuntimeDataProvider,
3933 artifact_store: &mut InMemoryArtifactStore,
3934 ctx: &mut RunContext,
3935 phase: Phase,
3936 ) -> Result<Vec<NodeResult>> {
3937 plan.validate()?;
3938 let mut results = Vec::new();
3939 let fold_ids = if phase == Phase::FitCv {
3940 plan.fold_set
3941 .as_ref()
3942 .map(|fold_set| {
3943 fold_set
3944 .folds
3945 .iter()
3946 .map(|fold| Some(fold.fold_id.clone()))
3947 .collect::<Vec<_>>()
3948 })
3949 .unwrap_or_else(|| vec![None])
3950 } else {
3951 vec![None]
3952 };
3953 for variant in &plan.variants {
3954 if ctx
3955 .variant_id
3956 .as_ref()
3957 .is_some_and(|requested| requested != &variant.variant_id)
3958 {
3959 continue;
3960 }
3961 for fold_id in &fold_ids {
3962 let seed_root = variant.seed.or(ctx.root_seed);
3963 results.extend(self.execute_phase_scope(
3964 plan,
3965 controllers,
3966 ctx,
3967 PhaseScope {
3968 phase,
3969 variant_id: Some(variant.variant_id.clone()),
3970 variant: Some(VariantExecutionSpec::from_plan(variant)),
3971 fold_id: fold_id.clone(),
3972 seed_root,
3973 },
3974 PhaseScopeResources {
3975 data_provider: Some(data_provider),
3976 artifact_store: Some(&mut *artifact_store),
3977 ..Default::default()
3978 },
3979 )?);
3980 }
3981 }
3982 Ok(results)
3983 }
3984
3985 pub fn execute_bundle_replay(
3986 &self,
3987 replay: BundleReplayExecution<'_>,
3988 ctx: &mut RunContext,
3989 ) -> Result<Vec<NodeResult>> {
3990 replay.bundle.validate_against_plan(replay.plan)?;
3991 replay
3992 .replay_request
3993 .validate_for_bundle_with_prediction_cache_store(
3994 replay.bundle,
3995 replay.prediction_cache_store.is_some(),
3996 )?;
3997 replay
3998 .bundle
3999 .validate_replay_envelopes(replay.data_envelopes)?;
4000 let prediction_cache_contracts = if replay.replay_request.phase == Phase::Refit {
4001 Some(replay_prediction_cache_contracts(replay.bundle)?)
4002 } else {
4003 None
4004 };
4005 if replay.replay_request.phase == Phase::Refit {
4006 preload_replay_prediction_cache_store(
4007 replay.bundle,
4008 replay.prediction_cache_store,
4009 ctx,
4010 )?;
4011 }
4012 let replay_artifacts = materialize_replay_artifact_handles(
4013 replay.plan,
4014 replay.bundle,
4015 replay.replay_request,
4016 replay.artifact_store,
4017 ctx,
4018 )?;
4019 let selected_variant = replay
4020 .bundle
4021 .selected_variant_id
4022 .as_ref()
4023 .map(|selected| {
4024 replay
4025 .plan
4026 .variants
4027 .iter()
4028 .find(|variant| &variant.variant_id == selected)
4029 .map(VariantExecutionSpec::from_plan)
4030 .ok_or_else(|| {
4031 DagMlError::RuntimeValidation(format!(
4032 "bundle `{}` selected unknown variant `{selected}`",
4033 replay.bundle.bundle_id
4034 ))
4035 })
4036 })
4037 .transpose()?;
4038 let seed_root = selected_variant
4039 .as_ref()
4040 .and_then(|variant| variant.seed)
4041 .or(ctx.root_seed);
4042
4043 self.execute_phase_scope(
4044 replay.plan,
4045 replay.controllers,
4046 ctx,
4047 PhaseScope {
4048 phase: replay.replay_request.phase,
4049 variant_id: replay.bundle.selected_variant_id.clone(),
4050 variant: selected_variant,
4051 fold_id: None,
4052 seed_root,
4053 },
4054 PhaseScopeResources {
4055 data_provider: Some(replay.data_provider),
4056 replay_artifact_handles: Some(&replay_artifacts.handles),
4057 replay_artifact_inputs: Some(&replay_artifacts.inputs),
4058 replay_bundle_id: Some(&replay.bundle.bundle_id),
4059 data_envelopes: Some(replay.data_envelopes),
4060 prediction_cache_store: replay.prediction_cache_store,
4061 prediction_cache_contracts: prediction_cache_contracts.as_ref(),
4062 ..Default::default()
4063 },
4064 )
4065 }
4066
4067 fn execute_phase_scope(
4068 &self,
4069 plan: &ExecutionPlan,
4070 controllers: &RuntimeControllerRegistry,
4071 ctx: &mut RunContext,
4072 scope: PhaseScope,
4073 mut resources: PhaseScopeResources<'_>,
4074 ) -> Result<Vec<NodeResult>> {
4075 let _phase_span = crate::observability::phase_span(
4076 ctx.run_id.as_str(),
4077 plan.id.as_str(),
4078 scope.phase.as_str(),
4079 scope.variant_id.as_ref().map(VariantId::as_str),
4080 scope.fold_id.as_ref().map(FoldId::as_str),
4081 )
4082 .entered();
4083 let mut results = Vec::new();
4084 let mut output_handles = BTreeMap::<NodeId, BTreeMap<String, HandleRef>>::new();
4085 let mut output_data_views =
4086 BTreeMap::<NodeId, BTreeMap<String, DataProviderViewSpec>>::new();
4087 let mut input_lineage = BTreeMap::<NodeId, LineageId>::new();
4088
4089 for level in plan.node_parallel_levels_for_phase(scope.phase)? {
4090 for node_id in &level {
4091 let node_plan = plan
4092 .node_plans
4093 .get(node_id)
4094 .expect("execution plan was validated");
4095 let controller = controllers.get(&node_plan.controller_id).ok_or_else(|| {
4096 DagMlError::RuntimeValidation(format!(
4097 "runtime controller `{}` is not registered",
4098 node_plan.controller_id
4099 ))
4100 })?;
4101 let collected_inputs = collect_input_handles(
4102 plan,
4103 node_plan,
4104 &output_handles,
4105 &output_data_views,
4106 &resources,
4107 ctx,
4108 &scope,
4109 )?;
4110 let mut input_handles = collected_inputs.handles;
4111 let mut artifact_inputs = BTreeMap::new();
4112 if let Some(node_artifact_handles) = resources
4113 .replay_artifact_handles
4114 .and_then(|handles| handles.get(node_id))
4115 {
4116 for (key, handle) in node_artifact_handles {
4117 if input_handles.insert(key.clone(), handle.clone()).is_some() {
4118 return Err(DagMlError::RuntimeValidation(format!(
4119 "node `{node_id}` received duplicate replay artifact input `{key}`"
4120 )));
4121 }
4122 }
4123 }
4124 if let Some(node_artifact_inputs) = resources
4125 .replay_artifact_inputs
4126 .and_then(|inputs| inputs.get(node_id))
4127 {
4128 for (key, spec) in node_artifact_inputs {
4129 if artifact_inputs.insert(key.clone(), spec.clone()).is_some() {
4130 return Err(DagMlError::RuntimeValidation(format!(
4131 "node `{node_id}` received duplicate replay artifact metadata `{key}`"
4132 )));
4133 }
4134 }
4135 }
4136 let task_node_plan = effective_node_plan_for_scope(node_plan, &scope)?;
4137 let inner_fold_set = inner_fold_set_for_scope(
4138 &plan.campaign,
4139 plan.fold_set.as_ref(),
4140 node_plan,
4141 &scope,
4142 )?;
4143 let fit_influence = fit_influence_task_for_node(
4144 plan,
4145 &task_node_plan,
4146 &collected_inputs.data_views,
4147 )?;
4148 let task = NodeTask {
4149 inner_fold_set,
4150 run_id: ctx.run_id.clone(),
4151 node_plan: task_node_plan.clone(),
4152 phase: scope.phase,
4153 variant_id: scope.variant_id.clone(),
4154 variant: scope.variant.clone(),
4155 fold_id: scope.fold_id.clone(),
4156 branch_path: Vec::new(),
4157 input_handles,
4158 data_views: collected_inputs.data_views,
4159 prediction_inputs: collected_inputs.prediction_inputs,
4160 artifact_inputs,
4161 fit_influence,
4162 seed: derive_task_seed(
4163 scope.seed_root,
4164 scope.variant_id.as_ref(),
4165 scope.fold_id.as_ref(),
4166 &task_node_plan,
4167 scope.phase,
4168 ),
4169 };
4170 let _node_span = crate::observability::node_span(
4171 task.run_id.as_str(),
4172 plan.id.as_str(),
4173 task.phase.as_str(),
4174 task.node_plan.node_id.as_str(),
4175 task.node_plan.controller_id.as_str(),
4176 )
4177 .entered();
4178 let mut result = controller.invoke(&task)?;
4179 record_fit_influence_diagnostic(&task, &mut result);
4180 result.validate_for_task(&task)?;
4181 apply_result_prediction_aggregation(
4182 plan,
4183 controllers,
4184 &task,
4185 &mut result,
4186 &resources,
4187 )?;
4188 attach_coordinator_input_lineage(
4189 &mut result,
4190 plan,
4191 &task.node_plan.node_id,
4192 &input_lineage,
4193 )?;
4194 if let Some(store) = resources.artifact_store.as_deref_mut() {
4195 if scope.phase == Phase::Refit {
4196 store.capture_refit_artifacts(&task, &result)?;
4197 }
4198 }
4199 for prediction in &result.predictions {
4200 ctx.prediction_store.append(prediction.clone())?;
4201 }
4202 for prediction in &result.aggregated_predictions {
4203 ctx.aggregated_prediction_store.append(prediction.clone())?;
4204 }
4205 ctx.lineage.record(result.lineage.clone())?;
4206 let data_views = derive_output_data_views(plan, &task, &result)?;
4207 output_handles.insert(node_id.clone(), result.outputs.clone());
4208 output_data_views.insert(node_id.clone(), data_views);
4209 input_lineage.insert(node_id.clone(), result.lineage.record_id.clone());
4210 results.push(result);
4211 }
4212 }
4213
4214 Ok(results)
4215 }
4216}
4217
4218impl ParallelScheduler {
4219 pub fn execute_phase(
4220 &self,
4221 plan: &ExecutionPlan,
4222 controllers: &RuntimeControllerRegistry,
4223 ctx: &mut RunContext,
4224 phase: Phase,
4225 ) -> Result<Vec<NodeResult>> {
4226 plan.validate()?;
4227 let variant_id = ctx.variant_id.clone();
4228 let seed_root = ctx.root_seed;
4229 self.execute_phase_scope(
4230 plan,
4231 controllers,
4232 ctx,
4233 PhaseScope {
4234 phase,
4235 variant_id,
4236 variant: None,
4237 fold_id: None,
4238 seed_root,
4239 },
4240 PhaseScopeResources::default(),
4241 )
4242 }
4243
4244 pub fn execute_phase_with_data_provider(
4245 &self,
4246 plan: &ExecutionPlan,
4247 controllers: &RuntimeControllerRegistry,
4248 data_provider: &dyn RuntimeDataProvider,
4249 ctx: &mut RunContext,
4250 phase: Phase,
4251 ) -> Result<Vec<NodeResult>> {
4252 plan.validate()?;
4253 let variant_id = ctx.variant_id.clone();
4254 let seed_root = ctx.root_seed;
4255 self.execute_phase_scope(
4256 plan,
4257 controllers,
4258 ctx,
4259 PhaseScope {
4260 phase,
4261 variant_id,
4262 variant: None,
4263 fold_id: None,
4264 seed_root,
4265 },
4266 PhaseScopeResources {
4267 data_provider: Some(data_provider),
4268 ..Default::default()
4269 },
4270 )
4271 }
4272
4273 pub fn execute_campaign_phase(
4274 &self,
4275 plan: &ExecutionPlan,
4276 controllers: &RuntimeControllerRegistry,
4277 ctx: &mut RunContext,
4278 phase: Phase,
4279 ) -> Result<Vec<NodeResult>> {
4280 plan.validate()?;
4281 let mut results = Vec::new();
4282 let fold_ids = if phase == Phase::FitCv {
4283 plan.fold_set
4284 .as_ref()
4285 .map(|fold_set| {
4286 fold_set
4287 .folds
4288 .iter()
4289 .map(|fold| Some(fold.fold_id.clone()))
4290 .collect::<Vec<_>>()
4291 })
4292 .unwrap_or_else(|| vec![None])
4293 } else {
4294 vec![None]
4295 };
4296 for variant in &plan.variants {
4297 if ctx
4298 .variant_id
4299 .as_ref()
4300 .is_some_and(|requested| requested != &variant.variant_id)
4301 {
4302 continue;
4303 }
4304 for fold_id in &fold_ids {
4305 let seed_root = variant.seed.or(ctx.root_seed);
4306 results.extend(self.execute_phase_scope(
4307 plan,
4308 controllers,
4309 ctx,
4310 PhaseScope {
4311 phase,
4312 variant_id: Some(variant.variant_id.clone()),
4313 variant: Some(VariantExecutionSpec::from_plan(variant)),
4314 fold_id: fold_id.clone(),
4315 seed_root,
4316 },
4317 PhaseScopeResources::default(),
4318 )?);
4319 }
4320 }
4321 Ok(results)
4322 }
4323
4324 pub fn execute_campaign_phase_with_data_provider(
4325 &self,
4326 plan: &ExecutionPlan,
4327 controllers: &RuntimeControllerRegistry,
4328 data_provider: &dyn RuntimeDataProvider,
4329 ctx: &mut RunContext,
4330 phase: Phase,
4331 ) -> Result<Vec<NodeResult>> {
4332 plan.validate()?;
4333 let mut results = Vec::new();
4334 let fold_ids = if phase == Phase::FitCv {
4335 plan.fold_set
4336 .as_ref()
4337 .map(|fold_set| {
4338 fold_set
4339 .folds
4340 .iter()
4341 .map(|fold| Some(fold.fold_id.clone()))
4342 .collect::<Vec<_>>()
4343 })
4344 .unwrap_or_else(|| vec![None])
4345 } else {
4346 vec![None]
4347 };
4348 for variant in &plan.variants {
4349 if ctx
4350 .variant_id
4351 .as_ref()
4352 .is_some_and(|requested| requested != &variant.variant_id)
4353 {
4354 continue;
4355 }
4356 for fold_id in &fold_ids {
4357 let seed_root = variant.seed.or(ctx.root_seed);
4358 results.extend(self.execute_phase_scope(
4359 plan,
4360 controllers,
4361 ctx,
4362 PhaseScope {
4363 phase,
4364 variant_id: Some(variant.variant_id.clone()),
4365 variant: Some(VariantExecutionSpec::from_plan(variant)),
4366 fold_id: fold_id.clone(),
4367 seed_root,
4368 },
4369 PhaseScopeResources {
4370 data_provider: Some(data_provider),
4371 ..Default::default()
4372 },
4373 )?);
4374 }
4375 }
4376 Ok(results)
4377 }
4378
4379 pub fn execute_campaign_phase_with_data_provider_and_artifact_store(
4380 &self,
4381 plan: &ExecutionPlan,
4382 controllers: &RuntimeControllerRegistry,
4383 data_provider: &dyn RuntimeDataProvider,
4384 artifact_store: &mut InMemoryArtifactStore,
4385 ctx: &mut RunContext,
4386 phase: Phase,
4387 ) -> Result<Vec<NodeResult>> {
4388 plan.validate()?;
4389 let mut results = Vec::new();
4390 let fold_ids = if phase == Phase::FitCv {
4391 plan.fold_set
4392 .as_ref()
4393 .map(|fold_set| {
4394 fold_set
4395 .folds
4396 .iter()
4397 .map(|fold| Some(fold.fold_id.clone()))
4398 .collect::<Vec<_>>()
4399 })
4400 .unwrap_or_else(|| vec![None])
4401 } else {
4402 vec![None]
4403 };
4404 for variant in &plan.variants {
4405 if ctx
4406 .variant_id
4407 .as_ref()
4408 .is_some_and(|requested| requested != &variant.variant_id)
4409 {
4410 continue;
4411 }
4412 for fold_id in &fold_ids {
4413 let seed_root = variant.seed.or(ctx.root_seed);
4414 results.extend(self.execute_phase_scope(
4415 plan,
4416 controllers,
4417 ctx,
4418 PhaseScope {
4419 phase,
4420 variant_id: Some(variant.variant_id.clone()),
4421 variant: Some(VariantExecutionSpec::from_plan(variant)),
4422 fold_id: fold_id.clone(),
4423 seed_root,
4424 },
4425 PhaseScopeResources {
4426 data_provider: Some(data_provider),
4427 artifact_store: Some(&mut *artifact_store),
4428 ..Default::default()
4429 },
4430 )?);
4431 }
4432 }
4433 Ok(results)
4434 }
4435
4436 pub fn execute_bundle_replay(
4437 &self,
4438 replay: BundleReplayExecution<'_>,
4439 ctx: &mut RunContext,
4440 ) -> Result<Vec<NodeResult>> {
4441 replay.bundle.validate_against_plan(replay.plan)?;
4442 replay
4443 .replay_request
4444 .validate_for_bundle_with_prediction_cache_store(
4445 replay.bundle,
4446 replay.prediction_cache_store.is_some(),
4447 )?;
4448 replay
4449 .bundle
4450 .validate_replay_envelopes(replay.data_envelopes)?;
4451 let prediction_cache_contracts = if replay.replay_request.phase == Phase::Refit {
4452 Some(replay_prediction_cache_contracts(replay.bundle)?)
4453 } else {
4454 None
4455 };
4456 if replay.replay_request.phase == Phase::Refit {
4457 preload_replay_prediction_cache_store(
4458 replay.bundle,
4459 replay.prediction_cache_store,
4460 ctx,
4461 )?;
4462 }
4463 let replay_artifacts = materialize_replay_artifact_handles(
4464 replay.plan,
4465 replay.bundle,
4466 replay.replay_request,
4467 replay.artifact_store,
4468 ctx,
4469 )?;
4470 let selected_variant = replay
4471 .bundle
4472 .selected_variant_id
4473 .as_ref()
4474 .map(|selected| {
4475 replay
4476 .plan
4477 .variants
4478 .iter()
4479 .find(|variant| &variant.variant_id == selected)
4480 .map(VariantExecutionSpec::from_plan)
4481 .ok_or_else(|| {
4482 DagMlError::RuntimeValidation(format!(
4483 "bundle `{}` selected unknown variant `{selected}`",
4484 replay.bundle.bundle_id
4485 ))
4486 })
4487 })
4488 .transpose()?;
4489 let seed_root = selected_variant
4490 .as_ref()
4491 .and_then(|variant| variant.seed)
4492 .or(ctx.root_seed);
4493
4494 self.execute_phase_scope(
4495 replay.plan,
4496 replay.controllers,
4497 ctx,
4498 PhaseScope {
4499 phase: replay.replay_request.phase,
4500 variant_id: replay.bundle.selected_variant_id.clone(),
4501 variant: selected_variant,
4502 fold_id: None,
4503 seed_root,
4504 },
4505 PhaseScopeResources {
4506 data_provider: Some(replay.data_provider),
4507 replay_artifact_handles: Some(&replay_artifacts.handles),
4508 replay_artifact_inputs: Some(&replay_artifacts.inputs),
4509 replay_bundle_id: Some(&replay.bundle.bundle_id),
4510 data_envelopes: Some(replay.data_envelopes),
4511 prediction_cache_store: replay.prediction_cache_store,
4512 prediction_cache_contracts: prediction_cache_contracts.as_ref(),
4513 ..Default::default()
4514 },
4515 )
4516 }
4517
4518 fn execute_phase_scope(
4519 &self,
4520 plan: &ExecutionPlan,
4521 controllers: &RuntimeControllerRegistry,
4522 ctx: &mut RunContext,
4523 scope: PhaseScope,
4524 mut resources: PhaseScopeResources<'_>,
4525 ) -> Result<Vec<NodeResult>> {
4526 let phase_span = crate::observability::phase_span(
4530 ctx.run_id.as_str(),
4531 plan.id.as_str(),
4532 scope.phase.as_str(),
4533 scope.variant_id.as_ref().map(VariantId::as_str),
4534 scope.fold_id.as_ref().map(FoldId::as_str),
4535 );
4536 let _phase_entered = phase_span.clone().entered();
4537 let plan_id = plan.id.as_str();
4539 plan.validate_parallel_controller_capabilities(self.max_workers, scope.phase)?;
4540 let mut results = Vec::new();
4541 let mut output_handles = BTreeMap::<NodeId, BTreeMap<String, HandleRef>>::new();
4542 let mut output_data_views =
4543 BTreeMap::<NodeId, BTreeMap<String, DataProviderViewSpec>>::new();
4544 let mut input_lineage = BTreeMap::<NodeId, LineageId>::new();
4545
4546 for level in plan.node_parallel_levels_for_phase(scope.phase)? {
4547 let mut prepared = Vec::<PreparedNodeTask>::new();
4548 for node_id in &level {
4549 let node_plan = plan
4550 .node_plans
4551 .get(node_id)
4552 .expect("execution plan was validated");
4553 let collected_inputs = collect_input_handles(
4554 plan,
4555 node_plan,
4556 &output_handles,
4557 &output_data_views,
4558 &resources,
4559 ctx,
4560 &scope,
4561 )?;
4562 let mut input_handles = collected_inputs.handles;
4563 let mut artifact_inputs = BTreeMap::new();
4564 if let Some(node_artifact_handles) = resources
4565 .replay_artifact_handles
4566 .and_then(|handles| handles.get(node_id))
4567 {
4568 for (key, handle) in node_artifact_handles {
4569 if input_handles.insert(key.clone(), handle.clone()).is_some() {
4570 return Err(DagMlError::RuntimeValidation(format!(
4571 "node `{node_id}` received duplicate replay artifact input `{key}`"
4572 )));
4573 }
4574 }
4575 }
4576 if let Some(node_artifact_inputs) = resources
4577 .replay_artifact_inputs
4578 .and_then(|inputs| inputs.get(node_id))
4579 {
4580 for (key, spec) in node_artifact_inputs {
4581 if artifact_inputs.insert(key.clone(), spec.clone()).is_some() {
4582 return Err(DagMlError::RuntimeValidation(format!(
4583 "node `{node_id}` received duplicate replay artifact metadata `{key}`"
4584 )));
4585 }
4586 }
4587 }
4588 let task_node_plan = effective_node_plan_for_scope(node_plan, &scope)?;
4589 let inner_fold_set = inner_fold_set_for_scope(
4590 &plan.campaign,
4591 plan.fold_set.as_ref(),
4592 node_plan,
4593 &scope,
4594 )?;
4595 let fit_influence = fit_influence_task_for_node(
4596 plan,
4597 &task_node_plan,
4598 &collected_inputs.data_views,
4599 )?;
4600 prepared.push(PreparedNodeTask {
4601 node_id: node_id.clone(),
4602 task: NodeTask {
4603 inner_fold_set,
4604 run_id: ctx.run_id.clone(),
4605 node_plan: task_node_plan.clone(),
4606 phase: scope.phase,
4607 variant_id: scope.variant_id.clone(),
4608 variant: scope.variant.clone(),
4609 fold_id: scope.fold_id.clone(),
4610 branch_path: Vec::new(),
4611 input_handles,
4612 data_views: collected_inputs.data_views,
4613 prediction_inputs: collected_inputs.prediction_inputs,
4614 artifact_inputs,
4615 fit_influence,
4616 seed: derive_task_seed(
4617 scope.seed_root,
4618 scope.variant_id.as_ref(),
4619 scope.fold_id.as_ref(),
4620 &task_node_plan,
4621 scope.phase,
4622 ),
4623 },
4624 });
4625 }
4626
4627 for chunk in prepared.chunks(self.max_workers) {
4628 let chunk_results =
4629 std::thread::scope(|thread_scope| -> Result<Vec<NodeResult>> {
4630 let mut handles = Vec::with_capacity(chunk.len());
4631 for prepared_task in chunk {
4632 let controller = controllers
4633 .get(&prepared_task.task.node_plan.controller_id)
4634 .ok_or_else(|| {
4635 DagMlError::RuntimeValidation(format!(
4636 "runtime controller `{}` is not registered",
4637 prepared_task.task.node_plan.controller_id
4638 ))
4639 })?;
4640 let worker_span = phase_span.clone();
4641 handles.push(thread_scope.spawn(move || {
4642 let _worker_span = worker_span.entered();
4643 let _node_span = crate::observability::node_span(
4644 prepared_task.task.run_id.as_str(),
4645 plan_id,
4646 prepared_task.task.phase.as_str(),
4647 prepared_task.task.node_plan.node_id.as_str(),
4648 prepared_task.task.node_plan.controller_id.as_str(),
4649 )
4650 .entered();
4651 let mut result = controller.invoke(&prepared_task.task)?;
4652 record_fit_influence_diagnostic(&prepared_task.task, &mut result);
4653 result.validate_for_task(&prepared_task.task)?;
4654 Ok(result)
4655 }));
4656 }
4657 handles
4658 .into_iter()
4659 .map(|handle| {
4660 handle.join().map_err(|_| {
4661 DagMlError::RuntimeValidation(
4662 "parallel scheduler worker panicked".to_string(),
4663 )
4664 })?
4665 })
4666 .collect()
4667 })?;
4668
4669 for (prepared_task, mut result) in chunk.iter().zip(chunk_results) {
4670 apply_result_prediction_aggregation(
4671 plan,
4672 controllers,
4673 &prepared_task.task,
4674 &mut result,
4675 &resources,
4676 )?;
4677 attach_coordinator_input_lineage(
4678 &mut result,
4679 plan,
4680 &prepared_task.task.node_plan.node_id,
4681 &input_lineage,
4682 )?;
4683 if let Some(store) = resources.artifact_store.as_deref_mut() {
4684 if scope.phase == Phase::Refit {
4685 store.capture_refit_artifacts(&prepared_task.task, &result)?;
4686 }
4687 }
4688 for prediction in &result.predictions {
4689 ctx.prediction_store.append(prediction.clone())?;
4690 }
4691 for prediction in &result.aggregated_predictions {
4692 ctx.aggregated_prediction_store.append(prediction.clone())?;
4693 }
4694 ctx.lineage.record(result.lineage.clone())?;
4695 let data_views = derive_output_data_views(plan, &prepared_task.task, &result)?;
4696 output_handles.insert(prepared_task.node_id.clone(), result.outputs.clone());
4697 output_data_views.insert(prepared_task.node_id.clone(), data_views);
4698 input_lineage.insert(
4699 prepared_task.node_id.clone(),
4700 result.lineage.record_id.clone(),
4701 );
4702 results.push(result);
4703 }
4704 }
4705 }
4706
4707 Ok(results)
4708 }
4709}
4710
4711struct PreparedNodeTask {
4712 node_id: NodeId,
4713 task: NodeTask,
4714}
4715
4716fn attach_coordinator_input_lineage(
4717 result: &mut NodeResult,
4718 plan: &ExecutionPlan,
4719 node_id: &NodeId,
4720 upstream_lineage: &BTreeMap<NodeId, LineageId>,
4721) -> Result<()> {
4722 let inferred = inferred_input_lineage_for_node(plan, node_id, upstream_lineage);
4723 if result.lineage.input_lineage.is_empty() {
4724 result.lineage.input_lineage = inferred;
4725 return Ok(());
4726 }
4727
4728 let declared = result
4729 .lineage
4730 .input_lineage
4731 .iter()
4732 .cloned()
4733 .collect::<BTreeSet<_>>()
4734 .into_iter()
4735 .collect::<Vec<_>>();
4736 if declared != inferred {
4737 return Err(DagMlError::RuntimeValidation(format!(
4738 "lineage for node `{}` declared input lineage {:?}, expected {:?}",
4739 result.node_id, declared, inferred
4740 )));
4741 }
4742 result.lineage.input_lineage = declared;
4743 Ok(())
4744}
4745
4746fn inferred_input_lineage_for_node(
4747 plan: &ExecutionPlan,
4748 node_id: &NodeId,
4749 upstream_lineage: &BTreeMap<NodeId, LineageId>,
4750) -> Vec<LineageId> {
4751 plan.graph_plan
4752 .graph
4753 .edges
4754 .iter()
4755 .filter(|edge| &edge.target.node_id == node_id && edge.contract.propagates_lineage)
4756 .filter_map(|edge| upstream_lineage.get(&edge.source.node_id).cloned())
4757 .collect::<BTreeSet<_>>()
4758 .into_iter()
4759 .collect()
4760}
4761
4762fn apply_result_prediction_aggregation(
4763 plan: &ExecutionPlan,
4764 controllers: &RuntimeControllerRegistry,
4765 task: &NodeTask,
4766 result: &mut NodeResult,
4767 resources: &PhaseScopeResources<'_>,
4768) -> Result<()> {
4769 let has_observation_predictions = !result.observation_predictions.is_empty();
4770 let has_sample_predictions = !result.predictions.is_empty();
4771 if !has_observation_predictions && !has_sample_predictions {
4772 return Ok(());
4773 }
4774 let Some(shape_plan) = &task.node_plan.shape_plan else {
4775 if !has_observation_predictions {
4776 return Ok(());
4777 }
4778 return Err(DagMlError::RuntimeValidation(format!(
4779 "node `{}` emitted observation predictions but has no data/model shape plan for aggregation",
4780 task.node_plan.node_id
4781 )));
4782 };
4783 let policy = &shape_plan.aggregation_policy;
4784 if !policy.store_aggregated_predictions {
4785 return Ok(());
4786 }
4787 if policy.aggregation_level == PredictionLevel::Observation {
4788 return Ok(());
4789 }
4790 if !has_observation_predictions && policy.aggregation_level == PredictionLevel::Sample {
4791 return Ok(());
4792 }
4793
4794 let mut derived_sample_blocks = Vec::new();
4795 if !result.observation_predictions.is_empty() {
4796 let relations = coordinator_relations_for_task(task, resources)?;
4797 let sample_policy = observation_to_sample_policy(policy);
4798 for block in result.observation_predictions.clone() {
4799 let requested_sample_order =
4800 requested_sample_order_for_observation_block(plan, task, &block, &relations)?;
4801 let sample_block =
4802 if sample_policy.method == crate::policy::AggregationMethod::CustomController {
4803 dispatch_custom_observation_aggregation(
4804 plan,
4805 controllers,
4806 aggregation_task_id(
4807 task,
4808 &block.producer_node,
4809 block.fold_id.as_ref(),
4810 "obs_to_sample",
4811 ),
4812 block,
4813 relations.clone(),
4814 sample_policy.clone(),
4815 requested_sample_order,
4816 )?
4817 } else {
4818 aggregate_observation_predictions(
4819 &block,
4820 &relations,
4821 &sample_policy,
4822 &requested_sample_order,
4823 )?
4824 };
4825 derived_sample_blocks.push(sample_block);
4826 }
4827 }
4828
4829 if policy.aggregation_level == PredictionLevel::Sample {
4830 result.predictions.extend(derived_sample_blocks);
4831 result.validate_for_task(task)?;
4832 return Ok(());
4833 }
4834
4835 if !result.aggregated_predictions.is_empty() {
4836 result.validate_for_task(task)?;
4837 return Ok(());
4838 }
4839
4840 let relations = coordinator_relations_for_task(task, resources)?;
4841 let sample_blocks = result
4842 .predictions
4843 .iter()
4844 .cloned()
4845 .chain(derived_sample_blocks)
4846 .collect::<Vec<_>>();
4847 for block in sample_blocks {
4848 let requested_unit_order =
4849 requested_unit_order_for_sample_block(policy.aggregation_level, &relations, &block)?;
4850 let aggregated = if policy.method == crate::policy::AggregationMethod::CustomController {
4851 dispatch_custom_sample_aggregation(
4852 plan,
4853 controllers,
4854 aggregation_task_id(
4855 task,
4856 &block.producer_node,
4857 block.fold_id.as_ref(),
4858 "sample_to_unit",
4859 ),
4860 block,
4861 relations.clone(),
4862 policy.clone(),
4863 requested_unit_order,
4864 )?
4865 } else {
4866 aggregate_sample_predictions_by_unit(&block, &relations, policy, &requested_unit_order)?
4867 };
4868 result.aggregated_predictions.push(aggregated);
4869 }
4870 result.validate_for_task(task)
4871}
4872
4873fn observation_to_sample_policy(policy: &AggregationPolicy) -> AggregationPolicy {
4874 let mut sample_policy = policy.clone();
4875 sample_policy.aggregation_level = PredictionLevel::Sample;
4876 sample_policy
4877}
4878
4879fn coordinator_relations_for_task(
4880 task: &NodeTask,
4881 resources: &PhaseScopeResources<'_>,
4882) -> Result<SampleRelationSet> {
4883 coordinator_relations_for_node(&task.node_plan, resources)?.ok_or_else(|| {
4884 DagMlError::RuntimeValidation(format!(
4885 "node `{}` needs coordinator relations for prediction aggregation but no matching data provider/envelope carries relations",
4886 task.node_plan.node_id
4887 ))
4888 })
4889}
4890
4891fn coordinator_relations_for_edge(
4892 plan: &ExecutionPlan,
4893 edge: &EdgeSpec,
4894 resources: &PhaseScopeResources<'_>,
4895) -> Result<SampleRelationSet> {
4896 let target_plan = plan.node_plans.get(&edge.target.node_id).ok_or_else(|| {
4897 DagMlError::Planning(format!(
4898 "OOF edge target node `{}` has no node plan",
4899 edge.target.node_id
4900 ))
4901 })?;
4902 if let Some(relations) = coordinator_relations_for_node(target_plan, resources)? {
4903 return Ok(relations);
4904 }
4905
4906 let source_plan = plan.node_plans.get(&edge.source.node_id).ok_or_else(|| {
4907 DagMlError::Planning(format!(
4908 "OOF edge source node `{}` has no node plan",
4909 edge.source.node_id
4910 ))
4911 })?;
4912 if let Some(relations) = coordinator_relations_for_node(source_plan, resources)? {
4913 return Ok(relations);
4914 }
4915
4916 Err(DagMlError::RuntimeValidation(format!(
4917 "edge `{}.{}` -> `{}.{}` needs coordinator relations for aggregated OOF validation but neither endpoint has a relation-carrying data binding",
4918 edge.source.node_id,
4919 edge.source.port_name,
4920 edge.target.node_id,
4921 edge.target.port_name
4922 )))
4923}
4924
4925fn coordinator_relations_for_node(
4926 node_plan: &NodePlan,
4927 resources: &PhaseScopeResources<'_>,
4928) -> Result<Option<SampleRelationSet>> {
4929 let mut selected: Option<SampleRelationSet> = None;
4930 for binding in &node_plan.data_bindings {
4931 if !binding.require_relations && binding.relation_fingerprint.is_none() {
4932 continue;
4933 }
4934 let relations = if let Some(envelopes) = resources.data_envelopes {
4935 let key = format!("{}.{}", binding.node_id, binding.input_name);
4936 let Some(envelope) = envelopes.get(&key) else {
4937 continue;
4938 };
4939 binding.validate_envelope(envelope)?;
4940 envelope.coordinator_relations.clone()
4941 } else if let Some(data_provider) = resources.data_provider {
4942 data_provider.coordinator_relations(binding)?
4943 } else {
4944 None
4945 };
4946 let Some(relations) = relations else {
4947 continue;
4948 };
4949 if let Some(previous) = &selected {
4950 if previous != &relations {
4951 return Err(DagMlError::RuntimeValidation(format!(
4952 "node `{}` has multiple non-identical coordinator relation sets",
4953 node_plan.node_id
4954 )));
4955 }
4956 } else {
4957 selected = Some(relations);
4958 }
4959 }
4960 Ok(selected)
4961}
4962
4963fn requested_sample_order_for_observation_block(
4964 plan: &ExecutionPlan,
4965 task: &NodeTask,
4966 block: &ObservationPredictionBlock,
4967 relations: &SampleRelationSet,
4968) -> Result<Vec<SampleId>> {
4969 if block.partition == PredictionPartition::Validation {
4970 if let Some(sample_ids) = validation_view_sample_ids(task) {
4971 return Ok(sample_ids.into_iter().collect());
4972 }
4973 if let (Some(fold_set), Some(fold_id)) = (plan.fold_set.as_ref(), block.fold_id.as_ref()) {
4974 if let Some(fold) = fold_set.folds.iter().find(|fold| &fold.fold_id == fold_id) {
4975 return Ok(fold.validation_sample_ids.clone());
4976 }
4977 }
4978 }
4979 first_seen_samples_for_observations(block, relations)
4980}
4981
4982fn first_seen_samples_for_observations(
4983 block: &ObservationPredictionBlock,
4984 relations: &SampleRelationSet,
4985) -> Result<Vec<SampleId>> {
4986 let mut seen = BTreeSet::new();
4987 let mut sample_order = Vec::new();
4988 for observation_id in &block.observation_ids {
4989 let sample_id = relations
4990 .sample_for_observation(observation_id)
4991 .ok_or_else(|| {
4992 DagMlError::OofValidation(format!(
4993 "observation prediction `{observation_id}` has no sample relation"
4994 ))
4995 })?;
4996 if seen.insert(sample_id.clone()) {
4997 sample_order.push(sample_id.clone());
4998 }
4999 }
5000 Ok(sample_order)
5001}
5002
5003fn requested_unit_order_for_sample_block(
5004 level: PredictionLevel,
5005 relations: &SampleRelationSet,
5006 block: &PredictionBlock,
5007) -> Result<Vec<PredictionUnitId>> {
5008 let mut seen = BTreeSet::new();
5009 let mut unit_order = Vec::new();
5010 for sample_id in &block.sample_ids {
5011 let unit_id = match level {
5012 PredictionLevel::Sample => PredictionUnitId::Sample(sample_id.clone()),
5013 PredictionLevel::Target => relations
5014 .target_for_sample(sample_id)
5015 .cloned()
5016 .map(PredictionUnitId::Target)
5017 .ok_or_else(|| {
5018 DagMlError::OofValidation(format!(
5019 "sample `{sample_id}` is missing target id for target aggregation"
5020 ))
5021 })?,
5022 PredictionLevel::Group => relations
5023 .group_for_sample(sample_id)
5024 .cloned()
5025 .map(PredictionUnitId::Group)
5026 .ok_or_else(|| {
5027 DagMlError::OofValidation(format!(
5028 "sample `{sample_id}` is missing group id for group aggregation"
5029 ))
5030 })?,
5031 PredictionLevel::Observation => {
5032 return Err(DagMlError::OofValidation(
5033 "sample prediction aggregation cannot output observation-level predictions"
5034 .to_string(),
5035 ));
5036 }
5037 };
5038 if seen.insert(unit_id.clone()) {
5039 unit_order.push(unit_id);
5040 }
5041 }
5042 Ok(unit_order)
5043}
5044
5045fn aggregation_task_id(
5046 task: &NodeTask,
5047 producer_node: &NodeId,
5048 fold_id: Option<&FoldId>,
5049 stage: &str,
5050) -> String {
5051 let fold = fold_id
5052 .map(ToString::to_string)
5053 .unwrap_or_else(|| "nofold".to_string());
5054 format!(
5055 "aggregation:{}:{}:{}:{}:{}",
5056 task.run_id, task.node_plan.node_id, producer_node, fold, stage
5057 )
5058}
5059
/// Gathers everything a node consumes for one phase scope: input handles,
/// data views, and OOF prediction input specs.
///
/// Inputs are collected in four passes, in this order:
/// 1. handles already produced by upstream nodes, keyed `"{upstream}.{port}"`;
/// 2. non-OOF data edges into this node whose target port has no data binding;
/// 3. training OOF prediction edges (see `incoming_training_oof_edges`);
/// 4. the node's own data bindings, materialized through the data provider.
///
/// # Errors
/// Returns `DagMlError::RuntimeValidation` when two sources map to the same
/// key in passes 2-4, or when the node declares data bindings but
/// `resources.data_provider` is `None`. Errors from the helpers called here
/// (OOF collection, materialization, view construction) are propagated.
fn collect_input_handles(
    plan: &ExecutionPlan,
    node_plan: &NodePlan,
    output_handles: &BTreeMap<NodeId, BTreeMap<String, HandleRef>>,
    output_data_views: &BTreeMap<NodeId, BTreeMap<String, DataProviderViewSpec>>,
    resources: &PhaseScopeResources<'_>,
    ctx: &RunContext,
    scope: &PhaseScope,
) -> Result<CollectedInputs> {
    let mut inputs = BTreeMap::new();
    let mut data_views = BTreeMap::new();
    let mut prediction_inputs = BTreeMap::new();
    let training_oof_edges = incoming_training_oof_edges(plan, node_plan, scope)?;
    // Producers reached through OOF edges are handled in pass 3, not pass 1.
    let training_oof_sources = training_oof_edges
        .iter()
        .map(|edge| edge.source.node_id.clone())
        .collect::<BTreeSet<_>>();
    // Input names served by data bindings; data edges to these ports are skipped.
    let bound_data_inputs = node_plan
        .data_bindings
        .iter()
        .map(|binding| binding.input_name.clone())
        .collect::<BTreeSet<_>>();
    // Pass 1: forward all handles of every non-OOF upstream node.
    // Note: this insert does not check for an existing key.
    for upstream in &node_plan.input_nodes {
        if training_oof_sources.contains(upstream) {
            continue;
        }
        if let Some(handles) = output_handles.get(upstream) {
            for (port, handle) in handles {
                inputs.insert(format!("{upstream}.{port}"), handle.clone());
            }
        }
    }
    // Pass 2: data edges (not requiring OOF) that target this node.
    for edge in plan
        .graph_plan
        .graph
        .edges
        .iter()
        .filter(|edge| edge.target.node_id == node_plan.node_id)
        .filter(|edge| edge.contract.kind == PortKind::Data && !edge.contract.requires_oof)
    {
        if bound_data_inputs.contains(&edge.target.port_name) {
            continue;
        }
        // Sources without a recorded handle for this port are silently skipped.
        let Some(handles) = output_handles.get(&edge.source.node_id) else {
            continue;
        };
        let Some(handle) = handles.get(&edge.source.port_name) else {
            continue;
        };
        let key = data_view_key(&edge.target.port_name);
        if inputs.insert(key.clone(), handle.clone()).is_some() {
            return Err(DagMlError::RuntimeValidation(format!(
                "node `{}` received duplicate data edge input `{key}`",
                node_plan.node_id
            )));
        }
        // Carry over the source's view specs (main and validation) when present.
        if let Some(source_views) = output_data_views.get(&edge.source.node_id) {
            if let Some(view) = source_views.get(&edge.source.port_name) {
                if data_views.insert(key.clone(), view.clone()).is_some() {
                    return Err(DagMlError::RuntimeValidation(format!(
                        "node `{}` received duplicate data edge view `{key}`",
                        node_plan.node_id
                    )));
                }
            }
            let source_validation_key = validation_data_view_key(&edge.source.port_name);
            if let Some(view) = source_views.get(&source_validation_key) {
                let validation_key = format!("{key}:validation");
                if data_views
                    .insert(validation_key.clone(), view.clone())
                    .is_some()
                {
                    return Err(DagMlError::RuntimeValidation(format!(
                        "node `{}` received duplicate data edge validation view `{validation_key}`",
                        node_plan.node_id
                    )));
                }
            }
        }
    }
    // Pass 3: OOF prediction edges, keyed "{source node}.{source port}".
    for edge in training_oof_edges {
        let key = format!("{}.{}", edge.source.node_id, edge.source.port_name);
        let input = collect_oof_prediction_input(plan, edge, ctx, scope, resources)?;
        if inputs.insert(key.clone(), input.handle).is_some() {
            return Err(DagMlError::RuntimeValidation(format!(
                "node `{}` received duplicate OOF prediction input `{key}`",
                node_plan.node_id
            )));
        }
        if prediction_inputs.insert(key.clone(), input.spec).is_some() {
            return Err(DagMlError::RuntimeValidation(format!(
                "node `{}` received duplicate OOF prediction spec `{key}`",
                node_plan.node_id
            )));
        }
    }
    // Data bindings cannot be served without a registered data provider.
    if !node_plan.data_bindings.is_empty() && resources.data_provider.is_none() {
        return Err(DagMlError::RuntimeValidation(format!(
            "node `{}` requires {} data binding(s) but no runtime data provider is registered",
            node_plan.node_id,
            node_plan.data_bindings.len()
        )));
    }
    // Pass 4: materialize each data binding and register its view(s) and handle(s).
    if let Some(data_provider) = resources.data_provider {
        for binding in &node_plan.data_bindings {
            let materialized = data_provider.materialize(&DataMaterializationRequest {
                run_id: ctx.run_id.clone(),
                node_id: node_plan.node_id.clone(),
                input_name: binding.input_name.clone(),
                phase: scope.phase,
                variant_id: scope.variant_id.clone(),
                fold_id: scope.fold_id.clone(),
                binding: binding.clone(),
            })?;
            let branch_view_for_node = branch_view_from_node_metadata(plan, &node_plan.node_id)?;
            let view = data_view_for_scope(
                binding,
                plan.fold_set.as_ref(),
                scope,
                branch_view_for_node.as_ref(),
            )?;
            let key = data_view_key(&binding.input_name);
            let view_handle = make_data_view_handle(
                data_provider,
                ctx,
                node_plan,
                scope,
                binding,
                &materialized,
                &view,
            )?;
            if data_views.insert(key.clone(), view).is_some() {
                return Err(DagMlError::RuntimeValidation(format!(
                    "node `{}` received duplicate data view `{key}`",
                    node_plan.node_id
                )));
            }
            if inputs.insert(key.clone(), view_handle).is_some() {
                return Err(DagMlError::RuntimeValidation(format!(
                    "node `{}` received duplicate data input `{key}`",
                    node_plan.node_id
                )));
            }

            // An optional validation view is registered under "{key}:validation".
            if let Some(validation_view) = validation_data_view_for_scope(
                binding,
                plan.fold_set.as_ref(),
                scope,
                branch_view_for_node.as_ref(),
            )? {
                let validation_key = format!("{key}:validation");
                let validation_handle = make_data_view_handle(
                    data_provider,
                    ctx,
                    node_plan,
                    scope,
                    binding,
                    &materialized,
                    &validation_view,
                )?;
                if data_views
                    .insert(validation_key.clone(), validation_view)
                    .is_some()
                {
                    return Err(DagMlError::RuntimeValidation(format!(
                        "node `{}` received duplicate validation data view `{validation_key}`",
                        node_plan.node_id
                    )));
                }
                if inputs
                    .insert(validation_key.clone(), validation_handle)
                    .is_some()
                {
                    return Err(DagMlError::RuntimeValidation(format!(
                        "node `{}` received duplicate validation data input `{validation_key}`",
                        node_plan.node_id
                    )));
                }
            }
        }
    }
    Ok(CollectedInputs {
        handles: inputs,
        data_views,
        prediction_inputs,
    })
}
5247
5248fn effective_node_plan_for_scope(node_plan: &NodePlan, scope: &PhaseScope) -> Result<NodePlan> {
5249 let Some(variant) = &scope.variant else {
5250 return Ok(node_plan.clone());
5251 };
5252 let params = variant.effective_params_for_node(&node_plan.node_id, &node_plan.params)?;
5253 if params == node_plan.params {
5254 return Ok(node_plan.clone());
5255 }
5256 let mut node_plan = node_plan.clone();
5257 node_plan.params = params;
5258 node_plan.params_fingerprint = stable_json_fingerprint(&node_plan.params)?;
5259 Ok(node_plan)
5260}
5261
5262fn incoming_training_oof_edges<'a>(
5263 plan: &'a ExecutionPlan,
5264 node_plan: &NodePlan,
5265 scope: &PhaseScope,
5266) -> Result<Vec<&'a EdgeSpec>> {
5267 if !scope.phase.is_training() {
5268 return Ok(Vec::new());
5269 }
5270 plan.graph_plan
5271 .graph
5272 .edges
5273 .iter()
5274 .filter(|edge| edge.target.node_id == node_plan.node_id && edge.contract.requires_oof)
5275 .map(|edge| {
5276 if edge.contract.kind != PortKind::Prediction {
5277 return Err(DagMlError::RuntimeValidation(format!(
5278 "edge `{}.{}` -> `{}.{}` requires OOF but is not a prediction edge",
5279 edge.source.node_id,
5280 edge.source.port_name,
5281 edge.target.node_id,
5282 edge.target.port_name
5283 )));
5284 }
5285 Ok(edge)
5286 })
5287 .collect()
5288}
5289
/// Result of collecting one OOF prediction edge: the handle to pass to the
/// consumer plus the spec describing that input.
struct CollectedPredictionInput {
    /// Prediction handle produced by `materialize_oof_prediction_handle`.
    handle: HandleRef,
    /// Input description (producer, ports, level, folds, units, width, names).
    spec: PredictionInputSpec,
}
5294
5295fn collect_oof_prediction_input(
5296 plan: &ExecutionPlan,
5297 edge: &EdgeSpec,
5298 ctx: &RunContext,
5299 scope: &PhaseScope,
5300 resources: &PhaseScopeResources<'_>,
5301) -> Result<CollectedPredictionInput> {
5302 if scope.phase == Phase::Refit {
5303 if let Some(contract) = replay_prediction_cache_contract_for_edge(resources, edge) {
5304 if contract.requirement.prediction_level != PredictionLevel::Sample {
5305 let source_plan = plan
5306 .node_plans
5307 .get(&edge.source.node_id)
5308 .expect("edge source has a node plan");
5309 let handle = materialize_oof_prediction_handle(
5310 plan,
5311 edge,
5312 ctx,
5313 scope,
5314 resources,
5315 &source_plan.controller_id,
5316 )?;
5317 return Ok(CollectedPredictionInput {
5318 handle,
5319 spec: prediction_input_spec_from_requirement(&contract.requirement, scope)?,
5320 });
5321 }
5322 }
5323 }
5324 let source_plan = plan
5325 .node_plans
5326 .get(&edge.source.node_id)
5327 .expect("edge source has a node plan");
5328 let prediction_level = oof_prediction_level_for_source(source_plan);
5329 if prediction_level != PredictionLevel::Sample {
5330 let blocks = match scope.phase {
5331 Phase::FitCv => validate_fit_cv_aggregated_oof_edge(
5332 plan,
5333 edge,
5334 ctx,
5335 scope,
5336 resources,
5337 prediction_level,
5338 )?,
5339 Phase::Refit => {
5340 validate_refit_aggregated_oof_edge(plan, edge, ctx, resources, prediction_level)?
5341 }
5342 _ => Vec::new(),
5343 };
5344 let handle = materialize_oof_prediction_handle(
5345 plan,
5346 edge,
5347 ctx,
5348 scope,
5349 resources,
5350 &source_plan.controller_id,
5351 )?;
5352 return Ok(CollectedPredictionInput {
5353 handle,
5354 spec: aggregated_prediction_input_spec(edge, scope, prediction_level, &blocks)?,
5355 });
5356 }
5357 let blocks = match scope.phase {
5358 Phase::FitCv => validate_fit_cv_oof_edge(plan, edge, ctx, scope)?,
5359 Phase::Refit => validate_refit_oof_edge(plan, edge, ctx)?,
5360 _ => Vec::new(),
5361 };
5362 let handle = materialize_oof_prediction_handle(
5363 plan,
5364 edge,
5365 ctx,
5366 scope,
5367 resources,
5368 &source_plan.controller_id,
5369 )?;
5370 Ok(CollectedPredictionInput {
5371 handle,
5372 spec: prediction_input_spec(edge, scope, &blocks)?,
5373 })
5374}
5375
5376fn oof_prediction_level_for_source(source_plan: &NodePlan) -> PredictionLevel {
5377 source_plan
5378 .shape_plan
5379 .as_ref()
5380 .map(|shape_plan| shape_plan.aggregation_policy.aggregation_level)
5381 .unwrap_or(PredictionLevel::Sample)
5382}
5383
5384fn replay_prediction_cache_contract_for_edge<'a>(
5385 resources: &'a PhaseScopeResources<'_>,
5386 edge: &EdgeSpec,
5387) -> Option<&'a ReplayPredictionCacheContract> {
5388 let contracts = resources.prediction_cache_contracts?;
5389 let key = bundle_prediction_requirement_key(
5390 &edge.source.node_id,
5391 &edge.source.port_name,
5392 &edge.target.node_id,
5393 &edge.target.port_name,
5394 );
5395 contracts.get(&key)
5396}
5397
5398fn materialize_oof_prediction_handle(
5399 plan: &ExecutionPlan,
5400 edge: &EdgeSpec,
5401 ctx: &RunContext,
5402 scope: &PhaseScope,
5403 resources: &PhaseScopeResources<'_>,
5404 producer_controller_id: &ControllerId,
5405) -> Result<HandleRef> {
5406 if scope.phase == Phase::Refit {
5407 if let (Some(store), Some(bundle_id), Some(contracts)) = (
5408 resources.prediction_cache_store,
5409 resources.replay_bundle_id,
5410 resources.prediction_cache_contracts,
5411 ) {
5412 let key = bundle_prediction_requirement_key(
5413 &edge.source.node_id,
5414 &edge.source.port_name,
5415 &edge.target.node_id,
5416 &edge.target.port_name,
5417 );
5418 let contract = contracts.get(&key).ok_or_else(|| {
5419 DagMlError::RuntimeValidation(format!(
5420 "replay prediction cache store cannot materialize missing requirement `{key}`"
5421 ))
5422 })?;
5423 let handle = store.materialize(&PredictionCacheMaterializationRequest {
5424 run_id: ctx.run_id.clone(),
5425 bundle_id: bundle_id.clone(),
5426 phase: scope.phase,
5427 variant_id: scope.variant_id.clone(),
5428 requirement: contract.requirement.clone(),
5429 cache: contract.cache.clone(),
5430 producer_controller_id: producer_controller_id.clone(),
5431 })?;
5432 if handle.kind != HandleKind::Prediction {
5433 return Err(DagMlError::RuntimeValidation(format!(
5434 "prediction cache store materialized requirement `{key}` as {:?}",
5435 handle.kind
5436 )));
5437 }
5438 if &handle.owner_controller != producer_controller_id {
5439 return Err(DagMlError::RuntimeValidation(format!(
5440 "prediction cache store materialized requirement `{key}` for controller `{}`, expected `{}`",
5441 handle.owner_controller, producer_controller_id
5442 )));
5443 }
5444 return Ok(handle);
5445 }
5446 }
5447 Ok(HandleRef {
5448 handle: deterministic_oof_handle(plan, edge, ctx, scope)?,
5449 kind: HandleKind::Prediction,
5450 owner_controller: producer_controller_id.clone(),
5451 })
5452}
5453
5454fn validate_fit_cv_oof_edge<'a>(
5455 plan: &ExecutionPlan,
5456 edge: &EdgeSpec,
5457 ctx: &'a RunContext,
5458 scope: &PhaseScope,
5459) -> Result<Vec<&'a PredictionBlock>> {
5460 let fold_id = scope.fold_id.as_ref().ok_or_else(|| {
5461 DagMlError::RuntimeValidation(format!(
5462 "edge `{}.{}` -> `{}.{}` requires OOF predictions but FIT_CV has no fold scope",
5463 edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
5464 ))
5465 })?;
5466 let blocks = ctx.prediction_store.find(
5467 Some(&edge.source.node_id),
5468 Some(&PredictionPartition::Validation),
5469 Some(fold_id),
5470 );
5471 if blocks.is_empty() {
5472 return Err(missing_oof_edge_error(edge, Some(fold_id)));
5473 }
5474 if edge.contract.requires_fold_alignment {
5475 let fold_set = required_fold_set_for_oof(plan, edge)?;
5476 validate_oof_blocks_match_fold(edge, fold_set, fold_id, &blocks)?;
5477 }
5478 Ok(blocks)
5479}
5480
5481fn validate_refit_oof_edge<'a>(
5482 plan: &ExecutionPlan,
5483 edge: &EdgeSpec,
5484 ctx: &'a RunContext,
5485) -> Result<Vec<&'a PredictionBlock>> {
5486 let blocks = ctx.prediction_store.find(
5487 Some(&edge.source.node_id),
5488 Some(&PredictionPartition::Validation),
5489 None,
5490 );
5491 if blocks.is_empty() {
5492 return Err(missing_oof_edge_error(edge, None));
5493 }
5494 if edge.contract.requires_fold_alignment {
5495 let fold_set = required_fold_set_for_oof(plan, edge)?;
5496 validate_oof_blocks_cover_fold_set(edge, fold_set, &blocks)?;
5497 }
5498 Ok(blocks)
5499}
5500
5501fn validate_fit_cv_aggregated_oof_edge<'a>(
5502 plan: &ExecutionPlan,
5503 edge: &EdgeSpec,
5504 ctx: &'a RunContext,
5505 scope: &PhaseScope,
5506 resources: &PhaseScopeResources<'_>,
5507 prediction_level: PredictionLevel,
5508) -> Result<Vec<&'a AggregatedPredictionBlock>> {
5509 let fold_id = scope.fold_id.as_ref().ok_or_else(|| {
5510 DagMlError::RuntimeValidation(format!(
5511 "edge `{}.{}` -> `{}.{}` requires aggregated OOF predictions but FIT_CV has no fold scope",
5512 edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
5513 ))
5514 })?;
5515 let blocks = ctx.aggregated_prediction_store.find(
5516 Some(&edge.source.node_id),
5517 Some(&PredictionPartition::Validation),
5518 Some(fold_id),
5519 Some(prediction_level),
5520 );
5521 if blocks.is_empty() {
5522 return Err(missing_oof_edge_error(edge, Some(fold_id)));
5523 }
5524 validate_aggregated_blocks_basic(edge, prediction_level, &blocks)?;
5525 if edge.contract.requires_fold_alignment {
5526 let fold_set = required_fold_set_for_oof(plan, edge)?;
5527 let relations = coordinator_relations_for_edge(plan, edge, resources)?;
5528 validate_aggregated_oof_blocks_match_fold(
5529 edge,
5530 fold_set,
5531 &relations,
5532 prediction_level,
5533 fold_id,
5534 &blocks,
5535 )?;
5536 }
5537 Ok(blocks)
5538}
5539
5540fn validate_refit_aggregated_oof_edge<'a>(
5541 plan: &ExecutionPlan,
5542 edge: &EdgeSpec,
5543 ctx: &'a RunContext,
5544 resources: &PhaseScopeResources<'_>,
5545 prediction_level: PredictionLevel,
5546) -> Result<Vec<&'a AggregatedPredictionBlock>> {
5547 let blocks = ctx.aggregated_prediction_store.find(
5548 Some(&edge.source.node_id),
5549 Some(&PredictionPartition::Validation),
5550 None,
5551 Some(prediction_level),
5552 );
5553 if blocks.is_empty() {
5554 return Err(missing_oof_edge_error(edge, None));
5555 }
5556 validate_aggregated_blocks_basic(edge, prediction_level, &blocks)?;
5557 if edge.contract.requires_fold_alignment {
5558 let fold_set = required_fold_set_for_oof(plan, edge)?;
5559 let relations = coordinator_relations_for_edge(plan, edge, resources)?;
5560 validate_aggregated_oof_blocks_cover_fold_set(
5561 edge,
5562 fold_set,
5563 &relations,
5564 prediction_level,
5565 &blocks,
5566 )?;
5567 }
5568 Ok(blocks)
5569}
5570
5571fn validate_aggregated_blocks_basic(
5572 edge: &EdgeSpec,
5573 prediction_level: PredictionLevel,
5574 blocks: &[&AggregatedPredictionBlock],
5575) -> Result<()> {
5576 for block in blocks {
5577 block.validate_shape()?;
5578 if block.partition != PredictionPartition::Validation {
5579 return Err(DagMlError::RuntimeValidation(format!(
5580 "edge `{}.{}` -> `{}.{}` selected non-validation aggregated predictions",
5581 edge.source.node_id,
5582 edge.source.port_name,
5583 edge.target.node_id,
5584 edge.target.port_name
5585 )));
5586 }
5587 if block.level != prediction_level {
5588 return Err(DagMlError::RuntimeValidation(format!(
5589 "edge `{}.{}` -> `{}.{}` selected {:?} aggregated predictions, expected {:?}",
5590 edge.source.node_id,
5591 edge.source.port_name,
5592 edge.target.node_id,
5593 edge.target.port_name,
5594 block.level,
5595 prediction_level
5596 )));
5597 }
5598 }
5599 Ok(())
5600}
5601
5602fn prediction_input_spec(
5603 edge: &EdgeSpec,
5604 scope: &PhaseScope,
5605 blocks: &[&PredictionBlock],
5606) -> Result<PredictionInputSpec> {
5607 let sample_ids = collect_unique_oof_samples(edge, blocks)?
5608 .into_iter()
5609 .collect::<Vec<_>>();
5610 let fold_ids = blocks
5611 .iter()
5612 .filter_map(|block| block.fold_id.clone())
5613 .collect::<BTreeSet<_>>()
5614 .into_iter()
5615 .collect::<Vec<_>>();
5616 let mut prediction_width = None;
5617 let mut target_names = None;
5618 for block in blocks {
5619 let width = block.validate_shape()?;
5620 let block_target_names = if block.target_names.is_empty() {
5621 (0..width)
5622 .map(|index| format!("p{index}"))
5623 .collect::<Vec<_>>()
5624 } else {
5625 block.target_names.clone()
5626 };
5627 if prediction_width.is_some_and(|expected| expected != width) {
5628 return Err(DagMlError::RuntimeValidation(format!(
5629 "edge `{}.{}` -> `{}.{}` OOF prediction width is not stable across folds",
5630 edge.source.node_id,
5631 edge.source.port_name,
5632 edge.target.node_id,
5633 edge.target.port_name
5634 )));
5635 }
5636 if target_names
5637 .as_ref()
5638 .is_some_and(|expected| expected != &block_target_names)
5639 {
5640 return Err(DagMlError::RuntimeValidation(format!(
5641 "edge `{}.{}` -> `{}.{}` OOF target names are not stable across folds",
5642 edge.source.node_id,
5643 edge.source.port_name,
5644 edge.target.node_id,
5645 edge.target.port_name
5646 )));
5647 }
5648 prediction_width = Some(width);
5649 target_names = Some(block_target_names);
5650 }
5651 Ok(PredictionInputSpec {
5652 producer_node: edge.source.node_id.clone(),
5653 source_port: edge.source.port_name.clone(),
5654 target_port: edge.target.port_name.clone(),
5655 partition: PredictionPartition::Validation,
5656 prediction_level: PredictionLevel::Sample,
5657 fold_id: scope.fold_id.clone(),
5658 fold_ids,
5659 unit_ids: sample_ids
5660 .iter()
5661 .cloned()
5662 .map(PredictionUnitId::Sample)
5663 .collect(),
5664 sample_ids,
5665 prediction_width: prediction_width.unwrap_or_default(),
5666 target_names: target_names.unwrap_or_default(),
5667 })
5668}
5669
5670fn aggregated_prediction_input_spec(
5671 edge: &EdgeSpec,
5672 scope: &PhaseScope,
5673 prediction_level: PredictionLevel,
5674 blocks: &[&AggregatedPredictionBlock],
5675) -> Result<PredictionInputSpec> {
5676 let unit_ids = collect_unique_aggregated_oof_units(edge, prediction_level, blocks)?
5677 .into_iter()
5678 .collect::<Vec<_>>();
5679 let fold_ids = blocks
5680 .iter()
5681 .filter_map(|block| block.fold_id.clone())
5682 .collect::<BTreeSet<_>>()
5683 .into_iter()
5684 .collect::<Vec<_>>();
5685 let mut prediction_width = None;
5686 let mut target_names = None;
5687 for block in blocks {
5688 let width = block.validate_shape()?;
5689 let block_target_names = if block.target_names.is_empty() {
5690 (0..width)
5691 .map(|index| format!("p{index}"))
5692 .collect::<Vec<_>>()
5693 } else {
5694 block.target_names.clone()
5695 };
5696 if prediction_width.is_some_and(|expected| expected != width) {
5697 return Err(DagMlError::RuntimeValidation(format!(
5698 "edge `{}.{}` -> `{}.{}` aggregated OOF prediction width is not stable across folds",
5699 edge.source.node_id,
5700 edge.source.port_name,
5701 edge.target.node_id,
5702 edge.target.port_name
5703 )));
5704 }
5705 if target_names
5706 .as_ref()
5707 .is_some_and(|expected| expected != &block_target_names)
5708 {
5709 return Err(DagMlError::RuntimeValidation(format!(
5710 "edge `{}.{}` -> `{}.{}` aggregated OOF target names are not stable across folds",
5711 edge.source.node_id,
5712 edge.source.port_name,
5713 edge.target.node_id,
5714 edge.target.port_name
5715 )));
5716 }
5717 prediction_width = Some(width);
5718 target_names = Some(block_target_names);
5719 }
5720 Ok(PredictionInputSpec {
5721 producer_node: edge.source.node_id.clone(),
5722 source_port: edge.source.port_name.clone(),
5723 target_port: edge.target.port_name.clone(),
5724 partition: PredictionPartition::Validation,
5725 prediction_level,
5726 fold_id: scope.fold_id.clone(),
5727 fold_ids,
5728 unit_ids,
5729 sample_ids: Vec::new(),
5730 prediction_width: prediction_width.unwrap_or_default(),
5731 target_names: target_names.unwrap_or_default(),
5732 })
5733}
5734
5735fn prediction_input_spec_from_requirement(
5736 requirement: &BundlePredictionRequirement,
5737 scope: &PhaseScope,
5738) -> Result<PredictionInputSpec> {
5739 requirement.validate()?;
5740 Ok(PredictionInputSpec {
5741 producer_node: requirement.producer_node.clone(),
5742 source_port: requirement.source_port.clone(),
5743 target_port: requirement.target_port.clone(),
5744 partition: requirement.partition.clone(),
5745 prediction_level: requirement.prediction_level,
5746 fold_id: scope.fold_id.clone(),
5747 fold_ids: requirement.fold_ids.clone(),
5748 unit_ids: requirement.unit_ids.clone(),
5749 sample_ids: requirement.sample_ids.clone(),
5750 prediction_width: requirement.prediction_width,
5751 target_names: requirement.target_names.clone(),
5752 })
5753}
5754
5755fn missing_oof_edge_error(edge: &EdgeSpec, fold_id: Option<&FoldId>) -> DagMlError {
5756 DagMlError::RuntimeValidation(format!(
5757 "edge `{}.{}` -> `{}.{}` requires OOF validation predictions from `{}`{}",
5758 edge.source.node_id,
5759 edge.source.port_name,
5760 edge.target.node_id,
5761 edge.target.port_name,
5762 edge.source.node_id,
5763 fold_id
5764 .map(|fold_id| format!(" for fold `{fold_id}`"))
5765 .unwrap_or_default()
5766 ))
5767}
5768
5769fn required_fold_set_for_oof<'a>(plan: &'a ExecutionPlan, edge: &EdgeSpec) -> Result<&'a FoldSet> {
5770 plan.fold_set.as_ref().ok_or_else(|| {
5771 DagMlError::RuntimeValidation(format!(
5772 "edge `{}.{}` -> `{}.{}` requires fold-aligned OOF predictions but the plan has no fold set",
5773 edge.source.node_id,
5774 edge.source.port_name,
5775 edge.target.node_id,
5776 edge.target.port_name
5777 ))
5778 })
5779}
5780
5781fn validate_oof_blocks_match_fold(
5782 edge: &EdgeSpec,
5783 fold_set: &FoldSet,
5784 fold_id: &FoldId,
5785 blocks: &[&PredictionBlock],
5786) -> Result<()> {
5787 let fold = fold_set
5788 .folds
5789 .iter()
5790 .find(|fold| &fold.fold_id == fold_id)
5791 .ok_or_else(|| {
5792 DagMlError::RuntimeValidation(format!(
5793 "edge `{}.{}` -> `{}.{}` references unknown fold `{fold_id}`",
5794 edge.source.node_id,
5795 edge.source.port_name,
5796 edge.target.node_id,
5797 edge.target.port_name
5798 ))
5799 })?;
5800 let actual = collect_unique_oof_samples(edge, blocks)?;
5801 let expected = fold
5802 .validation_sample_ids
5803 .iter()
5804 .cloned()
5805 .collect::<BTreeSet<_>>();
5806 if actual != expected {
5807 return Err(DagMlError::RuntimeValidation(format!(
5808 "edge `{}.{}` -> `{}.{}` OOF predictions do not match validation samples for fold `{fold_id}`",
5809 edge.source.node_id,
5810 edge.source.port_name,
5811 edge.target.node_id,
5812 edge.target.port_name
5813 )));
5814 }
5815 Ok(())
5816}
5817
5818fn validate_oof_blocks_cover_fold_set(
5819 edge: &EdgeSpec,
5820 fold_set: &FoldSet,
5821 blocks: &[&PredictionBlock],
5822) -> Result<()> {
5823 let folds = fold_set
5824 .folds
5825 .iter()
5826 .map(|fold| (&fold.fold_id, fold))
5827 .collect::<BTreeMap<_, _>>();
5828 let mut all_samples = BTreeSet::new();
5829 for block in blocks {
5830 let fold_id = block.fold_id.as_ref().ok_or_else(|| {
5831 DagMlError::RuntimeValidation(format!(
5832 "edge `{}.{}` -> `{}.{}` has OOF predictions without a fold id",
5833 edge.source.node_id,
5834 edge.source.port_name,
5835 edge.target.node_id,
5836 edge.target.port_name
5837 ))
5838 })?;
5839 let fold = folds.get(fold_id).ok_or_else(|| {
5840 DagMlError::RuntimeValidation(format!(
5841 "edge `{}.{}` -> `{}.{}` references unknown fold `{fold_id}`",
5842 edge.source.node_id,
5843 edge.source.port_name,
5844 edge.target.node_id,
5845 edge.target.port_name
5846 ))
5847 })?;
5848 let block_samples = collect_unique_oof_samples(edge, &[*block])?;
5849 let expected = fold
5850 .validation_sample_ids
5851 .iter()
5852 .cloned()
5853 .collect::<BTreeSet<_>>();
5854 if block_samples != expected {
5855 return Err(DagMlError::RuntimeValidation(format!(
5856 "edge `{}.{}` -> `{}.{}` OOF predictions do not match validation samples for fold `{fold_id}`",
5857 edge.source.node_id,
5858 edge.source.port_name,
5859 edge.target.node_id,
5860 edge.target.port_name
5861 )));
5862 }
5863 for sample_id in block_samples {
5864 if !all_samples.insert(sample_id.clone()) {
5865 return Err(DagMlError::RuntimeValidation(format!(
5866 "edge `{}.{}` -> `{}.{}` has duplicate OOF prediction for sample `{sample_id}`",
5867 edge.source.node_id,
5868 edge.source.port_name,
5869 edge.target.node_id,
5870 edge.target.port_name
5871 )));
5872 }
5873 }
5874 }
5875 let expected_all = fold_set.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
5876 if all_samples != expected_all {
5877 return Err(DagMlError::RuntimeValidation(format!(
5878 "edge `{}.{}` -> `{}.{}` OOF predictions do not cover the refit sample universe",
5879 edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
5880 )));
5881 }
5882 Ok(())
5883}
5884
5885fn validate_aggregated_oof_blocks_match_fold(
5886 edge: &EdgeSpec,
5887 fold_set: &FoldSet,
5888 relations: &SampleRelationSet,
5889 prediction_level: PredictionLevel,
5890 fold_id: &FoldId,
5891 blocks: &[&AggregatedPredictionBlock],
5892) -> Result<()> {
5893 let fold = fold_set
5894 .folds
5895 .iter()
5896 .find(|fold| &fold.fold_id == fold_id)
5897 .ok_or_else(|| {
5898 DagMlError::RuntimeValidation(format!(
5899 "edge `{}.{}` -> `{}.{}` references unknown fold `{fold_id}`",
5900 edge.source.node_id,
5901 edge.source.port_name,
5902 edge.target.node_id,
5903 edge.target.port_name
5904 ))
5905 })?;
5906 validate_aggregated_fold_unit_safety(edge, relations, prediction_level, fold)?;
5907 for block in blocks {
5908 if block.fold_id.as_ref() != Some(fold_id) {
5909 return Err(DagMlError::RuntimeValidation(format!(
5910 "edge `{}.{}` -> `{}.{}` selected aggregated OOF predictions outside fold `{fold_id}`",
5911 edge.source.node_id,
5912 edge.source.port_name,
5913 edge.target.node_id,
5914 edge.target.port_name
5915 )));
5916 }
5917 }
5918 let actual = collect_unique_aggregated_oof_units(edge, prediction_level, blocks)?;
5919 let expected = expected_prediction_units_for_samples(
5920 edge,
5921 relations,
5922 prediction_level,
5923 &fold.validation_sample_ids,
5924 )?;
5925 if actual != expected {
5926 return Err(DagMlError::RuntimeValidation(format!(
5927 "edge `{}.{}` -> `{}.{}` aggregated OOF predictions do not match {:?} validation units for fold `{fold_id}`",
5928 edge.source.node_id,
5929 edge.source.port_name,
5930 edge.target.node_id,
5931 edge.target.port_name,
5932 prediction_level
5933 )));
5934 }
5935 Ok(())
5936}
5937
5938fn validate_aggregated_oof_blocks_cover_fold_set(
5939 edge: &EdgeSpec,
5940 fold_set: &FoldSet,
5941 relations: &SampleRelationSet,
5942 prediction_level: PredictionLevel,
5943 blocks: &[&AggregatedPredictionBlock],
5944) -> Result<()> {
5945 let folds = fold_set
5946 .folds
5947 .iter()
5948 .map(|fold| (fold.fold_id.clone(), fold))
5949 .collect::<BTreeMap<_, _>>();
5950 let mut blocks_by_fold = BTreeMap::<FoldId, Vec<&AggregatedPredictionBlock>>::new();
5951 for block in blocks {
5952 let fold_id = block.fold_id.as_ref().ok_or_else(|| {
5953 DagMlError::RuntimeValidation(format!(
5954 "edge `{}.{}` -> `{}.{}` has aggregated OOF predictions without a fold id",
5955 edge.source.node_id,
5956 edge.source.port_name,
5957 edge.target.node_id,
5958 edge.target.port_name
5959 ))
5960 })?;
5961 if !folds.contains_key(fold_id) {
5962 return Err(DagMlError::RuntimeValidation(format!(
5963 "edge `{}.{}` -> `{}.{}` references unknown fold `{fold_id}`",
5964 edge.source.node_id,
5965 edge.source.port_name,
5966 edge.target.node_id,
5967 edge.target.port_name
5968 )));
5969 }
5970 blocks_by_fold
5971 .entry(fold_id.clone())
5972 .or_default()
5973 .push(*block);
5974 }
5975 for fold_id in folds.keys() {
5976 if !blocks_by_fold.contains_key(fold_id) {
5977 return Err(DagMlError::RuntimeValidation(format!(
5978 "edge `{}.{}` -> `{}.{}` is missing aggregated OOF predictions for fold `{fold_id}`",
5979 edge.source.node_id,
5980 edge.source.port_name,
5981 edge.target.node_id,
5982 edge.target.port_name
5983 )));
5984 }
5985 }
5986
5987 let mut all_units = BTreeSet::new();
5988 for (fold_id, fold_blocks) in blocks_by_fold {
5989 let fold = folds.get(&fold_id).expect("fold id was validated above");
5990 validate_aggregated_fold_unit_safety(edge, relations, prediction_level, fold)?;
5991 let fold_units = collect_unique_aggregated_oof_units(edge, prediction_level, &fold_blocks)?;
5992 let expected = expected_prediction_units_for_samples(
5993 edge,
5994 relations,
5995 prediction_level,
5996 &fold.validation_sample_ids,
5997 )?;
5998 if fold_units != expected {
5999 return Err(DagMlError::RuntimeValidation(format!(
6000 "edge `{}.{}` -> `{}.{}` aggregated OOF predictions do not match {:?} validation units for fold `{fold_id}`",
6001 edge.source.node_id,
6002 edge.source.port_name,
6003 edge.target.node_id,
6004 edge.target.port_name,
6005 prediction_level
6006 )));
6007 }
6008 for unit_id in fold_units {
6009 if !all_units.insert(unit_id.clone()) {
6010 return Err(DagMlError::RuntimeValidation(format!(
6011 "edge `{}.{}` -> `{}.{}` has duplicate aggregated OOF prediction for unit `{unit_id}`",
6012 edge.source.node_id,
6013 edge.source.port_name,
6014 edge.target.node_id,
6015 edge.target.port_name
6016 )));
6017 }
6018 }
6019 }
6020
6021 let expected_all = expected_prediction_units_for_samples(
6022 edge,
6023 relations,
6024 prediction_level,
6025 &fold_set.sample_ids,
6026 )?;
6027 if all_units != expected_all {
6028 return Err(DagMlError::RuntimeValidation(format!(
6029 "edge `{}.{}` -> `{}.{}` aggregated OOF predictions do not cover the refit {:?} unit universe",
6030 edge.source.node_id,
6031 edge.source.port_name,
6032 edge.target.node_id,
6033 edge.target.port_name,
6034 prediction_level
6035 )));
6036 }
6037 Ok(())
6038}
6039
6040fn validate_aggregated_fold_unit_safety(
6041 edge: &EdgeSpec,
6042 relations: &SampleRelationSet,
6043 prediction_level: PredictionLevel,
6044 fold: &FoldAssignment,
6045) -> Result<()> {
6046 let train_units = expected_prediction_units_for_samples(
6047 edge,
6048 relations,
6049 prediction_level,
6050 &fold.train_sample_ids,
6051 )?;
6052 let validation_units = expected_prediction_units_for_samples(
6053 edge,
6054 relations,
6055 prediction_level,
6056 &fold.validation_sample_ids,
6057 )?;
6058 if let Some(unit_id) = train_units.intersection(&validation_units).next() {
6059 return Err(DagMlError::RuntimeValidation(format!(
6060 "edge `{}.{}` -> `{}.{}` fold `{}` has {:?} unit `{unit_id}` in both train and validation partitions",
6061 edge.source.node_id,
6062 edge.source.port_name,
6063 edge.target.node_id,
6064 edge.target.port_name,
6065 fold.fold_id,
6066 prediction_level
6067 )));
6068 }
6069 Ok(())
6070}
6071
6072fn collect_unique_oof_samples(
6073 edge: &EdgeSpec,
6074 blocks: &[&PredictionBlock],
6075) -> Result<BTreeSet<SampleId>> {
6076 let mut samples = BTreeSet::new();
6077 for block in blocks {
6078 if block.partition != PredictionPartition::Validation {
6079 return Err(DagMlError::RuntimeValidation(format!(
6080 "edge `{}.{}` -> `{}.{}` selected non-validation predictions",
6081 edge.source.node_id,
6082 edge.source.port_name,
6083 edge.target.node_id,
6084 edge.target.port_name
6085 )));
6086 }
6087 for sample_id in &block.sample_ids {
6088 if !samples.insert(sample_id.clone()) {
6089 return Err(DagMlError::RuntimeValidation(format!(
6090 "edge `{}.{}` -> `{}.{}` has duplicate OOF prediction for sample `{sample_id}`",
6091 edge.source.node_id,
6092 edge.source.port_name,
6093 edge.target.node_id,
6094 edge.target.port_name
6095 )));
6096 }
6097 }
6098 }
6099 Ok(samples)
6100}
6101
6102fn collect_unique_aggregated_oof_units(
6103 edge: &EdgeSpec,
6104 prediction_level: PredictionLevel,
6105 blocks: &[&AggregatedPredictionBlock],
6106) -> Result<BTreeSet<PredictionUnitId>> {
6107 let mut unit_ids = BTreeSet::new();
6108 for block in blocks {
6109 block.validate_shape()?;
6110 if block.partition != PredictionPartition::Validation {
6111 return Err(DagMlError::RuntimeValidation(format!(
6112 "edge `{}.{}` -> `{}.{}` selected non-validation aggregated predictions",
6113 edge.source.node_id,
6114 edge.source.port_name,
6115 edge.target.node_id,
6116 edge.target.port_name
6117 )));
6118 }
6119 if block.level != prediction_level {
6120 return Err(DagMlError::RuntimeValidation(format!(
6121 "edge `{}.{}` -> `{}.{}` selected {:?} aggregated predictions, expected {:?}",
6122 edge.source.node_id,
6123 edge.source.port_name,
6124 edge.target.node_id,
6125 edge.target.port_name,
6126 block.level,
6127 prediction_level
6128 )));
6129 }
6130 for unit_id in &block.unit_ids {
6131 if !unit_ids.insert(unit_id.clone()) {
6132 return Err(DagMlError::RuntimeValidation(format!(
6133 "edge `{}.{}` -> `{}.{}` has duplicate aggregated OOF prediction for unit `{unit_id}`",
6134 edge.source.node_id,
6135 edge.source.port_name,
6136 edge.target.node_id,
6137 edge.target.port_name
6138 )));
6139 }
6140 }
6141 }
6142 Ok(unit_ids)
6143}
6144
6145fn expected_prediction_units_for_samples(
6146 edge: &EdgeSpec,
6147 relations: &SampleRelationSet,
6148 prediction_level: PredictionLevel,
6149 sample_ids: &[SampleId],
6150) -> Result<BTreeSet<PredictionUnitId>> {
6151 sample_ids
6152 .iter()
6153 .map(|sample_id| prediction_unit_for_sample(edge, relations, prediction_level, sample_id))
6154 .collect()
6155}
6156
6157fn prediction_unit_for_sample(
6158 edge: &EdgeSpec,
6159 relations: &SampleRelationSet,
6160 prediction_level: PredictionLevel,
6161 sample_id: &SampleId,
6162) -> Result<PredictionUnitId> {
6163 match prediction_level {
6164 PredictionLevel::Sample => Ok(PredictionUnitId::Sample(sample_id.clone())),
6165 PredictionLevel::Target => relations
6166 .target_for_sample(sample_id)
6167 .cloned()
6168 .map(PredictionUnitId::Target)
6169 .ok_or_else(|| {
6170 DagMlError::RuntimeValidation(format!(
6171 "edge `{}.{}` -> `{}.{}` needs target-level OOF predictions but sample `{sample_id}` has no target relation",
6172 edge.source.node_id,
6173 edge.source.port_name,
6174 edge.target.node_id,
6175 edge.target.port_name
6176 ))
6177 }),
6178 PredictionLevel::Group => relations
6179 .group_for_sample(sample_id)
6180 .cloned()
6181 .map(PredictionUnitId::Group)
6182 .ok_or_else(|| {
6183 DagMlError::RuntimeValidation(format!(
6184 "edge `{}.{}` -> `{}.{}` needs group-level OOF predictions but sample `{sample_id}` has no group relation",
6185 edge.source.node_id,
6186 edge.source.port_name,
6187 edge.target.node_id,
6188 edge.target.port_name
6189 ))
6190 }),
6191 PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(format!(
6192 "edge `{}.{}` -> `{}.{}` cannot consume observation-level OOF predictions from sample folds",
6193 edge.source.node_id, edge.source.port_name, edge.target.node_id, edge.target.port_name
6194 ))),
6195 }
6196}
6197
6198fn deterministic_oof_handle(
6199 plan: &ExecutionPlan,
6200 edge: &EdgeSpec,
6201 ctx: &RunContext,
6202 scope: &PhaseScope,
6203) -> Result<u64> {
6204 let fingerprint = stable_json_fingerprint(&(
6205 &plan.id,
6206 &ctx.run_id,
6207 &edge.source.node_id,
6208 &edge.source.port_name,
6209 &edge.target.node_id,
6210 &edge.target.port_name,
6211 scope.phase,
6212 &scope.variant_id,
6213 &scope.fold_id,
6214 ))?;
6215 Ok(u64::from_str_radix(&fingerprint[..16], 16).expect("sha256 hex prefix should fit into u64"))
6216}
6217
/// Bundle of three string-keyed maps: handles, data views and prediction inputs.
// NOTE(review): presumably the inputs gathered for one node task, keyed by
// input name — the producer of this struct is not visible here; confirm.
struct CollectedInputs {
    /// Handle references by key.
    handles: BTreeMap<String, HandleRef>,
    /// Data-provider view specs by key.
    data_views: BTreeMap<String, DataProviderViewSpec>,
    /// Prediction input specs by key.
    prediction_inputs: BTreeMap<String, PredictionInputSpec>,
}
6223
/// Returns `input_name` prefixed with `data:`.
fn data_view_key(input_name: &str) -> String {
    ["data:", input_name].concat()
}
6227
/// Returns `input_name` followed by the `:validation` suffix.
fn validation_data_view_key(input_name: &str) -> String {
    [input_name, ":validation"].concat()
}
6231
6232fn derive_output_data_views(
6233 plan: &ExecutionPlan,
6234 task: &NodeTask,
6235 result: &NodeResult,
6236) -> Result<BTreeMap<String, DataProviderViewSpec>> {
6237 let node = plan
6238 .graph_plan
6239 .graph
6240 .nodes
6241 .iter()
6242 .find(|node| node.id == task.node_plan.node_id)
6243 .expect("execution plan was validated");
6244 let mut views = BTreeMap::new();
6245 for port in node
6246 .ports
6247 .outputs
6248 .iter()
6249 .filter(|port| port.kind == PortKind::Data)
6250 {
6251 let Some(handle) = result.outputs.get(&port.name) else {
6252 continue;
6253 };
6254 if !matches!(handle.kind, HandleKind::Data | HandleKind::DataView) {
6255 return Err(DagMlError::RuntimeValidation(format!(
6256 "node `{}` emitted data output `{}` with non-data/data-view handle kind {:?}",
6257 task.node_plan.node_id, port.name, handle.kind
6258 )));
6259 }
6260 if let Some(view) = primary_output_data_view(task) {
6261 views.insert(
6262 port.name.clone(),
6263 output_data_view_for_port(task, result, &port.name, view)?,
6264 );
6265 }
6266 if let Some(validation_view) = validation_output_data_view(task) {
6267 views.insert(
6268 validation_data_view_key(&port.name),
6269 output_data_view_for_port(task, result, &port.name, validation_view)?,
6270 );
6271 }
6272 }
6273 Ok(views)
6274}
6275
6276fn output_data_view_for_port(
6277 task: &NodeTask,
6278 result: &NodeResult,
6279 port_name: &str,
6280 base_view: &DataProviderViewSpec,
6281) -> Result<DataProviderViewSpec> {
6282 let mut view = base_view.clone();
6283 if let Some(upstream_provenance) = view.extra.remove(DATA_OUTPUT_PROVENANCE_KEY) {
6284 let provenance: DataOutputProvenance =
6285 serde_json::from_value(upstream_provenance).map_err(|error| {
6286 DagMlError::RuntimeValidation(format!(
6287 "node `{}` cannot propagate data output `{port_name}` because upstream data output provenance is invalid JSON: {error}",
6288 task.node_plan.node_id
6289 ))
6290 })?;
6291 provenance.validate().map_err(|error| {
6292 DagMlError::RuntimeValidation(format!(
6293 "node `{}` cannot propagate data output `{port_name}` because upstream data output provenance is invalid: {error}",
6294 task.node_plan.node_id
6295 ))
6296 })?;
6297 }
6298 let shape_deltas = result
6299 .shape_deltas
6300 .iter()
6301 .filter(|delta| delta.node_id == task.node_plan.node_id)
6302 .cloned()
6303 .collect::<Vec<_>>();
6304 let mut provenance = DataOutputProvenance {
6305 schema_version: DATA_OUTPUT_PROVENANCE_SCHEMA_VERSION,
6306 producer_node: task.node_plan.node_id.clone(),
6307 producer_port: port_name.to_string(),
6308 producer_phase: task.phase,
6309 variant_id: task.variant_id.clone(),
6310 fold_id: task.fold_id.clone(),
6311 shape_plan_fingerprint: None,
6312 aggregation_policy_fingerprint: None,
6313 feature_namespace: None,
6314 feature_schema_fingerprint: None,
6315 representation_plan: None,
6316 representation_replay_manifest: None,
6317 representation_compatibility: None,
6318 relation_delta_fingerprint: None,
6319 shape_deltas,
6320 };
6321 if let Some(shape_plan) = &task.node_plan.shape_plan {
6322 provenance.shape_plan_fingerprint = Some(stable_json_fingerprint(shape_plan)?);
6323 provenance.aggregation_policy_fingerprint =
6324 Some(stable_json_fingerprint(&shape_plan.aggregation_policy)?);
6325 provenance.feature_namespace = shape_plan.feature_namespace.clone();
6326 provenance.feature_schema_fingerprint =
6327 output_feature_schema_fingerprint(shape_plan, result);
6328 }
6329 provenance.validate()?;
6330
6331 view.extra.insert(
6332 DATA_OUTPUT_PROVENANCE_KEY.to_string(),
6333 serde_json::to_value(provenance)?,
6334 );
6335 view.validate()?;
6336 Ok(view)
6337}
6338
6339fn output_feature_schema_fingerprint(
6340 shape_plan: &crate::policy::DataModelShapePlan,
6341 result: &NodeResult,
6342) -> Option<String> {
6343 result
6344 .shape_deltas
6345 .iter()
6346 .rev()
6347 .find(|delta| delta.kind == ShapeDeltaKind::Feature)
6348 .map(|delta| delta.after_fingerprint.clone())
6349 .or_else(|| shape_plan.feature_schema_fingerprint.clone())
6350}
6351
6352fn primary_output_data_view(task: &NodeTask) -> Option<&DataProviderViewSpec> {
6353 task.data_views
6354 .values()
6355 .find(|view| view.partition != DataRequestPartition::FoldValidation)
6356 .or_else(|| task.data_views.values().next())
6357}
6358
6359fn validation_output_data_view(task: &NodeTask) -> Option<&DataProviderViewSpec> {
6360 task.data_views
6361 .values()
6362 .find(|view| view.partition == DataRequestPartition::FoldValidation)
6363}
6364
6365fn make_data_view_handle(
6366 data_provider: &dyn RuntimeDataProvider,
6367 ctx: &RunContext,
6368 node_plan: &NodePlan,
6369 scope: &PhaseScope,
6370 binding: &DataBinding,
6371 data_handle: &HandleRef,
6372 view: &DataProviderViewSpec,
6373) -> Result<HandleRef> {
6374 view.validate()?;
6375 data_provider.make_view(&DataViewRequest {
6376 run_id: ctx.run_id.clone(),
6377 node_id: node_plan.node_id.clone(),
6378 input_name: binding.input_name.clone(),
6379 phase: scope.phase,
6380 variant_id: scope.variant_id.clone(),
6381 fold_id: scope.fold_id.clone(),
6382 binding: binding.clone(),
6383 data_handle: data_handle.clone(),
6384 view: view.clone(),
6385 })
6386}
6387
6388fn data_view_for_scope(
6389 binding: &DataBinding,
6390 fold_set: Option<&FoldSet>,
6391 scope: &PhaseScope,
6392 branch_view: Option<&crate::data::BranchViewPlan>,
6393) -> Result<DataProviderViewSpec> {
6394 let partition = data_partition_for_scope(binding, scope);
6395 data_view_for_partition(binding, fold_set, scope, partition, branch_view)
6396}
6397
6398fn validation_data_view_for_scope(
6399 binding: &DataBinding,
6400 fold_set: Option<&FoldSet>,
6401 scope: &PhaseScope,
6402 branch_view: Option<&crate::data::BranchViewPlan>,
6403) -> Result<Option<DataProviderViewSpec>> {
6404 if scope.phase != Phase::FitCv || scope.fold_id.is_none() {
6405 return Ok(None);
6406 }
6407 let partition = binding.view_policy.predict_partition;
6408 if partition == data_partition_for_scope(binding, scope) {
6409 return Ok(None);
6410 }
6411 data_view_for_partition(binding, fold_set, scope, partition, branch_view).map(Some)
6412}
6413
6414fn branch_view_from_node_metadata(
6422 plan: &ExecutionPlan,
6423 node_id: &NodeId,
6424) -> Result<Option<crate::data::BranchViewPlan>> {
6425 let node = match plan
6426 .graph_plan
6427 .graph
6428 .nodes
6429 .iter()
6430 .find(|node| &node.id == node_id)
6431 {
6432 Some(node) => node,
6433 None => return Ok(None),
6434 };
6435 let Some(value) = node.metadata.get("dsl_branch_view_plan") else {
6436 return Ok(None);
6437 };
6438 let plan: crate::data::BranchViewPlan =
6439 serde_json::from_value(value.clone()).map_err(|error| {
6440 DagMlError::RuntimeValidation(format!(
6441 "node `{node_id}` carries malformed `dsl_branch_view_plan` metadata: {error}"
6442 ))
6443 })?;
6444 plan.validate()?;
6445 Ok(Some(plan))
6446}
6447
6448fn data_view_for_partition(
6449 binding: &DataBinding,
6450 fold_set: Option<&FoldSet>,
6451 scope: &PhaseScope,
6452 partition: DataRequestPartition,
6453 branch_view: Option<&crate::data::BranchViewPlan>,
6454) -> Result<DataProviderViewSpec> {
6455 let fold = fold_for_scope(fold_set, scope.fold_id.as_ref())?;
6456 let sample_ids = sample_ids_for_partition(partition, fold_set, fold);
6457 if binding.view_policy.require_sample_ids
6458 && matches!(
6459 partition,
6460 DataRequestPartition::FoldTrain | DataRequestPartition::FoldValidation
6461 )
6462 && scope.fold_id.is_some()
6463 && sample_ids.as_ref().is_none_or(Vec::is_empty)
6464 {
6465 return Err(DagMlError::RuntimeValidation(format!(
6466 "data binding `{}` on `{}` requires sample ids for {:?}",
6467 binding.input_name, binding.node_id, partition
6468 )));
6469 }
6470 let include_augmented = match partition {
6471 DataRequestPartition::FoldTrain | DataRequestPartition::FullTrain => {
6472 binding.view_policy.include_augmented_train
6473 }
6474 DataRequestPartition::FoldValidation | DataRequestPartition::Predict => {
6475 binding.view_policy.include_augmented_validation
6476 }
6477 };
6478 let mut extra = BTreeMap::new();
6479 extra.insert(
6480 "feature_set_id".to_string(),
6481 serde_json::Value::String(binding.feature_set_id().to_string()),
6482 );
6483 if !binding.view_policy.unsafe_flags.is_empty() {
6484 extra.insert(
6485 "unsafe_flags".to_string(),
6486 serde_json::Value::Array(
6487 binding
6488 .view_policy
6489 .unsafe_flags
6490 .iter()
6491 .cloned()
6492 .map(serde_json::Value::String)
6493 .collect(),
6494 ),
6495 );
6496 }
6497 let view = DataProviderViewSpec {
6498 sample_ids,
6499 partition,
6500 fold_id: match partition {
6501 DataRequestPartition::FoldTrain | DataRequestPartition::FoldValidation => {
6502 scope.fold_id.clone()
6503 }
6504 DataRequestPartition::FullTrain | DataRequestPartition::Predict => None,
6505 },
6506 source_ids: (!binding.source_ids.is_empty()).then(|| binding.source_ids.clone()),
6507 columns: None,
6508 include_augmented,
6509 include_excluded: binding.view_policy.include_excluded,
6510 branch_view: branch_view.cloned(),
6511 extra,
6512 };
6513 view.validate()?;
6514 Ok(view)
6515}
6516
6517fn data_partition_for_scope(binding: &DataBinding, scope: &PhaseScope) -> DataRequestPartition {
6518 match scope.phase {
6519 Phase::FitCv => binding.view_policy.fit_partition,
6520 Phase::Refit => DataRequestPartition::FullTrain,
6521 Phase::Predict | Phase::Explain if scope.fold_id.is_none() => DataRequestPartition::Predict,
6522 Phase::Predict | Phase::Explain => binding.view_policy.predict_partition,
6523 Phase::Compile | Phase::Plan | Phase::Select => DataRequestPartition::FullTrain,
6524 }
6525}
6526
6527fn fold_for_scope<'a>(
6528 fold_set: Option<&'a FoldSet>,
6529 fold_id: Option<&FoldId>,
6530) -> Result<Option<&'a FoldAssignment>> {
6531 let Some(fold_id) = fold_id else {
6532 return Ok(None);
6533 };
6534 let fold_set = fold_set.ok_or_else(|| {
6535 DagMlError::RuntimeValidation(format!(
6536 "fold `{fold_id}` requested but execution plan has no fold set"
6537 ))
6538 })?;
6539 fold_set
6540 .folds
6541 .iter()
6542 .find(|fold| &fold.fold_id == fold_id)
6543 .map(Some)
6544 .ok_or_else(|| {
6545 DagMlError::RuntimeValidation(format!(
6546 "fold `{fold_id}` requested but is not present in fold set `{}`",
6547 fold_set.id
6548 ))
6549 })
6550}
6551
6552fn inner_fold_set_for_scope(
6558 campaign: &CampaignSpec,
6559 outer_fold_set: Option<&FoldSet>,
6560 node_plan: &NodePlan,
6561 scope: &PhaseScope,
6562) -> Result<Option<FoldSet>> {
6563 if scope.phase != Phase::FitCv {
6564 return Ok(None);
6565 }
6566 let Some(spec) =
6567 crate::fold::resolve_inner_cv(node_plan.inner_cv.as_ref(), campaign.inner_cv.as_ref())
6568 else {
6569 return Ok(None);
6570 };
6571 let Some(outer) = fold_for_scope(outer_fold_set, scope.fold_id.as_ref())? else {
6575 return Ok(None);
6576 };
6577 let outer_groups = &outer_fold_set
6578 .expect("fold_for_scope returned a fold, so the outer fold set is present")
6579 .sample_groups;
6580 Ok(Some(spec.build_inner_fold_set(outer, outer_groups)?))
6581}
6582
6583fn sample_ids_for_partition(
6584 partition: DataRequestPartition,
6585 fold_set: Option<&FoldSet>,
6586 fold: Option<&FoldAssignment>,
6587) -> Option<Vec<SampleId>> {
6588 match partition {
6589 DataRequestPartition::FoldTrain => fold.map(|fold| fold.train_sample_ids.clone()),
6590 DataRequestPartition::FoldValidation => fold.map(|fold| fold.validation_sample_ids.clone()),
6591 DataRequestPartition::FullTrain => fold_set.map(|fold_set| fold_set.sample_ids.clone()),
6592 DataRequestPartition::Predict => None,
6593 }
6594}
6595
6596fn preload_replay_prediction_cache_store(
6597 bundle: &ExecutionBundle,
6598 prediction_cache_store: Option<&dyn RuntimePredictionCacheStore>,
6599 ctx: &mut RunContext,
6600) -> Result<()> {
6601 if bundle.prediction_requirements.is_empty() {
6602 return Ok(());
6603 }
6604 let store = prediction_cache_store.ok_or_else(|| {
6605 DagMlError::RuntimeValidation(format!(
6606 "bundle `{}` cannot preload OOF prediction caches without a prediction cache store",
6607 bundle.bundle_id
6608 ))
6609 })?;
6610 if !ctx.prediction_store.blocks().is_empty() {
6611 return Err(DagMlError::RuntimeValidation(format!(
6612 "bundle `{}` cannot preload OOF prediction caches into a non-empty prediction store",
6613 bundle.bundle_id
6614 )));
6615 }
6616 let contracts = replay_prediction_cache_contracts(bundle)?;
6617 for contract in contracts.values() {
6618 if contract.requirement.prediction_level == PredictionLevel::Sample {
6619 let blocks = store.load_blocks(&contract.cache.requirement_key)?;
6620 if blocks.iter().any(|block| {
6621 block.producer_node != contract.requirement.producer_node
6622 || block.partition != contract.requirement.partition
6623 }) {
6624 return Err(DagMlError::RuntimeValidation(format!(
6625 "prediction cache store returned blocks outside requirement `{}`",
6626 contract.cache.requirement_key
6627 )));
6628 }
6629 let payload = build_prediction_cache_payload(&contract.requirement, &blocks)?;
6630 validate_prediction_cache_payload_matches_record(&payload, &contract.cache)?;
6631 for block in &payload.blocks {
6632 ctx.prediction_store.append(block.clone())?;
6633 }
6634 } else {
6635 let blocks = store.load_aggregated_blocks(&contract.cache.requirement_key)?;
6636 if blocks.iter().any(|block| {
6637 block.producer_node != contract.requirement.producer_node
6638 || block.partition != contract.requirement.partition
6639 || block.level != contract.requirement.prediction_level
6640 }) {
6641 return Err(DagMlError::RuntimeValidation(format!(
6642 "prediction cache store returned aggregated blocks outside requirement `{}`",
6643 contract.cache.requirement_key
6644 )));
6645 }
6646 let payload =
6647 build_aggregated_prediction_cache_payload(&contract.requirement, &blocks)?;
6648 validate_prediction_cache_payload_matches_record(&payload, &contract.cache)?;
6649 }
6650 }
6651 Ok(())
6652}
6653
6654fn replay_prediction_cache_contracts(
6655 bundle: &ExecutionBundle,
6656) -> Result<BTreeMap<String, ReplayPredictionCacheContract>> {
6657 bundle.validate()?;
6658 let requirements = bundle
6659 .prediction_requirements
6660 .iter()
6661 .map(|requirement| (requirement.key(), requirement))
6662 .collect::<BTreeMap<_, _>>();
6663 let mut contracts = BTreeMap::new();
6664 for cache in &bundle.prediction_caches {
6665 let requirement = requirements.get(&cache.requirement_key).ok_or_else(|| {
6666 DagMlError::RuntimeValidation(format!(
6667 "prediction cache `{}` references unknown prediction requirement `{}`",
6668 cache.cache_id, cache.requirement_key
6669 ))
6670 })?;
6671 contracts.insert(
6672 cache.requirement_key.clone(),
6673 ReplayPredictionCacheContract {
6674 requirement: (*requirement).clone(),
6675 cache: cache.clone(),
6676 },
6677 );
6678 }
6679 Ok(contracts)
6680}
6681
6682fn materialize_replay_artifact_handles(
6683 plan: &ExecutionPlan,
6684 bundle: &ExecutionBundle,
6685 replay_request: &ReplayPhaseRequest,
6686 artifact_store: &dyn RuntimeArtifactStore,
6687 ctx: &RunContext,
6688) -> Result<MaterializedReplayArtifacts> {
6689 let mut handles = BTreeMap::<NodeId, BTreeMap<String, HandleRef>>::new();
6690 let mut inputs = BTreeMap::<NodeId, BTreeMap<String, ArtifactInputSpec>>::new();
6691 for artifact in &bundle.refit_artifacts {
6692 artifact.validate()?;
6693 let node_plan = plan.node_plans.get(&artifact.node_id).ok_or_else(|| {
6694 DagMlError::RuntimeValidation(format!(
6695 "bundle `{}` artifact references unknown node `{}`",
6696 bundle.bundle_id, artifact.node_id
6697 ))
6698 })?;
6699 if !node_plan.supported_phases.contains(&replay_request.phase) {
6700 return Err(DagMlError::RuntimeValidation(format!(
6701 "bundle `{}` artifact node `{}` does not support replay phase {:?}",
6702 bundle.bundle_id, artifact.node_id, replay_request.phase
6703 )));
6704 }
6705 let handle = artifact_store.materialize(&ArtifactMaterializationRequest {
6706 run_id: ctx.run_id.clone(),
6707 bundle_id: bundle.bundle_id.clone(),
6708 node_id: artifact.node_id.clone(),
6709 phase: replay_request.phase,
6710 variant_id: bundle.selected_variant_id.clone(),
6711 controller_id: artifact.controller_id.clone(),
6712 artifact: artifact.artifact.clone(),
6713 params_fingerprint: artifact.params_fingerprint.clone(),
6714 })?;
6715 if !matches!(handle.kind, HandleKind::Model | HandleKind::Artifact) {
6716 return Err(DagMlError::RuntimeValidation(format!(
6717 "artifact `{}` materialized as unsupported handle kind {:?}",
6718 artifact.artifact.id, handle.kind
6719 )));
6720 }
6721 if handle.owner_controller != artifact.controller_id {
6722 return Err(DagMlError::RuntimeValidation(format!(
6723 "artifact `{}` handle owner `{}` does not match controller `{}`",
6724 artifact.artifact.id, handle.owner_controller, artifact.controller_id
6725 )));
6726 }
6727 let key = refit_artifact_input_key(&artifact.artifact.id);
6728 if handles
6729 .entry(artifact.node_id.clone())
6730 .or_default()
6731 .insert(key.clone(), handle)
6732 .is_some()
6733 {
6734 return Err(DagMlError::RuntimeValidation(format!(
6735 "duplicate replay artifact input `{key}` for node `{}`",
6736 artifact.node_id
6737 )));
6738 }
6739 if inputs
6740 .entry(artifact.node_id.clone())
6741 .or_default()
6742 .insert(key.clone(), ArtifactInputSpec::from_refit_record(artifact)?)
6743 .is_some()
6744 {
6745 return Err(DagMlError::RuntimeValidation(format!(
6746 "duplicate replay artifact metadata `{key}` for node `{}`",
6747 artifact.node_id
6748 )));
6749 }
6750 }
6751 Ok(MaterializedReplayArtifacts { handles, inputs })
6752}
6753
6754fn derive_task_seed(
6755 root_seed: Option<u64>,
6756 variant_id: Option<&VariantId>,
6757 fold_id: Option<&FoldId>,
6758 node_plan: &NodePlan,
6759 phase: Phase,
6760) -> Option<u64> {
6761 root_seed.map(|root| {
6762 let mut context = SeedContext::root(root);
6763 if let Some(variant_id) = variant_id {
6764 context = context.child(format!("variant:{variant_id}"));
6765 }
6766 if let Some(fold_id) = fold_id {
6767 context = context.child(format!("fold:{fold_id}"));
6768 }
6769 context
6770 .child(format!("node:{}", node_plan.node_id))
6771 .child(format!("phase:{phase:?}"))
6772 .derive_u64("task")
6773 })
6774}
6775
#[cfg(test)]
mod explain_contract_tests {
    use super::*;

    /// Builds an explanation block for target `y` that uses the given `method`.
    fn explanation_with_method(method: &str) -> ExplanationBlock {
        let producer_node = NodeId::new("model:base").unwrap();
        let payload = serde_json::json!({"feature_importance": [0.5, 0.3, 0.2]});
        ExplanationBlock {
            producer_node,
            method: method.to_string(),
            target_name: Some("y".to_string()),
            payload,
        }
    }

    #[test]
    fn validates_well_formed_explanation() {
        let explanation = explanation_with_method("shap");
        assert!(explanation.validate().is_ok());
    }

    #[test]
    fn rejects_empty_method() {
        // A whitespace-only method must be rejected.
        let explanation = explanation_with_method("  ");
        assert!(explanation.validate().is_err());
    }

    #[test]
    fn rejects_empty_target_name() {
        let blank_target = ExplanationBlock {
            target_name: Some(String::new()),
            ..explanation_with_method("shap")
        };
        assert!(blank_target.validate().is_err());
    }

    #[test]
    fn round_trips_through_json() {
        let original = explanation_with_method("permutation_importance");
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: ExplanationBlock = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, original);

        // Without a target name, the field must not appear in the JSON at all.
        let untargeted = ExplanationBlock {
            target_name: None,
            ..explanation_with_method("shap")
        };
        let encoded = serde_json::to_string(&untargeted).expect("serialize");
        assert!(!encoded.contains("target_name"));
    }
}
6819
// Additional tests are declared in the out-of-line `tests` submodule, which is
// compiled only for test builds.
#[cfg(test)]
mod tests;