entrenar/hf_pipeline/
error.rs1use std::path::PathBuf;
6use std::time::Duration;
7use thiserror::Error;
8
/// Result alias whose error type is fixed to [`FetchError`].
pub type Result<T> = std::result::Result<T, FetchError>;
11
12#[derive(Debug, Error)]
17pub enum FetchError {
18 #[error("Network timeout for {repo} after {elapsed:?}")]
20 NetworkTimeout { repo: String, elapsed: Duration },
21
22 #[error("Rate limited, retry after {retry_after:?}")]
24 RateLimited { retry_after: Duration },
25
26 #[error("Repository not found: {repo}")]
28 ModelNotFound { repo: String },
29
30 #[error("File not found in {repo}: {file}")]
32 FileNotFound { repo: String, file: String },
33
34 #[error("Corrupt file at {path}: expected SHA256 {expected_hash}, got {actual_hash}")]
36 CorruptFile { path: PathBuf, expected_hash: String, actual_hash: String },
37
38 #[error("Insufficient disk space: need {required} bytes, have {available} bytes")]
40 InsufficientDisk { required: u64, available: u64 },
41
42 #[error("Out of memory: model requires {required} bytes, available {available} bytes")]
44 OutOfMemory { required: u64, available: u64 },
45
46 #[error("Authentication failed: {message}")]
48 AuthenticationFailed { message: String },
49
50 #[error("Missing HF_TOKEN - set environment variable or use with_token()")]
52 MissingToken,
53
54 #[error("Invalid repository ID format (expected 'org/name'): {repo_id}")]
56 InvalidRepoId { repo_id: String },
57
58 #[error("Unsupported model format: {format}")]
60 UnsupportedFormat { format: String },
61
62 #[error("SECURITY: PyTorch .bin files may contain arbitrary code. Enable allow_pytorch_pickle to proceed.")]
64 PickleSecurityRisk,
65
66 #[error("Failed to parse config.json: {message}")]
68 ConfigParseError { message: String },
69
70 #[error("Tensor shape mismatch for {tensor}: expected {expected:?}, got {actual:?}")]
72 ShapeMismatch { tensor: String, expected: Vec<usize>, actual: Vec<usize> },
73
74 #[error("IO error: {0}")]
76 Io(#[from] std::io::Error),
77
78 #[error("JSON error: {0}")]
80 Json(#[from] serde_json::Error),
81
82 #[error("SafeTensors parse error: {message}")]
84 SafeTensorsParseError { message: String },
85
86 #[error("Leaderboard not found: {kind}")]
88 LeaderboardNotFound { kind: String },
89
90 #[error("Failed to parse leaderboard data: {message}")]
92 LeaderboardParseError { message: String },
93
94 #[error("HTTP error: {message}")]
96 HttpError { message: String },
97
98 #[error("GGUF write error: {message}")]
100 GgufWriteError { message: String },
101}
102
103impl FetchError {
104 #[must_use]
106 pub fn is_retryable(&self) -> bool {
107 matches!(self, Self::NetworkTimeout { .. } | Self::RateLimited { .. })
108 }
109
110 #[must_use]
112 pub fn retry_after(&self) -> Option<Duration> {
113 match self {
114 Self::RateLimited { retry_after } => Some(*retry_after),
115 Self::NetworkTimeout { .. } => Some(Duration::from_secs(5)),
116 _ => None,
117 }
118 }
119
120 #[must_use]
122 pub fn is_security_risk(&self) -> bool {
123 matches!(self, Self::PickleSecurityRisk | Self::CorruptFile { .. })
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a real `serde_json::Error` by parsing truncated JSON.
    fn sample_json_error() -> serde_json::Error {
        serde_json::from_str::<serde_json::Value>("{")
            .expect_err("truncated JSON must fail to parse")
    }

    /// A timeout is retryable and reports some retry delay.
    #[test]
    fn test_network_timeout_is_retryable() {
        let err = FetchError::NetworkTimeout {
            repo: "test/model".into(),
            elapsed: Duration::from_secs(30),
        };
        assert!(err.is_retryable());
        assert!(err.retry_after().is_some());
    }

    /// A rate-limit error is retryable and echoes the delay it carries.
    #[test]
    fn test_rate_limited_is_retryable() {
        let err = FetchError::RateLimited { retry_after: Duration::from_secs(60) };
        assert!(err.is_retryable());
        assert_eq!(err.retry_after(), Some(Duration::from_secs(60)));
    }

    /// A missing repository is neither retryable nor given a retry delay.
    #[test]
    fn test_model_not_found_not_retryable() {
        let err = FetchError::ModelNotFound { repo: "test/model".into() };
        assert!(!err.is_retryable());
        assert!(err.retry_after().is_none());
    }

    /// The pickle refusal is a security risk and must not be retried.
    #[test]
    fn test_pickle_is_security_risk() {
        let err = FetchError::PickleSecurityRisk;
        assert!(err.is_security_risk());
        assert!(!err.is_retryable());
    }

    /// A hash mismatch is classified as a security risk.
    #[test]
    fn test_corrupt_file_is_security_risk() {
        let err = FetchError::CorruptFile {
            path: PathBuf::from("/tmp/model.safetensors"),
            expected_hash: "abc123".into(),
            actual_hash: "def456".into(),
        };
        assert!(err.is_security_risk());
    }

    /// The missing-token message names the environment variable to set.
    #[test]
    fn test_missing_token_display() {
        let err = FetchError::MissingToken;
        let msg = err.to_string();
        assert!(msg.contains("HF_TOKEN"));
    }

    /// The invalid-repo message shows the expected `org/name` form.
    #[test]
    fn test_invalid_repo_id_display() {
        let err = FetchError::InvalidRepoId { repo_id: "invalid".into() };
        let msg = err.to_string();
        assert!(msg.contains("org/name"));
    }

    /// Every variant of `FetchError` renders a non-empty `Display` message.
    #[test]
    fn test_all_error_variants_display() {
        let errors: Vec<FetchError> = vec![
            FetchError::NetworkTimeout { repo: "r".into(), elapsed: Duration::from_secs(1) },
            FetchError::RateLimited { retry_after: Duration::from_secs(1) },
            FetchError::ModelNotFound { repo: "r".into() },
            FetchError::FileNotFound { repo: "r".into(), file: "f".into() },
            FetchError::CorruptFile {
                path: PathBuf::from("p"),
                expected_hash: "e".into(),
                actual_hash: "a".into(),
            },
            FetchError::InsufficientDisk { required: 100, available: 50 },
            FetchError::OutOfMemory { required: 100, available: 50 },
            FetchError::AuthenticationFailed { message: "m".into() },
            FetchError::MissingToken,
            FetchError::InvalidRepoId { repo_id: "r".into() },
            FetchError::UnsupportedFormat { format: "f".into() },
            FetchError::PickleSecurityRisk,
            FetchError::ConfigParseError { message: "m".into() },
            FetchError::ShapeMismatch {
                tensor: "t".into(),
                expected: vec![1, 2],
                actual: vec![3, 4],
            },
            // The three variants below were previously missing from this list.
            FetchError::Io(std::io::Error::new(std::io::ErrorKind::NotFound, "missing")),
            FetchError::Json(sample_json_error()),
            FetchError::SafeTensorsParseError { message: "bad header".into() },
            FetchError::LeaderboardNotFound { kind: "OpenASR".into() },
            FetchError::LeaderboardParseError { message: "missing field".into() },
            FetchError::HttpError { message: "connection refused".into() },
            FetchError::GgufWriteError { message: "alignment error".into() },
        ];

        for err in errors {
            let msg = err.to_string();
            assert!(!msg.is_empty(), "Error display should not be empty: {err:?}");
        }
    }

    /// `std::io::Error` converts into `FetchError::Io` through `From`.
    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let fetch_err: FetchError = io_err.into();
        assert!(matches!(fetch_err, FetchError::Io(_)));
    }

    /// `serde_json::Error` converts into `FetchError::Json` through `From`.
    #[test]
    fn test_json_error_conversion() {
        let fetch_err: FetchError = sample_json_error().into();
        assert!(matches!(fetch_err, FetchError::Json(_)));
        assert!(!fetch_err.is_retryable());
    }
}