#[cfg(feature = "visualization")]
pub mod conversions;
#[cfg(feature = "visualization")]
pub mod visualization;
#[cfg(feature = "visualization")]
pub use visualization::{RerunObserver, VisualizationConfig, VisualizationMode};
#[cfg(feature = "visualization")]
pub use conversions::{CollectRerun2D, CollectRerun3D, RerunConvert2D, RerunConvert3D};
use crate::core::problem::VariableEnum;
use faer::Mat;
use faer::sparse;
use std::collections::HashMap;
use thiserror::Error;
/// Errors that can arise in the observer / visualization layer.
///
/// Most variants carry a free-form `String` describing the underlying cause; the
/// user-facing (`Display`) message of each variant is defined by its `#[error(...)]`
/// attribute.
#[derive(Debug, Clone, Error)]
pub enum ObserverError {
/// The Rerun recording stream could not be initialized.
#[error("Failed to initialize Rerun recording stream: {0}")]
RerunInitialization(String),
/// The Rerun viewer could not be spawned.
#[error("Failed to spawn Rerun viewer: {0}")]
ViewerSpawnFailed(String),
/// Saving a recording to the file at `path` failed for the given `reason`.
#[error("Failed to save recording to file '{path}': {reason}")]
RecordingSaveFailed { path: String, reason: String },
/// Logging data to Rerun at `entity_path` failed for the given `reason`.
#[error("Failed to log data to Rerun at '{entity_path}': {reason}")]
LoggingFailed { entity_path: String, reason: String },
/// A matrix could not be converted into an image.
#[error("Failed to convert matrix to image: {0}")]
MatrixVisualizationFailed(String),
/// Tensor data could not be created.
#[error("Failed to create tensor data: {0}")]
TensorConversionFailed(String),
/// The recording stream is in a state that does not permit the requested operation.
#[error("Recording stream is in invalid state: {0}")]
InvalidState(String),
/// Locking a mutex failed because it was poisoned; `context` names where the lock was
/// attempted and `reason` carries the poison error text.
#[error("Mutex poisoned in {context}: {reason}")]
MutexPoisoned { context: String, reason: String },
}
/// Shorthand for a [`Result`] whose error type is [`ObserverError`].
pub type ObserverResult<T> = Result<T, ObserverError>;
/// Hook interface for watching an optimization run.
///
/// All hooks take `&self`, so an implementor that needs to record anything must use
/// interior mutability (e.g. `Arc<Mutex<_>>`). Implementors must be [`Send`].
///
/// Only [`OptObserver::on_step`] is required. The other hooks default to no-ops, so an
/// observer can opt in to just the data it cares about.
pub trait OptObserver: Send {
    /// Receives the current variable values together with the iteration index supplied
    /// by the caller.
    fn on_step(&self, values: &HashMap<String, VariableEnum>, iteration: usize);

    /// Receives scalar metrics for the current iteration.
    ///
    /// `damping` and `step_quality` are optional; `None` presumably means the caller has
    /// no such value for this iteration (TODO confirm against the solvers).
    ///
    /// The default implementation ignores every argument.
    fn set_iteration_metrics(
        &self,
        cost: f64,
        gradient_norm: f64,
        damping: Option<f64>,
        step_norm: f64,
        step_quality: Option<f64>,
    ) {
        let _ = (cost, gradient_norm, damping, step_norm, step_quality);
    }

    /// Receives (owned) copies of the sparse Hessian and the gradient, when available.
    ///
    /// The default implementation discards both.
    fn set_matrix_data(
        &self,
        hessian: Option<sparse::SparseColMat<usize, f64>>,
        gradient: Option<Mat<f64>>,
    ) {
        let _ = (hessian, gradient);
    }

    /// Called when optimization has finished, with the final values and the iteration
    /// count supplied by the caller.
    ///
    /// The default implementation does nothing.
    fn on_optimization_complete(&self, values: &HashMap<String, VariableEnum>, iterations: usize) {
        let _ = (values, iterations);
    }
}
/// An ordered collection of [`OptObserver`]s that fans every notification out to each
/// registered observer.
///
/// Observers are invoked in the order in which they were added. An empty collection is
/// valid: every notification method is then a no-op.
#[derive(Default)]
pub struct OptObserverVec {
    observers: Vec<Box<dyn OptObserver>>,
}

