Skip to main content

chomsky_extract/
lib.rs

1#![warn(missing_docs)]
2
3use chomsky_cost::{Cost, CostModel};
4use chomsky_uir::egraph::{Analysis, EGraph, HasDebugInfo, Id, Language};
5pub use chomsky_uir::{IKun, IKunTree};
6use std::collections::HashMap;
7
8use chomsky_types::ChomskyResult;
9use chomsky_uir::intent::CrossLanguageCall;
10
/// The output a [`Backend`] hands back from [`Backend::generate`].
///
/// Each variant only fixes how the payload is stored; what a given backend
/// actually puts inside is up to that backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendArtifact {
    /// Output held as text and tagged as source code.
    Source(String),
    /// Output held as raw bytes.
    Binary(Vec<u8>),
    /// Output held as text and tagged as assembly.
    Assembly(String),
    /// Several byte blobs addressed by a string key.
    ///
    /// NOTE(review): the key is presumably an output name or path; confirm
    /// against the backends that emit this variant.
    Collection(HashMap<String, Vec<u8>>),
}
18
19/// The standard trait for all backends.
20/// Backends consume an extracted IKunTree and produce an artifact.
21pub trait Backend {
22    fn name(&self) -> &str;
23    fn generate(&self, tree: &IKunTree) -> ChomskyResult<BackendArtifact>;
24
25    /// 获取该后端的成本模型
26    fn get_model(&self) -> &dyn CostModel {
27        &chomsky_cost::DEFAULT_COST_MODEL
28    }
29}
30
31pub struct IKunExtractor<'a, A, C>
32where
33    A: Analysis<IKun>,
34    A::Data: HasDebugInfo,
35    C: CostModel,
36{
37    pub egraph: &'a EGraph<IKun, A>,
38    pub cost_model: C,
39    pub costs: HashMap<Id, (Cost, IKun)>,
40}
41
42impl<'a, A, C> IKunExtractor<'a, A, C>
43where
44    A: Analysis<IKun>,
45    A::Data: HasDebugInfo,
46    C: CostModel,
47{
48    pub fn new(egraph: &'a EGraph<IKun, A>, cost_model: C) -> Self {
49        let mut extractor = Self {
50            egraph,
51            cost_model,
52            costs: HashMap::new(),
53        };
54        extractor.find_best();
55        extractor
56    }
57
58    fn find_best(&mut self) {
59        let mut changed = true;
60        while changed {
61            changed = false;
62            for entry in self.egraph.classes.iter() {
63                let id = *entry.key();
64                let eclass = entry.value();
65
66                for node in &eclass.nodes {
67                    let mut node_cost = self.cost_model.cost(node);
68                    let mut can_compute = true;
69
70                    for child in node.children() {
71                        let child_id = self.egraph.union_find.find(child);
72                        if let Some((child_cost, _)) = self.costs.get(&child_id) {
73                            node_cost = node_cost.add(child_cost);
74                        } else {
75                            can_compute = false;
76                            break;
77                        }
78                    }
79
80                    if can_compute {
81                        let current_best = self.costs.get(&id);
82
83                        if current_best.is_none()
84                            || node_cost.score() < current_best.unwrap().0.score()
85                        {
86                            self.costs.insert(id, (node_cost, node.clone()));
87                            changed = true;
88                        }
89                    }
90                }
91            }
92        }
93    }
94
95    pub fn extract(&self, id: Id) -> IKunTree {
96        let root = self.egraph.union_find.find(id);
97        let (_, node) = self.costs.get(&root).expect("No cost found for eclass");
98
99        let mut tree = match node {
100            IKun::Constant(v) => IKunTree::Constant(*v),
101            IKun::FloatConstant(v) => IKunTree::FloatConstant(*v),
102            IKun::BooleanConstant(v) => IKunTree::BooleanConstant(*v),
103            IKun::StringConstant(s) => IKunTree::StringConstant(s.clone()),
104            IKun::Symbol(s) => IKunTree::Symbol(s.clone()),
105            IKun::Map(f, x) => {
106                IKunTree::Map(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
107            }
108            IKun::Filter(f, x) => {
109                IKunTree::Filter(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
110            }
111            IKun::Reduce(f, init, list) => IKunTree::Reduce(
112                Box::new(self.extract(*f)),
113                Box::new(self.extract(*init)),
114                Box::new(self.extract(*list)),
115            ),
116            IKun::StateUpdate(var, val) => {
117                IKunTree::StateUpdate(Box::new(self.extract(*var)), Box::new(self.extract(*val)))
118            }
119            IKun::Choice(cond, t, f) => IKunTree::Choice(
120                Box::new(self.extract(*cond)),
121                Box::new(self.extract(*t)),
122                Box::new(self.extract(*f)),
123            ),
124            IKun::Repeat(cond, body) => {
125                IKunTree::Repeat(Box::new(self.extract(*cond)), Box::new(self.extract(*body)))
126            }
127            IKun::LifeCycle(setup, cleanup) => IKunTree::LifeCycle(
128                Box::new(self.extract(*setup)),
129                Box::new(self.extract(*cleanup)),
130            ),
131            IKun::Meta(body) => IKunTree::Meta(Box::new(self.extract(*body))),
132            IKun::Trap(body) => IKunTree::Trap(Box::new(self.extract(*body))),
133            IKun::Return(val) => IKunTree::Return(Box::new(self.extract(*val))),
134            IKun::Seq(ids) => IKunTree::Seq(ids.iter().map(|&id| self.extract(id)).collect()),
135            IKun::Compose(a, b) => {
136                IKunTree::Compose(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
137            }
138            IKun::WithContext(ctx, body) => {
139                IKunTree::WithContext(Box::new(self.extract(*ctx)), Box::new(self.extract(*body)))
140            }
141            IKun::WithConstraint(constraint, body) => IKunTree::WithConstraint(
142                Box::new(self.extract(*constraint)),
143                Box::new(self.extract(*body)),
144            ),
145            IKun::CpuContext => IKunTree::CpuContext,
146            IKun::GpuContext => IKunTree::GpuContext,
147            IKun::AsyncContext => IKunTree::AsyncContext,
148            IKun::SpatialContext => IKunTree::SpatialContext,
149            IKun::ComptimeContext => IKunTree::ComptimeContext,
150            IKun::ResourceContext => IKunTree::ResourceContext,
151            IKun::SafeContext => IKunTree::SafeContext,
152            IKun::EffectConstraint(e) => IKunTree::EffectConstraint(e.clone()),
153            IKun::OwnershipConstraint(o) => IKunTree::OwnershipConstraint(o.clone()),
154            IKun::TypeConstraint(t) => IKunTree::TypeConstraint(t.clone()),
155            IKun::AtomicConstraint => IKunTree::AtomicConstraint,
156            IKun::Extension(name, args) => IKunTree::Extension(
157                name.clone(),
158                args.iter().map(|&id| self.extract(id)).collect(),
159            ),
160            IKun::CrossLangCall(CrossLanguageCall {
161                language: lang,
162                module_path: group,
163                function_name: func,
164                arguments: args,
165            }) => IKunTree::CrossLangCall {
166                language: lang.clone(),
167                module_path: group.clone(),
168                function_name: func.clone(),
169                arguments: args.iter().map(|&id| self.extract(id)).collect(),
170            },
171            IKun::GpuMap(f, x) => {
172                IKunTree::GpuMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
173            }
174            IKun::CpuMap(f, x) => {
175                IKunTree::CpuMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
176            }
177            IKun::TiledMap(size, f, x) => IKunTree::TiledMap(
178                *size,
179                Box::new(self.extract(*f)),
180                Box::new(self.extract(*x)),
181            ),
182            IKun::VectorizedMap(width, f, x) => IKunTree::VectorizedMap(
183                *width,
184                Box::new(self.extract(*f)),
185                Box::new(self.extract(*x)),
186            ),
187            IKun::UnrolledMap(factor, f, x) => IKunTree::UnrolledMap(
188                *factor,
189                Box::new(self.extract(*f)),
190                Box::new(self.extract(*x)),
191            ),
192            IKun::SoAMap(f, x) => {
193                IKunTree::SoAMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
194            }
195            IKun::SoALayout(x) => IKunTree::SoALayout(Box::new(self.extract(*x))),
196            IKun::AoSLayout(x) => IKunTree::AoSLayout(Box::new(self.extract(*x))),
197            IKun::Tiled(size, x) => IKunTree::Tiled(*size, Box::new(self.extract(*x))),
198            IKun::Unrolled(factor, x) => IKunTree::Unrolled(*factor, Box::new(self.extract(*x))),
199            IKun::Vectorized(width, x) => IKunTree::Vectorized(*width, Box::new(self.extract(*x))),
200            IKun::Pipe(a, b) => {
201                IKunTree::Pipe(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
202            }
203            IKun::Reg(x) => IKunTree::Reg(Box::new(self.extract(*x))),
204            IKun::Lambda(params, body) => {
205                IKunTree::Lambda(params.clone(), Box::new(self.extract(*body)))
206            }
207            IKun::Apply(func, args) => IKunTree::Apply(
208                Box::new(self.extract(*func)),
209                args.iter().map(|&id| self.extract(id)).collect(),
210            ),
211            IKun::Closure(body, captured) => IKunTree::Closure(
212                Box::new(self.extract(*body)),
213                captured.iter().map(|&id| self.extract(id)).collect(),
214            ),
215            IKun::ResourceClone(x) => IKunTree::ResourceClone(Box::new(self.extract(*x))),
216            IKun::ResourceDrop(x) => IKunTree::ResourceDrop(Box::new(self.extract(*x))),
217            IKun::Import(m, s) => IKunTree::Import(m.clone(), s.clone()),
218            IKun::Export(s, body) => IKunTree::Export(s.clone(), Box::new(self.extract(*body))),
219            IKun::Module(m, items) => IKunTree::Module(
220                m.clone(),
221                items.iter().map(|&id| self.extract(id)).collect(),
222            ),
223        };
224
225        if let Some(loc) = self.egraph.get_class(root).data.get_locs().first() {
226            tree = IKunTree::Source(*loc, Box::new(tree));
227        }
228        tree
229    }
230
231    pub fn get_best_node(&self, id: Id) -> Option<&IKun> {
232        let root = self.egraph.union_find.find(id);
233        self.costs.get(&root).map(|(_, node)| node)
234    }
235
236    pub fn get_best_cost(&self, id: Id) -> Option<Cost> {
237        let root = self.egraph.union_find.find(id);
238        self.costs.get(&root).map(|(cost, _)| *cost)
239    }
240}
241
242/// A DAG-aware extractor that attempts to find the optimal extraction
243/// considering node sharing. This is a placeholder for a full ILP solver.
244pub struct IKunIlpExtractor<'a, A: Analysis<IKun>, C: CostModel> {
245    pub egraph: &'a EGraph<IKun, A>,
246    pub cost_model: C,
247}
248
249impl<'a, A: Analysis<IKun>, C: CostModel + Clone> IKunIlpExtractor<'a, A, C> {
250    pub fn new(egraph: &'a EGraph<IKun, A>, cost_model: C) -> Self {
251        Self { egraph, cost_model }
252    }
253
254    /// Performs extraction. Currently uses the greedy extractor as a fallback.
255    /// In a full implementation, this would use an ILP solver to find the
256    /// minimum cost subgraph that covers the root e-class.
257    pub fn extract(&self, id: Id) -> IKunTree
258    where
259        A::Data: chomsky_uir::egraph::HasDebugInfo,
260    {
261        let extractor = IKunExtractor::new(self.egraph, self.cost_model.clone());
262        extractor.extract(id)
263    }
264}