1use std::fs::File;
2use std::future::Future;
3use std::io::{self, ErrorKind, IoSlice, Result, Write};
4
5use libsql_sys::wal::either::Either;
6
7use super::buf::{IoBuf, IoBufMut};
8
9pub trait FileExt: Send + Sync + 'static {
10 fn len(&self) -> io::Result<u64>;
11 fn write_all_at(&self, buf: &[u8], offset: u64) -> Result<()> {
12 let mut written = 0;
13
14 while written != buf.len() {
15 written += self.write_at(&buf[written..], offset + written as u64)?;
16 }
17
18 Ok(())
19 }
20 fn write_at_vectored(&self, bufs: &[IoSlice], offset: u64) -> Result<usize>;
21 fn write_at(&self, buf: &[u8], offset: u64) -> Result<usize>;
22
23 fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<usize>;
24 fn read_exact_at(&self, buf: &mut [u8], offset: u64) -> Result<()> {
25 let mut read = 0;
26
27 while read != buf.len() {
28 let n = self.read_at(&mut buf[read..], offset + read as u64)?;
29 if n == 0 {
30 return Err(io::Error::new(
31 ErrorKind::UnexpectedEof,
32 "unexpected end-of-file",
33 ));
34 }
35 read += n;
36 }
37
38 Ok(())
39 }
40
41 fn sync_all(&self) -> Result<()>;
42
43 fn set_len(&self, len: u64) -> Result<()>;
44
45 fn cursor(&self, offset: u64) -> Cursor<Self>
46 where
47 Self: Sized,
48 {
49 Cursor {
50 file: self,
51 offset,
52 count: 0,
53 }
54 }
55
56 #[must_use]
57 fn read_exact_at_async<B: IoBufMut + Send + 'static>(
58 &self,
59 buf: B,
60 offset: u64,
61 ) -> impl Future<Output = (B, Result<()>)> + Send;
62
63 #[must_use]
64 fn read_at_async<B: IoBufMut + Send + 'static>(
65 &self,
66 buf: B,
67 offset: u64,
68 ) -> impl Future<Output = (B, Result<usize>)> + Send;
69
70 #[must_use]
71 fn write_all_at_async<B: IoBuf + Send + 'static>(
72 &self,
73 buf: B,
74 offset: u64,
75 ) -> impl Future<Output = (B, Result<()>)> + Send;
76}
77
78impl<U, V> FileExt for Either<U, V>
79where
80 V: FileExt,
81 U: FileExt,
82{
83 fn len(&self) -> io::Result<u64> {
84 match self {
85 Either::A(x) => x.len(),
86 Either::B(x) => x.len(),
87 }
88 }
89
90 fn write_at_vectored(&self, bufs: &[IoSlice], offset: u64) -> Result<usize> {
91 match self {
92 Either::A(x) => x.write_at_vectored(bufs, offset),
93 Either::B(x) => x.write_at_vectored(bufs, offset),
94 }
95 }
96
97 fn write_at(&self, buf: &[u8], offset: u64) -> Result<usize> {
98 match self {
99 Either::A(x) => x.write_at(buf, offset),
100 Either::B(x) => x.write_at(buf, offset),
101 }
102 }
103
104 fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<usize> {
105 match self {
106 Either::A(x) => x.read_at(buf, offset),
107 Either::B(x) => x.read_at(buf, offset),
108 }
109 }
110
111 fn sync_all(&self) -> Result<()> {
112 match self {
113 Either::A(x) => x.sync_all(),
114 Either::B(x) => x.sync_all(),
115 }
116 }
117
118 fn set_len(&self, len: u64) -> Result<()> {
119 match self {
120 Either::A(x) => x.set_len(len),
121 Either::B(x) => x.set_len(len),
122 }
123 }
124
125 fn read_exact_at_async<B: IoBufMut + Send + 'static>(
126 &self,
127 buf: B,
128 offset: u64,
129 ) -> impl Future<Output = (B, Result<()>)> + Send {
130 async move {
131 match self {
132 Either::A(x) => x.read_exact_at_async(buf, offset).await,
133 Either::B(x) => x.read_exact_at_async(buf, offset).await,
134 }
135 }
136 }
137
138 fn read_at_async<B: IoBufMut + Send + 'static>(
139 &self,
140 buf: B,
141 offset: u64,
142 ) -> impl Future<Output = (B, Result<usize>)> + Send {
143 async move {
144 match self {
145 Either::A(x) => x.read_at_async(buf, offset).await,
146 Either::B(x) => x.read_at_async(buf, offset).await,
147 }
148 }
149 }
150
151 fn write_all_at_async<B: IoBuf + Send + 'static>(
152 &self,
153 buf: B,
154 offset: u64,
155 ) -> impl Future<Output = (B, Result<()>)> + Send {
156 async move {
157 match self {
158 Either::A(x) => x.write_all_at_async(buf, offset).await,
159 Either::B(x) => x.write_all_at_async(buf, offset).await,
160 }
161 }
162 }
163}
164
165impl FileExt for File {
166 fn write_at_vectored(&self, bufs: &[IoSlice], offset: u64) -> Result<usize> {
167 Ok(nix::sys::uio::pwritev(self, bufs, offset as _)?)
168 }
169
170 fn write_at(&self, buf: &[u8], offset: u64) -> Result<usize> {
171 Ok(nix::sys::uio::pwrite(self, buf, offset as _)?)
172 }
173
174 fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<usize> {
175 let n = nix::sys::uio::pread(self, buf, offset as _)?;
176 Ok(n)
177 }
178
179 fn sync_all(&self) -> Result<()> {
180 std::fs::File::sync_all(self)
181 }
182
183 fn set_len(&self, len: u64) -> Result<()> {
184 std::fs::File::set_len(self, len)
185 }
186
187 async fn read_exact_at_async<B: IoBufMut + Send + 'static>(
188 &self,
189 mut buf: B,
190 offset: u64,
191 ) -> (B, Result<()>) {
192 let file = self.try_clone().unwrap();
193 let (buffer, ret) = tokio::task::spawn_blocking(move || {
194 let chunk = unsafe {
197 let len = buf.bytes_total();
198 let ptr = buf.stable_mut_ptr();
199 std::slice::from_raw_parts_mut(ptr, len)
200 };
201
202 let ret = file.read_exact_at(chunk, offset);
203 if ret.is_ok() {
204 unsafe {
205 buf.set_init(buf.bytes_total());
206 }
207 }
208 (buf, ret)
209 })
210 .await
211 .unwrap();
212
213 (buffer, ret)
214 }
215
216 async fn read_at_async<B: IoBufMut + Send + 'static>(
217 &self,
218 mut buf: B,
219 offset: u64,
220 ) -> (B, Result<usize>) {
221 let file = self.try_clone().unwrap();
222 let (buffer, ret) = tokio::task::spawn_blocking(move || {
223 let chunk = unsafe {
226 let len = buf.bytes_total();
227 let ptr = buf.stable_mut_ptr();
228 std::slice::from_raw_parts_mut(ptr, len)
229 };
230
231 let ret = file.read_at(chunk, offset);
232 if let Ok(n) = ret {
233 unsafe {
234 buf.set_init(n);
235 }
236 }
237 (buf, ret)
238 })
239 .await
240 .unwrap();
241
242 (buffer, ret)
243 }
244
245 async fn write_all_at_async<B: IoBuf + Send + 'static>(
246 &self,
247 buf: B,
248 offset: u64,
249 ) -> (B, Result<()>) {
250 let file = self.try_clone().unwrap();
251 let (buffer, ret) = tokio::task::spawn_blocking(move || {
252 let buffer = unsafe { std::slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init()) };
253 let ret = file.write_all_at(buffer, offset);
254 (buf, ret)
255 })
256 .await
257 .unwrap();
258
259 (buffer, ret)
260 }
261
262 fn len(&self) -> io::Result<u64> {
263 Ok(self.metadata()?.len())
264 }
265}
266
/// A sequential writer over a [`FileExt`] file, created by `FileExt::cursor`.
///
/// Each write goes to `offset + count` through the file's positional `write_at`, and `count`
/// advances by the number of bytes written.
#[derive(Debug)]
pub struct Cursor<'a, T> {
    // File the writes are issued against.
    file: &'a T,
    // Position in the file where the first write lands.
    offset: u64,
    // Bytes written so far through this cursor.
    count: u64,
}
273
274impl<T> Cursor<'_, T> {
275 pub fn count(&self) -> u64 {
276 self.count
277 }
278}
279
280impl<T: FileExt> Write for Cursor<'_, T> {
281 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
282 let count = self.file.write_at(buf, self.offset + self.count)?;
283 self.count += count as u64;
284 Ok(count)
285 }
286
287 fn flush(&mut self) -> std::io::Result<()> {
288 Ok(())
289 }
290}
291
/// A [`Write`] adapter that forwards writes to an inner writer and records a copy of every byte
/// the inner writer accepted.
pub struct BufCopy<W> {
    w: W,
    buf: Vec<u8>,
}

