1use std::path::{Path, PathBuf};
4
5use arrow::util::pretty::print_batches;
6
7use crate::{ArrowDataset, Dataset};
8
#[cfg(feature = "shuffle")]
// Loaded mix inputs: per-dataset `(dataset, sampling weight, display name)`
// triples, plus the sum of all weights (used to normalize fractions).
type MixInputs = (Vec<(ArrowDataset, f64, String)>, f64);
13
14pub(crate) fn load_dataset(path: &Path) -> crate::Result<ArrowDataset> {
16 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
17
18 match ext {
19 "parquet" => ArrowDataset::from_parquet(path),
20 "csv" => ArrowDataset::from_csv(path),
21 "json" | "jsonl" => ArrowDataset::from_json(path),
22 ext => Err(crate::Error::unsupported_format(ext)),
23 }
24}
25
26pub(crate) fn save_dataset(dataset: &ArrowDataset, path: &Path) -> crate::Result<()> {
28 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
29
30 match ext {
31 "parquet" => dataset.to_parquet(path),
32 "csv" => dataset.to_csv(path),
33 "json" | "jsonl" => dataset.to_json(path),
34 ext => Err(crate::Error::unsupported_format(ext)),
35 }
36}
37
/// Human-readable format name for a file path, derived from its extension.
///
/// Returns `"Unknown"` for unrecognized or missing extensions.
pub(crate) fn get_format(path: &Path) -> &'static str {
    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
    match ext {
        "parquet" => "Parquet",
        "arrow" | "ipc" => "Arrow IPC",
        "csv" => "CSV",
        "json" | "jsonl" => "JSON",
        _ => "Unknown",
    }
}
48
49pub(crate) fn cmd_convert(input: &Path, output: &Path) -> crate::Result<()> {
51 let dataset = load_dataset(input)?;
53
54 save_dataset(&dataset, output)?;
56
57 println!(
58 "Converted {} -> {} ({} rows)",
59 input.display(),
60 output.display(),
61 dataset.len()
62 );
63
64 Ok(())
65}
66
67pub(crate) fn cmd_info(path: &Path) -> crate::Result<()> {
72 let file_size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
73 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
74
75 println!("File: {}", path.display());
76 println!("Format: {}", get_format(path));
77
78 if ext == "parquet" {
79 let file = std::fs::File::open(path).map_err(|e| crate::Error::io(e, path))?;
81 let builder = parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file)
82 .map_err(crate::Error::Parquet)?;
83 let metadata = builder.metadata();
84 let num_rows: i64 = metadata.row_groups().iter().map(|rg| rg.num_rows()).sum();
85 let num_batches = metadata.num_row_groups();
86 let num_columns = metadata
87 .row_groups()
88 .first()
89 .map_or(0, |rg| rg.num_columns());
90 println!("Rows: {num_rows}");
91 println!("Batches: {num_batches}");
92 println!("Columns: {num_columns}");
93 } else {
94 let dataset = load_dataset(path)?;
95 println!("Rows: {}", dataset.len());
96 println!("Batches: {}", dataset.num_batches());
97 println!("Columns: {}", dataset.schema().fields().len());
98 }
99
100 println!("Size: {file_size} bytes");
101
102 Ok(())
103}
104
/// Print the first `rows` rows of the dataset at `path` as a pretty table.
///
/// Parquet inputs are read with a row limit pushed into the reader so only
/// the requested rows are decoded; other formats are loaded in full and then
/// sliced in memory.
pub(crate) fn cmd_head(path: &Path, rows: usize) -> crate::Result<()> {
    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
    let dataset = if ext == "parquet" {
        // Fast path: limit the parquet reader so at most `rows` rows are
        // materialized from disk.
        let file = std::fs::File::open(path).map_err(|e| crate::Error::io(e, path))?;
        let builder = parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file)
            .map_err(crate::Error::Parquet)?;
        let reader = builder
            .with_limit(rows)
            .build()
            .map_err(crate::Error::Parquet)?;
        let batches: Vec<arrow::record_batch::RecordBatch> = reader
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(crate::Error::Arrow)?;
        if batches.is_empty() {
            // Also reached when `rows == 0`, since the limit yields nothing.
            println!("Dataset is empty");
            return Ok(());
        }
        ArrowDataset::new(batches)?
    } else {
        load_dataset(path)?
    };

    if dataset.is_empty() {
        println!("Dataset is empty");
        return Ok(());
    }

    // Accumulate up to `rows` rows, slicing the final batch if it overshoots.
    let mut collected = Vec::new();
    let mut count = 0;

    for batch in dataset.iter() {
        let take = (rows - count).min(batch.num_rows());
        if take > 0 {
            collected.push(batch.slice(0, take));
            count += take;
        }
        if count >= rows {
            break;
        }
    }

    if collected.is_empty() {
        println!("No data to display");
        return Ok(());
    }

    print_batches(&collected).map_err(crate::Error::Arrow)?;

    // NOTE(review): for parquet inputs `dataset` was already truncated to
    // `rows` above, so this trailer only ever prints for non-parquet formats.
    if count < dataset.len() {
        println!("... showing {} of {} rows", count, dataset.len());
    }

    Ok(())
}
164
165pub(crate) fn cmd_schema(path: &Path) -> crate::Result<()> {
167 let dataset = load_dataset(path)?;
168 let schema = dataset.schema();
169
170 println!("Schema for {}:", path.display());
171 println!();
172
173 for (i, field) in schema.fields().iter().enumerate() {
174 let nullable = if field.is_nullable() {
175 "nullable"
176 } else {
177 "not null"
178 };
179 println!(
180 " {}: {} ({}) [{}]",
181 i,
182 field.name(),
183 field.data_type(),
184 nullable
185 );
186 }
187
188 println!();
189 println!("Total columns: {}", schema.fields().len());
190
191 Ok(())
192}
193
194#[cfg(feature = "shuffle")]
/// Parse an input spec of the form `path[:weight]`.
///
/// The candidate weight is the text after the *last* `:`; when it parses as
/// an `f64` it is used, otherwise the whole spec is treated as a plain path
/// with the default weight `1.0` (this keeps paths containing `:` working).
fn parse_input_spec(spec: &str) -> (PathBuf, f64) {
    match spec.rsplit_once(':') {
        Some((path, raw_weight)) => match raw_weight.parse::<f64>() {
            Ok(weight) => (PathBuf::from(path), weight),
            Err(_) => (PathBuf::from(spec), 1.0),
        },
        None => (PathBuf::from(spec), 1.0),
    }
}
205
206#[cfg(feature = "shuffle")]
208fn load_mix_inputs(inputs: &[String]) -> crate::Result<MixInputs> {
209 let mut datasets = Vec::new();
210 let mut total_weight = 0.0;
211
212 for spec in inputs {
213 let (path, weight) = parse_input_spec(spec);
214 if !path.exists() {
215 return Err(crate::Error::io(
216 std::io::Error::new(std::io::ErrorKind::NotFound, "Input file not found"),
217 &path,
218 ));
219 }
220 let dataset = load_dataset(&path)?;
221 println!(
222 " Loaded {} ({} rows, weight={:.2})",
223 path.display(),
224 dataset.len(),
225 weight
226 );
227 total_weight += weight;
228 datasets.push((dataset, weight, path.display().to_string()));
229 }
230 Ok((datasets, total_weight))
231}
232
233#[cfg(feature = "shuffle")]
235#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
/// Draw `rows_needed` rows from `dataset` in a seeded-random order.
///
/// When `rows_needed` exceeds the dataset size, indices are topped up by
/// cycling through `0..available` again, so rows are reused. The entire
/// dataset is concatenated into one batch before the `take`, which holds a
/// full contiguous copy of the data in memory.
fn sample_dataset(
    dataset: &ArrowDataset,
    rows_needed: usize,
    rng: &mut rand::rngs::StdRng,
) -> crate::Result<arrow::array::RecordBatch> {
    use rand::seq::SliceRandom;

    // Shuffle every row index once.
    let available = dataset.len();
    let mut indices: Vec<usize> = (0..available).collect();
    indices.shuffle(rng);

    // Oversampling: append cycled indices until `rows_needed` is covered.
    if rows_needed > available {
        let extra: Vec<usize> = (0..available)
            .cycle()
            .take(rows_needed - available)
            .collect();
        indices.extend(extra);
    }
    indices.truncate(rows_needed);

    // Flatten to a single batch so one `take` kernel can gather the sample.
    let schema = dataset.schema();
    let flat_batches: Vec<_> = dataset.iter().collect();
    let concatenated = arrow::compute::concat_batches(&schema, &flat_batches)
        .map_err(|e| crate::Error::invalid_config(format!("Arrow concat error: {e}")))?;

    // usize -> u32 narrowing is covered by the allow(cast_possible_truncation)
    // attribute on this function.
    let take_indices: Vec<u32> = indices.iter().map(|&i| i as u32).collect();
    let index_array = arrow::array::UInt32Array::from(take_indices);

    // Gather sampled rows column by column.
    let columns: Vec<arrow::array::ArrayRef> = (0..concatenated.num_columns())
        .map(|col_idx| {
            arrow::compute::take(concatenated.column(col_idx), &index_array, None)
                .map_err(|e| crate::Error::invalid_config(format!("Arrow take error: {e}")))
        })
        .collect::<crate::Result<Vec<_>>>()?;

    arrow::array::RecordBatch::try_new(schema, columns)
        .map_err(|e| crate::Error::invalid_config(format!("RecordBatch error: {e}")))
}
274
275#[cfg(feature = "shuffle")]
277#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
/// Mix several datasets into one output file using weighted sampling.
///
/// Each input spec is `path[:weight]`. Rows are drawn from every dataset in
/// proportion to its weight relative to the weight sum, using an RNG seeded
/// with `seed` for reproducibility. `max_rows == 0` means "target the total
/// number of available rows".
pub(crate) fn cmd_mix(
    inputs: &[String],
    output: &Path,
    seed: u64,
    max_rows: usize,
) -> crate::Result<()> {
    use rand::{rngs::StdRng, SeedableRng};

    if inputs.is_empty() {
        return Err(crate::Error::invalid_config("No input files provided"));
    }

    let (datasets, total_weight) = load_mix_inputs(inputs)?;
    if total_weight == 0.0 {
        return Err(crate::Error::invalid_config("All weights are zero"));
    }

    let total_available: usize = datasets.iter().map(|(d, _, _)| d.len()).sum();
    let target_rows = if max_rows > 0 {
        max_rows
    } else {
        total_available
    };

    println!(
        "\nMixing {} datasets → {} target rows",
        datasets.len(),
        target_rows
    );

    let mut rng = StdRng::seed_from_u64(seed);
    let mut all_batches = Vec::new();
    let mut mixed_rows = 0;

    for (dataset, weight, name) in &datasets {
        let fraction = weight / total_weight;
        // Truncating cast: the summed output can fall slightly short of
        // `target_rows` when the per-dataset shares are fractional.
        let rows_for_dataset = (target_rows as f64 * fraction) as usize;

        let batch = sample_dataset(dataset, rows_for_dataset, &mut rng)?;
        let count = batch.num_rows();
        all_batches.push(batch);
        mixed_rows += count;

        println!(" {} → {} rows ({:.1}%)", name, count, fraction * 100.0);
    }

    if all_batches.is_empty() {
        return Err(crate::Error::invalid_config("No data to mix"));
    }

    let mixed = ArrowDataset::new(all_batches)?;
    save_dataset(&mixed, output)?;

    println!("\nMixed {} rows → {}", mixed_rows, output.display());
    Ok(())
}
334
335#[cfg(feature = "shuffle")]
336pub(crate) fn cmd_fim(
337 input: &Path,
338 output: &Path,
339 column: &str,
340 rate: f64,
341 format: &str,
342 seed: u64,
343) -> crate::Result<()> {
344 use crate::transform::{Fim, FimFormat, Transform};
345
346 let dataset = load_dataset(input)?;
347 let fim_format = match format {
348 "spm" => FimFormat::SPM,
349 _ => FimFormat::PSM,
350 };
351
352 let fim = Fim::new(column)
353 .with_rate(rate)
354 .with_format(fim_format)
355 .with_seed(seed);
356
357 let mut all_batches = Vec::new();
358 for batch in dataset.iter() {
359 all_batches.push(fim.apply(batch)?);
360 }
361
362 let transformed = ArrowDataset::new(all_batches)?;
363 save_dataset(&transformed, output)?;
364
365 println!(
366 "FIM transform ({} format, {:.0}% rate) applied to '{}' column",
367 format.to_uppercase(),
368 rate * 100.0,
369 column
370 );
371 println!("{} rows → {}", dataset.len(), output.display());
372 Ok(())
373}
374
375pub(crate) fn cmd_dedup(input: &Path, output: &Path, column: Option<&str>) -> crate::Result<()> {
380 use crate::transform::{Transform, Unique};
381
382 let dataset = load_dataset(input)?;
383 let original_rows = dataset.len();
384
385 let dedup = match column {
387 Some(col) => Unique::by(vec![col]),
388 None => detect_text_column_dedup(&dataset),
389 };
390
391 let mut all_batches = Vec::new();
392 for batch in dataset.iter() {
393 all_batches.push(dedup.apply(batch)?);
394 }
395
396 let deduped = ArrowDataset::new(all_batches)?;
397 let deduped_rows = deduped.len();
398 save_dataset(&deduped, output)?;
399
400 let removed = original_rows - deduped_rows;
401 println!(
402 "Dedup: {} → {} rows ({} duplicates removed, {:.1}% reduction)",
403 original_rows,
404 deduped_rows,
405 removed,
406 removed as f64 / original_rows.max(1) as f64 * 100.0
407 );
408 Ok(())
409}
410
411fn detect_text_column_dedup(dataset: &ArrowDataset) -> crate::transform::Unique {
413 use arrow::datatypes::DataType;
414
415 use crate::transform::Unique;
416
417 let schema = dataset.schema();
418 for name in &["text", "content", "code", "source"] {
419 if let Some((_, field)) = schema.column_with_name(name) {
420 if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
421 return Unique::by(vec![*name]);
422 }
423 }
424 }
425 Unique::all()
427}
428
429pub(crate) fn cmd_filter_text(
434 input: &Path,
435 output: &Path,
436 column: Option<&str>,
437 min_score: f64,
438 min_length: usize,
439 max_length: usize,
440) -> crate::Result<()> {
441 use crate::transform::Transform;
442
443 let dataset = load_dataset(input)?;
444 let original_rows = dataset.len();
445
446 let col_name = column
447 .map(String::from)
448 .unwrap_or_else(|| find_text_column(&dataset));
449
450 let filter = TextQualityFilter::new(&col_name, min_score, min_length, max_length);
451
452 let mut all_batches = Vec::new();
453 for batch in dataset.iter() {
454 all_batches.push(filter.apply(batch)?);
455 }
456
457 let filtered = ArrowDataset::new(all_batches)?;
458 let kept = filtered.len();
459 save_dataset(&filtered, output)?;
460
461 let removed = original_rows - kept;
462 println!(
463 "Filter: {} → {} rows ({} removed, {:.1}% kept)",
464 original_rows,
465 kept,
466 removed,
467 kept as f64 / original_rows.max(1) as f64 * 100.0
468 );
469 println!(
470 " min_score={:.2} min_len={} max_len={} column='{}'",
471 min_score, min_length, max_length, col_name
472 );
473 Ok(())
474}
475
476fn find_text_column(dataset: &ArrowDataset) -> String {
478 use arrow::datatypes::DataType;
479 let schema = dataset.schema();
480 for name in &["text", "content", "code", "source"] {
481 if let Some((_, field)) = schema.column_with_name(name) {
482 if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
483 return (*name).to_string();
484 }
485 }
486 }
487 for field in schema.fields() {
489 if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
490 return field.name().clone();
491 }
492 }
493 "text".to_string()
494}
495
/// Row filter keeping strings that pass byte-length bounds and a composite
/// quality score (average of the `score_*` heuristics; see
/// `composite_score`).
struct TextQualityFilter {
    column: String,     // name of the string column to score
    min_score: f64,     // minimum composite score to keep a row
    min_length: usize,  // minimum text length in bytes (inclusive)
    max_length: usize,  // maximum text length in bytes (inclusive)
}

