nydus_utils/
reader.rs

1// Copyright (C) 2022 Alibaba Cloud. All rights reserved.
2//
3// SPDX-License-Identifier: Apache-2.0
4
5use std::fs::File;
6use std::io::{BufReader, Read, Seek, SeekFrom};
7use std::marker::PhantomData;
8use std::os::unix::io::{AsRawFd, RawFd};
9use std::sync::{Arc, Mutex};
10
11use sha2::Sha256;
12
13use crate::digest::DigestHasher;
14
/// A wrapper reader to read a range of data from a file.
///
/// Reads are limited to the byte range `[offset, offset + size)` of the file and are issued as
/// positional reads (`pread`), so the `File`'s own cursor is left untouched.
pub struct FileRangeReader<'a> {
    /// Raw descriptor of the borrowed file; ownership stays with the caller's `File`.
    fd: RawFd,
    /// File offset of the next byte to read.
    offset: u64,
    /// Number of bytes still available in the range.
    size: u64,
    /// Records that this reader borrows the `File` the descriptor came from for `'a`.
    r: PhantomData<&'a File>,
}
22
23impl FileRangeReader<'_> {
24    /// Create a wrapper reader to read a range of data from the file.
25    pub fn new(f: &File, offset: u64, size: u64) -> Self {
26        Self {
27            fd: f.as_raw_fd(),
28            offset,
29            size,
30            r: PhantomData,
31        }
32    }
33}
34
35impl Read for FileRangeReader<'_> {
36    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
37        let size = std::cmp::min(self.size as usize, buf.len());
38        let nr_read = nix::sys::uio::pread(self.fd, &mut buf[0..size], self.offset as i64)
39            .map_err(|_| last_error!())?;
40        self.offset += nr_read as u64;
41        self.size -= nr_read as u64;
42        Ok(nr_read)
43    }
44}
45
46struct BufReaderState<R: Read> {
47    reader: BufReader<R>,
48    pos: u64,
49    hash: Sha256,
50}
51
52/// A wrapper over `BufReader` to track current position.
53pub struct BufReaderInfo<R: Read> {
54    calc_digest: bool,
55    state: Arc<Mutex<BufReaderState<R>>>,
56}
57
58impl<R: Read> BufReaderInfo<R> {
59    /// Create a new instance of `BufReaderPos` from a `BufReader`.
60    pub fn from_buf_reader(buf_reader: BufReader<R>) -> Self {
61        let state = BufReaderState {
62            reader: buf_reader,
63            pos: 0,
64            hash: Sha256::default(),
65        };
66        Self {
67            calc_digest: true,
68            state: Arc::new(Mutex::new(state)),
69        }
70    }
71
72    /// Get current position of the reader.
73    pub fn position(&self) -> u64 {
74        self.state.lock().unwrap().pos
75    }
76
77    /// Get the hash object.
78    pub fn get_hash_object(&self) -> Sha256 {
79        self.state.lock().unwrap().hash.clone()
80    }
81
82    /// Enable or disable blob digest calculation.
83    pub fn enable_digest_calculation(&mut self, enable: bool) {
84        self.calc_digest = enable;
85    }
86}
87
88impl<R: Read> Read for BufReaderInfo<R> {
89    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
90        let mut state = self.state.lock().unwrap();
91        state.reader.read(buf).inspect(|&v| {
92            state.pos += v as u64;
93            if v > 0 && self.calc_digest {
94                state.hash.digest_update(&buf[..v]);
95            }
96        })
97    }
98}
99
100impl<R: Read + Seek> Seek for BufReaderInfo<R> {
101    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
102        let mut state = self.state.lock().unwrap();
103        let pos = state.reader.seek(pos)?;
104        state.pos = pos;
105        Ok(pos)
106    }
107}
108
109impl<R: Read> Clone for BufReaderInfo<R> {
110    fn clone(&self) -> Self {
111        Self {
112            calc_digest: self.calc_digest,
113            state: self.state.clone(),
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use vmm_sys_util::tempfile::TempFile;

    #[test]
    fn test_file_range_reader() {
        let file = TempFile::new().unwrap();
        std::fs::write(file.as_path(), b"This is a test").unwrap();
        let mut reader = FileRangeReader::new(file.as_file(), 4, 6);
        let mut buf = vec![0u8; 128];
        let res = reader.read(&mut buf).unwrap();
        assert_eq!(res, 6);
        assert_eq!(&buf[..6], b" is a ".as_slice());
        let res = reader.read(&mut buf).unwrap();
        assert_eq!(res, 0);
    }

    #[test]
    fn test_file_range_reader_small_buffer() {
        let file = TempFile::new().unwrap();
        std::fs::write(file.as_path(), b"This is a test").unwrap();
        let mut reader = FileRangeReader::new(file.as_file(), 0, 4);
        let mut buf = [0u8; 2];
        assert_eq!(reader.read(&mut buf).unwrap(), 2);
        assert_eq!(&buf, b"Th");
        assert_eq!(reader.read(&mut buf).unwrap(), 2);
        assert_eq!(&buf, b"is");
        // The range is exhausted even though the file has more data.
        assert_eq!(reader.read(&mut buf).unwrap(), 0);
    }

    #[test]
    fn test_file_range_reader_past_eof() {
        let file = TempFile::new().unwrap();
        std::fs::write(file.as_path(), b"This is a test").unwrap();
        // The range extends beyond EOF: only the bytes that exist are returned.
        let mut reader = FileRangeReader::new(file.as_file(), 10, 100);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();
        assert_eq!(buf, b"test");
    }

    #[test]
    fn test_buf_reader_info_position() {
        let data = b"0123456789".to_vec();
        let mut reader = BufReaderInfo::from_buf_reader(BufReader::new(std::io::Cursor::new(data)));
        let mut buf = [0u8; 4];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"0123");
        assert_eq!(reader.position(), 4);

        // Clones share the underlying state, including the position.
        let cloned = reader.clone();
        assert_eq!(cloned.position(), 4);
        assert_eq!(reader.seek(SeekFrom::Start(8)).unwrap(), 8);
        assert_eq!(cloned.position(), 8);

        reader.read_exact(&mut buf[..2]).unwrap();
        assert_eq!(&buf[..2], b"89");
        assert_eq!(reader.position(), 10);
        assert_eq!(reader.read(&mut buf).unwrap(), 0);
        assert_eq!(reader.position(), 10);
    }
}