use std::fmt::Debug;
/// A deferred stateful computation.
///
/// A `State<S, A>` wraps a one-shot function that receives a state of type
/// `S` and returns a result of type `A` together with the next state.
/// Nothing is executed until [`State::run`] is called, and because the
/// wrapped function is `FnOnce`, a `State` can be run at most once.
pub struct State<S, A> {
    /// The boxed state transition. `Box<dyn FnOnce ..>` defaults to a
    /// `'static` trait object, which is why the constructors below require
    /// `'static` closures and type parameters.
    f: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    /// Wraps the state transition `f` in a `State`.
    ///
    /// `f` receives the incoming state and must return `(result, next_state)`.
    pub fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        Self { f: Box::new(f) }
    }

    /// Lifts a plain value into a `State` that returns `a` as its result and
    /// passes the incoming state through unchanged.
    pub fn pure(a: A) -> Self {
        Self::new(move |s| (a, s))
    }

    /// Consumes the computation, feeding it `initial`, and returns the
    /// `(result, final_state)` pair produced by the wrapped function.
    pub fn run(self, initial: S) -> (A, S) {
        (self.f)(initial)
    }

    /// Transforms the result with `g` while leaving the state produced by
    /// `self` untouched.
    pub fn map<B: 'static>(self, g: impl FnOnce(A) -> B + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, next) = self.run(s);
            (g(a), next)
        })
    }

    /// Sequences two computations: runs `self`, passes its result to `g`, and
    /// runs the `State` returned by `g` on the state that `self` produced.
    pub fn bind<B: 'static>(self, g: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, next) = self.run(s);
            g(a).run(next)
        })
    }

    /// Returns a computation whose result is a clone of the current state;
    /// the state itself is passed through unchanged.
    ///
    /// The returned type is always `State<S, S>`, independent of the `A`
    /// named on the `State::<S, A>::get` path used to call it.
    pub fn get() -> State<S, S>
    where
        S: Clone,
    {
        State::new(|s: S| (s.clone(), s))
    }

    /// Returns a computation that discards the incoming state and replaces it
    /// with `new_state`, yielding `()` as its result.
    pub fn put(new_state: S) -> State<S, ()> {
        State::new(move |_| ((), new_state))
    }

    /// Returns a computation that replaces the state with `f(state)` and
    /// yields `()` as its result.
    pub fn modify(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
        State::new(move |s| ((), f(s)))
    }
}

/// Opaque debug representation.
///
/// The wrapped closure cannot be inspected, so a fixed placeholder is
/// printed. Neither `S` nor `A` is ever formatted, hence no `Debug` bound is
/// placed on them: a `State` is debug-printable whatever its result type is.
impl<S, A> Debug for State<S, A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("State<_, _>")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Left identity: binding `pure(a)` to a step equals applying the step to `a`.
    #[test]
    fn left_identity() {
        let seed = 42;
        // The closure captures nothing, so it is `Copy` and can be used on
        // both sides of the law.
        let step = |x: i32| State::new(move |s: i32| (x + s, s + 1));
        let via_bind = State::<i32, i32>::pure(seed).bind(step).run(10);
        let direct = step(seed).run(10);
        assert_eq!(via_bind, direct);
    }

    /// Right identity: binding a computation to `pure` changes nothing.
    #[test]
    fn right_identity() {
        let make = || State::new(|s: i32| (s * 2, s + 1));
        let (plain_val, plain_state) = make().run(10);
        let (bound_val, bound_state) = make().bind(State::<i32, i32>::pure).run(10);
        assert_eq!(plain_val, bound_val);
        assert_eq!(plain_state, bound_state);
    }

    /// Associativity: the grouping of successive binds does not matter.
    #[test]
    fn associativity() {
        let start = || State::new(|s: i32| (s, s + 1));
        let double = |x: i32| State::new(move |s: i32| (x * 2, s + 10));
        let shift = |x: i32| State::new(move |s: i32| (x + 100, s * 2));
        let chained = start().bind(double).bind(shift).run(0);
        let nested = start().bind(move |x| double(x).bind(shift)).run(0);
        assert_eq!(chained, nested);
    }

    /// `get` yields the current state as its result and leaves it in place.
    #[test]
    fn get_returns_state() {
        let outcome = State::<i32, i32>::get().run(42);
        assert_eq!(outcome.0, 42);
        assert_eq!(outcome.1, 42);
    }

    /// `put` overwrites whatever state was passed in.
    #[test]
    fn put_replaces_state() {
        let ((), replaced) = State::<i32, ()>::put(99).run(0);
        assert_eq!(replaced, 99);
    }

    /// `modify` applies the given function to the state.
    #[test]
    fn modify_transforms_state() {
        let ((), updated) = State::<i32, ()>::modify(|s| s + 10).run(32);
        assert_eq!(updated, 42);
    }

    /// A two-stage pipeline (tokenize, then parse) threaded through `bind`.
    #[test]
    fn engine_as_state_monad() {
        #[derive(Clone, Debug, PartialEq)]
        struct PipelineState {
            tokens: Vec<String>,
            parsed: bool,
        }

        // Stage 1: install the token list and report how many tokens exist.
        let tokenize = State::new(|state: PipelineState| {
            let tokens: Vec<String> = vec!["hello".into(), "world".into()];
            let count = tokens.len();
            (count, PipelineState { tokens, ..state })
        });

        // Stage 2: mark the pipeline as parsed when at least one token exists.
        let parse = |token_count: usize| {
            State::new(move |state: PipelineState| {
                let parsed = token_count > 0;
                (parsed, PipelineState { parsed, ..state })
            })
        };

        let initial = PipelineState {
            tokens: vec![],
            parsed: false,
        };

        let (result, final_state) = tokenize.bind(parse).run(initial);
        assert!(result);
        assert_eq!(final_state.tokens, vec!["hello", "world"]);
        assert!(final_state.parsed);
    }

    /// Reading the state and writing it straight back is a no-op on the state.
    #[test]
    fn get_then_put_is_identity() {
        let roundtrip = State::<i32, i32>::get().bind(State::<i32, ()>::put);
        let ((), state) = roundtrip.run(42);
        assert_eq!(state, 42);
    }

    /// A `get` following a `put` observes the freshly written state.
    #[test]
    fn put_then_get_returns_new_state() {
        let write_then_read = State::<i32, ()>::put(99).bind(|()| State::<i32, i32>::get());
        let (val, state) = write_then_read.run(0);
        assert_eq!(val, 99);
        assert_eq!(state, 99);
    }
}