Skip to main content

yulang_native/
closure.rs

1use std::collections::HashSet;
2
3use crate::control_ir::{
4    BlockId, NativeFunction, NativeModule, NativeStmt, NativeTerminator, ValueId,
5};
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct NativeClosureModule {
9    pub functions: Vec<NativeClosureFunction>,
10    pub roots: Vec<NativeClosureFunction>,
11}
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct NativeClosureFunction {
15    pub name: String,
16    pub params: Vec<ValueId>,
17    pub environment: NativeClosureEnvironment,
18    pub abi: NativeClosureAbi,
19    pub blocks: Vec<NativeClosureBlock>,
20}
21
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct NativeClosureBlock {
24    pub id: BlockId,
25    pub params: Vec<ValueId>,
26    pub stmts: Vec<NativeClosureStmt>,
27    pub terminator: NativeTerminator,
28}
29
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub enum NativeClosureStmt {
32    LoadEnv {
33        dest: ValueId,
34        slot: usize,
35    },
36    MakeClosure {
37        dest: ValueId,
38        target: String,
39        environment: Vec<NativeClosureCapture>,
40    },
41    ClosureCall {
42        dest: ValueId,
43        callee: ValueId,
44        args: Vec<ValueId>,
45    },
46    Native(NativeStmt),
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct NativeClosureEnvironment {
51    pub slots: Vec<NativeClosureSlot>,
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct NativeClosureAbi {
56    pub code: NativeClosureCodeRef,
57    pub environment: NativeClosureEnvRef,
58    pub params: Vec<ValueId>,
59}
60
/// Reference to the function holding a closure's code, by name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NativeClosureCodeRef {
    /// Name of the function; conversion sets it to the function's own name.
    pub function: String,
}
65
/// Size descriptor for a closure environment.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NativeClosureEnvRef {
    /// Number of environment slots (equal to the number of captured values).
    pub slots: usize,
}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq)]
72pub struct NativeClosureSlot {
73    pub index: usize,
74    pub value: ValueId,
75}
76
77#[derive(Debug, Clone, Copy, PartialEq, Eq)]
78pub struct NativeClosureCapture {
79    pub slot: usize,
80    pub value: ValueId,
81}
82
83pub fn closure_convert_module(module: &NativeModule) -> NativeClosureModule {
84    NativeClosureModule {
85        functions: module
86            .functions
87            .iter()
88            .map(closure_convert_function)
89            .collect(),
90        roots: module.roots.iter().map(closure_convert_function).collect(),
91    }
92}
93
94fn closure_convert_function(function: &NativeFunction) -> NativeClosureFunction {
95    let environment_slots = closure_environment_slots(function);
96    let capture_values = function.captures.iter().copied().collect::<HashSet<_>>();
97    let params = function
98        .params
99        .iter()
100        .copied()
101        .filter(|param| !capture_values.contains(param))
102        .collect();
103    let blocks = closure_convert_blocks(function, &environment_slots, &capture_values);
104    NativeClosureFunction {
105        name: function.name.clone(),
106        params,
107        abi: NativeClosureAbi {
108            code: NativeClosureCodeRef {
109                function: function.name.clone(),
110            },
111            environment: NativeClosureEnvRef {
112                slots: environment_slots.len(),
113            },
114            params: function
115                .params
116                .iter()
117                .copied()
118                .filter(|param| !capture_values.contains(param))
119                .collect(),
120        },
121        environment: NativeClosureEnvironment {
122            slots: environment_slots,
123        },
124        blocks,
125    }
126}
127
128fn closure_environment_slots(function: &NativeFunction) -> Vec<NativeClosureSlot> {
129    function
130        .captures
131        .iter()
132        .copied()
133        .enumerate()
134        .map(|(index, value)| NativeClosureSlot { index, value })
135        .collect()
136}
137
138fn closure_convert_blocks(
139    function: &NativeFunction,
140    environment_slots: &[NativeClosureSlot],
141    capture_values: &HashSet<ValueId>,
142) -> Vec<NativeClosureBlock> {
143    let entry = function.blocks.first().map(|block| block.id);
144    function
145        .blocks
146        .iter()
147        .map(|block| {
148            let mut stmts = Vec::new();
149            if Some(block.id) == entry {
150                stmts.extend(
151                    environment_slots
152                        .iter()
153                        .map(|slot| NativeClosureStmt::LoadEnv {
154                            dest: slot.value,
155                            slot: slot.index,
156                        }),
157                );
158            }
159            stmts.extend(block.stmts.iter().cloned().map(closure_convert_stmt));
160            NativeClosureBlock {
161                id: block.id,
162                params: block
163                    .params
164                    .iter()
165                    .copied()
166                    .filter(|param| !capture_values.contains(param))
167                    .collect(),
168                stmts,
169                terminator: block.terminator.clone(),
170            }
171        })
172        .collect()
173}
174
175fn closure_convert_stmt(stmt: NativeStmt) -> NativeClosureStmt {
176    match stmt {
177        NativeStmt::MakeClosure {
178            dest,
179            target,
180            captures,
181        } => NativeClosureStmt::MakeClosure {
182            dest,
183            target,
184            environment: captures
185                .into_iter()
186                .enumerate()
187                .map(|(slot, value)| NativeClosureCapture { slot, value })
188                .collect(),
189        },
190        NativeStmt::ClosureCall { dest, callee, args } => {
191            NativeClosureStmt::ClosureCall { dest, callee, args }
192        }
193        stmt => NativeClosureStmt::Native(stmt),
194    }
195}
196
#[cfg(test)]
mod tests {
    use crate::control_ir::{
        BlockId, NativeBlock, NativeFunction, NativeLiteral, NativeModule, NativeStmt,
        NativeTerminator, ValueId,
    };

