use crate::digest::Digest as DigestRef;
use crate::error;
use crate::models::MediaType;
use crate::models::Platform;
use crate::progress::{NoopHandle, ProgressHandle, ProgressReporter};
use crate::uri::{Reference, Uri};
use bon::Builder;
use bytes::Bytes;
use futures::FutureExt;
use futures::future::BoxFuture;
use reqwest::Response;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use snafu::{ResultExt, ensure};
use std::cmp::min;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_util::io::StreamReader;
use url::Url;
/// 5 MiB. Lower clamp for the chunk size used by `Layer::copy`, and the number of
/// buffered bytes at which `Writer` considers its buffer ready to be sent.
const MIN_CHUNK_SIZE: usize = 5 * 1024 * 1024;
/// 100 MiB. Upper clamp for the chunk size used by `Layer::copy`, and the most
/// bytes a single `Writer::poll_write` call will accept.
const MAX_CHUNK_SIZE: usize = 100 * 1024 * 1024;
/// Describes a single blob: its media type, byte size, digest string and an
/// optional platform.
///
/// (De)serialized with camelCase keys; `platform` is left out of the serialized
/// form when it is `None`.
#[derive(Debug, Serialize, Deserialize, Clone, Builder)]
#[serde(rename_all = "camelCase")]
pub struct Layer {
/// Media type recorded for the blob.
#[builder(into)]
media_type: MediaType,
/// Size of the blob in bytes; `Layer::open` uses it as the expected length.
#[builder(into)]
size: usize,
/// Digest string used to address the blob in registry calls. The reader built
/// by `Layer::open` verifies content against it in `sha256:<lowercase hex>` form.
#[builder(into)]
digest: String,
/// Optional platform attached to this descriptor.
#[builder(into)]
#[serde(skip_serializing_if = "Option::is_none")]
platform: Option<Platform>,
}
impl Layer {
/// Copies exactly `size` bytes from `reader` to `writer` in bounded chunks.
///
/// The chunk size is `size / 40`, clamped to `MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE`.
/// A single scratch buffer is allocated up front and reused for every chunk, so
/// the transfer never allocates more than one chunk (and never more than `size`).
///
/// # Errors
///
/// Returns a `LayerRead` error when `reader` fails or ends before `size` bytes
/// were delivered, and a `LayerWrite` error when `writer` rejects a chunk.
pub async fn copy<'a, R, W>(
    reader: &'a mut R,
    writer: &'a mut W,
    size: usize,
) -> crate::Result<()>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    let chunk_size = (size / 40).clamp(MIN_CHUNK_SIZE, MAX_CHUNK_SIZE);
    // Allocate (and zero) the scratch buffer once instead of once per chunk; it
    // is capped at `size` so small payloads do not reserve a full 5 MiB.
    let mut buffer = vec![0u8; min(chunk_size, size)];
    let mut remaining = size;
    while remaining > 0 {
        let read_size = min(buffer.len(), remaining);
        reader
            .read_exact(&mut buffer[..read_size])
            .await
            .context(error::LayerReadSnafu)?;
        writer
            .write_all(&buffer[..read_size])
            .await
            .context(error::LayerWriteSnafu)?;
        remaining -= read_size;
    }
    Ok(())
}
/// Prepares a [`Writer`] for uploading a blob to the repository named by `uri`.
///
/// When `digest` is supplied the registry is first asked whether that blob
/// already exists; if it does, `Ok(None)` is returned and nothing needs to be
/// uploaded. Otherwise a writer in its initial state is returned. No upload
/// request is issued here.
///
/// `progress`, when given, is started with `size` as the total and a label
/// derived from the digest (if it parses); without it a no-op handle is used.
///
/// # Errors
///
/// Propagates any error from the registry's blob existence check.
pub async fn create(
    uri: &Uri,
    media_type: &MediaType,
    size: usize,
    digest: Option<String>,
    progress: Option<&dyn ProgressReporter>,
) -> crate::Result<Option<Writer>> {
    if let Some(digest) = digest.as_ref() {
        trace!(target: "layer", "checking if a blob already exists with the digest: {digest}");
        let already_present = uri
            .registry()
            .check_blob(uri.repository(), digest.as_str())
            .await?;
        if already_present {
            debug!(target: "layer", "blob already exists with the digest: {digest}");
            return Ok(None);
        }
    }

    let handle: Box<dyn ProgressHandle> = if let Some(reporter) = progress {
        let parsed = digest.as_ref().and_then(|d| DigestRef::parse(d).ok());
        let label = match parsed {
            Some(d) => format!("blob {} ->", d.short(9)),
            None => "blob ->".to_string(),
        };
        reporter.start(size as u64, &label)
    } else {
        Box::new(NoopHandle)
    };

    let writer = Writer {
        uri: uri.clone(),
        media_type: media_type.clone(),
        size,
        written: 0,
        digest: Sha256::new(),
        state: WriterState::Initial,
        buffer: Vec::new(),
        progress: handle,
    };
    Ok(Some(writer))
}
/// Opens this layer's blob in the repository named by `uri` for reading.
///
/// The returned [`Reader`] is configured with this layer's digest and size as
/// the expected values. A length of `0` reported by the registry skips the
/// up-front size comparison.
///
/// # Errors
///
/// Propagates blob fetch errors, and fails with a `LayerSizeMismatch` error
/// when the registry reports a non-zero length different from `self.size`.
pub async fn open(
    &self,
    uri: &Uri,
    progress: Option<&dyn ProgressReporter>,
) -> crate::Result<Reader> {
    let (stream, content_length) = uri
        .registry()
        .fetch_blob(uri.repository(), self.digest.as_str())
        .await?;

    let expected = self.size as u64;
    let length_acceptable = content_length == 0 || content_length == expected;
    if !length_acceptable {
        return error::LayerSizeMismatchSnafu {
            expected: expected as usize,
            actual: content_length as usize,
        }
        .fail();
    }

    let handle: Box<dyn ProgressHandle> = if let Some(reporter) = progress {
        let label = match DigestRef::parse(&self.digest) {
            Ok(d) => format!("blob {} <-", d.short(9)),
            Err(_) => "blob <-".to_string(),
        };
        reporter.start(expected, &label)
    } else {
        Box::new(NoopHandle)
    };

    let expected_digest = Some(self.digest.clone());
    let expected_size = Some(self.size);
    Ok(Reader::new(
        StreamReader::new(stream),
        handle,
        expected_digest,
        expected_size,
    ))
}
pub async fn open_uri(uri: &Uri) -> crate::Result<Reader> {
ensure!(
matches!(uri.reference(), Reference::Digest { .. }),
error::DirectLoadBlobSnafu { uri: uri.clone() }
);
let digest = uri.reference().to_string();
let (stream, _size) = uri
.registry()
.fetch_blob(uri.repository(), digest.as_str())
.await?;
Ok(Reader::new(
StreamReader::new(stream),
Box::new(NoopHandle),
None,
None,
))
}
pub fn media_type(&self) -> &MediaType {
&self.media_type
}
/// Returns the digest string of this layer.
pub fn digest(&self) -> &str {
    self.digest.as_str()
}
pub fn size(&self) -> usize {
self.size
}
pub fn platform(&self) -> Option<Platform> {
self.platform.clone()
}
/// Asks the registry behind `uri` to delete this layer's blob from the
/// URI's repository.
///
/// # Errors
///
/// Propagates whatever error the registry's delete call returns.
pub async fn delete(&self, uri: &Uri) -> crate::Result<()> {
    let digest = self.digest.as_str();
    uri.registry().delete_blob(uri.repository(), digest).await
}
}
/// `AsyncRead` wrapper that hashes and counts every byte passing through it,
/// reports progress, and verifies the optional expected size and digest once
/// the inner stream stops producing data.
pub struct Reader {
// Underlying byte source.
inner: Pin<Box<dyn AsyncRead + Send + Sync>>,
// Progress sink: incremented per read, finished when the reader is dropped.
progress: Box<dyn ProgressHandle>,
// Digest string the content must match at end of stream; `None` skips the check.
expected_digest: Option<String>,
// Exact byte count the stream must deliver; `None` skips the check.
expected_size: Option<usize>,
// Running SHA-256 over all bytes read so far.
hasher: Sha256,
// Total number of bytes read so far.
bytes_read: usize,
// Set once the end-of-stream checks have run, so they run at most once.
finalized: bool,
}
impl Reader {
pub fn new(
inner: impl AsyncRead + Send + Sync + 'static,
progress: Box<dyn ProgressHandle>,
expected_digest: Option<String>,
expected_size: Option<usize>,
) -> Self {
Self {
inner: Box::pin(inner),
progress,
expected_digest,
expected_size,
hasher: Sha256::new(),
bytes_read: 0,
finalized: false,
}
}
}
impl Drop for Reader {
    /// Marks the progress handle as finished when the reader goes away,
    /// whether or not the stream was read to the end.
    fn drop(&mut self) {
        let Self { progress, .. } = self;
        progress.finish();
    }
}
impl AsyncRead for Reader {
    /// Reads from the wrapped stream, hashing and counting the bytes delivered.
    ///
    /// * Every non-empty read updates the SHA-256 state, the byte counter and
    ///   the progress handle. Exceeding `expected_size` fails immediately.
    /// * The first read that delivers no bytes into a buffer that still had
    ///   room is treated as end of stream: the byte count is compared with
    ///   `expected_size` and the digest with `expected_digest`. These checks
    ///   run at most once.
    /// * A caller-supplied buffer with no remaining capacity returns
    ///   `Ready(Ok(()))` without touching the stream. An unchanged fill level
    ///   would otherwise be indistinguishable from end of stream and would
    ///   trigger a spurious "short layer read" or digest error.
    ///
    /// Verification failures are reported as `std::io::Error::other`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();

