1use std::f64::consts::FRAC_PI_2;
4
5use fixed::types::I16F16;
6use fixed_exp2::FixedPowF;
7use fixed_trigonometry::sin;
/// Fixed-point number type used for every easing input, output, start and distance
/// in this module (`I16F16` from the `fixed` crate: 16 integer bits, 16 fractional bits).
pub type Fix = I16F16;
9
/// Generates an easing curve from a closure `$e: Fn(Fix) -> Fix` mapping progress `x`
/// (nominally in `[0, 1]`) to an eased fraction.
///
/// For `easer!(name, Name, closure)` this defines:
/// * `pub struct Name` — an iterator over the eased values of one transition,
/// * `pub fn name(start, end, steps) -> Name` — its constructor,
/// * `Name::at` / `Name::at_normalized` — one-off evaluation of the curve.
macro_rules! easer {
    ($f:ident, $t:ident, $e:expr) => {
        /// Iterator yielding `steps` eased values of a transition from a start to an
        /// end value, at progress `1/steps, 2/steps, ..., 1` (the start value itself,
        /// progress 0, is not yielded).
        #[derive(Debug, Clone)]
        #[must_use = "iterators are lazy and do nothing unless consumed"]
        pub struct $t {
            start: Fix,
            dist: Fix,
            // Number of values yielded so far; invariant: `step <= steps`.
            step: u64,
            steps: u64,
        }

        /// Creates an iterator easing from `start` to `end` in `steps` steps.
        ///
        /// The last yielded value corresponds to progress 1 (i.e. `end`, up to the
        /// curve's own rounding); `steps == 0` yields nothing.
        pub fn $f(start: Fix, end: Fix, steps: u64) -> $t {
            $t {
                start,
                dist: end - start,
                step: 0,
                steps,
            }
        }

        impl $t {
            /// Evaluates the curve at progress `x`, scaled by `dist` and offset by `start`
            /// (`curve(x) * dist + start`).
            pub fn at(x: Fix, start: Fix, dist: Fix) -> Fix {
                Fix::from_num($e(x)).mul_add(dist, start)
            }

            /// Evaluates the curve at progress `x` on the unit range (start 0, distance 1).
            pub fn at_normalized(x: Fix) -> Fix {
                Self::at(x, Fix::from_num(0), Fix::from_num(1))
            }
        }

        impl Iterator for $t {
            type Item = Fix;

            fn next(&mut self) -> Option<Fix> {
                // Stop without advancing once exhausted, so `step` never grows past
                // `steps` (keeps the iterator fused and `size_hint` exact).
                if self.step >= self.steps {
                    return None;
                }
                self.step += 1;
                // Progress x = step / steps, computed on the raw fixed-point bits in a
                // wide integer. Converting `step`/`steps` to `Fix` first would overflow
                // for `steps` beyond `Fix`'s integer range; this yields the same
                // truncated quotient bits without that limit.
                let raw = (u128::from(self.step) << Fix::FRAC_NBITS) / u128::from(self.steps);
                let x = Fix::from_bits(
                    ::std::convert::TryFrom::try_from(raw)
                        .expect("step <= steps, so step / steps <= 1 fits in Fix"),
                );
                Some(Self::at(x, self.start, self.dist))
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                let remaining = self.steps - self.step;
                match usize::try_from(remaining) {
                    Ok(n) => (n, Some(n)),
                    // More remaining items than `usize` can count.
                    Err(_) => (usize::MAX, None),
                }
            }
        }

        impl ::std::iter::FusedIterator for $t {}
    };
}
52
53easer!(linear, Linear, |x: Fix| { x });
54easer!(quad_in, QuadIn, |x: Fix| { x * x });
55easer!(quad_out, QuadOut, |x: Fix| {
56 -(x * (x - Fix::from_num(2)))
57});
58easer!(quad_inout, QuadInOut, |x: Fix| -> Fix {
59 if x < Fix::from_num(0.5) {
60 Fix::from_num(2) * x * x
61 } else {
62 (Fix::from_num(-2) * x * x) + x.mul_add(Fix::from_num(4), Fix::from_num(-1))
63 }
64});
65easer!(cubic_in, CubicIn, |x: Fix| { x * x * x });
66easer!(cubic_out, CubicOut, |x: Fix| {
67 let y = x - Fix::from_num(1);
68 y * y * y + Fix::from_num(1)
69});
70easer!(cubic_inout, CubicInOut, |x: Fix| {
71 if x < Fix::from_num(0.5) {
72 Fix::from_num(4) * x * x * x
73 } else {
74 let y = x.mul_add(2.into(), Fix::from_num(-2));
75 (y * y * y).mul_add(Fix::from_num(0.5), Fix::from_num(1))
76 }
77});
78easer!(quartic_in, QuarticIn, |x: Fix| { x * x * x * x });
79easer!(quartic_out, QuarticOut, |x: Fix| {
80 let y = x - Fix::from_num(1);
81 (y * y * y).mul_add(Fix::from_num(1) - x, Fix::from_num(1))
82});
83easer!(quartic_inout, QuarticInOut, |x: Fix| {
84 if x < Fix::from_num(0.5) {
85 Fix::from_num(8) * x * x * x * x
86 } else {
87 let y = x - Fix::from_num(1);
88 (y * y * y * y).mul_add(Fix::from_num(-8), Fix::from_num(1))
89 }
90});
91easer!(sin_in, SinIn, |x: Fix| {
92 let y = (x - Fix::from_num(1)) * Fix::from_num(FRAC_PI_2);
93 sin(y) + Fix::from_num(1)
94});
95easer!(sin_out, SinOut, |x: Fix| {
96 sin(x * Fix::from_num(FRAC_PI_2))
97});
98easer!(sin_inout, SinInOut, |x: Fix| {
99 if x < Fix::from_num(0.5) {
100 Fix::from_num(0.5)
101 * (Fix::from_num(1) - (x * x).mul_add(Fix::from_num(-4), Fix::from_num(1)).sqrt())
102 } else {
103 Fix::from_num(0.5)
104 * ((x.mul_add(Fix::from_num(-2), Fix::from_num(3))
105 * x.mul_add(Fix::from_num(2), Fix::from_num(-1)))
106 .sqrt()
107 + Fix::from_num(1))
108 }
109});
110easer!(exp_in, ExpIn, |x: Fix| {
111 if x == 0. {
112 Fix::from_num(0)
113 } else {
114 Fix::from_num(2).powf(Fix::from_num(10) * (x - Fix::from_num(1)))
115 }
116});
117
118easer!(exp_out, ExpOut, |x: Fix| {
119 if x == Fix::from_num(1) {
120 Fix::from_num(1)
121 } else {
122 Fix::from_num(2).powf(-Fix::from_num(10) * x) * Fix::from_num(-1) + Fix::from_num(1)
123 }
124});
125easer!(exp_inout, ExpInOut, |x: Fix| {
126 if x == Fix::from_num(1) {
127 Fix::from_num(1)
128 } else if x == 0. {
129 Fix::from_num(0)
130 } else if x < Fix::from_num(0.5) {
131 Fix::from_num(2).powf(x.mul_add(Fix::from_num(20), Fix::from_num(-10))) * Fix::from_num(0.5)
132 } else {
133 Fix::from_num(2)
134 .powf(x.mul_add(Fix::from_num(-20), Fix::from_num(10)))
135 .mul_add(Fix::from_num(-0.5), Fix::from_num(1))
136 }
137});
138easer!(smoothstep, SmoothStep, |x: Fix| {
139 if x == Fix::from_num(1) {
140 Fix::from_num(1)
141 } else if x == 0. {
142 Fix::from_num(0)
143 } else {
144 x * x * (Fix::from_num(3) - Fix::from_num(2) * x)
145 }
146});
147
#[cfg(test)]
mod test {
    /// Maximum tolerated deviation, as a fraction of the expected data's value range.
    const ERROR_MARGIN_FAC: f64 = 0.00015;

