// dreid_kernel/math.rs

1#[cfg(not(feature = "std"))]
2use libm::{acos, acosf, cos, cosf, exp, expf, fabs, fabsf, sin, sinf, sqrt, sqrtf};
3
/// Floating-point abstraction used by the force-field kernels.
///
/// The trait is implemented for both `f32` and `f64`, so a kernel can be written
/// once against `T: Real` and instantiated at either precision. The math
/// functions are provided either by `std` or by a `no_std` math library,
/// depending on how the crate is built.
pub trait Real:
    Copy
    + Clone
    + core::fmt::Debug
    + PartialEq
    + PartialOrd
    + From<f32>
    + core::ops::Neg<Output = Self>
    + core::ops::Add<Output = Self>
    + core::ops::Sub<Output = Self>
    + core::ops::Mul<Output = Self>
    + core::ops::Div<Output = Self>
{
    /// The constant π at this type's precision.
    fn pi() -> Self;

    /// Square root of `self`.
    fn sqrt(self) -> Self;

    /// Reciprocal of `self`, i.e. `1 / self`.
    fn recip(self) -> Self;

    /// Absolute value of `self`.
    fn abs(self) -> Self;

    /// The larger of `self` and `other`.
    fn max(self, other: Self) -> Self;

    /// The smaller of `self` and `other`.
    fn min(self, other: Self) -> Self;

    /// Exponential function applied to `self`.
    fn exp(self) -> Self;

    /// Sine of `self`.
    fn sin(self) -> Self;

    /// Cosine of `self`.
    fn cos(self) -> Self;

    /// Arc cosine of `self`.
    fn acos(self) -> Self;

    /// Reciprocal square root, `1 / sqrt(self)`.
    ///
    /// The default composes [`Real::sqrt`] and [`Real::recip`]. Implementors
    /// may override it with a specialised (e.g. hardware-assisted) version.
    #[inline(always)]
    fn rsqrt(self) -> Self {
        let root = Self::sqrt(self);
        Self::recip(root)
    }
}
43
// ============================================================================
// Implementation for f32
// ============================================================================
47
48impl Real for f32 {
49    #[inline(always)]
50    fn pi() -> Self {
51        core::f32::consts::PI
52    }
53
54    #[inline(always)]
55    fn sqrt(self) -> Self {
56        #[cfg(feature = "std")]
57        {
58            self.sqrt()
59        }
60        #[cfg(not(feature = "std"))]
61        {
62            sqrtf(self)
63        }
64    }
65
66    #[inline(always)]
67    fn recip(self) -> Self {
68        self.recip()
69    }
70
71    #[inline(always)]
72    fn abs(self) -> Self {
73        #[cfg(feature = "std")]
74        {
75            self.abs()
76        }
77        #[cfg(not(feature = "std"))]
78        {
79            fabsf(self)
80        }
81    }
82
83    #[inline(always)]
84    fn max(self, other: Self) -> Self {
85        #[cfg(feature = "std")]
86        {
87            self.max(other)
88        }
89        #[cfg(not(feature = "std"))]
90        {
91            if self > other { self } else { other }
92        }
93    }
94
95    #[inline(always)]
96    fn min(self, other: Self) -> Self {
97        #[cfg(feature = "std")]
98        {
99            self.min(other)
100        }
101        #[cfg(not(feature = "std"))]
102        {
103            if self < other { self } else { other }
104        }
105    }
106
107    #[inline(always)]
108    fn exp(self) -> Self {
109        #[cfg(feature = "std")]
110        {
111            self.exp()
112        }
113        #[cfg(not(feature = "std"))]
114        {
115            expf(self)
116        }
117    }
118
119    #[inline(always)]
120    fn sin(self) -> Self {
121        #[cfg(feature = "std")]
122        {
123            self.sin()
124        }
125        #[cfg(not(feature = "std"))]
126        {
127            sinf(self)
128        }
129    }
130
131    #[inline(always)]
132    fn cos(self) -> Self {
133        #[cfg(feature = "std")]
134        {
135            self.cos()
136        }
137        #[cfg(not(feature = "std"))]
138        {
139            cosf(self)
140        }
141    }
142
143    #[inline(always)]
144    fn acos(self) -> Self {
145        #[cfg(feature = "std")]
146        {
147            self.acos()
148        }
149        #[cfg(not(feature = "std"))]
150        {
151            acosf(self)
152        }
153    }
154}
155
// ============================================================================
// Implementation for f64
// ============================================================================
159
160impl Real for f64 {
161    #[inline(always)]
162    fn pi() -> Self {
163        core::f64::consts::PI
164    }
165
166    #[inline(always)]
167    fn sqrt(self) -> Self {
168        #[cfg(feature = "std")]
169        {
170            self.sqrt()
171        }
172        #[cfg(not(feature = "std"))]
173        {
174            sqrt(self)
175        }
176    }
177
178    #[inline(always)]
179    fn recip(self) -> Self {
180        self.recip()
181    }
182
183    #[inline(always)]
184    fn abs(self) -> Self {
185        #[cfg(feature = "std")]
186        {
187            self.abs()
188        }
189        #[cfg(not(feature = "std"))]
190        {
191            fabs(self)
192        }
193    }
194
195    #[inline(always)]
196    fn max(self, other: Self) -> Self {
197        #[cfg(feature = "std")]
198        {
199            self.max(other)
200        }
201        #[cfg(not(feature = "std"))]
202        {
203            if self > other { self } else { other }
204        }
205    }
206
207    #[inline(always)]
208    fn min(self, other: Self) -> Self {
209        #[cfg(feature = "std")]
210        {
211            self.min(other)
212        }
213        #[cfg(not(feature = "std"))]
214        {
215            if self < other { self } else { other }
216        }
217    }
218
219    #[inline(always)]
220    fn exp(self) -> Self {
221        #[cfg(feature = "std")]
222        {
223            self.exp()
224        }
225        #[cfg(not(feature = "std"))]
226        {
227            exp(self)
228        }
229    }
230
231    #[inline(always)]
232    fn sin(self) -> Self {
233        #[cfg(feature = "std")]
234        {
235            self.sin()
236        }
237        #[cfg(not(feature = "std"))]
238        {
239            sin(self)
240        }
241    }
242
243    #[inline(always)]
244    fn cos(self) -> Self {
245        #[cfg(feature = "std")]
246        {
247            self.cos()
248        }
249        #[cfg(not(feature = "std"))]
250        {
251            cos(self)
252        }
253    }
254
255    #[inline(always)]
256    fn acos(self) -> Self {
257        #[cfg(feature = "std")]
258        {
259            self.acos()
260        }
261        #[cfg(not(feature = "std"))]
262        {
263            acos(self)
264        }
265    }
266}