1#[inline]
15pub fn abs(x: f64) -> f64 {
16 libm::fabs(x)
17}
18
19#[inline]
21pub fn sqrt(x: f64) -> f64 {
22 libm::sqrt(x)
23}
24
25#[inline]
27pub fn exp(x: f64) -> f64 {
28 libm::exp(x)
29}
30
31#[inline]
33pub fn ln(x: f64) -> f64 {
34 libm::log(x)
35}
36
37#[inline]
39pub fn log2(x: f64) -> f64 {
40 libm::log2(x)
41}
42
43#[inline]
45pub fn log10(x: f64) -> f64 {
46 libm::log10(x)
47}
48
49#[inline]
51pub fn powf(x: f64, n: f64) -> f64 {
52 libm::pow(x, n)
53}
54
55#[inline]
57pub fn powi(x: f64, n: i32) -> f64 {
58 libm::pow(x, n as f64)
59}
60
61#[inline]
63pub fn sin(x: f64) -> f64 {
64 libm::sin(x)
65}
66
67#[inline]
69pub fn cos(x: f64) -> f64 {
70 libm::cos(x)
71}
72
73#[inline]
75pub fn floor(x: f64) -> f64 {
76 libm::floor(x)
77}
78
79#[inline]
81pub fn ceil(x: f64) -> f64 {
82 libm::ceil(x)
83}
84
85#[inline]
87pub fn round(x: f64) -> f64 {
88 libm::round(x)
89}
90
91#[inline]
93pub fn tanh(x: f64) -> f64 {
94 libm::tanh(x)
95}
96
97#[inline]
99pub fn softplus(x: f64) -> f64 {
100 if x > 20.0 {
101 x
102 } else if x < -20.0 {
103 libm::exp(x)
104 } else {
105 libm::log(1.0 + libm::exp(x))
106 }
107}
108
109#[inline]
111pub fn sigmoid(x: f64) -> f64 {
112 if x >= 0.0 {
113 let e = libm::exp(-x);
114 1.0 / (1.0 + e)
115 } else {
116 let e = libm::exp(x);
117 e / (1.0 + e)
118 }
119}
120
121#[inline]
123pub fn fmin(x: f64, y: f64) -> f64 {
124 libm::fmin(x, y)
125}
126
127#[inline]
129pub fn fmax(x: f64, y: f64) -> f64 {
130 libm::fmax(x, y)
131}
132
133#[inline]
135pub fn erf(x: f64) -> f64 {
136 libm::erf(x)
137}
138
139#[inline]
141pub fn abs_f32(x: f32) -> f32 {
142 libm::fabsf(x)
143}
144
145#[inline]
147pub fn sqrt_f32(x: f32) -> f32 {
148 libm::sqrtf(x)
149}
150
#[cfg(test)]
mod tests {
    // Unit tests for the math wrappers in the parent module.
    use super::*;

    /// Sample points shared by the softplus positivity and sigmoid range checks.
    const SAMPLES: [f64; 5] = [-10.0, -1.0, 0.0, 1.0, 10.0];

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn sqrt_of_4() {
        assert!(close(sqrt(4.0), 2.0, 1e-15));
    }

    #[test]
    fn exp_of_0() {
        assert!(close(exp(0.0), 1.0, 1e-15));
    }

    #[test]
    fn ln_of_e() {
        assert!(close(ln(core::f64::consts::E), 1.0, 1e-15));
    }

    #[test]
    fn abs_negative() {
        for &(input, expected) in &[(-5.0, 5.0), (5.0, 5.0), (0.0, 0.0)] {
            assert_eq!(abs(input), expected);
        }
    }

    #[test]
    fn powf_squares() {
        assert!(close(powf(3.0, 2.0), 9.0, 1e-15));
    }

    #[test]
    fn powi_cubes() {
        assert!(close(powi(2.0, 3), 8.0, 1e-15));
    }

    #[test]
    fn sin_cos_identity() {
        let angle = 1.0;
        let (s, c) = (sin(angle), cos(angle));
        assert!(close(s * s + c * c, 1.0, 1e-15));
    }

    #[test]
    fn floor_ceil_round() {
        let cases = [
            (floor(2.7), 2.0),
            (ceil(2.3), 3.0),
            (round(2.5), 3.0),
            (round(2.4), 2.0),
        ];
        for &(got, want) in &cases {
            assert_eq!(got, want);
        }
    }

    #[test]
    fn log2_of_8() {
        assert!(close(log2(8.0), 3.0, 1e-15));
    }

    #[test]
    fn tanh_of_0() {
        assert!(close(tanh(0.0), 0.0, 1e-15));
    }

    #[test]
    fn fmin_fmax() {
        assert_eq!(fmin(1.0, 2.0), 1.0);
        assert_eq!(fmax(1.0, 2.0), 2.0);
    }

    #[test]
    fn softplus_large_positive() {
        assert!(close(softplus(50.0), 50.0, 1e-10));
    }

    #[test]
    fn softplus_large_negative() {
        let result = softplus(-50.0);
        assert!((0.0..1e-20).contains(&result));
    }

    #[test]
    fn softplus_zero() {
        assert!(close(softplus(0.0), ln(2.0), 1e-12));
    }

    #[test]
    fn softplus_always_positive() {
        for x in SAMPLES.iter().copied() {
            assert!(softplus(x) > 0.0, "softplus({}) should be > 0", x);
        }
    }

    #[test]
    fn sigmoid_at_zero() {
        assert!(close(sigmoid(0.0), 0.5, 1e-12));
    }

    #[test]
    fn sigmoid_range() {
        for x in SAMPLES.iter().copied() {
            let s = sigmoid(x);
            assert!(
                0.0 < s && s < 1.0,
                "sigmoid({}) = {} should be in (0, 1)",
                x,
                s
            );
        }
    }

    #[test]
    fn sigmoid_symmetry() {
        let x = 3.0;
        assert!(close(sigmoid(x) + sigmoid(-x), 1.0, 1e-12));
    }

    #[test]
    fn sigmoid_extreme_values() {
        let unit = 0.0..=1.0;
        assert!(unit.contains(&sigmoid(100.0)));
        assert!(unit.contains(&sigmoid(-100.0)));
    }
}