shared_files/
writer.rs

1//! File writing functionality, notably the [`SharedFileWriter`] type.
2
3use crate::errors::{CompleteWritingError, WriteError};
4use crate::{FilePath, Sentinel, SharedFileType, WriteState};
5use crossbeam::atomic::AtomicCell;
6use pin_project::{pin_project, pinned_drop};
7use std::io::{Error, ErrorKind, IoSlice};
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tokio::io;
13use tokio::io::AsyncWrite;
14
15/// A writer for the shared temporary file.
16///
17/// ## Dropping the writer
18///
19/// Note that while dropping the writer while implicitly change it to "completed",
20/// you must manually call [`SharedFileWriter::sync_all`] or [`SharedFileWriter::sync_data`]
21/// to ensure all content is flushed to the underlying buffer.
22#[pin_project(PinnedDrop)]
23pub struct SharedFileWriter<T> {
24    /// The file to write to.
25    #[pin]
26    file: T,
27    /// The sentinel value to keep the file alive.
28    sentinel: Arc<Sentinel<T>>,
29}
30
31impl<T> SharedFileWriter<T> {
32    pub(crate) fn new(file: T, sentinel: Arc<Sentinel<T>>) -> Self {
33        Self { file, sentinel }
34    }
35
36    /// Gets the file path.
37    pub fn file_path(&self) -> &PathBuf
38    where
39        T: FilePath,
40    {
41        self.file.file_path()
42    }
43
44    /// Synchronizes data and metadata with the disk buffer.
45    pub async fn sync_all(&self) -> Result<(), T::SyncError>
46    where
47        T: SharedFileType,
48    {
49        self.file.sync_all().await?;
50        Self::sync_committed_and_written(&self.sentinel);
51        self.sentinel.wake_readers();
52        Ok(())
53    }
54
55    /// Synchronizes data with the disk buffer.
56    pub async fn sync_data(&self) -> Result<(), T::SyncError>
57    where
58        T: SharedFileType,
59    {
60        self.file.sync_data().await?;
61        Self::sync_committed_and_written(&self.sentinel);
62        self.sentinel.wake_readers();
63        Ok(())
64    }
65
66    /// Completes the writing operation.
67    ///
68    /// Use [`complete_no_sync`](Self::complete_no_sync) if you do not wish
69    /// to sync the file to disk.
70    pub async fn complete(self) -> Result<(), CompleteWritingError>
71    where
72        T: SharedFileType,
73    {
74        if self.sync_all().await.is_err() {
75            return Err(CompleteWritingError::SyncError);
76        }
77        self.complete_no_sync()
78    }
79
80    /// Completes the writing operation.
81    ///
82    /// If you need to sync the file to disk, consider calling
83    /// [`complete`](Self::complete) instead.
84    pub fn complete_no_sync(self) -> Result<(), CompleteWritingError> {
85        self.finalize_state()
86    }
87
88    /// Synchronizes the number of written bytes with the number of committed bytes.
89    fn sync_committed_and_written(sentinel: &Arc<Sentinel<T>>) {
90        match sentinel.state.load() {
91            WriteState::Pending(_committed, written) => {
92                sentinel.state.store(WriteState::Pending(written, written));
93            }
94            WriteState::Completed(_) => {}
95            WriteState::Failed => {}
96        }
97    }
98
99    /// Sets the state to finalized.
100    ///
101    /// See also [`update_state`](Self::update_state) for increasing the byte count.
102    fn finalize_state(&self) -> Result<(), CompleteWritingError> {
103        let result = match self.sentinel.state.load() {
104            WriteState::Pending(_committed, written) => {
105                assert_eq!(_committed, written, "The number of committed bytes is less than the number of written bytes - call sync before dropping");
106                self.sentinel.state.store(WriteState::Completed(written));
107                Ok(())
108            }
109            WriteState::Completed(_) => Ok(()),
110            WriteState::Failed => Err(CompleteWritingError::FileWritingFailed),
111        };
112
113        self.sentinel.wake_readers();
114        result
115    }
116
117    /// Updates the internal byte count with the specified number of bytes written.
118    /// Will produce an error if the update failed.
119    ///
120    /// ## Returns
121    /// Returns the number of bytes written in total.
122    ///
123    /// See also [`finalize_state`](Self::finalize_state) for finalizing the write.
124    fn update_state(state: &AtomicCell<WriteState>, written: usize) -> Result<usize, Error> {
125        match state.load() {
126            WriteState::Pending(committed, previously_written) => {
127                let count = previously_written + written;
128                state.store(WriteState::Pending(committed, count));
129                Ok(count)
130            }
131            WriteState::Completed(count) => {
132                // Ensure we do not try to write more data after completing
133                // the file.
134                if written != 0 {
135                    return Err(Error::new(ErrorKind::BrokenPipe, WriteError::FileClosed));
136                }
137                Ok(count)
138            }
139            WriteState::Failed => Err(Error::from(ErrorKind::Other)),
140        }
141    }
142
143    /// Processes a [`Poll`] result from a write operation.
144    ///
145    /// This will update the internal byte count and produce an error
146    /// if the update failed.
147    fn handle_poll_write_result(
148        sentinel: &Sentinel<T>,
149        poll: Poll<Result<usize, Error>>,
150    ) -> Poll<Result<usize, Error>> {
151        match poll {
152            Poll::Ready(result) => match result {
153                Ok(written) => match Self::update_state(&sentinel.state, written) {
154                    Ok(_) => Poll::Ready(Ok(written)),
155                    Err(e) => Poll::Ready(Err(e)),
156                },
157                Err(e) => {
158                    sentinel.state.store(WriteState::Failed);
159                    sentinel.wake_readers();
160                    Poll::Ready(Err(e))
161                }
162            },
163            Poll::Pending => Poll::Pending,
164        }
165    }
166}
167
168#[pinned_drop]
169impl<T> PinnedDrop for SharedFileWriter<T> {
170    fn drop(mut self: Pin<&mut Self>) {
171        self.finalize_state().ok();
172    }
173}
174
175impl<T> AsyncWrite for SharedFileWriter<T>
176where
177    T: AsyncWrite,
178{
179    fn poll_write(
180        self: Pin<&mut Self>,
181        cx: &mut Context<'_>,
182        buf: &[u8],
183    ) -> Poll<io::Result<usize>> {
184        let this = self.project();
185        let poll = this.file.poll_write(cx, buf);
186        Self::handle_poll_write_result(this.sentinel, poll)
187    }
188
189    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
190        let this = self.project();
191        match this.file.poll_flush(cx) {
192            Poll::Ready(result) => match result {
193                Ok(()) => {
194                    Self::sync_committed_and_written(this.sentinel);
195                    this.sentinel.wake_readers();
196                    Poll::Ready(Ok(()))
197                }
198                Err(e) => {
199                    this.sentinel.state.store(WriteState::Failed);
200                    this.sentinel.wake_readers();
201                    Poll::Ready(Err(e))
202                }
203            },
204            Poll::Pending => Poll::Pending,
205        }
206    }
207
208    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
209        let this = self.project();
210        match this.file.poll_shutdown(cx) {
211            Poll::Ready(result) => match result {
212                Ok(()) => {
213                    if let WriteState::Pending(_committed, written) = this.sentinel.state.load() {
214                        debug_assert_eq!(_committed, written);
215                        this.sentinel.state.store(WriteState::Completed(written));
216                    }
217
218                    Poll::Ready(Ok(()))
219                }
220                Err(e) => {
221                    this.sentinel.state.store(WriteState::Failed);
222                    Poll::Ready(Err(e))
223                }
224            },
225            Poll::Pending => Poll::Pending,
226        }
227    }
228
229    fn poll_write_vectored(
230        self: Pin<&mut Self>,
231        cx: &mut Context<'_>,
232        bufs: &[IoSlice<'_>],
233    ) -> Poll<Result<usize, Error>> {
234        let this = self.project();
235        let poll = this.file.poll_write_vectored(cx, bufs);
236        Self::handle_poll_write_result(this.sentinel, poll)
237    }
238
239    fn is_write_vectored(&self) -> bool {
240        self.file.is_write_vectored()
241    }
242}