#![allow(missing_docs)]
use super::MonadicValue;
use crate::diagnostics::Result;
use crate::eval::value::Value;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
/// Minimal monad interface implemented by the concrete monads in this module.
///
/// `Output<B>` names "the same container, but holding a `B`", which lets `bind`
/// and `map` change the element type without higher-kinded types.
pub trait MonadicOps<A> {
    /// The container type obtained after changing the element type to `B`.
    type Output<B>;

    /// Lifts a plain value into the monad.
    fn pure(value: A) -> Self;

    /// Sequences a computation by handing the contained value(s) to `f`.
    fn bind<B, F: Fn(A) -> Self::Output<B>>(self, f: F) -> Self::Output<B>;

    /// Transforms the contained value(s) with `f`, keeping the surrounding structure.
    fn map<B, F: Fn(A) -> B>(self, f: F) -> Self::Output<B>;
}
/// An optional value: either `Nothing` or `Just` a single payload.
///
/// Mirrors [`Option`]; `Maybe::to_option` / `Maybe::from_option` convert between them.
#[derive(Clone, Debug, PartialEq)]
pub enum Maybe<A> {
    /// No value is present.
    Nothing,
    /// A value is present.
    Just(A),
}
/// A value of one of two types.
///
/// In this module `Right` is the success branch: `MonadicOps::pure` wraps into
/// `Right`, and `bind`/`map` act only on `Right`.
#[derive(Clone, Debug, PartialEq)]
pub enum Either<L, R> {
    /// The alternative (typically error) branch.
    Left(L),
    /// The primary (success) branch.
    Right(R),
}
/// The list (non-determinism) monad, backed by a `Vec`.
#[derive(Clone, Debug, PartialEq)]
pub struct List<A> {
    /// Elements in order.
    items: Vec<A>,
}
#[derive(Clone)]
pub struct IO<A> {
action: Arc<dyn Fn() -> Result<A> + Send + Sync>,
}
impl<A> std::fmt::Debug for IO<A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("IO")
.field("action", &"<function>")
.finish()
}
}
#[derive(Clone)]
pub struct State<S, A> {
run_state: Arc<dyn Fn(S) -> Result<(A, S)> + Send + Sync>,
}
impl<S, A> std::fmt::Debug for State<S, A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("State")
.field("run_state", &"<function>")
.finish()
}
}
#[derive(Clone)]
pub struct Reader<R, A> {
run_reader: Arc<dyn Fn(R) -> Result<A> + Send + Sync>,
}
impl<R, A> std::fmt::Debug for Reader<R, A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Reader")
.field("run_reader", &"<function>")
.finish()
}
}
#[derive(Clone)]
pub struct Writer<W, A> {
run_writer: Arc<dyn Fn() -> Result<(A, W)> + Send + Sync>,
}
impl<W, A> std::fmt::Debug for Writer<W, A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Writer")
.field("run_writer", &"<function>")
.finish()
}
}
/// Monad-transformer lifting: embeds a base-monad value `M` into the transformer.
pub trait MonadTrans<M> {
    /// The transformed monad carrying an `A`.
    type T<A>;

    /// Lifts a base-monad value into the transformer type.
    fn lift<A>(ma: M) -> Self::T<A>;
}
/// Maybe-style transformer shell around an inner monadic value `M`.
#[derive(Debug, Clone)]
pub struct MaybeT<M, A> {
    /// The wrapped inner computation.
    run_maybe_t: M,
    /// Records the element type `A` without storing one.
    phantom: PhantomData<A>,
}
/// Either-style transformer shell around an inner monadic value `M`, with error type `E`.
#[derive(Debug, Clone)]
pub struct EitherT<E, M, A> {
    /// The wrapped inner computation.
    run_either_t: M,
    /// Records `E` and `A` at the type level without storing them.
    phantom: PhantomData<(E, A)>,
}
/// State transformer: a function from an initial state `S` to an inner monadic
/// value `M`; the result type `A` is tracked only at the type level.
#[derive(Clone)]
pub struct StateT<S, M, A> {
    /// Runs the transformer from a given starting state.
    run_state_t: Arc<dyn Fn(S) -> M + Send + Sync>,
    /// Records `A` without storing one.
    phantom: PhantomData<A>,
}

