use std::panic;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use crate::progress::{IndexProgress, SharedReporter};
/// Delivers `event` to `reporter`, containing any panic the reporter raises.
///
/// Progress reporting is auxiliary, so a faulty reporter must not take the
/// caller down with it: a panic is caught, its payload is turned into text
/// when it is a `&str` or a `String`, and a warning is logged instead.
fn safe_report(reporter: &SharedReporter, event: IndexProgress) {
    let deliver = panic::AssertUnwindSafe(move || reporter.report(event));
    let Err(payload) = panic::catch_unwind(deliver) else {
        return;
    };
    // `panic!("literal")` carries a `&str`; formatted panics carry a `String`.
    let msg = payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic".to_string());
    log::warn!("Progress reporter panicked (ignored): {msg}");
}
/// Minimum spacing between two throttled per-item progress events
/// (17 ms, i.e. roughly 60 updates per second).
const MIN_UPDATE_INTERVAL: Duration = Duration::from_micros(17_000);
/// Tracks progress through the phases of a graph build and forwards it to a
/// `SharedReporter` as `IndexProgress` events.
///
/// Phase metadata lives behind a mutex while the per-item counters are
/// atomics, so every method takes `&self`. Per-item progress events are
/// throttled to at most one per `MIN_UPDATE_INTERVAL`.
pub struct GraphBuildProgressTracker {
    /// Destination for every progress event this tracker emits.
    reporter: SharedReporter,
    /// Number, name and timing of the phase currently being tracked.
    phase_state: Mutex<PhaseState>,
    /// Items processed so far in the current phase; reset by `start_phase`.
    items_processed: AtomicUsize,
    /// Expected item count for the current phase, as passed to `start_phase`.
    total_items: AtomicUsize,
}
/// Mutable bookkeeping for the phase currently being tracked.
struct PhaseState {
    /// Number of the current phase (0 before any phase has been started).
    phase_number: u8,
    /// Name of the current phase (empty before any phase has been started).
    phase_name: &'static str,
    /// When the current phase started; used to compute the phase duration.
    phase_start: Instant,
    /// When the last throttled progress event was emitted, or the phase began.
    last_update: Instant,
}
impl Default for PhaseState {
    /// An unnamed "phase 0" whose timestamps are taken at construction time.
    fn default() -> Self {
        // Two separate clock reads, in the same order as the field list.
        let phase_start = Instant::now();
        let last_update = Instant::now();
        Self {
            phase_number: 0,
            phase_name: "",
            phase_start,
            last_update,
        }
    }
}
impl GraphBuildProgressTracker {
#[must_use]
pub fn new(reporter: SharedReporter) -> Self {
Self {
reporter,
phase_state: Mutex::new(PhaseState::default()),
items_processed: AtomicUsize::new(0),
total_items: AtomicUsize::new(0),
}
}
pub fn start_phase(&self, phase_number: u8, phase_name: &'static str, total_items: usize) {
self.items_processed.store(0, Ordering::SeqCst);
self.total_items.store(total_items, Ordering::SeqCst);
{
let mut state = self.phase_state.lock().unwrap();
state.phase_number = phase_number;
state.phase_name = phase_name;
state.phase_start = Instant::now();
state.last_update = Instant::now();
}
safe_report(
&self.reporter,
IndexProgress::GraphPhaseStarted {
phase_number,
phase_name,
total_items,
},
);
}
pub fn increment_progress(&self) {
self.add_progress(1);
}
pub fn add_progress(&self, count: usize) {
let new_count = self.items_processed.fetch_add(count, Ordering::SeqCst) + count;
self.maybe_emit_progress(new_count);
}
fn maybe_emit_progress(&self, items_processed: usize) {
let total = self.total_items.load(Ordering::SeqCst);
let emit_info = {
let Ok(mut state) = self.phase_state.try_lock() else {
return;
};
let now = Instant::now();
if now.duration_since(state.last_update) >= MIN_UPDATE_INTERVAL {
state.last_update = now;
Some(state.phase_number)
} else {
None
}
};
if let Some(phase_number) = emit_info {
safe_report(
&self.reporter,
IndexProgress::GraphPhaseProgress {
phase_number,
items_processed,
total_items: total,
},
);
}
}
pub fn complete_phase(&self) {
let (phase_number, phase_name, phase_duration) = {
let state = self.phase_state.lock().unwrap();
(
state.phase_number,
state.phase_name,
state.phase_start.elapsed(),
)
};
safe_report(
&self.reporter,
IndexProgress::GraphPhaseCompleted {
phase_number,
phase_name,
phase_duration,
},
);
}
pub fn start_saving(&self, component_name: &'static str) {
safe_report(
&self.reporter,
IndexProgress::SavingStarted { component_name },
);
}
pub fn complete_saving(&self, component_name: &'static str, save_duration: Duration) {
safe_report(
&self.reporter,
IndexProgress::SavingCompleted {
component_name,
save_duration,
},
);
}
#[cfg(test)]
pub fn current_progress(&self) -> usize {
self.items_processed.load(Ordering::SeqCst)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::progress::no_op_reporter;
    use std::sync::Arc;

    /// Reporter that records every event it receives so tests can inspect them.
    struct EventCapture {
        events: Mutex<Vec<IndexProgress>>,
    }

    impl EventCapture {
        fn new() -> Arc<Self> {
            let capture = Self {
                events: Mutex::new(Vec::new()),
            };
            Arc::new(capture)
        }

        /// Copy of every event recorded so far, in arrival order.
        fn events(&self) -> Vec<IndexProgress> {
            let recorded = self.events.lock().unwrap();
            recorded.to_vec()
        }

        /// Number of events recorded so far.
        fn event_count(&self) -> usize {
            let recorded = self.events.lock().unwrap();
            recorded.len()
        }
    }

    impl crate::progress::ProgressReporter for EventCapture {
        fn report(&self, event: IndexProgress) {
            let mut recorded = self.events.lock().unwrap();
            recorded.push(event);
        }
    }

    /// Builds a tracker wired to a fresh `EventCapture` and returns both.
    fn capturing_tracker() -> (Arc<EventCapture>, GraphBuildProgressTracker) {
        let capture = EventCapture::new();
        let tracker = GraphBuildProgressTracker::new(capture.clone());
        (capture, tracker)
    }

    #[test]
    fn test_phase_lifecycle() {
        let (capture, tracker) = capturing_tracker();
        tracker.start_phase(1, "Test phase", 100);
        tracker.complete_phase();

        let recorded = capture.events();
        assert_eq!(recorded.len(), 2);
        let started_ok = matches!(
            recorded[0],
            IndexProgress::GraphPhaseStarted {
                phase_number: 1,
                phase_name: "Test phase",
                total_items: 100
            }
        );
        assert!(started_ok);
        let completed_ok = matches!(
            recorded[1],
            IndexProgress::GraphPhaseCompleted {
                phase_number: 1,
                phase_name: "Test phase",
                ..
            }
        );
        assert!(completed_ok);
    }

    #[test]
    fn test_progress_increment() {
        let (capture, tracker) = capturing_tracker();
        tracker.start_phase(2, "Increment test", 10);

        tracker.increment_progress();
        assert_eq!(tracker.current_progress(), 1);

        tracker.complete_phase();
        assert!(capture.event_count() >= 2);
    }

    #[test]
    fn test_saving_events() {
        let (capture, tracker) = capturing_tracker();
        tracker.start_saving("symbols");
        tracker.complete_saving("symbols", Duration::from_millis(100));

        let recorded = capture.events();
        assert_eq!(recorded.len(), 2);
        let started_ok = matches!(
            recorded[0],
            IndexProgress::SavingStarted {
                component_name: "symbols"
            }
        );
        assert!(started_ok);
        let completed_ok = matches!(
            recorded[1],
            IndexProgress::SavingCompleted {
                component_name: "symbols",
                ..
            }
        );
        assert!(completed_ok);
    }

    #[test]
    fn test_no_op_reporter_no_panic() {
        let tracker = GraphBuildProgressTracker::new(no_op_reporter());
        tracker.start_phase(1, "No-op test", 1000);
        (0..1000).for_each(|_| tracker.increment_progress());
        tracker.complete_phase();
    }

    #[test]
    fn test_throttling_limits_updates() {
        let (capture, tracker) = capturing_tracker();
        tracker.start_phase(3, "Throttle test", 10000);
        (0..1000).for_each(|_| tracker.increment_progress());
        tracker.complete_phase();

        let progress_events = capture
            .events()
            .into_iter()
            .filter(|event| matches!(event, IndexProgress::GraphPhaseProgress { .. }))
            .count();
        assert!(
            progress_events < 100,
            "Expected throttling to limit updates"
        );
    }

    /// Reporter that panics on every event, to exercise panic isolation.
    struct PanickingReporter;

    impl crate::progress::ProgressReporter for PanickingReporter {
        fn report(&self, _event: IndexProgress) {
            panic!("Intentional test panic from PanickingReporter");
        }
    }

    #[test]
    fn test_safe_report_catches_panics() {
        let reporter: SharedReporter = Arc::new(PanickingReporter);
        let event = IndexProgress::SavingStarted {
            component_name: "test",
        };
        safe_report(&reporter, event);
    }

    #[test]
    fn test_tracker_with_panicking_reporter_continues() {
        let tracker = GraphBuildProgressTracker::new(Arc::new(PanickingReporter));

        tracker.start_phase(1, "Panic test", 100);
        tracker.increment_progress();
        tracker.add_progress(5);
        tracker.complete_phase();

        tracker.start_saving("test");
        tracker.complete_saving("test", Duration::from_millis(10));

        assert_eq!(tracker.current_progress(), 6);
    }
}