async_file_lock/
lib.rs

1#![deny(unused_must_use)]
2// #![cfg_attr(test, feature(test))]
3
4use std::task::{Context, Poll};
5use std::pin::Pin;
6use std::fmt::Formatter;
7use std::fmt::Debug;
8use std::future::Future;
9use std::io::{Error, Result, SeekFrom, Seek};
10use tokio::fs::{File, OpenOptions};
11use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite};
12use tokio::task::{spawn_blocking, JoinHandle};
13use futures_lite::{ready, FutureExt};
14use fs3::FileExt;
15use std::path::Path;
16use std::mem::MaybeUninit;
17
18/// Locks a file asynchronously.
19/// Auto locks a file if any read or write methods are called. If [Self::lock_exclusive]
20/// or [Self::lock_shared] has been called then the file will stay locked.
21/// Can auto seek to specified location before doing any read/write operation.
22///
23/// Note 1: Do not attempt to have multiple file handles for the same file. Because locking is done
24/// per process basis.
25/// Note 2: Remember to open a file with specified read and/or write mode as write calls will just
26/// be ignored if the file is opened in read mode.
27pub struct FileLock {
28    mode: SeekFrom,
29    state: State,
30    is_manually_locked: bool,
31    unlocked_file: Option<std::fs::File>,
32    locked_file: Option<File>,
33    result: Option<Result<u64>>,
34    locking_fut: Option<JoinHandle<std::result::Result<File, (std::fs::File, Error)>>>,
35    unlocking_fut: Option<Pin<Box<dyn Future<Output = std::fs::File> + Send>>>,
36    seek_fut: Option<JoinHandle<(Result<u64>, std::fs::File)>>,
37}
38
39impl FileLock {
40    /// Opens a file in read and write mode that is unlocked.
41    // This function will create a file if it does not exist, and will truncate it if it does.
42    pub async fn create(path: impl AsRef<Path>) -> Result<FileLock> {
43        let file = OpenOptions::new().write(true).read(true).create(true).truncate(true).open(path).await?;
44        Ok(FileLock::new_tokio(file).await)
45    }
46
47    /// Attempts to open a file in read and write mode that is unlocked.
48    pub async fn open(path: impl AsRef<Path>) -> Result<FileLock> {
49        let file = OpenOptions::new().write(true).read(true).open(path).await?;
50        Ok(FileLock::new_tokio(file).await)
51    }
52
53    /// Creates a new 'FileLock' from [`tokio::fs::File`].
54    pub async fn new_tokio(tokio_file: File) -> FileLock {
55        FileLock {
56            mode: SeekFrom::Current(0),
57            state: State::Unlocked,
58            is_manually_locked: false,
59            unlocked_file: Some(tokio_file.into_std().await),
60            locked_file: None,
61            result: None,
62            locking_fut: None,
63            unlocking_fut: None,
64            seek_fut: None
65        }
66    }
67
68    /// Creates a new 'FileLock' from [`std::fs::File`].
69    pub fn new_std(std_file: std::fs::File) -> FileLock {
70        FileLock {
71            mode: SeekFrom::Current(0),
72            state: State::Unlocked,
73            is_manually_locked: false,
74            unlocked_file: Some(std_file),
75            locked_file: None,
76            result: None,
77            locking_fut: None,
78            unlocking_fut: None,
79            seek_fut: None
80        }
81    }
82
83    /// Locks the file for reading and writing until [`Self::unlock`] is called.
84    pub fn lock_exclusive(&mut self) -> LockFuture {
85        if self.locked_file.is_some() {
86            panic!("File already locked.");
87        }
88        self.is_manually_locked = true;
89        LockFuture::new_exclusive(self)
90    }
91
92    /// Locks the file for reading and writing until [`Self::unlock`] is called. Returns an error if
93    /// the file is currently locked.
94    pub fn try_lock_exclusive(&mut self) -> Result<()> {
95        if self.locked_file.is_some() {
96            panic!("File already locked.");
97        }
98        self.is_manually_locked = true;
99        self.unlocked_file.as_mut().unwrap().try_lock_exclusive().map(|_| {
100            self.locked_file = Some(File::from_std(self.unlocked_file.take().unwrap()));
101            self.state = State::Locked;
102
103        })
104    }
105
106    /// Locks the file for reading until [`Self::unlock`] is called.
107    pub fn lock_shared(&mut self) -> LockFuture {
108        if self.locked_file.is_some() {
109            panic!("File already locked.");
110        }
111        self.is_manually_locked = true;
112        LockFuture::new_shared(self)
113    }
114
115    /// Locks the file for reading until [`Self::unlock`] is called. Returns an error if the file
116    /// is currently locked.
117    pub fn try_lock_shared(&mut self) -> Result<()> {
118        if self.locked_file.is_some() {
119            panic!("File already locked.");
120        }
121        self.is_manually_locked = true;
122        self.unlocked_file.as_mut().unwrap().try_lock_shared().map(|_| {
123            self.locked_file = Some(File::from_std(self.unlocked_file.take().unwrap()));
124            self.state = State::Locked;
125
126        })
127    }
128
129    /// Unlocks the file.
130    pub fn unlock(&mut self) -> UnlockFuture {
131        if self.unlocked_file.is_some() {
132            panic!("File already unlocked.");
133        }
134        self.is_manually_locked = false;
135        UnlockFuture::new(self)
136    }
137
138    /// Sets auto seeking mode. File will always seek to specified location before doing any
139    /// read/write operation.
140    pub fn set_seeking_mode(&mut self, mode: SeekFrom) {
141        self.mode = mode;
142    }
143
144    pub fn seeking_mode(&self) -> SeekFrom {
145        self.mode
146    }
147
148    /// Attempts to sync all OS-internal metadata to disk.
149    ///
150    /// This function will attempt to ensure that all in-core data reaches the
151    /// filesystem before returning.
152    ///
153    /// # Examples
154    ///
155    /// ```no_run
156    /// use tokio::fs::File;
157    /// use tokio::prelude::*;
158    ///
159    /// # async fn dox() -> std::io::Result<()> {
160    /// let mut file = File::create("foo.txt").await?;
161    /// file.write_all(b"hello, world!").await?;
162    /// file.sync_all().await?;
163    /// # Ok(())
164    /// # }
165    /// ```
166    ///
167    /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait.
168    ///
169    /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
170    /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
171    pub async fn sync_all(&mut self) -> Result<()> {
172        if let Some(file) = &mut self.locked_file {
173            return file.sync_all().await;
174        }
175        let file = self.unlocked_file.take().unwrap();
176        let (result, file) = spawn_blocking(|| {
177            (file.sync_all(), file)
178        }).await.unwrap();
179        self.unlocked_file = Some(file);
180        result
181    }
182
183    /// This function is similar to `sync_all`, except that it may not
184    /// synchronize file metadata to the filesystem.
185    ///
186    /// This is intended for use cases that must synchronize content, but don't
187    /// need the metadata on disk. The goal of this method is to reduce disk
188    /// operations.
189    ///
190    /// Note that some platforms may simply implement this in terms of `sync_all`.
191    ///
192    /// # Examples
193    ///
194    /// ```no_run
195    /// use tokio::fs::File;
196    /// use tokio::prelude::*;
197    ///
198    /// # async fn dox() -> std::io::Result<()> {
199    /// let mut file = File::create("foo.txt").await?;
200    /// file.write_all(b"hello, world!").await?;
201    /// file.sync_data().await?;
202    /// # Ok(())
203    /// # }
204    /// ```
205    ///
206    /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait.
207    ///
208    /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
209    /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
210    pub async fn sync_data(&mut self) -> Result<()> {
211        if let Some(file) = &mut self.locked_file {
212            return file.sync_data().await;
213        }
214        let file = self.unlocked_file.take().unwrap();
215        let (result, file) = spawn_blocking(|| {
216            (file.sync_data(), file)
217        }).await.unwrap();
218        self.unlocked_file = Some(file);
219        result
220    }
221
222    /// Gets a reference to the file.
223    ///
224    /// If the file is locked it will be in the second element of a tuple as [`tokio::fs::File`]
225    /// otherwise it will be in the first element as [`std::fs::File`].
226    /// It is inadvisable to directly read/write from/to the file.
227    pub fn get_ref(&self) -> (Option<&std::fs::File>, Option<&File>) {
228        (self.unlocked_file.as_ref(), self.locked_file.as_ref())
229    }
230
231    /// Gets a mutable reference to the file.
232    ///
233    /// If the file is locked it will be in the second element of a tuple as [`tokio::fs::File`]
234    /// otherwise it will be in the first element as [`std::fs::File`].
235    /// It is inadvisable to directly read/write from/to the file.
236    pub fn get_mut(&mut self) -> (Option<&mut std::fs::File>, Option<&mut File>) {
237        (self.unlocked_file.as_mut(), self.locked_file.as_mut())
238    }
239
240    fn poll_exclusive_lock(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
241        loop {
242            match &mut self.locking_fut {
243                None => {
244                    LockFuture::new_exclusive(self);
245                }
246                Some(_) => return self.poll_locking_fut(cx),
247            }
248        }
249    }
250
251    fn poll_shared_lock(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
252        loop {
253            match &mut self.locking_fut {
254                None => {
255                    LockFuture::new_shared(self);
256                }
257                Some(_) => return self.poll_locking_fut(cx),
258            }
259        }
260    }
261
262    fn poll_unlock(&mut self, cx: &mut Context<'_>) -> Poll<()> {
263        loop {
264            match &mut self.unlocking_fut {
265                None => {
266                    UnlockFuture::new(self);
267                }
268                Some(fut) => {
269                    // println!("unlocking");
270                    let file = ready!(fut.poll(cx));
271                    let result = file.unlock();
272                    self.unlocked_file = Some(file);
273                    if let Err(e) = result {
274                        self.result = Some(Err(e));
275                    }
276                    self.state = State::Unlocked;
277                    self.unlocking_fut.take();
278                    // println!("unlocked");
279                    return Poll::Ready(());
280                }
281            }
282        }
283    }
284
285    fn poll_locking_fut(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
286        let result = ready!(self.locking_fut.as_mut().unwrap().poll(cx)).unwrap();
287        self.locking_fut.take();
288        return match result {
289            Ok(file) => {
290                self.locked_file = Some(file);
291                self.state = State::Locked;
292                Poll::Ready(Ok(()))
293            }
294            Err((file, e)) => {
295                self.unlocked_file = Some(file);
296                self.state = State::Unlocked;
297                Poll::Ready(Err(e))
298            }
299        };
300    }
301}
302
// Drives the shared lock -> (seek) -> work -> unlock state machine used by the `poll_*` methods.
//
// * `$unlocked_map` converts the `u64` parked in `result` back into the caller's output type.
// * `$lock` is the `FileLock` method that acquires the lock (`poll_exclusive_lock` or
//   `poll_shared_lock`).
// * The `State::Working` block performs the actual I/O. It must either return, or move the
//   state to `Unlocking` with `result` set (auto lock).
//
// Errors raised while the file is auto-locked are parked in `result` and returned only after
// the file has been unlocked, so a failed seek never leaves the file locked. With a manual lock
// the error is returned directly and the state is left at `Locked`.
macro_rules! poll_loop {
    ($self:ident, $cx:ident, $unlocked_map:expr, $lock:ident, State::Working => $working:block) => {
        loop {
            match $self.state {
                State::Unlocked => {
                    if let Some(result) = $self.result.take() {
                        return Poll::Ready(result.map($unlocked_map));
                    }
                    $self.state = State::Locking;
                }
                State::Unlocking => ready!($self.poll_unlock($cx)),
                State::Locked => match $self.mode {
                    SeekFrom::Current(0) => $self.state = State::Working,
                    _ => {
                        let mode = $self.mode;
                        // `start_seek` can be pending while an earlier operation on the tokio
                        // file is still in flight; only move on once the seek was started.
                        match ready!($self.as_mut().start_seek($cx, mode)) {
                            Ok(()) => $self.state = State::Seeking,
                            Err(e) => {
                                if $self.is_manually_locked {
                                    return Poll::Ready(Err(e));
                                }
                                $self.result = Some(Err(e));
                                $self.state = State::Unlocking;
                            }
                        }
                    }
                },
                State::Working => {
                    $working
                },
                State::Locking => {
                    if let Err(e) = ready!($self.$lock($cx)) {
                        return Poll::Ready(Err(e));
                    }
                }
                State::Seeking => match ready!($self.as_mut().poll_complete($cx)) {
                    Ok(_) => $self.state = State::Working,
                    Err(e) => {
                        if $self.is_manually_locked {
                            $self.state = State::Locked;
                            return Poll::Ready(Err(e));
                        }
                        $self.result = Some(Err(e));
                        $self.state = State::Unlocking;
                    }
                },
            }
        }
    };
}
341
342impl AsyncWrite for FileLock {
343    fn poll_write(
344        mut self: Pin<&mut Self>,
345        cx: &mut Context<'_>,
346        buf: &[u8],
347    ) -> Poll<Result<usize>> {
348        poll_loop! {self, cx, |x| x as usize, poll_exclusive_lock,
349            State::Working => {
350                let result = ready!(Pin::new(self.locked_file.as_mut().unwrap())
351                        .as_mut()
352                        .poll_write(cx, buf));
353                // println!("written {:?}", &buf[..*result.as_ref().unwrap()]);
354                if self.is_manually_locked {
355                    self.state = State::Locked;
356                    return Poll::Ready(result);
357                } else {
358                    self.state = State::Unlocking;
359                    self.result = Some(result.map(|x| x as u64));
360                }
361            }
362            // State::Flushing => {
363            //     if let Err(e) = ready!(self.as_mut().poll_flush(cx)) {
364            //         self.result = Some(Err(e));
365            //     }
366            //     self.state = State::Unlocking;
367            // }
368        };
369    }
370
371    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
372        // println!("flushing");
373        poll_loop! {self, cx, |_| (), poll_exclusive_lock,
374            State::Working => {
375                let result = ready!(Pin::new(self.locked_file.as_mut().unwrap())
376                        .as_mut()
377                        .poll_flush(cx));
378                // println!("flushed");
379                if self.is_manually_locked {
380                    self.state = State::Locked;
381                    return Poll::Ready(result);
382                } else {
383                    self.state = State::Unlocking;
384                    self.result = Some(result.map(|_| 0));
385                }
386            }
387            // State::Flushing => {
388            //     let result = ready!(Pin::new(self.locked_file.as_mut().unwrap())
389            //             .as_mut()
390            //             .poll_flush(cx));
391            //     // println!("flushed");
392            //     return Poll::Ready(result);
393            // }
394        };
395    }
396
397    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
398        // println!("shutting down");
399        // We don't have to do anything as files are unlocked when underlying tokio file reports
400        // some progress. Looking at implementation of shutdown for `tokio::fs::File` says that it
401        // does nothing.
402        Poll::Ready(Ok(()))
403    }
404}
405
406impl AsyncRead for FileLock {
407    unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
408        false
409    }
410
411    fn poll_read(
412        mut self: Pin<&mut Self>,
413        cx: &mut Context<'_>,
414        buf: &mut [u8],
415    ) -> Poll<Result<usize>> {
416        poll_loop! {self, cx, |x| x as usize, poll_shared_lock,
417            State::Working => {
418                let result = ready!(Pin::new(self.locked_file.as_mut().unwrap())
419                        .as_mut()
420                        .poll_read(cx, buf));
421                if self.is_manually_locked {
422                    self.state = State::Locked;
423                    return Poll::Ready(result);
424                } else {
425                    self.state = State::Unlocking;
426                    self.result = Some(result.map(|x| x as u64));
427                }
428            }
429        };
430    }
431}
432
433impl AsyncSeek for FileLock {
434    fn start_seek(
435        mut self: Pin<&mut Self>,
436        cx: &mut Context<'_>,
437        position: SeekFrom,
438    ) -> Poll<Result<()>> {
439        if let Some(ref mut locked_file) = self.locked_file {
440            return Pin::new(locked_file)
441                .as_mut()
442                .start_seek(cx, position);
443        }
444        let mut file = self.unlocked_file.take().expect("Cannot seek while in the process of locking/unlocking/seeking");
445        self.seek_fut = Some(spawn_blocking(move || {
446            (file.seek(position), file)
447        }));
448        return Poll::Ready(Ok(()));
449    }
450
451    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> {
452        if let Some(ref mut locked_file) = self.locked_file {
453            return Pin::new(locked_file)
454                    .as_mut()
455                    .poll_complete(cx)
456        }
457        let (result, file) = ready!(Pin::new(self.seek_fut.as_mut().unwrap()).poll(cx)).unwrap();
458        self.seek_fut = None;
459        self.unlocked_file = Some(file);
460        return Poll::Ready(result);
461    }
462}
463
464impl Debug for FileLock {
465    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
466        let mut debug = f.debug_struct("FileLock");
467        match self.state {
468            State::Unlocked => {
469                debug.field("file", self.unlocked_file.as_ref().unwrap());
470            }
471            State::Locked => {
472                debug.field("file", self.locked_file.as_ref().unwrap());
473            }
474            _ => panic!("Invalid state"),
475        }
476        debug.field("mode", &self.mode).finish()
477    }
478}
479
/// Step of the `FileLock` state machine driven by `poll_loop!`, listed in lifecycle order.
enum State {
    /// No lock is held; the file lives in `unlocked_file` unless a blocking task owns it.
    Unlocked,
    /// A lock is being acquired through a `spawn_blocking` task.
    Locking,
    /// A lock is held; the file lives in `locked_file`.
    Locked,
    /// The auto-seek position is being applied before the I/O operation.
    Seeking,
    /// The actual read/write/flush is in progress.
    Working,
    /// The locked file is being converted back into a std file and unlocked.
    Unlocking,
}
488
489pub struct LockFuture<'a> {
490    file_lock: &'a mut FileLock,
491}
492
493impl<'a> LockFuture<'a> {
494    fn new_exclusive(file_lock: &'a mut FileLock) -> LockFuture<'a> {
495        // println!("locking exclusive");
496        let unlocked_file = file_lock.unlocked_file.take().unwrap();
497        file_lock.locking_fut = Some(spawn_blocking(move || {
498            let result = match unlocked_file.lock_exclusive() {
499                Ok(_) => Ok(File::from_std(unlocked_file)),
500                Err(e) => Err((unlocked_file, e)),
501            };
502            // println!("locked exclusive");
503            result
504        }));
505        LockFuture { file_lock }
506    }
507
508    fn new_shared(file_lock: &'a mut FileLock) -> LockFuture<'a> {
509        // println!("locking shared");
510        let unlocked_file = file_lock.unlocked_file.take().unwrap();
511        file_lock.locking_fut = Some(spawn_blocking(move || {
512            let result = match unlocked_file.lock_shared() {
513                Ok(_) => Ok(File::from_std(unlocked_file)),
514                Err(e) => Err((unlocked_file, e)),
515            };
516            // println!("locked shared");
517            result
518        }));
519        LockFuture { file_lock }
520    }
521}
522
523impl<'a> Future for LockFuture<'a> {
524    type Output = Result<()>;
525
526    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
527        self.file_lock.poll_locking_fut(cx)
528    }
529}
530
531pub struct UnlockFuture<'a> {
532    file_lock: &'a mut FileLock,
533}
534
535impl<'a> UnlockFuture<'a> {
536    fn new(file_lock: &'a mut FileLock) -> UnlockFuture<'a> {
537        file_lock.unlocking_fut = Some(file_lock.locked_file.take().unwrap().into_std().boxed());
538        file_lock.state = State::Unlocking;
539        UnlockFuture { file_lock }
540    }
541}
542
543impl<'a> Future for UnlockFuture<'a> {
544    type Output = ();
545
546    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
547        self.file_lock.poll_unlock(cx)
548    }
549}
550