1use crate::error::{Causality, NegativeWeight};
67use core::{
68 cell::Cell,
69 cmp::Ordering,
70 convert::Infallible,
71 fmt,
72 num::{NonZero, Wrapping},
73 ops::{Add, Sub},
74};
75
76pub struct RandomVariable<S: Weighable = f64, const M: usize = 2> {
114 moments: CentralMoments<M>,
116 inner: Cell<Option<Inner<S::Sample, S::Weight>>>,
118}
119
/// Bookkeeping carried by a random variable once it holds at least one
/// tabulation.
#[derive(Clone, Copy, Eq, PartialEq)]
struct Inner<S, W> {
    /// Number of tabulations; never zero by construction.
    count: NonZero<usize>,
    /// Smallest sample seen (the earliest timestamp for temporal observations).
    min: S,
    /// Largest sample seen (the latest timestamp for temporal observations).
    max: S,
    /// Weight bookkeeping whose shape is chosen by the observation type.
    weight: W,
}
133
/// Storage for the first `N` accumulated moments.
///
/// Slot 0 is the mean; slots 1.. hold the sums of the 2nd, 3rd and 4th powers
/// of deviations from the mean (divided by the weight only when a statistic
/// is requested).
#[derive(Clone, PartialEq)]
pub struct CentralMoments<const N: usize>(
    [Cell<f64>; N],
);
141
142impl<S: Weighable> RandomVariable<S> {
143 pub const fn new<const M: usize>() -> RandomVariable<S, M> {
145 RandomVariable {
146 moments: CentralMoments([const { Cell::new(0.0) }; M]),
147 inner: Cell::new(None),
148 }
149 }
150}
151
152impl<S: Weighable, const M: usize> RandomVariable<S, M> {
153 pub fn reset(&self) {
155 for m in &self.moments.0 {
156 m.set(0.0);
157 }
158 self.inner.set(None);
159 }
160
161 pub fn tabulate(&self, value: S)
163 where
164 CentralMoments<M>: Univariate,
165 {
166 assert!(self.try_tabulate(value).is_ok());
167 }
168
169 pub fn try_tabulate(&self, value: S) -> Result<(), S::Error>
172 where
173 CentralMoments<M>: Univariate,
174 {
175 value.sample(self)
176 }
177
178 pub fn join(self, other: Self) -> Self
182 where
183 CentralMoments<M>: Univariate,
184 {
185 let wl = self.weight();
186 let wr = other.weight();
187
188 match self.inner.take() {
190 Some(li) => {
191 match other.inner.into_inner() {
193 Some(ri) => {
194 let count = li
196 .count
197 .checked_add(ri.count.get())
198 .expect("overflow in the number of tabulations detected");
199
200 RandomVariable {
201 moments: self.moments.combine(other.moments, wl, wr),
202 inner: Cell::new(Some(Inner {
203 count,
204 weight: S::combine(li.weight, ri.weight),
205 min: if li.min < ri.min { li.min } else { ri.min },
206 max: if li.max > ri.max { li.max } else { ri.max },
207 })),
208 }
209 }
210 _ => {
211 self.inner.set(Some(li));
213 self
214 }
215 }
216 }
217 _ => {
218 other
220 }
221 }
222 }
223
224 pub fn count(&self) -> usize {
226 let inner = self.inner.take();
227 let res = inner.as_ref().map_or(0, |inner| inner.count.get());
228 self.inner.set(inner);
229 res
230 }
231
232 pub fn min(&self) -> Option<S::Sample> {
235 let inner = self.inner.take();
236 let res = inner.as_ref().map(|inner| inner.min.clone());
237 self.inner.set(inner);
238 res
239 }
240
241 pub fn max(&self) -> Option<S::Sample> {
244 let inner = self.inner.take();
245 let res = inner.as_ref().map(|inner| inner.max.clone());
246 self.inner.set(inner);
247 res
248 }
249
250 pub fn weight(&self) -> f64 {
252 S::weight(self)
253 }
254
255 pub fn sum(&self) -> S::Sample
257 where
258 CentralMoments<M>: Mean,
259 {
260 let weight = self.weight();
261
262 if weight == 0.0 {
263 return Sample::qualify(0.0);
264 }
265
266 Sample::qualify(self.moments.mean() * weight)
267 }
268
269 pub fn mean(&self) -> f64
273 where
274 CentralMoments<M>: Mean,
275 {
276 if self.weight() == 0.0 {
278 return f64::NAN;
279 }
280
281 self.moments.mean()
283 }
284
285 pub fn variance(&self) -> f64
295 where
296 CentralMoments<M>: Variance,
297 {
298 let w = self.weight();
299
300 if w <= 1.0 {
302 return f64::NAN;
303 }
304
305 let m2 = self.moments.m2().get();
307 if m2 * m2 <= f64::EPSILON * self.mean() {
308 return 0.0;
309 }
310
311 self.moments.sample_variance(w)
313 }
314
315 #[cfg(any(feature = "std", feature = "libm"))]
326 pub fn std_dev(&self) -> f64
327 where
328 CentralMoments<M>: Variance,
329 {
330 let w = self.weight();
331
332 if w <= 1.0 {
334 return f64::NAN;
335 }
336
337 let m2 = self.moments.m2().get();
339 if m2 * m2 <= f64::EPSILON * self.mean() {
340 return 0.0;
341 }
342
343 sqrt(self.moments.sample_variance(w))
345 }
346
347 #[cfg(any(feature = "std", feature = "libm"))]
357 pub fn skew(&self) -> f64
358 where
359 CentralMoments<M>: Skewness,
360 {
361 let w = self.weight();
362 let m2 = self.moments.m2().get();
363
364 if w <= 2.0 || m2 * m2 <= f64::EPSILON * self.mean() {
366 return f64::NAN;
367 }
368
369 self.moments.sample_skew(w)
371 }
372
373 pub fn kurtosis(&self) -> f64
383 where
384 CentralMoments<M>: Kurtosis,
385 {
386 let w = self.weight();
387 let m2 = self.moments.m2().get();
388
389 if w <= 3.0 || m2 * m2 <= f64::EPSILON * self.mean() {
391 return f64::NAN;
392 }
393
394 self.moments.sample_kurtosis(self.weight())
396 }
397
398 pub fn excess_kurtosis(&self) -> f64
408 where
409 CentralMoments<M>: Kurtosis,
410 {
411 let w = self.weight();
412 let m2 = self.moments.m2().get();
413
414 if w <= 3.0 || m2 * m2 <= f64::EPSILON * self.mean() {
416 return f64::NAN;
417 }
418
419 self.moments.sample_excess_kurtosis(self.weight())
421 }
422
423 pub fn population_variance(&self) -> f64
434 where
435 CentralMoments<M>: Variance,
436 {
437 let w = self.weight();
438
439 let m2 = self.moments.m2().get();
441 if m2 * m2 <= f64::EPSILON * self.mean() {
442 return 0.0;
443 }
444
445 self.moments.population_variance(w)
447 }
448
449 #[cfg(any(feature = "std", feature = "libm"))]
460 pub fn population_std_dev(&self) -> f64
461 where
462 CentralMoments<M>: Variance,
463 {
464 let w = self.weight();
465
466 let m2 = self.moments.m2().get();
468 if m2 * m2 <= f64::EPSILON * self.mean() {
469 return 0.0;
470 }
471
472 sqrt(self.moments.population_variance(w))
474 }
475
476 #[cfg(any(feature = "std", feature = "libm"))]
487 pub fn population_skew(&self) -> f64
488 where
489 CentralMoments<M>: Skewness,
490 {
491 let w = self.weight();
492 let m2 = self.moments.m2().get();
493
494 if m2 * m2 <= f64::EPSILON * self.mean() {
496 return f64::NAN;
497 }
498
499 self.moments.population_skew(w)
501 }
502
503 pub fn population_kurtosis(&self) -> f64
514 where
515 CentralMoments<M>: Kurtosis,
516 {
517 let w = self.weight();
518 let m2 = self.moments.m2().get();
519
520 if m2 * m2 <= f64::EPSILON * self.mean() {
522 return f64::NAN;
523 }
524
525 self.moments.population_kurtosis(w)
527 }
528
529 pub fn population_excess_kurtosis(&self) -> f64
542 where
543 CentralMoments<M>: Kurtosis,
544 {
545 let w = self.weight();
546 let m2 = self.moments.m2().get();
547
548 if m2 * m2 <= f64::EPSILON * self.mean() {
550 return f64::NAN;
551 }
552
553 self.moments.population_excess_kurtosis(w)
555 }
556}
557
558impl<T: Sample, S: Sample, const M: usize> RandomVariable<Utilized<T, S>, M>
560where
561 T: Add<Output = T> + Sub<Output = T>,
562{
563 pub fn tracked(&self) -> Option<S> {
565 let inner = self.inner.take();
566 let res = inner.as_ref().map(|inner| inner.weight.2.clone());
567 self.inner.set(inner);
568 res
569 }
570
571 pub fn update(&self, time: T)
575 where
576 CentralMoments<M>: Univariate,
577 {
578 assert!(
579 self.try_update(time).is_ok(),
580 "updating the time to before the previous tabulation is not permitted"
581 )
582 }
583
584 pub fn try_update(&self, time: T) -> Result<(), Causality<T>>
588 where
589 CentralMoments<M>: Univariate,
590 {
591 let mut result = Ok(());
592
593 if let Some(mut inner) = self.inner.take() {
594 match time.partial_cmp(&inner.weight.1) {
595 None | Some(Ordering::Less) => {
596 result = Err(Causality {
597 cause: inner.weight.1.clone(),
598 effect: time,
599 });
600 }
601 Some(Ordering::Equal) => {}
602 Some(Ordering::Greater) => {
603 self.moments.sample_w(
605 inner.weight.2.quantify(),
606 (time.clone() - inner.weight.1).quantify(),
607 (time.clone() - inner.weight.0.clone()).quantify(),
608 );
609
610 inner.max = time.clone();
612
613 inner.weight.1 = time;
615 }
616 }
617
618 self.inner.set(Some(inner));
620 }
621 result
624 }
625}
626
/// A value that can be tabulated and converted to and from `f64`.
///
/// The ordering is used to track the minimum and maximum observation.
pub trait Sample: Clone + PartialOrd {
    /// Converts the value to `f64` for the moment computations.
    fn quantify(&self) -> f64;

