1use crate::{Algorithm, Completable, Computable, Stateful};
2use cancel_this::is_cancelled;
3use std::marker::PhantomData;
4
5pub trait ComputationStep<CONTEXT, STATE, OUTPUT> {
19 fn step(context: &CONTEXT, state: &mut STATE) -> Completable<OUTPUT>;
23}
24
25#[derive(Debug)]
63#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
64#[cfg_attr(
65 feature = "serde",
66 serde(
67 bound = "CONTEXT: serde::Serialize + for<'a> serde::Deserialize<'a>, STATE: serde::Serialize + for<'a> serde::Deserialize<'a>"
68 )
69)]
70pub struct Computation<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> {
71 context: CONTEXT,
72 state: STATE,
73 #[cfg_attr(feature = "serde", serde(skip))]
74 _phantom: PhantomData<(OUTPUT, STEP)>,
75}
76
77impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> Computable<OUTPUT>
78 for Computation<CONTEXT, STATE, OUTPUT, STEP>
79{
80 fn try_compute(&mut self) -> Completable<OUTPUT> {
81 is_cancelled!()?;
82 STEP::step(&self.context, &mut self.state)
83 }
84}
85
86impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>> Stateful<CONTEXT, STATE>
87 for Computation<CONTEXT, STATE, OUTPUT, STEP>
88{
89 fn from_parts(context: CONTEXT, state: STATE) -> Self
90 where
91 Self: Sized + 'static,
92 {
93 Computation {
94 context,
95 state,
96 _phantom: Default::default(),
97 }
98 }
99
100 fn into_parts(self) -> (CONTEXT, STATE) {
101 (self.context, self.state)
102 }
103
104 fn context(&self) -> &CONTEXT {
105 &self.context
106 }
107
108 fn state(&self) -> &STATE {
109 &self.state
110 }
111
112 fn state_mut(&mut self) -> &mut STATE {
113 &mut self.state
114 }
115}
116
117impl<CONTEXT, STATE, OUTPUT, STEP: ComputationStep<CONTEXT, STATE, OUTPUT>>
118 Algorithm<CONTEXT, STATE, OUTPUT> for Computation<CONTEXT, STATE, OUTPUT, STEP>
119{
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Algorithm, Computable, Incomplete, Stateful};

    /// Suspends on the first two calls, then reports the context together with the final state.
    struct SimpleStep;

    impl ComputationStep<i32, u32, String> for SimpleStep {
        fn step(context: &i32, state: &mut u32) -> Completable<String> {
            *state += 1;
            match *state {
                0..=2 => Err(Incomplete::Suspended),
                finished => Ok(format!("context={}, state={}", context, finished)),
            }
        }
    }

    /// The computation type exercised by most tests below.
    type Simple = Computation<i32, u32, String, SimpleStep>;

    #[test]
    fn test_computation_from_parts() {
        let computation = Simple::from_parts(42, 0);
        assert_eq!((*computation.context(), *computation.state()), (42, 0));
    }

    #[test]
    fn test_computation_into_parts() {
        assert_eq!(Simple::from_parts(100, 5).into_parts(), (100, 5));
    }

    #[test]
    fn test_computation_state_mut() {
        let mut computation = Simple::from_parts(42, 0);
        *computation.state_mut() = 10;
        assert_eq!(*computation.state(), 10);
    }

    #[test]
    fn test_computation_try_compute() {
        let mut computation = Simple::from_parts(42, 0);

        // The first two calls only advance the state.
        for expected_state in 1..=2u32 {
            assert_eq!(computation.try_compute(), Err(Incomplete::Suspended));
            assert_eq!(*computation.state(), expected_state);
        }

        // The third call crosses the threshold and yields the output.
        assert_eq!(computation.try_compute().unwrap(), "context=42, state=3");
        assert_eq!(*computation.state(), 3);
    }

    #[test]
    fn test_computation_compute() {
        let mut computation = Simple::from_parts(100, 0);
        assert_eq!(computation.compute().unwrap(), "context=100, state=3");
        assert_eq!(*computation.state(), 3);
    }

    #[test]
    fn test_computation_configure() {
        let computation = Simple::configure(50, 0u32);
        assert_eq!((*computation.context(), *computation.state()), (50, 0));
    }

    #[test]
    fn test_computation_run() {
        assert_eq!(Simple::run(200, 0u32).unwrap(), "context=200, state=3");
    }

    #[test]
    fn test_computation_dyn_algorithm() {
        let computation = Simple::from_parts(42, 0);
        let mut erased = computation.dyn_algorithm();
        assert_eq!(erased.compute().unwrap(), "context=42, state=3");
    }

    /// Finishes on the very first call.
    struct ImmediateStep;

    impl ComputationStep<(), (), i32> for ImmediateStep {
        fn step(_context: &(), _state: &mut ()) -> Completable<i32> {
            Ok(42)
        }
    }

    #[test]
    fn test_computation_immediate_completion() {
        let mut computation = Computation::<(), (), i32, ImmediateStep>::from_parts((), ());
        assert_eq!(computation.try_compute(), Ok(42));
    }

    /// Never finishes; every call asks to be resumed later.
    struct NeverCompleteStep;

    impl ComputationStep<(), (), i32> for NeverCompleteStep {
        fn step(_context: &(), _state: &mut ()) -> Completable<i32> {
            Err(Incomplete::Suspended)
        }
    }

    #[test]
    fn test_computation_never_completes() {
        let mut computation = Computation::<(), (), i32, NeverCompleteStep>::from_parts((), ());
        for _ in 0..2 {
            assert_eq!(computation.try_compute(), Err(Incomplete::Suspended));
        }
    }
}