libsql_wal/io/
file.rs

1use std::fs::File;
2use std::future::Future;
3use std::io::{self, ErrorKind, IoSlice, Result, Write};
4
5use libsql_sys::wal::either::Either;
6
7use super::buf::{IoBuf, IoBufMut};
8
9pub trait FileExt: Send + Sync + 'static {
10    fn len(&self) -> io::Result<u64>;
11    fn write_all_at(&self, buf: &[u8], offset: u64) -> Result<()> {
12        let mut written = 0;
13
14        while written != buf.len() {
15            written += self.write_at(&buf[written..], offset + written as u64)?;
16        }
17
18        Ok(())
19    }
20    fn write_at_vectored(&self, bufs: &[IoSlice], offset: u64) -> Result<usize>;
21    fn write_at(&self, buf: &[u8], offset: u64) -> Result<usize>;
22
23    fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<usize>;
24    fn read_exact_at(&self, buf: &mut [u8], offset: u64) -> Result<()> {
25        let mut read = 0;
26
27        while read != buf.len() {
28            let n = self.read_at(&mut buf[read..], offset + read as u64)?;
29            if n == 0 {
30                return Err(io::Error::new(
31                    ErrorKind::UnexpectedEof,
32                    "unexpected end-of-file",
33                ));
34            }
35            read += n;
36        }
37
38        Ok(())
39    }
40
41    fn sync_all(&self) -> Result<()>;
42
43    fn set_len(&self, len: u64) -> Result<()>;
44
45    fn cursor(&self, offset: u64) -> Cursor<Self>
46    where
47        Self: Sized,
48    {
49        Cursor {
50            file: self,
51            offset,
52            count: 0,
53        }
54    }
55
56    #[must_use]
57    fn read_exact_at_async<B: IoBufMut + Send + 'static>(
58        &self,
59        buf: B,
60        offset: u64,
61    ) -> impl Future<Output = (B, Result<()>)> + Send;
62
63    #[must_use]
64    fn read_at_async<B: IoBufMut + Send + 'static>(
65        &self,
66        buf: B,
67        offset: u64,
68    ) -> impl Future<Output = (B, Result<usize>)> + Send;
69
70    #[must_use]
71    fn write_all_at_async<B: IoBuf + Send + 'static>(
72        &self,
73        buf: B,
74        offset: u64,
75    ) -> impl Future<Output = (B, Result<()>)> + Send;
76}
77
78impl<U, V> FileExt for Either<U, V>
79where
80    V: FileExt,
81    U: FileExt,
82{
83    fn len(&self) -> io::Result<u64> {
84        match self {
85            Either::A(x) => x.len(),
86            Either::B(x) => x.len(),
87        }
88    }
89
90    fn write_at_vectored(&self, bufs: &[IoSlice], offset: u64) -> Result<usize> {
91        match self {
92            Either::A(x) => x.write_at_vectored(bufs, offset),
93            Either::B(x) => x.write_at_vectored(bufs, offset),
94        }
95    }
96
97    fn write_at(&self, buf: &[u8], offset: u64) -> Result<usize> {
98        match self {
99            Either::A(x) => x.write_at(buf, offset),
100            Either::B(x) => x.write_at(buf, offset),
101        }
102    }
103
104    fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<usize> {
105        match self {
106            Either::A(x) => x.read_at(buf, offset),
107            Either::B(x) => x.read_at(buf, offset),
108        }
109    }
110
111    fn sync_all(&self) -> Result<()> {
112        match self {
113            Either::A(x) => x.sync_all(),
114            Either::B(x) => x.sync_all(),
115        }
116    }
117
118    fn set_len(&self, len: u64) -> Result<()> {
119        match self {
120            Either::A(x) => x.set_len(len),
121            Either::B(x) => x.set_len(len),
122        }
123    }
124
125    fn read_exact_at_async<B: IoBufMut + Send + 'static>(
126        &self,
127        buf: B,
128        offset: u64,
129    ) -> impl Future<Output = (B, Result<()>)> + Send {
130        async move {
131            match self {
132                Either::A(x) => x.read_exact_at_async(buf, offset).await,
133                Either::B(x) => x.read_exact_at_async(buf, offset).await,
134            }
135        }
136    }
137
138    fn read_at_async<B: IoBufMut + Send + 'static>(
139        &self,
140        buf: B,
141        offset: u64,
142    ) -> impl Future<Output = (B, Result<usize>)> + Send {
143        async move {
144            match self {
145                Either::A(x) => x.read_at_async(buf, offset).await,
146                Either::B(x) => x.read_at_async(buf, offset).await,
147            }
148        }
149    }
150
151    fn write_all_at_async<B: IoBuf + Send + 'static>(
152        &self,
153        buf: B,
154        offset: u64,
155    ) -> impl Future<Output = (B, Result<()>)> + Send {
156        async move {
157            match self {
158                Either::A(x) => x.write_all_at_async(buf, offset).await,
159                Either::B(x) => x.write_all_at_async(buf, offset).await,
160            }
161        }
162    }
163}
164
165impl FileExt for File {
166    fn write_at_vectored(&self, bufs: &[IoSlice], offset: u64) -> Result<usize> {
167        Ok(nix::sys::uio::pwritev(self, bufs, offset as _)?)
168    }
169
170    fn write_at(&self, buf: &[u8], offset: u64) -> Result<usize> {
171        Ok(nix::sys::uio::pwrite(self, buf, offset as _)?)
172    }
173
174    fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<usize> {
175        let n = nix::sys::uio::pread(self, buf, offset as _)?;
176        Ok(n)
177    }
178
179    fn sync_all(&self) -> Result<()> {
180        std::fs::File::sync_all(self)
181    }
182
183    fn set_len(&self, len: u64) -> Result<()> {
184        std::fs::File::set_len(self, len)
185    }
186
187    async fn read_exact_at_async<B: IoBufMut + Send + 'static>(
188        &self,
189        mut buf: B,
190        offset: u64,
191    ) -> (B, Result<()>) {
192        let file = self.try_clone().unwrap();
193        let (buffer, ret) = tokio::task::spawn_blocking(move || {
194            // let mut read = 0;
195
196            let chunk = unsafe {
197                let len = buf.bytes_total();
198                let ptr = buf.stable_mut_ptr();
199                std::slice::from_raw_parts_mut(ptr, len)
200            };
201
202            let ret = file.read_exact_at(chunk, offset);
203            if ret.is_ok() {
204                unsafe {
205                    buf.set_init(buf.bytes_total());
206                }
207            }
208            (buf, ret)
209        })
210        .await
211        .unwrap();
212
213        (buffer, ret)
214    }
215
216    async fn read_at_async<B: IoBufMut + Send + 'static>(
217        &self,
218        mut buf: B,
219        offset: u64,
220    ) -> (B, Result<usize>) {
221        let file = self.try_clone().unwrap();
222        let (buffer, ret) = tokio::task::spawn_blocking(move || {
223            // let mut read = 0;
224
225            let chunk = unsafe {
226                let len = buf.bytes_total();
227                let ptr = buf.stable_mut_ptr();
228                std::slice::from_raw_parts_mut(ptr, len)
229            };
230
231            let ret = file.read_at(chunk, offset);
232            if let Ok(n) = ret {
233                unsafe {
234                    buf.set_init(n);
235                }
236            }
237            (buf, ret)
238        })
239        .await
240        .unwrap();
241
242        (buffer, ret)
243    }
244
245    async fn write_all_at_async<B: IoBuf + Send + 'static>(
246        &self,
247        buf: B,
248        offset: u64,
249    ) -> (B, Result<()>) {
250        let file = self.try_clone().unwrap();
251        let (buffer, ret) = tokio::task::spawn_blocking(move || {
252            let buffer = unsafe { std::slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init()) };
253            let ret = file.write_all_at(buffer, offset);
254            (buf, ret)
255        })
256        .await
257        .unwrap();
258
259        (buffer, ret)
260    }
261
262    fn len(&self) -> io::Result<u64> {
263        Ok(self.metadata()?.len())
264    }
265}
266
267#[derive(Debug)]
268pub struct Cursor<'a, T> {
269    file: &'a T,
270    offset: u64,
271    count: u64,
272}
273
274impl<T> Cursor<'_, T> {
275    pub fn count(&self) -> u64 {
276        self.count
277    }
278}
279
280impl<T: FileExt> Write for Cursor<'_, T> {
281    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
282        let count = self.file.write_at(buf, self.offset + self.count)?;
283        self.count += count as u64;
284        Ok(count)
285    }
286
287    fn flush(&mut self) -> std::io::Result<()> {
288        Ok(())
289    }
290}
291
/// Writer adapter that forwards to an inner writer and keeps an in-memory
/// copy of every byte the inner writer accepted.
pub struct BufCopy<W> {
    w: W,
    buf: Vec<u8>,
}

