1use std::hash::{Hash, Hasher};
2
3use bevy::{
4 math::{FloatOrd, Quat, Vec2, Vec3, Vec3A, Vec4},
5 reflect::{FromReflect, Reflect},
6};
7use serde::{Deserialize, Serialize};
8
/// Linear interpolation between two values of the same type.
///
/// Implementors must be [`Copy`]; gradient key values are stored and returned
/// by value.
pub trait Lerp: Copy {
    /// Interpolate between `self` and `other` by `ratio`.
    ///
    /// By convention a `ratio` of `0.` corresponds to `self` and a `ratio` of
    /// `1.` corresponds to `other`.
    fn lerp(self, other: Self, ratio: f32) -> Self;
}

impl Lerp for f32 {
    #[inline]
    fn lerp(self, other: Self, ratio: f32) -> Self {
        // self * (1 - ratio) + other * ratio, with the first product fused.
        self.mul_add(1. - ratio, other * ratio)
    }
}

impl Lerp for f64 {
    #[inline]
    fn lerp(self, other: Self, ratio: f32) -> Self {
        // Widen `ratio` *before* taking its complement. Computing `1 - ratio`
        // in f32 and then widening rounds the weight to f32 precision, which
        // throws away accuracy the f64 result could otherwise keep.
        let ratio = f64::from(ratio);
        self.mul_add(1. - ratio, other * ratio)
    }
}
30
31macro_rules! impl_lerp_vecn {
32 ($t:ty) => {
33 impl Lerp for $t {
34 #[inline]
35 fn lerp(self, other: Self, ratio: f32) -> Self {
36 <$t>::lerp(self, other, ratio)
38 }
39 }
40 };
41}
42
43impl_lerp_vecn!(Vec2);
44impl_lerp_vecn!(Vec3);
45impl_lerp_vecn!(Vec3A);
46impl_lerp_vecn!(Vec4);
47
48impl Lerp for Quat {
49 fn lerp(self, other: Self, ratio: f32) -> Self {
50 self.slerp(other, ratio)
55 }
56}
57
/// A single key point of a [`Gradient`]: a value pinned at a given ratio.
#[derive(Debug, Default, Clone, Copy, PartialEq, Reflect, Serialize, Deserialize)]
pub struct GradientKey<T: Lerp + FromReflect> {
    /// Position of the key along the gradient.
    ///
    /// Private so that code outside this module cannot move a key and break
    /// the ordering of a gradient's keys; read it with
    /// [`GradientKey::ratio()`].
    ratio: f32,

