use std::fmt;
/// A value tagged with whether processing should continue or stop.
///
/// Both variants carry a payload of the same type `T`, so a `Step` always
/// holds a value; only the tag differs.
///
/// `Hash` is derived alongside `Eq` so a `Step<T>` can be stored in hashed
/// collections whenever `T` itself is hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Step<T> {
    /// Processing should keep going; carries the current value.
    Continue(T),
    /// Processing should halt; carries the value it halted with.
    Stop(T),
}
impl<T> Step<T> {
    /// Returns `true` if this is a [`Step::Continue`].
    #[inline]
    pub fn is_continue(&self) -> bool {
        matches!(self, Step::Continue(_))
    }

    /// Returns `true` if this is a [`Step::Stop`].
    #[inline]
    pub fn is_stop(&self) -> bool {
        matches!(self, Step::Stop(_))
    }

    /// Consumes the step and returns its payload, whichever variant it is.
    ///
    /// Both variants carry a value, so this never panics.
    #[inline]
    pub fn unwrap(self) -> T {
        match self {
            Step::Continue(v) | Step::Stop(v) => v,
        }
    }

    /// Applies `f` to the payload while keeping the variant unchanged.
    ///
    /// `f` is called exactly once, for `Continue` and `Stop` alike.
    #[inline]
    pub fn map<U, F>(self, f: F) -> Step<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            Step::Continue(v) => Step::Continue(f(v)),
            Step::Stop(v) => Step::Stop(f(v)),
        }
    }

    /// Chains a further step onto a `Continue`, short-circuiting on `Stop`.
    ///
    /// * `Continue(v)` returns `f(v)` as-is, so `f` decides the resulting
    ///   variant.
    /// * `Stop(v)` does not call `f`; the payload is converted with
    ///   `U::from(v)` and stays a `Stop`.
    ///
    /// The `U: From<T>` bound exists only for that second case: it is what
    /// lets a stopped payload cross the `T` -> `U` type change without
    /// running `f`.
    #[inline]
    pub fn and_then<U, F>(self, f: F) -> Step<U>
    where
        F: FnOnce(T) -> Step<U>,
        U: From<T>,
    {
        match self {
            Step::Continue(v) => f(v),
            // Short-circuit: `f` is skipped, only the payload type changes.
            Step::Stop(v) => Step::Stop(v.into()),
        }
    }

    /// Returns the payload of a `Continue` as `Some`, or `None` for a `Stop`.
    ///
    /// The payload of a `Stop` is dropped.
    #[inline]
    pub fn continue_value(self) -> Option<T> {
        match self {
            Step::Continue(v) => Some(v),
            Step::Stop(_) => None,
        }
    }
}
/// Wraps `value` in [`Step::Continue`].
///
/// Being a plain generic function, it can also be passed where a
/// `FnOnce(T) -> Step<T>` is expected.
#[inline]
pub fn cont<T>(value: T) -> Step<T> {
    Step::Continue(value)
}
/// Wraps `value` in [`Step::Stop`].
///
/// Being a plain generic function, it can also be passed where a
/// `FnOnce(T) -> Step<T>` is expected.
#[inline]
pub fn stop<T>(value: T) -> Step<T> {
    Step::Stop(value)
}
/// Returns `true` if `step` is a [`Step::Stop`].
///
/// Free-function form of [`Step::is_stop`]; it simply delegates to it.
#[inline]
pub fn is_stopped<T>(step: &Step<T>) -> bool {
    step.is_stop()
}
/// Consumes `step` and returns its payload, whichever variant it is.
///
/// Free-function form of [`Step::unwrap`]; it simply delegates to it.
#[inline]
pub fn unwrap_step<T>(step: Step<T>) -> T {
    step.unwrap()
}
impl<T: fmt::Display> fmt::Display for Step<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Step::Continue(v) => write!(f, "Continue({})", v),
Step::Stop(v) => write!(f, "Stop({})", v),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_continue() {
        let step = cont(42);
        assert!(step.is_continue());
        assert!(!step.is_stop());
        assert_eq!(step.unwrap(), 42);
    }

    #[test]
    fn test_stop() {
        let step = stop(42);
        assert!(step.is_stop());
        assert!(!step.is_continue());
        assert_eq!(step.unwrap(), 42);
    }

    #[test]
    fn test_map() {
        let step = cont(42);
        let mapped = step.map(|x| x * 2);
        assert_eq!(mapped, cont(84));
        let stopped = stop(42).map(|x| x * 2);
        assert_eq!(stopped, stop(84));
    }

    #[test]
    fn test_monad_laws() {
        // Left identity: cont(a).and_then(f) == f(a)
        let f = |x| cont(x * 2);
        assert_eq!(cont(42).and_then(f), f(42));

        // Right identity: m.and_then(cont) == m
        let m = cont(42);
        assert_eq!(m.and_then(cont), m);

        // Associativity: (m >>= f) >>= g == m >>= (|x| f(x) >>= g)
        let f = |x| cont(x * 2);
        let g = |x| cont(x + 10);
        let m = cont(42);
        assert_eq!(m.and_then(f).and_then(g), m.and_then(|x| f(x).and_then(g)));
    }

    #[test]
    fn test_and_then_stop_short_circuits() {
        // A stopped step must not invoke the continuation.
        let mut called = false;
        let result = stop(1_i32).and_then(|x| {
            called = true;
            cont(x)
        });
        assert!(!called);
        assert_eq!(result, stop(1));

        // The stopped payload is converted through `From` when the type changes.
        let widened: Step<i64> = stop(7_i32).and_then(|x: i32| cont(i64::from(x)));
        assert_eq!(widened, stop(7_i64));

        // The continuation decides the variant when the input is `Continue`.
        assert_eq!(cont(1_i32).and_then(stop), stop(1));
    }

    #[test]
    fn test_continue_value() {
        assert_eq!(cont(5_i32).continue_value(), Some(5));
        assert_eq!(stop(5_i32).continue_value(), None);
    }

    #[test]
    fn test_display() {
        assert_eq!(cont(1_i32).to_string(), "Continue(1)");
        assert_eq!(stop("x").to_string(), "Stop(x)");
    }

    #[test]
    fn test_free_helpers() {
        assert!(is_stopped(&stop(1_i32)));
        assert!(!is_stopped(&cont(1_i32)));
        assert_eq!(unwrap_step(cont(7_i32)), 7);
        assert_eq!(unwrap_step(stop(9_i32)), 9);
    }
}