1use thiserror::Error;
4
5#[derive(Debug, Error, Clone, PartialEq)]
7pub enum VisionError {
8 #[error("dimension mismatch: expected {expected}, got {got}")]
10 DimensionMismatch { expected: usize, got: usize },
11
12 #[error("shape mismatch: lhs {lhs:?} vs rhs {rhs:?}")]
14 ShapeMismatch { lhs: Vec<usize>, rhs: Vec<usize> },
15
16 #[error("empty input: {0}")]
18 EmptyInput(&'static str),
19
20 #[error("invalid image size: height={height}, width={width}, channels={channels}")]
22 InvalidImageSize {
23 height: usize,
24 width: usize,
25 channels: usize,
26 },
27
28 #[error("invalid patch size {patch_size}: image size {img_size} is not divisible")]
30 InvalidPatchSize { patch_size: usize, img_size: usize },
31
32 #[error("invalid embed dim: {0}")]
34 InvalidEmbedDim(usize),
35
36 #[error("invalid number of heads: {0}")]
38 InvalidNumHeads(usize),
39
40 #[error("head count {n_heads} does not divide embed dim {embed_dim}")]
42 HeadDimMismatch { n_heads: usize, embed_dim: usize },
43
44 #[error("invalid number of classes: {0}")]
46 InvalidNumClasses(usize),
47
48 #[error("invalid projection dim: {0}")]
50 InvalidProjDim(usize),
51
52 #[error("non-positive temperature: {0}")]
54 NonPositiveTemperature(f32),
55
56 #[error("invalid RoI box [{x1}, {y1}, {x2}, {y2}]")]
58 InvalidRoiBox { x1: f32, y1: f32, x2: f32, y2: f32 },
59
60 #[error("weight shape mismatch for '{name}': expected {expected:?}, got {got:?}")]
62 WeightShapeMismatch {
63 name: &'static str,
64 expected: Vec<usize>,
65 got: Vec<usize>,
66 },
67
68 #[error("non-finite value encountered: {0}")]
70 NonFinite(&'static str),
71
72 #[error("internal error: {0}")]
74 Internal(String),
75}
76
77pub type VisionResult<T> = Result<T, VisionError>;
79
#[cfg(test)]
mod tests {
    //! Pins the exact `Display` output of every `VisionError` variant, plus the
    //! trait bounds and `VisionResult` alias behavior callers rely on.

    use super::*;

    #[test]
    fn error_display_dimension_mismatch() {
        let e = VisionError::DimensionMismatch {
            expected: 64,
            got: 128,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 64, got 128");
    }

    #[test]
    fn error_display_shape_mismatch() {
        let e = VisionError::ShapeMismatch {
            lhs: vec![2, 4],
            rhs: vec![2, 8],
        };
        // Shapes are rendered with `{:?}`, i.e. as Rust slice debug output.
        assert_eq!(e.to_string(), "shape mismatch: lhs [2, 4] vs rhs [2, 8]");
    }

    #[test]
    fn error_display_empty_input() {
        let e = VisionError::EmptyInput("image tensor");
        assert_eq!(e.to_string(), "empty input: image tensor");
    }

    #[test]
    fn error_display_invalid_image_size() {
        let e = VisionError::InvalidImageSize {
            height: 0,
            width: 32,
            channels: 3,
        };
        assert_eq!(
            e.to_string(),
            "invalid image size: height=0, width=32, channels=3"
        );
    }

    #[test]
    fn error_display_invalid_patch_size() {
        let e = VisionError::InvalidPatchSize {
            patch_size: 5,
            img_size: 32,
        };
        let s = e.to_string();
        // Prefix check: the message may carry a trailing clarification
        // (e.g. naming the divisor) without invalidating this test.
        assert!(
            s.starts_with("invalid patch size 5: image size 32 is not divisible"),
            "unexpected message: {}",
            s
        );
    }

    #[test]
    fn error_display_scalar_config_errors() {
        assert_eq!(
            VisionError::InvalidEmbedDim(0).to_string(),
            "invalid embed dim: 0"
        );
        assert_eq!(
            VisionError::InvalidNumHeads(0).to_string(),
            "invalid number of heads: 0"
        );
        assert_eq!(
            VisionError::InvalidNumClasses(0).to_string(),
            "invalid number of classes: 0"
        );
        assert_eq!(
            VisionError::InvalidProjDim(0).to_string(),
            "invalid projection dim: 0"
        );
    }

    #[test]
    fn error_display_head_dim_mismatch() {
        let e = VisionError::HeadDimMismatch {
            n_heads: 3,
            embed_dim: 64,
        };
        assert_eq!(e.to_string(), "head count 3 does not divide embed dim 64");
    }

    #[test]
    fn error_display_non_positive_temperature() {
        let e = VisionError::NonPositiveTemperature(-0.1);
        // `f32` Display uses the shortest round-tripping representation.
        assert_eq!(e.to_string(), "non-positive temperature: -0.1");
    }

    #[test]
    fn error_display_invalid_roi_box() {
        let e = VisionError::InvalidRoiBox {
            x1: 5.0,
            y1: 0.0,
            x2: 3.0,
            y2: 4.0,
        };
        // Whole-number `f32` values display without a fractional part.
        assert_eq!(e.to_string(), "invalid RoI box [5, 0, 3, 4]");
    }

    #[test]
    fn error_display_weight_shape_mismatch() {
        let e = VisionError::WeightShapeMismatch {
            name: "patch_kernel",
            expected: vec![64, 3, 4, 4],
            got: vec![64, 3, 8, 8],
        };
        assert_eq!(
            e.to_string(),
            "weight shape mismatch for 'patch_kernel': expected [64, 3, 4, 4], got [64, 3, 8, 8]"
        );
    }

    #[test]
    fn error_display_non_finite() {
        let e = VisionError::NonFinite("attention logits");
        assert_eq!(
            e.to_string(),
            "non-finite value encountered: attention logits"
        );
    }

    #[test]
    fn error_display_internal() {
        let e = VisionError::Internal("unexpected state".into());
        assert_eq!(e.to_string(), "internal error: unexpected state");
    }

    #[test]
    fn error_clone_eq() {
        let a = VisionError::InvalidEmbedDim(0);
        let b = a.clone();
        assert_eq!(a, b);
        // Equality must look at the payload, not just the variant.
        assert_ne!(a, VisionError::InvalidEmbedDim(1));
    }

    #[test]
    fn error_is_std_error_and_thread_safe() {
        // Compile-time check: usable with `?`/`Box<dyn Error + Send + Sync>`.
        fn assert_traits<E: std::error::Error + Send + Sync + 'static>() {}
        assert_traits::<VisionError>();
    }

    #[test]
    fn result_alias_ok() {
        fn make_ok() -> VisionResult<u32> {
            Ok(42)
        }
        assert_eq!(make_ok().expect("ok result"), 42);
    }

    #[test]
    fn result_alias_err() {
        let r: VisionResult<u32> = Err(VisionError::EmptyInput("test"));
        assert_eq!(r, Err(VisionError::EmptyInput("test")));
    }
}