use std::collections::HashMap;
use std::path::{Path, PathBuf};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use log::{debug, info};
use rand::prelude::*;
use rand_chacha::ChaChaRng;
use try_from::TryInto;
use crate::error::{Error, Result};
use crate::io::{open_data, Compression};
use crate::split::{
single::{ProportionSplit, RowSplit, Split, SplitEnum},
splits::{SplitSelection, Splits},
writer::SplitWriter,
};
/// Builder that collects the settings for a [`Splitter`].
///
/// Start with [`SplitterBuilder::new`], chain any of the optional setters, then call
/// [`SplitterBuilder::build`].
pub struct SplitterBuilder {
    /// Path of the input data file.
    input: PathBuf,
    /// Optional output path prefix (the splitter falls back to the input path when unset).
    output_prefix: Option<PathBuf>,
    /// Row-count or proportion splits to produce.
    splits: Splits,
    /// Seed for the row-assignment RNG; `None` seeds it from OS entropy in `build`.
    seed: Option<u64>,
    /// Expected number of input rows, if known.
    total_rows: Option<u64>,
    /// Chunking setting handed through to the splitter.
    chunk_size: Option<u64>,
    /// Read the input through a CSV reader.
    csv: bool,
    /// Treat the first input line as a header (defaults to `true`).
    has_header: bool,
    /// Compression of the input file (defaults to uncompressed).
    input_compression: Compression,
    /// Compression applied to the outputs (defaults to uncompressed).
    output_compression: Compression,
}

impl SplitterBuilder {
    /// Creates a builder for `input` with the given splits and default options.
    ///
    /// A non-empty `row_splits` wins and `prop_splits` is ignored; otherwise `prop_splits`
    /// is converted via `TryInto`.
    ///
    /// # Errors
    /// Returns an error when `row_splits` is empty and converting `prop_splits` fails.
    pub fn new<P: AsRef<Path>>(
        input: &P,
        row_splits: Vec<RowSplit>,
        prop_splits: Vec<ProportionSplit>,
    ) -> Result<Self> {
        let splits = if !row_splits.is_empty() {
            Splits::Rows(row_splits.into())
        } else {
            Splits::Proportions(prop_splits.try_into()?)
        };
        Ok(Self {
            input: input.as_ref().to_path_buf(),
            output_prefix: None,
            splits,
            seed: None,
            total_rows: None,
            chunk_size: None,
            csv: false,
            has_header: true,
            input_compression: Compression::Uncompressed,
            output_compression: Compression::Uncompressed,
        })
    }

    /// Fixes the RNG seed so that split assignment is reproducible.
    pub fn seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Sets the path prefix used for output files.
    pub fn output_prefix(self, output_prefix: PathBuf) -> Self {
        Self {
            output_prefix: Some(output_prefix),
            ..self
        }
    }

    /// Sets the chunk size.
    pub fn chunk_size(self, chunk_size: u64) -> Self {
        Self {
            chunk_size: Some(chunk_size),
            ..self
        }
    }

    /// Declares how many rows the input holds.
    pub fn total_rows(self, total_rows: u64) -> Self {
        Self {
            total_rows: Some(total_rows),
            ..self
        }
    }

    /// Sets the compression of the input file.
    pub fn input_compression(self, input_compression: Compression) -> Self {
        Self {
            input_compression,
            ..self
        }
    }

    /// Sets the compression of the output files.
    pub fn output_compression(self, output_compression: Compression) -> Self {
        Self {
            output_compression,
            ..self
        }
    }

    /// Chooses whether the input is read as CSV.
    pub fn csv(self, csv: bool) -> Self {
        Self { csv, ..self }
    }

    /// Chooses whether the first input line is a header.
    pub fn has_header(self, has_header: bool) -> Self {
        Self { has_header, ..self }
    }

