Skip to main content

dag_ml_core/
relation.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::error::{DagMlError, Result};
7use crate::fold::FoldSet;
8use crate::ids::{GroupId, ObservationId, SampleId, TargetId};
9use crate::policy::{LeakageUnitPolicy, SplitUnit};
10
/// Side of a fold that a sample is assigned to.
///
/// Serialized in `snake_case` (`"train"`, `"validation"`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FoldPartition {
    /// The sample is listed among the fold's training samples.
    Train,
    /// The sample is listed among the fold's validation samples.
    Validation,
}
17
/// Granularity at which a relation's effective unit id is derived when no
/// explicit `unit_id` is given (see `SampleRelation::effective_unit_id`).
///
/// Serialized in `snake_case`; `Observation` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EntityUnitLevel {
    /// The unit is the sample itself: the effective unit id is the `sample_id`.
    PhysicalSample,
    /// The unit is a (sample, source) pair: the effective unit id is
    /// `"{sample_id}::{source_id}"`, so `source_id` is required.
    SourceSample,
    /// The unit is the single observation: the effective unit id is the `observation_id`.
    #[default]
    Observation,
    /// The unit is a derived combination: the effective unit id is the
    /// `derived_unit_id`, which is therefore required.
    Combo,
}
27
/// Links one observation to its sample plus optional unit, provenance and
/// grouping metadata.
///
/// Only `observation_id` and `sample_id` are required when deserializing; every
/// other field falls back to its `Default` value when absent.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SampleRelation {
    /// Granularity used to derive the effective unit id (defaults to `Observation`).
    #[serde(default)]
    pub unit_level: EntityUnitLevel,
    /// Explicit effective unit id; when set it takes precedence over `unit_level`.
    /// Must not be blank.
    #[serde(default)]
    pub unit_id: Option<String>,
    /// Identifier of this observation; must be unique within a `SampleRelationSet`.
    pub observation_id: ObservationId,
    /// Sample this observation belongs to.
    pub sample_id: SampleId,
    /// Source of the observation; must not be blank, and is required when
    /// `unit_level` is `SourceSample` and no `unit_id` is given.
    #[serde(default)]
    pub source_id: Option<String>,
    /// Repetition identifier; when present it must be at most 128 bytes of ASCII
    /// alphanumerics or `_`, `-`, `.`, `:`.
    #[serde(default)]
    pub rep_id: Option<String>,
    /// Target associated with the sample; all records of one sample that set
    /// a target must agree on it.
    #[serde(default)]
    pub target_id: Option<TargetId>,
    /// Group associated with the sample; all records of one sample that set
    /// a group must agree on it.
    #[serde(default)]
    pub group_id: Option<GroupId>,
    /// Sample this record originates from. For combo records it must equal
    /// `sample_id`; fold validation can require it to share a partition with
    /// `sample_id`.
    #[serde(default)]
    pub origin_sample_id: Option<SampleId>,
    /// Identifier of the derived unit; must not be blank and is required for
    /// combo records.
    #[serde(default)]
    pub derived_unit_id: Option<String>,
    /// Observations a combo record is built from; must be empty for non-combo
    /// records and non-empty for combo records.
    #[serde(default)]
    pub component_observation_ids: Vec<ObservationId>,
    /// Optional weight; when present it must be finite and strictly positive.
    #[serde(default)]
    pub sample_influence_weight: Option<f64>,
    /// Optional free-form quality marker; must not be blank when present.
    #[serde(default)]
    pub quality_flag: Option<String>,
    /// Augmentation marker. It is not checked by validation in this file but is
    /// part of the fingerprint payload.
    #[serde(default)]
    pub is_augmented: bool,
}
57
58impl SampleRelation {
59    pub fn new(observation_id: ObservationId, sample_id: SampleId) -> Self {
60        Self {
61            unit_level: EntityUnitLevel::Observation,
62            unit_id: None,
63            observation_id,
64            sample_id,
65            source_id: None,
66            rep_id: None,
67            target_id: None,
68            group_id: None,
69            origin_sample_id: None,
70            derived_unit_id: None,
71            component_observation_ids: Vec::new(),
72            sample_influence_weight: None,
73            quality_flag: None,
74            is_augmented: false,
75        }
76    }
77
78    pub fn effective_unit_id(&self) -> Result<String> {
79        if let Some(unit_id) = non_empty_optional("unit_id", &self.observation_id, &self.unit_id)? {
80            return Ok(unit_id.to_string());
81        }
82
83        match self.unit_level {
84            EntityUnitLevel::PhysicalSample => Ok(self.sample_id.to_string()),
85            EntityUnitLevel::SourceSample => {
86                let source_id =
87                    non_empty_optional("source_id", &self.observation_id, &self.source_id)?
88                        .ok_or_else(|| {
89                            DagMlError::CampaignValidation(format!(
90                                "source-sample relation `{}` requires source_id",
91                                self.observation_id
92                            ))
93                        })?;
94                Ok(format!("{}::{source_id}", self.sample_id))
95            }
96            EntityUnitLevel::Observation => Ok(self.observation_id.to_string()),
97            EntityUnitLevel::Combo => {
98                let derived_unit_id = non_empty_optional(
99                    "derived_unit_id",
100                    &self.observation_id,
101                    &self.derived_unit_id,
102                )?
103                .ok_or_else(|| {
104                    DagMlError::CampaignValidation(format!(
105                        "combo relation `{}` requires derived_unit_id",
106                        self.observation_id
107                    ))
108                })?;
109                Ok(derived_unit_id.to_string())
110            }
111        }
112    }
113
114    fn validate(&self) -> Result<()> {
115        non_empty_optional("unit_id", &self.observation_id, &self.unit_id)?;
116        non_empty_optional("source_id", &self.observation_id, &self.source_id)?;
117        non_empty_optional(
118            "derived_unit_id",
119            &self.observation_id,
120            &self.derived_unit_id,
121        )?;
122        non_empty_optional("quality_flag", &self.observation_id, &self.quality_flag)?;
123        validate_optional_identifier("rep_id", &self.observation_id, &self.rep_id)?;
124
125        if let Some(weight) = self.sample_influence_weight {
126            if !weight.is_finite() || weight <= 0.0 {
127                return Err(DagMlError::CampaignValidation(format!(
128                    "relation `{}` has invalid sample_influence_weight",
129                    self.observation_id
130                )));
131            }
132        }
133
134        if self.unit_level != EntityUnitLevel::Combo && !self.component_observation_ids.is_empty() {
135            return Err(DagMlError::CampaignValidation(format!(
136                "relation `{}` has component_observation_ids but is not a combo",
137                self.observation_id
138            )));
139        }
140
141        self.effective_unit_id()?;
142        Ok(())
143    }
144}
145
/// A collection of `SampleRelation` records that is validated and
/// fingerprinted as a whole.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct SampleRelationSet {
    /// The relation records; defaults to an empty list when absent from
    /// serialized input.
    #[serde(default)]
    pub records: Vec<SampleRelation>,
}
151
152pub fn relation_set_fingerprint(relations: &SampleRelationSet) -> Result<String> {
153    relations.fingerprint()
154}
155
/// Serializable snapshot of one `SampleRelation` together with its resolved
/// effective unit id; a sorted list of these is the fingerprint input.
///
/// Every field other than `effective_unit_id` is copied verbatim from the
/// source record.
#[derive(Clone, Debug, Serialize)]
struct CanonicalRelationRecord {
    // Resolved via `SampleRelation::effective_unit_id`; primary sort key.
    effective_unit_id: String,
    unit_level: EntityUnitLevel,
    unit_id: Option<String>,
    observation_id: ObservationId,
    sample_id: SampleId,
    source_id: Option<String>,
    rep_id: Option<String>,
    target_id: Option<TargetId>,
    group_id: Option<GroupId>,
    origin_sample_id: Option<SampleId>,
    derived_unit_id: Option<String>,
    component_observation_ids: Vec<ObservationId>,
    sample_influence_weight: Option<f64>,
    quality_flag: Option<String>,
    is_augmented: bool,
}
174
175impl SampleRelationSet {
176    pub fn validate(&self) -> Result<()> {
177        let mut observations = BTreeSet::new();
178        let mut observation_samples = BTreeMap::<ObservationId, SampleId>::new();
179        let mut unit_ids = BTreeMap::<String, ObservationId>::new();
180        let mut sample_targets = BTreeMap::<SampleId, TargetId>::new();
181        let mut sample_groups = BTreeMap::<SampleId, GroupId>::new();
182        for record in &self.records {
183            record.validate()?;
184            if !observations.insert(&record.observation_id) {
185                return Err(DagMlError::CampaignValidation(format!(
186                    "duplicate observation relation `{}`",
187                    record.observation_id
188                )));
189            }
190            observation_samples.insert(record.observation_id.clone(), record.sample_id.clone());
191            let effective_unit_id = record.effective_unit_id()?;
192            if let Some(previous) =
193                unit_ids.insert(effective_unit_id.clone(), record.observation_id.clone())
194            {
195                return Err(DagMlError::CampaignValidation(format!(
196                    "relations `{previous}` and `{}` share effective unit id `{effective_unit_id}`",
197                    record.observation_id
198                )));
199            }
200            if let Some(target_id) = &record.target_id {
201                if let Some(previous) = sample_targets.get(&record.sample_id) {
202                    if previous != target_id {
203                        return Err(DagMlError::CampaignValidation(format!(
204                            "sample `{}` maps to multiple targets",
205                            record.sample_id
206                        )));
207                    }
208                } else {
209                    sample_targets.insert(record.sample_id.clone(), target_id.clone());
210                }
211            }
212            if let Some(group_id) = &record.group_id {
213                if let Some(previous) = sample_groups.get(&record.sample_id) {
214                    if previous != group_id {
215                        return Err(DagMlError::CampaignValidation(format!(
216                            "sample `{}` maps to multiple groups",
217                            record.sample_id
218                        )));
219                    }
220                } else {
221                    sample_groups.insert(record.sample_id.clone(), group_id.clone());
222                }
223            }
224        }
225        for record in &self.records {
226            validate_combo_record(record, &observation_samples)?;
227        }
228        Ok(())
229    }
230
231    pub fn fingerprint(&self) -> Result<String> {
232        self.validate()?;
233        let mut canonical = self
234            .records
235            .iter()
236            .map(|record| {
237                let effective_unit_id = record.effective_unit_id()?;
238                Ok(CanonicalRelationRecord {
239                    effective_unit_id,
240                    unit_level: record.unit_level,
241                    unit_id: record.unit_id.clone(),
242                    observation_id: record.observation_id.clone(),
243                    sample_id: record.sample_id.clone(),
244                    source_id: record.source_id.clone(),
245                    rep_id: record.rep_id.clone(),
246                    target_id: record.target_id.clone(),
247                    group_id: record.group_id.clone(),
248                    origin_sample_id: record.origin_sample_id.clone(),
249                    derived_unit_id: record.derived_unit_id.clone(),
250                    component_observation_ids: record.component_observation_ids.clone(),
251                    sample_influence_weight: record.sample_influence_weight,
252                    quality_flag: record.quality_flag.clone(),
253                    is_augmented: record.is_augmented,
254                })
255            })
256            .collect::<Result<Vec<_>>>()?;
257        canonical.sort_by(|left, right| {
258            (
259                left.effective_unit_id.as_str(),
260                left.observation_id.as_str(),
261                left.sample_id.as_str(),
262            )
263                .cmp(&(
264                    right.effective_unit_id.as_str(),
265                    right.observation_id.as_str(),
266                    right.sample_id.as_str(),
267                ))
268        });
269        stable_json_fingerprint(&canonical)
270    }
271
272    pub fn validate_against_fold_set(
273        &self,
274        fold_set: &FoldSet,
275        policy: &LeakageUnitPolicy,
276    ) -> Result<()> {
277        self.validate()?;
278        fold_set.validate()?;
279        policy.validate()?;
280
281        let universe = fold_set.sample_ids.iter().collect::<BTreeSet<_>>();
282        for record in &self.records {
283            if !universe.contains(&record.sample_id) {
284                return Err(DagMlError::CampaignValidation(format!(
285                    "relation `{}` references sample `{}` outside fold set",
286                    record.observation_id, record.sample_id
287                )));
288            }
289            if let Some(origin_sample_id) = &record.origin_sample_id {
290                if !universe.contains(origin_sample_id) {
291                    return Err(DagMlError::CampaignValidation(format!(
292                        "relation `{}` references origin sample `{}` outside fold set",
293                        record.observation_id, origin_sample_id
294                    )));
295                }
296            }
297            if policy.require_group_ids && record.group_id.is_none() {
298                return Err(DagMlError::CampaignValidation(format!(
299                    "relation `{}` is missing required group id",
300                    record.observation_id
301                )));
302            }
303        }
304
305        let sample_to_target = self.sample_targets();
306        let sample_to_group = self.sample_groups();
307        validate_fold_set_groups_match_relations(fold_set, &sample_to_group)?;
308
309        for fold in &fold_set.folds {
310            let partitions = fold
311                .train_sample_ids
312                .iter()
313                .map(|sample_id| (sample_id, FoldPartition::Train))
314                .chain(
315                    fold.validation_sample_ids
316                        .iter()
317                        .map(|sample_id| (sample_id, FoldPartition::Validation)),
318                )
319                .collect::<BTreeMap<_, _>>();
320
321            if policy.forbid_origin_cross_fold {
322                for record in &self.records {
323                    if let Some(origin_sample_id) = &record.origin_sample_id {
324                        let sample_partition =
325                            partitions.get(&record.sample_id).ok_or_else(|| {
326                                DagMlError::CampaignValidation(format!(
327                                    "fold `{}` does not contain sample `{}`",
328                                    fold.fold_id, record.sample_id
329                                ))
330                            })?;
331                        let origin_partition =
332                            partitions.get(origin_sample_id).ok_or_else(|| {
333                                DagMlError::CampaignValidation(format!(
334                                    "fold `{}` does not contain origin sample `{}`",
335                                    fold.fold_id, origin_sample_id
336                                ))
337                            })?;
338                        if sample_partition != origin_partition {
339                            return Err(DagMlError::CampaignValidation(format!(
340                                "fold `{}` leaks origin sample `{}` into {:?} sample `{}`",
341                                fold.fold_id, origin_sample_id, sample_partition, record.sample_id
342                            )));
343                        }
344                    }
345                }
346            }
347
348            match policy.split_unit {
349                SplitUnit::PhysicalSample | SplitUnit::Observation | SplitUnit::Sample => {}
350                SplitUnit::Target => validate_unit_partitions(
351                    &fold.fold_id.to_string(),
352                    "target",
353                    &partitions,
354                    &sample_to_target,
355                )?,
356                SplitUnit::Group => validate_unit_partitions(
357                    &fold.fold_id.to_string(),
358                    "group",
359                    &partitions,
360                    &sample_to_group,
361                )?,
362            }
363        }
364        Ok(())
365    }
366
367    pub fn sample_for_observation(&self, observation_id: &ObservationId) -> Option<&SampleId> {
368        self.records
369            .iter()
370            .find(|record| &record.observation_id == observation_id)
371            .map(|record| &record.sample_id)
372    }
373
374    pub fn target_for_sample(&self, sample_id: &SampleId) -> Option<&TargetId> {
375        self.records
376            .iter()
377            .find(|record| &record.sample_id == sample_id)
378            .and_then(|record| record.target_id.as_ref())
379    }
380
381    pub fn group_for_sample(&self, sample_id: &SampleId) -> Option<&GroupId> {
382        self.records
383            .iter()
384            .find(|record| &record.sample_id == sample_id)
385            .and_then(|record| record.group_id.as_ref())
386    }
387
388    pub fn observation_count_for_sample(&self, sample_id: &SampleId) -> usize {
389        self.records
390            .iter()
391            .filter(|record| &record.sample_id == sample_id)
392            .count()
393    }
394
395    pub fn sample_targets(&self) -> BTreeMap<SampleId, TargetId> {
396        self.records
397            .iter()
398            .filter_map(|record| {
399                record
400                    .target_id
401                    .as_ref()
402                    .map(|target_id| (record.sample_id.clone(), target_id.clone()))
403            })
404            .collect()
405    }
406
407    pub fn sample_groups(&self) -> BTreeMap<SampleId, GroupId> {
408        self.records
409            .iter()
410            .filter_map(|record| {
411                record
412                    .group_id
413                    .as_ref()
414                    .map(|group_id| (record.sample_id.clone(), group_id.clone()))
415            })
416            .collect()
417    }
418}
419
420fn non_empty_optional<'a>(
421    field: &str,
422    observation_id: &ObservationId,
423    value: &'a Option<String>,
424) -> Result<Option<&'a str>> {
425    if let Some(value) = value.as_deref() {
426        if value.trim().is_empty() {
427            return Err(DagMlError::CampaignValidation(format!(
428                "relation `{observation_id}` has empty {field}"
429            )));
430        }
431        Ok(Some(value))
432    } else {
433        Ok(None)
434    }
435}
436
437fn validate_optional_identifier(
438    field: &str,
439    observation_id: &ObservationId,
440    value: &Option<String>,
441) -> Result<()> {
442    let Some(value) = non_empty_optional(field, observation_id, value)? else {
443        return Ok(());
444    };
445    if value.len() > 128
446        || !value
447            .bytes()
448            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' | b':'))
449    {
450        return Err(DagMlError::CampaignValidation(format!(
451            "relation `{observation_id}` has invalid {field}"
452        )));
453    }
454    Ok(())
455}
456
457fn validate_combo_record(
458    record: &SampleRelation,
459    observation_samples: &BTreeMap<ObservationId, SampleId>,
460) -> Result<()> {
461    if record.unit_level != EntityUnitLevel::Combo {
462        return Ok(());
463    }
464    if record.component_observation_ids.is_empty() {
465        return Err(DagMlError::CampaignValidation(format!(
466            "combo relation `{}` has no component observations",
467            record.observation_id
468        )));
469    }
470    if record.derived_unit_id.is_none() {
471        return Err(DagMlError::CampaignValidation(format!(
472            "combo relation `{}` requires derived_unit_id",
473            record.observation_id
474        )));
475    }
476    if let Some(origin_sample_id) = &record.origin_sample_id {
477        if origin_sample_id != &record.sample_id {
478            return Err(DagMlError::CampaignValidation(format!(
479                "combo relation `{}` origin sample `{}` differs from sample `{}`",
480                record.observation_id, origin_sample_id, record.sample_id
481            )));
482        }
483    }
484
485    let mut components = BTreeSet::new();
486    for component_observation_id in &record.component_observation_ids {
487        if component_observation_id == &record.observation_id {
488            return Err(DagMlError::CampaignValidation(format!(
489                "combo relation `{}` cannot list itself as a component",
490                record.observation_id
491            )));
492        }
493        if !components.insert(component_observation_id) {
494            return Err(DagMlError::CampaignValidation(format!(
495                "combo relation `{}` repeats component observation `{}`",
496                record.observation_id, component_observation_id
497            )));
498        }
499        let component_sample = observation_samples
500            .get(component_observation_id)
501            .ok_or_else(|| {
502                DagMlError::CampaignValidation(format!(
503                    "combo relation `{}` references missing component observation `{}`",
504                    record.observation_id, component_observation_id
505                ))
506            })?;
507        if component_sample != &record.sample_id {
508            return Err(DagMlError::CampaignValidation(format!(
509                "combo relation `{}` component observation `{}` belongs to sample `{}` not `{}`",
510                record.observation_id, component_observation_id, component_sample, record.sample_id
511            )));
512        }
513    }
514    Ok(())
515}
516
517fn validate_fold_set_groups_match_relations(
518    fold_set: &FoldSet,
519    sample_to_group: &BTreeMap<SampleId, GroupId>,
520) -> Result<()> {
521    for (sample_id, fold_group) in &fold_set.sample_groups {
522        if let Some(relation_group) = sample_to_group.get(sample_id) {
523            if relation_group != fold_group {
524                return Err(DagMlError::CampaignValidation(format!(
525                    "sample `{sample_id}` has group `{relation_group}` in relations but `{fold_group}` in fold set"
526                )));
527            }
528        }
529    }
530    Ok(())
531}
532
533fn validate_unit_partitions<Unit: Ord + std::fmt::Display>(
534    fold_id: &str,
535    label: &str,
536    partitions: &BTreeMap<&SampleId, FoldPartition>,
537    sample_units: &BTreeMap<SampleId, Unit>,
538) -> Result<()> {
539    let mut unit_partitions = BTreeMap::<&Unit, FoldPartition>::new();
540    for (sample_id, partition) in partitions {
541        let Some(unit) = sample_units.get(*sample_id) else {
542            return Err(DagMlError::CampaignValidation(format!(
543                "fold `{fold_id}` sample `{sample_id}` is missing {label} id"
544            )));
545        };
546        if let Some(previous) = unit_partitions.insert(unit, *partition) {
547            if previous != *partition {
548                return Err(DagMlError::CampaignValidation(format!(
549                    "fold `{fold_id}` leaks {label} `{unit}` across train/validation"
550                )));
551            }
552        }
553    }
554    Ok(())
555}
556
// Unit tests for relation validation, fingerprinting and fold-set compatibility.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fold::FoldAssignment;

    // Shorthand constructors; they panic on ids the id types reject.
    fn sid(value: &str) -> SampleId {
        SampleId::new(value).unwrap()
    }

    fn oid(value: &str) -> ObservationId {
        ObservationId::new(value).unwrap()
    }

    fn tid(value: &str) -> TargetId {
        TargetId::new(value).unwrap()
    }

    fn gid(value: &str) -> GroupId {
        GroupId::new(value).unwrap()
    }

    // Two folds over samples s1..s4 that swap the same halves between train
    // and validation; no sample groups are recorded in the fold set.
    fn fold_set() -> FoldSet {
        FoldSet {
            id: "outer".to_string(),
            sample_ids: vec![sid("s1"), sid("s2"), sid("s3"), sid("s4")],
            folds: vec![
                FoldAssignment {
                    fold_id: crate::ids::FoldId::new("fold:0").unwrap(),
                    train_sample_ids: vec![sid("s3"), sid("s4")],
                    validation_sample_ids: vec![sid("s1"), sid("s2")],
                    metadata: BTreeMap::new(),
                },
                FoldAssignment {
                    fold_id: crate::ids::FoldId::new("fold:1").unwrap(),
                    train_sample_ids: vec![sid("s1"), sid("s2")],
                    validation_sample_ids: vec![sid("s3"), sid("s4")],
                    metadata: BTreeMap::new(),
                },
            ],
            sample_groups: BTreeMap::new(),
        }
    }

    // Observation-level relation with both a target and a group set.
    fn relation(observation: &str, sample: &str, target: &str, group: &str) -> SampleRelation {
        let mut relation = SampleRelation::new(oid(observation), sid(sample));
        relation.target_id = Some(tid(target));
        relation.group_id = Some(gid(group));
        relation
    }

    // Like `relation`, with a fixed target/group plus a source and repetition id.
    fn source_relation(observation: &str, sample: &str, source: &str, rep: &str) -> SampleRelation {
        let mut relation = relation(observation, sample, "target:sample", "group:sample");
        relation.source_id = Some(source.to_string());
        relation.rep_id = Some(rep.to_string());
        relation
    }

    // Two observations of the same sample are accepted under the default
    // policy (presumably `SplitUnit::Sample`, per the test name — confirm).
    #[test]
    fn repeated_observations_validate_at_sample_split_unit() {
        let relations = SampleRelationSet {
            records: vec![
                relation("obs:1a", "s1", "t1", "g1"),
                relation("obs:1b", "s1", "t1", "g1"),
                relation("obs:2a", "s2", "t2", "g2"),
                relation("obs:3a", "s3", "t3", "g3"),
                relation("obs:4a", "s4", "t4", "g4"),
            ],
        };

        relations
            .validate_against_fold_set(&fold_set(), &LeakageUnitPolicy::default())
            .unwrap();
    }

    // Same data as above, with the split unit explicitly set to the physical sample.
    #[test]
    fn repeated_observations_validate_at_physical_sample_split_unit() {
        let relations = SampleRelationSet {
            records: vec![
                relation("obs:1a", "s1", "t1", "g1"),
                relation("obs:1b", "s1", "t1", "g1"),
                relation("obs:2a", "s2", "t2", "g2"),
                relation("obs:3a", "s3", "t3", "g3"),
                relation("obs:4a", "s4", "t4", "g4"),
            ],
        };
        let policy = LeakageUnitPolicy {
            split_unit: SplitUnit::PhysicalSample,
            ..LeakageUnitPolicy::default()
        };

        relations
            .validate_against_fold_set(&fold_set(), &policy)
            .unwrap();
    }

    // One sample with an uneven number of repetitions per source (A: 2, B: 3,
    // C: 2) plus a combo built from one observation of each source.
    #[test]
    fn asymmetric_multisource_repetitions_and_combo_validate_as_relations() {
        let mut combo = relation(
            "obs:s1.combo.a0.b0.c0",
            "s1",
            "target:sample",
            "group:sample",
        );
        combo.unit_level = EntityUnitLevel::Combo;
        combo.source_id = Some("combo".to_string());
        combo.derived_unit_id = Some("combo:s1:a0:b0:c0".to_string());
        combo.origin_sample_id = Some(sid("s1"));
        combo.component_observation_ids =
            vec![oid("obs:s1.A.0"), oid("obs:s1.B.0"), oid("obs:s1.C.0")];
        combo.sample_influence_weight = Some(1.0);
        combo.quality_flag = Some("ok".to_string());

        let relations = SampleRelationSet {
            records: vec![
                source_relation("obs:s1.A.0", "s1", "A", "rep:0"),
                source_relation("obs:s1.A.1", "s1", "A", "rep:1"),
                source_relation("obs:s1.B.0", "s1", "B", "rep:0"),
                source_relation("obs:s1.B.1", "s1", "B", "rep:1"),
                source_relation("obs:s1.B.2", "s1", "B", "rep:2"),
                source_relation("obs:s1.C.0", "s1", "C", "rep:0"),
                source_relation("obs:s1.C.1", "s1", "C", "rep:1"),
                combo,
            ],
        };

        relations.validate().unwrap();
        assert_eq!(
            relations.sample_for_observation(&oid("obs:s1.combo.a0.b0.c0")),
            Some(&sid("s1"))
        );
    }

    // A combo on sample s1 that lists an observation of sample s2 must be rejected.
    #[test]
    fn combo_components_cannot_cross_sample_boundary() {
        let mut combo = relation("obs:s1.combo", "s1", "target:sample", "group:sample");
        combo.unit_level = EntityUnitLevel::Combo;
        combo.derived_unit_id = Some("combo:s1".to_string());
        combo.component_observation_ids = vec![oid("obs:s1.A.0"), oid("obs:s2.B.0")];

        let relations = SampleRelationSet {
            records: vec![
                source_relation("obs:s1.A.0", "s1", "A", "rep:0"),
                source_relation("obs:s2.B.0", "s2", "B", "rep:0"),
                combo,
            ],
        };

        assert!(relations.validate().is_err());
    }

    // Reordering records keeps the fingerprint; changing a provenance field
    // (here `rep_id`) changes it.
    #[test]
    fn relation_fingerprint_is_order_stable_and_provenance_sensitive() {
        let left = SampleRelationSet {
            records: vec![
                source_relation("obs:s1.A.0", "s1", "A", "rep:0"),
                source_relation("obs:s1.B.0", "s1", "B", "rep:0"),
            ],
        };
        let right = SampleRelationSet {
            records: vec![
                source_relation("obs:s1.B.0", "s1", "B", "rep:0"),
                source_relation("obs:s1.A.0", "s1", "A", "rep:0"),
            ],
        };
        assert_eq!(left.fingerprint().unwrap(), right.fingerprint().unwrap());

        let mut changed = left.clone();
        changed.records[0].rep_id = Some("rep:1".to_string());
        assert_ne!(left.fingerprint().unwrap(), changed.fingerprint().unwrap());
    }

    // JSON without the newer fields still deserializes, defaults to
    // observation-level units, and validates.
    #[test]
    fn old_relation_json_defaults_to_observation_unit() {
        let relation: SampleRelation = serde_json::from_value(serde_json::json!({
            "observation_id": "obs:legacy",
            "sample_id": "s1",
            "target_id": "t1",
            "group_id": "g1",
            "source_id": "legacy",
            "is_augmented": false
        }))
        .unwrap();

        assert_eq!(relation.unit_level, EntityUnitLevel::Observation);
        assert!(relation.rep_id.is_none());
        assert!(relation.component_observation_ids.is_empty());
        SampleRelationSet {
            records: vec![relation],
        }
        .validate()
        .unwrap();
    }

    // Each step fixes the previous problem and introduces the next one:
    // bad `rep_id` character, non-positive weight, blank quality flag.
    #[test]
    fn relation_validation_rejects_invalid_new_fields() {
        // `/` is not in the allowed identifier alphabet.
        let mut invalid_rep = source_relation("obs:s1.A.0", "s1", "A", "rep/0");
        assert!(invalid_rep.validate().is_err());

        invalid_rep.rep_id = Some("rep:0".to_string());
        invalid_rep.sample_influence_weight = Some(0.0);
        assert!(invalid_rep.validate().is_err());

        invalid_rep.sample_influence_weight = Some(1.0);
        invalid_rep.quality_flag = Some(" ".to_string());
        assert!(invalid_rep.validate().is_err());
    }

    // s1 and s3 share a target but always sit on opposite sides of each fold.
    #[test]
    fn target_split_refuses_shared_target_across_fold_boundary() {
        let relations = SampleRelationSet {
            records: vec![
                relation("obs:1", "s1", "same_target", "g1"),
                relation("obs:2", "s2", "t2", "g2"),
                relation("obs:3", "s3", "same_target", "g3"),
                relation("obs:4", "s4", "t4", "g4"),
            ],
        };
        let policy = LeakageUnitPolicy {
            split_unit: SplitUnit::Target,
            ..LeakageUnitPolicy::default()
        };

        assert!(relations
            .validate_against_fold_set(&fold_set(), &policy)
            .is_err());
    }

    // The augmented observation lives on s3 but originates from s1, which is on
    // the other side of every fold. Relies on the default policy forbidding
    // origin cross-fold leakage — confirm against `LeakageUnitPolicy::default()`.
    #[test]
    fn augmentation_origin_cannot_cross_train_validation_boundary() {
        let mut generated = relation("obs:aug", "s3", "t3", "g3");
        generated.origin_sample_id = Some(sid("s1"));
        generated.is_augmented = true;
        let relations = SampleRelationSet {
            records: vec![
                relation("obs:1", "s1", "t1", "g1"),
                relation("obs:2", "s2", "t2", "g2"),
                generated,
                relation("obs:4", "s4", "t4", "g4"),
            ],
        };

        assert!(relations
            .validate_against_fold_set(&fold_set(), &LeakageUnitPolicy::default())
            .is_err());
    }
}