extsort_iter/run/
file_run.rs1use std::{
2 io::{self, ErrorKind, Read},
3 mem::{self, MaybeUninit},
4 num::NonZeroUsize,
5};
6
7use crate::tape::Tape;
8
9use super::Run;
10
/// A byte source that an [`ExternalRun`] pulls its entries from.
///
/// In addition to being a [`Read`]er, a backing can be told that the run
/// is done with it.
pub trait RunBacking: Read {
    /// Signals that the run has no further entries to hand out.
    ///
    /// `ExternalRun::next` invokes this every time it is called on an
    /// exhausted run, so implementations must tolerate repeated calls.
    fn finalize(&mut self);
}
18impl RunBacking for Box<dyn Read + Send> {
19 fn finalize(&mut self) {
20 let dummy = Box::new(io::Cursor::new(&[]));
24 *self = dummy;
25 }
26}
27
28pub struct ExternalRun<T, TBacking>
33where
34 TBacking: RunBacking,
35{
36 source: TBacking,
38 buffer: Vec<MaybeUninit<T>>,
43 read_idx: usize,
46 remaining_entries: usize,
49}
50
51impl<T, B> Drop for ExternalRun<T, B>
52where
53 B: RunBacking,
54{
55 fn drop(&mut self) {
56 if mem::needs_drop::<T>() {
58 while self.next().is_some() {}
60 }
61 }
62}
63
64pub fn create_buffer_run<T>(source: Vec<T>) -> ExternalRun<T, Box<dyn Read + Send>> {
67 let buffer: Vec<MaybeUninit<T>> = unsafe {
68 core::mem::transmute(source)
71 };
72
73 let remaining_entries = buffer.len();
74
75 ExternalRun {
76 source: Box::new(io::Cursor::new(&[])),
77 buffer,
78 read_idx: 0,
79 remaining_entries,
80 }
81}
82
83impl<T, TBacking> ExternalRun<T, TBacking>
84where
85 TBacking: RunBacking,
86{
87 pub fn from_tape(tape: Tape<TBacking>, buffer_size: NonZeroUsize) -> Self {
88 let num_entries = tape.num_entries();
89 let source = tape.into_backing();
90
91 let mut buffer = Vec::with_capacity(buffer_size.into());
92 for _ in 0..buffer_size.into() {
93 buffer.push(MaybeUninit::uninit());
94 }
95 let mut res = Self {
96 buffer,
97 read_idx: 0,
98 remaining_entries: num_entries,
99 source,
100 };
101
102 res.refill_buffer();
103
104 res
105 }
106
107 fn refill_buffer(&mut self) {
112 fn read_with_retry(source: &mut impl Read, buffer: &mut [u8]) -> io::Result<usize> {
114 loop {
115 match source.read(buffer) {
116 Ok(size) => break Ok(size),
117 Err(e) if e.kind() == ErrorKind::Interrupted => {}
118 err => break err,
119 }
120 }
121 }
122
123 fn try_read_exact(source: &mut impl Read, mut buffer: &mut [u8]) -> usize {
127 let mut bytes_read = 0;
128 while !buffer.is_empty() {
129 let read = read_with_retry(source, buffer).expect("Unable to perform read on FileRun. This means that the file was modified from under us!");
130 if read == 0 {
131 break;
132 }
133 buffer = &mut buffer[read..];
134 bytes_read += read;
135 }
136
137 bytes_read
138 }
139
140 let item_size = std::mem::size_of::<T>();
141
142 if item_size == 0 {
145 self.read_idx = 0;
146 return;
147 }
148
149 let slice = unsafe {
150 let start = self.buffer.as_mut_ptr() as *mut u8;
151 std::slice::from_raw_parts_mut(start, self.buffer.len() * item_size)
152 };
153
154 let bytes_read = try_read_exact(&mut self.source, slice);
155 assert_eq!(
156 0,
157 bytes_read % item_size,
158 "The size of the file does not match anymore! was it modified?"
159 );
160 let remaining_size = bytes_read / item_size;
161 self.buffer.truncate(remaining_size);
162
163 self.read_idx = 0;
164 }
165}
166
167impl<T, TBacking> Run<T> for ExternalRun<T, TBacking>
168where
169 TBacking: RunBacking,
170{
171 fn peek(&self) -> Option<&T> {
173 if self.remaining_entries == 0 {
174 None
175 } else {
176 unsafe { Some(self.buffer[self.read_idx].assume_init_ref()) }
181 }
182 }
183
184 fn next(&mut self) -> Option<T> {
186 if self.remaining_entries == 0 {
187 self.source.finalize();
188 return None;
189 }
190
191 let result = unsafe { self.buffer[self.read_idx].assume_init_read() };
199
200 self.read_idx += 1;
204 self.remaining_entries -= 1;
205
206 if self.read_idx >= self.buffer.len() {
210 self.refill_buffer();
211 }
212
213 Some(result)
214 }
215
216 fn remaining_items(&self) -> usize {
217 self.remaining_entries
218 }
219}
220
#[cfg(test)]
mod test {
    use std::fmt::Debug;

    use crate::tape::vec_to_tape;

    use super::*;

    /// In-memory backing for the tests; finalizing moves the cursor to the
    /// end of its data.
    impl RunBacking for std::io::Cursor<Vec<u8>> {
        fn finalize(&mut self) {
            let end = self.get_ref().len() as u64;
            self.set_position(end);
        }
    }

    /// Writes `data` to a tape, reads it back through an `ExternalRun` with
    /// the given buffer size and checks that the same entries come out.
    fn test_file_run<T>(data: Vec<T>, buffer_size: NonZeroUsize)
    where
        T: Clone + Eq + Debug,
    {
        let mut run = ExternalRun::from_tape(vec_to_tape(data.clone()), buffer_size);
        assert_eq!(data.len(), run.remaining_items());

        let mut collected = Vec::new();
        while let Some(entry) = run.next() {
            collected.push(entry);
        }
        assert_eq!(data, collected);
    }

    /// Dropping a partially consumed run of heap-owning entries must not
    /// blow up.
    #[test]
    fn test_drop() {
        let pattern: Vec<i32> = (1..5).collect();
        let data = vec![pattern; 20];
        let buffer_size = NonZeroUsize::new(4096).unwrap();

        let mut run: ExternalRun<Vec<i32>, _> =
            ExternalRun::from_tape(vec_to_tape(data), buffer_size);
        for _ in 0..10 {
            let _ = run.next();
        }
        drop(run);
    }

    #[test]
    fn works_with_vecs() {
        let row: Vec<i32> = (1..100).collect();
        let data = vec![row; 10];

        test_file_run(data, NonZeroUsize::new(2).unwrap());
    }

    #[test]
    fn works_with_zst() {
        test_file_run(vec![(); 10], NonZeroUsize::new(2).unwrap());
    }

    #[test]
    fn works_with_larger_buffer() {
        let size = NonZeroUsize::new(20).unwrap();

        test_file_run(vec![(); 10], size);
        test_file_run(vec![1337; 10], size);
    }

    #[test]
    fn works_with_empty_data() {
        let size = NonZeroUsize::new(10).unwrap();

        test_file_run(Vec::<()>::new(), size);
        test_file_run(Vec::<i32>::new(), size);
    }
}