impl<S, M, A> fmt::Debug for StateT<S, M, A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The function field is opaque; only the phantom marker is printed as-is.
        let mut builder = f.debug_struct("StateT");
        builder.field("run_state_t", &"<function>");
        builder.field("phantom", &self.phantom);
        builder.finish()
    }
}
/// Reader transformer: a function from an environment `R` to an inner monadic
/// value `M`; the result type `A` is tracked only at the type level.
#[derive(Clone)]
pub struct ReaderT<R, M, A> {
    /// Runs the transformer against a given environment.
    run_reader_t: Arc<dyn Fn(R) -> M + Send + Sync>,
    /// Records `A` without storing one.
    phantom: PhantomData<A>,
}

impl<R, M, A> fmt::Debug for ReaderT<R, M, A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The function field is opaque; only the phantom marker is printed as-is.
        let mut builder = f.debug_struct("ReaderT");
        builder.field("run_reader_t", &"<function>");
        builder.field("phantom", &self.phantom);
        builder.finish()
    }
}
/// Writer-style transformer shell around an inner monadic value `M`, with output type `W`.
#[derive(Debug, Clone)]
pub struct WriterT<W, M, A> {
    /// The wrapped inner computation.
    run_writer_t: M,
    /// Records `W` and `A` at the type level without storing them.
    phantom: PhantomData<(W, A)>,
}
/// Free-monad skeleton: either a finished `Pure` value or a suspended layer `F`.
///
/// NOTE(review): without higher-kinded types the `Free` variant stores `F`
/// directly rather than `F<Free<F, A>>`, so any nesting must be encoded in `F`.
#[derive(Debug, Clone)]
pub enum Free<F, A> {
    /// A completed computation holding its result.
    Pure(A),
    /// A suspended functor layer.
    Free(F),
}
/// Builder-style description of a monadic `do` block.
#[derive(Debug, Clone)]
pub struct DoNotation {
    /// Statements in the order they were added.
    bindings: Vec<DoBinding>,
    /// The block's final expression.
    result: String,
}

/// One statement in a `do` block: rendered as `var <- expr` when `var` is set,
/// otherwise as a bare `expr`.
#[derive(Debug, Clone)]
pub struct DoBinding {
    /// Name bound to the expression's result, if any.
    var: Option<String>,
    /// The statement's expression text.
    expr: String,
}
/// Namespace for free-standing monadic helpers (`when_maybe`, `unless_maybe`,
/// `sequence_maybe`, `map_maybe`); it carries no data and its helpers are
/// associated functions.
pub struct MonadOps;
impl<A> Maybe<A> {
pub fn just(value: A) -> Self {
Maybe::Just(value)
}
pub fn nothing() -> Self {
Maybe::Nothing
}
pub fn is_nothing(&self) -> bool {
matches!(self, Maybe::Nothing)
}
pub fn is_just(&self) -> bool {
matches!(self, Maybe::Just(_))
}
pub fn to_option(self) -> Option<A> {
match self {
Maybe::Just(a) => Some(a),
Maybe::Nothing => None,
}
}
pub fn from_option(opt: Option<A>) -> Self {
match opt {
Some(a) => Maybe::Just(a),
None => Maybe::Nothing,
}
}
}
impl<L, R> Either<L, R> {
    /// Wraps `value` in `Either::Left`.
    pub fn left(value: L) -> Self {
        Self::Left(value)
    }

    /// Wraps `value` in `Either::Right`.
    pub fn right(value: R) -> Self {
        Self::Right(value)
    }

