Skip to main content

feanor_math/
computation.rs

1use std::fmt::Arguments;
2use std::io::Write;
3use std::sync::atomic::{AtomicBool, Ordering};
4
5use atomicbox::AtomicOptionBox;
6
7use crate::seq::VectorFn;
8use crate::unstable_sealed::UnstableSealed;
9
10/// Provides an idiomatic way to convert a `Result<T, !>` into `T`, via
11/// ```rust
12/// # #![feature(never_type)]
13/// # use feanor_math::computation::*;
14/// fn some_computation() -> Result<&'static str, !> { Ok("this computation does not fail") }
15/// println!("{}", some_computation().unwrap_or_else(no_error));
16/// ```
17pub fn no_error<T>(error: !) -> T { error }
18
19/// Trait for objects that observe and control a potentially long-running computation.
20///
21/// The idea is that this trait defines multiple functions that can be called during an
22/// algorithm, and provide certain functionality. This way, each algorithm can decide which
23/// functionality is relevant and how it is used.
24///
25/// This is currently unstable-sealed, since I expect significant additional functionality,
26/// potentially including
27///  - Early aborts, timeouts
28///  - Multithreading
29///  - Logging
30///  - ...
31///
32/// As a user, this trait should currently be used by passing either [`LogProgress`]
33/// or [`DontObserve`] to algorithms.
34///
35/// Also, note that all `description` parameters passed to computation controller
36/// functions are only for logging/debugging purposes only. There is no specified format,
37/// nor any stability guarantees on those messages.
38///
39/// # Example
40///
41/// Which features of a [`ComputationController`] an algorithm supports is completely up
42/// to the algorithm. Elliptic Curve factorization currently supports logging, abortion
43/// and multithreading.
44/// ```rust
45/// # use feanor_math::ring::*;
46/// # use feanor_math::algorithms::ec_factor::*;
47/// # use feanor_math::rings::zn::*;
48/// # use feanor_math::computation::*;
49/// let ring = zn_64::Zn::new(8591966237);
50/// // factors 8591966237 while printing progress
51/// let factor = lenstra_ec_factor(ring, LOG_PROGRESS).unwrap_or_else(no_error);
52/// assert!(8591966237 % factor == 0);
53/// // factor it again, but don't print progress
54/// let factor = lenstra_ec_factor(ring, DontObserve).unwrap_or_else(no_error);
55/// assert!(8591966237 % factor == 0);
56/// ```
57/// If the multithreading with rayon is enabled, we can also do
58/// ```rust
59/// # use feanor_math::ring::*;
60/// # use feanor_math::algorithms::ec_factor::*;
61/// # use feanor_math::rings::zn::*;
62/// # use feanor_math::computation::*;
63/// # let ring = zn_64::Zn::new(8591966237);
64/// // factors 8591966237 using multiple threads
65/// let factor = lenstra_ec_factor(ring, RunMultithreadedLogProgress).unwrap_or_else(no_error);
66/// assert!(8591966237 % factor == 0);
67/// ```
68pub trait ComputationController: Clone + UnstableSealed {
69    type Abort: Send;
70
71    /// Called by algorithms in (more or less) regular time intervals, can provide
72    /// e.g. early aborts or tracking progress.
73    #[stability::unstable(feature = "enable")]
74    fn checkpoint(&self, _description: Arguments) -> Result<(), Self::Abort> { Ok(()) }
75
76    /// Runs the given closure with a clone of this iterator, possibly adding a log
77    /// message before and/or after the computation starts/finishes.
78    ///
79    /// I am currently not completely sure what the right behavior is when this
80    /// function is called multiple times (possibly nested) for clones of the current
81    /// controller. We should certainly support nesting of computations, but what should
82    /// happen in multithreaded scenarios, if we have clones of controllers, or multiple
83    /// different controllers?
84    #[stability::unstable(feature = "enable")]
85    fn run_computation<F, T>(self, _description: Arguments, computation: F) -> T
86    where
87        F: FnOnce(Self) -> T,
88    {
89        computation(self)
90    }
91
92    #[stability::unstable(feature = "enable")]
93    fn log(&self, _description: Arguments) {}
94
95    /// Inspired by Rayon, and behaves the same as `join()` there.
96    /// Concretely, this function runs both closures, possibly in parallel, and
97    /// returns their results.
98    #[stability::unstable(feature = "enable")]
99    fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
100    where
101        A: FnOnce(Self) -> RA + Send,
102        B: FnOnce(Self) -> RB + Send,
103        RA: Send,
104        RB: Send,
105    {
106        (oper_a(self.clone()), oper_b(self.clone()))
107    }
108}
109
/// The reason why a (part of a) short-circuiting computation was aborted.
///
/// `Finished` means that the computation was aborted because another part already
/// found a result or aborted. `Abort(e)` means that the controller chose to abort
/// the computation at a checkpoint, with data `e`.
///
/// The derived trait impls only apply when `E` implements the respective trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShortCircuitingComputationAbort<E> {
    Finished,
    Abort(E),
}
119
120/// Shared data of a short-circuiting computation.
121pub struct ShortCircuitingComputation<T, Controller>
122where
123    T: Send,
124    Controller: ComputationController,
125{
126    finished: AtomicBool,
127    abort: AtomicOptionBox<Controller::Abort>,
128    result: AtomicOptionBox<T>,
129}
130
131/// Handle to a short-circuiting computation.
132pub struct ShortCircuitingComputationHandle<'a, T, Controller>
133where
134    T: Send,
135    Controller: ComputationController,
136{
137    controller: Controller,
138    executor: &'a ShortCircuitingComputation<T, Controller>,
139}
140
141impl<'a, T, Controller> Clone for ShortCircuitingComputationHandle<'a, T, Controller>
142where
143    T: Send,
144    Controller: ComputationController,
145{
146    fn clone(&self) -> Self {
147        Self {
148            controller: self.controller.clone(),
149            executor: self.executor,
150        }
151    }
152}
153
154impl<'a, T, Controller> ShortCircuitingComputationHandle<'a, T, Controller>
155where
156    T: Send,
157    Controller: ComputationController,
158{
159    #[stability::unstable(feature = "enable")]
160    pub fn controller(&self) -> &Controller { &self.controller }
161
162    #[stability::unstable(feature = "enable")]
163    pub fn checkpoint(&self, description: Arguments) -> Result<(), ShortCircuitingComputationAbort<Controller::Abort>> {
164        if self.executor.finished.load(Ordering::Relaxed) {
165            return Err(ShortCircuitingComputationAbort::Finished);
166        } else if let Err(e) = self.controller.checkpoint(description) {
167            return Err(ShortCircuitingComputationAbort::Abort(e));
168        } else {
169            return Ok(());
170        }
171    }
172
173    #[stability::unstable(feature = "enable")]
174    pub fn log(&self, description: Arguments) { self.controller.log(description) }
175
176    #[stability::unstable(feature = "enable")]
177    pub fn join_many<V, F>(self, operations: V)
178    where
179        V: VectorFn<F> + Sync,
180        F: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>,
181    {
182        fn join_many_internal<'a, T, V, F, Controller>(
183            controller: Controller,
184            executor: &'a ShortCircuitingComputation<T, Controller>,
185            tasks: &V,
186            from: usize,
187            to: usize,
188            batch_tasks: usize,
189        ) where
190            T: Send,
191            Controller: ComputationController,
192            V: VectorFn<F> + Sync,
193            F: FnOnce(
194                ShortCircuitingComputationHandle<'a, T, Controller>,
195            ) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>,
196        {
197            if executor.finished.load(Ordering::Relaxed) {
198                return;
199            } else if from == to {
200                return;
201            } else if from + batch_tasks >= to {
202                for i in from..to {
203                    match tasks.at(i)(ShortCircuitingComputationHandle {
204                        controller: controller.clone(),
205                        executor,
206                    }) {
207                        Ok(Some(result)) => {
208                            executor.finished.store(true, Ordering::Relaxed);
209                            executor.result.store(Some(Box::new(result)), Ordering::AcqRel);
210                        }
211                        Err(ShortCircuitingComputationAbort::Abort(abort)) => {
212                            executor.finished.store(true, Ordering::Relaxed);
213                            executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
214                        }
215                        Err(ShortCircuitingComputationAbort::Finished) | Ok(None) => {}
216                    }
217                }
218            } else {
219                let mid = (from + to) / 2;
220                _ = controller.join(
221                    move |controller| join_many_internal(controller, executor, tasks, from, mid, batch_tasks),
222                    move |controller| join_many_internal(controller, executor, tasks, mid, to, batch_tasks),
223                );
224            }
225        }
226        join_many_internal(self.controller, self.executor, &operations, 0, operations.len(), 1)
227    }
228
229    #[stability::unstable(feature = "enable")]
230    pub fn join<A, B>(self, oper_a: A, oper_b: B)
231    where
232        A: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send,
233        B: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send,
234    {
235        let success_fn = |value: T| {
236            self.executor.finished.store(true, Ordering::Relaxed);
237            self.executor.result.store(Some(Box::new(value)), Ordering::AcqRel);
238        };
239        let abort_fn = |abort: Controller::Abort| {
240            self.executor.finished.store(true, Ordering::Relaxed);
241            self.executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
242        };
243        _ = self.controller.join(
244            |controller| {
245                if self.executor.finished.load(Ordering::Relaxed) {
246                    return;
247                }
248                match oper_a(ShortCircuitingComputationHandle {
249                    controller,
250                    executor: self.executor,
251                }) {
252                    Ok(Some(result)) => success_fn(result),
253                    Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
254                    Err(ShortCircuitingComputationAbort::Finished) => {}
255                    Ok(None) => {}
256                }
257            },
258            |controller| {
259                if self.executor.finished.load(Ordering::Relaxed) {
260                    return;
261                }
262                match oper_b(ShortCircuitingComputationHandle {
263                    controller,
264                    executor: self.executor,
265                }) {
266                    Ok(Some(result)) => success_fn(result),
267                    Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
268                    Err(ShortCircuitingComputationAbort::Finished) => {}
269                    Ok(None) => {}
270                }
271            },
272        );
273    }
274}
275
276impl<T, Controller> ShortCircuitingComputation<T, Controller>
277where
278    T: Send,
279    Controller: ComputationController,
280{
281    #[stability::unstable(feature = "enable")]
282    pub fn new() -> Self {
283        Self {
284            finished: AtomicBool::new(false),
285            abort: AtomicOptionBox::none(),
286            result: AtomicOptionBox::none(),
287        }
288    }
289
290    #[stability::unstable(feature = "enable")]
291    pub fn handle<'a>(&'a self, controller: Controller) -> ShortCircuitingComputationHandle<'a, T, Controller> {
292        ShortCircuitingComputationHandle {
293            controller,
294            executor: self,
295        }
296    }
297
298    #[stability::unstable(feature = "enable")]
299    pub fn finish(self) -> Result<Option<T>, Controller::Abort> {
300        if let Some(abort) = self.abort.swap(None, Ordering::AcqRel) {
301            return Err(*abort);
302        } else if let Some(result) = self.result.swap(None, Ordering::AcqRel) {
303            return Ok(Some(*result));
304        } else {
305            return Ok(None);
306        }
307    }
308}
309
/// Calls `checkpoint()` on the given controller (or short-circuiting handle) with a
/// message built by `std::format_args!`, and propagates an abort with `?`.
///
/// When only the controller is given, the message is empty.
#[macro_export]
macro_rules! checkpoint {
    ($controller:expr) => {
        $crate::checkpoint!($controller, "")
    };
    ($controller:expr, $($fmt:tt)*) => {
        ($controller).checkpoint(std::format_args!($($fmt)*))?
    };
}
319
/// Passes a message, built by `std::format_args!`, to the `log()` method of the given
/// controller (or short-circuiting handle).
#[macro_export]
macro_rules! log_progress {
    ($ctrl:expr, $($message:tt)*) => {
        ($ctrl).log(std::format_args!($($message)*))
    };
}
326
/// Controller that prints progress messages to stdout and never aborts.
///
/// Obtain an instance via [`LOG_PROGRESS`].
#[derive(Debug, Clone, Copy)]
pub struct LogProgress {
    /// Whether this controller was handed to a nested computation by `run_computation()`;
    /// this only decides whether the final "done." is followed by a newline.
    inner_comp: bool,
}

