Skip to main content

dag_ml_core/
aggregation.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{DagMlError, Result};
7use crate::ids::{ControllerId, FoldId, GroupId, NodeId, ObservationId, SampleId, TargetId};
8use crate::oof::{PredictionBlock, PredictionPartition};
9use crate::policy::{
10    AggregationMethod, AggregationPolicy, AggregationWeights, PredictionLevel, ReductionAxis,
11    ReductionMethod, ReductionPlan,
12};
13use crate::relation::{EntityUnitLevel, SampleRelationSet};
14
/// Task schema version accepted by `AggregationControllerTask::validate`.
pub const AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION: u32 = 1;
/// Identifier (URL) of the JSON schema for serialized aggregation controller tasks.
pub const AGGREGATION_CONTROLLER_TASK_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/aggregation_controller_task.v1.schema.json";
/// Result schema version accepted by `AggregationControllerResult::validate_for_task`.
pub const AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION: u32 = 1;
/// Identifier (URL) of the JSON schema for serialized aggregation controller results.
pub const AGGREGATION_CONTROLLER_RESULT_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/aggregation_controller_result.v1.schema.json";
/// Trim fraction passed to `SampleAccumulator::robust_mean` by the built-in
/// robust-mean reducer. NOTE(review): presumably the share trimmed from each
/// tail; confirm against `robust_mean`.
const DEFAULT_ROBUST_TRIM_FRACTION: f64 = 0.1;
22
/// Per-observation predictions emitted by one producer node for one partition
/// and, optionally, one fold, before they are reduced to sample level.
///
/// `observation_ids` and `values` are parallel: row `i` of `values` is the
/// prediction for `observation_ids[i]`. The invariants are enforced by
/// `ObservationPredictionBlock::validate_shape`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ObservationPredictionBlock {
    /// Optional identifier of this block; `None` when absent during deserialization.
    #[serde(default)]
    pub prediction_id: Option<String>,
    /// Node that produced the predictions.
    pub producer_node: NodeId,
    /// Partition the predictions were produced for.
    pub partition: PredictionPartition,
    /// Fold the predictions were produced for, if any.
    pub fold_id: Option<FoldId>,
    /// Observation identifiers, one per row of `values`; must be unique.
    pub observation_ids: Vec<ObservationId>,
    /// Prediction rows. Rows must be non-empty, share one width and hold only finite values.
    pub values: Vec<Vec<f64>>,
    /// Optional per-observation weights. When non-empty there must be exactly one
    /// finite, non-negative weight per observation. Skipped when serializing if empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub weights: Vec<f64>,
    /// Optional column names; when non-empty there must be one per prediction column.
    #[serde(default)]
    pub target_names: Vec<String>,
}
37
/// Identifier of an aggregation unit at sample, target or group level.
///
/// Serialized as an adjacently tagged object: the variant name in snake_case
/// under `"level"` and the wrapped id under `"id"`.
/// The derived ordering compares the variant first (in declaration order:
/// sample, target, group) and then the wrapped id.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "level", content = "id")]
pub enum PredictionUnitId {
    /// A single sample.
    Sample(SampleId),
    /// A target unit.
    Target(TargetId),
    /// A group unit.
    Group(GroupId),
}
45
46impl PredictionUnitId {
47    pub fn level(&self) -> PredictionLevel {
48        match self {
49            Self::Sample(_) => PredictionLevel::Sample,
50            Self::Target(_) => PredictionLevel::Target,
51            Self::Group(_) => PredictionLevel::Group,
52        }
53    }
54}
55
56impl fmt::Display for PredictionUnitId {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        match self {
59            Self::Sample(id) => write!(f, "sample:{id}"),
60            Self::Target(id) => write!(f, "target:{id}"),
61            Self::Group(id) => write!(f, "group:{id}"),
62        }
63    }
64}
65
/// Predictions reduced to one row per aggregation unit (sample, target or group).
///
/// `unit_ids` and `values` are parallel: row `i` of `values` is the prediction
/// for `unit_ids[i]`, and every unit id must belong to `level`. The invariants
/// are enforced by `AggregatedPredictionBlock::validate_shape`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AggregatedPredictionBlock {
    /// Optional identifier of this block; `None` when absent during deserialization.
    #[serde(default)]
    pub prediction_id: Option<String>,
    /// Node that produced the underlying predictions.
    pub producer_node: NodeId,
    /// Partition the predictions were produced for.
    pub partition: PredictionPartition,
    /// Fold the predictions were produced for, if any.
    pub fold_id: Option<FoldId>,
    /// Level shared by every entry of `unit_ids`.
    pub level: PredictionLevel,
    /// Unique unit identifiers, one per row of `values`.
    pub unit_ids: Vec<PredictionUnitId>,
    /// Aggregated prediction rows. Rows must be non-empty, share one width and hold only finite values.
    pub values: Vec<Vec<f64>>,
    /// Optional column names; when non-empty there must be one per prediction column.
    #[serde(default)]
    pub target_names: Vec<String>,
}
79
80impl AggregatedPredictionBlock {
81    pub fn validate_shape(&self) -> Result<usize> {
82        if self.unit_ids.len() != self.values.len() {
83            return Err(DagMlError::OofValidation(format!(
84                "producer `{}` has {} aggregated unit ids but {} prediction rows",
85                self.producer_node,
86                self.unit_ids.len(),
87                self.values.len()
88            )));
89        }
90        if self
91            .unit_ids
92            .iter()
93            .any(|unit_id| unit_id.level() != self.level)
94        {
95            return Err(DagMlError::OofValidation(format!(
96                "producer `{}` emitted aggregated units outside level {:?}",
97                self.producer_node, self.level
98            )));
99        }
100        let unique = self.unit_ids.iter().collect::<BTreeSet<_>>();
101        if unique.len() != self.unit_ids.len() {
102            return Err(DagMlError::OofValidation(format!(
103                "producer `{}` emitted duplicate aggregated unit ids",
104                self.producer_node
105            )));
106        }
107        let width = self.values.first().map_or(0, Vec::len);
108        if width == 0 {
109            return Err(DagMlError::OofValidation(format!(
110                "producer `{}` emitted empty aggregated prediction rows",
111                self.producer_node
112            )));
113        }
114        if self.values.iter().any(|row| row.len() != width) {
115            return Err(DagMlError::OofValidation(format!(
116                "producer `{}` emitted ragged aggregated prediction rows",
117                self.producer_node
118            )));
119        }
120        if self.values.iter().flatten().any(|value| !value.is_finite()) {
121            return Err(DagMlError::OofValidation(format!(
122                "producer `{}` emitted non-finite aggregated prediction values",
123                self.producer_node
124            )));
125        }
126        if !self.target_names.is_empty() && self.target_names.len() != width {
127            return Err(DagMlError::OofValidation(format!(
128                "producer `{}` has {} aggregated target names for width {}",
129                self.producer_node,
130                self.target_names.len(),
131                width
132            )));
133        }
134        Ok(width)
135    }
136}
137
138impl ObservationPredictionBlock {
139    pub fn validate_shape(&self) -> Result<usize> {
140        if self.observation_ids.len() != self.values.len() {
141            return Err(DagMlError::OofValidation(format!(
142                "producer `{}` has {} observation ids but {} prediction rows",
143                self.producer_node,
144                self.observation_ids.len(),
145                self.values.len()
146            )));
147        }
148        let width = self.values.first().map_or(0, Vec::len);
149        if width == 0 {
150            return Err(DagMlError::OofValidation(format!(
151                "producer `{}` emitted empty observation prediction rows",
152                self.producer_node
153            )));
154        }
155        if self.values.iter().any(|row| row.len() != width) {
156            return Err(DagMlError::OofValidation(format!(
157                "producer `{}` emitted ragged observation prediction rows",
158                self.producer_node
159            )));
160        }
161        if self.values.iter().flatten().any(|value| !value.is_finite()) {
162            return Err(DagMlError::OofValidation(format!(
163                "producer `{}` emitted non-finite observation prediction values",
164                self.producer_node
165            )));
166        }
167        if !self.weights.is_empty() {
168            if self.weights.len() != self.observation_ids.len() {
169                return Err(DagMlError::OofValidation(format!(
170                    "producer `{}` has {} observation weights but {} observation ids",
171                    self.producer_node,
172                    self.weights.len(),
173                    self.observation_ids.len()
174                )));
175            }
176            if self
177                .weights
178                .iter()
179                .any(|weight| !weight.is_finite() || *weight < 0.0)
180            {
181                return Err(DagMlError::OofValidation(format!(
182                    "producer `{}` emitted non-finite or negative observation weights",
183                    self.producer_node
184                )));
185            }
186        }
187        if !self.target_names.is_empty() && self.target_names.len() != width {
188            return Err(DagMlError::OofValidation(format!(
189                "producer `{}` has {} target names for width {}",
190                self.producer_node,
191                self.target_names.len(),
192                width
193            )));
194        }
195        let unique = self.observation_ids.iter().collect::<BTreeSet<_>>();
196        if unique.len() != self.observation_ids.len() {
197            return Err(DagMlError::OofValidation(format!(
198                "producer `{}` emitted duplicate observation predictions",
199                self.producer_node
200            )));
201        }
202        Ok(width)
203    }
204}
205
/// A unit of work handed to a custom aggregation controller.
///
/// The JSON schema for serialized tasks is identified by
/// [`AGGREGATION_CONTROLLER_TASK_SCHEMA_ID`]. Use `AggregationControllerTask::validate`
/// to check a task before dispatching it.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AggregationControllerTask {
    /// Task schema version. When it is missing during deserialization,
    /// `default_aggregation_controller_task_schema_version` supplies it (that function
    /// is defined elsewhere). `validate` only accepts
    /// [`AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION`].
    #[serde(default = "default_aggregation_controller_task_schema_version")]
    pub schema_version: u32,
    /// Task identifier. It must not be blank, and the result must echo it.
    pub task_id: String,
    /// Controller expected to run the task; must equal the policy's custom controller id.
    pub controller_id: ControllerId,
    /// Aggregation policy; `validate` requires the custom-controller method.
    pub policy: AggregationPolicy,
    /// Optional reduction plan. It is skipped when serializing if `None`, and the
    /// result must echo it unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_plan: Option<ReductionPlan>,
    /// Prediction payload to aggregate.
    pub input: AggregationControllerInput,
}
217
/// Payload of an aggregation controller task. It is internally tagged by
/// `input_kind`, whose value is the variant name in snake_case.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "input_kind", rename_all = "snake_case")]
pub enum AggregationControllerInput {
    /// Reduce observation-level predictions to one row per sample.
    ObservationToSample {
        /// Observation-level predictions to reduce.
        block: ObservationPredictionBlock,
        /// Relations used to map each observation to its sample.
        relations: SampleRelationSet,
        /// Non-empty, duplicate-free order of the samples the output must list.
        requested_sample_order: Vec<SampleId>,
    },
    /// Reduce sample-level predictions to one row per aggregation unit.
    SampleToUnit {
        /// Sample-level predictions to reduce.
        block: PredictionBlock,
        /// Relations used to map each sample to its aggregation unit.
        relations: SampleRelationSet,
        /// Non-empty, duplicate-free order of the units the output must list.
        requested_unit_order: Vec<PredictionUnitId>,
    },
}
232
/// Response from a custom aggregation controller for one `AggregationControllerTask`.
///
/// Use `AggregationControllerResult::validate_for_task` to check it against its task.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AggregationControllerResult {
    /// Result schema version. When it is missing during deserialization,
    /// `default_aggregation_controller_result_schema_version` supplies it (that
    /// function is defined elsewhere). Validation only accepts
    /// [`AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION`].
    #[serde(default = "default_aggregation_controller_result_schema_version")]
    pub schema_version: u32,
    /// Must equal the `task_id` of the task being answered.
    pub task_id: String,
    /// Echo of the task's reduction plan. It must be present and equal exactly
    /// when the task has one. It is skipped when serializing if `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_plan: Option<ReductionPlan>,
    /// Aggregated predictions.
    pub output: AggregationControllerOutput,
}
242
/// Output of an aggregation controller. It is internally tagged by `output_kind`,
/// whose value is the variant name in snake_case.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "output_kind", rename_all = "snake_case")]
pub enum AggregationControllerOutput {
    /// Sample-level predictions; the expected answer to `ObservationToSample` input.
    Sample { block: PredictionBlock },
    /// Unit-level predictions; the expected answer to `SampleToUnit` input.
    Unit { block: AggregatedPredictionBlock },
}
249
250impl AggregationControllerTask {
251    pub fn validate(&self) -> Result<()> {
252        if self.schema_version != AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION {
253            return Err(DagMlError::OofValidation(format!(
254                "aggregation controller task `{}` uses unsupported schema_version {}",
255                self.task_id, self.schema_version
256            )));
257        }
258        if self.task_id.trim().is_empty() {
259            return Err(DagMlError::OofValidation(
260                "aggregation controller task_id is empty".to_string(),
261            ));
262        }
263        self.policy.validate()?;
264        if self.policy.method != AggregationMethod::CustomController {
265            return Err(DagMlError::OofValidation(format!(
266                "aggregation controller task `{}` must use custom_controller method",
267                self.task_id
268            )));
269        }
270        let controller = self
271            .policy
272            .custom_controller
273            .as_ref()
274            .expect("custom_controller policy validation requires controller spec");
275        if controller.controller_id != self.controller_id {
276            return Err(DagMlError::OofValidation(format!(
277                "aggregation controller task `{}` targets controller `{}` but policy targets `{}`",
278                self.task_id, self.controller_id, controller.controller_id
279            )));
280        }
281        if let Some(reduction_plan) = &self.reduction_plan {
282            validate_aggregation_controller_reduction_plan(
283                reduction_plan,
284                &self.policy,
285                &self.input,
286            )?;
287        }
288        match &self.input {
289            AggregationControllerInput::ObservationToSample {
290                block,
291                relations,
292                requested_sample_order,
293            } => validate_aggregation_controller_observation_input(
294                block,
295                relations,
296                &self.policy,
297                requested_sample_order,
298            ),
299            AggregationControllerInput::SampleToUnit {
300                block,
301                relations,
302                requested_unit_order,
303            } => validate_aggregation_controller_sample_input(
304                block,
305                relations,
306                &self.policy,
307                requested_unit_order,
308            ),
309        }
310    }
311}
312
313impl AggregationControllerResult {
314    pub fn validate_for_task(&self, task: &AggregationControllerTask) -> Result<()> {
315        task.validate()?;
316        if self.schema_version != AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION {
317            return Err(DagMlError::OofValidation(format!(
318                "aggregation controller result `{}` uses unsupported schema_version {}",
319                self.task_id, self.schema_version
320            )));
321        }
322        if self.task_id != task.task_id {
323            return Err(DagMlError::OofValidation(format!(
324                "aggregation controller result task_id `{}` does not match task `{}`",
325                self.task_id, task.task_id
326            )));
327        }
328        validate_aggregation_controller_result_reduction_plan(task, self)?;
329        match (&task.input, &self.output) {
330            (
331                AggregationControllerInput::ObservationToSample {
332                    block: input_block,
333                    requested_sample_order,
334                    ..
335                },
336                AggregationControllerOutput::Sample { block },
337            ) => validate_aggregation_controller_sample_output(
338                input_block,
339                requested_sample_order,
340                block,
341            ),
342            (
343                AggregationControllerInput::SampleToUnit {
344                    block: input_block,
345                    requested_unit_order,
346                    ..
347                },
348                AggregationControllerOutput::Unit { block },
349            ) => validate_aggregation_controller_unit_output(
350                input_block,
351                requested_unit_order,
352                task.policy.aggregation_level,
353                block,
354            ),
355            (AggregationControllerInput::ObservationToSample { .. }, _) => {
356                Err(DagMlError::OofValidation(format!(
357                    "aggregation controller result `{}` must return sample output for observation input",
358                    self.task_id
359                )))
360            }
361            (AggregationControllerInput::SampleToUnit { .. }, _) => {
362                Err(DagMlError::OofValidation(format!(
363                    "aggregation controller result `{}` must return unit output for sample input",
364                    self.task_id
365                )))
366            }
367        }
368    }
369}
370
371fn validate_aggregation_controller_reduction_plan(
372    plan: &ReductionPlan,
373    policy: &AggregationPolicy,
374    input: &AggregationControllerInput,
375) -> Result<()> {
376    plan.validate()
377        .map_err(|error| DagMlError::OofValidation(error.to_string()))?;
378    if plan.method != ReductionMethod::from(policy.method) {
379        return Err(DagMlError::OofValidation(format!(
380            "reduction plan method {:?} does not match aggregation policy method {:?}",
381            plan.method, policy.method
382        )));
383    }
384    if plan.weight_source != policy.weights {
385        return Err(DagMlError::OofValidation(format!(
386            "reduction plan weight_source {:?} does not match aggregation policy weights {:?}",
387            plan.weight_source, policy.weights
388        )));
389    }
390    if plan.method == ReductionMethod::Custom {
391        let plan_controller = plan
392            .custom_controller
393            .as_ref()
394            .expect("reduction plan validation requires custom controller");
395        let policy_controller = policy
396            .custom_controller
397            .as_ref()
398            .expect("aggregation policy validation requires custom controller");
399        if plan_controller.controller_id != policy_controller.controller_id {
400            return Err(DagMlError::OofValidation(format!(
401                "reduction plan controller `{}` does not match aggregation policy controller `{}`",
402                plan_controller.controller_id, policy_controller.controller_id
403            )));
404        }
405    }
406    if plan.axis != ReductionAxis::Unit {
407        return Err(DagMlError::OofValidation(format!(
408            "aggregation controller reduction plan axis {:?} is not supported for unit aggregation tasks",
409            plan.axis
410        )));
411    }
412    match input {
413        AggregationControllerInput::ObservationToSample { .. } => {
414            if !matches!(
415                plan.input_unit_level,
416                EntityUnitLevel::Observation | EntityUnitLevel::Combo
417            ) {
418                return Err(DagMlError::OofValidation(format!(
419                    "observation aggregation reduction plan input_unit_level {:?} is invalid",
420                    plan.input_unit_level
421                )));
422            }
423            if plan.output_unit_level != EntityUnitLevel::PhysicalSample {
424                return Err(DagMlError::OofValidation(format!(
425                    "observation aggregation reduction plan output_unit_level {:?} must be physical_sample",
426                    plan.output_unit_level
427                )));
428            }
429            if policy.aggregation_level != PredictionLevel::Sample {
430                return Err(DagMlError::OofValidation(format!(
431                    "observation aggregation reduction plan must output sample predictions, got {:?}",
432                    policy.aggregation_level
433                )));
434            }
435        }
436        AggregationControllerInput::SampleToUnit { .. } => {
437            if plan.input_unit_level != EntityUnitLevel::PhysicalSample {
438                return Err(DagMlError::OofValidation(format!(
439                    "sample aggregation reduction plan input_unit_level {:?} must be physical_sample",
440                    plan.input_unit_level
441                )));
442            }
443            if plan.output_unit_level != EntityUnitLevel::PhysicalSample
444                || policy.aggregation_level != PredictionLevel::Sample
445            {
446                return Err(DagMlError::OofValidation(
447                    "sample aggregation reduction plans currently support only physical_sample output; target/group aggregation remains available without a ReductionPlan".to_string(),
448                ));
449            }
450        }
451    }
452    Ok(())
453}
454
455fn validate_aggregation_controller_result_reduction_plan(
456    task: &AggregationControllerTask,
457    result: &AggregationControllerResult,
458) -> Result<()> {
459    match (&task.reduction_plan, &result.reduction_plan) {
460        (Some(task_plan), Some(result_plan)) if task_plan == result_plan => Ok(()),
461        (Some(_), Some(_)) => Err(DagMlError::OofValidation(format!(
462            "aggregation controller result `{}` reduction_plan does not match task reduction_plan",
463            result.task_id
464        ))),
465        (Some(_), None) => Err(DagMlError::OofValidation(format!(
466            "aggregation controller result `{}` must echo task reduction_plan",
467            result.task_id
468        ))),
469        (None, Some(_)) => Err(DagMlError::OofValidation(format!(
470            "aggregation controller result `{}` declares reduction_plan but task does not",
471            result.task_id
472        ))),
473        (None, None) => Ok(()),
474    }
475}
476
477fn validate_aggregation_controller_observation_input(
478    block: &ObservationPredictionBlock,
479    relations: &SampleRelationSet,
480    policy: &AggregationPolicy,
481    requested_sample_order: &[SampleId],
482) -> Result<()> {
483    block.validate_shape()?;
484    relations.validate()?;
485    if policy.aggregation_level != PredictionLevel::Sample {
486        return Err(DagMlError::OofValidation(format!(
487            "observation aggregation controller task must output sample predictions, got {:?}",
488            policy.aggregation_level
489        )));
490    }
491    validate_unique_order(requested_sample_order, "requested_sample_order")?;
492    if matches!(
493        policy.weights,
494        AggregationWeights::ControllerEmitted | AggregationWeights::Quality
495    ) && block.weights.is_empty()
496    {
497        return Err(DagMlError::OofValidation(format!(
498            "aggregation controller task with {:?} weights requires observation weights",
499            policy.weights
500        )));
501    }
502    let requested = requested_sample_order.iter().collect::<BTreeSet<_>>();
503    let mut covered = BTreeSet::new();
504    for observation_id in &block.observation_ids {
505        let sample_id = relations
506            .sample_for_observation(observation_id)
507            .ok_or_else(|| {
508                DagMlError::OofValidation(format!(
509                    "observation prediction `{observation_id}` has no sample relation"
510                ))
511            })?;
512        if !requested.contains(sample_id) {
513            return Err(DagMlError::OofValidation(format!(
514                "observation prediction `{observation_id}` maps to unexpected sample `{sample_id}`"
515            )));
516        }
517        covered.insert(sample_id);
518    }
519    for sample_id in requested_sample_order {
520        if !covered.contains(sample_id) {
521            return Err(DagMlError::OofValidation(format!(
522                "sample `{sample_id}` has no observation predictions for aggregation controller task"
523            )));
524        }
525    }
526    Ok(())
527}
528
529fn validate_aggregation_controller_sample_input(
530    block: &PredictionBlock,
531    relations: &SampleRelationSet,
532    policy: &AggregationPolicy,
533    requested_unit_order: &[PredictionUnitId],
534) -> Result<()> {
535    validate_sample_prediction_block(block)?;
536    relations.validate()?;
537    if policy.aggregation_level == PredictionLevel::Observation {
538        return Err(DagMlError::OofValidation(
539            "sample aggregation controller task cannot output observation-level predictions"
540                .to_string(),
541        ));
542    }
543    if matches!(
544        policy.weights,
545        AggregationWeights::ControllerEmitted | AggregationWeights::Quality
546    ) {
547        return Err(DagMlError::OofValidation(format!(
548            "sample aggregation controller task cannot use {:?} weights without sample weights",
549            policy.weights
550        )));
551    }
552    validate_unique_order(requested_unit_order, "requested_unit_order")?;
553    if requested_unit_order
554        .iter()
555        .any(|unit_id| unit_id.level() != policy.aggregation_level)
556    {
557        return Err(DagMlError::OofValidation(format!(
558            "aggregation controller requested units do not match level {:?}",
559            policy.aggregation_level
560        )));
561    }
562    let requested = requested_unit_order.iter().collect::<BTreeSet<_>>();
563    let mut covered = BTreeSet::new();
564    for sample_id in &block.sample_ids {
565        let unit_id = unit_for_sample(relations, policy.aggregation_level, sample_id)?;
566        if !requested.contains(&unit_id) {
567            return Err(DagMlError::OofValidation(format!(
568                "sample prediction `{sample_id}` maps to unexpected aggregation unit `{unit_id}`"
569            )));
570        }
571        covered.insert(unit_id);
572    }
573    for unit_id in requested_unit_order {
574        if !covered.contains(unit_id) {
575            return Err(DagMlError::OofValidation(format!(
576                "aggregation unit `{unit_id}` has no sample predictions for aggregation controller task"
577            )));
578        }
579    }
580    Ok(())
581}
582
583fn validate_aggregation_controller_sample_output(
584    input_block: &ObservationPredictionBlock,
585    requested_sample_order: &[SampleId],
586    block: &PredictionBlock,
587) -> Result<()> {
588    validate_sample_prediction_block(block)?;
589    if block.producer_node != input_block.producer_node
590        || block.partition != input_block.partition
591        || block.fold_id != input_block.fold_id
592    {
593        return Err(DagMlError::OofValidation(format!(
594            "aggregation controller sample output for `{}` does not preserve producer, partition and fold",
595            input_block.producer_node
596        )));
597    }
598    if block.target_names != input_block.target_names {
599        return Err(DagMlError::OofValidation(format!(
600            "aggregation controller sample output for `{}` does not preserve target names",
601            input_block.producer_node
602        )));
603    }
604    if block.sample_ids != requested_sample_order {
605        return Err(DagMlError::OofValidation(format!(
606            "aggregation controller sample output for `{}` does not match requested sample order",
607            input_block.producer_node
608        )));
609    }
610    Ok(())
611}
612
613fn validate_aggregation_controller_unit_output(
614    input_block: &PredictionBlock,
615    requested_unit_order: &[PredictionUnitId],
616    expected_level: PredictionLevel,
617    block: &AggregatedPredictionBlock,
618) -> Result<()> {
619    block.validate_shape()?;
620    if block.producer_node != input_block.producer_node
621        || block.partition != input_block.partition
622        || block.fold_id != input_block.fold_id
623    {
624        return Err(DagMlError::OofValidation(format!(
625            "aggregation controller unit output for `{}` does not preserve producer, partition and fold",
626            input_block.producer_node
627        )));
628    }
629    if block.target_names != input_block.target_names {
630        return Err(DagMlError::OofValidation(format!(
631            "aggregation controller unit output for `{}` does not preserve target names",
632            input_block.producer_node
633        )));
634    }
635    if block.level != expected_level {
636        return Err(DagMlError::OofValidation(format!(
637            "aggregation controller unit output for `{}` has level {:?}, expected {:?}",
638            input_block.producer_node, block.level, expected_level
639        )));
640    }
641    if block.unit_ids != requested_unit_order {
642        return Err(DagMlError::OofValidation(format!(
643            "aggregation controller unit output for `{}` does not match requested unit order",
644            input_block.producer_node
645        )));
646    }
647    Ok(())
648}
649
650fn validate_unique_order<T>(values: &[T], label: &str) -> Result<()>
651where
652    T: Ord,
653{
654    if values.is_empty() {
655        return Err(DagMlError::OofValidation(format!(
656            "aggregation controller {label} is empty"
657        )));
658    }
659    let unique = values.iter().collect::<BTreeSet<_>>();
660    if unique.len() != values.len() {
661        return Err(DagMlError::OofValidation(format!(
662            "aggregation controller {label} contains duplicates"
663        )));
664    }
665    Ok(())
666}
667
668pub fn aggregate_observation_predictions(
669    block: &ObservationPredictionBlock,
670    relations: &SampleRelationSet,
671    policy: &AggregationPolicy,
672    requested_sample_order: &[SampleId],
673) -> Result<PredictionBlock> {
674    let width = block.validate_shape()?;
675    relations.validate()?;
676    policy.validate()?;
677    if requested_sample_order.is_empty() {
678        return Err(DagMlError::OofValidation(
679            "aggregation requested_sample_order is empty".to_string(),
680        ));
681    }
682    let requested = requested_sample_order.iter().collect::<BTreeSet<_>>();
683    if requested.len() != requested_sample_order.len() {
684        return Err(DagMlError::OofValidation(
685            "aggregation requested_sample_order contains duplicates".to_string(),
686        ));
687    }
688    if policy.aggregation_level != PredictionLevel::Sample {
689        return Err(DagMlError::OofValidation(format!(
690            "observation aggregation currently supports sample-level output, got {:?}",
691            policy.aggregation_level
692        )));
693    }
694    if policy.method == AggregationMethod::WeightedMean
695        && policy.weights == AggregationWeights::None
696    {
697        return Err(DagMlError::OofValidation(
698            "weighted_mean aggregation requires an explicit weights policy".to_string(),
699        ));
700    }
701    if policy.method != AggregationMethod::WeightedMean
702        && policy.weights != AggregationWeights::None
703    {
704        return Err(DagMlError::OofValidation(format!(
705            "aggregation weights {:?} are only valid with weighted_mean",
706            policy.weights
707        )));
708    }
709    if !block.weights.is_empty() && policy.method != AggregationMethod::WeightedMean {
710        return Err(DagMlError::OofValidation(format!(
711            "producer `{}` supplied observation weights for non-weighted aggregation {:?}",
712            block.producer_node, policy.method
713        )));
714    }
715
716    let store_rows = matches!(
717        policy.method,
718        AggregationMethod::Median | AggregationMethod::Vote | AggregationMethod::RobustMean
719    );
720    let mut accumulators = requested_sample_order
721        .iter()
722        .cloned()
723        .map(|sample_id| (sample_id, SampleAccumulator::new(width, store_rows)))
724        .collect::<BTreeMap<_, _>>();
725
726    for (row_idx, (observation_id, row)) in block
727        .observation_ids
728        .iter()
729        .zip(block.values.iter())
730        .enumerate()
731    {
732        let sample_id = relations
733            .sample_for_observation(observation_id)
734            .ok_or_else(|| {
735                DagMlError::OofValidation(format!(
736                    "observation prediction `{observation_id}` has no sample relation"
737                ))
738            })?;
739        if !requested.contains(sample_id) {
740            return Err(DagMlError::OofValidation(format!(
741                "observation prediction `{observation_id}` maps to unexpected sample `{sample_id}`"
742            )));
743        }
744        let accumulator = accumulators
745            .get_mut(sample_id)
746            .expect("requested sample accumulator exists");
747        let weight = observation_weight(block, policy, row_idx)?;
748        accumulator.push(row, weight);
749    }
750
751    let values = requested_sample_order
752        .iter()
753        .map(|sample_id| {
754            let accumulator = accumulators
755                .get(sample_id)
756                .expect("requested sample accumulator exists");
757            if accumulator.count == 0 {
758                return Err(DagMlError::OofValidation(format!(
759                    "sample `{sample_id}` has no observation predictions to aggregate"
760                )));
761            }
762            match policy.method {
763                AggregationMethod::Mean => Ok(accumulator.mean()),
764                AggregationMethod::WeightedMean => accumulator.weighted_mean(&sample_id.to_string()),
765                AggregationMethod::Median => Ok(accumulator.median()),
766                AggregationMethod::Vote => Ok(accumulator.vote()),
767                AggregationMethod::RobustMean => {
768                    Ok(accumulator.robust_mean(DEFAULT_ROBUST_TRIM_FRACTION))
769                }
770                AggregationMethod::ExcludeOutliers => Err(DagMlError::OofValidation(
771                    "exclude_outliers aggregation requires a custom aggregation controller"
772                        .to_string(),
773                )),
774                AggregationMethod::None => {
775                    if accumulator.count == 1 {
776                        Ok(accumulator
777                            .first_row
778                            .clone()
779                            .expect("single prediction accumulator stores first row"))
780                    } else {
781                        Err(DagMlError::OofValidation(format!(
782                            "sample `{sample_id}` has {} observation predictions but aggregation method is none",
783                            accumulator.count
784                        )))
785                    }
786                }
787                AggregationMethod::CustomController => Err(DagMlError::OofValidation(format!(
788                    "aggregation method {:?} is delegated to an aggregation controller",
789                    policy.method
790                ))),
791            }
792        })
793        .collect::<Result<Vec<Vec<f64>>>>()?;
794
795    Ok(PredictionBlock {
796        prediction_id: block
797            .prediction_id
798            .as_ref()
799            .map(|prediction_id| format!("{prediction_id}:sample_agg")),
800        producer_node: block.producer_node.clone(),
801        partition: block.partition.clone(),
802        fold_id: block.fold_id.clone(),
803        sample_ids: requested_sample_order.to_vec(),
804        values,
805        target_names: block.target_names.clone(),
806    })
807}
808
809pub fn aggregate_sample_predictions_by_unit(
810    block: &PredictionBlock,
811    relations: &SampleRelationSet,
812    policy: &AggregationPolicy,
813    requested_unit_order: &[PredictionUnitId],
814) -> Result<AggregatedPredictionBlock> {
815    let width = validate_sample_prediction_block(block)?;
816    relations.validate()?;
817    policy.validate()?;
818    if requested_unit_order.is_empty() {
819        return Err(DagMlError::OofValidation(
820            "aggregation requested_unit_order is empty".to_string(),
821        ));
822    }
823    let requested_level = policy.aggregation_level;
824    if requested_level == PredictionLevel::Observation {
825        return Err(DagMlError::OofValidation(
826            "sample prediction aggregation cannot output observation-level predictions".to_string(),
827        ));
828    }
829    if requested_unit_order
830        .iter()
831        .any(|unit_id| unit_id.level() != requested_level)
832    {
833        return Err(DagMlError::OofValidation(format!(
834            "aggregation requested units do not match level {:?}",
835            requested_level
836        )));
837    }
838    let requested = requested_unit_order.iter().collect::<BTreeSet<_>>();
839    if requested.len() != requested_unit_order.len() {
840        return Err(DagMlError::OofValidation(
841            "aggregation requested_unit_order contains duplicates".to_string(),
842        ));
843    }
844
845    let by_sample = block
846        .sample_ids
847        .iter()
848        .cloned()
849        .zip(block.values.iter().cloned())
850        .collect::<BTreeMap<_, _>>();
851    if requested_level == PredictionLevel::Sample {
852        let values = requested_unit_order
853            .iter()
854            .map(|unit_id| {
855                let PredictionUnitId::Sample(sample_id) = unit_id else {
856                    unreachable!("requested unit level already validated");
857                };
858                by_sample.get(sample_id).cloned().ok_or_else(|| {
859                    DagMlError::OofValidation(format!(
860                        "sample prediction block for `{}` is missing requested sample `{sample_id}`",
861                        block.producer_node
862                    ))
863                })
864            })
865            .collect::<Result<Vec<_>>>()?;
866        if by_sample.len() != requested_unit_order.len() {
867            return Err(DagMlError::OofValidation(format!(
868                "sample prediction block for `{}` contains samples outside requested sample order",
869                block.producer_node
870            )));
871        }
872        let aggregated = AggregatedPredictionBlock {
873            prediction_id: block.prediction_id.clone(),
874            producer_node: block.producer_node.clone(),
875            partition: block.partition.clone(),
876            fold_id: block.fold_id.clone(),
877            level: PredictionLevel::Sample,
878            unit_ids: requested_unit_order.to_vec(),
879            values,
880            target_names: block.target_names.clone(),
881        };
882        aggregated.validate_shape()?;
883        return Ok(aggregated);
884    }
885
886    if policy.method == AggregationMethod::WeightedMean
887        && matches!(
888            policy.weights,
889            AggregationWeights::ControllerEmitted | AggregationWeights::Quality
890        )
891    {
892        return Err(DagMlError::OofValidation(format!(
893            "sample-to-{:?} weighted_mean cannot use {:?} weights without sample-level weights",
894            requested_level, policy.weights
895        )));
896    }
897
898    let store_rows = matches!(
899        policy.method,
900        AggregationMethod::Median | AggregationMethod::Vote | AggregationMethod::RobustMean
901    );
902    let mut accumulators = requested_unit_order
903        .iter()
904        .cloned()
905        .map(|unit_id| (unit_id, SampleAccumulator::new(width, store_rows)))
906        .collect::<BTreeMap<_, _>>();
907
908    for (sample_id, row) in block.sample_ids.iter().zip(block.values.iter()) {
909        let unit_id = unit_for_sample(relations, requested_level, sample_id)?;
910        if !requested.contains(&unit_id) {
911            return Err(DagMlError::OofValidation(format!(
912                "sample prediction `{sample_id}` maps to unexpected aggregation unit `{unit_id}`"
913            )));
914        }
915        let weight = sample_weight(relations, policy, sample_id)?;
916        accumulators
917            .get_mut(&unit_id)
918            .expect("requested aggregation unit accumulator exists")
919            .push(row, weight);
920    }
921
922    let values = requested_unit_order
923        .iter()
924        .map(|unit_id| {
925            let accumulator = accumulators
926                .get(unit_id)
927                .expect("requested aggregation unit accumulator exists");
928            if accumulator.count == 0 {
929                return Err(DagMlError::OofValidation(format!(
930                    "aggregation unit `{unit_id}` has no sample predictions to aggregate"
931                )));
932            }
933            match policy.method {
934                AggregationMethod::Mean => Ok(accumulator.mean()),
935                AggregationMethod::WeightedMean => accumulator.weighted_mean(&unit_id.to_string()),
936                AggregationMethod::Median => Ok(accumulator.median()),
937                AggregationMethod::Vote => Ok(accumulator.vote()),
938                AggregationMethod::RobustMean => {
939                    Ok(accumulator.robust_mean(DEFAULT_ROBUST_TRIM_FRACTION))
940                }
941                AggregationMethod::ExcludeOutliers => Err(DagMlError::OofValidation(
942                    "exclude_outliers aggregation requires a custom aggregation controller"
943                        .to_string(),
944                )),
945                AggregationMethod::None => {
946                    if accumulator.count == 1 {
947                        Ok(accumulator
948                            .first_row
949                            .clone()
950                            .expect("single prediction accumulator stores first row"))
951                    } else {
952                        Err(DagMlError::OofValidation(format!(
953                            "aggregation unit `{unit_id}` has {} sample predictions but aggregation method is none",
954                            accumulator.count
955                        )))
956                    }
957                }
958                AggregationMethod::CustomController => Err(DagMlError::OofValidation(format!(
959                    "aggregation method {:?} is delegated to an aggregation controller",
960                    policy.method
961                ))),
962            }
963        })
964        .collect::<Result<Vec<_>>>()?;
965
966    let suffix = match requested_level {
967        PredictionLevel::Target => "target_agg",
968        PredictionLevel::Group => "group_agg",
969        PredictionLevel::Sample => "sample_agg",
970        PredictionLevel::Observation => unreachable!("observation output rejected above"),
971    };
972    let aggregated = AggregatedPredictionBlock {
973        prediction_id: block
974            .prediction_id
975            .as_ref()
976            .map(|prediction_id| format!("{prediction_id}:{suffix}")),
977        producer_node: block.producer_node.clone(),
978        partition: block.partition.clone(),
979        fold_id: block.fold_id.clone(),
980        level: requested_level,
981        unit_ids: requested_unit_order.to_vec(),
982        values,
983        target_names: block.target_names.clone(),
984    };
985    aggregated.validate_shape()?;
986    Ok(aggregated)
987}
988
989fn validate_sample_prediction_block(block: &PredictionBlock) -> Result<usize> {
990    let width = block.validate_shape()?;
991    if block
992        .values
993        .iter()
994        .flatten()
995        .any(|value| !value.is_finite())
996    {
997        return Err(DagMlError::OofValidation(format!(
998            "producer `{}` emitted non-finite sample prediction values",
999            block.producer_node
1000        )));
1001    }
1002    let unique = block.sample_ids.iter().collect::<BTreeSet<_>>();
1003    if unique.len() != block.sample_ids.len() {
1004        return Err(DagMlError::OofValidation(format!(
1005            "producer `{}` emitted duplicate sample predictions",
1006            block.producer_node
1007        )));
1008    }
1009    Ok(width)
1010}
1011
1012fn unit_for_sample(
1013    relations: &SampleRelationSet,
1014    level: PredictionLevel,
1015    sample_id: &SampleId,
1016) -> Result<PredictionUnitId> {
1017    match level {
1018        PredictionLevel::Sample => Ok(PredictionUnitId::Sample(sample_id.clone())),
1019        PredictionLevel::Target => relations
1020            .target_for_sample(sample_id)
1021            .cloned()
1022            .map(PredictionUnitId::Target)
1023            .ok_or_else(|| {
1024                DagMlError::OofValidation(format!(
1025                    "sample `{sample_id}` is missing target id for target aggregation"
1026                ))
1027            }),
1028        PredictionLevel::Group => relations
1029            .group_for_sample(sample_id)
1030            .cloned()
1031            .map(PredictionUnitId::Group)
1032            .ok_or_else(|| {
1033                DagMlError::OofValidation(format!(
1034                    "sample `{sample_id}` is missing group id for group aggregation"
1035                ))
1036            }),
1037        PredictionLevel::Observation => Err(DagMlError::OofValidation(
1038            "sample prediction aggregation cannot output observation-level predictions".to_string(),
1039        )),
1040    }
1041}
1042
1043fn sample_weight(
1044    relations: &SampleRelationSet,
1045    policy: &AggregationPolicy,
1046    sample_id: &SampleId,
1047) -> Result<f64> {
1048    if policy.method != AggregationMethod::WeightedMean {
1049        return Ok(1.0);
1050    }
1051    match policy.weights {
1052        AggregationWeights::RepetitionCount => {
1053            let count = relations.observation_count_for_sample(sample_id);
1054            if count == 0 {
1055                return Err(DagMlError::OofValidation(format!(
1056                    "sample `{sample_id}` has no observation relations for repetition_count weights"
1057                )));
1058            }
1059            Ok(count as f64)
1060        }
1061        AggregationWeights::ControllerEmitted | AggregationWeights::Quality => {
1062            Err(DagMlError::OofValidation(format!(
1063                "sample-level {:?} weights are not present in PredictionBlock",
1064                policy.weights
1065            )))
1066        }
1067        AggregationWeights::None => Err(DagMlError::OofValidation(
1068            "weighted_mean aggregation requires an explicit weights policy".to_string(),
1069        )),
1070    }
1071}
1072
/// Running state for reducing the prediction rows that fall into one output unit.
///
/// Running sums are always maintained. Full copies of the rows are kept only
/// when `store_rows` is set, which the order-statistic reductions require.
#[derive(Clone, Debug)]
struct SampleAccumulator {
    /// Column-wise sum of every pushed row (unweighted).
    sum: Vec<f64>,
    /// Column-wise sum of `value * weight` over every pushed row.
    weighted_sum: Vec<f64>,
    /// Sum of the weights of every pushed row.
    weight_sum: f64,
    /// Copies of every pushed row; filled only when `store_rows` is true.
    rows: Vec<Vec<f64>>,
    /// Copy of the first pushed row, kept regardless of `store_rows`.
    first_row: Option<Vec<f64>>,
    /// Whether `push` retains full rows in `rows`.
    store_rows: bool,
    /// Number of rows pushed so far.
    count: usize,
}
1083
1084impl SampleAccumulator {
1085    fn new(width: usize, store_rows: bool) -> Self {
1086        Self {
1087            sum: vec![0.0; width],
1088            weighted_sum: vec![0.0; width],
1089            weight_sum: 0.0,
1090            rows: Vec::new(),
1091            first_row: None,
1092            store_rows,
1093            count: 0,
1094        }
1095    }
1096
1097    fn push(&mut self, row: &[f64], weight: f64) {
1098        for (idx, value) in row.iter().enumerate() {
1099            self.sum[idx] += *value;
1100            self.weighted_sum[idx] += *value * weight;
1101        }
1102        self.weight_sum += weight;
1103        if self.first_row.is_none() {
1104            self.first_row = Some(row.to_vec());
1105        }
1106        if self.store_rows {
1107            self.rows.push(row.to_vec());
1108        }
1109        self.count += 1;
1110    }
1111
1112    fn mean(&self) -> Vec<f64> {
1113        self.sum
1114            .iter()
1115            .map(|value| *value / self.count as f64)
1116            .collect()
1117    }
1118
1119    fn weighted_mean(&self, unit_label: &str) -> Result<Vec<f64>> {
1120        if self.weight_sum <= 0.0 {
1121            return Err(DagMlError::OofValidation(format!(
1122                "aggregation unit `{unit_label}` has zero total prediction weight"
1123            )));
1124        }
1125        Ok(self
1126            .weighted_sum
1127            .iter()
1128            .map(|value| *value / self.weight_sum)
1129            .collect())
1130    }
1131
1132    fn median(&self) -> Vec<f64> {
1133        let width = self.sum.len();
1134        (0..width)
1135            .map(|column_idx| {
1136                let mut column = self
1137                    .rows
1138                    .iter()
1139                    .map(|row| row[column_idx])
1140                    .collect::<Vec<_>>();
1141                column.sort_by(f64::total_cmp);
1142                let middle = column.len() / 2;
1143                if column.len() % 2 == 1 {
1144                    column[middle]
1145                } else {
1146                    (column[middle - 1] + column[middle]) / 2.0
1147                }
1148            })
1149            .collect()
1150    }
1151
1152    fn vote(&self) -> Vec<f64> {
1153        let width = self.sum.len();
1154        (0..width)
1155            .map(|column_idx| {
1156                let mut column = self
1157                    .rows
1158                    .iter()
1159                    .map(|row| row[column_idx])
1160                    .collect::<Vec<_>>();
1161                column.sort_by(f64::total_cmp);
1162                mode_sorted(&column)
1163            })
1164            .collect()
1165    }
1166
1167    fn robust_mean(&self, trim_fraction: f64) -> Vec<f64> {
1168        let width = self.sum.len();
1169        (0..width)
1170            .map(|column_idx| {
1171                let mut column = self
1172                    .rows
1173                    .iter()
1174                    .map(|row| row[column_idx])
1175                    .collect::<Vec<_>>();
1176                column.sort_by(f64::total_cmp);
1177                let trim_count = ((column.len() as f64) * trim_fraction).floor() as usize;
1178                let max_trim = column.len().saturating_sub(1) / 2;
1179                let trim_count = trim_count.min(max_trim);
1180                let kept = &column[trim_count..column.len() - trim_count];
1181                kept.iter().sum::<f64>() / kept.len() as f64
1182            })
1183            .collect()
1184    }
1185}
1186
1187fn observation_weight(
1188    block: &ObservationPredictionBlock,
1189    policy: &AggregationPolicy,
1190    row_idx: usize,
1191) -> Result<f64> {
1192    if policy.method != AggregationMethod::WeightedMean {
1193        return Ok(1.0);
1194    }
1195    match policy.weights {
1196        AggregationWeights::ControllerEmitted | AggregationWeights::Quality => block
1197            .weights
1198            .get(row_idx)
1199            .copied()
1200            .ok_or_else(|| {
1201                DagMlError::OofValidation(format!(
1202                    "weighted_mean aggregation with {:?} weights requires one weight per observation",
1203                    policy.weights
1204                ))
1205            }),
1206        AggregationWeights::RepetitionCount => Ok(1.0),
1207        AggregationWeights::None => Err(DagMlError::OofValidation(
1208            "weighted_mean aggregation requires an explicit weights policy".to_string(),
1209        )),
1210    }
1211}
1212
/// Most frequent value of an already sorted slice.
///
/// Equal values are expected to be adjacent (the slice is sorted). On a tie, the
/// earliest run (the smallest value) wins. Panics if `values` is empty.
fn mode_sorted(values: &[f64]) -> f64 {
    // (value, run length) of the best run seen so far; length 0 lets the first run win.
    let mut best = (values[0], 0usize);
    let mut cursor = 0usize;
    while let Some(&run_value) = values.get(cursor) {
        // A run always contains its first element, even when it is NaN
        // (NaN != NaN), so the cursor always advances.
        let run_len = 1 + values[cursor + 1..]
            .iter()
            .take_while(|&&next| next == run_value)
            .count();
        if run_len > best.1 {
            best = (run_value, run_len);
        }
        cursor += run_len;
    }
    best.0
}
1236
/// Returns the current aggregation controller task schema version.
///
/// NOTE(review): presumably referenced from a `#[serde(default = "...")]`
/// attribute elsewhere in this file so that older payloads without
/// `schema_version` deserialize; confirm against the task struct definition.
fn default_aggregation_controller_task_schema_version() -> u32 {
    AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION
}
1240
/// Returns the current aggregation controller result schema version.
///
/// NOTE(review): presumably referenced from a `#[serde(default = "...")]`
/// attribute elsewhere in this file so that older payloads without
/// `schema_version` deserialize; confirm against the result struct definition.
fn default_aggregation_controller_result_schema_version() -> u32 {
    AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION
}
1244
1245#[cfg(test)]
1246mod tests {
1247    use super::*;
1248    use crate::ids::{ControllerId, GroupId, TargetId};
1249    use crate::relation::SampleRelation;
1250
1251    fn sid(value: &str) -> SampleId {
1252        SampleId::new(value).unwrap()
1253    }
1254
1255    fn oid(value: &str) -> ObservationId {
1256        ObservationId::new(value).unwrap()
1257    }
1258
1259    fn relation(observation: &str, sample: &str) -> SampleRelation {
1260        let mut relation = SampleRelation::new(oid(observation), sid(sample));
1261        relation.target_id = Some(TargetId::new(format!("target:{sample}")).unwrap());
1262        relation
1263    }
1264
1265    fn relation_with_units(
1266        observation: &str,
1267        sample: &str,
1268        target: &str,
1269        group: &str,
1270    ) -> SampleRelation {
1271        let mut relation = SampleRelation::new(oid(observation), sid(sample));
1272        relation.target_id = Some(TargetId::new(target).unwrap());
1273        relation.group_id = Some(GroupId::new(group).unwrap());
1274        relation
1275    }
1276
1277    fn combo_relation(observation: &str, sample: &str, components: &[&str]) -> SampleRelation {
1278        let mut relation = SampleRelation::new(oid(observation), sid(sample));
1279        relation.unit_level = EntityUnitLevel::Combo;
1280        relation.derived_unit_id = Some(format!("combo:{observation}"));
1281        relation.component_observation_ids =
1282            components.iter().map(|component| oid(component)).collect();
1283        relation
1284    }
1285
1286    fn custom_policy(level: PredictionLevel) -> AggregationPolicy {
1287        AggregationPolicy {
1288            aggregation_level: level,
1289            method: AggregationMethod::CustomController,
1290            custom_controller: Some(crate::policy::AggregationControllerSpec {
1291                controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1292                params: serde_json::json!({ "trim_fraction": 0.1 }),
1293            }),
1294            ..AggregationPolicy::default()
1295        }
1296    }
1297
1298    #[test]
1299    fn validates_custom_observation_aggregation_controller_result() {
1300        let reduction_plan = ReductionPlan {
1301            role: crate::policy::ReductionRole::FinalOutput,
1302            axis: ReductionAxis::Unit,
1303            input_unit_level: EntityUnitLevel::Observation,
1304            output_unit_level: EntityUnitLevel::PhysicalSample,
1305            method: ReductionMethod::Custom,
1306            custom_controller: Some(crate::policy::AggregationControllerSpec {
1307                controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1308                params: serde_json::json!({ "trim_fraction": 0.1 }),
1309            }),
1310            ..ReductionPlan::default()
1311        };
1312        let task = AggregationControllerTask {
1313            schema_version: AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
1314            task_id: "agg-task:obs.sample.fold0".to_string(),
1315            controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1316            policy: custom_policy(PredictionLevel::Sample),
1317            reduction_plan: Some(reduction_plan.clone()),
1318            input: AggregationControllerInput::ObservationToSample {
1319                block: ObservationPredictionBlock {
1320                    prediction_id: Some("prediction:model.fold0".to_string()),
1321                    producer_node: NodeId::new("model:pls").unwrap(),
1322                    partition: PredictionPartition::Validation,
1323                    fold_id: Some(FoldId::new("fold:0").unwrap()),
1324                    observation_ids: vec![oid("obs:1"), oid("obs:2"), oid("obs:3")],
1325                    values: vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![9.0, 10.0]],
1326                    weights: Vec::new(),
1327                    target_names: vec!["moisture".to_string(), "protein".to_string()],
1328                },
1329                relations: SampleRelationSet {
1330                    records: vec![
1331                        relation("obs:1", "sample:1"),
1332                        relation("obs:2", "sample:1"),
1333                        relation("obs:3", "sample:2"),
1334                    ],
1335                },
1336                requested_sample_order: vec![sid("sample:1"), sid("sample:2")],
1337            },
1338        };
1339        task.validate().unwrap();
1340
1341        let result = AggregationControllerResult {
1342            schema_version: AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION,
1343            task_id: task.task_id.clone(),
1344            reduction_plan: Some(reduction_plan),
1345            output: AggregationControllerOutput::Sample {
1346                block: PredictionBlock {
1347                    prediction_id: Some("prediction:model.fold0:custom_sample_agg".to_string()),
1348                    producer_node: NodeId::new("model:pls").unwrap(),
1349                    partition: PredictionPartition::Validation,
1350                    fold_id: Some(FoldId::new("fold:0").unwrap()),
1351                    sample_ids: vec![sid("sample:1"), sid("sample:2")],
1352                    values: vec![vec![2.0, 3.0], vec![9.0, 10.0]],
1353                    target_names: vec!["moisture".to_string(), "protein".to_string()],
1354                },
1355            },
1356        };
1357
1358        result.validate_for_task(&task).unwrap();
1359    }
1360
1361    #[test]
1362    fn custom_aggregation_controller_result_must_echo_reduction_plan() {
1363        let reduction_plan = ReductionPlan {
1364            method: ReductionMethod::Custom,
1365            custom_controller: Some(crate::policy::AggregationControllerSpec {
1366                controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1367                params: serde_json::json!({}),
1368            }),
1369            ..ReductionPlan::default()
1370        };
1371        let task = AggregationControllerTask {
1372            schema_version: AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
1373            task_id: "agg-task:obs.sample.fold0".to_string(),
1374            controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1375            policy: custom_policy(PredictionLevel::Sample),
1376            reduction_plan: Some(reduction_plan),
1377            input: AggregationControllerInput::ObservationToSample {
1378                block: ObservationPredictionBlock {
1379                    prediction_id: None,
1380                    producer_node: NodeId::new("model:pls").unwrap(),
1381                    partition: PredictionPartition::Validation,
1382                    fold_id: None,
1383                    observation_ids: vec![oid("obs:1")],
1384                    values: vec![vec![1.0]],
1385                    weights: Vec::new(),
1386                    target_names: vec!["y".to_string()],
1387                },
1388                relations: SampleRelationSet {
1389                    records: vec![relation("obs:1", "sample:1")],
1390                },
1391                requested_sample_order: vec![sid("sample:1")],
1392            },
1393        };
1394        let result = AggregationControllerResult {
1395            schema_version: AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION,
1396            task_id: task.task_id.clone(),
1397            reduction_plan: None,
1398            output: AggregationControllerOutput::Sample {
1399                block: PredictionBlock {
1400                    prediction_id: None,
1401                    producer_node: NodeId::new("model:pls").unwrap(),
1402                    partition: PredictionPartition::Validation,
1403                    fold_id: None,
1404                    sample_ids: vec![sid("sample:1")],
1405                    values: vec![vec![1.0]],
1406                    target_names: vec!["y".to_string()],
1407                },
1408            },
1409        };
1410
1411        let error = result.validate_for_task(&task).unwrap_err().to_string();
1412
1413        assert!(error.contains("echo task reduction_plan"));
1414    }
1415
1416    #[test]
1417    fn custom_aggregation_controller_result_refuses_order_mismatch() {
1418        let task = AggregationControllerTask {
1419            schema_version: AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
1420            task_id: "agg-task:obs.sample.fold0".to_string(),
1421            controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1422            policy: custom_policy(PredictionLevel::Sample),
1423            reduction_plan: None,
1424            input: AggregationControllerInput::ObservationToSample {
1425                block: ObservationPredictionBlock {
1426                    prediction_id: None,
1427                    producer_node: NodeId::new("model:pls").unwrap(),
1428                    partition: PredictionPartition::Validation,
1429                    fold_id: None,
1430                    observation_ids: vec![oid("obs:1"), oid("obs:2")],
1431                    values: vec![vec![1.0], vec![2.0]],
1432                    weights: Vec::new(),
1433                    target_names: vec!["y".to_string()],
1434                },
1435                relations: SampleRelationSet {
1436                    records: vec![relation("obs:1", "sample:1"), relation("obs:2", "sample:2")],
1437                },
1438                requested_sample_order: vec![sid("sample:1"), sid("sample:2")],
1439            },
1440        };
1441        let result = AggregationControllerResult {
1442            schema_version: AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION,
1443            task_id: task.task_id.clone(),
1444            reduction_plan: None,
1445            output: AggregationControllerOutput::Sample {
1446                block: PredictionBlock {
1447                    prediction_id: None,
1448                    producer_node: NodeId::new("model:pls").unwrap(),
1449                    partition: PredictionPartition::Validation,
1450                    fold_id: None,
1451                    sample_ids: vec![sid("sample:2"), sid("sample:1")],
1452                    values: vec![vec![2.0], vec![1.0]],
1453                    target_names: vec!["y".to_string()],
1454                },
1455            },
1456        };
1457
1458        let error = result.validate_for_task(&task).unwrap_err().to_string();
1459        assert!(error.contains("requested sample order"));
1460    }
1461
1462    #[test]
1463    fn validates_custom_sample_to_group_aggregation_controller_result() {
1464        let task = AggregationControllerTask {
1465            schema_version: AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
1466            task_id: "agg-task:sample.group.fold0".to_string(),
1467            controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1468            policy: custom_policy(PredictionLevel::Group),
1469            reduction_plan: None,
1470            input: AggregationControllerInput::SampleToUnit {
1471                block: PredictionBlock {
1472                    prediction_id: Some("prediction:model.fold0".to_string()),
1473                    producer_node: NodeId::new("model:pls").unwrap(),
1474                    partition: PredictionPartition::Validation,
1475                    fold_id: Some(FoldId::new("fold:0").unwrap()),
1476                    sample_ids: vec![sid("sample:1"), sid("sample:2"), sid("sample:3")],
1477                    values: vec![vec![1.0], vec![3.0], vec![10.0]],
1478                    target_names: vec!["y".to_string()],
1479                },
1480                relations: SampleRelationSet {
1481                    records: vec![
1482                        relation_with_units("obs:1", "sample:1", "target:1", "group:left"),
1483                        relation_with_units("obs:2", "sample:2", "target:2", "group:left"),
1484                        relation_with_units("obs:3", "sample:3", "target:3", "group:right"),
1485                    ],
1486                },
1487                requested_unit_order: vec![
1488                    PredictionUnitId::Group(GroupId::new("group:left").unwrap()),
1489                    PredictionUnitId::Group(GroupId::new("group:right").unwrap()),
1490                ],
1491            },
1492        };
1493        task.validate().unwrap();
1494
1495        let result = AggregationControllerResult {
1496            schema_version: AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION,
1497            task_id: task.task_id.clone(),
1498            reduction_plan: None,
1499            output: AggregationControllerOutput::Unit {
1500                block: AggregatedPredictionBlock {
1501                    prediction_id: Some("prediction:model.fold0:custom_group_agg".to_string()),
1502                    producer_node: NodeId::new("model:pls").unwrap(),
1503                    partition: PredictionPartition::Validation,
1504                    fold_id: Some(FoldId::new("fold:0").unwrap()),
1505                    level: PredictionLevel::Group,
1506                    unit_ids: vec![
1507                        PredictionUnitId::Group(GroupId::new("group:left").unwrap()),
1508                        PredictionUnitId::Group(GroupId::new("group:right").unwrap()),
1509                    ],
1510                    values: vec![vec![2.0], vec![10.0]],
1511                    target_names: vec!["y".to_string()],
1512                },
1513            },
1514        };
1515
1516        result.validate_for_task(&task).unwrap();
1517    }
1518
1519    #[test]
1520    fn averages_repeated_observation_predictions_by_sample() {
1521        let block = ObservationPredictionBlock {
1522            prediction_id: Some("pred:oof".to_string()),
1523            producer_node: NodeId::new("model:pls").unwrap(),
1524            partition: PredictionPartition::Validation,
1525            fold_id: Some(FoldId::new("fold:0").unwrap()),
1526            observation_ids: vec![oid("obs:1a"), oid("obs:1b"), oid("obs:2a")],
1527            values: vec![vec![1.0], vec![3.0], vec![10.0]],
1528            weights: Vec::new(),
1529            target_names: vec!["y".to_string()],
1530        };
1531        let relations = SampleRelationSet {
1532            records: vec![
1533                relation("obs:1a", "sample:1"),
1534                relation("obs:1b", "sample:1"),
1535                relation("obs:2a", "sample:2"),
1536            ],
1537        };
1538
1539        let aggregated = aggregate_observation_predictions(
1540            &block,
1541            &relations,
1542            &AggregationPolicy::default(),
1543            &[sid("sample:1"), sid("sample:2")],
1544        )
1545        .unwrap();
1546
1547        assert_eq!(
1548            aggregated.sample_ids,
1549            vec![sid("sample:1"), sid("sample:2")]
1550        );
1551        assert_eq!(aggregated.values, vec![vec![2.0], vec![10.0]]);
1552    }
1553
1554    #[test]
1555    fn aggregates_relation_backed_combo_predictions_by_sample() {
1556        let relations = SampleRelationSet {
1557            records: vec![
1558                relation("obs:s1.a", "sample:1"),
1559                relation("obs:s1.b", "sample:1"),
1560                relation("obs:s2.a", "sample:2"),
1561                relation("obs:s2.b", "sample:2"),
1562                combo_relation("obs:s1.combo", "sample:1", &["obs:s1.a", "obs:s1.b"]),
1563                combo_relation("obs:s2.combo", "sample:2", &["obs:s2.a", "obs:s2.b"]),
1564            ],
1565        };
1566        let block = ObservationPredictionBlock {
1567            prediction_id: Some("pred:combo".to_string()),
1568            producer_node: NodeId::new("model:combo").unwrap(),
1569            partition: PredictionPartition::Validation,
1570            fold_id: Some(FoldId::new("fold:0").unwrap()),
1571            observation_ids: vec![oid("obs:s1.combo"), oid("obs:s2.combo")],
1572            values: vec![vec![5.0], vec![9.0]],
1573            weights: Vec::new(),
1574            target_names: vec!["y".to_string()],
1575        };
1576
1577        let aggregated = aggregate_observation_predictions(
1578            &block,
1579            &relations,
1580            &AggregationPolicy::default(),
1581            &[sid("sample:1"), sid("sample:2")],
1582        )
1583        .unwrap();
1584
1585        assert_eq!(aggregated.values, vec![vec![5.0], vec![9.0]]);
1586    }
1587
1588    #[test]
1589    fn robust_mean_trims_extreme_repeated_predictions() {
1590        let observations = (0..10)
1591            .map(|idx| format!("obs:s1.{idx}"))
1592            .collect::<Vec<_>>();
1593        let relations = SampleRelationSet {
1594            records: observations
1595                .iter()
1596                .map(|observation| relation(observation, "sample:1"))
1597                .collect(),
1598        };
1599        let block = ObservationPredictionBlock {
1600            prediction_id: Some("pred:robust".to_string()),
1601            producer_node: NodeId::new("model:pls").unwrap(),
1602            partition: PredictionPartition::Validation,
1603            fold_id: Some(FoldId::new("fold:0").unwrap()),
1604            observation_ids: observations
1605                .iter()
1606                .map(|observation| oid(observation))
1607                .collect(),
1608            values: vec![
1609                vec![0.0],
1610                vec![1.0],
1611                vec![2.0],
1612                vec![3.0],
1613                vec![4.0],
1614                vec![5.0],
1615                vec![6.0],
1616                vec![7.0],
1617                vec![8.0],
1618                vec![100.0],
1619            ],
1620            weights: Vec::new(),
1621            target_names: vec!["y".to_string()],
1622        };
1623
1624        let aggregated = aggregate_observation_predictions(
1625            &block,
1626            &relations,
1627            &AggregationPolicy {
1628                method: AggregationMethod::RobustMean,
1629                ..AggregationPolicy::default()
1630            },
1631            &[sid("sample:1")],
1632        )
1633        .unwrap();
1634
1635        assert_eq!(aggregated.values, vec![vec![4.5]]);
1636    }
1637
1638    #[test]
1639    fn exclude_outliers_requires_custom_controller_path() {
1640        let relations = SampleRelationSet {
1641            records: vec![relation("obs:1", "sample:1")],
1642        };
1643        let block = ObservationPredictionBlock {
1644            prediction_id: None,
1645            producer_node: NodeId::new("model:pls").unwrap(),
1646            partition: PredictionPartition::Validation,
1647            fold_id: None,
1648            observation_ids: vec![oid("obs:1")],
1649            values: vec![vec![1.0]],
1650            weights: Vec::new(),
1651            target_names: vec!["y".to_string()],
1652        };
1653
1654        let error = aggregate_observation_predictions(
1655            &block,
1656            &relations,
1657            &AggregationPolicy {
1658                method: AggregationMethod::ExcludeOutliers,
1659                ..AggregationPolicy::default()
1660            },
1661            &[sid("sample:1")],
1662        )
1663        .unwrap_err()
1664        .to_string();
1665
1666        assert!(error.contains("custom aggregation controller"));
1667    }
1668
1669    #[test]
1670    fn aggregates_repeated_predictions_with_median_vote_and_weights() {
1671        let relations = SampleRelationSet {
1672            records: vec![
1673                relation("obs:1a", "sample:1"),
1674                relation("obs:1b", "sample:1"),
1675                relation("obs:1c", "sample:1"),
1676                relation("obs:2a", "sample:2"),
1677                relation("obs:2b", "sample:2"),
1678            ],
1679        };
1680        let base_block = ObservationPredictionBlock {
1681            prediction_id: Some("pred:oof".to_string()),
1682            producer_node: NodeId::new("model:pls").unwrap(),
1683            partition: PredictionPartition::Validation,
1684            fold_id: Some(FoldId::new("fold:0").unwrap()),
1685            observation_ids: vec![
1686                oid("obs:1a"),
1687                oid("obs:1b"),
1688                oid("obs:1c"),
1689                oid("obs:2a"),
1690                oid("obs:2b"),
1691            ],
1692            values: vec![
1693                vec![1.0, 0.0],
1694                vec![5.0, 1.0],
1695                vec![9.0, 1.0],
1696                vec![10.0, 2.0],
1697                vec![30.0, 3.0],
1698            ],
1699            weights: Vec::new(),
1700            target_names: vec!["regression".to_string(), "class".to_string()],
1701        };
1702        let sample_order = [sid("sample:1"), sid("sample:2")];
1703
1704        let median_policy = AggregationPolicy {
1705            method: AggregationMethod::Median,
1706            ..AggregationPolicy::default()
1707        };
1708        let median = aggregate_observation_predictions(
1709            &base_block,
1710            &relations,
1711            &median_policy,
1712            &sample_order,
1713        )
1714        .unwrap();
1715        assert_eq!(median.values, vec![vec![5.0, 1.0], vec![20.0, 2.5]]);
1716
1717        let vote_policy = AggregationPolicy {
1718            method: AggregationMethod::Vote,
1719            ..AggregationPolicy::default()
1720        };
1721        let vote =
1722            aggregate_observation_predictions(&base_block, &relations, &vote_policy, &sample_order)
1723                .unwrap();
1724        assert_eq!(vote.values, vec![vec![1.0, 1.0], vec![10.0, 2.0]]);
1725
1726        let mut weighted_block = base_block;
1727        weighted_block.weights = vec![1.0, 1.0, 2.0, 1.0, 3.0];
1728        let weighted_policy = AggregationPolicy {
1729            method: AggregationMethod::WeightedMean,
1730            weights: AggregationWeights::ControllerEmitted,
1731            ..AggregationPolicy::default()
1732        };
1733        let weighted = aggregate_observation_predictions(
1734            &weighted_block,
1735            &relations,
1736            &weighted_policy,
1737            &sample_order,
1738        )
1739        .unwrap();
1740        assert_eq!(weighted.values, vec![vec![6.0, 0.75], vec![25.0, 2.75]]);
1741    }
1742
1743    #[test]
1744    fn refuses_incompatible_observation_weight_contracts() {
1745        let relations = SampleRelationSet {
1746            records: vec![
1747                relation("obs:1a", "sample:1"),
1748                relation("obs:1b", "sample:1"),
1749            ],
1750        };
1751        let block = ObservationPredictionBlock {
1752            prediction_id: None,
1753            producer_node: NodeId::new("model:pls").unwrap(),
1754            partition: PredictionPartition::Validation,
1755            fold_id: None,
1756            observation_ids: vec![oid("obs:1a"), oid("obs:1b")],
1757            values: vec![vec![1.0], vec![2.0]],
1758            weights: vec![1.0, 2.0],
1759            target_names: vec!["y".to_string()],
1760        };
1761
1762        let mean_error = aggregate_observation_predictions(
1763            &block,
1764            &relations,
1765            &AggregationPolicy::default(),
1766            &[sid("sample:1")],
1767        )
1768        .unwrap_err()
1769        .to_string();
1770        assert!(
1771            mean_error.contains("non-weighted aggregation"),
1772            "unexpected mean error: {mean_error}"
1773        );
1774
1775        let mut missing_weights_block = block;
1776        missing_weights_block.weights.clear();
1777        let weighted_error = aggregate_observation_predictions(
1778            &missing_weights_block,
1779            &relations,
1780            &AggregationPolicy {
1781                method: AggregationMethod::WeightedMean,
1782                weights: AggregationWeights::ControllerEmitted,
1783                ..AggregationPolicy::default()
1784            },
1785            &[sid("sample:1")],
1786        )
1787        .unwrap_err()
1788        .to_string();
1789        assert!(
1790            weighted_error.contains("requires one weight per observation"),
1791            "unexpected weighted error: {weighted_error}"
1792        );
1793    }
1794
1795    #[test]
1796    fn aggregates_sample_predictions_to_target_and_group_units() {
1797        let relations = SampleRelationSet {
1798            records: vec![
1799                relation_with_units("obs:s1:a", "sample:1", "target:a", "group:left"),
1800                relation_with_units("obs:s1:b", "sample:1", "target:a", "group:left"),
1801                relation_with_units("obs:s2:a", "sample:2", "target:a", "group:left"),
1802                relation_with_units("obs:s3:a", "sample:3", "target:b", "group:right"),
1803            ],
1804        };
1805        let block = PredictionBlock {
1806            prediction_id: Some("pred:sample".to_string()),
1807            producer_node: NodeId::new("model:pls").unwrap(),
1808            partition: PredictionPartition::Validation,
1809            fold_id: Some(FoldId::new("fold:0").unwrap()),
1810            sample_ids: vec![sid("sample:1"), sid("sample:2"), sid("sample:3")],
1811            values: vec![vec![10.0], vec![4.0], vec![30.0]],
1812            target_names: vec!["y".to_string()],
1813        };
1814
1815        let target_policy = AggregationPolicy {
1816            aggregation_level: PredictionLevel::Target,
1817            method: AggregationMethod::Mean,
1818            ..AggregationPolicy::default()
1819        };
1820        let by_target = aggregate_sample_predictions_by_unit(
1821            &block,
1822            &relations,
1823            &target_policy,
1824            &[
1825                PredictionUnitId::Target(TargetId::new("target:a").unwrap()),
1826                PredictionUnitId::Target(TargetId::new("target:b").unwrap()),
1827            ],
1828        )
1829        .unwrap();
1830        assert_eq!(by_target.level, PredictionLevel::Target);
1831        assert_eq!(by_target.values, vec![vec![7.0], vec![30.0]]);
1832
1833        let group_policy = AggregationPolicy {
1834            aggregation_level: PredictionLevel::Group,
1835            method: AggregationMethod::WeightedMean,
1836            weights: AggregationWeights::RepetitionCount,
1837            ..AggregationPolicy::default()
1838        };
1839        let by_group = aggregate_sample_predictions_by_unit(
1840            &block,
1841            &relations,
1842            &group_policy,
1843            &[
1844                PredictionUnitId::Group(GroupId::new("group:left").unwrap()),
1845                PredictionUnitId::Group(GroupId::new("group:right").unwrap()),
1846            ],
1847        )
1848        .unwrap();
1849        assert_eq!(by_group.level, PredictionLevel::Group);
1850        assert_eq!(by_group.values, vec![vec![8.0], vec![30.0]]);
1851    }
1852
1853    #[test]
1854    fn refuses_target_group_aggregation_without_relation_units() {
1855        let relations = SampleRelationSet {
1856            records: vec![SampleRelation::new(oid("obs:1"), sid("sample:1"))],
1857        };
1858        let block = PredictionBlock {
1859            prediction_id: None,
1860            producer_node: NodeId::new("model:pls").unwrap(),
1861            partition: PredictionPartition::Validation,
1862            fold_id: None,
1863            sample_ids: vec![sid("sample:1")],
1864            values: vec![vec![1.0]],
1865            target_names: vec!["y".to_string()],
1866        };
1867
1868        let error = aggregate_sample_predictions_by_unit(
1869            &block,
1870            &relations,
1871            &AggregationPolicy {
1872                aggregation_level: PredictionLevel::Target,
1873                method: AggregationMethod::Mean,
1874                ..AggregationPolicy::default()
1875            },
1876            &[PredictionUnitId::Target(
1877                TargetId::new("target:missing").unwrap(),
1878            )],
1879        )
1880        .unwrap_err()
1881        .to_string();
1882        assert!(
1883            error.contains("missing target id"),
1884            "unexpected target aggregation error: {error}"
1885        );
1886    }
1887
1888    #[test]
1889    fn refuses_missing_observation_relation() {
1890        let block = ObservationPredictionBlock {
1891            prediction_id: None,
1892            producer_node: NodeId::new("model:pls").unwrap(),
1893            partition: PredictionPartition::Validation,
1894            fold_id: None,
1895            observation_ids: vec![oid("obs:missing")],
1896            values: vec![vec![1.0]],
1897            weights: Vec::new(),
1898            target_names: vec!["y".to_string()],
1899        };
1900
1901        assert!(aggregate_observation_predictions(
1902            &block,
1903            &SampleRelationSet::default(),
1904            &AggregationPolicy::default(),
1905            &[sid("sample:1")]
1906        )
1907        .is_err());
1908    }
1909}