1use thiserror::Error;
4
5#[derive(Debug, Error, PartialEq)]
7pub enum SslError {
8 #[error("dimension mismatch: expected {expected}, got {got}")]
9 DimensionMismatch { expected: usize, got: usize },
10
11 #[error("empty input")]
12 EmptyInput,
13
14 #[error("invalid temperature {temp}: must be > 0 and finite")]
15 InvalidTemperature { temp: f32 },
16
17 #[error("invalid momentum {momentum}: must be in [0, 1]")]
18 InvalidMomentum { momentum: f32 },
19
20 #[error("invalid mask ratio {ratio}: must be in [0, 1)")]
21 InvalidMaskRatio { ratio: f32 },
22
23 #[error("invalid number of crops: must be >= 1")]
24 InvalidNumCrops,
25
26 #[error("invalid loss weight {weight}: must be finite")]
27 InvalidLossWeight { weight: f32 },
28
29 #[error("queue capacity must be >= 1")]
30 QueueCapacityTooSmall,
31
32 #[error("queue is empty")]
33 QueueEmpty,
34
35 #[error("number of prototypes must be >= 2")]
36 NumPrototypesTooSmall,
37
38 #[error("Sinkhorn-Knopp diverged after {iters} iterations")]
39 SinkhornDiverged { iters: usize },
40
41 #[error("invalid feature dimension: must be > 0")]
42 InvalidFeatureDim,
43
44 #[error("invalid batch size: must be >= 2 (need positive + negative pairs)")]
45 BatchTooSmall,
46
47 #[error("non-finite value at: {location}")]
48 NanEncountered { location: &'static str },
49
50 #[error("invalid projector layer dim: in/hidden/out must be > 0")]
51 InvalidProjectorDim,
52
53 #[error("internal error: {0}")]
54 Internal(String),
55}
56
/// Shorthand for a `Result` whose error type is [`SslError`].
pub type SslResult<T> = Result<T, SslError>;
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Both the expected and the actual size must appear in the message.
    #[test]
    fn error_display_dimension_mismatch() {
        let rendered = SslError::DimensionMismatch {
            expected: 64,
            got: 32,
        }
        .to_string();
        assert!(rendered.contains("64"));
        assert!(rendered.contains("32"));
    }

    /// The offending temperature is echoed in the message.
    #[test]
    fn error_display_invalid_temperature() {
        let rendered = SslError::InvalidTemperature { temp: -1.0 }.to_string();
        assert!(rendered.contains("-1"));
    }

    /// The offending momentum is echoed in the message.
    #[test]
    fn error_display_invalid_momentum() {
        let rendered = SslError::InvalidMomentum { momentum: 1.5 }.to_string();
        assert!(rendered.contains("1.5"));
    }

    /// The iteration count is echoed in the message.
    #[test]
    fn error_display_sinkhorn() {
        let rendered = SslError::SinkhornDiverged { iters: 100 }.to_string();
        assert!(rendered.contains("100"));
    }

    /// The wrapped free-form text is echoed in the message.
    #[test]
    fn error_display_internal() {
        let rendered = SslError::Internal(String::from("bad shape")).to_string();
        assert!(rendered.contains("bad shape"));
    }

    /// Two values of the same unit variant compare equal.
    #[test]
    fn error_equality() {
        let (lhs, rhs) = (SslError::EmptyInput, SslError::EmptyInput);
        assert_eq!(lhs, rhs);
    }

    /// The alias accepts an `Ok` value.
    #[test]
    fn ssl_result_ok() {
        let outcome: SslResult<i32> = Ok(42);
        assert!(matches!(outcome, Ok(_)));
    }

    /// The alias accepts an `Err` holding an `SslError`.
    #[test]
    fn ssl_result_err() {
        let outcome: SslResult<i32> = Err(SslError::InvalidNumCrops);
        assert!(matches!(outcome, Err(_)));
    }
}