    /// Returns `true` for the `Left` variant.
    pub fn is_left(&self) -> bool {
        !self.is_right()
    }

    /// Returns `true` for the `Right` variant.
    pub fn is_right(&self) -> bool {
        matches!(self, Self::Right(_))
    }

    /// Applies `f` to a `Right` payload; a `Left` passes through untouched.
    pub fn map_right<R2, F>(self, f: F) -> Either<L, R2>
    where
        F: FnOnce(R) -> R2,
    {
        match self {
            Self::Right(value) => Either::Right(f(value)),
            Self::Left(value) => Either::Left(value),
        }
    }

    /// Applies `f` to a `Left` payload; a `Right` passes through untouched.
    pub fn map_left<L2, F>(self, f: F) -> Either<L2, R>
    where
        F: FnOnce(L) -> L2,
    {
        match self {
            Self::Right(value) => Either::Right(value),
            Self::Left(value) => Either::Left(f(value)),
        }
    }
}
impl<A> List<A> {
pub fn empty() -> Self {
List { items: Vec::new() }
}
pub fn singleton(item: A) -> Self {
List { items: vec![item] }
}
pub fn from_vec(items: Vec<A>) -> Self {
List { items }
}
pub fn to_vec(self) -> Vec<A> {
self.items
}
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
pub fn len(&self) -> usize {
self.items.len()
}
}
impl<A> IO<A> {
    /// Wraps `action` without running it.
    pub fn new<F>(action: F) -> Self
    where
        F: Fn() -> Result<A> + Send + Sync + 'static,
    {
        let shared = Arc::new(action);
        Self { action: shared }
    }

    /// Executes the wrapped action once and returns its result.
    pub fn run(self) -> Result<A> {
        let Self { action } = self;
        action()
    }
}
impl<S, A> State<S, A> {
    /// Wraps a state transition `S -> Result<(A, S)>` without running it.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(S) -> Result<(A, S)> + Send + Sync + 'static,
    {
        let shared = Arc::new(f);
        Self { run_state: shared }
    }

    /// Runs the transition from `initial_state`, returning `(result, final_state)`.
    pub fn run(self, initial_state: S) -> Result<(A, S)> {
        let Self { run_state } = self;
        run_state(initial_state)
    }

    /// Yields the current state as the result and leaves the state unchanged.
    ///
    /// NOTE(review): the impl's `S` and `A` are unused here, yet callers still have
    /// to name them, e.g. `State::<i32, i32>::get::<i32>()`.
    pub fn get<St>() -> State<St, St>
    where
        St: Clone + Send + Sync + 'static,
    {
        State::new(|current: St| {
            let snapshot = current.clone();
            Ok((snapshot, current))
        })
    }

    /// Replaces the state with a clone of `new_state`; the result is `()`.
    pub fn put(new_state: S) -> State<S, ()>
    where
        S: Clone + Send + Sync + 'static,
    {
        State::new(move |_previous: S| {
            let next = new_state.clone();
            Ok(((), next))
        })
    }

    /// Transforms the state with `f`; the result is `()`.
    pub fn modify<F>(f: F) -> State<S, ()>
    where
        F: Fn(S) -> S + Send + Sync + 'static,
    {
        State::new(move |previous: S| {
            let next = f(previous);
            Ok(((), next))
        })
    }
}
impl<R, A> Reader<R, A> {
    /// Wraps `f`, which computes a result from the environment, without running it.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(R) -> Result<A> + Send + Sync + 'static,
    {
        let shared = Arc::new(f);
        Self { run_reader: shared }
    }

    /// Runs the reader against `environment`.
    pub fn run(self, environment: R) -> Result<A> {
        let Self { run_reader } = self;
        run_reader(environment)
    }

