py_declare/
graph.rs

1use crate::*;
2use py_ir::types::TypeDefine;
3use std::collections::{HashMap, HashSet};
4use terl::{Span, WithSpan};
5
6/// used to declare which overload of function is called, or which possiable type is
7///
8#[derive(Default, Debug)]
9pub struct DeclareGraph {
10    pub(crate) groups: Vec<DeclareGroup>,
11    /// deps means that [`Branch`] depend **ALL** of them
12    ///
13    /// if any of them is impossible, the [`Branch`] will be removed, too
14    pub(crate) deps: HashMap<Branch, HashSet<Branch>>,
15    pub(crate) rdeps: HashMap<Branch, HashSet<Branch>>,
16}
17
18impl DeclareGraph {
19    pub fn new() -> Self {
20        Self::default()
21    }
22
23    fn insert_depends(&mut self, who: Branch, depend: HashSet<Branch>) {
24        if depend.is_empty() {
25            return;
26        }
27        for &depend in &depend {
28            self.rdeps.entry(depend).or_default().insert(who);
29        }
30        self.deps.entry(who).or_default().extend(depend);
31    }
32
33    pub(crate) fn new_group_inner(
34        &mut self,
35        span: Span,
36        failds: HashMap<usize, DeclareError>,
37        status: DeclareState,
38    ) -> GroupIdx {
39        let idx = GroupIdx {
40            idx: self.groups.len(),
41        };
42        self.groups
43            .push(DeclareGroup::new(span, idx, failds, status));
44        idx
45    }
46
47    pub fn new_static_group<I>(&mut self, at: terl::Span, items: I) -> GroupIdx
48    where
49        I: IntoIterator<Item = Type>,
50    {
51        self.new_group_inner(at, Default::default(), DeclareState::from_iter(items))
52    }
53
54    pub fn build_group(&mut self, gb: GroupBuilder) -> GroupIdx {
55        let gidx = GroupIdx::new(self.groups.len());
56
57        let mut alives = HashMap::new();
58        let mut failds = HashMap::new();
59
60        #[derive(Debug)]
61        enum BranchMark {
62            Used,
63            Error(DeclareError),
64        }
65
66        let mut used_branches: HashMap<GroupIdx, HashMap<usize, BranchMark>> = HashMap::new();
67        for branch_builder in gb.branches {
68            let ty = match branch_builder.state {
69                Ok(ty) => ty,
70                Err(e) => {
71                    failds.insert(alives.len() + failds.len(), e);
72                    continue;
73                }
74            };
75
76            let use_branch = |branch: Branch| -> Branch {
77                used_branches
78                    .entry(branch.belong_to)
79                    .or_default()
80                    .insert(branch.branch_idx, BranchMark::Used);
81                branch
82            };
83
84            match branch_builder.depends.merge_depends(use_branch) {
85                // may the branch doesnot depend on any other branches
86                Ok(branch_depends) if branch_depends.is_empty() => {
87                    let new_branch = Branch::new(gidx, alives.len() + failds.len());
88                    alives.insert(new_branch.branch_idx, ty.clone());
89                }
90                Ok(branch_depends) => {
91                    for branch_depends in branch_depends {
92                        let new_branch = Branch::new(gidx, alives.len() + failds.len());
93                        self.insert_depends(new_branch, branch_depends);
94                        alives.insert(new_branch.branch_idx, ty.clone());
95                    }
96                }
97                Err(group_errors) => {
98                    for (group, errors) in group_errors {
99                        let group = used_branches.entry(group).or_default();
100                        for (branch, error) in errors {
101                            group.entry(branch).or_insert(BranchMark::Error(error));
102                        }
103                    }
104                }
105            };
106        }
107
108        let new_group = self.new_group_inner(gb.span, failds, alives.into());
109
110        for (group, mut branch_marks) in used_branches {
111            // use hashmap avoiding remove same branch more than once
112            let remove: HashMap<_, _> = self[group].filter_alive(|branch, ty| {
113                match branch_marks.remove(&branch.branch_idx) {
114                    Some(BranchMark::Used) => Ok((branch, ty)),
115                    Some(BranchMark::Error(error)) => Err(error.with_location(gb.span)),
116                    None => Err(DeclareError::NeverUsed {
117                        in_group: gb.span,
118                        reason: None,
119                    }),
120                }
121                .map_err(|e| e.into_shared())
122            });
123            //
124            for (branch, reason) in remove {
125                self.remove_branch(branch, reason);
126            }
127        }
128
129        new_group
130    }
131
132    pub fn apply_filter<T, B>(&mut self, gidx: GroupIdx, defs: &Defs, filter: B)
133    where
134        T: Types,
135        B: BranchFilter<T>,
136    {
137        let location = self[gidx].get_span();
138        let reason = || {
139            DeclareError::Unexpect {
140                expect: filter.expect(defs),
141            }
142            .with_location(location)
143            .into_shared()
144        };
145        let removed = self[gidx].remove_branches(|_, ty| !filter.satisfy(ty), reason);
146
147        for (branch, reason) in removed {
148            self.remove_branch(branch, reason);
149        }
150    }
151
152    pub fn merge_group(&mut self, at: terl::Span, base: GroupIdx, from: GroupIdx) {
153        let bases = self[from].alives(|alives| {
154            alives
155                .map(|(branch, ty)| (branch, ty.get_type()))
156                .collect::<Vec<_>>()
157        });
158        let exists = self[base].alives(|alives| {
159            alives
160                .map(|(branch, ty)| (branch, ty.get_type()))
161                .collect::<Vec<_>>()
162        });
163
164        // to_branch, from_branch, type
165        let merge = exists
166            .iter()
167            .flat_map(|&(branch, ty)| {
168                bases
169                    .iter()
170                    .filter(move |(.., f_ty)| *f_ty == ty)
171                    .map(move |&(f_branch, ..)| (branch, f_branch, ty))
172            })
173            .collect::<Vec<_>>();
174
175        let (base_kept, from_kept): (HashSet<_>, HashSet<_>) =
176            merge.iter().map(|(base, from, _)| (*base, *from)).unzip();
177
178        let removed = bases
179            .iter()
180            .map(|(branch, ..)| *branch)
181            .filter(|branch| !from_kept.contains(branch))
182            .chain(
183                exists
184                    .iter()
185                    .map(|(branch, ..)| *branch)
186                    .filter(|branch| !base_kept.contains(branch)),
187            )
188            .collect::<Vec<_>>();
189
190        // TODO: improve error message here
191        let remove_reason = DeclareError::Filtered.with_location(at).into_shared();
192
193        for remove in removed {
194            self.remove_branch(remove, remove_reason.clone());
195        }
196    }
197
198    /// declare a [`DeclareGroup`]'s result is a type
199    ///
200    /// return [`Err`] if the type has be declared and isn't given type,
201    /// or non of [`Branch`] match the given type
202    pub fn declare_type(&mut self, at: terl::Span, gidx: GroupIdx, expect_ty: &TypeDefine) {
203        let group = &mut self[gidx];
204        // TODO: unknown type support
205
206        let reason = || {
207            DeclareError::Unexpect {
208                expect: expect_ty.to_string(),
209            }
210            .with_location(at)
211            .into_shared()
212        };
213        for (branch, remove) in group.remove_branches(|_, ty| ty.get_type() != expect_ty, reason) {
214            self.remove_branch(branch, remove);
215        }
216    }
217
218    /// Zhu double eight: is your Nine Clan([`Branch`]) wholesale?
219    ///
220    /// `remove` a node, and all node which must depend on it, and then a generate a error
221    /// with [`Type`] which previous branch stored in
222    ///
223    /// # Note:
224    ///
225    /// make sure the reason passed in are wrapped by rc(by calling [`DeclareError::into_shared`])
226    pub(crate) fn remove_branch(&mut self, branch: Branch, reason: DeclareError) {
227        // is it impossiable to be a cycle dep in map?
228        let group = &mut self[branch.belong_to];
229        let group_loc = group.get_span();
230        {
231            let reason = group.remove_branch(branch.branch_idx, reason.clone());
232            group.push_error(branch.branch_idx, reason.clone());
233        }
234
235        // remove all branches depend on removed branch
236        if let Some(rdeps) = self.rdeps.remove(&branch) {
237            for rdep in rdeps {
238                self.remove_branch(rdep, reason.clone());
239            }
240        }
241        // remove the record of all branch which removed branch depend on
242        if let Some(deps) = self.deps.remove(&branch) {
243            for dep in deps {
244                match self.rdeps.get_mut(&dep) {
245                    Some(rdeps) if rdeps.len() == 1 => {
246                        let reason = DeclareError::NeverUsed {
247                            in_group: group_loc,
248                            reason: Some(reason.clone().into()),
249                        };
250                        self.remove_branch(dep, reason);
251                    }
252                    Some(rdeps) => {
253                        rdeps.remove(&branch);
254                    }
255                    None => {}
256                }
257            }
258        }
259    }
260
261    pub fn declare_all(&mut self) -> Result<(), Vec<terl::Error>> {
262        let mut errors = vec![];
263        for group in &self.groups {
264            // un-declared group
265            if !group.is_declared() {
266                errors.push(group.make_error());
267            }
268        }
269        if errors.is_empty() {
270            Ok(())
271        } else {
272            Err(errors)
273        }
274    }
275
276    pub fn get_type(&self, gidx: GroupIdx) -> &TypeDefine {
277        self[gidx].result().get_type()
278    }
279}
280
281impl std::ops::Index<GroupIdx> for DeclareGraph {
282    type Output = DeclareGroup;
283
284    fn index(&self, index: GroupIdx) -> &Self::Output {
285        &self.groups[index.idx]
286    }
287}
288
289impl std::ops::IndexMut<GroupIdx> for DeclareGraph {
290    fn index_mut(&mut self, index: GroupIdx) -> &mut Self::Output {
291        &mut self.groups[index.idx]
292    }
293}
294
295impl std::ops::Index<Branch> for DeclareGraph {
296    type Output = Type;
297
298    fn index(&self, index: Branch) -> &Self::Output {
299        self[index.belong_to].get_branch(index)
300    }
301}