1use crate::errors::{CompleteWritingError, WriteError};
4use crate::{FilePath, Sentinel, SharedFileType, WriteState};
5use crossbeam::atomic::AtomicCell;
6use pin_project::{pin_project, pinned_drop};
7use std::io::{Error, ErrorKind, IoSlice};
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tokio::io;
13use tokio::io::AsyncWrite;
14
/// Writer half of a shared file.
///
/// Every write, flush, sync and shutdown on the wrapped `file` is mirrored into the
/// shared [`Sentinel`]'s write state, and waiting readers are woken through it.
///
/// Field order is significant: `file` is dropped before `sentinel`.
#[pin_project(PinnedDrop)]
pub struct SharedFileWriter<T> {
    /// The wrapped file. It is structurally pinned so the `AsyncWrite` impl can poll it
    /// as `Pin<&mut T>`.
    #[pin]
    file: T,
    /// Shared state. It holds the write progress (`WriteState`) and is used to wake readers.
    sentinel: Arc<Sentinel<T>>,
}
30
31impl<T> SharedFileWriter<T> {
32 pub(crate) fn new(file: T, sentinel: Arc<Sentinel<T>>) -> Self {
33 Self { file, sentinel }
34 }
35
36 pub fn file_path(&self) -> &PathBuf
38 where
39 T: FilePath,
40 {
41 self.file.file_path()
42 }
43
44 pub async fn sync_all(&self) -> Result<(), T::SyncError>
46 where
47 T: SharedFileType,
48 {
49 self.file.sync_all().await?;
50 Self::sync_committed_and_written(&self.sentinel);
51 self.sentinel.wake_readers();
52 Ok(())
53 }
54
55 pub async fn sync_data(&self) -> Result<(), T::SyncError>
57 where
58 T: SharedFileType,
59 {
60 self.file.sync_data().await?;
61 Self::sync_committed_and_written(&self.sentinel);
62 self.sentinel.wake_readers();
63 Ok(())
64 }
65
66 pub async fn complete(self) -> Result<(), CompleteWritingError>
71 where
72 T: SharedFileType,
73 {
74 if self.sync_all().await.is_err() {
75 return Err(CompleteWritingError::SyncError);
76 }
77 self.complete_no_sync()
78 }
79
80 pub fn complete_no_sync(self) -> Result<(), CompleteWritingError> {
85 self.finalize_state()
86 }
87
88 fn sync_committed_and_written(sentinel: &Arc<Sentinel<T>>) {
90 match sentinel.state.load() {
91 WriteState::Pending(_committed, written) => {
92 sentinel.state.store(WriteState::Pending(written, written));
93 }
94 WriteState::Completed(_) => {}
95 WriteState::Failed => {}
96 }
97 }
98
99 fn finalize_state(&self) -> Result<(), CompleteWritingError> {
103 let result = match self.sentinel.state.load() {
104 WriteState::Pending(_committed, written) => {
105 assert_eq!(_committed, written, "The number of committed bytes is less than the number of written bytes - call sync before dropping");
106 self.sentinel.state.store(WriteState::Completed(written));
107 Ok(())
108 }
109 WriteState::Completed(_) => Ok(()),
110 WriteState::Failed => Err(CompleteWritingError::FileWritingFailed),
111 };
112
113 self.sentinel.wake_readers();
114 result
115 }
116
117 fn update_state(state: &AtomicCell<WriteState>, written: usize) -> Result<usize, Error> {
125 match state.load() {
126 WriteState::Pending(committed, previously_written) => {
127 let count = previously_written + written;
128 state.store(WriteState::Pending(committed, count));
129 Ok(count)
130 }
131 WriteState::Completed(count) => {
132 if written != 0 {
135 return Err(Error::new(ErrorKind::BrokenPipe, WriteError::FileClosed));
136 }
137 Ok(count)
138 }
139 WriteState::Failed => Err(Error::from(ErrorKind::Other)),
140 }
141 }
142
143 fn handle_poll_write_result(
148 sentinel: &Sentinel<T>,
149 poll: Poll<Result<usize, Error>>,
150 ) -> Poll<Result<usize, Error>> {
151 match poll {
152 Poll::Ready(result) => match result {
153 Ok(written) => match Self::update_state(&sentinel.state, written) {
154 Ok(_) => Poll::Ready(Ok(written)),
155 Err(e) => Poll::Ready(Err(e)),
156 },
157 Err(e) => {
158 sentinel.state.store(WriteState::Failed);
159 sentinel.wake_readers();
160 Poll::Ready(Err(e))
161 }
162 },
163 Poll::Pending => Poll::Pending,
164 }
165 }
166}
167
168#[pinned_drop]
169impl<T> PinnedDrop for SharedFileWriter<T> {
170 fn drop(mut self: Pin<&mut Self>) {
171 self.finalize_state().ok();
172 }
173}
174
175impl<T> AsyncWrite for SharedFileWriter<T>
176where
177 T: AsyncWrite,
178{
179 fn poll_write(
180 self: Pin<&mut Self>,
181 cx: &mut Context<'_>,
182 buf: &[u8],
183 ) -> Poll<io::Result<usize>> {
184 let this = self.project();
185 let poll = this.file.poll_write(cx, buf);
186 Self::handle_poll_write_result(this.sentinel, poll)
187 }
188
189 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
190 let this = self.project();
191 match this.file.poll_flush(cx) {
192 Poll::Ready(result) => match result {
193 Ok(()) => {
194 Self::sync_committed_and_written(this.sentinel);
195 this.sentinel.wake_readers();
196 Poll::Ready(Ok(()))
197 }
198 Err(e) => {
199 this.sentinel.state.store(WriteState::Failed);
200 this.sentinel.wake_readers();
201 Poll::Ready(Err(e))
202 }
203 },
204 Poll::Pending => Poll::Pending,
205 }
206 }
207
208 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
209 let this = self.project();
210 match this.file.poll_shutdown(cx) {
211 Poll::Ready(result) => match result {
212 Ok(()) => {
213 if let WriteState::Pending(_committed, written) = this.sentinel.state.load() {
214 debug_assert_eq!(_committed, written);
215 this.sentinel.state.store(WriteState::Completed(written));
216 }
217
218 Poll::Ready(Ok(()))
219 }
220 Err(e) => {
221 this.sentinel.state.store(WriteState::Failed);
222 Poll::Ready(Err(e))
223 }
224 },
225 Poll::Pending => Poll::Pending,
226 }
227 }
228
229 fn poll_write_vectored(
230 self: Pin<&mut Self>,
231 cx: &mut Context<'_>,
232 bufs: &[IoSlice<'_>],
233 ) -> Poll<Result<usize, Error>> {
234 let this = self.project();
235 let poll = this.file.poll_write_vectored(cx, bufs);
236 Self::handle_poll_write_result(this.sentinel, poll)
237 }
238
239 fn is_write_vectored(&self) -> bool {
240 self.file.is_write_vectored()
241 }
242}