use std::time::Instant;
use super::params::GraphAlgorithm;
/// Point-in-time view of a running graph algorithm, as produced by
/// `ProgressReporter::snapshot`.
#[derive(Debug, Clone)]
pub struct AlgoProgress {
    /// Algorithm the snapshot describes.
    pub algorithm: GraphAlgorithm,
    /// Iteration number supplied by the caller when the snapshot was taken.
    pub iteration: usize,
    /// Iteration budget the reporter was created with.
    pub max_iterations: usize,
    /// Most recently reported convergence delta, or `None` if no delta has
    /// been reported.
    pub convergence_delta: Option<f64>,
    /// Convergence threshold the reporter was created with, if any.
    pub tolerance: Option<f64>,
    /// Milliseconds elapsed since the reporter was created.
    pub elapsed_ms: u64,
    /// Node count the reporter was created with.
    pub nodes_processed: usize,
    /// `true` only when both a delta and a tolerance are present and the
    /// delta is strictly below the tolerance.
    pub converged: bool,
}
/// Tracks timing and convergence state for one run of a graph algorithm and
/// emits `tracing` events as the run progresses.
///
/// `Debug` is derived so the reporter can appear in diagnostics and inside
/// other `Debug` types, matching `AlgoProgress`.
#[derive(Debug)]
pub struct ProgressReporter {
    /// Algorithm being run.
    algorithm: GraphAlgorithm,
    /// Iteration budget; recorded and logged, not enforced by the reporter.
    max_iterations: usize,
    /// Convergence threshold, or `None` when no threshold was supplied.
    tolerance: Option<f64>,
    /// Node count supplied at construction.
    node_count: usize,
    /// Moment the reporter was constructed; the base for all elapsed times.
    start: Instant,
    /// Delta passed to the most recent `report_iteration` call, if any.
    last_delta: Option<f64>,
}
impl ProgressReporter {
    /// Creates a reporter, logs an `info` "graph algorithm started" event, and
    /// starts the elapsed-time clock.
    ///
    /// * `algorithm` - algorithm being run; its `name()` is used in log events.
    /// * `max_iterations` - iteration budget; recorded and logged only.
    /// * `tolerance` - convergence threshold, or `None` if there is none.
    /// * `node_count` - node count to record and log.
    pub fn new(
        algorithm: GraphAlgorithm,
        max_iterations: usize,
        tolerance: Option<f64>,
        node_count: usize,
    ) -> Self {
        tracing::info!(
            algorithm = algorithm.name(),
            max_iterations,
            node_count,
            "graph algorithm started"
        );
        Self {
            algorithm,
            max_iterations,
            tolerance,
            node_count,
            start: Instant::now(),
            last_delta: None,
        }
    }

    /// Records `convergence_delta` as the latest delta and logs a `debug`
    /// "graph algorithm iteration" event.
    ///
    /// `iteration` is only logged; it is not stored. A missing delta is logged
    /// as `0.0`, so a zero in the log does not by itself mean a delta was
    /// reported.
    pub fn report_iteration(&mut self, iteration: usize, convergence_delta: Option<f64>) {
        self.last_delta = convergence_delta;
        let elapsed_ms = self.elapsed_ms();
        let converged = self.converged_or(false);
        tracing::debug!(
            algorithm = self.algorithm.name(),
            iteration,
            max_iterations = self.max_iterations,
            convergence_delta = convergence_delta.unwrap_or(0.0),
            elapsed_ms,
            converged,
            "graph algorithm iteration"
        );
    }

    /// Builds an [`AlgoProgress`] describing the current state.
    ///
    /// `iteration` is copied into the snapshot as given. `converged` is `true`
    /// only when a delta has been reported, a tolerance is set, and the delta
    /// is strictly below the tolerance.
    pub fn snapshot(&self, iteration: usize) -> AlgoProgress {
        let converged = self.converged_or(false);
        AlgoProgress {
            algorithm: self.algorithm,
            iteration,
            max_iterations: self.max_iterations,
            convergence_delta: self.last_delta,
            tolerance: self.tolerance,
            elapsed_ms: self.elapsed_ms(),
            nodes_processed: self.node_count,
            converged,
        }
    }