    /// Builds a value of this type from an `f64` (used, for instance, by
    /// `RandomVariable::sum`).
    fn qualify(val: f64) -> Self;
}
647
648pub trait Weighable: Clone + Sized {
653 type Sample: Sample;
655 type Weight: Clone;
657 type Error;
659
660 fn sample<const M: usize>(self, rv: &RandomVariable<Self, M>) -> Result<(), Self::Error>
663 where
664 CentralMoments<M>: Univariate;
665
666 fn weight<const M: usize>(rv: &RandomVariable<Self, M>) -> f64;
668
669 fn combine(weight: Self::Weight, other: Self::Weight) -> Self::Weight;
671}
672
673impl<S: Sample> Weighable for S {
674 type Sample = S;
675 type Weight = ();
676 type Error = Infallible;
677
678 fn sample<const M: usize>(self, rv: &RandomVariable<Self, M>) -> Result<(), Self::Error>
679 where
680 CentralMoments<M>: Univariate,
681 {
682 rv.inner.set(Some(match rv.inner.replace(None) {
683 Some(mut inner) => {
684 inner.count = inner
686 .count
687 .checked_add(1)
688 .expect("overflow in the number of tabulations detected");
689
690 if inner.min > self {
692 inner.min = self.clone();
693 }
694 if inner.max < self {
695 inner.max = self.clone();
696 }
697
698 rv.moments.sample_u(self.quantify(), inner.count.get());
700
701 inner
702 }
703 _ => {
704 rv.moments.sample_u(self.quantify(), 1);
706 Inner {
707 count: NonZero::new(1).unwrap(),
708 min: self.clone(),
709 max: self,
710 weight: (),
711 }
712 }
713 }));
714
715 Ok(())
716 }
717
718 fn weight<const M: usize>(rv: &RandomVariable<Self, M>) -> f64 {
719 rv.count() as f64
720 }
721
722 fn combine(_: Self::Weight, _: Self::Weight) -> Self::Weight {}
723}
724
/// An observation `S` carrying an explicit non-negative weight `W`.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Weighted<S, W = f64>(
    /// The observed value.
    pub S,
    /// The weight of the observation.
    pub W,
);
733
734impl<S: Sample, W: Sample + Add<Output = W>> Weighable for Weighted<S, W> {
735 type Sample = S;
736 type Weight = W;
737 type Error = NegativeWeight<W>;
738
739 fn sample<const M: usize>(self, rv: &RandomVariable<Self, M>) -> Result<(), Self::Error>
740 where
741 CentralMoments<M>: Univariate,
742 {
743 let Self(value, weight) = self;
744
745 let w = weight.quantify();
747
748 if w.partial_cmp(&0.0).is_none_or(Ordering::is_lt) {
750 return Err(NegativeWeight { weight });
751 }
752
753 rv.inner.set(Some(match rv.inner.replace(None) {
754 Some(mut inner) => {
755 inner.count = inner
757 .count
758 .checked_add(1)
759 .expect("overflow in the number of tabulations detected");
760 inner.weight = inner.weight + weight;
761
762 if inner.min > value {
764 inner.min = value.clone();
765 }
766 if inner.max < value {
767 inner.max = value.clone();
768 }
769
770 rv.moments
772 .sample_w(value.quantify(), w, inner.weight.quantify());
773
774 inner
775 }
776 _ => {
777 rv.moments.sample_w(value.quantify(), w, w);
779 Inner {
780 count: const { NonZero::new(1).unwrap() },
781 min: value.clone(),
782 max: value.clone(),
783 weight,
784 }
785 }
786 }));
787
788 Ok(())
789 }
790
791 fn weight<const M: usize>(rv: &RandomVariable<Self, M>) -> f64 {
792 let inner = rv.inner.take();
793 let w = inner.as_ref().map_or(0.0, |inner| inner.weight.quantify());
794 rv.inner.set(inner);
795 w
796 }
797
798 fn combine(weight: Self::Weight, other: Self::Weight) -> Self::Weight {
799 weight + other
800 }
801}
802
/// A value `S` observed at time `T`.
///
/// Each value is weighted by how long it was held, i.e. the time until the
/// next observation or clock update.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Utilized<T, S = bool>(
    /// The timestamp of the observation.
    pub T,
    /// The value that holds from this timestamp onwards.
    pub S,
);
813
814impl<T: Sample, S: Sample> Weighable for Utilized<T, S>
815where
816 T: Add<Output = T> + Sub<Output = T>,
817{
818 type Sample = T;
819 type Weight = (T, T, S);
820 type Error = Causality<T>;
821
822 fn sample<const M: usize>(self, rv: &RandomVariable<Self, M>) -> Result<(), Self::Error>
823 where
824 CentralMoments<M>: Univariate,
825 {
826 use core::mem::replace;
827 let Self(time, value) = self;
828 let mut result = Ok(());
829
830 let inner = match rv.inner.take() {
831 Some(mut inner) => {
832 match time.partial_cmp(&inner.weight.1) {
833 None | Some(Ordering::Less) => {
835 result = Err(Causality {
836 cause: inner.weight.1.clone(),
837 effect: time,
838 });
839 }
840 Some(Ordering::Equal) => {
841 inner.weight.2 = value;
843 }
844 Some(Ordering::Greater) => {
845 inner.count = inner
847 .count
848 .checked_add(1)
849 .expect("overflow in the number of tabulations detected");
850
851 inner.max = time.clone();
853
854 rv.moments.sample_w(
856 replace(&mut inner.weight.2, value).quantify(),
857 (time.clone() - inner.weight.1.clone()).quantify(),
858 (time.clone() - inner.weight.0.clone()).quantify(),
859 );
860
861 inner.weight.1 = time;
863 }
864 }
865
866 inner
867 }
868 _ => {
869 Inner {
871 count: const { NonZero::new(1).unwrap() },
872 min: time.clone(),
873 max: time.clone(),
874 weight: (time.clone(), time, value),
875 }
876 }
877 };
878
879 rv.inner.set(Some(inner));
880 result
881 }
882
883 fn weight<const M: usize>(rv: &RandomVariable<Self, M>) -> f64 {
884 let inner = rv.inner.take();
885 let w = inner.as_ref().map_or(0.0, |inner| {
886 (inner.weight.1.clone() - inner.weight.0.clone()).quantify()
887 });
888 rv.inner.set(inner);
889 w
890 }
891
892 fn combine(weight: Self::Weight, other: Self::Weight) -> Self::Weight {
893 (weight.0, weight.1 + (other.1 - other.0), other.2)
894 }
895}
896
897impl<S: Weighable, const M: usize> Default for RandomVariable<S, M> {
900 fn default() -> Self {
901 RandomVariable::<S>::new::<M>()
902 }
903}
904
905impl<S: Weighable, const M: usize> Clone for RandomVariable<S, M> {
906 fn clone(&self) -> Self {
907 let inner = self.inner.take();
908 self.inner.set(inner.clone());
909
910 RandomVariable {
911 moments: self.moments.clone(),
912 inner: Cell::new(inner),
913 }
914 }
915}
916
917impl<S1, S2> PartialEq<RandomVariable<S2>> for RandomVariable<S1>
918where
919 S1: Weighable,
920 S2: Weighable<Sample = S1::Sample>,
921{
922 fn eq(&self, other: &RandomVariable<S2>) -> bool {
923 self.moments.eq(&other.moments)
924 }
925}
926
927impl<S: Weighable> fmt::Debug for RandomVariable<S, 0>
928where
929 S::Sample: fmt::Debug,
930{
931 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
932 let mut dbs = f.debug_struct("RandomVariable");
933 let len = self.count();
934
935 dbs.field("n", &len);
936
937 if let Some(inner) = self.inner.take() {
938 dbs.field("min", &inner.min).field("max", &inner.max);
939 self.inner.set(Some(inner));
940 }
941
942 dbs.finish()
943 }
944}
945
946impl<S: Weighable> fmt::Debug for RandomVariable<S, 1>
947where
948 S::Sample: fmt::Debug,
949{
950 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
951 let mut dbs = f.debug_struct("RandomVariable");
952 let len = self.count();
953
954 dbs.field("n", &len);
955
956 if let Some(inner) = self.inner.take() {
957 dbs.field("min", &inner.min).field("max", &inner.max);
958 self.inner.set(Some(inner));
959 dbs.field("mean", &self.mean());
960 }
961
962 dbs.finish()
963 }
964}
965
966impl<S: Weighable> fmt::Debug for RandomVariable<S, 2>
967where
968 S::Sample: fmt::Debug,
969{
970 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
971 let mut dbs = f.debug_struct("RandomVariable");
972 let len = self.count();
973
974 dbs.field("n", &len);
975
976 if let Some(inner) = self.inner.take() {
977 dbs.field("min", &inner.min).field("max", &inner.max);
978 self.inner.set(Some(inner));
979 dbs.field("mean", &self.mean());
980
981 #[cfg(any(feature = "std", feature = "libm"))]
982 dbs.field("sdev", &self.std_dev());
983
984 #[cfg(all(not(feature = "std"), not(feature = "libm")))]
985 dbs.field("variance", &self.variance());
986 }
987
988 dbs.finish()
989 }
990}
991
992impl<S: Weighable> fmt::Debug for RandomVariable<S, 3>
993where
994 S::Sample: fmt::Debug,
995{
996 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
997 let mut dbs = f.debug_struct("RandomVariable");
998 let len = self.count();
999
1000 dbs.field("n", &len);
1001
1002 if let Some(inner) = self.inner.take() {
1003 dbs.field("min", &inner.min).field("max", &inner.max);
1004 self.inner.set(Some(inner));
1005 dbs.field("mean", &self.mean());
1006
1007 #[cfg(any(feature = "std", feature = "libm"))]
1008 dbs.field("sdev", &self.std_dev())
1009 .field("skew", &self.skew());
1010
1011 #[cfg(all(not(feature = "std"), not(feature = "libm")))]
1012 dbs.field("variance", &self.variance());
1013 }
1014
1015 dbs.finish()
1016 }
1017}
1018
1019impl<S: Weighable> fmt::Debug for RandomVariable<S, 4>
1020where
1021 S::Sample: fmt::Debug,
1022{
1023 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1024 let mut dbs = f.debug_struct("RandomVariable");
1025 let len = self.count();
1026
1027 dbs.field("n", &len);
1028
1029 if let Some(inner) = self.inner.take() {
1030 dbs.field("min", &inner.min).field("max", &inner.max);
1031 self.inner.set(Some(inner));
1032 dbs.field("mean", &self.mean());
1033
1034 #[cfg(any(feature = "std", feature = "libm"))]
1035 dbs.field("sdev", &self.std_dev())
1036 .field("skew", &self.skew());
1037
1038 #[cfg(all(not(feature = "std"), not(feature = "libm")))]
1039 dbs.field("variance", &self.variance());
1040
1041 dbs.field("kurt", &self.kurtosis());
1042 }
1043
1044 dbs.finish()
1045 }
1046}
1047
1048macro_rules! impl_primitive_value {
1049 ($($T:ident),*) => {$(
1050 impl Sample for $T {
1051 fn quantify(&self) -> f64 {
1052 *self as f64
1053 }
1054
1055 fn qualify(val: f64) -> Self {
1056 val as $T
1057 }
1058 }
1059 )*}
1060}
1061
1062macro_rules! impl_wrapped_value {
1063 ($($T:ident < $V:ty >),*) => {$(
1064 impl Sample for $T<$V> {
1065 fn quantify(&self) -> f64 {
1066 self.0 as f64
1067 }
1068
1069 fn qualify(val: f64) -> Self {
1070 $T(val as $V)
1071 }
1072 }
1073 )*}
1074}
1075
1076impl_primitive_value!(
1077 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64
1078);
1079
1080impl_wrapped_value!(
1081 Wrapping<i8>,
1082 Wrapping<i16>,
1083 Wrapping<i32>,
1084 Wrapping<i64>,
1085 Wrapping<isize>,
1086 Wrapping<u8>,
1087 Wrapping<u16>,
1088 Wrapping<u32>,
1089 Wrapping<u64>,
1090 Wrapping<usize>
1091);
1092
1093impl Sample for bool {
1094 fn quantify(&self) -> f64 {
1095 *self as u8 as f64
1096 }
1097
1098 fn qualify(val: f64) -> Self {
1099 val != 0.0
1100 }
1101}
1102
/// Incremental update and merge operations over a set of accumulated moments.
#[doc(hidden)]
pub trait Univariate {
    /// Folds in an unweighted observation `x`, where `n` is the count
    /// including `x`.
    fn sample_u(&self, x: f64, n: usize);

