1use std::path::{Path, PathBuf};
4
5use arrow::util::pretty::print_batches;
6
7use crate::{ArrowDataset, Dataset};
8
/// Loaded inputs for `cmd_mix`: one `(dataset, weight, display path)` triple
/// per input spec, paired with the sum of all weights.
#[cfg(feature = "shuffle")]
type MixInputs = (Vec<(ArrowDataset, f64, String)>, f64);
13
14pub(crate) fn load_dataset(path: &Path) -> crate::Result<ArrowDataset> {
16 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
17
18 match ext {
19 "parquet" => ArrowDataset::from_parquet(path),
20 "csv" => ArrowDataset::from_csv(path),
21 "json" | "jsonl" => ArrowDataset::from_json(path),
22 ext => Err(crate::Error::unsupported_format(ext)),
23 }
24}
25
26pub(crate) fn save_dataset(dataset: &ArrowDataset, path: &Path) -> crate::Result<()> {
28 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
29
30 match ext {
31 "parquet" => dataset.to_parquet(path),
32 "csv" => dataset.to_csv(path),
33 "json" | "jsonl" => dataset.to_json(path),
34 ext => Err(crate::Error::unsupported_format(ext)),
35 }
36}
37
/// Maps the extension of `path` to a human-readable format label.
///
/// Returns `"Unknown"` when the extension is missing, not valid UTF-8, or
/// not one of the recognised (case-sensitive) extensions.
pub(crate) fn get_format(path: &Path) -> &'static str {
    let Some(extension) = path.extension().and_then(|e| e.to_str()) else {
        return "Unknown";
    };

    match extension {
        "parquet" => "Parquet",
        "csv" => "CSV",
        "arrow" | "ipc" => "Arrow IPC",
        "json" | "jsonl" => "JSON",
        _ => "Unknown",
    }
}
48
49pub(crate) fn cmd_convert(input: &Path, output: &Path) -> crate::Result<()> {
51 let dataset = load_dataset(input)?;
53
54 save_dataset(&dataset, output)?;
56
57 println!(
58 "Converted {} -> {} ({} rows)",
59 input.display(),
60 output.display(),
61 dataset.len()
62 );
63
64 Ok(())
65}
66
67pub(crate) fn cmd_info(path: &Path) -> crate::Result<()> {
72 let file_size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
73 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
74
75 println!("File: {}", path.display());
76 println!("Format: {}", get_format(path));
77
78 if ext == "parquet" {
79 let file = std::fs::File::open(path).map_err(|e| crate::Error::io(e, path))?;
81 let builder = parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file)
82 .map_err(crate::Error::Parquet)?;
83 let metadata = builder.metadata();
84 let num_rows: i64 = metadata.row_groups().iter().map(|rg| rg.num_rows()).sum();
85 let num_batches = metadata.num_row_groups();
86 let num_columns = metadata
87 .row_groups()
88 .first()
89 .map_or(0, |rg| rg.num_columns());
90 println!("Rows: {num_rows}");
91 println!("Batches: {num_batches}");
92 println!("Columns: {num_columns}");
93 } else {
94 let dataset = load_dataset(path)?;
95 println!("Rows: {}", dataset.len());
96 println!("Batches: {}", dataset.num_batches());
97 println!("Columns: {}", dataset.schema().fields().len());
98 }
99
100 println!("Size: {file_size} bytes");
101
102 Ok(())
103}
104
105pub(crate) fn cmd_head(path: &Path, rows: usize) -> crate::Result<()> {
107 let Some(dataset) = load_head_dataset(path, rows)? else {
108 return Ok(());
109 };
110
111 if dataset.is_empty() {
112 println!("Dataset is empty");
113 return Ok(());
114 }
115
116 let (collected, count) = collect_head_batches(&dataset, rows);
117 if collected.is_empty() {
118 println!("No data to display");
119 return Ok(());
120 }
121
122 print_batches(&collected).map_err(crate::Error::Arrow)?;
123 if count < dataset.len() {
124 println!("... showing {} of {} rows", count, dataset.len());
125 }
126 Ok(())
127}
128
129fn load_head_dataset(path: &Path, rows: usize) -> crate::Result<Option<ArrowDataset>> {
134 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
137 if ext != "parquet" {
138 return Ok(Some(load_dataset(path)?));
139 }
140 let file = std::fs::File::open(path).map_err(|e| crate::Error::io(e, path))?;
141 let builder = parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file)
142 .map_err(crate::Error::Parquet)?;
143 let reader = builder
144 .with_limit(rows)
145 .build()
146 .map_err(crate::Error::Parquet)?;
147 let batches: Vec<arrow::record_batch::RecordBatch> = reader
148 .collect::<std::result::Result<Vec<_>, _>>()
149 .map_err(crate::Error::Arrow)?;
150 if batches.is_empty() {
151 println!("Dataset is empty");
152 return Ok(None);
153 }
154 Ok(Some(ArrowDataset::new(batches)?))
155}
156
157fn collect_head_batches(
158 dataset: &ArrowDataset,
159 rows: usize,
160) -> (Vec<arrow::record_batch::RecordBatch>, usize) {
161 let mut collected = Vec::new();
162 let mut count = 0;
163 for batch in dataset.iter() {
164 let take = (rows - count).min(batch.num_rows());
165 if take > 0 {
166 collected.push(batch.slice(0, take));
167 count += take;
168 }
169 if count >= rows {
170 break;
171 }
172 }
173 (collected, count)
174}
175
176pub(crate) fn cmd_schema(path: &Path) -> crate::Result<()> {
178 let dataset = load_dataset(path)?;
179 let schema = dataset.schema();
180
181 println!("Schema for {}:", path.display());
182 println!();
183
184 for (i, field) in schema.fields().iter().enumerate() {
185 let nullable = if field.is_nullable() {
186 "nullable"
187 } else {
188 "not null"
189 };
190 println!(
191 " {}: {} ({}) [{}]",
192 i,
193 field.name(),
194 field.data_type(),
195 nullable
196 );
197 }
198
199 println!();
200 println!("Total columns: {}", schema.fields().len());
201
202 Ok(())
203}
204
205#[cfg(feature = "shuffle")]
/// Splits an input spec of the form `path:weight` into its path and weight.
///
/// The split happens at the last `:`. When there is no `:` or the text after
/// it does not parse as an `f64`, the whole spec is taken as the path and the
/// weight defaults to `1.0`.
fn parse_input_spec(spec: &str) -> (PathBuf, f64) {
    spec.rsplit_once(':')
        .and_then(|(path, weight_text)| {
            weight_text
                .parse::<f64>()
                .ok()
                .map(|weight| (PathBuf::from(path), weight))
        })
        .unwrap_or_else(|| (PathBuf::from(spec), 1.0))
}
216
/// Parses every input spec, loads the referenced datasets and sums their
/// weights.
///
/// Prints one "Loaded ..." line per input.
///
/// # Errors
///
/// * `invalid_config` when a weight is negative, NaN or infinite. Such
///   weights would otherwise silently distort the per-dataset fractions
///   computed from the weight total.
/// * An I/O "not found" error when an input path does not exist.
/// * Any error from `load_dataset`.
#[cfg(feature = "shuffle")]
fn load_mix_inputs(inputs: &[String]) -> crate::Result<MixInputs> {
    let mut datasets = Vec::with_capacity(inputs.len());
    let mut total_weight = 0.0;

    for spec in inputs {
        let (path, weight) = parse_input_spec(spec);

        // Validate before touching the file so a bad spec fails fast.
        if !weight.is_finite() || weight < 0.0 {
            return Err(crate::Error::invalid_config(format!(
                "Invalid weight {weight} for input '{}': weights must be finite and non-negative",
                path.display()
            )));
        }

        if !path.exists() {
            return Err(crate::Error::io(
                std::io::Error::new(std::io::ErrorKind::NotFound, "Input file not found"),
                &path,
            ));
        }

        let dataset = load_dataset(&path)?;
        println!(
            " Loaded {} ({} rows, weight={:.2})",
            path.display(),
            dataset.len(),
            weight
        );
        total_weight += weight;
        datasets.push((dataset, weight, path.display().to_string()));
    }

    Ok((datasets, total_weight))
}
243
/// Draws `rows_needed` rows from `dataset` into a single record batch.
///
/// All row indices are shuffled with `rng`. When more rows are requested than
/// exist, the shuffled indices are followed by indices cycling from row 0
/// upwards until the requested count is reached (oversampling). When fewer
/// are requested, the shuffled list is truncated.
///
/// # Errors
///
/// Returns `invalid_config` wrapping the Arrow error if concatenating the
/// batches, taking the rows, or building the result batch fails.
#[cfg(feature = "shuffle")]
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
fn sample_dataset(
    dataset: &ArrowDataset,
    rows_needed: usize,
    rng: &mut rand::rngs::StdRng,
) -> crate::Result<arrow::array::RecordBatch> {
    use rand::seq::SliceRandom;

    let available = dataset.len();
    let mut indices: Vec<usize> = (0..available).collect();
    indices.shuffle(rng);

    if rows_needed > available {
        // `cycle` over an empty range yields nothing, so an empty dataset
        // simply produces an empty sample.
        indices.extend((0..available).cycle().take(rows_needed - available));
    }
    indices.truncate(rows_needed);

    let schema = dataset.schema();
    let flat_batches: Vec<_> = dataset.iter().collect();
    let concatenated = arrow::compute::concat_batches(&schema, &flat_batches)
        .map_err(|e| crate::Error::invalid_config(format!("Arrow concat error: {e}")))?;

    // Use 64-bit indices: `usize as u32` would silently wrap for row numbers
    // above `u32::MAX` and select the wrong rows.
    let take_indices: Vec<u64> = indices.iter().map(|&i| i as u64).collect();
    let index_array = arrow::array::UInt64Array::from(take_indices);

    let columns: Vec<arrow::array::ArrayRef> = (0..concatenated.num_columns())
        .map(|col_idx| {
            arrow::compute::take(concatenated.column(col_idx), &index_array, None)
                .map_err(|e| crate::Error::invalid_config(format!("Arrow take error: {e}")))
        })
        .collect::<crate::Result<Vec<_>>>()?;

    arrow::array::RecordBatch::try_new(schema, columns)
        .map_err(|e| crate::Error::invalid_config(format!("RecordBatch error: {e}")))
}
285
/// Mixes several weighted datasets into one output file.
///
/// Each input spec is `path` or `path:weight`. Every dataset contributes
/// `target_rows * weight / total_weight` rows, rounded down, so the mixed
/// total can be slightly below the target. The target is `max_rows`, or the
/// combined size of all inputs when `max_rows` is 0. Sampling is seeded with
/// `seed`.
///
/// # Errors
///
/// Returns `invalid_config` when no inputs are given or the weights sum to
/// zero, and propagates loading, sampling and saving errors.
#[cfg(feature = "shuffle")]
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
pub(crate) fn cmd_mix(
    inputs: &[String],
    output: &Path,
    seed: u64,
    max_rows: usize,
) -> crate::Result<()> {
    use rand::{rngs::StdRng, SeedableRng};

    if inputs.is_empty() {
        return Err(crate::Error::invalid_config("No input files provided"));
    }

    let (datasets, total_weight) = load_mix_inputs(inputs)?;
    if total_weight == 0.0 {
        return Err(crate::Error::invalid_config("All weights are zero"));
    }

    let target_rows = match max_rows {
        // 0 means "no cap": use every available row as the target.
        0 => datasets.iter().map(|(d, _, _)| d.len()).sum::<usize>(),
        capped => capped,
    };

    println!(
        "\nMixing {} datasets → {} target rows",
        datasets.len(),
        target_rows
    );

    let mut rng = StdRng::seed_from_u64(seed);
    let mut sampled = Vec::with_capacity(datasets.len());

    for (dataset, weight, name) in &datasets {
        let share = weight / total_weight;
        let quota = (target_rows as f64 * share) as usize;

        let batch = sample_dataset(dataset, quota, &mut rng)?;
        println!(" {} → {} rows ({:.1}%)", name, batch.num_rows(), share * 100.0);
        sampled.push(batch);
    }

    if sampled.is_empty() {
        return Err(crate::Error::invalid_config("No data to mix"));
    }

    let mixed_rows: usize = sampled.iter().map(|batch| batch.num_rows()).sum();
    let mixed = ArrowDataset::new(sampled)?;
    save_dataset(&mixed, output)?;

    println!("\nMixed {} rows → {}", mixed_rows, output.display());
    Ok(())
}
345
/// Applies a fill-in-the-middle transform to `column` of the dataset at
/// `input` and writes the result to `output`.
///
/// `format` selects the FIM layout: `"spm"` (matched case-insensitively)
/// selects SPM, and any other value falls back to PSM. The printed summary
/// names the layout that was actually applied, not the raw argument.
///
/// # Errors
///
/// Propagates errors from loading the input, applying the transform to any
/// batch, building the output dataset, or saving it.
#[cfg(feature = "shuffle")]
pub(crate) fn cmd_fim(
    input: &Path,
    output: &Path,
    column: &str,
    rate: f64,
    format: &str,
    seed: u64,
) -> crate::Result<()> {
    use crate::transform::{Fim, FimFormat, Transform};

    let dataset = load_dataset(input)?;

    // Keep the label next to the variant so the summary cannot disagree with
    // the transform: previously "SPM" ran PSM while reporting "SPM".
    let (fim_format, format_label) = if format.eq_ignore_ascii_case("spm") {
        (FimFormat::SPM, "SPM")
    } else {
        (FimFormat::PSM, "PSM")
    };

    let fim = Fim::new(column)
        .with_rate(rate)
        .with_format(fim_format)
        .with_seed(seed);

    let all_batches = dataset
        .iter()
        .map(|batch| fim.apply(batch))
        .collect::<crate::Result<Vec<_>>>()?;

    let transformed = ArrowDataset::new(all_batches)?;
    save_dataset(&transformed, output)?;

    println!(
        "FIM transform ({} format, {:.0}% rate) applied to '{}' column",
        format_label,
        rate * 100.0,
        column
    );
    println!("{} rows → {}", dataset.len(), output.display());
    Ok(())
}
385
386pub(crate) fn cmd_dedup(input: &Path, output: &Path, column: Option<&str>) -> crate::Result<()> {
391 use crate::transform::{Transform, Unique};
392
393 let dataset = load_dataset(input)?;
394 let original_rows = dataset.len();
395
396 let dedup = match column {
398 Some(col) => Unique::by(vec![col]),
399 None => detect_text_column_dedup(&dataset),
400 };
401
402 let mut all_batches = Vec::new();
403 for batch in dataset.iter() {
404 all_batches.push(dedup.apply(batch)?);
405 }
406
407 let deduped = ArrowDataset::new(all_batches)?;
408 let deduped_rows = deduped.len();
409 save_dataset(&deduped, output)?;
410
411 let removed = original_rows - deduped_rows;
412 println!(
413 "Dedup: {} → {} rows ({} duplicates removed, {:.1}% reduction)",
414 original_rows,
415 deduped_rows,
416 removed,
417 removed as f64 / original_rows.max(1) as f64 * 100.0
418 );
419 Ok(())
420}
421
422fn detect_text_column_dedup(dataset: &ArrowDataset) -> crate::transform::Unique {
424 use arrow::datatypes::DataType;
425
426 use crate::transform::Unique;
427
428 let schema = dataset.schema();
429 for name in &["text", "content", "code", "source"] {
430 if let Some((_, field)) = schema.column_with_name(name) {
431 if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
432 return Unique::by(vec![*name]);
433 }
434 }
435 }
436 Unique::all()
438}
439
440pub(crate) fn cmd_filter_text(
445 input: &Path,
446 output: &Path,
447 column: Option<&str>,
448 min_score: f64,
449 min_length: usize,
450 max_length: usize,
451) -> crate::Result<()> {
452 use crate::transform::Transform;
453
454 let dataset = load_dataset(input)?;
455 let original_rows = dataset.len();
456
457 let col_name = column
458 .map(String::from)
459 .unwrap_or_else(|| find_text_column(&dataset));
460
461 let filter = TextQualityFilter::new(&col_name, min_score, min_length, max_length);
462
463 let mut all_batches = Vec::new();
464 for batch in dataset.iter() {
465 all_batches.push(filter.apply(batch)?);
466 }
467
468 let filtered = ArrowDataset::new(all_batches)?;
469 let kept = filtered.len();
470 save_dataset(&filtered, output)?;
471
472 let removed = original_rows - kept;
473 println!(
474 "Filter: {} → {} rows ({} removed, {:.1}% kept)",
475 original_rows,
476 kept,
477 removed,
478 kept as f64 / original_rows.max(1) as f64 * 100.0
479 );
480 println!(
481 " min_score={:.2} min_len={} max_len={} column='{}'",
482 min_score, min_length, max_length, col_name
483 );
484 Ok(())
485}
486
487fn find_text_column(dataset: &ArrowDataset) -> String {
489 use arrow::datatypes::DataType;
490 let schema = dataset.schema();
491 for name in &["text", "content", "code", "source"] {
492 if let Some((_, field)) = schema.column_with_name(name) {
493 if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
494 return (*name).to_string();
495 }
496 }
497 }
498 for field in schema.fields() {
500 if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
501 return field.name().clone();
502 }
503 }
504 "text".to_string()
505}
506
/// Row filter that keeps only rows whose text passes the quality checks.
struct TextQualityFilter {
    /// Name of the string column that is scored.
    column: String,
    /// Minimum composite score a text must reach to be kept.
    min_score: f64,
    /// Minimum accepted text length (compared against `str::len`).
    min_length: usize,
    /// Maximum accepted text length (compared against `str::len`).
    max_length: usize,
}

