1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
use hibitset::{BitProducer, BitSetLike};
use rayon::iter::{
    plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
    ParallelIterator,
};

use crate::world::Index;

/// The purpose of the `ParJoin` trait is to provide a way
/// to access multiple storages in parallel at the same time with
/// the merged bit set.
///
/// A join is consumed by [`ParJoin::par_join`], which wraps it in a
/// [`JoinParIter`] that can be driven as a rayon `ParallelIterator`.
///
/// # Safety
///
/// `ParJoin::get` must be callable from multiple threads, simultaneously.
///
/// The `Self::Mask` value returned with the `Self::Value` must correspond such
/// that it is safe to retrieve items from `Self::Value` whose presence is
/// indicated in the mask. As part of this, `BitSetLike::iter` must not produce
/// an iterator that repeats an `Index` value.
pub unsafe trait ParJoin {
    /// Type of joined components.
    type Type;
    /// Type of joined storages.
    type Value;
    /// Type of joined bit mask.
    type Mask: BitSetLike;

    /// Create a joined parallel iterator over the contents.
    ///
    /// Logs a warning (via `log::warn!`) when `Self::is_unconstrained()`
    /// reports `true`, since such a join may visit every index.
    fn par_join(self) -> JoinParIter<Self>
    where
        Self: Sized,
    {
        if Self::is_unconstrained() {
            log::warn!(
                "`ParJoin` possibly iterating through all indices, \
                you might've made a join with all `MaybeJoin`s, \
                which is unbounded in length."
            );
        }

        JoinParIter(self)
    }

    /// Open this join by returning the mask and the storages.
    ///
    /// # Safety
    ///
    /// This is unsafe because implementations of this trait can permit the
    /// `Value` to be mutated independently of the `Mask`. If the `Mask` does
    /// not correctly report the status of the `Value` then illegal memory
    /// access can occur.
    unsafe fn open(self) -> (Self::Mask, Self::Value);

    /// Get a joined component value by a given index.
    ///
    /// `value` is a shared reference to the storages and `id` is the index of
    /// the component to retrieve.
    ///
    /// # Safety
    ///
    /// * A call to `get` must be preceded by a check if `id` is part of
    ///   `Self::Mask`.
    /// * The value returned from this method must no longer be alive before
    ///   subsequent calls with the same `id`.
    unsafe fn get(value: &Self::Value, id: Index) -> Self::Type;

    /// If this `ParJoin` typically returns all indices in the mask, then
    /// iterating over only it or combined with other joins that are also
    /// dangerous will cause the `JoinParIter` to go through all indices which
    /// is usually not what is wanted and will kill performance.
    ///
    /// The default implementation returns `false`.
    #[inline]
    fn is_unconstrained() -> bool {
        false
    }
}

/// `JoinParIter` is a `ParallelIterator` over a group of storages.
///
/// It is returned by `ParJoin::par_join` and simply wraps the join `J`; the
/// join is only opened (via `ParJoin::open`) once the iterator is driven.
#[must_use]
pub struct JoinParIter<J>(J);

impl<J> ParallelIterator for JoinParIter<J>
where
    J: ParJoin + Send,
    J::Mask: Send + Sync,
    J::Type: Send,
    J::Value: Send + Sync,
{
    type Item = J::Type;

    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        // SAFETY: `keys` and `values` are not exposed outside this module and
        // we only use `values` for calling `ParJoin::get`.
        let (keys, values) = unsafe { self.0.open() };
        // Create a bit producer which splits on up to three levels
        let producer = BitProducer((&keys).iter(), 3);

        bridge_unindexed(JoinProducer::<J>::new(producer, &values), consumer)
    }
}

/// Rayon producer used by `JoinParIter`: pairs a splittable producer of
/// indices from the join's mask with a shared reference to the join's
/// storages.
struct JoinProducer<'a, J>
where
    J: ParJoin + Send,
    J::Mask: Send + Sync + 'a,
    J::Type: Send,
    J::Value: Send + Sync + 'a,
{
    /// Producer over the indices set in the borrowed mask.
    keys: BitProducer<'a, J::Mask>,
    /// Storages passed to `ParJoin::get` for every produced index.
    values: &'a J::Value,
}

impl<'a, J> JoinProducer<'a, J>
where
    J: ParJoin + Send,
    J::Type: Send,
    J::Value: 'a + Send + Sync,
    J::Mask: 'a + Send + Sync,
{
    /// Builds a producer from an index producer and the storages it reads.
    fn new(keys: BitProducer<'a, J::Mask>, values: &'a J::Value) -> Self {
        Self { keys, values }
    }
}

impl<'a, J> UnindexedProducer for JoinProducer<'a, J>
where
    J: ParJoin + Send,
    J::Type: Send,
    J::Value: 'a + Send + Sync,
    J::Mask: 'a + Send + Sync,
{
    type Item = J::Type;

    /// Splits the underlying bit producer; both halves keep sharing the same
    /// storages reference.
    fn split(self) -> (Self, Option<Self>) {
        let Self { keys, values } = self;
        let (left, right) = keys.split();

        (
            Self::new(left, values),
            right.map(|rest| Self::new(rest, values)),
        )
    }

    /// Feeds the component of every index in this producer's share of the
    /// mask into `folder`.
    fn fold_with<F>(self, folder: F) -> F
    where
        F: Folder<Self::Item>,
    {
        let Self { keys, values } = self;
        let indices = keys.0;

        // SAFETY: every `id` comes from the `Mask` returned by
        // `ParJoin::open`. No index is seen twice: the bit set is split
        // between producers, and `ParJoin` requires that the bit set
        // iterator doesn't repeat indices.
        folder.consume_iter(indices.map(|id| unsafe { J::get(values, id) }))
    }
}