use crate::s3::signer::{ChunkSigningContext, sign_chunk, sign_trailer};
use crate::s3::utils::{ChecksumAlgorithm, b64_encode, sha256_hash};
use bytes::Bytes;
use crc_fast::{CrcAlgorithm, Digest as CrcFastDigest};
use futures_util::Stream;
#[cfg(feature = "ring")]
use ring::digest::{Context, SHA256};
use sha1::{Digest as Sha1Digest, Sha1};
#[cfg(not(feature = "ring"))]
use sha2::Sha256;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context as TaskContext, Poll};
/// Default chunk size (64 KiB) used when re-chunking bodies for aws-chunked encoding.
const DEFAULT_CHUNK_SIZE: usize = 64 * 1024;
/// Incremental checksum state for one `ChecksumAlgorithm`.
///
/// The SHA-256 variant is backed by `ring` when the `ring` feature is
/// enabled and by `sha2` otherwise; both expose the same update/finalize
/// flow to the encoders below.
enum StreamingHasher {
    Crc32(CrcFastDigest),
    Crc32c(CrcFastDigest),
    Crc64nvme(CrcFastDigest),
    Sha1(Sha1),
    #[cfg(feature = "ring")]
    Sha256(Context),
    #[cfg(not(feature = "ring"))]
    Sha256(Sha256),
}
impl StreamingHasher {
fn new(algorithm: ChecksumAlgorithm) -> Self {
match algorithm {
ChecksumAlgorithm::CRC32 => {
StreamingHasher::Crc32(CrcFastDigest::new(CrcAlgorithm::Crc32IsoHdlc))
}
ChecksumAlgorithm::CRC32C => {
StreamingHasher::Crc32c(CrcFastDigest::new(CrcAlgorithm::Crc32Iscsi))
}
ChecksumAlgorithm::CRC64NVME => {
StreamingHasher::Crc64nvme(CrcFastDigest::new(CrcAlgorithm::Crc64Nvme))
}
ChecksumAlgorithm::SHA1 => StreamingHasher::Sha1(Sha1::new()),
#[cfg(feature = "ring")]
ChecksumAlgorithm::SHA256 => StreamingHasher::Sha256(Context::new(&SHA256)),
#[cfg(not(feature = "ring"))]
ChecksumAlgorithm::SHA256 => StreamingHasher::Sha256(Sha256::new()),
}
}
fn update(&mut self, data: &[u8]) {
match self {
StreamingHasher::Crc32(d) => d.update(data),
StreamingHasher::Crc32c(d) => d.update(data),
StreamingHasher::Crc64nvme(d) => d.update(data),
StreamingHasher::Sha1(h) => h.update(data),
#[cfg(feature = "ring")]
StreamingHasher::Sha256(ctx) => ctx.update(data),
#[cfg(not(feature = "ring"))]
StreamingHasher::Sha256(h) => h.update(data),
}
}
fn finalize(self) -> String {
match self {
StreamingHasher::Crc32(d) => b64_encode((d.finalize() as u32).to_be_bytes()),
StreamingHasher::Crc32c(d) => b64_encode((d.finalize() as u32).to_be_bytes()),
StreamingHasher::Crc64nvme(d) => b64_encode(d.finalize().to_be_bytes()),
StreamingHasher::Sha1(h) => {
let result = h.finalize();
b64_encode(&result[..])
}
#[cfg(feature = "ring")]
StreamingHasher::Sha256(ctx) => b64_encode(ctx.finish().as_ref()),
#[cfg(not(feature = "ring"))]
StreamingHasher::Sha256(h) => {
let result = h.finalize();
b64_encode(&result[..])
}
}
}
}
/// State machine for `AwsChunkedEncoder`: body chunks, then the
/// terminating zero-length chunk, then the checksum trailer, then EOF.
#[derive(Clone, Copy)]
enum EncoderState {
    Streaming,
    FinalChunk,
    Trailer,
    Done,
}
/// Wraps a byte stream in unsigned `aws-chunked` transfer encoding,
/// appending a checksum trailer computed over the raw (unframed) body.
pub struct AwsChunkedEncoder<S> {
    inner: S,
    algorithm: ChecksumAlgorithm,
    // `Some` until the trailer is emitted; taken exactly once in `Trailer`.
    hasher: Option<StreamingHasher>,
    state: EncoderState,
}
impl<S> AwsChunkedEncoder<S> {
    /// Create an encoder over `inner` whose trailing checksum is computed
    /// with `algorithm`.
    pub fn new(inner: S, algorithm: ChecksumAlgorithm) -> Self {
        let hasher = StreamingHasher::new(algorithm);
        Self {
            inner,
            algorithm,
            hasher: Some(hasher),
            state: EncoderState::Streaming,
        }
    }
}
impl<S, E> Stream for AwsChunkedEncoder<S>
where
S: Stream<Item = Result<Bytes, E>> + Unpin,
{
type Item = Result<Bytes, E>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
loop {
match self.state {
EncoderState::Streaming => {
let inner = Pin::new(&mut self.inner);
match inner.poll_next(cx) {
Poll::Ready(Some(Ok(chunk))) => {
if chunk.is_empty() {
continue;
}
if let Some(ref mut hasher) = self.hasher {
hasher.update(&chunk);
}
let chunk_len = chunk.len();
let chunk_header = format!("{chunk_len:x}\r\n");
let mut output =
Vec::with_capacity(chunk_header.len() + chunk.len() + 2);
output.extend_from_slice(chunk_header.as_bytes());
output.extend_from_slice(&chunk);
output.extend_from_slice(b"\r\n");
return Poll::Ready(Some(Ok(Bytes::from(output))));
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e)));
}
Poll::Ready(None) => {
self.state = EncoderState::FinalChunk;
}
Poll::Pending => {
return Poll::Pending;
}
}
}
EncoderState::FinalChunk => {
self.state = EncoderState::Trailer;
return Poll::Ready(Some(Ok(Bytes::from_static(b"0\r\n"))));
}
EncoderState::Trailer => {
let hasher = self.hasher.take().expect("hasher should exist");
let checksum_value = hasher.finalize();
let trailer = format!(
"{}:{}\r\n\r\n",
self.algorithm.header_name(),
checksum_value
);
self.state = EncoderState::Done;
return Poll::Ready(Some(Ok(Bytes::from(trailer))));
}
EncoderState::Done => {
return Poll::Ready(None);
}
}
}
}
}
/// Compute the exact on-the-wire size of a `content_length`-byte body
/// after unsigned `aws-chunked` encoding with a trailing checksum.
///
/// Assumes the body is framed as `content_length / chunk_size` chunks of
/// exactly `chunk_size` bytes plus one partial chunk for any remainder —
/// i.e. the body stream is fed through `RechunkingStream` with the same
/// `chunk_size` before encoding.
///
/// # Panics
///
/// Panics (division by zero) if `chunk_size` is 0.
pub fn calculate_encoded_length(
    content_length: u64,
    chunk_size: usize,
    algorithm: ChecksumAlgorithm,
) -> u64 {
    debug_assert!(chunk_size > 0, "chunk_size must be non-zero");
    let chunk_size = chunk_size as u64;
    let full_chunks = content_length / chunk_size;
    let last_chunk_size = content_length % chunk_size;

    // Per-chunk framing: "<hex len>\r\n<data>\r\n".
    let hex_len_full = format!("{chunk_size:x}").len() as u64;
    let full_chunk_overhead = full_chunks * (hex_len_full + 2 + chunk_size + 2);
    let partial_chunk_overhead = if last_chunk_size > 0 {
        let hex_len_partial = format!("{last_chunk_size:x}").len() as u64;
        hex_len_partial + 2 + last_chunk_size + 2
    } else {
        0
    };

    // Terminating zero-length chunk: "0\r\n".
    let final_chunk = 3;

    // Trailer: "<header>:<b64 checksum>\r\n\r\n" (':' is 1, CRLFCRLF is 4).
    let trailer_header_len = algorithm.header_name().len() as u64;
    // Base64 length of the fixed-size digest for each algorithm:
    // 4 bytes -> 8 chars, 8 -> 12, 20 -> 28, 32 -> 44.
    let checksum_b64_len = match algorithm {
        ChecksumAlgorithm::CRC32 | ChecksumAlgorithm::CRC32C => 8,
        ChecksumAlgorithm::CRC64NVME => 12,
        ChecksumAlgorithm::SHA1 => 28,
        ChecksumAlgorithm::SHA256 => 44,
    };
    let trailer_len = trailer_header_len + 1 + checksum_b64_len + 4;

    full_chunk_overhead + partial_chunk_overhead + final_chunk + trailer_len
}
/// Default chunk size (64 KiB) used when none is specified.
///
/// `const fn` so the value can also be used in constant contexts;
/// existing callers are unaffected.
pub const fn default_chunk_size() -> usize {
    DEFAULT_CHUNK_SIZE
}
/// State machine for `SignedAwsChunkedEncoder`: like `EncoderState`, with
/// an extra step that emits the `x-amz-trailer-signature` line.
#[derive(Clone, Copy)]
enum SignedEncoderState {
    Streaming,
    FinalChunk,
    Trailer,
    TrailerSignature,
    Done,
}
/// Wraps a byte stream in SigV4 streaming `aws-chunked` encoding: every
/// chunk frame carries a `chunk-signature`, and the body ends with a
/// checksum trailer plus an `x-amz-trailer-signature` line.
pub struct SignedAwsChunkedEncoder<S> {
    inner: S,
    algorithm: ChecksumAlgorithm,
    // `Some` until the trailer is emitted; taken exactly once in `Trailer`.
    hasher: Option<StreamingHasher>,
    state: SignedEncoderState,
    // Derived SigV4 signing key, provided via `ChunkSigningContext`.
    signing_key: Arc<[u8]>,
    // Request timestamp, e.g. "20130524T000000Z".
    date_time: String,
    // Credential scope, e.g. "20130524/us-east-1/s3/aws4_request".
    scope: String,
    // Signature chain: starts at the seed (request) signature and is
    // replaced by each emitted chunk's signature.
    current_signature: String,
    // Checksum produced in `Trailer`, consumed in `TrailerSignature`.
    checksum_value: Option<String>,
}
impl<S> SignedAwsChunkedEncoder<S> {
    /// Create a signed encoder over `inner`, seeding the signature chain
    /// from `context`.
    pub fn new(inner: S, algorithm: ChecksumAlgorithm, context: ChunkSigningContext) -> Self {
        let hasher = StreamingHasher::new(algorithm);
        Self {
            inner,
            algorithm,
            hasher: Some(hasher),
            state: SignedEncoderState::Streaming,
            signing_key: context.signing_key,
            date_time: context.date_time,
            scope: context.scope,
            current_signature: context.seed_signature,
            checksum_value: None,
        }
    }