    /// Logs an `info` "graph algorithm completed" event with the final state.
    ///
    /// Unlike [`snapshot`](Self::snapshot) and
    /// [`report_iteration`](Self::report_iteration), a run with no reported
    /// delta or no tolerance is logged here as `converged = true`.
    /// NOTE(review): presumably because runs without a tolerance are treated
    /// as trivially complete - confirm this asymmetry is intended.
    ///
    /// A missing final delta is logged as `0.0`.
    pub fn finish(&self) {
        let elapsed_ms = self.elapsed_ms();
        let converged = self.converged_or(true);
        tracing::info!(
            algorithm = self.algorithm.name(),
            elapsed_ms,
            converged,
            final_delta = self.last_delta.unwrap_or(0.0),
            node_count = self.node_count,
            "graph algorithm completed"
        );
    }

    /// Returns the milliseconds elapsed since the reporter was constructed.
    pub fn elapsed_ms(&self) -> u64 {
        // `as_millis` yields a u128. The narrowing cast could only truncate
        // after roughly 584 million years of elapsed time, so it is kept.
        self.start.elapsed().as_millis() as u64
    }

    /// Single source of truth for the convergence test: `true` when the latest
    /// delta is strictly below the tolerance, `fallback` when either the delta
    /// or the tolerance is missing.
    fn converged_or(&self, fallback: bool) -> bool {
        match (self.last_delta, self.tolerance) {
            (Some(delta), Some(tolerance)) => delta < tolerance,
            _ => fallback,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A run whose last delta drops below the tolerance reports convergence,
    /// and the snapshot carries the reporter's configuration through.
    #[test]
    fn progress_reporter_lifecycle() {
        let mut reporter = ProgressReporter::new(GraphAlgorithm::PageRank, 20, Some(1e-7), 1000);
        reporter.report_iteration(1, Some(0.5));
        reporter.report_iteration(2, Some(0.01));
        reporter.report_iteration(3, Some(1e-8));

        let snap = reporter.snapshot(3);
        assert!(snap.converged);
        assert_eq!(snap.algorithm, GraphAlgorithm::PageRank);
        assert_eq!(snap.iteration, 3);
        assert_eq!(snap.max_iterations, 20);
        assert_eq!(snap.tolerance, Some(1e-7));
        assert_eq!(snap.nodes_processed, 1000);

        reporter.finish();
    }

    /// With no tolerance and no reported delta the snapshot is not converged.
    #[test]
    fn progress_non_iterative() {
        let reporter = ProgressReporter::new(GraphAlgorithm::Wcc, 1, None, 500);
        let snap = reporter.snapshot(1);
        assert!(!snap.converged);
        assert!(snap.convergence_delta.is_none());
        reporter.finish();
    }

    /// A reported delta without a tolerance never counts as converged.
    #[test]
    fn progress_not_converged() {
        let mut reporter = ProgressReporter::new(GraphAlgorithm::LabelPropagation, 10, None, 100);
        reporter.report_iteration(1, Some(50.0));
        let snap = reporter.snapshot(1);
        assert!(!snap.converged);
    }

    /// A delta above a configured tolerance is not converged.
    #[test]
    fn progress_delta_above_tolerance_not_converged() {
        let mut reporter = ProgressReporter::new(GraphAlgorithm::PageRank, 10, Some(1e-3), 100);
        reporter.report_iteration(1, Some(50.0));
        let snap = reporter.snapshot(1);
        assert!(!snap.converged);
        assert_eq!(snap.convergence_delta, Some(50.0));
    }

    /// The comparison is strict: a delta equal to the tolerance is not converged.
    #[test]
    fn progress_delta_equal_to_tolerance_not_converged() {
        let mut reporter = ProgressReporter::new(GraphAlgorithm::PageRank, 10, Some(1e-6), 100);
        reporter.report_iteration(1, Some(1e-6));
        assert!(!reporter.snapshot(1).converged);
    }

    /// Only the most recently reported delta determines convergence.
    #[test]
    fn snapshot_tracks_latest_delta() {
        let mut reporter = ProgressReporter::new(GraphAlgorithm::PageRank, 10, Some(1e-7), 100);
        reporter.report_iteration(1, Some(1e-8));
        assert!(reporter.snapshot(1).converged);

        reporter.report_iteration(2, Some(0.5));
        let snap = reporter.snapshot(2);
        assert!(!snap.converged);
        assert_eq!(snap.convergence_delta, Some(0.5));
    }

    /// Elapsed time advances with wall-clock time.
    #[test]
    fn elapsed_increases() {
        let reporter = ProgressReporter::new(GraphAlgorithm::Sssp, 1, None, 10);
        std::thread::sleep(std::time::Duration::from_millis(5));
        assert!(reporter.elapsed_ms() >= 4);
    }
}