    use super::*;

    #[test]
    fn converts_first_order_function_to_empty_environment_closure() {
        let function = NativeFunction {
            name: "root".to_string(),
            captures: Vec::new(),
            params: vec![ValueId(0)],
            blocks: vec![NativeBlock {
                id: BlockId(0),
                params: vec![ValueId(0)],
                stmts: vec![NativeStmt::Literal {
                    dest: ValueId(1),
                    literal: NativeLiteral::Int("1".to_string()),
                }],
                terminator: NativeTerminator::Return(ValueId(1)),
            }],
        };
        let module = NativeModule {
            functions: Vec::new(),
            roots: vec![function.clone()],
        };

        let converted = closure_convert_module(&module);

        assert_eq!(
            converted.roots,
            vec![NativeClosureFunction {
                name: "root".to_string(),
                params: vec![ValueId(0)],
                abi: NativeClosureAbi {
                    code: NativeClosureCodeRef {
                        function: "root".to_string(),
                    },
                    environment: NativeClosureEnvRef { slots: 0 },
                    params: vec![ValueId(0)],
                },
                environment: NativeClosureEnvironment { slots: Vec::new() },
                blocks: function
                    .blocks
                    .into_iter()
                    .map(|block| NativeClosureBlock {
                        id: block.id,
                        params: block.params,
                        stmts: block.stmts.into_iter().map(closure_convert_stmt).collect(),
                        terminator: block.terminator,
                    })
                    .collect(),
            }]
        );
    }

    #[test]
    fn separates_capture_params_into_environment_slots() {
        let function = NativeFunction {
            name: "root#lambda0".to_string(),
            captures: vec![ValueId(0)],
            params: vec![ValueId(0), ValueId(1)],
            blocks: vec![NativeBlock {
                id: BlockId(0),
                params: vec![ValueId(0), ValueId(1)],
                stmts: Vec::new(),
                terminator: NativeTerminator::Return(ValueId(1)),
            }],
        };
        let module = NativeModule {
            functions: vec![function.clone()],
            roots: Vec::new(),
        };

        let converted = closure_convert_module(&module);

        assert_eq!(converted.functions[0].params, vec![ValueId(1)]);
        assert_eq!(
            converted.functions[0].abi,
            NativeClosureAbi {
                code: NativeClosureCodeRef {
                    function: "root#lambda0".to_string(),
                },
                environment: NativeClosureEnvRef { slots: 1 },
                params: vec![ValueId(1)],
            }
        );
        assert_eq!(converted.functions[0].blocks[0].params, vec![ValueId(1)]);
        assert_eq!(
            converted.functions[0].environment,
            NativeClosureEnvironment {
                slots: vec![NativeClosureSlot {
                    index: 0,
                    value: ValueId(0),
                }]
            }
        );
        assert_eq!(
            converted.functions[0].blocks[0].stmts,
            vec![NativeClosureStmt::LoadEnv {
                dest: ValueId(0),
                slot: 0,
            }]
        );
        assert_eq!(
            converted.functions[0].blocks[0].terminator,
            NativeTerminator::Return(ValueId(1))
        );
    }