    /// Sign one chunk's SHA-256 hash, roll the signature chain forward,
    /// and return the new signature.
    fn sign_chunk_data(&mut self, chunk_hash: &str) -> String {
        let next = sign_chunk(
            &self.signing_key,
            &self.date_time,
            &self.scope,
            &self.current_signature,
            chunk_hash,
        );
        self.current_signature.clone_from(&next);
        next
    }
}
impl<S, E> Stream for SignedAwsChunkedEncoder<S>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
{
    type Item = Result<Bytes, E>;
    /// Drives the SigV4 streaming state machine: signed data frames ->
    /// signed zero-length chunk -> checksum trailer -> trailer signature
    /// -> EOF.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match self.state {
                SignedEncoderState::Streaming => {
                    let inner = Pin::new(&mut self.inner);
                    match inner.poll_next(cx) {
                        Poll::Ready(Some(Ok(chunk))) => {
                            // A zero-length frame would read as the encoding
                            // terminator, so skip empty upstream chunks.
                            if chunk.is_empty() {
                                continue;
                            }
                            if let Some(ref mut hasher) = self.hasher {
                                hasher.update(&chunk);
                            }
                            // Each chunk is signed over its own SHA-256,
                            // chaining off the previous signature.
                            let chunk_hash = sha256_hash(&chunk);
                            let signature = self.sign_chunk_data(&chunk_hash);
                            let chunk_len = chunk.len();
                            let chunk_header =
                                format!("{chunk_len:x};chunk-signature={signature}\r\n");
                            let mut output =
                                Vec::with_capacity(chunk_header.len() + chunk.len() + 2);
                            output.extend_from_slice(chunk_header.as_bytes());
                            output.extend_from_slice(&chunk);
                            output.extend_from_slice(b"\r\n");
                            return Poll::Ready(Some(Ok(Bytes::from(output))));
                        }
                        Poll::Ready(Some(Err(e))) => {
                            return Poll::Ready(Some(Err(e)));
                        }
                        Poll::Ready(None) => {
                            self.state = SignedEncoderState::FinalChunk;
                        }
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    }
                }
                SignedEncoderState::FinalChunk => {
                    // The terminating chunk carries no payload, so it is
                    // signed over the SHA-256 of the empty string.
                    const EMPTY_SHA256: &str =
                        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
                    let signature = self.sign_chunk_data(EMPTY_SHA256);
                    let final_chunk = format!("0;chunk-signature={}\r\n", signature);
                    self.state = SignedEncoderState::Trailer;
                    return Poll::Ready(Some(Ok(Bytes::from(final_chunk))));
                }
                SignedEncoderState::Trailer => {
                    // Emit the checksum trailer header (lowercase name, CRLF
                    // terminated); keep the value for the signature step.
                    let hasher = self.hasher.take().expect("hasher should exist");
                    let checksum_value = hasher.finalize();
                    self.checksum_value = Some(checksum_value.clone());
                    let trailer = format!(
                        "{}:{}\r\n",
                        self.algorithm.header_name().to_lowercase(),
                        checksum_value
                    );
                    self.state = SignedEncoderState::TrailerSignature;
                    return Poll::Ready(Some(Ok(Bytes::from(trailer))));
                }
                SignedEncoderState::TrailerSignature => {
                    let checksum_value =
                        self.checksum_value.as_ref().expect("checksum should exist");
                    // The canonical trailer string signed here uses a bare
                    // `\n`, unlike the CRLF written to the wire above.
                    let canonical_trailers = format!(
                        "{}:{}\n", self.algorithm.header_name().to_lowercase(),
                        checksum_value
                    );
                    let trailers_hash = sha256_hash(canonical_trailers.as_bytes());
                    let trailer_signature = sign_trailer(
                        &self.signing_key,
                        &self.date_time,
                        &self.scope,
                        &self.current_signature,
                        &trailers_hash,
                    );
                    // Final bytes of the body: signature line + blank line.
                    let trailer_sig_line =
                        format!("x-amz-trailer-signature:{}\r\n\r\n", trailer_signature);
                    self.state = SignedEncoderState::Done;
                    return Poll::Ready(Some(Ok(Bytes::from(trailer_sig_line))));
                }
                SignedEncoderState::Done => {
                    return Poll::Ready(None);
                }
            }
        }
    }
}
/// Compute the exact on-the-wire size of a `content_length`-byte body
/// after SigV4 *signed* `aws-chunked` encoding with a trailing checksum
/// (see `SignedAwsChunkedEncoder` for the frame layout).
///
/// Assumes the body is framed as chunks of exactly `chunk_size` bytes plus
/// one partial chunk for any remainder — i.e. the body stream is fed
/// through `RechunkingStream` with the same `chunk_size` before encoding.
///
/// # Panics
///
/// Panics (division by zero) if `chunk_size` is 0.
pub fn calculate_signed_encoded_length(
    content_length: u64,
    chunk_size: usize,
    algorithm: ChecksumAlgorithm,
) -> u64 {
    debug_assert!(chunk_size > 0, "chunk_size must be non-zero");
    let chunk_size = chunk_size as u64;
    let full_chunks = content_length / chunk_size;
    let last_chunk_size = content_length % chunk_size;

    // ";chunk-signature=" (17 bytes) + 64 hex signature chars = 81.
    let signature_overhead: u64 = 81;

    // Per-chunk framing: "<hex len>;chunk-signature=<sig>\r\n<data>\r\n".
    let hex_len_full = format!("{chunk_size:x}").len() as u64;
    let full_chunk_overhead =
        full_chunks * (hex_len_full + signature_overhead + 2 + chunk_size + 2);
    let partial_chunk_overhead = if last_chunk_size > 0 {
        let hex_len_partial = format!("{last_chunk_size:x}").len() as u64;
        hex_len_partial + signature_overhead + 2 + last_chunk_size + 2
    } else {
        0
    };

    // Terminating chunk: "0" + ";chunk-signature=" + 64 hex + "\r\n" = 84.
    let final_chunk = 84;

    // Checksum trailer: "<lowercase header>:<b64 checksum>\r\n".
    let trailer_header_len = algorithm.header_name().to_lowercase().len() as u64;
    // Base64 length of the fixed-size digest for each algorithm:
    // 4 bytes -> 8 chars, 8 -> 12, 20 -> 28, 32 -> 44.
    let checksum_b64_len = match algorithm {
        ChecksumAlgorithm::CRC32 | ChecksumAlgorithm::CRC32C => 8,
        ChecksumAlgorithm::CRC64NVME => 12,
        ChecksumAlgorithm::SHA1 => 28,
        ChecksumAlgorithm::SHA256 => 44,
    };
    let checksum_trailer = trailer_header_len + 1 + checksum_b64_len + 2;

    // "x-amz-trailer-signature:" (24) + 64 hex + "\r\n\r\n" (4) = 92.
    let trailer_signature = 92;

    full_chunk_overhead
        + partial_chunk_overhead
        + final_chunk
        + checksum_trailer
        + trailer_signature
}
/// Re-buffers an upstream byte stream so that every emitted chunk is
/// exactly `chunk_size` bytes, except possibly the last.
pub struct RechunkingStream<S> {
    inner: S,
    chunk_size: usize,
    // Bytes received but not yet emitted; drained in `chunk_size` pieces.
    buffer: Vec<u8>,
    // Set once `inner` returns `None`; any remainder is then flushed.
    done: bool,
}
impl<S> RechunkingStream<S> {
    /// Wrap `inner` so downstream sees chunks of exactly `chunk_size`
    /// bytes (except possibly the final, partial chunk).
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0 — a zero chunk size would otherwise
    /// make `poll_next` yield empty chunks forever (draining `..0` from
    /// the buffer always succeeds).
    pub fn new(inner: S, chunk_size: usize) -> Self {
        assert!(chunk_size > 0, "chunk_size must be non-zero");
        Self {
            inner,
            chunk_size,
            buffer: Vec::with_capacity(chunk_size),
            done: false,
        }
    }

