git_internal/internal/zlib/stream/
inflate.rs1use std::{io, io::BufRead};
5
6use flate2::{Decompress, FlushDecompress, Status};
7
8use crate::{internal::object::types::ObjectType, utils::HashAlgorithm};
9
10pub struct ReadBoxed<R> {
16 pub inner: R,
18 pub decompressor: Box<Decompress>,
20 count_hash: bool,
22 pub hash: HashAlgorithm,
25}
26impl<R> ReadBoxed<R>
27where
28 R: BufRead,
29{
30 pub fn new(inner: R, obj_type: ObjectType, size: usize) -> Self {
33 let mut hash = HashAlgorithm::new();
35 hash.update(obj_type.to_bytes());
36 hash.update(b" ");
37 hash.update(size.to_string().as_bytes());
38 hash.update(b"\0");
39 ReadBoxed {
40 inner,
41 hash,
42 count_hash: true,
43 decompressor: Box::new(Decompress::new(true)),
44 }
45 }
46
47 pub fn new_for_delta(inner: R) -> Self {
50 ReadBoxed {
51 inner,
52 hash: HashAlgorithm::new(),
53 count_hash: false,
54 decompressor: Box::new(Decompress::new(true)),
55 }
56 }
57}
58impl<R> io::Read for ReadBoxed<R>
59where
60 R: BufRead,
61{
62 fn read(&mut self, into: &mut [u8]) -> io::Result<usize> {
63 let o = read(&mut self.inner, &mut self.decompressor, into)?;
64 if self.count_hash {
66 self.hash.update(&into[..o]);
67 }
68 Ok(o)
69 }
70}
71
72fn read(rd: &mut impl BufRead, state: &mut Decompress, mut dst: &mut [u8]) -> io::Result<usize> {
74 let mut total_written = 0;
75 loop {
76 let (written, consumed, ret, eof);
77 {
78 let input = rd.fill_buf()?;
79 eof = input.is_empty();
80 let before_out = state.total_out();
81 let before_in = state.total_in();
82 let flush = if eof {
83 FlushDecompress::Finish
84 } else {
85 FlushDecompress::None
86 };
87 ret = state.decompress(input, dst, flush);
88 written = (state.total_out() - before_out) as usize;
89 total_written += written;
90 dst = &mut dst[written..];
91 consumed = (state.total_in() - before_in) as usize;
92 }
93 rd.consume(consumed);
94
95 match ret {
96 Ok(Status::StreamEnd) => return Ok(total_written),
98 Ok(Status::Ok | Status::BufError) if eof || dst.is_empty() => return Ok(total_written),
100 Ok(Status::Ok | Status::BufError) if consumed != 0 || written != 0 => continue,
102 Ok(Status::Ok | Status::BufError) => unreachable!("Definitely a bug somewhere"),
104 Err(..) => {
105 return Err(io::Error::new(
106 io::ErrorKind::InvalidInput,
107 "corrupt deflate stream",
108 ));
109 }
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};

    use flate2::{Compression, write::ZlibEncoder};
    use sha1::{Digest, Sha1};

    use super::*;
    use crate::hash::{HashKind, ObjectHash, set_hash_kind_for_test};

    /// Compresses `data` into a zlib stream at the default compression level.
    fn zlib_compress(data: &[u8]) -> Vec<u8> {
        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Reads `reader` to its end and returns everything it produced.
    fn inflate_all<R: BufRead>(reader: &mut ReadBoxed<R>) -> Vec<u8> {
        let mut inflated = Vec::new();
        reader.read_to_end(&mut inflated).unwrap();
        inflated
    }

    /// An object reader hashes the header plus the inflated body (SHA-1).
    #[test]
    fn inflate_object_counts_hash() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let body = b"hello\n";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new(source, ObjectType::Blob, body.len());
        assert_eq!(inflate_all(&mut reader), body);

        // Rebuild "<type> <size>\0<body>" by hand and compare digests.
        let mut expected = Sha1::new();
        expected.update(ObjectType::Blob.to_bytes());
        expected.update(b" ");
        expected.update(body.len().to_string());
        expected.update(b"\0");
        expected.update(body);
        assert_eq!(reader.hash.finalize(), expected.finalize().to_vec());
    }

    /// A delta reader inflates the data but leaves the hash untouched.
    #[test]
    fn inflate_delta_skips_hash() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let body = b"delta bytes";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new_for_delta(source);
        assert_eq!(inflate_all(&mut reader), body);

        // The digest must equal that of a hasher which never saw any input.
        let untouched = Sha1::new().finalize();
        assert_eq!(reader.hash.finalize(), untouched.to_vec());
    }

    /// Input that is not a zlib stream surfaces as `InvalidInput`.
    #[test]
    fn corrupt_stream_returns_error() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let data = b"not a valid zlib stream";
        let mut reader = ReadBoxed::new(io::Cursor::new(data), ObjectType::Blob, data.len());

        let mut scratch = [0u8; 16];
        let failure = reader.read(&mut scratch).unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::InvalidInput);
    }

    /// With SHA-256 selected, the reader's digest matches `ObjectHash`.
    #[test]
    fn inflate_object_counts_hash_sha256() {
        let _guard = set_hash_kind_for_test(HashKind::Sha256);
        let body = b"content";
        let source = io::Cursor::new(zlib_compress(body));

        let mut reader = ReadBoxed::new(source, ObjectType::Blob, body.len());
        assert_eq!(inflate_all(&mut reader), body);

        let reader_hash = reader.hash.finalize();
        let expected = ObjectHash::from_type_and_data(ObjectType::Blob, body);

        assert_eq!(reader_hash.len(), 32);
        assert_eq!(expected.as_ref().len(), 32);
        assert_eq!(reader_hash.as_slice(), expected.as_ref());
    }
}