use std::collections::{HashMap, HashSet};
use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use super::{DefId, Ty};
/// Identifier of an effect-row variable, rendered as `?E<n>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct EffectVarId(pub u32);

impl EffectVarId {
    /// Allocates a new identifier from a process-wide atomic counter.
    ///
    /// Successive calls return strictly increasing ids (until the `u32`
    /// counter wraps).
    pub fn fresh() -> Self {
        static NEXT_ID: AtomicU32 = AtomicU32::new(0);
        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
        EffectVarId(id)
    }
}

impl fmt::Display for EffectVarId {
    /// Formats the variable as `?E` followed by its numeric id.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let EffectVarId(raw) = *self;
        write!(f, "?E{}", raw)
    }
}
/// A single effect label, optionally applied to type arguments
/// (for example `IO`, or `Error<E>`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Effect {
    pub name: Arc<str>,
    pub params: Vec<Ty>,
}

impl Effect {
    /// Creates an effect called `name` with no type arguments.
    pub fn new(name: impl Into<Arc<str>>) -> Self {
        Self::with_params(name, Vec::new())
    }

    /// Creates an effect called `name` applied to `params`.
    pub fn with_params(name: impl Into<Arc<str>>, params: Vec<Ty>) -> Self {
        let name = name.into();
        Effect { name, params }
    }

    /// The built-in `IO` effect.
    pub fn io() -> Self {
        Self::new("IO")
    }

    /// The built-in `Error<err_ty>` effect.
    pub fn error(err_ty: Ty) -> Self {
        Self::with_params("Error", vec![err_ty])
    }

    /// The built-in `Async` effect.
    pub fn async_effect() -> Self {
        Self::new("Async")
    }

    /// The built-in `State<state_ty>` effect.
    pub fn state(state_ty: Ty) -> Self {
        Self::with_params("State", vec![state_ty])
    }

    /// The built-in `NonDet` effect.
    pub fn nondet() -> Self {
        Self::new("NonDet")
    }

    /// An effect literally named `Pure`.
    ///
    /// NOTE(review): `EffectRow::empty()` is what represents purity elsewhere
    /// in this module; this label is just an ordinary named effect.
    pub fn pure() -> Self {
        Self::new("Pure")
    }

    /// Returns `true` if this effect is named `IO` (type arguments ignored).
    pub fn is_io(&self) -> bool {
        self.has_name("IO")
    }

    /// Returns `true` if this effect is named `Error` (type arguments ignored).
    pub fn is_error(&self) -> bool {
        self.has_name("Error")
    }

    /// Returns `true` if this effect is named `Async` (type arguments ignored).
    pub fn is_async(&self) -> bool {
        self.has_name("Async")
    }

    /// Compares the effect's name against `expected`.
    fn has_name(&self, expected: &str) -> bool {
        &*self.name == expected
    }
}
impl fmt::Display for Effect {
    /// Renders `Name` or `Name<P1, P2, ...>` when type arguments are present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)?;
        if let Some((first, rest)) = self.params.split_first() {
            write!(f, "<{}", first)?;
            for param in rest {
                write!(f, ", {}", param)?;
            }
            f.write_str(">")?;
        }
        Ok(())
    }
}
/// A row of effects: a set of concrete effects plus an optional tail variable.
///
/// A row with `tail: None` is closed (exactly these effects). A row with
/// `tail: Some(v)` is open and is extended by whatever `v` is bound to in an
/// [`EffectSubstitution`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EffectRow {
    pub effects: HashSet<Effect>,
    pub tail: Option<EffectVarId>,
}

impl Hash for EffectRow {
    /// Hashes the row independently of `HashSet` iteration order.
    ///
    /// Equality is set equality, so the hash must not depend on the order the
    /// effects are visited in. Sorting by name alone is not enough, because
    /// effects may share a name and differ only in type arguments
    /// (`Error<A>` vs `Error<B>`). Instead each effect is hashed on its own and
    /// the sorted digests are fed to `state`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let mut digests: Vec<u64> = self
            .effects
            .iter()
            .map(|eff| {
                let mut h = std::collections::hash_map::DefaultHasher::new();
                eff.hash(&mut h);
                h.finish()
            })
            .collect();
        digests.sort_unstable();
        digests.hash(state);
        self.tail.hash(state);
    }
}
impl EffectRow {
    /// Returns the closed row with no effects (displayed as `Pure`).
    pub fn empty() -> Self {
        Self {
            effects: HashSet::new(),
            tail: None,
        }
    }