    /// Consumes the builder and produces a [`Splitter`].
    ///
    /// The RNG is seeded from `seed` when one was given, otherwise from OS entropy.
    pub fn build(self) -> Result<Splitter> {
        let Self {
            input,
            output_prefix,
            splits,
            seed,
            total_rows,
            chunk_size,
            csv,
            has_header,
            input_compression,
            output_compression,
        } = self;
        let rng = seed.map_or_else(ChaChaRng::from_entropy, ChaChaRng::seed_from_u64);
        Ok(Splitter {
            input,
            splits,
            rng,
            output_prefix,
            chunk_size,
            total_rows,
            input_compression,
            output_compression,
            csv,
            has_header,
        })
    }
}
pub struct Splitter {
input: PathBuf,
splits: Splits,
rng: ChaChaRng,
output_prefix: Option<PathBuf>,
chunk_size: Option<u64>,
total_rows: Option<u64>,
input_compression: Compression,
output_compression: Compression,
csv: bool,
has_header: bool,
}
impl Splitter {
pub fn run(mut self) -> Result<()> {
let multi = MultiProgress::new();
let progress: HashMap<String, ProgressBar> = match (&self.splits, self.total_rows) {
(Splits::Proportions(p), Some(t)) => p
.splits
.iter()
.map(|p| {
let style = ProgressStyle::default_bar()
.template("{msg:<10}: [{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/~{len:7} (ETA: {eta_precise})")
.progress_chars("█▉▊▋▌▍▎▏ ");
let split_total = p.proportion * t as f64;
let pb = multi.add(ProgressBar::new(split_total as u64));
pb.set_draw_delta(10); pb.set_message(&p.name());
pb.set_style(style);
(p.name().to_string(), pb)
})
.collect(),
(Splits::Proportions(p), None) => p
.splits
.iter()
.map(|p| {
let style = ProgressStyle::default_bar()
.template("{msg:<10}: [{elapsed_precise}] {spinner:.green} {pos:>7}");
let pb = multi.add(ProgressBar::new_spinner());
pb.set_draw_delta(10); pb.set_style(style);
pb.set_message(&p.name());
(p.name().to_string(), pb)
})
.collect(),
(Splits::Rows(r), _) => r
.splits
.iter()
.map(|r| {
let style = ProgressStyle::default_bar()
.template("{msg:<10}: [{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} (ETA: {eta_precise})")
.progress_chars("█▉▊▋▌▍▎▏ ");
let pb = multi.add(ProgressBar::new(r.total as u64));
pb.set_draw_delta(10); pb.set_message(&r.name());
pb.set_style(style);
(r.name().to_string().clone(), pb)
})
.collect()
};
let mut senders = HashMap::new();
let mut chunk_writers = Vec::new();
let output_path = match self.output_prefix {
Some(ref f) => f.clone(),
None => self.input.clone(),
};
match &self.splits {
Splits::Proportions(p) => {
for split in p.iter() {
let split = SplitEnum::Proportion((*split).clone());
let (split_sender, mut split_chunk_writers) = SplitWriter::new(
&output_path,
&split,
self.chunk_size,
self.total_rows,
self.output_compression,
)?;
senders.insert(split.name().to_string(), split_sender);
chunk_writers.append(&mut split_chunk_writers);
}
}
Splits::Rows(r) => {
for split in r.iter() {
let split = SplitEnum::Rows((*split).clone());
let (split_sender, mut split_chunk_writers) = SplitWriter::new(
&output_path,
&split,
self.chunk_size,
self.total_rows,
self.output_compression,
)?;
senders.insert(split.name().to_string(), split_sender);
chunk_writers.append(&mut split_chunk_writers);
}
}
};
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(chunk_writers.len() + 2)
.thread_name(|num| format!("thread-{}", num))
.start_handler(|num| debug!("thread {} starting", num))
.exit_handler(|num| debug!("thread {} finishing", num))
.build()
.unwrap();
pool.scope(move |scope| {
info!("Reading data from {}", self.input.to_str().unwrap());
let reader_builder = if self.csv {
let mut reader_builder = csv::ReaderBuilder::new();
reader_builder.has_headers(false);
Some(reader_builder)
} else {
None
};
let mut reader = open_data(&self.input, self.input_compression, reader_builder)?;
if self.has_header {
info!("Writing header to files");
let header = match reader.read_line() {
Some(h) => h?,
None => return Err(Error::EmptyFile),
};
for sender in senders.values_mut() {
sender.send_all(&header)?;
}
}
scope.spawn(move |_| multi.join().unwrap());
let has_header = self.has_header;
{
for writer in chunk_writers {
scope.spawn(move |_| {
let mut chunk_id = writer.chunk_id;
let mut rows_sent_to_chunk = 0;
let mut file = writer.output(chunk_id).expect("Could not open file");
let mut header: Header<String> = if has_header {
Header::None
} else {
Header::Disabled
};
for row in writer.receiver.iter() {
if header == Header::None {
header = Header::Some(row.clone());
}
if let Some(chunk_size) = writer.chunk_size {
if rows_sent_to_chunk > (chunk_size) {
chunk_id = chunk_id.map(|c| c + 2);
file = writer.output(chunk_id).expect("Could not open file");
if let Header::Some(h) = header.as_ref() {
writer
.handle_row(&mut file, h)
.expect("Could not write row to file");
}
rows_sent_to_chunk = 1
}
}
writer
.handle_row(&mut file, &row)
.expect("Could not write row to file");
rows_sent_to_chunk += 1;
}
})
}
}
info!("Reading lines");
while let Some(record) = reader.read_line() {
let split = self.splits.get_split(&mut self.rng);
match split {
SplitSelection::Some(split) => {
match senders.get_mut(split).unwrap().send(record.unwrap()) {
Ok(_) => progress[split].inc(1),
Err(e) => return Err(e),
}
}
SplitSelection::None => continue,
SplitSelection::Done => break,
}
}
progress.values().for_each(|f| f.finish_at_current_pos());
info!("Finished writing to files");
for (_, sender) in senders {
sender.finish();
}
Ok(())
})?;
Ok(())
}
}
/// Header state tracked by each chunk-writer thread.
#[derive(Debug, PartialEq)]
enum Header<T> {
    /// The input has no header line.
    Disabled,
    /// A header is expected but has not been received yet.
    None,
    /// The captured header line.
    Some(T),
}

impl Header<String> {
    /// Borrows the stored header, turning `Header<String>` into `Header<&str>` without cloning.
    fn as_ref(&self) -> Header<&str> {
        match *self {
            Self::Some(ref text) => Header::Some(text.as_str()),
            Self::None => Header::None,
            Self::Disabled => Header::Disabled,
        }
    }
}