feanor_math/
computation.rs

1use std::{fmt::Arguments, io::Write, sync::atomic::{AtomicBool, Ordering}};
2
3use atomicbox::AtomicOptionBox;
4
5use crate::{seq::VectorFn, unstable_sealed::UnstableSealed};
6
7///
8/// Provides an idiomatic way to convert a `Result<T, !>` into `T`, via
9/// ```rust
10/// # #![feature(never_type)]
11/// # use feanor_math::computation::*;
12/// fn some_computation() -> Result<&'static str, !> { Ok("this computation does not fail") }
13/// println!("{}", some_computation().unwrap_or_else(no_error));
14/// ```
15/// 
16pub fn no_error<T>(error: !) -> T {
17    error
18}
19
20///
21/// Trait for objects that observe and control a potentially long-running computation.
22/// 
23/// The idea is that this trait defines multiple functions that can be called during an
24/// algorithm, and provide certain functionality. This way, each algorithm can decide which
25/// functionality is relevant and how it is used.
26/// 
27/// This is currently unstable-sealed, since I expect significant additional functionality,
28/// potentially including
29///  - Early aborts, timeouts
30///  - Multithreading
31///  - Logging
32///  - ...
33/// 
34/// As a user, this trait should currently be used by passing either [`LogProgress`]
35/// or [`DontObserve`] to algorithms.
36/// 
37/// Also, note that all `description` parameters passed to computation controller
38/// functions are only for logging/debugging purposes only. There is no specified format,
39/// nor any stability guarantees on those messages.
40/// 
41/// # Example
42/// 
43/// Which features of a [`ComputationController`] an algorithm supports is completely up
44/// to the algorithm. Elliptic Curve factorization currently supports logging, abortion
45/// and multithreading.
46/// ```rust
47/// # use feanor_math::ring::*;
48/// # use feanor_math::algorithms::ec_factor::*;
49/// # use feanor_math::rings::zn::*;
50/// # use feanor_math::computation::*;
51/// let ring = zn_64::Zn::new(8591966237);
52/// // factors 8591966237 while printing progress
53/// let factor = lenstra_ec_factor(ring, LOG_PROGRESS).unwrap_or_else(no_error);
54/// assert!(8591966237 % factor == 0);
55/// // factor it again, but don't print progress
56/// let factor = lenstra_ec_factor(ring, DontObserve).unwrap_or_else(no_error);
57/// assert!(8591966237 % factor == 0);
58/// ```
59/// If the multithreading with rayon is enabled, we can also do
60/// ```rust
61/// # use feanor_math::ring::*;
62/// # use feanor_math::algorithms::ec_factor::*;
63/// # use feanor_math::rings::zn::*;
64/// # use feanor_math::computation::*;
65/// # let ring = zn_64::Zn::new(8591966237);
66/// // factors 8591966237 using multiple threads
67/// let factor = lenstra_ec_factor(ring, RunMultithreadedLogProgress).unwrap_or_else(no_error);
68/// assert!(8591966237 % factor == 0);
69/// ```
70///
71pub trait ComputationController: Clone + UnstableSealed {
72
73    type Abort: Send;
74
75    ///
76    /// Called by algorithms in (more or less) regular time intervals, can provide
77    /// e.g. early aborts or tracking progress.
78    /// 
79    #[stability::unstable(feature = "enable")]
80    fn checkpoint(&self, _description: Arguments) -> Result<(), Self::Abort> { 
81        Ok(())
82    }
83    
84    ///
85    /// Runs the given closure with a clone of this iterator, possibly adding a log
86    /// message before and/or after the computation starts/finishes.
87    /// 
88    /// I am currently not completely sure what the right behavior is when this
89    /// function is called multiple times (possibly nested) for clones of the current
90    /// controller. We should certainly support nesting of computations, but what should
91    /// happen in multithreaded scenarios, if we have clones of controllers, or multiple
92    /// different controllers?
93    /// 
94    #[stability::unstable(feature = "enable")]
95    fn run_computation<F, T>(self, _description: Arguments, computation: F) -> T
96        where F: FnOnce(Self) -> T
97    {
98        computation(self)
99    }
100
101    #[stability::unstable(feature = "enable")]
102    fn log(&self, _description: Arguments) {}
103
104    ///
105    /// Inspired by Rayon, and behaves the same as `join()` there.
106    /// Concretely, this function runs both closures, possibly in parallel, and
107    /// returns their results.
108    /// 
109    /// 
110    #[stability::unstable(feature = "enable")]
111    fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
112        where
113            A: FnOnce(Self) -> RA + Send,
114            B: FnOnce(Self) -> RB + Send,
115            RA: Send,
116            RB: Send
117    {
118        (oper_a(self.clone()), oper_b(self.clone()))
119    }
120}
121
///
/// The reason why a (part of a) short-circuiting computation was aborted.
///
/// `Finished` means that the computation was aborted, since another part already
/// found a result or aborted. `Abort(e)` means that the controller chose to abort
/// the computation at a checkpoint, with data `e`.
///
pub enum ShortCircuitingComputationAbort<E> {
    /// Another part of the computation already found a result or aborted.
    Finished,
    /// The controller chose to abort at a checkpoint; contains the abort data.
    Abort(E)
}
133
///
/// Shared data of a short-circuiting computation.
///
/// Tasks access this data through a [`ShortCircuitingComputationHandle`], which is
/// created by [`ShortCircuitingComputation::handle()`]. The final outcome is retrieved
/// by [`ShortCircuitingComputation::finish()`].
///
pub struct ShortCircuitingComputation<T, Controller>
    where T: Send,
        Controller: ComputationController
{
    // set to `true` as soon as some task has produced a result or was aborted by the controller
    finished: AtomicBool,
    // abort data of a task that was aborted by the controller, if any
    abort: AtomicOptionBox<Controller::Abort>,
    // result of a task that found one, if any
    result: AtomicOptionBox<T>,
}
145
///
/// Handle to a short-circuiting computation.
///
/// A handle bundles a controller with a reference to the shared state of the
/// computation, stored in a [`ShortCircuitingComputation`].
///
pub struct ShortCircuitingComputationHandle<'a, T, Controller>
    where T: Send,
        Controller: ComputationController
{
    // the controller that checkpoints, log messages and joins are forwarded to
    controller: Controller,
    // the shared state of the computation this handle belongs to
    executor: &'a ShortCircuitingComputation<T, Controller>
}
156
157impl<'a, T, Controller> Clone for ShortCircuitingComputationHandle<'a, T, Controller>
158    where T: Send,
159        Controller: ComputationController
160{
161    fn clone(&self) -> Self {
162        Self {
163            controller: self.controller.clone(),
164            executor: self.executor
165        }
166    }
167}
168
169impl<'a, T, Controller> ShortCircuitingComputationHandle<'a, T, Controller>
170    where T: Send,
171        Controller: ComputationController
172{
173    #[stability::unstable(feature = "enable")]
174    pub fn controller(&self) -> &Controller {
175        &self.controller
176    }
177
178    #[stability::unstable(feature = "enable")]
179    pub fn checkpoint(&self, description: Arguments) -> Result<(), ShortCircuitingComputationAbort<Controller::Abort>> { 
180        if self.executor.finished.load(Ordering::Relaxed) {
181            return Err(ShortCircuitingComputationAbort::Finished);
182        } else if let Err(e) = self.controller.checkpoint(description) {
183            return Err(ShortCircuitingComputationAbort::Abort(e));
184        } else {
185            return Ok(());
186        }
187    }
188
189    #[stability::unstable(feature = "enable")]
190    pub fn log(&self, description: Arguments) {
191        self.controller.log(description)
192    }
193
194    #[stability::unstable(feature = "enable")]
195    pub fn join_many<V, F>(self, operations: V)
196        where V: VectorFn<F> + Sync,
197            F: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>
198    {
199        fn join_many_internal<'a, T, V, F, Controller>(controller: Controller, executor: &'a ShortCircuitingComputation<T, Controller>, tasks: &V, from: usize, to: usize, batch_tasks: usize)
200            where T: Send,
201                Controller: ComputationController,
202                V: VectorFn<F> + Sync,
203                F: FnOnce(ShortCircuitingComputationHandle<'a, T, Controller>) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>
204        {
205            if executor.finished.load(Ordering::Relaxed) {
206                return;
207            } else if from == to {
208                return;
209            } else if from + batch_tasks >= to {
210                for i in from..to {
211                    match tasks.at(i)(ShortCircuitingComputationHandle {
212                        controller: controller.clone(),
213                        executor: executor
214                    }) {
215                        Ok(Some(result)) => {
216                            executor.finished.store(true, Ordering::Relaxed);
217                            executor.result.store(Some(Box::new(result)), Ordering::AcqRel);
218                        },
219                        Err(ShortCircuitingComputationAbort::Abort(abort)) => {
220                            executor.finished.store(true, Ordering::Relaxed);
221                            executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
222                        },
223                        Err(ShortCircuitingComputationAbort::Finished) | Ok(None) => {}
224                    }
225                }
226            } else {
227                let mid = (from + to) / 2;
228                controller.join(move |controller| join_many_internal(controller, executor, tasks, from, mid, batch_tasks), move |controller| join_many_internal(controller, executor, tasks, mid, to, batch_tasks));
229            }
230        }
231        join_many_internal(self.controller, self.executor, &operations, 0, operations.len(), 1)
232    }
233
234    #[stability::unstable(feature = "enable")]
235    pub fn join<A, B>(self, oper_a: A, oper_b: B)
236        where
237            A: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send,
238            B: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send
239    {
240        let success_fn = |value: T| {
241            self.executor.finished.store(true, Ordering::Relaxed);
242            self.executor.result.store(Some(Box::new(value)), Ordering::AcqRel);
243        };
244        let abort_fn = |abort: Controller::Abort| {
245            self.executor.finished.store(true, Ordering::Relaxed);
246            self.executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
247        };
248        self.controller.join(
249            |controller| {
250                if self.executor.finished.load(Ordering::Relaxed) {
251                    return;
252                }
253                match oper_a(ShortCircuitingComputationHandle {
254                    controller,
255                    executor: self.executor
256                }) {
257                    Ok(Some(result)) => success_fn(result),
258                    Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
259                    Err(ShortCircuitingComputationAbort::Finished) => {},
260                    Ok(None) => {}
261                }
262            },
263            |controller| {
264                if self.executor.finished.load(Ordering::Relaxed) {
265                    return;
266                }
267                match oper_b(ShortCircuitingComputationHandle {
268                    controller,
269                    executor: self.executor
270                }) {
271                    Ok(Some(result)) => success_fn(result),
272                    Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
273                    Err(ShortCircuitingComputationAbort::Finished) => {},
274                    Ok(None) => {}
275                }
276            }
277        );
278    }
279}
280
281impl<T, Controller> ShortCircuitingComputation<T, Controller>
282    where T: Send,
283        Controller: ComputationController
284{
285    #[stability::unstable(feature = "enable")]
286    pub fn new() -> Self {
287        Self {
288            finished: AtomicBool::new(false),
289            abort: AtomicOptionBox::none(),
290            result: AtomicOptionBox::none()
291        }
292    }
293
294    #[stability::unstable(feature = "enable")]
295    pub fn handle<'a>(&'a self, controller: Controller) -> ShortCircuitingComputationHandle<'a, T, Controller> {
296        ShortCircuitingComputationHandle {
297            controller: controller,
298            executor: self
299        }
300    }
301
302    #[stability::unstable(feature = "enable")]
303    pub fn finish(self) -> Result<Option<T>, Controller::Abort> {
304        if let Some(abort) = self.abort.swap(None, Ordering::AcqRel) {
305            return Err(*abort);
306        } else if let Some(result) = self.result.swap(None, Ordering::AcqRel) {
307            return Ok(Some(*result));
308        } else {
309            return Ok(None);
310        }
311    }
312}
313
///
/// Calls `checkpoint()` on the given controller and propagates an abort to the
/// caller using `?`.
///
/// Everything after the controller is passed to `format_args!()` to build the
/// description. Without further arguments, the description is empty.
///
#[macro_export]
macro_rules! checkpoint {
    ($ctrl:expr, $($fmt:tt)*) => (
        ($ctrl).checkpoint(std::format_args!($($fmt)*))?
    );
    ($ctrl:expr) => (
        ($ctrl).checkpoint(std::format_args!(""))?
    );
}
323
///
/// Calls `log()` on the given controller. Everything after the controller is
/// passed to `format_args!()` to build the message.
///
#[macro_export]
macro_rules! log_progress {
    ($ctrl:expr, $($fmt:tt)*) => (
        ($ctrl).log(std::format_args!($($fmt)*))
    );
}
330
///
/// Controller that prints all log and checkpoint messages to stdout,
/// and never aborts a computation.
///
/// Use [`LOG_PROGRESS`] to get an instance.
///
#[derive(Clone, Copy, Debug)]
pub struct LogProgress {
    // `true` for the controller that `run_computation()` passes to its closure; in
    // that case, the final "done." message is printed without a trailing newline
    inner_comp: bool
}

