shared_files/
reader.rs

1//! File reading functionality, notably the [`SharedFileReader`] type.
2
3use crate::errors::ReadError;
4use crate::{Sentinel, SharedFileType, WriteState};
5use pin_project::{pin_project, pinned_drop};
6use std::io::{ErrorKind, SeekFrom};
7use std::pin::Pin;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tokio::io;
12use tokio::io::{AsyncRead, AsyncSeek, ReadBuf};
13use uuid::Uuid;
14
15/// A reader for the shared temporary file.
16#[pin_project(PinnedDrop)]
17pub struct SharedFileReader<T> {
18    /// The ID of the reader.
19    id: Uuid,
20    /// The file to read from.
21    #[pin]
22    file: T,
23    /// The sentinel value to keep the file alive.
24    sentinel: Arc<Sentinel<T>>,
25    /// The number of bytes read. Used to keep track
26    /// of how many bytes need to be read from the underlying buffer.
27    read: AtomicUsize,
28}
29
/// Node ID passed to the version-1 UUID generator for reader IDs.
///
/// Reader IDs never leave the current system, so the concrete bytes are arbitrary.
static NODE_ID: &[u8; 6] = &[2, 3, 0, 6, 1, 2];
32
33impl<T> SharedFileReader<T>
34where
35    T: SharedFileType<Type = T>,
36{
37    pub(crate) fn new(file: T, sentinel: Arc<Sentinel<T>>) -> Self {
38        Self {
39            id: Uuid::now_v1(NODE_ID),
40            file,
41            sentinel,
42            read: AtomicUsize::new(0),
43        }
44    }
45
46    /// Creates a new, independent reader.
47    pub async fn fork(&self) -> Result<Self, T::OpenError> {
48        Ok(Self {
49            id: Uuid::now_v1(NODE_ID),
50            file: self.sentinel.original.open_ro().await?,
51            sentinel: self.sentinel.clone(),
52            read: AtomicUsize::new(0),
53        })
54    }
55}
56
57impl<T> SharedFileReader<T> {
58    /// Gets the (expected) size of the file to read.
59    pub fn file_size(&self) -> FileSize {
60        match self.sentinel.state.load() {
61            WriteState::Pending(commited, _written) => FileSize::AtLeast(commited),
62            WriteState::Completed(size) => FileSize::Exactly(size),
63            WriteState::Failed => FileSize::Error,
64        }
65    }
66}
67
/// The size of the file to read, as far as it is currently known.
///
/// Values can be compared with `==`, which makes them usable in `assert_eq!`
/// and ordinary equality checks.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum FileSize {
    /// The file is not entirely written yet. The contained amount is the
    /// minimum number of bytes known to exist.
    AtLeast(usize),
    /// The file is completely written and has exactly the contained number of bytes.
    Exactly(usize),
    /// An error occurred while writing the file; reading may not complete.
    Error,
}
79
80impl FileSize {
81    /// Returns the minimum or exact file size if it is known, or [`None`] otherwise.
82    pub fn minimum_size(&self) -> Option<usize> {
83        if let Self::AtLeast(len) = self {
84            Some(*len)
85        } else {
86            self.exact_size()
87        }
88    }
89
90    /// Returns the exact file size if it is known, or [`None`] otherwise.
91    pub fn exact_size(&self) -> Option<usize> {
92        if let Self::Exactly(len) = self {
93            Some(*len)
94        } else {
95            None
96        }
97    }
98}
99
100#[pinned_drop]
101impl<T> PinnedDrop for SharedFileReader<T> {
102    fn drop(mut self: Pin<&mut Self>) {
103        self.sentinel.remove_reader_waker(&self.id)
104    }
105}
106
107impl<T> AsyncRead for SharedFileReader<T>
108where
109    T: AsyncRead,
110{
111    fn poll_read(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &mut ReadBuf<'_>,
115    ) -> Poll<io::Result<()>> {
116        let read_so_far = self.read.load(Ordering::Acquire);
117
118        let current_total = match self.sentinel.state.load() {
119            WriteState::Pending(committed, _written) => {
120                // If the number of committed bytes is the same as the number
121                // of bytes we have already read, try again later.
122                if read_so_far == committed {
123                    self.sentinel.register_reader_waker(self.id, cx.waker());
124                    return Poll::Pending;
125                }
126                committed
127            }
128            WriteState::Completed(count) => {
129                // If we have read all there is, we're done.
130                if read_so_far == count {
131                    return Poll::Ready(Ok(()));
132                }
133                count
134            }
135            WriteState::Failed => {
136                return Poll::Ready(Err(io::Error::new(
137                    ErrorKind::BrokenPipe,
138                    ReadError::FileClosed,
139                )))
140            }
141        };
142
143        // Ensure to not read more bytes than were actually written
144        // by constraining the actual buffer to a smaller one if needed.
145        let read_at_most = (current_total - read_so_far).min(buf.remaining());
146        let mut smaller_buf = buf.take(read_at_most);
147        let read_offset = smaller_buf.filled().len();
148
149        let this = self.project();
150
151        if let Poll::Ready(result) = this.file.poll_read(cx, &mut smaller_buf) {
152            this.sentinel.remove_reader_waker(this.id);
153            if let Err(e) = result {
154                return Poll::Ready(Err(e));
155            }
156
157            // If the buffer was advanced, return the result.
158            let read_now = smaller_buf.filled().len();
159            if read_now != read_offset {
160                // Advance the parent buffer.
161                unsafe {
162                    buf.assume_init(read_now);
163                }
164                buf.set_filled(read_now);
165
166                let read = read_so_far + (read_now - read_offset);
167                this.read.store(read, Ordering::Release);
168                return Poll::Ready(result);
169            }
170
171            // If the buffer was not advanced and source file is completed (or in fail state),
172            // return as-is. Otherwise, reset and wait.
173            match this.sentinel.state.load() {
174                WriteState::Pending(_, _) => {}
175                WriteState::Completed(_) => return Poll::Ready(Ok(())),
176                WriteState::Failed => {
177                    return Poll::Ready(Err(io::Error::new(
178                        ErrorKind::BrokenPipe,
179                        ReadError::FileClosed,
180                    )))
181                }
182            }
183        }
184
185        // "Advance" the parent buffer.
186        buf.advance(0);
187
188        // Re-register waker and try again.
189        this.sentinel.register_reader_waker(*this.id, cx.waker());
190        Poll::Pending
191    }
192}
193
194impl<T> AsyncSeek for SharedFileReader<T>
195where
196    T: AsyncSeek,
197{
198    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
199        let this = self.project();
200        this.file.start_seek(position)
201    }
202
203    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
204        let this = self.project();
205        this.file.poll_complete(cx)
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Only `Exactly` reports an exact size.
    #[test]
    fn test_exact_size() {
        let cases = [
            (FileSize::Exactly(42), Some(42)),
            (FileSize::AtLeast(41), None),
            (FileSize::Error, None),
        ];

        for &(size, expected) in &cases {
            assert_eq!(size.exact_size(), expected);
        }
    }

    /// `Exactly` and `AtLeast` both report a minimum size; `Error` does not.
    #[test]
    fn test_minimum_size() {
        let cases = [
            (FileSize::Exactly(42), Some(42)),
            (FileSize::AtLeast(41), Some(41)),
            (FileSize::Error, None),
        ];

        for &(size, expected) in &cases {
            assert_eq!(size.minimum_size(), expected);
        }
    }
}