impl OptObserverVec {
    /// Creates an empty collection (same as [`Default::default`]).
    pub fn new() -> Self {
        Self {
            observers: Vec::new(),
        }
    }

    /// Registers `observer`. It is notified after every previously added observer.
    pub fn add(&mut self, observer: impl OptObserver + 'static) {
        self.observers.push(Box::new(observer));
    }

    /// Forwards the per-iteration scalar metrics to every observer, in insertion order.
    #[inline]
    pub fn set_iteration_metrics(
        &self,
        cost: f64,
        gradient_norm: f64,
        damping: Option<f64>,
        step_norm: f64,
        step_quality: Option<f64>,
    ) {
        for observer in &self.observers {
            observer.set_iteration_metrics(cost, gradient_norm, damping, step_norm, step_quality);
        }
    }

    /// Forwards the Hessian and gradient to every observer, in insertion order.
    ///
    /// Each observer takes the matrices by value. All but the last observer therefore
    /// receive clones, and the last observer receives the originals. This avoids one
    /// redundant copy of potentially large matrices per call; with a single observer,
    /// no clone is made at all.
    #[inline]
    pub fn set_matrix_data(
        &self,
        hessian: Option<sparse::SparseColMat<usize, f64>>,
        gradient: Option<Mat<f64>>,
    ) {
        let Some((last, rest)) = self.observers.split_last() else {
            // No observers: nothing to deliver, and no clones are made.
            return;
        };
        for observer in rest {
            observer.set_matrix_data(hessian.clone(), gradient.clone());
        }
        // The originals are no longer needed here, so hand them to the final observer.
        last.set_matrix_data(hessian, gradient);
    }

    /// Calls [`OptObserver::on_step`] on every observer, in insertion order.
    #[inline]
    pub fn notify(&self, values: &HashMap<String, VariableEnum>, iteration: usize) {
        for observer in &self.observers {
            observer.on_step(values, iteration);
        }
    }

    /// Calls [`OptObserver::on_optimization_complete`] on every observer, in insertion
    /// order.
    #[inline]
    pub fn notify_complete(&self, values: &HashMap<String, VariableEnum>, iterations: usize) {
        for observer in &self.observers {
            observer.on_optimization_complete(values, iterations);
        }
    }

