1use std::collections::HashSet;
2
3use crate::control_ir::{
4 BlockId, NativeFunction, NativeModule, NativeStmt, NativeTerminator, ValueId,
5};
6
/// Closure-converted form of a [`NativeModule`].
///
/// Produced by [`closure_convert_module`], which converts the source module's
/// `functions` and `roots` element-wise, keeping their original order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeClosureModule {
    /// Converted counterparts of the source module's `functions`.
    pub functions: Vec<NativeClosureFunction>,
    /// Converted counterparts of the source module's `roots`.
    pub roots: Vec<NativeClosureFunction>,
}
12
/// A function after closure conversion.
///
/// Captured values are removed from the function's parameter list and from
/// every block's parameter list. They are described by `environment` instead,
/// and the entry (first) block begins with one [`NativeClosureStmt::LoadEnv`]
/// per environment slot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeClosureFunction {
    /// Name of the source function, copied unchanged.
    pub name: String,
    /// Explicit parameters: the source function's params minus its captures.
    pub params: Vec<ValueId>,
    /// Environment layout: one slot per captured value, in capture order.
    pub environment: NativeClosureEnvironment,
    /// Calling-convention summary (code reference, environment size, params).
    pub abi: NativeClosureAbi,
    /// Converted blocks, in the same order as the source function's blocks.
    pub blocks: Vec<NativeClosureBlock>,
}
21
/// A basic block after closure conversion.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeClosureBlock {
    /// Id of the source block, copied unchanged.
    pub id: BlockId,
    /// The source block's params with captured values filtered out.
    pub params: Vec<ValueId>,
    /// Converted statements; in the entry block these are preceded by the
    /// environment loads.
    pub stmts: Vec<NativeClosureStmt>,
    /// The source block's terminator, copied unchanged.
    pub terminator: NativeTerminator,
}
29
/// A statement in closure-converted IR.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NativeClosureStmt {
    /// Defines `dest` from environment slot `slot`. The conversion pass emits
    /// one of these per environment slot at the start of the entry block.
    LoadEnv {
        dest: ValueId,
        slot: usize,
    },
    /// Builds a closure value `dest` for the function named `target`.
    /// `environment` pairs each slot index with the value placed in it; the
    /// i-th capture of the source statement is assigned slot i.
    MakeClosure {
        dest: ValueId,
        target: String,
        environment: Vec<NativeClosureCapture>,
    },
    /// Calls the closure value `callee` with `args`, binding the result to
    /// `dest`; carried over field-for-field from `NativeStmt::ClosureCall`.
    ClosureCall {
        dest: ValueId,
        callee: ValueId,
        args: Vec<ValueId>,
    },
    /// Any other native statement, passed through unchanged.
    Native(NativeStmt),
}
48
/// Environment layout of a closure-converted function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeClosureEnvironment {
    /// One slot per captured value, listed in capture order; each slot's
    /// `index` equals its position in this vector.
    pub slots: Vec<NativeClosureSlot>,
}
53
/// Calling-convention summary for a closure-converted function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeClosureAbi {
    /// The function that implements the closure's code (the conversion pass
    /// always points this at the function itself).
    pub code: NativeClosureCodeRef,
    /// Size of the environment the function expects.
    pub environment: NativeClosureEnvRef,
    /// Explicit (non-captured) parameters, in declaration order.
    pub params: Vec<ValueId>,
}
60
/// Reference to the function implementing a closure's code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeClosureCodeRef {
    /// Name of the referenced function.
    pub function: String,
}
65
/// Description of the environment a closure-converted function expects.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeClosureEnvRef {
    /// Number of environment slots (one per captured value).
    pub slots: usize,
}
70
/// One slot of a function's environment, as seen from inside the function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NativeClosureSlot {
    /// Zero-based slot position.
    pub index: usize,
    /// The captured value that the entry block's `LoadEnv` defines from this
    /// slot.
    pub value: ValueId,
}
76
/// One environment entry supplied when building a closure with
/// [`NativeClosureStmt::MakeClosure`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NativeClosureCapture {
    /// Zero-based slot position in the target's environment.
    pub slot: usize,
    /// Value (from the function executing `MakeClosure`) assigned to `slot`.
    pub value: ValueId,
}
82
83pub fn closure_convert_module(module: &NativeModule) -> NativeClosureModule {
84 NativeClosureModule {
85 functions: module
86 .functions
87 .iter()
88 .map(closure_convert_function)
89 .collect(),
90 roots: module.roots.iter().map(closure_convert_function).collect(),
91 }
92}
93
94fn closure_convert_function(function: &NativeFunction) -> NativeClosureFunction {
95 let environment_slots = closure_environment_slots(function);
96 let capture_values = function.captures.iter().copied().collect::<HashSet<_>>();
97 let params = function
98 .params
99 .iter()
100 .copied()
101 .filter(|param| !capture_values.contains(param))
102 .collect();
103 let blocks = closure_convert_blocks(function, &environment_slots, &capture_values);
104 NativeClosureFunction {
105 name: function.name.clone(),
106 params,
107 abi: NativeClosureAbi {
108 code: NativeClosureCodeRef {
109 function: function.name.clone(),
110 },
111 environment: NativeClosureEnvRef {
112 slots: environment_slots.len(),
113 },
114 params: function
115 .params
116 .iter()
117 .copied()
118 .filter(|param| !capture_values.contains(param))
119 .collect(),
120 },
121 environment: NativeClosureEnvironment {
122 slots: environment_slots,
123 },
124 blocks,
125 }
126}
127
128fn closure_environment_slots(function: &NativeFunction) -> Vec<NativeClosureSlot> {
129 function
130 .captures
131 .iter()
132 .copied()
133 .enumerate()
134 .map(|(index, value)| NativeClosureSlot { index, value })
135 .collect()
136}
137
138fn closure_convert_blocks(
139 function: &NativeFunction,
140 environment_slots: &[NativeClosureSlot],
141 capture_values: &HashSet<ValueId>,
142) -> Vec<NativeClosureBlock> {
143 let entry = function.blocks.first().map(|block| block.id);
144 function
145 .blocks
146 .iter()
147 .map(|block| {
148 let mut stmts = Vec::new();
149 if Some(block.id) == entry {
150 stmts.extend(
151 environment_slots
152 .iter()
153 .map(|slot| NativeClosureStmt::LoadEnv {
154 dest: slot.value,
155 slot: slot.index,
156 }),
157 );
158 }
159 stmts.extend(block.stmts.iter().cloned().map(closure_convert_stmt));
160 NativeClosureBlock {
161 id: block.id,
162 params: block
163 .params
164 .iter()
165 .copied()
166 .filter(|param| !capture_values.contains(param))
167 .collect(),
168 stmts,
169 terminator: block.terminator.clone(),
170 }
171 })
172 .collect()
173}
174
175fn closure_convert_stmt(stmt: NativeStmt) -> NativeClosureStmt {
176 match stmt {
177 NativeStmt::MakeClosure {
178 dest,
179 target,
180 captures,
181 } => NativeClosureStmt::MakeClosure {
182 dest,
183 target,
184 environment: captures
185 .into_iter()
186 .enumerate()
187 .map(|(slot, value)| NativeClosureCapture { slot, value })
188 .collect(),
189 },
190 NativeStmt::ClosureCall { dest, callee, args } => {
191 NativeClosureStmt::ClosureCall { dest, callee, args }
192 }
193 stmt => NativeClosureStmt::Native(stmt),
194 }
195}
196
#[cfg(test)]
mod tests {
    use crate::control_ir::{
        BlockId, NativeBlock, NativeFunction, NativeLiteral, NativeModule, NativeStmt,
        NativeTerminator, ValueId,
    };