/// Top-level [`LogProgress`] controller; pass it to algorithms whose progress should be
/// printed.
pub const LOG_PROGRESS: LogProgress = LogProgress { inner_comp: false };

/// Use this in tests, to distinguish it from temporary uses of
/// `LOG_PROGRESS` that shouldn't be used when publishing the crate.
#[cfg(test)]
pub(crate) const TEST_LOG_PROGRESS: LogProgress = LOG_PROGRESS;
338
339impl UnstableSealed for LogProgress {}
340
341impl ComputationController for LogProgress {
342    type Abort = !;
343
344    #[stability::unstable(feature = "enable")]
345    fn log(&self, description: Arguments) {
346        print!("{}", description);
347        std::io::stdout().flush().unwrap();
348    }
349
350    #[stability::unstable(feature = "enable")]
351    fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
352    where
353        F: FnOnce(Self) -> T,
354    {
355        self.log(description);
356        let result = computation(Self { inner_comp: true });
357        if self.inner_comp {
358            self.log(format_args!("done."));
359        } else {
360            self.log(format_args!("done.\n"));
361        }
362        return result;
363    }
364
365    #[stability::unstable(feature = "enable")]
366    fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
367        self.log(description);
368        Ok(())
369    }
370}
371
/// Controller that does not observe the computation at all.
///
/// It relies on the default methods of [`ComputationController`]: log messages are
/// discarded, checkpoints never abort, and `join()` runs both operations sequentially.
#[derive(Debug, Clone, Copy)]
pub struct DontObserve;
374
375impl UnstableSealed for DontObserve {}
376
377impl ComputationController for DontObserve {
378    type Abort = !;
379}
380
#[cfg(feature = "parallel")]
mod parallel_controller {

