1use crate::{Algorithm, Completable, Computable, Stateful};
2use cancel_this::is_cancelled;
3use serde::{Deserialize, Serialize};
4use std::marker::PhantomData;
5
6pub trait ComputationStep<CONTEXT, STATE, OUTPUT> {
20 fn step(context: &CONTEXT, state: &mut STATE) -> Completable<OUTPUT>;
24}
25
26#[derive(Debug, Serialize, Deserialize)]
64#[serde(
65 bound = "CONTEXT: Serialize + for<'a> Deserialize<'a>, STATE: Serialize + for<'a> Deserialize<'a>"
66)]
67pub struct Computation<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> {
68 context: CONTEXT,
69 state: STATE,
70 #[serde(skip)]
71 _phantom: PhantomData<(OUTPUT, STEP)>,
72}
73
74impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> Computable<OUTPUT>
75 for Computation<CONTEXT, STATE, OUTPUT, STEP>
76{
77 fn try_compute(&mut self) -> Completable<OUTPUT> {
78 is_cancelled!()?;
79 STEP::step(&self.context, &mut self.state)
80 }
81}
82
83impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> Stateful<CONTEXT, STATE>
84 for Computation<CONTEXT, STATE, OUTPUT, STEP>
85{
86 fn from_parts(context: CONTEXT, state: STATE) -> Self
87 where
88 Self: Sized + 'static,
89 {
90 Computation {
91 context,
92 state,
93 _phantom: Default::default(),
94 }
95 }
96
97 fn into_parts(self) -> (CONTEXT, STATE) {
98 (self.context, self.state)
99 }
100
101 fn context(&self) -> &CONTEXT {
102 &self.context
103 }
104
105 fn state(&self) -> &STATE {
106 &self.state
107 }
108
109 fn state_mut(&mut self) -> &mut STATE {
110 &mut self.state
111 }
112}
113
114impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>>
115 Algorithm<CONTEXT, STATE, OUTPUT> for Computation<CONTEXT, STATE, OUTPUT, STEP>
116{
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Algorithm, Computable, Incomplete, Stateful};

    /// The computation most tests drive: `i32` context, `u32` counter state.
    type Counting = Computation<i32, u32, String, SimpleStep>;

    /// Bumps the counter on every call and only finishes once it reaches 3.
    struct SimpleStep;

    impl ComputationStep<i32, u32, String> for SimpleStep {
        fn step(context: &i32, state: &mut u32) -> Completable<String> {
            *state += 1;
            if *state >= 3 {
                Ok(format!("context={}, state={}", context, state))
            } else {
                Err(Incomplete::Suspended)
            }
        }
    }

    #[test]
    fn test_computation_from_parts() {
        let computation = Counting::from_parts(42, 0);
        assert_eq!(*computation.context(), 42);
        assert_eq!(*computation.state(), 0);
    }

    #[test]
    fn test_computation_into_parts() {
        let (context, state) = Counting::from_parts(100, 5).into_parts();
        assert_eq!((context, state), (100, 5));
    }

    #[test]
    fn test_computation_state_mut() {
        let mut computation = Counting::from_parts(42, 0);
        *computation.state_mut() = 10;
        assert_eq!(*computation.state(), 10);
    }

    #[test]
    fn test_computation_try_compute() {
        let mut computation = Counting::from_parts(42, 0);

        // The first two calls only advance the counter.
        for expected_state in 1..=2 {
            assert_eq!(computation.try_compute(), Err(Incomplete::Suspended));
            assert_eq!(*computation.state(), expected_state);
        }

        // The third call completes and reports the final counter value.
        assert_eq!(computation.try_compute().unwrap(), "context=42, state=3");
        assert_eq!(*computation.state(), 3);
    }

    #[test]
    fn test_computation_compute() {
        let mut computation = Counting::from_parts(100, 0);
        assert_eq!(computation.compute().unwrap(), "context=100, state=3");
        assert_eq!(*computation.state(), 3);
    }

    #[test]
    fn test_computation_configure() {
        let computation = Counting::configure(50, 0u32);
        assert_eq!(*computation.context(), 50);
        assert_eq!(*computation.state(), 0);
    }

    #[test]
    fn test_computation_run() {
        let output = Counting::run(200, 0u32).unwrap();
        assert_eq!(output, "context=200, state=3");
    }

    #[test]
    fn test_computation_dyn_algorithm() {
        let computation = Counting::from_parts(42, 0);
        let mut boxed = computation.dyn_algorithm();
        assert_eq!(boxed.compute().unwrap(), "context=42, state=3");
    }

    /// Finishes on the very first call.
    struct ImmediateStep;

    impl ComputationStep<(), (), i32> for ImmediateStep {
        fn step(_context: &(), _state: &mut ()) -> Completable<i32> {
            Ok(42)
        }
    }

    #[test]
    fn test_computation_immediate_completion() {
        let mut computation = Computation::<(), (), i32, ImmediateStep>::from_parts((), ());
        assert_eq!(computation.try_compute().unwrap(), 42);
    }

    /// Always suspends and never produces an output.
    struct NeverCompleteStep;

    impl ComputationStep<(), (), i32> for NeverCompleteStep {
        fn step(_context: &(), _state: &mut ()) -> Completable<i32> {
            Err(Incomplete::Suspended)
        }
    }

    #[test]
    fn test_computation_never_completes() {
        let mut computation = Computation::<(), (), i32, NeverCompleteStep>::from_parts((), ());
        for _ in 0..2 {
            assert_eq!(computation.try_compute(), Err(Incomplete::Suspended));
        }
    }
}
236}