use std::fmt::Debug;
/// A computation that is either finished or waiting on a command.
///
/// `F` is the command type and `A` is the type of the final value. Both must
/// be `Clone + Debug` because the boxed continuation type [`CloneFn`] is
/// declared with those bounds.
#[derive(Debug, Clone)]
pub enum Free<F, A>
where
    F: Clone + Debug,
    A: Clone + Debug,
{
    /// A finished computation carrying its result.
    Pure(A),
    /// A pending command together with the continuation that, given the
    /// command, produces the next `Free` step.
    Suspend(F, Box<dyn CloneFn<F, A>>),
}
/// A continuation from a command `F` to the next [`Free`] step that can also
/// be duplicated while hidden behind a `Box<dyn CloneFn<F, A>>`.
///
/// `Debug` is a supertrait, so every implementor (and the trait object) can be
/// formatted with `{:?}`.
pub trait CloneFn<F, A>: Debug
where
    F: Clone + Debug,
    A: Clone + Debug,
{
    /// Consumes the command `f` and returns the next step of the computation.
    fn call(&self, f: F) -> Free<F, A>;

    /// Returns a boxed trait object produced from `self`; the `Clone` impl for
    /// `Box<dyn CloneFn<F, A>>` delegates to this method.
    fn clone_box(&self) -> Box<dyn CloneFn<F, A>>;
}
/// Makes boxed continuations cloneable by delegating to
/// [`CloneFn::clone_box`] on the trait object inside the box.
impl<F, A> Clone for Box<dyn CloneFn<F, A>>
where
    F: Clone + Debug,
    A: Clone + Debug,
{
    fn clone(&self) -> Self {
        // Deref through `&Box<dyn ..>` to the trait object explicitly so the
        // call clearly targets `clone_box`, not `Clone::clone` on the box.
        (**self).clone_box()
    }
}
impl<F: Clone + Debug + 'static, A: Clone + Debug + 'static> Free<F, A> {
    /// Builds a finished computation holding `a`.
    pub fn pure(a: A) -> Self {
        Free::Pure(a)
    }

    /// Suspends on `cmd` with a continuation that hands the command straight
    /// back as the result, so running the returned value yields `cmd`.
    ///
    /// The return type is `Free<F, F>`, not `Self`; `A` is not constrained by
    /// the argument or the return type, so callers have to spell it out, e.g.
    /// `Free::<Cmd, Cmd>::lift(cmd)`.
    pub fn lift(cmd: F) -> Free<F, F> {
        /// Continuation that wraps whatever command it receives in `Pure`.
        #[derive(Debug, Clone)]
        struct IdCont<F: Clone + Debug>(std::marker::PhantomData<F>);

        impl<F: Clone + Debug + 'static> CloneFn<F, F> for IdCont<F> {
            fn call(&self, f: F) -> Free<F, F> {
                Free::Pure(f)
            }

            fn clone_box(&self) -> Box<dyn CloneFn<F, F>> {
                Box::new(self.clone())
            }
        }

        Free::Suspend(cmd, Box::new(IdCont(std::marker::PhantomData)))
    }

    /// Evaluates the computation and returns its final value.
    ///
    /// Each `Suspend` step feeds its command to its continuation and carries
    /// on with whatever that returns, until a `Pure` value is reached. The
    /// steps are driven by a loop rather than by recursion, so the stack
    /// depth does not grow with the number of `Suspend` steps, and every
    /// continuation is dropped as soon as it has produced the next step.
    ///
    /// If the continuations never produce `Pure`, this never returns.
    pub fn run(self) -> A
    where
        // NOTE(review): this bound is not exercised by the body; it is kept
        // so the public signature stays exactly as callers already see it.
        F: Into<A>,
    {
        let mut current = self;
        loop {
            match current {
                Free::Pure(a) => return a,
                Free::Suspend(cmd, k) => current = k.call(cmd),
            }
        }
    }
}
/// A value paired with the ordered list of step names recorded alongside it.
#[derive(Debug, Clone)]
pub struct Chain<A>
where
    A: Clone + Debug,
{
    /// Names of the recorded steps, oldest first.
    steps: Vec<String>,
    /// The current value carried by the chain.
    value: A,
}
impl<A: Clone + Debug> Chain<A> {
    /// Creates a chain holding `value` with no steps recorded yet.
    pub fn start(value: A) -> Self {
        let steps = Vec::new();
        Chain { steps, value }
    }

    /// Records `name` as the next step, then applies `f` to the current value
    /// and returns a chain carrying the result together with the extended
    /// step list.
    pub fn then<B: Clone + Debug>(self, name: &str, f: impl FnOnce(A) -> B) -> Chain<B> {
        let Chain {
            steps: mut labels,
            value: current,
        } = self;
        // The label is pushed before `f` runs, matching the recording order.
        labels.push(String::from(name));
        let next = f(current);
        Chain {
            steps: labels,
            value: next,
        }
    }

    /// Borrows the current value.
    pub fn value(&self) -> &A {
        let Chain { value, .. } = self;
        value
    }

    /// Borrows the recorded step names, oldest first.
    pub fn steps(&self) -> &[String] {
        self.steps.as_slice()
    }

    /// Consumes the chain, returning the final value and the step names.
    pub fn run(self) -> (A, Vec<String>) {
        let Chain { steps, value } = self;
        (value, steps)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `then` call appends its label and the closures compose in order.
    #[test]
    fn chain_accumulates_steps() {
        let doubled = Chain::start(10).then("double", |n| n * 2);
        let incremented = doubled.then("add_one", |n| n + 1);
        let rendered = incremented.then("to_string", |n| format!("{n}"));
        let (text, labels) = rendered.run();
        assert_eq!(text, "21");
        assert_eq!(labels, vec!["double", "add_one", "to_string"]);
    }

    /// A freshly started chain exposes its value and has no steps.
    #[test]
    fn chain_empty() {
        let fresh = Chain::start(42);
        assert_eq!(*fresh.value(), 42);
        assert!(fresh.steps().is_empty());
    }

    /// A single step may change the value's type.
    #[test]
    fn chain_single_step() {
        let measured = Chain::start("hello").then("length", |word| word.len());
        let (len, labels) = measured.run();
        assert_eq!(len, 5);
        assert_eq!(labels, vec!["length"]);
    }

    /// A multi-stage pipeline expressed as a chain yields both the response
    /// and the trace of stage names.
    #[test]
    fn pipeline_as_chain() {
        let tokenized = Chain::start("is a dog an animal")
            .then("tokenize", |text| text.split_whitespace().count());
        let parsed = tokenized.then("parse", |tokens| tokens > 0);
        let interpreted =
            parsed.then("interpret", |ok| if ok { "question" } else { "unknown" });
        let (response, trace) = interpreted
            .then("respond", |intent| format!("Understood: {intent}"))
            .run();
        assert_eq!(response, "Understood: question");
        assert_eq!(trace, vec!["tokenize", "parse", "interpret", "respond"]);
    }

    mod prop {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// After `n` increments the chain holds `n` steps and the value `n`.
            #[test]
            fn prop_chain_step_count(n in 0..10usize) {
                let chain = (0..n).fold(Chain::start(0), |acc, i| {
                    acc.then(&format!("step{i}"), |x| x + 1)
                });
                prop_assert_eq!(chain.steps().len(), n);
                prop_assert_eq!(*chain.value(), n as i32);
            }
        }
    }
}