1use std::ffi::OsString;
2use std::fs::{self, File};
3use std::io::{self, BufReader, BufWriter, Read, Write};
4use std::num::ParseIntError;
5use std::path::{Path, PathBuf};
6
7use flate2::Compression;
8use flate2::read::MultiGzDecoder;
9use flate2::write::GzEncoder;
10use indicatif::{ProgressBar, ProgressStyle};
11use rayon::prelude::*;
12use thiserror::Error;
13use tracing::{debug, info};
14use walkdir::WalkDir;
15
/// Capacity (1 MiB) shared by buffered readers, buffered writers, and the
/// byte-copy scratch buffer used while concatenating files.
const IO_BUFFER_SIZE: usize = 1024 * 1024;
17
/// User-supplied settings for one merge run.
#[derive(Debug, Clone)]
pub struct MergeConfig {
    /// Directory scanned recursively for `.npy` and `.csv.gz` inputs.
    pub input_path: PathBuf,
    /// Number of output shards per file type; must be at least 1.
    pub num_files: usize,
    /// Base file path when `num_files == 1` (suffixes are appended), or the
    /// output directory when `num_files > 1` (numbered shard files inside).
    pub output_path: PathBuf,
}
24
/// Summary of a completed merge: how many inputs were consumed and which
/// output files were written.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeReport {
    /// Count of `.npy` input files merged.
    pub npy_inputs: usize,
    /// Count of `.csv.gz` input files merged.
    pub csv_gz_inputs: usize,
    /// Paths of the `.npy` shard files that were written.
    pub npy_outputs: Vec<PathBuf>,
    /// Paths of the `.csv.gz` shard files that were written.
    pub csv_gz_outputs: Vec<PathBuf>,
}
32
/// Input files found under the input directory, each list sorted
/// lexicographically (see `discover_files`) for deterministic sharding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscoveredFiles {
    /// Sorted `.npy` input paths.
    pub npy_files: Vec<PathBuf>,
    /// Sorted `.csv.gz` input paths.
    pub csv_gz_files: Vec<PathBuf>,
}
38
/// Planned destination paths, one entry per output shard and file type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OutputPlan {
    /// Destination paths for merged `.npy` data.
    pub npy_outputs: Vec<PathBuf>,
    /// Destination paths for merged `.csv.gz` data.
    pub csv_gz_outputs: Vec<PathBuf>,
}
44
/// All failure modes of the merge pipeline, from configuration validation
/// through file discovery, directory setup, per-shard I/O, and CSV rewriting.
#[derive(Debug, Error)]
pub enum MergeError {
    /// The requested shard count was zero.
    #[error("`--num-files` must be at least 1, got {0}")]
    InvalidNumFiles(usize),
    /// `fs::metadata` failed on an input or output path.
    #[error("failed to read metadata for path {path}: {source}")]
    ReadPathMetadata {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// The input path exists but is not a directory.
    #[error("input path is not a directory: {0}")]
    InputPathNotDirectory(PathBuf),
    /// Recursive traversal of the input directory failed.
    #[error("failed to walk input directory {path}: {source}")]
    WalkInputDirectory {
        path: PathBuf,
        #[source]
        source: walkdir::Error,
    },
    /// Sharded output requested but the output path is a non-directory.
    #[error("output path exists and is not a directory: {0}")]
    OutputPathNotDirectory(PathBuf),
    /// `fs::create_dir_all` failed for an output (or parent) directory.
    #[error("failed to create directory {path}: {source}")]
    CreateDirectory {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// Internal invariant violation: shard list and output list lengths differ.
    #[error(
        "shard count mismatch for {file_type}: {input_shards} input shard sets but {output_paths} output paths"
    )]
    ShardConfigurationMismatch {
        file_type: &'static str,
        input_shards: usize,
        output_paths: usize,
    },
    /// An input file could not be opened for reading.
    #[error("failed to open source file {path}: {source}")]
    OpenSourceFile {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// An output file could not be created.
    #[error("failed to create destination file {path}: {source}")]
    CreateDestinationFile {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// Byte copy from an input file into an output file failed mid-stream.
    #[error("failed to copy data from {source_path} to {destination_path}: {source}")]
    CopyFileData {
        source_path: PathBuf,
        destination_path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// Final flush (or gzip finish) of an output file failed.
    #[error("failed to flush destination file {path}: {source}")]
    FlushDestinationFile {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// The CSV reader failed to decode a record (`row` is 1-based).
    #[error("failed to read CSV record at {path}:{row}: {source}")]
    ReadCsvRecord {
        path: PathBuf,
        row: usize,
        #[source]
        source: csv::Error,
    },
    /// A data record had fewer columns than the required minimum of 5.
    #[error("invalid CSV record at {path}:{row}: expected at least 5 columns, found {columns}")]
    InvalidCsvRecord {
        path: PathBuf,
        row: usize,
        columns: usize,
    },
    /// A `start`/`end` cell could not be parsed as `u64`.
    #[error("failed to parse `{column}` at {path}:{row}: `{value}` ({source})")]
    ParseCsvField {
        path: PathBuf,
        row: usize,
        column: &'static str,
        value: String,
        #[source]
        source: ParseIntError,
    },
    /// A record's span is inverted (`end < start`).
    #[error("invalid metadata span at {path}:{row}: end ({end}) is less than start ({start})")]
    InvalidCsvSpan {
        path: PathBuf,
        row: usize,
        start: u64,
        end: u64,
    },
    /// Remapping a span past the running offset would overflow `u64`.
    #[error(
        "overflow while remapping metadata span at {path}:{row}: start={start}, length={length}"
    )]
    CsvSpanOverflow {
        path: PathBuf,
        row: usize,
        start: u64,
        length: u64,
    },
    /// Writing (or flushing) a record to the output CSV failed.
    #[error("failed to write CSV record to {path}: {source}")]
    WriteCsvRecord {
        path: PathBuf,
        #[source]
        source: csv::Error,
    },
}
149
150pub fn merge_files(config: &MergeConfig) -> Result<MergeReport, MergeError> {
151 validate_config(config)?;
152 let discovered = discover_files(&config.input_path)?;
153 let plan = build_output_plan(&config.output_path, config.num_files)?;
154 create_output_directories(config, &plan)?;
155
156 info!(
157 input_path = %config.input_path.display(),
158 output_path = %config.output_path.display(),
159 num_files = config.num_files,
160 "starting merge",
161 );
162 info!(
163 npy_files = discovered.npy_files.len(),
164 csv_gz_files = discovered.csv_gz_files.len(),
165 "discovered source files",
166 );
167 debug!(?plan, "resolved output files");
168
169 let npy_shards = shard_paths(&discovered.npy_files, config.num_files);
170 let csv_gz_shards = shard_paths(&discovered.csv_gz_files, config.num_files);
171 let progress =
172 build_progress_bar((discovered.npy_files.len() + discovered.csv_gz_files.len()) as u64);
173
174 let (npy_result, csv_result) = rayon::join(
175 || merge_npy_shards(&npy_shards, &plan.npy_outputs, &progress),
176 || merge_csv_gz_shards(&csv_gz_shards, &plan.csv_gz_outputs, &progress),
177 );
178 npy_result?;
179 csv_result?;
180
181 progress.finish_with_message("merge complete");
182 info!("merge complete");
183
184 Ok(MergeReport {
185 npy_inputs: discovered.npy_files.len(),
186 csv_gz_inputs: discovered.csv_gz_files.len(),
187 npy_outputs: plan.npy_outputs,
188 csv_gz_outputs: plan.csv_gz_outputs,
189 })
190}
191
192fn validate_config(config: &MergeConfig) -> Result<(), MergeError> {
193 if config.num_files == 0 {
194 return Err(MergeError::InvalidNumFiles(config.num_files));
195 }
196
197 let metadata =
198 fs::metadata(&config.input_path).map_err(|source| MergeError::ReadPathMetadata {
199 path: config.input_path.clone(),
200 source,
201 })?;
202 if !metadata.is_dir() {
203 return Err(MergeError::InputPathNotDirectory(config.input_path.clone()));
204 }
205
206 Ok(())
207}
208
209pub fn discover_files(input_path: &Path) -> Result<DiscoveredFiles, MergeError> {
210 let mut npy_files = Vec::new();
211 let mut csv_gz_files = Vec::new();
212
213 for entry in WalkDir::new(input_path) {
214 let entry = entry.map_err(|source| MergeError::WalkInputDirectory {
215 path: input_path.to_path_buf(),
216 source,
217 })?;
218 if !entry.file_type().is_file() {
219 continue;
220 }
221
222 let path = entry.into_path();
223 if is_npy_file(&path) {
224 npy_files.push(path);
225 } else if is_csv_gz_file(&path) {
226 csv_gz_files.push(path);
227 }
228 }
229
230 npy_files.sort();
231 csv_gz_files.sort();
232
233 Ok(DiscoveredFiles {
234 npy_files,
235 csv_gz_files,
236 })
237}
238
239pub fn build_output_plan(output_path: &Path, num_files: usize) -> Result<OutputPlan, MergeError> {
240 if num_files == 0 {
241 return Err(MergeError::InvalidNumFiles(num_files));
242 }
243
244 if num_files == 1 {
245 return Ok(OutputPlan {
246 npy_outputs: vec![append_extension(output_path, "npy")],
247 csv_gz_outputs: vec![append_extension(output_path, "csv.gz")],
248 });
249 }
250
251 let npy_outputs = (0..num_files)
252 .map(|index| output_path.join(format!("{index:08}.npy")))
253 .collect();
254 let csv_gz_outputs = (0..num_files)
255 .map(|index| output_path.join(format!("{index:08}.csv.gz")))
256 .collect();
257
258 Ok(OutputPlan {
259 npy_outputs,
260 csv_gz_outputs,
261 })
262}
263
264fn create_output_directories(config: &MergeConfig, plan: &OutputPlan) -> Result<(), MergeError> {
265 if config.num_files == 1 {
266 for output in plan.npy_outputs.iter().chain(plan.csv_gz_outputs.iter()) {
267 create_parent_dir(output)?;
268 }
269 return Ok(());
270 }
271
272 if config.output_path.exists() {
273 let metadata =
274 fs::metadata(&config.output_path).map_err(|source| MergeError::ReadPathMetadata {
275 path: config.output_path.clone(),
276 source,
277 })?;
278 if !metadata.is_dir() {
279 return Err(MergeError::OutputPathNotDirectory(
280 config.output_path.clone(),
281 ));
282 }
283 return Ok(());
284 }
285
286 fs::create_dir_all(&config.output_path).map_err(|source| MergeError::CreateDirectory {
287 path: config.output_path.clone(),
288 source,
289 })?;
290 Ok(())
291}
292
293fn create_parent_dir(path: &Path) -> Result<(), MergeError> {
294 let Some(parent) = path.parent() else {
295 return Ok(());
296 };
297 if parent.as_os_str().is_empty() {
298 return Ok(());
299 }
300 fs::create_dir_all(parent).map_err(|source| MergeError::CreateDirectory {
301 path: parent.to_path_buf(),
302 source,
303 })?;
304 Ok(())
305}
306
/// Distribute `paths` across `num_shards` buckets round-robin: item `i`
/// lands in bucket `i % num_shards`. Empty buckets are returned as empty
/// vectors so the result always has exactly `num_shards` entries.
pub fn shard_paths(paths: &[PathBuf], num_shards: usize) -> Vec<Vec<PathBuf>> {
    let mut shards: Vec<Vec<PathBuf>> =
        std::iter::repeat_with(Vec::new).take(num_shards).collect();
    paths.iter().enumerate().for_each(|(position, path)| {
        shards[position % num_shards].push(path.clone());
    });
    shards
}
314
315pub fn merge_npy_shards(
316 input_shards: &[Vec<PathBuf>],
317 output_paths: &[PathBuf],
318 progress: &ProgressBar,
319) -> Result<(), MergeError> {
320 if input_shards.len() != output_paths.len() {
321 return Err(MergeError::ShardConfigurationMismatch {
322 file_type: "npy",
323 input_shards: input_shards.len(),
324 output_paths: output_paths.len(),
325 });
326 }
327
328 (0..output_paths.len())
329 .into_par_iter()
330 .try_for_each(|index| {
331 let shard_inputs = &input_shards[index];
332 let shard_output = &output_paths[index];
333 let shard_progress = progress.clone();
334 merge_single_npy_shard(shard_inputs, shard_output, &shard_progress)
335 })
336}
337
338pub fn merge_csv_gz_shards(
339 input_shards: &[Vec<PathBuf>],
340 output_paths: &[PathBuf],
341 progress: &ProgressBar,
342) -> Result<(), MergeError> {
343 if input_shards.len() != output_paths.len() {
344 return Err(MergeError::ShardConfigurationMismatch {
345 file_type: "csv.gz",
346 input_shards: input_shards.len(),
347 output_paths: output_paths.len(),
348 });
349 }
350
351 (0..output_paths.len())
352 .into_par_iter()
353 .try_for_each(|index| {
354 let shard_inputs = &input_shards[index];
355 let shard_output = &output_paths[index];
356 let shard_progress = progress.clone();
357 merge_single_csv_gz_shard(shard_inputs, shard_output, &shard_progress)
358 })
359}
360
361fn merge_single_npy_shard(
362 input_paths: &[PathBuf],
363 output_path: &Path,
364 progress: &ProgressBar,
365) -> Result<(), MergeError> {
366 let output_file =
367 File::create(output_path).map_err(|source| MergeError::CreateDestinationFile {
368 path: output_path.to_path_buf(),
369 source,
370 })?;
371 let mut writer = BufWriter::with_capacity(IO_BUFFER_SIZE, output_file);
372 let mut buffer = vec![0_u8; IO_BUFFER_SIZE];
373
374 for input_path in input_paths {
375 let input_file = File::open(input_path).map_err(|source| MergeError::OpenSourceFile {
376 path: input_path.clone(),
377 source,
378 })?;
379 let mut reader = BufReader::with_capacity(IO_BUFFER_SIZE, input_file);
380 copy_reader_to_writer(&mut reader, &mut writer, &mut buffer).map_err(|source| {
381 MergeError::CopyFileData {
382 source_path: input_path.clone(),
383 destination_path: output_path.to_path_buf(),
384 source,
385 }
386 })?;
387 progress.inc(1);
388 }
389
390 writer
391 .flush()
392 .map_err(|source| MergeError::FlushDestinationFile {
393 path: output_path.to_path_buf(),
394 source,
395 })?;
396 Ok(())
397}
398
/// Decompress every `.csv.gz` input in `input_paths` (in order), rewrite the
/// first two columns (`start`, `end`) so spans become contiguous across the
/// concatenated output, and re-compress everything into `output_path`.
///
/// Only the first header row encountered (detected by `is_metadata_header`)
/// is written; later headers are dropped. `progress` advances once per
/// input file.
fn merge_single_csv_gz_shard(
    input_paths: &[PathBuf],
    output_path: &Path,
    progress: &ProgressBar,
) -> Result<(), MergeError> {
    let output_file =
        File::create(output_path).map_err(|source| MergeError::CreateDestinationFile {
            path: output_path.to_path_buf(),
            source,
        })?;
    let writer = BufWriter::with_capacity(IO_BUFFER_SIZE, output_file);
    let encoder = GzEncoder::new(writer, Compression::default());
    let mut csv_writer = csv::WriterBuilder::new()
        .has_headers(false)
        .from_writer(encoder);
    // Running offset: where the next record's remapped span must begin.
    // It accumulates across ALL input files, which is what stitches the
    // per-file spans into one contiguous sequence.
    let mut next_start = 0_u64;
    let mut wrote_header = false;

    for input_path in input_paths {
        let input_file = File::open(input_path).map_err(|source| MergeError::OpenSourceFile {
            path: input_path.clone(),
            source,
        })?;
        let reader = BufReader::with_capacity(IO_BUFFER_SIZE, input_file);
        // MultiGzDecoder also handles inputs made of several concatenated
        // gzip members, not just a single stream.
        let decoder = MultiGzDecoder::new(reader);
        // `flexible(true)` tolerates rows with varying column counts; the
        // minimum-column check below enforces what we actually need.
        let mut csv_reader = csv::ReaderBuilder::new()
            .has_headers(false)
            .flexible(true)
            .from_reader(decoder);

        for (row_index, maybe_record) in csv_reader.records().enumerate() {
            // 1-based row number for error messages.
            let row = row_index + 1;
            let record = maybe_record.map_err(|source| MergeError::ReadCsvRecord {
                path: input_path.clone(),
                row,
                source,
            })?;
            if record.is_empty() {
                continue;
            }

            // Header rows pass through once (from whichever file yields one
            // first) and are otherwise skipped.
            if is_metadata_header(&record) {
                if !wrote_header {
                    csv_writer.write_record(&record).map_err(|source| {
                        MergeError::WriteCsvRecord {
                            path: output_path.to_path_buf(),
                            source,
                        }
                    })?;
                    wrote_header = true;
                }
                continue;
            }

            if record.len() < 5 {
                return Err(MergeError::InvalidCsvRecord {
                    path: input_path.clone(),
                    row,
                    columns: record.len(),
                });
            }

            let start = parse_csv_u64(&record, 0, input_path, row, "start")?;
            let end = parse_csv_u64(&record, 1, input_path, row, "end")?;
            if end < start {
                return Err(MergeError::InvalidCsvSpan {
                    path: input_path.clone(),
                    row,
                    start,
                    end,
                });
            }

            // Preserve the span's length while shifting its origin onto the
            // running offset; guard the addition against u64 overflow.
            let length = end - start;
            let new_start = next_start;
            let new_end =
                new_start
                    .checked_add(length)
                    .ok_or_else(|| MergeError::CsvSpanOverflow {
                        path: input_path.clone(),
                        row,
                        start: new_start,
                        length,
                    })?;

            // Copy the record, replacing only the remapped span columns.
            let mut output_record = record
                .iter()
                .map(std::string::ToString::to_string)
                .collect::<Vec<_>>();
            output_record[0] = new_start.to_string();
            output_record[1] = new_end.to_string();

            csv_writer.write_record(&output_record).map_err(|source| {
                MergeError::WriteCsvRecord {
                    path: output_path.to_path_buf(),
                    source,
                }
            })?;
            next_start = new_end;
        }

        progress.inc(1);
    }

    // Teardown must unwind the layering in order: flush the CSV writer,
    // recover the gzip encoder, finish the gzip stream (writes the trailer),
    // then flush the underlying buffered file writer.
    csv_writer
        .flush()
        .map_err(|source| MergeError::WriteCsvRecord {
            path: output_path.to_path_buf(),
            source: source.into(),
        })?;
    let encoder = csv_writer
        .into_inner()
        .map_err(|source| MergeError::WriteCsvRecord {
            path: output_path.to_path_buf(),
            source: source.into_error().into(),
        })?;
    let mut writer = encoder
        .finish()
        .map_err(|source| MergeError::FlushDestinationFile {
            path: output_path.to_path_buf(),
            source,
        })?;
    writer
        .flush()
        .map_err(|source| MergeError::FlushDestinationFile {
            path: output_path.to_path_buf(),
            source,
        })?;
    Ok(())
}
529
530fn parse_csv_u64(
531 record: &csv::StringRecord,
532 index: usize,
533 path: &Path,
534 row: usize,
535 column: &'static str,
536) -> Result<u64, MergeError> {
537 let value = record
538 .get(index)
539 .ok_or_else(|| MergeError::InvalidCsvRecord {
540 path: path.to_path_buf(),
541 row,
542 columns: record.len(),
543 })?;
544 value
545 .parse::<u64>()
546 .map_err(|source| MergeError::ParseCsvField {
547 path: path.to_path_buf(),
548 row,
549 column,
550 value: value.to_string(),
551 source,
552 })
553}
554
555fn is_metadata_header(record: &csv::StringRecord) -> bool {
556 let Some(first) = record.get(0) else {
557 return false;
558 };
559 let Some(second) = record.get(1) else {
560 return false;
561 };
562 first.eq_ignore_ascii_case("start") && second.eq_ignore_ascii_case("end")
563}
564
/// Stream all bytes from `reader` to `writer` through the caller-provided
/// `buffer`, returning the number of bytes written.
///
/// Reads that fail with [`io::ErrorKind::Interrupted`] are retried, matching
/// the behavior of [`std::io::copy`]; any other error aborts the copy.
///
/// # Errors
///
/// Propagates the first non-interrupted read error or any write error.
fn copy_reader_to_writer<R: Read, W: Write>(
    reader: &mut R,
    writer: &mut W,
    buffer: &mut [u8],
) -> io::Result<u64> {
    let mut total_written = 0_u64;
    loop {
        let read_bytes = match reader.read(buffer) {
            Ok(0) => break,
            Ok(read_bytes) => read_bytes,
            // `Interrupted` is transient (e.g. a signal landed mid-read);
            // retry instead of failing the whole shard.
            Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
            Err(error) => return Err(error),
        };
        writer.write_all(&buffer[..read_bytes])?;
        total_written += read_bytes as u64;
    }
    Ok(total_written)
}
581
/// True when the path's extension is exactly `npy`.
fn is_npy_file(path: &Path) -> bool {
    matches!(path.extension(), Some(extension) if extension == "npy")
}
585
/// True when the file name is valid UTF-8 and ends with `.csv.gz`.
fn is_csv_gz_file(path: &Path) -> bool {
    let Some(file_name) = path.file_name().and_then(|name| name.to_str()) else {
        return false;
    };
    file_name.ends_with(".csv.gz")
}
591
/// Append `.{extension}` to the end of `path` as raw text (unlike
/// `Path::set_extension`, this never replaces an existing extension).
///
/// The path is first rebuilt from its components, which drops any trailing
/// separator so `/tmp/out/` + `npy` yields `/tmp/out.npy` rather than
/// `/tmp/out/.npy`.
fn append_extension(path: &Path, extension: &str) -> PathBuf {
    let mut normalized: PathBuf = path.components().map(|c| c.as_os_str()).collect();
    if normalized.as_os_str().is_empty() {
        normalized = path.to_path_buf();
    }

    let mut joined = OsString::from(normalized.into_os_string());
    joined.push(".");
    joined.push(extension);
    PathBuf::from(joined)
}
606
607fn build_progress_bar(total_files: u64) -> ProgressBar {
608 if cfg!(test) {
609 return ProgressBar::hidden();
610 }
611
612 let progress = ProgressBar::new(total_files);
613 let style = ProgressStyle::with_template(
614 "[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} files ({eta}) {msg}",
615 )
616 .unwrap_or_else(|_| ProgressStyle::default_bar())
617 .progress_chars("=>-");
618 progress.set_style(style);
619 progress.set_message("merging");
620 progress
621}
622
/// Unit and integration tests covering discovery, output planning, path
/// sharding, both merge paths, and the end-to-end pipeline, using temporary
/// directories for all filesystem work.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;
    use tempfile::tempdir;

    // Compress `content` into a fresh gzip file at `path`.
    fn write_gzip(path: &Path, content: &str) {
        let file = File::create(path).expect("create gzip input");
        let mut encoder = GzEncoder::new(file, Compression::default());
        encoder
            .write_all(content.as_bytes())
            .expect("write gzip content");
        encoder.finish().expect("finish gzip input");
    }

    // Decompress the gzip file at `path` into a UTF-8 string.
    fn read_gzip(path: &Path) -> String {
        let file = File::open(path).expect("open gzip output");
        let reader = BufReader::new(file);
        let mut decoder = MultiGzDecoder::new(reader);
        let mut content = String::new();
        decoder
            .read_to_string(&mut content)
            .expect("read gzip output");
        content
    }

    // Discovery must recurse into nested directories, keep only .npy and
    // .csv.gz files, and return each list sorted.
    #[test]
    fn discovers_files_recursively_and_ignores_others() {
        let temp = tempdir().expect("create tempdir");
        let nested = temp.path().join("nested").join("inner");
        fs::create_dir_all(&nested).expect("create nested dirs");

        fs::write(temp.path().join("b.npy"), [1_u8, 2_u8]).expect("write npy");
        fs::write(nested.join("a.npy"), [3_u8]).expect("write nested npy");
        fs::write(temp.path().join("skip.txt"), "ignore").expect("write skip file");
        write_gzip(&nested.join("z.csv.gz"), "zeta\n");
        write_gzip(&temp.path().join("m.csv.gz"), "mu\n");

        let discovered = discover_files(temp.path()).expect("discover files");

        assert_eq!(discovered.npy_files.len(), 2);
        assert_eq!(discovered.csv_gz_files.len(), 2);
        assert!(discovered.npy_files[0] < discovered.npy_files[1]);
        assert!(discovered.csv_gz_files[0] < discovered.csv_gz_files[1]);
    }

    // num_files == 1 treats the output path as a base name and appends
    // suffixes.
    #[test]
    fn builds_single_output_plan() {
        let output = PathBuf::from("/tmp/output/base");
        let plan = build_output_plan(&output, 1).expect("build plan");
        assert_eq!(
            plan.npy_outputs,
            vec![PathBuf::from("/tmp/output/base.npy")]
        );
        assert_eq!(
            plan.csv_gz_outputs,
            vec![PathBuf::from("/tmp/output/base.csv.gz")]
        );
    }

    // num_files > 1 treats the output path as a directory of zero-padded
    // shard files.
    #[test]
    fn builds_sharded_output_plan() {
        let output = PathBuf::from("/tmp/output/shards");
        let plan = build_output_plan(&output, 3).expect("build plan");
        assert_eq!(
            plan.npy_outputs,
            vec![
                PathBuf::from("/tmp/output/shards/00000000.npy"),
                PathBuf::from("/tmp/output/shards/00000001.npy"),
                PathBuf::from("/tmp/output/shards/00000002.npy"),
            ]
        );
        assert_eq!(
            plan.csv_gz_outputs,
            vec![
                PathBuf::from("/tmp/output/shards/00000000.csv.gz"),
                PathBuf::from("/tmp/output/shards/00000001.csv.gz"),
                PathBuf::from("/tmp/output/shards/00000002.csv.gz"),
            ]
        );
    }

    // A trailing separator on the base path must not produce a `base/.npy`
    // style name.
    #[test]
    fn builds_single_output_plan_with_trailing_separator() {
        let output = PathBuf::from("/tmp/output/base/");
        let plan = build_output_plan(&output, 1).expect("build plan");
        assert_eq!(
            plan.npy_outputs,
            vec![PathBuf::from("/tmp/output/base.npy")]
        );
        assert_eq!(
            plan.csv_gz_outputs,
            vec![PathBuf::from("/tmp/output/base.csv.gz")]
        );
    }

    // Item i lands in shard i % num_shards.
    #[test]
    fn shards_paths_round_robin() {
        let paths = vec![
            PathBuf::from("a.npy"),
            PathBuf::from("b.npy"),
            PathBuf::from("c.npy"),
            PathBuf::from("d.npy"),
            PathBuf::from("e.npy"),
        ];

        let shards = shard_paths(&paths, 2);
        assert_eq!(
            shards[0],
            vec![
                PathBuf::from("a.npy"),
                PathBuf::from("c.npy"),
                PathBuf::from("e.npy")
            ]
        );
        assert_eq!(
            shards[1],
            vec![PathBuf::from("b.npy"), PathBuf::from("d.npy")]
        );
    }

    // npy merging is a raw byte concatenation in shard order.
    #[test]
    fn merges_npy_shards_by_byte_concatenation() {
        let temp = tempdir().expect("create tempdir");
        let inputs = temp.path().join("inputs");
        let outputs = temp.path().join("outputs");
        fs::create_dir_all(&inputs).expect("create input dir");
        fs::create_dir_all(&outputs).expect("create output dir");

        let a = inputs.join("a.npy");
        let b = inputs.join("b.npy");
        let c = inputs.join("c.npy");
        fs::write(&a, [1_u8, 2_u8]).expect("write a");
        fs::write(&b, [3_u8]).expect("write b");
        fs::write(&c, [4_u8, 5_u8]).expect("write c");

        let shards = vec![vec![a.clone(), c.clone()], vec![b.clone()]];
        let out0 = outputs.join("00000000.npy");
        let out1 = outputs.join("00000001.npy");
        let progress = ProgressBar::hidden();

        merge_npy_shards(&shards, &[out0.clone(), out1.clone()], &progress)
            .expect("merge npy shards");

        assert_eq!(
            fs::read(out0).expect("read out0"),
            vec![1_u8, 2_u8, 4_u8, 5_u8]
        );
        assert_eq!(fs::read(out1).expect("read out1"), vec![3_u8]);
    }

    // csv.gz merging keeps one header and remaps spans to be contiguous
    // across concatenated inputs (c.csv.gz's 0..4 becomes 5..9).
    #[test]
    fn merges_csv_gz_shards_by_decompress_and_recompress() {
        let temp = tempdir().expect("create tempdir");
        let inputs = temp.path().join("inputs");
        let outputs = temp.path().join("outputs");
        fs::create_dir_all(&inputs).expect("create input dir");
        fs::create_dir_all(&outputs).expect("create output dir");

        let first = inputs.join("a.csv.gz");
        let second = inputs.join("b.csv.gz");
        let third = inputs.join("c.csv.gz");
        write_gzip(
            &first,
            "start,end,id,src,loc\n0,2,id-a,src-a,1\n2,5,id-b,src-a,2\n",
        );
        write_gzip(&second, "start,end,id,src,loc\n0,1,id-c,src-b,1\n");
        write_gzip(&third, "start,end,id,src,loc\n0,4,id-d,src-c,9\n");

        let shards = vec![vec![first.clone(), third.clone()], vec![second.clone()]];
        let out0 = outputs.join("00000000.csv.gz");
        let out1 = outputs.join("00000001.csv.gz");
        let progress = ProgressBar::hidden();

        merge_csv_gz_shards(&shards, &[out0.clone(), out1.clone()], &progress)
            .expect("merge csv.gz shards");

        assert_eq!(
            read_gzip(&out0),
            "start,end,id,src,loc\n0,2,id-a,src-a,1\n2,5,id-b,src-a,2\n5,9,id-d,src-c,9\n"
        );
        assert_eq!(read_gzip(&out1), "start,end,id,src,loc\n0,1,id-c,src-b,1\n");
    }

    // Full pipeline: discovery (ignoring non-matching files), sharded
    // output directory creation, and both merge paths.
    #[test]
    fn runs_end_to_end_with_sharded_outputs() {
        let temp = tempdir().expect("create tempdir");
        let input_root = temp.path().join("input");
        let nested = input_root.join("nested");
        fs::create_dir_all(&nested).expect("create nested input dir");

        fs::write(input_root.join("a.npy"), [1_u8]).expect("write a.npy");
        fs::write(nested.join("b.npy"), [2_u8]).expect("write b.npy");
        write_gzip(&input_root.join("a.csv.gz"), "0,1,id-a,src-a,1\n");
        write_gzip(&nested.join("b.csv.gz"), "0,1,id-b,src-b,1\n");
        fs::write(input_root.join("ignore.bin"), [9_u8]).expect("write ignored file");

        let output_path = temp.path().join("sharded");
        let config = MergeConfig {
            input_path: input_root.clone(),
            num_files: 2,
            output_path: output_path.clone(),
        };

        let report = merge_files(&config).expect("run merge");

        assert_eq!(report.npy_inputs, 2);
        assert_eq!(report.csv_gz_inputs, 2);
        assert!(output_path.join("00000000.npy").exists());
        assert!(output_path.join("00000001.npy").exists());
        assert!(output_path.join("00000000.csv.gz").exists());
        assert!(output_path.join("00000001.csv.gz").exists());
    }

    // A zero shard count is rejected up front.
    #[test]
    fn rejects_zero_num_files() {
        let output = PathBuf::from("anything");
        let err = build_output_plan(&output, 0).expect_err("expected invalid num-files error");
        match err {
            MergeError::InvalidNumFiles(value) => assert_eq!(value, 0),
            other => panic!("unexpected error: {other}"),
        }
    }
}