medmodels_core/medrecord/
group_mapping.rs

1use super::{EdgeIndex, MedRecordAttribute, NodeIndex};
2use crate::errors::MedRecordError;
3use medmodels_utils::aliases::{MrHashMap, MrHashMapEntry, MrHashSet};
4use serde::{Deserialize, Serialize};
5
6pub type Group = MedRecordAttribute;
7
8#[derive(Debug, Serialize, Deserialize, Clone)]
9pub(super) struct GroupMapping {
10    nodes_in_group: MrHashMap<Group, MrHashSet<NodeIndex>>,
11    edges_in_group: MrHashMap<Group, MrHashSet<EdgeIndex>>,
12    groups_of_node: MrHashMap<NodeIndex, MrHashSet<Group>>,
13    groups_of_edge: MrHashMap<EdgeIndex, MrHashSet<Group>>,
14}
15
16impl GroupMapping {
17    pub fn new() -> Self {
18        Self {
19            nodes_in_group: MrHashMap::new(),
20            edges_in_group: MrHashMap::new(),
21            groups_of_node: MrHashMap::new(),
22            groups_of_edge: MrHashMap::new(),
23        }
24    }
25
26    pub fn add_group(
27        &mut self,
28        group: Group,
29        node_indices: Option<Vec<NodeIndex>>,
30        edge_indices: Option<Vec<EdgeIndex>>,
31    ) -> Result<(), MedRecordError> {
32        match self.nodes_in_group.entry(group.clone()) {
33            MrHashMapEntry::Occupied(o) => Err(MedRecordError::AssertionError(format!(
34                "Group {} already exists",
35                o.key()
36            ))),
37            MrHashMapEntry::Vacant(v) => {
38                v.insert(MrHashSet::from_iter(
39                    node_indices.clone().unwrap_or_default().into_iter(),
40                ));
41                Ok(())
42            }
43        }?;
44
45        match self.edges_in_group.entry(group.clone()) {
46            MrHashMapEntry::Occupied(o) => Err(MedRecordError::AssertionError(format!(
47                "Group {} already exists",
48                o.key()
49            ))),
50            MrHashMapEntry::Vacant(v) => {
51                v.insert(MrHashSet::from_iter(
52                    edge_indices.clone().unwrap_or_default().into_iter(),
53                ));
54                Ok(())
55            }
56        }?;
57
58        match (node_indices, edge_indices) {
59            (None, None) => (),
60            (None, Some(edge_indices)) => {
61                for edge_index in edge_indices {
62                    self.groups_of_edge
63                        .entry(edge_index)
64                        .or_default()
65                        .insert(group.clone());
66                }
67            }
68            (Some(node_indices), None) => {
69                for node_index in node_indices {
70                    self.groups_of_node
71                        .entry(node_index)
72                        .or_default()
73                        .insert(group.clone());
74                }
75            }
76            (Some(node_indices), Some(edge_indices)) => {
77                for node_index in node_indices {
78                    self.groups_of_node
79                        .entry(node_index)
80                        .or_default()
81                        .insert(group.clone());
82                }
83
84                for edge_index in edge_indices {
85                    self.groups_of_edge
86                        .entry(edge_index)
87                        .or_default()
88                        .insert(group.clone());
89                }
90            }
91        };
92        Ok(())
93    }
94
95    pub fn add_node_to_group(
96        &mut self,
97        group: Group,
98        node_index: NodeIndex,
99    ) -> Result<(), MedRecordError> {
100        let nodes_in_group =
101            self.nodes_in_group
102                .get_mut(&group)
103                .ok_or(MedRecordError::IndexError(format!(
104                    "Cannot find group {}",
105                    group
106                )))?;
107
108        if !nodes_in_group.insert(node_index.clone()) {
109            return Err(MedRecordError::AssertionError(format!(
110                "Node with index {} already in group {}",
111                node_index, group
112            )));
113        }
114
115        self.groups_of_node
116            .entry(node_index)
117            .or_default()
118            .insert(group);
119
120        Ok(())
121    }
122
123    pub fn add_edge_to_group(
124        &mut self,
125        group: Group,
126        edge_index: EdgeIndex,
127    ) -> Result<(), MedRecordError> {
128        let edges_in_group =
129            self.edges_in_group
130                .get_mut(&group)
131                .ok_or(MedRecordError::IndexError(format!(
132                    "Cannot find group {}",
133                    group
134                )))?;
135
136        if !edges_in_group.insert(edge_index) {
137            return Err(MedRecordError::AssertionError(format!(
138                "Edge with index {} already in group {}",
139                edge_index, group
140            )));
141        }
142
143        self.groups_of_edge
144            .entry(edge_index)
145            .or_default()
146            .insert(group);
147
148        Ok(())
149    }
150
151    pub fn remove_group(&mut self, group: &Group) -> Result<(), MedRecordError> {
152        let nodes_in_group =
153            self.nodes_in_group
154                .remove(group)
155                .ok_or(MedRecordError::IndexError(format!(
156                    "Cannot find group {}",
157                    group
158                )))?;
159
160        for node in nodes_in_group {
161            self.groups_of_node
162                .get_mut(&node)
163                .expect("Node must exist")
164                .remove(group);
165        }
166
167        Ok(())
168    }
169
170    pub fn remove_node(&mut self, node_index: &NodeIndex) {
171        let groups_of_node = self.groups_of_node.remove(node_index);
172
173        let Some(groups_of_node) = groups_of_node else {
174            return;
175        };
176
177        for group in groups_of_node {
178            self.nodes_in_group
179                .get_mut(&group)
180                .expect("Group must exist")
181                .remove(node_index);
182        }
183    }
184
185    pub fn remove_edge(&mut self, edge_index: &EdgeIndex) {
186        let groups_of_edge = self.groups_of_edge.remove(edge_index);
187
188        let Some(groups_of_edge) = groups_of_edge else {
189            return;
190        };
191
192        for group in groups_of_edge {
193            self.edges_in_group
194                .get_mut(&group)
195                .expect("Group must exist")
196                .remove(edge_index);
197        }
198    }
199
200    pub fn remove_node_from_group(
201        &mut self,
202        group: &Group,
203        node_index: &NodeIndex,
204    ) -> Result<(), MedRecordError> {
205        let nodes_in_group =
206            self.nodes_in_group
207                .get_mut(group)
208                .ok_or(MedRecordError::IndexError(format!(
209                    "Cannot find group {}",
210                    group
211                )))?;
212
213        nodes_in_group
214            .remove(node_index)
215            .then_some(())
216            .ok_or(MedRecordError::AssertionError(format!(
217                "Node with index {} not in group {}",
218                node_index, group
219            )))
220    }
221
222    pub fn remove_edge_from_group(
223        &mut self,
224        group: &Group,
225        edge_index: &EdgeIndex,
226    ) -> Result<(), MedRecordError> {
227        let edges_in_group =
228            self.edges_in_group
229                .get_mut(group)
230                .ok_or(MedRecordError::IndexError(format!(
231                    "Cannot find group {}",
232                    group
233                )))?;
234
235        edges_in_group
236            .remove(edge_index)
237            .then_some(())
238            .ok_or(MedRecordError::AssertionError(format!(
239                "Edge with index {} not in group {}",
240                edge_index, group
241            )))
242    }
243
244    pub fn groups(&self) -> impl Iterator<Item = &Group> {
245        self.nodes_in_group.keys()
246    }
247
248    pub fn nodes_in_group(
249        &self,
250        group: &Group,
251    ) -> Result<impl Iterator<Item = &NodeIndex>, MedRecordError> {
252        Ok(self
253            .nodes_in_group
254            .get(group)
255            .ok_or(MedRecordError::IndexError(format!(
256                "Cannot find group {}",
257                group
258            )))?
259            .iter())
260    }
261
262    pub fn edges_in_group(
263        &self,
264        group: &Group,
265    ) -> Result<impl Iterator<Item = &EdgeIndex>, MedRecordError> {
266        Ok(self
267            .edges_in_group
268            .get(group)
269            .ok_or(MedRecordError::IndexError(format!(
270                "Cannot find group {}",
271                group
272            )))?
273            .iter())
274    }
275
276    pub fn groups_of_node(&self, node_index: &NodeIndex) -> impl Iterator<Item = &Group> {
277        self.groups_of_node.get(node_index).into_iter().flatten()
278    }
279
280    pub fn groups_of_edge(&self, edge_index: &EdgeIndex) -> impl Iterator<Item = &Group> {
281        self.groups_of_edge.get(edge_index).into_iter().flatten()
282    }
283
284    pub fn group_count(&self) -> usize {
285        self.nodes_in_group.len()
286    }
287
288    pub fn contains_group(&self, group: &Group) -> bool {
289        self.nodes_in_group.contains_key(group)
290    }
291
292    pub fn clear(&mut self) {
293        self.nodes_in_group.clear();
294        self.edges_in_group.clear();
295        self.groups_of_node.clear();
296        self.groups_of_edge.clear();
297    }
298}
299
#[cfg(test)]
mod test {
    use super::GroupMapping;
    use crate::errors::MedRecordError;

