use crate::device::Device;
/// Error type returned by ferrotorch APIs (the error half of [`FerrotorchResult`]).
///
/// Marked `#[non_exhaustive]`: downstream `match`es must include a wildcard arm,
/// since new variants may be added without a breaking change.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum FerrotorchError {
/// Shapes that were required to agree did not; `message` describes the mismatch.
#[error("shape mismatch: {message}")]
ShapeMismatch { message: String },
/// An operand was on device `got` where device `expected` was required.
#[error("device mismatch: expected {expected}, got {got}")]
DeviceMismatch { expected: Device, got: Device },
/// `backward` was called on a tensor that is not a scalar; `shape` is that tensor's shape.
#[error("backward called on non-scalar tensor with shape {shape:?}")]
BackwardNonScalar { shape: Vec<usize> },
/// A non-leaf tensor had no gradient function attached.
#[error("no gradient function on non-leaf tensor")]
NoGradFn,
/// Element types disagreed. Both dtypes are carried as their string names.
#[error("dtype mismatch: expected {expected}, got {got}")]
DtypeMismatch { expected: String, got: String },
/// `index` was out of range along axis `axis`, whose length is `size`.
#[error("index out of bounds: index {index} on axis {axis} with size {size}")]
IndexOutOfBounds {
index: usize,
axis: usize,
size: usize,
},
/// A caller-supplied argument was rejected; `message` explains why.
#[error("invalid argument: {message}")]
InvalidArgument { message: String },
/// An internal lock was found poisoned.
// NOTE(review): presumably converted from a std `PoisonError` (a thread panicked
// while holding the lock) — confirm at the construction sites.
#[error("internal lock poisoned: {message}")]
LockPoisoned { message: String },
/// An internal invariant was violated; `message` carries diagnostic detail.
#[error("internal error: {message}")]
Internal { message: String },
/// No GPU backend is registered. Per the message, the fix is to install
/// `ferrotorch-gpu` and call its `init()`.
#[error("no GPU backend available -- install ferrotorch-gpu and call init()")]
DeviceUnavailable,
/// CPU-slice access was attempted on a GPU-resident tensor; move it with `.cpu()` first.
#[error("cannot access GPU tensor data as CPU slice -- call .cpu() first")]
GpuTensorNotAccessible,
/// Operation `op` has no CUDA implementation; move the tensor with `.cpu()` first.
#[error("{op} is not supported on CUDA tensors -- call .cpu() first")]
NotImplementedOnCuda { op: &'static str },
/// Failure reported by a GPU backend.
///
/// The backend error is type-erased into a boxed `Error + Send + Sync` trait object.
/// It is exposed through [`std::error::Error::source`], so callers can
/// `downcast_ref` it back to the concrete backend type.
// NOTE(review): the Display string also interpolates `{source}`, so reporters that
// walk the `source()` chain print the backend message twice. Common Rust guidance
// is to do one or the other. Changing it alters Display output, which the
// `gpu_variant_display` test currently pins.
#[error("gpu error: {source}")]
Gpu {
/// The underlying backend error, reported via `source()`.
#[source]
source: Box<dyn std::error::Error + Send + Sync + 'static>,
},
/// A data-loading worker panicked; `message` describes the panic.
#[error("data loading worker panicked: {message}")]
WorkerPanic { message: String },
/// Wraps an error from `ferray_core`.
///
/// `#[error(transparent)]` forwards both Display and `source()` to the inner
/// error. `#[from]` lets `?` convert a `FerrayError` into this variant.
#[error(transparent)]
Ferray(#[from] ferray_core::FerrayError),
}
/// Shorthand for a `Result` whose error type is [`FerrotorchError`].
pub type FerrotorchResult<T> = std::result::Result<T, FerrotorchError>;
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;

    /// Stand-in for an arbitrary GPU-backend error type.
    #[derive(Debug, thiserror::Error)]
    #[error("test error: {0}")]
    struct TestError(&'static str);

    /// Builds a `FerrotorchError::Gpu` whose boxed source is `TestError(msg)`.
    fn gpu_error(msg: &'static str) -> FerrotorchError {
        FerrotorchError::Gpu {
            // `From<E: Error + Send + Sync>` for the boxed trait object does the boxing.
            source: TestError(msg).into(),
        }
    }

    /// The boxed backend error must be reachable via `source()` and downcastable.
    #[test]
    fn gpu_variant_preserves_source_chain() {
        let wrapped = gpu_error("backend kernel failed");
        let cause = Error::source(&wrapped).expect("source must be set via #[source]");
        let recovered = cause
            .downcast_ref::<TestError>()
            .expect("downcast back to TestError");
        assert_eq!(recovered.0, "backend kernel failed");
    }

    /// Display of the `Gpu` variant embeds the backend error's own message.
    #[test]
    fn gpu_variant_display() {
        let rendered = gpu_error("oom").to_string();
        assert_eq!(rendered, "gpu error: test error: oom");
    }
}