use rand::{prelude::SliceRandom, thread_rng};
use std::{
fs::File,
io::{BufRead, BufReader},
};
use crate::pipeline::*;
/// Loads text examples from a set of files and serves them in a randomized order.
///
/// Examples are separated by `delimeter` (a newline unless changed with
/// [`RandomLoader::with_delimeter`]) and are numbered consecutively across
/// `files`, in the order the files are listed.
pub struct RandomLoader {
    /// Paths of the files examples are read from; their order defines example indexes.
    files: Vec<String>,
    /// Text that separates two consecutive examples inside a file.
    delimeter: String,
    /// Shuffled example indexes scheduled for loading; empty until the node is reset.
    load_order: Vec<usize>,
    /// Position inside `load_order` of the next index to serve.
    currently_loaded_index: usize,
    /// Exclusive upper bound on the example indexes that may be scheduled.
    max_index: usize,
    /// Lower bound (inclusive) on the example indexes that may be scheduled.
    min_index: usize,
}

impl RandomLoader {
    /// Creates a loader over `files` with a newline delimiter and no index limits.
    ///
    /// The load order starts out empty.
    pub fn new(files: &[String]) -> Self {
        RandomLoader {
            files: files.to_vec(),
            delimeter: "\n".to_string(),
            load_order: Vec::new(),
            currently_loaded_index: 0,
            min_index: 0,
            max_index: usize::MAX,
        }
    }

    /// Creates a loader over every non-directory entry directly inside `path`.
    ///
    /// Sub-directories are skipped because they cannot be read as example files.
    /// The collected paths are sorted: `read_dir` yields entries in a platform-
    /// and filesystem-dependent order, and example indexes (and therefore the
    /// `min_index` / `max_index` window) depend on the file order.
    ///
    /// # Panics
    ///
    /// Panics if the directory or one of its entries cannot be read, or if an
    /// entry's path is not valid UTF-8.
    pub fn from_directory(path: &str) -> Self {
        let mut files: Vec<String> = std::fs::read_dir(path)
            .unwrap_or_else(|err| panic!("failed to read directory {}: {}", path, err))
            .map(|entry| {
                entry
                    .unwrap_or_else(|err| {
                        panic!("failed to read an entry of {}: {}", path, err)
                    })
                    .path()
            })
            .filter(|entry_path| !entry_path.is_dir())
            .map(|entry_path| {
                entry_path
                    .to_str()
                    .unwrap_or_else(|| panic!("path {:?} is not valid UTF-8", entry_path))
                    .to_string()
            })
            .collect();
        files.sort_unstable();
        // Reuse the defaults from `new` so both constructors always agree.
        RandomLoader {
            files,
            ..Self::new(&[])
        }
    }

    /// Returns the loader with `delimeter` as the example separator.
    pub fn with_delimeter(self, delimeter: String) -> Self {
        RandomLoader { delimeter, ..self }
    }

    /// Returns the loader with `max_index` as the exclusive upper example index.
    pub fn max_index(self, max_index: usize) -> Self {
        RandomLoader { max_index, ..self }
    }

    /// Returns the loader with `min_index` as the first example index to use.
    pub fn min_index(self, min_index: usize) -> Self {
        RandomLoader { min_index, ..self }
    }
}
impl Node for RandomLoader {
    type Input = Vec<()>;
    type Output = Vec<String>;

    /// Loads the next batch of examples and returns them in shuffled order.
    ///
    /// Only the length of `input` is used: it is the number of examples
    /// requested. The next `input.len()` indexes of the load order (fewer if
    /// the order is nearly exhausted) are sorted so that the files can be
    /// scanned front to back in a single pass; scanning stops as soon as the
    /// last wanted example has been read. An empty batch, or an exhausted or
    /// never-initialised load order, yields an empty vector without touching
    /// any file.
    ///
    /// The cursor advances by the number of indexes taken from the load order,
    /// so a scheduled index that no longer exists in the files is dropped
    /// instead of being requested again on the next call.
    ///
    /// # Panics
    ///
    /// Panics if one of the files cannot be opened.
    fn process(&mut self, input: Self::Input) -> Self::Output {
        let window_start = self.currently_loaded_index.min(self.load_order.len());
        let window_end = self
            .load_order
            .len()
            .min(window_start.saturating_add(input.len()));
        let mut wanted = self.load_order[window_start..window_end].to_vec();
        wanted.sort_unstable();
        self.currently_loaded_index = window_end;

        let mut loaded = Vec::with_capacity(wanted.len());
        let mut wanted = wanted.into_iter().peekable();
        // Index of the example currently being visited, counted across all files.
        let mut current_index = 0usize;
        for path in &self.files {
            if wanted.peek().is_none() {
                break;
            }
            visit_file_examples(path, &self.delimeter, |example| {
                if wanted.peek() == Some(&current_index) {
                    loaded.push(example);
                    wanted.next();
                }
                current_index += 1;
                // Keep scanning only while something is still wanted.
                wanted.peek().is_some()
            });
        }
        loaded.shuffle(&mut thread_rng());
        loaded
    }