    /// Number of nodes in `group`; panics if the group is missing.
    fn node_count(mapping: &GroupMapping, group: &'static str) -> usize {
        mapping.nodes_in_group(&group.into()).unwrap().count()
    }

    /// Number of edges in `group`; panics if the group is missing.
    fn edge_count(mapping: &GroupMapping, group: &'static str) -> usize {
        mapping.edges_in_group(&group.into()).unwrap().count()
    }

    #[test]
    fn test_add_group() {
        let mut mapping = GroupMapping::new();
        assert_eq!(mapping.group_count(), 0);

        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(mapping.group_count(), 1);

        mapping
            .add_group(
                "1".into(),
                Some(vec!["0".into(), "1".into()]),
                Some(vec![0, 1]),
            )
            .unwrap();

        assert_eq!(mapping.group_count(), 2);
        assert_eq!(node_count(&mapping, "1"), 2);
        assert_eq!(edge_count(&mapping, "1"), 2);
    }

    #[test]
    fn test_invalid_add_group() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();

        // Registering a name that is already taken is rejected.
        assert!(matches!(
            mapping.add_group("0".into(), None, None),
            Err(MedRecordError::AssertionError(_))
        ));
    }

    #[test]
    fn test_add_node_to_group() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(node_count(&mapping, "0"), 0);

        mapping.add_node_to_group("0".into(), "0".into()).unwrap();
        assert_eq!(node_count(&mapping, "0"), 1);
    }

    #[test]
    fn test_invalid_add_node_to_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into()]), None)
            .unwrap();

        // Unknown group.
        assert!(matches!(
            mapping.add_node_to_group("50".into(), "1".into()),
            Err(MedRecordError::IndexError(_))
        ));

        // Node is already a member.
        assert!(matches!(
            mapping.add_node_to_group("0".into(), "0".into()),
            Err(MedRecordError::AssertionError(_))
        ));
    }

    #[test]
    fn test_add_edge_to_group() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(edge_count(&mapping, "0"), 0);

        mapping.add_edge_to_group("0".into(), 0).unwrap();
        assert_eq!(edge_count(&mapping, "0"), 1);
    }

    #[test]
    fn test_invalid_add_edge_to_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0]))
            .unwrap();

        // Unknown group.
        assert!(matches!(
            mapping.add_edge_to_group("50".into(), 1),
            Err(MedRecordError::IndexError(_))
        ));

        // Edge is already a member.
        assert!(matches!(
            mapping.add_edge_to_group("0".into(), 0),
            Err(MedRecordError::AssertionError(_))
        ));
    }

    #[test]
    fn test_remove_group() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(mapping.group_count(), 1);

        mapping.remove_group(&"0".into()).unwrap();
        assert_eq!(mapping.group_count(), 0);
    }

    #[test]
    fn test_invalid_remove_group() {
        let mut mapping = GroupMapping::new();

        // Nothing to remove.
        assert!(matches!(
            mapping.remove_group(&"0".into()),
            Err(MedRecordError::IndexError(_))
        ));
    }

    #[test]
    fn test_remove_node() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into()]), None)
            .unwrap();
        assert_eq!(node_count(&mapping, "0"), 1);

        mapping.remove_node(&"0".into());
        assert_eq!(node_count(&mapping, "0"), 0);
    }

    #[test]
    fn test_remove_edge() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0]))
            .unwrap();
        assert_eq!(edge_count(&mapping, "0"), 1);

        mapping.remove_edge(&0);
        assert_eq!(edge_count(&mapping, "0"), 0);
    }

    #[test]
    fn test_remove_node_from_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into(), "1".into()]), None)
            .unwrap();
        assert_eq!(node_count(&mapping, "0"), 2);

        mapping
            .remove_node_from_group(&"0".into(), &"0".into())
            .unwrap();
        assert_eq!(node_count(&mapping, "0"), 1);
    }

    #[test]
    fn test_invalid_remove_node_from_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into()]), None)
            .unwrap();

        // Unknown group.
        assert!(matches!(
            mapping.remove_node_from_group(&"50".into(), &"0".into()),
            Err(MedRecordError::IndexError(_))
        ));

        // Node is not a member.
        assert!(matches!(
            mapping.remove_node_from_group(&"0".into(), &"50".into()),
            Err(MedRecordError::AssertionError(_))
        ));
    }

    #[test]
    fn test_remove_edge_from_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0, 1]))
            .unwrap();
        assert_eq!(edge_count(&mapping, "0"), 2);

        mapping
            .remove_edge_from_group(&"0".into(), &0)
            .unwrap();
        assert_eq!(edge_count(&mapping, "0"), 1);
    }

    #[test]
    fn test_invalid_remove_edge_from_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0]))
            .unwrap();

        // Unknown group.
        assert!(matches!(
            mapping.remove_edge_from_group(&"50".into(), &0),
            Err(MedRecordError::IndexError(_))
        ));

        // Edge is not a member.
        assert!(matches!(
            mapping.remove_edge_from_group(&"0".into(), &50),
            Err(MedRecordError::AssertionError(_))
        ));
    }

    #[test]
    fn test_groups() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();

        assert_eq!(mapping.groups().count(), 1);
    }

    #[test]
    fn test_nodes_in_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into(), "1".into()]), None)
            .unwrap();

        assert_eq!(node_count(&mapping, "0"), 2);
    }

    #[test]
    fn test_invalid_nodes_in_group() {
        let mapping = GroupMapping::new();

        // Querying an unknown group fails.
        assert!(matches!(
            mapping.nodes_in_group(&"0".into()),
            Err(MedRecordError::IndexError(_))
        ));
    }

    #[test]
    fn test_edges_in_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0, 1]))
            .unwrap();

        assert_eq!(edge_count(&mapping, "0"), 2);
    }

    #[test]
    fn test_invalid_edges_in_group() {
        let mapping = GroupMapping::new();

        // Querying an unknown group fails.
        assert!(matches!(
            mapping.edges_in_group(&"0".into()),
            Err(MedRecordError::IndexError(_))
        ));
    }

    #[test]
    fn test_groups_of_node() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into()]), None)
            .unwrap();

        assert_eq!(mapping.groups_of_node(&"0".into()).count(), 1);
    }

    #[test]
    fn test_groups_of_edge() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0]))
            .unwrap();

        assert_eq!(mapping.groups_of_edge(&0).count(), 1);
    }

    #[test]
    fn test_group_count() {
        let mut mapping = GroupMapping::new();
        assert_eq!(mapping.group_count(), 0);

        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(mapping.group_count(), 1);
    }

    #[test]
    fn test_contains_group() {
        let mut mapping = GroupMapping::new();
        assert!(!mapping.contains_group(&"0".into()));

        mapping.add_group("0".into(), None, None).unwrap();
        assert!(mapping.contains_group(&"0".into()));
    }

    #[test]
    fn test_clear() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(mapping.group_count(), 1);

        mapping.clear();
        assert_eq!(mapping.group_count(), 0);
    }
}