    /// Folds in observation `x` with weight `w`, where `s` is the total weight
    /// including `w`.
    fn sample_w(&self, x: f64, w: f64, s: f64);

    /// Merges `other` into `self`, given the total weights `wl` (self) and `wr`
    /// (other).
    fn combine(self, other: Self, wl: f64, wr: f64) -> Self;
}
1118
/// Access to the accumulated mean.
#[doc(hidden)]
pub trait Mean {
    /// The cell holding the running mean.
    fn m1(&self) -> &Cell<f64>;

    /// Returns the current running mean.
    fn mean(&self) -> f64 {
        let cell = self.m1();
        cell.get()
    }
}
1130
1131#[doc(hidden)]
1133pub trait Variance: Mean {
1134 fn m2(&self) -> &Cell<f64>;
1136
1137 fn sample_variance(&self, w: f64) -> f64 {
1140 let m2 = self.m2().get();
1141 m2 / (w - 1.0)
1142 }
1143
1144 fn population_variance(&self, w: f64) -> f64 {
1147 let m2 = self.m2().get();
1148 m2 / w
1149 }
1150}
1151
1152#[doc(hidden)]
1154pub trait Skewness: Variance {
1155 fn m3(&self) -> &Cell<f64>;
1157
1158 #[cfg(any(feature = "std", feature = "libm"))]
1161 fn sample_skew(&self, w: f64) -> f64 {
1162 let variance = self.sample_variance(w);
1163 let stddev = sqrt(variance);
1164 w / ((w - 1.0) * (w - 2.0)) * self.m3().get() / (variance * stddev)
1165 }
1166
1167 #[cfg(any(feature = "std", feature = "libm"))]
1170 fn population_skew(&self, w: f64) -> f64 {
1171 let variance = self.population_variance(w);
1172 let stddev = sqrt(variance);
1173 self.m3().get() / (w * variance * stddev)
1174 }
1175}
1176
1177#[doc(hidden)]
1179pub trait Kurtosis: Variance {
1180 fn m4(&self) -> &Cell<f64>;
1182
1183 fn sample_kurtosis(&self, w: f64) -> f64 {
1186 let v = self.sample_variance(w);
1187 (w / (w - 1.0)) * ((w + 1.0) / (w - 2.0)) / (w - 3.0) * self.m4().get() / (v * v)
1188 }
1189
1190 fn sample_excess_kurtosis(&self, w: f64) -> f64 {
1196 self.sample_kurtosis(w) - 3.0 * ((w - 1.0) / (w - 2.0)) * ((w - 1.0) / (w - 3.0))
1197 }
1198
1199 fn population_kurtosis(&self, w: f64) -> f64 {
1202 let v = self.population_variance(w);
1203 self.m4().get() / w / (v * v)
1204 }
1205
1206 fn population_excess_kurtosis(&self, w: f64) -> f64 {
1212 self.population_kurtosis(w) - 3.0
1213 }
1214}
1215
1216macro_rules! central_moments {
1217 (Mean for $CM:ident <$($N:literal),*>) => {
1218 $(impl Mean for $CM<$N> {
1219 #[inline]
1220 fn m1(&self) -> &Cell<f64> {
1221 &self.0[0]
1222 }
1223 })*
1224 };
1225 (Variance for $CM:ident <$($N:literal),*>) => {
1226 $(impl Variance for $CM<$N> {
1227 #[inline]
1228 fn m2(&self) -> &Cell<f64> {
1229 &self.0[1]
1230 }
1231 })*
1232 };
1233 (Skewness for $CM:ident <$($N:literal),*>) => {
1234 $(impl Skewness for $CM<$N> {
1235 #[inline]
1236 fn m3(&self) -> &Cell<f64> {
1237 &self.0[2]
1238 }
1239 })*
1240 };
1241 (Kurtosis for $CM:ident <$($N:literal),*>) => {
1242 $(impl Kurtosis for $CM<$N> {
1243 #[inline]
1244 fn m4(&self) -> &Cell<f64> {
1245 &self.0[3]
1246 }
1247 })*
1248 };
1249}
1250
1251central_moments!(Mean for CentralMoments<1, 2, 3, 4>);
1252central_moments!(Variance for CentralMoments<2, 3, 4>);
1253central_moments!(Skewness for CentralMoments<3, 4>);
1254central_moments!(Kurtosis for CentralMoments<4>);
1255
1256impl Univariate for CentralMoments<0> {
1257 fn sample_u(&self, _x: f64, _n: usize) {}
1258
1259 fn sample_w(&self, _x: f64, _w: f64, _s: f64) {}
1260
1261 fn combine(self, _other: Self, _wl: f64, _wr: f64) -> Self {
1262 self
1263 }
1264}
1265
1266impl Univariate for CentralMoments<1> {
1267 fn sample_u(&self, x: f64, n: usize) {
1268 let m1 = self.m1();
1269 m1.set(m1.get() + (x - m1.get()) / n as f64);
1270 }
1271
1272 fn sample_w(&self, x: f64, w: f64, s: f64) {
1273 let m1 = self.m1();
1274 m1.set(m1.get() + w * (x - m1.get()) / s);
1275 }
1276
1277 fn combine(self, other: Self, wl: f64, wr: f64) -> Self {
1278 let sm1 = self.m1();
1279 let om1 = other.m1();
1280
1281 sm1.set((wl * sm1.get() + wr * om1.get()) / (wl + wr));
1282
1283 self
1284 }
1285}
1286
1287impl Univariate for CentralMoments<2> {
1288 fn sample_u(&self, x: f64, n: usize) {
1289 let m1 = self.m1();
1290 let m2 = self.m2();
1291 let d = x - m1.get();
1292
1293 m1.set(m1.get() + d / n as f64);
1294 m2.set(m2.get() + d * (x - m1.get()));
1295 }
1296
1297 fn sample_w(&self, x: f64, w: f64, s: f64) {
1298 let m1 = self.m1();
1299 let m2 = self.m2();
1300 let d = w * (x - m1.get());
1301
1302 m1.set(m1.get() + d / s);
1303 m2.set(m2.get() + d * (x - m1.get()));
1304 }
1305
1306 fn combine(self, other: Self, wl: f64, wr: f64) -> Self {
1307 let sm1 = self.m1();
1308 let om1 = other.m1();
1309 let sm2 = self.m2();
1310 let om2 = other.m2();
1311
1312 let m1d = om1.get() - sm1.get();
1313 let w = wl + wr;
1314
1315 sm2.set(sm2.get() + om2.get() + m1d * m1d * wl * wr / w);
1316 sm1.set((wl * sm1.get() + wr * om1.get()) / w);
1317
1318 self
1319 }
1320}
1321
// Single-pass updates of the mean and of the sums of 2nd/3rd powers of
// deviations (m2, m3 are sums; the statistics divide by the weight later).
// Statement order matters: m3's update reads the *previous* m2.
impl Univariate for CentralMoments<3> {
    // `n` is the count including `x`.
    fn sample_u(&self, x: f64, n: usize) {
        let m1 = self.m1();
        let m2 = self.m2();
        let m3 = self.m3();
        // Deviation from the old mean.
        let d = x - m1.get();
        let s = n as f64;

        let a = d / s;
        m1.set(m1.get() + a);
        // Deviation from the new mean; d * b == d^2 (n - 1) / n.
        let b = x - m1.get();
        m3.set(m3.get() + a * (d * b * (s - 2.0) - 3.0 * m2.get()));
        m2.set(m2.get() + d * b);
    }

