use std::sync::{atomic::AtomicBool, Arc, Mutex};
use crate::EncodingError;
/// An object-safe progress callback whose implementors are `Send`.
///
/// It is stored as `&mut dyn ThreadSafeReporter` so that any concrete closure type can be held
/// behind a single reference type.
trait ThreadSafeReporter: Send {
    /// Receives one progress value.
    fn report(&mut self, progress: f32);
}

/// Every `Send` closure accepting an `f32` is a reporter; reporting just invokes the closure.
impl<F> ThreadSafeReporter for F
where
    F: FnMut(f32) + Send,
{
    fn report(&mut self, progress: f32) {
        (*self)(progress)
    }
}
/// The callback a `Progress` forwards its (already projected) progress values to, if any.
enum ProgressInner<'a> {
    /// A `Send` reporter; this is the only kind `ParallelProgress` forwards to.
    Send(&'a mut dyn ThreadSafeReporter),
    /// A reporter without a `Send` bound, so it cannot be handed to other threads.
    Bound(&'a mut dyn FnMut(f32)),
    /// No reporter at all; progress values are dropped.
    None,
}
/// A sub-interval `[start, start + length]` of a parent progress scale.
///
/// Local progress values (expected in `0.0..=1.0`) are mapped into this interval by
/// [`ProgressRange::project`]; ranges can be nested with [`ProgressRange::sub_range`].
#[derive(Clone, Copy)]
pub(crate) struct ProgressRange {
    /// Offset of the interval on the parent scale.
    pub start: f32,
    /// Width of the interval on the parent scale.
    pub length: f32,
}

impl ProgressRange {
    /// The identity range: starts at `0.0` and covers a width of `1.0`.
    pub const FULL: Self = Self { start: 0.0, length: 1.0 };

    /// Builds the range spanning `from..=to`.
    ///
    /// In debug builds this asserts that `from <= to`.
    pub fn from_to(from: f32, to: f32) -> Self {
        debug_assert!(from <= to);
        let length = to - from;
        Self { start: from, length }
    }

    /// Interprets `other` relative to `self` and returns the equivalent range on the parent
    /// scale.
    pub fn sub_range(&self, other: Self) -> Self {
        let Self {
            start: outer_start,
            length: outer_len,
        } = *self;
        Self {
            start: outer_start + other.start * outer_len,
            length: other.length * outer_len,
        }
    }

    /// Maps a local progress value onto the parent scale.
    ///
    /// In debug builds this asserts that `progress` lies within `0.0..=1.0`.
    pub fn project(&self, progress: f32) -> f32 {
        debug_assert!((0.0..=1.0).contains(&progress));
        let scaled = self.length * progress;
        self.start + scaled
    }
}

impl Default for ProgressRange {
    /// Defaults to [`ProgressRange::FULL`].
    fn default() -> Self {
        Self::FULL
    }
}
/// A progress sink handed down through an operation, optionally carrying a cancellation token.
///
/// Values passed to it are mapped through `range` before reaching the reporter, so nested steps
/// can each report in `0.0..=1.0` while the reporter sees overall progress.
pub struct Progress<'a> {
    /// Where projected progress values are delivered (possibly nowhere).
    reporter: ProgressInner<'a>,
    /// The slice of the overall progress scale this tracker reports into.
    range: ProgressRange,
    /// Token checked by the cancellation helpers; `None` means the operation is never cancelled.
    cancel: Option<CancellationToken>,
}
impl<'a> Progress<'a> {
pub fn new<F: FnMut(f32) + Send>(reporter: &'a mut F) -> Self {
Self {
reporter: ProgressInner::Send(reporter),
range: ProgressRange::FULL,
cancel: None,
}
}
pub fn new_single_threaded<F: FnMut(f32)>(reporter: &'a mut F) -> Self {
Self {
reporter: ProgressInner::Bound(reporter),
range: ProgressRange::FULL,
cancel: None,
}
}
pub fn none() -> Self {
Self {
reporter: ProgressInner::None,
range: ProgressRange::FULL,
cancel: None,
}
}
pub fn with_cancellation(mut self, token: &CancellationToken) -> Self {
self.cancel = Some(token.clone());
self
}
pub fn report(&mut self, progress: f32) {
let progress = self.range.project(progress);
match &mut self.reporter {
ProgressInner::Send(report) => report.report(progress),
ProgressInner::Bound(report) => report(progress),
ProgressInner::None => {}
}
}
pub fn is_cancelled(&self) -> bool {
if let Some(cancel) = &self.cancel {
cancel.is_cancelled()
} else {
false
}
}
pub(crate) fn sub_range(&mut self, range: ProgressRange) -> Progress<'_> {
let range = self.range.sub_range(range);
Progress {
reporter: match &mut self.reporter {
ProgressInner::Send(f) => ProgressInner::Send(*f),
ProgressInner::Bound(f) => ProgressInner::Bound(*f),
ProgressInner::None => ProgressInner::None,
},
range,
cancel: self.cancel.clone(),
}
}
pub(crate) fn check_cancelled(&self) -> Result<(), EncodingError> {
if self.is_cancelled() {
Err(EncodingError::Cancelled)
} else {
Ok(())
}
}
pub(crate) fn checked_report(&mut self, progress: f32) -> Result<(), EncodingError> {
self.check_cancelled()?;
self.report(progress);
Ok(())
}
pub(crate) fn checked_report_if(
&mut self,
cond: bool,
progress: f32,
) -> Result<(), EncodingError> {
if cond {
self.checked_report(progress)
} else {
self.check_cancelled()
}
}
}
/// A shareable flag for requesting cooperative cancellation.
///
/// Clones share one `Arc<AtomicBool>`, so cancelling or resetting through any clone is observed
/// by all of them. All accesses use `SeqCst` ordering.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a token that is not cancelled.
    pub fn new() -> Self {
        Self {
            cancelled: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Marks this token (and every clone of it) as cancelled.
    pub fn cancel(&self) {
        self.cancelled
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Clears the cancelled flag for this token and every clone of it.
    pub fn reset(&self) {
        self.cancelled
            .store(false, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns whether cancellation is currently requested.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(std::sync::atomic::Ordering::SeqCst)
    }
}

impl Default for CancellationToken {
    /// Same as [`CancellationToken::new`]: an uncancelled token.
    fn default() -> Self {
        Self::new()
    }
}

// Compile-time guarantee that tokens can be shared with and moved to other threads.
const _: () = {
    fn assert_send_sync<T: Send + Sync>() {}
    let _ = assert_send_sync::<CancellationToken>;
};
/// State guarded by the `ParallelProgress` mutex: the number of work units submitted so far and
/// the reporter that receives the projected progress.
type InnerState<'a> = (u64, &'a mut dyn ThreadSafeReporter);

/// A progress tracker whose `submit` takes `&self`, so it can be shared by reference between
/// threads. Work is counted in integer units against `total`, and updates are serialized
/// through a `Mutex`.
pub(crate) struct ParallelProgress<'a> {
    /// `None` when the originating `Progress` had no `Send` reporter; submissions are then
    /// ignored.
    progress: Option<Mutex<InnerState<'a>>>,
    /// Number of work units that corresponds to 100%.
    total: u64,
    /// Range of the originating `Progress` that the completed fraction is mapped into.
    range: ProgressRange,
    /// Cancellation token inherited from the originating `Progress`.
    cancel: Option<CancellationToken>,
}

impl<'a> ParallelProgress<'a> {
    /// Creates a parallel tracker that reports into `progress`'s range and reporter, expecting
    /// `total` units of work.
    ///
    /// Only a `Send` reporter can be shared this way. For a reporter created with
    /// `Progress::new_single_threaded`, or for `Progress::none`, submissions report nothing.
    pub fn new(progress: &'a mut &mut Progress, total: u64) -> Self {
        Self {
            progress: match &mut progress.reporter {
                ProgressInner::Send(f) => Some(Mutex::new((0, *f))),
                // Non-`Send` reporters cannot be put behind a shareable mutex.
                ProgressInner::Bound(_) | ProgressInner::None => None,
            },
            total,
            range: progress.range,
            cancel: progress.cancel.clone(),
        }
    }

    /// Records `progress` more completed units and reports the new overall fraction.
    ///
    /// When `total` is zero, the work is treated as complete and `1.0` is reported; this avoids
    /// a `0 / 0` NaN. The fraction is clamped to `1.0` if more than `total` units are submitted,
    /// so the projection never sees an out-of-range value.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned, i.e. a previous reporter call panicked while the lock
    /// was held.
    pub fn submit(&self, progress: u64) {
        if let Some(mutex) = &self.progress {
            let mut guard = mutex.lock().unwrap();
            let (done, reporter) = &mut *guard;
            *done = done.saturating_add(progress);
            let fraction = if self.total == 0 {
                1.0
            } else {
                (*done as f32 / self.total as f32).min(1.0)
            };
            reporter.report(self.range.project(fraction));
        }
    }

    /// Returns `true` only when a token is attached and it has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        matches!(&self.cancel, Some(token) if token.is_cancelled())
    }

    /// Returns `Err(EncodingError::Cancelled)` if cancellation was requested, `Ok(())` otherwise.
    pub fn check_cancelled(&self) -> Result<(), EncodingError> {
        if self.is_cancelled() {
            Err(EncodingError::Cancelled)
        } else {
            Ok(())
        }
    }
}