    use super::*;

    /// Controller that runs the two operations of [`ComputationController::join()`] in
    /// parallel via `rayon::join()`, and delegates all other functionality (checkpoints,
    /// logging, `run_computation()`) to the wrapped controller `rest`.
    #[stability::unstable(feature = "enable")]
    pub struct ExecuteMultithreaded<Rest: ComputationController + Send> {
        rest: Rest,
    }

    impl<Rest: ComputationController + Send + Copy> Copy for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> Clone for ExecuteMultithreaded<Rest> {
        fn clone(&self) -> Self {
            Self {
                rest: self.rest.clone(),
            }
        }
    }

    impl<Rest: ComputationController + Send> UnstableSealed for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> ComputationController for ExecuteMultithreaded<Rest> {
        type Abort = Rest::Abort;

        #[stability::unstable(feature = "enable")]
        fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> { self.rest.checkpoint(description) }

        // Must be forwarded explicitly: otherwise the trait's no-op default `log()` applies
        // and all messages (e.g. for `RunMultithreadedLogProgress`) are silently dropped.
        #[stability::unstable(feature = "enable")]
        fn log(&self, description: Arguments) { self.rest.log(description) }

        #[stability::unstable(feature = "enable")]
        fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
        where
            F: FnOnce(Self) -> T,
        {
            self.rest
                .run_computation(description, |rest| computation(ExecuteMultithreaded { rest }))
        }

