#[cfg(feature = "visualization")]
pub mod conversions;
#[cfg(feature = "visualization")]
pub mod visualization;
#[cfg(feature = "visualization")]
pub use visualization::{RerunObserver, VisualizationConfig, VisualizationMode};
#[cfg(feature = "visualization")]
pub use conversions::{
CollectRerunArrows2D, CollectRerunArrows3D, CollectRerunPoints2D, CollectRerunPoints3D,
ToRerunArrows2D, ToRerunArrows3D, ToRerunPoints2D, ToRerunPoints3D, ToRerunTransform3D,
ToRerunTransform3DFrom2D, ToRerunVec2D, ToRerunVec3D,
};
use crate::core::problem::VariableEnum;
use faer::Mat;
use faer::sparse;
use std::collections::HashMap;
use thiserror::Error;
use tracing::error;
/// Error type for this module; its variants cover Rerun setup, logging,
/// matrix/tensor conversion, recording-stream state and mutex poisoning.
///
/// Each variant's `Display` text is the `#[error(...)]` string attached to it,
/// with the variant's payload interpolated into the placeholders. All payloads
/// are plain `String`s, so the error is `Clone` and carries no source error.
#[derive(Debug, Clone, Error)]
pub enum ObserverError {
/// The Rerun recording stream could not be initialized; the payload is the detail text.
#[error("Failed to initialize Rerun recording stream: {0}")]
RerunInitialization(String),
/// The Rerun viewer could not be spawned; the payload is the detail text.
#[error("Failed to spawn Rerun viewer: {0}")]
ViewerSpawnFailed(String),
/// Saving a recording to the file at `path` failed for the given `reason`.
#[error("Failed to save recording to file '{path}': {reason}")]
RecordingSaveFailed { path: String, reason: String },
/// Logging data to Rerun at `entity_path` failed for the given `reason`.
#[error("Failed to log data to Rerun at '{entity_path}': {reason}")]
LoggingFailed { entity_path: String, reason: String },
/// A matrix could not be converted to an image; the payload is the detail text.
#[error("Failed to convert matrix to image: {0}")]
MatrixVisualizationFailed(String),
/// Tensor data could not be created; the payload is the detail text.
#[error("Failed to create tensor data: {0}")]
TensorConversionFailed(String),
/// The recording stream is in an invalid state; the payload describes it.
#[error("Recording stream is in invalid state: {0}")]
InvalidState(String),
/// A mutex was found poisoned; `context` names where, `reason` carries the detail text.
#[error("Mutex poisoned in {context}: {reason}")]
MutexPoisoned { context: String, reason: String },
}
impl ObserverError {
    /// Emits this error's `Display` text through `tracing::error!` and hands
    /// the error back unchanged, so the call can be chained inside `map_err`.
    #[must_use]
    pub fn log(self) -> Self {
        error!("{}", self);
        self
    }

    /// Like [`ObserverError::log`], but appends the `Debug` rendering of
    /// `source_error` to the logged line after a `" | Source: "` separator.
    ///
    /// `source_error` is only formatted for the log event; it is not stored in
    /// the returned error.
    #[must_use]
    pub fn log_with_source<E>(self, source_error: E) -> Self
    where
        E: std::fmt::Debug,
    {
        error!("{} | Source: {:?}", self, source_error);
        self
    }
}
/// Shorthand for a `Result` whose error type is [`ObserverError`].
pub type ObserverResult<T> = Result<T, ObserverError>;
/// Callback interface for observers that `OptObserverVec` forwards its
/// notifications to.
///
/// Implementors must be `Send`. Only [`OptObserver::on_step`] is required; the
/// remaining methods have empty default bodies, so an implementor overrides
/// just the callbacks it needs. Every method takes `&self`, so an observer
/// that records state has to use interior mutability.
pub trait OptObserver: Send {
/// Required callback receiving the current `values` map and an `iteration`
/// number.
///
/// NOTE(review): the map keys are presumably variable names and `iteration`
/// the optimizer's step counter — confirm against the caller.
fn on_step(&self, values: &HashMap<String, VariableEnum>, iteration: usize);
/// Optional callback receiving scalar metrics; does nothing by default.
///
/// `_damping` and `_step_quality` are `Option`s, so a caller may omit them.
/// NOTE(review): presumably these are per-iteration solver statistics —
/// confirm units and meaning against the caller.
fn set_iteration_metrics(
&self,
_cost: f64,
_gradient_norm: f64,
_damping: Option<f64>,
_step_norm: f64,
_step_quality: Option<f64>,
) {
}
/// Optional callback receiving an owned sparse matrix and an owned dense
/// matrix, either of which may be `None`; does nothing by default.
fn set_matrix_data(
&self,
_hessian: Option<sparse::SparseColMat<usize, f64>>,
_gradient: Option<Mat<f64>>,
) {
}
/// Optional callback receiving a `values` map and an iteration count; does
/// nothing by default.
///
/// NOTE(review): presumably called once when optimization finishes, with
/// the final values — confirm against the caller.
fn on_optimization_complete(
&self,
_values: &HashMap<String, VariableEnum>,
_iterations: usize,
) {
}
}
/// Collection of boxed [`OptObserver`] trait objects.
///
/// The derived `Default` yields a collection with no observers.
#[derive(Default)]
pub struct OptObserverVec {
// Observers in insertion order; the forwarding methods iterate this vector front to back.
observers: Vec<Box<dyn OptObserver>>,
}
impl OptObserverVec {
    /// Creates a collection with no registered observers.
    pub fn new() -> Self {
        Self {
            observers: Vec::new(),
        }
    }

