Skip to main content

rand/distributions/
utils.rs

1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Math helper functions
10
11
/// Full-width ("widening") multiplication.
///
/// Implementations return the product of `self` and the right-hand operand
/// without discarding any bits; the implementations in this module report it
/// as a `(high, low)` pair of words.
pub(crate) trait WideningMultiply<RHS = Self> {
    /// Type holding the full product.
    type Output;

    /// Multiplies `self` by `rhs`, keeping every bit of the result.
    fn wmul(self, rhs: RHS) -> Self::Output;
}
17
18macro_rules! wmul_impl {
19    ($ty:ty, $wide:ty, $shift:expr) => {
20        impl WideningMultiply for $ty {
21            type Output = ($ty, $ty);
22
23            #[inline(always)]
24            fn wmul(self, x: $ty) -> Self::Output {
25                let tmp = (self as $wide) * (x as $wide);
26                ((tmp >> $shift) as $ty, tmp as $ty)
27            }
28        }
29    };
30
31    // simd bulk implementation
32    ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
33        $(
34            impl WideningMultiply for $ty {
35                type Output = ($ty, $ty);
36
37                #[inline(always)]
38                fn wmul(self, x: $ty) -> Self::Output {
39                    // For supported vectors, this should compile to a couple
40                    // supported multiply & swizzle instructions (no actual
41                    // casting).
42                    // TODO: optimize
43                    let y: $wide = self.cast();
44                    let x: $wide = x.cast();
45                    let tmp = y * x;
46                    let hi: $ty = (tmp >> $shift).cast();
47                    let lo: $ty = tmp.cast();
48                    (hi, lo)
49                }
50            }
51        )+
52    };
53}
54wmul_impl! { u8, u16, 8 }
55wmul_impl! { u16, u32, 16 }
56wmul_impl! { u32, u64, 32 }
57wmul_impl! { u64, u128, 64 }
58
59// This code is a translation of the __mulddi3 function in LLVM's
60// compiler-rt. It is an optimised variant of the common method
61// `(a + b) * (c + d) = ac + ad + bc + bd`.
62//
63// For some reason LLVM can optimise the C version very well, but
64// keeps shuffling registers in this Rust translation.
65macro_rules! wmul_impl_large {
66    ($ty:ty, $half:expr) => {
67        impl WideningMultiply for $ty {
68            type Output = ($ty, $ty);
69
70            #[inline(always)]
71            fn wmul(self, b: $ty) -> Self::Output {
72                const LOWER_MASK: $ty = !0 >> $half;
73                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
74                let mut t = low >> $half;
75                low &= LOWER_MASK;
76                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
77                low += (t & LOWER_MASK) << $half;
78                let mut high = t >> $half;
79                t = low >> $half;
80                low &= LOWER_MASK;
81                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
82                low += (t & LOWER_MASK) << $half;
83                high += t >> $half;
84                high += (self >> $half).wrapping_mul(b >> $half);
85
86                (high, low)
87            }
88        }
89    };
90
91    // simd bulk implementation
92    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
93        $(
94            impl WideningMultiply for $ty {
95                type Output = ($ty, $ty);
96
97                #[inline(always)]
98                fn wmul(self, b: $ty) -> Self::Output {
99                    // needs wrapping multiplication
100                    const LOWER_MASK: $scalar = !0 >> $half;
101                    let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
102                    let mut t = low >> $half;
103                    low &= LOWER_MASK;
104                    t += (self >> $half) * (b & LOWER_MASK);
105                    low += (t & LOWER_MASK) << $half;
106                    let mut high = t >> $half;
107                    t = low >> $half;
108                    low &= LOWER_MASK;
109                    t += (b >> $half) * (self & LOWER_MASK);
110                    low += (t & LOWER_MASK) << $half;
111                    high += t >> $half;
112                    high += (self >> $half) * (b >> $half);
113
114                    (high, low)
115                }
116            }
117        )+
118    };
119}
120wmul_impl_large! { u128, 64 }
121
122macro_rules! wmul_impl_usize {
123    ($ty:ty) => {
124        impl WideningMultiply for usize {
125            type Output = (usize, usize);
126
127            #[inline(always)]
128            fn wmul(self, x: usize) -> Self::Output {
129                let (high, low) = (self as $ty).wmul(x as $ty);
130                (high as usize, low as usize)
131            }
132        }
133    };
134}
135#[cfg(target_pointer_width = "16")]
136wmul_impl_usize! { u16 }
137#[cfg(target_pointer_width = "32")]
138wmul_impl_usize! { u32 }
139#[cfg(target_pointer_width = "64")]
140wmul_impl_usize! { u64 }
141
/// Lane-wise helpers shared by scalar floats and SIMD float vectors.
///
/// `PartialOrd` on vectors compares lexicographically; these helpers instead
/// compare each lane individually and combine the per-lane results (what
/// `a.lt(b).all()` would express). Putting them in a trait lets one piece of
/// generic code serve `f32`, `f64` and vector types alike. Only the operations
/// that are actually needed exist here.
pub(crate) trait FloatSIMDUtils {
    /// Per-lane boolean mask produced by the `*_mask` methods.
    type Mask;