impl TextQualityFilter {
    /// Builds a filter over `column` with the given thresholds.
    fn new(column: &str, min_score: f64, min_length: usize, max_length: usize) -> Self {
        let column = String::from(column);
        Self {
            column,
            min_score,
            min_length,
            max_length,
        }
    }
}
525
526impl crate::transform::Transform for TextQualityFilter {
527 fn apply(&self, batch: arrow::array::RecordBatch) -> crate::Result<arrow::array::RecordBatch> {
528 use arrow::{
529 array::{Array, BooleanArray, StringArray},
530 compute::filter_record_batch,
531 };
532
533 let schema = batch.schema();
534 let col_idx = schema
535 .column_with_name(&self.column)
536 .map(|(i, _)| i)
537 .ok_or_else(|| crate::Error::column_not_found(&self.column))?;
538
539 let text_arr = batch
540 .column(col_idx)
541 .as_any()
542 .downcast_ref::<StringArray>()
543 .ok_or_else(|| crate::Error::column_not_found(&self.column))?;
544
545 let mask: BooleanArray = (0..text_arr.len())
546 .map(|i| {
547 if text_arr.is_null(i) {
548 Some(false)
549 } else {
550 let text = text_arr.value(i);
551 Some(passes_quality(
552 text,
553 self.min_score,
554 self.min_length,
555 self.max_length,
556 ))
557 }
558 })
559 .collect();
560
561 filter_record_batch(&batch, &mask).map_err(crate::Error::Arrow)
562 }
563}
564
565fn passes_quality(text: &str, min_score: f64, min_len: usize, max_len: usize) -> bool {
567 let len = text.len();
568 if len < min_len || len > max_len {
569 return false;
570 }
571 composite_score(text) >= min_score
572}
573
574fn composite_score(text: &str) -> f64 {
576 let s1 = score_alnum_ratio(text);
577 let s2 = score_line_length(text);
578 let s3 = score_dup_lines(text);
579 let s4 = score_entropy(text);
580 (s1 + s2 + s3 + s4) / 4.0
581}
582
/// Scores `text` by its share of alphanumeric characters.
///
/// The ratio is alphanumeric characters over all characters. Both counts are
/// in `char`s; dividing by the byte length would under-score text made of
/// multi-byte characters. Returns `0.0` for empty text or a ratio below 0.2,
/// the ratio itself for 0.2 up to (not including) 0.3, and `1.0` otherwise.
fn score_alnum_ratio(text: &str) -> f64 {
    // Single pass: count alphanumeric characters and all characters together.
    let (alnum, total) = text.chars().fold((0usize, 0usize), |(alnum, total), c| {
        (alnum + usize::from(c.is_alphanumeric()), total + 1)
    });
    if total == 0 {
        return 0.0;
    }

    let ratio = alnum as f64 / total as f64;
    if ratio < 0.2 {
        0.0
    } else if ratio < 0.3 {
        ratio
    } else {
        1.0
    }
}
598
/// Scores `text` by its average line length, computed as total byte length
/// divided by the number of lines.
///
/// Returns `0.0` when there are no lines, `0.2` for an average below 10,
/// `0.5` for an average above 200, and `1.0` in between.
fn score_line_length(text: &str) -> f64 {
    let line_count = text.lines().count();
    if line_count == 0 {
        return 0.0;
    }

    match text.len() as f64 / line_count as f64 {
        avg if avg < 10.0 => 0.2,
        avg if avg > 200.0 => 0.5,
        _ => 1.0,
    }
}
614
/// Scores `text` by how many of its lines are repeats.
///
/// The duplicate ratio is `1 - unique_lines / total_lines`. Text with at most
/// one line scores `1.0`; a ratio above 0.5 scores `0.2`; otherwise the score
/// is `1 - ratio`.
fn score_dup_lines(text: &str) -> f64 {
    use std::collections::HashSet;

    let mut distinct: HashSet<&str> = HashSet::new();
    let mut total = 0usize;
    for line in text.lines() {
        total += 1;
        distinct.insert(line);
    }

    if total <= 1 {
        return 1.0;
    }

    let dup_ratio = 1.0 - (distinct.len() as f64 / total as f64);
    if dup_ratio > 0.5 {
        return 0.2;
    }
    1.0 - dup_ratio
}
630
/// Scores `text` by the Shannon entropy of its bytes, measured in bits.
///
/// Returns `0.0` for empty text, `0.2` when the entropy is below 2.0 bits,
/// `0.3` when it is above 6.5 bits, and `1.0` otherwise.
fn score_entropy(text: &str) -> f64 {
    if text.is_empty() {
        return 0.0;
    }

    // Histogram of byte values.
    let mut histogram = [0u32; 256];
    text.bytes()
        .for_each(|byte| histogram[usize::from(byte)] += 1);

    let total = text.len() as f64;
    let mut nats = 0.0_f64;
    for &occurrences in &histogram {
        if occurrences > 0 {
            let p = f64::from(occurrences) / total;
            nats += -p * p.ln();
        }
    }

    // Natural-log entropy divided by ln 2 gives bits.
    let bits = nats / std::f64::consts::LN_2;
    if bits < 2.0 {
        0.2
    } else if bits > 6.5 {
        0.3
    } else {
        1.0
    }
}
658
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_precision_loss,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::redundant_clone,
    clippy::cast_lossless,
    clippy::redundant_closure_for_method_calls,
    clippy::too_many_lines,
    clippy::float_cmp,
    clippy::similar_names,
    clippy::needless_late_init,
    clippy::redundant_pattern_matching
)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;

    /// CSV content used by the tests that read hand-written CSV input.
    const CSV_ROWS: &str = "id,name\n1,foo\n2,bar\n";

    /// Newline-delimited JSON content used by the tests that read
    /// hand-written JSON input.
    const JSON_ROWS: &str = r#"{"id":1,"name":"foo"}
{"id":2,"name":"bar"}"#;

    /// Creates a temporary directory that is removed when the handle drops.
    fn scratch_dir() -> tempfile::TempDir {
        tempfile::tempdir()
            .ok()
            .unwrap_or_else(|| panic!("Should create temp dir"))
    }

    /// Writes a Parquet file with `rows` rows of (`id`, `name`) to `path`.
    fn create_test_parquet(path: &Path, rows: usize) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

        let batch = arrow::array::RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(names)),
            ],
        )
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"));

        let dataset = ArrowDataset::from_batch(batch)
            .ok()
            .unwrap_or_else(|| panic!("Should create dataset"));

        dataset
            .to_parquet(path)
            .ok()
            .unwrap_or_else(|| panic!("Should write parquet"));
    }

    /// Creates `file_name` inside `dir` as a test Parquet file and returns
    /// its path.
    fn parquet_fixture(dir: &tempfile::TempDir, file_name: &str, rows: usize) -> PathBuf {
        let path = dir.path().join(file_name);
        create_test_parquet(&path, rows);
        path
    }

    /// Builds an in-memory dataset with a single `id` column holding 1, 2, 3.
    fn three_id_dataset() -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let batch = arrow::array::RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        ArrowDataset::from_batch(batch).unwrap()
    }

    #[test]
    fn test_cmd_info() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "test.parquet", 100);

        assert!(cmd_info(&path).is_ok());
    }

    #[test]
    fn test_cmd_head() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "test.parquet", 100);

        assert!(cmd_head(&path, 5).is_ok());
    }

    #[test]
    fn test_cmd_schema() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "test.parquet", 10);

        assert!(cmd_schema(&path).is_ok());
    }

    #[test]
    fn test_cmd_convert() {
        let dir = scratch_dir();
        let input = parquet_fixture(&dir, "input.parquet", 50);
        let output = dir.path().join("output.parquet");

        assert!(cmd_convert(&input, &output).is_ok());

        let original = ArrowDataset::from_parquet(&input)
            .ok()
            .unwrap_or_else(|| panic!("Should load original"));
        let converted = ArrowDataset::from_parquet(&output)
            .ok()
            .unwrap_or_else(|| panic!("Should load converted"));

        assert_eq!(original.len(), converted.len());
    }

    #[test]
    fn test_load_dataset_unsupported() {
        let path = PathBuf::from("test.xyz");
        assert!(load_dataset(&path).is_err());
    }

    #[test]
    fn test_get_format() {
        assert_eq!(get_format(Path::new("test.parquet")), "Parquet");
        assert_eq!(get_format(Path::new("test.arrow")), "Arrow IPC");
        assert_eq!(get_format(Path::new("test.csv")), "CSV");
        assert_eq!(get_format(Path::new("test.json")), "JSON");
        assert_eq!(get_format(Path::new("test.unknown")), "Unknown");
    }

    #[test]
    fn test_cmd_head_with_more_rows_than_dataset() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "test.parquet", 5);

        assert!(cmd_head(&path, 100).is_ok());
    }

    #[test]
    fn test_cmd_convert_parquet_to_csv() {
        let dir = scratch_dir();
        let input = parquet_fixture(&dir, "input.parquet", 25);
        let output = dir.path().join("output.csv");

        assert!(cmd_convert(&input, &output).is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_cmd_convert_parquet_to_json() {
        let dir = scratch_dir();
        let input = parquet_fixture(&dir, "input.parquet", 15);
        let output = dir.path().join("output.json");

        assert!(cmd_convert(&input, &output).is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_save_dataset_unsupported_format() {
        let dir = scratch_dir();
        let input = parquet_fixture(&dir, "data.parquet", 5);
        let output = dir.path().join("output.xyz");

        let dataset = ArrowDataset::from_parquet(&input)
            .ok()
            .unwrap_or_else(|| panic!("Should load"));

        assert!(save_dataset(&dataset, &output).is_err());
    }

    #[test]
    fn test_get_format_ipc() {
        assert_eq!(get_format(Path::new("test.ipc")), "Arrow IPC");
    }

    #[test]
    fn test_get_format_jsonl() {
        assert_eq!(get_format(Path::new("test.jsonl")), "JSON");
    }

    #[test]
    fn test_get_format_no_extension() {
        assert_eq!(get_format(Path::new("testfile")), "Unknown");
    }

    #[test]
    fn test_cmd_convert_unsupported_output() {
        let dir = scratch_dir();
        let input = parquet_fixture(&dir, "input.parquet", 10);
        let output = dir.path().join("output.xyz");

        assert!(cmd_convert(&input, &output).is_err());
    }

    #[test]
    fn test_load_dataset_xyz_format() {
        let dir = scratch_dir();
        let path = dir.path().join("data.xyz");

        std::fs::write(&path, "some data")
            .ok()
            .unwrap_or_else(|| panic!("Should write file"));

        assert!(load_dataset(&path).is_err());
    }

    #[test]
    fn test_get_format_arrow() {
        assert_eq!(get_format(Path::new("test.arrow")), "Arrow IPC");
    }

    #[test]
    fn test_get_format_unknown() {
        assert_eq!(get_format(Path::new("test.feather")), "Unknown");
        assert_eq!(get_format(Path::new("test.txt")), "Unknown");
    }

    #[test]
    fn test_load_dataset_csv() {
        let dir = scratch_dir();
        let parquet_path = parquet_fixture(&dir, "data.parquet", 10);
        let csv_path = dir.path().join("data.csv");

        ArrowDataset::from_parquet(&parquet_path)
            .ok()
            .unwrap_or_else(|| panic!("Should load"))
            .to_csv(&csv_path)
            .ok()
            .unwrap_or_else(|| panic!("Should write csv"));

        assert!(load_dataset(&csv_path).is_ok());
    }

    #[test]
    fn test_load_dataset_json() {
        let dir = scratch_dir();
        let parquet_path = parquet_fixture(&dir, "data.parquet", 10);
        let json_path = dir.path().join("data.json");

        ArrowDataset::from_parquet(&parquet_path)
            .ok()
            .unwrap_or_else(|| panic!("Should load"))
            .to_json(&json_path)
            .ok()
            .unwrap_or_else(|| panic!("Should write json"));

        assert!(load_dataset(&json_path).is_ok());
    }

    #[test]
    fn test_load_dataset_jsonl() {
        let dir = scratch_dir();
        let parquet_path = parquet_fixture(&dir, "data.parquet", 10);
        let jsonl_path = dir.path().join("data.jsonl");

        ArrowDataset::from_parquet(&parquet_path)
            .ok()
            .unwrap_or_else(|| panic!("Should load"))
            .to_json(&jsonl_path)
            .ok()
            .unwrap_or_else(|| panic!("Should write jsonl"));

        assert!(load_dataset(&jsonl_path).is_ok());
    }

    #[test]
    fn test_save_dataset_parquet() {
        let dir = scratch_dir();
        let path = dir.path().join("output.parquet");

        assert!(save_dataset(&three_id_dataset(), &path).is_ok());
    }

    #[test]
    fn test_save_dataset_csv() {
        let dir = scratch_dir();
        let path = dir.path().join("output.csv");

        assert!(save_dataset(&three_id_dataset(), &path).is_ok());
    }

    #[test]
    fn test_save_dataset_json() {
        let dir = scratch_dir();
        let path = dir.path().join("output.json");

        assert!(save_dataset(&three_id_dataset(), &path).is_ok());
    }

    #[test]
    fn test_save_dataset_unknown_extension() {
        let dir = scratch_dir();
        let path = dir.path().join("output.xyz");

        assert!(save_dataset(&three_id_dataset(), &path).is_err());
    }

    #[test]
    fn test_cmd_convert_to_csv_format() {
        let dir = scratch_dir();
        let input = parquet_fixture(&dir, "input.parquet", 20);
        let output = dir.path().join("output.csv");

        assert!(cmd_convert(&input, &output).is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_cmd_convert_to_json_format() {
        let dir = scratch_dir();
        let input = parquet_fixture(&dir, "input.parquet", 20);
        let output = dir.path().join("output.json");

        assert!(cmd_convert(&input, &output).is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_cmd_head_more_than_available() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "small.parquet", 5);

        assert!(cmd_head(&path, 100).is_ok());
    }

    #[test]
    fn test_load_dataset_csv_file() {
        let dir = scratch_dir();
        let csv_path = dir.path().join("test.csv");

        std::fs::write(&csv_path, CSV_ROWS).unwrap();

        assert!(load_dataset(&csv_path).is_ok());
    }

    #[test]
    fn test_load_dataset_json_file() {
        let dir = scratch_dir();
        let json_path = dir.path().join("test.json");

        std::fs::write(&json_path, JSON_ROWS).unwrap();

        assert!(load_dataset(&json_path).is_ok());
    }

    #[test]
    fn test_cmd_head_zero_rows() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "test.parquet", 50);

        assert!(cmd_head(&path, 0).is_ok());
    }

    #[test]
    fn test_cmd_info_small_file() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "small.parquet", 5);

        assert!(cmd_info(&path).is_ok());
    }

    #[test]
    fn test_cmd_info_large_file() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "large.parquet", 1000);

        assert!(cmd_info(&path).is_ok());
    }

    #[test]
    fn test_cmd_schema_complex() {
        let dir = scratch_dir();
        let path = dir.path().join("complex.parquet");

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));

        let batch = arrow::array::RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(arrow::array::Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();

        ArrowDataset::from_batch(batch)
            .unwrap()
            .to_parquet(&path)
            .unwrap();

        assert!(cmd_schema(&path).is_ok());
    }

    #[test]
    fn test_cmd_convert_csv_to_parquet() {
        let dir = scratch_dir();
        let csv_path = dir.path().join("input.csv");
        let parquet_path = dir.path().join("output.parquet");

        std::fs::write(&csv_path, CSV_ROWS).unwrap();

        assert!(cmd_convert(&csv_path, &parquet_path).is_ok());
        assert!(parquet_path.exists());
    }

    #[test]
    fn test_cmd_convert_json_to_parquet() {
        let dir = scratch_dir();
        let json_path = dir.path().join("input.json");
        let parquet_path = dir.path().join("output.parquet");

        std::fs::write(&json_path, JSON_ROWS).unwrap();

        assert!(cmd_convert(&json_path, &parquet_path).is_ok());
        assert!(parquet_path.exists());
    }

    #[test]
    fn test_save_dataset_jsonl() {
        let dir = scratch_dir();
        let path = dir.path().join("output.jsonl");

        assert!(save_dataset(&three_id_dataset(), &path).is_ok());
    }

    #[test]
    fn test_load_dataset_no_extension() {
        let path = PathBuf::from("file_without_extension");
        assert!(load_dataset(&path).is_err());
    }

    #[test]
    fn test_cmd_head_exact_rows() {
        let dir = scratch_dir();
        let path = parquet_fixture(&dir, "exact.parquet", 10);

        assert!(cmd_head(&path, 10).is_ok());
    }

    #[test]
    fn test_cmd_convert_parquet_to_parquet() {
        let dir = scratch_dir();
        let input = parquet_fixture(&dir, "input.parquet", 20);
        let output = dir.path().join("output.parquet");

        assert!(cmd_convert(&input, &output).is_ok());

        let original = ArrowDataset::from_parquet(&input).unwrap();
        let converted = ArrowDataset::from_parquet(&output).unwrap();
        assert_eq!(original.len(), converted.len());
    }

    #[test]
    fn test_get_format_all_types() {
        let expectations = [
            ("data.parquet", "Parquet"),
            ("data.arrow", "Arrow IPC"),
            ("data.ipc", "Arrow IPC"),
            ("data.csv", "CSV"),
            ("data.json", "JSON"),
            ("data.jsonl", "JSON"),
            ("data.txt", "Unknown"),
            ("data.yaml", "Unknown"),
            ("data", "Unknown"),
        ];
        for (file_name, label) in expectations {
            assert_eq!(get_format(Path::new(file_name)), label);
        }
    }
}