use std::{
sync::atomic::{
AtomicU64,
Ordering::{self,Relaxed},
},
mem::ManuallyDrop,
fmt::{Debug, Formatter},
};
/// Plain (non-atomic) view of the 64-bit state: a measurement counter plus the running average.
///
/// `#[repr(C)]` pins the field order and offsets (`counter` at byte 0, `average` at byte 4).
/// Without it the compiler is free to reorder the fields, which would make any read through
/// this view of the union disagree with the packed atomic word in an unspecified way.
///
/// NOTE: the packing helpers of this file keep the counter in the low 32 bits of the `u64`,
/// which corresponds to this field order on little-endian targets only.
#[repr(C)]
struct IncrementalAveragePair32 {
    /// Number of measurements accounted for in `average`.
    counter: u32,
    /// Running average of the measurements.
    average: f32,
}

/// A counter + average pair stored in 8 bytes so both can be updated together atomically.
///
/// The same 8 bytes can be viewed either as the two 32-bit halves (`split`) or as a single
/// atomic word (`joined`). `#[repr(C)]` guarantees both views start at offset 0 and that the
/// union has the size and alignment of its largest member -- the default union representation
/// gives no such guarantee.
#[repr(C)]
pub union AtomicIncrementalAverage64 {
    /// Field-by-field view; both fields are plain `Copy` data, so nothing is ever dropped.
    split: ManuallyDrop<IncrementalAveragePair32>,
    /// Atomic view over the very same bytes.
    joined: ManuallyDrop<AtomicU64>,
}
impl Default for AtomicIncrementalAverage64 {
    /// Builds the empty state: no measurements counted and an average of `0.0`.
    ///
    /// The state is written through the `split` view; both fields are zero, so every
    /// byte of the union is initialized to zero.
    fn default() -> Self {
        let empty = IncrementalAveragePair32 {
            average: 0.0,
            counter: 0,
        };
        Self {
            split: ManuallyDrop::new(empty),
        }
    }
}
impl AtomicIncrementalAverage64 {
pub fn new() -> Self {
Self::default()
}
pub fn inc(&self, measurement: f32) {
self.atomic_compute(Relaxed, Relaxed, |mut counter, average| {
if counter == u32::MAX {
counter = 100;
}
( ( counter + 1 ),
( ( counter as f32 / (1.0 + counter as f32) ) * average ) + ( measurement / (1.0 + counter as f32) )
)
});
}
pub fn probe(&self) -> (u32, f32) {
AtomicIncrementalAverage64::split_joined(unsafe {&self.joined}.load(Relaxed))
}
pub fn lightweight_probe(&self) -> (u32, f32) {
unsafe {
(self.split.counter, self.split.average)
}
}
fn _reset(&self, weight: u32) {
self.atomic_compute(Relaxed, Relaxed, |_counter, average| (weight, average) );
}
fn _split(&self) -> &IncrementalAveragePair32 {
unsafe { &self.split }
}
fn _atomic(&self) -> &AtomicU64 {
unsafe { &self.joined }
}
fn split_joined(joined: u64) -> (u32, f32) {
( (joined & ((1<<32)-1)) as u32, f32::from_bits((joined >> 32) as u32) )
}
fn join_split(counter: u32, average: f32) -> u64 {
(counter as u64) | ((f32::to_bits(average) as u64) << 32)
}
fn atomic_compute(&self, load_ordering: Ordering, store_ordering: Ordering, computation: impl Fn(u32, f32) -> (u32, f32)) {
unsafe {
let mut current_joined = self.joined.load(load_ordering);
loop {
let (current_counter, current_average) = AtomicIncrementalAverage64::split_joined(current_joined);
let (new_counter, new_average) = computation(current_counter, current_average);
let new_joined = AtomicIncrementalAverage64::join_split(new_counter, new_average);
match self.joined.compare_exchange(current_joined, new_joined, store_ordering, load_ordering) {
Err(reloaded_current_joined) => current_joined = reloaded_current_joined,
Ok(_) => break,
}
}
}
}
}
impl Debug for AtomicIncrementalAverage64 {
    /// Prints a snapshot of the counter and of the average (5 decimal places).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // take a single snapshot so both printed values belong together
        let snapshot = self.probe();
        f.write_fmt(format_args!("AtomicIncrementalAverage64 {{counter: {}, average: {:.5}}}", snapshot.0, snapshot.1))
    }
}
#[cfg(any(test,doc))]
mod tests {
    use super::*;

    /// Feeds the sequence `0, 1, 2, ...` into the average and checks, after every single
    /// insertion and once more at the end, that the counter is exact and that the average
    /// stays close to the closed-form mean `(n - 1) / 2` of the first `n` naturals.
    #[cfg_attr(not(doc),test)]
    fn incremental_averages() {
        const ELEMENTS: u32 = 12345679;

        // mean of 0, 1, ..., elements-1
        fn expected_average(elements: u32) -> f32 {
            (elements as f32 - 1.0) / 2.0
        }

        let rolling_avg = AtomicIncrementalAverage64::new();
        for i in 0..ELEMENTS {
            let fed_so_far = i + 1;
            rolling_avg.inc(i as f32);

            let (observed_count, observed_avg) = rolling_avg.lightweight_probe();
            assert_eq!(observed_count, fed_so_far, "count is wrong");

            let expected = expected_average(fed_so_far);
            let delta = (observed_avg - expected).abs();
            assert!(delta <= 0.207, "incremental average, probed along the way, is wrong (within ~ 10^-1 precision) at element #{i} -- observed: {}; expected: {} / delta: {delta}", observed_avg, expected);
        }

        // final state, read through the atomic probe
        let (observed_count, observed_avg) = rolling_avg.probe();
        assert_eq!(observed_count, ELEMENTS, "count is wrong");
        let final_expected = expected_average(ELEMENTS);
        let final_delta = (observed_avg - final_expected).abs();
        assert!(final_delta <= 1e-4, "average is wrong -- observed: {}; expected: {}", observed_avg, final_expected);
    }
}