1use std::collections::{BTreeMap, BTreeSet};
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{DagMlError, Result};
7use crate::ids::{ControllerId, FoldId, GroupId, NodeId, ObservationId, SampleId, TargetId};
8use crate::oof::{PredictionBlock, PredictionPartition};
9use crate::policy::{
10 AggregationMethod, AggregationPolicy, AggregationWeights, PredictionLevel, ReductionAxis,
11 ReductionMethod, ReductionPlan,
12};
13use crate::relation::{EntityUnitLevel, SampleRelationSet};
14
/// Schema version accepted for aggregation controller tasks.
pub const AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION: u32 = 1;
/// Schema version accepted for aggregation controller results.
pub const AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION: u32 = 1;

/// Schema identifier string for aggregation controller tasks.
pub const AGGREGATION_CONTROLLER_TASK_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/aggregation_controller_task.v1.schema.json";
/// Schema identifier string for aggregation controller results.
pub const AGGREGATION_CONTROLLER_RESULT_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/aggregation_controller_result.v1.schema.json";

/// Trim fraction handed to `robust_mean` by the built-in aggregation functions.
const DEFAULT_ROBUST_TRIM_FRACTION: f64 = 0.1;
22
23#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
24pub struct ObservationPredictionBlock {
25 #[serde(default)]
26 pub prediction_id: Option<String>,
27 pub producer_node: NodeId,
28 pub partition: PredictionPartition,
29 pub fold_id: Option<FoldId>,
30 pub observation_ids: Vec<ObservationId>,
31 pub values: Vec<Vec<f64>>,
32 #[serde(default, skip_serializing_if = "Vec::is_empty")]
33 pub weights: Vec<f64>,
34 #[serde(default)]
35 pub target_names: Vec<String>,
36}
37
38#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
39#[serde(rename_all = "snake_case", tag = "level", content = "id")]
40pub enum PredictionUnitId {
41 Sample(SampleId),
42 Target(TargetId),
43 Group(GroupId),
44}
45
46impl PredictionUnitId {
47 pub fn level(&self) -> PredictionLevel {
48 match self {
49 Self::Sample(_) => PredictionLevel::Sample,
50 Self::Target(_) => PredictionLevel::Target,
51 Self::Group(_) => PredictionLevel::Group,
52 }
53 }
54}
55
56impl fmt::Display for PredictionUnitId {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 match self {
59 Self::Sample(id) => write!(f, "sample:{id}"),
60 Self::Target(id) => write!(f, "target:{id}"),
61 Self::Group(id) => write!(f, "group:{id}"),
62 }
63 }
64}
65
66#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
67pub struct AggregatedPredictionBlock {
68 #[serde(default)]
69 pub prediction_id: Option<String>,
70 pub producer_node: NodeId,
71 pub partition: PredictionPartition,
72 pub fold_id: Option<FoldId>,
73 pub level: PredictionLevel,
74 pub unit_ids: Vec<PredictionUnitId>,
75 pub values: Vec<Vec<f64>>,
76 #[serde(default)]
77 pub target_names: Vec<String>,
78}
79
80impl AggregatedPredictionBlock {
81 pub fn validate_shape(&self) -> Result<usize> {
82 if self.unit_ids.len() != self.values.len() {
83 return Err(DagMlError::OofValidation(format!(
84 "producer `{}` has {} aggregated unit ids but {} prediction rows",
85 self.producer_node,
86 self.unit_ids.len(),
87 self.values.len()
88 )));
89 }
90 if self
91 .unit_ids
92 .iter()
93 .any(|unit_id| unit_id.level() != self.level)
94 {
95 return Err(DagMlError::OofValidation(format!(
96 "producer `{}` emitted aggregated units outside level {:?}",
97 self.producer_node, self.level
98 )));
99 }
100 let unique = self.unit_ids.iter().collect::<BTreeSet<_>>();
101 if unique.len() != self.unit_ids.len() {
102 return Err(DagMlError::OofValidation(format!(
103 "producer `{}` emitted duplicate aggregated unit ids",
104 self.producer_node
105 )));
106 }
107 let width = self.values.first().map_or(0, Vec::len);
108 if width == 0 {
109 return Err(DagMlError::OofValidation(format!(
110 "producer `{}` emitted empty aggregated prediction rows",
111 self.producer_node
112 )));
113 }
114 if self.values.iter().any(|row| row.len() != width) {
115 return Err(DagMlError::OofValidation(format!(
116 "producer `{}` emitted ragged aggregated prediction rows",
117 self.producer_node
118 )));
119 }
120 if self.values.iter().flatten().any(|value| !value.is_finite()) {
121 return Err(DagMlError::OofValidation(format!(
122 "producer `{}` emitted non-finite aggregated prediction values",
123 self.producer_node
124 )));
125 }
126 if !self.target_names.is_empty() && self.target_names.len() != width {
127 return Err(DagMlError::OofValidation(format!(
128 "producer `{}` has {} aggregated target names for width {}",
129 self.producer_node,
130 self.target_names.len(),
131 width
132 )));
133 }
134 Ok(width)
135 }
136}
137
138impl ObservationPredictionBlock {
139 pub fn validate_shape(&self) -> Result<usize> {
140 if self.observation_ids.len() != self.values.len() {
141 return Err(DagMlError::OofValidation(format!(
142 "producer `{}` has {} observation ids but {} prediction rows",
143 self.producer_node,
144 self.observation_ids.len(),
145 self.values.len()
146 )));
147 }
148 let width = self.values.first().map_or(0, Vec::len);
149 if width == 0 {
150 return Err(DagMlError::OofValidation(format!(
151 "producer `{}` emitted empty observation prediction rows",
152 self.producer_node
153 )));
154 }
155 if self.values.iter().any(|row| row.len() != width) {
156 return Err(DagMlError::OofValidation(format!(
157 "producer `{}` emitted ragged observation prediction rows",
158 self.producer_node
159 )));
160 }
161 if self.values.iter().flatten().any(|value| !value.is_finite()) {
162 return Err(DagMlError::OofValidation(format!(
163 "producer `{}` emitted non-finite observation prediction values",
164 self.producer_node
165 )));
166 }
167 if !self.weights.is_empty() {
168 if self.weights.len() != self.observation_ids.len() {
169 return Err(DagMlError::OofValidation(format!(
170 "producer `{}` has {} observation weights but {} observation ids",
171 self.producer_node,
172 self.weights.len(),
173 self.observation_ids.len()
174 )));
175 }
176 if self
177 .weights
178 .iter()
179 .any(|weight| !weight.is_finite() || *weight < 0.0)
180 {
181 return Err(DagMlError::OofValidation(format!(
182 "producer `{}` emitted non-finite or negative observation weights",
183 self.producer_node
184 )));
185 }
186 }
187 if !self.target_names.is_empty() && self.target_names.len() != width {
188 return Err(DagMlError::OofValidation(format!(
189 "producer `{}` has {} target names for width {}",
190 self.producer_node,
191 self.target_names.len(),
192 width
193 )));
194 }
195 let unique = self.observation_ids.iter().collect::<BTreeSet<_>>();
196 if unique.len() != self.observation_ids.len() {
197 return Err(DagMlError::OofValidation(format!(
198 "producer `{}` emitted duplicate observation predictions",
199 self.producer_node
200 )));
201 }
202 Ok(width)
203 }
204}
205
206#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
207pub struct AggregationControllerTask {
208 #[serde(default = "default_aggregation_controller_task_schema_version")]
209 pub schema_version: u32,
210 pub task_id: String,
211 pub controller_id: ControllerId,
212 pub policy: AggregationPolicy,
213 #[serde(default, skip_serializing_if = "Option::is_none")]
214 pub reduction_plan: Option<ReductionPlan>,
215 pub input: AggregationControllerInput,
216}
217
218#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
219#[serde(tag = "input_kind", rename_all = "snake_case")]
220pub enum AggregationControllerInput {
221 ObservationToSample {
222 block: ObservationPredictionBlock,
223 relations: SampleRelationSet,
224 requested_sample_order: Vec<SampleId>,
225 },
226 SampleToUnit {
227 block: PredictionBlock,
228 relations: SampleRelationSet,
229 requested_unit_order: Vec<PredictionUnitId>,
230 },
231}
232
233#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
234pub struct AggregationControllerResult {
235 #[serde(default = "default_aggregation_controller_result_schema_version")]
236 pub schema_version: u32,
237 pub task_id: String,
238 #[serde(default, skip_serializing_if = "Option::is_none")]
239 pub reduction_plan: Option<ReductionPlan>,
240 pub output: AggregationControllerOutput,
241}
242
243#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
244#[serde(tag = "output_kind", rename_all = "snake_case")]
245pub enum AggregationControllerOutput {
246 Sample { block: PredictionBlock },
247 Unit { block: AggregatedPredictionBlock },
248}
249
250impl AggregationControllerTask {
251 pub fn validate(&self) -> Result<()> {
252 if self.schema_version != AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION {
253 return Err(DagMlError::OofValidation(format!(
254 "aggregation controller task `{}` uses unsupported schema_version {}",
255 self.task_id, self.schema_version
256 )));
257 }
258 if self.task_id.trim().is_empty() {
259 return Err(DagMlError::OofValidation(
260 "aggregation controller task_id is empty".to_string(),
261 ));
262 }
263 self.policy.validate()?;
264 if self.policy.method != AggregationMethod::CustomController {
265 return Err(DagMlError::OofValidation(format!(
266 "aggregation controller task `{}` must use custom_controller method",
267 self.task_id
268 )));
269 }
270 let controller = self
271 .policy
272 .custom_controller
273 .as_ref()
274 .expect("custom_controller policy validation requires controller spec");
275 if controller.controller_id != self.controller_id {
276 return Err(DagMlError::OofValidation(format!(
277 "aggregation controller task `{}` targets controller `{}` but policy targets `{}`",
278 self.task_id, self.controller_id, controller.controller_id
279 )));
280 }
281 if let Some(reduction_plan) = &self.reduction_plan {
282 validate_aggregation_controller_reduction_plan(
283 reduction_plan,
284 &self.policy,
285 &self.input,
286 )?;
287 }
288 match &self.input {
289 AggregationControllerInput::ObservationToSample {
290 block,
291 relations,
292 requested_sample_order,
293 } => validate_aggregation_controller_observation_input(
294 block,
295 relations,
296 &self.policy,
297 requested_sample_order,
298 ),
299 AggregationControllerInput::SampleToUnit {
300 block,
301 relations,
302 requested_unit_order,
303 } => validate_aggregation_controller_sample_input(
304 block,
305 relations,
306 &self.policy,
307 requested_unit_order,
308 ),
309 }
310 }
311}
312
313impl AggregationControllerResult {
314 pub fn validate_for_task(&self, task: &AggregationControllerTask) -> Result<()> {
315 task.validate()?;
316 if self.schema_version != AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION {
317 return Err(DagMlError::OofValidation(format!(
318 "aggregation controller result `{}` uses unsupported schema_version {}",
319 self.task_id, self.schema_version
320 )));
321 }
322 if self.task_id != task.task_id {
323 return Err(DagMlError::OofValidation(format!(
324 "aggregation controller result task_id `{}` does not match task `{}`",
325 self.task_id, task.task_id
326 )));
327 }
328 validate_aggregation_controller_result_reduction_plan(task, self)?;
329 match (&task.input, &self.output) {
330 (
331 AggregationControllerInput::ObservationToSample {
332 block: input_block,
333 requested_sample_order,
334 ..
335 },
336 AggregationControllerOutput::Sample { block },
337 ) => validate_aggregation_controller_sample_output(
338 input_block,
339 requested_sample_order,
340 block,
341 ),
342 (
343 AggregationControllerInput::SampleToUnit {
344 block: input_block,
345 requested_unit_order,
346 ..
347 },
348 AggregationControllerOutput::Unit { block },
349 ) => validate_aggregation_controller_unit_output(
350 input_block,
351 requested_unit_order,
352 task.policy.aggregation_level,
353 block,
354 ),
355 (AggregationControllerInput::ObservationToSample { .. }, _) => {
356 Err(DagMlError::OofValidation(format!(
357 "aggregation controller result `{}` must return sample output for observation input",
358 self.task_id
359 )))
360 }
361 (AggregationControllerInput::SampleToUnit { .. }, _) => {
362 Err(DagMlError::OofValidation(format!(
363 "aggregation controller result `{}` must return unit output for sample input",
364 self.task_id
365 )))
366 }
367 }
368 }
369}
370
371fn validate_aggregation_controller_reduction_plan(
372 plan: &ReductionPlan,
373 policy: &AggregationPolicy,
374 input: &AggregationControllerInput,
375) -> Result<()> {
376 plan.validate()
377 .map_err(|error| DagMlError::OofValidation(error.to_string()))?;
378 if plan.method != ReductionMethod::from(policy.method) {
379 return Err(DagMlError::OofValidation(format!(
380 "reduction plan method {:?} does not match aggregation policy method {:?}",
381 plan.method, policy.method
382 )));
383 }
384 if plan.weight_source != policy.weights {
385 return Err(DagMlError::OofValidation(format!(
386 "reduction plan weight_source {:?} does not match aggregation policy weights {:?}",
387 plan.weight_source, policy.weights
388 )));
389 }
390 if plan.method == ReductionMethod::Custom {
391 let plan_controller = plan
392 .custom_controller
393 .as_ref()
394 .expect("reduction plan validation requires custom controller");
395 let policy_controller = policy
396 .custom_controller
397 .as_ref()
398 .expect("aggregation policy validation requires custom controller");
399 if plan_controller.controller_id != policy_controller.controller_id {
400 return Err(DagMlError::OofValidation(format!(
401 "reduction plan controller `{}` does not match aggregation policy controller `{}`",
402 plan_controller.controller_id, policy_controller.controller_id
403 )));
404 }
405 }
406 if plan.axis != ReductionAxis::Unit {
407 return Err(DagMlError::OofValidation(format!(
408 "aggregation controller reduction plan axis {:?} is not supported for unit aggregation tasks",
409 plan.axis
410 )));
411 }
412 match input {
413 AggregationControllerInput::ObservationToSample { .. } => {
414 if !matches!(
415 plan.input_unit_level,
416 EntityUnitLevel::Observation | EntityUnitLevel::Combo
417 ) {
418 return Err(DagMlError::OofValidation(format!(
419 "observation aggregation reduction plan input_unit_level {:?} is invalid",
420 plan.input_unit_level
421 )));
422 }
423 if plan.output_unit_level != EntityUnitLevel::PhysicalSample {
424 return Err(DagMlError::OofValidation(format!(
425 "observation aggregation reduction plan output_unit_level {:?} must be physical_sample",
426 plan.output_unit_level
427 )));
428 }
429 if policy.aggregation_level != PredictionLevel::Sample {
430 return Err(DagMlError::OofValidation(format!(
431 "observation aggregation reduction plan must output sample predictions, got {:?}",
432 policy.aggregation_level
433 )));
434 }
435 }
436 AggregationControllerInput::SampleToUnit { .. } => {
437 if plan.input_unit_level != EntityUnitLevel::PhysicalSample {
438 return Err(DagMlError::OofValidation(format!(
439 "sample aggregation reduction plan input_unit_level {:?} must be physical_sample",
440 plan.input_unit_level
441 )));
442 }
443 if plan.output_unit_level != EntityUnitLevel::PhysicalSample
444 || policy.aggregation_level != PredictionLevel::Sample
445 {
446 return Err(DagMlError::OofValidation(
447 "sample aggregation reduction plans currently support only physical_sample output; target/group aggregation remains available without a ReductionPlan".to_string(),
448 ));
449 }
450 }
451 }
452 Ok(())
453}
454
455fn validate_aggregation_controller_result_reduction_plan(
456 task: &AggregationControllerTask,
457 result: &AggregationControllerResult,
458) -> Result<()> {
459 match (&task.reduction_plan, &result.reduction_plan) {
460 (Some(task_plan), Some(result_plan)) if task_plan == result_plan => Ok(()),
461 (Some(_), Some(_)) => Err(DagMlError::OofValidation(format!(
462 "aggregation controller result `{}` reduction_plan does not match task reduction_plan",
463 result.task_id
464 ))),
465 (Some(_), None) => Err(DagMlError::OofValidation(format!(
466 "aggregation controller result `{}` must echo task reduction_plan",
467 result.task_id
468 ))),
469 (None, Some(_)) => Err(DagMlError::OofValidation(format!(
470 "aggregation controller result `{}` declares reduction_plan but task does not",
471 result.task_id
472 ))),
473 (None, None) => Ok(()),
474 }
475}
476
477fn validate_aggregation_controller_observation_input(
478 block: &ObservationPredictionBlock,
479 relations: &SampleRelationSet,
480 policy: &AggregationPolicy,
481 requested_sample_order: &[SampleId],
482) -> Result<()> {
483 block.validate_shape()?;
484 relations.validate()?;
485 if policy.aggregation_level != PredictionLevel::Sample {
486 return Err(DagMlError::OofValidation(format!(
487 "observation aggregation controller task must output sample predictions, got {:?}",
488 policy.aggregation_level
489 )));
490 }
491 validate_unique_order(requested_sample_order, "requested_sample_order")?;
492 if matches!(
493 policy.weights,
494 AggregationWeights::ControllerEmitted | AggregationWeights::Quality
495 ) && block.weights.is_empty()
496 {
497 return Err(DagMlError::OofValidation(format!(
498 "aggregation controller task with {:?} weights requires observation weights",
499 policy.weights
500 )));
501 }
502 let requested = requested_sample_order.iter().collect::<BTreeSet<_>>();
503 let mut covered = BTreeSet::new();
504 for observation_id in &block.observation_ids {
505 let sample_id = relations
506 .sample_for_observation(observation_id)
507 .ok_or_else(|| {
508 DagMlError::OofValidation(format!(
509 "observation prediction `{observation_id}` has no sample relation"
510 ))
511 })?;
512 if !requested.contains(sample_id) {
513 return Err(DagMlError::OofValidation(format!(
514 "observation prediction `{observation_id}` maps to unexpected sample `{sample_id}`"
515 )));
516 }
517 covered.insert(sample_id);
518 }
519 for sample_id in requested_sample_order {
520 if !covered.contains(sample_id) {
521 return Err(DagMlError::OofValidation(format!(
522 "sample `{sample_id}` has no observation predictions for aggregation controller task"
523 )));
524 }
525 }
526 Ok(())
527}
528
529fn validate_aggregation_controller_sample_input(
530 block: &PredictionBlock,
531 relations: &SampleRelationSet,
532 policy: &AggregationPolicy,
533 requested_unit_order: &[PredictionUnitId],
534) -> Result<()> {
535 validate_sample_prediction_block(block)?;
536 relations.validate()?;
537 if policy.aggregation_level == PredictionLevel::Observation {
538 return Err(DagMlError::OofValidation(
539 "sample aggregation controller task cannot output observation-level predictions"
540 .to_string(),
541 ));
542 }
543 if matches!(
544 policy.weights,
545 AggregationWeights::ControllerEmitted | AggregationWeights::Quality
546 ) {
547 return Err(DagMlError::OofValidation(format!(
548 "sample aggregation controller task cannot use {:?} weights without sample weights",
549 policy.weights
550 )));
551 }
552 validate_unique_order(requested_unit_order, "requested_unit_order")?;
553 if requested_unit_order
554 .iter()
555 .any(|unit_id| unit_id.level() != policy.aggregation_level)
556 {
557 return Err(DagMlError::OofValidation(format!(
558 "aggregation controller requested units do not match level {:?}",
559 policy.aggregation_level
560 )));
561 }
562 let requested = requested_unit_order.iter().collect::<BTreeSet<_>>();
563 let mut covered = BTreeSet::new();
564 for sample_id in &block.sample_ids {
565 let unit_id = unit_for_sample(relations, policy.aggregation_level, sample_id)?;
566 if !requested.contains(&unit_id) {
567 return Err(DagMlError::OofValidation(format!(
568 "sample prediction `{sample_id}` maps to unexpected aggregation unit `{unit_id}`"
569 )));
570 }
571 covered.insert(unit_id);
572 }
573 for unit_id in requested_unit_order {
574 if !covered.contains(unit_id) {
575 return Err(DagMlError::OofValidation(format!(
576 "aggregation unit `{unit_id}` has no sample predictions for aggregation controller task"
577 )));
578 }
579 }
580 Ok(())
581}
582
583fn validate_aggregation_controller_sample_output(
584 input_block: &ObservationPredictionBlock,
585 requested_sample_order: &[SampleId],
586 block: &PredictionBlock,
587) -> Result<()> {
588 validate_sample_prediction_block(block)?;
589 if block.producer_node != input_block.producer_node
590 || block.partition != input_block.partition
591 || block.fold_id != input_block.fold_id
592 {
593 return Err(DagMlError::OofValidation(format!(
594 "aggregation controller sample output for `{}` does not preserve producer, partition and fold",
595 input_block.producer_node
596 )));
597 }
598 if block.target_names != input_block.target_names {
599 return Err(DagMlError::OofValidation(format!(
600 "aggregation controller sample output for `{}` does not preserve target names",
601 input_block.producer_node
602 )));
603 }
604 if block.sample_ids != requested_sample_order {
605 return Err(DagMlError::OofValidation(format!(
606 "aggregation controller sample output for `{}` does not match requested sample order",
607 input_block.producer_node
608 )));
609 }
610 Ok(())
611}
612
613fn validate_aggregation_controller_unit_output(
614 input_block: &PredictionBlock,
615 requested_unit_order: &[PredictionUnitId],
616 expected_level: PredictionLevel,
617 block: &AggregatedPredictionBlock,
618) -> Result<()> {
619 block.validate_shape()?;
620 if block.producer_node != input_block.producer_node
621 || block.partition != input_block.partition
622 || block.fold_id != input_block.fold_id
623 {
624 return Err(DagMlError::OofValidation(format!(
625 "aggregation controller unit output for `{}` does not preserve producer, partition and fold",
626 input_block.producer_node
627 )));
628 }
629 if block.target_names != input_block.target_names {
630 return Err(DagMlError::OofValidation(format!(
631 "aggregation controller unit output for `{}` does not preserve target names",
632 input_block.producer_node
633 )));
634 }
635 if block.level != expected_level {
636 return Err(DagMlError::OofValidation(format!(
637 "aggregation controller unit output for `{}` has level {:?}, expected {:?}",
638 input_block.producer_node, block.level, expected_level
639 )));
640 }
641 if block.unit_ids != requested_unit_order {
642 return Err(DagMlError::OofValidation(format!(
643 "aggregation controller unit output for `{}` does not match requested unit order",
644 input_block.producer_node
645 )));
646 }
647 Ok(())
648}
649
650fn validate_unique_order<T>(values: &[T], label: &str) -> Result<()>
651where
652 T: Ord,
653{
654 if values.is_empty() {
655 return Err(DagMlError::OofValidation(format!(
656 "aggregation controller {label} is empty"
657 )));
658 }
659 let unique = values.iter().collect::<BTreeSet<_>>();
660 if unique.len() != values.len() {
661 return Err(DagMlError::OofValidation(format!(
662 "aggregation controller {label} contains duplicates"
663 )));
664 }
665 Ok(())
666}
667
668pub fn aggregate_observation_predictions(
669 block: &ObservationPredictionBlock,
670 relations: &SampleRelationSet,
671 policy: &AggregationPolicy,
672 requested_sample_order: &[SampleId],
673) -> Result<PredictionBlock> {
674 let width = block.validate_shape()?;
675 relations.validate()?;
676 policy.validate()?;
677 if requested_sample_order.is_empty() {
678 return Err(DagMlError::OofValidation(
679 "aggregation requested_sample_order is empty".to_string(),
680 ));
681 }
682 let requested = requested_sample_order.iter().collect::<BTreeSet<_>>();
683 if requested.len() != requested_sample_order.len() {
684 return Err(DagMlError::OofValidation(
685 "aggregation requested_sample_order contains duplicates".to_string(),
686 ));
687 }
688 if policy.aggregation_level != PredictionLevel::Sample {
689 return Err(DagMlError::OofValidation(format!(
690 "observation aggregation currently supports sample-level output, got {:?}",
691 policy.aggregation_level
692 )));
693 }
694 if policy.method == AggregationMethod::WeightedMean
695 && policy.weights == AggregationWeights::None
696 {
697 return Err(DagMlError::OofValidation(
698 "weighted_mean aggregation requires an explicit weights policy".to_string(),
699 ));
700 }
701 if policy.method != AggregationMethod::WeightedMean
702 && policy.weights != AggregationWeights::None
703 {
704 return Err(DagMlError::OofValidation(format!(
705 "aggregation weights {:?} are only valid with weighted_mean",
706 policy.weights
707 )));
708 }
709 if !block.weights.is_empty() && policy.method != AggregationMethod::WeightedMean {
710 return Err(DagMlError::OofValidation(format!(
711 "producer `{}` supplied observation weights for non-weighted aggregation {:?}",
712 block.producer_node, policy.method
713 )));
714 }
715
716 let store_rows = matches!(
717 policy.method,
718 AggregationMethod::Median | AggregationMethod::Vote | AggregationMethod::RobustMean
719 );
720 let mut accumulators = requested_sample_order
721 .iter()
722 .cloned()
723 .map(|sample_id| (sample_id, SampleAccumulator::new(width, store_rows)))
724 .collect::<BTreeMap<_, _>>();
725
726 for (row_idx, (observation_id, row)) in block
727 .observation_ids
728 .iter()
729 .zip(block.values.iter())
730 .enumerate()
731 {
732 let sample_id = relations
733 .sample_for_observation(observation_id)
734 .ok_or_else(|| {
735 DagMlError::OofValidation(format!(
736 "observation prediction `{observation_id}` has no sample relation"
737 ))
738 })?;
739 if !requested.contains(sample_id) {
740 return Err(DagMlError::OofValidation(format!(
741 "observation prediction `{observation_id}` maps to unexpected sample `{sample_id}`"
742 )));
743 }
744 let accumulator = accumulators
745 .get_mut(sample_id)
746 .expect("requested sample accumulator exists");
747 let weight = observation_weight(block, policy, row_idx)?;
748 accumulator.push(row, weight);
749 }
750
751 let values = requested_sample_order
752 .iter()
753 .map(|sample_id| {
754 let accumulator = accumulators
755 .get(sample_id)
756 .expect("requested sample accumulator exists");
757 if accumulator.count == 0 {
758 return Err(DagMlError::OofValidation(format!(
759 "sample `{sample_id}` has no observation predictions to aggregate"
760 )));
761 }
762 match policy.method {
763 AggregationMethod::Mean => Ok(accumulator.mean()),
764 AggregationMethod::WeightedMean => accumulator.weighted_mean(&sample_id.to_string()),
765 AggregationMethod::Median => Ok(accumulator.median()),
766 AggregationMethod::Vote => Ok(accumulator.vote()),
767 AggregationMethod::RobustMean => {
768 Ok(accumulator.robust_mean(DEFAULT_ROBUST_TRIM_FRACTION))
769 }
770 AggregationMethod::ExcludeOutliers => Err(DagMlError::OofValidation(
771 "exclude_outliers aggregation requires a custom aggregation controller"
772 .to_string(),
773 )),
774 AggregationMethod::None => {
775 if accumulator.count == 1 {
776 Ok(accumulator
777 .first_row
778 .clone()
779 .expect("single prediction accumulator stores first row"))
780 } else {
781 Err(DagMlError::OofValidation(format!(
782 "sample `{sample_id}` has {} observation predictions but aggregation method is none",
783 accumulator.count
784 )))
785 }
786 }
787 AggregationMethod::CustomController => Err(DagMlError::OofValidation(format!(
788 "aggregation method {:?} is delegated to an aggregation controller",
789 policy.method
790 ))),
791 }
792 })
793 .collect::<Result<Vec<Vec<f64>>>>()?;
794
795 Ok(PredictionBlock {
796 prediction_id: block
797 .prediction_id
798 .as_ref()
799 .map(|prediction_id| format!("{prediction_id}:sample_agg")),
800 producer_node: block.producer_node.clone(),
801 partition: block.partition.clone(),
802 fold_id: block.fold_id.clone(),
803 sample_ids: requested_sample_order.to_vec(),
804 values,
805 target_names: block.target_names.clone(),
806 })
807}
808
809pub fn aggregate_sample_predictions_by_unit(
810 block: &PredictionBlock,
811 relations: &SampleRelationSet,
812 policy: &AggregationPolicy,
813 requested_unit_order: &[PredictionUnitId],
814) -> Result<AggregatedPredictionBlock> {
815 let width = validate_sample_prediction_block(block)?;
816 relations.validate()?;
817 policy.validate()?;
818 if requested_unit_order.is_empty() {
819 return Err(DagMlError::OofValidation(
820 "aggregation requested_unit_order is empty".to_string(),
821 ));
822 }
823 let requested_level = policy.aggregation_level;
824 if requested_level == PredictionLevel::Observation {
825 return Err(DagMlError::OofValidation(
826 "sample prediction aggregation cannot output observation-level predictions".to_string(),
827 ));
828 }
829 if requested_unit_order
830 .iter()
831 .any(|unit_id| unit_id.level() != requested_level)
832 {
833 return Err(DagMlError::OofValidation(format!(
834 "aggregation requested units do not match level {:?}",
835 requested_level
836 )));
837 }
838 let requested = requested_unit_order.iter().collect::<BTreeSet<_>>();
839 if requested.len() != requested_unit_order.len() {
840 return Err(DagMlError::OofValidation(
841 "aggregation requested_unit_order contains duplicates".to_string(),
842 ));
843 }
844
845 let by_sample = block
846 .sample_ids
847 .iter()
848 .cloned()
849 .zip(block.values.iter().cloned())
850 .collect::<BTreeMap<_, _>>();
851 if requested_level == PredictionLevel::Sample {
852 let values = requested_unit_order
853 .iter()
854 .map(|unit_id| {
855 let PredictionUnitId::Sample(sample_id) = unit_id else {
856 unreachable!("requested unit level already validated");
857 };
858 by_sample.get(sample_id).cloned().ok_or_else(|| {
859 DagMlError::OofValidation(format!(
860 "sample prediction block for `{}` is missing requested sample `{sample_id}`",
861 block.producer_node
862 ))
863 })
864 })
865 .collect::<Result<Vec<_>>>()?;
866 if by_sample.len() != requested_unit_order.len() {
867 return Err(DagMlError::OofValidation(format!(
868 "sample prediction block for `{}` contains samples outside requested sample order",
869 block.producer_node
870 )));
871 }
872 let aggregated = AggregatedPredictionBlock {
873 prediction_id: block.prediction_id.clone(),
874 producer_node: block.producer_node.clone(),
875 partition: block.partition.clone(),
876 fold_id: block.fold_id.clone(),
877 level: PredictionLevel::Sample,
878 unit_ids: requested_unit_order.to_vec(),
879 values,
880 target_names: block.target_names.clone(),
881 };
882 aggregated.validate_shape()?;
883 return Ok(aggregated);
884 }
885
886 if policy.method == AggregationMethod::WeightedMean
887 && matches!(
888 policy.weights,
889 AggregationWeights::ControllerEmitted | AggregationWeights::Quality
890 )
891 {
892 return Err(DagMlError::OofValidation(format!(
893 "sample-to-{:?} weighted_mean cannot use {:?} weights without sample-level weights",
894 requested_level, policy.weights
895 )));
896 }
897
898 let store_rows = matches!(
899 policy.method,
900 AggregationMethod::Median | AggregationMethod::Vote | AggregationMethod::RobustMean
901 );
902 let mut accumulators = requested_unit_order
903 .iter()
904 .cloned()
905 .map(|unit_id| (unit_id, SampleAccumulator::new(width, store_rows)))
906 .collect::<BTreeMap<_, _>>();
907
908 for (sample_id, row) in block.sample_ids.iter().zip(block.values.iter()) {
909 let unit_id = unit_for_sample(relations, requested_level, sample_id)?;
910 if !requested.contains(&unit_id) {
911 return Err(DagMlError::OofValidation(format!(
912 "sample prediction `{sample_id}` maps to unexpected aggregation unit `{unit_id}`"
913 )));
914 }
915 let weight = sample_weight(relations, policy, sample_id)?;
916 accumulators
917 .get_mut(&unit_id)
918 .expect("requested aggregation unit accumulator exists")
919 .push(row, weight);
920 }
921
922 let values = requested_unit_order
923 .iter()
924 .map(|unit_id| {
925 let accumulator = accumulators
926 .get(unit_id)
927 .expect("requested aggregation unit accumulator exists");
928 if accumulator.count == 0 {
929 return Err(DagMlError::OofValidation(format!(
930 "aggregation unit `{unit_id}` has no sample predictions to aggregate"
931 )));
932 }
933 match policy.method {
934 AggregationMethod::Mean => Ok(accumulator.mean()),
935 AggregationMethod::WeightedMean => accumulator.weighted_mean(&unit_id.to_string()),
936 AggregationMethod::Median => Ok(accumulator.median()),
937 AggregationMethod::Vote => Ok(accumulator.vote()),
938 AggregationMethod::RobustMean => {
939 Ok(accumulator.robust_mean(DEFAULT_ROBUST_TRIM_FRACTION))
940 }
941 AggregationMethod::ExcludeOutliers => Err(DagMlError::OofValidation(
942 "exclude_outliers aggregation requires a custom aggregation controller"
943 .to_string(),
944 )),
945 AggregationMethod::None => {
946 if accumulator.count == 1 {
947 Ok(accumulator
948 .first_row
949 .clone()
950 .expect("single prediction accumulator stores first row"))
951 } else {
952 Err(DagMlError::OofValidation(format!(
953 "aggregation unit `{unit_id}` has {} sample predictions but aggregation method is none",
954 accumulator.count
955 )))
956 }
957 }
958 AggregationMethod::CustomController => Err(DagMlError::OofValidation(format!(
959 "aggregation method {:?} is delegated to an aggregation controller",
960 policy.method
961 ))),
962 }
963 })
964 .collect::<Result<Vec<_>>>()?;
965
966 let suffix = match requested_level {
967 PredictionLevel::Target => "target_agg",
968 PredictionLevel::Group => "group_agg",
969 PredictionLevel::Sample => "sample_agg",
970 PredictionLevel::Observation => unreachable!("observation output rejected above"),
971 };
972 let aggregated = AggregatedPredictionBlock {
973 prediction_id: block
974 .prediction_id
975 .as_ref()
976 .map(|prediction_id| format!("{prediction_id}:{suffix}")),
977 producer_node: block.producer_node.clone(),
978 partition: block.partition.clone(),
979 fold_id: block.fold_id.clone(),
980 level: requested_level,
981 unit_ids: requested_unit_order.to_vec(),
982 values,
983 target_names: block.target_names.clone(),
984 };
985 aggregated.validate_shape()?;
986 Ok(aggregated)
987}
988
989fn validate_sample_prediction_block(block: &PredictionBlock) -> Result<usize> {
990 let width = block.validate_shape()?;
991 if block
992 .values
993 .iter()
994 .flatten()
995 .any(|value| !value.is_finite())
996 {
997 return Err(DagMlError::OofValidation(format!(
998 "producer `{}` emitted non-finite sample prediction values",
999 block.producer_node
1000 )));
1001 }
1002 let unique = block.sample_ids.iter().collect::<BTreeSet<_>>();
1003 if unique.len() != block.sample_ids.len() {
1004 return Err(DagMlError::OofValidation(format!(
1005 "producer `{}` emitted duplicate sample predictions",
1006 block.producer_node
1007 )));
1008 }
1009 Ok(width)
1010}
1011
1012fn unit_for_sample(
1013 relations: &SampleRelationSet,
1014 level: PredictionLevel,
1015 sample_id: &SampleId,
1016) -> Result<PredictionUnitId> {
1017 match level {
1018 PredictionLevel::Sample => Ok(PredictionUnitId::Sample(sample_id.clone())),
1019 PredictionLevel::Target => relations
1020 .target_for_sample(sample_id)
1021 .cloned()
1022 .map(PredictionUnitId::Target)
1023 .ok_or_else(|| {
1024 DagMlError::OofValidation(format!(
1025 "sample `{sample_id}` is missing target id for target aggregation"
1026 ))
1027 }),
1028 PredictionLevel::Group => relations
1029 .group_for_sample(sample_id)
1030 .cloned()
1031 .map(PredictionUnitId::Group)
1032 .ok_or_else(|| {
1033 DagMlError::OofValidation(format!(
1034 "sample `{sample_id}` is missing group id for group aggregation"
1035 ))
1036 }),
1037 PredictionLevel::Observation => Err(DagMlError::OofValidation(
1038 "sample prediction aggregation cannot output observation-level predictions".to_string(),
1039 )),
1040 }
1041}
1042
1043fn sample_weight(
1044 relations: &SampleRelationSet,
1045 policy: &AggregationPolicy,
1046 sample_id: &SampleId,
1047) -> Result<f64> {
1048 if policy.method != AggregationMethod::WeightedMean {
1049 return Ok(1.0);
1050 }
1051 match policy.weights {
1052 AggregationWeights::RepetitionCount => {
1053 let count = relations.observation_count_for_sample(sample_id);
1054 if count == 0 {
1055 return Err(DagMlError::OofValidation(format!(
1056 "sample `{sample_id}` has no observation relations for repetition_count weights"
1057 )));
1058 }
1059 Ok(count as f64)
1060 }
1061 AggregationWeights::ControllerEmitted | AggregationWeights::Quality => {
1062 Err(DagMlError::OofValidation(format!(
1063 "sample-level {:?} weights are not present in PredictionBlock",
1064 policy.weights
1065 )))
1066 }
1067 AggregationWeights::None => Err(DagMlError::OofValidation(
1068 "weighted_mean aggregation requires an explicit weights policy".to_string(),
1069 )),
1070 }
1071}
1072
1073#[derive(Clone, Debug)]
1074struct SampleAccumulator {
1075 sum: Vec<f64>,
1076 weighted_sum: Vec<f64>,
1077 weight_sum: f64,
1078 rows: Vec<Vec<f64>>,
1079 first_row: Option<Vec<f64>>,
1080 store_rows: bool,
1081 count: usize,
1082}
1083
1084impl SampleAccumulator {
1085 fn new(width: usize, store_rows: bool) -> Self {
1086 Self {
1087 sum: vec![0.0; width],
1088 weighted_sum: vec![0.0; width],
1089 weight_sum: 0.0,
1090 rows: Vec::new(),
1091 first_row: None,
1092 store_rows,
1093 count: 0,
1094 }
1095 }
1096
1097 fn push(&mut self, row: &[f64], weight: f64) {
1098 for (idx, value) in row.iter().enumerate() {
1099 self.sum[idx] += *value;
1100 self.weighted_sum[idx] += *value * weight;
1101 }
1102 self.weight_sum += weight;
1103 if self.first_row.is_none() {
1104 self.first_row = Some(row.to_vec());
1105 }
1106 if self.store_rows {
1107 self.rows.push(row.to_vec());
1108 }
1109 self.count += 1;
1110 }
1111
1112 fn mean(&self) -> Vec<f64> {
1113 self.sum
1114 .iter()
1115 .map(|value| *value / self.count as f64)
1116 .collect()
1117 }
1118
1119 fn weighted_mean(&self, unit_label: &str) -> Result<Vec<f64>> {
1120 if self.weight_sum <= 0.0 {
1121 return Err(DagMlError::OofValidation(format!(
1122 "aggregation unit `{unit_label}` has zero total prediction weight"
1123 )));
1124 }
1125 Ok(self
1126 .weighted_sum
1127 .iter()
1128 .map(|value| *value / self.weight_sum)
1129 .collect())
1130 }
1131
1132 fn median(&self) -> Vec<f64> {
1133 let width = self.sum.len();
1134 (0..width)
1135 .map(|column_idx| {
1136 let mut column = self
1137 .rows
1138 .iter()
1139 .map(|row| row[column_idx])
1140 .collect::<Vec<_>>();
1141 column.sort_by(f64::total_cmp);
1142 let middle = column.len() / 2;
1143 if column.len() % 2 == 1 {
1144 column[middle]
1145 } else {
1146 (column[middle - 1] + column[middle]) / 2.0
1147 }
1148 })
1149 .collect()
1150 }
1151
1152 fn vote(&self) -> Vec<f64> {
1153 let width = self.sum.len();
1154 (0..width)
1155 .map(|column_idx| {
1156 let mut column = self
1157 .rows
1158 .iter()
1159 .map(|row| row[column_idx])
1160 .collect::<Vec<_>>();
1161 column.sort_by(f64::total_cmp);
1162 mode_sorted(&column)
1163 })
1164 .collect()
1165 }
1166
1167 fn robust_mean(&self, trim_fraction: f64) -> Vec<f64> {
1168 let width = self.sum.len();
1169 (0..width)
1170 .map(|column_idx| {
1171 let mut column = self
1172 .rows
1173 .iter()
1174 .map(|row| row[column_idx])
1175 .collect::<Vec<_>>();
1176 column.sort_by(f64::total_cmp);
1177 let trim_count = ((column.len() as f64) * trim_fraction).floor() as usize;
1178 let max_trim = column.len().saturating_sub(1) / 2;
1179 let trim_count = trim_count.min(max_trim);
1180 let kept = &column[trim_count..column.len() - trim_count];
1181 kept.iter().sum::<f64>() / kept.len() as f64
1182 })
1183 .collect()
1184 }
1185}
1186
1187fn observation_weight(
1188 block: &ObservationPredictionBlock,
1189 policy: &AggregationPolicy,
1190 row_idx: usize,
1191) -> Result<f64> {
1192 if policy.method != AggregationMethod::WeightedMean {
1193 return Ok(1.0);
1194 }
1195 match policy.weights {
1196 AggregationWeights::ControllerEmitted | AggregationWeights::Quality => block
1197 .weights
1198 .get(row_idx)
1199 .copied()
1200 .ok_or_else(|| {
1201 DagMlError::OofValidation(format!(
1202 "weighted_mean aggregation with {:?} weights requires one weight per observation",
1203 policy.weights
1204 ))
1205 }),
1206 AggregationWeights::RepetitionCount => Ok(1.0),
1207 AggregationWeights::None => Err(DagMlError::OofValidation(
1208 "weighted_mean aggregation requires an explicit weights policy".to_string(),
1209 )),
1210 }
1211}
1212
/// Returns the most frequent value of an ascending-sorted slice.
///
/// Equal values must be adjacent. When several values share the highest frequency, the one
/// whose run comes first in the slice wins.
///
/// # Panics
///
/// Panics when `values` is empty.
fn mode_sorted(values: &[f64]) -> f64 {
    // (value, run length) of the best run seen so far; a zero length lets the first run win.
    let mut winner = (values[0], 0usize);
    let mut run_start = 0usize;

    for idx in 1..=values.len() {
        let run_continues = idx < values.len() && values[idx] == values[run_start];
        if run_continues {
            continue;
        }
        let run_len = idx - run_start;
        // Strictly greater, so an earlier run keeps the win on ties.
        if run_len > winner.1 {
            winner = (values[run_start], run_len);
        }
        run_start = idx;
    }

    winner.0
}
1236
1237fn default_aggregation_controller_task_schema_version() -> u32 {
1238 AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION
1239}
1240
1241fn default_aggregation_controller_result_schema_version() -> u32 {
1242 AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION
1243}
1244
1245#[cfg(test)]
1246mod tests {
1247 use super::*;
1248 use crate::ids::{ControllerId, GroupId, TargetId};
1249 use crate::relation::SampleRelation;
1250
1251 fn sid(value: &str) -> SampleId {
1252 SampleId::new(value).unwrap()
1253 }
1254
1255 fn oid(value: &str) -> ObservationId {
1256 ObservationId::new(value).unwrap()
1257 }
1258
1259 fn relation(observation: &str, sample: &str) -> SampleRelation {
1260 let mut relation = SampleRelation::new(oid(observation), sid(sample));
1261 relation.target_id = Some(TargetId::new(format!("target:{sample}")).unwrap());
1262 relation
1263 }
1264
1265 fn relation_with_units(
1266 observation: &str,
1267 sample: &str,
1268 target: &str,
1269 group: &str,
1270 ) -> SampleRelation {
1271 let mut relation = SampleRelation::new(oid(observation), sid(sample));
1272 relation.target_id = Some(TargetId::new(target).unwrap());
1273 relation.group_id = Some(GroupId::new(group).unwrap());
1274 relation
1275 }
1276
1277 fn combo_relation(observation: &str, sample: &str, components: &[&str]) -> SampleRelation {
1278 let mut relation = SampleRelation::new(oid(observation), sid(sample));
1279 relation.unit_level = EntityUnitLevel::Combo;
1280 relation.derived_unit_id = Some(format!("combo:{observation}"));
1281 relation.component_observation_ids =
1282 components.iter().map(|component| oid(component)).collect();
1283 relation
1284 }
1285
1286 fn custom_policy(level: PredictionLevel) -> AggregationPolicy {
1287 AggregationPolicy {
1288 aggregation_level: level,
1289 method: AggregationMethod::CustomController,
1290 custom_controller: Some(crate::policy::AggregationControllerSpec {
1291 controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1292 params: serde_json::json!({ "trim_fraction": 0.1 }),
1293 }),
1294 ..AggregationPolicy::default()
1295 }
1296 }
1297
1298 #[test]
1299 fn validates_custom_observation_aggregation_controller_result() {
1300 let reduction_plan = ReductionPlan {
1301 role: crate::policy::ReductionRole::FinalOutput,
1302 axis: ReductionAxis::Unit,
1303 input_unit_level: EntityUnitLevel::Observation,
1304 output_unit_level: EntityUnitLevel::PhysicalSample,
1305 method: ReductionMethod::Custom,
1306 custom_controller: Some(crate::policy::AggregationControllerSpec {
1307 controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1308 params: serde_json::json!({ "trim_fraction": 0.1 }),
1309 }),
1310 ..ReductionPlan::default()
1311 };
1312 let task = AggregationControllerTask {
1313 schema_version: AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
1314 task_id: "agg-task:obs.sample.fold0".to_string(),
1315 controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1316 policy: custom_policy(PredictionLevel::Sample),
1317 reduction_plan: Some(reduction_plan.clone()),
1318 input: AggregationControllerInput::ObservationToSample {
1319 block: ObservationPredictionBlock {
1320 prediction_id: Some("prediction:model.fold0".to_string()),
1321 producer_node: NodeId::new("model:pls").unwrap(),
1322 partition: PredictionPartition::Validation,
1323 fold_id: Some(FoldId::new("fold:0").unwrap()),
1324 observation_ids: vec![oid("obs:1"), oid("obs:2"), oid("obs:3")],
1325 values: vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![9.0, 10.0]],
1326 weights: Vec::new(),
1327 target_names: vec!["moisture".to_string(), "protein".to_string()],
1328 },
1329 relations: SampleRelationSet {
1330 records: vec![
1331 relation("obs:1", "sample:1"),
1332 relation("obs:2", "sample:1"),
1333 relation("obs:3", "sample:2"),
1334 ],
1335 },
1336 requested_sample_order: vec![sid("sample:1"), sid("sample:2")],
1337 },
1338 };
1339 task.validate().unwrap();
1340
1341 let result = AggregationControllerResult {
1342 schema_version: AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION,
1343 task_id: task.task_id.clone(),
1344 reduction_plan: Some(reduction_plan),
1345 output: AggregationControllerOutput::Sample {
1346 block: PredictionBlock {
1347 prediction_id: Some("prediction:model.fold0:custom_sample_agg".to_string()),
1348 producer_node: NodeId::new("model:pls").unwrap(),
1349 partition: PredictionPartition::Validation,
1350 fold_id: Some(FoldId::new("fold:0").unwrap()),
1351 sample_ids: vec![sid("sample:1"), sid("sample:2")],
1352 values: vec![vec![2.0, 3.0], vec![9.0, 10.0]],
1353 target_names: vec!["moisture".to_string(), "protein".to_string()],
1354 },
1355 },
1356 };
1357
1358 result.validate_for_task(&task).unwrap();
1359 }
1360
1361 #[test]
1362 fn custom_aggregation_controller_result_must_echo_reduction_plan() {
1363 let reduction_plan = ReductionPlan {
1364 method: ReductionMethod::Custom,
1365 custom_controller: Some(crate::policy::AggregationControllerSpec {
1366 controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1367 params: serde_json::json!({}),
1368 }),
1369 ..ReductionPlan::default()
1370 };
1371 let task = AggregationControllerTask {
1372 schema_version: AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
1373 task_id: "agg-task:obs.sample.fold0".to_string(),
1374 controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1375 policy: custom_policy(PredictionLevel::Sample),
1376 reduction_plan: Some(reduction_plan),
1377 input: AggregationControllerInput::ObservationToSample {
1378 block: ObservationPredictionBlock {
1379 prediction_id: None,
1380 producer_node: NodeId::new("model:pls").unwrap(),
1381 partition: PredictionPartition::Validation,
1382 fold_id: None,
1383 observation_ids: vec![oid("obs:1")],
1384 values: vec![vec![1.0]],
1385 weights: Vec::new(),
1386 target_names: vec!["y".to_string()],
1387 },
1388 relations: SampleRelationSet {
1389 records: vec![relation("obs:1", "sample:1")],
1390 },
1391 requested_sample_order: vec![sid("sample:1")],
1392 },
1393 };
1394 let result = AggregationControllerResult {
1395 schema_version: AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION,
1396 task_id: task.task_id.clone(),
1397 reduction_plan: None,
1398 output: AggregationControllerOutput::Sample {
1399 block: PredictionBlock {
1400 prediction_id: None,
1401 producer_node: NodeId::new("model:pls").unwrap(),
1402 partition: PredictionPartition::Validation,
1403 fold_id: None,
1404 sample_ids: vec![sid("sample:1")],
1405 values: vec![vec![1.0]],
1406 target_names: vec!["y".to_string()],
1407 },
1408 },
1409 };
1410
1411 let error = result.validate_for_task(&task).unwrap_err().to_string();
1412
1413 assert!(error.contains("echo task reduction_plan"));
1414 }
1415
1416 #[test]
1417 fn custom_aggregation_controller_result_refuses_order_mismatch() {
1418 let task = AggregationControllerTask {
1419 schema_version: AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
1420 task_id: "agg-task:obs.sample.fold0".to_string(),
1421 controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1422 policy: custom_policy(PredictionLevel::Sample),
1423 reduction_plan: None,
1424 input: AggregationControllerInput::ObservationToSample {
1425 block: ObservationPredictionBlock {
1426 prediction_id: None,
1427 producer_node: NodeId::new("model:pls").unwrap(),
1428 partition: PredictionPartition::Validation,
1429 fold_id: None,
1430 observation_ids: vec![oid("obs:1"), oid("obs:2")],
1431 values: vec![vec![1.0], vec![2.0]],
1432 weights: Vec::new(),
1433 target_names: vec!["y".to_string()],
1434 },
1435 relations: SampleRelationSet {
1436 records: vec![relation("obs:1", "sample:1"), relation("obs:2", "sample:2")],
1437 },
1438 requested_sample_order: vec![sid("sample:1"), sid("sample:2")],
1439 },
1440 };
1441 let result = AggregationControllerResult {
1442 schema_version: AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION,
1443 task_id: task.task_id.clone(),
1444 reduction_plan: None,
1445 output: AggregationControllerOutput::Sample {
1446 block: PredictionBlock {
1447 prediction_id: None,
1448 producer_node: NodeId::new("model:pls").unwrap(),
1449 partition: PredictionPartition::Validation,
1450 fold_id: None,
1451 sample_ids: vec![sid("sample:2"), sid("sample:1")],
1452 values: vec![vec![2.0], vec![1.0]],
1453 target_names: vec!["y".to_string()],
1454 },
1455 },
1456 };
1457
1458 let error = result.validate_for_task(&task).unwrap_err().to_string();
1459 assert!(error.contains("requested sample order"));
1460 }
1461
1462 #[test]
1463 fn validates_custom_sample_to_group_aggregation_controller_result() {
1464 let task = AggregationControllerTask {
1465 schema_version: AGGREGATION_CONTROLLER_TASK_SCHEMA_VERSION,
1466 task_id: "agg-task:sample.group.fold0".to_string(),
1467 controller_id: ControllerId::new("controller:agg.trimmed").unwrap(),
1468 policy: custom_policy(PredictionLevel::Group),
1469 reduction_plan: None,
1470 input: AggregationControllerInput::SampleToUnit {
1471 block: PredictionBlock {
1472 prediction_id: Some("prediction:model.fold0".to_string()),
1473 producer_node: NodeId::new("model:pls").unwrap(),
1474 partition: PredictionPartition::Validation,
1475 fold_id: Some(FoldId::new("fold:0").unwrap()),
1476 sample_ids: vec![sid("sample:1"), sid("sample:2"), sid("sample:3")],
1477 values: vec![vec![1.0], vec![3.0], vec![10.0]],
1478 target_names: vec!["y".to_string()],
1479 },
1480 relations: SampleRelationSet {
1481 records: vec![
1482 relation_with_units("obs:1", "sample:1", "target:1", "group:left"),
1483 relation_with_units("obs:2", "sample:2", "target:2", "group:left"),
1484 relation_with_units("obs:3", "sample:3", "target:3", "group:right"),
1485 ],
1486 },
1487 requested_unit_order: vec![
1488 PredictionUnitId::Group(GroupId::new("group:left").unwrap()),
1489 PredictionUnitId::Group(GroupId::new("group:right").unwrap()),
1490 ],
1491 },
1492 };
1493 task.validate().unwrap();
1494
1495 let result = AggregationControllerResult {
1496 schema_version: AGGREGATION_CONTROLLER_RESULT_SCHEMA_VERSION,
1497 task_id: task.task_id.clone(),
1498 reduction_plan: None,
1499 output: AggregationControllerOutput::Unit {
1500 block: AggregatedPredictionBlock {
1501 prediction_id: Some("prediction:model.fold0:custom_group_agg".to_string()),
1502 producer_node: NodeId::new("model:pls").unwrap(),
1503 partition: PredictionPartition::Validation,
1504 fold_id: Some(FoldId::new("fold:0").unwrap()),
1505 level: PredictionLevel::Group,
1506 unit_ids: vec![
1507 PredictionUnitId::Group(GroupId::new("group:left").unwrap()),
1508 PredictionUnitId::Group(GroupId::new("group:right").unwrap()),
1509 ],
1510 values: vec![vec![2.0], vec![10.0]],
1511 target_names: vec!["y".to_string()],
1512 },
1513 },
1514 };
1515
1516 result.validate_for_task(&task).unwrap();
1517 }
1518
1519 #[test]
1520 fn averages_repeated_observation_predictions_by_sample() {
1521 let block = ObservationPredictionBlock {
1522 prediction_id: Some("pred:oof".to_string()),
1523 producer_node: NodeId::new("model:pls").unwrap(),
1524 partition: PredictionPartition::Validation,
1525 fold_id: Some(FoldId::new("fold:0").unwrap()),
1526 observation_ids: vec![oid("obs:1a"), oid("obs:1b"), oid("obs:2a")],
1527 values: vec![vec![1.0], vec![3.0], vec![10.0]],
1528 weights: Vec::new(),
1529 target_names: vec!["y".to_string()],
1530 };
1531 let relations = SampleRelationSet {
1532 records: vec![
1533 relation("obs:1a", "sample:1"),
1534 relation("obs:1b", "sample:1"),
1535 relation("obs:2a", "sample:2"),
1536 ],
1537 };
1538
1539 let aggregated = aggregate_observation_predictions(
1540 &block,
1541 &relations,
1542 &AggregationPolicy::default(),
1543 &[sid("sample:1"), sid("sample:2")],
1544 )
1545 .unwrap();
1546
1547 assert_eq!(
1548 aggregated.sample_ids,
1549 vec![sid("sample:1"), sid("sample:2")]
1550 );
1551 assert_eq!(aggregated.values, vec![vec![2.0], vec![10.0]]);
1552 }
1553
1554 #[test]
1555 fn aggregates_relation_backed_combo_predictions_by_sample() {
1556 let relations = SampleRelationSet {
1557 records: vec![
1558 relation("obs:s1.a", "sample:1"),
1559 relation("obs:s1.b", "sample:1"),
1560 relation("obs:s2.a", "sample:2"),
1561 relation("obs:s2.b", "sample:2"),
1562 combo_relation("obs:s1.combo", "sample:1", &["obs:s1.a", "obs:s1.b"]),
1563 combo_relation("obs:s2.combo", "sample:2", &["obs:s2.a", "obs:s2.b"]),
1564 ],
1565 };
1566 let block = ObservationPredictionBlock {
1567 prediction_id: Some("pred:combo".to_string()),
1568 producer_node: NodeId::new("model:combo").unwrap(),
1569 partition: PredictionPartition::Validation,
1570 fold_id: Some(FoldId::new("fold:0").unwrap()),
1571 observation_ids: vec![oid("obs:s1.combo"), oid("obs:s2.combo")],
1572 values: vec![vec![5.0], vec![9.0]],
1573 weights: Vec::new(),
1574 target_names: vec!["y".to_string()],
1575 };
1576
1577 let aggregated = aggregate_observation_predictions(
1578 &block,
1579 &relations,
1580 &AggregationPolicy::default(),
1581 &[sid("sample:1"), sid("sample:2")],
1582 )
1583 .unwrap();
1584
1585 assert_eq!(aggregated.values, vec![vec![5.0], vec![9.0]]);
1586 }
1587
1588 #[test]
1589 fn robust_mean_trims_extreme_repeated_predictions() {
1590 let observations = (0..10)
1591 .map(|idx| format!("obs:s1.{idx}"))
1592 .collect::<Vec<_>>();
1593 let relations = SampleRelationSet {
1594 records: observations
1595 .iter()
1596 .map(|observation| relation(observation, "sample:1"))
1597 .collect(),
1598 };
1599 let block = ObservationPredictionBlock {
1600 prediction_id: Some("pred:robust".to_string()),
1601 producer_node: NodeId::new("model:pls").unwrap(),
1602 partition: PredictionPartition::Validation,
1603 fold_id: Some(FoldId::new("fold:0").unwrap()),
1604 observation_ids: observations
1605 .iter()
1606 .map(|observation| oid(observation))
1607 .collect(),
1608 values: vec![
1609 vec![0.0],
1610 vec![1.0],
1611 vec![2.0],
1612 vec![3.0],
1613 vec![4.0],
1614 vec![5.0],
1615 vec![6.0],
1616 vec![7.0],
1617 vec![8.0],
1618 vec![100.0],
1619 ],
1620 weights: Vec::new(),
1621 target_names: vec!["y".to_string()],
1622 };
1623
1624 let aggregated = aggregate_observation_predictions(
1625 &block,
1626 &relations,
1627 &AggregationPolicy {
1628 method: AggregationMethod::RobustMean,
1629 ..AggregationPolicy::default()
1630 },
1631 &[sid("sample:1")],
1632 )
1633 .unwrap();
1634
1635 assert_eq!(aggregated.values, vec![vec![4.5]]);
1636 }
1637
1638 #[test]
1639 fn exclude_outliers_requires_custom_controller_path() {
1640 let relations = SampleRelationSet {
1641 records: vec![relation("obs:1", "sample:1")],
1642 };
1643 let block = ObservationPredictionBlock {
1644 prediction_id: None,
1645 producer_node: NodeId::new("model:pls").unwrap(),
1646 partition: PredictionPartition::Validation,
1647 fold_id: None,
1648 observation_ids: vec![oid("obs:1")],
1649 values: vec![vec![1.0]],
1650 weights: Vec::new(),
1651 target_names: vec!["y".to_string()],
1652 };
1653
1654 let error = aggregate_observation_predictions(
1655 &block,
1656 &relations,
1657 &AggregationPolicy {
1658 method: AggregationMethod::ExcludeOutliers,
1659 ..AggregationPolicy::default()
1660 },
1661 &[sid("sample:1")],
1662 )
1663 .unwrap_err()
1664 .to_string();
1665
1666 assert!(error.contains("custom aggregation controller"));
1667 }
1668
1669 #[test]
1670 fn aggregates_repeated_predictions_with_median_vote_and_weights() {
1671 let relations = SampleRelationSet {
1672 records: vec![
1673 relation("obs:1a", "sample:1"),
1674 relation("obs:1b", "sample:1"),
1675 relation("obs:1c", "sample:1"),
1676 relation("obs:2a", "sample:2"),
1677 relation("obs:2b", "sample:2"),
1678 ],
1679 };
1680 let base_block = ObservationPredictionBlock {
1681 prediction_id: Some("pred:oof".to_string()),
1682 producer_node: NodeId::new("model:pls").unwrap(),
1683 partition: PredictionPartition::Validation,
1684 fold_id: Some(FoldId::new("fold:0").unwrap()),
1685 observation_ids: vec![
1686 oid("obs:1a"),
1687 oid("obs:1b"),
1688 oid("obs:1c"),
1689 oid("obs:2a"),
1690 oid("obs:2b"),
1691 ],
1692 values: vec![
1693 vec![1.0, 0.0],
1694 vec![5.0, 1.0],
1695 vec![9.0, 1.0],
1696 vec![10.0, 2.0],
1697 vec![30.0, 3.0],
1698 ],
1699 weights: Vec::new(),
1700 target_names: vec!["regression".to_string(), "class".to_string()],
1701 };
1702 let sample_order = [sid("sample:1"), sid("sample:2")];
1703
1704 let median_policy = AggregationPolicy {
1705 method: AggregationMethod::Median,
1706 ..AggregationPolicy::default()
1707 };
1708 let median = aggregate_observation_predictions(
1709 &base_block,
1710 &relations,
1711 &median_policy,
1712 &sample_order,
1713 )
1714 .unwrap();
1715 assert_eq!(median.values, vec![vec![5.0, 1.0], vec![20.0, 2.5]]);
1716
1717 let vote_policy = AggregationPolicy {
1718 method: AggregationMethod::Vote,
1719 ..AggregationPolicy::default()
1720 };
1721 let vote =
1722 aggregate_observation_predictions(&base_block, &relations, &vote_policy, &sample_order)
1723 .unwrap();
1724 assert_eq!(vote.values, vec![vec![1.0, 1.0], vec![10.0, 2.0]]);
1725
1726 let mut weighted_block = base_block;
1727 weighted_block.weights = vec![1.0, 1.0, 2.0, 1.0, 3.0];
1728 let weighted_policy = AggregationPolicy {
1729 method: AggregationMethod::WeightedMean,
1730 weights: AggregationWeights::ControllerEmitted,
1731 ..AggregationPolicy::default()
1732 };
1733 let weighted = aggregate_observation_predictions(
1734 &weighted_block,
1735 &relations,
1736 &weighted_policy,
1737 &sample_order,
1738 )
1739 .unwrap();
1740 assert_eq!(weighted.values, vec![vec![6.0, 0.75], vec![25.0, 2.75]]);
1741 }
1742
1743 #[test]
1744 fn refuses_incompatible_observation_weight_contracts() {
1745 let relations = SampleRelationSet {
1746 records: vec![
1747 relation("obs:1a", "sample:1"),
1748 relation("obs:1b", "sample:1"),
1749 ],
1750 };
1751 let block = ObservationPredictionBlock {
1752 prediction_id: None,
1753 producer_node: NodeId::new("model:pls").unwrap(),
1754 partition: PredictionPartition::Validation,
1755 fold_id: None,
1756 observation_ids: vec![oid("obs:1a"), oid("obs:1b")],
1757 values: vec![vec![1.0], vec![2.0]],
1758 weights: vec![1.0, 2.0],
1759 target_names: vec!["y".to_string()],
1760 };
1761
1762 let mean_error = aggregate_observation_predictions(
1763 &block,
1764 &relations,
1765 &AggregationPolicy::default(),
1766 &[sid("sample:1")],
1767 )
1768 .unwrap_err()
1769 .to_string();
1770 assert!(
1771 mean_error.contains("non-weighted aggregation"),
1772 "unexpected mean error: {mean_error}"
1773 );
1774
1775 let mut missing_weights_block = block;
1776 missing_weights_block.weights.clear();
1777 let weighted_error = aggregate_observation_predictions(
1778 &missing_weights_block,
1779 &relations,
1780 &AggregationPolicy {
1781 method: AggregationMethod::WeightedMean,
1782 weights: AggregationWeights::ControllerEmitted,
1783 ..AggregationPolicy::default()
1784 },
1785 &[sid("sample:1")],
1786 )
1787 .unwrap_err()
1788 .to_string();
1789 assert!(
1790 weighted_error.contains("requires one weight per observation"),
1791 "unexpected weighted error: {weighted_error}"
1792 );
1793 }
1794
1795 #[test]
1796 fn aggregates_sample_predictions_to_target_and_group_units() {
1797 let relations = SampleRelationSet {
1798 records: vec![
1799 relation_with_units("obs:s1:a", "sample:1", "target:a", "group:left"),
1800 relation_with_units("obs:s1:b", "sample:1", "target:a", "group:left"),
1801 relation_with_units("obs:s2:a", "sample:2", "target:a", "group:left"),
1802 relation_with_units("obs:s3:a", "sample:3", "target:b", "group:right"),
1803 ],
1804 };
1805 let block = PredictionBlock {
1806 prediction_id: Some("pred:sample".to_string()),
1807 producer_node: NodeId::new("model:pls").unwrap(),
1808 partition: PredictionPartition::Validation,
1809 fold_id: Some(FoldId::new("fold:0").unwrap()),
1810 sample_ids: vec![sid("sample:1"), sid("sample:2"), sid("sample:3")],
1811 values: vec![vec![10.0], vec![4.0], vec![30.0]],
1812 target_names: vec!["y".to_string()],
1813 };
1814
1815 let target_policy = AggregationPolicy {
1816 aggregation_level: PredictionLevel::Target,
1817 method: AggregationMethod::Mean,
1818 ..AggregationPolicy::default()
1819 };
1820 let by_target = aggregate_sample_predictions_by_unit(
1821 &block,
1822 &relations,
1823 &target_policy,
1824 &[
1825 PredictionUnitId::Target(TargetId::new("target:a").unwrap()),
1826 PredictionUnitId::Target(TargetId::new("target:b").unwrap()),
1827 ],
1828 )
1829 .unwrap();
1830 assert_eq!(by_target.level, PredictionLevel::Target);
1831 assert_eq!(by_target.values, vec![vec![7.0], vec![30.0]]);
1832
1833 let group_policy = AggregationPolicy {
1834 aggregation_level: PredictionLevel::Group,
1835 method: AggregationMethod::WeightedMean,
1836 weights: AggregationWeights::RepetitionCount,
1837 ..AggregationPolicy::default()
1838 };
1839 let by_group = aggregate_sample_predictions_by_unit(
1840 &block,
1841 &relations,
1842 &group_policy,
1843 &[
1844 PredictionUnitId::Group(GroupId::new("group:left").unwrap()),
1845 PredictionUnitId::Group(GroupId::new("group:right").unwrap()),
1846 ],
1847 )
1848 .unwrap();
1849 assert_eq!(by_group.level, PredictionLevel::Group);
1850 assert_eq!(by_group.values, vec![vec![8.0], vec![30.0]]);
1851 }
1852
1853 #[test]
1854 fn refuses_target_group_aggregation_without_relation_units() {
1855 let relations = SampleRelationSet {
1856 records: vec![SampleRelation::new(oid("obs:1"), sid("sample:1"))],
1857 };
1858 let block = PredictionBlock {
1859 prediction_id: None,
1860 producer_node: NodeId::new("model:pls").unwrap(),
1861 partition: PredictionPartition::Validation,
1862 fold_id: None,
1863 sample_ids: vec![sid("sample:1")],
1864 values: vec![vec![1.0]],
1865 target_names: vec!["y".to_string()],
1866 };
1867
1868 let error = aggregate_sample_predictions_by_unit(
1869 &block,
1870 &relations,
1871 &AggregationPolicy {
1872 aggregation_level: PredictionLevel::Target,
1873 method: AggregationMethod::Mean,
1874 ..AggregationPolicy::default()
1875 },
1876 &[PredictionUnitId::Target(
1877 TargetId::new("target:missing").unwrap(),
1878 )],
1879 )
1880 .unwrap_err()
1881 .to_string();
1882 assert!(
1883 error.contains("missing target id"),
1884 "unexpected target aggregation error: {error}"
1885 );
1886 }
1887
1888 #[test]
1889 fn refuses_missing_observation_relation() {
1890 let block = ObservationPredictionBlock {
1891 prediction_id: None,
1892 producer_node: NodeId::new("model:pls").unwrap(),
1893 partition: PredictionPartition::Validation,
1894 fold_id: None,
1895 observation_ids: vec![oid("obs:missing")],
1896 values: vec![vec![1.0]],
1897 weights: Vec::new(),
1898 target_names: vec!["y".to_string()],
1899 };
1900
1901 assert!(aggregate_observation_predictions(
1902 &block,
1903 &SampleRelationSet::default(),
1904 &AggregationPolicy::default(),
1905 &[sid("sample:1")]
1906 )
1907 .is_err());
1908 }
1909}