use comp_cat_rs::effect::io::Io;
/// One step of a pipeline: a one-shot function from an input `A` to an
/// effectful computation [`Io<E, B>`](Io).
///
/// A `Stage` owns its boxed closure and is consumed by `apply`, so each stage
/// value runs at most once. Stages are composed with [`Stage::then`].
#[must_use]
pub struct Stage<E, A, B> {
    /// The boxed step. It is `FnOnce` so it may move captured values out, and
    /// `Send` so the whole `Stage` may be sent to another thread.
    run: Box<dyn FnOnce(A) -> Io<E, B> + Send>,
}

// `dyn FnOnce` has no `Debug` impl, so `#[derive(Debug)]` is unavailable.
// Print the type name and mark the opaque closure field as elided.
impl<E, A, B> std::fmt::Debug for Stage<E, A, B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Stage").finish_non_exhaustive()
    }
}
impl<E: Send + 'static, A: Send + 'static, B: Send + 'static> Stage<E, A, B> {
pub fn new(f: impl FnOnce(A) -> Io<E, B> + Send + 'static) -> Self {
Self { run: Box::new(f) }
}
pub fn apply(self, input: A) -> Io<E, B> {
(self.run)(input)
}
pub fn then<C: Send + 'static>(self, next: Stage<E, B, C>) -> Stage<E, A, C> {
Stage::new(move |a| self.apply(a).flat_map(move |b| next.apply(b)))
}
pub fn map_output<C: Send + 'static>(
self,
f: impl FnOnce(B) -> C + Send + 'static,
) -> Stage<E, A, C> {
Stage::new(move |a| self.apply(a).map(f))
}
}
// `clippy::mismatching_type_param_order` fires because this impl reuses the
// name `A` in the struct's third (output) position, where the definition
// declares `B`. That is deliberate: `identity` only exists when the input and
// output types coincide.
#[allow(clippy::mismatching_type_param_order)]
impl<E: Send + 'static, A: Send + 'static> Stage<E, A, A> {
    /// The identity stage: wraps its input unchanged with `Io::pure`.
    ///
    /// It is intended to be a left and right unit for [`Stage::then`]
    /// (exercised by the identity-law tests).
    pub fn identity() -> Self {
        Self::new(Io::pure)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;

    /// Inputs each law is checked against, rather than a single point.
    const SAMPLE_INPUTS: [i32; 4] = [-5, 0, 7, 10];

    // A `Stage` is consumed by `apply`, so each side of a law comparison needs
    // its own instance. These factories build a fresh stage on every call
    // instead of duplicating the closure at each use site.

    /// `x * 3`
    fn triple() -> Stage<Infallible, i32, i32> {
        Stage::new(|x| Io::pure(x * 3))
    }

    /// `x + 1`
    fn add_one() -> Stage<Infallible, i32, i32> {
        Stage::new(|x| Io::pure(x + 1))
    }

    /// `x * 2`
    fn double() -> Stage<Infallible, i32, i32> {
        Stage::new(|x| Io::pure(x * 2))
    }

    /// `x - 3`
    fn sub_three() -> Stage<Infallible, i32, i32> {
        Stage::new(|x| Io::pure(x - 3))
    }

    #[test]
    fn identity_passes_through() {
        let stage: Stage<Infallible, i32, i32> = Stage::identity();
        assert_eq!(stage.apply(42).run(), Ok(42));
    }

    #[test]
    fn then_composes_sequentially() {
        let combined = double().then(add_one());
        assert_eq!(combined.apply(5).run(), Ok(11));
    }

    /// `identity.then(f) == f`
    #[test]
    fn left_identity_law() {
        for input in SAMPLE_INPUTS {
            let via_id = Stage::identity().then(triple());
            assert_eq!(triple().apply(input).run(), via_id.apply(input).run());
        }
    }

    /// `f.then(identity) == f`
    #[test]
    fn right_identity_law() {
        for input in SAMPLE_INPUTS {
            let via_id = triple().then(Stage::identity());
            assert_eq!(triple().apply(input).run(), via_id.apply(input).run());
        }
    }

    /// `(f.then(g)).then(h) == f.then(g.then(h))`
    #[test]
    fn associativity_law() {
        for input in SAMPLE_INPUTS {
            let left = add_one().then(double()).then(sub_three());
            let right = add_one().then(double().then(sub_three()));
            assert_eq!(left.apply(input).run(), right.apply(input).run());
        }
        // Pin the actual value as well, so the law cannot pass vacuously:
        // ((10 + 1) * 2) - 3 == 19.
        let composed = add_one().then(double()).then(sub_three());
        assert_eq!(composed.apply(10).run(), Ok(19));
    }

    #[test]
    fn map_output_transforms_result() {
        let stage: Stage<Infallible, i32, String> = double().map_output(|x| x.to_string());
        assert_eq!(stage.apply(5).run(), Ok("10".to_string()));
    }
}