Skip to main content

oxicuda_vision/
error.rs

//! Error types for `oxicuda-vision`.
2
3use thiserror::Error;
4
5/// Errors returned by `oxicuda-vision` operations.
6#[derive(Debug, Error, Clone, PartialEq)]
7pub enum VisionError {
8    /// Tensor dimension does not match the expected value.
9    #[error("dimension mismatch: expected {expected}, got {got}")]
10    DimensionMismatch { expected: usize, got: usize },
11
12    /// Shape mismatch between two tensors.
13    #[error("shape mismatch: lhs {lhs:?} vs rhs {rhs:?}")]
14    ShapeMismatch { lhs: Vec<usize>, rhs: Vec<usize> },
15
16    /// The input slice or tensor is empty.
17    #[error("empty input: {0}")]
18    EmptyInput(&'static str),
19
20    /// Image spatial dimensions are zero or inconsistent.
21    #[error("invalid image size: height={height}, width={width}, channels={channels}")]
22    InvalidImageSize {
23        height: usize,
24        width: usize,
25        channels: usize,
26    },
27
28    /// Patch size is zero, negative, or does not divide the image dimension.
29    #[error("invalid patch size {patch_size}: image size {img_size} is not divisible")]
30    InvalidPatchSize { patch_size: usize, img_size: usize },
31
32    /// Embedding dimension is zero or otherwise invalid.
33    #[error("invalid embed dim: {0}")]
34    InvalidEmbedDim(usize),
35
36    /// Number of attention heads is zero or invalid.
37    #[error("invalid number of heads: {0}")]
38    InvalidNumHeads(usize),
39
40    /// Head dimension does not divide the embedding dimension.
41    #[error("head count {n_heads} does not divide embed dim {embed_dim}")]
42    HeadDimMismatch { n_heads: usize, embed_dim: usize },
43
44    /// Number of output classes is zero.
45    #[error("invalid number of classes: {0}")]
46    InvalidNumClasses(usize),
47
48    /// Projection dimension is zero.
49    #[error("invalid projection dim: {0}")]
50    InvalidProjDim(usize),
51
52    /// Contrastive loss temperature is non-positive.
53    #[error("non-positive temperature: {0}")]
54    NonPositiveTemperature(f32),
55
56    /// RoI box coordinates are invalid (e.g., x1 >= x2).
57    #[error("invalid RoI box [{x1}, {y1}, {x2}, {y2}]")]
58    InvalidRoiBox { x1: f32, y1: f32, x2: f32, y2: f32 },
59
60    /// Weight tensor has wrong shape.
61    #[error("weight shape mismatch for '{name}': expected {expected:?}, got {got:?}")]
62    WeightShapeMismatch {
63        name: &'static str,
64        expected: Vec<usize>,
65        got: Vec<usize>,
66    },
67
68    /// NaN or infinity encountered in intermediate values.
69    #[error("non-finite value encountered: {0}")]
70    NonFinite(&'static str),
71
72    /// Internal logic error (should not occur in correct usage).
73    #[error("internal error: {0}")]
74    Internal(String),
75}
76
77/// Convenience alias for `Result<T, VisionError>`.
78pub type VisionResult<T> = Result<T, VisionError>;
79
80// ─── Tests ───────────────────────────────────────────────────────────────────
81
#[cfg(test)]
mod tests {
    //! Unit tests for [`VisionError`] and [`VisionResult`].
    //!
    //! Display tests compare the whole rendered message rather than
    //! substrings, so that a dropped or swapped field in an `#[error(...)]`
    //! format string is caught.

    use super::*;

    /// Compile-time check: `E` is a standard error that can cross threads and
    /// be boxed as `Box<dyn Error + Send + Sync>`.
    fn assert_std_error<E: std::error::Error + Send + Sync + 'static>() {}

    #[test]
    fn error_is_std_error_send_sync() {
        assert_std_error::<VisionError>();
    }

    #[test]
    fn error_display_dimension_mismatch() {
        let e = VisionError::DimensionMismatch {
            expected: 64,
            got: 128,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 64, got 128");
    }

    #[test]
    fn error_display_shape_mismatch() {
        let e = VisionError::ShapeMismatch {
            lhs: vec![2, 4],
            rhs: vec![2, 8],
        };
        assert_eq!(e.to_string(), "shape mismatch: lhs [2, 4] vs rhs [2, 8]");
    }

    #[test]
    fn error_display_empty_input() {
        let e = VisionError::EmptyInput("image tensor");
        assert_eq!(e.to_string(), "empty input: image tensor");
    }

    #[test]
    fn error_display_invalid_image_size() {
        let e = VisionError::InvalidImageSize {
            height: 0,
            width: 32,
            channels: 3,
        };
        assert_eq!(
            e.to_string(),
            "invalid image size: height=0, width=32, channels=3"
        );
    }

    #[test]
    fn error_display_invalid_patch_size() {
        let e = VisionError::InvalidPatchSize {
            patch_size: 5,
            img_size: 32,
        };
        assert_eq!(
            e.to_string(),
            "invalid patch size 5: image size 32 is not divisible"
        );
    }

    #[test]
    fn error_display_invalid_embed_dim() {
        let e = VisionError::InvalidEmbedDim(0);
        assert_eq!(e.to_string(), "invalid embed dim: 0");
    }

    #[test]
    fn error_display_invalid_num_heads() {
        let e = VisionError::InvalidNumHeads(0);
        assert_eq!(e.to_string(), "invalid number of heads: 0");
    }

    #[test]
    fn error_display_head_dim_mismatch() {
        let e = VisionError::HeadDimMismatch {
            n_heads: 3,
            embed_dim: 64,
        };
        assert_eq!(e.to_string(), "head count 3 does not divide embed dim 64");
    }

    #[test]
    fn error_display_invalid_num_classes() {
        let e = VisionError::InvalidNumClasses(0);
        assert_eq!(e.to_string(), "invalid number of classes: 0");
    }

    #[test]
    fn error_display_invalid_proj_dim() {
        let e = VisionError::InvalidProjDim(0);
        assert_eq!(e.to_string(), "invalid projection dim: 0");
    }

    #[test]
    fn error_display_non_positive_temperature() {
        let e = VisionError::NonPositiveTemperature(-0.1);
        assert_eq!(e.to_string(), "non-positive temperature: -0.1");
    }

    #[test]
    fn error_display_invalid_roi_box() {
        let e = VisionError::InvalidRoiBox {
            x1: 5.0,
            y1: 0.0,
            x2: 3.0,
            y2: 4.0,
        };
        // `Display` for floats omits a trailing `.0`, so 5.0 renders as `5`.
        assert_eq!(e.to_string(), "invalid RoI box [5, 0, 3, 4]");
    }

    #[test]
    fn error_display_weight_shape_mismatch() {
        let e = VisionError::WeightShapeMismatch {
            name: "patch_kernel",
            expected: vec![64, 3, 4, 4],
            got: vec![64, 3, 8, 8],
        };
        assert_eq!(
            e.to_string(),
            "weight shape mismatch for 'patch_kernel': expected [64, 3, 4, 4], got [64, 3, 8, 8]"
        );
    }

    #[test]
    fn error_display_non_finite() {
        let e = VisionError::NonFinite("attention logits");
        assert_eq!(
            e.to_string(),
            "non-finite value encountered: attention logits"
        );
    }

    #[test]
    fn error_display_internal() {
        let e = VisionError::Internal("unexpected state".into());
        assert_eq!(e.to_string(), "internal error: unexpected state");
    }

    #[test]
    fn error_clone_eq() {
        let a = VisionError::InvalidEmbedDim(0);
        let b = a.clone();
        assert_eq!(a, b);
        // Same variant with a different payload, and a different variant with
        // the same payload, must both compare unequal.
        assert_ne!(a, VisionError::InvalidEmbedDim(1));
        assert_ne!(a, VisionError::InvalidNumHeads(0));
    }

    #[test]
    fn error_nan_payload_is_not_equal_to_itself() {
        // The derived `PartialEq` follows IEEE-754: NaN != NaN.
        let e = VisionError::NonPositiveTemperature(f32::NAN);
        assert_ne!(e, e.clone());
    }

    #[test]
    fn result_alias_ok() {
        fn make_ok() -> VisionResult<u32> {
            Ok(42)
        }
        assert_eq!(make_ok().expect("ok result"), 42);
    }

    #[test]
    fn result_alias_err() {
        let r: VisionResult<u32> = Err(VisionError::EmptyInput("test"));
        assert_eq!(r, Err(VisionError::EmptyInput("test")));
    }
}