dataflow/pipeline/loader/
keyed.rs

1use std::{
2    fs::File,
3    io::{BufRead, BufReader},
4};
5
6use itertools::Itertools;
7
8use crate::pipeline::*;
9
/// A loader that fetches delimiter-separated text segments ("examples") by global index.
///
/// The segments of all `files` are numbered consecutively in file order: the first file
/// holds indices `0..file_sizes[0]`, the next file continues where it left off, and so on.
#[derive(Clone, Debug)]
pub struct KeyedLoader {
    /// Paths of the files to read, in indexing order.
    files: Vec<String>,
    /// Number of segments in each entry of `files` (same order), as counted when the loader
    /// is constructed or reset.
    file_sizes: Vec<usize>,
    /// Separator between segments; `"\n"` means one segment per line.
    delimeter: String,
}
17
18impl KeyedLoader {
19    pub fn new(files: &[&str], delimeter: &str) -> Self {
20        // Get file sizes
21        let file_sizes: Vec<usize> = files
22            .iter()
23            .map(|f| {
24                let file = File::open(f).unwrap();
25                let reader = BufReader::new(file);
26                let mut delimeter_count = 0;
27                if delimeter == "\n" {
28                    delimeter_count = reader.lines().count();
29                } else {
30                    for line in reader.lines().flatten() {
31                        delimeter_count += line.matches(delimeter).count();
32                    }
33                    delimeter_count += 1; // Since delimeters divide the examples, there should be 1 more example than delimeter
34                }
35                delimeter_count
36            })
37            .collect();
38
39        KeyedLoader {
40            files: files.iter().map(|s| s.to_string()).collect(),
41            file_sizes,
42            delimeter: delimeter.to_string(),
43        }
44    }
45}
46
47impl Node<Vec<usize>> for KeyedLoader {
48    type Output = Vec<String>;
49
50    fn process(&mut self, input: Vec<usize>) -> Self::Output {
51        // Get bounds to load from
52        let (min, max) = input.iter().minmax().into_option().unwrap().to_owned();
53        let (mut min, mut max) = (*min, *max);
54        let (mut min_file, mut max_file) = (0, 0);
55        let mut counter = 0;
56        for (index, file_size) in self.file_sizes.iter().enumerate() {
57            counter += file_size;
58            if counter > min {
59                min_file = index;
60                min -= counter + file_size;
61            }
62            if counter + file_size > max {
63                max_file = index;
64                max -= counter + file_size;
65            }
66        }
67        // Sort inputs and keep track of order (orig order, sorted indexes)
68        let mut sorted_inputs: Vec<(usize, usize)> = input.into_iter().enumerate().collect();
69        sorted_inputs.sort_by(|a, b| a.1.cmp(&b.1));
70
71        // Load all segments from min to max
72        let mut buffer = Vec::with_capacity(sorted_inputs.len());
73        for file_index in min_file..max_file + 1 {
74            let file = File::open(&self.files[file_index]).unwrap();
75            let reader = BufReader::new(file);
76
77            let mut index_counter = 0;
78            let mut segment_counter = if file_index == min_file { min } else { 0 };
79            let segments_to_take = if file_index == max_file {
80                max
81            } else {
82                self.file_sizes[file_index]
83            };
84            if self.delimeter == "\n" {
85                for line in reader.lines().flatten() {
86                    if segment_counter == sorted_inputs[index_counter].1 {
87                        buffer.push(line);
88                        index_counter += 1;
89                        if index_counter == sorted_inputs.len() {
90                            return buffer;
91                        }
92                    }
93                    segment_counter += 1;
94                }
95            } else {
96                let mut intermediate_segment = "".to_string();
97                for line in reader.lines().flatten() {
98                    let line_segments: Vec<&str> = line.split(&self.delimeter).collect();
99
100                    if segment_counter == sorted_inputs[index_counter].1 {
101                        buffer.push(format!("{}{}", intermediate_segment, line_segments[0]));
102                        index_counter += 1;
103                        if index_counter == sorted_inputs.len() {
104                            return buffer;
105                        }
106                    }
107                    for line_segment in line_segments
108                        .iter()
109                        .take((segments_to_take - counter).min(line_segments.len() - 1))
110                    {
111                        if segment_counter == sorted_inputs[index_counter].1 {
112                            buffer.push(line_segment.to_string());
113                            index_counter += 1;
114                            if index_counter == sorted_inputs.len() {
115                                return buffer;
116                            }
117                        }
118                    }
119                    intermediate_segment = line_segments.last().unwrap().to_string();
120
121                    segment_counter += line_segments.len() - 1;
122                    if segment_counter >= segments_to_take {
123                        break;
124                    }
125                }
126            }
127        }
128
129        buffer
130    }
131
132    fn reset(&mut self) {
133        // Recalculate file sizes
134        let file_sizes = self
135            .files
136            .iter()
137            .map(|f| {
138                let file = File::open(f).unwrap();
139                let reader = BufReader::new(file);
140                let mut delimeter_count = 0;
141                if self.delimeter == "\n" {
142                    delimeter_count = reader.lines().count();
143                } else {
144                    for line in reader.lines().flatten() {
145                        delimeter_count += line.matches(&self.delimeter).count();
146                    }
147                    delimeter_count += 1; // Since delimeters divide the examples, there should be 1 more example than delimeter
148                }
149                delimeter_count
150            })
151            .collect();
152        self.file_sizes = file_sizes;
153    }
154
155    fn data_remaining(&self, before: usize) -> usize {
156        before
157    }
158}