Skip to main content

yulang_native/
abi_validate.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3
4use crate::abi::{NativeAbiBlock, NativeAbiFunction, NativeAbiModule, NativeAbiStmt};
5use crate::control_ir::{BlockId, NativeTerminator, ValueId};
6
7pub type NativeAbiValidateResult<T> = Result<T, NativeAbiValidateError>;
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub enum NativeAbiValidateError {
11    DuplicateFunction {
12        name: String,
13    },
14    DuplicateBlock {
15        function: String,
16        block: BlockId,
17    },
18    DuplicateBlockParam {
19        function: String,
20        block: BlockId,
21        value: ValueId,
22    },
23    DuplicateValue {
24        function: String,
25        block: BlockId,
26        value: ValueId,
27    },
28    UndefinedValue {
29        function: String,
30        block: BlockId,
31        value: ValueId,
32    },
33    MissingBlock {
34        function: String,
35        block: BlockId,
36    },
37    EnvSlotOutOfRange {
38        function: String,
39        block: BlockId,
40        slot: usize,
41        slots: usize,
42    },
43}
44
45impl fmt::Display for NativeAbiValidateError {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        match self {
48            NativeAbiValidateError::DuplicateFunction { name } => {
49                write!(f, "duplicate native ABI function `{name}`")
50            }
51            NativeAbiValidateError::DuplicateBlock { function, block } => {
52                write!(f, "duplicate native ABI block {block:?} in `{function}`")
53            }
54            NativeAbiValidateError::DuplicateBlockParam {
55                function,
56                block,
57                value,
58            } => write!(
59                f,
60                "duplicate native ABI block param {value:?} in block {block:?} of `{function}`"
61            ),
62            NativeAbiValidateError::DuplicateValue {
63                function,
64                block,
65                value,
66            } => write!(
67                f,
68                "duplicate native ABI value {value:?} in block {block:?} of `{function}`"
69            ),
70            NativeAbiValidateError::UndefinedValue {
71                function,
72                block,
73                value,
74            } => write!(
75                f,
76                "undefined native ABI value {value:?} in block {block:?} of `{function}`"
77            ),
78            NativeAbiValidateError::MissingBlock { function, block } => {
79                write!(f, "missing native ABI block {block:?} in `{function}`")
80            }
81            NativeAbiValidateError::EnvSlotOutOfRange {
82                function,
83                block,
84                slot,
85                slots,
86            } => write!(
87                f,
88                "native ABI env slot {slot} is out of range for {slots} slots in block {block:?} of `{function}`"
89            ),
90        }
91    }
92}
93
94impl std::error::Error for NativeAbiValidateError {}
95
96pub fn validate_abi_module(module: &NativeAbiModule) -> NativeAbiValidateResult<()> {
97    let mut functions = HashSet::new();
98    for function in module.functions.iter().chain(&module.roots) {
99        if !functions.insert(function.name.clone()) {
100            return Err(NativeAbiValidateError::DuplicateFunction {
101                name: function.name.clone(),
102            });
103        }
104        validate_function(function)?;
105    }
106    Ok(())
107}
108
109fn validate_function(function: &NativeAbiFunction) -> NativeAbiValidateResult<()> {
110    let mut blocks = HashSet::new();
111    for block in &function.blocks {
112        if !blocks.insert(block.id) {
113            return Err(NativeAbiValidateError::DuplicateBlock {
114                function: function.name.clone(),
115                block: block.id,
116            });
117        }
118    }
119    let entry = function.blocks.first().map(|block| block.id);
120    let block_start_values = function_block_start_values(function);
121    for block in &function.blocks {
122        let values = block_start_values
123            .get(&block.id)
124            .cloned()
125            .unwrap_or_default();
126        validate_block(function, block, &blocks, Some(block.id) == entry, values)?;
127    }
128    Ok(())
129}
130
131fn validate_block(
132    function: &NativeAbiFunction,
133    block: &NativeAbiBlock,
134    blocks: &HashSet<BlockId>,
135    is_entry: bool,
136    mut values: HashSet<ValueId>,
137) -> NativeAbiValidateResult<()> {
138    let block_params = if is_entry && block.params.starts_with(&function.params) {
139        &block.params[function.params.len()..]
140    } else {
141        block.params.as_slice()
142    };
143    let mut seen_params = function.params.iter().copied().collect::<HashSet<_>>();
144    for param in block_params {
145        if !seen_params.insert(*param) {
146            return Err(NativeAbiValidateError::DuplicateBlockParam {
147                function: function.name.clone(),
148                block: block.id,
149                value: *param,
150            });
151        }
152    }
153    for stmt in &block.stmts {
154        validate_stmt_uses(function, block, stmt, &values)?;
155        let dest = stmt_dest(stmt);
156        if !values.insert(dest) {
157            return Err(NativeAbiValidateError::DuplicateValue {
158                function: function.name.clone(),
159                block: block.id,
160                value: dest,
161            });
162        }
163    }
164    validate_terminator(function, block, blocks, &values)
165}
166
167fn function_block_start_values(function: &NativeAbiFunction) -> HashMap<BlockId, HashSet<ValueId>> {
168    let mut start = function
169        .blocks
170        .iter()
171        .map(|block| {
172            (
173                block.id,
174                block.params.iter().copied().collect::<HashSet<_>>(),
175            )
176        })
177        .collect::<HashMap<_, _>>();
178    if let Some(entry) = function.blocks.first() {
179        start
180            .entry(entry.id)
181            .or_default()
182            .extend(function.params.iter().copied());
183    }
184
185    let mut changed = true;
186    while changed {
187        changed = false;
188        for block in &function.blocks {
189            let mut out = start.get(&block.id).cloned().unwrap_or_default();
190            for stmt in &block.stmts {
191                out.insert(stmt_dest(stmt));
192            }
193            for successor in terminator_successors(&block.terminator) {
194                let entry = start.entry(successor).or_default();
195                let old_len = entry.len();
196                entry.extend(out.iter().copied());
197                changed |= entry.len() != old_len;
198            }
199        }
200    }
201    start
202}
203
204fn validate_stmt_uses(
205    function: &NativeAbiFunction,
206    block: &NativeAbiBlock,
207    stmt: &NativeAbiStmt,
208    values: &HashSet<ValueId>,
209) -> NativeAbiValidateResult<()> {
210    match stmt {
211        NativeAbiStmt::Literal { .. } => Ok(()),
212        NativeAbiStmt::Primitive { args, .. }
213        | NativeAbiStmt::DirectCall { args, .. }
214        | NativeAbiStmt::Tuple { items: args, .. }
215        | NativeAbiStmt::IndirectClosureCall { args, .. } => {
216            for arg in args {
217                require_value(function, block, values, *arg)?;
218            }
219            if let NativeAbiStmt::IndirectClosureCall { callee, .. } = stmt {
220                require_value(function, block, values, *callee)?;
221            }
222            Ok(())
223        }
224        NativeAbiStmt::Record { base, fields, .. } => {
225            if let Some(base) = base {
226                require_value(function, block, values, *base)?;
227            }
228            for field in fields {
229                require_value(function, block, values, field.value)?;
230            }
231            Ok(())
232        }
233        NativeAbiStmt::RecordWithoutFields { base, .. } => {
234            require_value(function, block, values, *base)
235        }
236        NativeAbiStmt::Variant { value, .. } => {
237            if let Some(value) = value {
238                require_value(function, block, values, *value)?;
239            }
240            Ok(())
241        }
242        NativeAbiStmt::Select { base, .. } => require_value(function, block, values, *base),
243        NativeAbiStmt::TupleGet { tuple, .. } => require_value(function, block, values, *tuple),
244        NativeAbiStmt::VariantTagEq { variant, .. }
245        | NativeAbiStmt::VariantPayload { variant, .. } => {
246            require_value(function, block, values, *variant)
247        }
248        NativeAbiStmt::ValueEq { left, right, .. } => {
249            require_value(function, block, values, *left)?;
250            require_value(function, block, values, *right)
251        }
252        NativeAbiStmt::BoolAnd { left, right, .. } => {
253            require_value(function, block, values, *left)?;
254            require_value(function, block, values, *right)
255        }
256        NativeAbiStmt::LoadEnv { slot, .. } => {
257            if *slot >= function.environment_slots {
258                return Err(NativeAbiValidateError::EnvSlotOutOfRange {
259                    function: function.name.clone(),
260                    block: block.id,
261                    slot: *slot,
262                    slots: function.environment_slots,
263                });
264            }
265            Ok(())
266        }
267        NativeAbiStmt::AllocateClosure { environment, .. } => {
268            for value in environment {
269                require_value(function, block, values, *value)?;
270            }
271            Ok(())
272        }
273    }
274}
275
276fn validate_terminator(
277    function: &NativeAbiFunction,
278    block: &NativeAbiBlock,
279    blocks: &HashSet<BlockId>,
280    values: &HashSet<ValueId>,
281) -> NativeAbiValidateResult<()> {
282    match &block.terminator {
283        NativeTerminator::Return(value) => require_value(function, block, values, *value),
284        NativeTerminator::Jump { target, args } => {
285            require_block(function, *target, blocks)?;
286            for arg in args {
287                require_value(function, block, values, *arg)?;
288            }
289            Ok(())
290        }
291        NativeTerminator::Branch {
292            cond,
293            then_block,
294            else_block,
295        } => {
296            require_value(function, block, values, *cond)?;
297            require_block(function, *then_block, blocks)?;
298            require_block(function, *else_block, blocks)
299        }
300    }
301}
302
303fn terminator_successors(terminator: &NativeTerminator) -> Vec<BlockId> {
304    match terminator {
305        NativeTerminator::Return(_) => Vec::new(),
306        NativeTerminator::Jump { target, .. } => vec![*target],
307        NativeTerminator::Branch {
308            then_block,
309            else_block,
310            ..
311        } => vec![*then_block, *else_block],
312    }
313}
314
315fn stmt_dest(stmt: &NativeAbiStmt) -> ValueId {
316    match stmt {
317        NativeAbiStmt::Literal { dest, .. }
318        | NativeAbiStmt::Primitive { dest, .. }
319        | NativeAbiStmt::DirectCall { dest, .. }
320        | NativeAbiStmt::Tuple { dest, .. }
321        | NativeAbiStmt::Record { dest, .. }
322        | NativeAbiStmt::RecordWithoutFields { dest, .. }
323        | NativeAbiStmt::Variant { dest, .. }
324        | NativeAbiStmt::Select { dest, .. }
325        | NativeAbiStmt::TupleGet { dest, .. }
326        | NativeAbiStmt::VariantTagEq { dest, .. }
327        | NativeAbiStmt::VariantPayload { dest, .. }
328        | NativeAbiStmt::ValueEq { dest, .. }
329        | NativeAbiStmt::BoolAnd { dest, .. }
330        | NativeAbiStmt::LoadEnv { dest, .. }
331        | NativeAbiStmt::AllocateClosure { dest, .. }
332        | NativeAbiStmt::IndirectClosureCall { dest, .. } => *dest,
333    }
334}
335
336fn require_value(
337    function: &NativeAbiFunction,
338    block: &NativeAbiBlock,
339    values: &HashSet<ValueId>,
340    value: ValueId,
341) -> NativeAbiValidateResult<()> {
342    if values.contains(&value) {
343        Ok(())
344    } else {
345        Err(NativeAbiValidateError::UndefinedValue {
346            function: function.name.clone(),
347            block: block.id,
348            value,
349        })
350    }
351}
352
353fn require_block(
354    function: &NativeAbiFunction,
355    block: BlockId,
356    blocks: &HashSet<BlockId>,
357) -> NativeAbiValidateResult<()> {
358    if blocks.contains(&block) {
359        Ok(())
360    } else {
361        Err(NativeAbiValidateError::MissingBlock {
362            function: function.name.clone(),
363            block,
364        })
365    }
366}
367
#[cfg(test)]
mod tests {
    use crate::abi::{NativeAbiBlock, NativeAbiFunction, NativeAbiModule, NativeAbiStmt};
    use crate::control_ir::{BlockId, NativeLiteral, NativeTerminator, ValueId};

