1use super::{EdgeIndex, MedRecordAttribute, NodeIndex};
2use crate::errors::MedRecordError;
3use medmodels_utils::aliases::{MrHashMap, MrHashMapEntry, MrHashSet};
4use serde::{Deserialize, Serialize};
5
6pub type Group = MedRecordAttribute;
7
8#[derive(Debug, Serialize, Deserialize, Clone)]
9pub(super) struct GroupMapping {
10 nodes_in_group: MrHashMap<Group, MrHashSet<NodeIndex>>,
11 edges_in_group: MrHashMap<Group, MrHashSet<EdgeIndex>>,
12 groups_of_node: MrHashMap<NodeIndex, MrHashSet<Group>>,
13 groups_of_edge: MrHashMap<EdgeIndex, MrHashSet<Group>>,
14}
15
16impl GroupMapping {
17 pub fn new() -> Self {
18 Self {
19 nodes_in_group: MrHashMap::new(),
20 edges_in_group: MrHashMap::new(),
21 groups_of_node: MrHashMap::new(),
22 groups_of_edge: MrHashMap::new(),
23 }
24 }
25
26 pub fn add_group(
27 &mut self,
28 group: Group,
29 node_indices: Option<Vec<NodeIndex>>,
30 edge_indices: Option<Vec<EdgeIndex>>,
31 ) -> Result<(), MedRecordError> {
32 match self.nodes_in_group.entry(group.clone()) {
33 MrHashMapEntry::Occupied(o) => Err(MedRecordError::AssertionError(format!(
34 "Group {} already exists",
35 o.key()
36 ))),
37 MrHashMapEntry::Vacant(v) => {
38 v.insert(MrHashSet::from_iter(
39 node_indices.clone().unwrap_or_default().into_iter(),
40 ));
41 Ok(())
42 }
43 }?;
44
45 match self.edges_in_group.entry(group.clone()) {
46 MrHashMapEntry::Occupied(o) => Err(MedRecordError::AssertionError(format!(
47 "Group {} already exists",
48 o.key()
49 ))),
50 MrHashMapEntry::Vacant(v) => {
51 v.insert(MrHashSet::from_iter(
52 edge_indices.clone().unwrap_or_default().into_iter(),
53 ));
54 Ok(())
55 }
56 }?;
57
58 match (node_indices, edge_indices) {
59 (None, None) => (),
60 (None, Some(edge_indices)) => {
61 for edge_index in edge_indices {
62 self.groups_of_edge
63 .entry(edge_index)
64 .or_default()
65 .insert(group.clone());
66 }
67 }
68 (Some(node_indices), None) => {
69 for node_index in node_indices {
70 self.groups_of_node
71 .entry(node_index)
72 .or_default()
73 .insert(group.clone());
74 }
75 }
76 (Some(node_indices), Some(edge_indices)) => {
77 for node_index in node_indices {
78 self.groups_of_node
79 .entry(node_index)
80 .or_default()
81 .insert(group.clone());
82 }
83
84 for edge_index in edge_indices {
85 self.groups_of_edge
86 .entry(edge_index)
87 .or_default()
88 .insert(group.clone());
89 }
90 }
91 };
92 Ok(())
93 }
94
95 pub fn add_node_to_group(
96 &mut self,
97 group: Group,
98 node_index: NodeIndex,
99 ) -> Result<(), MedRecordError> {
100 let nodes_in_group =
101 self.nodes_in_group
102 .get_mut(&group)
103 .ok_or(MedRecordError::IndexError(format!(
104 "Cannot find group {}",
105 group
106 )))?;
107
108 if !nodes_in_group.insert(node_index.clone()) {
109 return Err(MedRecordError::AssertionError(format!(
110 "Node with index {} already in group {}",
111 node_index, group
112 )));
113 }
114
115 self.groups_of_node
116 .entry(node_index)
117 .or_default()
118 .insert(group);
119
120 Ok(())
121 }
122
123 pub fn add_edge_to_group(
124 &mut self,
125 group: Group,
126 edge_index: EdgeIndex,
127 ) -> Result<(), MedRecordError> {
128 let edges_in_group =
129 self.edges_in_group
130 .get_mut(&group)
131 .ok_or(MedRecordError::IndexError(format!(
132 "Cannot find group {}",
133 group
134 )))?;
135
136 if !edges_in_group.insert(edge_index) {
137 return Err(MedRecordError::AssertionError(format!(
138 "Edge with index {} already in group {}",
139 edge_index, group
140 )));
141 }
142
143 self.groups_of_edge
144 .entry(edge_index)
145 .or_default()
146 .insert(group);
147
148 Ok(())
149 }
150
151 pub fn remove_group(&mut self, group: &Group) -> Result<(), MedRecordError> {
152 let nodes_in_group =
153 self.nodes_in_group
154 .remove(group)
155 .ok_or(MedRecordError::IndexError(format!(
156 "Cannot find group {}",
157 group
158 )))?;
159
160 for node in nodes_in_group {
161 self.groups_of_node
162 .get_mut(&node)
163 .expect("Node must exist")
164 .remove(group);
165 }
166
167 Ok(())
168 }
169
170 pub fn remove_node(&mut self, node_index: &NodeIndex) {
171 let groups_of_node = self.groups_of_node.remove(node_index);
172
173 let Some(groups_of_node) = groups_of_node else {
174 return;
175 };
176
177 for group in groups_of_node {
178 self.nodes_in_group
179 .get_mut(&group)
180 .expect("Group must exist")
181 .remove(node_index);
182 }
183 }
184
185 pub fn remove_edge(&mut self, edge_index: &EdgeIndex) {
186 let groups_of_edge = self.groups_of_edge.remove(edge_index);
187
188 let Some(groups_of_edge) = groups_of_edge else {
189 return;
190 };
191
192 for group in groups_of_edge {
193 self.edges_in_group
194 .get_mut(&group)
195 .expect("Group must exist")
196 .remove(edge_index);
197 }
198 }
199
200 pub fn remove_node_from_group(
201 &mut self,
202 group: &Group,
203 node_index: &NodeIndex,
204 ) -> Result<(), MedRecordError> {
205 let nodes_in_group =
206 self.nodes_in_group
207 .get_mut(group)
208 .ok_or(MedRecordError::IndexError(format!(
209 "Cannot find group {}",
210 group
211 )))?;
212
213 nodes_in_group
214 .remove(node_index)
215 .then_some(())
216 .ok_or(MedRecordError::AssertionError(format!(
217 "Node with index {} not in group {}",
218 node_index, group
219 )))
220 }
221
222 pub fn remove_edge_from_group(
223 &mut self,
224 group: &Group,
225 edge_index: &EdgeIndex,
226 ) -> Result<(), MedRecordError> {
227 let edges_in_group =
228 self.edges_in_group
229 .get_mut(group)
230 .ok_or(MedRecordError::IndexError(format!(
231 "Cannot find group {}",
232 group
233 )))?;
234
235 edges_in_group
236 .remove(edge_index)
237 .then_some(())
238 .ok_or(MedRecordError::AssertionError(format!(
239 "Edge with index {} not in group {}",
240 edge_index, group
241 )))
242 }
243
244 pub fn groups(&self) -> impl Iterator<Item = &Group> {
245 self.nodes_in_group.keys()
246 }
247
248 pub fn nodes_in_group(
249 &self,
250 group: &Group,
251 ) -> Result<impl Iterator<Item = &NodeIndex>, MedRecordError> {
252 Ok(self
253 .nodes_in_group
254 .get(group)
255 .ok_or(MedRecordError::IndexError(format!(
256 "Cannot find group {}",
257 group
258 )))?
259 .iter())
260 }
261
262 pub fn edges_in_group(
263 &self,
264 group: &Group,
265 ) -> Result<impl Iterator<Item = &EdgeIndex>, MedRecordError> {
266 Ok(self
267 .edges_in_group
268 .get(group)
269 .ok_or(MedRecordError::IndexError(format!(
270 "Cannot find group {}",
271 group
272 )))?
273 .iter())
274 }
275
276 pub fn groups_of_node(&self, node_index: &NodeIndex) -> impl Iterator<Item = &Group> {
277 self.groups_of_node.get(node_index).into_iter().flatten()
278 }
279
280 pub fn groups_of_edge(&self, edge_index: &EdgeIndex) -> impl Iterator<Item = &Group> {
281 self.groups_of_edge.get(edge_index).into_iter().flatten()
282 }
283
284 pub fn group_count(&self) -> usize {
285 self.nodes_in_group.len()
286 }
287
288 pub fn contains_group(&self, group: &Group) -> bool {
289 self.nodes_in_group.contains_key(group)
290 }
291
292 pub fn clear(&mut self) {
293 self.nodes_in_group.clear();
294 self.edges_in_group.clear();
295 self.groups_of_node.clear();
296 self.groups_of_edge.clear();
297 }
298}
299
#[cfg(test)]
mod test {
    use super::GroupMapping;
    use crate::errors::MedRecordError;

