use std::collections::HashSet;
use std::fmt;
use crate::cps_ir::{
CpsContinuation, CpsContinuationId, CpsFunction, CpsHandlerId, CpsModule, CpsStmt,
CpsTerminator, CpsValueId,
};
/// Structural error reported by [`validate_cps_module`] for a CPS module.
///
/// Every variant carries, in `function`, the name of the CPS function in which
/// the problem was found.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CpsValidateError {
    /// The function's declared `entry` id is not the id of any of its continuations.
    MissingEntry {
        function: String,
        entry: CpsContinuationId,
    },
    /// Two continuations of the function share the id `id`.
    DuplicateContinuation {
        function: String,
        id: CpsContinuationId,
    },
    /// A handler arm, statement, or terminator refers to continuation `id`,
    /// which the function does not define.
    MissingContinuation {
        function: String,
        id: CpsContinuationId,
    },
    /// Two handlers of the function share the id `id`.
    DuplicateHandler {
        function: String,
        id: CpsHandlerId,
    },
    /// A `Perform` terminator refers to handler `id`, which the function does
    /// not define.
    MissingHandler {
        function: String,
        id: CpsHandlerId,
    },
    /// A `Continue` terminator passes `actual` arguments to continuation `id`,
    /// whose parameter list has length `expected`.
    ContinuationArityMismatch {
        function: String,
        id: CpsContinuationId,
        expected: usize,
        actual: usize,
    },
    /// Value `id` is used where it is not visible. Either it is not a parameter,
    /// capture, or earlier definition of the using continuation, or (for a
    /// capture) it is not defined anywhere in the function.
    MissingValue {
        function: String,
        id: CpsValueId,
    },
}
impl fmt::Display for CpsValidateError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CpsValidateError::MissingEntry { function, entry } => {
write!(
f,
"CPS function {function} has no entry continuation {entry:?}"
)
}
CpsValidateError::DuplicateContinuation { function, id } => {
write!(
f,
"CPS function {function} defines continuation {id:?} twice"
)
}
CpsValidateError::MissingContinuation { function, id } => {
write!(
f,
"CPS function {function} references missing continuation {id:?}"
)
}
CpsValidateError::DuplicateHandler { function, id } => {
write!(f, "CPS function {function} defines handler {id:?} twice")
}
CpsValidateError::MissingHandler { function, id } => {
write!(
f,
"CPS function {function} references missing handler {id:?}"
)
}
CpsValidateError::ContinuationArityMismatch {
function,
id,
expected,
actual,
} => write!(
f,
"CPS function {function} calls continuation {id:?} with {actual} arguments, expected {expected}"
),
CpsValidateError::MissingValue { function, id } => {
write!(f, "CPS function {function} references missing value {id:?}")
}
}
}
}
impl std::error::Error for CpsValidateError {}
pub fn validate_cps_module(module: &CpsModule) -> Result<(), CpsValidateError> {
for function in module.functions.iter().chain(&module.roots) {
validate_function(function)?;
}
Ok(())
}
fn validate_function(function: &CpsFunction) -> Result<(), CpsValidateError> {
let mut continuation_ids = HashSet::new();
for continuation in &function.continuations {
if !continuation_ids.insert(continuation.id) {
return Err(CpsValidateError::DuplicateContinuation {
function: function.name.clone(),
id: continuation.id,
});
}
}
if !continuation_ids.contains(&function.entry) {
return Err(CpsValidateError::MissingEntry {
function: function.name.clone(),
entry: function.entry,
});
}
let mut handler_ids = HashSet::new();
for handler in &function.handlers {
if !handler_ids.insert(handler.id) {
return Err(CpsValidateError::DuplicateHandler {
function: function.name.clone(),
id: handler.id,
});
}
for arm in &handler.arms {
require_continuation(function, &continuation_ids, arm.entry)?;
}
}
let defined_values = function_defined_values(function);
for continuation in &function.continuations {
validate_continuation(
function,
continuation,
&continuation_ids,
&handler_ids,
&defined_values,
)?;
}
Ok(())
}
fn function_defined_values(function: &CpsFunction) -> HashSet<CpsValueId> {
let mut values = function.params.iter().copied().collect::<HashSet<_>>();
for continuation in &function.continuations {
values.extend(continuation.params.iter().copied());
for stmt in &continuation.stmts {
match stmt {
CpsStmt::Literal { dest, .. }
| CpsStmt::FreshGuard { dest, .. }
| CpsStmt::PeekGuard { dest }
| CpsStmt::FindGuard { dest, .. }
| CpsStmt::MakeThunk { dest, .. }
| CpsStmt::AddThunkBoundary { dest, .. }
| CpsStmt::MakeClosure { dest, .. }
| CpsStmt::MakeRecursiveClosure { dest, .. }
| CpsStmt::ForceThunk { dest, .. }
| CpsStmt::Tuple { dest, .. }
| CpsStmt::Record { dest, .. }
| CpsStmt::RecordWithoutFields { dest, .. }
| CpsStmt::Variant { dest, .. }
| CpsStmt::Select { dest, .. }
| CpsStmt::SelectWithDefault { dest, .. }
| CpsStmt::RecordHasField { dest, .. }
| CpsStmt::TupleGet { dest, .. }
| CpsStmt::VariantTagEq { dest, .. }
| CpsStmt::VariantPayload { dest, .. }
| CpsStmt::Primitive { dest, .. }
| CpsStmt::DirectCall { dest, .. }
| CpsStmt::ApplyClosure { dest, .. }
| CpsStmt::CloneContinuation { dest, .. }
| CpsStmt::Resume { dest, .. }
| CpsStmt::ResumeWithHandler { dest, .. } => {
values.insert(*dest);
}
CpsStmt::InstallHandler { .. } | CpsStmt::UninstallHandler { .. } => {}
}
}
}
values
}
/// Validates the body of one continuation belonging to `function`.
///
/// A statement or terminator may only use values that are this continuation's
/// parameters, its captures, or destinations of earlier statements in the same
/// continuation. Each capture must additionally appear in `defined_values`,
/// i.e. be defined somewhere in `function`. Continuation references must be in
/// `continuation_ids`. A `Perform` handler must be in `handler_ids` unless its
/// index equals `usize::MAX`, which is never looked up. A `Continue` must
/// target an existing continuation with exactly as many parameters as it
/// passes arguments.
///
/// Operands are checked in source order.
///
/// # Errors
///
/// Returns the first [`CpsValidateError`] found.
fn validate_continuation(
    function: &CpsFunction,
    continuation: &CpsContinuation,
    continuation_ids: &HashSet<CpsContinuationId>,
    handler_ids: &HashSet<CpsHandlerId>,
    defined_values: &HashSet<CpsValueId>,
) -> Result<(), CpsValidateError> {
    /// Checks, in iteration order, that every id yielded by `ids` is in `scope`.
    fn require_all<'a>(
        function: &CpsFunction,
        scope: &HashSet<CpsValueId>,
        ids: impl IntoIterator<Item = &'a CpsValueId>,
    ) -> Result<(), CpsValidateError> {
        ids.into_iter()
            .try_for_each(|&id| require_value(function, scope, id))
    }

    let mut in_scope: HashSet<CpsValueId> = continuation.params.iter().copied().collect();
    for &captured in &continuation.captures {
        require_value(function, defined_values, captured)?;
        in_scope.insert(captured);
    }

    for stmt in &continuation.stmts {
        match stmt {
            // Definitions without value operands.
            CpsStmt::Literal { dest, .. }
            | CpsStmt::FreshGuard { dest, .. }
            | CpsStmt::PeekGuard { dest } => {
                in_scope.insert(*dest);
            }
            // Definitions that point at a continuation entry.
            CpsStmt::MakeThunk { dest, entry }
            | CpsStmt::MakeClosure { dest, entry }
            | CpsStmt::MakeRecursiveClosure { dest, entry } => {
                require_continuation(function, continuation_ids, *entry)?;
                in_scope.insert(*dest);
            }
            // Definitions with exactly one value operand.
            CpsStmt::FindGuard {
                dest,
                guard: operand,
            }
            | CpsStmt::ForceThunk {
                dest,
                thunk: operand,
            }
            | CpsStmt::CloneContinuation {
                dest,
                source: operand,
            }
            | CpsStmt::RecordWithoutFields {
                dest,
                base: operand,
                ..
            }
            | CpsStmt::Select {
                dest,
                base: operand,
                ..
            }
            | CpsStmt::RecordHasField {
                dest,
                base: operand,
                ..
            }
            | CpsStmt::TupleGet {
                dest,
                tuple: operand,
                ..
            }
            | CpsStmt::VariantTagEq {
                dest,
                variant: operand,
                ..
            }
            | CpsStmt::VariantPayload {
                dest,
                variant: operand,
                ..
            } => {
                require_value(function, &in_scope, *operand)?;
                in_scope.insert(*dest);
            }
            // Definitions with exactly two value operands, checked first-to-second.
            CpsStmt::AddThunkBoundary {
                dest,
                thunk: first,
                guard: second,
                ..
            }
            | CpsStmt::SelectWithDefault {
                dest,
                base: first,
                default: second,
                ..
            }
            | CpsStmt::ApplyClosure {
                dest,
                closure: first,
                arg: second,
            }
            | CpsStmt::Resume {
                dest,
                resumption: first,
                arg: second,
            } => {
                require_value(function, &in_scope, *first)?;
                require_value(function, &in_scope, *second)?;
                in_scope.insert(*dest);
            }
            CpsStmt::Tuple { dest, items } => {
                require_all(function, &in_scope, items)?;
                in_scope.insert(*dest);
            }
            CpsStmt::Record { dest, base, fields } => {
                // `base` is optional; iterating an `Option` visits it at most once.
                require_all(function, &in_scope, base)?;
                for field in fields {
                    require_value(function, &in_scope, field.value)?;
                }
                in_scope.insert(*dest);
            }
            CpsStmt::Variant { dest, value, .. } => {
                require_all(function, &in_scope, value)?;
                in_scope.insert(*dest);
            }
            CpsStmt::Primitive { dest, args, .. } | CpsStmt::DirectCall { dest, args, .. } => {
                require_all(function, &in_scope, args)?;
                in_scope.insert(*dest);
            }
            CpsStmt::ResumeWithHandler {
                dest,
                resumption,
                arg,
                envs,
                ..
            } => {
                require_value(function, &in_scope, *resumption)?;
                require_value(function, &in_scope, *arg)?;
                for env in envs {
                    require_all(function, &in_scope, &env.values)?;
                }
                in_scope.insert(*dest);
            }
            CpsStmt::InstallHandler { envs, .. } => {
                for env in envs {
                    require_all(function, &in_scope, &env.values)?;
                }
            }
            CpsStmt::UninstallHandler { .. } => {}
        }
    }

    match &continuation.terminator {
        CpsTerminator::Return(value) => require_value(function, &in_scope, *value),
        CpsTerminator::Continue { target, args } => {
            let target_cont = match function
                .continuations
                .iter()
                .find(|candidate| candidate.id == *target)
            {
                Some(found) => found,
                None => {
                    return Err(CpsValidateError::MissingContinuation {
                        function: function.name.clone(),
                        id: *target,
                    })
                }
            };
            // Arity is reported before any undefined argument value.
            let expected = target_cont.params.len();
            let actual = args.len();
            if expected != actual {
                return Err(CpsValidateError::ContinuationArityMismatch {
                    function: function.name.clone(),
                    id: *target,
                    expected,
                    actual,
                });
            }
            require_all(function, &in_scope, args)
        }
        CpsTerminator::Branch {
            cond,
            then_cont,
            else_cont,
        } => {
            require_value(function, &in_scope, *cond)?;
            require_continuation(function, continuation_ids, *then_cont)?;
            require_continuation(function, continuation_ids, *else_cont)
        }
        CpsTerminator::Perform {
            payload,
            resume,
            blocked,
            handler,
            ..
        } => {
            require_value(function, &in_scope, *payload)?;
            require_all(function, &in_scope, blocked)?;
            require_continuation(function, continuation_ids, *resume)?;
            // A handler index of `usize::MAX` is skipped rather than looked up.
            // Presumably it means "no statically known handler"; confirm in cps_ir.
            if handler.0 != usize::MAX {
                require_handler(function, handler_ids, *handler)?;
            }
            Ok(())
        }
        CpsTerminator::EffectfulCall { args, resume, .. } => {
            require_all(function, &in_scope, args)?;
            require_continuation(function, continuation_ids, *resume)
        }
        CpsTerminator::EffectfulApply {
            closure,
            arg,
            resume,
        } => {
            require_value(function, &in_scope, *closure)?;
            require_value(function, &in_scope, *arg)?;
            require_continuation(function, continuation_ids, *resume)
        }
        CpsTerminator::EffectfulForce { thunk, resume } => {
            require_value(function, &in_scope, *thunk)?;
            require_continuation(function, continuation_ids, *resume)
        }
    }
}
/// Succeeds when `id` is a member of `values`.
///
/// # Errors
///
/// Returns [`CpsValidateError::MissingValue`] tagged with `function`'s name
/// when `id` is absent.
fn require_value(
    function: &CpsFunction,
    values: &HashSet<CpsValueId>,
    id: CpsValueId,
) -> Result<(), CpsValidateError> {
    if !values.contains(&id) {
        return Err(CpsValidateError::MissingValue {
            function: function.name.clone(),
            id,
        });
    }
    Ok(())
}
/// Succeeds when `id` is one of the continuation ids in `continuation_ids`.
///
/// # Errors
///
/// Returns [`CpsValidateError::MissingContinuation`] tagged with `function`'s
/// name when `id` is absent.
fn require_continuation(
    function: &CpsFunction,
    continuation_ids: &HashSet<CpsContinuationId>,
    id: CpsContinuationId,
) -> Result<(), CpsValidateError> {
    if !continuation_ids.contains(&id) {
        return Err(CpsValidateError::MissingContinuation {
            function: function.name.clone(),
            id,
        });
    }
    Ok(())
}
/// Succeeds when `id` is one of the handler ids in `handler_ids`.
///
/// # Errors
///
/// Returns [`CpsValidateError::MissingHandler`] tagged with `function`'s name
/// when `id` is absent.
fn require_handler(
    function: &CpsFunction,
    handler_ids: &HashSet<CpsHandlerId>,
    id: CpsHandlerId,
) -> Result<(), CpsValidateError> {
    if !handler_ids.contains(&id) {
        return Err(CpsValidateError::MissingHandler {
            function: function.name.clone(),
            id,
        });
    }
    Ok(())
}