Skip to main content

oci_client/
blob.rs

1//! Helpers for interacting with blobs and their verification
2use std::task::Poll;
3
4use futures_util::stream::{BoxStream, Stream};
5use futures_util::TryStreamExt;
6
7use crate::digest::Digester;
8use crate::errors::DigestError;
9
10/// Stream response of a blob with optional content length if available
11pub struct SizedStream {
12    /// The length of the stream if the upstream registry sent a `Content-Length` header
13    pub content_length: Option<u64>,
14    /// The digest header value if the upstream registry sent a `Digest` header. This should be used
15    /// (in addition to the layer digest) for validation when using partial requests as the library
16    /// can't validate against the full response.
17    pub digest_header_value: Option<String>,
18    /// The stream of bytes
19    pub stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
20}
21
22impl Stream for SizedStream {
23    type Item = Result<bytes::Bytes, std::io::Error>;
24
25    fn poll_next(
26        mut self: std::pin::Pin<&mut Self>,
27        cx: &mut std::task::Context<'_>,
28    ) -> Poll<Option<Self::Item>> {
29        self.stream.try_poll_next_unpin(cx)
30    }
31}
32
33/// The response of a partial blob request
34pub enum BlobResponse {
35    /// The response is a full blob (for example when partial requests aren't supported)
36    Full(SizedStream),
37    /// The response is a partial blob as requested
38    Partial(SizedStream),
39}
40
41pub(crate) struct VerifyingStream {
42    stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
43    layer_digester: Digester,
44    expected_layer_digest: String,
45    header_digester: Option<(Digester, String)>,
46}
47
48impl VerifyingStream {
49    pub fn new(
50        stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
51        layer_digester: Digester,
52        expected_layer_digest: String,
53        header_digester_and_digest: Option<(Digester, String)>,
54    ) -> Self {
55        Self {
56            stream,
57            layer_digester,
58            expected_layer_digest,
59            header_digester: header_digester_and_digest,
60        }
61    }
62}
63
64impl Stream for VerifyingStream {
65    type Item = Result<bytes::Bytes, std::io::Error>;
66
67    fn poll_next(
68        self: std::pin::Pin<&mut Self>,
69        cx: &mut std::task::Context<'_>,
70    ) -> Poll<Option<Self::Item>> {
71        let this = self.get_mut();
72        match futures_util::ready!(this.stream.as_mut().poll_next(cx)) {
73            Some(Ok(bytes)) => {
74                this.layer_digester.update(&bytes);
75                if let Some((digester, _)) = this.header_digester.as_mut() {
76                    digester.update(&bytes);
77                }
78                Poll::Ready(Some(Ok(bytes)))
79            }
80            Some(Err(e)) => Poll::Ready(Some(Err(e))),
81            None => {
82                // Now that we've reached the end of the stream, verify the digest(s)
83                match this.header_digester.as_mut() {
84                    Some((digester, expected)) => {
85                        // Check the header digester and then the layer digester before returning
86                        let digest = digester.finalize();
87                        if digest != *expected {
88                            return Poll::Ready(Some(Err(std::io::Error::other(
89                                DigestError::VerificationError {
90                                    expected: expected.clone(),
91                                    actual: digest,
92                                },
93                            ))));
94                        }
95                        let digest = this.layer_digester.finalize();
96                        if digest == this.expected_layer_digest {
97                            Poll::Ready(None)
98                        } else {
99                            Poll::Ready(Some(Err(std::io::Error::other(
100                                DigestError::VerificationError {
101                                    expected: expected.clone(),
102                                    actual: digest,
103                                },
104                            ))))
105                        }
106                    }
107                    None => {
108                        let digest = this.layer_digester.finalize();
109                        if digest == this.expected_layer_digest {
110                            Poll::Ready(None)
111                        } else {
112                            Poll::Ready(Some(Err(std::io::Error::other(
113                                DigestError::VerificationError {
114                                    expected: this.expected_layer_digest.clone(),
115                                    actual: digest,
116                                },
117                            ))))
118                        }
119                    }
120                }
121            }
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use futures_util::TryStreamExt;
    use sha2::Digest as _;

    /// Payload fed through every `VerifyingStream` under test.
    const DATA: &[u8] = b"Hello, world!";
    /// A syntactically valid but wrong `sha256:` digest.
    const INCORRECT_LAYER_SHA: &str = "sha256:incorrect_hash";
    /// A syntactically valid but wrong `sha512:` digest.
    const INCORRECT_HEADER_SHA: &str = "sha512:incorrect_hash";

    /// A single-chunk stream yielding `DATA`.
    fn data_stream() -> BoxStream<'static, Result<Bytes, std::io::Error>> {
        Box::pin(futures_util::stream::iter(vec![Ok(Bytes::from_static(
            DATA,
        ))]))
    }

    /// The correct `sha256:` digest of `DATA`, used as the layer digest.
    fn layer_sha() -> String {
        format!("sha256:{:x}", sha2::Sha256::digest(DATA))
    }

    /// The correct `sha512:` digest of `DATA`, used as the header digest.
    fn header_sha() -> String {
        format!("sha512:{:x}", sha2::Sha512::digest(DATA))
    }

    /// Drains `stream` and asserts that it fails with a `DigestError::VerificationError`.
    async fn assert_verification_error(stream: VerifyingStream) {
        let err = stream
            .try_collect::<Vec<_>>()
            .await
            .expect_err("Should error with invalid sha");

        let err = err
            .into_inner()
            .expect("Should have inner error")
            .downcast::<DigestError>()
            .expect("Should be a DigestError");
        assert!(
            matches!(*err, DigestError::VerificationError { .. }),
            "Error should be a verification error"
        );
    }

    #[tokio::test]
    async fn test_verifying_stream_correct_layer_sha() {
        let sha = layer_sha();
        let stream =
            VerifyingStream::new(data_stream(), Digester::new(&sha).unwrap(), sha.clone(), None);
        stream
            .try_collect::<Vec<_>>()
            .await
            .expect("Should not error with valid data");
    }

    #[tokio::test]
    async fn test_verifying_stream_incorrect_layer_sha() {
        let stream = VerifyingStream::new(
            data_stream(),
            Digester::new(INCORRECT_LAYER_SHA).unwrap(),
            INCORRECT_LAYER_SHA.to_string(),
            None,
        );
        assert_verification_error(stream).await;
    }

    #[tokio::test]
    async fn test_verifying_stream_correct_layer_and_header_sha() {
        let sha = layer_sha();
        let header = header_sha();
        let stream = VerifyingStream::new(
            data_stream(),
            Digester::new(&sha).unwrap(),
            sha.clone(),
            Some((Digester::new(&header).unwrap(), header.clone())),
        );
        stream
            .try_collect::<Vec<_>>()
            .await
            .expect("Should not error with valid data");
    }

    #[tokio::test]
    async fn test_verifying_stream_incorrect_header_sha() {
        let sha = layer_sha();
        let stream = VerifyingStream::new(
            data_stream(),
            Digester::new(&sha).unwrap(),
            sha.clone(),
            Some((
                Digester::new(INCORRECT_HEADER_SHA).unwrap(),
                INCORRECT_HEADER_SHA.to_string(),
            )),
        );
        assert_verification_error(stream).await;
    }

    /// A correct header digest must not mask a wrong layer digest.
    #[tokio::test]
    async fn test_verifying_stream_correct_header_incorrect_layer_sha() {
        let header = header_sha();
        let stream = VerifyingStream::new(
            data_stream(),
            Digester::new(INCORRECT_LAYER_SHA).unwrap(),
            INCORRECT_LAYER_SHA.to_string(),
            Some((Digester::new(&header).unwrap(), header.clone())),
        );
        assert_verification_error(stream).await;
    }
}