impl<W> BufCopy<W> {
    /// Wraps `w`, starting with an empty copy buffer.
    pub fn new(w: W) -> Self {
        BufCopy {
            w,
            buf: Vec::new(),
        }
    }

    /// Consumes the adapter, returning the inner writer and the copied bytes.
    pub fn into_parts(self) -> (W, Vec<u8>) {
        (self.w, self.buf)
    }

    /// Borrows the inner writer.
    pub fn get_ref(&self) -> &W {
        let BufCopy { w, .. } = self;
        w
    }
}

impl<W> Write for BufCopy<W>
where
    W: Write,
{
    /// Writes to the inner writer, then records exactly the prefix of `buf`
    /// that the inner writer reported as written.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        self.w.write(buf).map(|accepted| {
            self.buf.extend_from_slice(&buf[..accepted]);
            accepted
        })
    }

    /// Flushes the inner writer.
    fn flush(&mut self) -> Result<()> {
        self.w.flush()
    }
}
323
#[cfg(test)]
mod test {
    use std::io::Read;

    use tempfile::tempfile;

    use super::*;

    /// Two async positional writes must land back to back in the file.
    #[tokio::test]
    async fn test_write_async() {
        let mut file = tempfile().unwrap();

        let buf = vec![1u8; 12345];
        let (buf, ret) = file.write_all_at_async(buf, 0).await;
        ret.unwrap();
        assert_eq!(buf.len(), 12345);
        assert!(buf.iter().all(|x| *x == 1));

        let buf = vec![2u8; 50];
        let (buf, ret) = file.write_all_at_async(buf, 12345).await;
        ret.unwrap();
        assert_eq!(buf.len(), 50);
        assert!(buf.iter().all(|x| *x == 2));

        // Only positional writes were issued above, so this sequential read
        // is expected to start from the beginning of the file.
        let mut out = Vec::new();
        file.read_to_end(&mut out).unwrap();
        // Check the length first: without it, the slice checks below would
        // pass vacuously if the second write had been lost.
        assert_eq!(out.len(), 12345 + 50);
        assert!(out[0..12345].iter().all(|x| *x == 1));
        assert!(out[12345..].iter().all(|x| *x == 2));
    }

    /// Async exact reads must return the bytes stored at the given offsets.
    #[tokio::test]
    async fn test_read() {
        let mut file = tempfile().unwrap();

        file.write_all(&[1; 12345]).unwrap();
        file.write_all(&[2; 50]).unwrap();

        let buf = vec![0u8; 12345];
        let (buf, ret) = file.read_exact_at_async(buf, 0).await;
        ret.unwrap();
        assert_eq!(buf.len(), 12345);
        assert!(buf.iter().all(|x| *x == 1));

        // Start from zeroes so the assertion below can only pass if the read
        // actually filled the buffer with the file's contents.
        let buf = vec![0u8; 50];
        let (buf, ret) = file.read_exact_at_async(buf, 12345).await;
        ret.unwrap();
        assert_eq!(buf.len(), 50);
        assert!(buf.iter().all(|x| *x == 2));
    }
}