use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// A point-in-time snapshot of a write's progress, handed to a [`ProgressCallback`].
///
/// Snapshots are plain values: they can be cloned, stored, and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WriteProgress {
    /// Time elapsed since the tracker was created, measured when the snapshot was taken.
    elapsed: Duration,
    /// Rows recorded so far.
    output_rows: usize,
    /// Bytes reported on the wire if any were recorded, otherwise the in-memory byte total.
    output_bytes: usize,
    /// Expected total row count, if known (filled in from the row count on the final snapshot).
    total_rows: Option<usize>,
    /// Number of tasks currently being tracked as active.
    active_tasks: usize,
    /// Total number of tasks the write is split into.
    total_tasks: usize,
    /// `true` only on the final snapshot of a write.
    done: bool,
}
impl WriteProgress {
    /// Time elapsed since the tracker was created, measured when this snapshot was taken.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.elapsed
    }

    /// Number of rows recorded so far.
    #[must_use]
    pub fn output_rows(&self) -> usize {
        self.output_rows
    }

    /// Bytes written so far.
    ///
    /// This is the wire-byte count once any wire bytes have been recorded; before
    /// that it falls back to the in-memory size of the recorded batches.
    #[must_use]
    pub fn output_bytes(&self) -> usize {
        self.output_bytes
    }

    /// Expected total number of rows, when known.
    ///
    /// Intermediate snapshots carry whatever total the tracker was created with
    /// (possibly `None`); the final snapshot (`done() == true`) always carries
    /// `Some`, falling back to the number of rows recorded.
    #[must_use]
    pub fn total_rows(&self) -> Option<usize> {
        self.total_rows
    }

    /// Number of tasks that are active at the time of the snapshot.
    #[must_use]
    pub fn active_tasks(&self) -> usize {
        self.active_tasks
    }

    /// Total number of tasks the write is split into (1 unless set explicitly).
    #[must_use]
    pub fn total_tasks(&self) -> usize {
        self.total_tasks
    }

    /// Whether this is the final snapshot of the write.
    #[must_use]
    pub fn done(&self) -> bool {
        self.done
    }
}
/// Shared, user-supplied callback that receives [`WriteProgress`] snapshots.
///
/// `WriteProgressTracker` locks this mutex for the whole duration of each
/// invocation, so calls never overlap. If the mutex is poisoned, the tracker
/// recovers the inner callback instead of propagating the poison.
pub type ProgressCallback = Arc<Mutex<dyn FnMut(&WriteProgress) + Send>>;
impl std::fmt::Debug for WriteProgressTracker {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WriteProgressTracker")
.field("total_rows", &self.total_rows)
.finish()
}
}
pub(crate) struct WriteProgressTracker {
rows_and_bytes: std::sync::Mutex<(usize, usize)>,
wire_bytes: AtomicUsize,
active_tasks: Arc<AtomicUsize>,
total_tasks: AtomicUsize,
start: Instant,
total_rows: Option<usize>,
callback: ProgressCallback,
}
impl WriteProgressTracker {
pub fn new(callback: ProgressCallback, total_rows: Option<usize>) -> Self {
Self {
rows_and_bytes: std::sync::Mutex::new((0, 0)),
wire_bytes: AtomicUsize::new(0),
active_tasks: Arc::new(AtomicUsize::new(0)),
total_tasks: AtomicUsize::new(1),
start: Instant::now(),
total_rows,
callback,
}
}
pub fn set_total_tasks(&self, n: usize) {
self.total_tasks.store(n, Ordering::Relaxed);
}
pub fn track_task(&self) -> ActiveTaskGuard {
self.active_tasks.fetch_add(1, Ordering::Relaxed);
ActiveTaskGuard(self.active_tasks.clone())
}
pub fn record_batch(&self, rows: usize, bytes: usize) {
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
let mut guard = self
.rows_and_bytes
.lock()
.unwrap_or_else(|e| e.into_inner());
guard.0 += rows;
guard.1 += bytes;
let progress = self.snapshot(guard.0, guard.1, false);
drop(guard);
cb(&progress);
}
pub fn record_bytes(&self, bytes: usize) {
self.wire_bytes.fetch_add(bytes, Ordering::Relaxed);
}
pub fn finish(&self) {
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
let guard = self
.rows_and_bytes
.lock()
.unwrap_or_else(|e| e.into_inner());
let mut snap = self.snapshot(guard.0, guard.1, true);
snap.total_rows = Some(self.total_rows.unwrap_or(guard.0));
drop(guard);
cb(&snap);
}
fn snapshot(&self, rows: usize, in_memory_bytes: usize, done: bool) -> WriteProgress {
let wire = self.wire_bytes.load(Ordering::Relaxed);
let output_bytes = if wire > 0 { wire } else { in_memory_bytes };
WriteProgress {
elapsed: self.start.elapsed(),
output_rows: rows,
output_bytes,
total_rows: self.total_rows,
active_tasks: self.active_tasks.load(Ordering::Relaxed),
total_tasks: self.total_tasks.load(Ordering::Relaxed),
done,
}
}
}
/// Keeps one task counted as active; decrements the shared counter when dropped.
///
/// Returned by `WriteProgressTracker::track_task`, which increments the counter
/// before constructing the guard. Discarding the guard straight away
/// un-counts the task immediately, so the value is `#[must_use]`.
#[must_use = "the task is only counted as active while this guard is alive"]
pub(crate) struct ActiveTaskGuard(Arc<AtomicUsize>);

