extsort_iter/run/
file_run.rs

1use std::{
2    io::{self, ErrorKind, Read},
3    mem::{self, MaybeUninit},
4    num::NonZeroUsize,
5};
6
7use crate::tape::Tape;
8
9use super::Run;
10
/// A backing for a run. Basically, we extend the Read trait
/// with an option for premature resource release.
pub trait RunBacking: Read {
    /// allow the backing to optionally free its resources early,
    /// before the value itself is dropped.
    /// after this call, no more data should be returned by `read`.
    fn finalize(&mut self);
}
18impl RunBacking for Box<dyn Read + Send> {
19    fn finalize(&mut self) {
20        // for boxed backings, we can release the underlying resources
21        // by replacing it with a dummy implementation.
22        // this will drop the old value.
23        let dummy = Box::new(io::Cursor::new(&[]));
24        *self = dummy;
25    }
26}
27
/// A run backed by a file on disk.
/// The file is deleted when the run is dropped.
/// Our TBacking type will be File for the real
/// usage in the library and Cursor for testing
pub struct ExternalRun<T, TBacking>
where
    TBacking: RunBacking,
{
    /// the source file that is backing our run.
    /// it is `finalize`d once the run has been fully consumed.
    source: TBacking,
    /// the data maintained in our sort buffer.
    /// we maintain the invariant that all the entries
    /// from the read_idx to the end of the buffer
    /// are actually initialized.
    buffer: Vec<MaybeUninit<T>>,
    /// index of the next entry to hand out; everything from this
    /// index to the end of `buffer` is initialized.
    read_idx: usize,
    /// the remaining entries for this run, counting both buffered
    /// entries and those still on the backing.
    /// used for the size_hint and to be able to deal with zero sized types
    remaining_entries: usize,
}
50
impl<T, B> Drop for ExternalRun<T, B>
where
    B: RunBacking,
{
    fn drop(&mut self) {
        // if T has drop glue, we must run it for every element we have not
        // yet handed out — both the buffered ones and the ones still sitting
        // in the backing. draining via `next` covers both. for types without
        // a destructor we can skip the drain entirely.
        if mem::needs_drop::<T>() {
            // drop all elements by reading from the source until all items are exhausted
            while self.next().is_some() {}
        }
    }
}
63
64/// Creates a new FileRun Object that uses the provided source as its
65/// buffer, but is not actually backed by anything on disk
66pub fn create_buffer_run<T>(source: Vec<T>) -> ExternalRun<T, Box<dyn Read + Send>> {
67    let buffer: Vec<MaybeUninit<T>> = unsafe {
68        // we are only transmuting our Vec<T> to Vec<MaybeUninit<T>>.
69        // this is guaranteed to have the same binary representation.
70        core::mem::transmute(source)
71    };
72
73    let remaining_entries = buffer.len();
74
75    ExternalRun {
76        source: Box::new(io::Cursor::new(&[])),
77        buffer,
78        read_idx: 0,
79        remaining_entries,
80    }
81}
82
83impl<T, TBacking> ExternalRun<T, TBacking>
84where
85    TBacking: RunBacking,
86{
87    pub fn from_tape(tape: Tape<TBacking>, buffer_size: NonZeroUsize) -> Self {
88        let num_entries = tape.num_entries();
89        let source = tape.into_backing();
90
91        let mut buffer = Vec::with_capacity(buffer_size.into());
92        for _ in 0..buffer_size.into() {
93            buffer.push(MaybeUninit::uninit());
94        }
95        let mut res = Self {
96            buffer,
97            read_idx: 0,
98            remaining_entries: num_entries,
99            source,
100        };
101
102        res.refill_buffer();
103
104        res
105    }
106
107    /// refills the read buffer.
108    /// this should only be called if the read_idx is at the end of the buffer
109    ///
110    /// This function may panic on IO errors
111    fn refill_buffer(&mut self) {
112        /// keep retrying the read if it returns with an interrupted error.
113        fn read_with_retry(source: &mut impl Read, buffer: &mut [u8]) -> io::Result<usize> {
114            loop {
115                match source.read(buffer) {
116                    Ok(size) => break Ok(size),
117                    Err(e) if e.kind() == ErrorKind::Interrupted => {}
118                    err => break err,
119                }
120            }
121        }
122
123        /// try to read exactly the requested number of bytes
124        /// This function may only return less than the number of requested bytes
125        /// when the end of the run is reached.
126        fn try_read_exact(source: &mut impl Read, mut buffer: &mut [u8]) -> usize {
127            let mut bytes_read = 0;
128            while !buffer.is_empty() {
129                let read = read_with_retry(source, buffer).expect("Unable to perform read on FileRun. This means that the file was modified from under us!");
130                if read == 0 {
131                    break;
132                }
133                buffer = &mut buffer[read..];
134                bytes_read += read;
135            }
136
137            bytes_read
138        }
139
140        let item_size = std::mem::size_of::<T>();
141
142        // for ZSTs it really does not make sense to try to read them back from our
143        // io backing, so we just reset the read index
144        if item_size == 0 {
145            self.read_idx = 0;
146            return;
147        }
148
149        let slice = unsafe {
150            let start = self.buffer.as_mut_ptr() as *mut u8;
151            std::slice::from_raw_parts_mut(start, self.buffer.len() * item_size)
152        };
153
154        let bytes_read = try_read_exact(&mut self.source, slice);
155        assert_eq!(
156            0,
157            bytes_read % item_size,
158            "The size of the file does not match anymore! was it modified?"
159        );
160        let remaining_size = bytes_read / item_size;
161        self.buffer.truncate(remaining_size);
162
163        self.read_idx = 0;
164    }
165}
166
impl<T, TBacking> Run<T> for ExternalRun<T, TBacking>
where
    TBacking: RunBacking,
{
    /// Peek at the next entry in the run without consuming it.
    /// Returns None once the run is exhausted.
    fn peek(&self) -> Option<&T> {
        if self.remaining_entries == 0 {
            None
        } else {
            // SAFETY:
            // we always ensure that everything from the read_idx to the
            // end of the buffer is properly initialized from the backing file.
            // `next` refills the buffer whenever read_idx reaches its end, so
            // while remaining_entries > 0 the read_idx is inside the buffer
            // bounds and points at an initialized element.
            unsafe { Some(self.buffer[self.read_idx].assume_init_ref()) }
        }
    }

    /// Get the next item from the run and advance its position
    fn next(&mut self) -> Option<T> {
        if self.remaining_entries == 0 {
            // the run is exhausted: give the backing a chance to release its
            // resources early instead of waiting for the run to be dropped.
            self.source.finalize();
            return None;
        }

        // when we have reached this point, we can be certain that we are inside the
        // buffer bounds.

        // SAFETY:
        // we always ensure that everything from the read_idx to the
        // end of the buffer is properly initialized from the backing file.
        // so while the read_idx is inside the buffer bounds, it must be valid.
        // assume_init_read moves the value out; we bump read_idx right below
        // so this slot is never read again.
        let result = unsafe { self.buffer[self.read_idx].assume_init_read() };

        // we consumed the value at the read_index so we need to make sure that we increment it
        // to maintain the buffer invariant
        // as well as decrement the remaining entries in our run.
        self.read_idx += 1;
        self.remaining_entries -= 1;

        // we check if we need to refill the buffer in case we have reached the end
        // we do this here to make sure that the peek is always inside
        // the buffer as long as there are still items
        if self.read_idx >= self.buffer.len() {
            self.refill_buffer();
        }

        Some(result)
    }

    /// the number of items not yet yielded by this run.
    fn remaining_items(&self) -> usize {
        self.remaining_entries
    }
}
220
#[cfg(test)]
mod test {
    use std::fmt::Debug;

