1use std::path::PathBuf;
6use std::process::ExitCode;
7use thiserror::Error;
8
9pub type Result<T> = std::result::Result<T, CliError>;
11
12#[derive(Error, Debug)]
14pub enum CliError {
15 #[error("File not found: {0}")]
17 FileNotFound(PathBuf),
18
19 #[error("Not a file: {0}")]
21 NotAFile(PathBuf),
22
23 #[error("Invalid APR format: {0}")]
25 InvalidFormat(String),
26
27 #[error("IO error: {0}")]
29 Io(#[from] std::io::Error),
30
31 #[error("Validation failed: {0}")]
33 ValidationFailed(String),
34
35 #[error("Aprender error: {0}")]
37 Aprender(String),
38
39 #[error("Model load failed: {0}")]
41 #[allow(dead_code)]
42 ModelLoadFailed(String),
43
44 #[error("Inference failed: {0}")]
46 #[allow(dead_code)]
47 InferenceFailed(String),
48
49 #[error("Feature not enabled: {0}")]
51 #[allow(dead_code)]
52 FeatureDisabled(String),
53
54 #[error("Network error: {0}")]
56 NetworkError(String),
57
58 #[error("HTTP 404 Not Found: {0}")]
60 HttpNotFound(String),
61}
62
63impl CliError {
64 pub fn exit_code(&self) -> ExitCode {
66 contract_pre_exit_code_semantics!();
67 contract_pre_error_mapping!();
68 contract_pre_exit_code_on_error!();
69 match self {
70 Self::FileNotFound(_) | Self::NotAFile(_) => ExitCode::from(3),
71 Self::InvalidFormat(_) => ExitCode::from(4),
72 Self::Io(_) => ExitCode::from(7),
73 Self::ValidationFailed(_) => ExitCode::from(5),
74 Self::Aprender(_) => ExitCode::from(1),
75 Self::ModelLoadFailed(_) => ExitCode::from(6),
76 Self::InferenceFailed(_) => ExitCode::from(8),
77 Self::FeatureDisabled(_) => ExitCode::from(9),
78 Self::NetworkError(_) => ExitCode::from(10),
79 Self::HttpNotFound(_) => ExitCode::from(11),
80 }
81 }
82}
83
84impl From<aprender::error::AprenderError> for CliError {
85 fn from(e: aprender::error::AprenderError) -> Self {
86 Self::Aprender(e.to_string())
87 }
88}
89
90pub fn resolve_model_path(
96 path: &std::path::Path,
97) -> std::result::Result<std::path::PathBuf, CliError> {
98 if !path.exists() {
99 return Err(CliError::FileNotFound(path.to_path_buf()));
100 }
101 if path.is_file() {
102 return Ok(path.to_path_buf());
103 }
104 if path.is_dir() {
105 if let Some(parent) = path.parent() {
108 let depth = path.components().count();
109 if depth <= 2 {
111 return Err(CliError::NotAFile(path.to_path_buf()));
112 }
113 let _ = parent; }
115
116 let index = path.join("model.safetensors.index.json");
119 if index.is_file() {
120 return Ok(index);
121 }
122 let candidates = [
124 "model.safetensors",
125 "model-00001-of-00001.safetensors",
126 "model-00001-of-00002.safetensors",
127 "model-00001-of-00003.safetensors",
128 "model-00001-of-00004.safetensors",
129 ];
130 for candidate in &candidates {
131 let p = path.join(candidate);
132 if p.is_file() {
133 return Ok(p);
134 }
135 }
136 if let Ok(entries) = std::fs::read_dir(path) {
138 for entry in entries.flatten() {
139 let p = entry.path();
140 let is_temp = p
142 .file_name()
143 .is_some_and(|n| n.to_string_lossy().starts_with("rosetta_temp"));
144 if !is_temp && p.extension().is_some_and(|ext| ext == "gguf") && p.is_file() {
145 return Ok(p);
146 }
147 }
148 }
149 if let Ok(entries) = std::fs::read_dir(path) {
151 for entry in entries.flatten() {
152 let p = entry.path();
153 let is_temp = p
154 .file_name()
155 .is_some_and(|n| n.to_string_lossy().starts_with("rosetta_temp"));
156 if !is_temp && p.extension().is_some_and(|ext| ext == "apr") && p.is_file() {
157 return Ok(p);
158 }
159 }
160 }
161 Err(CliError::ValidationFailed(format!(
162 "Directory {} does not contain a model file (expected model.safetensors, *.gguf, or *.apr)",
163 path.display()
164 )))
165 } else {
166 Err(CliError::NotAFile(path.to_path_buf()))
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns a path under the system temp dir that is unique to this test process,
    /// so concurrent test runs cannot clobber (or observe) each other's fixtures.
    fn scratch_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}-{name}", std::process::id()))
    }

    // ---- exit codes -------------------------------------------------------

    #[test]
    fn test_file_not_found_exit_code() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_not_a_file_exit_code() {
        let err = CliError::NotAFile(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_invalid_format_exit_code() {
        let err = CliError::InvalidFormat("bad".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(4));
    }

    #[test]
    fn test_io_error_exit_code() {
        let err = CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "test"));
        assert_eq!(err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_validation_failed_exit_code() {
        let err = CliError::ValidationFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(5));
    }

    #[test]
    fn test_aprender_error_exit_code() {
        let err = CliError::Aprender("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(1));
    }

    #[test]
    fn test_model_load_failed_exit_code() {
        let err = CliError::ModelLoadFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(6));
    }

    #[test]
    fn test_inference_failed_exit_code() {
        let err = CliError::InferenceFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(8));
    }

    #[test]
    fn test_feature_disabled_exit_code() {
        let err = CliError::FeatureDisabled("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(9));
    }

    #[test]
    fn test_network_error_exit_code() {
        let err = CliError::NetworkError("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(10));
    }

    #[test]
    fn test_http_not_found_exit_code() {
        let err = CliError::HttpNotFound("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(11));
    }

    // ---- Display messages -------------------------------------------------

    #[test]
    fn test_file_not_found_display() {
        let err = CliError::FileNotFound(PathBuf::from("/model.apr"));
        assert_eq!(err.to_string(), "File not found: /model.apr");
    }

    #[test]
    fn test_not_a_file_display() {
        let err = CliError::NotAFile(PathBuf::from("/dir"));
        assert_eq!(err.to_string(), "Not a file: /dir");
    }

    #[test]
    fn test_invalid_format_display() {
        let err = CliError::InvalidFormat("bad magic".to_string());
        assert_eq!(err.to_string(), "Invalid APR format: bad magic");
    }

    #[test]
    fn test_validation_failed_display() {
        let err = CliError::ValidationFailed("missing field".to_string());
        assert_eq!(err.to_string(), "Validation failed: missing field");
    }

    #[test]
    fn test_aprender_error_display() {
        let err = CliError::Aprender("internal".to_string());
        assert_eq!(err.to_string(), "Aprender error: internal");
    }

    #[test]
    fn test_model_load_failed_display() {
        let err = CliError::ModelLoadFailed("corrupt".to_string());
        assert_eq!(err.to_string(), "Model load failed: corrupt");
    }

    #[test]
    fn test_inference_failed_display() {
        let err = CliError::InferenceFailed("OOM".to_string());
        assert_eq!(err.to_string(), "Inference failed: OOM");
    }

    #[test]
    fn test_feature_disabled_display() {
        let err = CliError::FeatureDisabled("cuda".to_string());
        assert_eq!(err.to_string(), "Feature not enabled: cuda");
    }

    #[test]
    fn test_network_error_display() {
        let err = CliError::NetworkError("timeout".to_string());
        assert_eq!(err.to_string(), "Network error: timeout");
    }

    #[test]
    fn test_http_not_found_display() {
        let err = CliError::HttpNotFound("tokenizer.json".to_string());
        assert_eq!(err.to_string(), "HTTP 404 Not Found: tokenizer.json");
    }

    // ---- conversions and trait impls --------------------------------------

    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let cli_err: CliError = io_err.into();
        assert!(cli_err.to_string().contains("file missing"));
        assert_eq!(cli_err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_debug_impl() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        let debug = format!("{:?}", err);
        assert!(debug.contains("FileNotFound"));
    }

    // ---- Result alias -----------------------------------------------------

    #[test]
    fn test_result_type_ok() {
        let result: Result<i32> = Ok(42);
        assert_eq!(result.expect("value"), 42);
    }

    #[test]
    fn test_result_type_err() {
        let result: Result<i32> = Err(CliError::InvalidFormat("test".to_string()));
        assert!(result.is_err());
    }

    // ---- exit-code uniqueness ---------------------------------------------

    #[test]
    fn test_all_exit_codes_are_distinct_per_category() {
        let codes = [
            (CliError::FileNotFound(PathBuf::from("a")).exit_code(), "file"),
            (CliError::InvalidFormat("a".to_string()).exit_code(), "format"),
            (
                CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "")).exit_code(),
                "io",
            ),
            (
                CliError::ValidationFailed("a".to_string()).exit_code(),
                "validation",
            ),
            (CliError::Aprender("a".to_string()).exit_code(), "aprender"),
            (
                CliError::ModelLoadFailed("a".to_string()).exit_code(),
                "model_load",
            ),
            (
                CliError::InferenceFailed("a".to_string()).exit_code(),
                "inference",
            ),
            (
                CliError::FeatureDisabled("a".to_string()).exit_code(),
                "feature",
            ),
            (CliError::NetworkError("a".to_string()).exit_code(), "network"),
            (
                CliError::HttpNotFound("a".to_string()).exit_code(),
                "http_not_found",
            ),
        ];

        // Every category must be distinguishable from every other by exit code alone.
        for (i, (code_a, name_a)) in codes.iter().enumerate() {
            for (code_b, name_b) in &codes[i + 1..] {
                assert_ne!(code_a, code_b, "`{name_a}` and `{name_b}` share an exit code");
            }
        }

        // FileNotFound and NotAFile deliberately form one category ("file").
        assert_eq!(
            CliError::NotAFile(PathBuf::from("a")).exit_code(),
            codes[0].0
        );
        assert_eq!(codes[0].0, ExitCode::from(3));
    }

    // ---- resolve_model_path -----------------------------------------------

    #[test]
    fn test_resolve_model_path_nonexistent() {
        let result = resolve_model_path(std::path::Path::new("/nonexistent/path/model.gguf"));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CliError::FileNotFound(_)));
    }

    #[test]
    fn test_resolve_model_path_regular_file() {
        let tmp = scratch_path("apr-test-resolve.safetensors");
        std::fs::write(&tmp, b"test").expect("write");
        let result = resolve_model_path(&tmp);
        assert!(result.is_ok());
        assert_eq!(result.expect("value"), tmp);
        std::fs::remove_file(&tmp).ok();
    }

    #[test]
    fn test_resolve_model_path_dir_with_safetensors() {
        let dir = scratch_path("apr-test-resolve-dir");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let model_file = dir.join("model.safetensors");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        assert!(result.is_ok());
        assert_eq!(result.expect("value"), model_file);
        std::fs::remove_file(&model_file).ok();
        std::fs::remove_dir(&dir).ok();
    }

    #[test]
    fn test_resolve_model_path_dir_with_gguf() {
        let dir = scratch_path("apr-test-resolve-gguf");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let model_file = dir.join("model-q4.gguf");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        assert!(result.is_ok());
        assert_eq!(result.expect("value"), model_file);
        std::fs::remove_file(&model_file).ok();
        std::fs::remove_dir(&dir).ok();
    }

    #[test]
    fn test_resolve_model_path_dir_with_sharded_safetensors() {
        let dir = scratch_path("apr-test-resolve-sharded");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let index_file = dir.join("model.safetensors.index.json");
        let shard_file = dir.join("model-00001-of-00002.safetensors");
        std::fs::write(&index_file, b"{}").expect("write index");
        std::fs::write(&shard_file, b"test").expect("write shard");
        let result = resolve_model_path(&dir);
        assert!(result.is_ok());
        assert_eq!(
            result.expect("value"),
            index_file,
            "index.json must take priority over shard files"
        );
        std::fs::remove_file(&shard_file).ok();
        std::fs::remove_file(&index_file).ok();
        std::fs::remove_dir(&dir).ok();
    }

    #[test]
    fn test_resolve_model_path_empty_dir() {
        let dir = scratch_path("apr-test-resolve-empty");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let result = resolve_model_path(&dir);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CliError::ValidationFailed(_)));
        std::fs::remove_dir(&dir).ok();
    }
}
461}