use atomic_float::AtomicF64;
use std::sync::{
Arc, RwLock,
atomic::{AtomicU32, AtomicUsize, Ordering},
};
/// Tracks how far playback has progressed by counting consumed samples and
/// optionally notifying a callback when the position has moved far enough.
///
/// All mutable state sits behind `Arc`-wrapped atomics (or an `RwLock` for the
/// callback) so handles to it can be cloned out and updated elsewhere.
pub struct ProgressTracker {
/// Running count of consumed samples; the position is derived from it as
/// `count / (sample_rate * channels)`.
consumed_samples: Arc<AtomicUsize>,
/// Sample rate of the current audio; `0` until an audio spec has been set.
sample_rate: Arc<AtomicU32>,
/// Channel count of the current audio; `0` until an audio spec has been set.
channels: Arc<AtomicU32>,
/// Optional callback invoked with the new position whenever it differs from
/// the last reported position by more than `threshold`.
#[allow(clippy::type_complexity)]
callback: Arc<RwLock<Option<Box<dyn Fn(f64) + Send + Sync + 'static>>>>,
/// Position recorded the last time the threshold was crossed (or the
/// position was explicitly set/reset).
last_reported_position: Arc<AtomicF64>,
/// Minimum absolute change in position required before reporting again.
threshold: f64,
}
impl ProgressTracker {
    /// Reporting threshold used when [`ProgressTracker::new`] is given `None`.
    const DEFAULT_THRESHOLD: f64 = 0.1;

    /// Creates a tracker with no audio spec, no callback and zero consumed
    /// samples.
    ///
    /// `threshold` is the minimum absolute change in position that triggers a
    /// callback; `None` selects the default of `0.1`.
    #[must_use]
    pub fn new(threshold: Option<f64>) -> Self {
        Self {
            consumed_samples: Arc::new(AtomicUsize::new(0)),
            sample_rate: Arc::new(AtomicU32::new(0)),
            channels: Arc::new(AtomicU32::new(0)),
            callback: Arc::new(RwLock::new(None)),
            last_reported_position: Arc::new(AtomicF64::new(0.0)),
            threshold: threshold.unwrap_or(Self::DEFAULT_THRESHOLD),
        }
    }

    /// Records the sample rate and channel count used to convert the sample
    /// counter into a position.
    ///
    /// Until both values are non-zero, [`Self::get_position`] returns `None`
    /// and no callbacks are fired.
    pub fn set_audio_spec(&self, sample_rate: u32, channels: u32) {
        self.sample_rate.store(sample_rate, Ordering::SeqCst);
        self.channels.store(channels, Ordering::SeqCst);
        log::debug!("ProgressTracker: audio spec set - rate={sample_rate}, channels={channels}");
    }

    /// Installs, replaces or (with `None`) clears the progress callback.
    ///
    /// If the callback lock is poisoned, an error is logged and the existing
    /// callback is left in place.
    ///
    /// The callback is invoked while a read lock on the callback slot is held,
    /// so it must not call `set_callback` on the same tracker.
    pub fn set_callback(&self, callback: Option<Box<dyn Fn(f64) + Send + Sync + 'static>>) {
        if let Ok(mut slot) = self.callback.write() {
            *slot = callback;
            log::debug!("ProgressTracker: callback set");
        } else {
            log::error!("ProgressTracker: failed to acquire write lock for callback");
        }
    }

    /// Returns a shared handle to the consumed-sample counter.
    #[must_use]
    pub fn consumed_samples_counter(&self) -> Arc<AtomicUsize> {
        Arc::clone(&self.consumed_samples)
    }

    /// Adds `additional_samples` to the counter and fires the callback if the
    /// position has moved more than the threshold since the last report.
    ///
    /// This is [`Self::update_from_callback_refs`] applied to this tracker's
    /// own state, so both entry points always behave identically.
    pub fn update_consumed_samples(&self, additional_samples: usize) {
        Self::update_from_callback_refs(
            &self.consumed_samples,
            &self.sample_rate,
            &self.channels,
            &self.callback,
            &self.last_reported_position,
            additional_samples,
            self.threshold,
        );
    }

