1use crate::errors::ReadError;
4use crate::{Sentinel, SharedFileType, WriteState};
5use pin_project::{pin_project, pinned_drop};
6use std::io::{ErrorKind, SeekFrom};
7use std::pin::Pin;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tokio::io;
12use tokio::io::{AsyncRead, AsyncSeek, ReadBuf};
13use uuid::Uuid;
14
/// A reader over a file that may still be written to.
///
/// Each reader owns its own file handle and shares a [`Sentinel`] with the
/// other participants; the sentinel's write state bounds how much data the
/// reader hands out in `poll_read`.
#[pin_project(PinnedDrop)]
pub struct SharedFileReader<T> {
    /// Identifier under which this reader registers and removes its waker
    /// on the sentinel.
    id: Uuid,
    /// The file handle that read and seek operations are forwarded to.
    #[pin]
    file: T,
    /// Shared state: provides the current write state, the reader waker
    /// registry and the original file used by `fork`.
    sentinel: Arc<Sentinel<T>>,
    /// Number of bytes this reader has handed out so far.
    read: AtomicUsize,
}
29
/// Fixed six-byte node ID passed to `Uuid::now_v1` whenever a reader ID is
/// generated.
static NODE_ID: &[u8; 6] = &[2, 3, 0, 6, 1, 2];
32
33impl<T> SharedFileReader<T>
34where
35 T: SharedFileType<Type = T>,
36{
37 pub(crate) fn new(file: T, sentinel: Arc<Sentinel<T>>) -> Self {
38 Self {
39 id: Uuid::now_v1(NODE_ID),
40 file,
41 sentinel,
42 read: AtomicUsize::new(0),
43 }
44 }
45
46 pub async fn fork(&self) -> Result<Self, T::OpenError> {
48 Ok(Self {
49 id: Uuid::now_v1(NODE_ID),
50 file: self.sentinel.original.open_ro().await?,
51 sentinel: self.sentinel.clone(),
52 read: AtomicUsize::new(0),
53 })
54 }
55}
56
57impl<T> SharedFileReader<T> {
58 pub fn file_size(&self) -> FileSize {
60 match self.sentinel.state.load() {
61 WriteState::Pending(commited, _written) => FileSize::AtLeast(commited),
62 WriteState::Completed(size) => FileSize::Exactly(size),
63 WriteState::Failed => FileSize::Error,
64 }
65 }
66}
67
/// The size of a shared file as seen by a reader.
#[derive(Debug, Copy, Clone)]
pub enum FileSize {
    /// The file holds at least this many bytes.
    AtLeast(usize),
    /// The file holds exactly this many bytes.
    Exactly(usize),
    /// No size is available because of an error.
    Error,
}

impl FileSize {
    /// Returns the number of bytes the file is known to contain at minimum.
    ///
    /// Both [`FileSize::AtLeast`] and [`FileSize::Exactly`] carry such a
    /// value; [`FileSize::Error`] yields `None`.
    pub fn minimum_size(&self) -> Option<usize> {
        match *self {
            Self::AtLeast(len) | Self::Exactly(len) => Some(len),
            Self::Error => None,
        }
    }

    /// Returns the exact size, which is only available for
    /// [`FileSize::Exactly`]; every other variant yields `None`.
    pub fn exact_size(&self) -> Option<usize> {
        match *self {
            Self::Exactly(len) => Some(len),
            Self::AtLeast(_) | Self::Error => None,
        }
    }
}
99
100#[pinned_drop]
101impl<T> PinnedDrop for SharedFileReader<T> {
102 fn drop(mut self: Pin<&mut Self>) {
103 self.sentinel.remove_reader_waker(&self.id)
104 }
105}
106
107impl<T> AsyncRead for SharedFileReader<T>
108where
109 T: AsyncRead,
110{
111 fn poll_read(
112 self: Pin<&mut Self>,
113 cx: &mut Context<'_>,
114 buf: &mut ReadBuf<'_>,
115 ) -> Poll<io::Result<()>> {
116 let read_so_far = self.read.load(Ordering::Acquire);
117
118 let current_total = match self.sentinel.state.load() {
119 WriteState::Pending(committed, _written) => {
120 if read_so_far == committed {
123 self.sentinel.register_reader_waker(self.id, cx.waker());
124 return Poll::Pending;
125 }
126 committed
127 }
128 WriteState::Completed(count) => {
129 if read_so_far == count {
131 return Poll::Ready(Ok(()));
132 }
133 count
134 }
135 WriteState::Failed => {
136 return Poll::Ready(Err(io::Error::new(
137 ErrorKind::BrokenPipe,
138 ReadError::FileClosed,
139 )))
140 }
141 };
142
143 let read_at_most = (current_total - read_so_far).min(buf.remaining());
146 let mut smaller_buf = buf.take(read_at_most);
147 let read_offset = smaller_buf.filled().len();
148
149 let this = self.project();
150
151 if let Poll::Ready(result) = this.file.poll_read(cx, &mut smaller_buf) {
152 this.sentinel.remove_reader_waker(this.id);
153 if let Err(e) = result {
154 return Poll::Ready(Err(e));
155 }
156
157 let read_now = smaller_buf.filled().len();
159 if read_now != read_offset {
160 unsafe {
162 buf.assume_init(read_now);
163 }
164 buf.set_filled(read_now);
165
166 let read = read_so_far + (read_now - read_offset);
167 this.read.store(read, Ordering::Release);
168 return Poll::Ready(result);
169 }
170
171 match this.sentinel.state.load() {
174 WriteState::Pending(_, _) => {}
175 WriteState::Completed(_) => return Poll::Ready(Ok(())),
176 WriteState::Failed => {
177 return Poll::Ready(Err(io::Error::new(
178 ErrorKind::BrokenPipe,
179 ReadError::FileClosed,
180 )))
181 }
182 }
183 }
184
185 buf.advance(0);
187
188 this.sentinel.register_reader_waker(*this.id, cx.waker());
190 Poll::Pending
191 }
192}
193
194impl<T> AsyncSeek for SharedFileReader<T>
195where
196 T: AsyncSeek,
197{
198 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
199 let this = self.project();
200 this.file.start_seek(position)
201 }
202
203 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
204 let this = self.project();
205 this.file.poll_complete(cx)
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Only `Exactly` reports an exact size.
    #[test]
    fn test_exact_size() {
        let cases = [
            (FileSize::Exactly(42), Some(42)),
            (FileSize::AtLeast(41), None),
            (FileSize::Error, None),
        ];
        for (size, expected) in cases.iter() {
            assert_eq!(size.exact_size(), *expected);
        }
    }

    /// `Exactly` and `AtLeast` report a minimum size; `Error` does not.
    #[test]
    fn test_minimum_size() {
        let cases = [
            (FileSize::Exactly(42), Some(42)),
            (FileSize::AtLeast(41), Some(41)),
            (FileSize::Error, None),
        ];
        for (size, expected) in cases.iter() {
            assert_eq!(size.minimum_size(), *expected);
        }
    }
}