1use std::collections::{HashMap, HashSet};
2use std::fmt;
3
4use crate::abi::{NativeAbiBlock, NativeAbiFunction, NativeAbiModule, NativeAbiStmt};
5use crate::control_ir::{BlockId, NativeTerminator, ValueId};
6
7pub type NativeAbiValidateResult<T> = Result<T, NativeAbiValidateError>;
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub enum NativeAbiValidateError {
11 DuplicateFunction {
12 name: String,
13 },
14 DuplicateBlock {
15 function: String,
16 block: BlockId,
17 },
18 DuplicateBlockParam {
19 function: String,
20 block: BlockId,
21 value: ValueId,
22 },
23 DuplicateValue {
24 function: String,
25 block: BlockId,
26 value: ValueId,
27 },
28 UndefinedValue {
29 function: String,
30 block: BlockId,
31 value: ValueId,
32 },
33 MissingBlock {
34 function: String,
35 block: BlockId,
36 },
37 EnvSlotOutOfRange {
38 function: String,
39 block: BlockId,
40 slot: usize,
41 slots: usize,
42 },
43}
44
45impl fmt::Display for NativeAbiValidateError {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 NativeAbiValidateError::DuplicateFunction { name } => {
49 write!(f, "duplicate native ABI function `{name}`")
50 }
51 NativeAbiValidateError::DuplicateBlock { function, block } => {
52 write!(f, "duplicate native ABI block {block:?} in `{function}`")
53 }
54 NativeAbiValidateError::DuplicateBlockParam {
55 function,
56 block,
57 value,
58 } => write!(
59 f,
60 "duplicate native ABI block param {value:?} in block {block:?} of `{function}`"
61 ),
62 NativeAbiValidateError::DuplicateValue {
63 function,
64 block,
65 value,
66 } => write!(
67 f,
68 "duplicate native ABI value {value:?} in block {block:?} of `{function}`"
69 ),
70 NativeAbiValidateError::UndefinedValue {
71 function,
72 block,
73 value,
74 } => write!(
75 f,
76 "undefined native ABI value {value:?} in block {block:?} of `{function}`"
77 ),
78 NativeAbiValidateError::MissingBlock { function, block } => {
79 write!(f, "missing native ABI block {block:?} in `{function}`")
80 }
81 NativeAbiValidateError::EnvSlotOutOfRange {
82 function,
83 block,
84 slot,
85 slots,
86 } => write!(
87 f,
88 "native ABI env slot {slot} is out of range for {slots} slots in block {block:?} of `{function}`"
89 ),
90 }
91 }
92}
93
94impl std::error::Error for NativeAbiValidateError {}
95
96pub fn validate_abi_module(module: &NativeAbiModule) -> NativeAbiValidateResult<()> {
97 let mut functions = HashSet::new();
98 for function in module.functions.iter().chain(&module.roots) {
99 if !functions.insert(function.name.clone()) {
100 return Err(NativeAbiValidateError::DuplicateFunction {
101 name: function.name.clone(),
102 });
103 }
104 validate_function(function)?;
105 }
106 Ok(())
107}
108
109fn validate_function(function: &NativeAbiFunction) -> NativeAbiValidateResult<()> {
110 let mut blocks = HashSet::new();
111 for block in &function.blocks {
112 if !blocks.insert(block.id) {
113 return Err(NativeAbiValidateError::DuplicateBlock {
114 function: function.name.clone(),
115 block: block.id,
116 });
117 }
118 }
119 let entry = function.blocks.first().map(|block| block.id);
120 let block_start_values = function_block_start_values(function);
121 for block in &function.blocks {
122 let values = block_start_values
123 .get(&block.id)
124 .cloned()
125 .unwrap_or_default();
126 validate_block(function, block, &blocks, Some(block.id) == entry, values)?;
127 }
128 Ok(())
129}
130
131fn validate_block(
132 function: &NativeAbiFunction,
133 block: &NativeAbiBlock,
134 blocks: &HashSet<BlockId>,
135 is_entry: bool,
136 mut values: HashSet<ValueId>,
137) -> NativeAbiValidateResult<()> {
138 let block_params = if is_entry && block.params.starts_with(&function.params) {
139 &block.params[function.params.len()..]
140 } else {
141 block.params.as_slice()
142 };
143 let mut seen_params = function.params.iter().copied().collect::<HashSet<_>>();
144 for param in block_params {
145 if !seen_params.insert(*param) {
146 return Err(NativeAbiValidateError::DuplicateBlockParam {
147 function: function.name.clone(),
148 block: block.id,
149 value: *param,
150 });
151 }
152 }
153 for stmt in &block.stmts {
154 validate_stmt_uses(function, block, stmt, &values)?;
155 let dest = stmt_dest(stmt);
156 if !values.insert(dest) {
157 return Err(NativeAbiValidateError::DuplicateValue {
158 function: function.name.clone(),
159 block: block.id,
160 value: dest,
161 });
162 }
163 }
164 validate_terminator(function, block, blocks, &values)
165}
166
167fn function_block_start_values(function: &NativeAbiFunction) -> HashMap<BlockId, HashSet<ValueId>> {
168 let mut start = function
169 .blocks
170 .iter()
171 .map(|block| {
172 (
173 block.id,
174 block.params.iter().copied().collect::<HashSet<_>>(),
175 )
176 })
177 .collect::<HashMap<_, _>>();
178 if let Some(entry) = function.blocks.first() {
179 start
180 .entry(entry.id)
181 .or_default()
182 .extend(function.params.iter().copied());
183 }
184
185 let mut changed = true;
186 while changed {
187 changed = false;
188 for block in &function.blocks {
189 let mut out = start.get(&block.id).cloned().unwrap_or_default();
190 for stmt in &block.stmts {
191 out.insert(stmt_dest(stmt));
192 }
193 for successor in terminator_successors(&block.terminator) {
194 let entry = start.entry(successor).or_default();
195 let old_len = entry.len();
196 entry.extend(out.iter().copied());
197 changed |= entry.len() != old_len;
198 }
199 }
200 }
201 start
202}
203
204fn validate_stmt_uses(
205 function: &NativeAbiFunction,
206 block: &NativeAbiBlock,
207 stmt: &NativeAbiStmt,
208 values: &HashSet<ValueId>,
209) -> NativeAbiValidateResult<()> {
210 match stmt {
211 NativeAbiStmt::Literal { .. } => Ok(()),
212 NativeAbiStmt::Primitive { args, .. }
213 | NativeAbiStmt::DirectCall { args, .. }
214 | NativeAbiStmt::Tuple { items: args, .. }
215 | NativeAbiStmt::IndirectClosureCall { args, .. } => {
216 for arg in args {
217 require_value(function, block, values, *arg)?;
218 }
219 if let NativeAbiStmt::IndirectClosureCall { callee, .. } = stmt {
220 require_value(function, block, values, *callee)?;
221 }
222 Ok(())
223 }
224 NativeAbiStmt::Record { base, fields, .. } => {
225 if let Some(base) = base {
226 require_value(function, block, values, *base)?;
227 }
228 for field in fields {
229 require_value(function, block, values, field.value)?;
230 }
231 Ok(())
232 }
233 NativeAbiStmt::RecordWithoutFields { base, .. } => {
234 require_value(function, block, values, *base)
235 }
236 NativeAbiStmt::Variant { value, .. } => {
237 if let Some(value) = value {
238 require_value(function, block, values, *value)?;
239 }
240 Ok(())
241 }
242 NativeAbiStmt::Select { base, .. } => require_value(function, block, values, *base),
243 NativeAbiStmt::TupleGet { tuple, .. } => require_value(function, block, values, *tuple),
244 NativeAbiStmt::VariantTagEq { variant, .. }
245 | NativeAbiStmt::VariantPayload { variant, .. } => {
246 require_value(function, block, values, *variant)
247 }
248 NativeAbiStmt::ValueEq { left, right, .. } => {
249 require_value(function, block, values, *left)?;
250 require_value(function, block, values, *right)
251 }
252 NativeAbiStmt::BoolAnd { left, right, .. } => {
253 require_value(function, block, values, *left)?;
254 require_value(function, block, values, *right)
255 }
256 NativeAbiStmt::LoadEnv { slot, .. } => {
257 if *slot >= function.environment_slots {
258 return Err(NativeAbiValidateError::EnvSlotOutOfRange {
259 function: function.name.clone(),
260 block: block.id,
261 slot: *slot,
262 slots: function.environment_slots,
263 });
264 }
265 Ok(())
266 }
267 NativeAbiStmt::AllocateClosure { environment, .. } => {
268 for value in environment {
269 require_value(function, block, values, *value)?;
270 }
271 Ok(())
272 }
273 }
274}
275
276fn validate_terminator(
277 function: &NativeAbiFunction,
278 block: &NativeAbiBlock,
279 blocks: &HashSet<BlockId>,
280 values: &HashSet<ValueId>,
281) -> NativeAbiValidateResult<()> {
282 match &block.terminator {
283 NativeTerminator::Return(value) => require_value(function, block, values, *value),
284 NativeTerminator::Jump { target, args } => {
285 require_block(function, *target, blocks)?;
286 for arg in args {
287 require_value(function, block, values, *arg)?;
288 }
289 Ok(())
290 }
291 NativeTerminator::Branch {
292 cond,
293 then_block,
294 else_block,
295 } => {
296 require_value(function, block, values, *cond)?;
297 require_block(function, *then_block, blocks)?;
298 require_block(function, *else_block, blocks)
299 }
300 }
301}
302
303fn terminator_successors(terminator: &NativeTerminator) -> Vec<BlockId> {
304 match terminator {
305 NativeTerminator::Return(_) => Vec::new(),
306 NativeTerminator::Jump { target, .. } => vec![*target],
307 NativeTerminator::Branch {
308 then_block,
309 else_block,
310 ..
311 } => vec![*then_block, *else_block],
312 }
313}
314
315fn stmt_dest(stmt: &NativeAbiStmt) -> ValueId {
316 match stmt {
317 NativeAbiStmt::Literal { dest, .. }
318 | NativeAbiStmt::Primitive { dest, .. }
319 | NativeAbiStmt::DirectCall { dest, .. }
320 | NativeAbiStmt::Tuple { dest, .. }
321 | NativeAbiStmt::Record { dest, .. }
322 | NativeAbiStmt::RecordWithoutFields { dest, .. }
323 | NativeAbiStmt::Variant { dest, .. }
324 | NativeAbiStmt::Select { dest, .. }
325 | NativeAbiStmt::TupleGet { dest, .. }
326 | NativeAbiStmt::VariantTagEq { dest, .. }
327 | NativeAbiStmt::VariantPayload { dest, .. }
328 | NativeAbiStmt::ValueEq { dest, .. }
329 | NativeAbiStmt::BoolAnd { dest, .. }
330 | NativeAbiStmt::LoadEnv { dest, .. }
331 | NativeAbiStmt::AllocateClosure { dest, .. }
332 | NativeAbiStmt::IndirectClosureCall { dest, .. } => *dest,
333 }
334}
335
336fn require_value(
337 function: &NativeAbiFunction,
338 block: &NativeAbiBlock,
339 values: &HashSet<ValueId>,
340 value: ValueId,
341) -> NativeAbiValidateResult<()> {
342 if values.contains(&value) {
343 Ok(())
344 } else {
345 Err(NativeAbiValidateError::UndefinedValue {
346 function: function.name.clone(),
347 block: block.id,
348 value,
349 })
350 }
351}
352
353fn require_block(
354 function: &NativeAbiFunction,
355 block: BlockId,
356 blocks: &HashSet<BlockId>,
357) -> NativeAbiValidateResult<()> {
358 if blocks.contains(&block) {
359 Ok(())
360 } else {
361 Err(NativeAbiValidateError::MissingBlock {
362 function: function.name.clone(),
363 block,
364 })
365 }
366}
367
#[cfg(test)]
mod tests {
    use crate::abi::{NativeAbiBlock, NativeAbiFunction, NativeAbiModule, NativeAbiStmt};
    use crate::control_ir::{BlockId, NativeLiteral, NativeTerminator, ValueId};

