git_internal/internal/zlib/stream/
inflate.rs1use std::{io, io::BufRead};
5
6use flate2::{Decompress, FlushDecompress, Status};
7
8use crate::internal::object::types::ObjectType;
9use crate::utils::HashAlgorithm;
10
11pub struct ReadBoxed<R> {
17 pub inner: R,
19 pub decompressor: Box<Decompress>,
21 count_hash: bool,
23 pub hash: HashAlgorithm,
26}
27impl<R> ReadBoxed<R>
28where
29 R: BufRead,
30{
31 pub fn new(inner: R, obj_type: ObjectType, size: usize) -> Self {
34 let mut hash = HashAlgorithm::new();
36 hash.update(obj_type.to_bytes());
37 hash.update(b" ");
38 hash.update(size.to_string().as_bytes());
39 hash.update(b"\0");
40 ReadBoxed {
41 inner,
42 hash,
43 count_hash: true,
44 decompressor: Box::new(Decompress::new(true)),
45 }
46 }
47
48 pub fn new_for_delta(inner: R) -> Self {
51 ReadBoxed {
52 inner,
53 hash: HashAlgorithm::new(),
54 count_hash: false,
55 decompressor: Box::new(Decompress::new(true)),
56 }
57 }
58}
59impl<R> io::Read for ReadBoxed<R>
60where
61 R: BufRead,
62{
63 fn read(&mut self, into: &mut [u8]) -> io::Result<usize> {
64 let o = read(&mut self.inner, &mut self.decompressor, into)?;
65 if self.count_hash {
67 self.hash.update(&into[..o]);
68 }
69 Ok(o)
70 }
71}
72
73fn read(rd: &mut impl BufRead, state: &mut Decompress, mut dst: &mut [u8]) -> io::Result<usize> {
75 let mut total_written = 0;
76 loop {
77 let (written, consumed, ret, eof);
78 {
79 let input = rd.fill_buf()?;
80 eof = input.is_empty();
81 let before_out = state.total_out();
82 let before_in = state.total_in();
83 let flush = if eof {
84 FlushDecompress::Finish
85 } else {
86 FlushDecompress::None
87 };
88 ret = state.decompress(input, dst, flush);
89 written = (state.total_out() - before_out) as usize;
90 total_written += written;
91 dst = &mut dst[written..];
92 consumed = (state.total_in() - before_in) as usize;
93 }
94 rd.consume(consumed);
95
96 match ret {
97 Ok(Status::StreamEnd) => return Ok(total_written),
99 Ok(Status::Ok | Status::BufError) if eof || dst.is_empty() => return Ok(total_written),
101 Ok(Status::Ok | Status::BufError) if consumed != 0 || written != 0 => continue,
103 Ok(Status::Ok | Status::BufError) => unreachable!("Definitely a bug somewhere"),
105 Err(..) => {
106 return Err(io::Error::new(
107 io::ErrorKind::InvalidInput,
108 "corrupt deflate stream",
109 ));
110 }
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::{HashKind, ObjectHash, set_hash_kind_for_test};
    use flate2::{Compression, write::ZlibEncoder};
    use sha1::{Digest, Sha1};
    use std::io::Read;
    use std::io::Write;

    /// Compresses `data` into a zlib stream at the default compression level.
    fn zlib_compress(data: &[u8]) -> Vec<u8> {
        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Reads `reader` to the end and returns everything it inflated.
    fn inflate_all<R: io::BufRead>(reader: &mut ReadBoxed<R>) -> Vec<u8> {
        let mut inflated = Vec::new();
        reader.read_to_end(&mut inflated).unwrap();
        inflated
    }

    /// A reader built with `new` hashes the header followed by the body (SHA-1).
    #[test]
    fn inflate_object_counts_hash() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let body = b"hello\n";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new(source, ObjectType::Blob, body.len());
        assert_eq!(inflate_all(&mut reader), body);

        // Rebuild the same header + body digest independently.
        let mut expected = Sha1::new();
        expected.update(ObjectType::Blob.to_bytes());
        expected.update(b" ");
        expected.update(body.len().to_string());
        expected.update(b"\0");
        expected.update(body);
        assert_eq!(reader.hash.finalize(), expected.finalize().to_vec());
    }

    /// A reader built with `new_for_delta` inflates but never updates the hash.
    #[test]
    fn inflate_delta_skips_hash() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let body = b"delta bytes";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new_for_delta(source);
        assert_eq!(inflate_all(&mut reader), body);

        // The digest must equal that of an untouched SHA-1 state.
        let untouched = Sha1::new().finalize();
        assert_eq!(reader.hash.finalize(), untouched.to_vec());
    }

    /// Input that is not a zlib stream surfaces as an `InvalidInput` error.
    #[test]
    fn corrupt_stream_returns_error() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let data = b"not a valid zlib stream";
        let mut reader = ReadBoxed::new(io::Cursor::new(data), ObjectType::Blob, data.len());
        let mut scratch = [0u8; 16];
        let failure = reader.read(&mut scratch).unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::InvalidInput);
    }

    /// With SHA-256 selected, the streamed hash matches `ObjectHash::from_type_and_data`.
    #[test]
    fn inflate_object_counts_hash_sha256() {
        let _guard = set_hash_kind_for_test(HashKind::Sha256);
        let body = b"content";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new(source, ObjectType::Blob, body.len());
        assert_eq!(inflate_all(&mut reader), body);

        let reader_hash = reader.hash.finalize();
        let expected = ObjectHash::from_type_and_data(ObjectType::Blob, body);

        assert_eq!(reader_hash.len(), 32);
        assert_eq!(expected.as_ref().len(), 32);
        assert_eq!(reader_hash.as_slice(), expected.as_ref());
    }
}