///
/// The [`LogProgress`] controller for a computation that is not nested
/// inside another `run_computation()`.
///
pub const LOG_PROGRESS: LogProgress = LogProgress { inner_comp: false };

///
/// Use this in tests, to distinguish it from temporary uses of
/// `LOG_PROGRESS` that shouldn't be used when publishing the crate.
///
#[cfg(test)]
pub(crate) const TEST_LOG_PROGRESS: LogProgress = LogProgress { inner_comp: false };

impl UnstableSealed for LogProgress {}
346
347impl ComputationController for LogProgress {
348
349    type Abort = !;
350
351    #[stability::unstable(feature = "enable")]
352    fn log(&self, description: Arguments) {
353        print!("{}", description);
354        std::io::stdout().flush().unwrap();
355    }
356
357    #[stability::unstable(feature = "enable")]
358    fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
359        where F: FnOnce(Self) -> T
360    {
361        self.log(description);
362        let result = computation(Self { inner_comp: true });
363        if self.inner_comp {
364            self.log(format_args!("done."));
365        } else {
366            self.log(format_args!("done.\n"));
367        }
368        return result;
369    }
370
371    #[stability::unstable(feature = "enable")]
372    fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
373        self.log(description);
374        Ok(())
375    }
376}
377
///
/// Controller that uses the default implementation of every function of
/// [`ComputationController`], i.e. it logs nothing and never aborts a computation.
///
#[derive(Clone, Copy, Debug)]
pub struct DontObserve;

