1use std::task::Poll;
3
4use futures_util::stream::{BoxStream, Stream};
5use futures_util::TryStreamExt;
6
7use crate::digest::Digester;
8use crate::errors::DigestError;
9
/// A boxed stream of byte chunks bundled with optional size and digest metadata.
pub struct SizedStream {
    /// Length of the content, when known; `None` otherwise.
    // NOTE(review): presumably a byte count from a `Content-Length` header — confirm against the code that builds this.
    pub content_length: Option<u64>,
    /// Digest value taken from a header, when one is available.
    // NOTE(review): which header supplies this is not visible here — confirm at the construction site.
    pub digest_header_value: Option<String>,
    /// The underlying stream; each item is either a chunk of bytes or an I/O error.
    pub stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
}
21
22impl Stream for SizedStream {
23 type Item = Result<bytes::Bytes, std::io::Error>;
24
25 fn poll_next(
26 mut self: std::pin::Pin<&mut Self>,
27 cx: &mut std::task::Context<'_>,
28 ) -> Poll<Option<Self::Item>> {
29 self.stream.try_poll_next_unpin(cx)
30 }
31}
32
/// A blob response carrying a [`SizedStream`], tagged as either full or partial.
pub enum BlobResponse {
    /// The stream is tagged as a full response.
    // NOTE(review): presumably the entire blob content — confirm against the producer of this enum.
    Full(SizedStream),
    /// The stream is tagged as a partial response.
    // NOTE(review): presumably a requested sub-range of the blob — confirm against the producer of this enum.
    Partial(SizedStream),
}
40
/// A stream wrapper that hashes every chunk it yields and checks the resulting
/// digest(s) against expected values once the inner stream is exhausted.
pub(crate) struct VerifyingStream {
    /// The wrapped stream whose chunks are passed through to the consumer.
    stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
    /// Digester fed with every chunk; its final value is compared to `expected_layer_digest`.
    layer_digester: Digester,
    /// The digest the complete content must produce from `layer_digester`.
    expected_layer_digest: String,
    /// Optional second digester paired with the digest it is expected to produce;
    /// when present it is also fed every chunk and checked at end of stream.
    header_digester: Option<(Digester, String)>,
}
47
48impl VerifyingStream {
49 pub fn new(
50 stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
51 layer_digester: Digester,
52 expected_layer_digest: String,
53 header_digester_and_digest: Option<(Digester, String)>,
54 ) -> Self {
55 Self {
56 stream,
57 layer_digester,
58 expected_layer_digest,
59 header_digester: header_digester_and_digest,
60 }
61 }
62}
63
64impl Stream for VerifyingStream {
65 type Item = Result<bytes::Bytes, std::io::Error>;
66
67 fn poll_next(
68 self: std::pin::Pin<&mut Self>,
69 cx: &mut std::task::Context<'_>,
70 ) -> Poll<Option<Self::Item>> {
71 let this = self.get_mut();
72 match futures_util::ready!(this.stream.as_mut().poll_next(cx)) {
73 Some(Ok(bytes)) => {
74 this.layer_digester.update(&bytes);
75 if let Some((digester, _)) = this.header_digester.as_mut() {
76 digester.update(&bytes);
77 }
78 Poll::Ready(Some(Ok(bytes)))
79 }
80 Some(Err(e)) => Poll::Ready(Some(Err(e))),
81 None => {
82 match this.header_digester.as_mut() {
84 Some((digester, expected)) => {
85 let digest = digester.finalize();
87 if digest != *expected {
88 return Poll::Ready(Some(Err(std::io::Error::new(
89 std::io::ErrorKind::Other,
90 DigestError::VerificationError {
91 expected: expected.clone(),
92 actual: digest,
93 },
94 ))));
95 }
96 let digest = this.layer_digester.finalize();
97 if digest == this.expected_layer_digest {
98 Poll::Ready(None)
99 } else {
100 Poll::Ready(Some(Err(std::io::Error::new(
101 std::io::ErrorKind::Other,
102 DigestError::VerificationError {
103 expected: expected.clone(),
104 actual: digest,
105 },
106 ))))
107 }
108 }
109 None => {
110 let digest = this.layer_digester.finalize();
111 if digest == this.expected_layer_digest {
112 Poll::Ready(None)
113 } else {
114 Poll::Ready(Some(Err(std::io::Error::new(
115 std::io::ErrorKind::Other,
116 DigestError::VerificationError {
117 expected: this.expected_layer_digest.clone(),
118 actual: digest,
119 },
120 ))))
121 }
122 }
123 }
124 }
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use futures_util::TryStreamExt;
    use sha2::Digest as _;

    /// Payload streamed through the verifier in every scenario.
    const DATA: &[u8] = b"Hello, world!";

    /// Wraps a single-chunk stream of `DATA` in a `VerifyingStream` that expects
    /// `layer` as the layer digest and, optionally, `header` as the header digest.
    fn verifier_for(layer: &str, header: Option<&str>) -> VerifyingStream {
        let chunks: Vec<Result<Bytes, std::io::Error>> = vec![Ok(Bytes::from_static(DATA))];
        VerifyingStream::new(
            Box::pin(futures_util::stream::iter(chunks)),
            Digester::new(layer).unwrap(),
            layer.to_string(),
            header.map(|h| (Digester::new(h).unwrap(), h.to_string())),
        )
    }

    /// Asserts that `err` wraps a `DigestError::VerificationError`.
    fn assert_verification_error(err: std::io::Error) {
        let inner = err
            .into_inner()
            .expect("Should have inner error")
            .downcast::<DigestError>()
            .expect("Should be a DigestError");
        assert!(
            matches!(*inner, DigestError::VerificationError { .. }),
            "Error should be a verification error"
        );
    }

    #[tokio::test]
    async fn test_verifying_stream() {
        let good_layer = format!("sha256:{:x}", sha2::Sha256::digest(DATA));
        let bad_layer = "sha256:incorrect_hash";
        let good_header = format!("sha512:{:x}", sha2::Sha512::digest(DATA));
        let bad_header = "sha512:incorrect_hash";

        // Matching layer digest, no header digest: the stream drains cleanly.
        verifier_for(&good_layer, None)
            .try_collect::<Vec<_>>()
            .await
            .expect("Should not error with valid data");

        // Wrong layer digest, no header digest: verification fails at end of stream.
        let err = verifier_for(bad_layer, None)
            .try_collect::<Vec<_>>()
            .await
            .expect_err("Should error with invalid sha");
        assert_verification_error(err);

        // Matching layer and header digests: the stream drains cleanly.
        verifier_for(&good_layer, Some(&good_header))
            .try_collect::<Vec<_>>()
            .await
            .expect("Should not error with valid data");

        // Matching layer digest but wrong header digest: verification fails.
        let err = verifier_for(&good_layer, Some(bad_header))
            .try_collect::<Vec<_>>()
            .await
            .expect_err("Should error with invalid sha");
        assert_verification_error(err);
    }
}