#![allow(missing_docs)]
use crate::diagnostics::{Error, Result, Span};
use crate::eval::value::{Value, ThreadSafeEnvironment};
use crate::ast::Literal;
use std::collections::HashMap;
use std::fmt;
use std::rc::Rc;
use std::sync::Arc;
use std::any::TypeId;
/// Conversion from an interpreter [`Value`] into a Rust type.
///
/// The implementations in this file report a value of the wrong shape as a
/// type error.
pub trait FromValue: Sized {
    /// Consumes `value` and converts it, or returns an error if it cannot be
    /// represented as `Self`.
    fn from_value(value: Value) -> Result<Self>;
}
/// Conversion from a Rust type into an interpreter [`Value`]; infallible by
/// signature.
pub trait ToValue {
    /// Consumes `self` and wraps it as a [`Value`].
    fn to_value(self) -> Value;
}
/// Identity conversion: a [`Value`] already is a [`Value`], so this never fails.
impl FromValue for Value {
    fn from_value(value: Value) -> Result<Self> {
        Ok(std::convert::identity(value))
    }
}
/// Identity conversion: the value is handed back untouched.
impl ToValue for Value {
    fn to_value(self) -> Value {
        std::convert::identity(self)
    }
}
impl FromValue for i32 {
fn from_value(value: Value) -> Result<Self> {
match value {
Value::Literal(Literal::ExactInteger(i)) => Ok(i as i32),
Value::Literal(Literal::InexactReal(f)) if f.fract() == 0.0 => Ok(f as i32),
Value::Literal(literal) if literal.is_number() => {
if let Some(f) = literal.to_f64() {
Ok(f as i32)
} else {
Err(Box::new(Error::type_error("Cannot convert number to i32", Span::new(0, 0))))
}
}
_ => Err(Box::new(Error::type_error("Expected number", Span::new(0, 0)))),
}
}
}
/// Stores the integer as an exact integer literal, widening losslessly to `i64`.
impl ToValue for i32 {
    fn to_value(self) -> Value {
        let widened = i64::from(self);
        Value::Literal(Literal::integer(widened))
    }
}
impl FromValue for String {
fn from_value(value: Value) -> Result<Self> {
match value {
Value::Literal(Literal::String(s)) => Ok(s),
_ => Err(Box::new(Error::type_error("Expected string", Span::new(0, 0)))),
}
}
}
/// Wraps the string as a string literal without copying it.
impl ToValue for String {
    fn to_value(self) -> Value {
        let literal = Literal::String(self);
        Value::Literal(literal)
    }
}
impl FromValue for bool {
fn from_value(value: Value) -> Result<Self> {
match value {
Value::Literal(Literal::Boolean(b)) => Ok(b),
_ => Err(Box::new(Error::type_error("Expected boolean", Span::new(0, 0)))),
}
}
}
/// Wraps the flag as a boolean literal.
impl ToValue for bool {
    fn to_value(self) -> Value {
        let literal = Literal::Boolean(self);
        Value::Literal(literal)
    }
}
/// A computation in the continuation monad.
///
/// It is represented as an inspectable tree of [`ContComputation`] nodes and
/// is evaluated by [`run_continuation`].
#[derive(Debug, Clone)]
pub struct ContinuationMonad<A> {
    // Root node of the computation tree.
    computation: ContComputation<A>,
}
/// A shareable, cloneable closure `A -> B` tagged with a numeric id.
///
/// Cloning is cheap: clones share the same underlying closure through an
/// `Arc`. The id exists only for `Debug` output, because closures themselves
/// cannot be printed.
#[derive(Clone)]
pub struct ContinuationFunc<A, B> {
    id: u64,
    func: Arc<dyn Fn(A) -> B + Send + Sync + 'static>,
}