    /// Number of nodes currently stored in the group called `name`.
    fn node_count(mapping: &GroupMapping, name: &'static str) -> usize {
        mapping.nodes_in_group(&name.into()).unwrap().count()
    }

    /// Number of edges currently stored in the group called `name`.
    fn edge_count(mapping: &GroupMapping, name: &'static str) -> usize {
        mapping.edges_in_group(&name.into()).unwrap().count()
    }

    /// `true` when `result` failed with an `IndexError`.
    fn is_index_error<T>(result: Result<T, MedRecordError>) -> bool {
        matches!(result, Err(MedRecordError::IndexError(_)))
    }

    /// `true` when `result` failed with an `AssertionError`.
    fn is_assertion_error<T>(result: Result<T, MedRecordError>) -> bool {
        matches!(result, Err(MedRecordError::AssertionError(_)))
    }

    #[test]
    fn test_add_group() {
        let mut mapping = GroupMapping::new();
        assert_eq!(mapping.group_count(), 0);

        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(mapping.group_count(), 1);

        let nodes = vec!["0".into(), "1".into()];
        let edges = vec![0, 1];
        mapping
            .add_group("1".into(), Some(nodes), Some(edges))
            .unwrap();

        assert_eq!(mapping.group_count(), 2);
        assert_eq!(node_count(&mapping, "1"), 2);
        assert_eq!(edge_count(&mapping, "1"), 2);
    }