impl Drop for ActiveTaskGuard {
    fn drop(&mut self) {
        // Balances the increment performed when the guard was handed out.
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Calls `WriteProgressTracker::finish` (at most once) when dropped, if it holds a tracker.
///
/// Because this runs from `Drop`, the final `done = true` report fires on every
/// exit path of the owning scope, including early returns.
/// NOTE(review): this also runs during panic unwinding, so the user callback may
/// execute while the thread is already panicking; a panic inside it would abort.
pub(crate) struct FinishOnDrop(pub Option<Arc<WriteProgressTracker>>);

impl Drop for FinishOnDrop {
    fn drop(&mut self) {
        // Move the tracker out so `finish` can never be reached a second time.
        if let Some(tracker) = std::mem::take(&mut self.0) {
            tracker.finish();
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use arrow_array::record_batch;
    use crate::connect;

    #[tokio::test]
    async fn test_progress_monitor_fires_callback() {
        let db = connect("memory://").execute().await.unwrap();
        let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
        let table = db
            .create_table("progress_test", batch)
            .execute()
            .await
            .unwrap();
        let callback_count = Arc::new(AtomicUsize::new(0));
        let last_rows = Arc::new(AtomicUsize::new(0));
        let max_active = Arc::new(AtomicUsize::new(0));
        let last_total_tasks = Arc::new(AtomicUsize::new(0));
        let cb_count = callback_count.clone();
        let cb_rows = last_rows.clone();
        let cb_active = max_active.clone();
        let cb_total_tasks = last_total_tasks.clone();
        let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
        table
            .add(new_data)
            .progress(move |p| {
                cb_count.fetch_add(1, Ordering::SeqCst);
                cb_rows.store(p.output_rows(), Ordering::SeqCst);
                cb_active.fetch_max(p.active_tasks(), Ordering::SeqCst);
                cb_total_tasks.store(p.total_tasks(), Ordering::SeqCst);
            })
            .execute()
            .await
            .unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 6);
        assert!(callback_count.load(Ordering::SeqCst) >= 1);
        assert_eq!(last_rows.load(Ordering::SeqCst), 3);
        assert!(max_active.load(Ordering::SeqCst) >= 1);
        assert!(last_total_tasks.load(Ordering::SeqCst) >= 1);
    }

    #[tokio::test]
    async fn test_progress_done_fires_at_end() {
        let db = connect("memory://").execute().await.unwrap();
        let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
        let table = db
            .create_table("progress_done", batch)
            .execute()
            .await
            .unwrap();
        let seen_done = Arc::new(std::sync::Mutex::new(Vec::<bool>::new()));
        let seen = seen_done.clone();
        let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
        table
            .add(new_data)
            .progress(move |p| {
                seen.lock().unwrap().push(p.done());
            })
            .execute()
            .await
            .unwrap();
        let done_flags = seen_done.lock().unwrap();
        assert!(!done_flags.is_empty(), "at least one callback must fire");
        let last = *done_flags.last().unwrap();
        assert!(last, "last callback must have done=true");
        for &d in done_flags.iter().rev().skip(1) {
            assert!(!d, "non-final callbacks must have done=false");
        }
    }

    #[tokio::test]
    async fn test_progress_total_rows_known() {
        let db = connect("memory://").execute().await.unwrap();
        let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
        let table = db
            .create_table("total_known", batch)
            .execute()
            .await
            .unwrap();
        let seen_total = Arc::new(std::sync::Mutex::new(Vec::new()));
        let seen = seen_total.clone();
        let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
        table
            .add(new_data)
            .progress(move |p| {
                seen.lock().unwrap().push(p.total_rows());
            })
            .execute()
            .await
            .unwrap();
        let totals = seen_total.lock().unwrap();
        assert!(
            totals.contains(&Some(3)),
            "expected total_rows=Some(3) in at least one callback, got: {:?}",
            *totals
        );
    }

    #[tokio::test]
    async fn test_progress_total_rows_unknown() {
        use arrow_array::RecordBatchIterator;
        let db = connect("memory://").execute().await.unwrap();
        let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
        let table = db
            .create_table("total_unknown", batch)
            .execute()
            .await
            .unwrap();
        let seen_total = Arc::new(std::sync::Mutex::new(Vec::new()));
        let seen = seen_total.clone();
        let schema = arrow_schema::Schema::new(vec![arrow_schema::Field::new(
            "id",
            arrow_schema::DataType::Int32,
            false,
        )]);
        let new_data: Box<dyn arrow_array::RecordBatchReader + Send> =
            Box::new(RecordBatchIterator::new(
                vec![Ok(record_batch!(("id", Int32, [4, 5, 6])).unwrap())],
                Arc::new(schema),
            ));
        table
            .add(new_data)
            .progress(move |p| {
                seen.lock().unwrap().push((p.total_rows(), p.done()));
            })
            .execute()
            .await
            .unwrap();
        let entries = seen_total.lock().unwrap();
        assert!(!entries.is_empty(), "at least one callback must fire");
        for (total, done) in entries.iter() {
            if *done {
                assert!(
                    total.is_some(),
                    "done callback must have total_rows set, got: {:?}",
                    total
                );
            } else {
                assert_eq!(
                    *total, None,
                    "intermediate callback must have total_rows=None, got: {:?}",
                    total
                );
            }
        }
    }

    /// Poisons `callback`'s mutex by panicking on another thread while holding
    /// the lock, then checks that the mutex really is poisoned.
    fn poison(callback: &super::ProgressCallback) {
        let cb_clone = callback.clone();
        let handle = std::thread::spawn(move || {
            let _guard = cb_clone.lock().unwrap();
            panic!("intentional panic to poison callback mutex");
        });
        let _ = handle.join();
        assert!(
            callback.lock().is_err(),
            "callback mutex should be poisoned"
        );
    }

    #[test]
    fn test_record_batch_recovers_from_poisoned_callback_lock() {
        use super::{ProgressCallback, WriteProgressTracker};
        use std::sync::Mutex;
        let seen = Arc::new(Mutex::new(Vec::<(usize, usize, bool)>::new()));
        let sink = seen.clone();
        let callback: ProgressCallback =
            Arc::new(Mutex::new(move |p: &super::WriteProgress| {
                sink.lock()
                    .unwrap()
                    .push((p.output_rows(), p.output_bytes(), p.done()));
            }));
        poison(&callback);
        let tracker = WriteProgressTracker::new(callback, Some(100));
        tracker.record_batch(10, 1024);
        // Recovery must not just avoid panicking: the callback still has to run.
        assert_eq!(*seen.lock().unwrap(), vec![(10, 1024, false)]);
    }

    #[test]
    fn test_finish_recovers_from_poisoned_callback_lock() {
        use super::{ProgressCallback, WriteProgressTracker};
        use std::sync::Mutex;
        let seen = Arc::new(Mutex::new(Vec::<(bool, Option<usize>)>::new()));
        let sink = seen.clone();
        let callback: ProgressCallback =
            Arc::new(Mutex::new(move |p: &super::WriteProgress| {
                sink.lock().unwrap().push((p.done(), p.total_rows()));
            }));
        poison(&callback);
        let tracker = WriteProgressTracker::new(callback, Some(100));
        tracker.finish();
        // The final report must still be delivered after recovering the lock.
        assert_eq!(*seen.lock().unwrap(), vec![(true, Some(100))]);
    }
}