1use std::{collections::HashSet, hash::Hash};
2
3use thiserror::Error;
4
/// A description of resource claims that can be merged together and checked
/// against one another for conflicts.
///
/// Implementors must also provide a [`Default`] value.
pub trait Resources: Default {
    /// Merges the claims described by `other` into `self`.
    fn union(&mut self, other: &Self);

    /// Reports whether the claims in `self` conflict with the claims in
    /// `other`.
    fn conflicts_with(&self, other: &Self) -> bool;
}
13
14#[derive(Debug, Error)]
15#[error("resource conflict in {type_name:?}")]
16pub struct ResourceConflict {
17 pub type_name: &'static str,
18}
19
/// A pair of resource sets keyed by `R`: resources claimed for reading and
/// resources claimed for writing.
///
/// The inherent constructors and `add_*` methods keep the two sets disjoint:
/// a resource that is claimed for writing is stored only in `writes`.
pub struct RwResources<R> {
    /// Resources claimed for reading only.
    reads: HashSet<R>,
    /// Resources claimed for writing.
    writes: HashSet<R>,
}
28
29impl<R> Default for RwResources<R>
30where
31 R: Eq + Hash,
32{
33 fn default() -> Self {
34 RwResources {
35 reads: HashSet::new(),
36 writes: HashSet::new(),
37 }
38 }
39}
40
41impl<R> RwResources<R>
42where
43 R: Eq + Hash,
44{
45 pub fn new() -> Self {
46 Default::default()
47 }
48
49 pub fn from_iters(
50 reads: impl IntoIterator<Item = R>,
51 writes: impl IntoIterator<Item = R>,
52 ) -> Self {
53 let writes: HashSet<R> = writes.into_iter().collect();
54 let reads: HashSet<R> = reads.into_iter().filter(|r| !writes.contains(r)).collect();
55 RwResources { reads, writes }
56 }
57
58 pub fn add_read(&mut self, r: R) {
59 if !self.writes.contains(&r) {
60 self.reads.insert(r);
61 }
62 }
63
64 pub fn add_write(&mut self, r: R) {
65 self.reads.remove(&r);
66 self.writes.insert(r);
67 }
68
69 pub fn read(mut self, r: R) -> Self {
70 self.add_read(r);
71 self
72 }
73
74 pub fn write(mut self, r: R) -> Self {
75 self.add_write(r);
76 self
77 }
78}
79
80impl<R: Eq + Hash + Clone> Resources for RwResources<R> {
81 fn union(&mut self, other: &Self) {
82 for w in &other.writes {
83 self.writes.insert(w.clone());
84 }
85
86 for r in &other.reads {
87 if !self.writes.contains(r) {
88 self.reads.insert(r.clone());
89 }
90 }
91 }
92
93 fn conflicts_with(&self, other: &Self) -> bool {
94 self.writes.intersection(&other.reads).next().is_some()
95 || self.writes.intersection(&other.writes).next().is_some()
96 || other.writes.intersection(&self.reads).next().is_some()
97 || other.writes.intersection(&self.writes).next().is_some()
98 }
99}