1#![warn(missing_docs)]
2
3use chomsky_cost::{Cost, CostModel};
4use chomsky_uir::egraph::{Analysis, EGraph, HasDebugInfo, Id, Language};
5pub use chomsky_uir::{IKun, IKunTree};
6use std::collections::HashMap;
7
8use chomsky_types::ChomskyResult;
9use chomsky_uir::intent::CrossLanguageCall;
10
/// Output produced by a [`Backend`] after lowering an [`IKunTree`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BackendArtifact {
    /// Generated source text.
    Source(String),
    /// Raw binary output.
    Binary(Vec<u8>),
    /// Generated assembly text.
    Assembly(String),
    /// Several named binary outputs, keyed by name.
    Collection(HashMap<String, Vec<u8>>),
}
18
19pub trait Backend {
22 fn name(&self) -> &str;
23 fn generate(&self, tree: &IKunTree) -> ChomskyResult<BackendArtifact>;
24
25 fn get_model(&self) -> &dyn CostModel {
27 &chomsky_cost::DEFAULT_COST_MODEL
28 }
29}
30
31pub struct IKunExtractor<'a, A, C>
32where
33 A: Analysis<IKun>,
34 A::Data: HasDebugInfo,
35 C: CostModel,
36{
37 pub egraph: &'a EGraph<IKun, A>,
38 pub cost_model: C,
39 pub costs: HashMap<Id, (Cost, IKun)>,
40}
41
42impl<'a, A, C> IKunExtractor<'a, A, C>
43where
44 A: Analysis<IKun>,
45 A::Data: HasDebugInfo,
46 C: CostModel,
47{
48 pub fn new(egraph: &'a EGraph<IKun, A>, cost_model: C) -> Self {
49 let mut extractor = Self {
50 egraph,
51 cost_model,
52 costs: HashMap::new(),
53 };
54 extractor.find_best();
55 extractor
56 }
57
58 fn find_best(&mut self) {
59 let mut changed = true;
60 while changed {
61 changed = false;
62 for entry in self.egraph.classes.iter() {
63 let id = *entry.key();
64 let eclass = entry.value();
65
66 for node in &eclass.nodes {
67 let mut node_cost = self.cost_model.cost(node);
68 let mut can_compute = true;
69
70 for child in node.children() {
71 let child_id = self.egraph.union_find.find(child);
72 if let Some((child_cost, _)) = self.costs.get(&child_id) {
73 node_cost = node_cost.add(child_cost);
74 } else {
75 can_compute = false;
76 break;
77 }
78 }
79
80 if can_compute {
81 let current_best = self.costs.get(&id);
82
83 if current_best.is_none()
84 || node_cost.score() < current_best.unwrap().0.score()
85 {
86 self.costs.insert(id, (node_cost, node.clone()));
87 changed = true;
88 }
89 }
90 }
91 }
92 }
93 }
94
95 pub fn extract(&self, id: Id) -> IKunTree {
96 let root = self.egraph.union_find.find(id);
97 let (_, node) = self.costs.get(&root).expect("No cost found for eclass");
98
99 let mut tree = match node {
100 IKun::Constant(v) => IKunTree::Constant(*v),
101 IKun::FloatConstant(v) => IKunTree::FloatConstant(*v),
102 IKun::BooleanConstant(v) => IKunTree::BooleanConstant(*v),
103 IKun::StringConstant(s) => IKunTree::StringConstant(s.clone()),
104 IKun::None => IKunTree::None,
105 IKun::Symbol(s) => IKunTree::Symbol(s.clone()),
106 IKun::Map(f, x) => {
107 IKunTree::Map(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
108 }
109 IKun::Filter(f, x) => {
110 IKunTree::Filter(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
111 }
112 IKun::Reduce(f, init, list) => IKunTree::Reduce(
113 Box::new(self.extract(*f)),
114 Box::new(self.extract(*init)),
115 Box::new(self.extract(*list)),
116 ),
117 IKun::StateUpdate(var, val) => {
118 IKunTree::StateUpdate(Box::new(self.extract(*var)), Box::new(self.extract(*val)))
119 }
120 IKun::Choice(cond, t, f) => IKunTree::Choice(
121 Box::new(self.extract(*cond)),
122 Box::new(self.extract(*t)),
123 Box::new(self.extract(*f)),
124 ),
125 IKun::Repeat(cond, body) => {
126 IKunTree::Repeat(Box::new(self.extract(*cond)), Box::new(self.extract(*body)))
127 }
128 IKun::LifeCycle(setup, cleanup) => IKunTree::LifeCycle(
129 Box::new(self.extract(*setup)),
130 Box::new(self.extract(*cleanup)),
131 ),
132 IKun::Meta(body) => IKunTree::Meta(Box::new(self.extract(*body))),
133 IKun::Trap(body) => IKunTree::Trap(Box::new(self.extract(*body))),
134 IKun::Return(val) => IKunTree::Return(Box::new(self.extract(*val))),
135 IKun::Seq(ids) => IKunTree::Seq(ids.iter().map(|&id| self.extract(id)).collect()),
136 IKun::Compose(a, b) => {
137 IKunTree::Compose(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
138 }
139 IKun::WithContext(ctx, body) => {
140 IKunTree::WithContext(Box::new(self.extract(*ctx)), Box::new(self.extract(*body)))
141 }
142 IKun::WithConstraint(constraint, body) => IKunTree::WithConstraint(
143 Box::new(self.extract(*constraint)),
144 Box::new(self.extract(*body)),
145 ),
146 IKun::CpuContext => IKunTree::CpuContext,
147 IKun::GpuContext => IKunTree::GpuContext,
148 IKun::AsyncContext => IKunTree::AsyncContext,
149 IKun::SpatialContext => IKunTree::SpatialContext,
150 IKun::ComptimeContext => IKunTree::ComptimeContext,
151 IKun::ResourceContext => IKunTree::ResourceContext,
152 IKun::SafeContext => IKunTree::SafeContext,
153 IKun::EffectConstraint(e) => IKunTree::EffectConstraint(e.clone()),
154 IKun::OwnershipConstraint(o) => IKunTree::OwnershipConstraint(o.clone()),
155 IKun::TypeConstraint(t) => IKunTree::TypeConstraint(t.clone()),
156 IKun::AtomicConstraint => IKunTree::AtomicConstraint,
157 IKun::Extension(name, args) => IKunTree::Extension(
158 name.clone(),
159 args.iter().map(|&id| self.extract(id)).collect(),
160 ),
161 IKun::CrossLangCall(CrossLanguageCall {
162 language: lang,
163 module_path: group,
164 function_name: func,
165 arguments: args,
166 }) => IKunTree::CrossLangCall {
167 language: lang.clone(),
168 module_path: group.clone(),
169 function_name: func.clone(),
170 arguments: args.iter().map(|&id| self.extract(id)).collect(),
171 },
172 IKun::GpuMap(f, x) => {
173 IKunTree::GpuMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
174 }
175 IKun::CpuMap(f, x) => {
176 IKunTree::CpuMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
177 }
178 IKun::TiledMap(size, f, x) => IKunTree::TiledMap(
179 *size,
180 Box::new(self.extract(*f)),
181 Box::new(self.extract(*x)),
182 ),
183 IKun::VectorizedMap(width, f, x) => IKunTree::VectorizedMap(
184 *width,
185 Box::new(self.extract(*f)),
186 Box::new(self.extract(*x)),
187 ),
188 IKun::UnrolledMap(factor, f, x) => IKunTree::UnrolledMap(
189 *factor,
190 Box::new(self.extract(*f)),
191 Box::new(self.extract(*x)),
192 ),
193 IKun::SoAMap(f, x) => {
194 IKunTree::SoAMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
195 }
196 IKun::SoALayout(x) => IKunTree::SoALayout(Box::new(self.extract(*x))),
197 IKun::AoSLayout(x) => IKunTree::AoSLayout(Box::new(self.extract(*x))),
198 IKun::Tiled(size, x) => IKunTree::Tiled(*size, Box::new(self.extract(*x))),
199 IKun::Unrolled(factor, x) => IKunTree::Unrolled(*factor, Box::new(self.extract(*x))),
200 IKun::Vectorized(width, x) => IKunTree::Vectorized(*width, Box::new(self.extract(*x))),
201 IKun::Pipe(a, b) => {
202 IKunTree::Pipe(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
203 }
204 IKun::Reg(x) => IKunTree::Reg(Box::new(self.extract(*x))),
205 IKun::Lambda(params, body) => {
206 IKunTree::Lambda(params.clone(), Box::new(self.extract(*body)))
207 }
208 IKun::Apply(func, args) => IKunTree::Apply(
209 Box::new(self.extract(*func)),
210 args.iter().map(|&id| self.extract(id)).collect(),
211 ),
212 IKun::Closure(body, captured) => IKunTree::Closure(
213 Box::new(self.extract(*body)),
214 captured.iter().map(|&id| self.extract(id)).collect(),
215 ),
216 IKun::ResourceClone(x) => IKunTree::ResourceClone(Box::new(self.extract(*x))),
217 IKun::ResourceDrop(x) => IKunTree::ResourceDrop(Box::new(self.extract(*x))),
218 IKun::AddrOf(a) => IKunTree::AddrOf(Box::new(self.extract(*a))),
219 IKun::Deref(a) => IKunTree::Deref(Box::new(self.extract(*a))),
220 IKun::PtrOffset(a, b) => {
221 IKunTree::PtrOffset(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
222 }
223 IKun::ClassDef(name, bases, body) => IKunTree::ClassDef(
224 name.clone(),
225 bases.iter().map(|&id| self.extract(id)).collect(),
226 Box::new(self.extract(*body)),
227 ),
228 IKun::Table(ids) => IKunTree::Table(ids.iter().map(|&id| self.extract(id)).collect()),
229 IKun::Pair(a, b) => {
230 IKunTree::Pair(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
231 }
232 IKun::GetIndex(a, b) => {
233 IKunTree::GetIndex(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
234 }
235 IKun::SetIndex(a, b, c) => IKunTree::SetIndex(
236 Box::new(self.extract(*a)),
237 Box::new(self.extract(*b)),
238 Box::new(self.extract(*c)),
239 ),
240 IKun::BinaryOp(op, a, b) => IKunTree::BinaryOp(
241 op.clone(),
242 Box::new(self.extract(*a)),
243 Box::new(self.extract(*b)),
244 ),
245 IKun::UnaryOp(op, a) => IKunTree::UnaryOp(op.clone(), Box::new(self.extract(*a))),
246 IKun::Import(m, s) => IKunTree::Import(m.clone(), s.clone()),
247 IKun::Export(s, body) => IKunTree::Export(s.clone(), Box::new(self.extract(*body))),
248 IKun::Module(m, items) => IKunTree::Module(
249 m.clone(),
250 items.iter().map(|&id| self.extract(id)).collect(),
251 ),
252 IKun::StaticAccess(a, s) => IKunTree::StaticAccess(Box::new(self.extract(*a)), *s),
253 IKun::WitnessAccess(a, b, s) => {
254 IKunTree::WitnessAccess(Box::new(self.extract(*a)), Box::new(self.extract(*b)), *s)
255 }
256 IKun::DynamicAccess(a, s, call_site_id) => {
257 IKunTree::DynamicAccess(Box::new(self.extract(*a)), s.clone(), *call_site_id)
258 }
259 IKun::StaticCall(a, ids) => IKunTree::StaticCall(
260 Box::new(self.extract(*a)),
261 ids.iter().map(|&id| self.extract(id)).collect(),
262 ),
263 IKun::WitnessCall(a, b, s, ids) => IKunTree::WitnessCall(
264 Box::new(self.extract(*a)),
265 Box::new(self.extract(*b)),
266 *s,
267 ids.iter().map(|&id| self.extract(id)).collect(),
268 ),
269 IKun::DynamicCall(a, s, ids, call_site_id) => IKunTree::DynamicCall(
270 Box::new(self.extract(*a)),
271 s.clone(),
272 ids.iter().map(|&id| self.extract(id)).collect(),
273 *call_site_id,
274 ),
275 };
276
277 if let Some(loc) = self.egraph.get_class(root).data.get_locs().first() {
278 tree = IKunTree::Source(*loc, Box::new(tree));
279 }
280 tree
281 }
282
283 pub fn get_best_node(&self, id: Id) -> Option<&IKun> {
284 let root = self.egraph.union_find.find(id);
285 self.costs.get(&root).map(|(_, node)| node)
286 }
287
288 pub fn get_best_cost(&self, id: Id) -> Option<Cost> {
289 let root = self.egraph.union_find.find(id);
290 self.costs.get(&root).map(|(cost, _)| *cost)
291 }
292}
293
294pub struct IKunIlpExtractor<'a, A: Analysis<IKun>, C: CostModel> {
297 pub egraph: &'a EGraph<IKun, A>,
298 pub cost_model: C,
299}
300
301impl<'a, A: Analysis<IKun>, C: CostModel + Clone> IKunIlpExtractor<'a, A, C> {
302 pub fn new(egraph: &'a EGraph<IKun, A>, cost_model: C) -> Self {
303 Self { egraph, cost_model }
304 }
305
306 pub fn extract(&self, id: Id) -> IKunTree
310 where
311 A::Data: chomsky_uir::egraph::HasDebugInfo,
312 {
313 let extractor = IKunExtractor::new(self.egraph, self.cost_model.clone());
314 extractor.extract(id)
315 }
316}