raddy/scalar/borrow_operator_traits_impl.rs

1use crate::Ad;
2use std::ops::{Add, Mul};
3
4// ################################### Unary Operators ###################################
5
6impl<const N: usize> Ad<N> {
7    pub fn neg(&self) -> Self {
8        let mut res = Self::_zeroed();
9        res.value = -self.value;
10        res.grad = -self.grad;
11        res.hess = -self.hess;
12
13        res
14    }
15
16    pub fn sqrt(&self) -> Self {
17        if self.value < -0.0 {
18            // We don't want to mute this behavior or get NaN as this is fucking undebuggable.
19            panic!("Sqrt on negative value!");
20        }
21        let f = self.value.sqrt();
22
23        Self::chain(f, 0.5 / f, -0.25 / (f * self.value), self)
24    }
25
26    pub fn square(&self) -> Self {
27        let mut res = Self::_zeroed();
28        res.value = self.value * self.value;
29        res.grad = 2.0 * self.value * self.grad;
30        res.hess = 2.0 * (self.value * self.hess + self.grad * self.grad.transpose());
31
32        res
33    }
34
35    pub fn powi(&self, exponent: i32) -> Self {
36        if self.value.abs() == 0.0 && exponent == 0 {
37            // We don't want to mute this behavior or get NaN as this is fucking undebuggable.
38            panic!("0.pow(0) is undefined!");
39        }
40
41        let f2 = self.value.powi(exponent - 2);
42        let f1 = f2 * self.value;
43        let f = f1 * self.value;
44
45        // exponent in float
46        let ef = exponent as f64;
47
48        Self::chain(f, ef * f1, ef * (ef - 1.0) * f2, self)
49    }
50
51    pub fn powf(&self, exponent: f64) -> Self {
52        if self.value.abs() == 0.0 && exponent.abs() == 0.0 {
53            // We don't want to mute this behavior or get NaN as this is fucking undebuggable.
54            panic!("0.pow(0) is undefined!");
55        }
56
57        let f2 = self.value.powf(exponent - 2.0);
58        let f1 = f2 * self.value;
59        let f = f1 * self.value;
60
61        // exponent in float
62
63        Self::chain(f, exponent * f1, exponent * (exponent - 1.0) * f2, self)
64    }
65
66    pub fn abs(&self) -> Self {
67        let mut res = Self::_zeroed();
68        res.value = self.value.abs();
69        let sign = if self.value >= 0.0 { 1.0 } else { -1.0 };
70        res.grad = sign * self.grad;
71        res.hess = sign * self.hess;
72
73        res
74    }
75
76    pub fn exp(&self) -> Self {
77        let exp_val = self.value.exp();
78
79        Self::chain(exp_val, exp_val, exp_val, self)
80    }
81
82    pub fn ln(&self) -> Self {
83        if self.value <= 0.0 {
84            panic!("Ln on non-positive value!");
85        }
86        let inv = 1.0 / self.value;
87
88        Self::chain(self.value.ln(), inv, -inv * inv, self)
89    }
90
91    pub fn log(&self, base: f64) -> Self {
92        if self.value <= 0.0 {
93            panic!("Log2 on non-positive value!");
94        }
95        if base <= 0.0 {
96            panic!("Base must be positive!");
97        }
98
99        let inv = 1.0 / self.value / base.ln();
100
101        Self::chain(self.value.log(base), inv, -inv / self.value, self)
102    }
103
104    pub fn log2(&self) -> Self {
105        if self.value <= 0.0 {
106            panic!("Log2 on non-positive value!");
107        }
108        let inv = 1.0 / self.value / std::f64::consts::LN_2;
109
110        Self::chain(self.value.log2(), inv, -inv / self.value, self)
111    }
112
113    pub fn log10(&self) -> Self {
114        if self.value <= 0.0 {
115            panic!("Log10 on non-positive value!");
116        }
117        let inv = 1.0 / self.value / std::f64::consts::LN_10;
118
119        Self::chain(self.value.log10(), inv, -inv / self.value, self)
120    }
121
122    pub fn sin(&self) -> Self {
123        let sin_val = self.value.sin();
124        let cos_val = self.value.cos();
125
126        Self::chain(sin_val, cos_val, -sin_val, self)
127    }
128
129    pub fn cos(&self) -> Self {
130        let cos_val = self.value.cos();
131        let sin_val = self.value.sin();
132
133        Self::chain(cos_val, -sin_val, -cos_val, self)
134    }
135
136    pub fn tan(&self) -> Self {
137        let cos_val = self.value.cos();
138        let cos_sq = cos_val * cos_val;
139
140        Self::chain(
141            self.value.tan(),
142            1.0 / cos_sq,
143            2.0 * self.value.sin() / (cos_sq * cos_val),
144            self,
145        )
146    }
147
148    pub fn asin(&self) -> Self {
149        if self.value < -1.0 || self.value > 1.0 {
150            panic!("Asin out of domain!");
151        }
152        let s = 1.0 - self.value * self.value;
153        let s_sqrt = s.sqrt();
154
155        Self::chain(
156            self.value.asin(),
157            1.0 / s_sqrt,
158            self.value / (s * s_sqrt),
159            self,
160        )
161    }
162
163    pub fn acos(&self) -> Self {
164        if self.value < -1.0 || self.value > 1.0 {
165            panic!("Acos out of domain!");
166        }
167        let s = 1.0 - self.value * self.value;
168        let s_sqrt = s.sqrt();
169
170        Self::chain(
171            self.value.acos(),
172            -1.0 / s_sqrt,
173            -self.value / (s * s_sqrt),
174            self,
175        )
176    }
177
178    #[deprecated = "Please use atan2 instead."]
179    pub fn atan(&self) -> Self {
180        let s = self.value * self.value + 1.0;
181
182        Self::chain(
183            self.value.atan(),
184            1.0 / s,
185            -2.0 * self.value / (s * s),
186            self,
187        )
188    }
189
190    pub fn sinh(&self) -> Self {
191        let sinh_val = self.value.sinh();
192        let cosh_val = self.value.cosh();
193
194        Self::chain(sinh_val, cosh_val, sinh_val, self)
195    }
196
197    pub fn cosh(&self) -> Self {
198        let sinh_val = self.value.sinh();
199        let cosh_val = self.value.cosh();
200
201        Self::chain(cosh_val, sinh_val, cosh_val, self)
202    }
203
204    pub fn tanh(&self) -> Self {
205        let cosh_val = self.value.cosh();
206        let cosh_sq = cosh_val * cosh_val;
207
208        Self::chain(
209            self.value.tanh(),
210            1.0 / cosh_sq,
211            -2.0 * self.value.sinh() / (cosh_sq * cosh_val),
212            self,
213        )
214    }
215
216    pub fn asinh(&self) -> Self {
217        let s = self.value * self.value + 1.0;
218        let s_sqrt = s.sqrt();
219
220        Self::chain(
221            self.value.asinh(),
222            1.0 / s_sqrt,
223            -self.value / (s * s_sqrt),
224            self,
225        )
226    }
227
228    pub fn acosh(&self) -> Self {
229        if self.value < 1.0 {
230            panic!("Acosh out of domain!");
231        }
232        let sm = self.value - 1.0;
233        let sp = self.value + 1.0;
234        let prod = (sm * sp).sqrt();
235
236        Self::chain(
237            self.value.acosh(),
238            1.0 / prod,
239            -self.value / (prod * sm * sp),
240            self,
241        )
242    }
243
244    pub fn atanh(&self) -> Self {
245        if self.value <= -1.0 || self.value >= 1.0 {
246            panic!("Atanh out of domain!");
247        }
248        let s = 1.0 - self.value * self.value;
249
250        Self::chain(
251            self.value.atanh(),
252            1.0 / s,
253            2.0 * self.value / (s * s),
254            self,
255        )
256    }
257}
258
259// ################################### Binary Operators ###################################
260
261impl<const N: usize> Ad<N> {
262    pub fn add_value(&self, other: f64) -> Self {
263        let mut res = Self::_zeroed();
264        res.value = self.value + other;
265        res
266    }
267
268    pub fn sub_value(&self, other: f64) -> Self {
269        let mut res = Self::_zeroed();
270        res.value = self.value - other;
271        res
272    }
273
274    pub fn mul_value(&self, other: f64) -> Self {
275        let mut res = Self::_zeroed();
276        res.value = self.value * other;
277        res.grad = self.grad * other;
278        res.hess = self.hess * other;
279
280        res
281    }
282
283    pub fn recip(&self) -> Self {
284        // todo!("resolve codegen problem")
285        Ad::inactive_scalar(1.0) / self
286    }
287
288    pub fn div_value(&self, other: f64) -> Self {
289        if other.abs() == 0.0 {
290            // We don't want to mute this behavior or get NaN as this is fucking undebuggable.
291            panic!("Division By Zero!");
292        }
293
294        let mut res = Self::_zeroed();
295        res.value = self.value / other;
296        res.grad = self.grad / other;
297        res.hess = self.hess / other;
298
299        res
300    }
301
302    /// ## self is y
303    pub fn atan2(&self, x: &Self) -> Self {
304        let mut res = Self::_zeroed();
305
306        // Compute scalar value of atan2
307        res.value = self.value.atan2(x.value);
308
309        // Gradient computation
310        let u = x.value * &self.grad - self.value * &x.grad;
311        let v = x.value * x.value + self.value * self.value;
312        res.grad = &u / v;
313
314        // Hessian computation (if enabled)
315        let du = x.value * &self.hess - self.value * &x.hess + &self.grad * x.grad.transpose()
316            - &x.grad * self.grad.transpose();
317        let dv = 2.0 * (x.value * &x.grad + self.value * &self.grad);
318        res.hess = (&du - &res.grad * dv.transpose()) / v;
319
320        res
321    }
322
323    pub fn min(&self, other: &Self) -> Self {
324        if self < other {
325            self.clone()
326        } else {
327            other.clone()
328        }
329    }
330
331    pub fn max(&self, other: &Self) -> Self {
332        if self > other {
333            self.clone()
334        } else {
335            other.clone()
336        }
337    }
338
339    pub fn clamp(&self, low: &Self, high: &Self) -> Self {
340        self.max(low).min(high)
341    }
342
343    // Computes hypot(self, b) = sqrt(self^2 + b^2) with gradients and Hessians if enabled.
344    pub fn hypot(&self, other: &Self) -> Self {
345        (self.mul(self).add(&other.mul(other))).sqrt()
346    }
347}