use oxicuda_blas::BlasError;
use oxicuda_driver::CudaError;
use oxicuda_ptx::PtxGenError;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SolverError {
#[error("CUDA driver error: {0}")]
Cuda(#[from] CudaError),
#[error("BLAS error: {0}")]
Blas(#[from] BlasError),
#[error("PTX generation error: {0}")]
PtxGeneration(#[from] PtxGenError),
#[error("singular matrix detected")]
SingularMatrix,
#[error("matrix is not positive definite")]
NotPositiveDefinite,
#[error("dimension mismatch: {0}")]
DimensionMismatch(String),
#[error("convergence failure after {iterations} iterations (residual = {residual:.6e})")]
ConvergenceFailure {
iterations: u32,
residual: f64,
},
#[error("workspace of at least {0} bytes required")]
WorkspaceRequired(usize),
#[error("internal solver error: {0}")]
InternalError(String),
}
/// Shorthand for a [`std::result::Result`] whose error type is [`SolverError`].
pub type SolverResult<T> = std::result::Result<T, SolverError>;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display_singular_matrix() {
        let err = SolverError::SingularMatrix;
        assert_eq!(err.to_string(), "singular matrix detected");
    }

    #[test]
    fn display_not_positive_definite() {
        let err = SolverError::NotPositiveDefinite;
        assert_eq!(err.to_string(), "matrix is not positive definite");
    }

    #[test]
    fn display_dimension_mismatch() {
        let err = SolverError::DimensionMismatch("3x3 vs 4x4".into());
        assert_eq!(err.to_string(), "dimension mismatch: 3x3 vs 4x4");
    }

    #[test]
    fn display_convergence_failure() {
        let err = SolverError::ConvergenceFailure {
            iterations: 100,
            residual: 1e-3,
        };
        // Pin the exact `{:.6e}` rendering: a bare `contains("1.0")` check would
        // also pass for a wrong precision or a missing exponent.
        assert_eq!(
            err.to_string(),
            "convergence failure after 100 iterations (residual = 1.000000e-3)"
        );
    }

    #[test]
    fn display_workspace_required() {
        let err = SolverError::WorkspaceRequired(4096);
        assert_eq!(err.to_string(), "workspace of at least 4096 bytes required");
    }

    #[test]
    fn display_internal_error() {
        let err = SolverError::InternalError("unexpected state".into());
        assert_eq!(err.to_string(), "internal solver error: unexpected state");
    }

    #[test]
    fn from_cuda_error() {
        let cuda_err = CudaError::NotInitialized;
        let solver_err: SolverError = cuda_err.into();
        assert!(matches!(solver_err, SolverError::Cuda(_)));
        assert!(solver_err.to_string().starts_with("CUDA driver error: "));
        // `#[from]` implies `#[source]`, so the wrapped error must be reachable.
        assert!(std::error::Error::source(&solver_err).is_some());
    }

    #[test]
    fn from_blas_error() {
        let blas_err = BlasError::InvalidDimension("test".into());
        let solver_err: SolverError = blas_err.into();
        assert!(matches!(solver_err, SolverError::Blas(_)));
        assert!(solver_err.to_string().starts_with("BLAS error: "));
        assert!(std::error::Error::source(&solver_err).is_some());
    }

    #[test]
    fn question_mark_converts_cuda_error() {
        // Exercises the generated `From` impl through `?` and the result alias.
        fn propagate() -> SolverResult<()> {
            let inner: Result<(), CudaError> = Err(CudaError::NotInitialized);
            inner?;
            Ok(())
        }
        assert!(matches!(propagate(), Err(SolverError::Cuda(_))));
    }
}