impl<W> BufCopy<W> {
    /// Wraps `w` with an initially empty copy buffer.
    pub fn new(w: W) -> Self {
        let buf = Vec::new();
        BufCopy { w, buf }
    }

    /// Consumes the adapter and returns the inner writer and the bytes copied so far.
    pub fn into_parts(self) -> (W, Vec<u8>) {
        (self.w, self.buf)
    }

    /// Borrows the inner writer.
    pub fn get_ref(&self) -> &W {
        &self.w
    }
}

impl<W> Write for BufCopy<W>
where
    W: Write,
{
    /// Writes to the inner writer, then copies exactly the prefix it accepted.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let accepted = self.w.write(buf)?;
        let accepted_bytes = &buf[..accepted];
        self.buf.extend_from_slice(accepted_bytes);
        Ok(accepted)
    }

    /// Flushes the inner writer; the copy buffer itself needs no flushing.
    fn flush(&mut self) -> Result<()> {
        self.w.flush()
    }
}
323
#[cfg(test)]
mod test {
    use std::io::Read;

    use tempfile::tempfile;

    use super::*;

    #[tokio::test]
    async fn test_write_async() {
        let mut file = tempfile().unwrap();

        let buf = vec![1u8; 12345];
        let (buf, ret) = file.write_all_at_async(buf, 0).await;
        ret.unwrap();
        assert_eq!(buf.len(), 12345);
        assert!(buf.iter().all(|x| *x == 1));

        let buf = vec![2u8; 50];
        let (buf, ret) = file.write_all_at_async(buf, 12345).await;
        ret.unwrap();
        assert_eq!(buf.len(), 50);
        assert!(buf.iter().all(|x| *x == 2));

        // Positional writes leave the file cursor at 0, so this reads the whole file.
        let mut out = Vec::new();
        file.read_to_end(&mut out).unwrap();
        assert_eq!(out.len(), 12345 + 50);
        assert!(out[0..12345].iter().all(|x| *x == 1));
        assert!(out[12345..].iter().all(|x| *x == 2));
    }