impl UnstableSealed for DontObserve {}

impl ComputationController for DontObserve {

    // `!` has no values, so a checkpoint of this controller can never return an abort
    type Abort = !;
}
387
#[cfg(feature = "parallel")]
mod parallel_controller {

    use super::*;

    ///
    /// Controller that runs the two tasks of [`ComputationController::join()`] using
    /// `rayon::join()`, and forwards checkpoints, log messages and `run_computation()`
    /// to the wrapped controller `Rest`.
    ///
    #[stability::unstable(feature = "enable")]
    pub struct ExecuteMultithreaded<Rest: ComputationController + Send> {
        rest: Rest
    }

    impl<Rest: ComputationController + Send + Copy> Copy for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> Clone for ExecuteMultithreaded<Rest> {
        fn clone(&self) -> Self {
            Self { rest: self.rest.clone() }
        }
    }

    impl<Rest: ComputationController + Send> UnstableSealed for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> ComputationController for ExecuteMultithreaded<Rest> {
        type Abort = Rest::Abort;

        ///
        /// Forwards the checkpoint to the wrapped controller.
        ///
        #[stability::unstable(feature = "enable")]
        fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
            self.rest.checkpoint(description)
        }

        ///
        /// Lets the wrapped controller run the computation, and wraps the controller
        /// it passes to the closure again in an [`ExecuteMultithreaded`].
        ///
        #[stability::unstable(feature = "enable")]
        fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
            where F: FnOnce(Self) -> T
        {
            self.rest.run_computation(description, |rest| computation(ExecuteMultithreaded { rest }))
        }