impl<A, B> ContinuationFunc<A, B> {
    /// Boxes `func` behind a shared pointer and tags it with `id`.
    pub fn new<F>(id: u64, func: F) -> Self
    where
        F: Fn(A) -> B + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(A) -> B + Send + Sync + 'static> = Arc::new(func);
        ContinuationFunc { id, func: shared }
    }

    /// Invokes the wrapped closure with `arg` and returns its result.
    pub fn call(&self, arg: A) -> B {
        let target = &*self.func;
        target(arg)
    }
}

impl<A, B> std::fmt::Debug for ContinuationFunc<A, B> {
    /// Prints `ContinuationFunc(<id>)`; the wrapped closure is opaque.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "ContinuationFunc({id})", id = self.id)
    }
}
/// One node of a [`ContinuationMonad`] computation tree, as interpreted by
/// [`run_continuation`].
#[derive(Debug, Clone)]
pub enum ContComputation<A> {
    /// An already-computed result; running it returns the value as-is.
    Pure(A),
    /// `call/cc`: `proc` is handed a [`ContinuationFunction`] and returns the
    /// computation to run next.
    CallCC {
        proc: ContinuationFunc<ContinuationFunction, ContinuationMonad<A>>,
    },
    /// Applies `continuation` to `value` (converted with `ToValue`). The
    /// continuation's result is converted back to `A` with `FromValue`.
    ApplyContinuation {
        continuation: ContinuationFunction,
        value: A,
    },
    /// Runs `inner`, then runs the computation `next` builds from its result.
    Bind {
        inner: Box<ContinuationMonad<Value>>,
        next: ContinuationFunc<Value, ContinuationMonad<A>>,
    },
    /// Executes `effect_computation`, then runs the computation `continuation`
    /// builds from the effect's result.
    Effect {
        effect_computation: EffectfulComputation,
        continuation: ContinuationFunc<Value, ContinuationMonad<A>>,
    },
}
/// A first-class continuation value whose `apply` succeeds at most once.
///
/// `invoked` is part of the cloned state: a clone made before invocation can
/// itself be applied once.
#[derive(Debug, Clone)]
pub struct ContinuationFunction {
    /// Identifier printed by the `Display` impl.
    pub id: u64,
    /// Environment stored at construction; not read by the code in this file.
    pub environment: Arc<ThreadSafeEnvironment>,
    /// What applying the continuation does.
    pub computation: ContinuationComputation,
    /// Set by `apply`. Once true, further applications fail and `is_valid`
    /// returns false.
    pub invoked: bool,
}
/// The behavior a [`ContinuationFunction`] performs when applied.
#[derive(Debug, Clone)]
pub enum ContinuationComputation {
    /// Resume a captured evaluation stack with the applied value.
    ///
    /// NOTE(review): resumption is currently a stub that ignores `stack` and
    /// `captured_env` and returns the applied value unchanged.
    EvaluationContext {
        stack: Vec<EvaluationFrame>,
        captured_env: Arc<ThreadSafeEnvironment>,
    },
    /// Call `function` with the applied value prepended to `args`.
    ///
    /// NOTE(review): `apply` does not perform this call yet and yields
    /// `Value::Unspecified`.
    FunctionCall {
        function: Value,
        args: Vec<Value>,
        env: Arc<ThreadSafeEnvironment>,
    },
    /// Apply `first`, then feed its result to `second`.
    Composed {
        first: Box<ContinuationFunction>,
        second: Box<ContinuationFunction>,
    },
}
/// A frame of a captured evaluation stack, as stored in
/// `ContinuationComputation::EvaluationContext`.
///
/// NOTE(review): nothing in this file inspects these frames. The per-variant
/// descriptions below are read off the field names and should be confirmed
/// against the evaluator that builds them.
#[derive(Debug, Clone)]
pub enum EvaluationFrame {
    /// Presumably a call whose arguments are partly evaluated.
    Application {
        function: Value,
        evaluated_args: Vec<Value>,
        pending_args: Vec<Value>,
        env: Arc<ThreadSafeEnvironment>,
    },
    /// Presumably a conditional waiting for its test value.
    Conditional {
        then_branch: Value,
        else_branch: Value,
        env: Arc<ThreadSafeEnvironment>,
    },
    /// Presumably the remaining expressions of a sequence.
    Sequence {
        remaining: Vec<Value>,
        env: Arc<ThreadSafeEnvironment>,
    },
    /// Presumably a `let` form being evaluated.
    LetBinding {
        bindings: HashMap<String, Value>,
        body: Value,
        env: Arc<ThreadSafeEnvironment>,
    },
}
/// A side effect requested by a `ContComputation::Effect` node and carried
/// out by `execute_effect`.
#[derive(Debug, Clone)]
pub enum EffectfulComputation {
    /// An I/O action (see [`ContIOAction`]).
    IO {
        action: ContIOAction,
    },
    /// A state action (see [`ContStateAction`]); `execute_effect` currently
    /// ignores `state`.
    State {
        action: ContStateAction,
        state: Arc<ThreadSafeEnvironment>,
    },
    /// Fails the computation: `execute_effect` returns `error` as an `Err`.
    Error {
        error: Error,
    },
}
/// I/O actions understood by `execute_effect`.
#[derive(Debug, Clone)]
pub enum ContIOAction {
    /// Yields a fixed placeholder string ("input"); no real input is read.
    Read,
    /// Prints the value with `print!` (no trailing newline) and yields
    /// `Value::Unspecified`.
    Write(Value),
    /// Prints the value with `println!` and yields `Value::Unspecified`.
    Print(Value),
    /// Performs no I/O and yields the value unchanged.
    Return(Value),
}
// Shorter public alias for `ContIOAction`.
pub use ContIOAction as IOAction;
/// State actions understood by `execute_effect`.
///
/// NOTE(review): state handling is currently a stub. `Get` does not read any
/// state and `Put` discards its argument; both yield `Value::Unspecified`.
#[derive(Debug, Clone)]
pub enum ContStateAction {
    /// Intended to read the current state.
    Get,
    /// Intended to replace the current state.
    Put(Arc<ThreadSafeEnvironment>),
    /// Yields the value unchanged.
    Return(Value),
}
// Shorter public alias for `ContStateAction`.
pub use ContStateAction as StateAction;
impl<A> ContinuationMonad<A> {
pub fn pure(value: A) -> Self {
Self {
computation: ContComputation::Pure(value),
}
}
pub fn call_cc<F>(f: F) -> ContinuationMonad<Value>
where
F: Fn(ContinuationFunction) -> ContinuationMonad<Value> + Send + Sync + 'static,
{
static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
let id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
ContinuationMonad {
computation: ContComputation::CallCC {
proc: ContinuationFunc::new(id, f),
},
}
}
pub fn apply_continuation(cont: ContinuationFunction, value: A) -> Self {
Self {
computation: ContComputation::ApplyContinuation {
continuation: cont,
value,
},
}
}
pub fn bind(self, f: impl Fn(A) -> ContinuationMonad<Value> + Send + Sync + 'static) -> ContinuationMonad<Value>
where
A: 'static,
{
ContinuationMonad::pure(Value::Unspecified) }
pub fn lift_effect(effect: EffectfulComputation) -> ContinuationMonad<Value> {
static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
let id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
ContinuationMonad {
computation: ContComputation::Effect {
effect_computation: effect,
continuation: ContinuationFunc::new(id, ContinuationMonad::pure),
},
}
}
}
impl ContinuationFunction {
    /// Creates a continuation that has not been invoked yet.
    pub fn new(
        id: u64,
        environment: Arc<ThreadSafeEnvironment>,
        computation: ContinuationComputation,
    ) -> Self {
        Self {
            id,
            environment,
            computation,
            invoked: false,
        }
    }

