Skip to main content

goggles/
resources.rs

1use std::{collections::HashSet, hash::Hash};
2
3use thiserror::Error;
4
/// A set of accessed "resources" that may conflict when used at the same time.
///
/// `Default::default()` is presumably the empty set (it is for the `RwResources` implementation
/// in this module) — confirm for other implementors.
pub trait Resources: Default {
    /// Merge `other` into `self`, so that `self` covers the resources of both sets.
    fn union(&mut self, other: &Self);

    /// Return `true` when at least one resource in `self` may not be used at the same time as
    /// some resource in `other`.
    fn conflicts_with(&self, other: &Self) -> bool;
}
13
14#[derive(Debug, Error)]
15#[error("resource conflict in {type_name:?}")]
16pub struct ResourceConflict {
17    pub type_name: &'static str,
18}
19
/// A [`Resources`] implementation that describes R/W locks.
///
/// Two read locks for the same resource do not conflict, but a read and a write or two writes to
/// the same resource do.
#[derive(Debug, Clone)]
pub struct RwResources<R> {
    /// Resources accessed read-only. A write lock subsumes a read lock, so a resource listed in
    /// `writes` does not need to appear here as well.
    reads: HashSet<R>,
    /// Resources accessed for writing.
    writes: HashSet<R>,
}
28
29impl<R> Default for RwResources<R>
30where
31    R: Eq + Hash,
32{
33    fn default() -> Self {
34        RwResources {
35            reads: HashSet::new(),
36            writes: HashSet::new(),
37        }
38    }
39}
40
41impl<R> RwResources<R>
42where
43    R: Eq + Hash,
44{
45    pub fn new() -> Self {
46        Default::default()
47    }
48
49    pub fn from_iters(
50        reads: impl IntoIterator<Item = R>,
51        writes: impl IntoIterator<Item = R>,
52    ) -> Self {
53        let writes: HashSet<R> = writes.into_iter().collect();
54        let reads: HashSet<R> = reads.into_iter().filter(|r| !writes.contains(r)).collect();
55        RwResources { reads, writes }
56    }
57
58    pub fn add_read(&mut self, r: R) {
59        if !self.writes.contains(&r) {
60            self.reads.insert(r);
61        }
62    }
63
64    pub fn add_write(&mut self, r: R) {
65        self.reads.remove(&r);
66        self.writes.insert(r);
67    }
68
69    pub fn read(mut self, r: R) -> Self {
70        self.add_read(r);
71        self
72    }
73
74    pub fn write(mut self, r: R) -> Self {
75        self.add_write(r);
76        self
77    }
78}
79
80impl<R: Eq + Hash + Clone> Resources for RwResources<R> {
81    fn union(&mut self, other: &Self) {
82        for w in &other.writes {
83            self.writes.insert(w.clone());
84        }
85
86        for r in &other.reads {
87            if !self.writes.contains(r) {
88                self.reads.insert(r.clone());
89            }
90        }
91    }
92
93    fn conflicts_with(&self, other: &Self) -> bool {
94        self.writes.intersection(&other.reads).next().is_some()
95            || self.writes.intersection(&other.writes).next().is_some()
96            || other.writes.intersection(&self.reads).next().is_some()
97            || other.writes.intersection(&self.writes).next().is_some()
98    }
99}