use thiserror::Error;
/// Convenience alias for a `Result` whose error side is [`LearnedRhoError`].
pub type LearnedRhoResult<T> = std::result::Result<T, LearnedRhoError>;
/// The error type carried by [`LearnedRhoResult`].
///
/// `Display` output for each variant comes from the `#[error(...)]` attribute attached to
/// it; variants holding a `String` interpolate that free-form description into the message.
#[derive(Debug, Error)]
pub enum LearnedRhoError {
    /// A size check failed: `expected` was required but `actual` was supplied.
    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// The size the operation required.
        expected: usize,
        /// The size that was actually observed.
        actual: usize,
    },

    /// A configuration value was rejected; the payload describes why.
    #[error("invalid configuration: {0}")]
    InvalidConfiguration(String),

    /// A failure reported while training; the payload describes it.
    #[error("training error: {0}")]
    TrainingError(String),

    /// A failure reported during a forward pass; the payload describes it.
    #[error("forward pass error: {0}")]
    ForwardError(String),

    /// A failure reported during a backward pass; the payload describes it.
    #[error("backward pass error: {0}")]
    BackwardError(String),

    /// A failure reported during consolidation; the payload describes it.
    #[error("consolidation error: {0}")]
    ConsolidationError(String),

    /// A failure reported by the replay buffer; the payload describes it.
    #[error("replay buffer error: {0}")]
    ReplayBufferError(String),

    /// An operation was attempted before the model was initialized.
    #[error("model not initialized")]
    NotInitialized,

    /// A numerical-instability condition was detected; the payload describes it.
    #[error("numerical instability: {0}")]
    NumericalInstability(String),

    /// Catch-all for unexpected internal failures; the payload describes it.
    #[error("internal learned rho error: {0}")]
    Internal(String),
}
impl LearnedRhoError {
#[must_use]
pub fn dim_mismatch(expected: usize, actual: usize) -> Self {
Self::DimensionMismatch { expected, actual }
}
#[must_use]
pub fn training(msg: impl Into<String>) -> Self {
Self::TrainingError(msg.into())
}
#[must_use]
pub fn numerical(msg: impl Into<String>) -> Self {
Self::NumericalInstability(msg.into())
}
}