    /// Boxes `observer` and appends it after all previously added observers.
    pub fn add(&mut self, observer: impl OptObserver + 'static) {
        self.observers.push(Box::new(observer));
    }

    /// Forwards the scalar metrics to every observer, in insertion order.
    #[inline]
    pub fn set_iteration_metrics(
        &self,
        cost: f64,
        gradient_norm: f64,
        damping: Option<f64>,
        step_norm: f64,
        step_quality: Option<f64>,
    ) {
        for observer in &self.observers {
            observer.set_iteration_metrics(cost, gradient_norm, damping, step_norm, step_quality);
        }
    }

    /// Hands `hessian` and `gradient` to every observer, in insertion order.
    ///
    /// Each observer receives its own owned value. Only the observers before
    /// the last one get clones; the last observer takes ownership of the
    /// arguments themselves, so a single registered observer costs no matrix
    /// copy at all. With no observers the arguments are simply dropped.
    #[inline]
    pub fn set_matrix_data(
        &self,
        hessian: Option<sparse::SparseColMat<usize, f64>>,
        gradient: Option<Mat<f64>>,
    ) {
        if let Some((last, earlier)) = self.observers.split_last() {
            for observer in earlier {
                observer.set_matrix_data(hessian.clone(), gradient.clone());
            }
            // Moving instead of cloning saves one deep copy per call.
            last.set_matrix_data(hessian, gradient);
        }
    }

    /// Calls `on_step(values, iteration)` on every observer, in insertion order.
    #[inline]
    pub fn notify(&self, values: &HashMap<String, VariableEnum>, iteration: usize) {
        for observer in &self.observers {
            observer.on_step(values, iteration);
        }
    }

    /// Calls `on_optimization_complete(values, iterations)` on every observer,
    /// in insertion order.
    #[inline]
    pub fn notify_complete(&self, values: &HashMap<String, VariableEnum>, iterations: usize) {
        for observer in &self.observers {
            observer.on_optimization_complete(values, iterations);
        }
    }

    /// Returns `true` when no observer has been added.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.observers.is_empty()
    }

    /// Returns the number of registered observers.
    #[inline]
    pub fn len(&self) -> usize {
        self.observers.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex, MutexGuard};

    /// Locks `calls`, turning a poisoned mutex into a logged
    /// `ObserverError::MutexPoisoned` tagged with `context`.
    fn lock_calls<'a>(
        calls: &'a Mutex<Vec<usize>>,
        context: &str,
    ) -> ObserverResult<MutexGuard<'a, Vec<usize>>> {
        calls.lock().map_err(|e| {
            ObserverError::MutexPoisoned {
                context: context.to_string(),
                reason: e.to_string(),
            }
            .log()
        })
    }

    /// Observer that appends every iteration number it is notified about to a
    /// shared vector.
    #[derive(Clone)]
    struct TestObserver {
        calls: Arc<Mutex<Vec<usize>>>,
    }

    impl OptObserver for TestObserver {
        fn on_step(&self, _values: &HashMap<String, VariableEnum>, iteration: usize) {
            // A poisoned lock is logged by the helper and the step is skipped.
            if let Ok(mut recorded) = lock_calls(&self.calls, "TestObserver::on_step") {
                recorded.push(iteration);
            }
        }
    }

    /// An empty collection reports zero observers and tolerates `notify`.
    #[test]
    fn test_empty_observers() {
        let observers = OptObserverVec::new();
        assert!(observers.is_empty());
        assert_eq!(observers.len(), 0);
        observers.notify(&HashMap::new(), 0);
    }

    /// One observer sees every notification, in order.
    #[test]
    fn test_single_observer() -> Result<(), ObserverError> {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let mut observers = OptObserverVec::new();
        observers.add(TestObserver {
            calls: Arc::clone(&calls),
        });
        assert_eq!(observers.len(), 1);

        for iteration in 0..3 {
            observers.notify(&HashMap::new(), iteration);
        }

        let recorded = lock_calls(&calls, "test_single_observer")?;
        assert_eq!(*recorded, vec![0, 1, 2]);
        Ok(())
    }

    /// Every registered observer receives the same notification.
    #[test]
    fn test_multiple_observers() -> Result<(), ObserverError> {
        let calls1 = Arc::new(Mutex::new(Vec::new()));
        let calls2 = Arc::new(Mutex::new(Vec::new()));
        let mut observers = OptObserverVec::new();
        observers.add(TestObserver {
            calls: Arc::clone(&calls1),
        });
        observers.add(TestObserver {
            calls: Arc::clone(&calls2),
        });
        assert_eq!(observers.len(), 2);

        observers.notify(&HashMap::new(), 5);

        let recorded1 = lock_calls(&calls1, "test_multiple_observers (calls1)")?;
        assert_eq!(*recorded1, vec![5]);
        let recorded2 = lock_calls(&calls2, "test_multiple_observers (calls2)")?;
        assert_eq!(*recorded2, vec![5]);
        Ok(())
    }
}