    /// Returns the current position (`consumed / (sample_rate * channels)`,
    /// i.e. seconds when the sample rate is in Hz), or `None` while the
    /// sample rate or channel count is still zero.
    #[must_use]
    pub fn get_position(&self) -> Option<f64> {
        let consumed = self.consumed_samples.load(Ordering::SeqCst);
        let sample_rate = self.sample_rate.load(Ordering::SeqCst);
        let channels = self.channels.load(Ordering::SeqCst);
        Self::position_of(consumed, sample_rate, channels)
    }

    /// Overwrites the consumed-sample counter with `samples`.
    ///
    /// When the audio spec is known, the last reported position is moved to
    /// the new position as well, so the jump itself is not reported and the
    /// next callback fires only after a further threshold-sized change. When
    /// the spec is unknown, the last reported position is left untouched.
    pub fn set_consumed_samples(&self, samples: usize) {
        self.consumed_samples.store(samples, Ordering::SeqCst);
        if let Some(position) = self.get_position() {
            self.last_reported_position
                .store(position, Ordering::SeqCst);
        }
        log::debug!("ProgressTracker: consumed samples set to {samples}");
    }

    /// Zeroes the sample counter and the last reported position. The audio
    /// spec and the callback are kept.
    pub fn reset(&self) {
        self.consumed_samples.store(0, Ordering::SeqCst);
        self.last_reported_position.store(0.0, Ordering::SeqCst);
        log::debug!("ProgressTracker: reset for new track");
    }

    /// Returns shared handles to the tracker's state, in the order expected by
    /// [`Self::update_from_callback_refs`]: consumed samples, sample rate,
    /// channels, callback slot, last reported position.
    ///
    /// The threshold is not part of the tuple; callers pass it separately.
    #[must_use]
    #[allow(clippy::type_complexity)]
    pub fn get_callback_refs(
        &self,
    ) -> (
        Arc<AtomicUsize>,
        Arc<AtomicU32>,
        Arc<AtomicU32>,
        Arc<RwLock<Option<Box<dyn Fn(f64) + Send + Sync + 'static>>>>,
        Arc<AtomicF64>,
    ) {
        (
            Arc::clone(&self.consumed_samples),
            Arc::clone(&self.sample_rate),
            Arc::clone(&self.channels),
            Arc::clone(&self.callback),
            Arc::clone(&self.last_reported_position),
        )
    }

    /// Adds `additional_samples` to `consumed_samples` and reports the new
    /// position through `callback` when it differs from
    /// `last_reported_position` by more than `threshold`.
    ///
    /// Behaviour in detail:
    /// * `additional_samples == 0` is a no-op.
    /// * The counter is always advanced, even when the audio spec is unknown
    ///   (rate or channels of zero); in that case nothing is reported.
    /// * Once the threshold is crossed, `last_reported_position` is updated
    ///   before the callback lock is tried. If the lock cannot be taken
    ///   without blocking (or is poisoned), or no callback is installed, the
    ///   report is skipped but the baseline has still advanced.
    /// * The callback runs while the read lock is held.
    #[allow(clippy::type_complexity)]
    pub fn update_from_callback_refs(
        consumed_samples: &Arc<AtomicUsize>,
        sample_rate: &Arc<AtomicU32>,
        channels: &Arc<AtomicU32>,
        callback: &Arc<RwLock<Option<Box<dyn Fn(f64) + Send + Sync + 'static>>>>,
        last_reported_position: &Arc<AtomicF64>,
        additional_samples: usize,
        threshold: f64,
    ) {
        if additional_samples == 0 {
            return;
        }

        // `fetch_add` wraps on overflow, so the new total must be derived with
        // `wrapping_add` too: a plain `+` could panic in debug builds (or
        // disagree with the stored counter) once the count exceeds `usize`.
        let new_consumed = consumed_samples
            .fetch_add(additional_samples, Ordering::SeqCst)
            .wrapping_add(additional_samples);

        let sample_rate_val = sample_rate.load(Ordering::SeqCst);
        let channels_val = channels.load(Ordering::SeqCst);
        let Some(current_position) = Self::position_of(new_consumed, sample_rate_val, channels_val)
        else {
            return;
        };

        let last_position = last_reported_position.load(Ordering::SeqCst);
        if (current_position - last_position).abs() > threshold {
            last_reported_position.store(current_position, Ordering::SeqCst);

            // Non-blocking on purpose: never stall the updating thread on a
            // concurrent `set_callback`.
            let Ok(callback_guard) = callback.try_read() else {
                return;
            };
            if let Some(cb) = callback_guard.as_ref() {
                cb(current_position);
            }
        }
    }