    use super::*;

    // A capture-free function converts to an empty-environment closure whose
    // params, ABI params and blocks match the source one-to-one.
    #[test]
    fn converts_first_order_function_to_empty_environment_closure() {
        let function = NativeFunction {
            name: "root".to_string(),
            captures: Vec::new(),
            params: vec![ValueId(0)],
            blocks: vec![NativeBlock {
                id: BlockId(0),
                params: vec![ValueId(0)],
                stmts: vec![NativeStmt::Literal {
                    dest: ValueId(1),
                    literal: NativeLiteral::Int("1".to_string()),
                }],
                terminator: NativeTerminator::Return(ValueId(1)),
            }],
        };
        let module = NativeModule {
            functions: Vec::new(),
            roots: vec![function.clone()],
        };

        let converted = closure_convert_module(&module);

        assert_eq!(
            converted.roots,
            vec![NativeClosureFunction {
                name: "root".to_string(),
                params: vec![ValueId(0)],
                abi: NativeClosureAbi {
                    code: NativeClosureCodeRef {
                        function: "root".to_string(),
                    },
                    environment: NativeClosureEnvRef { slots: 0 },
                    params: vec![ValueId(0)],
                },
                environment: NativeClosureEnvironment { slots: Vec::new() },
                blocks: function
                    .blocks
                    .into_iter()
                    .map(|block| NativeClosureBlock {
                        id: block.id,
                        params: block.params,
                        stmts: block.stmts.into_iter().map(closure_convert_stmt).collect(),
                        terminator: block.terminator,
                    })
                    .collect(),
            }]
        );
    }