    /// Invokes this continuation with `value` and returns the result.
    ///
    /// Continuations are one-shot. The first call marks the continuation as
    /// invoked, even if evaluation then fails, and later calls are rejected.
    ///
    /// # Errors
    ///
    /// Returns a runtime error if the continuation was already invoked. For a
    /// `Composed` continuation, errors from either stage are propagated.
    pub fn apply(&mut self, value: Value) -> Result<Value> {
        if self.invoked {
            return Err(Box::new(Error::runtime_error(
                "Continuation has already been invoked".to_string(),
                None,
            )));
        }
        self.invoked = true;
        match &self.computation {
            ContinuationComputation::EvaluationContext {
                stack,
                captured_env,
            } => self.restore_evaluation_context(stack, Arc::clone(captured_env), value),
            ContinuationComputation::FunctionCall { .. } => {
                // TODO: call `function` with `value` prepended to `args` in `env`
                // once an evaluator hook is available. Until then no call is made
                // and the continuation yields `Unspecified`.
                Ok(Value::Unspecified)
            }
            ContinuationComputation::Composed { first, second } => {
                // The stages are applied as clones, so their own `invoked` flags
                // stay untouched; one-shot semantics rest on `self.invoked`.
                let intermediate = first.clone().apply(value)?;
                second.clone().apply(intermediate)
            }
        }
    }