    /// Yields (a clone of) the environment itself.
    ///
    /// NOTE(review): the impl's `R` and `A` are unused here, yet callers still have
    /// to name them, e.g. `Reader::<String, String>::ask::<String>()`.
    pub fn ask<Rd>() -> Reader<Rd, Rd>
    where
        Rd: Clone + Send + Sync + 'static,
    {
        Reader::new(|env: Rd| {
            let snapshot = env.clone();
            Ok(snapshot)
        })
    }

    /// Yields `f` applied to the environment.
    pub fn asks<F, B>(f: F) -> Reader<R, B>
    where
        F: Fn(R) -> B + Send + Sync + 'static,
    {
        Reader::new(move |env| {
            let projected = f(env);
            Ok(projected)
        })
    }
}
impl<W, A> Writer<W, A> {
    /// Wraps `f`, which produces `(result, output)`, without running it.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn() -> Result<(A, W)> + Send + Sync + 'static,
    {
        let shared = Arc::new(f);
        Self { run_writer: shared }
    }

    /// Runs the writer, returning `(result, output)`.
    pub fn run(self) -> Result<(A, W)> {
        let Self { run_writer } = self;
        run_writer()
    }

    /// Emits (a clone of) `w` as the output, with `()` as the result.
    pub fn tell(w: W) -> Writer<W, ()>
    where
        W: Clone + Send + Sync + 'static,
    {
        Writer::new(move || {
            let output = w.clone();
            Ok(((), output))
        })
    }
}
impl<A> MonadicOps<A> for Maybe<A> {
    type Output<B> = Maybe<B>;

    /// Wraps `value` in `Just`.
    fn pure(value: A) -> Self {
        Self::Just(value)
    }

    /// Passes a `Just` payload to `f`; `Nothing` short-circuits without calling `f`.
    fn bind<B, F>(self, f: F) -> Self::Output<B>
    where
        F: Fn(A) -> Self::Output<B>,
    {
        if let Maybe::Just(inner) = self {
            f(inner)
        } else {
            Maybe::Nothing
        }
    }

    /// Applies `f` to a `Just` payload, expressed as `bind` plus re-wrapping.
    fn map<B, F>(self, f: F) -> Self::Output<B>
    where
        F: Fn(A) -> B,
    {
        self.bind::<B, _>(|inner| Maybe::Just(f(inner)))
    }
}
impl<L, R> MonadicOps<R> for Either<L, R> {
    type Output<B> = Either<L, B>;

    /// Wraps `value` in `Right`, the success branch.
    fn pure(value: R) -> Self {
        Self::Right(value)
    }

    /// Passes a `Right` payload to `f`; a `Left` is propagated without calling `f`.
    fn bind<B, F>(self, f: F) -> Self::Output<B>
    where
        F: Fn(R) -> Self::Output<B>,
    {
        match self {
            Either::Left(err) => Either::Left(err),
            Either::Right(value) => f(value),
        }
    }

    /// Applies `f` to a `Right` payload, expressed as `bind` plus re-wrapping.
    fn map<B, F>(self, f: F) -> Self::Output<B>
    where
        F: Fn(R) -> B,
    {
        self.bind::<B, _>(|value| Either::Right(f(value)))
    }
}
impl<A> MonadicOps<A> for List<A> {
    type Output<B> = List<B>;

    /// Lifts a value into a one-element list.
    fn pure(value: A) -> Self {
        List::singleton(value)
    }

    /// Applies `f` to every element and concatenates the resulting lists in order.
    fn bind<B, F>(self, f: F) -> Self::Output<B>
    where
        F: Fn(A) -> Self::Output<B>,
    {
        List::from_vec(self.items.into_iter().flat_map(|item| f(item).items).collect())
    }

    /// Applies `f` to every element, preserving order and length.
    fn map<B, F>(self, f: F) -> Self::Output<B>
    where
        F: Fn(A) -> B,
    {
        // `map` keeps the exact size hint, so `collect` allocates the output once.
        List::from_vec(self.items.into_iter().map(f).collect())
    }
}
impl MonadOps {
    /// Returns `action` when `condition` holds; otherwise succeeds with `Just(())`.
    pub fn when_maybe(condition: bool, action: Maybe<()>) -> Maybe<()> {
        if condition {
            action
        } else {
            Maybe::Just(())
        }
    }