        /// Runs both operations via `rayon::join()`, each with its own copy of the controller.
        #[stability::unstable(feature = "enable")]
        fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
        where
            A: FnOnce(Self) -> RA + Send,
            B: FnOnce(Self) -> RB + Send,
            RA: Send,
            RB: Send,
        {
            let self1 = self.clone();
            let self2 = self;
            rayon::join(|| oper_a(self1), || oper_b(self2))
        }
    }

    /// Multithreaded controller that prints progress like [`LOG_PROGRESS`].
    #[stability::unstable(feature = "enable")]
    // NOTE(review): CamelCase name is part of the public API, so it is kept as is.
    #[allow(non_upper_case_globals)]
    pub static RunMultithreadedLogProgress: ExecuteMultithreaded<LogProgress> =
        ExecuteMultithreaded { rest: LOG_PROGRESS };

    /// Multithreaded controller that otherwise does not observe the computation, like
    /// [`DontObserve`].
    #[stability::unstable(feature = "enable")]
    // NOTE(review): CamelCase name is part of the public API, so it is kept as is.
    #[allow(non_upper_case_globals)]
    pub static RunMultithreaded: ExecuteMultithreaded<DontObserve> = ExecuteMultithreaded { rest: DontObserve };
}
440
441#[cfg(feature = "parallel")]
442pub use parallel_controller::*;