1use crate::*;
2use py_ir::types::TypeDefine;
3use std::collections::{HashMap, HashSet};
4use terl::{Span, WithSpan};
5
6#[derive(Default, Debug)]
9pub struct DeclareGraph {
10 pub(crate) groups: Vec<DeclareGroup>,
11 pub(crate) deps: HashMap<Branch, HashSet<Branch>>,
15 pub(crate) rdeps: HashMap<Branch, HashSet<Branch>>,
16}
17
18impl DeclareGraph {
19 pub fn new() -> Self {
20 Self::default()
21 }
22
23 fn insert_depends(&mut self, who: Branch, depend: HashSet<Branch>) {
24 if depend.is_empty() {
25 return;
26 }
27 for &depend in &depend {
28 self.rdeps.entry(depend).or_default().insert(who);
29 }
30 self.deps.entry(who).or_default().extend(depend);
31 }
32
33 pub(crate) fn new_group_inner(
34 &mut self,
35 span: Span,
36 failds: HashMap<usize, DeclareError>,
37 status: DeclareState,
38 ) -> GroupIdx {
39 let idx = GroupIdx {
40 idx: self.groups.len(),
41 };
42 self.groups
43 .push(DeclareGroup::new(span, idx, failds, status));
44 idx
45 }
46
47 pub fn new_static_group<I>(&mut self, at: terl::Span, items: I) -> GroupIdx
48 where
49 I: IntoIterator<Item = Type>,
50 {
51 self.new_group_inner(at, Default::default(), DeclareState::from_iter(items))
52 }
53
54 pub fn build_group(&mut self, gb: GroupBuilder) -> GroupIdx {
55 let gidx = GroupIdx::new(self.groups.len());
56
57 let mut alives = HashMap::new();
58 let mut failds = HashMap::new();
59
60 #[derive(Debug)]
61 enum BranchMark {
62 Used,
63 Error(DeclareError),
64 }
65
66 let mut used_branches: HashMap<GroupIdx, HashMap<usize, BranchMark>> = HashMap::new();
67 for branch_builder in gb.branches {
68 let ty = match branch_builder.state {
69 Ok(ty) => ty,
70 Err(e) => {
71 failds.insert(alives.len() + failds.len(), e);
72 continue;
73 }
74 };
75
76 let use_branch = |branch: Branch| -> Branch {
77 used_branches
78 .entry(branch.belong_to)
79 .or_default()
80 .insert(branch.branch_idx, BranchMark::Used);
81 branch
82 };
83
84 match branch_builder.depends.merge_depends(use_branch) {
85 Ok(branch_depends) if branch_depends.is_empty() => {
87 let new_branch = Branch::new(gidx, alives.len() + failds.len());
88 alives.insert(new_branch.branch_idx, ty.clone());
89 }
90 Ok(branch_depends) => {
91 for branch_depends in branch_depends {
92 let new_branch = Branch::new(gidx, alives.len() + failds.len());
93 self.insert_depends(new_branch, branch_depends);
94 alives.insert(new_branch.branch_idx, ty.clone());
95 }
96 }
97 Err(group_errors) => {
98 for (group, errors) in group_errors {
99 let group = used_branches.entry(group).or_default();
100 for (branch, error) in errors {
101 group.entry(branch).or_insert(BranchMark::Error(error));
102 }
103 }
104 }
105 };
106 }
107
108 let new_group = self.new_group_inner(gb.span, failds, alives.into());
109
110 for (group, mut branch_marks) in used_branches {
111 let remove: HashMap<_, _> = self[group].filter_alive(|branch, ty| {
113 match branch_marks.remove(&branch.branch_idx) {
114 Some(BranchMark::Used) => Ok((branch, ty)),
115 Some(BranchMark::Error(error)) => Err(error.with_location(gb.span)),
116 None => Err(DeclareError::NeverUsed {
117 in_group: gb.span,
118 reason: None,
119 }),
120 }
121 .map_err(|e| e.into_shared())
122 });
123 for (branch, reason) in remove {
125 self.remove_branch(branch, reason);
126 }
127 }
128
129 new_group
130 }
131
132 pub fn apply_filter<T, B>(&mut self, gidx: GroupIdx, defs: &Defs, filter: B)
133 where
134 T: Types,
135 B: BranchFilter<T>,
136 {
137 let location = self[gidx].get_span();
138 let reason = || {
139 DeclareError::Unexpect {
140 expect: filter.expect(defs),
141 }
142 .with_location(location)
143 .into_shared()
144 };
145 let removed = self[gidx].remove_branches(|_, ty| !filter.satisfy(ty), reason);
146
147 for (branch, reason) in removed {
148 self.remove_branch(branch, reason);
149 }
150 }
151
152 pub fn merge_group(&mut self, at: terl::Span, base: GroupIdx, from: GroupIdx) {
153 let bases = self[from].alives(|alives| {
154 alives
155 .map(|(branch, ty)| (branch, ty.get_type()))
156 .collect::<Vec<_>>()
157 });
158 let exists = self[base].alives(|alives| {
159 alives
160 .map(|(branch, ty)| (branch, ty.get_type()))
161 .collect::<Vec<_>>()
162 });
163
164 let merge = exists
166 .iter()
167 .flat_map(|&(branch, ty)| {
168 bases
169 .iter()
170 .filter(move |(.., f_ty)| *f_ty == ty)
171 .map(move |&(f_branch, ..)| (branch, f_branch, ty))
172 })
173 .collect::<Vec<_>>();
174
175 let (base_kept, from_kept): (HashSet<_>, HashSet<_>) =
176 merge.iter().map(|(base, from, _)| (*base, *from)).unzip();
177
178 let removed = bases
179 .iter()
180 .map(|(branch, ..)| *branch)
181 .filter(|branch| !from_kept.contains(branch))
182 .chain(
183 exists
184 .iter()
185 .map(|(branch, ..)| *branch)
186 .filter(|branch| !base_kept.contains(branch)),
187 )
188 .collect::<Vec<_>>();
189
190 let remove_reason = DeclareError::Filtered.with_location(at).into_shared();
192
193 for remove in removed {
194 self.remove_branch(remove, remove_reason.clone());
195 }
196 }
197
198 pub fn declare_type(&mut self, at: terl::Span, gidx: GroupIdx, expect_ty: &TypeDefine) {
203 let group = &mut self[gidx];
204 let reason = || {
207 DeclareError::Unexpect {
208 expect: expect_ty.to_string(),
209 }
210 .with_location(at)
211 .into_shared()
212 };
213 for (branch, remove) in group.remove_branches(|_, ty| ty.get_type() != expect_ty, reason) {
214 self.remove_branch(branch, remove);
215 }
216 }
217
218 pub(crate) fn remove_branch(&mut self, branch: Branch, reason: DeclareError) {
227 let group = &mut self[branch.belong_to];
229 let group_loc = group.get_span();
230 {
231 let reason = group.remove_branch(branch.branch_idx, reason.clone());
232 group.push_error(branch.branch_idx, reason.clone());
233 }
234
235 if let Some(rdeps) = self.rdeps.remove(&branch) {
237 for rdep in rdeps {
238 self.remove_branch(rdep, reason.clone());
239 }
240 }
241 if let Some(deps) = self.deps.remove(&branch) {
243 for dep in deps {
244 match self.rdeps.get_mut(&dep) {
245 Some(rdeps) if rdeps.len() == 1 => {
246 let reason = DeclareError::NeverUsed {
247 in_group: group_loc,
248 reason: Some(reason.clone().into()),
249 };
250 self.remove_branch(dep, reason);
251 }
252 Some(rdeps) => {
253 rdeps.remove(&branch);
254 }
255 None => {}
256 }
257 }
258 }
259 }
260
261 pub fn declare_all(&mut self) -> Result<(), Vec<terl::Error>> {
262 let mut errors = vec![];
263 for group in &self.groups {
264 if !group.is_declared() {
266 errors.push(group.make_error());
267 }
268 }
269 if errors.is_empty() {
270 Ok(())
271 } else {
272 Err(errors)
273 }
274 }
275
276 pub fn get_type(&self, gidx: GroupIdx) -> &TypeDefine {
277 self[gidx].result().get_type()
278 }
279}
280
281impl std::ops::Index<GroupIdx> for DeclareGraph {
282 type Output = DeclareGroup;
283
284 fn index(&self, index: GroupIdx) -> &Self::Output {
285 &self.groups[index.idx]
286 }
287}
288
289impl std::ops::IndexMut<GroupIdx> for DeclareGraph {
290 fn index_mut(&mut self, index: GroupIdx) -> &mut Self::Output {
291 &mut self.groups[index.idx]
292 }
293}
294
295impl std::ops::Index<Branch> for DeclareGraph {
296 type Output = Type;
297
298 fn index(&self, index: Branch) -> &Self::Output {
299 self[index.belong_to].get_branch(index)
300 }
301}