Skip to main content

zsync_rs/
control.rs

1use std::io::{BufRead, Read};
2
3use crate::rsum::Rsum;
4
5#[derive(Debug, thiserror::Error)]
6pub enum ParseError {
7    #[error("IO error: {0}")]
8    Io(#[from] std::io::Error),
9    #[error("Invalid header: {0}")]
10    InvalidHeader(String),
11    #[error("Missing required field: {0}")]
12    MissingField(String),
13    #[error("Invalid blocksize: {0}")]
14    InvalidBlocksize(String),
15    #[error("Invalid hash lengths: {0}")]
16    InvalidHashLengths(String),
17    #[error("Invalid length: {0}")]
18    InvalidLength(String),
19    #[error("Unexpected end of file")]
20    UnexpectedEof,
21}
22
23#[derive(Debug, Clone, Copy)]
24pub struct BlockChecksum {
25    pub rsum: Rsum,
26    pub checksum: [u8; 16],
27}
28
29#[derive(Debug, Clone)]
30pub struct ControlFile {
31    pub version: String,
32    pub filename: Option<String>,
33    pub mtime: Option<String>,
34    pub blocksize: usize,
35    pub length: u64,
36    pub hash_lengths: HashLengths,
37    pub urls: Vec<String>,
38    pub sha1: Option<String>,
39    pub block_checksums: Vec<BlockChecksum>,
40}
41
/// Sizes from the `Hash-Lengths` header, describing how each block's checksums
/// are stored in the control file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HashLengths {
    /// First `Hash-Lengths` field; the parser accepts `1..=2`.
    /// NOTE(review): presumably the number of consecutive block matches the
    /// matcher requires — confirm against the matching code.
    pub seq_matches: u8,
    /// Number of rolling-checksum bytes stored per block; the parser accepts `1..=4`.
    pub rsum_bytes: u8,
    /// Number of strong-checksum bytes stored per block; the parser accepts `3..=16`.
    pub checksum_bytes: u8,
}
48
49impl Default for HashLengths {
50    fn default() -> Self {
51        Self {
52            seq_matches: 1,
53            rsum_bytes: 4,
54            checksum_bytes: 16,
55        }
56    }
57}
58
59impl ControlFile {
60    pub fn parse<R: Read>(reader: R) -> Result<Self, ParseError> {
61        let mut reader = std::io::BufReader::new(reader);
62        let mut line = String::new();
63
64        let mut version = String::new();
65        let mut filename = None;
66        let mut mtime = None;
67        let mut blocksize = None;
68        let mut length = None;
69        let mut hash_lengths = HashLengths::default();
70        let mut urls = Vec::new();
71        let mut sha1 = None;
72
73        loop {
74            line.clear();
75            let bytes_read = reader.read_line(&mut line)?;
76            if bytes_read == 0 {
77                return Err(ParseError::UnexpectedEof);
78            }
79
80            let trimmed = line.trim_end_matches(['\n', '\r', ' ']);
81            if trimmed.is_empty() {
82                break;
83            }
84
85            let Some((key, value)) = trimmed.split_once(':') else {
86                return Err(ParseError::InvalidHeader(trimmed.to_string()));
87            };
88
89            let value = value.trim_start_matches(' ');
90
91            match key {
92                "zsync" => {
93                    version = value.to_string();
94                }
95                "Filename" => {
96                    filename = Some(value.to_string());
97                }
98                "MTime" => {
99                    mtime = Some(value.to_string());
100                }
101                "Blocksize" => {
102                    let bs: usize = value
103                        .parse()
104                        .map_err(|_| ParseError::InvalidBlocksize(value.to_string()))?;
105                    if bs == 0 || (bs & (bs - 1)) != 0 {
106                        return Err(ParseError::InvalidBlocksize(value.to_string()));
107                    }
108                    blocksize = Some(bs);
109                }
110                "Length" => {
111                    length = Some(
112                        value
113                            .parse()
114                            .map_err(|_| ParseError::InvalidLength(value.to_string()))?,
115                    );
116                }
117                "URL" => {
118                    urls.push(value.to_string());
119                }
120                "Hash-Lengths" => {
121                    let parts: Vec<&str> = value.split(',').collect();
122                    if parts.len() != 3 {
123                        return Err(ParseError::InvalidHashLengths(value.to_string()));
124                    }
125                    let seq_matches: u8 = parts[0]
126                        .parse()
127                        .map_err(|_| ParseError::InvalidHashLengths(value.to_string()))?;
128                    let rsum_bytes: u8 = parts[1]
129                        .parse()
130                        .map_err(|_| ParseError::InvalidHashLengths(value.to_string()))?;
131                    let checksum_bytes: u8 = parts[2]
132                        .parse()
133                        .map_err(|_| ParseError::InvalidHashLengths(value.to_string()))?;
134
135                    if !(1..=2).contains(&seq_matches)
136                        || !(1..=4).contains(&rsum_bytes)
137                        || !(3..=16).contains(&checksum_bytes)
138                    {
139                        return Err(ParseError::InvalidHashLengths(value.to_string()));
140                    }
141
142                    hash_lengths = HashLengths {
143                        seq_matches,
144                        rsum_bytes,
145                        checksum_bytes,
146                    };
147                }
148                "SHA-1" => {
149                    if value.len() != 40 {
150                        return Err(ParseError::InvalidHeader(
151                            "SHA-1 digest wrong length".to_string(),
152                        ));
153                    }
154                    sha1 = Some(value.to_string());
155                }
156                _ => {}
157            }
158        }
159
160        let blocksize =
161            blocksize.ok_or_else(|| ParseError::MissingField("Blocksize".to_string()))?;
162        let length: u64 = length.ok_or_else(|| ParseError::MissingField("Length".to_string()))?;
163
164        let num_blocks = length.div_ceil(blocksize as u64) as usize;
165
166        // Sanity check: avoid massive allocations from malformed input.
167        // Each block needs (rsum_bytes + checksum_bytes) of data following the header.
168        const MAX_BLOCKS: usize = 64 * 1024 * 1024;
169        if num_blocks > MAX_BLOCKS {
170            return Err(ParseError::InvalidLength(format!(
171                "too many blocks: {num_blocks}"
172            )));
173        }
174
175        let block_checksums = Self::read_block_checksums(&mut reader, num_blocks, hash_lengths)?;
176
177        Ok(Self {
178            version,
179            filename,
180            mtime,
181            blocksize,
182            length,
183            hash_lengths,
184            urls,
185            sha1,
186            block_checksums,
187        })
188    }
189
190    fn read_block_checksums<R: BufRead>(
191        reader: &mut R,
192        num_blocks: usize,
193        hash_lengths: HashLengths,
194    ) -> Result<Vec<BlockChecksum>, ParseError> {
195        let mut checksums = Vec::with_capacity(num_blocks);
196        let entry_size = (hash_lengths.rsum_bytes + hash_lengths.checksum_bytes) as usize;
197        let mut buf = vec![0u8; entry_size];
198
199        for _ in 0..num_blocks {
200            reader.read_exact(&mut buf)?;
201
202            let rsum_bytes = hash_lengths.rsum_bytes as usize;
203            let (rsum_a, rsum_b) = match rsum_bytes {
204                1 => (0u16, u16::from(buf[0])),
205                2 => (0u16, u16::from_be_bytes([buf[0], buf[1]])),
206                3 => (u16::from(buf[0]), u16::from_be_bytes([buf[1], buf[2]])),
207                4 => (
208                    u16::from_be_bytes([buf[0], buf[1]]),
209                    u16::from_be_bytes([buf[2], buf[3]]),
210                ),
211                _ => (0, 0),
212            };
213
214            let mut checksum = [0u8; 16];
215            checksum[..hash_lengths.checksum_bytes as usize]
216                .copy_from_slice(&buf[rsum_bytes..entry_size]);
217
218            checksums.push(BlockChecksum {
219                rsum: Rsum {
220                    a: rsum_a,
221                    b: rsum_b,
222                },
223                checksum,
224            });
225        }
226
227        Ok(checksums)
228    }
229
230    pub fn num_blocks(&self) -> usize {
231        self.block_checksums.len()
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Concatenates a textual header (which must already end with the
    /// blank-line terminator) with the raw checksum table that follows it.
    fn control_file(header: &str, table: &[u8]) -> Vec<u8> {
        let mut data = header.as_bytes().to_vec();
        data.extend_from_slice(table);
        data
    }

    #[test]
    fn test_parse_minimal() {
        let data = control_file(
            "zsync: 0.6.2\nBlocksize: 2048\nLength: 2048\nHash-Lengths: 1,4,16\n\n",
            &[0u8; 20],
        );
        let cf = ControlFile::parse(&data[..]).expect("minimal control file should parse");
        assert_eq!(cf.version, "0.6.2");
        assert_eq!(cf.blocksize, 2048);
        assert_eq!(cf.length, 2048);
        assert_eq!(cf.num_blocks(), 1);
        assert_eq!(cf.hash_lengths.seq_matches, 1);
        assert_eq!(cf.hash_lengths.rsum_bytes, 4);
        assert_eq!(cf.hash_lengths.checksum_bytes, 16);
    }

    #[test]
    fn test_parse_decodes_short_rsums() {
        // Hash-Lengths 2,3,5: 3 rsum bytes + 5 checksum bytes per block.
        let data = control_file(
            "zsync: 0.6.2\nBlocksize: 2048\nLength: 4096\nHash-Lengths: 2,3,5\n\n",
            &[
                0x01, 0x02, 0x03, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, //
                0xff, 0x00, 0x10, 1, 2, 3, 4, 5,
            ],
        );
        let cf = ControlFile::parse(&data[..]).expect("control file should parse");
        assert_eq!(cf.num_blocks(), 2);

        let first = &cf.block_checksums[0];
        assert_eq!((first.rsum.a, first.rsum.b), (0x0001, 0x0203));
        let mut expected = [0u8; 16];
        expected[..5].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee]);
        assert_eq!(first.checksum, expected);

        let second = &cf.block_checksums[1];
        assert_eq!((second.rsum.a, second.rsum.b), (0x00ff, 0x0010));
        expected[..5].copy_from_slice(&[1, 2, 3, 4, 5]);
        assert_eq!(second.checksum, expected);
    }

    #[test]
    fn test_parse_optional_fields() {
        let sha1 = "0123456789abcdef0123456789abcdef01234567";
        let header = format!(
            "zsync: 0.6.2\r\nFilename: a.iso\nMTime: Sat, 01 Jan 2022 00:00:00 +0000\n\
             Blocksize: 2048\nLength: 0\nURL: a.iso\nURL: http://example.com/a.iso\n\
             SHA-1: {sha1}\nX-Unknown: ignored\n\n"
        );
        let cf = ControlFile::parse(header.as_bytes()).expect("control file should parse");
        assert_eq!(cf.version, "0.6.2");
        assert_eq!(cf.filename.as_deref(), Some("a.iso"));
        assert_eq!(cf.mtime.as_deref(), Some("Sat, 01 Jan 2022 00:00:00 +0000"));
        assert_eq!(cf.urls, vec!["a.iso", "http://example.com/a.iso"]);
        assert_eq!(cf.sha1.as_deref(), Some(sha1));
        assert_eq!(cf.num_blocks(), 0);
    }

    #[test]
    fn test_parse_missing_blocksize() {
        let result = ControlFile::parse(&b"zsync: 0.6.2\nLength: 4096\n\n"[..]);
        assert!(matches!(&result, Err(ParseError::MissingField(f)) if f == "Blocksize"));
    }

    #[test]
    fn test_parse_invalid_blocksize() {
        let data = b"zsync: 0.6.2\nBlocksize: 1000\nLength: 4096\n\n";
        let result = ControlFile::parse(&data[..]);
        assert!(matches!(result, Err(ParseError::InvalidBlocksize(_))));
    }

    #[test]
    fn test_parse_invalid_hash_lengths() {
        for hl in ["1,4", "1,4,16,1", "0,4,16", "1,5,16", "1,4,2", "1,x,16"] {
            let header =
                format!("zsync: 0.6.2\nBlocksize: 2048\nHash-Lengths: {hl}\nLength: 2048\n\n");
            let result = ControlFile::parse(header.as_bytes());
            assert!(
                matches!(result, Err(ParseError::InvalidHashLengths(_))),
                "accepted Hash-Lengths {hl}"
            );
        }
    }

    #[test]
    fn test_parse_bad_sha1_length() {
        let result = ControlFile::parse(&b"zsync: 0.6.2\nSHA-1: abc\n\n"[..]);
        assert!(matches!(result, Err(ParseError::InvalidHeader(_))));
    }

    #[test]
    fn test_parse_rejects_too_many_blocks() {
        let data = b"zsync: 0.6.2\nBlocksize: 1024\nLength: 18446744073709551615\n\n";
        let result = ControlFile::parse(&data[..]);
        assert!(matches!(result, Err(ParseError::InvalidLength(_))));
    }

    #[test]
    fn test_parse_eof_in_header() {
        let result = ControlFile::parse(&b"zsync: 0.6.2\nBlocksize: 2048\n"[..]);
        assert!(matches!(result, Err(ParseError::UnexpectedEof)));
    }

    #[test]
    fn test_parse_truncated_checksums() {
        // Two default-sized (4 + 16 byte) entries are required; only 30 bytes follow.
        let data = control_file(
            "zsync: 0.6.2\nBlocksize: 2048\nLength: 4096\n\n",
            &[0u8; 30],
        );
        assert!(ControlFile::parse(&data[..]).is_err());
    }
}