// computation_process/computation.rs

1use crate::{Algorithm, Completable, Computable, Stateful};
2use cancel_this::is_cancelled;
3use serde::{Deserialize, Serialize};
4use std::marker::PhantomData;
5
6/// Defines a single step of a [`Computation`].
7///
8/// Implement this trait to define the logic for advancing a computation.
9/// Each call to `step` should either:
10/// - Return `Ok(output)` if the computation is complete
11/// - Return `Err(Incomplete::Suspended)` to yield control and allow resumption later
12/// - Return `Err(Incomplete::Cancelled(_))` if cancellation was detected
13///
14/// # Type Parameters
15///
16/// - `CONTEXT`: Immutable configuration/input for the computation
17/// - `STATE`: Mutable state that persists across steps
18/// - `OUTPUT`: The final result type of the computation
19pub trait ComputationStep<CONTEXT, STATE, OUTPUT> {
20    /// Execute one step of the computation.
21    ///
22    /// This method is called repeatedly until it returns `Ok(output)`.
23    fn step(context: &CONTEXT, state: &mut STATE) -> Completable<OUTPUT>;
24}
25
26/// A stateful computation that can be suspended and resumed.
27///
28/// `Computation` is the default implementation of [`Algorithm`]. It delegates the
29/// actual computation logic to a [`ComputationStep`] implementation while handling
30/// the boilerplate of state management and cancellation checking.
31///
32/// # Type Parameters
33///
34/// - `CONTEXT`: Immutable configuration passed to each step
35/// - `STATE`: Mutable state that persists across steps  
36/// - `OUTPUT`: The final result type
37/// - `STEP`: The [`ComputationStep`] implementation that defines the computation logic
38///
39/// # Example
40///
41/// ```rust
42/// use computation_process::{Computation, ComputationStep, Completable, Incomplete, Computable, Stateful};
43///
44/// struct SumStep;
45///
46/// impl ComputationStep<Vec<i32>, usize, i32> for SumStep {
47///     fn step(numbers: &Vec<i32>, index: &mut usize) -> Completable<i32> {
48///         if *index < numbers.len() {
49///             *index += 1;
50///             Err(Incomplete::Suspended) // Suspend after processing each number
51///         } else {
52///             Ok(numbers.iter().sum())
53///         }
54///     }
55/// }
56///
57/// let mut computation = Computation::<Vec<i32>, usize, i32, SumStep>::from_parts(
58///     vec![1, 2, 3, 4, 5],
59///     0,
60/// );
61/// assert_eq!(computation.compute().unwrap(), 15);
62/// ```
63#[derive(Debug, Serialize, Deserialize)]
64#[serde(
65    bound = "CONTEXT: Serialize + for<'a> Deserialize<'a>, STATE: Serialize + for<'a> Deserialize<'a>"
66)]
67pub struct Computation<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> {
68    context: CONTEXT,
69    state: STATE,
70    #[serde(skip)]
71    _phantom: PhantomData<(OUTPUT, STEP)>,
72}
73
74impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> Computable<OUTPUT>
75    for Computation<CONTEXT, STATE, OUTPUT, STEP>
76{
77    fn try_compute(&mut self) -> Completable<OUTPUT> {
78        is_cancelled!()?;
79        STEP::step(&self.context, &mut self.state)
80    }
81}
82
83impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> Stateful<CONTEXT, STATE>
84    for Computation<CONTEXT, STATE, OUTPUT, STEP>
85{
86    fn from_parts(context: CONTEXT, state: STATE) -> Self
87    where
88        Self: Sized + 'static,
89    {
90        Computation {
91            context,
92            state,
93            _phantom: Default::default(),
94        }
95    }
96
97    fn into_parts(self) -> (CONTEXT, STATE) {
98        (self.context, self.state)
99    }
100
101    fn context(&self) -> &CONTEXT {
102        &self.context
103    }
104
105    fn state(&self) -> &STATE {
106        &self.state
107    }
108
109    fn state_mut(&mut self) -> &mut STATE {
110        &mut self.state
111    }
112}
113
114impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>>
115    Algorithm<CONTEXT, STATE, OUTPUT> for Computation<CONTEXT, STATE, OUTPUT, STEP>
116{
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Algorithm, Computable, Incomplete, Stateful};

    /// The computation most tests drive: an `i32` context and a `u32` counter state.
    type Simple = Computation<i32, u32, String, SimpleStep>;

    /// Counts invocations in `state`. It suspends until the counter reaches 3, then
    /// reports the context together with the final counter value.
    struct SimpleStep;

    impl ComputationStep<i32, u32, String> for SimpleStep {
        fn step(context: &i32, state: &mut u32) -> Completable<String> {
            *state += 1;
            if *state >= 3 {
                Ok(format!("context={}, state={}", context, state))
            } else {
                Err(Incomplete::Suspended)
            }
        }
    }

    #[test]
    fn test_computation_from_parts() {
        let computation = Simple::from_parts(42, 0);
        assert_eq!(*computation.context(), 42);
        assert_eq!(*computation.state(), 0);
    }

    #[test]
    fn test_computation_into_parts() {
        let (context, state) = Simple::from_parts(100, 5).into_parts();
        assert_eq!(context, 100);
        assert_eq!(state, 5);
    }

    #[test]
    fn test_computation_state_mut() {
        let mut computation = Simple::from_parts(42, 0);
        *computation.state_mut() = 10;
        assert_eq!(*computation.state(), 10);
    }

    #[test]
    fn test_computation_try_compute() {
        let mut computation = Simple::from_parts(42, 0);

        // Calls one and two suspend, advancing the counter each time.
        for expected_state in 1..=2 {
            assert_eq!(computation.try_compute(), Err(Incomplete::Suspended));
            assert_eq!(*computation.state(), expected_state);
        }

        // Call three completes and leaves the counter at 3.
        assert_eq!(computation.try_compute().unwrap(), "context=42, state=3");
        assert_eq!(*computation.state(), 3);
    }

    #[test]
    fn test_computation_compute() {
        let mut computation = Simple::from_parts(100, 0);
        assert_eq!(computation.compute().unwrap(), "context=100, state=3");
        assert_eq!(*computation.state(), 3);
    }

    #[test]
    fn test_computation_configure() {
        let computation = Simple::configure(50, 0u32);
        assert_eq!(*computation.context(), 50);
        assert_eq!(*computation.state(), 0);
    }

    #[test]
    fn test_computation_run() {
        assert_eq!(Simple::run(200, 0u32).unwrap(), "context=200, state=3");
    }

    #[test]
    fn test_computation_dyn_algorithm() {
        let computation = Simple::from_parts(42, 0);
        let mut erased = computation.dyn_algorithm();
        assert_eq!(erased.compute().unwrap(), "context=42, state=3");
    }

    /// Finishes on its very first invocation.
    struct ImmediateStep;

    impl ComputationStep<(), (), i32> for ImmediateStep {
        fn step(_context: &(), _state: &mut ()) -> Completable<i32> {
            Ok(42)
        }
    }

    #[test]
    fn test_computation_immediate_completion() {
        let mut computation = Computation::<(), (), i32, ImmediateStep>::from_parts((), ());
        assert_eq!(computation.try_compute().unwrap(), 42);
    }

    /// Suspends forever; it never produces an output.
    struct NeverCompleteStep;

    impl ComputationStep<(), (), i32> for NeverCompleteStep {
        fn step(_context: &(), _state: &mut ()) -> Completable<i32> {
            Err(Incomplete::Suspended)
        }
    }

    #[test]
    fn test_computation_never_completes() {
        let mut computation = Computation::<(), (), i32, NeverCompleteStep>::from_parts((), ());
        // compute() would spin forever here, so only individual steps are exercised.
        for _ in 0..2 {
            assert_eq!(computation.try_compute(), Err(Incomplete::Suspended));
        }
    }
}