    /// Returns `true` if no observers are registered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.observers.is_empty()
    }

    /// Returns the number of registered observers.
    #[inline]
    pub fn len(&self) -> usize {
        self.observers.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ErrorLogging;
    use std::sync::{Arc, Mutex};

    /// Observer that records the iteration index of every `on_step` call.
    #[derive(Clone)]
    struct TestObserver {
        calls: Arc<Mutex<Vec<usize>>>,
    }

    impl OptObserver for TestObserver {
        fn on_step(&self, _values: &HashMap<String, VariableEnum>, iteration: usize) {
            if let Ok(mut guard) = self.calls.lock().map_err(|e| {
                ObserverError::MutexPoisoned {
                    context: "TestObserver::on_step".to_string(),
                    reason: e.to_string(),
                }
                .log()
            }) {
                guard.push(iteration);
            }
        }
    }

    #[test]
    fn test_empty_observers() {
        let observers = OptObserverVec::new();
        assert!(observers.is_empty());
        assert_eq!(observers.len(), 0);
        observers.notify(&HashMap::new(), 0);
    }

    #[test]
    fn test_single_observer() -> Result<(), ObserverError> {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let observer = TestObserver {
            calls: calls.clone(),
        };
        let mut observers = OptObserverVec::new();
        observers.add(observer);
        assert_eq!(observers.len(), 1);
        observers.notify(&HashMap::new(), 0);
        observers.notify(&HashMap::new(), 1);
        observers.notify(&HashMap::new(), 2);
        let guard = calls.lock().map_err(|e| {
            ObserverError::MutexPoisoned {
                context: "test_single_observer".to_string(),
                reason: e.to_string(),
            }
            .log()
        })?;
        assert_eq!(*guard, vec![0, 1, 2]);
        Ok(())
    }

    #[test]
    fn test_multiple_observers() -> Result<(), ObserverError> {
        let calls1 = Arc::new(Mutex::new(Vec::new()));
        let calls2 = Arc::new(Mutex::new(Vec::new()));
        let observer1 = TestObserver {
            calls: calls1.clone(),
        };
        let observer2 = TestObserver {
            calls: calls2.clone(),
        };
        let mut observers = OptObserverVec::new();
        observers.add(observer1);
        observers.add(observer2);
        assert_eq!(observers.len(), 2);
        observers.notify(&HashMap::new(), 5);
        let guard1 = calls1.lock().map_err(|e| {
            ObserverError::MutexPoisoned {
                context: "test_multiple_observers (calls1)".to_string(),
                reason: e.to_string(),
            }
            .log()
        })?;
        assert_eq!(*guard1, vec![5]);
        let guard2 = calls2.lock().map_err(|e| {
            ObserverError::MutexPoisoned {
                context: "test_multiple_observers (calls2)".to_string(),
                reason: e.to_string(),
            }
            .log()
        })?;
        assert_eq!(*guard2, vec![5]);
        Ok(())
    }

    #[test]
    fn test_observer_error_rerun_initialization_display() {
        let e = ObserverError::RerunInitialization("init fail".into());
        assert!(e.to_string().contains("init fail"));
    }

    #[test]
    fn test_observer_error_viewer_spawn_failed_display() {
        let e = ObserverError::ViewerSpawnFailed("spawn fail".into());
        assert!(e.to_string().contains("spawn fail"));
    }

    #[test]
    fn test_observer_error_recording_save_failed_display() {
        let e = ObserverError::RecordingSaveFailed {
            path: "/tmp/out.rrd".into(),
            reason: "disk full".into(),
        };
        let s = e.to_string();
        assert!(s.contains("/tmp/out.rrd"), "{s}");
        assert!(s.contains("disk full"), "{s}");
    }

    #[test]
    fn test_observer_error_logging_failed_display() {
        let e = ObserverError::LoggingFailed {
            entity_path: "world/points".into(),
            reason: "timeout".into(),
        };
        let s = e.to_string();
        assert!(s.contains("world/points"), "{s}");
        assert!(s.contains("timeout"), "{s}");
    }

    #[test]
    fn test_observer_error_matrix_visualization_failed_display() {
        let e = ObserverError::MatrixVisualizationFailed("bad dims".into());
        assert!(e.to_string().contains("bad dims"));
    }

    #[test]
    fn test_observer_error_tensor_conversion_failed_display() {
        let e = ObserverError::TensorConversionFailed("nan values".into());
        assert!(e.to_string().contains("nan values"));
    }

    #[test]
    fn test_observer_error_invalid_state_display() {
        let e = ObserverError::InvalidState("stream closed".into());
        assert!(e.to_string().contains("stream closed"));
    }

    #[test]
    fn test_observer_error_mutex_poisoned_display() {
        let e = ObserverError::MutexPoisoned {
            context: "on_step".into(),
            reason: "thread panicked".into(),
        };
        let s = e.to_string();
        assert!(s.contains("on_step"), "{s}");
        assert!(s.contains("thread panicked"), "{s}");
    }

    #[test]
    fn test_observer_error_log_returns_self() {
        let e = ObserverError::InvalidState("log_test".into());
        let returned = e.log();
        assert!(returned.to_string().contains("log_test"));
    }

    #[test]
    fn test_observer_error_log_with_source_returns_self() {
        let e = ObserverError::MatrixVisualizationFailed("src_test".into());
        let source = std::io::Error::other("src");
        let returned = e.log_with_source(source);
        assert!(returned.to_string().contains("src_test"));
    }

    #[test]
    fn test_set_iteration_metrics_no_panic() {
        let mut observers = OptObserverVec::new();
        observers.add(TestObserver {
            calls: Arc::new(Mutex::new(Vec::new())),
        });
        observers.set_iteration_metrics(1.5, 1e-3, Some(1e-4), 0.01, Some(0.9));
    }

    #[test]
    fn test_set_iteration_metrics_empty_no_panic() {
        let observers = OptObserverVec::new();
        observers.set_iteration_metrics(0.0, 0.0, None, 0.0, None);
    }

    #[test]
    fn test_set_matrix_data_no_panic() {
        let mut observers = OptObserverVec::new();
        observers.add(TestObserver {
            calls: Arc::new(Mutex::new(Vec::new())),
        });
        observers.set_matrix_data(None, None);
    }

    /// Observer that counts `on_optimization_complete` calls.
    #[derive(Clone)]
    struct CompleteObserver {
        complete_calls: Arc<Mutex<usize>>,
    }

    impl OptObserver for CompleteObserver {
        fn on_step(&self, _values: &HashMap<String, VariableEnum>, _iteration: usize) {}
        fn on_optimization_complete(
            &self,
            _values: &HashMap<String, VariableEnum>,
            _iterations: usize,
        ) {
            if let Ok(mut guard) = self.complete_calls.lock() {
                *guard += 1;
            }
        }
    }

    #[test]
    fn test_notify_complete_calls_on_optimization_complete() {
        let complete_calls = Arc::new(Mutex::new(0usize));
        let observer = CompleteObserver {
            complete_calls: complete_calls.clone(),
        };
        let mut observers = OptObserverVec::new();
        observers.add(observer);
        observers.notify_complete(&HashMap::new(), 10);
        let count = *complete_calls.lock().unwrap_or_else(|e| e.into_inner());
        assert_eq!(count, 1);
    }

    #[test]
    fn test_notify_complete_empty_no_panic() {
        let observers = OptObserverVec::new();
        observers.notify_complete(&HashMap::new(), 5);
    }

    #[test]
    fn test_default_trait_methods_no_panic() {
        let observer = TestObserver {
            calls: Arc::new(Mutex::new(Vec::new())),
        };
        observer.set_iteration_metrics(1.0, 1e-3, None, 0.0, None);
        observer.set_matrix_data(None, None);
        observer.on_optimization_complete(&HashMap::new(), 5);
    }

    /// One recorded `set_iteration_metrics` call:
    /// `(cost, gradient_norm, damping, step_norm, step_quality)`.
    type RecordedMetrics = (f64, f64, Option<f64>, f64, Option<f64>);

    /// One recorded `set_matrix_data` call: `(hessian was Some, gradient shape)`.
    type RecordedMatrixCall = (bool, Option<(usize, usize)>);

    /// Observer that records everything delivered through the optional hooks. The
    /// "no_panic" tests above cannot check that data actually reaches every observer;
    /// this observer makes that check possible.
    #[derive(Clone, Default)]
    struct RecordingObserver {
        metrics: Arc<Mutex<Vec<RecordedMetrics>>>,
        matrix_calls: Arc<Mutex<Vec<RecordedMatrixCall>>>,
    }

    impl OptObserver for RecordingObserver {
        fn on_step(&self, _values: &HashMap<String, VariableEnum>, _iteration: usize) {}

        fn set_iteration_metrics(
            &self,
            cost: f64,
            gradient_norm: f64,
            damping: Option<f64>,
            step_norm: f64,
            step_quality: Option<f64>,
        ) {
            self.metrics
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .push((cost, gradient_norm, damping, step_norm, step_quality));
        }

        fn set_matrix_data(
            &self,
            hessian: Option<sparse::SparseColMat<usize, f64>>,
            gradient: Option<Mat<f64>>,
        ) {
            let shape = gradient.as_ref().map(|g| (g.nrows(), g.ncols()));
            self.matrix_calls
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .push((hessian.is_some(), shape));
        }
    }

    #[test]
    fn test_set_iteration_metrics_forwards_to_every_observer() {
        let first = RecordingObserver::default();
        let second = RecordingObserver::default();
        let mut observers = OptObserverVec::new();
        observers.add(first.clone());
        observers.add(second.clone());
        observers.set_iteration_metrics(1.5, 1e-3, Some(1e-4), 0.01, Some(0.9));
        for rec in [&first, &second] {
            let metrics = rec.metrics.lock().unwrap_or_else(|e| e.into_inner());
            assert_eq!(*metrics, vec![(1.5, 1e-3, Some(1e-4), 0.01, Some(0.9))]);
        }
    }

    #[test]
    fn test_set_matrix_data_forwards_to_every_observer() {
        // Three observers exercise both the "clone for earlier observers" path and the
        // "move into the last observer" path.
        let recorders = [
            RecordingObserver::default(),
            RecordingObserver::default(),
            RecordingObserver::default(),
        ];
        let mut observers = OptObserverVec::new();
        for rec in &recorders {
            observers.add(rec.clone());
        }
        observers.set_matrix_data(None, Some(Mat::<f64>::zeros(3, 1)));
        for rec in &recorders {
            let calls = rec.matrix_calls.lock().unwrap_or_else(|e| e.into_inner());
            assert_eq!(*calls, vec![(false, Some((3, 1)))]);
        }
    }
}