Skip to main content

oxicuda_ssl/
error.rs

1//! Error types for `oxicuda-ssl`.
2
3use thiserror::Error;
4
5/// All errors that can be returned from `oxicuda-ssl`.
6#[derive(Debug, Error, PartialEq)]
7pub enum SslError {
8    #[error("dimension mismatch: expected {expected}, got {got}")]
9    DimensionMismatch { expected: usize, got: usize },
10
11    #[error("empty input")]
12    EmptyInput,
13
14    #[error("invalid temperature {temp}: must be > 0 and finite")]
15    InvalidTemperature { temp: f32 },
16
17    #[error("invalid momentum {momentum}: must be in [0, 1]")]
18    InvalidMomentum { momentum: f32 },
19
20    #[error("invalid mask ratio {ratio}: must be in [0, 1)")]
21    InvalidMaskRatio { ratio: f32 },
22
23    #[error("invalid number of crops: must be >= 1")]
24    InvalidNumCrops,
25
26    #[error("invalid loss weight {weight}: must be finite")]
27    InvalidLossWeight { weight: f32 },
28
29    #[error("queue capacity must be >= 1")]
30    QueueCapacityTooSmall,
31
32    #[error("queue is empty")]
33    QueueEmpty,
34
35    #[error("number of prototypes must be >= 2")]
36    NumPrototypesTooSmall,
37
38    #[error("Sinkhorn-Knopp diverged after {iters} iterations")]
39    SinkhornDiverged { iters: usize },
40
41    #[error("invalid feature dimension: must be > 0")]
42    InvalidFeatureDim,
43
44    #[error("invalid batch size: must be >= 2 (need positive + negative pairs)")]
45    BatchTooSmall,
46
47    #[error("non-finite value at: {location}")]
48    NanEncountered { location: &'static str },
49
50    #[error("invalid projector layer dim: in/hidden/out must be > 0")]
51    InvalidProjectorDim,
52
53    #[error("internal error: {0}")]
54    Internal(String),
55}
56
57/// Convenience result alias.
58pub type SslResult<T> = Result<T, SslError>;
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the `thiserror` derive provides a
    /// `std::error::Error` implementation.
    #[test]
    fn implements_std_error() {
        fn assert_error<E: std::error::Error>() {}
        assert_error::<SslError>();
        let boxed: Box<dyn std::error::Error> = Box::new(SslError::EmptyInput);
        assert_eq!(boxed.to_string(), "empty input");
    }

    #[test]
    fn error_display_dimension_mismatch() {
        let e = SslError::DimensionMismatch {
            expected: 64,
            got: 32,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 64, got 32");
    }

    #[test]
    fn error_display_invalid_temperature() {
        // `f32` Display prints -1.0 as "-1".
        let e = SslError::InvalidTemperature { temp: -1.0 };
        assert_eq!(
            e.to_string(),
            "invalid temperature -1: must be > 0 and finite"
        );
    }

    #[test]
    fn error_display_invalid_momentum() {
        let e = SslError::InvalidMomentum { momentum: 1.5 };
        assert_eq!(e.to_string(), "invalid momentum 1.5: must be in [0, 1]");
    }

    #[test]
    fn error_display_invalid_mask_ratio() {
        let e = SslError::InvalidMaskRatio { ratio: 1.0 };
        assert_eq!(e.to_string(), "invalid mask ratio 1: must be in [0, 1)");
    }

    #[test]
    fn error_display_invalid_loss_weight_nan() {
        let e = SslError::InvalidLossWeight { weight: f32::NAN };
        assert_eq!(e.to_string(), "invalid loss weight NaN: must be finite");
    }

    #[test]
    fn error_display_sinkhorn() {
        let e = SslError::SinkhornDiverged { iters: 100 };
        assert_eq!(e.to_string(), "Sinkhorn-Knopp diverged after 100 iterations");
    }

    #[test]
    fn error_display_nan_encountered() {
        let e = SslError::NanEncountered { location: "logits" };
        assert_eq!(e.to_string(), "non-finite value at: logits");
    }

    #[test]
    fn error_display_internal() {
        let e = SslError::Internal("bad shape".into());
        assert_eq!(e.to_string(), "internal error: bad shape");
    }

    /// Pins the exact message of every payload-free variant.
    #[test]
    fn error_display_unit_variants() {
        let cases = [
            (SslError::EmptyInput, "empty input"),
            (
                SslError::InvalidNumCrops,
                "invalid number of crops: must be >= 1",
            ),
            (SslError::QueueCapacityTooSmall, "queue capacity must be >= 1"),
            (SslError::QueueEmpty, "queue is empty"),
            (
                SslError::NumPrototypesTooSmall,
                "number of prototypes must be >= 2",
            ),
            (
                SslError::InvalidFeatureDim,
                "invalid feature dimension: must be > 0",
            ),
            (
                SslError::BatchTooSmall,
                "invalid batch size: must be >= 2 (need positive + negative pairs)",
            ),
            (
                SslError::InvalidProjectorDim,
                "invalid projector layer dim: in/hidden/out must be > 0",
            ),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    fn error_equality() {
        assert_eq!(SslError::EmptyInput, SslError::EmptyInput);
        assert_ne!(SslError::EmptyInput, SslError::QueueEmpty);
        assert_eq!(
            SslError::DimensionMismatch {
                expected: 3,
                got: 4
            },
            SslError::DimensionMismatch {
                expected: 3,
                got: 4
            }
        );
        // Payload fields participate in equality.
        assert_ne!(
            SslError::DimensionMismatch {
                expected: 3,
                got: 4
            },
            SslError::DimensionMismatch {
                expected: 4,
                got: 3
            }
        );
    }

    /// `PartialEq` is derived field-wise, so an `f32` NaN payload makes a
    /// variant unequal to an identical copy, mirroring `f32::NAN != f32::NAN`.
    #[test]
    fn nan_payload_is_not_equal_to_itself() {
        let e = SslError::InvalidTemperature { temp: f32::NAN };
        assert_ne!(e, SslError::InvalidTemperature { temp: f32::NAN });
        assert!(matches!(e, SslError::InvalidTemperature { temp } if temp.is_nan()));
    }

    #[test]
    fn ssl_result_ok() {
        let r: SslResult<i32> = Ok(42);
        assert_eq!(r, Ok(42));
    }

    #[test]
    fn ssl_result_err() {
        let r: SslResult<i32> = Err(SslError::InvalidNumCrops);
        assert_eq!(r, Err(SslError::InvalidNumCrops));
    }

    #[test]
    fn ssl_result_propagates_with_question_mark() {
        fn inner() -> SslResult<i32> {
            Err(SslError::QueueEmpty)
        }
        fn outer() -> SslResult<i32> {
            Ok(inner()? + 1)
        }
        assert_eq!(outer(), Err(SslError::QueueEmpty));
    }
}