    // Weighted counterpart: `w` is this observation's weight and `s` the total
    // weight including it.
    fn sample_w(&self, x: f64, w: f64, s: f64) {
        let m1 = self.m1();
        let m2 = self.m2();
        let m3 = self.m3();
        let d = x - m1.get();

        let a = w * d / s;
        m1.set(m1.get() + a);
        let b = x - m1.get();
        m3.set(m3.get() + a * (d * b * (s - 2.0 * w) - 3.0 * m2.get()));
        m2.set(m2.get() + w * d * b);
    }

    // Pairwise merge of two partial results (Chan et al. / Pébay formulas),
    // with `wl`/`wr` the total weights of `self`/`other`. m3 is merged first
    // because it needs both sides' un-merged m2.
    fn combine(self, other: Self, wl: f64, wr: f64) -> Self {
        let sm1 = self.m1();
        let om1 = other.m1();
        let sm2 = self.m2();
        let om2 = other.m2();
        let sm3 = self.m3();
        let om3 = other.m3();

        let m1d = om1.get() - sm1.get();
        let m1d2 = m1d * m1d;
        let w = wl + wr;

        sm3.set(
            sm3.get()
                + om3.get() + 3.0 * m1d * (wl * om2.get() - wr * sm2.get()) / w
                + m1d2 * m1d * wl * wr * (wl - wr) / (w * w),
        );

        sm2.set(sm2.get() + om2.get() + m1d2 * wl * wr / w);
        sm1.set((wl * sm1.get() + wr * om1.get()) / w);

        self
    }
}
1374
// Single-pass updates of the mean and of the sums of 2nd/3rd/4th powers of
// deviations. Statement order matters: m4 is updated first (it reads the old
// m2 and m3), then the mean, then m3 (reads the old m2), then m2.
impl Univariate for CentralMoments<4> {
    // `n` is the count including `x`.
    fn sample_u(&self, x: f64, n: usize) {
        let m1 = self.m1();
        let m2 = self.m2();
        let m3 = self.m3();
        let m4 = self.m4();
        // Deviation from the old mean.
        let d = x - m1.get();
        let s = n as f64;

        let a = d / s;
        m4.set(
            m4.get()
                + a * (6.0 * a * m2.get() - 4.0 * m3.get()
                    + a * a * d * (s - 1.0) * (s * (s - 3.0) + 3.0)),
        );
        m1.set(m1.get() + a);
        // Deviation from the new mean.
        let b = x - m1.get();
        m3.set(m3.get() + a * (d * b * (s - 2.0) - 3.0 * m2.get()));
        m2.set(m2.get() + d * b);
    }