        // Zero-capacity read: nothing can be delivered, and it says nothing
        // about whether the stream has ended.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        let before = buf.filled().len();
        // Pending and I/O errors from the inner stream pass straight through.
        std::task::ready!(this.inner.as_mut().poll_read(cx, buf))?;
        let after = buf.filled().len();
        let delta = after - before;

        if delta > 0 {
            this.hasher.update(&buf.filled()[before..after]);
            this.bytes_read += delta;
            if let Some(expected) = this.expected_size
                && this.bytes_read > expected
            {
                return Poll::Ready(Err(std::io::Error::other(format!(
                    "registry returned more bytes than declared (expected {expected}, got {})",
                    this.bytes_read
                ))));
            }
            this.progress.inc(delta as u64);
            return Poll::Ready(Ok(()));
        }

        // End of stream. Run the final checks only once; later polls just
        // report EOF again.
        if this.finalized {
            return Poll::Ready(Ok(()));
        }
        this.finalized = true;

        if let Some(expected) = this.expected_size
            && this.bytes_read != expected
        {
            return Poll::Ready(Err(std::io::Error::other(format!(
                "short layer read (expected {expected}, got {})",
                this.bytes_read
            ))));
        }

        if let Some(expected) = this.expected_digest.as_ref() {
            // Clone so the hasher can be finalized through a `&mut` borrow.
            let computed = format!(
                "sha256:{}",
                base16::encode_lower(this.hasher.clone().finalize().as_slice())
            );
            if computed != *expected {
                return Poll::Ready(Err(std::io::Error::other(format!(
                    "layer digest mismatch: expected {expected}, computed {computed}"
                ))));
            }
        }