    #[test]
    fn converts_make_closure_to_environment_allocation() {
        let function = NativeFunction {
            name: "root".to_string(),
            captures: Vec::new(),
            params: Vec::new(),
            blocks: vec![NativeBlock {
                id: BlockId(0),
                params: Vec::new(),
                stmts: vec![NativeStmt::MakeClosure {
                    dest: ValueId(2),
                    target: "root#lambda0".to_string(),
                    captures: vec![ValueId(0), ValueId(1)],
                }],
                terminator: NativeTerminator::Return(ValueId(2)),
            }],
        };
        let module = NativeModule {
            functions: Vec::new(),
            roots: vec![function],
        };

        let converted = closure_convert_module(&module);

        assert_eq!(
            converted.roots[0].blocks[0].stmts,
            vec![NativeClosureStmt::MakeClosure {
                dest: ValueId(2),
                target: "root#lambda0".to_string(),
                environment: vec![
                    NativeClosureCapture {
                        slot: 0,
                        value: ValueId(0),
                    },
                    NativeClosureCapture {
                        slot: 1,
                        value: ValueId(1),
                    },
                ],
            }]
        );
    }

    /// `ClosureCall` statements are carried over field-for-field.
    #[test]
    fn passes_closure_call_through_unchanged() {
        let function = NativeFunction {
            name: "root".to_string(),
            captures: Vec::new(),
            params: vec![ValueId(0), ValueId(1)],
            blocks: vec![NativeBlock {
                id: BlockId(0),
                params: vec![ValueId(0), ValueId(1)],
                stmts: vec![NativeStmt::ClosureCall {
                    dest: ValueId(2),
                    callee: ValueId(0),
                    args: vec![ValueId(1)],
                }],
                terminator: NativeTerminator::Return(ValueId(2)),
            }],
        };
        let module = NativeModule {
            functions: Vec::new(),
            roots: vec![function],
        };

        let converted = closure_convert_module(&module);

        assert_eq!(
            converted.roots[0].blocks[0].stmts,
            vec![NativeClosureStmt::ClosureCall {
                dest: ValueId(2),
                callee: ValueId(0),
                args: vec![ValueId(1)],
            }]
        );
    }

    /// Only the entry block receives `LoadEnv` statements. Later blocks keep
    /// their (non-captured) parameters and get no environment loads.
    #[test]
    fn loads_environment_only_in_entry_block() {
        let function = NativeFunction {
            name: "root#lambda0".to_string(),
            captures: vec![ValueId(0)],
            params: vec![ValueId(0)],
            blocks: vec![
                NativeBlock {
                    id: BlockId(0),
                    params: vec![ValueId(0)],
                    stmts: Vec::new(),
                    terminator: NativeTerminator::Return(ValueId(0)),
                },
                NativeBlock {
                    id: BlockId(1),
                    params: vec![ValueId(1)],
                    stmts: Vec::new(),
                    terminator: NativeTerminator::Return(ValueId(1)),
                },
            ],
        };
        let module = NativeModule {
            functions: vec![function],
            roots: Vec::new(),
        };

        let converted = closure_convert_module(&module);
        let converted_function = &converted.functions[0];

        assert_eq!(
            converted_function.blocks[0].stmts,
            vec![NativeClosureStmt::LoadEnv {
                dest: ValueId(0),
                slot: 0,
            }]
        );
        assert!(converted_function.blocks[1].stmts.is_empty());
        assert_eq!(converted_function.blocks[1].params, vec![ValueId(1)]);
        assert_eq!(converted_function.params, converted_function.abi.params);
    }

    /// An empty module converts to an empty module.
    #[test]
    fn converts_empty_module() {
        let module = NativeModule {
            functions: Vec::new(),
            roots: Vec::new(),
        };

        assert_eq!(
            closure_convert_module(&module),
            NativeClosureModule {
                functions: Vec::new(),
                roots: Vec::new(),
            }
        );
    }
}