    // Weighted counterpart: `w` is this observation's weight and `s` the total
    // weight including it. `s * (s - 3w) + 3w^2` equals
    // `(s-w)^2 - (s-w) w + w^2` from the pairwise-merge formula.
    fn sample_w(&self, x: f64, w: f64, s: f64) {
        let m1 = self.m1();
        let m2 = self.m2();
        let m3 = self.m3();
        let m4 = self.m4();
        let d = x - m1.get();

        let a = d / s;
        let t = w * a;
        m4.set(
            m4.get()
                + t * (6.0 * t * m2.get() - 4.0 * m3.get()
                    + a * a * d * (s - w) * (s * (s - 3.0 * w) + 3.0 * w * w)),
        );
        m1.set(m1.get() + t);
        let b = x - m1.get();
        m3.set(m3.get() + t * (d * b * (s - 2.0 * w) - 3.0 * m2.get()));
        m2.set(m2.get() + w * d * b);
    }

    // Pairwise merge of two partial results (Chan et al. / Pébay formulas),
    // with `wl`/`wr` the total weights of `self`/`other`. Higher moments are
    // merged first because they need the un-merged lower ones.
    fn combine(self, other: Self, wl: f64, wr: f64) -> Self {
        let sm1 = self.m1();
        let om1 = other.m1();
        let sm2 = self.m2();
        let om2 = other.m2();
        let sm3 = self.m3();
        let om3 = other.m3();
        let sm4 = self.m4();
        let om4 = other.m4();

        let m1d = om1.get() - sm1.get();
        let m1d2 = m1d * m1d;
        let w = wl + wr;
        let w2 = w * w;

        sm4.set(
            sm4.get()
                + om4.get() + 4.0 * m1d * (wl * om3.get() - wr * sm3.get()) / w
                + 6.0 * m1d2 * (wl * wl * om2.get() / w2 + wr * wr * sm2.get() / w2)
                + m1d2 * m1d2 * wl * wr / w * (wl * wl - wl * wr + wr * wr) / w2,
        );

        sm3.set(
            sm3.get()
                + om3.get() + 3.0 * m1d * (wl * om2.get() - wr * sm2.get()) / w
                + m1d2 * m1d * wl * wr * (wl - wr) / w2,
        );

        sm2.set(sm2.get() + om2.get() + m1d2 * wl * wr / w);
        sm1.set((wl * sm1.get() + wr * om1.get()) / w);

        self
    }
}
1450
/// Square root that uses `std` when available and falls back to `libm` in
/// `no_std` builds.
#[cfg(any(feature = "std", feature = "libm"))]
fn sqrt(val: f64) -> f64 {
    #[cfg(feature = "std")]
    let root = val.sqrt();

    #[cfg(all(feature = "libm", not(feature = "std")))]
    let root = libm::sqrt(val);

    root
}
1464
#[cfg(test)]
mod tests {
    use float_eq::float_eq;
    use prop::collection::vec;
    use proptest::prelude::*;

