1use std::path::PathBuf;
6use std::process::ExitCode;
7use thiserror::Error;
8
9pub type Result<T> = std::result::Result<T, CliError>;
11
12#[derive(Error, Debug)]
14pub enum CliError {
15 #[error("File not found: {0}")]
17 FileNotFound(PathBuf),
18
19 #[error("Not a file: {0}")]
21 NotAFile(PathBuf),
22
23 #[error("Invalid APR format: {0}")]
25 InvalidFormat(String),
26
27 #[error("IO error: {0}")]
29 Io(#[from] std::io::Error),
30
31 #[error("Validation failed: {0}")]
33 ValidationFailed(String),
34
35 #[error("Aprender error: {0}")]
37 Aprender(String),
38
39 #[error("Model load failed: {0}")]
41 #[allow(dead_code)]
42 ModelLoadFailed(String),
43
44 #[error("Inference failed: {0}")]
46 #[allow(dead_code)]
47 InferenceFailed(String),
48
49 #[error("Feature not enabled: {0}")]
51 #[allow(dead_code)]
52 FeatureDisabled(String),
53
54 #[error("Network error: {0}")]
56 NetworkError(String),
57
58 #[error("HTTP 404 Not Found: {0}")]
60 HttpNotFound(String),
61}
62
63impl CliError {
64 pub fn exit_code(&self) -> ExitCode {
66 match self {
67 Self::FileNotFound(_) | Self::NotAFile(_) => ExitCode::from(3),
68 Self::InvalidFormat(_) => ExitCode::from(4),
69 Self::Io(_) => ExitCode::from(7),
70 Self::ValidationFailed(_) => ExitCode::from(5),
71 Self::Aprender(_) => ExitCode::from(1),
72 Self::ModelLoadFailed(_) => ExitCode::from(6),
73 Self::InferenceFailed(_) => ExitCode::from(8),
74 Self::FeatureDisabled(_) => ExitCode::from(9),
75 Self::NetworkError(_) => ExitCode::from(10),
76 Self::HttpNotFound(_) => ExitCode::from(11),
77 }
78 }
79}
80
81impl From<aprender::error::AprenderError> for CliError {
82 fn from(e: aprender::error::AprenderError) -> Self {
83 Self::Aprender(e.to_string())
84 }
85}
86
87pub fn resolve_model_path(
93 path: &std::path::Path,
94) -> std::result::Result<std::path::PathBuf, CliError> {
95 if !path.exists() {
96 return Err(CliError::FileNotFound(path.to_path_buf()));
97 }
98 if path.is_file() {
99 return Ok(path.to_path_buf());
100 }
101 if path.is_dir() {
102 let index = path.join("model.safetensors.index.json");
105 if index.is_file() {
106 return Ok(index);
107 }
108 let candidates = [
110 "model.safetensors",
111 "model-00001-of-00001.safetensors",
112 "model-00001-of-00002.safetensors",
113 "model-00001-of-00003.safetensors",
114 "model-00001-of-00004.safetensors",
115 ];
116 for candidate in &candidates {
117 let p = path.join(candidate);
118 if p.is_file() {
119 return Ok(p);
120 }
121 }
122 if let Ok(entries) = std::fs::read_dir(path) {
124 for entry in entries.flatten() {
125 let p = entry.path();
126 if p.extension().is_some_and(|ext| ext == "gguf") && p.is_file() {
127 return Ok(p);
128 }
129 }
130 }
131 if let Ok(entries) = std::fs::read_dir(path) {
133 for entry in entries.flatten() {
134 let p = entry.path();
135 if p.extension().is_some_and(|ext| ext == "apr") && p.is_file() {
136 return Ok(p);
137 }
138 }
139 }
140 Err(CliError::ValidationFailed(format!(
141 "Directory {} does not contain a model file (expected model.safetensors, *.gguf, or *.apr)",
142 path.display()
143 )))
144 } else {
145 Err(CliError::NotAFile(path.to_path_buf()))
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns a scratch path under the system temp dir. The process id is part of
    /// the name, so concurrent test processes sharing the temp dir never reuse each
    /// other's fixtures.
    fn scratch_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("apr-test-{}-{name}", std::process::id()))
    }

    // ---- exit codes ------------------------------------------------------

    #[test]
    fn test_file_not_found_exit_code() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_not_a_file_exit_code() {
        let err = CliError::NotAFile(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_invalid_format_exit_code() {
        let err = CliError::InvalidFormat("bad".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(4));
    }

    #[test]
    fn test_io_error_exit_code() {
        let err = CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "test"));
        assert_eq!(err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_validation_failed_exit_code() {
        let err = CliError::ValidationFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(5));
    }

    #[test]
    fn test_aprender_error_exit_code() {
        let err = CliError::Aprender("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(1));
    }

    #[test]
    fn test_model_load_failed_exit_code() {
        let err = CliError::ModelLoadFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(6));
    }

    #[test]
    fn test_inference_failed_exit_code() {
        let err = CliError::InferenceFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(8));
    }

    #[test]
    fn test_feature_disabled_exit_code() {
        let err = CliError::FeatureDisabled("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(9));
    }

    #[test]
    fn test_network_error_exit_code() {
        let err = CliError::NetworkError("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(10));
    }

    #[test]
    fn test_http_not_found_exit_code() {
        let err = CliError::HttpNotFound("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(11));
    }

    #[test]
    fn test_all_exit_codes_are_distinct_per_category() {
        // One representative per category. `NotAFile` intentionally shares code 3
        // with `FileNotFound`, so only one of the two is listed.
        let codes = [
            (CliError::FileNotFound(PathBuf::from("a")).exit_code(), "file"),
            (CliError::InvalidFormat("a".to_string()).exit_code(), "format"),
            (
                CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "")).exit_code(),
                "io",
            ),
            (CliError::ValidationFailed("a".to_string()).exit_code(), "validation"),
            (CliError::Aprender("a".to_string()).exit_code(), "aprender"),
            (CliError::ModelLoadFailed("a".to_string()).exit_code(), "model_load"),
            (CliError::InferenceFailed("a".to_string()).exit_code(), "inference"),
            (CliError::FeatureDisabled("a".to_string()).exit_code(), "feature"),
            (CliError::NetworkError("a".to_string()).exit_code(), "network"),
            (CliError::HttpNotFound("a".to_string()).exit_code(), "http_not_found"),
        ];
        for (i, (code_a, name_a)) in codes.iter().enumerate() {
            for (code_b, name_b) in &codes[i + 1..] {
                assert_ne!(code_a, code_b, "{name_a} and {name_b} share an exit code");
            }
        }
        assert_eq!(codes[0].0, ExitCode::from(3));
    }

    // ---- Display messages ------------------------------------------------

    #[test]
    fn test_file_not_found_display() {
        let err = CliError::FileNotFound(PathBuf::from("/model.apr"));
        assert_eq!(err.to_string(), "File not found: /model.apr");
    }

    #[test]
    fn test_not_a_file_display() {
        let err = CliError::NotAFile(PathBuf::from("/dir"));
        assert_eq!(err.to_string(), "Not a file: /dir");
    }

    #[test]
    fn test_invalid_format_display() {
        let err = CliError::InvalidFormat("bad magic".to_string());
        assert_eq!(err.to_string(), "Invalid APR format: bad magic");
    }

    #[test]
    fn test_validation_failed_display() {
        let err = CliError::ValidationFailed("missing field".to_string());
        assert_eq!(err.to_string(), "Validation failed: missing field");
    }

    #[test]
    fn test_aprender_error_display() {
        let err = CliError::Aprender("internal".to_string());
        assert_eq!(err.to_string(), "Aprender error: internal");
    }

    #[test]
    fn test_model_load_failed_display() {
        let err = CliError::ModelLoadFailed("corrupt".to_string());
        assert_eq!(err.to_string(), "Model load failed: corrupt");
    }

    #[test]
    fn test_inference_failed_display() {
        let err = CliError::InferenceFailed("OOM".to_string());
        assert_eq!(err.to_string(), "Inference failed: OOM");
    }

    #[test]
    fn test_feature_disabled_display() {
        let err = CliError::FeatureDisabled("cuda".to_string());
        assert_eq!(err.to_string(), "Feature not enabled: cuda");
    }

    #[test]
    fn test_network_error_display() {
        let err = CliError::NetworkError("timeout".to_string());
        assert_eq!(err.to_string(), "Network error: timeout");
    }

    #[test]
    fn test_http_not_found_display() {
        let err = CliError::HttpNotFound("tokenizer.json".to_string());
        assert_eq!(err.to_string(), "HTTP 404 Not Found: tokenizer.json");
    }

    // ---- conversions, Debug, Result alias --------------------------------

    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let cli_err: CliError = io_err.into();
        assert!(cli_err.to_string().contains("file missing"));
        assert_eq!(cli_err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_debug_impl() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        let debug = format!("{:?}", err);
        assert!(debug.contains("FileNotFound"));
    }

    #[test]
    fn test_result_type_ok() {
        let result: Result<i32> = Ok(42);
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_result_type_err() {
        let result: Result<i32> = Err(CliError::InvalidFormat("test".to_string()));
        assert!(result.is_err());
    }

    // ---- resolve_model_path ----------------------------------------------
    // Fixtures are removed before asserting so a failing test does not leak files.

    #[test]
    fn test_resolve_model_path_nonexistent() {
        let result = resolve_model_path(std::path::Path::new("/nonexistent/path/model.gguf"));
        assert!(matches!(result, Err(CliError::FileNotFound(_))));
    }

    #[test]
    fn test_resolve_model_path_regular_file() {
        let tmp = scratch_path("resolve.safetensors");
        std::fs::write(&tmp, b"test").expect("write");
        let result = resolve_model_path(&tmp);
        std::fs::remove_file(&tmp).ok();
        assert_eq!(result.expect("regular file should resolve"), tmp);
    }

    #[test]
    fn test_resolve_model_path_dir_with_safetensors() {
        let dir = scratch_path("resolve-dir");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let model_file = dir.join("model.safetensors");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        std::fs::remove_file(&model_file).ok();
        std::fs::remove_dir(&dir).ok();
        assert_eq!(result.expect("safetensors dir should resolve"), model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_gguf() {
        let dir = scratch_path("resolve-gguf");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let model_file = dir.join("model-q4.gguf");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        std::fs::remove_file(&model_file).ok();
        std::fs::remove_dir(&dir).ok();
        assert_eq!(result.expect("gguf dir should resolve"), model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_apr() {
        let dir = scratch_path("resolve-apr");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let model_file = dir.join("model.apr");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        std::fs::remove_file(&model_file).ok();
        std::fs::remove_dir(&dir).ok();
        assert_eq!(result.expect("apr dir should resolve"), model_file);
    }

    #[test]
    fn test_resolve_model_path_safetensors_beats_gguf() {
        let dir = scratch_path("resolve-priority");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let safetensors = dir.join("model.safetensors");
        let gguf = dir.join("model-q4.gguf");
        std::fs::write(&safetensors, b"test").expect("write safetensors");
        std::fs::write(&gguf, b"test").expect("write gguf");
        let result = resolve_model_path(&dir);
        std::fs::remove_file(&gguf).ok();
        std::fs::remove_file(&safetensors).ok();
        std::fs::remove_dir(&dir).ok();
        assert_eq!(result.expect("dir should resolve"), safetensors);
    }

    #[test]
    fn test_resolve_model_path_dir_with_sharded_safetensors() {
        let dir = scratch_path("resolve-sharded");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let index_file = dir.join("model.safetensors.index.json");
        let shard_file = dir.join("model-00001-of-00002.safetensors");
        std::fs::write(&index_file, b"{}").expect("write index");
        std::fs::write(&shard_file, b"test").expect("write shard");
        let result = resolve_model_path(&dir);
        std::fs::remove_file(&shard_file).ok();
        std::fs::remove_file(&index_file).ok();
        std::fs::remove_dir(&dir).ok();
        assert_eq!(
            result.expect("sharded dir should resolve"),
            index_file,
            "index.json must take priority over shard files"
        );
    }

    #[test]
    fn test_resolve_model_path_empty_dir() {
        let dir = scratch_path("resolve-empty");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let result = resolve_model_path(&dir);
        std::fs::remove_dir(&dir).ok();
        assert!(matches!(result, Err(CliError::ValidationFailed(_))));
    }
}
440}