    use std::{
        fs::{self, File},
        io::{BufWriter, Write},
        iter::zip,
        path::{Path, PathBuf},
    };

    use anyhow::anyhow;

    use super::*;

    /// Expands to the name of the enclosing function (e.g. `linear_test`).
    macro_rules! function {
        () => {{
            fn f() {}
            fn type_name_of<T>(_: T) -> &'static str {
                std::any::type_name::<T>()
            }
            let name = type_name_of(f);
            let full_path = name.strip_suffix("::f").unwrap();
            full_path.split("::").last().unwrap()
        }};
    }

    /// Compares measured fixed-point values against expected `f64` values.
    ///
    /// Every measured value may deviate from its expected counterpart by at most
    /// `ERROR_MARGIN_FAC` times the value range (max − min) of `ought_data`.
    ///
    /// # Panics
    /// On any deviation outside the margin, or when the two series differ in length.
    /// Before panicking, both series are written as JSON to
    /// `<CARGO_MANIFEST_DIR>/jupyter-tests/{test_name}-ought.json` and `-is.json`.
    ///
    /// # Errors
    /// If the output directory or files cannot be created or written.
    fn must_be_within_error_margin_else_write(
        test_name: &str,
        ought_data: Vec<f64>,
        is_data: Vec<Fix>,
    ) -> anyhow::Result<()> {
        let (min, max) = ought_data
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        // Empty expected data has no range (and therefore zero tolerance).
        let value_range = if ought_data.is_empty() { 0.0 } else { max - min };

        let max_error = (value_range * ERROR_MARGIN_FAC).abs();
        let inside_margin = |is: f64, ought: f64| -> bool {
            let delta = (is - ought).abs();

            let res = delta <= max_error;
            if !res {
                eprintln!("measured error outside of acceptable margin: {is} <> {ought}");
            }
            res
        };

        // `zip` stops at the shorter series, so a length mismatch must be checked
        // explicitly; otherwise an iterator yielding too few values passes vacuously.
        let mut ok = ought_data.len() == is_data.len();
        if !ok {
            eprintln!(
                "expected {} values but measured {}",
                ought_data.len(),
                is_data.len()
            );
        }
        // Check every pair (no short-circuit) so each out-of-margin value is reported.
        for (ought, is) in zip(ought_data.iter(), is_data.iter()) {
            if !inside_margin(is.to_num::<f64>(), *ought) {
                ok = false;
            }
        }

        if !ok {
            let root = option_env!("CARGO_MANIFEST_DIR")
                .ok_or(anyhow!("missing env var CARGO_MANIFEST_DIR"))?;
            let target_path = PathBuf::from(root).join("jupyter-tests");
            // The output directory may not exist yet (e.g. in a fresh checkout).
            fs::create_dir_all(&target_path)?;

            let is_converted = is_data
                .into_iter()
                .map(|fix| fix.to_num())
                .collect::<Vec<f64>>();
            write_json(
                &target_path.join(format!("{test_name}-ought.json")),
                &ought_data,
            )?;
            write_json(
                &target_path.join(format!("{test_name}-is.json")),
                &is_converted,
            )?;

            panic!("{test_name} outside of error {ERROR_MARGIN_FAC} margin: {is_converted:?} <> {ought_data:?}");
        }
        Ok(())
    }

    /// Serializes `data` as a JSON array into a freshly created file at `path`.
    fn write_json(path: &Path, data: &[f64]) -> anyhow::Result<()> {
        let mut writer = BufWriter::new(File::create(path)?);
        serde_json::to_writer(&mut writer, data)?;
        // Flush explicitly: dropping a `BufWriter` swallows write errors.
        writer.flush()?;
        Ok(())
    }

    #[test]
    fn at() {
        let res = ExpInOut::at(Fix::from_num(0.5), Fix::from_num(10.), Fix::from_num(100));
        assert_eq!(res, Fix::from_num(60.));
    }

    #[test]
    fn at_normalized() {
        let res = ExpInOut::at_normalized(Fix::from_num(0.75));
        assert_eq!(res, Fix::from_num(0.984375));
    }

    #[test]
    fn zero_steps_yields_nothing() {
        let mut it = linear(Fix::from_num(0), Fix::from_num(1), 0);
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }

    #[test]
    fn last_value_is_end() {
        let last = linear(Fix::from_num(0), Fix::from_num(10000), 10).last();
        assert_eq!(last, Some(Fix::from_num(10000)));
    }

    #[test]
    fn linear_test() -> anyhow::Result<()> {
        let model = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        let res: Vec<Fix> = linear(Fix::from_num(0), Fix::from_num(1), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn quad_in_test() -> anyhow::Result<()> {
        let model = vec![
            100., 400., 900., 1600., 2500., 3600., 4900., 6400., 8100., 10000.,
        ];
        let res: Vec<Fix> = quad_in(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn quad_out_test() -> anyhow::Result<()> {
        let model = vec![
            1900., 3600., 5100., 6400., 7500., 8400., 9100., 9600., 9900., 10000.,
        ];
        let res: Vec<Fix> = quad_out(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn quad_inout_test() -> anyhow::Result<()> {
        let model = vec![
            200., 800., 1800., 3200., 5000., 6800., 8200., 9200., 9800., 10000.,
        ];
        let res: Vec<Fix> = quad_inout(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn cubic_in_test() -> anyhow::Result<()> {
        let model = vec![
            10., 80., 270., 640., 1250., 2160., 3430., 5120., 7290., 10000.,
        ];
        let res: Vec<Fix> = cubic_in(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn cubic_out_test() -> anyhow::Result<()> {
        let model = vec![
            2710., 4880., 6570., 7840., 8750., 9360., 9730., 9920., 9990., 10000.,
        ];
        let res: Vec<Fix> = cubic_out(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn quartic_in_test() -> anyhow::Result<()> {
        let model = vec![1., 16., 81., 256., 625., 1296., 2401., 4096., 6561., 10000.];
        let res: Vec<Fix> = quartic_in(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn quartic_out_test() -> anyhow::Result<()> {
        let model = vec![
            3439., 5904., 7599., 8704., 9375., 9744., 9919., 9984., 9999., 10000.,
        ];
        let res: Vec<Fix> = quartic_out(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn quartic_inout_test() -> anyhow::Result<()> {
        let model = vec![
            8., 128., 648., 2048., 5000., 7952., 9352., 9872., 9992., 10000.,
        ];
        let res: Vec<Fix> = quartic_inout(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn sin_in_test() -> anyhow::Result<()> {
        let model = vec![
            123.116594,
            489.434837,
            1089.934758,
            1909.830056,
            2928.932188,
            4122.147477,
            5460.095003,
            6909.830056,
            8435.655350,
            10000.,
        ];
        let res: Vec<Fix> = sin_in(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn sin_out_test() -> anyhow::Result<()> {
        let model = vec![
            1564.344650,
            3090.169944,
            4539.904997,
            5877.852523,
            7071.067812,
            8090.169944,
            8910.065242,
            9510.565163,
            9876.883406,
            10000.,
        ];
        let res: Vec<Fix> = sin_out(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn sin_inout_test() -> anyhow::Result<()> {
        let model = vec![
            101.020514,
            417.424305,
            1000.,
            2000.,
            5000.,
            8000.,
            9000.,
            9582.575695,
            9898.979486,
            10000.,
        ];
        let res: Vec<Fix> = sin_inout(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn exp_in_test() -> anyhow::Result<()> {
        let model = vec![
            19.53125, 39.0625, 78.125, 156.25, 312.5, 625., 1250., 2500., 5000., 10000.,
        ];
        let res: Vec<Fix> = exp_in(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn exp_out_test() -> anyhow::Result<()> {
        let model = vec![
            5000., 7500., 8750., 9375., 9687.5, 9843.75, 9921.875, 9960.9375, 9980.46875, 10000.,
        ];
        let res: Vec<Fix> = exp_out(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn exp_inout_test() -> anyhow::Result<()> {
        let model = vec![
            19.53125, 78.125, 312.5, 1250., 5000., 8750., 9687.5, 9921.875, 9980.46875, 10000.,
        ];
        let res: Vec<Fix> = exp_inout(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }

    #[test]
    fn smoothstep_test() -> anyhow::Result<()> {
        let model = vec![
            280.00000000000006,
            1040.0000000000002,
            2160.0,
            3520.000000000001,
            5000.0,
            6480.0,
            7839.999999999999,
            8960.000000000002,
            9720.0,
            10000.,
        ];
        let res: Vec<Fix> = smoothstep(Fix::from_num(0), Fix::from_num(10000), 10).collect();
        must_be_within_error_margin_else_write(function!(), model, res)
    }
}