    use super::*;

    // A freshly constructed random variable of each observation kind reports no
    // tabulations, no extrema, and NaN statistics.
    #[test]
    fn new_rv_is_empty() {
        let rv1 = RandomVariable::<i32>::default();
        let rv2 = RandomVariable::<Weighted<i32>>::default();
        let rv3 = RandomVariable::<Utilized<f32, i32>>::default();

        assert!(
            rv1.count() == 0
                && rv1.min().is_none()
                && rv1.max().is_none()
                && rv1.mean().is_nan()
                && rv1.variance().is_nan()
                && rv1.std_dev().is_nan()
        );

        assert!(
            rv2.count() == 0
                && rv2.min().is_none()
                && rv2.max().is_none()
                && rv2.mean().is_nan()
                && rv2.variance().is_nan()
                && rv2.std_dev().is_nan()
        );

        assert!(
            rv3.count() == 0
                && rv3.min().is_none()
                && rv3.max().is_none()
                && rv3.mean().is_nan()
                && rv3.variance().is_nan()
                && rv3.std_dev().is_nan()
        );
    }

    // `tabulate` panics when handed a negative weight.
    #[test]
    #[should_panic]
    fn negative_weight_panics() {
        let rv = RandomVariable::new::<4>();

        for e in 0..100 {
            rv.tabulate(Weighted(e, 1.5));
        }

        rv.tabulate(Weighted(0, -1.0));
    }

