use std::collections::VecDeque;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::time::{Duration, Instant};

use serde::{Deserialize, Serialize};

use crate::algebra::prelude::*;
use crate::preconditioner::stats::ParIluIterSample;
/// Progress notifications delivered to a [`Monitor`] during ILU preconditioner setup.
///
/// `'a` is the lifetime of data borrowed by a variant (currently the per-iteration sample),
/// so an event can be emitted without copying that data.
pub enum Event<'a> {
    /// ILU setup is starting.
    IluSetupBegin {
        /// Hash of the setup options in use (`TextMonitor` prints it as 16 hex digits).
        opts_hash: u64,
    },
    /// An ILU setup iteration was performed; `sample` carries its statistics
    /// (`TextMonitor` reads the `iter`, `residual` and `time_s` fields).
    IluSetupIter {
        sample: &'a ParIluIterSample,
    },
    /// ILU setup has finished.
    IluSetupEnd {
        /// Number of setup iterations performed.
        iters: u32,
        /// Whether the setup iteration reported convergence.
        converged: bool,
        /// Total setup time, in seconds.
        setup_time_s: R,
    },
}
/// Sink for [`Event`]s emitted while setting up a preconditioner.
///
/// Implementors must be `Send + Sync`, so a single monitor may be shared between threads.
/// Because `on_event` only receives `&self`, any state an implementor wants to update needs
/// interior mutability.
pub trait Monitor: Send + Sync {
    /// Handles one event. Borrowed payloads (see [`Event::IluSetupIter`]) are only valid for
    /// the duration of this call and cannot be retained.
    fn on_event(&self, ev: Event<'_>);
}
/// A [`Monitor`] that silently discards every event.
///
/// Use it as the no-op default when no progress reporting is wanted. It is a unit type, so it
/// is `Copy` and `Default` and can be created anywhere without arguments.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullMonitor;

impl Monitor for NullMonitor {
    /// Ignores the event.
    fn on_event(&self, _: Event<'_>) {}
}
/// Residual norms observed at one solver iteration, consumed by [`log_residuals`].
#[derive(Clone, Copy, Debug, Default)]
pub struct ResidualSnapshot {
    /// Norm of the true residual (presumably recomputed explicitly from the current
    /// iterate — confirm with callers). Logged as `true=`.
    pub true_residual: R,
    /// Norm of the preconditioned residual. Logged as `prec=`.
    pub preconditioned_residual: R,
    /// Residual norm tracked by the solver's own recurrence, if it keeps one.
    /// Logged as `recur=` only when present.
    pub recurrence_residual: Option<R>,
}
/// Emits one `Info`-level line with the residual norms of `solver` at `iteration`.
///
/// The `recur=` field is appended only when the snapshot carries a recurrence residual.
/// Without the `logging` feature this function does nothing.
#[inline]
pub fn log_residuals(iteration: usize, solver: &str, snapshot: ResidualSnapshot) {
    #[cfg(not(feature = "logging"))]
    let _ = (iteration, solver, snapshot);

    #[cfg(feature = "logging")]
    {
        // Skip formatting entirely when nobody listens at Info level.
        if !log::log_enabled!(log::Level::Info) {
            return;
        }
        let ResidualSnapshot {
            true_residual,
            preconditioned_residual,
            recurrence_residual,
        } = snapshot;
        if let Some(recur) = recurrence_residual {
            log::info!(
                "{solver}: it {iteration:>4} true={:.3e} prec={:.3e} recur={:.3e}",
                true_residual,
                preconditioned_residual,
                recur,
            );
        } else {
            log::info!(
                "{solver}: it {iteration:>4} true={:.3e} prec={:.3e}",
                true_residual,
                preconditioned_residual,
            );
        }
    }
}
/// Heuristic stagnation test on a short residual history.
///
/// For every consecutive pair in `recent` the ratio `recent[i + 1] / recent[i]` is formed.
/// The function returns `true` when the arithmetic mean of those ratios exceeds `threshold`.
/// A mean close to 1 means the residual barely shrinks per step.
///
/// Returns `false` in any of these cases:
/// - `recent` holds fewer than two values;
/// - any value used as a divisor (every element except the last) is `<= 0`;
/// - the mean ratio is `NaN`, because `NaN > threshold` is false.
#[inline]
pub fn stagnation_detected(recent: &[R], threshold: R) -> bool {
    if recent.len() < 2 {
        return false;
    }
    let zero = R::default();
    // Sum the ratios in one pass instead of first materializing them in a Vec;
    // `None` aborts the fold as soon as a non-positive divisor is found.
    let ratio_sum = recent.windows(2).try_fold(zero, |acc, pair| {
        let (prev, cur) = (pair[0], pair[1]);
        if prev <= zero {
            None
        } else {
            Some(acc + cur / prev)
        }
    });
    match ratio_sum {
        Some(sum) => {
            let pair_count = S::from_real((recent.len() - 1) as f64).real();
            sum / pair_count > threshold
        }
        None => false,
    }
}
/// Emits a `Warn`-level message that `solver` stagnated at `iteration`, reporting the
/// current `residual` and the remedial `action` being taken.
///
/// Compiles to a no-op unless the `logging` feature is enabled.
#[inline]
pub fn log_krylov_stagnation(solver: &str, iteration: usize, residual: R, action: &str) {
    #[cfg(feature = "logging")]
    {
        // Avoid formatting work when warnings are filtered out.
        if !log::log_enabled!(log::Level::Warn) {
            return;
        }
        log::warn!(
            "{solver}: stagnation detected at it {iteration} (res={residual:.3e}); {action}"
        );
    }
    #[cfg(not(feature = "logging"))]
    {
        let _ = (solver, iteration, residual, action);
    }
}
/// A [`Monitor`] that reports ILU setup progress through the `log` crate at `Info` level.
///
/// Only an instance whose `rank0` flag is set prints anything. Presumably the flag is set
/// only on the first process of a distributed run, so each message appears once (confirm
/// with callers). Without the `logging` feature every event is discarded.
pub struct TextMonitor {
    /// Whether this instance emits output.
    pub rank0: bool,
}

impl Monitor for TextMonitor {
    fn on_event(&self, ev: Event<'_>) {
        #[cfg(feature = "logging")]
        {
            if !self.rank0 {
                return;
            }
            match ev {
                Event::IluSetupBegin { opts_hash } => {
                    log::info!("ILU: setup begin (opts={opts_hash:016x})");
                }
                Event::IluSetupIter { sample } => {
                    log::info!(
                        "ILU: it {:>3} parilu_res≈{:.3e} dt={:.3e}s",
                        sample.iter,
                        sample.residual,
                        sample.time_s,
                    );
                }
                Event::IluSetupEnd {
                    iters,
                    converged,
                    setup_time_s,
                } => {
                    log::info!(
                        "ILU: setup end iters={iters} converged={converged} time={setup_time_s:.3}s"
                    );
                }
            }
        }
        #[cfg(not(feature = "logging"))]
        let _ = ev;
    }
}
/// Summary of a monitored solve, produced by `IterationMonitor::get_statistics`.
///
/// Derives serde `Serialize`/`Deserialize`, so it can be persisted alongside other results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceStats {
    /// Number of iterations counted by the monitor.
    pub total_iterations: usize,
    /// Residual norm at the start of the monitored history (zero/default if nothing was recorded).
    pub initial_residual: R,
    /// Residual norm of the most recently recorded iteration (zero/default if none).
    pub final_residual: R,
    /// Mean of the per-iteration residual ratios `r_k / r_{k-1}`; 1 if none were computed.
    pub avg_convergence_rate: R,
    /// Smallest residual ratio observed (fastest step); 1 if none were computed.
    pub best_convergence_rate: R,
    /// Largest residual ratio observed (slowest step); 1 if none were computed.
    pub worst_convergence_rate: R,
    /// Time from the start of the solve to the most recently recorded iteration.
    pub total_solve_time: Duration,
    /// Mean time between consecutive recorded iterations.
    pub avg_iteration_time: Duration,
    /// Mean preconditioner time over the iterations that reported one; `None` if none did.
    pub avg_pc_time: Option<Duration>,
    /// Whether the solve was marked as converged.
    pub converged: bool,
    /// Human-readable reason for the final status ("In progress" until marked).
    pub convergence_reason: String,
}
/// One entry of an `IterationMonitor` history, built by `record_iteration`.
#[derive(Debug, Clone)]
pub struct IterationData {
    /// Iteration index supplied by the caller.
    pub iteration: usize,
    /// Residual norm supplied by the caller.
    pub residual_norm: R,
    /// `residual_norm` divided by the previous entry's residual norm. `None` in any of these
    /// cases: iteration 0, no previous entry, rate computation disabled, or either norm not
    /// strictly positive.
    pub convergence_rate: Option<R>,
    /// Time elapsed since the previous entry was recorded. Zero for iteration 0 or when there
    /// is no previous entry.
    pub iteration_time: Duration,
    /// Preconditioner time for this iteration, if the caller reported one.
    pub pc_time: Option<Duration>,
    /// Instant at which this entry was recorded.
    pub timestamp: Instant,
}
pub struct IterationMonitor {
history: Vec<IterationData>,
csv_writer: Option<BufWriter<File>>,
solve_start_time: Option<Instant>,
max_history: usize,
compute_rates: bool,
converged: bool,
convergence_reason: String,
}
impl IterationMonitor {
pub fn new_with_capacity(max_history: usize) -> Self {
Self {
history: Vec::new(),
csv_writer: None,
solve_start_time: None,
max_history,
compute_rates: true,
converged: false,
convergence_reason: "In progress".to_string(),
}
}
pub fn new() -> Self {
Self::new_with_capacity(1000) }
pub fn enable_csv_logging(&mut self, filename: &str) -> Result<(), std::io::Error> {
let file = File::create(filename)?;
let mut writer = BufWriter::new(file);
writeln!(
writer,
"iteration,residual_norm,convergence_rate,iteration_time_ms,pc_time_ms,elapsed_s"
)?;
self.csv_writer = Some(writer);
Ok(())
}
pub fn start_solve(&mut self) {
self.solve_start_time = Some(Instant::now());
self.history.clear();
self.converged = false;
self.convergence_reason = "In progress".to_string();
}
pub fn record_iteration(
&mut self,
iteration: usize,
residual_norm: R,
pc_time: Option<Duration>,
) {
let now = Instant::now();
let iteration_time = if iteration == 0 {
Duration::from_nanos(0)
} else if let Some(prev) = self.history.last() {
now.duration_since(prev.timestamp)
} else {
Duration::from_nanos(0)
};
let convergence_rate = if iteration > 0 && self.compute_rates {
if let Some(prev) = self.history.last() {
if prev.residual_norm > R::default() && residual_norm > R::default() {
Some(residual_norm / prev.residual_norm)
} else {
None
}
} else {
None
}
} else {
None
};
let iter_data = IterationData {
iteration,
residual_norm,
convergence_rate,
iteration_time,
pc_time,
timestamp: now,
};
if self.max_history > 0 && self.history.len() >= self.max_history {
self.history.remove(0); }
self.history.push(iter_data.clone());
if let Some(ref mut writer) = self.csv_writer {
let rate_str = iter_data
.convergence_rate
.map(|r| format!("{r:.6e}"))
.unwrap_or_default();
let pc_time_str = iter_data
.pc_time
.map(|t| format!("{:.3}", t.as_secs_f64() * 1000.0))
.unwrap_or_default();
let elapsed_s = self
.solve_start_time
.map(|t0| t0.elapsed().as_secs_f64())
.unwrap_or_default();
let _ = writeln!(
writer,
"{},{:.6e},{},{:.3},{},{:.6}",
iteration,
residual_norm,
rate_str,
iteration_time.as_secs_f64() * 1000.0,
pc_time_str,
elapsed_s
);
let _ = writer.flush();
}
}
pub fn mark_converged(&mut self, reason: &str) {
self.converged = true;
self.convergence_reason = reason.to_string();
}
pub fn mark_diverged(&mut self, reason: &str) {
self.converged = false;
self.convergence_reason = reason.to_string();
}
pub fn get_statistics(&self) -> ConvergenceStats {
let total_iterations = self.history.len();
let initial_residual = self
.history
.first()
.map(|d| d.residual_norm)
.unwrap_or_default();
let final_residual = self
.history
.last()
.map(|d| d.residual_norm)
.unwrap_or_default();
let total_solve_time = self
.solve_start_time
.and_then(|start| {
self.history
.last()
.map(|last| last.timestamp.duration_since(start))
})
.unwrap_or_default();
let avg_iteration_time = if total_iterations > 1 {
let total_nanos = self
.history
.iter()
.skip(1) .map(|d| d.iteration_time.as_nanos())
.sum::<u128>();
let avg_nanos = (total_nanos / ((total_iterations - 1) as u128)) as u64;
Duration::from_nanos(avg_nanos)
} else {
Duration::from_nanos(0)
};
let avg_pc_time = {
let pc_times: Vec<_> = self.history.iter().filter_map(|d| d.pc_time).collect();
if !pc_times.is_empty() {
let total_pc_nanos: u128 = pc_times.iter().map(|t| t.as_nanos()).sum();
let avg_nanos = (total_pc_nanos / (pc_times.len() as u128)) as u64;
Some(Duration::from_nanos(avg_nanos))
} else {
None
}
};
let rates: Vec<R> = self
.history
.iter()
.filter_map(|d| d.convergence_rate)
.collect();
let (avg_convergence_rate, best_convergence_rate, worst_convergence_rate) =
if !rates.is_empty() {
let len_r = S::from_real(rates.len() as f64).real();
let avg = rates.iter().copied().sum::<R>() / len_r;
let best = rates.iter().fold(R::INFINITY, |a, &b| a.min(b));
let worst = rates.iter().fold(R::default(), |a, &b| a.max(b));
(avg, best, worst)
} else {
let one = S::one().real();
(one, one, one) };
ConvergenceStats {
total_iterations,
initial_residual,
final_residual,
avg_convergence_rate,
best_convergence_rate,
worst_convergence_rate,
total_solve_time,
avg_iteration_time,
avg_pc_time,
converged: self.converged,
convergence_reason: self.convergence_reason.clone(),
}
}
pub fn current_iteration(&self) -> usize {
self.history.len()
}
pub fn current_residual(&self) -> Option<R> {
self.history.last().map(|d| d.residual_norm)
}
pub fn recent_convergence_rate(&self, window: usize) -> Option<R> {
if self.history.len() < 2 {
return None;
}
let start_idx = self
.history
.len()
.saturating_sub(window.min(self.history.len()));
let recent_rates: Vec<R> = self.history[start_idx..]
.iter()
.filter_map(|d| d.convergence_rate)
.collect();
if recent_rates.is_empty() {
None
} else {
let len_r = S::from_real(recent_rates.len() as f64).real();
Some(recent_rates.iter().copied().sum::<R>() / len_r)
}
}
pub fn is_stagnating(&self, threshold: R, window: usize) -> bool {
if let Some(recent_rate) = self.recent_convergence_rate(window) {
recent_rate > threshold
} else {
false
}
}
}
impl Default for IterationMonitor {
fn default() -> Self {
Self::new()
}
}
impl Drop for IterationMonitor {
fn drop(&mut self) {
if let Some(ref mut writer) = self.csv_writer {
let _ = writer.flush();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Lifts an `f64` literal into the crate's real scalar type through `S`.
    fn to_real(value: f64) -> R {
        S::from_real(value).real()
    }

    /// Three halving residuals: checks counts, endpoints, status and average rate.
    #[test]
    fn test_monitor_basic_functionality() {
        let mut mon = IterationMonitor::new();
        mon.start_solve();
        mon.record_iteration(0, S::one().real(), None);
        mon.record_iteration(1, to_real(0.5), Some(Duration::from_millis(10)));
        mon.record_iteration(2, to_real(0.25), Some(Duration::from_millis(12)));
        mon.mark_converged("Relative tolerance achieved");

        let summary = mon.get_statistics();
        assert_eq!(summary.total_iterations, 3);
        crate::assert_s_close!(
            "initial residual",
            S::from_real(summary.initial_residual),
            S::one()
        );
        crate::assert_s_close!(
            "final residual",
            S::from_real(summary.final_residual),
            S::from_real(0.25)
        );
        assert!(summary.converged);
        assert!(summary.avg_convergence_rate < S::one().real());
    }

    /// Residuals shrinking by 10x per step give a recent rate of 0.1.
    #[test]
    fn test_convergence_rate_calculation() {
        let mut mon = IterationMonitor::new();
        mon.start_solve();
        mon.record_iteration(0, S::one().real(), None);
        mon.record_iteration(1, to_real(0.1), None);
        mon.record_iteration(2, to_real(0.01), None);

        let rate = mon
            .recent_convergence_rate(2)
            .expect("two recorded rates are inside the window");
        crate::assert_s_close!("recent rate", S::from_real(rate), S::from_real(0.1));
    }

    /// A ~1% decrease per step counts as stagnation at 0.95 but not at 0.999.
    #[test]
    fn test_stagnation_detection() {
        let mut mon = IterationMonitor::new();
        mon.start_solve();
        mon.record_iteration(0, S::one().real(), None);
        mon.record_iteration(1, to_real(0.99), None);
        mon.record_iteration(2, to_real(0.98), None);

        assert!(mon.is_stagnating(to_real(0.95), 2));
        assert!(!mon.is_stagnating(to_real(0.999), 2));
    }
}