1use num_traits::float::FloatCore;
87use std::sync::{Arc, Mutex};
88use tokio_util::sync::CancellationToken;
89
90use crate::engine::policy::{CancellationPolicy, CompletionPolicy, EnginePolicy, PolicyStack};
91use crate::{
92 engine::{
93 checkpoint::{CheckpointBackend, CheckpointExtension},
94 extensions::Extensions,
95 Engine,
96 },
97 state::{Snapshotable, State, StateRestorer},
98 watchers::{Frequency, Observe, Observers},
99 FallibleProcedure, UserState,
100};
101
102pub trait GenerateBuilder: Sized {
103 fn build_for<P>(self, problem: P) -> Builder<Self, P, Uninitialised>
104 where
105 Self: FallibleProcedure<P>,
106 Self::State: UserState;
107}
108
109impl<Proc> GenerateBuilder for Proc {
110 fn build_for<P>(self, problem: P) -> Builder<Self, P, Uninitialised>
111 where
112 Proc: FallibleProcedure<P>,
113 Proc::State: UserState,
114 {
115 Builder {
116 procedure: self,
117 problem,
118 state: None,
119 time: true,
120 cancellation_token: None,
121
122 observers: Observers::new(),
123
124 policies: PolicyStack::new()
125 .add(CancellationPolicy)
126 .add(CompletionPolicy),
127
128 extensions: Extensions::new(),
129
130 _initialised: std::marker::PhantomData,
131 }
132 }
133}
134
/// Typestate marker: the builder has not been given a starting state yet,
/// so it cannot be finalised.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Uninitialised;

/// Typestate marker: the builder holds a starting state and exposes the
/// `finalise*` methods.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Initialised;
137
138pub struct Builder<Proc, P, I>
139where
140 Proc: FallibleProcedure<P>,
141 Proc::State: UserState,
142 <Proc::State as UserState>::Float: FloatCore,
143{
144 procedure: Proc,
145 problem: P,
146 state: Option<Proc::State>,
147 time: bool,
148 cancellation_token: Option<CancellationToken>,
149
150 observers: Observers<Proc::State>,
151
152 policies: PolicyStack<<Proc::State as UserState>::Float>,
153 extensions: Extensions<Proc::State>,
154
155 _initialised: std::marker::PhantomData<I>,
156}
157
158impl<Proc, P, I> Builder<Proc, P, I>
159where
160 Proc: FallibleProcedure<P>,
161 Proc::State: UserState,
162 <Proc::State as UserState>::Float: FloatCore + 'static,
163{
164 #[must_use]
165 pub fn time(mut self, time: bool) -> Self {
166 self.time = time;
167 self
168 }
169
170 #[must_use]
172 pub fn attach_observer<OBS>(mut self, observer: OBS, frequency: Frequency) -> Self
173 where
174 OBS: Observe<Proc::State> + 'static,
175 {
176 self.observers
177 .attach(Arc::new(Mutex::new(observer)), frequency);
178 self
179 }
180
181 #[must_use]
182 pub fn and_policy<Q>(mut self, policy: Q) -> Self
183 where
184 Q: EnginePolicy<<Proc::State as UserState>::Float> + 'static,
185 {
186 self.policies = self.policies.add(policy);
187 self
188 }
189
190 #[must_use]
191 pub fn cancellation_token(mut self, token: CancellationToken) -> Self {
192 self.cancellation_token = Some(token);
193 self
194 }
195
196 #[must_use]
197 pub fn with_default_policies(
201 mut self,
202 max_iter: usize,
203 absolute_tolerance: <Proc::State as UserState>::Float,
204 window_size: usize,
205 ) -> Self {
206 self.policies = self.policies.merge(PolicyStack::standard(
207 max_iter,
208 absolute_tolerance,
209 window_size,
210 ));
211 self
212 }
213
214 #[must_use]
215 pub fn with_checkpoint_backend<C>(mut self, store: C) -> Self
221 where
222 C: CheckpointBackend<
223 <Proc::State as Snapshotable>::Snapshot,
224 <Proc::State as UserState>::Float,
225 > + 'static,
226 Proc::State: Snapshotable,
227 {
228 self.extensions = self.extensions.add(CheckpointExtension::new(store));
229 self
230 }
231}
232
233impl<Proc, P> Builder<Proc, P, Uninitialised>
234where
235 Proc: FallibleProcedure<P>,
236 Proc::State: UserState,
237 <Proc::State as UserState>::Float: FloatCore + 'static,
238{
239 #[must_use]
241 pub fn with_initial_state(self, user: Proc::State) -> Builder<Proc, P, Initialised> {
242 Builder {
243 procedure: self.procedure,
244 problem: self.problem,
245 state: Some(user),
246 time: self.time,
247 cancellation_token: self.cancellation_token,
248
249 observers: self.observers,
250
251 policies: self.policies,
252
253 extensions: self.extensions,
254
255 _initialised: std::marker::PhantomData,
256 }
257 }
258
259 #[must_use]
260 pub fn resume_from_checkpoint(
261 self,
262 snapshot: <Proc::State as Snapshotable>::Snapshot,
263 ) -> Builder<Proc, P, Initialised>
264 where
265 Proc: FallibleProcedure<P>,
266 Proc::State: Snapshotable + StateRestorer<Proc::State>,
267 {
268 let user = Proc::State::restore(snapshot);
269
270 Builder {
271 procedure: self.procedure,
272 problem: self.problem,
273 state: Some(user),
274 time: self.time,
275 cancellation_token: self.cancellation_token,
276
277 observers: self.observers,
278
279 policies: self.policies,
280
281 extensions: self.extensions,
282
283 _initialised: std::marker::PhantomData,
284 }
285 }
286}
287
288impl<Proc, P> Builder<Proc, P, Initialised>
289where
290 Proc: FallibleProcedure<P>,
291 Proc::State: UserState,
292 <Proc::State as UserState>::Float: FloatCore + 'static,
293{
294 pub fn finalise(mut self) -> Engine<Proc, P, PolicyStack<<Proc::State as UserState>::Float>>
299 where
300 <Proc::State as UserState>::Float: num_traits::FromPrimitive,
301 {
302 let user = self.state.take().expect("builder invariant: user is set");
303
304 let cancellation = self.cancellation_token.unwrap_or_default();
305
306 #[cfg(feature = "ctrlc")]
307 {
308 let token = cancellation.clone();
309 ctrlc::set_handler(move || {
310 token.cancel();
311 })
312 .unwrap();
313 }
314
315 Engine {
316 procedure: self.procedure,
317 problem: self.problem,
318 state: State::new(user),
319
320 time: self.time,
321 start_time: None,
322
323 cancellation,
324
325 policy: self.policies,
326
327 observers: self.observers,
328 extensions: self.extensions,
329 }
330 }
331
332 pub fn finalise_with(
340 mut self,
341 policy: PolicyStack<<Proc::State as UserState>::Float>,
342 ) -> Engine<Proc, P, PolicyStack<<Proc::State as UserState>::Float>> {
343 let user = self.state.take().expect("builder invariant: user is set");
344 let cancellation = self.cancellation_token.unwrap_or_default();
345
346 #[cfg(feature = "ctrlc")]
347 {
348 let token = cancellation.clone();
349 ctrlc::set_handler(move || {
350 token.cancel();
351 })
352 .unwrap();
353 }
354
355 Engine {
356 procedure: self.procedure,
357 problem: self.problem,
358 state: State::new(user),
359
360 time: self.time,
361 start_time: None,
362
363 cancellation,
364
365 policy,
366
367 observers: self.observers,
368 extensions: self.extensions,
369 }
370 }
371}
372