feanor_math/
computation.rs

1use std::{fmt::Arguments, io::Write, sync::atomic::{AtomicBool, Ordering}};
2
3use atomicbox::AtomicOptionBox;
4
5use crate::{seq::VectorFn, unstable_sealed::UnstableSealed};
6
7///
8/// Provides an idiomatic way to convert a `Result<T, !>` into `T`, via
9/// ```
10/// # #![feature(never_type)]
11/// # use feanor_math::computation::*;
12/// fn some_computation() -> Result<&'static str, !> { Ok("this computation does not fail") }
13/// println!("{}", some_computation().unwrap_or_else(no_error));
14/// ```
15/// 
16pub fn no_error<T>(error: !) -> T {
17    error
18}
19
20///
21/// Trait for objects that observe and control a potentially long-running computation.
22/// 
23/// The idea is that this trait defines multiple functions that can be called during an
24/// algorithm, and provide certain functionality. This way, each algorithm can decide which
25/// functionality is relevant and how it is used.
26/// 
27/// This is currently unstable-sealed, since I expect significant additional functionality,
28/// potentially including
29///  - Early aborts, timeouts
30///  - Multithreading
31///  - Logging
32///  - ...
33/// 
34/// As a user, this trait should currently be used by passing either [`LogProgress`]
35/// or [`DontObserve`] to algorithms.
36/// 
37/// Also, note that all `description` parameters passed to computation controller
38/// functions are only for logging/debugging purposes only. There is no specified format,
39/// nor any stability guarantees on those messages.
40/// 
41/// # Example
42/// 
43/// Which features of a [`ComputationController`] an algorithm supports is completely up
44/// to the algorithm. Elliptic Curve factorization currently supports logging, abortion
45/// and multithreading.
46/// ```
47/// # use feanor_math::ring::*;
48/// # use feanor_math::algorithms::ec_factor::*;
49/// # use feanor_math::rings::zn::*;
50/// # use feanor_math::computation::*;
51/// let ring = zn_64::Zn::new(8591966237);
52/// // factors 8591966237 while printing progress
53/// let factor = lenstra_ec_factor(ring, LogProgress).unwrap_or_else(no_error);
54/// assert!(8591966237 % factor == 0);
55/// // factor it again, but don't print progress
56/// let factor = lenstra_ec_factor(ring, DontObserve).unwrap_or_else(no_error);
57/// assert!(8591966237 % factor == 0);
58/// ```
59/// If the multithreading with rayon is enabled, we can also do
60/// ```
61/// # use feanor_math::ring::*;
62/// # use feanor_math::algorithms::ec_factor::*;
63/// # use feanor_math::rings::zn::*;
64/// # use feanor_math::computation::*;
65/// # let ring = zn_64::Zn::new(8591966237);
66/// // factors 8591966237 using multiple threads
67/// let factor = lenstra_ec_factor(ring, RunMultithreadedLogProgress).unwrap_or_else(no_error);
68/// assert!(8591966237 % factor == 0);
69/// ```
70///
71pub trait ComputationController: Clone + UnstableSealed {
72
73    type Abort: Send;
74
75    ///
76    /// Called by algorithms in (more or less) regular time intervals, can provide
77    /// e.g. early aborts or tracking progress.
78    /// 
79    #[stability::unstable(feature = "enable")]
80    fn checkpoint(&self, _description: Arguments) -> Result<(), Self::Abort> { 
81        Ok(())
82    }
83    
84    ///
85    /// Runs the given closure with a clone of this iterator, possibly adding a log
86    /// message before and/or after the computation starts/finishes.
87    /// 
88    /// I am currently not completely sure what the right behavior is when this
89    /// function is called multiple times (possibly nested) for clones of the current
90    /// controller. We should certainly support nesting of computations, but what should
91    /// happen in multithreaded scenarios, if we have clones of controllers, or multiple
92    /// different controllers?
93    /// 
94    #[stability::unstable(feature = "enable")]
95    fn run_computation<F, T>(self, _description: Arguments, computation: F) -> T
96        where F: FnOnce(Self) -> T
97    {
98        computation(self)
99    }
100
101    #[stability::unstable(feature = "enable")]
102    fn log(&self, _description: Arguments) {}
103
104    ///
105    /// Inspired by Rayon, and behaves the same as `join()` there.
106    /// Concretely, this function runs both closures, possibly in parallel, and
107    /// returns their results.
108    /// 
109    /// 
110    #[stability::unstable(feature = "enable")]
111    fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
112        where
113            A: FnOnce(Self) -> RA + Send,
114            B: FnOnce(Self) -> RB + Send,
115            RA: Send,
116            RB: Send
117    {
118        (oper_a(self.clone()), oper_b(self.clone()))
119    }
120}
121
///
/// The reason why a (part of a) short-circuiting computation was aborted.
/// 
/// `Finished` means that the computation was aborted, since another part already
/// found a result or aborted. `Abort(e)` means that the controller chose to abort 
/// the computation at a checkpoint, with data `e`.
/// 
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShortCircuitingComputationAbort<E> {
    Finished,
    Abort(E)
}
133
134///
135/// Shared data of a short-circuiting computation.
136/// 
137pub struct ShortCircuitingComputation<T, Controller>
138    where T: Send,
139        Controller: ComputationController
140{
141    finished: AtomicBool,
142    abort: AtomicOptionBox<Controller::Abort>,
143    result: AtomicOptionBox<T>,
144}
145
146///
147/// Handle to a short-circuiting computation.
148/// 
149pub struct ShortCircuitingComputationHandle<'a, T, Controller>
150    where T: Send,
151        Controller: ComputationController
152{
153    controller: Controller,
154    executor: &'a ShortCircuitingComputation<T, Controller>
155}
156
157impl<'a, T, Controller> Clone for ShortCircuitingComputationHandle<'a, T, Controller>
158    where T: Send,
159        Controller: ComputationController
160{
161    fn clone(&self) -> Self {
162        Self {
163            controller: self.controller.clone(),
164            executor: self.executor
165        }
166    }
167}
168
169impl<'a, T, Controller> ShortCircuitingComputationHandle<'a, T, Controller>
170    where T: Send,
171        Controller: ComputationController
172{
173    #[stability::unstable(feature = "enable")]
174    pub fn controller(&self) -> &Controller {
175        &self.controller
176    }
177
178    #[stability::unstable(feature = "enable")]
179    pub fn checkpoint(&self, description: Arguments) -> Result<(), ShortCircuitingComputationAbort<Controller::Abort>> { 
180        if self.executor.finished.load(Ordering::Relaxed) {
181            return Err(ShortCircuitingComputationAbort::Finished);
182        } else if let Err(e) = self.controller.checkpoint(description) {
183            return Err(ShortCircuitingComputationAbort::Abort(e));
184        } else {
185            return Ok(());
186        }
187    }
188
189    #[stability::unstable(feature = "enable")]
190    pub fn log(&self, description: Arguments) {
191        self.controller.log(description)
192    }
193
194    #[stability::unstable(feature = "enable")]
195    pub fn join_many<V, F>(self, operations: V)
196        where V: VectorFn<F> + Sync,
197            F: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>
198    {
199        fn join_many_internal<'a, T, V, F, Controller>(controller: Controller, executor: &'a ShortCircuitingComputation<T, Controller>, tasks: &V, from: usize, to: usize, batch_tasks: usize)
200            where T: Send,
201                Controller: ComputationController,
202                V: VectorFn<F> + Sync,
203                F: FnOnce(ShortCircuitingComputationHandle<'a, T, Controller>) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>
204        {
205            if executor.finished.load(Ordering::Relaxed) {
206                return;
207            } else if from == to {
208                return;
209            } else if from + batch_tasks >= to {
210                for i in from..to {
211                    match tasks.at(i)(ShortCircuitingComputationHandle {
212                        controller: controller.clone(),
213                        executor: executor
214                    }) {
215                        Ok(Some(result)) => {
216                            executor.finished.store(true, Ordering::Relaxed);
217                            executor.result.store(Some(Box::new(result)), Ordering::AcqRel);
218                        },
219                        Err(ShortCircuitingComputationAbort::Abort(abort)) => {
220                            executor.finished.store(true, Ordering::Relaxed);
221                            executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
222                        },
223                        Err(ShortCircuitingComputationAbort::Finished) | Ok(None) => {}
224                    }
225                }
226            } else {
227                let mid = (from + to) / 2;
228                controller.join(move |controller| join_many_internal(controller, executor, tasks, from, mid, batch_tasks), move |controller| join_many_internal(controller, executor, tasks, mid, to, batch_tasks));
229            }
230        }
231        join_many_internal(self.controller, self.executor, &operations, 0, operations.len(), 1)
232    }
233
234    #[stability::unstable(feature = "enable")]
235    pub fn join<A, B>(self, oper_a: A, oper_b: B)
236        where
237            A: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send,
238            B: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send
239    {
240        let success_fn = |value: T| {
241            self.executor.finished.store(true, Ordering::Relaxed);
242            self.executor.result.store(Some(Box::new(value)), Ordering::AcqRel);
243        };
244        let abort_fn = |abort: Controller::Abort| {
245            self.executor.finished.store(true, Ordering::Relaxed);
246            self.executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
247        };
248        self.controller.join(
249            |controller| {
250                if self.executor.finished.load(Ordering::Relaxed) {
251                    return;
252                }
253                match oper_a(ShortCircuitingComputationHandle {
254                    controller,
255                    executor: self.executor
256                }) {
257                    Ok(Some(result)) => success_fn(result),
258                    Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
259                    Err(ShortCircuitingComputationAbort::Finished) => {},
260                    Ok(None) => {}
261                }
262            },
263            |controller| {
264                if self.executor.finished.load(Ordering::Relaxed) {
265                    return;
266                }
267                match oper_b(ShortCircuitingComputationHandle {
268                    controller,
269                    executor: self.executor
270                }) {
271                    Ok(Some(result)) => success_fn(result),
272                    Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
273                    Err(ShortCircuitingComputationAbort::Finished) => {},
274                    Ok(None) => {}
275                }
276            }
277        );
278    }
279}
280
281impl<T, Controller> ShortCircuitingComputation<T, Controller>
282    where T: Send,
283        Controller: ComputationController
284{
285    #[stability::unstable(feature = "enable")]
286    pub fn new() -> Self {
287        Self {
288            finished: AtomicBool::new(false),
289            abort: AtomicOptionBox::none(),
290            result: AtomicOptionBox::none()
291        }
292    }
293
294    #[stability::unstable(feature = "enable")]
295    pub fn handle<'a>(&'a self, controller: Controller) -> ShortCircuitingComputationHandle<'a, T, Controller> {
296        ShortCircuitingComputationHandle {
297            controller: controller,
298            executor: self
299        }
300    }
301
302    #[stability::unstable(feature = "enable")]
303    pub fn finish(self) -> Result<Option<T>, Controller::Abort> {
304        if let Some(abort) = self.abort.swap(None, Ordering::AcqRel) {
305            return Err(*abort);
306        } else if let Some(result) = self.result.swap(None, Ordering::AcqRel) {
307            return Ok(Some(*result));
308        } else {
309            return Ok(None);
310        }
311    }
312}
313
///
/// Calls `checkpoint()` on the given controller (or handle) with a message built as by
/// `format_args!`, and propagates an abort via `?`.
/// 
/// Hence it can only be used inside functions returning a `Result` whose error type the
/// abort type can be converted into. If only the controller is given, the message is empty.
/// 
#[macro_export]
macro_rules! checkpoint {
    ($controller:expr) => {
        // absolute path, so that a local item named `std` at the call site cannot interfere
        ($controller).checkpoint(::std::format_args!(""))?
    };
    ($controller:expr, $($args:tt)*) => {
        ($controller).checkpoint(::std::format_args!($($args)*))?
    };
}
323
///
/// Calls `log()` on the given controller (or handle) with a message built as by
/// `format_args!`.
/// 
#[macro_export]
macro_rules! log_progress {
    ($controller:expr, $($args:tt)*) => {
        // absolute path, so that a local item named `std` at the call site cannot interfere
        ($controller).log(::std::format_args!($($args)*))
    };
}
330
331#[derive(Clone, Copy)]
332pub struct LogProgress;
333
334impl UnstableSealed for LogProgress {}
335
336impl ComputationController for LogProgress {
337
338    type Abort = !;
339
340    #[stability::unstable(feature = "enable")]
341    fn log(&self, description: Arguments) {
342        print!("{}", description);
343        std::io::stdout().flush().unwrap();
344    }
345
346    #[stability::unstable(feature = "enable")]
347    fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
348        where F: FnOnce(Self) -> T
349    {
350        self.log(description);
351        let result = computation(self);
352        self.log(format_args!("\n"));
353        return result;
354    }
355
356    #[stability::unstable(feature = "enable")]
357    fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
358        self.log(description);
359        Ok(())
360    }
361}
362
363#[derive(Clone, Copy)]
364pub struct DontObserve;
365
366impl UnstableSealed for DontObserve {}
367
368impl ComputationController for DontObserve {
369
370    type Abort = !;
371}
372
#[cfg(feature = "parallel")]
mod parallel_controller {