    /// Converts a sample count into a position, or returns `None` when the
    /// sample rate or channel count is zero (audio spec not yet known).
    // `usize -> f64` is exact up to 2^53; beyond that only sub-sample
    // precision is lost, which is irrelevant for a progress position.
    #[allow(clippy::cast_precision_loss)]
    fn position_of(samples: usize, sample_rate: u32, channels: u32) -> Option<f64> {
        if sample_rate == 0 || channels == 0 {
            return None;
        }
        Some(samples as f64 / (f64::from(sample_rate) * f64::from(channels)))
    }
}
impl std::fmt::Debug for ProgressTracker {
    /// Writes a snapshot of the tracker's numeric state. The callback is not
    /// printable, so the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read every atomic once up front so the output is built from plain values.
        let consumed = self.consumed_samples.load(Ordering::SeqCst);
        let rate = self.sample_rate.load(Ordering::SeqCst);
        let channel_count = self.channels.load(Ordering::SeqCst);
        let last_reported = self.last_reported_position.load(Ordering::SeqCst);

        let mut out = f.debug_struct("ProgressTracker");
        out.field("consumed_samples", &consumed);
        out.field("sample_rate", &rate);
        out.field("channels", &channel_count);
        out.field("last_reported_position", &last_reported);
        out.field("threshold", &self.threshold);
        out.finish_non_exhaustive()
    }
}
/// The default tracker is `ProgressTracker::new(None)`, i.e. it uses the
/// default reporting threshold.
impl Default for ProgressTracker {
fn default() -> Self {
Self::new(None)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Boxed progress callback, as accepted by `ProgressTracker::set_callback`.
    type ProgressCallback = Box<dyn Fn(f64) + Send + Sync + 'static>;
    /// Shared callback slot, as accepted by `update_from_callback_refs`.
    type SharedCallback = Arc<RwLock<Option<ProgressCallback>>>;
    /// Positions captured by a recording callback.
    type Recorded = Arc<Mutex<Vec<f64>>>;

    /// Builds a callback that appends every reported position to a shared
    /// vector, and returns that vector alongside the callback.
    fn recording_callback() -> (Recorded, ProgressCallback) {
        let positions: Recorded = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&positions);
        let callback: ProgressCallback = Box::new(move |pos: f64| {
            sink.lock().unwrap().push(pos);
        });
        (positions, callback)
    }

    /// Runs `update_from_callback_refs` once against freshly built shared
    /// state (last reported position `0.0`) and returns the resulting
    /// consumed-sample count.
    fn update_via_refs(
        initial_samples: usize,
        sample_rate: u32,
        channels: u32,
        callback: Option<ProgressCallback>,
        additional_samples: usize,
        threshold: f64,
    ) -> usize {
        let consumed = Arc::new(AtomicUsize::new(initial_samples));
        let sample_rate = Arc::new(AtomicU32::new(sample_rate));
        let channels = Arc::new(AtomicU32::new(channels));
        let callback: SharedCallback = Arc::new(RwLock::new(callback));
        let last_pos = Arc::new(AtomicF64::new(0.0));
        ProgressTracker::update_from_callback_refs(
            &consumed,
            &sample_rate,
            &channels,
            &callback,
            &last_pos,
            additional_samples,
            threshold,
        );
        consumed.load(Ordering::SeqCst)
    }