    /// Inverse of `when_maybe`: returns `action` only when `condition` is false.
    pub fn unless_maybe(condition: bool, action: Maybe<()>) -> Maybe<()> {
        Self::when_maybe(!condition, action)
    }

    /// Pairs two `Maybe` values, yielding `Nothing` if either one is `Nothing`.
    ///
    /// Both inputs are owned, so their payloads are moved into the pair; no
    /// `Clone` bound is needed.
    pub fn sequence_maybe<A, B>(ma: Maybe<A>, mb: Maybe<B>) -> Maybe<(A, B)> {
        match (ma, mb) {
            (Maybe::Just(a), Maybe::Just(b)) => Maybe::Just((a, b)),
            (Maybe::Nothing, _) | (_, Maybe::Nothing) => Maybe::Nothing,
        }
    }

    /// Applies `f` to each item, collecting the `Just` payloads in order.
    ///
    /// Returns `Nothing` as soon as `f` yields `Nothing`; `f` is not called on
    /// the remaining items.
    pub fn map_maybe<A, B, F>(f: F, items: Vec<A>) -> Maybe<Vec<B>>
    where
        F: Fn(A) -> Maybe<B>,
    {
        // At most one output per input, so reserve up front.
        let mut results = Vec::with_capacity(items.len());
        for item in items {
            let Maybe::Just(mapped) = f(item) else {
                return Maybe::Nothing;
            };
            results.push(mapped);
        }
        Maybe::Just(results)
    }
}
impl DoNotation {
    /// Creates an empty `do` block: no statements and an empty result expression.
    pub fn new() -> Self {
        Self {
            bindings: Vec::new(),
            result: String::new(),
        }
    }

    /// Appends a statement; `Some(var)` renders as `var <- expr`, `None` as a bare `expr`.
    pub fn bind(mut self, var: Option<String>, expr: String) -> Self {
        self.bindings.push(DoBinding { var, expr });
        self
    }

    /// Sets the block's final expression, replacing any previous one.
    pub fn result(mut self, expr: String) -> Self {
        self.result = expr;
        self
    }

