use crate::types::{Effect, StackType, Type};
/// Determines which input types a closure has to capture when it is created.
///
/// The concrete input types of `body_effect` and `call_effect` are gathered
/// bottom-to-top with [`extract_concrete_types`]. Every input the body lists
/// beyond what the call site lists is treated as a capture, and the captured
/// types are taken from the bottom of the body's input list, in that same
/// bottom-to-top order.
///
/// Only the *number* of inputs is compared here; this function does not check
/// that the call-site types agree with the body's remaining input types.
///
/// # Errors
///
/// Returns a descriptive message when `call_effect` lists more concrete
/// inputs than `body_effect` does.
pub fn calculate_captures(body_effect: &Effect, call_effect: &Effect) -> Result<Vec<Type>, String> {
    let needed = extract_concrete_types(&body_effect.inputs);
    let provided = extract_concrete_types(&call_effect.inputs).len();

    // `checked_sub` yields `None` exactly when the call site supplies more
    // values than the body consumes, which is the only error case.
    match needed.len().checked_sub(provided) {
        Some(capture_count) => Ok(needed.into_iter().take(capture_count).collect()),
        None => Err(format!(
            "Closure signature error: call site provides {} values but body only needs {}",
            provided,
            needed.len()
        )),
    }
}
/// Flattens a stack type into the list of its element types, ordered from the
/// bottom of the stack to the top.
///
/// The walk follows `Cons` cells until it reaches `Empty` or a `RowVar`; both
/// terminators contribute nothing to the result, so a row variable at the
/// base of the stack is simply ignored.
pub(crate) fn extract_concrete_types(stack: &StackType) -> Vec<Type> {
    let mut collected = Vec::new();
    let mut cursor = stack;

    // Walk from the top of the stack downwards, recording each element type.
    loop {
        match cursor {
            StackType::Cons { rest, top } => {
                collected.push(top.clone());
                cursor = &**rest;
            }
            StackType::Empty | StackType::RowVar(_) => break,
        }
    }

    // Elements were gathered top-first; flip them into bottom-to-top order.
    collected.reverse();
    collected
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Effect, StackType, Type};

    /// Builds a stack from `types`, where the first element ends up at the
    /// bottom and the last element ends up on top.
    fn stack_of(types: &[Type]) -> StackType {
        types
            .iter()
            .cloned()
            .fold(StackType::Empty, |below, top| StackType::Cons {
                rest: Box::new(below),
                top,
            })
    }

    /// Builds an effect with the given input and output stacks and an empty
    /// `effects` list.
    fn effect_of(inputs: &[Type], outputs: &[Type]) -> Effect {
        let inputs = stack_of(inputs);
        let outputs = stack_of(outputs);
        Effect {
            inputs,
            outputs,
            effects: Vec::new(),
        }
    }

    /// An empty stack has no concrete types.
    #[test]
    fn test_extract_empty_stack() {
        assert!(extract_concrete_types(&StackType::Empty).is_empty());
    }

    /// A one-element stack yields exactly that element.
    #[test]
    fn test_extract_single_type() {
        let extracted = extract_concrete_types(&stack_of(&[Type::Int]));
        assert_eq!(extracted, vec![Type::Int]);
    }

    /// Extraction preserves bottom-to-top order.
    #[test]
    fn test_extract_multiple_types() {
        let expected = vec![Type::Int, Type::String, Type::Bool];
        let extracted = extract_concrete_types(&stack_of(&expected));
        assert_eq!(extracted, expected);
    }

    /// Matching input counts mean nothing has to be captured.
    #[test]
    fn test_calculate_no_captures() {
        let body = effect_of(&[Type::Int], &[Type::Int]);
        let call = effect_of(&[Type::Int], &[Type::Int]);
        let captured = calculate_captures(&body, &call).unwrap();
        assert!(captured.is_empty());
    }

    /// One extra body input results in one captured type.
    #[test]
    fn test_calculate_one_capture() {
        let body = effect_of(&[Type::Int, Type::Int], &[Type::Int]);
        let call = effect_of(&[Type::Int], &[Type::Int]);
        let captured = calculate_captures(&body, &call).unwrap();
        assert_eq!(captured, vec![Type::Int]);
    }

    /// Captures are taken from the bottom of the body's inputs.
    #[test]
    fn test_calculate_multiple_captures() {
        let body = effect_of(&[Type::Int, Type::String, Type::Bool], &[Type::Bool]);
        let call = effect_of(&[Type::Bool], &[Type::Bool]);
        let captured = calculate_captures(&body, &call).unwrap();
        assert_eq!(captured, vec![Type::Int, Type::String]);
    }

    /// A call site with no inputs forces every body input to be captured.
    #[test]
    fn test_calculate_all_captured() {
        let body = effect_of(&[Type::Int, Type::String], &[Type::Int]);
        let call = effect_of(&[], &[Type::Int]);
        let captured = calculate_captures(&body, &call).unwrap();
        assert_eq!(captured, vec![Type::Int, Type::String]);
    }

    /// Supplying more call-site inputs than the body needs is an error.
    #[test]
    fn test_calculate_error_too_many_call_inputs() {
        let body = effect_of(&[Type::Int], &[Type::Int]);
        let call = effect_of(&[Type::Int, Type::Int], &[Type::Int]);
        let outcome = calculate_captures(&body, &call);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("provides 2 values but body only needs 1"));
    }
}