    #[test]
    fn test_invalid_add_group() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();

        // Adding the same group twice must be rejected.
        let duplicate = mapping.add_group("0".into(), None, None);
        assert!(is_assertion_error(duplicate));
    }

    #[test]
    fn test_add_node_to_group() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(node_count(&mapping, "0"), 0);

        mapping.add_node_to_group("0".into(), "0".into()).unwrap();
        assert_eq!(node_count(&mapping, "0"), 1);
    }

    #[test]
    fn test_invalid_add_node_to_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into()]), None)
            .unwrap();

        // Unknown group.
        let unknown_group = mapping.add_node_to_group("50".into(), "1".into());
        assert!(is_index_error(unknown_group));

        // Node is already a member.
        let already_member = mapping.add_node_to_group("0".into(), "0".into());
        assert!(is_assertion_error(already_member));
    }

    #[test]
    fn test_add_edge_to_group() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(edge_count(&mapping, "0"), 0);

        mapping.add_edge_to_group("0".into(), 0).unwrap();
        assert_eq!(edge_count(&mapping, "0"), 1);
    }

    #[test]
    fn test_invalid_add_edge_to_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0]))
            .unwrap();

        // Unknown group.
        let unknown_group = mapping.add_edge_to_group("50".into(), 1);
        assert!(is_index_error(unknown_group));

        // Edge is already a member.
        let already_member = mapping.add_edge_to_group("0".into(), 0);
        assert!(is_assertion_error(already_member));
    }

    #[test]
    fn test_remove_group() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(mapping.group_count(), 1);

        mapping.remove_group(&"0".into()).unwrap();
        assert_eq!(mapping.group_count(), 0);
    }

    #[test]
    fn test_invalid_remove_group() {
        let mut mapping = GroupMapping::new();

        let missing = mapping.remove_group(&"0".into());
        assert!(is_index_error(missing));
    }

    #[test]
    fn test_remove_node() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into()]), None)
            .unwrap();
        assert_eq!(node_count(&mapping, "0"), 1);

        mapping.remove_node(&"0".into());
        assert_eq!(node_count(&mapping, "0"), 0);
    }

    #[test]
    fn test_remove_edge() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0]))
            .unwrap();
        assert_eq!(edge_count(&mapping, "0"), 1);

        mapping.remove_edge(&0);
        assert_eq!(edge_count(&mapping, "0"), 0);
    }

    #[test]
    fn test_remove_node_from_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into(), "1".into()]), None)
            .unwrap();
        assert_eq!(node_count(&mapping, "0"), 2);

        mapping
            .remove_node_from_group(&"0".into(), &"0".into())
            .unwrap();
        assert_eq!(node_count(&mapping, "0"), 1);
    }

    #[test]
    fn test_invalid_remove_node_from_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into()]), None)
            .unwrap();

        // Unknown group.
        let unknown_group = mapping.remove_node_from_group(&"50".into(), &"0".into());
        assert!(is_index_error(unknown_group));

        // Node is not a member of the group.
        let not_member = mapping.remove_node_from_group(&"0".into(), &"50".into());
        assert!(is_assertion_error(not_member));
    }

    #[test]
    fn test_remove_edge_from_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0, 1]))
            .unwrap();
        assert_eq!(edge_count(&mapping, "0"), 2);

        mapping.remove_edge_from_group(&"0".into(), &0).unwrap();
        assert_eq!(edge_count(&mapping, "0"), 1);
    }

    #[test]
    fn test_invalid_remove_edge_from_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0]))
            .unwrap();

        // Unknown group.
        let unknown_group = mapping.remove_edge_from_group(&"50".into(), &0);
        assert!(is_index_error(unknown_group));

        // Edge is not a member of the group.
        let not_member = mapping.remove_edge_from_group(&"0".into(), &50);
        assert!(is_assertion_error(not_member));
    }

    #[test]
    fn test_groups() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();

        assert_eq!(mapping.groups().count(), 1);
    }

    #[test]
    fn test_nodes_in_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into(), "1".into()]), None)
            .unwrap();

        assert_eq!(node_count(&mapping, "0"), 2);
    }

    #[test]
    fn test_invalid_nodes_in_group() {
        let mapping = GroupMapping::new();

        assert!(is_index_error(mapping.nodes_in_group(&"0".into())));
    }

    #[test]
    fn test_edges_in_group() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0, 1]))
            .unwrap();

        assert_eq!(edge_count(&mapping, "0"), 2);
    }

    #[test]
    fn test_invalid_edges_in_group() {
        let mapping = GroupMapping::new();

        assert!(is_index_error(mapping.edges_in_group(&"0".into())));
    }

    #[test]
    fn test_groups_of_node() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), Some(vec!["0".into()]), None)
            .unwrap();

        assert_eq!(mapping.groups_of_node(&"0".into()).count(), 1);
    }

    #[test]
    fn test_groups_of_edge() {
        let mut mapping = GroupMapping::new();
        mapping
            .add_group("0".into(), None, Some(vec![0]))
            .unwrap();

        assert_eq!(mapping.groups_of_edge(&0).count(), 1);
    }

    #[test]
    fn test_group_count() {
        let mut mapping = GroupMapping::new();
        assert_eq!(mapping.group_count(), 0);

        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(mapping.group_count(), 1);
    }

    #[test]
    fn test_contains_group() {
        let mut mapping = GroupMapping::new();
        assert!(!mapping.contains_group(&"0".into()));

        mapping.add_group("0".into(), None, None).unwrap();
        assert!(mapping.contains_group(&"0".into()));
    }

    #[test]
    fn test_clear() {
        let mut mapping = GroupMapping::new();
        mapping.add_group("0".into(), None, None).unwrap();
        assert_eq!(mapping.group_count(), 1);

        mapping.clear();
        assert_eq!(mapping.group_count(), 0);
    }
}