    /// Convenience constructor using `DEFAULT_CHUNK_SIZE` (64 KiB).
    pub fn with_default_chunk_size(inner: S) -> Self {
        Self::new(inner, DEFAULT_CHUNK_SIZE)
    }
}
impl<S, E> Stream for RechunkingStream<S>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
{
    type Item = Result<Bytes, E>;

    /// Accumulates upstream bytes and releases them in `chunk_size`
    /// pieces; whatever remains when upstream ends is flushed as one
    /// final (possibly short) chunk.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
        // Split exactly `n` bytes off the front of `buf`.
        fn take_front(buf: &mut Vec<u8>, n: usize) -> Bytes {
            let piece: Vec<u8> = buf.drain(..n).collect();
            Bytes::from(piece)
        }

        let target = self.chunk_size;
        if self.done && self.buffer.is_empty() {
            return Poll::Ready(None);
        }
        // Serve leftovers from a previous poll before touching upstream.
        if self.buffer.len() >= target {
            let piece = take_front(&mut self.buffer, target);
            return Poll::Ready(Some(Ok(piece)));
        }
        loop {
            if self.done {
                if self.buffer.is_empty() {
                    return Poll::Ready(None);
                }
                // Upstream exhausted: flush the partial tail in one piece.
                let tail = std::mem::take(&mut self.buffer);
                return Poll::Ready(Some(Ok(Bytes::from(tail))));
            }
            match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(data))) => {
                    if data.is_empty() {
                        continue;
                    }
                    self.buffer.extend_from_slice(&data);
                    if self.buffer.len() >= target {
                        let piece = take_front(&mut self.buffer, target);
                        return Poll::Ready(Some(Ok(piece)));
                    }
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(None) => self.done = true,
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// "Hello, World!" is 13 (0xd) bytes: expect one data frame, the
    /// terminating zero chunk, and a CRC32 checksum trailer ending in a
    /// blank line.
    #[tokio::test]
    async fn test_aws_chunked_encoder_simple() {
        let data = Bytes::from("Hello, World!");
        let stream = futures_util::stream::iter(vec![Ok::<_, std::io::Error>(data.clone())]);
        let mut encoder = AwsChunkedEncoder::new(stream, ChecksumAlgorithm::CRC32);
        let mut output = Vec::new();
        while let Some(chunk) = encoder.next().await {
            output.extend_from_slice(&chunk.unwrap());
        }
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.starts_with("d\r\n"));
        assert!(output_str.contains("Hello, World!"));
        assert!(output_str.contains("X-Amz-Checksum-CRC32:"));
        assert!(output_str.ends_with("\r\n\r\n"));
        assert!(output_str.contains("\r\n0\r\n"));
    }

    /// Two upstream chunks ("Hello, " = 7, "World!" = 6 bytes) each get
    /// their own frame; the checksum covers the concatenated body.
    #[tokio::test]
    async fn test_aws_chunked_encoder_multiple_chunks() {
        let chunks = vec![
            Ok::<_, std::io::Error>(Bytes::from("Hello, ")),
            Ok(Bytes::from("World!")),
        ];
        let stream = futures_util::stream::iter(chunks);
        let mut encoder = AwsChunkedEncoder::new(stream, ChecksumAlgorithm::CRC64NVME);
        let mut output = Vec::new();
        while let Some(chunk) = encoder.next().await {
            output.extend_from_slice(&chunk.unwrap());
        }
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.starts_with("7\r\n"));
        assert!(output_str.contains("6\r\n"));
        assert!(output_str.contains("X-Amz-Checksum-CRC64NVME:"));
    }

    /// Sanity check: framing overhead makes the encoded length exceed the
    /// raw content length.
    #[test]
    fn test_calculate_encoded_length() {
        let len = calculate_encoded_length(100, 64 * 1024, ChecksumAlgorithm::CRC32);
        assert!(len > 100);
    }

    /// Fixed signing inputs for deterministic signatures (the key, date,
    /// scope, and seed signature match the AWS SigV4 streaming-upload
    /// documentation example).
    fn test_signing_context() -> ChunkSigningContext {
        ChunkSigningContext {
            signing_key: Arc::from(vec![
                0x98, 0xf1, 0xd8, 0x89, 0xfe, 0xc4, 0xf4, 0x42, 0x1a, 0xdc, 0x52, 0x2b, 0xab, 0x0c,
                0xe1, 0xf8, 0x2c, 0x6c, 0x4e, 0x4e, 0xc3, 0x9a, 0xe1, 0xf6, 0xcc, 0xf2, 0x0e, 0x8f,
                0x40, 0x89, 0x45, 0x65,
            ]),
            date_time: "20130524T000000Z".to_string(),
            scope: "20130524/us-east-1/s3/aws4_request".to_string(),
            seed_signature: "4f232c4386841ef735655705268965c44a0e4690baa4adea153f7db9fa80a0a9"
                .to_string(),
        }
    }

    /// Signed variant: data frame, signed zero chunk, lowercase checksum
    /// trailer, then the trailer-signature line ending in a blank line.
    #[tokio::test]
    async fn test_signed_encoder_simple() {
        let data = Bytes::from("Hello, World!");
        let stream = futures_util::stream::iter(vec![Ok::<_, std::io::Error>(data)]);
        let context = test_signing_context();
        let mut encoder = SignedAwsChunkedEncoder::new(stream, ChecksumAlgorithm::CRC32, context);
        let mut output = Vec::new();
        while let Some(chunk) = encoder.next().await {
            output.extend_from_slice(&chunk.unwrap());
        }
        let output_str = String::from_utf8(output).unwrap();
        assert!(output_str.starts_with("d;chunk-signature="));
        assert!(output_str.contains("Hello, World!"));
        assert!(output_str.contains("0;chunk-signature="));
        assert!(output_str.contains("x-amz-checksum-crc32:"));
        assert!(output_str.contains("x-amz-trailer-signature:"));
        assert!(output_str.ends_with("\r\n\r\n"));
    }

    /// Two data frames + the zero chunk = exactly three chunk signatures.
    #[tokio::test]
    async fn test_signed_encoder_multiple_chunks() {
        let chunks = vec![
            Ok::<_, std::io::Error>(Bytes::from("Hello, ")),
            Ok(Bytes::from("World!")),
        ];
        let stream = futures_util::stream::iter(chunks);
        let context = test_signing_context();
        let mut encoder = SignedAwsChunkedEncoder::new(stream, ChecksumAlgorithm::CRC32C, context);
        let mut output = Vec::new();
        while let Some(chunk) = encoder.next().await {
            output.extend_from_slice(&chunk.unwrap());
        }
        let output_str = String::from_utf8(output).unwrap();
        let sig_count = output_str.matches(";chunk-signature=").count();
        assert_eq!(sig_count, 3);
        assert!(output_str.contains("x-amz-checksum-crc32c:"));
        assert!(output_str.contains("x-amz-trailer-signature:"));
    }

    /// Every chunk signature and the trailer signature must be 64
    /// lowercase-hex characters (SHA-256 HMAC output).
    #[tokio::test]
    async fn test_signed_encoder_signature_is_64_hex_chars() {
        let data = Bytes::from("test");
        let stream = futures_util::stream::iter(vec![Ok::<_, std::io::Error>(data)]);
        let context = test_signing_context();
        let mut encoder = SignedAwsChunkedEncoder::new(stream, ChecksumAlgorithm::CRC32, context);
        let mut output = Vec::new();
        while let Some(chunk) = encoder.next().await {
            output.extend_from_slice(&chunk.unwrap());
        }
        let output_str = String::from_utf8(output).unwrap();
        for sig_match in output_str.match_indices(";chunk-signature=") {
            let start = sig_match.0 + sig_match.1.len();
            let sig = &output_str[start..start + 64];
            assert!(
                sig.chars().all(|c| c.is_ascii_hexdigit()),
                "Signature should be hex: {}",
                sig
            );
        }
        // 24 = len("x-amz-trailer-signature:").
        let trailer_sig_start = output_str.find("x-amz-trailer-signature:").unwrap() + 24;
        let trailer_sig = &output_str[trailer_sig_start..trailer_sig_start + 64];
        assert!(
            trailer_sig.chars().all(|c| c.is_ascii_hexdigit()),
            "Trailer signature should be hex: {}",
            trailer_sig
        );
    }

    /// Signed encoding adds per-chunk signatures and a trailer signature,
    /// so it must always be longer than the unsigned encoding.
    #[test]
    fn test_calculate_signed_encoded_length() {
        let len = calculate_signed_encoded_length(100, 64 * 1024, ChecksumAlgorithm::CRC32);
        let unsigned_len = calculate_encoded_length(100, 64 * 1024, ChecksumAlgorithm::CRC32);
        assert!(
            len > unsigned_len,
            "Signed length {} should be > unsigned length {}",
            len,
            unsigned_len
        );
    }

    /// 200 KiB at a 64 KiB chunk size exercises the multi-full-chunk plus
    /// partial-chunk path of the calculator.
    #[test]
    fn test_calculate_signed_encoded_length_multiple_chunks() {
        let content_len = 200 * 1024;
        let chunk_size = 64 * 1024;
        let len =
            calculate_signed_encoded_length(content_len, chunk_size, ChecksumAlgorithm::SHA256);
        assert!(len > content_len);
    }

    /// 10 x 1 KiB input re-chunked at 4 KiB -> 4096, 4096, 2048 bytes.
    #[tokio::test]
    async fn test_rechunking_stream_combines_small_chunks() {
        let chunk_size = 1024;
        let num_chunks = 10;
        let chunks: Vec<Result<Bytes, std::io::Error>> = (0..num_chunks)
            .map(|i| Ok(Bytes::from(vec![i as u8; chunk_size])))
            .collect();
        let stream = futures_util::stream::iter(chunks);
        let mut rechunker = RechunkingStream::new(stream, 4096);
        let mut output_chunks = Vec::new();
        while let Some(chunk) = rechunker.next().await {
            output_chunks.push(chunk.unwrap());
        }
        assert_eq!(output_chunks.len(), 3);
        assert_eq!(output_chunks[0].len(), 4096);
        assert_eq!(output_chunks[1].len(), 4096);
        assert_eq!(output_chunks[2].len(), 2048);
        let total: usize = output_chunks.iter().map(|c| c.len()).sum();
        assert_eq!(total, num_chunks * chunk_size);
    }

    /// One 10000-byte input split at 4 KiB -> 4096, 4096, 1808 bytes.
    #[tokio::test]
    async fn test_rechunking_stream_passes_large_chunks() {
        let data = Bytes::from(vec![42u8; 10000]);
        let stream = futures_util::stream::iter(vec![Ok::<_, std::io::Error>(data)]);
        let mut rechunker = RechunkingStream::new(stream, 4096);
        let mut output_chunks = Vec::new();
        while let Some(chunk) = rechunker.next().await {
            output_chunks.push(chunk.unwrap());
        }
        assert_eq!(output_chunks.len(), 3);
        assert_eq!(output_chunks[0].len(), 4096);
        assert_eq!(output_chunks[1].len(), 4096);
        assert_eq!(output_chunks[2].len(), 1808);
    }

    /// An exact multiple of the chunk size produces no partial tail chunk.
    #[tokio::test]
    async fn test_rechunking_stream_exact_multiple() {
        let data = Bytes::from(vec![1u8; 8192]);
        let stream = futures_util::stream::iter(vec![Ok::<_, std::io::Error>(data)]);
        let mut rechunker = RechunkingStream::new(stream, 4096);
        let mut output_chunks = Vec::new();
        while let Some(chunk) = rechunker.next().await {
            output_chunks.push(chunk.unwrap());
        }
        assert_eq!(output_chunks.len(), 2);
        assert_eq!(output_chunks[0].len(), 4096);
        assert_eq!(output_chunks[1].len(), 4096);
    }

    /// An empty upstream yields no chunks at all (not even an empty one).
    #[tokio::test]
    async fn test_rechunking_stream_empty() {
        let stream = futures_util::stream::iter(Vec::<Result<Bytes, std::io::Error>>::new());
        let mut rechunker = RechunkingStream::new(stream, 4096);
        let result = rechunker.next().await;
        assert!(result.is_none());
    }

    /// Re-chunking must be byte-transparent: concatenated output equals
    /// the original input regardless of input/output chunk boundaries.
    #[tokio::test]
    async fn test_rechunking_stream_preserves_data() {
        let original: Vec<u8> = (0..=255).cycle().take(15000).collect();
        let chunks: Vec<Result<Bytes, std::io::Error>> = original
            .chunks(100)
            .map(|c| Ok(Bytes::copy_from_slice(c)))
            .collect();
        let stream = futures_util::stream::iter(chunks);
        let mut rechunker = RechunkingStream::new(stream, 4096);
        let mut output = Vec::new();
        while let Some(chunk) = rechunker.next().await {
            output.extend_from_slice(&chunk.unwrap());
        }
        assert_eq!(output, original);
    }
}