1use std::path::PathBuf;
6use std::process::ExitCode;
7use thiserror::Error;
8
9pub type Result<T> = std::result::Result<T, CliError>;
11
12#[derive(Error, Debug)]
14pub enum CliError {
15 #[error("File not found: {0}")]
17 FileNotFound(PathBuf),
18
19 #[error("Not a file: {0}")]
21 NotAFile(PathBuf),
22
23 #[error("Invalid APR format: {0}")]
25 InvalidFormat(String),
26
27 #[error("IO error: {0}")]
29 Io(#[from] std::io::Error),
30
31 #[error("Validation failed: {0}")]
33 ValidationFailed(String),
34
35 #[error("Aprender error: {0}")]
37 Aprender(String),
38
39 #[error("Model load failed: {0}")]
41 #[allow(dead_code)]
42 ModelLoadFailed(String),
43
44 #[error("Inference failed: {0}")]
46 #[allow(dead_code)]
47 InferenceFailed(String),
48
49 #[error("Feature not enabled: {0}")]
51 #[allow(dead_code)]
52 FeatureDisabled(String),
53
54 #[error("Network error: {0}")]
56 NetworkError(String),
57
58 #[error("HTTP 404 Not Found: {0}")]
60 HttpNotFound(String),
61}
62
63impl CliError {
64 pub fn exit_code(&self) -> ExitCode {
66 contract_pre_exit_code_semantics!();
67 match self {
68 Self::FileNotFound(_) | Self::NotAFile(_) => ExitCode::from(3),
69 Self::InvalidFormat(_) => ExitCode::from(4),
70 Self::Io(_) => ExitCode::from(7),
71 Self::ValidationFailed(_) => ExitCode::from(5),
72 Self::Aprender(_) => ExitCode::from(1),
73 Self::ModelLoadFailed(_) => ExitCode::from(6),
74 Self::InferenceFailed(_) => ExitCode::from(8),
75 Self::FeatureDisabled(_) => ExitCode::from(9),
76 Self::NetworkError(_) => ExitCode::from(10),
77 Self::HttpNotFound(_) => ExitCode::from(11),
78 }
79 }
80}
81
82impl From<aprender::error::AprenderError> for CliError {
83 fn from(e: aprender::error::AprenderError) -> Self {
84 Self::Aprender(e.to_string())
85 }
86}
87
88pub fn resolve_model_path(
94 path: &std::path::Path,
95) -> std::result::Result<std::path::PathBuf, CliError> {
96 if !path.exists() {
97 return Err(CliError::FileNotFound(path.to_path_buf()));
98 }
99 if path.is_file() {
100 return Ok(path.to_path_buf());
101 }
102 if path.is_dir() {
103 if let Some(parent) = path.parent() {
106 let depth = path.components().count();
107 if depth <= 2 {
109 return Err(CliError::NotAFile(path.to_path_buf()));
110 }
111 let _ = parent; }
113
114 let index = path.join("model.safetensors.index.json");
117 if index.is_file() {
118 return Ok(index);
119 }
120 let candidates = [
122 "model.safetensors",
123 "model-00001-of-00001.safetensors",
124 "model-00001-of-00002.safetensors",
125 "model-00001-of-00003.safetensors",
126 "model-00001-of-00004.safetensors",
127 ];
128 for candidate in &candidates {
129 let p = path.join(candidate);
130 if p.is_file() {
131 return Ok(p);
132 }
133 }
134 if let Ok(entries) = std::fs::read_dir(path) {
136 for entry in entries.flatten() {
137 let p = entry.path();
138 let is_temp = p
140 .file_name()
141 .is_some_and(|n| n.to_string_lossy().starts_with("rosetta_temp"));
142 if !is_temp && p.extension().is_some_and(|ext| ext == "gguf") && p.is_file() {
143 return Ok(p);
144 }
145 }
146 }
147 if let Ok(entries) = std::fs::read_dir(path) {
149 for entry in entries.flatten() {
150 let p = entry.path();
151 let is_temp = p
152 .file_name()
153 .is_some_and(|n| n.to_string_lossy().starts_with("rosetta_temp"));
154 if !is_temp && p.extension().is_some_and(|ext| ext == "apr") && p.is_file() {
155 return Ok(p);
156 }
157 }
158 }
159 Err(CliError::ValidationFailed(format!(
160 "Directory {} does not contain a model file (expected model.safetensors, *.gguf, or *.apr)",
161 path.display()
162 )))
163 } else {
164 Err(CliError::NotAFile(path.to_path_buf()))
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Per-process scratch path under the system temp dir, so concurrent test
    /// runs on one machine do not trample each other's fixtures.
    fn scratch_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}-{name}", std::process::id()))
    }

    // ----- exit codes -----

    #[test]
    fn test_file_not_found_exit_code() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_not_a_file_exit_code() {
        let err = CliError::NotAFile(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_invalid_format_exit_code() {
        let err = CliError::InvalidFormat("bad".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(4));
    }

    #[test]
    fn test_io_error_exit_code() {
        let err = CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "test"));
        assert_eq!(err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_validation_failed_exit_code() {
        let err = CliError::ValidationFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(5));
    }

    #[test]
    fn test_aprender_error_exit_code() {
        let err = CliError::Aprender("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(1));
    }

    #[test]
    fn test_model_load_failed_exit_code() {
        let err = CliError::ModelLoadFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(6));
    }

    #[test]
    fn test_inference_failed_exit_code() {
        let err = CliError::InferenceFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(8));
    }

    #[test]
    fn test_feature_disabled_exit_code() {
        let err = CliError::FeatureDisabled("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(9));
    }

    #[test]
    fn test_network_error_exit_code() {
        let err = CliError::NetworkError("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(10));
    }

    #[test]
    fn test_http_not_found_exit_code() {
        let err = CliError::HttpNotFound("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(11));
    }

    // ----- Display messages -----

    #[test]
    fn test_file_not_found_display() {
        let err = CliError::FileNotFound(PathBuf::from("/model.apr"));
        assert_eq!(err.to_string(), "File not found: /model.apr");
    }

    #[test]
    fn test_not_a_file_display() {
        let err = CliError::NotAFile(PathBuf::from("/dir"));
        assert_eq!(err.to_string(), "Not a file: /dir");
    }

    #[test]
    fn test_invalid_format_display() {
        let err = CliError::InvalidFormat("bad magic".to_string());
        assert_eq!(err.to_string(), "Invalid APR format: bad magic");
    }

    #[test]
    fn test_validation_failed_display() {
        let err = CliError::ValidationFailed("missing field".to_string());
        assert_eq!(err.to_string(), "Validation failed: missing field");
    }

    #[test]
    fn test_aprender_error_display() {
        let err = CliError::Aprender("internal".to_string());
        assert_eq!(err.to_string(), "Aprender error: internal");
    }

    #[test]
    fn test_model_load_failed_display() {
        let err = CliError::ModelLoadFailed("corrupt".to_string());
        assert_eq!(err.to_string(), "Model load failed: corrupt");
    }

    #[test]
    fn test_inference_failed_display() {
        let err = CliError::InferenceFailed("OOM".to_string());
        assert_eq!(err.to_string(), "Inference failed: OOM");
    }

    #[test]
    fn test_feature_disabled_display() {
        let err = CliError::FeatureDisabled("cuda".to_string());
        assert_eq!(err.to_string(), "Feature not enabled: cuda");
    }

    #[test]
    fn test_network_error_display() {
        let err = CliError::NetworkError("timeout".to_string());
        assert_eq!(err.to_string(), "Network error: timeout");
    }

    #[test]
    fn test_http_not_found_display() {
        let err = CliError::HttpNotFound("tokenizer.json".to_string());
        assert_eq!(err.to_string(), "HTTP 404 Not Found: tokenizer.json");
    }

    // ----- conversions and trait impls -----

    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let cli_err: CliError = io_err.into();
        assert!(cli_err.to_string().contains("file missing"));
        assert_eq!(cli_err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_debug_impl() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        let debug = format!("{:?}", err);
        assert!(debug.contains("FileNotFound"));
    }

    // ----- Result alias -----

    #[test]
    fn test_result_type_ok() {
        let result: Result<i32> = Ok(42);
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_result_type_err() {
        let result: Result<i32> = Err(CliError::InvalidFormat("test".to_string()));
        assert!(result.is_err());
    }

    // ----- exit-code uniqueness -----

    #[test]
    fn test_all_exit_codes_are_distinct_per_category() {
        // One representative per category; FileNotFound/NotAFile deliberately
        // share a code, so only FileNotFound is listed.
        let codes = [
            (CliError::FileNotFound(PathBuf::from("a")).exit_code(), "file"),
            (CliError::InvalidFormat("a".to_string()).exit_code(), "format"),
            (
                CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "")).exit_code(),
                "io",
            ),
            (CliError::ValidationFailed("a".to_string()).exit_code(), "validation"),
            (CliError::Aprender("a".to_string()).exit_code(), "aprender"),
            (CliError::ModelLoadFailed("a".to_string()).exit_code(), "model_load"),
            (CliError::InferenceFailed("a".to_string()).exit_code(), "inference"),
            (CliError::FeatureDisabled("a".to_string()).exit_code(), "feature"),
            (CliError::NetworkError("a".to_string()).exit_code(), "network"),
            (CliError::HttpNotFound("a".to_string()).exit_code(), "http_not_found"),
        ];
        assert_eq!(codes[0].0, ExitCode::from(3));
        // The property the test name promises: no two categories share a status.
        for (i, (code_a, name_a)) in codes.iter().enumerate() {
            for (code_b, name_b) in &codes[i + 1..] {
                assert_ne!(code_a, code_b, "{name_a} and {name_b} share an exit code");
            }
        }
    }

    // ----- resolve_model_path -----

    #[test]
    fn test_resolve_model_path_nonexistent() {
        let result = resolve_model_path(std::path::Path::new("/nonexistent/path/model.gguf"));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CliError::FileNotFound(_)));
    }

    #[test]
    fn test_resolve_model_path_regular_file() {
        let tmp = scratch_path("apr-test-resolve.safetensors");
        std::fs::write(&tmp, b"test").expect("write");
        let result = resolve_model_path(&tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tmp);
        std::fs::remove_file(&tmp).ok();
    }

    #[test]
    fn test_resolve_model_path_dir_with_safetensors() {
        let dir = scratch_path("apr-test-resolve-dir");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let model_file = dir.join("model.safetensors");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_file);
        std::fs::remove_file(&model_file).ok();
        std::fs::remove_dir(&dir).ok();
    }

    #[test]
    fn test_resolve_model_path_dir_with_gguf() {
        let dir = scratch_path("apr-test-resolve-gguf");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let model_file = dir.join("model-q4.gguf");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_file);
        std::fs::remove_file(&model_file).ok();
        std::fs::remove_dir(&dir).ok();
    }

    #[test]
    fn test_resolve_model_path_dir_with_sharded_safetensors() {
        let dir = scratch_path("apr-test-resolve-sharded");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let index_file = dir.join("model.safetensors.index.json");
        let shard_file = dir.join("model-00001-of-00002.safetensors");
        std::fs::write(&index_file, b"{}").expect("write index");
        std::fs::write(&shard_file, b"test").expect("write shard");
        let result = resolve_model_path(&dir);
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            index_file,
            "index.json must take priority over shard files"
        );
        std::fs::remove_file(&shard_file).ok();
        std::fs::remove_file(&index_file).ok();
        std::fs::remove_dir(&dir).ok();
    }

    #[test]
    fn test_resolve_model_path_empty_dir() {
        let dir = scratch_path("apr-test-resolve-empty");
        std::fs::create_dir_all(&dir).expect("mkdir");
        let result = resolve_model_path(&dir);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CliError::ValidationFailed(_)));
        std::fs::remove_dir(&dir).ok();
    }
}