    // `tabulate` panics when a temporal observation goes back in time.
    #[test]
    #[should_panic]
    fn negative_time_delta_panics() {
        let rv = RandomVariable::new::<4>();
        let mut time = 0;

        for e in 0..100 {
            rv.tabulate(Utilized(time, e));
            time += 1;
        }

        rv.tabulate(Utilized(time - 2, 0));
    }

    proptest! {
        // The running mean matches sum / count, within a relative tolerance
        // that grows with the number of samples.
        #[test]
        #[cfg_attr(miri, ignore)]
        fn mean_calculated_correctly(elems in vec(0..1000, 2..100)) {
            let rv1 = RandomVariable::new::<4>();
            let n = elems.len();
            let mut s = 0;

            for elem in elems {
                rv1.tabulate(elem);
                s += elem;
            }

            let mean = s as f64 / n as f64;
            prop_assert!(
                float_eq!(
                    rv1.mean(),
                    mean,
                    r2nd <= ((n as f64) * 0.5) * f64::EPSILON
                ),
                "{:?} != {:?}", rv1.mean(), mean
            );
        }

        // Joining with an empty random variable, on either side, leaves the
        // moments unchanged (unweighted observations).
        #[test]
        #[cfg_attr(miri, ignore)]
        fn join_with_empty_rv_unweighted(elems in vec(-100.0f64..100.0, 2..100)) {
            let rv1 = RandomVariable::default();

            for elem in elems {
                rv1.tabulate(elem);
            }

            prop_assert_eq!(&rv1, &rv1.clone().join(RandomVariable::new()));
            prop_assert_eq!(&rv1, &RandomVariable::new().join(rv1.clone()));
        }

        // Same as above, for weighted observations.
        #[test]
        #[cfg_attr(miri, ignore)]
        fn join_with_empty_rv_weighted(elems in vec((0..100, 0.0f64..1000.0), 2..100)) {
            let rv1 = RandomVariable::default();

            for elem in elems {
                rv1.tabulate(Weighted(elem.0, elem.1));
            }

            prop_assert_eq!(&rv1, &rv1.clone().join(RandomVariable::default()));
            prop_assert_eq!(&rv1, &RandomVariable::default().join(rv1.clone()));
        }

        // Same as above, for temporal observations with non-decreasing times.
        #[test]
        #[cfg_attr(miri, ignore)]
        fn join_with_empty_rv_temporal(elems in vec((0.0f64..1000.0, 0..100), 2..100)) {
            let rv1 = RandomVariable::default();
            let mut time = 0.0;

            for elem in elems {
                time += elem.0;
                rv1.tabulate(Utilized(time, elem.1));
            }

            prop_assert_eq!(&rv1, &rv1.clone().join(RandomVariable::default()));
            prop_assert_eq!(&rv1, &RandomVariable::default().join(rv1.clone()));
        }

        // Weight 1 on every observation must reproduce the unweighted moments.
        #[test]
        #[cfg_attr(miri, ignore)]
        fn equivalence_between_unweighted_and_weighted(elems in vec(-100.0..100.0, 3..100)) {
            let rv1 = RandomVariable::default();
            let rv2 = RandomVariable::default();

            for elem in elems {
                rv1.tabulate(elem);
                rv2.tabulate(Weighted(elem, 1));
            }

            prop_assert_eq!(rv1, rv2);
        }

        // Holding value `e` for `w` time units is equivalent to tabulating `e`
        // with weight `w`; the final `update` credits the last held value.
        #[test]
        #[cfg_attr(miri, ignore)]
        fn equivalence_between_weighted_and_temporal(elems in vec((1..100, 1..100), 3..100)) {
            let rv1 = RandomVariable::default();
            let rv2 = RandomVariable::default();
            let mut t = 0;

            for (w,e) in elems {
                rv1.tabulate(Weighted(e, w));
                rv2.tabulate(Utilized(t, e));
                t += w;
            }
            rv2.update(t);

            prop_assert_eq!(rv1, rv2);
        }
    }
}