// alimentar/cli/basic.rs
1//! Basic CLI commands for data conversion and inspection.
2
3use std::path::{Path, PathBuf};
4
5use arrow::util::pretty::print_batches;
6
7use crate::{ArrowDataset, Dataset};
8
/// Weighted dataset inputs: each entry is (dataset, weight, display_name), plus
/// total weight.
///
/// Produced by `load_mix_inputs`; carrying the summed weight lets callers
/// compute per-dataset sampling fractions without re-summing.
#[cfg(feature = "shuffle")]
type MixInputs = (Vec<(ArrowDataset, f64, String)>, f64);
13
14/// Load a dataset from a file path based on extension.
15pub(crate) fn load_dataset(path: &Path) -> crate::Result<ArrowDataset> {
16    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
17
18    match ext {
19        "parquet" => ArrowDataset::from_parquet(path),
20        "csv" => ArrowDataset::from_csv(path),
21        "json" | "jsonl" => ArrowDataset::from_json(path),
22        ext => Err(crate::Error::unsupported_format(ext)),
23    }
24}
25
26/// Save a dataset to a file path based on extension.
27pub(crate) fn save_dataset(dataset: &ArrowDataset, path: &Path) -> crate::Result<()> {
28    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
29
30    match ext {
31        "parquet" => dataset.to_parquet(path),
32        "csv" => dataset.to_csv(path),
33        "json" | "jsonl" => dataset.to_json(path),
34        ext => Err(crate::Error::unsupported_format(ext)),
35    }
36}
37
/// Human-readable format name derived from a file extension.
///
/// Unrecognized or missing extensions map to "Unknown".
pub(crate) fn get_format(path: &Path) -> &'static str {
    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
    match ext {
        "parquet" => "Parquet",
        "arrow" | "ipc" => "Arrow IPC",
        "csv" => "CSV",
        "json" | "jsonl" => "JSON",
        _ => "Unknown",
    }
}
48
49/// Convert between data formats.
50pub(crate) fn cmd_convert(input: &Path, output: &Path) -> crate::Result<()> {
51    // Load input (supports parquet, csv)
52    let dataset = load_dataset(input)?;
53
54    // Save output (supports parquet, csv)
55    save_dataset(&dataset, output)?;
56
57    println!(
58        "Converted {} -> {} ({} rows)",
59        input.display(),
60        output.display(),
61        dataset.len()
62    );
63
64    Ok(())
65}
66
67/// Display dataset information.
68/// ALB-099: For Parquet, reads only file metadata (footer) — zero row group
69/// decoding. dhat profiling showed the old path loaded the entire dataset (69.8
70/// MB for 9 MB file).
71pub(crate) fn cmd_info(path: &Path) -> crate::Result<()> {
72    let file_size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
73    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
74
75    println!("File: {}", path.display());
76    println!("Format: {}", get_format(path));
77
78    if ext == "parquet" {
79        // Read only Parquet footer metadata — no row group decoding
80        let file = std::fs::File::open(path).map_err(|e| crate::Error::io(e, path))?;
81        let builder = parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file)
82            .map_err(crate::Error::Parquet)?;
83        let metadata = builder.metadata();
84        let num_rows: i64 = metadata.row_groups().iter().map(|rg| rg.num_rows()).sum();
85        let num_batches = metadata.num_row_groups();
86        let num_columns = metadata
87            .row_groups()
88            .first()
89            .map_or(0, |rg| rg.num_columns());
90        println!("Rows: {num_rows}");
91        println!("Batches: {num_batches}");
92        println!("Columns: {num_columns}");
93    } else {
94        let dataset = load_dataset(path)?;
95        println!("Rows: {}", dataset.len());
96        println!("Batches: {}", dataset.num_batches());
97        println!("Columns: {}", dataset.schema().fields().len());
98    }
99
100    println!("Size: {file_size} bytes");
101
102    Ok(())
103}
104
105/// Display first N rows of a dataset.
106pub(crate) fn cmd_head(path: &Path, rows: usize) -> crate::Result<()> {
107    // ALB-099: For Parquet, use with_limit() to avoid loading the entire dataset.
108    // dhat profiling showed `head -n 10` allocated 69.8 MB for a 9 MB/1M-row file.
109    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
110    let dataset = if ext == "parquet" {
111        let file = std::fs::File::open(path).map_err(|e| crate::Error::io(e, path))?;
112        let builder = parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file)
113            .map_err(crate::Error::Parquet)?;
114        let reader = builder
115            .with_limit(rows)
116            .build()
117            .map_err(crate::Error::Parquet)?;
118        let batches: Vec<arrow::record_batch::RecordBatch> = reader
119            .collect::<std::result::Result<Vec<_>, _>>()
120            .map_err(crate::Error::Arrow)?;
121        if batches.is_empty() {
122            println!("Dataset is empty");
123            return Ok(());
124        }
125        ArrowDataset::new(batches)?
126    } else {
127        load_dataset(path)?
128    };
129
130    if dataset.is_empty() {
131        println!("Dataset is empty");
132        return Ok(());
133    }
134
135    // Collect rows into batches
136    let mut collected = Vec::new();
137    let mut count = 0;
138
139    for batch in dataset.iter() {
140        let take = (rows - count).min(batch.num_rows());
141        if take > 0 {
142            collected.push(batch.slice(0, take));
143            count += take;
144        }
145        if count >= rows {
146            break;
147        }
148    }
149
150    if collected.is_empty() {
151        println!("No data to display");
152        return Ok(());
153    }
154
155    // Print using Arrow's pretty printer
156    print_batches(&collected).map_err(crate::Error::Arrow)?;
157
158    if count < dataset.len() {
159        println!("... showing {} of {} rows", count, dataset.len());
160    }
161
162    Ok(())
163}
164
165/// Display dataset schema.
166pub(crate) fn cmd_schema(path: &Path) -> crate::Result<()> {
167    let dataset = load_dataset(path)?;
168    let schema = dataset.schema();
169
170    println!("Schema for {}:", path.display());
171    println!();
172
173    for (i, field) in schema.fields().iter().enumerate() {
174        let nullable = if field.is_nullable() {
175            "nullable"
176        } else {
177            "not null"
178        };
179        println!(
180            "  {}: {} ({}) [{}]",
181            i,
182            field.name(),
183            field.data_type(),
184            nullable
185        );
186    }
187
188    println!();
189    println!("Total columns: {}", schema.fields().len());
190
191    Ok(())
192}
193
/// Parse an input spec of the form "path" or "path:weight".
///
/// A trailing `:<float>` is treated as a weight; otherwise the whole spec is
/// the path and the weight defaults to 1.0. Splitting from the right keeps a
/// Windows drive letter (`C:\...`) from being mistaken for a weight.
#[cfg(feature = "shuffle")]
fn parse_input_spec(spec: &str) -> (PathBuf, f64) {
    match spec.rsplit_once(':') {
        // Only treat the suffix as a weight when it parses as a float.
        Some((path, suffix)) => match suffix.parse::<f64>() {
            Ok(weight) => (PathBuf::from(path), weight),
            Err(_) => (PathBuf::from(spec), 1.0),
        },
        None => (PathBuf::from(spec), 1.0),
    }
}
205
/// Load input datasets from specs.
///
/// Each spec is `path` or `path:weight` (see `parse_input_spec`). Missing
/// files are rejected up front; loaded entries carry their weight and display
/// name, and the summed weight is returned alongside.
#[cfg(feature = "shuffle")]
fn load_mix_inputs(inputs: &[String]) -> crate::Result<MixInputs> {
    let mut loaded = Vec::new();
    let mut weight_sum = 0.0;

    for spec in inputs {
        let (path, weight) = parse_input_spec(spec);
        if !path.exists() {
            let not_found =
                std::io::Error::new(std::io::ErrorKind::NotFound, "Input file not found");
            return Err(crate::Error::io(not_found, &path));
        }
        let dataset = load_dataset(&path)?;
        println!(
            "  Loaded {} ({} rows, weight={:.2})",
            path.display(),
            dataset.len(),
            weight
        );
        weight_sum += weight;
        loaded.push((dataset, weight, path.display().to_string()));
    }
    Ok((loaded, weight_sum))
}
232
/// Sample rows from a dataset with optional upsampling.
///
/// Shuffles all row indices, then — if more rows are requested than exist —
/// repeats indices cyclically before truncating to `rows_needed`. The result
/// is materialized as a single `RecordBatch` via Arrow `take`.
#[cfg(feature = "shuffle")]
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
fn sample_dataset(
    dataset: &ArrowDataset,
    rows_needed: usize,
    rng: &mut rand::rngs::StdRng,
) -> crate::Result<arrow::array::RecordBatch> {
    use rand::seq::SliceRandom;

    let available = dataset.len();

    // Shuffle first so upsampled (repeated) indices follow the same order.
    let mut indices: Vec<usize> = (0..available).collect();
    indices.shuffle(rng);
    if rows_needed > available {
        indices.extend((0..available).cycle().take(rows_needed - available));
    }
    indices.truncate(rows_needed);

    // Flatten the dataset into one batch so a single `take` covers all rows.
    let schema = dataset.schema();
    let batches: Vec<_> = dataset.iter().collect();
    let combined = arrow::compute::concat_batches(&schema, &batches)
        .map_err(|e| crate::Error::invalid_config(format!("Arrow concat error: {e}")))?;

    let take_indices: Vec<u32> = indices.iter().map(|&i| i as u32).collect();
    let index_array = arrow::array::UInt32Array::from(take_indices);

    let columns = (0..combined.num_columns())
        .map(|col_idx| {
            arrow::compute::take(combined.column(col_idx), &index_array, None)
                .map_err(|e| crate::Error::invalid_config(format!("Arrow take error: {e}")))
        })
        .collect::<crate::Result<Vec<arrow::array::ArrayRef>>>()?;

    arrow::array::RecordBatch::try_new(schema, columns)
        .map_err(|e| crate::Error::invalid_config(format!("RecordBatch error: {e}")))
}
274
/// Mix multiple datasets with weighted sampling.
///
/// Each input contributes `weight / total_weight` of the target row count,
/// sampled (with upsampling when needed) via a seeded RNG for reproducibility.
#[cfg(feature = "shuffle")]
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
pub(crate) fn cmd_mix(
    inputs: &[String],
    output: &Path,
    seed: u64,
    max_rows: usize,
) -> crate::Result<()> {
    use rand::{rngs::StdRng, SeedableRng};

    if inputs.is_empty() {
        return Err(crate::Error::invalid_config("No input files provided"));
    }

    let (datasets, total_weight) = load_mix_inputs(inputs)?;
    if total_weight == 0.0 {
        return Err(crate::Error::invalid_config("All weights are zero"));
    }

    // max_rows == 0 means "use everything available".
    let target_rows = match max_rows {
        0 => datasets.iter().map(|(d, _, _)| d.len()).sum(),
        n => n,
    };

    println!(
        "\nMixing {} datasets → {} target rows",
        datasets.len(),
        target_rows
    );

    let mut rng = StdRng::seed_from_u64(seed);
    let mut mixed_batches = Vec::new();
    let mut mixed_rows = 0;

    for (dataset, weight, name) in &datasets {
        let fraction = weight / total_weight;
        let share = (target_rows as f64 * fraction) as usize;

        let batch = sample_dataset(dataset, share, &mut rng)?;
        mixed_rows += batch.num_rows();
        println!(
            "  {} → {} rows ({:.1}%)",
            name,
            batch.num_rows(),
            fraction * 100.0
        );
        mixed_batches.push(batch);
    }

    if mixed_batches.is_empty() {
        return Err(crate::Error::invalid_config("No data to mix"));
    }

    let mixed = ArrowDataset::new(mixed_batches)?;
    save_dataset(&mixed, output)?;

    println!("\nMixed {} rows → {}", mixed_rows, output.display());
    Ok(())
}
334
/// Apply the FIM transform to `column` of `input` and write the result.
///
/// `format == "spm"` selects `FimFormat::SPM`; any other value uses PSM.
/// `rate` and `seed` are forwarded to the transform.
#[cfg(feature = "shuffle")]
pub(crate) fn cmd_fim(
    input: &Path,
    output: &Path,
    column: &str,
    rate: f64,
    format: &str,
    seed: u64,
) -> crate::Result<()> {
    use crate::transform::{Fim, FimFormat, Transform};

    let dataset = load_dataset(input)?;

    // Any value other than "spm" (including typos) falls back to PSM.
    let fim_format = if format == "spm" {
        FimFormat::SPM
    } else {
        FimFormat::PSM
    };

    let fim = Fim::new(column)
        .with_rate(rate)
        .with_format(fim_format)
        .with_seed(seed);

    let transformed_batches = dataset
        .iter()
        .map(|batch| fim.apply(batch))
        .collect::<crate::Result<Vec<_>>>()?;

    let transformed = ArrowDataset::new(transformed_batches)?;
    save_dataset(&transformed, output)?;

    println!(
        "FIM transform ({} format, {:.0}% rate) applied to '{}' column",
        format.to_uppercase(),
        rate * 100.0,
        column
    );
    println!("{} rows → {}", dataset.len(), output.display());
    Ok(())
}
374
375/// R-019: Deduplicate dataset rows by text content hash.
376///
377/// Uses SHA-256 content hashing for exact deduplication on the specified
378/// text column. Falls back to full-row deduplication if no column specified.
379pub(crate) fn cmd_dedup(input: &Path, output: &Path, column: Option<&str>) -> crate::Result<()> {
380    use crate::transform::{Transform, Unique};
381
382    let dataset = load_dataset(input)?;
383    let original_rows = dataset.len();
384
385    // Use content-hash dedup on text column, or full-row dedup
386    let dedup = match column {
387        Some(col) => Unique::by(vec![col]),
388        None => detect_text_column_dedup(&dataset),
389    };
390
391    let mut all_batches = Vec::new();
392    for batch in dataset.iter() {
393        all_batches.push(dedup.apply(batch)?);
394    }
395
396    let deduped = ArrowDataset::new(all_batches)?;
397    let deduped_rows = deduped.len();
398    save_dataset(&deduped, output)?;
399
400    let removed = original_rows - deduped_rows;
401    println!(
402        "Dedup: {} → {} rows ({} duplicates removed, {:.1}% reduction)",
403        original_rows,
404        deduped_rows,
405        removed,
406        removed as f64 / original_rows.max(1) as f64 * 100.0
407    );
408    Ok(())
409}
410
411/// Auto-detect text column and create Unique transform for it.
412fn detect_text_column_dedup(dataset: &ArrowDataset) -> crate::transform::Unique {
413    use arrow::datatypes::DataType;
414
415    use crate::transform::Unique;
416
417    let schema = dataset.schema();
418    for name in &["text", "content", "code", "source"] {
419        if let Some((_, field)) = schema.column_with_name(name) {
420            if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
421                return Unique::by(vec![*name]);
422            }
423        }
424    }
425    // Fallback: dedup on all columns
426    Unique::all()
427}
428
429/// R-022: Filter dataset rows by text quality signals.
430///
431/// Computes quality scores for each row and removes low-quality entries.
432/// Signals: line length, alphanumeric ratio, duplicate line ratio, entropy.
433pub(crate) fn cmd_filter_text(
434    input: &Path,
435    output: &Path,
436    column: Option<&str>,
437    min_score: f64,
438    min_length: usize,
439    max_length: usize,
440) -> crate::Result<()> {
441    use crate::transform::Transform;
442
443    let dataset = load_dataset(input)?;
444    let original_rows = dataset.len();
445
446    let col_name = column
447        .map(String::from)
448        .unwrap_or_else(|| find_text_column(&dataset));
449
450    let filter = TextQualityFilter::new(&col_name, min_score, min_length, max_length);
451
452    let mut all_batches = Vec::new();
453    for batch in dataset.iter() {
454        all_batches.push(filter.apply(batch)?);
455    }
456
457    let filtered = ArrowDataset::new(all_batches)?;
458    let kept = filtered.len();
459    save_dataset(&filtered, output)?;
460
461    let removed = original_rows - kept;
462    println!(
463        "Filter: {} → {} rows ({} removed, {:.1}% kept)",
464        original_rows,
465        kept,
466        removed,
467        kept as f64 / original_rows.max(1) as f64 * 100.0
468    );
469    println!(
470        "  min_score={:.2} min_len={} max_len={} column='{}'",
471        min_score, min_length, max_length, col_name
472    );
473    Ok(())
474}
475
476/// Find the first text column in a dataset.
477fn find_text_column(dataset: &ArrowDataset) -> String {
478    use arrow::datatypes::DataType;
479    let schema = dataset.schema();
480    for name in &["text", "content", "code", "source"] {
481        if let Some((_, field)) = schema.column_with_name(name) {
482            if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
483                return (*name).to_string();
484            }
485        }
486    }
487    // Fallback: first Utf8 column
488    for field in schema.fields() {
489        if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
490            return field.name().clone();
491        }
492    }
493    "text".to_string()
494}
495
/// Text quality filter transform.
///
/// Keeps rows whose `column` value passes the length bounds and reaches the
/// composite quality score threshold (see `passes_quality`).
struct TextQualityFilter {
    // Name of the text column to score.
    column: String,
    // Minimum composite quality score (0.0-1.0) a row must reach.
    min_score: f64,
    // Inclusive length bounds, in bytes (checked via `str::len`).
    min_length: usize,
    max_length: usize,
}
503
504impl TextQualityFilter {
505    fn new(column: &str, min_score: f64, min_length: usize, max_length: usize) -> Self {
506        Self {
507            column: column.to_string(),
508            min_score,
509            min_length,
510            max_length,
511        }
512    }
513}
514
515impl crate::transform::Transform for TextQualityFilter {
516    fn apply(&self, batch: arrow::array::RecordBatch) -> crate::Result<arrow::array::RecordBatch> {
517        use arrow::{
518            array::{Array, BooleanArray, StringArray},
519            compute::filter_record_batch,
520        };
521
522        let schema = batch.schema();
523        let col_idx = schema
524            .column_with_name(&self.column)
525            .map(|(i, _)| i)
526            .ok_or_else(|| crate::Error::column_not_found(&self.column))?;
527
528        let text_arr = batch
529            .column(col_idx)
530            .as_any()
531            .downcast_ref::<StringArray>()
532            .ok_or_else(|| crate::Error::column_not_found(&self.column))?;
533
534        let mask: BooleanArray = (0..text_arr.len())
535            .map(|i| {
536                if text_arr.is_null(i) {
537                    Some(false)
538                } else {
539                    let text = text_arr.value(i);
540                    Some(passes_quality(
541                        text,
542                        self.min_score,
543                        self.min_length,
544                        self.max_length,
545                    ))
546                }
547            })
548            .collect();
549
550        filter_record_batch(&batch, &mask).map_err(crate::Error::Arrow)
551    }
552}
553
554/// Check if a text document passes quality thresholds.
555fn passes_quality(text: &str, min_score: f64, min_len: usize, max_len: usize) -> bool {
556    let len = text.len();
557    if len < min_len || len > max_len {
558        return false;
559    }
560    composite_score(text) >= min_score
561}
562
563/// Compute composite quality score (0.0-1.0) for a text document.
564fn composite_score(text: &str) -> f64 {
565    let s1 = score_alnum_ratio(text);
566    let s2 = score_line_length(text);
567    let s3 = score_dup_lines(text);
568    let s4 = score_entropy(text);
569    (s1 + s2 + s3 + s4) / 4.0
570}
571
/// Alphanumeric character ratio. Below 0.3 = likely binary/garbage.
///
/// The ratio is computed over characters (not bytes) so multi-byte UTF-8
/// alphanumerics are not undercounted; the old byte-based divisor skewed the
/// ratio low for non-ASCII text.
///
/// Returns 0.0 for empty input or a ratio below 0.2, the raw ratio in
/// [0.2, 0.3), and 1.0 once the ratio reaches 0.3.
fn score_alnum_ratio(text: &str) -> f64 {
    if text.is_empty() {
        return 0.0;
    }
    // Single pass: count total chars and alphanumeric chars together.
    let mut total = 0usize;
    let mut alnum = 0usize;
    for c in text.chars() {
        total += 1;
        if c.is_alphanumeric() {
            alnum += 1;
        }
    }
    let ratio = alnum as f64 / total as f64;
    if ratio < 0.2 {
        0.0
    } else if ratio < 0.3 {
        ratio
    } else {
        1.0
    }
}
587
/// Average line length score.
///
/// Averages bytes per line: under 10 scores 0.2 (choppy), over 200 scores
/// 0.5 (likely minified/log spam), anything in between scores 1.0.
fn score_line_length(text: &str) -> f64 {
    let line_count = text.lines().count();
    if line_count == 0 {
        return 0.0;
    }
    let avg = text.len() as f64 / line_count as f64;
    if avg < 10.0 {
        0.2
    } else if avg > 200.0 {
        0.5
    } else {
        1.0
    }
}
603
/// Duplicate line ratio. High = boilerplate.
///
/// Over 50% duplicated lines scores a flat 0.2; otherwise the score is the
/// fraction of distinct lines.
fn score_dup_lines(text: &str) -> f64 {
    use std::collections::HashSet;

    let total = text.lines().count();
    // Zero or one line cannot contain duplicates.
    if total <= 1 {
        return 1.0;
    }
    let distinct = text.lines().collect::<HashSet<&str>>().len();
    let dup_ratio = 1.0 - (distinct as f64 / total as f64);
    if dup_ratio > 0.5 {
        0.2
    } else {
        1.0 - dup_ratio
    }
}
619
/// Character-level Shannon entropy. Low = repetitive, high = random/binary.
///
/// Entropy is computed over bytes and converted to bits: under 2 bits scores
/// 0.2 (too repetitive), above 6.5 bits scores 0.3 (likely random/binary),
/// otherwise 1.0.
fn score_entropy(text: &str) -> f64 {
    if text.is_empty() {
        return 0.0;
    }
    // Byte histogram; `text.len()` is the byte count, matching it.
    let mut histogram = [0u32; 256];
    for &byte in text.as_bytes() {
        histogram[usize::from(byte)] += 1;
    }
    let total = text.len() as f64;
    let nats: f64 = histogram
        .iter()
        .filter(|&&count| count > 0)
        .map(|&count| {
            let p = f64::from(count) / total;
            -p * p.ln()
        })
        .sum();
    let bits = nats / std::f64::consts::LN_2;
    if bits < 2.0 {
        0.2
    } else if bits > 6.5 {
        0.3
    } else {
        1.0
    }
}
647
648#[cfg(test)]
649#[allow(
650    clippy::cast_possible_truncation,
651    clippy::cast_possible_wrap,
652    clippy::cast_precision_loss,
653    clippy::uninlined_format_args,
654    clippy::unwrap_used,
655    clippy::expect_used,
656    clippy::redundant_clone,
657    clippy::cast_lossless,
658    clippy::redundant_closure_for_method_calls,
659    clippy::too_many_lines,
660    clippy::float_cmp,
661    clippy::similar_names,
662    clippy::needless_late_init,
663    clippy::redundant_pattern_matching
664)]
665mod tests {
666    use std::sync::Arc;
667
668    use arrow::{
669        array::{Int32Array, StringArray},
670        datatypes::{DataType, Field, Schema},
671    };
672
673    use super::*;
674
675    fn create_test_parquet(path: &Path, rows: usize) {
676        let schema = Arc::new(Schema::new(vec![
677            Field::new("id", DataType::Int32, false),
678            Field::new("name", DataType::Utf8, false),
679        ]));
680
681        let ids: Vec<i32> = (0..rows as i32).collect();
682        let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();
683
684        let batch = arrow::array::RecordBatch::try_new(
685            schema,
686            vec![
687                Arc::new(Int32Array::from(ids)),
688                Arc::new(StringArray::from(names)),
689            ],
690        )
691        .ok()
692        .unwrap_or_else(|| panic!("Should create batch"));
693
694        let dataset = ArrowDataset::from_batch(batch)
695            .ok()
696            .unwrap_or_else(|| panic!("Should create dataset"));
697
698        dataset
699            .to_parquet(path)
700            .ok()
701            .unwrap_or_else(|| panic!("Should write parquet"));
702    }
703
704    #[test]
705    fn test_cmd_info() {
706        let temp_dir = tempfile::tempdir()
707            .ok()
708            .unwrap_or_else(|| panic!("Should create temp dir"));
709        let path = temp_dir.path().join("test.parquet");
710        create_test_parquet(&path, 100);
711
712        let result = cmd_info(&path);
713        assert!(result.is_ok());
714    }
715
716    #[test]
717    fn test_cmd_head() {
718        let temp_dir = tempfile::tempdir()
719            .ok()
720            .unwrap_or_else(|| panic!("Should create temp dir"));
721        let path = temp_dir.path().join("test.parquet");
722        create_test_parquet(&path, 100);
723
724        let result = cmd_head(&path, 5);
725        assert!(result.is_ok());
726    }
727
728    #[test]
729    fn test_cmd_schema() {
730        let temp_dir = tempfile::tempdir()
731            .ok()
732            .unwrap_or_else(|| panic!("Should create temp dir"));
733        let path = temp_dir.path().join("test.parquet");
734        create_test_parquet(&path, 10);
735
736        let result = cmd_schema(&path);
737        assert!(result.is_ok());
738    }
739
740    #[test]
741    fn test_cmd_convert() {
742        let temp_dir = tempfile::tempdir()
743            .ok()
744            .unwrap_or_else(|| panic!("Should create temp dir"));
745        let input = temp_dir.path().join("input.parquet");
746        let output = temp_dir.path().join("output.parquet");
747        create_test_parquet(&input, 50);
748
749        let result = cmd_convert(&input, &output);
750        assert!(result.is_ok());
751
752        // Verify output was created and has same data
753        let original = ArrowDataset::from_parquet(&input)
754            .ok()
755            .unwrap_or_else(|| panic!("Should load original"));
756        let converted = ArrowDataset::from_parquet(&output)
757            .ok()
758            .unwrap_or_else(|| panic!("Should load converted"));
759
760        assert_eq!(original.len(), converted.len());
761    }
762
763    #[test]
764    fn test_load_dataset_unsupported() {
765        let path = PathBuf::from("test.xyz");
766        let result = load_dataset(&path);
767        assert!(result.is_err());
768    }
769
770    #[test]
771    fn test_get_format() {
772        assert_eq!(get_format(Path::new("test.parquet")), "Parquet");
773        assert_eq!(get_format(Path::new("test.arrow")), "Arrow IPC");
774        assert_eq!(get_format(Path::new("test.csv")), "CSV");
775        assert_eq!(get_format(Path::new("test.json")), "JSON");
776        assert_eq!(get_format(Path::new("test.unknown")), "Unknown");
777    }
778
779    #[test]
780    fn test_cmd_head_with_more_rows_than_dataset() {
781        let temp_dir = tempfile::tempdir()
782            .ok()
783            .unwrap_or_else(|| panic!("Should create temp dir"));
784        let path = temp_dir.path().join("test.parquet");
785        create_test_parquet(&path, 5);
786
787        // Request more rows than exist
788        let result = cmd_head(&path, 100);
789        assert!(result.is_ok());
790    }
791
792    #[test]
793    fn test_cmd_convert_parquet_to_csv() {
794        let temp_dir = tempfile::tempdir()
795            .ok()
796            .unwrap_or_else(|| panic!("Should create temp dir"));
797        let input = temp_dir.path().join("input.parquet");
798        let output = temp_dir.path().join("output.csv");
799        create_test_parquet(&input, 25);
800
801        let result = cmd_convert(&input, &output);
802        assert!(result.is_ok());
803        assert!(output.exists());
804    }
805
806    #[test]
807    fn test_cmd_convert_parquet_to_json() {
808        let temp_dir = tempfile::tempdir()
809            .ok()
810            .unwrap_or_else(|| panic!("Should create temp dir"));
811        let input = temp_dir.path().join("input.parquet");
812        let output = temp_dir.path().join("output.json");
813        create_test_parquet(&input, 15);
814
815        let result = cmd_convert(&input, &output);
816        assert!(result.is_ok());
817        assert!(output.exists());
818    }
819
820    #[test]
821    fn test_save_dataset_unsupported_format() {
822        let temp_dir = tempfile::tempdir()
823            .ok()
824            .unwrap_or_else(|| panic!("Should create temp dir"));
825        let input = temp_dir.path().join("data.parquet");
826        let output = temp_dir.path().join("output.xyz");
827        create_test_parquet(&input, 5);
828
829        let dataset = ArrowDataset::from_parquet(&input)
830            .ok()
831            .unwrap_or_else(|| panic!("Should load"));
832
833        let result = save_dataset(&dataset, &output);
834        assert!(result.is_err());
835    }
836
837    #[test]
838    fn test_get_format_ipc() {
839        assert_eq!(get_format(Path::new("test.ipc")), "Arrow IPC");
840    }
841
842    #[test]
843    fn test_get_format_jsonl() {
844        assert_eq!(get_format(Path::new("test.jsonl")), "JSON");
845    }
846
847    #[test]
848    fn test_get_format_no_extension() {
849        assert_eq!(get_format(Path::new("testfile")), "Unknown");
850    }
851
852    #[test]
853    fn test_cmd_convert_unsupported_output() {
854        let temp_dir = tempfile::tempdir()
855            .ok()
856            .unwrap_or_else(|| panic!("Should create temp dir"));
857        let input = temp_dir.path().join("input.parquet");
858        let output = temp_dir.path().join("output.xyz");
859        create_test_parquet(&input, 10);
860
861        let result = cmd_convert(&input, &output);
862        assert!(result.is_err());
863    }
864
865    #[test]
866    fn test_load_dataset_xyz_format() {
867        let temp_dir = tempfile::tempdir()
868            .ok()
869            .unwrap_or_else(|| panic!("Should create temp dir"));
870        let path = temp_dir.path().join("data.xyz");
871
872        std::fs::write(&path, "some data")
873            .ok()
874            .unwrap_or_else(|| panic!("Should write file"));
875
876        let result = load_dataset(&path);
877        assert!(result.is_err());
878    }
879
880    #[test]
881    fn test_get_format_arrow() {
882        assert_eq!(get_format(Path::new("test.arrow")), "Arrow IPC");
883    }
884
885    #[test]
886    fn test_get_format_unknown() {
887        assert_eq!(get_format(Path::new("test.feather")), "Unknown");
888        assert_eq!(get_format(Path::new("test.txt")), "Unknown");
889    }
890
891    #[test]
892    fn test_load_dataset_csv() {
893        let temp_dir = tempfile::tempdir()
894            .ok()
895            .unwrap_or_else(|| panic!("Should create temp dir"));
896        let parquet_path = temp_dir.path().join("data.parquet");
897        let csv_path = temp_dir.path().join("data.csv");
898
899        create_test_parquet(&parquet_path, 10);
900
901        // Convert to CSV first
902        let dataset = ArrowDataset::from_parquet(&parquet_path)
903            .ok()
904            .unwrap_or_else(|| panic!("Should load"));
905        dataset
906            .to_csv(&csv_path)
907            .ok()
908            .unwrap_or_else(|| panic!("Should write csv"));
909
910        // Load from CSV
911        let loaded = load_dataset(&csv_path);
912        assert!(loaded.is_ok());
913    }
914
915    #[test]
916    fn test_load_dataset_json() {
917        let temp_dir = tempfile::tempdir()
918            .ok()
919            .unwrap_or_else(|| panic!("Should create temp dir"));
920        let parquet_path = temp_dir.path().join("data.parquet");
921        let json_path = temp_dir.path().join("data.json");
922
923        create_test_parquet(&parquet_path, 10);
924
925        // Convert to JSON first
926        let dataset = ArrowDataset::from_parquet(&parquet_path)
927            .ok()
928            .unwrap_or_else(|| panic!("Should load"));
929        dataset
930            .to_json(&json_path)
931            .ok()
932            .unwrap_or_else(|| panic!("Should write json"));
933
934        // Load from JSON
935        let loaded = load_dataset(&json_path);
936        assert!(loaded.is_ok());
937    }
938
939    #[test]
940    fn test_load_dataset_jsonl() {
941        let temp_dir = tempfile::tempdir()
942            .ok()
943            .unwrap_or_else(|| panic!("Should create temp dir"));
944        let parquet_path = temp_dir.path().join("data.parquet");
945        let jsonl_path = temp_dir.path().join("data.jsonl");
946
947        create_test_parquet(&parquet_path, 10);
948
949        // Convert to JSON first (jsonl is same format)
950        let dataset = ArrowDataset::from_parquet(&parquet_path)
951            .ok()
952            .unwrap_or_else(|| panic!("Should load"));
953        dataset
954            .to_json(&jsonl_path)
955            .ok()
956            .unwrap_or_else(|| panic!("Should write jsonl"));
957
958        // Load from JSONL
959        let loaded = load_dataset(&jsonl_path);
960        assert!(loaded.is_ok());
961    }
962
963    #[test]
964    fn test_save_dataset_parquet() {
965        let temp_dir = tempfile::tempdir()
966            .ok()
967            .unwrap_or_else(|| panic!("Should create temp dir"));
968        let path = temp_dir.path().join("output.parquet");
969
970        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
971        let batch = arrow::array::RecordBatch::try_new(
972            schema,
973            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
974        )
975        .unwrap();
976        let dataset = ArrowDataset::from_batch(batch).unwrap();
977
978        let result = save_dataset(&dataset, &path);
979        assert!(result.is_ok());
980    }
981
982    #[test]
983    fn test_save_dataset_csv() {
984        let temp_dir = tempfile::tempdir()
985            .ok()
986            .unwrap_or_else(|| panic!("Should create temp dir"));
987        let path = temp_dir.path().join("output.csv");
988
989        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
990        let batch = arrow::array::RecordBatch::try_new(
991            schema,
992            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
993        )
994        .unwrap();
995        let dataset = ArrowDataset::from_batch(batch).unwrap();
996
997        let result = save_dataset(&dataset, &path);
998        assert!(result.is_ok());
999    }
1000
1001    #[test]
1002    fn test_save_dataset_json() {
1003        let temp_dir = tempfile::tempdir()
1004            .ok()
1005            .unwrap_or_else(|| panic!("Should create temp dir"));
1006        let path = temp_dir.path().join("output.json");
1007
1008        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1009        let batch = arrow::array::RecordBatch::try_new(
1010            schema,
1011            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
1012        )
1013        .unwrap();
1014        let dataset = ArrowDataset::from_batch(batch).unwrap();
1015
1016        let result = save_dataset(&dataset, &path);
1017        assert!(result.is_ok());
1018    }
1019
1020    #[test]
1021    fn test_save_dataset_unknown_extension() {
1022        let temp_dir = tempfile::tempdir()
1023            .ok()
1024            .unwrap_or_else(|| panic!("Should create temp dir"));
1025        let path = temp_dir.path().join("output.xyz");
1026
1027        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1028        let batch = arrow::array::RecordBatch::try_new(
1029            schema,
1030            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
1031        )
1032        .unwrap();
1033        let dataset = ArrowDataset::from_batch(batch).unwrap();
1034
1035        let result = save_dataset(&dataset, &path);
1036        assert!(result.is_err());
1037    }
1038
1039    #[test]
1040    fn test_cmd_convert_to_csv_format() {
1041        let temp_dir = tempfile::tempdir()
1042            .ok()
1043            .unwrap_or_else(|| panic!("Should create temp dir"));
1044        let input = temp_dir.path().join("input.parquet");
1045        let output = temp_dir.path().join("output.csv");
1046        create_test_parquet(&input, 20);
1047
1048        let result = cmd_convert(&input, &output);
1049        assert!(result.is_ok());
1050        assert!(output.exists());
1051    }
1052
1053    #[test]
1054    fn test_cmd_convert_to_json_format() {
1055        let temp_dir = tempfile::tempdir()
1056            .ok()
1057            .unwrap_or_else(|| panic!("Should create temp dir"));
1058        let input = temp_dir.path().join("input.parquet");
1059        let output = temp_dir.path().join("output.json");
1060        create_test_parquet(&input, 20);
1061
1062        let result = cmd_convert(&input, &output);
1063        assert!(result.is_ok());
1064        assert!(output.exists());
1065    }
1066
1067    #[test]
1068    fn test_cmd_head_more_than_available() {
1069        let temp_dir = tempfile::tempdir()
1070            .ok()
1071            .unwrap_or_else(|| panic!("Should create temp dir"));
1072        let path = temp_dir.path().join("small.parquet");
1073        create_test_parquet(&path, 5);
1074
1075        // Request more rows than available
1076        let result = cmd_head(&path, 100);
1077        assert!(result.is_ok());
1078    }
1079
1080    #[test]
1081    fn test_load_dataset_csv_file() {
1082        let temp_dir = tempfile::tempdir()
1083            .ok()
1084            .unwrap_or_else(|| panic!("Should create temp dir"));
1085        let csv_path = temp_dir.path().join("test.csv");
1086
1087        // Create a simple CSV
1088        std::fs::write(&csv_path, "id,name\n1,foo\n2,bar\n").unwrap();
1089
1090        let result = load_dataset(&csv_path);
1091        assert!(result.is_ok());
1092    }
1093
1094    #[test]
1095    fn test_load_dataset_json_file() {
1096        let temp_dir = tempfile::tempdir()
1097            .ok()
1098            .unwrap_or_else(|| panic!("Should create temp dir"));
1099        let json_path = temp_dir.path().join("test.json");
1100
1101        // Create a simple JSON Lines file
1102        std::fs::write(
1103            &json_path,
1104            r#"{"id":1,"name":"foo"}
1105{"id":2,"name":"bar"}"#,
1106        )
1107        .unwrap();
1108
1109        let result = load_dataset(&json_path);
1110        assert!(result.is_ok());
1111    }
1112
1113    // === Additional CLI basic tests ===
1114
1115    #[test]
1116    fn test_cmd_head_zero_rows() {
1117        let temp_dir = tempfile::tempdir()
1118            .ok()
1119            .unwrap_or_else(|| panic!("Should create temp dir"));
1120        let path = temp_dir.path().join("test.parquet");
1121        create_test_parquet(&path, 50);
1122
1123        // Request 0 rows - should still work
1124        let result = cmd_head(&path, 0);
1125        assert!(result.is_ok());
1126    }
1127
1128    #[test]
1129    fn test_cmd_info_small_file() {
1130        let temp_dir = tempfile::tempdir()
1131            .ok()
1132            .unwrap_or_else(|| panic!("Should create temp dir"));
1133        let path = temp_dir.path().join("small.parquet");
1134        create_test_parquet(&path, 5);
1135
1136        let result = cmd_info(&path);
1137        assert!(result.is_ok());
1138    }
1139
1140    #[test]
1141    fn test_cmd_info_large_file() {
1142        let temp_dir = tempfile::tempdir()
1143            .ok()
1144            .unwrap_or_else(|| panic!("Should create temp dir"));
1145        let path = temp_dir.path().join("large.parquet");
1146        create_test_parquet(&path, 1000);
1147
1148        let result = cmd_info(&path);
1149        assert!(result.is_ok());
1150    }
1151
1152    #[test]
1153    fn test_cmd_schema_complex() {
1154        let temp_dir = tempfile::tempdir()
1155            .ok()
1156            .unwrap_or_else(|| panic!("Should create temp dir"));
1157        let path = temp_dir.path().join("complex.parquet");
1158
1159        // Create with more columns
1160        let schema = Arc::new(Schema::new(vec![
1161            Field::new("id", DataType::Int32, false),
1162            Field::new("name", DataType::Utf8, true),
1163            Field::new("value", DataType::Float64, true),
1164        ]));
1165
1166        let batch = arrow::array::RecordBatch::try_new(
1167            schema,
1168            vec![
1169                Arc::new(Int32Array::from(vec![1, 2, 3])),
1170                Arc::new(StringArray::from(vec!["a", "b", "c"])),
1171                Arc::new(arrow::array::Float64Array::from(vec![1.0, 2.0, 3.0])),
1172            ],
1173        )
1174        .unwrap();
1175
1176        let dataset = ArrowDataset::from_batch(batch).unwrap();
1177        dataset.to_parquet(&path).unwrap();
1178
1179        let result = cmd_schema(&path);
1180        assert!(result.is_ok());
1181    }
1182
1183    #[test]
1184    fn test_cmd_convert_csv_to_parquet() {
1185        let temp_dir = tempfile::tempdir()
1186            .ok()
1187            .unwrap_or_else(|| panic!("Should create temp dir"));
1188        let csv_path = temp_dir.path().join("input.csv");
1189        let parquet_path = temp_dir.path().join("output.parquet");
1190
1191        std::fs::write(&csv_path, "id,name\n1,foo\n2,bar\n").unwrap();
1192
1193        let result = cmd_convert(&csv_path, &parquet_path);
1194        assert!(result.is_ok());
1195        assert!(parquet_path.exists());
1196    }
1197
1198    #[test]
1199    fn test_cmd_convert_json_to_parquet() {
1200        let temp_dir = tempfile::tempdir()
1201            .ok()
1202            .unwrap_or_else(|| panic!("Should create temp dir"));
1203        let json_path = temp_dir.path().join("input.json");
1204        let parquet_path = temp_dir.path().join("output.parquet");
1205
1206        std::fs::write(
1207            &json_path,
1208            r#"{"id":1,"name":"foo"}
1209{"id":2,"name":"bar"}"#,
1210        )
1211        .unwrap();
1212
1213        let result = cmd_convert(&json_path, &parquet_path);
1214        assert!(result.is_ok());
1215        assert!(parquet_path.exists());
1216    }
1217
1218    #[test]
1219    fn test_save_dataset_jsonl() {
1220        let temp_dir = tempfile::tempdir()
1221            .ok()
1222            .unwrap_or_else(|| panic!("Should create temp dir"));
1223        let path = temp_dir.path().join("output.jsonl");
1224
1225        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1226        let batch = arrow::array::RecordBatch::try_new(
1227            schema,
1228            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
1229        )
1230        .unwrap();
1231        let dataset = ArrowDataset::from_batch(batch).unwrap();
1232
1233        let result = save_dataset(&dataset, &path);
1234        assert!(result.is_ok());
1235    }
1236
1237    #[test]
1238    fn test_load_dataset_no_extension() {
1239        let path = PathBuf::from("file_without_extension");
1240        let result = load_dataset(&path);
1241        assert!(result.is_err());
1242    }
1243
1244    #[test]
1245    fn test_cmd_head_exact_rows() {
1246        let temp_dir = tempfile::tempdir()
1247            .ok()
1248            .unwrap_or_else(|| panic!("Should create temp dir"));
1249        let path = temp_dir.path().join("exact.parquet");
1250        create_test_parquet(&path, 10);
1251
1252        // Request exact number of rows
1253        let result = cmd_head(&path, 10);
1254        assert!(result.is_ok());
1255    }
1256
1257    #[test]
1258    fn test_cmd_convert_parquet_to_parquet() {
1259        let temp_dir = tempfile::tempdir()
1260            .ok()
1261            .unwrap_or_else(|| panic!("Should create temp dir"));
1262        let input = temp_dir.path().join("input.parquet");
1263        let output = temp_dir.path().join("output.parquet");
1264        create_test_parquet(&input, 20);
1265
1266        let result = cmd_convert(&input, &output);
1267        assert!(result.is_ok());
1268
1269        // Both should have same data
1270        let original = ArrowDataset::from_parquet(&input).unwrap();
1271        let converted = ArrowDataset::from_parquet(&output).unwrap();
1272        assert_eq!(original.len(), converted.len());
1273    }
1274
1275    #[test]
1276    fn test_get_format_all_types() {
1277        assert_eq!(get_format(Path::new("data.parquet")), "Parquet");
1278        assert_eq!(get_format(Path::new("data.arrow")), "Arrow IPC");
1279        assert_eq!(get_format(Path::new("data.ipc")), "Arrow IPC");
1280        assert_eq!(get_format(Path::new("data.csv")), "CSV");
1281        assert_eq!(get_format(Path::new("data.json")), "JSON");
1282        assert_eq!(get_format(Path::new("data.jsonl")), "JSON");
1283        assert_eq!(get_format(Path::new("data.txt")), "Unknown");
1284        assert_eq!(get_format(Path::new("data.yaml")), "Unknown");
1285        assert_eq!(get_format(Path::new("data")), "Unknown");
1286    }
1287}