dep_res/
lib.rs

1use dashmap::{DashMap, DashSet};
2use rayon::{iter::Either, prelude::*};
3use std::{
4    hash::Hash,
5    ops::Deref,
6    rc::Rc,
7    sync::{atomic::AtomicBool, Arc},
8};
9use thiserror::Error;
10use tuples::TupleCloned;
11
12#[cfg(test)]
13mod tests;
14
/// Describes one node of a dependency graph: its own identifier plus the
/// identifiers of the nodes it directly depends on.
pub trait DepMeta {
    /// Identifier type; it is used as a key in hashed collections and is
    /// cloned when stored.
    type Id: Clone + Eq + Hash;

    /// Returns the identifier of this item (by value).
    fn get_id(&self) -> Self::Id;

    /// Returns the identifiers this item directly depends on.
    /// An empty slice means the item has no dependencies.
    fn get_deps(&self) -> &[Self::Id];
}
22
23mod impls {
24    use crate::*;
25
26    impl<T: DepMeta> DepMeta for &T {
27        type Id = T::Id;
28
29        fn get_id(&self) -> Self::Id {
30            (**self).get_id()
31        }
32
33        fn get_deps(&self) -> &[Self::Id] {
34            (**self).get_deps()
35        }
36    }
37
38    impl<T: DepMeta> DepMeta for Rc<T> {
39        type Id = T::Id;
40
41        fn get_id(&self) -> Self::Id {
42            self.deref().get_id()
43        }
44
45        fn get_deps(&self) -> &[Self::Id] {
46            self.deref().get_deps()
47        }
48    }
49
50    impl<T: DepMeta> DepMeta for Box<T> {
51        type Id = T::Id;
52
53        fn get_id(&self) -> Self::Id {
54            self.deref().get_id()
55        }
56
57        fn get_deps(&self) -> &[Self::Id] {
58            self.deref().get_deps()
59        }
60    }
61
62    impl<T: DepMeta> DepMeta for Arc<T> {
63        type Id = T::Id;
64
65        fn get_id(&self) -> Self::Id {
66            self.deref().get_id()
67        }
68
69        fn get_deps(&self) -> &[Self::Id] {
70            self.deref().get_deps()
71        }
72    }
73}
74
75#[derive(Debug, Default)]
76pub struct DepRes<Id: Eq + Hash + Clone> {
77    ids: DashSet<Id>,
78    deps: DashMap<Id, DashSet<Id>>,
79}
80
81impl<Id: Eq + Hash + Clone> DepRes<Id> {
82    pub fn new() -> Self {
83        Self {
84            ids: DashSet::new(),
85            deps: DashMap::new(),
86        }
87    }
88}
89
90impl<Id: Sync + Send + Eq + Hash + Clone> DepRes<Id> {
91    pub fn add<'a>(
92        &self,
93        items: &'a impl IntoParallelRefIterator<'a, Item = impl DepMeta<Id = Id>>,
94    ) {
95        items.par_iter().for_each(|item| {
96            let id = item.get_id();
97            let deps = item.get_deps();
98            let has_dep = !deps.is_empty();
99            if has_dep {
100                deps.par_iter().for_each(|dep| {
101                    let dset = self
102                        .deps
103                        .entry(id.clone())
104                        .or_insert_with(|| DashSet::new());
105                    dset.insert(dep.clone());
106                });
107            }
108            self.ids.insert(id);
109        });
110    }
111}
112
113#[derive(Debug, Default, Clone)]
114pub struct ResolvedDeps<Id: Eq + Hash + Clone> {
115    lvs: DashMap<usize, Arc<DashSet<Id>>>,
116}
117
/// A level number paired with the payload (`deps`) stored for that level.
#[derive(Clone, Debug, Default)]
pub struct DepLevel<D> {
    /// Numeric level this entry belongs to.
    pub level: usize,
    /// Payload associated with the level.
    pub deps: D,
}
123
124impl<Id: Eq + Hash + Clone> ResolvedDeps<Id> {
125    fn new(lvs: DashMap<usize, Arc<DashSet<Id>>>) -> Self {
126        Self { lvs }
127    }
128
129    pub fn sorted_by_level(&self) -> Vec<Id> {
130        let mut vec = self.lvs.iter().collect::<Vec<_>>();
131        vec.sort_by_key(|r| *r.key());
132        let ids = vec
133            .iter()
134            .flat_map(|r| r.value().iter().map(|a| a.clone()))
135            .collect::<Vec<_>>();
136        ids
137    }
138
139    pub fn raw_level(&self) -> &DashMap<usize, Arc<DashSet<Id>>> {
140        &self.lvs
141    }
142
143    pub fn iter_level(&self) -> impl Iterator<Item = DepLevel<Arc<DashSet<Id>>>> + '_ {
144        self.lvs.iter().map(|kv| DepLevel {
145            level: *kv.key(),
146            deps: kv.value().clone(),
147        })
148    }
149}
150
151impl<Id: Sync + Send + Eq + Hash + Clone> DepRes<Id> {
152    pub fn resolve(&mut self) -> Result<ResolvedDeps<Id>, DepResolveError> {
153        let lvs = DashMap::new();
154
155        if self.ids.is_empty() {
156            return Ok(ResolvedDeps::new(lvs));
157        }
158
159        let (lv0, other): (DashSet<Id>, DashSet<Id>) = self.ids.par_iter().partition_map(|kv| {
160            let id = kv.key().clone();
161            if let None = self.deps.get(&id) {
162                Either::Left(id.clone())
163            } else {
164                Either::Right(id.clone())
165            }
166        });
167        if lv0.is_empty() {
168            return Err(DepResolveError::IslandsOrCircular);
169        }
170        let lv0 = Arc::new(lv0);
171        lvs.insert(0, lv0.cloned());
172
173        let mut last = lv0;
174        let mut other = other;
175        let mut new_other = DashSet::new();
176        let internal_data_error = AtomicBool::new(false);
177        let mut lv = 1;
178        loop {
179            let lvn = DashSet::new();
180            new_other.clear();
181
182            other.par_iter().for_each(|id| {
183                if let Some(deps) = self.deps.get(&*id) {
184                    if deps.par_iter().any(|id| last.contains(&*id)) {
185                        lvn.insert(id.clone());
186                    } else {
187                        new_other.insert(id.clone());
188                    }
189                } else {
190                    internal_data_error.store(true, std::sync::atomic::Ordering::Relaxed);
191                }
192            });
193
194            if internal_data_error.load(std::sync::atomic::Ordering::Relaxed) {
195                return Err(DepResolveError::InternalDataError);
196            }
197
198            if lvn.is_empty() {
199                if other.is_empty() {
200                    return Ok(ResolvedDeps::new(lvs));
201                } else {
202                    return Err(DepResolveError::IslandsOrCircular);
203                }
204            }
205
206            let lvn = Arc::new(lvn);
207            lvs.insert(lv, lvn.cloned());
208            last = lvn;
209            std::mem::swap(&mut other, &mut new_other);
210            lv += 1;
211        }
212    }
213}
214
215#[derive(Error, Debug, Clone, PartialEq, Eq, Hash)]
216pub enum DepResolveError {
217    #[error("There are islands or circular reference dependencies")]
218    IslandsOrCircular,
219    #[error("internal data error")]
220    InternalDataError,
221}