1use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{InspectArgs, InspectMode};
6use std::path::Path;
7
8fn inspect_safetensors(path: &Path, level: LogLevel) -> Result<(), String> {
10 use safetensors::SafeTensors;
11
12 let data = std::fs::read(path).map_err(|e| format!("Failed to read file: {e}"))?;
13
14 let tensors =
15 SafeTensors::deserialize(&data).map_err(|e| format!("Failed to parse SafeTensors: {e}"))?;
16
17 let tensor_names: Vec<String> = tensors.names().iter().map(|s| (*s).to_string()).collect();
18 let total_params = count_total_parameters(&tensors, &tensor_names);
19 let file_size = data.len();
20
21 log_model_info(level, file_size, total_params, tensor_names.len());
22
23 if level == LogLevel::Verbose {
24 log_tensor_details(level, &tensors, &tensor_names);
25 }
26
27 Ok(())
28}
29
30fn count_total_parameters(tensors: &safetensors::SafeTensors<'_>, names: &[String]) -> u64 {
32 let mut total: u64 = 0;
33 for name in names {
34 if let Ok(tensor) = tensors.tensor(name) {
35 let params: u64 = tensor.shape().iter().product::<usize>() as u64;
36 total += params;
37 }
38 }
39 total
40}
41
42fn log_model_info(level: LogLevel, file_size: usize, total_params: u64, tensor_count: usize) {
44 log(level, LogLevel::Normal, "Model Information:");
45 log(level, LogLevel::Normal, &format!(" File size: {:.2} MB", file_size as f64 / 1_000_000.0));
46 log(level, LogLevel::Normal, &format!(" Parameters: {:.2}B", total_params as f64 / 1e9));
47 log(level, LogLevel::Normal, &format!(" Tensors: {tensor_count}"));
48}
49
50fn log_tensor_details(level: LogLevel, tensors: &safetensors::SafeTensors<'_>, names: &[String]) {
52 log(level, LogLevel::Verbose, "\nTensor Details:");
53 for name in &names[..names.len().min(20)] {
54 if let Ok(tensor) = tensors.tensor(name) {
55 log(
56 level,
57 LogLevel::Verbose,
58 &format!(" {}: {:?} ({:?})", name, tensor.shape(), tensor.dtype()),
59 );
60 }
61 }
62 if names.len() > 20 {
63 log(level, LogLevel::Verbose, &format!(" ... and {} more tensors", names.len() - 20));
64 }
65}
66
67fn inspect_gguf(path: &Path, level: LogLevel) -> Result<(), String> {
69 let metadata = std::fs::metadata(path).map_err(|e| format!("Failed to read metadata: {e}"))?;
70
71 log(level, LogLevel::Normal, "GGUF Model Information:");
72 log(
73 level,
74 LogLevel::Normal,
75 &format!(" File size: {:.2} MB", metadata.len() as f64 / 1_000_000.0),
76 );
77 log(level, LogLevel::Normal, " Format: GGUF (llama.cpp compatible)");
78 log(level, LogLevel::Normal, " (Use llama.cpp for detailed GGUF inspection)");
79
80 Ok(())
81}
82
83fn inspect_data_file(
85 path: &Path,
86 ext: &str,
87 mode: InspectMode,
88 z_threshold: f32,
89 level: LogLevel,
90) -> Result<(), String> {
91 let metadata = std::fs::metadata(path).map_err(|e| format!("Failed to read metadata: {e}"))?;
92
93 match mode {
94 InspectMode::Summary => {
95 log(level, LogLevel::Normal, "Data Summary:");
96 log(
97 level,
98 LogLevel::Normal,
99 &format!(" File size: {:.2} MB", metadata.len() as f64 / 1_000_000.0),
100 );
101 log(level, LogLevel::Normal, &format!(" Format: {ext}"));
102 }
103 InspectMode::Outliers => {
104 log(
105 level,
106 LogLevel::Normal,
107 &format!("Outlier Detection (z-threshold: {z_threshold}):"),
108 );
109 log(level, LogLevel::Normal, " Load data with alimentar for outlier analysis");
110 }
111 InspectMode::Distribution => {
112 log(level, LogLevel::Normal, "Distribution Statistics:");
113 log(level, LogLevel::Normal, " Load data with alimentar for distribution analysis");
114 }
115 InspectMode::Schema => {
116 log(level, LogLevel::Normal, "Schema:");
117 log(level, LogLevel::Normal, &format!(" Format: {ext}"));
118 }
119 }
120
121 Ok(())
122}
123
/// Returns the final extension of `path` (the text after the last `.` of the file name).
///
/// Returns `""` when the path has no extension (including dotfiles such as `.hidden`) or
/// when the extension is not valid UTF-8.
fn get_extension(path: &Path) -> &str {
    match path.extension() {
        Some(os_ext) => os_ext.to_str().unwrap_or(""),
        None => "",
    }
}
128
129fn inspect_lora_adapter(dir: &Path, level: LogLevel) -> Result<(), String> {
131 let config_path = dir.join("adapter_config.json");
132 let adapter_path = dir.join("adapter_model.safetensors");
133
134 log(level, LogLevel::Normal, "LoRA Adapter:");
135 log(level, LogLevel::Normal, &format!(" Directory: {}", dir.display()));
136
137 if config_path.exists() {
139 let config_str =
140 std::fs::read_to_string(&config_path).map_err(|e| format!("Read config: {e}"))?;
141 if let Ok(config) = serde_json::from_str::<serde_json::Value>(&config_str) {
142 if let Some(rank) = config.get("r").and_then(serde_json::Value::as_u64) {
143 log(level, LogLevel::Normal, &format!(" Rank: {rank}"));
144 }
145 if let Some(alpha) = config.get("lora_alpha").and_then(serde_json::Value::as_f64) {
146 log(level, LogLevel::Normal, &format!(" Alpha: {alpha}"));
147 }
148 if let Some(modules) =
149 config.get("target_modules").and_then(serde_json::Value::as_array)
150 {
151 let names: Vec<&str> =
152 modules.iter().filter_map(serde_json::Value::as_str).collect();
153 log(level, LogLevel::Normal, &format!(" Target modules: {}", names.join(", ")));
154 }
155 if let Some(base) =
156 config.get("base_model_name_or_path").and_then(serde_json::Value::as_str)
157 {
158 log(level, LogLevel::Normal, &format!(" Base model: {base}"));
159 }
160 }
161 }
162
163 if adapter_path.exists() {
165 let size = std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0);
166 log(level, LogLevel::Normal, &format!(" Adapter size: {:.2} MB", size as f64 / 1e6));
167
168 let data = std::fs::read(&adapter_path).map_err(|e| format!("Read adapter: {e}"))?;
169 if let Ok(tensors) = safetensors::SafeTensors::deserialize(&data) {
170 let names: Vec<String> = tensors.names().iter().map(|s| (*s).to_string()).collect();
171 log(level, LogLevel::Normal, &format!(" Adapter tensors: {}", names.len()));
172 let total_params: u64 = names
173 .iter()
174 .filter_map(|n| tensors.tensor(n).ok())
175 .map(|t| t.shape().iter().product::<usize>() as u64)
176 .sum();
177 log(level, LogLevel::Normal, &format!(" Trainable params: {total_params}"));
178 }
179 } else {
180 log(level, LogLevel::Normal, " (no adapter_model.safetensors found)");
181 }
182
183 Ok(())
184}
185
186pub fn run_inspect(args: InspectArgs, level: LogLevel) -> Result<(), String> {
187 log(level, LogLevel::Normal, &format!("Inspecting: {}", args.input.display()));
188
189 if !args.input.exists() {
190 return Err(format!("File not found: {}", args.input.display()));
191 }
192
193 if args.input.is_dir() && args.input.join("adapter_config.json").exists() {
195 return inspect_lora_adapter(&args.input, level);
196 }
197
198 let ext = get_extension(&args.input);
199 log(level, LogLevel::Normal, &format!(" Mode: {}", args.mode));
200
201 match ext {
202 "safetensors" => inspect_safetensors(&args.input, level),
203 "gguf" => inspect_gguf(&args.input, level),
204 "parquet" | "csv" => {
205 inspect_data_file(&args.input, ext, args.mode, args.z_threshold, level)
206 }
207 _ => {
208 if args.input.is_dir() {
209 Err(format!(
210 "Directory {} does not contain adapter_config.json",
211 args.input.display()
212 ))
213 } else {
214 Err(format!(
215 "Unsupported file format: {ext}. Use .safetensors, .gguf, .parquet, or .csv"
216 ))
217 }
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// A scratch path in the system temp dir that is removed when dropped.
    ///
    /// The name is prefixed with this process's id so concurrent test runs (or other
    /// crates using generic names such as `test_model.gguf`) cannot clobber each other's
    /// fixtures. `Drop` cleans up even when an assertion or `expect` panics mid-test.
    struct TempPath(PathBuf);

    impl TempPath {
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!("inspect_{}_{name}", std::process::id())))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            // Best-effort cleanup: the path may never have been created.
            if self.0.is_dir() {
                let _ = std::fs::remove_dir_all(&self.0);
            } else {
                let _ = std::fs::remove_file(&self.0);
            }
        }
    }

    /// Builds `InspectArgs` for `input` with no column filter.
    fn inspect_args(input: &Path, mode: InspectMode, z_threshold: f32) -> InspectArgs {
        InspectArgs { input: input.to_path_buf(), mode, columns: None, z_threshold }
    }

    /// Serializes a single zero-filled F32 tensor named `name` and returns the file bytes.
    fn single_f32_tensor(name: &str, shape: Vec<usize>) -> Vec<u8> {
        use safetensors::serialize;
        use safetensors::tensor::TensorView;
        use safetensors::Dtype;

        let data = vec![0u8; shape.iter().product::<usize>() * 4];
        let view = TensorView::new(Dtype::F32, shape, &data).unwrap();
        serialize(vec![(name, view)], None).unwrap()
    }

    /// Deserializes `bytes` and counts parameters over every tensor name it contains.
    fn count_from_bytes(bytes: &[u8]) -> u64 {
        let st = safetensors::SafeTensors::deserialize(bytes).unwrap();
        let names: Vec<String> = st.names().iter().map(|s| (*s).to_string()).collect();
        count_total_parameters(&st, &names)
    }

    #[test]
    fn test_get_extension_safetensors() {
        assert_eq!(get_extension(Path::new("model.safetensors")), "safetensors");
    }

    #[test]
    fn test_get_extension_gguf() {
        assert_eq!(get_extension(Path::new("model.gguf")), "gguf");
    }

    #[test]
    fn test_get_extension_parquet() {
        assert_eq!(get_extension(Path::new("data.parquet")), "parquet");
    }

    #[test]
    fn test_get_extension_csv() {
        assert_eq!(get_extension(Path::new("data.csv")), "csv");
    }

    #[test]
    fn test_get_extension_none() {
        assert_eq!(get_extension(Path::new("noextension")), "");
    }

    #[test]
    fn test_get_extension_compound() {
        assert_eq!(get_extension(Path::new("model.v2.safetensors")), "safetensors");
    }

    #[test]
    fn test_get_extension_dotfile() {
        assert_eq!(get_extension(Path::new(".hidden")), "");
    }

    #[test]
    fn test_run_inspect_file_not_found() {
        let missing = Path::new("/nonexistent/path/model.safetensors");
        let err = run_inspect(inspect_args(missing, InspectMode::Summary, 3.0), LogLevel::Normal)
            .expect_err("missing input must be rejected");
        assert!(err.contains("File not found"));
    }

    #[test]
    fn test_run_inspect_unsupported_format() {
        let tmp = TempPath::new("test_inspect.xyz");
        std::fs::write(tmp.path(), "test").expect("file write should succeed");
        let err =
            run_inspect(inspect_args(tmp.path(), InspectMode::Summary, 3.0), LogLevel::Normal)
                .expect_err("unknown extension must be rejected");
        assert!(err.contains("Unsupported file format"));
    }

    #[test]
    fn test_inspect_data_file_summary() {
        let tmp = TempPath::new("inspect_sum.csv");
        std::fs::write(tmp.path(), "a,b\n1,2").expect("write csv");
        let r = inspect_data_file(tmp.path(), "csv", InspectMode::Summary, 3.0, LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_inspect_data_file_outliers() {
        let tmp = TempPath::new("inspect_out.csv");
        std::fs::write(tmp.path(), "a,b\n1,2").expect("write csv");
        let r = inspect_data_file(tmp.path(), "csv", InspectMode::Outliers, 2.5, LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_inspect_data_file_distribution() {
        let tmp = TempPath::new("inspect_dist.csv");
        std::fs::write(tmp.path(), "a,b\n1,2").expect("write csv");
        let r =
            inspect_data_file(tmp.path(), "csv", InspectMode::Distribution, 3.0, LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_inspect_data_file_schema() {
        let tmp = TempPath::new("inspect_sch.parquet");
        std::fs::write(tmp.path(), "fake parquet").expect("write file");
        let r =
            inspect_data_file(tmp.path(), "parquet", InspectMode::Schema, 3.0, LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_inspect_gguf() {
        let tmp = TempPath::new("test_model.gguf");
        std::fs::write(tmp.path(), "GGUF fake data 12345678").expect("write gguf");
        assert!(inspect_gguf(tmp.path(), LogLevel::Normal).is_ok());
    }

    #[test]
    fn test_count_total_parameters_single_scalar() {
        assert_eq!(count_from_bytes(&single_f32_tensor("scalar", vec![1])), 1);
    }

    #[test]
    fn test_count_total_parameters_with_data() {
        assert_eq!(count_from_bytes(&single_f32_tensor("w", vec![2, 3])), 6);
    }

    #[test]
    fn test_log_model_info() {
        log_model_info(LogLevel::Normal, 1_000_000, 500_000_000, 100);
    }

    #[test]
    fn test_log_model_info_verbose() {
        log_model_info(LogLevel::Verbose, 0, 0, 0);
    }

    #[test]
    fn test_run_inspect_csv_file() {
        let tmp = TempPath::new("ri_test.csv");
        std::fs::write(tmp.path(), "col1,col2\nval1,val2").expect("write csv");
        let r = run_inspect(inspect_args(tmp.path(), InspectMode::Summary, 3.0), LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_run_inspect_gguf_file() {
        let tmp = TempPath::new("ri_test.gguf");
        std::fs::write(tmp.path(), "GGUF fake").expect("write gguf");
        let r = run_inspect(inspect_args(tmp.path(), InspectMode::Summary, 3.0), LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_run_inspect_directory_no_adapter() {
        let tmp = TempPath::new("ri_dir_test");
        std::fs::create_dir_all(tmp.path()).expect("create dir");
        let err =
            run_inspect(inspect_args(tmp.path(), InspectMode::Summary, 3.0), LogLevel::Normal)
                .expect_err("directory without adapter config must be rejected");
        assert!(err.contains("adapter_config.json"));
    }

    #[test]
    fn test_run_inspect_parquet_distribution() {
        let tmp = TempPath::new("ri_test.parquet");
        std::fs::write(tmp.path(), "fake parquet data").expect("write");
        let r =
            run_inspect(inspect_args(tmp.path(), InspectMode::Distribution, 3.0), LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_run_inspect_outlier_mode() {
        let tmp = TempPath::new("ri_outlier.csv");
        std::fs::write(tmp.path(), "x\n1\n2\n3").expect("write");
        let r = run_inspect(inspect_args(tmp.path(), InspectMode::Outliers, 2.0), LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_run_inspect_schema_mode() {
        let tmp = TempPath::new("ri_schema.csv");
        std::fs::write(tmp.path(), "a,b\n1,2").expect("write");
        let r = run_inspect(inspect_args(tmp.path(), InspectMode::Schema, 3.0), LogLevel::Normal);
        assert!(r.is_ok());
    }

    #[test]
    fn test_inspect_lora_adapter_empty_dir() {
        let tmp = TempPath::new("ri_lora_empty");
        std::fs::create_dir_all(tmp.path()).expect("create dir");
        std::fs::write(tmp.path().join("adapter_config.json"), "{}").expect("write config");
        assert!(inspect_lora_adapter(tmp.path(), LogLevel::Normal).is_ok());
    }

    #[test]
    fn test_inspect_lora_adapter_with_config() {
        let tmp = TempPath::new("ri_lora_cfg");
        std::fs::create_dir_all(tmp.path()).expect("create dir");
        let config = serde_json::json!({
            "r": 16,
            "lora_alpha": 32.0,
            "target_modules": ["q_proj", "v_proj"],
            "base_model_name_or_path": "test/model"
        });
        std::fs::write(tmp.path().join("adapter_config.json"), config.to_string())
            .expect("write");
        assert!(inspect_lora_adapter(tmp.path(), LogLevel::Normal).is_ok());
    }

    #[test]
    fn test_run_inspect_lora_adapter_dir() {
        let tmp = TempPath::new("ri_lora_run");
        std::fs::create_dir_all(tmp.path()).expect("create dir");
        std::fs::write(tmp.path().join("adapter_config.json"), "{}").expect("write");
        let r = run_inspect(inspect_args(tmp.path(), InspectMode::Summary, 3.0), LogLevel::Normal);
        assert!(r.is_ok());
    }
}