1#[inline]
15pub fn abs(x: f64) -> f64 {
16 libm::fabs(x)
17}
18
19#[inline]
21pub fn sqrt(x: f64) -> f64 {
22 libm::sqrt(x)
23}
24
25#[inline]
27pub fn exp(x: f64) -> f64 {
28 libm::exp(x)
29}
30
31#[inline]
33pub fn ln(x: f64) -> f64 {
34 libm::log(x)
35}
36
37#[inline]
39pub fn log2(x: f64) -> f64 {
40 libm::log2(x)
41}
42
43#[inline]
45pub fn log10(x: f64) -> f64 {
46 libm::log10(x)
47}
48
49#[inline]
51pub fn powf(x: f64, n: f64) -> f64 {
52 libm::pow(x, n)
53}
54
55#[inline]
57pub fn powi(x: f64, n: i32) -> f64 {
58 libm::pow(x, n as f64)
59}
60
61#[inline]
63pub fn sin(x: f64) -> f64 {
64 libm::sin(x)
65}
66
67#[inline]
69pub fn cos(x: f64) -> f64 {
70 libm::cos(x)
71}
72
73#[inline]
75pub fn floor(x: f64) -> f64 {
76 libm::floor(x)
77}
78
79#[inline]
81pub fn ceil(x: f64) -> f64 {
82 libm::ceil(x)
83}
84
85#[inline]
87pub fn round(x: f64) -> f64 {
88 libm::round(x)
89}
90
91#[inline]
93pub fn tanh(x: f64) -> f64 {
94 libm::tanh(x)
95}
96
97#[inline]
99pub fn softplus(x: f64) -> f64 {
100 if x > 20.0 {
101 x
102 } else if x < -20.0 {
103 libm::exp(x)
104 } else {
105 libm::log(1.0 + libm::exp(x))
106 }
107}
108
109#[inline]
111pub fn sigmoid(x: f64) -> f64 {
112 if x >= 0.0 {
113 let e = libm::exp(-x);
114 1.0 / (1.0 + e)
115 } else {
116 let e = libm::exp(x);
117 e / (1.0 + e)
118 }
119}
120
121#[inline]
123pub fn fmin(x: f64, y: f64) -> f64 {
124 libm::fmin(x, y)
125}
126
127#[inline]
129pub fn fmax(x: f64, y: f64) -> f64 {
130 libm::fmax(x, y)
131}
132
133#[inline]
138pub fn silu(x: f64) -> f64 {
139 x * sigmoid(x)
140}
141
142#[inline]
144pub fn silu_derivative(x: f64) -> f64 {
145 let s = sigmoid(x);
146 s + x * s * (1.0 - s)
147}
148
149#[inline]
151pub fn erf(x: f64) -> f64 {
152 libm::erf(x)
153}
154
155#[inline]
157pub fn abs_f32(x: f32) -> f32 {
158 libm::fabsf(x)
159}
160
161#[inline]
163pub fn sqrt_f32(x: f32) -> f32 {
164 libm::sqrtf(x)
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sqrt_of_4() {
        assert!((sqrt(4.0) - 2.0).abs() < 1e-15);
    }

    #[test]
    fn exp_of_0() {
        assert!((exp(0.0) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn ln_of_e() {
        assert!((ln(core::f64::consts::E) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn abs_negative() {
        assert_eq!(abs(-5.0), 5.0);
        assert_eq!(abs(5.0), 5.0);
        assert_eq!(abs(0.0), 0.0);
    }

    #[test]
    fn powf_squares() {
        assert!((powf(3.0, 2.0) - 9.0).abs() < 1e-15);
    }

    #[test]
    fn powi_cubes() {
        assert!((powi(2.0, 3) - 8.0).abs() < 1e-15);
    }

    #[test]
    fn powi_negative_exponent() {
        assert!((powi(2.0, -2) - 0.25).abs() < 1e-15);
    }

    #[test]
    fn sin_cos_identity() {
        let x = 1.0;
        let s = sin(x);
        let c = cos(x);
        assert!((s * s + c * c - 1.0).abs() < 1e-15);
    }

    #[test]
    fn floor_ceil_round() {
        assert_eq!(floor(2.7), 2.0);
        assert_eq!(ceil(2.3), 3.0);
        // Halfway cases round away from zero, in both directions.
        assert_eq!(round(2.5), 3.0);
        assert_eq!(round(-2.5), -3.0);
        assert_eq!(round(2.4), 2.0);
    }

    #[test]
    fn log2_of_8() {
        assert!((log2(8.0) - 3.0).abs() < 1e-15);
    }

    #[test]
    fn log10_of_1000() {
        assert!((log10(1000.0) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn tanh_of_0() {
        assert!((tanh(0.0)).abs() < 1e-15);
    }

    #[test]
    fn fmin_fmax() {
        assert_eq!(fmin(1.0, 2.0), 1.0);
        assert_eq!(fmax(1.0, 2.0), 2.0);
    }

    #[test]
    fn erf_basic_values() {
        assert!(erf(0.0).abs() < 1e-15);
        assert!((erf(1.0) - 0.842_700_792_949_714_9).abs() < 1e-12);
        // erf is an odd function.
        assert!((erf(-1.0) + erf(1.0)).abs() < 1e-15);
    }

    #[test]
    fn softplus_large_positive() {
        assert!((softplus(50.0) - 50.0).abs() < 1e-10);
    }

    #[test]
    fn softplus_large_negative() {
        let result = softplus(-50.0);
        // ln(1 + e^-50) ≈ e^-50 ≈ 1.9e-22: tiny but never negative.
        assert!((0.0..1e-20).contains(&result));
    }

    #[test]
    fn softplus_zero() {
        let expected = ln(2.0);
        assert!((softplus(0.0) - expected).abs() < 1e-12);
    }

    #[test]
    fn softplus_always_positive() {
        for &x in &[-10.0, -1.0, 0.0, 1.0, 10.0] {
            assert!(softplus(x) > 0.0, "softplus({}) should be > 0", x);
        }
    }

    #[test]
    fn sigmoid_at_zero() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn sigmoid_range() {
        for &x in &[-10.0, -1.0, 0.0, 1.0, 10.0] {
            let s = sigmoid(x);
            // Strict bounds, matching the message: these inputs are far from
            // the saturation region where rounding can yield exactly 0 or 1.
            assert!(s > 0.0 && s < 1.0, "sigmoid({}) = {} should be in (0, 1)", x, s);
        }
    }

    #[test]
    fn sigmoid_symmetry() {
        let x = 3.0;
        assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn sigmoid_extreme_values() {
        let s_pos = sigmoid(100.0);
        let s_neg = sigmoid(-100.0);
        assert!((0.0..=1.0).contains(&s_pos));
        assert!((0.0..=1.0).contains(&s_neg));
    }

    #[test]
    fn silu_basic_values() {
        assert!(silu(0.0).abs() < 1e-15);
        assert!((silu(50.0) - 50.0).abs() < 1e-10);
    }

    #[test]
    fn silu_derivative_matches_finite_difference() {
        assert!((silu_derivative(0.0) - 0.5).abs() < 1e-12);
        let h = 1e-6;
        for &x in &[-3.0, -0.5, 0.7, 2.5] {
            let numeric = (silu(x + h) - silu(x - h)) / (2.0 * h);
            let analytic = silu_derivative(x);
            assert!(
                (analytic - numeric).abs() < 1e-6,
                "silu_derivative({}) = {} but central difference gives {}",
                x,
                analytic,
                numeric
            );
        }
    }

    #[test]
    fn f32_helpers() {
        assert_eq!(abs_f32(-2.5), 2.5);
        assert_eq!(sqrt_f32(9.0), 3.0);
    }
}
292}