        ///
        /// Forwards the log message to the wrapped controller.
        ///
        /// Without this, the default implementation would silently discard all
        /// messages, even if the wrapped controller is a logging one.
        ///
        #[stability::unstable(feature = "enable")]
        fn log(&self, description: Arguments) {
            self.rest.log(description)
        }

        ///
        /// Runs both tasks using `rayon::join()`, each with its own copy
        /// of this controller.
        ///
        #[stability::unstable(feature = "enable")]
        fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
            where
                A: FnOnce(Self) -> RA + Send,
                B: FnOnce(Self) -> RB + Send,
                RA: Send,
                RB: Send
        {
            let self1 = self.clone();
            let self2 = self;
            rayon::join(|| oper_a(self1), || oper_b(self2))
        }
    }

    ///
    /// Controller that runs joined tasks via rayon and logs progress
    /// using [`LOG_PROGRESS`].
    ///
    #[stability::unstable(feature = "enable")]
    #[allow(non_upper_case_globals)]
    pub static RunMultithreadedLogProgress: ExecuteMultithreaded<LogProgress> = ExecuteMultithreaded { rest: LOG_PROGRESS };
    ///
    /// Controller that runs joined tasks via rayon and, like [`DontObserve`],
    /// does not log anything.
    ///
    #[stability::unstable(feature = "enable")]
    #[allow(non_upper_case_globals)]
    pub static RunMultithreaded: ExecuteMultithreaded<DontObserve> = ExecuteMultithreaded { rest: DontObserve };
}
444
445#[cfg(feature = "parallel")]
446pub use parallel_controller::*;