1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use std::{fs::File, io::{BufReader, BufRead}};

use itertools::Itertools;

use crate::pipeline::Node;

/// A loader with a key generating function
pub struct KeyedLoader {
    files: Vec<String>,
    file_sizes: Vec<usize>,
    delimeter: String,
}

impl KeyedLoader {
    pub fn new(files: &[&str], delimeter: &str) -> Self {
        // Get file sizes
        let file_sizes: Vec<usize> = files.iter().map(|f| {
            let file = File::open(f).unwrap();
            let reader = BufReader::new(file);
            let mut delimeter_count = 0;
            if delimeter == "\n" {
                delimeter_count = reader.lines().count();
            } else {
                for line in reader.lines().flatten() {
                    delimeter_count += line.matches(delimeter).count();
                }
                delimeter_count += 1; // Since delimeters divide the examples, there should be 1 more example than delimeter
            }
            delimeter_count
        }).collect();

        KeyedLoader {
            files: files.iter().map(|s| s.to_string()).collect(),
            file_sizes,
            delimeter: delimeter.to_string()
        }
    }
}

impl Node for KeyedLoader {
    type Input = usize;
    type Output = String;

    fn process(&mut self, input: Vec<Self::Input>) -> Vec<Self::Output> {
        // Get bounds to load from
        let (min, max) = input.iter().minmax().into_option().unwrap().to_owned();
        let (mut min, mut max) = (*min, *max);
        let (mut min_file, mut max_file) = (0, 0);
        let mut counter = 0;
        for (index, file_size) in self.file_sizes.iter().enumerate() {
            counter += file_size;
            if counter > min {
                min_file = index;
                min -= counter + file_size;
            }
            if counter + file_size > max {
                max_file = index;
                max -= counter + file_size;
            }
        }
        // Sort inputs and keep track of order (orig order, sorted indexes)
        let mut sorted_inputs: Vec<(usize, usize)> = input.into_iter().enumerate().collect();
        sorted_inputs.sort_by(|a, b| a.1.cmp(&b.1));

        // Load all segments from min to max
        let mut buffer = Vec::with_capacity(sorted_inputs.len());
        for file_index in min_file..max_file + 1 {
            let file = File::open(&self.files[file_index]).unwrap();
            let reader = BufReader::new(file);
            
            let mut index_counter = 0;
            let mut segment_counter = if file_index == min_file {min} else {0};
            let segments_to_take = if file_index == max_file {max} else {self.file_sizes[file_index]};
            if self.delimeter == "\n" {
                for line in reader.lines().flatten() {
                    if segment_counter == sorted_inputs[index_counter].1 {
                        buffer.push(line);
                        index_counter += 1;
                        if index_counter == sorted_inputs.len() {return buffer;}
                    }
                    segment_counter += 1;
                }
            } else {
                let mut intermediate_segment = "".to_string();
                for line in reader.lines().flatten() {
                    let line_segments: Vec<&str> = line.split(&self.delimeter).collect();

                    if segment_counter == sorted_inputs[index_counter].1 {
                        buffer.push(format!("{}{}", intermediate_segment, line_segments[0]));
                        index_counter += 1;
                        if index_counter == sorted_inputs.len() {return buffer;}
                    }
                    for line_segment in line_segments.iter().take((segments_to_take - counter).min(line_segments.len() - 1)) {
                        if segment_counter == sorted_inputs[index_counter].1 {
                            buffer.push(line_segment.to_string());
                            index_counter += 1;
                            if index_counter == sorted_inputs.len() {return buffer;}
                        }
                    }
                    intermediate_segment = line_segments.last().unwrap().to_string();
                        
                    segment_counter += line_segments.len() - 1;
                    if segment_counter >= segments_to_take {
                        break;
                    }
                }
            }
        }

        buffer
    }

    fn reset(&mut self) {
        // Recalculate file sizes
        let file_sizes = self.files.iter().map(|f| {
            let file = File::open(f).unwrap();
            let reader = BufReader::new(file);
            let mut delimeter_count = 0;
            if self.delimeter == "\n" {
                delimeter_count = reader.lines().count();
            } else {
                for line in reader.lines().flatten() {
                    delimeter_count += line.matches(&self.delimeter).count();
                }
                delimeter_count += 1; // Since delimeters divide the examples, there should be 1 more example than delimeter
            }
            delimeter_count
        }).collect();
        self.file_sizes = file_sizes;
    }

    fn data_remaining(&self) -> usize {
        usize::MAX
    }
}