    use super::*;

    /// Builds a module with no regular functions and a single root function
    /// named `root` that has no parameters and one parameterless block
    /// (`BlockId(0)`) containing `stmts` and returning `returned`.
    fn single_block_root(
        environment_slots: usize,
        stmts: Vec<NativeAbiStmt>,
        returned: ValueId,
    ) -> NativeAbiModule {
        let block = NativeAbiBlock {
            id: BlockId(0),
            params: Vec::new(),
            stmts,
            terminator: NativeTerminator::Return(returned),
        };
        let root = NativeAbiFunction {
            name: "root".to_string(),
            params: Vec::new(),
            environment_slots,
            blocks: vec![block],
        };
        NativeAbiModule {
            functions: Vec::new(),
            roots: vec![root],
        }
    }

    #[test]
    fn accepts_valid_abi_module() {
        let module = single_block_root(
            1,
            vec![
                NativeAbiStmt::LoadEnv {
                    dest: ValueId(0),
                    slot: 0,
                },
                NativeAbiStmt::Literal {
                    dest: ValueId(1),
                    literal: NativeLiteral::Int("1".to_string()),
                },
                NativeAbiStmt::AllocateClosure {
                    dest: ValueId(2),
                    target: "root#lambda0".to_string(),
                    environment: vec![ValueId(0), ValueId(1)],
                },
            ],
            ValueId(2),
        );

        validate_abi_module(&module).expect("valid abi");
    }

    #[test]
    fn rejects_out_of_range_env_slot() {
        // Zero declared slots, so even slot 0 is out of range.
        let module = single_block_root(
            0,
            vec![NativeAbiStmt::LoadEnv {
                dest: ValueId(0),
                slot: 0,
            }],
            ValueId(0),
        );

        let expected = NativeAbiValidateError::EnvSlotOutOfRange {
            function: "root".to_string(),
            block: BlockId(0),
            slot: 0,
            slots: 0,
        };
        assert_eq!(validate_abi_module(&module), Err(expected));
    }

    #[test]
    fn rejects_undefined_call_argument() {
        // `ValueId(0)` is never defined before the call reads it.
        let module = single_block_root(
            0,
            vec![NativeAbiStmt::DirectCall {
                dest: ValueId(1),
                target: "f".to_string(),
                args: vec![ValueId(0)],
            }],
            ValueId(1),
        );

        let expected = NativeAbiValidateError::UndefinedValue {
            function: "root".to_string(),
            block: BlockId(0),
            value: ValueId(0),
        };
        assert_eq!(validate_abi_module(&module), Err(expected));
    }
}