    // Captured params move out of the param lists and into environment slots
    // that are loaded at the top of the entry block.
    #[test]
    fn separates_capture_params_into_environment_slots() {
        let function = NativeFunction {
            name: "root#lambda0".to_string(),
            captures: vec![ValueId(0)],
            params: vec![ValueId(0), ValueId(1)],
            blocks: vec![NativeBlock {
                id: BlockId(0),
                params: vec![ValueId(0), ValueId(1)],
                stmts: Vec::new(),
                terminator: NativeTerminator::Return(ValueId(1)),
            }],
        };
        let module = NativeModule {
            functions: vec![function.clone()],
            roots: Vec::new(),
        };

        let converted = closure_convert_module(&module);

        assert_eq!(converted.functions[0].params, vec![ValueId(1)]);
        assert_eq!(
            converted.functions[0].abi,
            NativeClosureAbi {
                code: NativeClosureCodeRef {
                    function: "root#lambda0".to_string(),
                },
                environment: NativeClosureEnvRef { slots: 1 },
                params: vec![ValueId(1)],
            }
        );
        assert_eq!(converted.functions[0].blocks[0].params, vec![ValueId(1)]);
        assert_eq!(
            converted.functions[0].environment,
            NativeClosureEnvironment {
                slots: vec![NativeClosureSlot {
                    index: 0,
                    value: ValueId(0),
                }]
            }
        );
        assert_eq!(
            converted.functions[0].blocks[0].stmts,
            vec![NativeClosureStmt::LoadEnv {
                dest: ValueId(0),
                slot: 0,
            }]
        );
        assert_eq!(
            converted.functions[0].blocks[0].terminator,
            NativeTerminator::Return(ValueId(1))
        );
    }

    // Environment loads are emitted only in the entry block; later blocks
    // keep their own params and converted statements untouched.
    #[test]
    fn loads_environment_only_in_entry_block() {
        let function = NativeFunction {
            name: "root#lambda1".to_string(),
            captures: vec![ValueId(0)],
            params: vec![ValueId(0), ValueId(1)],
            blocks: vec![
                NativeBlock {
                    id: BlockId(0),
                    params: vec![ValueId(0), ValueId(1)],
                    stmts: Vec::new(),
                    terminator: NativeTerminator::Return(ValueId(1)),
                },
                NativeBlock {
                    id: BlockId(1),
                    params: vec![ValueId(2)],
                    stmts: vec![NativeStmt::Literal {
                        dest: ValueId(3),
                        literal: NativeLiteral::Int("2".to_string()),
                    }],
                    terminator: NativeTerminator::Return(ValueId(3)),
                },
            ],
        };
        let module = NativeModule {
            functions: vec![function],
            roots: Vec::new(),
        };

        let converted = closure_convert_module(&module);
        let blocks = &converted.functions[0].blocks;

        assert_eq!(blocks.len(), 2);
        assert_eq!(
            blocks[0].stmts,
            vec![NativeClosureStmt::LoadEnv {
                dest: ValueId(0),
                slot: 0,
            }]
        );
        assert_eq!(blocks[1].params, vec![ValueId(2)]);
        assert_eq!(
            blocks[1].stmts,
            vec![NativeClosureStmt::Native(NativeStmt::Literal {
                dest: ValueId(3),
                literal: NativeLiteral::Int("2".to_string()),
            })]
        );
        assert_eq!(blocks[1].terminator, NativeTerminator::Return(ValueId(3)));
    }

    // MakeClosure captures are assigned environment slots positionally.
    #[test]
    fn converts_make_closure_to_environment_allocation() {
        let function = NativeFunction {
            name: "root".to_string(),
            captures: Vec::new(),
            params: Vec::new(),
            blocks: vec![NativeBlock {
                id: BlockId(0),
                params: Vec::new(),
                stmts: vec![NativeStmt::MakeClosure {
                    dest: ValueId(2),
                    target: "root#lambda0".to_string(),
                    captures: vec![ValueId(0), ValueId(1)],
                }],
                terminator: NativeTerminator::Return(ValueId(2)),
            }],
        };
        let module = NativeModule {
            functions: Vec::new(),
            roots: vec![function],
        };

        let converted = closure_convert_module(&module);

        assert_eq!(
            converted.roots[0].blocks[0].stmts,
            vec![NativeClosureStmt::MakeClosure {
                dest: ValueId(2),
                target: "root#lambda0".to_string(),
                environment: vec![
                    NativeClosureCapture {
                        slot: 0,
                        value: ValueId(0),
                    },
                    NativeClosureCapture {
                        slot: 1,
                        value: ValueId(1),
                    },
                ],
            }]
        );
    }
}