impl TextQualityFilter {
    /// Build a filter over `column` with the given score and length bounds.
    fn new(column: &str, min_score: f64, min_length: usize, max_length: usize) -> Self {
        Self {
            column: column.to_string(),
            min_score,
            min_length,
            max_length,
        }
    }
}
514
515impl crate::transform::Transform for TextQualityFilter {
516 fn apply(&self, batch: arrow::array::RecordBatch) -> crate::Result<arrow::array::RecordBatch> {
517 use arrow::{
518 array::{Array, BooleanArray, StringArray},
519 compute::filter_record_batch,
520 };
521
522 let schema = batch.schema();
523 let col_idx = schema
524 .column_with_name(&self.column)
525 .map(|(i, _)| i)
526 .ok_or_else(|| crate::Error::column_not_found(&self.column))?;
527
528 let text_arr = batch
529 .column(col_idx)
530 .as_any()
531 .downcast_ref::<StringArray>()
532 .ok_or_else(|| crate::Error::column_not_found(&self.column))?;
533
534 let mask: BooleanArray = (0..text_arr.len())
535 .map(|i| {
536 if text_arr.is_null(i) {
537 Some(false)
538 } else {
539 let text = text_arr.value(i);
540 Some(passes_quality(
541 text,
542 self.min_score,
543 self.min_length,
544 self.max_length,
545 ))
546 }
547 })
548 .collect();
549
550 filter_record_batch(&batch, &mask).map_err(crate::Error::Arrow)
551 }
552}
553
554fn passes_quality(text: &str, min_score: f64, min_len: usize, max_len: usize) -> bool {
556 let len = text.len();
557 if len < min_len || len > max_len {
558 return false;
559 }
560 composite_score(text) >= min_score
561}
562
563fn composite_score(text: &str) -> f64 {
565 let s1 = score_alnum_ratio(text);
566 let s2 = score_line_length(text);
567 let s3 = score_dup_lines(text);
568 let s4 = score_entropy(text);
569 (s1 + s2 + s3 + s4) / 4.0
570}
571
/// Score the fraction of alphanumeric characters in `text`.
///
/// Both numerator and denominator are counted in Unicode scalar values.
/// (The previous version divided a char count by the byte length, which
/// under-scored any non-ASCII text.) Ratios below 0.2 score 0.0, ratios of
/// 0.3 and above score 1.0, and the ratio itself is returned in between.
fn score_alnum_ratio(text: &str) -> f64 {
    let total = text.chars().count();
    if total == 0 {
        return 0.0;
    }
    let alnum = text.chars().filter(|c| c.is_alphanumeric()).count();
    let ratio = alnum as f64 / total as f64;
    if ratio < 0.2 {
        0.0
    } else if ratio < 0.3 {
        ratio
    } else {
        1.0
    }
}
587
/// Score the average line length (bytes of the whole text per line).
///
/// Very short lines (< 10) suggest boilerplate/noise (0.2); very long lines
/// (> 200) suggest minified or data-like content (0.5); otherwise 1.0.
/// Empty text scores 0.0.
fn score_line_length(text: &str) -> f64 {
    let line_count = text.lines().count();
    if line_count == 0 {
        return 0.0;
    }
    let avg = text.len() as f64 / line_count as f64;
    match avg {
        a if a < 10.0 => 0.2,
        a if a > 200.0 => 0.5,
        _ => 1.0,
    }
}
603
/// Penalize texts where a large share of lines are exact duplicates.
///
/// Zero or one line scores 1.0. A duplicate ratio above 0.5 scores a flat
/// 0.2; otherwise the score is `1.0 - dup_ratio`.
fn score_dup_lines(text: &str) -> f64 {
    use std::collections::HashSet;

    let total = text.lines().count();
    if total <= 1 {
        return 1.0;
    }
    let distinct = text.lines().collect::<HashSet<_>>().len();
    let dup_ratio = 1.0 - distinct as f64 / total as f64;
    if dup_ratio > 0.5 {
        0.2
    } else {
        1.0 - dup_ratio
    }
}
619
/// Score the byte-level Shannon entropy of `text`, in bits.
///
/// Below 2 bits the text is considered repetitive (0.2); above 6.5 bits it
/// looks binary-like/random (0.3); anything in between scores 1.0. Empty
/// text scores 0.0.
fn score_entropy(text: &str) -> f64 {
    if text.is_empty() {
        return 0.0;
    }

    // Byte histogram over the raw UTF-8 bytes.
    let mut counts = [0u32; 256];
    for &byte in text.as_bytes() {
        counts[usize::from(byte)] += 1;
    }

    let len = text.len() as f64;
    let nats: f64 = counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = f64::from(c) / len;
            -p * p.ln()
        })
        .sum();
    // Convert natural-log entropy to bits.
    let bits = nats / std::f64::consts::LN_2;

    if bits < 2.0 {
        0.2
    } else if bits > 6.5 {
        0.3
    } else {
        1.0
    }
}
647
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_precision_loss,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::redundant_clone,
    clippy::cast_lossless,
    clippy::redundant_closure_for_method_calls,
    clippy::too_many_lines,
    clippy::float_cmp,
    clippy::similar_names,
    clippy::needless_late_init,
    clippy::redundant_pattern_matching
)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;

    // NOTE: the previous `.ok().unwrap_or_else(|| panic!(...))` pattern threw
    // the underlying error away before panicking; `.expect(...)` keeps the
    // error in the panic message, which makes test failures diagnosable.

    /// Create a temp dir, panicking with the underlying error on failure.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("should create temp dir")
    }

    /// Build a minimal single-column (`id: Int32`) dataset with rows 1..=3.
    fn tiny_dataset() -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let batch = arrow::array::RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .expect("should create batch");
        ArrowDataset::from_batch(batch).expect("should create dataset")
    }

    /// Write a two-column (`id: Int32`, `name: Utf8`) parquet file with
    /// `rows` rows at `path`.
    fn create_test_parquet(path: &Path, rows: usize) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

        let batch = arrow::array::RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(names)),
            ],
        )
        .expect("should create batch");

        ArrowDataset::from_batch(batch)
            .expect("should create dataset")
            .to_parquet(path)
            .expect("should write parquet");
    }

    #[test]
    fn test_cmd_info() {
        let temp = temp_dir();
        let path = temp.path().join("test.parquet");
        create_test_parquet(&path, 100);

        assert!(cmd_info(&path).is_ok());
    }

    #[test]
    fn test_cmd_head() {
        let temp = temp_dir();
        let path = temp.path().join("test.parquet");
        create_test_parquet(&path, 100);

        assert!(cmd_head(&path, 5).is_ok());
    }

    #[test]
    fn test_cmd_schema() {
        let temp = temp_dir();
        let path = temp.path().join("test.parquet");
        create_test_parquet(&path, 10);

        assert!(cmd_schema(&path).is_ok());
    }

    #[test]
    fn test_cmd_convert() {
        let temp = temp_dir();
        let input = temp.path().join("input.parquet");
        let output = temp.path().join("output.parquet");
        create_test_parquet(&input, 50);

        assert!(cmd_convert(&input, &output).is_ok());

        let original = ArrowDataset::from_parquet(&input).expect("should load original");
        let converted = ArrowDataset::from_parquet(&output).expect("should load converted");
        assert_eq!(original.len(), converted.len());
    }

    #[test]
    fn test_load_dataset_unsupported() {
        let path = PathBuf::from("test.xyz");
        assert!(load_dataset(&path).is_err());
    }

    #[test]
    fn test_get_format() {
        assert_eq!(get_format(Path::new("test.parquet")), "Parquet");
        assert_eq!(get_format(Path::new("test.arrow")), "Arrow IPC");
        assert_eq!(get_format(Path::new("test.csv")), "CSV");
        assert_eq!(get_format(Path::new("test.json")), "JSON");
        assert_eq!(get_format(Path::new("test.unknown")), "Unknown");
    }

    #[test]
    fn test_cmd_head_with_more_rows_than_dataset() {
        let temp = temp_dir();
        let path = temp.path().join("test.parquet");
        create_test_parquet(&path, 5);

        assert!(cmd_head(&path, 100).is_ok());
    }

    #[test]
    fn test_cmd_convert_parquet_to_csv() {
        let temp = temp_dir();
        let input = temp.path().join("input.parquet");
        let output = temp.path().join("output.csv");
        create_test_parquet(&input, 25);

        assert!(cmd_convert(&input, &output).is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_cmd_convert_parquet_to_json() {
        let temp = temp_dir();
        let input = temp.path().join("input.parquet");
        let output = temp.path().join("output.json");
        create_test_parquet(&input, 15);

        assert!(cmd_convert(&input, &output).is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_save_dataset_unsupported_format() {
        let temp = temp_dir();
        let input = temp.path().join("data.parquet");
        let output = temp.path().join("output.xyz");
        create_test_parquet(&input, 5);

        let dataset = ArrowDataset::from_parquet(&input).expect("should load");
        assert!(save_dataset(&dataset, &output).is_err());
    }

    #[test]
    fn test_get_format_ipc() {
        assert_eq!(get_format(Path::new("test.ipc")), "Arrow IPC");
    }

    #[test]
    fn test_get_format_jsonl() {
        assert_eq!(get_format(Path::new("test.jsonl")), "JSON");
    }

    #[test]
    fn test_get_format_no_extension() {
        assert_eq!(get_format(Path::new("testfile")), "Unknown");
    }

    #[test]
    fn test_cmd_convert_unsupported_output() {
        let temp = temp_dir();
        let input = temp.path().join("input.parquet");
        let output = temp.path().join("output.xyz");
        create_test_parquet(&input, 10);

        assert!(cmd_convert(&input, &output).is_err());
    }

    #[test]
    fn test_load_dataset_xyz_format() {
        let temp = temp_dir();
        let path = temp.path().join("data.xyz");

        std::fs::write(&path, "some data").expect("should write file");

        assert!(load_dataset(&path).is_err());
    }

    #[test]
    fn test_get_format_arrow() {
        assert_eq!(get_format(Path::new("test.arrow")), "Arrow IPC");
    }

    #[test]
    fn test_get_format_unknown() {
        assert_eq!(get_format(Path::new("test.feather")), "Unknown");
        assert_eq!(get_format(Path::new("test.txt")), "Unknown");
    }

    #[test]
    fn test_load_dataset_csv() {
        let temp = temp_dir();
        let parquet_path = temp.path().join("data.parquet");
        let csv_path = temp.path().join("data.csv");

        create_test_parquet(&parquet_path, 10);

        let dataset = ArrowDataset::from_parquet(&parquet_path).expect("should load");
        dataset.to_csv(&csv_path).expect("should write csv");

        assert!(load_dataset(&csv_path).is_ok());
    }

    #[test]
    fn test_load_dataset_json() {
        let temp = temp_dir();
        let parquet_path = temp.path().join("data.parquet");
        let json_path = temp.path().join("data.json");

        create_test_parquet(&parquet_path, 10);

        let dataset = ArrowDataset::from_parquet(&parquet_path).expect("should load");
        dataset.to_json(&json_path).expect("should write json");

        assert!(load_dataset(&json_path).is_ok());
    }

    #[test]
    fn test_load_dataset_jsonl() {
        let temp = temp_dir();
        let parquet_path = temp.path().join("data.parquet");
        let jsonl_path = temp.path().join("data.jsonl");

        create_test_parquet(&parquet_path, 10);

        let dataset = ArrowDataset::from_parquet(&parquet_path).expect("should load");
        dataset.to_json(&jsonl_path).expect("should write jsonl");

        assert!(load_dataset(&jsonl_path).is_ok());
    }

    #[test]
    fn test_save_dataset_parquet() {
        let temp = temp_dir();
        let path = temp.path().join("output.parquet");

        assert!(save_dataset(&tiny_dataset(), &path).is_ok());
    }

    #[test]
    fn test_save_dataset_csv() {
        let temp = temp_dir();
        let path = temp.path().join("output.csv");

        assert!(save_dataset(&tiny_dataset(), &path).is_ok());
    }

    #[test]
    fn test_save_dataset_json() {
        let temp = temp_dir();
        let path = temp.path().join("output.json");

        assert!(save_dataset(&tiny_dataset(), &path).is_ok());
    }

    #[test]
    fn test_save_dataset_unknown_extension() {
        let temp = temp_dir();
        let path = temp.path().join("output.xyz");

        assert!(save_dataset(&tiny_dataset(), &path).is_err());
    }

    #[test]
    fn test_cmd_convert_to_csv_format() {
        let temp = temp_dir();
        let input = temp.path().join("input.parquet");
        let output = temp.path().join("output.csv");
        create_test_parquet(&input, 20);

        assert!(cmd_convert(&input, &output).is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_cmd_convert_to_json_format() {
        let temp = temp_dir();
        let input = temp.path().join("input.parquet");
        let output = temp.path().join("output.json");
        create_test_parquet(&input, 20);

        assert!(cmd_convert(&input, &output).is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_cmd_head_more_than_available() {
        let temp = temp_dir();
        let path = temp.path().join("small.parquet");
        create_test_parquet(&path, 5);

        assert!(cmd_head(&path, 100).is_ok());
    }

    #[test]
    fn test_load_dataset_csv_file() {
        let temp = temp_dir();
        let csv_path = temp.path().join("test.csv");

        std::fs::write(&csv_path, "id,name\n1,foo\n2,bar\n").expect("should write csv");

        assert!(load_dataset(&csv_path).is_ok());
    }

    #[test]
    fn test_load_dataset_json_file() {
        let temp = temp_dir();
        let json_path = temp.path().join("test.json");

        std::fs::write(
            &json_path,
            r#"{"id":1,"name":"foo"}
{"id":2,"name":"bar"}"#,
        )
        .expect("should write json");

        assert!(load_dataset(&json_path).is_ok());
    }

    #[test]
    fn test_cmd_head_zero_rows() {
        let temp = temp_dir();
        let path = temp.path().join("test.parquet");
        create_test_parquet(&path, 50);

        assert!(cmd_head(&path, 0).is_ok());
    }

    #[test]
    fn test_cmd_info_small_file() {
        let temp = temp_dir();
        let path = temp.path().join("small.parquet");
        create_test_parquet(&path, 5);

        assert!(cmd_info(&path).is_ok());
    }

    #[test]
    fn test_cmd_info_large_file() {
        let temp = temp_dir();
        let path = temp.path().join("large.parquet");
        create_test_parquet(&path, 1000);

        assert!(cmd_info(&path).is_ok());
    }

    #[test]
    fn test_cmd_schema_complex() {
        let temp = temp_dir();
        let path = temp.path().join("complex.parquet");

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));

        let batch = arrow::array::RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(arrow::array::Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .expect("should create batch");

        let dataset = ArrowDataset::from_batch(batch).expect("should create dataset");
        dataset.to_parquet(&path).expect("should write parquet");

        assert!(cmd_schema(&path).is_ok());
    }

    #[test]
    fn test_cmd_convert_csv_to_parquet() {
        let temp = temp_dir();
        let csv_path = temp.path().join("input.csv");
        let parquet_path = temp.path().join("output.parquet");

        std::fs::write(&csv_path, "id,name\n1,foo\n2,bar\n").expect("should write csv");

        assert!(cmd_convert(&csv_path, &parquet_path).is_ok());
        assert!(parquet_path.exists());
    }

    #[test]
    fn test_cmd_convert_json_to_parquet() {
        let temp = temp_dir();
        let json_path = temp.path().join("input.json");
        let parquet_path = temp.path().join("output.parquet");

        std::fs::write(
            &json_path,
            r#"{"id":1,"name":"foo"}
{"id":2,"name":"bar"}"#,
        )
        .expect("should write json");

        assert!(cmd_convert(&json_path, &parquet_path).is_ok());
        assert!(parquet_path.exists());
    }

    #[test]
    fn test_save_dataset_jsonl() {
        let temp = temp_dir();
        let path = temp.path().join("output.jsonl");

        assert!(save_dataset(&tiny_dataset(), &path).is_ok());
    }

    #[test]
    fn test_load_dataset_no_extension() {
        let path = PathBuf::from("file_without_extension");
        assert!(load_dataset(&path).is_err());
    }

    #[test]
    fn test_cmd_head_exact_rows() {
        let temp = temp_dir();
        let path = temp.path().join("exact.parquet");
        create_test_parquet(&path, 10);

        assert!(cmd_head(&path, 10).is_ok());
    }

    #[test]
    fn test_cmd_convert_parquet_to_parquet() {
        let temp = temp_dir();
        let input = temp.path().join("input.parquet");
        let output = temp.path().join("output.parquet");
        create_test_parquet(&input, 20);

        assert!(cmd_convert(&input, &output).is_ok());

        let original = ArrowDataset::from_parquet(&input).expect("should load original");
        let converted = ArrowDataset::from_parquet(&output).expect("should load converted");
        assert_eq!(original.len(), converted.len());
    }

    #[test]
    fn test_get_format_all_types() {
        assert_eq!(get_format(Path::new("data.parquet")), "Parquet");
        assert_eq!(get_format(Path::new("data.arrow")), "Arrow IPC");
        assert_eq!(get_format(Path::new("data.ipc")), "Arrow IPC");
        assert_eq!(get_format(Path::new("data.csv")), "CSV");
        assert_eq!(get_format(Path::new("data.json")), "JSON");
        assert_eq!(get_format(Path::new("data.jsonl")), "JSON");
        assert_eq!(get_format(Path::new("data.txt")), "Unknown");
        assert_eq!(get_format(Path::new("data.yaml")), "Unknown");
        assert_eq!(get_format(Path::new("data")), "Unknown");
    }
}