    impl RunBacking for std::io::Cursor<Vec<u8>> {
        fn finalize(&mut self) {
            // a cursor owns no external resources; jumping to the end of the
            // underlying buffer is enough to stop it from yielding bytes.
            let end = self.get_ref().len();
            self.set_position(end as u64);
        }
    }

    use crate::tape::vec_to_tape;

    use super::*;

    /// round-trips `data` through a tape-backed run and checks that every
    /// element comes back out in the original order.
    fn test_file_run<T>(data: Vec<T>, buffer_size: NonZeroUsize)
    where
        T: Clone + Eq + Debug,
    {
        let tape = vec_to_tape(data.clone());
        let mut run = ExternalRun::from_tape(tape, buffer_size);

        assert_eq!(data.len(), run.remaining_items());
        let mut collected = Vec::new();
        while let Some(item) = run.next() {
            collected.push(item);
        }
        assert_eq!(data, collected);
    }

    #[test]
    fn test_drop() {
        // partially consume a run of droppable elements, then drop it;
        // the Drop impl must drain and drop the rest without issues.
        let row: Vec<i32> = (1..5).collect();
        let data = vec![row; 20];
        let tape = vec_to_tape(data);
        let mut run: ExternalRun<Vec<i32>, _> =
            ExternalRun::from_tape(tape, NonZeroUsize::new(4096).unwrap());
        for _ in 0..10 {
            run.next();
        }
        drop(run);
    }

    #[test]
    fn works_with_vecs() {
        let row: Vec<i32> = (1..100).collect();
        let data = vec![row; 10];

        test_file_run(data, NonZeroUsize::new(2).unwrap());
    }

    #[test]
    fn works_with_zst() {
        test_file_run(vec![(); 10], NonZeroUsize::new(2).unwrap());
    }

    #[test]
    fn works_with_larger_buffer() {
        // buffer is larger than the number of entries
        let size = NonZeroUsize::new(20).unwrap();
        test_file_run(vec![(); 10], size);
        test_file_run(vec![1337; 10], size);
    }

    #[test]
    fn works_with_empty_data() {
        let size = NonZeroUsize::new(10).unwrap();
        test_file_run(vec![(); 0], size);
        test_file_run(vec![1; 0], size);
    }
}
296}