    /// Builds a closed row containing exactly `effects`.
    pub fn closed(effects: impl IntoIterator<Item = Effect>) -> Self {
        Self {
            effects: effects.into_iter().collect(),
            tail: None,
        }
    }

    /// Builds an open row containing `effects`, followed by the variable `tail`.
    pub fn open(effects: impl IntoIterator<Item = Effect>, tail: EffectVarId) -> Self {
        Self {
            effects: effects.into_iter().collect(),
            tail: Some(tail),
        }
    }

    /// Builds a row that is only the variable `var`.
    pub fn var(var: EffectVarId) -> Self {
        Self {
            effects: HashSet::new(),
            tail: Some(var),
        }
    }

    /// Builds a row that is only a newly allocated variable.
    pub fn fresh_var() -> Self {
        Self::var(EffectVarId::fresh())
    }

    /// Returns `true` for the closed row with no effects.
    pub fn is_empty(&self) -> bool {
        self.effects.is_empty() && self.tail.is_none()
    }

    /// Returns `true` when the row has no tail variable.
    pub fn is_closed(&self) -> bool {
        self.tail.is_none()
    }

    /// Returns `true` when the row ends in a tail variable.
    pub fn is_open(&self) -> bool {
        self.tail.is_some()
    }

    /// Returns `true` if `effect` (name and type arguments) is in the row's
    /// concrete set. The tail variable is not consulted.
    pub fn contains(&self, effect: &Effect) -> bool {
        self.effects.contains(effect)
    }

    /// Returns `true` if any concrete effect is named `IO`.
    pub fn has_io(&self) -> bool {
        self.effects.iter().any(|e| e.is_io())
    }

    /// Returns `true` if any concrete effect is named `Error`.
    pub fn has_error(&self) -> bool {
        self.effects.iter().any(|e| e.is_error())
    }

    /// Returns `true` if any concrete effect is named `Async`.
    pub fn has_async(&self) -> bool {
        self.effects.iter().any(|e| e.is_async())
    }

    /// Inserts `effect` into the concrete set (no-op if already present).
    pub fn add(&mut self, effect: Effect) {
        self.effects.insert(effect);
    }

    /// Removes `effect` from the concrete set, returning whether it was present.
    pub fn remove(&mut self, effect: &Effect) -> bool {
        self.effects.remove(effect)
    }

    /// Returns a row containing the union of both rows' concrete effects.
    ///
    /// Tails: an identical tail, or a tail present on only one side, is kept.
    /// Two distinct tails are replaced by a freshly allocated variable.
    pub fn merge(&self, other: &EffectRow) -> EffectRow {
        let effects: HashSet<_> = self.effects.union(&other.effects).cloned().collect();
        let tail = match (self.tail, other.tail) {
            (Some(v1), Some(v2)) if v1 == v2 => Some(v1),
            (Some(v), None) | (None, Some(v)) => Some(v),
            // NOTE(review): the fresh variable is not linked to either input
            // tail by any substitution here.
            (Some(_), Some(_)) => Some(EffectVarId::fresh()),
            (None, None) => None,
        };
        EffectRow { effects, tail }
    }