        Poll::Ready(Ok(()))
    }
}
/// `AsyncWrite` sink that buffers written bytes and uploads them to a registry,
/// either as one monolithic POST or as a chunked upload session, while keeping
/// a running SHA-256 of everything handed to an upload request.
pub struct Writer {
// Target registry/repository.
uri: Uri,
// Media type copied into the `Layer` returned by `layer()`.
media_type: MediaType,
// Declared blob size given to `Layer::create`; once the buffer holds at least
// this many bytes (and it is non-zero) `poll_write` switches to a monolithic POST.
size: usize,
// Bytes the registry has accepted so far (advanced only after a successful response).
written: usize,
// Running SHA-256 over every chunk that has been handed to an upload request.
digest: Sha256,
// Current position in the upload state machine.
state: WriterState,
// Bytes accepted by `poll_write` but not yet handed to an upload request.
buffer: Vec<u8>,
// Progress sink, advanced as chunks are acknowledged.
progress: Box<dyn ProgressHandle>,
}
/// Upload state machine driven by the `AsyncWrite` implementation of `Writer`.
enum WriterState {
// No request has been issued yet; data is only being buffered.
Initial,
// The request that opens an upload session (`start_upload`) is in flight.
Starting(BoxFuture<'static, crate::Result<Response>>),
// An upload session is open and no request is in flight; `upload_url` is the
// location the next chunk or the finishing request is sent to.
Idle { upload_url: Url },
// A chunk upload (`upload_part`) is in flight; `chunk_len` is added to
// `written` once the registry accepts it.
Uploading {
fut: BoxFuture<'static, crate::Result<Response>>,
chunk_len: usize,
},
// The final request (`finish_blob_upload` or a monolithic `post_blob`) is in
// flight; `pending_advance` is the number of bytes it carries.
Finishing {
fut: BoxFuture<'static, crate::Result<Response>>,
pending_advance: usize,
},
// The upload completed (or shutdown found nothing to send). Also used as a
// transient placeholder while `launch_patch`/`launch_finish` swap states.
Done,
// A request failed; the message is replayed on every subsequent poll.
Failed(String),
}
impl Writer {
pub async fn layer(&mut self) -> crate::Result<Layer> {
if !matches!(self.state, WriterState::Done) {
return Err(error::Error::LayerWrite {
source: std::io::Error::other("writer.layer() called before shutdown"),
});
}
let digest_bytes = self.digest.clone().finalize();
let digest = format!("sha256:{}", base16::encode_lower(&digest_bytes));
self.progress.finish();
Ok(Layer {
media_type: self.media_type.clone(),
digest,
size: self.written,
platform: None,
})
}
/// Moves the writer into the `Failed` state with the rendered message of `e`
/// and returns an `io::Error` carrying the same message.
fn fail<E: std::fmt::Display>(&mut self, e: E) -> std::io::Error {
    let reason = e.to_string();
    let io_error = std::io::Error::other(reason.clone());
    self.state = WriterState::Failed(reason);
    io_error
}
/// Reports whether at least `MIN_CHUNK_SIZE` bytes are waiting in the buffer.
fn buffer_ready_to_flush(&self) -> bool {
    let buffered = self.buffer.len();
    MIN_CHUNK_SIZE <= buffered
}
/// Polls the in-flight chunk upload.
///
/// On a successful response the next upload location is extracted, `written`
/// and the progress handle advance by the chunk length, and the writer
/// returns to `Idle`. A transport error, a non-success status or a missing
/// location puts the writer into `Failed` and returns the error.
///
/// Must only be called while the state is `Uploading`.
fn poll_uploading(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
    let WriterState::Uploading { fut, chunk_len } = &mut self.state else {
        unreachable!("poll_uploading called outside Uploading state");
    };
    let sent = *chunk_len;

    let response = match std::task::ready!(fut.poll_unpin(cx)) {
        Ok(response) => response,
        Err(e) => return Poll::Ready(Err(self.fail(e))),
    };

    let status = response.status();
    if !status.is_success() {
        let err = self.fail(format!("registry rejected chunk: {status}"));
        return Poll::Ready(Err(err));
    }

    match crate::client::extract_location(&response, &self.current_url()) {
        Ok(upload_url) => {
            // Only count the chunk once the registry has acknowledged it.
            self.written += sent;
            self.progress.inc(sent as u64);
            self.state = WriterState::Idle { upload_url };
            Poll::Ready(Ok(()))
        }
        Err(e) => Poll::Ready(Err(self.fail(e))),
    }
}
/// Polls the in-flight finishing request.
///
/// On a successful response `written` and the progress handle advance by the
/// bytes carried in that request, the writer becomes `Done` and the progress
/// handle is finished. A transport error or non-success status puts the
/// writer into `Failed` and returns the error.
///
/// Must only be called while the state is `Finishing`.
fn poll_finishing(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
    let WriterState::Finishing {
        fut,
        pending_advance,
    } = &mut self.state
    else {
        unreachable!("poll_finishing called outside Finishing state");
    };
    let advance = *pending_advance;

    let response = match std::task::ready!(fut.poll_unpin(cx)) {
        Ok(response) => response,
        Err(e) => return Poll::Ready(Err(self.fail(e))),
    };

    let status = response.status();
    if !status.is_success() {
        let err = self.fail(format!("registry rejected blob finalize: {status}"));
        return Poll::Ready(Err(err));
    }

    self.written += advance;
    self.progress.inc(advance as u64);
    self.state = WriterState::Done;
    self.progress.finish();
    Poll::Ready(Ok(()))
}
/// Polls the in-flight request that opens an upload session.
///
/// On a successful response the upload location is extracted and the writer
/// becomes `Idle`. A transport error, a non-success status or a missing
/// location puts the writer into `Failed` and returns the error.
///
/// Must only be called while the state is `Starting`.
fn poll_starting(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
    let WriterState::Starting(fut) = &mut self.state else {
        unreachable!("poll_starting called outside Starting state");
    };

    let response = match std::task::ready!(fut.poll_unpin(cx)) {
        Ok(response) => response,
        Err(e) => return Poll::Ready(Err(self.fail(e))),
    };

    let status = response.status();
    if !status.is_success() {
        let err = self.fail(format!("registry rejected upload start: {status}"));
        return Poll::Ready(Err(err));
    }

    match crate::client::extract_location(&response, &self.current_url()) {
        Ok(upload_url) => {
            self.state = WriterState::Idle { upload_url };
            Poll::Ready(Ok(()))
        }
        Err(e) => Poll::Ready(Err(self.fail(e))),
    }
}
/// Returns the registry URL passed to `extract_location` alongside each
/// response, falling back to a fixed placeholder URL when the registry URL
/// cannot be produced.
fn current_url(&self) -> Url {
    match self.uri.registry().url() {
        Ok(url) => url,
        Err(_) => Url::parse("http://invalid.local/").expect("invariant"),
    }
}
/// Sends the whole buffer as the next chunk of the open upload session.
///
/// The buffer is drained, folded into the running digest and handed to
/// `upload_part` with the byte range `written..written + chunk_len`; the
/// writer then moves to `Uploading`.
///
/// Must only be called while the state is `Idle`.
fn launch_patch(&mut self) {
    // `Done` is only a placeholder so the `Idle` payload can be moved out; the
    // real state is assigned at the end of this function.
    let WriterState::Idle { upload_url } =
        std::mem::replace(&mut self.state, WriterState::Done)
    else {
        unreachable!("launch_patch called outside Idle state");
    };
    let chunk = std::mem::take(&mut self.buffer);
    let chunk_len = chunk.len();
    self.digest.update(&chunk);
    let bytes = Bytes::from(chunk);
    let start = self.written;
    let end = start + chunk_len;
    let client = self.uri.registry().client.clone();
    // `upload_url` is already owned and not needed afterwards, so it is moved
    // into the future directly instead of being cloned first.
    let fut = async move { client.upload_part(upload_url, bytes, start, end).await }.boxed();
    self.state = WriterState::Uploading { fut, chunk_len };
}
/// Closes the open upload session, sending whatever is left in the buffer
/// together with the final digest.
///
/// The remaining bytes (possibly none) are folded into the running digest,
/// the digest is rendered as `sha256:<hex>`, and `finish_blob_upload` is
/// started; the writer then moves to `Finishing`.
///
/// Must only be called while the state is `Idle`.
fn launch_finish(&mut self) {
    // `Done` is a placeholder while the `Idle` payload is moved out.
    let previous = std::mem::replace(&mut self.state, WriterState::Done);
    let WriterState::Idle { upload_url } = previous else {
        unreachable!("launch_finish called outside Idle state");
    };

    let tail = std::mem::take(&mut self.buffer);
    let pending_advance = tail.len();
    if !tail.is_empty() {
        self.digest.update(&tail);
    }
    let hash = self.digest.clone().finalize();
    let digest = format!("sha256:{}", base16::encode_lower(&hash));

    let start = self.written;
    let end = start + pending_advance;
    let body = Bytes::from(tail);
    let client = self.uri.registry().client.clone();
    let fut = async move {
        client
            .finish_blob_upload(upload_url, body, digest, start, end)
            .await
    }
    .boxed();

    self.state = WriterState::Finishing {
        fut,
        pending_advance,
    };
}
/// Uploads the entire buffer in a single `post_blob` request.
///
/// The buffer is drained and hashed, the digest is rendered as
/// `sha256:<hex>`, and the writer moves to `Finishing`. If the registry URL
/// cannot be produced the writer moves to `Failed` instead and no request is
/// started; the failure surfaces on the next poll.
fn launch_monolithic_post(&mut self) {
    let payload = std::mem::take(&mut self.buffer);
    let pending_advance = payload.len();
    self.digest.update(&payload);
    let hash = self.digest.clone().finalize();
    let digest = format!("sha256:{}", base16::encode_lower(&hash));
    let body = Bytes::from(payload);

    let registry_url = match self.uri.registry().url() {
        Err(e) => {
            self.state = WriterState::Failed(e.to_string());
            return;
        }
        Ok(url) => url,
    };
    let repository = self.uri.repository().clone();
    let client = self.uri.registry().client.clone();

    let fut = async move {
        client
            .post_blob(registry_url, repository, body, digest)
            .await
    }
    .boxed();
    self.state = WriterState::Finishing {
        fut,
        pending_advance,
    };
}
/// Starts the request that opens a chunked upload session and moves the
/// writer to `Starting`.
///
/// If the registry URL cannot be produced the writer moves to `Failed`
/// instead and no request is started.
fn launch_start(&mut self) {
    let registry_url = match self.uri.registry().url() {
        Err(e) => {
            self.state = WriterState::Failed(e.to_string());
            return;
        }
        Ok(url) => url,
    };
    let client = self.uri.registry().client.clone();
    let repository = self.uri.repository().clone();
    let request = async move { client.start_upload(registry_url, repository).await }.boxed();
    self.state = WriterState::Starting(request);
}
}
impl AsyncWrite for Writer {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
loop {
match &this.state {
WriterState::Failed(reason) => {
return Poll::Ready(Err(std::io::Error::other(reason.clone())));
}
WriterState::Done => {
return Poll::Ready(Err(std::io::Error::other(
"writer is closed; further writes are rejected",
)));
}
WriterState::Starting(_) => match this.poll_starting(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Uploading { .. } => match this.poll_uploading(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Finishing { .. } => match this.poll_finishing(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Initial | WriterState::Idle { .. } => break,
}
}
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
let to_take = buf.len().min(MAX_CHUNK_SIZE);
this.buffer.extend_from_slice(&buf[..to_take]);
match this.state {
WriterState::Initial => {
if this.buffer.len() >= this.size && this.size > 0 {
this.launch_monolithic_post();
} else if this.buffer_ready_to_flush() {
this.launch_start();
}
}
WriterState::Idle { .. } => {
if this.buffer_ready_to_flush() {
this.launch_patch();
}
}
_ => unreachable!("only Initial/Idle reachable post-drive"),
}
Poll::Ready(Ok(to_take))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
loop {
match &this.state {
WriterState::Failed(reason) => {
return Poll::Ready(Err(std::io::Error::other(reason.clone())));
}
WriterState::Done => return Poll::Ready(Ok(())),
WriterState::Starting(_) => match this.poll_starting(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Uploading { .. } => match this.poll_uploading(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Finishing { .. } => match this.poll_finishing(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Initial | WriterState::Idle { .. } => return Poll::Ready(Ok(())),
}
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
loop {
match &this.state {
WriterState::Failed(reason) => {
return Poll::Ready(Err(std::io::Error::other(reason.clone())));
}
WriterState::Done => return Poll::Ready(Ok(())),
WriterState::Starting(_) => match this.poll_starting(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Uploading { .. } => match this.poll_uploading(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Finishing { .. } => match this.poll_finishing(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriterState::Initial => {
if !this.buffer.is_empty() {
this.launch_monolithic_post();
continue;
}
this.state = WriterState::Done;
return Poll::Ready(Ok(()));
}
WriterState::Idle { .. } => {
this.launch_finish();
continue;
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// In-memory source that yields at most seven bytes per poll, forcing
    /// callers to cope with short reads.
    struct ShortReader {
        data: Vec<u8>,
        pos: usize,
    }

    impl AsyncRead for ShortReader {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let this = self.get_mut();
            let start = this.pos;
            let left = this.data.len() - start;
            let take = left.min(7).min(buf.remaining());
            let end = start + take;
            buf.put_slice(&this.data[start..end]);
            this.pos = end;
            Poll::Ready(Ok(()))
        }
    }

    /// Builds a `Reader` over an empty stream that expects zero bytes and the
    /// given digest.
    fn empty_reader(expected_digest: &str) -> Reader {
        Reader::new(
            tokio::io::empty(),
            Box::new(NoopHandle),
            Some(expected_digest.to_string()),
            Some(0),
        )
    }

    #[tokio::test]
    async fn copy_handles_short_reads() {
        let payload = vec![0xABu8; 31];
        let mut src = ShortReader {
            data: payload.clone(),
            pos: 0,
        };
        let mut dst: Vec<u8> = Vec::new();
        Layer::copy(&mut src, &mut dst, payload.len())
            .await
            .unwrap();
        assert_eq!(dst, payload);
    }

    #[tokio::test]
    async fn reader_finalizes_digest_on_eof_when_present() {
        // SHA-256 of the empty input.
        let empty_digest =
            "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        let mut reader = empty_reader(empty_digest);
        let mut out = Vec::new();
        reader.read_to_end(&mut out).await.unwrap();
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn reader_detects_digest_mismatch() {
        let bad_digest = "sha256:0000000000000000000000000000000000000000000000000000000000000000";
        let mut reader = empty_reader(bad_digest);
        let mut out = Vec::new();
        let err = reader.read_to_end(&mut out).await.unwrap_err();
        assert!(err.to_string().contains("digest mismatch"));
    }
}
}