Skip to main content

reshard_tokenized/
lib.rs

use std::ffi::OsString;
use std::fs::{self, File};
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::num::ParseIntError;
use std::path::{Path, PathBuf};

use flate2::Compression;
use flate2::read::MultiGzDecoder;
use flate2::write::GzEncoder;
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use thiserror::Error;
use tracing::{debug, info, warn};
use walkdir::WalkDir;
15
/// Size in bytes (1 MiB) of the read/write buffers used when streaming files.
const IO_BUFFER_SIZE: usize = 1 << 20;
17
/// Parameters for a single merge run (see [`merge_files`]).
#[derive(Debug, Clone)]
pub struct MergeConfig {
    /// Directory that is walked recursively for `.npy` and `.csv.gz` inputs;
    /// must exist and be a directory.
    pub input_path: PathBuf,
    /// Number of output shards per file type; must be at least 1.
    pub num_files: usize,
    /// With `num_files == 1`, the base path onto which `.npy` / `.csv.gz` are
    /// appended; otherwise the directory that receives the numbered shard files.
    pub output_path: PathBuf,
}
24
/// Summary returned by [`merge_files`] after a successful run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeReport {
    /// Number of `.npy` files discovered under the input directory.
    pub npy_inputs: usize,
    /// Number of `.csv.gz` files discovered under the input directory.
    pub csv_gz_inputs: usize,
    /// Paths of the `.npy` outputs, one per shard, in shard order.
    pub npy_outputs: Vec<PathBuf>,
    /// Paths of the `.csv.gz` outputs, one per shard, in shard order.
    pub csv_gz_outputs: Vec<PathBuf>,
}
32
/// Input files found by [`discover_files`]; each list is sorted by path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscoveredFiles {
    /// Regular files whose extension is exactly `npy`.
    pub npy_files: Vec<PathBuf>,
    /// Regular files whose name ends with `.csv.gz`.
    pub csv_gz_files: Vec<PathBuf>,
}
38
/// Output paths produced by [`build_output_plan`], one entry per shard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OutputPlan {
    /// Destination `.npy` file for each shard, indexed by shard number.
    pub npy_outputs: Vec<PathBuf>,
    /// Destination `.csv.gz` file for each shard, indexed by shard number.
    pub csv_gz_outputs: Vec<PathBuf>,
}
44
/// Errors produced while validating, discovering, sharding, and merging inputs.
#[derive(Debug, Error)]
pub enum MergeError {
    /// The requested number of output shards was zero.
    #[error("`--num-files` must be at least 1, got {0}")]
    InvalidNumFiles(usize),
    /// `fs::metadata` failed for the input path or the output path.
    #[error("failed to read metadata for path {path}: {source}")]
    ReadPathMetadata {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// The input path exists but is not a directory.
    #[error("input path is not a directory: {0}")]
    InputPathNotDirectory(PathBuf),
    /// Walking the input tree failed; `path` is the walk root, not necessarily
    /// the entry that failed (the wrapped `walkdir::Error` carries that).
    #[error("failed to walk input directory {path}: {source}")]
    WalkInputDirectory {
        path: PathBuf,
        #[source]
        source: walkdir::Error,
    },
    /// In sharded mode, the output path exists but is not a directory.
    #[error("output path exists and is not a directory: {0}")]
    OutputPathNotDirectory(PathBuf),
    /// `fs::create_dir_all` failed for an output directory.
    #[error("failed to create directory {path}: {source}")]
    CreateDirectory {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// The number of input shard lists differs from the number of output paths.
    #[error(
        "shard count mismatch for {file_type}: {input_shards} input shard sets but {output_paths} output paths"
    )]
    ShardConfigurationMismatch {
        file_type: &'static str,
        input_shards: usize,
        output_paths: usize,
    },
    /// An input file could not be opened.
    #[error("failed to open source file {path}: {source}")]
    OpenSourceFile {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// An output file could not be created.
    #[error("failed to create destination file {path}: {source}")]
    CreateDestinationFile {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// Copying the bytes of an `.npy` input into its shard output failed.
    #[error("failed to copy data from {source_path} to {destination_path}: {source}")]
    CopyFileData {
        source_path: PathBuf,
        destination_path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// Flushing an output file, or finishing its gzip stream, failed.
    #[error("failed to flush destination file {path}: {source}")]
    FlushDestinationFile {
        path: PathBuf,
        #[source]
        source: io::Error,
    },
    /// The CSV reader failed; I/O and gzip decoding errors from the underlying
    /// reader surface here too. `row` is the 1-based record index in that file.
    #[error("failed to read CSV record at {path}:{row}: {source}")]
    ReadCsvRecord {
        path: PathBuf,
        row: usize,
        #[source]
        source: csv::Error,
    },
    /// A non-header metadata row has fewer than 5 columns.
    #[error("invalid CSV record at {path}:{row}: expected at least 5 columns, found {columns}")]
    InvalidCsvRecord {
        path: PathBuf,
        row: usize,
        columns: usize,
    },
    /// The `start` or `end` column is not an unsigned 64-bit integer.
    #[error("failed to parse `{column}` at {path}:{row}: `{value}` ({source})")]
    ParseCsvField {
        path: PathBuf,
        row: usize,
        column: &'static str,
        value: String,
        #[source]
        source: ParseIntError,
    },
    /// A metadata row has `end < start`.
    #[error("invalid metadata span at {path}:{row}: end ({end}) is less than start ({start})")]
    InvalidCsvSpan {
        path: PathBuf,
        row: usize,
        start: u64,
        end: u64,
    },
    /// The rebased span end overflowed `u64`; `start` is the rebased start
    /// offset, not the value read from the file.
    #[error(
        "overflow while remapping metadata span at {path}:{row}: start={start}, length={length}"
    )]
    CsvSpanOverflow {
        path: PathBuf,
        row: usize,
        start: u64,
        length: u64,
    },
    /// Writing or flushing the output CSV writer failed.
    #[error("failed to write CSV record to {path}: {source}")]
    WriteCsvRecord {
        path: PathBuf,
        #[source]
        source: csv::Error,
    },
}
149
150pub fn merge_files(config: &MergeConfig) -> Result<MergeReport, MergeError> {
151    validate_config(config)?;
152    let discovered = discover_files(&config.input_path)?;
153    let plan = build_output_plan(&config.output_path, config.num_files)?;
154    create_output_directories(config, &plan)?;
155
156    info!(
157        input_path = %config.input_path.display(),
158        output_path = %config.output_path.display(),
159        num_files = config.num_files,
160        "starting merge",
161    );
162    info!(
163        npy_files = discovered.npy_files.len(),
164        csv_gz_files = discovered.csv_gz_files.len(),
165        "discovered source files",
166    );
167    debug!(?plan, "resolved output files");
168
169    let npy_shards = shard_paths(&discovered.npy_files, config.num_files);
170    let csv_gz_shards = shard_paths(&discovered.csv_gz_files, config.num_files);
171    let progress =
172        build_progress_bar((discovered.npy_files.len() + discovered.csv_gz_files.len()) as u64);
173
174    let (npy_result, csv_result) = rayon::join(
175        || merge_npy_shards(&npy_shards, &plan.npy_outputs, &progress),
176        || merge_csv_gz_shards(&csv_gz_shards, &plan.csv_gz_outputs, &progress),
177    );
178    npy_result?;
179    csv_result?;
180
181    progress.finish_with_message("merge complete");
182    info!("merge complete");
183
184    Ok(MergeReport {
185        npy_inputs: discovered.npy_files.len(),
186        csv_gz_inputs: discovered.csv_gz_files.len(),
187        npy_outputs: plan.npy_outputs,
188        csv_gz_outputs: plan.csv_gz_outputs,
189    })
190}
191
192fn validate_config(config: &MergeConfig) -> Result<(), MergeError> {
193    if config.num_files == 0 {
194        return Err(MergeError::InvalidNumFiles(config.num_files));
195    }
196
197    let metadata =
198        fs::metadata(&config.input_path).map_err(|source| MergeError::ReadPathMetadata {
199            path: config.input_path.clone(),
200            source,
201        })?;
202    if !metadata.is_dir() {
203        return Err(MergeError::InputPathNotDirectory(config.input_path.clone()));
204    }
205
206    Ok(())
207}
208
209pub fn discover_files(input_path: &Path) -> Result<DiscoveredFiles, MergeError> {
210    let mut npy_files = Vec::new();
211    let mut csv_gz_files = Vec::new();
212
213    for entry in WalkDir::new(input_path) {
214        let entry = entry.map_err(|source| MergeError::WalkInputDirectory {
215            path: input_path.to_path_buf(),
216            source,
217        })?;
218        if !entry.file_type().is_file() {
219            continue;
220        }
221
222        let path = entry.into_path();
223        if is_npy_file(&path) {
224            npy_files.push(path);
225        } else if is_csv_gz_file(&path) {
226            csv_gz_files.push(path);
227        }
228    }
229
230    npy_files.sort();
231    csv_gz_files.sort();
232
233    Ok(DiscoveredFiles {
234        npy_files,
235        csv_gz_files,
236    })
237}
238
239pub fn build_output_plan(output_path: &Path, num_files: usize) -> Result<OutputPlan, MergeError> {
240    if num_files == 0 {
241        return Err(MergeError::InvalidNumFiles(num_files));
242    }
243
244    if num_files == 1 {
245        return Ok(OutputPlan {
246            npy_outputs: vec![append_extension(output_path, "npy")],
247            csv_gz_outputs: vec![append_extension(output_path, "csv.gz")],
248        });
249    }
250
251    let npy_outputs = (0..num_files)
252        .map(|index| output_path.join(format!("{index:08}.npy")))
253        .collect();
254    let csv_gz_outputs = (0..num_files)
255        .map(|index| output_path.join(format!("{index:08}.csv.gz")))
256        .collect();
257
258    Ok(OutputPlan {
259        npy_outputs,
260        csv_gz_outputs,
261    })
262}
263
264fn create_output_directories(config: &MergeConfig, plan: &OutputPlan) -> Result<(), MergeError> {
265    if config.num_files == 1 {
266        for output in plan.npy_outputs.iter().chain(plan.csv_gz_outputs.iter()) {
267            create_parent_dir(output)?;
268        }
269        return Ok(());
270    }
271
272    if config.output_path.exists() {
273        let metadata =
274            fs::metadata(&config.output_path).map_err(|source| MergeError::ReadPathMetadata {
275                path: config.output_path.clone(),
276                source,
277            })?;
278        if !metadata.is_dir() {
279            return Err(MergeError::OutputPathNotDirectory(
280                config.output_path.clone(),
281            ));
282        }
283        return Ok(());
284    }
285
286    fs::create_dir_all(&config.output_path).map_err(|source| MergeError::CreateDirectory {
287        path: config.output_path.clone(),
288        source,
289    })?;
290    Ok(())
291}
292
293fn create_parent_dir(path: &Path) -> Result<(), MergeError> {
294    let Some(parent) = path.parent() else {
295        return Ok(());
296    };
297    if parent.as_os_str().is_empty() {
298        return Ok(());
299    }
300    fs::create_dir_all(parent).map_err(|source| MergeError::CreateDirectory {
301        path: parent.to_path_buf(),
302        source,
303    })?;
304    Ok(())
305}
306
/// Distributes `paths` round-robin over `num_shards` buckets: path `i` goes to
/// bucket `i % num_shards`, and order within each bucket follows `paths`.
///
/// # Panics
///
/// Panics if `num_shards` is 0 while `paths` is non-empty.
pub fn shard_paths(paths: &[PathBuf], num_shards: usize) -> Vec<Vec<PathBuf>> {
    let mut buckets: Vec<Vec<PathBuf>> = (0..num_shards).map(|_| Vec::new()).collect();
    paths.iter().enumerate().for_each(|(position, path)| {
        buckets[position % num_shards].push(path.to_path_buf());
    });
    buckets
}
314
315pub fn merge_npy_shards(
316    input_shards: &[Vec<PathBuf>],
317    output_paths: &[PathBuf],
318    progress: &ProgressBar,
319) -> Result<(), MergeError> {
320    if input_shards.len() != output_paths.len() {
321        return Err(MergeError::ShardConfigurationMismatch {
322            file_type: "npy",
323            input_shards: input_shards.len(),
324            output_paths: output_paths.len(),
325        });
326    }
327
328    (0..output_paths.len())
329        .into_par_iter()
330        .try_for_each(|index| {
331            let shard_inputs = &input_shards[index];
332            let shard_output = &output_paths[index];
333            let shard_progress = progress.clone();
334            merge_single_npy_shard(shard_inputs, shard_output, &shard_progress)
335        })
336}
337
338pub fn merge_csv_gz_shards(
339    input_shards: &[Vec<PathBuf>],
340    output_paths: &[PathBuf],
341    progress: &ProgressBar,
342) -> Result<(), MergeError> {
343    if input_shards.len() != output_paths.len() {
344        return Err(MergeError::ShardConfigurationMismatch {
345            file_type: "csv.gz",
346            input_shards: input_shards.len(),
347            output_paths: output_paths.len(),
348        });
349    }
350
351    (0..output_paths.len())
352        .into_par_iter()
353        .try_for_each(|index| {
354            let shard_inputs = &input_shards[index];
355            let shard_output = &output_paths[index];
356            let shard_progress = progress.clone();
357            merge_single_csv_gz_shard(shard_inputs, shard_output, &shard_progress)
358        })
359}
360
361fn merge_single_npy_shard(
362    input_paths: &[PathBuf],
363    output_path: &Path,
364    progress: &ProgressBar,
365) -> Result<(), MergeError> {
366    let output_file =
367        File::create(output_path).map_err(|source| MergeError::CreateDestinationFile {
368            path: output_path.to_path_buf(),
369            source,
370        })?;
371    let mut writer = BufWriter::with_capacity(IO_BUFFER_SIZE, output_file);
372    let mut buffer = vec![0_u8; IO_BUFFER_SIZE];
373
374    for input_path in input_paths {
375        let input_file = File::open(input_path).map_err(|source| MergeError::OpenSourceFile {
376            path: input_path.clone(),
377            source,
378        })?;
379        let mut reader = BufReader::with_capacity(IO_BUFFER_SIZE, input_file);
380        copy_reader_to_writer(&mut reader, &mut writer, &mut buffer).map_err(|source| {
381            MergeError::CopyFileData {
382                source_path: input_path.clone(),
383                destination_path: output_path.to_path_buf(),
384                source,
385            }
386        })?;
387        progress.inc(1);
388    }
389
390    writer
391        .flush()
392        .map_err(|source| MergeError::FlushDestinationFile {
393            path: output_path.to_path_buf(),
394            source,
395        })?;
396    Ok(())
397}
398
399fn merge_single_csv_gz_shard(
400    input_paths: &[PathBuf],
401    output_path: &Path,
402    progress: &ProgressBar,
403) -> Result<(), MergeError> {
404    let output_file =
405        File::create(output_path).map_err(|source| MergeError::CreateDestinationFile {
406            path: output_path.to_path_buf(),
407            source,
408        })?;
409    let writer = BufWriter::with_capacity(IO_BUFFER_SIZE, output_file);
410    let encoder = GzEncoder::new(writer, Compression::default());
411    let mut csv_writer = csv::WriterBuilder::new()
412        .has_headers(false)
413        .from_writer(encoder);
414    let mut next_start = 0_u64;
415    let mut wrote_header = false;
416
417    for input_path in input_paths {
418        let input_file = File::open(input_path).map_err(|source| MergeError::OpenSourceFile {
419            path: input_path.clone(),
420            source,
421        })?;
422        let reader = BufReader::with_capacity(IO_BUFFER_SIZE, input_file);
423        let decoder = MultiGzDecoder::new(reader);
424        let mut csv_reader = csv::ReaderBuilder::new()
425            .has_headers(false)
426            .flexible(true)
427            .from_reader(decoder);
428
429        for (row_index, maybe_record) in csv_reader.records().enumerate() {
430            let row = row_index + 1;
431            let record = maybe_record.map_err(|source| MergeError::ReadCsvRecord {
432                path: input_path.clone(),
433                row,
434                source,
435            })?;
436            if record.is_empty() {
437                continue;
438            }
439
440            if is_metadata_header(&record) {
441                if !wrote_header {
442                    csv_writer.write_record(&record).map_err(|source| {
443                        MergeError::WriteCsvRecord {
444                            path: output_path.to_path_buf(),
445                            source,
446                        }
447                    })?;
448                    wrote_header = true;
449                }
450                continue;
451            }
452
453            if record.len() < 5 {
454                return Err(MergeError::InvalidCsvRecord {
455                    path: input_path.clone(),
456                    row,
457                    columns: record.len(),
458                });
459            }
460
461            let start = parse_csv_u64(&record, 0, input_path, row, "start")?;
462            let end = parse_csv_u64(&record, 1, input_path, row, "end")?;
463            if end < start {
464                return Err(MergeError::InvalidCsvSpan {
465                    path: input_path.clone(),
466                    row,
467                    start,
468                    end,
469                });
470            }
471
472            let length = end - start;
473            let new_start = next_start;
474            let new_end =
475                new_start
476                    .checked_add(length)
477                    .ok_or_else(|| MergeError::CsvSpanOverflow {
478                        path: input_path.clone(),
479                        row,
480                        start: new_start,
481                        length,
482                    })?;
483
484            let mut output_record = record
485                .iter()
486                .map(std::string::ToString::to_string)
487                .collect::<Vec<_>>();
488            output_record[0] = new_start.to_string();
489            output_record[1] = new_end.to_string();
490
491            csv_writer.write_record(&output_record).map_err(|source| {
492                MergeError::WriteCsvRecord {
493                    path: output_path.to_path_buf(),
494                    source,
495                }
496            })?;
497            next_start = new_end;
498        }
499
500        progress.inc(1);
501    }
502
503    csv_writer
504        .flush()
505        .map_err(|source| MergeError::WriteCsvRecord {
506            path: output_path.to_path_buf(),
507            source: source.into(),
508        })?;
509    let encoder = csv_writer
510        .into_inner()
511        .map_err(|source| MergeError::WriteCsvRecord {
512            path: output_path.to_path_buf(),
513            source: source.into_error().into(),
514        })?;
515    let mut writer = encoder
516        .finish()
517        .map_err(|source| MergeError::FlushDestinationFile {
518            path: output_path.to_path_buf(),
519            source,
520        })?;
521    writer
522        .flush()
523        .map_err(|source| MergeError::FlushDestinationFile {
524            path: output_path.to_path_buf(),
525            source,
526        })?;
527    Ok(())
528}
529
530fn parse_csv_u64(
531    record: &csv::StringRecord,
532    index: usize,
533    path: &Path,
534    row: usize,
535    column: &'static str,
536) -> Result<u64, MergeError> {
537    let value = record
538        .get(index)
539        .ok_or_else(|| MergeError::InvalidCsvRecord {
540            path: path.to_path_buf(),
541            row,
542            columns: record.len(),
543        })?;
544    value
545        .parse::<u64>()
546        .map_err(|source| MergeError::ParseCsvField {
547            path: path.to_path_buf(),
548            row,
549            column,
550            value: value.to_string(),
551            source,
552        })
553}
554
555fn is_metadata_header(record: &csv::StringRecord) -> bool {
556    let Some(first) = record.get(0) else {
557        return false;
558    };
559    let Some(second) = record.get(1) else {
560        return false;
561    };
562    first.eq_ignore_ascii_case("start") && second.eq_ignore_ascii_case("end")
563}
564
/// Streams everything from `reader` into `writer` through the caller-provided
/// `buffer` and returns the number of bytes copied.
///
/// Like `std::io::copy`, a read that fails with `ErrorKind::Interrupted` is
/// retried instead of aborting the copy.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` if `buffer` is empty. Reading into an
/// empty buffer always yields 0, so the copy would otherwise silently stop
/// after zero bytes. Any other read or write error is propagated.
fn copy_reader_to_writer<R: Read, W: Write>(
    reader: &mut R,
    writer: &mut W,
    buffer: &mut [u8],
) -> io::Result<u64> {
    if buffer.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "copy buffer must not be empty",
        ));
    }

    let mut total_written = 0_u64;
    loop {
        let read_bytes = match reader.read(buffer) {
            Ok(0) => return Ok(total_written),
            Ok(count) => count,
            // A signal interrupted the read before any data arrived; try again.
            Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
            Err(error) => return Err(error),
        };
        writer.write_all(&buffer[..read_bytes])?;
        // usize -> u64 is lossless on every platform Rust supports.
        total_written += read_bytes as u64;
    }
}
581
/// Returns true when `path` has the exact (case-sensitive) extension `npy`.
fn is_npy_file(path: &Path) -> bool {
    matches!(path.extension(), Some(ext) if ext == "npy")
}
585
/// Returns true when the final component of `path` is valid UTF-8 and ends
/// with `.csv.gz` (case-sensitive).
fn is_csv_gz_file(path: &Path) -> bool {
    match path.file_name().map(|name| name.to_str()) {
        Some(Some(name)) => name.ends_with(".csv.gz"),
        _ => false,
    }
}
591
/// Returns `path` with `.{extension}` appended to its final component, for
/// example `/tmp/base` becomes `/tmp/base.npy`.
///
/// The path is first rebuilt from its components, which drops trailing
/// separators (and interior `.` segments), so `/tmp/base/` also becomes
/// `/tmp/base.npy` rather than `/tmp/base/.npy`.
fn append_extension(path: &Path, extension: &str) -> PathBuf {
    let rebuilt: PathBuf = path.components().collect();
    let base = if rebuilt.as_os_str().is_empty() {
        path.to_path_buf()
    } else {
        rebuilt
    };

    let mut with_suffix: OsString = base.into_os_string();
    with_suffix.push(".");
    with_suffix.push(extension);
    PathBuf::from(with_suffix)
}
606
607fn build_progress_bar(total_files: u64) -> ProgressBar {
608    if cfg!(test) {
609        return ProgressBar::hidden();
610    }
611
612    let progress = ProgressBar::new(total_files);
613    let style = ProgressStyle::with_template(
614        "[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} files ({eta}) {msg}",
615    )
616    .unwrap_or_else(|_| ProgressStyle::default_bar())
617    .progress_chars("=>-");
618    progress.set_style(style);
619    progress.set_message("merging");
620    progress
621}
622
/// Unit and end-to-end tests for discovery, planning, sharding, and merging.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;
    use tempfile::tempdir;

    /// Writes `content` to `path` as a single gzip member.
    fn write_gzip(path: &Path, content: &str) {
        let file = File::create(path).expect("create gzip input");
        let mut encoder = GzEncoder::new(file, Compression::default());
        encoder
            .write_all(content.as_bytes())
            .expect("write gzip content");
        encoder.finish().expect("finish gzip input");
    }

    /// Reads and decompresses the whole gzip file at `path` into a `String`.
    fn read_gzip(path: &Path) -> String {
        let file = File::open(path).expect("open gzip output");
        let reader = BufReader::new(file);
        let mut decoder = MultiGzDecoder::new(reader);
        let mut content = String::new();
        decoder
            .read_to_string(&mut content)
            .expect("read gzip output");
        content
    }

    /// Nested `.npy` / `.csv.gz` files are found; unrelated files are ignored
    /// and each list comes back sorted.
    #[test]
    fn discovers_files_recursively_and_ignores_others() {
        let temp = tempdir().expect("create tempdir");
        let nested = temp.path().join("nested").join("inner");
        fs::create_dir_all(&nested).expect("create nested dirs");

        fs::write(temp.path().join("b.npy"), [1_u8, 2_u8]).expect("write npy");
        fs::write(nested.join("a.npy"), [3_u8]).expect("write nested npy");
        fs::write(temp.path().join("skip.txt"), "ignore").expect("write skip file");
        write_gzip(&nested.join("z.csv.gz"), "zeta\n");
        write_gzip(&temp.path().join("m.csv.gz"), "mu\n");

        let discovered = discover_files(temp.path()).expect("discover files");

        assert_eq!(discovered.npy_files.len(), 2);
        assert_eq!(discovered.csv_gz_files.len(), 2);
        assert!(discovered.npy_files[0] < discovered.npy_files[1]);
        assert!(discovered.csv_gz_files[0] < discovered.csv_gz_files[1]);
    }

    /// A single shard appends `.npy` / `.csv.gz` to the base output path.
    #[test]
    fn builds_single_output_plan() {
        let output = PathBuf::from("/tmp/output/base");
        let plan = build_output_plan(&output, 1).expect("build plan");
        assert_eq!(
            plan.npy_outputs,
            vec![PathBuf::from("/tmp/output/base.npy")]
        );
        assert_eq!(
            plan.csv_gz_outputs,
            vec![PathBuf::from("/tmp/output/base.csv.gz")]
        );
    }

    /// Multiple shards become zero-padded files inside the output directory.
    #[test]
    fn builds_sharded_output_plan() {
        let output = PathBuf::from("/tmp/output/shards");
        let plan = build_output_plan(&output, 3).expect("build plan");
        assert_eq!(
            plan.npy_outputs,
            vec![
                PathBuf::from("/tmp/output/shards/00000000.npy"),
                PathBuf::from("/tmp/output/shards/00000001.npy"),
                PathBuf::from("/tmp/output/shards/00000002.npy"),
            ]
        );
        assert_eq!(
            plan.csv_gz_outputs,
            vec![
                PathBuf::from("/tmp/output/shards/00000000.csv.gz"),
                PathBuf::from("/tmp/output/shards/00000001.csv.gz"),
                PathBuf::from("/tmp/output/shards/00000002.csv.gz"),
            ]
        );
    }

    /// A trailing separator on the base path must not produce `base/.npy`.
    #[test]
    fn builds_single_output_plan_with_trailing_separator() {
        let output = PathBuf::from("/tmp/output/base/");
        let plan = build_output_plan(&output, 1).expect("build plan");
        assert_eq!(
            plan.npy_outputs,
            vec![PathBuf::from("/tmp/output/base.npy")]
        );
        assert_eq!(
            plan.csv_gz_outputs,
            vec![PathBuf::from("/tmp/output/base.csv.gz")]
        );
    }

    /// Paths are dealt to shards round-robin, preserving input order.
    #[test]
    fn shards_paths_round_robin() {
        let paths = vec![
            PathBuf::from("a.npy"),
            PathBuf::from("b.npy"),
            PathBuf::from("c.npy"),
            PathBuf::from("d.npy"),
            PathBuf::from("e.npy"),
        ];

        let shards = shard_paths(&paths, 2);
        assert_eq!(
            shards[0],
            vec![
                PathBuf::from("a.npy"),
                PathBuf::from("c.npy"),
                PathBuf::from("e.npy")
            ]
        );
        assert_eq!(
            shards[1],
            vec![PathBuf::from("b.npy"), PathBuf::from("d.npy")]
        );
    }

    /// `.npy` shards are the raw byte concatenation of their inputs, in order.
    #[test]
    fn merges_npy_shards_by_byte_concatenation() {
        let temp = tempdir().expect("create tempdir");
        let inputs = temp.path().join("inputs");
        let outputs = temp.path().join("outputs");
        fs::create_dir_all(&inputs).expect("create input dir");
        fs::create_dir_all(&outputs).expect("create output dir");

        let a = inputs.join("a.npy");
        let b = inputs.join("b.npy");
        let c = inputs.join("c.npy");
        fs::write(&a, [1_u8, 2_u8]).expect("write a");
        fs::write(&b, [3_u8]).expect("write b");
        fs::write(&c, [4_u8, 5_u8]).expect("write c");

        let shards = vec![vec![a.clone(), c.clone()], vec![b.clone()]];
        let out0 = outputs.join("00000000.npy");
        let out1 = outputs.join("00000001.npy");
        let progress = ProgressBar::hidden();

        merge_npy_shards(&shards, &[out0.clone(), out1.clone()], &progress)
            .expect("merge npy shards");

        assert_eq!(
            fs::read(out0).expect("read out0"),
            vec![1_u8, 2_u8, 4_u8, 5_u8]
        );
        assert_eq!(fs::read(out1).expect("read out1"), vec![3_u8]);
    }

    /// `.csv.gz` shards keep one header and rebase spans of later inputs so
    /// they continue where the previous input's spans ended.
    #[test]
    fn merges_csv_gz_shards_by_decompress_and_recompress() {
        let temp = tempdir().expect("create tempdir");
        let inputs = temp.path().join("inputs");
        let outputs = temp.path().join("outputs");
        fs::create_dir_all(&inputs).expect("create input dir");
        fs::create_dir_all(&outputs).expect("create output dir");

        let first = inputs.join("a.csv.gz");
        let second = inputs.join("b.csv.gz");
        let third = inputs.join("c.csv.gz");
        write_gzip(
            &first,
            "start,end,id,src,loc\n0,2,id-a,src-a,1\n2,5,id-b,src-a,2\n",
        );
        write_gzip(&second, "start,end,id,src,loc\n0,1,id-c,src-b,1\n");
        write_gzip(&third, "start,end,id,src,loc\n0,4,id-d,src-c,9\n");

        let shards = vec![vec![first.clone(), third.clone()], vec![second.clone()]];
        let out0 = outputs.join("00000000.csv.gz");
        let out1 = outputs.join("00000001.csv.gz");
        let progress = ProgressBar::hidden();

        merge_csv_gz_shards(&shards, &[out0.clone(), out1.clone()], &progress)
            .expect("merge csv.gz shards");

        assert_eq!(
            read_gzip(&out0),
            "start,end,id,src,loc\n0,2,id-a,src-a,1\n2,5,id-b,src-a,2\n5,9,id-d,src-c,9\n"
        );
        assert_eq!(read_gzip(&out1), "start,end,id,src,loc\n0,1,id-c,src-b,1\n");
    }

    /// Full pipeline through `merge_files` with two shards; only checks counts
    /// and that every planned output file was created.
    #[test]
    fn runs_end_to_end_with_sharded_outputs() {
        let temp = tempdir().expect("create tempdir");
        let input_root = temp.path().join("input");
        let nested = input_root.join("nested");
        fs::create_dir_all(&nested).expect("create nested input dir");

        fs::write(input_root.join("a.npy"), [1_u8]).expect("write a.npy");
        fs::write(nested.join("b.npy"), [2_u8]).expect("write b.npy");
        write_gzip(&input_root.join("a.csv.gz"), "0,1,id-a,src-a,1\n");
        write_gzip(&nested.join("b.csv.gz"), "0,1,id-b,src-b,1\n");
        fs::write(input_root.join("ignore.bin"), [9_u8]).expect("write ignored file");

        let output_path = temp.path().join("sharded");
        let config = MergeConfig {
            input_path: input_root.clone(),
            num_files: 2,
            output_path: output_path.clone(),
        };

        let report = merge_files(&config).expect("run merge");

        assert_eq!(report.npy_inputs, 2);
        assert_eq!(report.csv_gz_inputs, 2);
        assert!(output_path.join("00000000.npy").exists());
        assert!(output_path.join("00000001.npy").exists());
        assert!(output_path.join("00000000.csv.gz").exists());
        assert!(output_path.join("00000001.csv.gz").exists());
    }

    /// Zero shards is rejected with `InvalidNumFiles(0)`.
    #[test]
    fn rejects_zero_num_files() {
        let output = PathBuf::from("anything");
        let err = build_output_plan(&output, 0).expect_err("expected invalid num-files error");
        match err {
            MergeError::InvalidNumFiles(value) => assert_eq!(value, 0),
            other => panic!("unexpected error: {other}"),
        }
    }
}