1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use std::{collections::HashSet, hash::Hash};

use thiserror::Error;

/// Abstraction over a set of accessed "resources", some combinations of which must not be in use
/// at the same time.
pub trait Resources: Default {
    /// Returns `true` when at least one resource in `self` may not be used at the same time as
    /// some resource in `other`.
    fn conflicts_with(&self, other: &Self) -> bool;

    /// Merges `other` into `self`, so that afterwards `self` describes the resources of both.
    fn union(&mut self, other: &Self);
}

/// Error value describing a resource conflict.
///
/// The `Display` message (generated by `thiserror`) is `resource conflict in <type_name>`. It
/// formats `type_name` with `Debug`, so the name appears quoted, e.g.
/// `resource conflict in "Foo"`.
#[derive(Debug, Error)]
#[error("resource conflict in {type_name:?}")]
pub struct ResourceConflict {
    /// Name of the type the conflict is reported for. NOTE(review): presumably produced by
    /// `std::any::type_name` at the construction site — confirm against callers.
    pub type_name: &'static str,
}

/// A `Resources` implementation that describes R/W locks.
///
/// Two read locks for the same resource do not conflict, but a read and a write or two writes to
/// the same resource do.
///
/// Each resource is kept in at most one of the two internal sets:
/// - Adding a write for a resource that is currently read moves it to the write set.
/// - Adding a read for a resource that is already written leaves it untouched. The write
///   already conflicts with everything the read would.
#[derive(Debug, Clone)]
pub struct RwResources<R> {
    /// Resources that are only read. Always disjoint from `writes`.
    reads: HashSet<R>,
    /// Resources that are written.
    writes: HashSet<R>,
}

impl<R> Default for RwResources<R>
where
    R: Eq + Hash,
{
    /// Creates an empty set: no reads and no writes.
    fn default() -> Self {
        RwResources {
            reads: HashSet::new(),
            writes: HashSet::new(),
        }
    }
}

impl<R> RwResources<R>
where
    R: Eq + Hash,
{
    /// Creates an empty set of resources; equivalent to `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Builds a set from separate collections of read and written resources.
    ///
    /// A resource that appears in both `reads` and `writes` is recorded only as a write.
    pub fn from_iters(
        reads: impl IntoIterator<Item = R>,
        writes: impl IntoIterator<Item = R>,
    ) -> Self {
        let writes: HashSet<R> = writes.into_iter().collect();
        // Drop reads already covered by a write so the two sets stay disjoint.
        let reads: HashSet<R> = reads.into_iter().filter(|r| !writes.contains(r)).collect();
        RwResources { reads, writes }
    }

    /// Records a read of `r`. Does nothing if `r` is already recorded as a write.
    pub fn add_read(&mut self, r: R) {
        if !self.writes.contains(&r) {
            self.reads.insert(r);
        }
    }

    /// Records a write of `r`, upgrading it if it was previously recorded only as a read.
    pub fn add_write(&mut self, r: R) {
        self.reads.remove(&r);
        self.writes.insert(r);
    }

    /// Builder-style form of [`add_read`](Self::add_read): consumes `self` and returns it with
    /// the read recorded.
    #[must_use]
    pub fn read(mut self, r: R) -> Self {
        self.add_read(r);
        self
    }

    /// Builder-style form of [`add_write`](Self::add_write): consumes `self` and returns it with
    /// the write recorded.
    #[must_use]
    pub fn write(mut self, r: R) -> Self {
        self.add_write(r);
        self
    }
}

impl<R: Eq + Hash + Clone> Resources for RwResources<R> {
    /// Merges `other` into `self`, keeping `reads` and `writes` disjoint.
    ///
    /// A write from `other` removes any read of the same resource already held by `self`. A read
    /// from `other` is ignored if `self` (after merging writes) already writes that resource.
    fn union(&mut self, other: &Self) {
        for w in &other.writes {
            // Upgrade an existing read to a write. Otherwise the resource would sit in both sets.
            self.reads.remove(w);
            self.writes.insert(w.clone());
        }

        for r in &other.reads {
            if !self.writes.contains(r) {
                self.reads.insert(r.clone());
            }
        }
    }

    /// Returns `true` if either side writes a resource that the other side reads or writes.
    /// Shared reads never conflict.
    fn conflicts_with(&self, other: &Self) -> bool {
        // Write/write overlap is symmetric, so it is checked only once.
        !self.writes.is_disjoint(&other.writes)
            || !self.writes.is_disjoint(&other.reads)
            || !other.writes.is_disjoint(&self.reads)
    }
}