git_internal/internal/zlib/stream/
inflate.rs1use std::{io, io::BufRead};
5
6use flate2::{Decompress, FlushDecompress, Status};
7
8use crate::{internal::object::types::ObjectType, utils::HashAlgorithm};
9
10pub struct ReadBoxed<R> {
16 pub inner: R,
18 pub decompressor: Box<Decompress>,
20 count_hash: bool,
22 pub hash: HashAlgorithm,
25}
26impl<R> ReadBoxed<R>
27where
28 R: BufRead,
29{
30 pub fn new(inner: R, obj_type: ObjectType, size: usize) -> Self {
33 let mut hash = HashAlgorithm::new();
35 hash.update(
36 obj_type
37 .to_bytes()
38 .expect("ReadBoxed::new called with a delta type"),
39 );
40 hash.update(b" ");
41 hash.update(size.to_string().as_bytes());
42 hash.update(b"\0");
43 ReadBoxed {
44 inner,
45 hash,
46 count_hash: true,
47 decompressor: Box::new(Decompress::new(true)),
48 }
49 }
50
51 pub fn new_for_delta(inner: R) -> Self {
54 ReadBoxed {
55 inner,
56 hash: HashAlgorithm::new(),
57 count_hash: false,
58 decompressor: Box::new(Decompress::new(true)),
59 }
60 }
61}
62impl<R> io::Read for ReadBoxed<R>
63where
64 R: BufRead,
65{
66 fn read(&mut self, into: &mut [u8]) -> io::Result<usize> {
67 let o = read(&mut self.inner, &mut self.decompressor, into)?;
68 if self.count_hash {
70 self.hash.update(&into[..o]);
71 }
72 Ok(o)
73 }
74}
75
76fn read(rd: &mut impl BufRead, state: &mut Decompress, mut dst: &mut [u8]) -> io::Result<usize> {
78 let mut total_written = 0;
79 loop {
80 let (written, consumed, ret, eof);
81 {
82 let input = rd.fill_buf()?;
83 eof = input.is_empty();
84 let before_out = state.total_out();
85 let before_in = state.total_in();
86 let flush = if eof {
87 FlushDecompress::Finish
88 } else {
89 FlushDecompress::None
90 };
91 ret = state.decompress(input, dst, flush);
92 written = (state.total_out() - before_out) as usize;
93 total_written += written;
94 dst = &mut dst[written..];
95 consumed = (state.total_in() - before_in) as usize;
96 }
97 rd.consume(consumed);
98
99 match ret {
100 Ok(Status::StreamEnd) => return Ok(total_written),
102 Ok(Status::Ok | Status::BufError) if eof || dst.is_empty() => return Ok(total_written),
104 Ok(Status::Ok | Status::BufError) if consumed != 0 || written != 0 => continue,
106 Ok(Status::Ok | Status::BufError) => unreachable!("Definitely a bug somewhere"),
108 Err(..) => {
109 return Err(io::Error::new(
110 io::ErrorKind::InvalidInput,
111 "corrupt deflate stream",
112 ));
113 }
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};

    use flate2::{Compression, write::ZlibEncoder};
    use sha1::{Digest, Sha1};

    use super::*;
    use crate::hash::{HashKind, ObjectHash, set_hash_kind_for_test};

    /// Compresses `data` into a zlib stream at the default compression level.
    fn zlib_compress(data: &[u8]) -> Vec<u8> {
        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Reads `reader` to its end and returns everything it inflated.
    fn inflate_fully<R: BufRead>(reader: &mut ReadBoxed<R>) -> Vec<u8> {
        let mut inflated = Vec::new();
        reader.read_to_end(&mut inflated).unwrap();
        inflated
    }

    /// An object reader must hash the object header followed by the body.
    #[test]
    fn inflate_object_counts_hash() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let body = b"hello\n";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new(source, ObjectType::Blob, body.len());
        let inflated = inflate_fully(&mut reader);
        assert_eq!(inflated, body);

        // Rebuild "<type> <size>\0<body>" by hand and hash it independently.
        let mut expected = Sha1::new();
        expected.update(ObjectType::Blob.to_bytes().unwrap());
        expected.update(b" ");
        expected.update(body.len().to_string());
        expected.update(b"\0");
        expected.update(body);
        assert_eq!(reader.hash.finalize(), expected.finalize().to_vec());
    }

    /// A delta reader inflates data but leaves the hash untouched.
    #[test]
    fn inflate_delta_skips_hash() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let body = b"delta bytes";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new_for_delta(source);
        let inflated = inflate_fully(&mut reader);
        assert_eq!(inflated, body);

        // Nothing was fed to the hash, so it must equal the hash of no input.
        let empty_hash = Sha1::new().finalize();
        assert_eq!(reader.hash.finalize(), empty_hash.to_vec());
    }

    /// Input that is not a zlib stream must surface as `InvalidInput`.
    #[test]
    fn corrupt_stream_returns_error() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let data = b"not a valid zlib stream";
        let mut reader = ReadBoxed::new(io::Cursor::new(data), ObjectType::Blob, data.len());

        let mut scratch = [0u8; 16];
        let err = reader.read(&mut scratch).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// With SHA-256 selected, the reader's hash must match the object hash
    /// computed from the type and data directly.
    #[test]
    fn inflate_object_counts_hash_sha256() {
        let _guard = set_hash_kind_for_test(HashKind::Sha256);
        let body = b"content";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new(source, ObjectType::Blob, body.len());
        let inflated = inflate_fully(&mut reader);
        assert_eq!(inflated, body);

        let reader_hash = reader.hash.finalize();
        let expected = ObjectHash::from_type_and_data(ObjectType::Blob, body);

        assert_eq!(reader_hash.len(), 32);
        assert_eq!(expected.as_ref().len(), 32);
        assert_eq!(reader_hash.as_slice(), expected.as_ref());
    }
}