    /// Value of the gradient at this key's ratio.
    pub value: T,
}
70
71impl<T: Lerp + FromReflect> GradientKey<T> {
72 pub fn ratio(&self) -> f32 {
74 self.ratio
75 }
76}
77
78impl Hash for GradientKey<f32> {
79 fn hash<H: Hasher>(&self, state: &mut H) {
80 FloatOrd(self.ratio).hash(state);
81 FloatOrd(self.value).hash(state);
82 }
83}
84
85impl Hash for GradientKey<Vec2> {
86 fn hash<H: Hasher>(&self, state: &mut H) {
87 FloatOrd(self.ratio).hash(state);
88 FloatOrd(self.value.x).hash(state);
89 FloatOrd(self.value.y).hash(state);
90 }
91}
92
93impl Hash for GradientKey<Vec3> {
94 fn hash<H: Hasher>(&self, state: &mut H) {
95 FloatOrd(self.ratio).hash(state);
96 FloatOrd(self.value.x).hash(state);
97 FloatOrd(self.value.y).hash(state);
98 FloatOrd(self.value.z).hash(state);
99 }
100}
101
102impl Hash for GradientKey<Vec4> {
103 fn hash<H: Hasher>(&self, state: &mut H) {
104 FloatOrd(self.ratio).hash(state);
105 FloatOrd(self.value.x).hash(state);
106 FloatOrd(self.value.y).hash(state);
107 FloatOrd(self.value.z).hash(state);
108 FloatOrd(self.value.w).hash(state);
109 }
110}
111
/// A gradient: an ordered list of keys, each associating a value with a
/// ratio, which can be sampled at any ratio by interpolating between keys.
#[derive(Debug, Default, Clone, PartialEq, Reflect, Serialize, Deserialize)]
pub struct Gradient<T: Lerp + FromReflect> {
    /// Keys sorted by ascending ratio. Keys sharing the same ratio keep the
    /// order in which they were added.
    keys: Vec<GradientKey<T>>,
}
134
135#[allow(clippy::derived_hash_with_manual_eq)]
138impl<T> Hash for Gradient<T>
139where
140 T: Default + Lerp + FromReflect,
141 GradientKey<T>: Hash,
142{
143 fn hash<H: Hasher>(&self, state: &mut H) {
144 self.keys.hash(state);
145 }
146}
147
148impl<T: Lerp + FromReflect> Gradient<T> {
149 pub const fn new() -> Self {
159 Self { keys: vec![] }
160 }
161
162 pub fn constant(value: T) -> Self {
176 Self {
177 keys: vec![GradientKey::<T> { ratio: 0., value }],
178 }
179 }
180
181 pub fn linear(start: T, end: T) -> Self {
195 Self {
196 keys: vec![
197 GradientKey::<T> {
198 ratio: 0.,
199 value: start,
200 },
201 GradientKey::<T> {
202 ratio: 1.,
203 value: end,
204 },
205 ],
206 }
207 }
208
209 pub fn from_keys(keys: impl IntoIterator<Item = (f32, T)>) -> Self {
240 let mut keys = keys
243 .into_iter()
244 .map(|(ratio, value)| GradientKey { ratio, value })
245 .collect::<Vec<_>>();
246 keys.sort_by(|a, b| FloatOrd(a.ratio).cmp(&FloatOrd(b.ratio)));
247 Self { keys }
248 }
249
250 pub fn is_empty(&self) -> bool {
263 self.keys.is_empty()
264 }
265
266 pub fn len(&self) -> usize {
277 self.keys.len()
278 }
279
280 pub fn with_key(mut self, ratio: f32, value: T) -> Self {
300 self.add_key(ratio, value);
301 self
302 }
303
304 pub fn with_keys(mut self, keys: impl ExactSizeIterator<Item = (f32, T)>) -> Self {
334 self.keys.reserve(keys.len());
335 for (ratio, value) in keys {
336 self.add_key(ratio, value);
337 }
338 self
339 }
340
341 pub fn add_key(&mut self, ratio: f32, value: T) {
352 assert!(ratio >= 0.0);
353 assert!(ratio <= 1.0);
354 let index = match self
355 .keys
356 .binary_search_by(|key| FloatOrd(key.ratio).cmp(&FloatOrd(ratio)))
357 {
358 Ok(mut index) => {
359 let len = self.keys.len();
363 while index + 1 < len && self.keys[index].ratio == self.keys[index + 1].ratio {
364 index += 1;
365 }
366 index + 1 }
368 Err(upper_index) => upper_index,
369 };
370 self.keys.insert(index, GradientKey { ratio, value });
371 }
372
373 pub fn keys(&self) -> &[GradientKey<T>] {
375 &self.keys[..]
376 }
377
378 pub fn keys_mut(&mut self) -> &mut [GradientKey<T>] {
380 &mut self.keys[..]
381 }
382
383 pub fn sample(&self, ratio: f32) -> T {
395 assert!(!self.keys.is_empty());
396 match self
397 .keys
398 .binary_search_by(|key| FloatOrd(key.ratio).cmp(&FloatOrd(ratio)))
399 {
400 Ok(mut index) => {
401 while index > 0 && self.keys[index - 1].ratio == self.keys[index].ratio {
404 index -= 1;
405 }
406 self.keys[index].value
407 }
408 Err(upper_index) => {
409 if upper_index > 0 {
410 if upper_index < self.keys.len() {
411 let key0 = &self.keys[upper_index - 1];
412 let key1 = &self.keys[upper_index];
413 let t = (ratio - key0.ratio) / (key1.ratio - key0.ratio);
414 key0.value.lerp(key1.value, t)
415 } else {
416 self.keys[upper_index - 1].value
418 }
419 } else {
420 self.keys[upper_index].value
422 }
423 }
424 }
425 }
426
427 pub fn sample_by(&self, start: f32, inc: f32, dst: &mut [T]) {
439 let count = dst.len();
440 assert!(!self.keys.is_empty());
441 let mut ratio = start;
442 let first_ratio = self.keys[0].ratio;
444 let first_col = self.keys[0].value;
445 let mut idst = 0;
446 while idst < count && ratio <= first_ratio {
447 dst[idst] = first_col;
448 idst += 1;
449 ratio += inc;
450 }
451 let mut ikey = 1;
453 let len = self.keys.len();
454 for i in idst..count {
455 while ikey < len && ratio > self.keys[ikey].ratio {
457 ikey += 1;
458 }
459 if ikey >= len {
460 let last_col = self.keys[len - 1].value;
462 for d in &mut dst[i..] {
463 *d = last_col;
464 }
465 return;
466 }
467 if self.keys[ikey].ratio == ratio {
468 dst[i] = self.keys[ikey].value;
469 } else {
470 let k0 = &self.keys[ikey - 1];
471 let k1 = &self.keys[ikey];
472 let t = (ratio - k0.ratio) / (k1.ratio - k0.ratio);
473 dst[i] = k0.value.lerp(k1.value, t);
474 }
475 ratio += inc;
476 }
477 }
478}
479
// Unit tests for `Lerp`, `Gradient` construction and sampling, reflection,
// serde round-tripping and hashing.
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;

    use bevy::reflect::{PartialReflect, ReflectRef, Struct};
    use rand::{distr::StandardUniform, prelude::Distribution, rng, rngs::ThreadRng, Rng};

    use super::*;
    use crate::test_utils::*;

    const RED: Vec4 = Vec4::new(1., 0., 0., 1.);
    const BLUE: Vec4 = Vec4::new(0., 0., 1., 1.);
    const GREEN: Vec4 = Vec4::new(0., 1., 0., 1.);

    // Three keys, two of which share ratio 0.8 (BLUE inserted before GREEN).
    fn make_test_gradient() -> Gradient<Vec4> {
        let mut g = Gradient::new();
        g.add_key(0.5, RED);
        g.add_key(0.8, BLUE);
        g.add_key(0.8, GREEN);
        g
    }

    // Component-wise comparison: every absolute difference must be below `tol`.
    fn color_approx_eq(c0: Vec4, c1: Vec4, tol: f32) -> bool {
        ((c0.x - c1.x).abs() < tol)
            && ((c0.y - c1.y).abs() < tol)
            && ((c0.z - c1.z).abs() < tol)
            && ((c0.w - c1.w).abs() < tol)
    }

    // Scalar lerp in both directions, and Quat lerp matching `slerp`. Swapping
    // the Quat endpoints mirrors the ratio (t -> 1 - t).
    #[test]
    fn lerp_test() {
        assert_approx_eq!(Lerp::lerp(3_f32, 5_f32, 0.1), 3.2_f32);
        assert_approx_eq!(Lerp::lerp(3_f32, 5_f32, 0.5), 4.0_f32);
        assert_approx_eq!(Lerp::lerp(3_f32, 5_f32, 0.9), 4.8_f32);
        assert_approx_eq!(Lerp::lerp(5_f32, 3_f32, 0.1), 4.8_f32);
        assert_approx_eq!(Lerp::lerp(5_f32, 3_f32, 0.5), 4.0_f32);
        assert_approx_eq!(Lerp::lerp(5_f32, 3_f32, 0.9), 3.2_f32);

        assert_approx_eq!(Lerp::lerp(3_f64, 5_f64, 0.1), 3.2_f64);
        assert_approx_eq!(Lerp::lerp(3_f64, 5_f64, 0.5), 4.0_f64);
        assert_approx_eq!(Lerp::lerp(3_f64, 5_f64, 0.9), 4.8_f64);
        assert_approx_eq!(Lerp::lerp(5_f64, 3_f64, 0.1), 4.8_f64);
        assert_approx_eq!(Lerp::lerp(5_f64, 3_f64, 0.5), 4.0_f64);
        assert_approx_eq!(Lerp::lerp(5_f64, 3_f64, 0.9), 3.2_f64);

        let s = Quat::IDENTITY;
        let e = Quat::from_rotation_x(90_f32.to_radians());
        assert_approx_eq!(Lerp::lerp(s, e, 0.1), s.slerp(e, 0.1));
        assert_approx_eq!(Lerp::lerp(s, e, 0.5), s.slerp(e, 0.5));
        assert_approx_eq!(Lerp::lerp(s, e, 0.9), s.slerp(e, 0.9));
        assert_approx_eq!(Lerp::lerp(e, s, 0.1), s.slerp(e, 0.9));
        assert_approx_eq!(Lerp::lerp(e, s, 0.5), s.slerp(e, 0.5));
        assert_approx_eq!(Lerp::lerp(e, s, 0.9), s.slerp(e, 0.1));
    }

    // A constant gradient returns its value everywhere, including ratios far
    // outside [0:1].
    #[test]
    fn constant() {
        let grad = Gradient::constant(3.0);
        assert!(!grad.is_empty());
        assert_eq!(grad.len(), 1);
        for r in [
            -1e5, -0.5, -0.0001, 0., 0.0001, 0.3, 0.5, 0.9, 0.9999, 1., 1.0001, 100., 1e5,
        ] {
            assert_approx_eq!(grad.sample(r), 3.0);
        }
    }

    // `with_keys` appends: new keys at an existing ratio go after the old ones,
    // so the later key at each ratio governs the segment that follows it.
    #[test]
    fn with_keys() {
        let g = Gradient::new().with_keys([(0.5, RED), (0.8, BLUE)].into_iter());
        assert_eq!(g.len(), 2);
        let g2 = g.with_keys([(0.5, BLUE), (0.8, RED)].into_iter());
        assert_eq!(g2.len(), 4);
        assert_eq!(g2.sample(0.499), RED);
        assert_eq!(g2.sample(0.501), BLUE);
        assert_eq!(g2.sample(0.799), BLUE);
        assert_eq!(g2.sample(0.801), RED);
    }

    // Keys with duplicate ratios are kept in insertion order.
    #[test]
    fn add_key() {
        let mut g = Gradient::new();
        assert!(g.is_empty());
        assert_eq!(g.len(), 0);
        g.add_key(0.3, RED);
        assert!(!g.is_empty());
        assert_eq!(g.len(), 1);
        let mut g = g.with_key(0.3, RED);
        assert_eq!(g.len(), 2);
        g.add_key(0.7, BLUE);
        g.add_key(0.7, GREEN);
        assert_eq!(g.len(), 4);
        let keys = g.keys();
        assert_eq!(keys.len(), 4);
        assert!(color_approx_eq(RED, keys[0].value, 1e-5));
        assert!(color_approx_eq(RED, keys[1].value, 1e-5));
        assert!(color_approx_eq(BLUE, keys[2].value, 1e-5));
        assert!(color_approx_eq(GREEN, keys[3].value, 1e-5));
    }

    // Clamping outside the keys, interpolation between keys, and duplicate
    // ratios: exactly on 0.8 gives the first key (BLUE), just after gives the
    // last one (GREEN).
    #[test]
    fn sample() {
        let mut g = Gradient::new();
        g.add_key(0.5, RED);
        assert_eq!(RED, g.sample(0.0));
        assert_eq!(RED, g.sample(0.5));
        assert_eq!(RED, g.sample(1.0));
        g.add_key(0.8, BLUE);
        g.add_key(0.8, GREEN);
        assert_eq!(RED, g.sample(0.0));
        assert_eq!(RED, g.sample(0.499));
        assert_eq!(RED, g.sample(0.5));
        let expected = RED.lerp(BLUE, 1. / 3.);
        let actual = g.sample(0.6);
        assert!(color_approx_eq(actual, expected, 1e-5));
        assert_eq!(BLUE, g.sample(0.8));
        assert_eq!(GREEN, g.sample(0.801));
        assert_eq!(GREEN, g.sample(1.0));
    }

    // `sample_by` must agree with `sample` at positions start + i * inc.
    #[test]
    fn sample_by() {
        let g = Gradient::from_keys([(0.5, RED), (0.8, BLUE)]);
        const COUNT: usize = 256;
        let mut data: [Vec4; COUNT] = [Vec4::ZERO; COUNT];
        let start = 0.;
        let inc = 1. / COUNT as f32;
        g.sample_by(start, inc, &mut data[..]);
        for (i, &d) in data.iter().enumerate() {
            let ratio = inc.mul_add(i as f32, start);
            let expected = g.sample(ratio);
            assert!(color_approx_eq(expected, d, 1e-5));
        }
    }

    // Reflection: type info, downcasting, `from_reflect`, and field-level
    // access to the private `keys` list. The type path check pins the module
    // path `bevy_hanabi::gradient`.
    #[test]
    fn reflect() {
        let g = make_test_gradient();

        let reflect: &dyn PartialReflect = &g;
        assert!(reflect
            .get_represented_type_info()
            .unwrap()
            .is::<Gradient<Vec4>>());
        let g_reflect = reflect.try_downcast_ref::<Gradient<Vec4>>();
        assert!(g_reflect.is_some());
        let g_reflect = g_reflect.unwrap();
        assert_eq!(*g_reflect, g);

        let g_from = Gradient::<Vec4>::from_reflect(reflect).unwrap();
        assert_eq!(g_from, g);

        assert!(g
            .get_represented_type_info()
            .unwrap()
            .type_path()
            .starts_with("bevy_hanabi::gradient::Gradient<"));
        let keys = g.field("keys").unwrap();
        let ReflectRef::List(keys) = keys.reflect_ref() else {
            panic!("Invalid type");
        };
        assert_eq!(keys.len(), 3);
        for (i, (r, v)) in [(0.5, RED), (0.8, BLUE), (0.8, GREEN)].iter().enumerate() {
            let k = keys.get(i).unwrap();
            let gk = k.try_downcast_ref::<GradientKey<Vec4>>().unwrap();
            assert_approx_eq!(gk.ratio(), r);
            assert_approx_eq!(gk.value, v);

            let ReflectRef::Struct(k) = k.reflect_ref() else {
                panic!("Invalid type");
            };
            assert!(k
                .get_represented_type_info()
                .unwrap()
                .type_path()
                .contains("GradientKey"));
        }
    }

    // RON serialization round-trip preserves the gradient exactly.
    #[test]
    fn serde() {
        let g = make_test_gradient();

        let s = ron::to_string(&g).unwrap();
        let g_serde: Gradient<Vec4> = ron::from_str(&s).unwrap();
        assert_eq!(g, g_serde);
    }

    // Hash a gradient with a fresh `DefaultHasher`.
    fn hash_gradient<T>(g: &Gradient<T>) -> u64
    where
        T: Default + Lerp + FromReflect,
        GradientKey<T>: Hash,
    {
        let mut hasher = DefaultHasher::default();
        g.hash(&mut hasher);
        hasher.finish()
    }

    // Build `count` keys with random values and ratios evenly spaced over
    // [0:1]. A single key is placed at ratio 0.
    fn make_keys<R, T, S>(rng: &mut R, count: usize) -> Vec<(f32, T)>
    where
        R: Rng + ?Sized,
        T: Lerp + FromReflect + From<S>,
        StandardUniform: Distribution<S>,
    {
        if count == 0 {
            return vec![];
        }
        if count == 1 {
            return vec![(0., rng.random().into())];
        }
        let mut ret = Vec::with_capacity(count);
        for i in 0..count {
            ret.push((i as f32 / (count - 1) as f32, rng.random().into()));
        }
        ret
    }

    // For each hashable key type, equal gradients must hash equal, and changing
    // the first key's value must change both equality and the hash. The
    // `assert_ne!` on hashes assumes no 64-bit collision between the two
    // gradients.
    #[test]
    fn hash() {
        let mut thread_rng = rng();
        for count in 0..10 {
            let keys: Vec<(f32, f32)> = make_keys::<ThreadRng, f32, f32>(&mut thread_rng, count);
            let mut g1 = Gradient::new().with_keys(keys.into_iter());
            let g2 = g1.clone();
            assert_eq!(g1, g2);
            assert_eq!(hash_gradient(&g1), hash_gradient(&g2));
            if count > 0 {
                g1.keys_mut()[0].value += 1.;
                assert_ne!(g1, g2);
                assert_ne!(hash_gradient(&g1), hash_gradient(&g2));
                g1.keys_mut()[0].value = g2.keys()[0].value;
                assert_eq!(g1, g2);
                assert_eq!(hash_gradient(&g1), hash_gradient(&g2));
            }
        }

        let mut thread_rng = rng();
        for count in 0..10 {
            let keys: Vec<(f32, Vec2)> =
                make_keys::<ThreadRng, Vec2, (f32, f32)>(&mut thread_rng, count);
            let mut g1 = Gradient::new().with_keys(keys.into_iter());
            let g2 = g1.clone();
            assert_eq!(g1, g2);
            assert_eq!(hash_gradient(&g1), hash_gradient(&g2));
            if count > 0 {
                g1.keys_mut()[0].value += 1.;
                assert_ne!(g1, g2);
                assert_ne!(hash_gradient(&g1), hash_gradient(&g2));
                g1.keys_mut()[0].value = g2.keys()[0].value;
                assert_eq!(g1, g2);
                assert_eq!(hash_gradient(&g1), hash_gradient(&g2));
            }
        }

        let mut thread_rng = rng();
        for count in 0..10 {
            let keys: Vec<(f32, Vec3)> =
                make_keys::<ThreadRng, Vec3, (f32, f32, f32)>(&mut thread_rng, count);
            let mut g1 = Gradient::new().with_keys(keys.into_iter());
            let g2 = g1.clone();
            assert_eq!(g1, g2);
            assert_eq!(hash_gradient(&g1), hash_gradient(&g2));
            if count > 0 {
                g1.keys_mut()[0].value += 1.;
                assert_ne!(g1, g2);
                assert_ne!(hash_gradient(&g1), hash_gradient(&g2));
                g1.keys_mut()[0].value = g2.keys()[0].value;
                assert_eq!(g1, g2);
                assert_eq!(hash_gradient(&g1), hash_gradient(&g2));
            }
        }

        let mut thread_rng = rng();
        for count in 0..10 {
            let keys: Vec<(f32, Vec4)> =
                make_keys::<ThreadRng, Vec4, (f32, f32, f32, f32)>(&mut thread_rng, count);
            let mut g1 = Gradient::new().with_keys(keys.into_iter());
            let g2 = g1.clone();
            assert_eq!(g1, g2);
            assert_eq!(hash_gradient(&g1), hash_gradient(&g2));
            if count > 0 {
                g1.keys_mut()[0].value += 1.;
                assert_ne!(g1, g2);
                assert_ne!(hash_gradient(&g1), hash_gradient(&g2));
                g1.keys_mut()[0].value = g2.keys()[0].value;
                assert_eq!(g1, g2);
                assert_eq!(hash_gradient(&g1), hash_gradient(&g2));
            }
        }
    }
}