    /// Resumes a captured evaluation context with `value`.
    ///
    /// Currently a stub: the frame stack and environment are ignored and
    /// `value` is returned unchanged.
    fn restore_evaluation_context(
        &self,
        _stack: &[EvaluationFrame],
        _env: Arc<ThreadSafeEnvironment>,
        value: Value,
    ) -> Result<Value> {
        Ok(value)
    }

    /// Returns `true` while the continuation has not been invoked yet.
    pub fn is_valid(&self) -> bool {
        !self.invoked
    }
}
pub fn run_continuation<A>(cont: ContinuationMonad<A>) -> Result<A>
where
A: FromValue + ToValue + 'static,
{
match cont.computation {
ContComputation::Pure(value) => Ok(value),
ContComputation::CallCC { proc } => {
let dummy_cont = ContinuationFunction::new(
0,
Arc::new(ThreadSafeEnvironment::new(None, 0)),
ContinuationComputation::EvaluationContext {
stack: Vec::new(),
captured_env: Arc::new(ThreadSafeEnvironment::new(None, 0)),
},
);
let result_cont = proc.call(dummy_cont);
run_continuation(result_cont)
}
ContComputation::ApplyContinuation { mut continuation, value } => {
let value_as_value = value.to_value();
let result = continuation.apply(value_as_value)?;
if std::any::TypeId::of::<A>() == std::any::TypeId::of::<Value>() {
unsafe { Ok(std::mem::transmute_copy(&result)) }
} else {
match A::from_value(result) {
Ok(converted) => Ok(converted),
Err(_) => Err(Box::new(Error::type_error("Type conversion failed in continuation application", Span::new(0, 0)))),
}
}
}
ContComputation::Bind { inner, next } => {
let intermediate_result = run_continuation(*inner)?;
let final_cont = next.call(intermediate_result);
run_continuation(final_cont)
}
ContComputation::Effect { effect_computation, continuation } => {
let effect_result = execute_effect(effect_computation)?;
let cont_result = continuation.call(effect_result);
run_continuation(cont_result)
}
}
}
fn execute_effect(effect: EffectfulComputation) -> Result<Value> {
match effect {
EffectfulComputation::IO { action } => {
match action {
ContIOAction::Read => {
Ok(Value::string("input".to_string()))
}
ContIOAction::Write(value) => {
print!("{value}");
Ok(Value::Unspecified)
}
ContIOAction::Print(value) => {
println!("{value}");
Ok(Value::Unspecified)
}
ContIOAction::Return(value) => Ok(value),
}
}
EffectfulComputation::State { action, state: _ } => {
match action {
ContStateAction::Get => {
Ok(Value::Unspecified)
}
ContStateAction::Put(_new_state) => {
Ok(Value::Unspecified)
}
ContStateAction::Return(value) => Ok(value),
}
}
EffectfulComputation::Error { error } => {
Err(Box::new(error))
}
}
}
/// Builds a computation that captures the current continuation and
/// immediately applies it to `escape_value`.
pub fn escape_continuation(escape_value: Value) -> ContinuationMonad<Value> {
    ContinuationMonad::<Value>::call_cc(move |escape| {
        // The procedure is an `Fn` and may run more than once, so each run
        // gets its own copy of the escape value.
        let payload = escape_value.clone();
        ContinuationMonad::<Value>::apply_continuation(escape, payload)
    })
}
/// Wraps `computation` in a `call/cc` node.
///
/// NOTE(review): the captured continuation is ignored, so nothing is actually
/// retried. `computation` is simply called each time the `call/cc` procedure
/// runs.
pub fn retry_continuation<F>(computation: F) -> ContinuationMonad<Value>
where
    F: Fn() -> ContinuationMonad<Value> + Send + Sync + 'static,
{
    let body = move |_unused_k| computation();
    ContinuationMonad::<Value>::call_cc(body)
}
impl fmt::Display for ContinuationFunction {
    /// Renders as `Continuation(<id>)`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Continuation({id})", id = self.id)
    }
}
impl<A: fmt::Display> fmt::Display for ContinuationMonad<A> {
    /// Renders only the root node.
    ///
    /// `Pure` shows its value and `ApplyContinuation` shows its continuation;
    /// the other node kinds hold closures and print a fixed placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.computation {
            ContComputation::Pure(value) => write!(f, "Pure({value})"),
            ContComputation::ApplyContinuation { continuation, .. } => {
                write!(f, "ApplyContinuation({continuation})")
            }
            ContComputation::CallCC { .. } => f.write_str("CallCC(<procedure>)"),
            ContComputation::Bind { .. } => f.write_str("Bind(<computation>)"),
            ContComputation::Effect { .. } => f.write_str("Effect(<computation>)"),
        }
    }
}
// NOTE(review): SAFETY justification missing. These impls override the
// compiler's auto-trait analysis. They would be redundant if every field
// (including `Value`, `Error` and `ThreadSafeEnvironment`) were already
// Send/Sync, so they presumably exist because some component is not; this
// file does import `std::rc::Rc`. If any of those types holds an `Rc`,
// `RefCell` or similar, moving or sharing these values across threads is
// undefined behavior. Confirm against those type definitions and remove the
// impls if they cannot be justified. (`ContinuationFunc` closures are already
// `Send + Sync` by their bounds.)
unsafe impl<A: Send> Send for ContinuationMonad<A> {}
unsafe impl<A: Sync> Sync for ContinuationMonad<A> {}
unsafe impl Send for ContinuationFunction {}
unsafe impl Sync for ContinuationFunction {}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh root environment for building continuations in tests.
    fn empty_env() -> Arc<ThreadSafeEnvironment> {
        Arc::new(ThreadSafeEnvironment::new(None, 0))
    }

    /// A `pure` computation runs to its wrapped value.
    #[test]
    fn test_pure_continuation() {
        let outcome = run_continuation(ContinuationMonad::pure(42_i32)).unwrap();
        assert_eq!(outcome, 42);
    }

    /// `bind` feeds the first result into the follow-up computation.
    #[test]
    fn test_continuation_bind() {
        let doubled = ContinuationMonad::pure(21_i32)
            .bind(|n: i32| ContinuationMonad::pure((n * 2).to_value()));
        assert_eq!(run_continuation(doubled).unwrap(), 42_i32.to_value());
    }

    /// Escaping with a value makes the whole computation evaluate to it.
    #[test]
    fn test_escape_continuation() {
        let escaped = escape_continuation(42_i32.to_value());
        assert_eq!(run_continuation(escaped).unwrap(), 42_i32.to_value());
    }

    /// A freshly built continuation has not been invoked yet.
    #[test]
    fn test_continuation_function_validity() {
        let env = empty_env();
        let cont_func = ContinuationFunction::new(
            1,
            env,
            ContinuationComputation::EvaluationContext {
                stack: Vec::new(),
                captured_env: empty_env(),
            },
        );
        assert!(cont_func.is_valid());
    }
}