    #[tokio::test]
    async fn test_read() {
        let mut file = tempfile().unwrap();

        file.write_all(&[1; 12345]).unwrap();
        file.write_all(&[2; 50]).unwrap();

        let buf = vec![0u8; 12345];
        let (buf, ret) = file.read_exact_at_async(buf, 0).await;
        ret.unwrap();
        assert_eq!(buf.len(), 12345);
        assert!(buf.iter().all(|x| *x == 1));

        // Start from zeros so the assertion proves the bytes were actually read.
        let buf = vec![0u8; 50];
        let (buf, ret) = file.read_exact_at_async(buf, 12345).await;
        ret.unwrap();
        assert_eq!(buf.len(), 50);
        assert!(buf.iter().all(|x| *x == 2));

        // Only 5 bytes remain past offset 12390, so filling 10 must hit end-of-file.
        let (_, ret) = file.read_exact_at_async(vec![0u8; 10], 12390).await;
        assert_eq!(ret.unwrap_err().kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn test_read_at_async_short_read() {
        let mut file = tempfile().unwrap();
        file.write_all(&[3; 100]).unwrap();

        // Only 50 bytes exist past offset 50, so the read is short.
        let (buf, ret) = file.read_at_async(vec![0u8; 100], 50).await;
        let n = ret.unwrap();
        assert_eq!(n, 50);
        assert!(buf[..n].iter().all(|x| *x == 3));
    }
}