Skip to main content

entrenar/cli/commands/
inspect.rs

1//! Inspect command implementation
2
3use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{InspectArgs, InspectMode};
6use std::path::Path;
7
8/// Inspect a SafeTensors model file
9fn inspect_safetensors(path: &Path, level: LogLevel) -> Result<(), String> {
10    use safetensors::SafeTensors;
11
12    let data = std::fs::read(path).map_err(|e| format!("Failed to read file: {e}"))?;
13
14    let tensors =
15        SafeTensors::deserialize(&data).map_err(|e| format!("Failed to parse SafeTensors: {e}"))?;
16
17    let tensor_names: Vec<String> = tensors.names().iter().map(|s| (*s).to_string()).collect();
18    let total_params = count_total_parameters(&tensors, &tensor_names);
19    let file_size = data.len();
20
21    log_model_info(level, file_size, total_params, tensor_names.len());
22
23    if level == LogLevel::Verbose {
24        log_tensor_details(level, &tensors, &tensor_names);
25    }
26
27    Ok(())
28}
29
30/// Count total parameters across all tensors
31fn count_total_parameters(tensors: &safetensors::SafeTensors<'_>, names: &[String]) -> u64 {
32    let mut total: u64 = 0;
33    for name in names {
34        if let Ok(tensor) = tensors.tensor(name) {
35            let params: u64 = tensor.shape().iter().product::<usize>() as u64;
36            total += params;
37        }
38    }
39    total
40}
41
42/// Log basic model information
43fn log_model_info(level: LogLevel, file_size: usize, total_params: u64, tensor_count: usize) {
44    log(level, LogLevel::Normal, "Model Information:");
45    log(level, LogLevel::Normal, &format!("  File size: {:.2} MB", file_size as f64 / 1_000_000.0));
46    log(level, LogLevel::Normal, &format!("  Parameters: {:.2}B", total_params as f64 / 1e9));
47    log(level, LogLevel::Normal, &format!("  Tensors: {tensor_count}"));
48}
49
50/// Log detailed tensor information
51fn log_tensor_details(level: LogLevel, tensors: &safetensors::SafeTensors<'_>, names: &[String]) {
52    log(level, LogLevel::Verbose, "\nTensor Details:");
53    for name in &names[..names.len().min(20)] {
54        if let Ok(tensor) = tensors.tensor(name) {
55            log(
56                level,
57                LogLevel::Verbose,
58                &format!("  {}: {:?} ({:?})", name, tensor.shape(), tensor.dtype()),
59            );
60        }
61    }
62    if names.len() > 20 {
63        log(level, LogLevel::Verbose, &format!("  ... and {} more tensors", names.len() - 20));
64    }
65}
66
67/// Inspect a GGUF model file
68fn inspect_gguf(path: &Path, level: LogLevel) -> Result<(), String> {
69    let metadata = std::fs::metadata(path).map_err(|e| format!("Failed to read metadata: {e}"))?;
70
71    log(level, LogLevel::Normal, "GGUF Model Information:");
72    log(
73        level,
74        LogLevel::Normal,
75        &format!("  File size: {:.2} MB", metadata.len() as f64 / 1_000_000.0),
76    );
77    log(level, LogLevel::Normal, "  Format: GGUF (llama.cpp compatible)");
78    log(level, LogLevel::Normal, "  (Use llama.cpp for detailed GGUF inspection)");
79
80    Ok(())
81}
82
83/// Inspect a data file (parquet or csv)
84fn inspect_data_file(
85    path: &Path,
86    ext: &str,
87    mode: InspectMode,
88    z_threshold: f32,
89    level: LogLevel,
90) -> Result<(), String> {
91    let metadata = std::fs::metadata(path).map_err(|e| format!("Failed to read metadata: {e}"))?;
92
93    match mode {
94        InspectMode::Summary => {
95            log(level, LogLevel::Normal, "Data Summary:");
96            log(
97                level,
98                LogLevel::Normal,
99                &format!("  File size: {:.2} MB", metadata.len() as f64 / 1_000_000.0),
100            );
101            log(level, LogLevel::Normal, &format!("  Format: {ext}"));
102        }
103        InspectMode::Outliers => {
104            log(
105                level,
106                LogLevel::Normal,
107                &format!("Outlier Detection (z-threshold: {z_threshold}):"),
108            );
109            log(level, LogLevel::Normal, "  Load data with alimentar for outlier analysis");
110        }
111        InspectMode::Distribution => {
112            log(level, LogLevel::Normal, "Distribution Statistics:");
113            log(level, LogLevel::Normal, "  Load data with alimentar for distribution analysis");
114        }
115        InspectMode::Schema => {
116            log(level, LogLevel::Normal, "Schema:");
117            log(level, LogLevel::Normal, &format!("  Format: {ext}"));
118        }
119    }
120
121    Ok(())
122}
123
/// Get the file extension of `path` as a string slice.
///
/// The extension is returned exactly as it appears in the path: its case is
/// NOT normalised, so `MODEL.GGUF` yields `"GGUF"`. Callers that need
/// case-insensitive matching must lowercase the result themselves.
///
/// Returns `""` when the path has no extension (including dotfiles such as
/// `.hidden`) or when the extension is not valid UTF-8.
fn get_extension(path: &Path) -> &str {
    path.extension().and_then(std::ffi::OsStr::to_str).unwrap_or("")
}
128
129/// Inspect a LoRA adapter directory (ENT-LoRA-018)
130fn inspect_lora_adapter(dir: &Path, level: LogLevel) -> Result<(), String> {
131    let config_path = dir.join("adapter_config.json");
132    let adapter_path = dir.join("adapter_model.safetensors");
133
134    log(level, LogLevel::Normal, "LoRA Adapter:");
135    log(level, LogLevel::Normal, &format!("  Directory: {}", dir.display()));
136
137    // Read adapter_config.json
138    if config_path.exists() {
139        let config_str =
140            std::fs::read_to_string(&config_path).map_err(|e| format!("Read config: {e}"))?;
141        if let Ok(config) = serde_json::from_str::<serde_json::Value>(&config_str) {
142            if let Some(rank) = config.get("r").and_then(serde_json::Value::as_u64) {
143                log(level, LogLevel::Normal, &format!("  Rank: {rank}"));
144            }
145            if let Some(alpha) = config.get("lora_alpha").and_then(serde_json::Value::as_f64) {
146                log(level, LogLevel::Normal, &format!("  Alpha: {alpha}"));
147            }
148            if let Some(modules) =
149                config.get("target_modules").and_then(serde_json::Value::as_array)
150            {
151                let names: Vec<&str> =
152                    modules.iter().filter_map(serde_json::Value::as_str).collect();
153                log(level, LogLevel::Normal, &format!("  Target modules: {}", names.join(", ")));
154            }
155            if let Some(base) =
156                config.get("base_model_name_or_path").and_then(serde_json::Value::as_str)
157            {
158                log(level, LogLevel::Normal, &format!("  Base model: {base}"));
159            }
160        }
161    }
162
163    // Read adapter_model.safetensors
164    if adapter_path.exists() {
165        let size = std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0);
166        log(level, LogLevel::Normal, &format!("  Adapter size: {:.2} MB", size as f64 / 1e6));
167
168        let data = std::fs::read(&adapter_path).map_err(|e| format!("Read adapter: {e}"))?;
169        if let Ok(tensors) = safetensors::SafeTensors::deserialize(&data) {
170            let names: Vec<String> = tensors.names().iter().map(|s| (*s).to_string()).collect();
171            log(level, LogLevel::Normal, &format!("  Adapter tensors: {}", names.len()));
172            let total_params: u64 = names
173                .iter()
174                .filter_map(|n| tensors.tensor(n).ok())
175                .map(|t| t.shape().iter().product::<usize>() as u64)
176                .sum();
177            log(level, LogLevel::Normal, &format!("  Trainable params: {total_params}"));
178        }
179    } else {
180        log(level, LogLevel::Normal, "  (no adapter_model.safetensors found)");
181    }
182
183    Ok(())
184}
185
186pub fn run_inspect(args: InspectArgs, level: LogLevel) -> Result<(), String> {
187    log(level, LogLevel::Normal, &format!("Inspecting: {}", args.input.display()));
188
189    if !args.input.exists() {
190        return Err(format!("File not found: {}", args.input.display()));
191    }
192
193    // ENT-LoRA-018: Check if this is a LoRA adapter directory
194    if args.input.is_dir() && args.input.join("adapter_config.json").exists() {
195        return inspect_lora_adapter(&args.input, level);
196    }
197
198    let ext = get_extension(&args.input);
199    log(level, LogLevel::Normal, &format!("  Mode: {}", args.mode));
200
201    match ext {
202        "safetensors" => inspect_safetensors(&args.input, level),
203        "gguf" => inspect_gguf(&args.input, level),
204        "parquet" | "csv" => {
205            inspect_data_file(&args.input, ext, args.mode, args.z_threshold, level)
206        }
207        _ => {
208            if args.input.is_dir() {
209                Err(format!(
210                    "Directory {} does not contain adapter_config.json",
211                    args.input.display()
212                ))
213            } else {
214                Err(format!(
215                    "Unsupported file format: {ext}. Use .safetensors, .gguf, .parquet, or .csv"
216                ))
217            }
218        }
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Path of a scratch entry with the given name inside the system temp dir.
    fn scratch(name: &str) -> PathBuf {
        std::env::temp_dir().join(name)
    }

    /// Build `InspectArgs` for `input` with no column selection.
    fn args_for(input: PathBuf, mode: InspectMode, z_threshold: f32) -> InspectArgs {
        InspectArgs { input, mode, columns: None, z_threshold }
    }

    /// Serialise one F32 tensor of the given shape backed by `raw`, parse it
    /// back, and return what `count_total_parameters` reports for it.
    fn count_params_of(name: &str, shape: Vec<usize>, raw: &[u8]) -> u64 {
        use safetensors::serialize;
        use safetensors::tensor::TensorView;
        use safetensors::Dtype;

        let view = TensorView::new(Dtype::F32, shape, raw).unwrap();
        let bytes = serialize(vec![(name, view)], None).unwrap();
        let parsed = safetensors::SafeTensors::deserialize(&bytes).unwrap();
        let names: Vec<String> = parsed.names().iter().map(|s| (*s).to_string()).collect();
        count_total_parameters(&parsed, &names)
    }

    // --- get_extension ---------------------------------------------------

    #[test]
    fn test_get_extension_safetensors() {
        assert_eq!(get_extension(&PathBuf::from("model.safetensors")), "safetensors");
    }

    #[test]
    fn test_get_extension_gguf() {
        assert_eq!(get_extension(&PathBuf::from("model.gguf")), "gguf");
    }

    #[test]
    fn test_get_extension_parquet() {
        assert_eq!(get_extension(&PathBuf::from("data.parquet")), "parquet");
    }

    #[test]
    fn test_get_extension_csv() {
        assert_eq!(get_extension(&PathBuf::from("data.csv")), "csv");
    }

    #[test]
    fn test_get_extension_none() {
        assert_eq!(get_extension(&PathBuf::from("noextension")), "");
    }

    // --- run_inspect: error paths ------------------------------------------

    #[test]
    fn test_run_inspect_file_not_found() {
        let missing = PathBuf::from("/nonexistent/path/model.safetensors");
        let outcome = run_inspect(args_for(missing, InspectMode::Summary, 3.0), LogLevel::Normal);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("File not found"));
    }

    #[test]
    fn test_run_inspect_unsupported_format() {
        let file = scratch("test_inspect.xyz");
        std::fs::write(&file, "test").expect("file write should succeed");

        let outcome =
            run_inspect(args_for(file.clone(), InspectMode::Summary, 3.0), LogLevel::Normal);

        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("Unsupported file format"));
    }

    // --- inspect_data_file: one test per mode ------------------------------

    #[test]
    fn test_inspect_data_file_summary() {
        let file = scratch("inspect_sum.csv");
        std::fs::write(&file, "a,b\n1,2").expect("write csv");
        let outcome = inspect_data_file(&file, "csv", InspectMode::Summary, 3.0, LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_inspect_data_file_outliers() {
        let file = scratch("inspect_out.csv");
        std::fs::write(&file, "a,b\n1,2").expect("write csv");
        let outcome = inspect_data_file(&file, "csv", InspectMode::Outliers, 2.5, LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_inspect_data_file_distribution() {
        let file = scratch("inspect_dist.csv");
        std::fs::write(&file, "a,b\n1,2").expect("write csv");
        let outcome =
            inspect_data_file(&file, "csv", InspectMode::Distribution, 3.0, LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_inspect_data_file_schema() {
        let file = scratch("inspect_sch.parquet");
        std::fs::write(&file, "fake parquet").expect("write file");
        let outcome =
            inspect_data_file(&file, "parquet", InspectMode::Schema, 3.0, LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    // --- inspect_gguf --------------------------------------------------------

    #[test]
    fn test_inspect_gguf() {
        let file = scratch("test_model.gguf");
        std::fs::write(&file, "GGUF fake data 12345678").expect("write gguf");
        let outcome = inspect_gguf(&file, LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    // --- count_total_parameters ----------------------------------------------

    #[test]
    fn test_count_total_parameters_single_scalar() {
        // Shape [1] holds a single f32, i.e. 4 bytes and 1 parameter.
        let raw = [0u8; 4];
        assert_eq!(count_params_of("scalar", vec![1], &raw), 1);
    }

    #[test]
    fn test_count_total_parameters_with_data() {
        // Shape [2, 3] holds six f32 values, i.e. 24 bytes and 6 parameters.
        let raw = [0u8; 24];
        assert_eq!(count_params_of("w", vec![2, 3], &raw), 6);
    }

    // --- log_model_info: smoke tests -------------------------------------------

    #[test]
    fn test_log_model_info() {
        log_model_info(LogLevel::Normal, 1_000_000, 500_000_000, 100);
    }

    #[test]
    fn test_log_model_info_verbose() {
        log_model_info(LogLevel::Verbose, 0, 0, 0);
    }

    // --- run_inspect: dispatch by file type and mode ---------------------------

    #[test]
    fn test_run_inspect_csv_file() {
        let file = scratch("ri_test.csv");
        std::fs::write(&file, "col1,col2\nval1,val2").expect("write csv");
        let outcome =
            run_inspect(args_for(file.clone(), InspectMode::Summary, 3.0), LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_inspect_gguf_file() {
        let file = scratch("ri_test.gguf");
        std::fs::write(&file, "GGUF fake").expect("write gguf");
        let outcome =
            run_inspect(args_for(file.clone(), InspectMode::Summary, 3.0), LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_inspect_directory_no_adapter() {
        let dir = scratch("ri_dir_test");
        let _ = std::fs::create_dir_all(&dir);
        let outcome =
            run_inspect(args_for(dir.clone(), InspectMode::Summary, 3.0), LogLevel::Normal);
        let _ = std::fs::remove_dir_all(&dir);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("adapter_config.json"));
    }

    #[test]
    fn test_run_inspect_parquet_distribution() {
        let file = scratch("ri_test.parquet");
        std::fs::write(&file, "fake parquet data").expect("write");
        let outcome =
            run_inspect(args_for(file.clone(), InspectMode::Distribution, 3.0), LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_inspect_outlier_mode() {
        let file = scratch("ri_outlier.csv");
        std::fs::write(&file, "x\n1\n2\n3").expect("write");
        let outcome =
            run_inspect(args_for(file.clone(), InspectMode::Outliers, 2.0), LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_inspect_schema_mode() {
        let file = scratch("ri_schema.csv");
        std::fs::write(&file, "a,b\n1,2").expect("write");
        let outcome =
            run_inspect(args_for(file.clone(), InspectMode::Schema, 3.0), LogLevel::Normal);
        let _ = std::fs::remove_file(&file);
        assert!(outcome.is_ok());
    }

    // --- LoRA adapter directories ------------------------------------------------

    #[test]
    fn test_inspect_lora_adapter_empty_dir() {
        let dir = scratch("ri_lora_empty");
        let _ = std::fs::create_dir_all(&dir);
        std::fs::write(dir.join("adapter_config.json"), "{}").expect("write config");
        let outcome = inspect_lora_adapter(&dir, LogLevel::Normal);
        let _ = std::fs::remove_dir_all(&dir);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_inspect_lora_adapter_with_config() {
        let dir = scratch("ri_lora_cfg");
        let _ = std::fs::create_dir_all(&dir);
        let config = serde_json::json!({
            "r": 16,
            "lora_alpha": 32.0,
            "target_modules": ["q_proj", "v_proj"],
            "base_model_name_or_path": "test/model"
        });
        std::fs::write(dir.join("adapter_config.json"), config.to_string()).expect("write");
        let outcome = inspect_lora_adapter(&dir, LogLevel::Normal);
        let _ = std::fs::remove_dir_all(&dir);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_inspect_lora_adapter_dir() {
        let dir = scratch("ri_lora_run");
        let _ = std::fs::create_dir_all(&dir);
        std::fs::write(dir.join("adapter_config.json"), "{}").expect("write");
        let outcome =
            run_inspect(args_for(dir.clone(), InspectMode::Summary, 3.0), LogLevel::Normal);
        let _ = std::fs::remove_dir_all(&dir);
        assert!(outcome.is_ok());
    }

    // --- get_extension: edge cases -------------------------------------------------

    #[test]
    fn test_get_extension_compound() {
        assert_eq!(get_extension(&PathBuf::from("model.v2.safetensors")), "safetensors");
    }

    #[test]
    fn test_get_extension_dotfile() {
        assert_eq!(get_extension(&PathBuf::from(".hidden")), "");
    }
}