Skip to main content

chomsky_extract/
lib.rs

1#![warn(missing_docs)]
2
3use chomsky_cost::{Cost, CostModel};
4use chomsky_uir::egraph::{Analysis, EGraph, HasDebugInfo, Id, Language};
5pub use chomsky_uir::{IKun, IKunTree};
6use std::collections::HashMap;
7
8use chomsky_types::ChomskyResult;
9use chomsky_uir::intent::CrossLanguageCall;
10
/// The output produced by a [`Backend`] from an extracted [`IKunTree`].
///
/// Which variant a backend returns is up to that backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendArtifact {
    /// Generated code in textual form.
    Source(String),
    /// Raw generated bytes.
    Binary(Vec<u8>),
    /// Generated assembly in textual form.
    Assembly(String),
    /// Several named byte blobs produced together.
    ///
    /// NOTE(review): keys are presumably output names (e.g. file names) —
    /// confirm against the backends that produce this variant.
    Collection(HashMap<String, Vec<u8>>),
}
18
19/// The standard trait for all backends.
20/// Backends consume an extracted IKunTree and produce an artifact.
21pub trait Backend {
22    fn name(&self) -> &str;
23    fn generate(&self, tree: &IKunTree) -> ChomskyResult<BackendArtifact>;
24
25    /// 获取该后端的成本模型
26    fn get_model(&self) -> &dyn CostModel {
27        &chomsky_cost::DEFAULT_COST_MODEL
28    }
29}
30
31pub struct IKunExtractor<'a, A, C>
32where
33    A: Analysis<IKun>,
34    A::Data: HasDebugInfo,
35    C: CostModel,
36{
37    pub egraph: &'a EGraph<IKun, A>,
38    pub cost_model: C,
39    pub costs: HashMap<Id, (Cost, IKun)>,
40}
41
42impl<'a, A, C> IKunExtractor<'a, A, C>
43where
44    A: Analysis<IKun>,
45    A::Data: HasDebugInfo,
46    C: CostModel,
47{
48    pub fn new(egraph: &'a EGraph<IKun, A>, cost_model: C) -> Self {
49        let mut extractor = Self {
50            egraph,
51            cost_model,
52            costs: HashMap::new(),
53        };
54        extractor.find_best();
55        extractor
56    }
57
58    fn find_best(&mut self) {
59        let mut changed = true;
60        while changed {
61            changed = false;
62            for entry in self.egraph.classes.iter() {
63                let id = *entry.key();
64                let eclass = entry.value();
65
66                for node in &eclass.nodes {
67                    let mut node_cost = self.cost_model.cost(node);
68                    let mut can_compute = true;
69
70                    for child in node.children() {
71                        let child_id = self.egraph.union_find.find(child);
72                        if let Some((child_cost, _)) = self.costs.get(&child_id) {
73                            node_cost = node_cost.add(child_cost);
74                        } else {
75                            can_compute = false;
76                            break;
77                        }
78                    }
79
80                    if can_compute {
81                        let current_best = self.costs.get(&id);
82
83                        if current_best.is_none()
84                            || node_cost.score() < current_best.unwrap().0.score()
85                        {
86                            self.costs.insert(id, (node_cost, node.clone()));
87                            changed = true;
88                        }
89                    }
90                }
91            }
92        }
93    }
94
95    pub fn extract(&self, id: Id) -> IKunTree {
96        let root = self.egraph.union_find.find(id);
97        let (_, node) = self.costs.get(&root).expect("No cost found for eclass");
98
99        let mut tree = match node {
100            IKun::Constant(v) => IKunTree::Constant(*v),
101            IKun::FloatConstant(v) => IKunTree::FloatConstant(*v),
102            IKun::BooleanConstant(v) => IKunTree::BooleanConstant(*v),
103            IKun::StringConstant(s) => IKunTree::StringConstant(s.clone()),
104            IKun::None => IKunTree::None,
105            IKun::Symbol(s) => IKunTree::Symbol(s.clone()),
106            IKun::Map(f, x) => {
107                IKunTree::Map(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
108            }
109            IKun::Filter(f, x) => {
110                IKunTree::Filter(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
111            }
112            IKun::Reduce(f, init, list) => IKunTree::Reduce(
113                Box::new(self.extract(*f)),
114                Box::new(self.extract(*init)),
115                Box::new(self.extract(*list)),
116            ),
117            IKun::StateUpdate(var, val) => {
118                IKunTree::StateUpdate(Box::new(self.extract(*var)), Box::new(self.extract(*val)))
119            }
120            IKun::Choice(cond, t, f) => IKunTree::Choice(
121                Box::new(self.extract(*cond)),
122                Box::new(self.extract(*t)),
123                Box::new(self.extract(*f)),
124            ),
125            IKun::Repeat(cond, body) => {
126                IKunTree::Repeat(Box::new(self.extract(*cond)), Box::new(self.extract(*body)))
127            }
128            IKun::LifeCycle(setup, cleanup) => IKunTree::LifeCycle(
129                Box::new(self.extract(*setup)),
130                Box::new(self.extract(*cleanup)),
131            ),
132            IKun::Meta(body) => IKunTree::Meta(Box::new(self.extract(*body))),
133            IKun::Trap(body) => IKunTree::Trap(Box::new(self.extract(*body))),
134            IKun::Return(val) => IKunTree::Return(Box::new(self.extract(*val))),
135            IKun::Seq(ids) => IKunTree::Seq(ids.iter().map(|&id| self.extract(id)).collect()),
136            IKun::Compose(a, b) => {
137                IKunTree::Compose(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
138            }
139            IKun::WithContext(ctx, body) => {
140                IKunTree::WithContext(Box::new(self.extract(*ctx)), Box::new(self.extract(*body)))
141            }
142            IKun::WithConstraint(constraint, body) => IKunTree::WithConstraint(
143                Box::new(self.extract(*constraint)),
144                Box::new(self.extract(*body)),
145            ),
146            IKun::CpuContext => IKunTree::CpuContext,
147            IKun::GpuContext => IKunTree::GpuContext,
148            IKun::AsyncContext => IKunTree::AsyncContext,
149            IKun::SpatialContext => IKunTree::SpatialContext,
150            IKun::ComptimeContext => IKunTree::ComptimeContext,
151            IKun::ResourceContext => IKunTree::ResourceContext,
152            IKun::SafeContext => IKunTree::SafeContext,
153            IKun::EffectConstraint(e) => IKunTree::EffectConstraint(e.clone()),
154            IKun::OwnershipConstraint(o) => IKunTree::OwnershipConstraint(o.clone()),
155            IKun::TypeConstraint(t) => IKunTree::TypeConstraint(t.clone()),
156            IKun::AtomicConstraint => IKunTree::AtomicConstraint,
157            IKun::Extension(name, args) => IKunTree::Extension(
158                name.clone(),
159                args.iter().map(|&id| self.extract(id)).collect(),
160            ),
161            IKun::CrossLangCall(CrossLanguageCall {
162                language: lang,
163                module_path: group,
164                function_name: func,
165                arguments: args,
166            }) => IKunTree::CrossLangCall {
167                language: lang.clone(),
168                module_path: group.clone(),
169                function_name: func.clone(),
170                arguments: args.iter().map(|&id| self.extract(id)).collect(),
171            },
172            IKun::GpuMap(f, x) => {
173                IKunTree::GpuMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
174            }
175            IKun::CpuMap(f, x) => {
176                IKunTree::CpuMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
177            }
178            IKun::TiledMap(size, f, x) => IKunTree::TiledMap(
179                *size,
180                Box::new(self.extract(*f)),
181                Box::new(self.extract(*x)),
182            ),
183            IKun::VectorizedMap(width, f, x) => IKunTree::VectorizedMap(
184                *width,
185                Box::new(self.extract(*f)),
186                Box::new(self.extract(*x)),
187            ),
188            IKun::UnrolledMap(factor, f, x) => IKunTree::UnrolledMap(
189                *factor,
190                Box::new(self.extract(*f)),
191                Box::new(self.extract(*x)),
192            ),
193            IKun::SoAMap(f, x) => {
194                IKunTree::SoAMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
195            }
196            IKun::SoALayout(x) => IKunTree::SoALayout(Box::new(self.extract(*x))),
197            IKun::AoSLayout(x) => IKunTree::AoSLayout(Box::new(self.extract(*x))),
198            IKun::Tiled(size, x) => IKunTree::Tiled(*size, Box::new(self.extract(*x))),
199            IKun::Unrolled(factor, x) => IKunTree::Unrolled(*factor, Box::new(self.extract(*x))),
200            IKun::Vectorized(width, x) => IKunTree::Vectorized(*width, Box::new(self.extract(*x))),
201            IKun::Pipe(a, b) => {
202                IKunTree::Pipe(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
203            }
204            IKun::Reg(x) => IKunTree::Reg(Box::new(self.extract(*x))),
205            IKun::Lambda(params, body) => {
206                IKunTree::Lambda(params.clone(), Box::new(self.extract(*body)))
207            }
208            IKun::Apply(func, args) => IKunTree::Apply(
209                Box::new(self.extract(*func)),
210                args.iter().map(|&id| self.extract(id)).collect(),
211            ),
212            IKun::Closure(body, captured) => IKunTree::Closure(
213                Box::new(self.extract(*body)),
214                captured.iter().map(|&id| self.extract(id)).collect(),
215            ),
216            IKun::ResourceClone(x) => IKunTree::ResourceClone(Box::new(self.extract(*x))),
217            IKun::ResourceDrop(x) => IKunTree::ResourceDrop(Box::new(self.extract(*x))),
218            IKun::AddrOf(a) => IKunTree::AddrOf(Box::new(self.extract(*a))),
219            IKun::Deref(a) => IKunTree::Deref(Box::new(self.extract(*a))),
220            IKun::PtrOffset(a, b) => {
221                IKunTree::PtrOffset(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
222            }
223            IKun::ClassDef(name, bases, body) => IKunTree::ClassDef(
224                name.clone(),
225                bases.iter().map(|&id| self.extract(id)).collect(),
226                Box::new(self.extract(*body)),
227            ),
228            IKun::Table(ids) => IKunTree::Table(ids.iter().map(|&id| self.extract(id)).collect()),
229            IKun::Pair(a, b) => {
230                IKunTree::Pair(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
231            }
232            IKun::GetIndex(a, b) => {
233                IKunTree::GetIndex(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
234            }
235            IKun::SetIndex(a, b, c) => IKunTree::SetIndex(
236                Box::new(self.extract(*a)),
237                Box::new(self.extract(*b)),
238                Box::new(self.extract(*c)),
239            ),
240            IKun::BinaryOp(op, a, b) => IKunTree::BinaryOp(
241                op.clone(),
242                Box::new(self.extract(*a)),
243                Box::new(self.extract(*b)),
244            ),
245            IKun::UnaryOp(op, a) => IKunTree::UnaryOp(op.clone(), Box::new(self.extract(*a))),
246            IKun::Import(m, s) => IKunTree::Import(m.clone(), s.clone()),
247            IKun::Export(s, body) => IKunTree::Export(s.clone(), Box::new(self.extract(*body))),
248            IKun::Module(m, items) => IKunTree::Module(
249                m.clone(),
250                items.iter().map(|&id| self.extract(id)).collect(),
251            ),
252            IKun::StaticAccess(a, s) => IKunTree::StaticAccess(Box::new(self.extract(*a)), *s),
253            IKun::WitnessAccess(a, b, s) => {
254                IKunTree::WitnessAccess(Box::new(self.extract(*a)), Box::new(self.extract(*b)), *s)
255            }
256            IKun::DynamicAccess(a, s, call_site_id) => {
257                IKunTree::DynamicAccess(Box::new(self.extract(*a)), s.clone(), *call_site_id)
258            }
259            IKun::StaticCall(a, ids) => IKunTree::StaticCall(
260                Box::new(self.extract(*a)),
261                ids.iter().map(|&id| self.extract(id)).collect(),
262            ),
263            IKun::WitnessCall(a, b, s, ids) => IKunTree::WitnessCall(
264                Box::new(self.extract(*a)),
265                Box::new(self.extract(*b)),
266                *s,
267                ids.iter().map(|&id| self.extract(id)).collect(),
268            ),
269            IKun::DynamicCall(a, s, ids, call_site_id) => IKunTree::DynamicCall(
270                Box::new(self.extract(*a)),
271                s.clone(),
272                ids.iter().map(|&id| self.extract(id)).collect(),
273                *call_site_id,
274            ),
275        };
276
277        if let Some(loc) = self.egraph.get_class(root).data.get_locs().first() {
278            tree = IKunTree::Source(*loc, Box::new(tree));
279        }
280        tree
281    }
282
283    pub fn get_best_node(&self, id: Id) -> Option<&IKun> {
284        let root = self.egraph.union_find.find(id);
285        self.costs.get(&root).map(|(_, node)| node)
286    }
287
288    pub fn get_best_cost(&self, id: Id) -> Option<Cost> {
289        let root = self.egraph.union_find.find(id);
290        self.costs.get(&root).map(|(cost, _)| *cost)
291    }
292}
293
294/// A DAG-aware extractor that attempts to find the optimal extraction
295/// considering node sharing. This is a placeholder for a full ILP solver.
296pub struct IKunIlpExtractor<'a, A: Analysis<IKun>, C: CostModel> {
297    pub egraph: &'a EGraph<IKun, A>,
298    pub cost_model: C,
299}
300
301impl<'a, A: Analysis<IKun>, C: CostModel + Clone> IKunIlpExtractor<'a, A, C> {
302    pub fn new(egraph: &'a EGraph<IKun, A>, cost_model: C) -> Self {
303        Self { egraph, cost_model }
304    }
305
306    /// Performs extraction. Currently uses the greedy extractor as a fallback.
307    /// In a full implementation, this would use an ILP solver to find the
308    /// minimum cost subgraph that covers the root e-class.
309    pub fn extract(&self, id: Id) -> IKunTree
310    where
311        A::Data: chomsky_uir::egraph::HasDebugInfo,
312    {
313        let extractor = IKunExtractor::new(self.egraph, self.cost_model.clone());
314        extractor.extract(id)
315    }
316}