    #[test_log::test]
    fn test_progress_tracker_new() {
        let tracker = ProgressTracker::new(Some(0.5));
        assert!((tracker.threshold - 0.5).abs() < f64::EPSILON);
        assert_eq!(tracker.get_position(), None);
    }

    #[test_log::test]
    fn test_progress_tracker_default() {
        let tracker = ProgressTracker::default();
        assert!((tracker.threshold - 0.1).abs() < f64::EPSILON);
        assert_eq!(tracker.get_position(), None);
    }

    #[test_log::test]
    fn test_set_audio_spec() {
        let tracker = ProgressTracker::new(None);
        tracker.set_audio_spec(44100, 2);
        assert_eq!(tracker.sample_rate.load(Ordering::SeqCst), 44100);
        assert_eq!(tracker.channels.load(Ordering::SeqCst), 2);
    }

    #[test_log::test]
    fn test_get_position_without_spec() {
        let tracker = ProgressTracker::new(None);
        assert_eq!(tracker.get_position(), None);
    }

    #[test_log::test]
    fn test_get_position_with_spec() {
        let tracker = ProgressTracker::new(None);
        tracker.set_audio_spec(44100, 2);
        // 88200 samples / (44100 * 2) = 1.0
        tracker.update_consumed_samples(88200);
        let position = tracker.get_position().unwrap();
        assert!((position - 1.0).abs() < 0.001);
    }

    #[test_log::test]
    fn test_update_consumed_samples_zero() {
        let tracker = ProgressTracker::new(None);
        tracker.set_audio_spec(44100, 2);
        tracker.update_consumed_samples(0);
        assert_eq!(tracker.consumed_samples.load(Ordering::SeqCst), 0);
    }

    #[test_log::test]
    fn test_update_consumed_samples_incremental() {
        let tracker = ProgressTracker::new(None);
        tracker.set_audio_spec(44100, 2);
        tracker.update_consumed_samples(44100);
        assert_eq!(tracker.consumed_samples.load(Ordering::SeqCst), 44100);
        tracker.update_consumed_samples(44100);
        assert_eq!(tracker.consumed_samples.load(Ordering::SeqCst), 88200);
    }

    #[test_log::test]
    fn test_set_consumed_samples() {
        let tracker = ProgressTracker::new(None);
        tracker.set_audio_spec(44100, 2);
        tracker.set_consumed_samples(176_400);
        assert_eq!(tracker.consumed_samples.load(Ordering::SeqCst), 176_400);
        // 176400 samples / (44100 * 2) = 2.0
        let position = tracker.get_position().unwrap();
        assert!((position - 2.0).abs() < 0.001);
    }

    #[test_log::test]
    fn test_reset() {
        let tracker = ProgressTracker::new(None);
        tracker.set_audio_spec(44100, 2);
        tracker.update_consumed_samples(88200);
        assert!(tracker.consumed_samples.load(Ordering::SeqCst) > 0);
        tracker.reset();
        assert_eq!(tracker.consumed_samples.load(Ordering::SeqCst), 0);
        assert!(tracker.last_reported_position.load(Ordering::SeqCst).abs() < f64::EPSILON);
    }

    #[test_log::test]
    fn test_callback_triggered_on_threshold() {
        let tracker = ProgressTracker::new(Some(0.5));
        tracker.set_audio_spec(44100, 2);
        let (positions, callback) = recording_callback();
        tracker.set_callback(Some(callback));
        // 0.25 is below the 0.5 threshold: nothing reported yet.
        tracker.update_consumed_samples(22050);
        assert_eq!(positions.lock().unwrap().len(), 0);
        // 0.75 crosses the threshold: exactly one report.
        tracker.update_consumed_samples(44100);
        assert_eq!(positions.lock().unwrap().len(), 1);
    }

