async_fd_lock/
write_guard.rs

1use std::{
2    io::{self, BufRead, Read, Seek, Write},
3    pin::Pin,
4};
5
6use cfg_if::cfg_if;
7use pin_project::{pin_project, pinned_drop};
8
9use crate::sys::{AsOpenFile, AsOpenFileExt, RwLockGuard};
10
11/// An exclusive lock on a file.
12///
13/// # Panics
14///
15/// Dropping this type may panic if the lock fails to unlock.
16#[must_use = "if unused the RwLock will immediately unlock"]
17#[derive(Debug)]
18#[pin_project(PinnedDrop)]
19pub struct RwLockWriteGuard<T: AsOpenFile> {
20    #[pin]
21    file: Option<T>,
22}
23
24impl<T: AsOpenFile> RwLockWriteGuard<T> {
25    pub(crate) fn new<F: AsOpenFile>(file: T, guard: RwLockGuard<F>) -> Self {
26        guard.defuse();
27        Self { file: Some(file) }
28    }
29
30    pub fn inner(&self) -> &T {
31        self.file
32            .as_ref()
33            .expect("file only removed during release")
34    }
35
36    pub fn inner_mut(&mut self) -> &mut T {
37        self.file
38            .as_mut()
39            .expect("file only removed during release")
40    }
41
42    pub fn inner_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
43        self.project()
44            .file
45            .as_pin_mut()
46            .expect("file only removed during release")
47    }
48
49    /// Releases the lock, returning the inner file.
50    pub fn release(mut self) -> io::Result<T> {
51        let file = self.file.take().expect("file only removed during release");
52        file.release_lock_blocking()?;
53        Ok(file)
54    }
55}
56
57/// Delegate [`Read`] to the inner file.
58impl<T: AsOpenFile + Read> Read for RwLockWriteGuard<T> {
59    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
60        self.inner_mut().read(buf)
61    }
62
63    fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
64        self.inner_mut().read_vectored(bufs)
65    }
66
67    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
68        self.inner_mut().read_to_end(buf)
69    }
70
71    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
72        self.inner_mut().read_to_string(buf)
73    }
74
75    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
76        self.inner_mut().read_exact(buf)
77    }
78
79    fn by_ref(&mut self) -> &mut Self
80    where
81        Self: Sized,
82    {
83        self
84    }
85}
86
87impl<T: AsOpenFile + BufRead> BufRead for RwLockWriteGuard<T> {
88    fn fill_buf(&mut self) -> io::Result<&[u8]> {
89        self.inner_mut().fill_buf()
90    }
91
92    fn consume(&mut self, amt: usize) {
93        self.inner_mut().consume(amt)
94    }
95
96    fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> io::Result<usize> {
97        self.inner_mut().read_until(byte, buf)
98    }
99
100    fn read_line(&mut self, buf: &mut String) -> io::Result<usize> {
101        self.inner_mut().read_line(buf)
102    }
103}
104
105/// Delegate [`Write`] to the inner file.
106impl<T: AsOpenFile + Write> Write for RwLockWriteGuard<T> {
107    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
108        self.inner_mut().write(buf)
109    }
110
111    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
112        self.inner_mut().write_vectored(bufs)
113    }
114
115    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
116        self.inner_mut().write_all(buf)
117    }
118
119    fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> io::Result<()> {
120        self.inner_mut().write_fmt(fmt)
121    }
122
123    fn by_ref(&mut self) -> &mut Self
124    where
125        Self: Sized,
126    {
127        self
128    }
129
130    fn flush(&mut self) -> io::Result<()> {
131        self.inner_mut().flush()
132    }
133}
134
135impl<T: AsOpenFile + Seek> Seek for RwLockWriteGuard<T> {
136    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
137        self.inner_mut().seek(pos)
138    }
139
140    fn rewind(&mut self) -> io::Result<()> {
141        self.inner_mut().rewind()
142    }
143
144    fn stream_position(&mut self) -> io::Result<u64> {
145        self.inner_mut().stream_position()
146    }
147
148    fn seek_relative(&mut self, offset: i64) -> io::Result<()> {
149        self.inner_mut().seek_relative(offset)
150    }
151}
152
153cfg_if! {
154    if #[cfg(feature = "async")] {
155        use std::task::{Context, Poll};
156        use tokio::io::{AsyncRead, AsyncBufRead, AsyncWrite, AsyncSeek};
157
158        /// Delegate [`AsyncRead`] to the inner file.
159        impl<T: AsOpenFile + AsyncRead> AsyncRead for RwLockWriteGuard<T> {
160            fn poll_read(
161                self: Pin<&mut Self>,
162                cx: &mut std::task::Context<'_>,
163                buf: &mut tokio::io::ReadBuf<'_>,
164            ) -> std::task::Poll<io::Result<()>> {
165                self.inner_pin_mut().poll_read(cx, buf)
166            }
167        }
168
169        impl<T: AsOpenFile + AsyncBufRead> AsyncBufRead for RwLockWriteGuard<T> {
170            fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
171                self.inner_pin_mut().poll_fill_buf(cx)
172            }
173
174            fn consume(self: Pin<&mut Self>, amt: usize) {
175                self.inner_pin_mut().consume(amt)
176            }
177        }
178
179        /// Delegate [`AsyncWrite`] to the inner file.
180        impl<T: AsOpenFile + AsyncWrite> AsyncWrite for RwLockWriteGuard<T> {
181            fn poll_write(
182                self: Pin<&mut Self>,
183                cx: &mut std::task::Context<'_>,
184                buf: &[u8],
185            ) -> std::task::Poll<Result<usize, io::Error>> {
186                self.inner_pin_mut().poll_write(cx, buf)
187            }
188
189            fn poll_write_vectored(
190                self: Pin<&mut Self>,
191                cx: &mut std::task::Context<'_>,
192                bufs: &[io::IoSlice<'_>],
193            ) -> std::task::Poll<Result<usize, io::Error>> {
194                self.inner_pin_mut().poll_write_vectored(cx, bufs)
195            }
196
197            fn is_write_vectored(&self) -> bool {
198                self.inner().is_write_vectored()
199            }
200
201            fn poll_flush(
202                self: Pin<&mut Self>,
203                cx: &mut std::task::Context<'_>,
204            ) -> std::task::Poll<Result<(), io::Error>> {
205                self.inner_pin_mut().poll_flush(cx)
206            }
207
208            fn poll_shutdown(
209                self: Pin<&mut Self>,
210                cx: &mut std::task::Context<'_>,
211            ) -> std::task::Poll<Result<(), io::Error>> {
212                self.inner_pin_mut().poll_shutdown(cx)
213            }
214        }
215
216        impl<T: AsOpenFile + AsyncSeek> AsyncSeek for RwLockWriteGuard<T> {
217            fn start_seek(self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> {
218                self.inner_pin_mut().start_seek(position)
219            }
220
221            fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
222                self.inner_pin_mut().poll_complete(cx)
223            }
224        }
225    }
226}
227
228/// Release the lock if it was not already released, as indicated by a `None`.
229#[pinned_drop]
230impl<T: AsOpenFile> PinnedDrop for RwLockWriteGuard<T> {
231    #[inline]
232    fn drop(self: Pin<&mut Self>) {
233        if let Some(file) = self.project().file.as_pin_mut() {
234            let _ = file.release_lock_blocking();
235        }
236    }
237}