1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::error::{DagMlError, Result};
7use crate::fold::FoldSet;
8use crate::ids::{GroupId, ObservationId, SampleId, TargetId};
9use crate::policy::{LeakageUnitPolicy, SplitUnit};
10
11#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum FoldPartition {
14 Train,
15 Validation,
16}
17
18#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum EntityUnitLevel {
21 PhysicalSample,
22 SourceSample,
23 #[default]
24 Observation,
25 Combo,
26}
27
28#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
29pub struct SampleRelation {
30 #[serde(default)]
31 pub unit_level: EntityUnitLevel,
32 #[serde(default)]
33 pub unit_id: Option<String>,
34 pub observation_id: ObservationId,
35 pub sample_id: SampleId,
36 #[serde(default)]
37 pub source_id: Option<String>,
38 #[serde(default)]
39 pub rep_id: Option<String>,
40 #[serde(default)]
41 pub target_id: Option<TargetId>,
42 #[serde(default)]
43 pub group_id: Option<GroupId>,
44 #[serde(default)]
45 pub origin_sample_id: Option<SampleId>,
46 #[serde(default)]
47 pub derived_unit_id: Option<String>,
48 #[serde(default)]
49 pub component_observation_ids: Vec<ObservationId>,
50 #[serde(default)]
51 pub sample_influence_weight: Option<f64>,
52 #[serde(default)]
53 pub quality_flag: Option<String>,
54 #[serde(default)]
55 pub is_augmented: bool,
56}
57
58impl SampleRelation {
59 pub fn new(observation_id: ObservationId, sample_id: SampleId) -> Self {
60 Self {
61 unit_level: EntityUnitLevel::Observation,
62 unit_id: None,
63 observation_id,
64 sample_id,
65 source_id: None,
66 rep_id: None,
67 target_id: None,
68 group_id: None,
69 origin_sample_id: None,
70 derived_unit_id: None,
71 component_observation_ids: Vec::new(),
72 sample_influence_weight: None,
73 quality_flag: None,
74 is_augmented: false,
75 }
76 }
77
78 pub fn effective_unit_id(&self) -> Result<String> {
79 if let Some(unit_id) = non_empty_optional("unit_id", &self.observation_id, &self.unit_id)? {
80 return Ok(unit_id.to_string());
81 }
82
83 match self.unit_level {
84 EntityUnitLevel::PhysicalSample => Ok(self.sample_id.to_string()),
85 EntityUnitLevel::SourceSample => {
86 let source_id =
87 non_empty_optional("source_id", &self.observation_id, &self.source_id)?
88 .ok_or_else(|| {
89 DagMlError::CampaignValidation(format!(
90 "source-sample relation `{}` requires source_id",
91 self.observation_id
92 ))
93 })?;
94 Ok(format!("{}::{source_id}", self.sample_id))
95 }
96 EntityUnitLevel::Observation => Ok(self.observation_id.to_string()),
97 EntityUnitLevel::Combo => {
98 let derived_unit_id = non_empty_optional(
99 "derived_unit_id",
100 &self.observation_id,
101 &self.derived_unit_id,
102 )?
103 .ok_or_else(|| {
104 DagMlError::CampaignValidation(format!(
105 "combo relation `{}` requires derived_unit_id",
106 self.observation_id
107 ))
108 })?;
109 Ok(derived_unit_id.to_string())
110 }
111 }
112 }
113
114 fn validate(&self) -> Result<()> {
115 non_empty_optional("unit_id", &self.observation_id, &self.unit_id)?;
116 non_empty_optional("source_id", &self.observation_id, &self.source_id)?;
117 non_empty_optional(
118 "derived_unit_id",
119 &self.observation_id,
120 &self.derived_unit_id,
121 )?;
122 non_empty_optional("quality_flag", &self.observation_id, &self.quality_flag)?;
123 validate_optional_identifier("rep_id", &self.observation_id, &self.rep_id)?;
124
125 if let Some(weight) = self.sample_influence_weight {
126 if !weight.is_finite() || weight <= 0.0 {
127 return Err(DagMlError::CampaignValidation(format!(
128 "relation `{}` has invalid sample_influence_weight",
129 self.observation_id
130 )));
131 }
132 }
133
134 if self.unit_level != EntityUnitLevel::Combo && !self.component_observation_ids.is_empty() {
135 return Err(DagMlError::CampaignValidation(format!(
136 "relation `{}` has component_observation_ids but is not a combo",
137 self.observation_id
138 )));
139 }
140
141 self.effective_unit_id()?;
142 Ok(())
143 }
144}
145
146#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
147pub struct SampleRelationSet {
148 #[serde(default)]
149 pub records: Vec<SampleRelation>,
150}
151
152pub fn relation_set_fingerprint(relations: &SampleRelationSet) -> Result<String> {
153 relations.fingerprint()
154}
155
156#[derive(Clone, Debug, Serialize)]
157struct CanonicalRelationRecord {
158 effective_unit_id: String,
159 unit_level: EntityUnitLevel,
160 unit_id: Option<String>,
161 observation_id: ObservationId,
162 sample_id: SampleId,
163 source_id: Option<String>,
164 rep_id: Option<String>,
165 target_id: Option<TargetId>,
166 group_id: Option<GroupId>,
167 origin_sample_id: Option<SampleId>,
168 derived_unit_id: Option<String>,
169 component_observation_ids: Vec<ObservationId>,
170 sample_influence_weight: Option<f64>,
171 quality_flag: Option<String>,
172 is_augmented: bool,
173}
174
175impl SampleRelationSet {
176 pub fn validate(&self) -> Result<()> {
177 let mut observations = BTreeSet::new();
178 let mut observation_samples = BTreeMap::<ObservationId, SampleId>::new();
179 let mut unit_ids = BTreeMap::<String, ObservationId>::new();
180 let mut sample_targets = BTreeMap::<SampleId, TargetId>::new();
181 let mut sample_groups = BTreeMap::<SampleId, GroupId>::new();
182 for record in &self.records {
183 record.validate()?;
184 if !observations.insert(&record.observation_id) {
185 return Err(DagMlError::CampaignValidation(format!(
186 "duplicate observation relation `{}`",
187 record.observation_id
188 )));
189 }
190 observation_samples.insert(record.observation_id.clone(), record.sample_id.clone());
191 let effective_unit_id = record.effective_unit_id()?;
192 if let Some(previous) =
193 unit_ids.insert(effective_unit_id.clone(), record.observation_id.clone())
194 {
195 return Err(DagMlError::CampaignValidation(format!(
196 "relations `{previous}` and `{}` share effective unit id `{effective_unit_id}`",
197 record.observation_id
198 )));
199 }
200 if let Some(target_id) = &record.target_id {
201 if let Some(previous) = sample_targets.get(&record.sample_id) {
202 if previous != target_id {
203 return Err(DagMlError::CampaignValidation(format!(
204 "sample `{}` maps to multiple targets",
205 record.sample_id
206 )));
207 }
208 } else {
209 sample_targets.insert(record.sample_id.clone(), target_id.clone());
210 }
211 }
212 if let Some(group_id) = &record.group_id {
213 if let Some(previous) = sample_groups.get(&record.sample_id) {
214 if previous != group_id {
215 return Err(DagMlError::CampaignValidation(format!(
216 "sample `{}` maps to multiple groups",
217 record.sample_id
218 )));
219 }
220 } else {
221 sample_groups.insert(record.sample_id.clone(), group_id.clone());
222 }
223 }
224 }
225 for record in &self.records {
226 validate_combo_record(record, &observation_samples)?;
227 }
228 Ok(())
229 }
230
231 pub fn fingerprint(&self) -> Result<String> {
232 self.validate()?;
233 let mut canonical = self
234 .records
235 .iter()
236 .map(|record| {
237 let effective_unit_id = record.effective_unit_id()?;
238 Ok(CanonicalRelationRecord {
239 effective_unit_id,
240 unit_level: record.unit_level,
241 unit_id: record.unit_id.clone(),
242 observation_id: record.observation_id.clone(),
243 sample_id: record.sample_id.clone(),
244 source_id: record.source_id.clone(),
245 rep_id: record.rep_id.clone(),
246 target_id: record.target_id.clone(),
247 group_id: record.group_id.clone(),
248 origin_sample_id: record.origin_sample_id.clone(),
249 derived_unit_id: record.derived_unit_id.clone(),
250 component_observation_ids: record.component_observation_ids.clone(),
251 sample_influence_weight: record.sample_influence_weight,
252 quality_flag: record.quality_flag.clone(),
253 is_augmented: record.is_augmented,
254 })
255 })
256 .collect::<Result<Vec<_>>>()?;
257 canonical.sort_by(|left, right| {
258 (
259 left.effective_unit_id.as_str(),
260 left.observation_id.as_str(),
261 left.sample_id.as_str(),
262 )
263 .cmp(&(
264 right.effective_unit_id.as_str(),
265 right.observation_id.as_str(),
266 right.sample_id.as_str(),
267 ))
268 });
269 stable_json_fingerprint(&canonical)
270 }
271
272 pub fn validate_against_fold_set(
273 &self,
274 fold_set: &FoldSet,
275 policy: &LeakageUnitPolicy,
276 ) -> Result<()> {
277 self.validate()?;
278 fold_set.validate()?;
279 policy.validate()?;
280
281 let universe = fold_set.sample_ids.iter().collect::<BTreeSet<_>>();
282 for record in &self.records {
283 if !universe.contains(&record.sample_id) {
284 return Err(DagMlError::CampaignValidation(format!(
285 "relation `{}` references sample `{}` outside fold set",
286 record.observation_id, record.sample_id
287 )));
288 }
289 if let Some(origin_sample_id) = &record.origin_sample_id {
290 if !universe.contains(origin_sample_id) {
291 return Err(DagMlError::CampaignValidation(format!(
292 "relation `{}` references origin sample `{}` outside fold set",
293 record.observation_id, origin_sample_id
294 )));
295 }
296 }
297 if policy.require_group_ids && record.group_id.is_none() {
298 return Err(DagMlError::CampaignValidation(format!(
299 "relation `{}` is missing required group id",
300 record.observation_id
301 )));
302 }
303 }
304
305 let sample_to_target = self.sample_targets();
306 let sample_to_group = self.sample_groups();
307 validate_fold_set_groups_match_relations(fold_set, &sample_to_group)?;
308
309 for fold in &fold_set.folds {
310 let partitions = fold
311 .train_sample_ids
312 .iter()
313 .map(|sample_id| (sample_id, FoldPartition::Train))
314 .chain(
315 fold.validation_sample_ids
316 .iter()
317 .map(|sample_id| (sample_id, FoldPartition::Validation)),
318 )
319 .collect::<BTreeMap<_, _>>();
320
321 if policy.forbid_origin_cross_fold {
322 for record in &self.records {
323 if let Some(origin_sample_id) = &record.origin_sample_id {
324 let sample_partition =
325 partitions.get(&record.sample_id).ok_or_else(|| {
326 DagMlError::CampaignValidation(format!(
327 "fold `{}` does not contain sample `{}`",
328 fold.fold_id, record.sample_id
329 ))
330 })?;
331 let origin_partition =
332 partitions.get(origin_sample_id).ok_or_else(|| {
333 DagMlError::CampaignValidation(format!(
334 "fold `{}` does not contain origin sample `{}`",
335 fold.fold_id, origin_sample_id
336 ))
337 })?;
338 if sample_partition != origin_partition {
339 return Err(DagMlError::CampaignValidation(format!(
340 "fold `{}` leaks origin sample `{}` into {:?} sample `{}`",
341 fold.fold_id, origin_sample_id, sample_partition, record.sample_id
342 )));
343 }
344 }
345 }
346 }
347
348 match policy.split_unit {
349 SplitUnit::PhysicalSample | SplitUnit::Observation | SplitUnit::Sample => {}
350 SplitUnit::Target => validate_unit_partitions(
351 &fold.fold_id.to_string(),
352 "target",
353 &partitions,
354 &sample_to_target,
355 )?,
356 SplitUnit::Group => validate_unit_partitions(
357 &fold.fold_id.to_string(),
358 "group",
359 &partitions,
360 &sample_to_group,
361 )?,
362 }
363 }
364 Ok(())
365 }
366
367 pub fn sample_for_observation(&self, observation_id: &ObservationId) -> Option<&SampleId> {
368 self.records
369 .iter()
370 .find(|record| &record.observation_id == observation_id)
371 .map(|record| &record.sample_id)
372 }
373
374 pub fn target_for_sample(&self, sample_id: &SampleId) -> Option<&TargetId> {
375 self.records
376 .iter()
377 .find(|record| &record.sample_id == sample_id)
378 .and_then(|record| record.target_id.as_ref())
379 }
380
381 pub fn group_for_sample(&self, sample_id: &SampleId) -> Option<&GroupId> {
382 self.records
383 .iter()
384 .find(|record| &record.sample_id == sample_id)
385 .and_then(|record| record.group_id.as_ref())
386 }
387
388 pub fn observation_count_for_sample(&self, sample_id: &SampleId) -> usize {
389 self.records
390 .iter()
391 .filter(|record| &record.sample_id == sample_id)
392 .count()
393 }
394
395 pub fn sample_targets(&self) -> BTreeMap<SampleId, TargetId> {
396 self.records
397 .iter()
398 .filter_map(|record| {
399 record
400 .target_id
401 .as_ref()
402 .map(|target_id| (record.sample_id.clone(), target_id.clone()))
403 })
404 .collect()
405 }
406
407 pub fn sample_groups(&self) -> BTreeMap<SampleId, GroupId> {
408 self.records
409 .iter()
410 .filter_map(|record| {
411 record
412 .group_id
413 .as_ref()
414 .map(|group_id| (record.sample_id.clone(), group_id.clone()))
415 })
416 .collect()
417 }
418}
419
420fn non_empty_optional<'a>(
421 field: &str,
422 observation_id: &ObservationId,
423 value: &'a Option<String>,
424) -> Result<Option<&'a str>> {
425 if let Some(value) = value.as_deref() {
426 if value.trim().is_empty() {
427 return Err(DagMlError::CampaignValidation(format!(
428 "relation `{observation_id}` has empty {field}"
429 )));
430 }
431 Ok(Some(value))
432 } else {
433 Ok(None)
434 }
435}
436
437fn validate_optional_identifier(
438 field: &str,
439 observation_id: &ObservationId,
440 value: &Option<String>,
441) -> Result<()> {
442 let Some(value) = non_empty_optional(field, observation_id, value)? else {
443 return Ok(());
444 };
445 if value.len() > 128
446 || !value
447 .bytes()
448 .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' | b':'))
449 {
450 return Err(DagMlError::CampaignValidation(format!(
451 "relation `{observation_id}` has invalid {field}"
452 )));
453 }
454 Ok(())
455}
456
457fn validate_combo_record(
458 record: &SampleRelation,
459 observation_samples: &BTreeMap<ObservationId, SampleId>,
460) -> Result<()> {
461 if record.unit_level != EntityUnitLevel::Combo {
462 return Ok(());
463 }
464 if record.component_observation_ids.is_empty() {
465 return Err(DagMlError::CampaignValidation(format!(
466 "combo relation `{}` has no component observations",
467 record.observation_id
468 )));
469 }
470 if record.derived_unit_id.is_none() {
471 return Err(DagMlError::CampaignValidation(format!(
472 "combo relation `{}` requires derived_unit_id",
473 record.observation_id
474 )));
475 }
476 if let Some(origin_sample_id) = &record.origin_sample_id {
477 if origin_sample_id != &record.sample_id {
478 return Err(DagMlError::CampaignValidation(format!(
479 "combo relation `{}` origin sample `{}` differs from sample `{}`",
480 record.observation_id, origin_sample_id, record.sample_id
481 )));
482 }
483 }
484
485 let mut components = BTreeSet::new();
486 for component_observation_id in &record.component_observation_ids {
487 if component_observation_id == &record.observation_id {
488 return Err(DagMlError::CampaignValidation(format!(
489 "combo relation `{}` cannot list itself as a component",
490 record.observation_id
491 )));
492 }
493 if !components.insert(component_observation_id) {
494 return Err(DagMlError::CampaignValidation(format!(
495 "combo relation `{}` repeats component observation `{}`",
496 record.observation_id, component_observation_id
497 )));
498 }
499 let component_sample = observation_samples
500 .get(component_observation_id)
501 .ok_or_else(|| {
502 DagMlError::CampaignValidation(format!(
503 "combo relation `{}` references missing component observation `{}`",
504 record.observation_id, component_observation_id
505 ))
506 })?;
507 if component_sample != &record.sample_id {
508 return Err(DagMlError::CampaignValidation(format!(
509 "combo relation `{}` component observation `{}` belongs to sample `{}` not `{}`",
510 record.observation_id, component_observation_id, component_sample, record.sample_id
511 )));
512 }
513 }
514 Ok(())
515}
516
517fn validate_fold_set_groups_match_relations(
518 fold_set: &FoldSet,
519 sample_to_group: &BTreeMap<SampleId, GroupId>,
520) -> Result<()> {
521 for (sample_id, fold_group) in &fold_set.sample_groups {
522 if let Some(relation_group) = sample_to_group.get(sample_id) {
523 if relation_group != fold_group {
524 return Err(DagMlError::CampaignValidation(format!(
525 "sample `{sample_id}` has group `{relation_group}` in relations but `{fold_group}` in fold set"
526 )));
527 }
528 }
529 }
530 Ok(())
531}
532
533fn validate_unit_partitions<Unit: Ord + std::fmt::Display>(
534 fold_id: &str,
535 label: &str,
536 partitions: &BTreeMap<&SampleId, FoldPartition>,
537 sample_units: &BTreeMap<SampleId, Unit>,
538) -> Result<()> {
539 let mut unit_partitions = BTreeMap::<&Unit, FoldPartition>::new();
540 for (sample_id, partition) in partitions {
541 let Some(unit) = sample_units.get(*sample_id) else {
542 return Err(DagMlError::CampaignValidation(format!(
543 "fold `{fold_id}` sample `{sample_id}` is missing {label} id"
544 )));
545 };
546 if let Some(previous) = unit_partitions.insert(unit, *partition) {
547 if previous != *partition {
548 return Err(DagMlError::CampaignValidation(format!(
549 "fold `{fold_id}` leaks {label} `{unit}` across train/validation"
550 )));
551 }
552 }
553 }
554 Ok(())
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fold::FoldAssignment;
    use crate::ids::FoldId;

    // --- id shorthands -----------------------------------------------------

    fn sid(value: &str) -> SampleId {
        SampleId::new(value).unwrap()
    }

    fn oid(value: &str) -> ObservationId {
        ObservationId::new(value).unwrap()
    }

    fn tid(value: &str) -> TargetId {
        TargetId::new(value).unwrap()
    }

    fn gid(value: &str) -> GroupId {
        GroupId::new(value).unwrap()
    }

    // --- fixtures ----------------------------------------------------------

    fn samples(values: &[&str]) -> Vec<SampleId> {
        values.iter().copied().map(sid).collect()
    }

    fn fold(id: &str, train: &[&str], validation: &[&str]) -> FoldAssignment {
        FoldAssignment {
            fold_id: FoldId::new(id).unwrap(),
            train_sample_ids: samples(train),
            validation_sample_ids: samples(validation),
            metadata: BTreeMap::new(),
        }
    }

    /// Two complementary folds over samples `s1`..`s4`.
    fn fold_set() -> FoldSet {
        FoldSet {
            id: "outer".to_string(),
            sample_ids: samples(&["s1", "s2", "s3", "s4"]),
            folds: vec![
                fold("fold:0", &["s3", "s4"], &["s1", "s2"]),
                fold("fold:1", &["s1", "s2"], &["s3", "s4"]),
            ],
            sample_groups: BTreeMap::new(),
        }
    }

    fn relation(observation: &str, sample: &str, target: &str, group: &str) -> SampleRelation {
        SampleRelation {
            target_id: Some(tid(target)),
            group_id: Some(gid(group)),
            ..SampleRelation::new(oid(observation), sid(sample))
        }
    }

    fn source_relation(observation: &str, sample: &str, source: &str, rep: &str) -> SampleRelation {
        SampleRelation {
            source_id: Some(source.to_string()),
            rep_id: Some(rep.to_string()),
            ..relation(observation, sample, "target:sample", "group:sample")
        }
    }

    fn set_of(records: Vec<SampleRelation>) -> SampleRelationSet {
        SampleRelationSet { records }
    }

    /// One sample with two observations plus three single-observation samples.
    fn repeated_observation_records() -> Vec<SampleRelation> {
        vec![
            relation("obs:1a", "s1", "t1", "g1"),
            relation("obs:1b", "s1", "t1", "g1"),
            relation("obs:2a", "s2", "t2", "g2"),
            relation("obs:3a", "s3", "t3", "g3"),
            relation("obs:4a", "s4", "t4", "g4"),
        ]
    }

    // --- tests -------------------------------------------------------------

    #[test]
    fn repeated_observations_validate_at_sample_split_unit() {
        let relations = set_of(repeated_observation_records());
        let policy = LeakageUnitPolicy::default();

        relations
            .validate_against_fold_set(&fold_set(), &policy)
            .unwrap();
    }

    #[test]
    fn repeated_observations_validate_at_physical_sample_split_unit() {
        let relations = set_of(repeated_observation_records());
        let policy = LeakageUnitPolicy {
            split_unit: SplitUnit::PhysicalSample,
            ..LeakageUnitPolicy::default()
        };

        relations
            .validate_against_fold_set(&fold_set(), &policy)
            .unwrap();
    }

    #[test]
    fn asymmetric_multisource_repetitions_and_combo_validate_as_relations() {
        let combo = SampleRelation {
            unit_level: EntityUnitLevel::Combo,
            source_id: Some("combo".to_string()),
            derived_unit_id: Some("combo:s1:a0:b0:c0".to_string()),
            origin_sample_id: Some(sid("s1")),
            component_observation_ids: vec![
                oid("obs:s1.A.0"),
                oid("obs:s1.B.0"),
                oid("obs:s1.C.0"),
            ],
            sample_influence_weight: Some(1.0),
            quality_flag: Some("ok".to_string()),
            ..relation(
                "obs:s1.combo.a0.b0.c0",
                "s1",
                "target:sample",
                "group:sample",
            )
        };

        let relations = set_of(vec![
            source_relation("obs:s1.A.0", "s1", "A", "rep:0"),
            source_relation("obs:s1.A.1", "s1", "A", "rep:1"),
            source_relation("obs:s1.B.0", "s1", "B", "rep:0"),
            source_relation("obs:s1.B.1", "s1", "B", "rep:1"),
            source_relation("obs:s1.B.2", "s1", "B", "rep:2"),
            source_relation("obs:s1.C.0", "s1", "C", "rep:0"),
            source_relation("obs:s1.C.1", "s1", "C", "rep:1"),
            combo,
        ]);

        relations.validate().unwrap();
        let expected_sample = sid("s1");
        assert_eq!(
            relations.sample_for_observation(&oid("obs:s1.combo.a0.b0.c0")),
            Some(&expected_sample)
        );
    }

    #[test]
    fn combo_components_cannot_cross_sample_boundary() {
        let combo = SampleRelation {
            unit_level: EntityUnitLevel::Combo,
            derived_unit_id: Some("combo:s1".to_string()),
            component_observation_ids: vec![oid("obs:s1.A.0"), oid("obs:s2.B.0")],
            ..relation("obs:s1.combo", "s1", "target:sample", "group:sample")
        };

        let relations = set_of(vec![
            source_relation("obs:s1.A.0", "s1", "A", "rep:0"),
            source_relation("obs:s2.B.0", "s2", "B", "rep:0"),
            combo,
        ]);

        assert!(relations.validate().is_err());
    }

    #[test]
    fn relation_fingerprint_is_order_stable_and_provenance_sensitive() {
        let first = source_relation("obs:s1.A.0", "s1", "A", "rep:0");
        let second = source_relation("obs:s1.B.0", "s1", "B", "rep:0");

        let left = set_of(vec![first.clone(), second.clone()]);
        let right = set_of(vec![second, first]);
        assert_eq!(left.fingerprint().unwrap(), right.fingerprint().unwrap());

        let mut changed = left.clone();
        changed.records[0].rep_id = Some("rep:1".to_string());
        assert_ne!(left.fingerprint().unwrap(), changed.fingerprint().unwrap());
    }

    #[test]
    fn old_relation_json_defaults_to_observation_unit() {
        let legacy_json = serde_json::json!({
            "observation_id": "obs:legacy",
            "sample_id": "s1",
            "target_id": "t1",
            "group_id": "g1",
            "source_id": "legacy",
            "is_augmented": false
        });
        let relation: SampleRelation = serde_json::from_value(legacy_json).unwrap();

        assert_eq!(relation.unit_level, EntityUnitLevel::Observation);
        assert!(relation.rep_id.is_none());
        assert!(relation.component_observation_ids.is_empty());

        set_of(vec![relation]).validate().unwrap();
    }

    #[test]
    fn relation_validation_rejects_invalid_new_fields() {
        // Each step repairs the previous defect and introduces the next one.
        let mut candidate = source_relation("obs:s1.A.0", "s1", "A", "rep/0");
        assert!(candidate.validate().is_err());

        candidate.rep_id = Some("rep:0".to_string());
        candidate.sample_influence_weight = Some(0.0);
        assert!(candidate.validate().is_err());

        candidate.sample_influence_weight = Some(1.0);
        candidate.quality_flag = Some(" ".to_string());
        assert!(candidate.validate().is_err());
    }

    #[test]
    fn target_split_refuses_shared_target_across_fold_boundary() {
        let relations = set_of(vec![
            relation("obs:1", "s1", "same_target", "g1"),
            relation("obs:2", "s2", "t2", "g2"),
            relation("obs:3", "s3", "same_target", "g3"),
            relation("obs:4", "s4", "t4", "g4"),
        ]);
        let policy = LeakageUnitPolicy {
            split_unit: SplitUnit::Target,
            ..LeakageUnitPolicy::default()
        };

        let outcome = relations.validate_against_fold_set(&fold_set(), &policy);
        assert!(outcome.is_err());
    }

    #[test]
    fn augmentation_origin_cannot_cross_train_validation_boundary() {
        let generated = SampleRelation {
            origin_sample_id: Some(sid("s1")),
            is_augmented: true,
            ..relation("obs:aug", "s3", "t3", "g3")
        };
        let relations = set_of(vec![
            relation("obs:1", "s1", "t1", "g1"),
            relation("obs:2", "s2", "t2", "g2"),
            generated,
            relation("obs:4", "s4", "t4", "g4"),
        ]);

        let outcome =
            relations.validate_against_fold_set(&fold_set(), &LeakageUnitPolicy::default());
        assert!(outcome.is_err());
    }
}