// yulang_native/cps_validate.rs

1use std::collections::HashSet;
2use std::fmt;
3
4use crate::cps_ir::{
5    CpsContinuation, CpsContinuationId, CpsFunction, CpsHandlerId, CpsModule, CpsStmt,
6    CpsTerminator, CpsValueId,
7};
8
/// Structural problems that validation of a CPS function can report.
///
/// Every variant carries `function`, the name of the function in which the
/// problem was found, plus the id that triggered it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CpsValidateError {
    /// The function's `entry` id does not match any of its continuations.
    MissingEntry {
        function: String,
        entry: CpsContinuationId,
    },
    /// Two continuations of the same function share the id `id`.
    DuplicateContinuation {
        function: String,
        id: CpsContinuationId,
    },
    /// A handler arm, statement, or terminator names a continuation id that
    /// the function does not define.
    MissingContinuation {
        function: String,
        id: CpsContinuationId,
    },
    /// Two handlers of the same function share the id `id`.
    DuplicateHandler {
        function: String,
        id: CpsHandlerId,
    },
    /// A terminator names a handler id that the function does not define.
    MissingHandler {
        function: String,
        id: CpsHandlerId,
    },
    /// A `Continue` terminator passes `actual` arguments to continuation `id`,
    /// which declares `expected` parameters.
    ContinuationArityMismatch {
        function: String,
        id: CpsContinuationId,
        expected: usize,
        actual: usize,
    },
    /// A value id is used where it is not available: either a capture that is
    /// defined nowhere in the function, or an operand that is not yet visible
    /// in the continuation being checked.
    MissingValue {
        function: String,
        id: CpsValueId,
    },
}
42
43impl fmt::Display for CpsValidateError {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            CpsValidateError::MissingEntry { function, entry } => {
47                write!(
48                    f,
49                    "CPS function {function} has no entry continuation {entry:?}"
50                )
51            }
52            CpsValidateError::DuplicateContinuation { function, id } => {
53                write!(
54                    f,
55                    "CPS function {function} defines continuation {id:?} twice"
56                )
57            }
58            CpsValidateError::MissingContinuation { function, id } => {
59                write!(
60                    f,
61                    "CPS function {function} references missing continuation {id:?}"
62                )
63            }
64            CpsValidateError::DuplicateHandler { function, id } => {
65                write!(f, "CPS function {function} defines handler {id:?} twice")
66            }
67            CpsValidateError::MissingHandler { function, id } => {
68                write!(
69                    f,
70                    "CPS function {function} references missing handler {id:?}"
71                )
72            }
73            CpsValidateError::ContinuationArityMismatch {
74                function,
75                id,
76                expected,
77                actual,
78            } => write!(
79                f,
80                "CPS function {function} calls continuation {id:?} with {actual} arguments, expected {expected}"
81            ),
82            CpsValidateError::MissingValue { function, id } => {
83                write!(f, "CPS function {function} references missing value {id:?}")
84            }
85        }
86    }
87}
88
// Opt into the standard error trait. The body is empty because the trait's
// default methods are used as-is; `Debug` (derived) and `Display` supply the
// formatting the trait requires.
impl std::error::Error for CpsValidateError {}
90
91pub fn validate_cps_module(module: &CpsModule) -> Result<(), CpsValidateError> {
92    for function in module.functions.iter().chain(&module.roots) {
93        validate_function(function)?;
94    }
95    Ok(())
96}
97
98fn validate_function(function: &CpsFunction) -> Result<(), CpsValidateError> {
99    let mut continuation_ids = HashSet::new();
100    for continuation in &function.continuations {
101        if !continuation_ids.insert(continuation.id) {
102            return Err(CpsValidateError::DuplicateContinuation {
103                function: function.name.clone(),
104                id: continuation.id,
105            });
106        }
107    }
108    if !continuation_ids.contains(&function.entry) {
109        return Err(CpsValidateError::MissingEntry {
110            function: function.name.clone(),
111            entry: function.entry,
112        });
113    }
114
115    let mut handler_ids = HashSet::new();
116    for handler in &function.handlers {
117        if !handler_ids.insert(handler.id) {
118            return Err(CpsValidateError::DuplicateHandler {
119                function: function.name.clone(),
120                id: handler.id,
121            });
122        }
123        for arm in &handler.arms {
124            require_continuation(function, &continuation_ids, arm.entry)?;
125        }
126    }
127
128    let defined_values = function_defined_values(function);
129    for continuation in &function.continuations {
130        validate_continuation(
131            function,
132            continuation,
133            &continuation_ids,
134            &handler_ids,
135            &defined_values,
136        )?;
137    }
138    Ok(())
139}
140
141fn function_defined_values(function: &CpsFunction) -> HashSet<CpsValueId> {
142    let mut values = function.params.iter().copied().collect::<HashSet<_>>();
143    for continuation in &function.continuations {
144        values.extend(continuation.params.iter().copied());
145        for stmt in &continuation.stmts {
146            match stmt {
147                CpsStmt::Literal { dest, .. }
148                | CpsStmt::FreshGuard { dest, .. }
149                | CpsStmt::PeekGuard { dest }
150                | CpsStmt::FindGuard { dest, .. }
151                | CpsStmt::MakeThunk { dest, .. }
152                | CpsStmt::AddThunkBoundary { dest, .. }
153                | CpsStmt::MakeClosure { dest, .. }
154                | CpsStmt::MakeRecursiveClosure { dest, .. }
155                | CpsStmt::ForceThunk { dest, .. }
156                | CpsStmt::Tuple { dest, .. }
157                | CpsStmt::Record { dest, .. }
158                | CpsStmt::RecordWithoutFields { dest, .. }
159                | CpsStmt::Variant { dest, .. }
160                | CpsStmt::Select { dest, .. }
161                | CpsStmt::SelectWithDefault { dest, .. }
162                | CpsStmt::RecordHasField { dest, .. }
163                | CpsStmt::TupleGet { dest, .. }
164                | CpsStmt::VariantTagEq { dest, .. }
165                | CpsStmt::VariantPayload { dest, .. }
166                | CpsStmt::Primitive { dest, .. }
167                | CpsStmt::DirectCall { dest, .. }
168                | CpsStmt::ApplyClosure { dest, .. }
169                | CpsStmt::CloneContinuation { dest, .. }
170                | CpsStmt::Resume { dest, .. }
171                | CpsStmt::ResumeWithHandler { dest, .. } => {
172                    values.insert(*dest);
173                }
174                CpsStmt::InstallHandler { .. } | CpsStmt::UninstallHandler { .. } => {}
175            }
176        }
177    }
178    values
179}
180
181fn validate_continuation(
182    function: &CpsFunction,
183    continuation: &CpsContinuation,
184    continuation_ids: &HashSet<CpsContinuationId>,
185    handler_ids: &HashSet<CpsHandlerId>,
186    defined_values: &HashSet<CpsValueId>,
187) -> Result<(), CpsValidateError> {
188    let mut values = continuation.params.iter().copied().collect::<HashSet<_>>();
189    for capture in &continuation.captures {
190        require_value(function, defined_values, *capture)?;
191        values.insert(*capture);
192    }
193
194    for stmt in &continuation.stmts {
195        match stmt {
196            CpsStmt::Literal { dest, .. } => {
197                values.insert(*dest);
198            }
199            CpsStmt::FreshGuard { dest, .. } | CpsStmt::PeekGuard { dest } => {
200                values.insert(*dest);
201            }
202            CpsStmt::FindGuard { dest, guard } => {
203                require_value(function, &values, *guard)?;
204                values.insert(*dest);
205            }
206            CpsStmt::MakeThunk { dest, entry } => {
207                require_continuation(function, continuation_ids, *entry)?;
208                values.insert(*dest);
209            }
210            CpsStmt::AddThunkBoundary {
211                dest, thunk, guard, ..
212            } => {
213                require_value(function, &values, *thunk)?;
214                require_value(function, &values, *guard)?;
215                values.insert(*dest);
216            }
217            CpsStmt::MakeClosure { dest, entry } => {
218                require_continuation(function, continuation_ids, *entry)?;
219                values.insert(*dest);
220            }
221            CpsStmt::MakeRecursiveClosure { dest, entry } => {
222                require_continuation(function, continuation_ids, *entry)?;
223                values.insert(*dest);
224            }
225            CpsStmt::ForceThunk { dest, thunk } => {
226                require_value(function, &values, *thunk)?;
227                values.insert(*dest);
228            }
229            CpsStmt::Tuple { dest, items } => {
230                for item in items {
231                    require_value(function, &values, *item)?;
232                }
233                values.insert(*dest);
234            }
235            CpsStmt::Record { dest, base, fields } => {
236                if let Some(base) = base {
237                    require_value(function, &values, *base)?;
238                }
239                for field in fields {
240                    require_value(function, &values, field.value)?;
241                }
242                values.insert(*dest);
243            }
244            CpsStmt::RecordWithoutFields { dest, base, .. } => {
245                require_value(function, &values, *base)?;
246                values.insert(*dest);
247            }
248            CpsStmt::Variant { dest, value, .. } => {
249                if let Some(value) = value {
250                    require_value(function, &values, *value)?;
251                }
252                values.insert(*dest);
253            }
254            CpsStmt::Select { dest, base, .. } => {
255                require_value(function, &values, *base)?;
256                values.insert(*dest);
257            }
258            CpsStmt::SelectWithDefault {
259                dest,
260                base,
261                default,
262                ..
263            } => {
264                require_value(function, &values, *base)?;
265                require_value(function, &values, *default)?;
266                values.insert(*dest);
267            }
268            CpsStmt::RecordHasField { dest, base, .. } => {
269                require_value(function, &values, *base)?;
270                values.insert(*dest);
271            }
272            CpsStmt::TupleGet { dest, tuple, .. } => {
273                require_value(function, &values, *tuple)?;
274                values.insert(*dest);
275            }
276            CpsStmt::VariantTagEq { dest, variant, .. }
277            | CpsStmt::VariantPayload { dest, variant, .. } => {
278                require_value(function, &values, *variant)?;
279                values.insert(*dest);
280            }
281            CpsStmt::Primitive { dest, args, .. } | CpsStmt::DirectCall { dest, args, .. } => {
282                for arg in args {
283                    require_value(function, &values, *arg)?;
284                }
285                values.insert(*dest);
286            }
287            CpsStmt::ApplyClosure { dest, closure, arg } => {
288                require_value(function, &values, *closure)?;
289                require_value(function, &values, *arg)?;
290                values.insert(*dest);
291            }
292            CpsStmt::CloneContinuation { dest, source } => {
293                require_value(function, &values, *source)?;
294                values.insert(*dest);
295            }
296            CpsStmt::Resume {
297                dest,
298                resumption,
299                arg,
300            } => {
301                require_value(function, &values, *resumption)?;
302                require_value(function, &values, *arg)?;
303                values.insert(*dest);
304            }
305            CpsStmt::ResumeWithHandler {
306                dest,
307                resumption,
308                arg,
309                envs,
310                ..
311            } => {
312                require_value(function, &values, *resumption)?;
313                require_value(function, &values, *arg)?;
314                for env in envs {
315                    for value in &env.values {
316                        require_value(function, &values, *value)?;
317                    }
318                }
319                values.insert(*dest);
320            }
321            CpsStmt::InstallHandler { envs, .. } => {
322                for env in envs {
323                    for value in &env.values {
324                        require_value(function, &values, *value)?;
325                    }
326                }
327            }
328            CpsStmt::UninstallHandler { .. } => {}
329        }
330    }
331
332    match &continuation.terminator {
333        CpsTerminator::Return(value) => require_value(function, &values, *value),
334        CpsTerminator::Continue { target, args } => {
335            let target_cont = function
336                .continuations
337                .iter()
338                .find(|continuation| continuation.id == *target)
339                .ok_or_else(|| CpsValidateError::MissingContinuation {
340                    function: function.name.clone(),
341                    id: *target,
342                })?;
343            if target_cont.params.len() != args.len() {
344                return Err(CpsValidateError::ContinuationArityMismatch {
345                    function: function.name.clone(),
346                    id: *target,
347                    expected: target_cont.params.len(),
348                    actual: args.len(),
349                });
350            }
351            for arg in args {
352                require_value(function, &values, *arg)?;
353            }
354            Ok(())
355        }
356        CpsTerminator::Branch {
357            cond,
358            then_cont,
359            else_cont,
360        } => {
361            require_value(function, &values, *cond)?;
362            require_continuation(function, continuation_ids, *then_cont)?;
363            require_continuation(function, continuation_ids, *else_cont)
364        }
365        CpsTerminator::Perform {
366            payload,
367            resume,
368            blocked,
369            handler,
370            ..
371        } => {
372            require_value(function, &values, *payload)?;
373            if let Some(blocked) = blocked {
374                require_value(function, &values, *blocked)?;
375            }
376            require_continuation(function, continuation_ids, *resume)?;
377            if handler.0 == usize::MAX {
378                Ok(())
379            } else {
380                require_handler(function, handler_ids, *handler)
381            }
382        }
383        CpsTerminator::EffectfulCall { args, resume, .. } => {
384            for arg in args {
385                require_value(function, &values, *arg)?;
386            }
387            require_continuation(function, continuation_ids, *resume)
388        }
389        CpsTerminator::EffectfulApply {
390            closure,
391            arg,
392            resume,
393        } => {
394            require_value(function, &values, *closure)?;
395            require_value(function, &values, *arg)?;
396            require_continuation(function, continuation_ids, *resume)
397        }
398        CpsTerminator::EffectfulForce { thunk, resume } => {
399            require_value(function, &values, *thunk)?;
400            require_continuation(function, continuation_ids, *resume)
401        }
402    }
403}
404
405fn require_value(
406    function: &CpsFunction,
407    values: &HashSet<CpsValueId>,
408    id: CpsValueId,
409) -> Result<(), CpsValidateError> {
410    if values.contains(&id) {
411        Ok(())
412    } else {
413        Err(CpsValidateError::MissingValue {
414            function: function.name.clone(),
415            id,
416        })
417    }
418}
419
420fn require_continuation(
421    function: &CpsFunction,
422    continuation_ids: &HashSet<CpsContinuationId>,
423    id: CpsContinuationId,
424) -> Result<(), CpsValidateError> {
425    if continuation_ids.contains(&id) {
426        Ok(())
427    } else {
428        Err(CpsValidateError::MissingContinuation {
429            function: function.name.clone(),
430            id,
431        })
432    }
433}
434
435fn require_handler(
436    function: &CpsFunction,
437    handler_ids: &HashSet<CpsHandlerId>,
438    id: CpsHandlerId,
439) -> Result<(), CpsValidateError> {
440    if handler_ids.contains(&id) {
441        Ok(())
442    } else {
443        Err(CpsValidateError::MissingHandler {
444            function: function.name.clone(),
445            id,
446        })
447    }
448}