1use std::task::Poll;
3
4use futures_util::stream::{BoxStream, Stream};
5use futures_util::TryStreamExt;
6
7use crate::digest::Digester;
8use crate::errors::DigestError;
9
10pub struct SizedStream {
12 pub content_length: Option<u64>,
14 pub digest_header_value: Option<String>,
18 pub stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
20}
21
22impl Stream for SizedStream {
23 type Item = Result<bytes::Bytes, std::io::Error>;
24
25 fn poll_next(
26 mut self: std::pin::Pin<&mut Self>,
27 cx: &mut std::task::Context<'_>,
28 ) -> Poll<Option<Self::Item>> {
29 self.stream.try_poll_next_unpin(cx)
30 }
31}
32
/// A blob body wrapped in a [`SizedStream`], in one of two variants.
pub enum BlobResponse {
    /// The response carries the blob's content.
    /// NOTE(review): the name suggests the complete blob — confirm at construction sites.
    Full(SizedStream),
    /// The response carries part of the blob's content.
    /// NOTE(review): presumably a ranged / partial-content response — confirm at construction sites.
    Partial(SizedStream),
}
40
/// A stream wrapper that hashes every chunk it forwards and checks the
/// resulting digest(s) once the inner stream ends.
pub(crate) struct VerifyingStream {
    /// The inner stream whose chunks are forwarded to the caller.
    stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
    /// Digester updated with every successfully read chunk.
    layer_digester: Digester,
    /// Digest string that `layer_digester` must produce at end of stream.
    expected_layer_digest: String,
    /// Optional second digester, paired with the digest string it must produce.
    /// Also updated with every chunk and checked before the layer digest.
    /// NOTE(review): the name suggests the expected value comes from a response header — confirm at call sites.
    header_digester: Option<(Digester, String)>,
}
47
48impl VerifyingStream {
49 pub fn new(
50 stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
51 layer_digester: Digester,
52 expected_layer_digest: String,
53 header_digester_and_digest: Option<(Digester, String)>,
54 ) -> Self {
55 Self {
56 stream,
57 layer_digester,
58 expected_layer_digest,
59 header_digester: header_digester_and_digest,
60 }
61 }
62}
63
64impl Stream for VerifyingStream {
65 type Item = Result<bytes::Bytes, std::io::Error>;
66
67 fn poll_next(
68 self: std::pin::Pin<&mut Self>,
69 cx: &mut std::task::Context<'_>,
70 ) -> Poll<Option<Self::Item>> {
71 let this = self.get_mut();
72 match futures_util::ready!(this.stream.as_mut().poll_next(cx)) {
73 Some(Ok(bytes)) => {
74 this.layer_digester.update(&bytes);
75 if let Some((digester, _)) = this.header_digester.as_mut() {
76 digester.update(&bytes);
77 }
78 Poll::Ready(Some(Ok(bytes)))
79 }
80 Some(Err(e)) => Poll::Ready(Some(Err(e))),
81 None => {
82 match this.header_digester.as_mut() {
84 Some((digester, expected)) => {
85 let digest = digester.finalize();
87 if digest != *expected {
88 return Poll::Ready(Some(Err(std::io::Error::other(
89 DigestError::VerificationError {
90 expected: expected.clone(),
91 actual: digest,
92 },
93 ))));
94 }
95 let digest = this.layer_digester.finalize();
96 if digest == this.expected_layer_digest {
97 Poll::Ready(None)
98 } else {
99 Poll::Ready(Some(Err(std::io::Error::other(
100 DigestError::VerificationError {
101 expected: expected.clone(),
102 actual: digest,
103 },
104 ))))
105 }
106 }
107 None => {
108 let digest = this.layer_digester.finalize();
109 if digest == this.expected_layer_digest {
110 Poll::Ready(None)
111 } else {
112 Poll::Ready(Some(Err(std::io::Error::other(
113 DigestError::VerificationError {
114 expected: this.expected_layer_digest.clone(),
115 actual: digest,
116 },
117 ))))
118 }
119 }
120 }
121 }
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use futures_util::TryStreamExt;
    use sha2::Digest as _;

    /// Payload streamed by every test case.
    const DATA: &[u8] = b"Hello, world!";

    /// Returns the correct `sha256:` digest of [`DATA`].
    fn sha256_of_data() -> String {
        format!("sha256:{:x}", sha2::Sha256::digest(DATA))
    }

    /// Returns the correct `sha512:` digest of [`DATA`].
    fn sha512_of_data() -> String {
        format!("sha512:{:x}", sha2::Sha512::digest(DATA))
    }

    /// Builds a `VerifyingStream` over a single [`DATA`] chunk.
    ///
    /// `layer_digest` and `header_digest` (when given) are each used both to
    /// construct the digester and as the digest it is expected to produce.
    fn verifying_stream(layer_digest: &str, header_digest: Option<&str>) -> VerifyingStream {
        VerifyingStream::new(
            Box::pin(futures_util::stream::iter(vec![Ok(Bytes::from_static(
                DATA,
            ))])),
            Digester::new(layer_digest).unwrap(),
            layer_digest.to_string(),
            header_digest.map(|digest| (Digester::new(digest).unwrap(), digest.to_string())),
        )
    }

    /// Drains `stream`, asserting it succeeds and passes [`DATA`] through unchanged.
    async fn assert_verifies(stream: VerifyingStream) {
        let chunks = stream
            .try_collect::<Vec<_>>()
            .await
            .expect("Should not error with valid data");
        let received: Vec<u8> = chunks
            .iter()
            .flat_map(|chunk| chunk.iter().copied())
            .collect();
        assert_eq!(received, DATA, "Data should pass through unchanged");
    }

    /// Drains `stream`, asserting it fails with `DigestError::VerificationError`.
    async fn assert_verification_error(stream: VerifyingStream) {
        let err = stream
            .try_collect::<Vec<_>>()
            .await
            .expect_err("Should error with invalid sha");

        let err = err
            .into_inner()
            .expect("Should have inner error")
            .downcast::<DigestError>()
            .expect("Should be a DigestError");
        assert!(
            matches!(*err, DigestError::VerificationError { .. }),
            "Error should be a verification error"
        );
    }

    #[tokio::test]
    async fn accepts_matching_layer_digest() {
        let layer = sha256_of_data();
        assert_verifies(verifying_stream(&layer, None)).await;
    }

    #[tokio::test]
    async fn rejects_mismatched_layer_digest() {
        assert_verification_error(verifying_stream("sha256:incorrect_hash", None)).await;
    }

    #[tokio::test]
    async fn accepts_matching_layer_and_header_digests() {
        let layer = sha256_of_data();
        let header = sha512_of_data();
        assert_verifies(verifying_stream(&layer, Some(header.as_str()))).await;
    }

    #[tokio::test]
    async fn rejects_mismatched_header_digest() {
        let layer = sha256_of_data();
        assert_verification_error(verifying_stream(&layer, Some("sha512:incorrect_hash"))).await;
    }

    #[tokio::test]
    async fn rejects_mismatched_layer_digest_with_matching_header() {
        // The header check passes here, so the layer check alone must fail.
        let header = sha512_of_data();
        assert_verification_error(verifying_stream(
            "sha256:incorrect_hash",
            Some(header.as_str()),
        ))
        .await;
    }
}