use thiserror::Error;
#[derive(Debug, Error, Clone, PartialEq)]
pub enum VisionError {
#[error("dimension mismatch: expected {expected}, got {got}")]
DimensionMismatch { expected: usize, got: usize },
#[error("shape mismatch: lhs {lhs:?} vs rhs {rhs:?}")]
ShapeMismatch { lhs: Vec<usize>, rhs: Vec<usize> },
#[error("empty input: {0}")]
EmptyInput(&'static str),
#[error("invalid image size: height={height}, width={width}, channels={channels}")]
InvalidImageSize {
height: usize,
width: usize,
channels: usize,
},
#[error("invalid patch size {patch_size}: image size {img_size} is not divisible")]
InvalidPatchSize { patch_size: usize, img_size: usize },
#[error("invalid embed dim: {0}")]
InvalidEmbedDim(usize),
#[error("invalid number of heads: {0}")]
InvalidNumHeads(usize),
#[error("head count {n_heads} does not divide embed dim {embed_dim}")]
HeadDimMismatch { n_heads: usize, embed_dim: usize },
#[error("invalid number of classes: {0}")]
InvalidNumClasses(usize),
#[error("invalid projection dim: {0}")]
InvalidProjDim(usize),
#[error("non-positive temperature: {0}")]
NonPositiveTemperature(f32),
#[error("invalid RoI box [{x1}, {y1}, {x2}, {y2}]")]
InvalidRoiBox { x1: f32, y1: f32, x2: f32, y2: f32 },
#[error("weight shape mismatch for '{name}': expected {expected:?}, got {got:?}")]
WeightShapeMismatch {
name: &'static str,
expected: Vec<usize>,
got: Vec<usize>,
},
#[error("non-finite value encountered: {0}")]
NonFinite(&'static str),
#[error("internal error: {0}")]
Internal(String),
}
/// Convenience alias: a [`std::result::Result`] whose error type is [`VisionError`].
pub type VisionResult<T> = std::result::Result<T, VisionError>;
// Unit tests pinning the exact `Display` text of each `VisionError` variant,
// the derived `Clone`/`PartialEq` behaviour, the `std::error::Error` impl,
// and the `VisionResult` alias.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_display_dimension_mismatch() {
        let e = VisionError::DimensionMismatch {
            expected: 64,
            got: 128,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 64, got 128");
    }

    #[test]
    fn error_display_shape_mismatch() {
        let e = VisionError::ShapeMismatch {
            lhs: vec![2, 4],
            rhs: vec![2, 8],
        };
        // Shapes are rendered with `Debug`, i.e. `[2, 4]`.
        assert_eq!(e.to_string(), "shape mismatch: lhs [2, 4] vs rhs [2, 8]");
    }

    #[test]
    fn error_display_empty_input() {
        let e = VisionError::EmptyInput("image tensor");
        assert_eq!(e.to_string(), "empty input: image tensor");
    }

    #[test]
    fn error_display_invalid_image_size() {
        let e = VisionError::InvalidImageSize {
            height: 0,
            width: 32,
            channels: 3,
        };
        assert_eq!(
            e.to_string(),
            "invalid image size: height=0, width=32, channels=3"
        );
    }

    #[test]
    fn error_display_invalid_patch_size() {
        let e = VisionError::InvalidPatchSize {
            patch_size: 5,
            img_size: 32,
        };
        let s = e.to_string();
        // Pin the leading part of the message; trailing wording may evolve.
        assert!(
            s.starts_with("invalid patch size 5: image size 32 is not divisible"),
            "unexpected message: {}",
            s
        );
    }

    #[test]
    fn error_display_scalar_hyperparameters() {
        assert_eq!(
            VisionError::InvalidEmbedDim(0).to_string(),
            "invalid embed dim: 0"
        );
        assert_eq!(
            VisionError::InvalidNumHeads(0).to_string(),
            "invalid number of heads: 0"
        );
        assert_eq!(
            VisionError::InvalidNumClasses(1).to_string(),
            "invalid number of classes: 1"
        );
        assert_eq!(
            VisionError::InvalidProjDim(0).to_string(),
            "invalid projection dim: 0"
        );
    }

    #[test]
    fn error_display_head_dim_mismatch() {
        let e = VisionError::HeadDimMismatch {
            n_heads: 3,
            embed_dim: 64,
        };
        assert_eq!(e.to_string(), "head count 3 does not divide embed dim 64");
    }

    #[test]
    fn error_display_non_positive_temperature() {
        let e = VisionError::NonPositiveTemperature(-0.1);
        assert_eq!(e.to_string(), "non-positive temperature: -0.1");
    }

    #[test]
    fn error_display_invalid_roi_box() {
        let e = VisionError::InvalidRoiBox {
            x1: 5.0,
            y1: 0.0,
            x2: 3.0,
            y2: 4.0,
        };
        // `f32`'s `Display` drops a trailing `.0`, so 5.0 renders as `5`.
        assert_eq!(e.to_string(), "invalid RoI box [5, 0, 3, 4]");
    }

    #[test]
    fn error_display_weight_shape_mismatch() {
        let e = VisionError::WeightShapeMismatch {
            name: "patch_kernel",
            expected: vec![64, 3, 4, 4],
            got: vec![64, 3, 8, 8],
        };
        assert_eq!(
            e.to_string(),
            "weight shape mismatch for 'patch_kernel': expected [64, 3, 4, 4], got [64, 3, 8, 8]"
        );
    }

    #[test]
    fn error_display_non_finite() {
        let e = VisionError::NonFinite("attention logits");
        assert_eq!(
            e.to_string(),
            "non-finite value encountered: attention logits"
        );
    }

    #[test]
    fn error_display_internal() {
        let e = VisionError::Internal("unexpected state".into());
        assert_eq!(e.to_string(), "internal error: unexpected state");
    }

    #[test]
    fn error_implements_std_error_without_source() {
        // Anonymous import: brings `source()` into scope without clashing with
        // the `Error` name glob-imported from the parent module.
        use std::error::Error as _;
        let e = VisionError::Internal("x".into());
        assert!(e.source().is_none());
        let boxed: Box<dyn std::error::Error> = Box::new(e);
        assert_eq!(boxed.to_string(), "internal error: x");
    }

    #[test]
    fn error_clone_eq() {
        let a = VisionError::InvalidEmbedDim(0);
        let b = a.clone();
        assert_eq!(a, b);
        // Equality must depend on both the payload and the variant.
        assert_ne!(a, VisionError::InvalidEmbedDim(1));
        assert_ne!(a, VisionError::InvalidProjDim(0));
    }

    #[test]
    fn error_nan_payload_is_not_self_equal() {
        // Derived `PartialEq` compares `f32` fields with `==`, so a NaN payload
        // makes the error unequal even to its own clone.
        let e = VisionError::NonPositiveTemperature(f32::NAN);
        assert_ne!(e.clone(), e);
    }

    #[test]
    fn result_alias_ok() {
        fn make_ok() -> VisionResult<u32> {
            Ok(42)
        }
        assert_eq!(make_ok().expect("ok result"), 42);
    }

    #[test]
    fn result_alias_err() {
        let r: VisionResult<u32> = Err(VisionError::EmptyInput("test"));
        assert_eq!(r, Err(VisionError::EmptyInput("test")));
    }
}