    /// Renders the block on one line in explicit-brace layout:
    /// `do { x <- e1; e2; result }`.
    ///
    /// Every statement is emitted in insertion order, followed by the result
    /// expression; a block with no statements renders as `do { result }`.
    pub fn compile(&self) -> String {
        let statements: Vec<String> = self
            .bindings
            .iter()
            .map(|binding| match &binding.var {
                Some(var) => format!("{} <- {}", var, binding.expr),
                None => binding.expr.clone(),
            })
            .chain(std::iter::once(self.result.clone()))
            .collect();
        format!("do {{ {} }}", statements.join("; "))
    }
}
/// Converts a `Maybe` into a pure `MonadicValue`.
///
/// `Just(a)` becomes `pure(a.into())`; `Nothing` becomes `pure(Value::Nil)`.
/// NOTE(review): `Nothing` is therefore indistinguishable from a `Just` whose
/// payload converts to `Value::Nil`.
pub fn maybe_to_monadic<A>(maybe: Maybe<A>) -> MonadicValue
where
    A: Into<Value>,
{
    let payload = if let Maybe::Just(inner) = maybe {
        inner.into()
    } else {
        Value::Nil
    };
    MonadicValue::pure(payload)
}
pub fn either_to_monadic<L, R>(either: Either<L, R>) -> MonadicValue
where
L: Into<Value>,
R: Into<Value>,
{
match either {
Either::Right(r) => MonadicValue::pure(r.into()),
Either::Left(l) => MonadicValue::error(super::monad::ErrorAction::Return(l.into())),
}
}
/// Converts a `List` into a pure `MonadicValue` holding a `Value::list` of the
/// converted elements, in order.
pub fn list_to_monadic<A>(list: List<A>) -> MonadicValue
where
    A: Into<Value>,
{
    let converted = list
        .items
        .into_iter()
        .map(|element| -> Value { element.into() })
        .collect::<Vec<Value>>();
    MonadicValue::pure(Value::list(converted))
}
impl fmt::Display for DoNotation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "do")?;
for binding in &self.bindings {
if let Some(var) = &binding.var {
writeln!(f, " {} <- {}", var, binding.expr)?;
} else {
writeln!(f, " {}", binding.expr)?;
}
}
writeln!(f, " {}", self.result)
}
}
impl Default for DoNotation {
fn default() -> Self {
Self::new()
}
}
// `Send` and `Sync` are auto traits, so the compiler already implements them for
// these types whenever their fields allow it:
// - `Maybe`, `Either` and `List` follow their payloads, exactly as the old manual
//   impls did.
// - `IO`, `State`, `Reader` and `Writer` hold only an `Arc<dyn Fn .. + Send + Sync>`,
//   which is thread-safe regardless of the type parameters. The old manual impls
//   there merely added extra bounds.
// No `unsafe` is needed. These compile-time checks fail if a future field change
// silently drops thread-safety.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}

    assert_send_sync::<Maybe<i32>>();
    assert_send_sync::<Either<String, i32>>();
    assert_send_sync::<List<i32>>();
    assert_send_sync::<IO<i32>>();
    assert_send_sync::<State<i32, i32>>();
    assert_send_sync::<Reader<String, i32>>();
    assert_send_sync::<Writer<String, i32>>();
};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_maybe_monad_ops() {
        let just_5 = Maybe::just(5);
        let nothing = Maybe::<i32>::nothing();
        assert!(just_5.is_just());
        assert!(nothing.is_nothing());
        let result = just_5.bind(|x| Maybe::just(x * 2));
        assert_eq!(result, Maybe::just(10));
        let result = nothing.bind(|x| Maybe::just(x * 2));
        assert_eq!(result, Maybe::nothing());
        let result = Maybe::just(5).map(|x| x * 2);
        assert_eq!(result, Maybe::just(10));
    }

    #[test]
    fn test_maybe_option_round_trip() {
        assert_eq!(Maybe::from_option(Some(3)).to_option(), Some(3));
        assert_eq!(Maybe::<i32>::from_option(None), Maybe::nothing());
        assert_eq!(Maybe::<i32>::nothing().to_option(), None);
    }

    #[test]
    fn test_either_monad_ops() {
        let right_5 = Either::<String, i32>::right(5);
        let left_err = Either::<String, i32>::left("error".to_string());
        assert!(right_5.is_right());
        assert!(left_err.is_left());
        let result = right_5.bind(|x| Either::right(x * 2));
        assert_eq!(result, Either::right(10));
        let result = left_err.bind(|x| Either::right(x * 2));
        assert!(result.is_left());
        let result = Either::<String, i32>::right(5).map(|x| x * 2);
        assert_eq!(result, Either::right(10));
    }

    #[test]
    fn test_either_map_sides() {
        let failed: Either<String, i32> = Either::left("bad".to_string());
        assert_eq!(failed.map_left(|s| s.len()), Either::left(3));
        let ok: Either<String, i32> = Either::right(4);
        assert_eq!(ok.map_right(|n| n + 1), Either::right(5));
    }

    #[test]
    fn test_list_monad() {
        let list = List::from_vec(vec![1, 2, 3]);
        assert_eq!(list.len(), 3);
        assert!(!list.is_empty());
        let empty = List::<i32>::empty();
        assert!(empty.is_empty());
        let singleton = List::singleton(42);
        assert_eq!(singleton.len(), 1);
    }

    #[test]
    fn test_list_bind_and_map() {
        let doubled = List::from_vec(vec![1, 2, 3]).map(|x| x * 2);
        assert_eq!(doubled.to_vec(), vec![2, 4, 6]);
        let expanded = List::from_vec(vec![1, 2]).bind(|x| List::from_vec(vec![x, x * 10]));
        assert_eq!(expanded.to_vec(), vec![1, 10, 2, 20]);
        let lifted = <List<i32> as MonadicOps<i32>>::pure(7);
        assert_eq!(lifted, List::singleton(7));
    }

    #[test]
    fn test_monad_ops_helpers() {
        assert_eq!(
            MonadOps::sequence_maybe(Maybe::just(1), Maybe::just(2)),
            Maybe::just((1, 2))
        );
        assert_eq!(
            MonadOps::sequence_maybe(Maybe::just(1), Maybe::<i32>::nothing()),
            Maybe::nothing()
        );
        assert_eq!(MonadOps::when_maybe(false, Maybe::nothing()), Maybe::just(()));
        assert_eq!(MonadOps::unless_maybe(false, Maybe::nothing()), Maybe::nothing());
        let halve = |x: i32| if x % 2 == 0 { Maybe::just(x / 2) } else { Maybe::nothing() };
        assert_eq!(MonadOps::map_maybe(halve, vec![2, 4]), Maybe::just(vec![1, 2]));
        assert_eq!(MonadOps::map_maybe(halve, vec![2, 3]), Maybe::nothing());
    }

    #[test]
    fn test_state_monad() {
        let computation = State::new(|s: i32| Ok((s + 1, s * 2)));
        let result = computation.run(5).unwrap();
        assert_eq!(result, (6, 10));
        let get_state = State::<i32, i32>::get();
        let result = get_state.run(42).unwrap();
        assert_eq!(result, (42, 42));
    }

    #[test]
    fn test_state_put_and_modify() {
        assert_eq!(State::<i32, ()>::put(7).run(1).unwrap(), ((), 7));
        assert_eq!(State::<i32, ()>::modify(|s| s * 3).run(4).unwrap(), ((), 12));
    }

    #[test]
    fn test_reader_monad() {
        let reader = Reader::new(|env: String| Ok(env.len()));
        let result = reader.run("hello".to_string()).unwrap();
        assert_eq!(result, 5);
        let ask = Reader::<String, String>::ask();
        let result = ask.run("world".to_string()).unwrap();
        assert_eq!(result, "world");
    }

    #[test]
    fn test_writer_monad() {
        let writer = Writer::new(|| Ok((42, "logged".to_string())));
        let result = writer.run().unwrap();
        assert_eq!(result, (42, "logged".to_string()));
    }

    #[test]
    fn test_reader_asks_and_writer_tell() {
        let len = Reader::<String, ()>::asks(|env: String| env.len())
            .run("abc".to_string())
            .unwrap();
        assert_eq!(len, 3);
        let logged = Writer::<String, ()>::tell("entry".to_string()).run().unwrap();
        assert_eq!(logged, ((), "entry".to_string()));
    }

    #[test]
    fn test_do_notation() {
        let do_block = DoNotation::new()
            .bind(Some("x".to_string()), "getValue()".to_string())
            .bind(Some("y".to_string()), "getAnother()".to_string())
            .bind(None, "sideEffect()".to_string())
            .result("return (x + y)".to_string());
        assert_eq!(do_block.bindings.len(), 3);
        assert_eq!(do_block.result, "return (x + y)");
    }

    #[test]
    fn test_do_notation_display() {
        let do_block = DoNotation::new()
            .bind(Some("x".to_string()), "getValue()".to_string())
            .bind(None, "sideEffect()".to_string())
            .result("return x".to_string());
        assert_eq!(
            do_block.to_string(),
            "do\n x <- getValue()\n sideEffect()\n return x\n"
        );
    }
}