    /// Integer type accepted by [`FloatSIMDUtils::cast_from_int`].
    type UInt;

    /// `true` when every lane of `self` is strictly below the matching lane of `other`.
    fn all_lt(self, other: Self) -> bool;

    /// `true` when every lane of `self` is at most the matching lane of `other`.
    fn all_le(self, other: Self) -> bool;

    /// `true` when every lane is finite.
    fn all_finite(self) -> bool;

    /// Mask of the lanes that are finite.
    fn finite_mask(self) -> Self::Mask;

    /// Mask of the lanes where `self > other`.
    fn gt_mask(self, other: Self) -> Self::Mask;

    /// Mask of the lanes where `self >= other`.
    fn ge_mask(self, other: Self) -> Self::Mask;

    /// Steps every lane selected by `mask` down to the next lower value the
    /// floating-point type can represent. At least one lane must be selected.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    /// Converts an integer by numeric value, not by reinterpreting its bits.
    fn cast_from_int(i: Self::UInt) -> Self;
}
168
/// Fallback for the float classification methods that `std` builds provide but
/// that are missing from `core` primitives.
#[cfg(not(feature = "std"))]
#[allow(unused)]
// False positive: the method names and receivers deliberately mirror `std`.
#[allow(clippy::wrong_self_convention)]
pub(crate) trait Float: Sized {
    /// Returns `true` if the value is neither NaN nor infinite.
    fn is_finite(self) -> bool;

    /// Returns `true` if the value is NaN.
    fn is_nan(self) -> bool;

    /// Returns `true` if the value is positive or negative infinity.
    fn is_infinite(self) -> bool;
}
179
/// Gives scalar floats the lane-access API of SIMD vectors, so generic code can
/// treat a scalar as a vector with exactly one lane.
#[allow(unused)]
pub(crate) trait FloatAsSIMD: Sized {
    /// Broadcasts `value` to every lane; a scalar's only lane is `value` itself.
    #[inline(always)]
    fn splat(value: Self) -> Self {
        value
    }

    /// Number of lanes, which is always one for a scalar.
    #[inline(always)]
    fn lanes() -> usize {
        1
    }

    /// Returns a copy with lane `index` set to `new_value`. Only lane `0`
    /// exists (checked in debug builds), so the result is simply `new_value`.
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self {
        debug_assert_eq!(index, 0);
        new_value
    }

    /// Reads lane `index`. Only lane `0` exists (checked in debug builds).
    #[inline(always)]
    fn extract(self, index: usize) -> Self {
        debug_assert_eq!(index, 0);
        self
    }
}
202
/// Lets a plain `bool` stand in for a SIMD mask by providing the usual
/// `any`/`all`/`none` reductions.
#[allow(unused)]
pub(crate) trait BoolAsSIMD: Sized {
    /// `true` if no lane is set.
    fn none(self) -> bool;

    /// `true` if at least one lane is set.
    fn any(self) -> bool;

    /// `true` if every lane is set.
    fn all(self) -> bool;
}
209
210impl BoolAsSIMD for bool {
211    #[inline(always)]
212    fn any(self) -> bool {
213        self
214    }
215
216    #[inline(always)]
217    fn all(self) -> bool {
218        self
219    }
220
221    #[inline(always)]
222    fn none(self) -> bool {
223        !self
224    }
225}
226
227macro_rules! scalar_float_impl {
228    ($ty:ident, $uty:ident) => {
229        #[cfg(not(feature = "std"))]
230        impl Float for $ty {
231            #[inline]
232            fn is_nan(self) -> bool {
233                self != self
234            }
235
236            #[inline]
237            fn is_infinite(self) -> bool {
238                self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
239            }
240
241            #[inline]
242            fn is_finite(self) -> bool {
243                !(self.is_nan() || self.is_infinite())
244            }
245        }
246
247        impl FloatSIMDUtils for $ty {
248            type Mask = bool;
249            type UInt = $uty;
250
251            #[inline(always)]
252            fn all_lt(self, other: Self) -> bool {
253                self < other
254            }
255
256            #[inline(always)]
257            fn all_le(self, other: Self) -> bool {
258                self <= other
259            }
260
261            #[inline(always)]
262            fn all_finite(self) -> bool {
263                self.is_finite()
264            }
265
266            #[inline(always)]
267            fn finite_mask(self) -> Self::Mask {
268                self.is_finite()
269            }
270
271            #[inline(always)]
272            fn gt_mask(self, other: Self) -> Self::Mask {
273                self > other
274            }
275
276            #[inline(always)]
277            fn ge_mask(self, other: Self) -> Self::Mask {
278                self >= other
279            }
280
281            #[inline(always)]
282            fn decrease_masked(self, mask: Self::Mask) -> Self {
283                debug_assert!(mask, "At least one lane must be set");
284                <$ty>::from_bits(self.to_bits() - 1)
285            }
286
287            #[inline]
288            fn cast_from_int(i: Self::UInt) -> Self {
289                i as $ty
290            }
291        }
292
293        impl FloatAsSIMD for $ty {}
294    };
295}
296
297scalar_float_impl!(f32, u32);
298scalar_float_impl!(f64, u64);