    /// Applies `subst` to this row, resolving the tail variable transitively.
    ///
    /// While the current tail is bound in `subst`, the bound row's effects are
    /// added and its tail becomes the new tail. Chasing the whole chain matters
    /// because unification binds variables to rows that end in other, later-bound
    /// variables. A cycle (a variable reachable from its own binding) stops the
    /// walk and leaves that variable as the tail.
    pub fn substitute(&self, subst: &EffectSubstitution) -> EffectRow {
        let mut effects = self.effects.clone();
        let mut tail = self.tail;
        let mut visited = HashSet::new();
        while let Some(var) = tail {
            if !visited.insert(var) {
                break;
            }
            match subst.get(var) {
                Some(row) => {
                    effects.extend(row.effects.iter().cloned());
                    tail = row.tail;
                }
                None => break,
            }
        }
        EffectRow { effects, tail }
    }
}
impl fmt::Display for EffectRow {
    /// Renders the row as `Pure`, `A, B`, `?E3`, or `A, B | ?E3`.
    ///
    /// Effects are sorted by their rendered text so the output does not depend
    /// on `HashSet` iteration order. Diagnostics stay stable across runs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_empty() {
            return write!(f, "Pure");
        }
        let mut rendered: Vec<String> = self.effects.iter().map(|e| e.to_string()).collect();
        rendered.sort();
        let listed = rendered.join(", ");
        match self.tail {
            Some(tail) if rendered.is_empty() => write!(f, "{}", tail),
            Some(tail) => write!(f, "{} | {}", listed, tail),
            None => write!(f, "{}", listed),
        }
    }
}
/// Mapping from effect-row variables to the rows they are bound to.
#[derive(Debug, Clone, Default)]
pub struct EffectSubstitution {
    map: HashMap<EffectVarId, EffectRow>,
}

impl EffectSubstitution {
    /// Creates an empty substitution.
    pub fn new() -> Self {
        EffectSubstitution {
            map: HashMap::new(),
        }
    }

    /// Binds `var` to `row`, replacing any previous binding.
    pub fn insert(&mut self, var: EffectVarId, row: EffectRow) {
        let _previous = self.map.insert(var, row);
    }

    /// Returns the row bound to `var`, if any.
    pub fn get(&self, var: EffectVarId) -> Option<&EffectRow> {
        self.map.get(&var)
    }

    /// Returns `true` if `var` has a binding.
    pub fn contains(&self, var: EffectVarId) -> bool {
        self.get(var).is_some()
    }

    /// Composes two substitutions.
    ///
    /// Every binding of `other` is rewritten by applying `self` to its row.
    /// Bindings of `self` whose variable `other` does not bind are then copied in
    /// unchanged. A variable bound by both keeps `other`'s (rewritten) row.
    pub fn compose(&self, other: &EffectSubstitution) -> EffectSubstitution {
        let mut map: HashMap<EffectVarId, EffectRow> = other
            .map
            .iter()
            .map(|(var, row)| (*var, row.substitute(self)))
            .collect();
        for (var, row) in &self.map {
            map.entry(*var).or_insert_with(|| row.clone());
        }
        EffectSubstitution { map }
    }
}
/// Declaration of an effect: its definition id, name, type parameters and
/// operations.
#[derive(Debug, Clone)]
pub struct EffectDef {
    pub def_id: DefId,
    pub name: Arc<str>,
    pub type_params: Vec<Arc<str>>,
    pub operations: Vec<EffectOperation>,
}

impl EffectDef {
    /// Creates a definition with no type parameters and no operations.
    pub fn new(def_id: DefId, name: impl Into<Arc<str>>) -> Self {
        let name: Arc<str> = name.into();
        EffectDef {
            def_id,
            name,
            type_params: vec![],
            operations: vec![],
        }
    }

    /// Builder step: appends a type parameter named `name`.
    pub fn with_type_param(self, name: impl Into<Arc<str>>) -> Self {
        let mut def = self;
        def.type_params.push(name.into());
        def
    }

    /// Builder step: appends the operation `op`.
    pub fn with_operation(self, op: EffectOperation) -> Self {
        let mut def = self;
        def.operations.push(op);
        def
    }
}
/// A single operation declared by an effect: its name, parameter types and
/// return type.
#[derive(Debug, Clone)]
pub struct EffectOperation {
    pub name: Arc<str>,
    pub params: Vec<Ty>,
    pub return_ty: Ty,
}

impl EffectOperation {
    /// Creates an operation signature `name(params) -> return_ty`.
    pub fn new(name: impl Into<Arc<str>>, params: Vec<Ty>, return_ty: Ty) -> Self {
        let name: Arc<str> = name.into();
        EffectOperation {
            name,
            params,
            return_ty,
        }
    }
}
/// A handler for one effect: a clause per handled operation plus an optional
/// return clause.
///
/// NOTE(review): the meaning of the two types in `return_clause` is not
/// established by this module (presumably input and output type) — confirm.
#[derive(Debug, Clone)]
pub struct EffectHandler {
    pub effect: Effect,
    pub clauses: Vec<HandlerClause>,
    pub return_clause: Option<(Ty, Ty)>,
}