    #[test_log::test]
    fn test_callback_not_triggered_without_spec() {
        let tracker = ProgressTracker::new(Some(0.1));
        let (positions, callback) = recording_callback();
        tracker.set_callback(Some(callback));
        tracker.update_consumed_samples(88200);
        assert_eq!(positions.lock().unwrap().len(), 0);
    }

    #[test_log::test]
    fn test_set_callback_none_clears_callback() {
        let tracker = ProgressTracker::new(Some(0.5));
        tracker.set_audio_spec(44100, 2);
        let (positions, callback) = recording_callback();
        tracker.set_callback(Some(callback));
        tracker.set_callback(None);
        // 1.0 crosses the threshold, but no callback is installed any more.
        tracker.update_consumed_samples(88200);
        assert!(positions.lock().unwrap().is_empty());
    }

    #[test_log::test]
    fn test_consumed_samples_counter() {
        let tracker = ProgressTracker::new(None);
        let counter = tracker.consumed_samples_counter();
        tracker.set_audio_spec(44100, 2);
        tracker.update_consumed_samples(1000);
        assert_eq!(counter.load(Ordering::SeqCst), 1000);
    }

    #[test_log::test]
    fn test_get_callback_refs() {
        let tracker = ProgressTracker::new(Some(0.2));
        tracker.set_audio_spec(48000, 2);
        tracker.update_consumed_samples(1000);
        let (consumed, rate, channels, _callback, last_pos) = tracker.get_callback_refs();
        assert_eq!(consumed.load(Ordering::SeqCst), 1000);
        assert_eq!(rate.load(Ordering::SeqCst), 48000);
        assert_eq!(channels.load(Ordering::SeqCst), 2);
        // 1000 samples is far below the 0.2 threshold, so nothing was reported.
        assert!(last_pos.load(Ordering::SeqCst).abs() < f64::EPSILON);
    }

    #[test_log::test]
    fn test_update_from_callback_refs() {
        assert_eq!(update_via_refs(0, 44100, 2, None, 88200, 0.5), 88200);
    }

    #[test_log::test]
    fn test_update_from_callback_refs_with_callback() {
        let (positions, callback) = recording_callback();
        update_via_refs(0, 44100, 2, Some(callback), 88200, 0.5);
        let recorded = positions.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        assert!((recorded[0] - 1.0).abs() < 0.001);
    }

    #[test_log::test]
    fn test_update_from_callback_refs_zero_samples() {
        assert_eq!(update_via_refs(100, 44100, 2, None, 0, 0.1), 100);
    }

    #[test_log::test]
    fn test_update_from_callback_refs_zero_sample_rate() {
        let (positions, callback) = recording_callback();
        let consumed = update_via_refs(0, 0, 2, Some(callback), 88200, 0.5);
        assert_eq!(consumed, 88200);
        assert!(positions.lock().unwrap().is_empty());
    }

    #[test_log::test]
    fn test_update_from_callback_refs_zero_channels() {
        let (positions, callback) = recording_callback();
        let consumed = update_via_refs(0, 44100, 0, Some(callback), 88200, 0.5);
        assert_eq!(consumed, 88200);
        assert!(positions.lock().unwrap().is_empty());
    }

    #[test_log::test]
    fn test_update_from_callback_refs_below_threshold() {
        let (positions, callback) = recording_callback();
        // 22050 samples -> 0.25, below the 0.5 threshold.
        let consumed = update_via_refs(0, 44100, 2, Some(callback), 22050, 0.5);
        assert_eq!(consumed, 22050);
        assert!(positions.lock().unwrap().is_empty());
    }
}