    /// Counts the available examples and builds a fresh shuffled load order.
    ///
    /// Files are counted in order until the running total reaches `max_index`.
    /// The index range `min_index..min(total, max_index)` is cut into blocks of
    /// 100 000 consecutive indexes; the blocks are shuffled and so are the
    /// indexes inside each block. Every block is clamped to the end of the
    /// range, so the load order never contains an index that does not exist.
    /// A `min_index` beyond the range produces an empty load order.
    ///
    /// # Panics
    ///
    /// Panics if one of the files cannot be opened.
    fn reset(&mut self) {
        const BLOCK_SIZE: usize = 100_000;

        let mut total_examples = 0usize;
        for path in &self.files {
            let lines = readable_lines(path);
            total_examples += if self.delimeter == "\n" {
                lines.count()
            } else {
                // N delimiters split a file into N + 1 examples.
                lines
                    .map(|line| line.matches(self.delimeter.as_str()).count())
                    .sum::<usize>()
                    + 1
            };
            if total_examples >= self.max_index {
                break;
            }
        }

        let range_end = total_examples.min(self.max_index);
        let range_start = self.min_index.min(range_end);

        let mut rng = thread_rng();
        let mut block_starts: Vec<usize> =
            (range_start..range_end).step_by(BLOCK_SIZE).collect();
        block_starts.shuffle(&mut rng);

        let mut load_order = Vec::with_capacity(range_end - range_start);
        for block_start in block_starts {
            let block_end = block_start.saturating_add(BLOCK_SIZE).min(range_end);
            let first_of_block = load_order.len();
            load_order.extend(block_start..block_end);
            load_order[first_of_block..].shuffle(&mut rng);
        }

        self.load_order = load_order;
        self.currently_loaded_index = 0;
    }

    /// Returns how many scheduled example indexes have not been served yet.
    ///
    /// The `_before` argument is ignored.
    fn data_remaining(&self, _before: usize) -> usize {
        self.load_order
            .len()
            .saturating_sub(self.currently_loaded_index)
    }
}

/// Opens `path` and yields its lines without their terminators.
///
/// Iteration ends at the first line that cannot be read, so a persistent I/O
/// error cannot turn into an endless stream of `Err` items.
///
/// # Panics
///
/// Panics if the file cannot be opened.
fn readable_lines(path: &str) -> impl Iterator<Item = String> {
    let file =
        File::open(path).unwrap_or_else(|err| panic!("failed to open {}: {}", path, err));
    BufReader::new(file).lines().map_while(Result::ok)
}

/// Calls `visit` with every example stored in `path`, in file order, until
/// `visit` returns `false` or the file ends.
///
/// With the newline delimiter every line is one example. With any other
/// delimiter the lines of the file are joined without their terminators and
/// split on the delimiter within each line, so a file containing N delimiters
/// yields N + 1 examples (the text after the last delimiter is the final
/// example, possibly empty).
fn visit_file_examples<F>(path: &str, delimeter: &str, mut visit: F)
where
    F: FnMut(String) -> bool,
{
    let lines = readable_lines(path);
    if delimeter == "\n" {
        for line in lines {
            if !visit(line) {
                return;
            }
        }
        return;
    }

    // Text of the example that is still being assembled.
    let mut pending = String::new();
    for line in lines {
        let mut pieces = line.split(delimeter);
        // The first piece continues the example started on an earlier line.
        if let Some(first) = pieces.next() {
            pending.push_str(first);
        }
        // Every further piece sits behind a delimiter, which completes `pending`.
        for piece in pieces {
            let finished = std::mem::replace(&mut pending, piece.to_string());
            if !visit(finished) {
                return;
            }
        }
    }
    visit(pending);
}
impl ExplicitNode<Vec<()>, Vec<String>> for RandomLoader {
fn process(&mut self, input: Vec<()>) -> Vec<String> {
<Self as Node>::process(self, input)
}
fn data_remaining(&self, before: usize) -> usize {
<Self as Node>::data_remaining(self, before)
}
fn reset(&mut self) {
<Self as Node>::reset(self);
}
}