oci_client/
blob.rs

1//! Helpers for interacting with blobs and their verification
2use std::task::Poll;
3
4use futures_util::stream::{BoxStream, Stream};
5use futures_util::TryStreamExt;
6
7use crate::digest::Digester;
8use crate::errors::DigestError;
9
10/// Stream response of a blob with optional content length if available
11pub struct SizedStream {
12    /// The length of the stream if the upstream registry sent a `Content-Length` header
13    pub content_length: Option<u64>,
14    /// The digest header value if the upstream registry sent a `Digest` header. This should be used
15    /// (in addition to the layer digest) for validation when using partial requests as the library
16    /// can't validate against the full response.
17    pub digest_header_value: Option<String>,
18    /// The stream of bytes
19    pub stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
20}
21
22impl Stream for SizedStream {
23    type Item = Result<bytes::Bytes, std::io::Error>;
24
25    fn poll_next(
26        mut self: std::pin::Pin<&mut Self>,
27        cx: &mut std::task::Context<'_>,
28    ) -> Poll<Option<Self::Item>> {
29        self.stream.try_poll_next_unpin(cx)
30    }
31}
32
33/// The response of a partial blob request
34pub enum BlobResponse {
35    /// The response is a full blob (for example when partial requests aren't supported)
36    Full(SizedStream),
37    /// The response is a partial blob as requested
38    Partial(SizedStream),
39}
40
41pub(crate) struct VerifyingStream {
42    stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
43    layer_digester: Digester,
44    expected_layer_digest: String,
45    header_digester: Option<(Digester, String)>,
46}
47
48impl VerifyingStream {
49    pub fn new(
50        stream: BoxStream<'static, Result<bytes::Bytes, std::io::Error>>,
51        layer_digester: Digester,
52        expected_layer_digest: String,
53        header_digester_and_digest: Option<(Digester, String)>,
54    ) -> Self {
55        Self {
56            stream,
57            layer_digester,
58            expected_layer_digest,
59            header_digester: header_digester_and_digest,
60        }
61    }
62}
63
64impl Stream for VerifyingStream {
65    type Item = Result<bytes::Bytes, std::io::Error>;
66
67    fn poll_next(
68        self: std::pin::Pin<&mut Self>,
69        cx: &mut std::task::Context<'_>,
70    ) -> Poll<Option<Self::Item>> {
71        let this = self.get_mut();
72        match futures_util::ready!(this.stream.as_mut().poll_next(cx)) {
73            Some(Ok(bytes)) => {
74                this.layer_digester.update(&bytes);
75                if let Some((digester, _)) = this.header_digester.as_mut() {
76                    digester.update(&bytes);
77                }
78                Poll::Ready(Some(Ok(bytes)))
79            }
80            Some(Err(e)) => Poll::Ready(Some(Err(e))),
81            None => {
82                // Now that we've reached the end of the stream, verify the digest(s)
83                match this.header_digester.as_mut() {
84                    Some((digester, expected)) => {
85                        // Check the header digester and then the layer digester before returning
86                        let digest = digester.finalize();
87                        if digest != *expected {
88                            return Poll::Ready(Some(Err(std::io::Error::new(
89                                std::io::ErrorKind::Other,
90                                DigestError::VerificationError {
91                                    expected: expected.clone(),
92                                    actual: digest,
93                                },
94                            ))));
95                        }
96                        let digest = this.layer_digester.finalize();
97                        if digest == this.expected_layer_digest {
98                            Poll::Ready(None)
99                        } else {
100                            Poll::Ready(Some(Err(std::io::Error::new(
101                                std::io::ErrorKind::Other,
102                                DigestError::VerificationError {
103                                    expected: expected.clone(),
104                                    actual: digest,
105                                },
106                            ))))
107                        }
108                    }
109                    None => {
110                        let digest = this.layer_digester.finalize();
111                        if digest == this.expected_layer_digest {
112                            Poll::Ready(None)
113                        } else {
114                            Poll::Ready(Some(Err(std::io::Error::new(
115                                std::io::ErrorKind::Other,
116                                DigestError::VerificationError {
117                                    expected: this.expected_layer_digest.clone(),
118                                    actual: digest,
119                                },
120                            ))))
121                        }
122                    }
123                }
124            }
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use futures_util::TryStreamExt;
    use sha2::Digest as _;

    /// Payload shared by every case below.
    const DATA: &[u8] = b"Hello, world!";

    /// Wraps `data` in a stream that yields it as a single chunk.
    fn single_chunk(
        data: &'static [u8],
    ) -> futures_util::stream::BoxStream<'static, Result<Bytes, std::io::Error>> {
        Box::pin(futures_util::stream::iter(vec![Ok(Bytes::from_static(data))]))
    }

    /// Drains `stream`, expecting it to complete without error.
    async fn expect_ok(stream: VerifyingStream) {
        stream
            .try_collect::<Vec<_>>()
            .await
            .expect("Should not error with valid data");
    }

    /// Drains `stream`, expecting it to fail with a `DigestError::VerificationError`.
    async fn expect_verification_error(stream: VerifyingStream) {
        let err = stream
            .try_collect::<Vec<_>>()
            .await
            .expect_err("Should error with invalid sha");

        let err = err
            .into_inner()
            .expect("Should have inner error")
            .downcast::<DigestError>()
            .expect("Should be a DigestError");
        assert!(
            matches!(*err, DigestError::VerificationError { .. }),
            "Error should be a verification error"
        );
    }

    #[tokio::test]
    async fn test_verifying_stream() {
        let correct_sha = format!("sha256:{:x}", sha2::Sha256::digest(DATA));
        let correct_header_sha = format!("sha512:{:x}", sha2::Sha512::digest(DATA));
        let incorrect_sha = "sha256:incorrect_hash";
        let incorrect_header_sha = "sha512:incorrect_hash";

        // Correct layer digest, no header digest
        expect_ok(VerifyingStream::new(
            single_chunk(DATA),
            Digester::new(&correct_sha).unwrap(),
            correct_sha.clone(),
            None,
        ))
        .await;

        // Incorrect layer digest, no header digest
        expect_verification_error(VerifyingStream::new(
            single_chunk(DATA),
            Digester::new(incorrect_sha).unwrap(),
            incorrect_sha.to_string(),
            None,
        ))
        .await;

        // Correct layer digest and correct header digest
        expect_ok(VerifyingStream::new(
            single_chunk(DATA),
            Digester::new(&correct_sha).unwrap(),
            correct_sha.clone(),
            Some((
                Digester::new(&correct_header_sha).unwrap(),
                correct_header_sha.clone(),
            )),
        ))
        .await;

        // Correct layer digest, incorrect header digest
        expect_verification_error(VerifyingStream::new(
            single_chunk(DATA),
            Digester::new(&correct_sha).unwrap(),
            correct_sha.clone(),
            Some((
                Digester::new(incorrect_header_sha).unwrap(),
                incorrect_header_sha.to_string(),
            )),
        ))
        .await;

        // Incorrect layer digest, correct header digest: the layer check must still fail even
        // though the header check passes.
        expect_verification_error(VerifyingStream::new(
            single_chunk(DATA),
            Digester::new(incorrect_sha).unwrap(),
            incorrect_sha.to_string(),
            Some((
                Digester::new(&correct_header_sha).unwrap(),
                correct_header_sha.clone(),
            )),
        ))
        .await;
    }
}