    use super::*;

    ///
    /// A [`ComputationController`] that runs the two closures passed to `join()` in
    /// parallel using `rayon::join()`, and forwards checkpoints, logging and
    /// `run_computation()` to the wrapped controller.
    /// 
    #[stability::unstable(feature = "enable")]
    pub struct ExecuteMultithreaded<Rest: ComputationController + Send> {
        rest: Rest
    }

    impl<Rest: ComputationController + Send + Copy> Copy for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> Clone for ExecuteMultithreaded<Rest> {
        fn clone(&self) -> Self {
            Self { rest: self.rest.clone() }
        }
    }
    
    impl<Rest: ComputationController + Send> UnstableSealed for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> ComputationController for ExecuteMultithreaded<Rest> {
        type Abort = Rest::Abort;

        #[stability::unstable(feature = "enable")]
        fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> { 
            self.rest.checkpoint(description)
        }

        // Forwarded explicitly; otherwise the trait's no-op default would be used and
        // all log messages of the wrapped controller (e.g. `LogProgress`) would be lost.
        #[stability::unstable(feature = "enable")]
        fn log(&self, description: Arguments) {
            self.rest.log(description)
        }
    
        #[stability::unstable(feature = "enable")]
        fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
            where F: FnOnce(Self) -> T
        {
            self.rest.run_computation(description, |rest| computation(ExecuteMultithreaded { rest }))
        }

        #[stability::unstable(feature = "enable")]
        fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
            where
                A: FnOnce(Self) -> RA + Send,
                B: FnOnce(Self) -> RB + Send,
                RA: Send,
                RB: Send
        {
            let self1 = self.clone();
            let self2 = self;
            rayon::join(|| oper_a(self1), || oper_b(self2))
        }
    }

    ///
    /// Runs computations multithreaded and prints progress to stdout.
    /// 
    #[stability::unstable(feature = "enable")]
    #[allow(non_upper_case_globals)]
    pub static RunMultithreadedLogProgress: ExecuteMultithreaded<LogProgress> = ExecuteMultithreaded { rest: LogProgress };

    ///
    /// Runs computations multithreaded without logging.
    /// 
    #[stability::unstable(feature = "enable")]
    #[allow(non_upper_case_globals)]
    pub static RunMultithreaded: ExecuteMultithreaded<DontObserve> = ExecuteMultithreaded { rest: DontObserve };
}

#[cfg(feature = "parallel")]
pub use parallel_controller::*;