raddy/scalar/
borrow_operator_traits_impl.rs1use crate::Ad;
2use std::ops::{Add, Mul};
3
4impl<const N: usize> Ad<N> {
7 pub fn neg(&self) -> Self {
8 let mut res = Self::_zeroed();
9 res.value = -self.value;
10 res.grad = -self.grad;
11 res.hess = -self.hess;
12
13 res
14 }
15
16 pub fn sqrt(&self) -> Self {
17 if self.value < -0.0 {
18 panic!("Sqrt on negative value!");
20 }
21 let f = self.value.sqrt();
22
23 Self::chain(f, 0.5 / f, -0.25 / (f * self.value), self)
24 }
25
26 pub fn square(&self) -> Self {
27 let mut res = Self::_zeroed();
28 res.value = self.value * self.value;
29 res.grad = 2.0 * self.value * self.grad;
30 res.hess = 2.0 * (self.value * self.hess + self.grad * self.grad.transpose());
31
32 res
33 }
34
35 pub fn powi(&self, exponent: i32) -> Self {
36 if self.value.abs() == 0.0 && exponent == 0 {
37 panic!("0.pow(0) is undefined!");
39 }
40
41 let f2 = self.value.powi(exponent - 2);
42 let f1 = f2 * self.value;
43 let f = f1 * self.value;
44
45 let ef = exponent as f64;
47
48 Self::chain(f, ef * f1, ef * (ef - 1.0) * f2, self)
49 }
50
51 pub fn powf(&self, exponent: f64) -> Self {
52 if self.value.abs() == 0.0 && exponent.abs() == 0.0 {
53 panic!("0.pow(0) is undefined!");
55 }
56
57 let f2 = self.value.powf(exponent - 2.0);
58 let f1 = f2 * self.value;
59 let f = f1 * self.value;
60
61 Self::chain(f, exponent * f1, exponent * (exponent - 1.0) * f2, self)
64 }
65
66 pub fn abs(&self) -> Self {
67 let mut res = Self::_zeroed();
68 res.value = self.value.abs();
69 let sign = if self.value >= 0.0 { 1.0 } else { -1.0 };
70 res.grad = sign * self.grad;
71 res.hess = sign * self.hess;
72
73 res
74 }
75
76 pub fn exp(&self) -> Self {
77 let exp_val = self.value.exp();
78
79 Self::chain(exp_val, exp_val, exp_val, self)
80 }
81
82 pub fn ln(&self) -> Self {
83 if self.value <= 0.0 {
84 panic!("Ln on non-positive value!");
85 }
86 let inv = 1.0 / self.value;
87
88 Self::chain(self.value.ln(), inv, -inv * inv, self)
89 }
90
91 pub fn log(&self, base: f64) -> Self {
92 if self.value <= 0.0 {
93 panic!("Log2 on non-positive value!");
94 }
95 if base <= 0.0 {
96 panic!("Base must be positive!");
97 }
98
99 let inv = 1.0 / self.value / base.ln();
100
101 Self::chain(self.value.log(base), inv, -inv / self.value, self)
102 }
103
104 pub fn log2(&self) -> Self {
105 if self.value <= 0.0 {
106 panic!("Log2 on non-positive value!");
107 }
108 let inv = 1.0 / self.value / std::f64::consts::LN_2;
109
110 Self::chain(self.value.log2(), inv, -inv / self.value, self)
111 }
112
113 pub fn log10(&self) -> Self {
114 if self.value <= 0.0 {
115 panic!("Log10 on non-positive value!");
116 }
117 let inv = 1.0 / self.value / std::f64::consts::LN_10;
118
119 Self::chain(self.value.log10(), inv, -inv / self.value, self)
120 }
121
122 pub fn sin(&self) -> Self {
123 let sin_val = self.value.sin();
124 let cos_val = self.value.cos();
125
126 Self::chain(sin_val, cos_val, -sin_val, self)
127 }
128
129 pub fn cos(&self) -> Self {
130 let cos_val = self.value.cos();
131 let sin_val = self.value.sin();
132
133 Self::chain(cos_val, -sin_val, -cos_val, self)
134 }
135
136 pub fn tan(&self) -> Self {
137 let cos_val = self.value.cos();
138 let cos_sq = cos_val * cos_val;
139
140 Self::chain(
141 self.value.tan(),
142 1.0 / cos_sq,
143 2.0 * self.value.sin() / (cos_sq * cos_val),
144 self,
145 )
146 }
147
148 pub fn asin(&self) -> Self {
149 if self.value < -1.0 || self.value > 1.0 {
150 panic!("Asin out of domain!");
151 }
152 let s = 1.0 - self.value * self.value;
153 let s_sqrt = s.sqrt();
154
155 Self::chain(
156 self.value.asin(),
157 1.0 / s_sqrt,
158 self.value / (s * s_sqrt),
159 self,
160 )
161 }
162
163 pub fn acos(&self) -> Self {
164 if self.value < -1.0 || self.value > 1.0 {
165 panic!("Acos out of domain!");
166 }
167 let s = 1.0 - self.value * self.value;
168 let s_sqrt = s.sqrt();
169
170 Self::chain(
171 self.value.acos(),
172 -1.0 / s_sqrt,
173 -self.value / (s * s_sqrt),
174 self,
175 )
176 }
177
178 #[deprecated = "Please use atan2 instead."]
179 pub fn atan(&self) -> Self {
180 let s = self.value * self.value + 1.0;
181
182 Self::chain(
183 self.value.atan(),
184 1.0 / s,
185 -2.0 * self.value / (s * s),
186 self,
187 )
188 }
189
190 pub fn sinh(&self) -> Self {
191 let sinh_val = self.value.sinh();
192 let cosh_val = self.value.cosh();
193
194 Self::chain(sinh_val, cosh_val, sinh_val, self)
195 }
196
197 pub fn cosh(&self) -> Self {
198 let sinh_val = self.value.sinh();
199 let cosh_val = self.value.cosh();
200
201 Self::chain(cosh_val, sinh_val, cosh_val, self)
202 }
203
204 pub fn tanh(&self) -> Self {
205 let cosh_val = self.value.cosh();
206 let cosh_sq = cosh_val * cosh_val;
207
208 Self::chain(
209 self.value.tanh(),
210 1.0 / cosh_sq,
211 -2.0 * self.value.sinh() / (cosh_sq * cosh_val),
212 self,
213 )
214 }
215
216 pub fn asinh(&self) -> Self {
217 let s = self.value * self.value + 1.0;
218 let s_sqrt = s.sqrt();
219
220 Self::chain(
221 self.value.asinh(),
222 1.0 / s_sqrt,
223 -self.value / (s * s_sqrt),
224 self,
225 )
226 }
227
228 pub fn acosh(&self) -> Self {
229 if self.value < 1.0 {
230 panic!("Acosh out of domain!");
231 }
232 let sm = self.value - 1.0;
233 let sp = self.value + 1.0;
234 let prod = (sm * sp).sqrt();
235
236 Self::chain(
237 self.value.acosh(),
238 1.0 / prod,
239 -self.value / (prod * sm * sp),
240 self,
241 )
242 }
243
244 pub fn atanh(&self) -> Self {
245 if self.value <= -1.0 || self.value >= 1.0 {
246 panic!("Atanh out of domain!");
247 }
248 let s = 1.0 - self.value * self.value;
249
250 Self::chain(
251 self.value.atanh(),
252 1.0 / s,
253 2.0 * self.value / (s * s),
254 self,
255 )
256 }
257}
258
259impl<const N: usize> Ad<N> {
262 pub fn add_value(&self, other: f64) -> Self {
263 let mut res = Self::_zeroed();
264 res.value = self.value + other;
265 res
266 }
267
268 pub fn sub_value(&self, other: f64) -> Self {
269 let mut res = Self::_zeroed();
270 res.value = self.value - other;
271 res
272 }
273
274 pub fn mul_value(&self, other: f64) -> Self {
275 let mut res = Self::_zeroed();
276 res.value = self.value * other;
277 res.grad = self.grad * other;
278 res.hess = self.hess * other;
279
280 res
281 }
282
283 pub fn recip(&self) -> Self {
284 Ad::inactive_scalar(1.0) / self
286 }
287
288 pub fn div_value(&self, other: f64) -> Self {
289 if other.abs() == 0.0 {
290 panic!("Division By Zero!");
292 }
293
294 let mut res = Self::_zeroed();
295 res.value = self.value / other;
296 res.grad = self.grad / other;
297 res.hess = self.hess / other;
298
299 res
300 }
301
302 pub fn atan2(&self, x: &Self) -> Self {
304 let mut res = Self::_zeroed();
305
306 res.value = self.value.atan2(x.value);
308
309 let u = x.value * &self.grad - self.value * &x.grad;
311 let v = x.value * x.value + self.value * self.value;
312 res.grad = &u / v;
313
314 let du = x.value * &self.hess - self.value * &x.hess + &self.grad * x.grad.transpose()
316 - &x.grad * self.grad.transpose();
317 let dv = 2.0 * (x.value * &x.grad + self.value * &self.grad);
318 res.hess = (&du - &res.grad * dv.transpose()) / v;
319
320 res
321 }
322
323 pub fn min(&self, other: &Self) -> Self {
324 if self < other {
325 self.clone()
326 } else {
327 other.clone()
328 }
329 }
330
331 pub fn max(&self, other: &Self) -> Self {
332 if self > other {
333 self.clone()
334 } else {
335 other.clone()
336 }
337 }
338
339 pub fn clamp(&self, low: &Self, high: &Self) -> Self {
340 self.max(low).min(high)
341 }
342
343 pub fn hypot(&self, other: &Self) -> Self {
345 (self.mul(self).add(&other.mul(other))).sqrt()
346 }
347}