1mod bmap;
2pub use crate::bmap::*;
3mod discarder;
4pub use crate::discarder::*;
5use async_trait::async_trait;
6use futures::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt};
7use futures::TryFutureExt;
8use sha2::{Digest, Sha256};
9use thiserror::Error;
10
11use std::io::Result as IOResult;
12use std::io::{Read, Seek, SeekFrom, Write};
13
/// Moves a stream forward by a number of bytes relative to its current position.
///
/// The copy routines use this to skip over the unmapped gaps between block ranges.
pub trait SeekForward {
    /// Advances the stream position by `offset` bytes.
    fn seek_forward(&mut self, offset: u64) -> IOResult<()>;
}

/// Blanket implementation for every [`Seek`] type, using a relative [`SeekFrom::Current`] seek.
impl<T: Seek> SeekForward for T {
    /// Advances the stream position by `forward` bytes.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::ErrorKind::InvalidInput`] error, without touching the stream, if
    /// `forward` exceeds `i64::MAX` (a plain `as` cast would wrap it to a negative offset and
    /// seek backwards). Otherwise propagates any error from the underlying seek.
    fn seek_forward(&mut self, forward: u64) -> IOResult<()> {
        use std::convert::TryFrom;

        let delta = i64::try_from(forward).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "seek offset does not fit in an i64",
            )
        })?;
        self.seek(SeekFrom::Current(delta))?;
        Ok(())
    }
}
25
26#[async_trait]
27pub trait AsyncSeekForward {
28 async fn async_seek_forward(&mut self, offset: u64) -> IOResult<()>;
29}
30
31#[async_trait]
32impl<T: AsyncSeek + Unpin + Send> AsyncSeekForward for T {
33 async fn async_seek_forward(&mut self, forward: u64) -> IOResult<()> {
34 self.seek(SeekFrom::Current(forward as i64)).await?;
35 Ok(())
36 }
37}
38
39#[derive(Debug, Error)]
40pub enum CopyError {
41 #[error("Failed to Read: {0}")]
42 ReadError(std::io::Error),
43 #[error("Failed to Write: {0}")]
44 WriteError(std::io::Error),
45 #[error("Checksum error")]
46 ChecksumError,
47 #[error("Unexpected EOF on input")]
48 UnexpectedEof,
49}
50
51pub fn copy<I, O>(input: &mut I, output: &mut O, map: &Bmap) -> Result<(), CopyError>
52where
53 I: Read + SeekForward,
54 O: Write + SeekForward,
55{
56 let mut hasher = match map.checksum_type() {
57 HashType::Sha256 => Sha256::new(),
58 };
59
60 let mut v = vec![0; 8 * 1024 * 1024];
62
63 let buf = v.as_mut_slice();
64 let mut position = 0;
65 for range in map.block_map() {
66 let forward = range.offset() - position;
67 input.seek_forward(forward).map_err(CopyError::ReadError)?;
68 output
69 .seek_forward(forward)
70 .map_err(CopyError::WriteError)?;
71
72 let mut left = range.length() as usize;
73 while left > 0 {
74 let toread = left.min(buf.len());
75 let r = input
76 .read(&mut buf[0..toread])
77 .map_err(CopyError::ReadError)?;
78 if r == 0 {
79 return Err(CopyError::UnexpectedEof);
80 }
81 hasher.update(&buf[0..r]);
82 output
83 .write_all(&buf[0..r])
84 .map_err(CopyError::WriteError)?;
85 left -= r;
86 }
87 let digest = hasher.finalize_reset();
88 if range.checksum().as_slice() != digest.as_slice() {
89 return Err(CopyError::ChecksumError);
90 }
91
92 position = range.offset() + range.length();
93 }
94
95 Ok(())
96}
97
98pub async fn copy_async<I, O>(input: &mut I, output: &mut O, map: &Bmap) -> Result<(), CopyError>
99where
100 I: AsyncRead + AsyncSeekForward + Unpin,
101 O: AsyncWrite + AsyncSeekForward + Unpin,
102{
103 let mut hasher = match map.checksum_type() {
104 HashType::Sha256 => Sha256::new(),
105 };
106
107 let mut v = vec![0; 8 * 1024 * 1024];
109
110 let buf = v.as_mut_slice();
111 let mut position = 0;
112 for range in map.block_map() {
113 let forward = range.offset() - position;
114 input
115 .async_seek_forward(forward)
116 .map_err(CopyError::ReadError)
117 .await?;
118 output.flush().map_err(CopyError::WriteError).await?;
119 output
120 .async_seek_forward(forward)
121 .map_err(CopyError::WriteError)
122 .await?;
123
124 let mut left = range.length() as usize;
125 while left > 0 {
126 let toread = left.min(buf.len());
127 let r = input
128 .read(&mut buf[0..toread])
129 .map_err(CopyError::ReadError)
130 .await?;
131 if r == 0 {
132 return Err(CopyError::UnexpectedEof);
133 }
134 hasher.update(&buf[0..r]);
135 output
136 .write_all(&buf[0..r])
137 .await
138 .map_err(CopyError::WriteError)?;
139 left -= r;
140 }
141 let digest = hasher.finalize_reset();
142 if range.checksum().as_slice() != digest.as_slice() {
143 return Err(CopyError::ChecksumError);
144 }
145
146 position = range.offset() + range.length();
147 }
148 Ok(())
149}
150
151pub fn copy_nobmap<I, O>(input: &mut I, output: &mut O) -> Result<(), CopyError>
152where
153 I: Read,
154 O: Write,
155{
156 std::io::copy(input, output).map_err(CopyError::WriteError)?;
157 Ok(())
158}
159
160pub async fn copy_async_nobmap<I, O>(input: &mut I, output: &mut O) -> Result<(), CopyError>
161where
162 I: AsyncRead + AsyncSeekForward + Unpin,
163 O: AsyncWrite + AsyncSeekForward + Unpin,
164{
165 futures::io::copy(input, output)
166 .map_err(CopyError::WriteError)
167 .await?;
168 Ok(())
169}