    use super::*;

    /// Wraps `stmts` and `terminator` in a module whose only function is a
    /// parameterless root named `root` with a single block `BlockId(0)`.
    fn single_block_root(
        environment_slots: usize,
        stmts: Vec<NativeAbiStmt>,
        terminator: NativeTerminator,
    ) -> NativeAbiModule {
        NativeAbiModule {
            functions: Vec::new(),
            roots: vec![NativeAbiFunction {
                name: "root".to_string(),
                params: Vec::new(),
                environment_slots,
                blocks: vec![NativeAbiBlock {
                    id: BlockId(0),
                    params: Vec::new(),
                    stmts,
                    terminator,
                }],
            }],
        }
    }

    #[test]
    fn accepts_valid_abi_module() {
        // Loads env slot 0, builds a literal, and captures both in a closure.
        let module = single_block_root(
            1,
            vec![
                NativeAbiStmt::LoadEnv {
                    dest: ValueId(0),
                    slot: 0,
                },
                NativeAbiStmt::Literal {
                    dest: ValueId(1),
                    literal: NativeLiteral::Int("1".to_string()),
                },
                NativeAbiStmt::AllocateClosure {
                    dest: ValueId(2),
                    target: "root#lambda0".to_string(),
                    environment: vec![ValueId(0), ValueId(1)],
                },
            ],
            NativeTerminator::Return(ValueId(2)),
        );

        validate_abi_module(&module).expect("valid abi");
    }

    #[test]
    fn rejects_out_of_range_env_slot() {
        // With zero environment slots, even slot 0 is out of range.
        let module = single_block_root(
            0,
            vec![NativeAbiStmt::LoadEnv {
                dest: ValueId(0),
                slot: 0,
            }],
            NativeTerminator::Return(ValueId(0)),
        );

        let expected = NativeAbiValidateError::EnvSlotOutOfRange {
            function: "root".to_string(),
            block: BlockId(0),
            slot: 0,
            slots: 0,
        };
        assert_eq!(validate_abi_module(&module), Err(expected));
    }

    #[test]
    fn rejects_undefined_call_argument() {
        // ValueId(0) is passed to `f` without ever being defined.
        let module = single_block_root(
            0,
            vec![NativeAbiStmt::DirectCall {
                dest: ValueId(1),
                target: "f".to_string(),
                args: vec![ValueId(0)],
            }],
            NativeTerminator::Return(ValueId(1)),
        );

        let expected = NativeAbiValidateError::UndefinedValue {
            function: "root".to_string(),
            block: BlockId(0),
            value: ValueId(0),
        };
        assert_eq!(validate_abi_module(&module), Err(expected));
    }
}