/// One clause of a handler, binding the parameters of a named operation.
#[derive(Debug, Clone)]
pub struct HandlerClause {
    pub operation: Arc<str>,
    pub param_names: Vec<Arc<str>>,
    pub resumes: bool,
}

impl HandlerClause {
    /// Creates a clause for `operation`. `resumes` defaults to `true`.
    pub fn new(operation: impl Into<Arc<str>>, param_names: Vec<Arc<str>>) -> Self {
        let operation: Arc<str> = operation.into();
        HandlerClause {
            operation,
            param_names,
            resumes: true,
        }
    }

    /// Builder step: overrides the `resumes` flag.
    pub fn with_resume(self, resumes: bool) -> Self {
        HandlerClause { resumes, ..self }
    }
}
/// A function type annotated with the effect row it may perform.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EffectfulFn {
    pub params: Vec<Ty>,
    pub return_ty: Ty,
    pub effects: EffectRow,
}

impl EffectfulFn {
    /// Creates a function type with the given effect row.
    pub fn new(params: Vec<Ty>, return_ty: Ty, effects: EffectRow) -> Self {
        EffectfulFn {
            params,
            return_ty,
            effects,
        }
    }

    /// Creates a function type whose effect row is the empty closed row.
    pub fn pure(params: Vec<Ty>, return_ty: Ty) -> Self {
        let effects = EffectRow::empty();
        EffectfulFn {
            params,
            return_ty,
            effects,
        }
    }

