1#![warn(missing_docs)]
2
3use chomsky_cost::{Cost, CostModel};
4use chomsky_uir::egraph::{Analysis, EGraph, HasDebugInfo, Id, Language};
5pub use chomsky_uir::{IKun, IKunTree};
6use std::collections::HashMap;
7
8use chomsky_types::ChomskyResult;
9use chomsky_uir::intent::CrossLanguageCall;
10
/// Output returned by [`Backend::generate`] after lowering an [`IKunTree`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BackendArtifact {
    /// Generated source code, held as text.
    Source(String),
    /// Raw binary output bytes.
    Binary(Vec<u8>),
    /// Assembly listing, held as text.
    Assembly(String),
    /// Several byte blobs keyed by a string (presumably a file or section
    /// name — confirm against the backends that produce it).
    Collection(HashMap<String, Vec<u8>>),
}
18
19pub trait Backend {
22 fn name(&self) -> &str;
23 fn generate(&self, tree: &IKunTree) -> ChomskyResult<BackendArtifact>;
24
25 fn get_model(&self) -> &dyn CostModel {
27 &chomsky_cost::DEFAULT_COST_MODEL
28 }
29}
30
31pub struct IKunExtractor<'a, A, C>
32where
33 A: Analysis<IKun>,
34 A::Data: HasDebugInfo,
35 C: CostModel,
36{
37 pub egraph: &'a EGraph<IKun, A>,
38 pub cost_model: C,
39 pub costs: HashMap<Id, (Cost, IKun)>,
40}
41
42impl<'a, A, C> IKunExtractor<'a, A, C>
43where
44 A: Analysis<IKun>,
45 A::Data: HasDebugInfo,
46 C: CostModel,
47{
48 pub fn new(egraph: &'a EGraph<IKun, A>, cost_model: C) -> Self {
49 let mut extractor = Self {
50 egraph,
51 cost_model,
52 costs: HashMap::new(),
53 };
54 extractor.find_best();
55 extractor
56 }
57
58 fn find_best(&mut self) {
59 let mut changed = true;
60 while changed {
61 changed = false;
62 for entry in self.egraph.classes.iter() {
63 let id = *entry.key();
64 let eclass = entry.value();
65
66 for node in &eclass.nodes {
67 let mut node_cost = self.cost_model.cost(node);
68 let mut can_compute = true;
69
70 for child in node.children() {
71 let child_id = self.egraph.union_find.find(child);
72 if let Some((child_cost, _)) = self.costs.get(&child_id) {
73 node_cost = node_cost.add(child_cost);
74 } else {
75 can_compute = false;
76 break;
77 }
78 }
79
80 if can_compute {
81 let current_best = self.costs.get(&id);
82
83 if current_best.is_none()
84 || node_cost.score() < current_best.unwrap().0.score()
85 {
86 self.costs.insert(id, (node_cost, node.clone()));
87 changed = true;
88 }
89 }
90 }
91 }
92 }
93 }
94
95 pub fn extract(&self, id: Id) -> IKunTree {
96 let root = self.egraph.union_find.find(id);
97 let (_, node) = self.costs.get(&root).expect("No cost found for eclass");
98
99 let mut tree = match node {
100 IKun::Constant(v) => IKunTree::Constant(*v),
101 IKun::FloatConstant(v) => IKunTree::FloatConstant(*v),
102 IKun::BooleanConstant(v) => IKunTree::BooleanConstant(*v),
103 IKun::StringConstant(s) => IKunTree::StringConstant(s.clone()),
104 IKun::Symbol(s) => IKunTree::Symbol(s.clone()),
105 IKun::Map(f, x) => {
106 IKunTree::Map(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
107 }
108 IKun::Filter(f, x) => {
109 IKunTree::Filter(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
110 }
111 IKun::Reduce(f, init, list) => IKunTree::Reduce(
112 Box::new(self.extract(*f)),
113 Box::new(self.extract(*init)),
114 Box::new(self.extract(*list)),
115 ),
116 IKun::StateUpdate(var, val) => {
117 IKunTree::StateUpdate(Box::new(self.extract(*var)), Box::new(self.extract(*val)))
118 }
119 IKun::Choice(cond, t, f) => IKunTree::Choice(
120 Box::new(self.extract(*cond)),
121 Box::new(self.extract(*t)),
122 Box::new(self.extract(*f)),
123 ),
124 IKun::Repeat(cond, body) => {
125 IKunTree::Repeat(Box::new(self.extract(*cond)), Box::new(self.extract(*body)))
126 }
127 IKun::LifeCycle(setup, cleanup) => IKunTree::LifeCycle(
128 Box::new(self.extract(*setup)),
129 Box::new(self.extract(*cleanup)),
130 ),
131 IKun::Meta(body) => IKunTree::Meta(Box::new(self.extract(*body))),
132 IKun::Trap(body) => IKunTree::Trap(Box::new(self.extract(*body))),
133 IKun::Return(val) => IKunTree::Return(Box::new(self.extract(*val))),
134 IKun::Seq(ids) => IKunTree::Seq(ids.iter().map(|&id| self.extract(id)).collect()),
135 IKun::Compose(a, b) => {
136 IKunTree::Compose(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
137 }
138 IKun::WithContext(ctx, body) => {
139 IKunTree::WithContext(Box::new(self.extract(*ctx)), Box::new(self.extract(*body)))
140 }
141 IKun::WithConstraint(constraint, body) => IKunTree::WithConstraint(
142 Box::new(self.extract(*constraint)),
143 Box::new(self.extract(*body)),
144 ),
145 IKun::CpuContext => IKunTree::CpuContext,
146 IKun::GpuContext => IKunTree::GpuContext,
147 IKun::AsyncContext => IKunTree::AsyncContext,
148 IKun::SpatialContext => IKunTree::SpatialContext,
149 IKun::ComptimeContext => IKunTree::ComptimeContext,
150 IKun::ResourceContext => IKunTree::ResourceContext,
151 IKun::SafeContext => IKunTree::SafeContext,
152 IKun::EffectConstraint(e) => IKunTree::EffectConstraint(e.clone()),
153 IKun::OwnershipConstraint(o) => IKunTree::OwnershipConstraint(o.clone()),
154 IKun::TypeConstraint(t) => IKunTree::TypeConstraint(t.clone()),
155 IKun::AtomicConstraint => IKunTree::AtomicConstraint,
156 IKun::Extension(name, args) => IKunTree::Extension(
157 name.clone(),
158 args.iter().map(|&id| self.extract(id)).collect(),
159 ),
160 IKun::CrossLangCall(CrossLanguageCall {
161 language: lang,
162 module_path: group,
163 function_name: func,
164 arguments: args,
165 }) => IKunTree::CrossLangCall {
166 language: lang.clone(),
167 module_path: group.clone(),
168 function_name: func.clone(),
169 arguments: args.iter().map(|&id| self.extract(id)).collect(),
170 },
171 IKun::GpuMap(f, x) => {
172 IKunTree::GpuMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
173 }
174 IKun::CpuMap(f, x) => {
175 IKunTree::CpuMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
176 }
177 IKun::TiledMap(size, f, x) => IKunTree::TiledMap(
178 *size,
179 Box::new(self.extract(*f)),
180 Box::new(self.extract(*x)),
181 ),
182 IKun::VectorizedMap(width, f, x) => IKunTree::VectorizedMap(
183 *width,
184 Box::new(self.extract(*f)),
185 Box::new(self.extract(*x)),
186 ),
187 IKun::UnrolledMap(factor, f, x) => IKunTree::UnrolledMap(
188 *factor,
189 Box::new(self.extract(*f)),
190 Box::new(self.extract(*x)),
191 ),
192 IKun::SoAMap(f, x) => {
193 IKunTree::SoAMap(Box::new(self.extract(*f)), Box::new(self.extract(*x)))
194 }
195 IKun::SoALayout(x) => IKunTree::SoALayout(Box::new(self.extract(*x))),
196 IKun::AoSLayout(x) => IKunTree::AoSLayout(Box::new(self.extract(*x))),
197 IKun::Tiled(size, x) => IKunTree::Tiled(*size, Box::new(self.extract(*x))),
198 IKun::Unrolled(factor, x) => IKunTree::Unrolled(*factor, Box::new(self.extract(*x))),
199 IKun::Vectorized(width, x) => IKunTree::Vectorized(*width, Box::new(self.extract(*x))),
200 IKun::Pipe(a, b) => {
201 IKunTree::Pipe(Box::new(self.extract(*a)), Box::new(self.extract(*b)))
202 }
203 IKun::Reg(x) => IKunTree::Reg(Box::new(self.extract(*x))),
204 IKun::Lambda(params, body) => {
205 IKunTree::Lambda(params.clone(), Box::new(self.extract(*body)))
206 }
207 IKun::Apply(func, args) => IKunTree::Apply(
208 Box::new(self.extract(*func)),
209 args.iter().map(|&id| self.extract(id)).collect(),
210 ),
211 IKun::Closure(body, captured) => IKunTree::Closure(
212 Box::new(self.extract(*body)),
213 captured.iter().map(|&id| self.extract(id)).collect(),
214 ),
215 IKun::ResourceClone(x) => IKunTree::ResourceClone(Box::new(self.extract(*x))),
216 IKun::ResourceDrop(x) => IKunTree::ResourceDrop(Box::new(self.extract(*x))),
217 IKun::Import(m, s) => IKunTree::Import(m.clone(), s.clone()),
218 IKun::Export(s, body) => IKunTree::Export(s.clone(), Box::new(self.extract(*body))),
219 IKun::Module(m, items) => IKunTree::Module(
220 m.clone(),
221 items.iter().map(|&id| self.extract(id)).collect(),
222 ),
223 };
224
225 if let Some(loc) = self.egraph.get_class(root).data.get_locs().first() {
226 tree = IKunTree::Source(*loc, Box::new(tree));
227 }
228 tree
229 }
230
231 pub fn get_best_node(&self, id: Id) -> Option<&IKun> {
232 let root = self.egraph.union_find.find(id);
233 self.costs.get(&root).map(|(_, node)| node)
234 }
235
236 pub fn get_best_cost(&self, id: Id) -> Option<Cost> {
237 let root = self.egraph.union_find.find(id);
238 self.costs.get(&root).map(|(cost, _)| *cost)
239 }
240}
241
242pub struct IKunIlpExtractor<'a, A: Analysis<IKun>, C: CostModel> {
245 pub egraph: &'a EGraph<IKun, A>,
246 pub cost_model: C,
247}
248
249impl<'a, A: Analysis<IKun>, C: CostModel + Clone> IKunIlpExtractor<'a, A, C> {
250 pub fn new(egraph: &'a EGraph<IKun, A>, cost_model: C) -> Self {
251 Self { egraph, cost_model }
252 }
253
254 pub fn extract(&self, id: Id) -> IKunTree
258 where
259 A::Data: chomsky_uir::egraph::HasDebugInfo,
260 {
261 let extractor = IKunExtractor::new(self.egraph, self.cost_model.clone());
262 extractor.extract(id)
263 }
264}