    /// Returns `true` when the effect row is empty *and* closed.
    ///
    /// An open row with no concrete effects is not considered pure.
    pub fn is_pure(&self) -> bool {
        self.effects.effects.is_empty() && self.effects.tail.is_none()
    }
}
impl fmt::Display for EffectfulFn {
    /// Renders `fn(P1, P2) -> R`, with ` with <row>` appended when the row is
    /// not empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("fn(")?;
        let mut separator = "";
        for param in &self.params {
            write!(f, "{}{}", separator, param)?;
            separator = ", ";
        }
        write!(f, ") -> {}", self.return_ty)?;
        if self.effects.is_empty() {
            return Ok(());
        }
        write!(f, " with {}", self.effects)
    }
}
/// State used while checking effects: known effect definitions, the current
/// substitution for effect-row variables, and recorded constraints.
///
/// `Default` yields an empty context. `EffectContext::new` additionally
/// registers the built-in effects.
#[derive(Debug, Default)]
pub struct EffectContext {
    // Effect definitions keyed by effect name.
    effects: HashMap<Arc<str>, EffectDef>,
    // Variable bindings accumulated by `unify_rows`.
    subst: EffectSubstitution,
    // Constraints appended by `add_constraint`; nothing in the code shown reads them back.
    constraints: Vec<EffectConstraint>,
}
impl EffectContext {
pub fn new() -> Self {
let mut ctx = Self::default();
ctx.register_builtin_effects();
ctx
}
fn register_builtin_effects(&mut self) {
let io = EffectDef::new(DefId::new(0, 0), "IO")
.with_operation(EffectOperation::new("print", vec![Ty::str()], Ty::unit()))
.with_operation(EffectOperation::new("read_line", vec![], Ty::str()));
self.register_effect(io);
let error = EffectDef::new(DefId::new(0, 1), "Error")
.with_type_param("E")
.with_operation(EffectOperation::new(
"throw",
vec![Ty::param("E", 0)],
Ty::never(),
))
.with_operation(EffectOperation::new("catch", vec![], Ty::param("E", 0)));
self.register_effect(error);
let async_eff = EffectDef::new(DefId::new(0, 2), "Async")
.with_operation(EffectOperation::new(
"await",
vec![Ty::param("T", 0)],
Ty::param("T", 0),
))
.with_operation(EffectOperation::new("spawn", vec![], Ty::unit()));
self.register_effect(async_eff);
let state = EffectDef::new(DefId::new(0, 3), "State")
.with_type_param("S")
.with_operation(EffectOperation::new("get", vec![], Ty::param("S", 0)))
.with_operation(EffectOperation::new(
"put",
vec![Ty::param("S", 0)],
Ty::unit(),
))
.with_operation(EffectOperation::new(
"modify",
vec![Ty::function(vec![Ty::param("S", 0)], Ty::param("S", 0))],
Ty::unit(),
));
self.register_effect(state);
let nondet = EffectDef::new(DefId::new(0, 4), "NonDet")
.with_operation(EffectOperation::new(
"choose",
vec![Ty::param("T", 0), Ty::param("T", 0)],
Ty::param("T", 0),
))
.with_operation(EffectOperation::new("fail", vec![], Ty::never()));
self.register_effect(nondet);
}
pub fn register_effect(&mut self, effect: EffectDef) {
self.effects.insert(effect.name.clone(), effect);
}
pub fn get_effect(&self, name: &str) -> Option<&EffectDef> {
self.effects.get(name)
}
pub fn all_effects(&self) -> Vec<&EffectDef> {
self.effects.values().collect()
}
pub fn add_constraint(&mut self, constraint: EffectConstraint) {
self.constraints.push(constraint);
}
pub fn unify_rows(&mut self, r1: &EffectRow, r2: &EffectRow) -> Result<(), EffectError> {
let r1 = r1.substitute(&self.subst);
let r2 = r2.substitute(&self.subst);
match (r1.tail, r2.tail) {
(None, None) => {
if r1.effects == r2.effects {
Ok(())
} else {
Err(EffectError::Mismatch {
expected: r1.clone(),
found: r2.clone(),
})
}
}
(Some(v), None) => {
if r1.effects.is_subset(&r2.effects) {
let diff: HashSet<_> = r2.effects.difference(&r1.effects).cloned().collect();
self.subst.insert(v, EffectRow::closed(diff));
Ok(())
} else {
Err(EffectError::Mismatch {
expected: r1.clone(),
found: r2.clone(),
})
}
}
(None, Some(v)) => {
if r2.effects.is_subset(&r1.effects) {
let diff: HashSet<_> = r1.effects.difference(&r2.effects).cloned().collect();
self.subst.insert(v, EffectRow::closed(diff));
Ok(())
} else {
Err(EffectError::Mismatch {
expected: r1.clone(),
found: r2.clone(),
})
}
}
(Some(v1), Some(v2)) if v1 == v2 => {
if r1.effects == r2.effects {
Ok(())
} else {
Err(EffectError::Mismatch {
expected: r1.clone(),
found: r2.clone(),
})
}
}
(Some(v1), Some(v2)) => {
let fresh = EffectVarId::fresh();
let union: HashSet<_> = r1.effects.union(&r2.effects).cloned().collect();
let r1_diff: HashSet<_> = union.difference(&r1.effects).cloned().collect();
self.subst.insert(v1, EffectRow::open(r1_diff, fresh));
let r2_diff: HashSet<_> = union.difference(&r2.effects).cloned().collect();
self.subst.insert(v2, EffectRow::open(r2_diff, fresh));
Ok(())
}
}
}
pub fn subsumes(&self, r1: &EffectRow, r2: &EffectRow) -> bool {
let r1 = r1.substitute(&self.subst);
let r2 = r2.substitute(&self.subst);
if !r1.effects.is_subset(&r2.effects) {
return false;
}
match (r1.tail, r2.tail) {
(None, _) => true,
(Some(_), None) => false, (Some(v1), Some(v2)) => v1 == v2, }
}
pub fn apply_subst(&self, row: &EffectRow) -> EffectRow {
row.substitute(&self.subst)
}
}
/// A constraint on effects that can be recorded in an `EffectContext`.
///
/// NOTE(review): constraints are only stored in this module, never solved,
/// so the exact semantics of each variant are not established here — confirm
/// against the consumer.
#[derive(Debug, Clone)]
pub enum EffectConstraint {
    /// Presumably: the two rows must be equal (cf. `EffectContext::unify_rows`). TODO confirm.
    Equal(EffectRow, EffectRow),
    /// Presumably: a subsumption relation (cf. `EffectContext::subsumes`); argument order TODO confirm.
    Subsumes(EffectRow, EffectRow),
    /// Presumably: the given effect must be discharged by a handler. TODO confirm.
    MustHandle(Effect),
}
/// Errors produced by effect checking.
#[derive(Debug, Clone)]
pub enum EffectError {
    /// Two rows could not be unified; both rows are as resolved by the
    /// context's substitution (returned by `EffectContext::unify_rows`).
    Mismatch {
        expected: EffectRow,
        found: EffectRow,
    },
    /// An effect was left unhandled.
    Unhandled(Effect),
    /// An effect name did not resolve to a known definition.
    UnknownEffect(String),
    /// An operation name is not declared by the named effect.
    UnknownOperation { effect: String, operation: String },
    /// A handler was malformed; the string describes why.
    InvalidHandler(String),
}
impl fmt::Display for EffectError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
EffectError::Mismatch { expected, found } => {
write!(
f,
"effect mismatch: expected {{{}}}, found {{{}}}",
expected, found
)
}
EffectError::Unhandled(effect) => {
write!(f, "unhandled effect: {}", effect)
}
EffectError::UnknownEffect(name) => {
write!(f, "unknown effect: {}", name)
}
EffectError::UnknownOperation { effect, operation } => {
write!(
f,
"unknown operation '{}' in effect '{}'",
operation, effect
)
}
EffectError::InvalidHandler(msg) => {
write!(f, "invalid handler: {}", msg)
}
}
}
}
impl std::error::Error for EffectError {}
/// Returns the definitions of all built-in effects registered by
/// `EffectContext::new`, in unspecified (`HashMap`) order.
pub fn builtin_effects() -> Vec<EffectDef> {
    EffectContext::new().effects.into_values().collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Effects render as their name, with type arguments in angle brackets.
    #[test]
    fn test_effect_display() {
        assert_eq!(Effect::io().to_string(), "IO");
        assert_eq!(Effect::async_effect().to_string(), "Async");
        assert_eq!(Effect::error(Ty::str()).to_string(), "Error<str>");
    }

    /// Basic row predicates and rendering of empty and single-effect rows.
    #[test]
    fn test_effect_row() {
        let pure_row = EffectRow::empty();
        assert!(pure_row.is_empty());
        assert!(pure_row.is_closed());
        assert_eq!(pure_row.to_string(), "Pure");

        let io_only = EffectRow::closed(vec![Effect::io()]);
        assert!(!io_only.is_empty());
        assert!(io_only.has_io());
        assert_eq!(io_only.to_string(), "IO");

        let io_and_async = EffectRow::closed(vec![Effect::io(), Effect::async_effect()]);
        assert!(io_and_async.has_io());
        assert!(io_and_async.has_async());
    }

    /// Merging two closed rows yields a closed row with both effects.
    #[test]
    fn test_effect_row_merge() {
        let io_only = EffectRow::closed(vec![Effect::io()]);
        let async_only = EffectRow::closed(vec![Effect::async_effect()]);
        let combined = io_only.merge(&async_only);
        assert!(combined.has_io());
        assert!(combined.has_async());
        assert!(combined.is_closed());
    }

    /// An open row keeps its concrete effects and reports itself open.
    #[test]
    fn test_open_effect_row() {
        let row = EffectRow::open(vec![Effect::io()], EffectVarId::fresh());
        assert!(row.is_open());
        assert!(row.has_io());
    }

    /// Closed rows unify iff their effect sets are equal.
    #[test]
    fn test_effect_unification() {
        let left = EffectRow::closed(vec![Effect::io()]);
        let right = EffectRow::closed(vec![Effect::io()]);
        assert!(EffectContext::new().unify_rows(&left, &right).is_ok());

        let io_row = EffectRow::closed(vec![Effect::io()]);
        let async_row = EffectRow::closed(vec![Effect::async_effect()]);
        assert!(EffectContext::new().unify_rows(&io_row, &async_row).is_err());
    }

    /// Purity follows the emptiness of the function's effect row.
    #[test]
    fn test_effectful_fn() {
        let total = EffectfulFn::pure(vec![Ty::int(super::super::IntTy::I32)], Ty::bool());
        assert!(total.is_pure());

        let printer = EffectfulFn::new(
            vec![Ty::str()],
            Ty::unit(),
            EffectRow::closed(vec![Effect::io()]),
        );
        assert!(!printer.is_pure());
        assert!(printer.effects.has_io());
    }
}