use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering};
/// Default chunking threshold (512 KiB): payloads strictly larger than this are chunked.
const DEFAULT_THRESHOLD_BYTES: usize = 512 * 1024;
/// Default maximum size of one chunk (256 KiB); also used when a configured chunk size is `0`.
const DEFAULT_CHUNK_SIZE_BYTES: usize = 256 * 1024;
/// Byte prefix marking a stored value as chunk metadata (JSON follows the prefix).
const CHUNKED_SENTINEL: &[u8] = b"CELERS_CHUNKED:";
/// Settings controlling when and how large payloads are split into chunks.
///
/// Build one with [`ChunkingConfig::new`] or [`ChunkingConfig::disabled`] and
/// adjust it with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Master switch; when `false`, `ResultChunker::needs_chunking` always returns `false`.
    pub enabled: bool,
    /// Payloads strictly larger than this many bytes are chunked.
    pub threshold_bytes: usize,
    /// Maximum size of each chunk in bytes. A value of `0` makes
    /// `ResultChunker::split_chunks` fall back to the built-in default chunk size.
    pub chunk_size_bytes: usize,
    /// Whether a CRC32 of the whole payload is recorded when splitting and
    /// verified when reassembling.
    pub checksum_enabled: bool,
}
impl Default for ChunkingConfig {
fn default() -> Self {
Self {
enabled: true,
threshold_bytes: DEFAULT_THRESHOLD_BYTES,
chunk_size_bytes: DEFAULT_CHUNK_SIZE_BYTES,
checksum_enabled: true,
}
}
}
impl ChunkingConfig {
    /// Returns the default configuration (same as [`Default::default`]).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the default configuration with chunking switched off; all other
    /// fields keep their default values.
    pub fn disabled() -> Self {
        let mut config = Self::default();
        config.enabled = false;
        config
    }

    /// Builder: payloads strictly larger than `bytes` will be chunked.
    pub fn with_threshold(self, bytes: usize) -> Self {
        Self {
            threshold_bytes: bytes,
            ..self
        }
    }

    /// Builder: maximum chunk size in bytes (`0` means "use the default chunk size").
    pub fn with_chunk_size(self, bytes: usize) -> Self {
        Self {
            chunk_size_bytes: bytes,
            ..self
        }
    }

    /// Builder: turn CRC32 checksum recording/verification on or off.
    pub fn with_checksum(self, enabled: bool) -> Self {
        Self {
            checksum_enabled: enabled,
            ..self
        }
    }
}
/// Describes how a payload was split. It is produced by `ResultChunker::split_chunks`
/// and stored (JSON-encoded after the sentinel prefix) via `ResultChunker::create_sentinel`.
///
/// NOTE(review): when read back with `ResultChunker::parse_sentinel`, these values
/// come from storage and should be treated as untrusted until validated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    /// Number of chunks. `split_chunks` always produces at least one, because an
    /// empty payload becomes a single empty chunk.
    pub total_chunks: usize,
    /// Chunk size used when splitting; every chunk except possibly the last has this length.
    pub chunk_size: usize,
    /// Length in bytes of the original, unsplit payload.
    pub total_size: usize,
    /// CRC32 (via `crc32fast`) of the whole payload, or `None` when checksums are disabled.
    pub checksum: Option<u32>,
    /// Creation time in seconds since the Unix epoch, taken from `chrono::Utc::now()`.
    pub created_at: u64,
}
/// Counters describing chunking activity, safe to update through a shared reference.
///
/// Every update and read uses `Ordering::Relaxed`. Each counter is consistent on its
/// own, but reading several counters does not give a synchronized snapshot.
#[derive(Debug)]
pub struct ChunkingStats {
    /// Total number of chunks produced by splitting.
    chunks_stored: AtomicU64,
    /// Total number of chunks successfully reassembled.
    chunks_loaded: AtomicU64,
    /// Total payload bytes that were split into chunks.
    bytes_chunked: AtomicU64,
    /// Number of reassemblies rejected because of a checksum mismatch.
    checksum_failures: AtomicU64,
}

impl Default for ChunkingStats {
    fn default() -> Self {
        ChunkingStats::new()
    }
}

impl ChunkingStats {
    /// Creates a fresh set of counters, all starting at zero.
    pub fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            chunks_stored: zero(),
            chunks_loaded: zero(),
            bytes_chunked: zero(),
            checksum_failures: zero(),
        }
    }

    /// Records a split: adds `num_chunks` stored chunks and `total_bytes` chunked bytes.
    pub fn record_store(&self, num_chunks: u64, total_bytes: u64) {
        Self::add_relaxed(&self.chunks_stored, num_chunks);
        Self::add_relaxed(&self.bytes_chunked, total_bytes);
    }

    /// Records a successful reassembly of `num_chunks` chunks.
    pub fn record_load(&self, num_chunks: u64) {
        Self::add_relaxed(&self.chunks_loaded, num_chunks);
    }

    /// Records one checksum mismatch.
    pub fn record_checksum_failure(&self) {
        Self::add_relaxed(&self.checksum_failures, 1);
    }

    /// Current number of stored chunks.
    pub fn chunks_stored(&self) -> u64 {
        Self::load_relaxed(&self.chunks_stored)
    }

    /// Current number of loaded chunks.
    pub fn chunks_loaded(&self) -> u64 {
        Self::load_relaxed(&self.chunks_loaded)
    }

    /// Current number of chunked bytes.
    pub fn bytes_chunked(&self) -> u64 {
        Self::load_relaxed(&self.bytes_chunked)
    }

    /// Current number of checksum failures.
    pub fn checksum_failures(&self) -> u64 {
        Self::load_relaxed(&self.checksum_failures)
    }

    /// Sets every counter back to zero, one counter at a time.
    pub fn reset(&self) {
        let counters = [
            &self.chunks_stored,
            &self.chunks_loaded,
            &self.bytes_chunked,
            &self.checksum_failures,
        ];
        for counter in counters {
            counter.store(0, Ordering::Relaxed);
        }
    }

    /// Relaxed `fetch_add` shared by the `record_*` methods.
    fn add_relaxed(counter: &AtomicU64, amount: u64) {
        counter.fetch_add(amount, Ordering::Relaxed);
    }

    /// Relaxed load shared by the getter methods.
    fn load_relaxed(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }
}
impl Clone for ResultChunker {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
stats: ChunkingStats::new(),
}
}
}
/// Splits large byte payloads into chunks, reassembles and verifies them, and
/// encodes/decodes the sentinel value that marks a payload as chunked.
pub struct ResultChunker {
    /// Chunking settings supplied at construction.
    config: ChunkingConfig,
    /// Activity counters for this instance; a clone starts with its own zeroed counters.
    stats: ChunkingStats,
}
impl ResultChunker {
    /// Creates a chunker with the given configuration and zeroed statistics.
    pub fn new(config: ChunkingConfig) -> Self {
        Self {
            config,
            stats: ChunkingStats::new(),
        }
    }

    /// Returns `true` when chunking is enabled and `data` is strictly larger than
    /// the configured threshold. A payload exactly at the threshold is not chunked.
    pub fn needs_chunking(&self, data: &[u8]) -> bool {
        self.config.enabled && data.len() > self.config.threshold_bytes
    }

    /// Builds the sentinel value stored in place of a chunked payload. It is the
    /// `CHUNKED_SENTINEL` prefix followed by `metadata` encoded as JSON.
    ///
    /// Serializing `ChunkMetadata` (plain integers and an `Option<u32>`) is not
    /// expected to fail. If it ever did, only the bare prefix is returned, which
    /// [`ResultChunker::parse_sentinel`] then rejects.
    pub fn create_sentinel(&self, metadata: &ChunkMetadata) -> Vec<u8> {
        let mut sentinel = CHUNKED_SENTINEL.to_vec();
        if let Ok(json) = serde_json::to_vec(metadata) {
            sentinel.extend_from_slice(&json);
        }
        sentinel
    }

    /// Returns `true` if `data` begins with the chunked-payload sentinel prefix.
    pub fn is_chunked(data: &[u8]) -> bool {
        data.starts_with(CHUNKED_SENTINEL)
    }

    /// Parses a value produced by [`ResultChunker::create_sentinel`] back into metadata.
    ///
    /// The result reflects whatever bytes were stored, so treat it as untrusted.
    /// [`ResultChunker::reassemble_chunks`] validates it against the actual chunks.
    ///
    /// # Errors
    /// Returns an error if `data` lacks the sentinel prefix, or if the remainder
    /// is not valid JSON for `ChunkMetadata`.
    pub fn parse_sentinel(data: &[u8]) -> Result<ChunkMetadata, std::io::Error> {
        let Some(json) = data.strip_prefix(CHUNKED_SENTINEL) else {
            return Err(std::io::Error::other("not a chunked sentinel"));
        };
        serde_json::from_slice(json).map_err(|e| std::io::Error::other(e.to_string()))
    }

    /// Splits `data` into consecutive chunks of at most the configured chunk size.
    /// Returns them together with the metadata needed to reassemble them.
    ///
    /// - A configured chunk size of `0` falls back to `DEFAULT_CHUNK_SIZE_BYTES`.
    /// - Empty input yields a single empty chunk (`total_chunks == 1`).
    /// - When checksums are enabled, the CRC32 of the whole payload is recorded.
    ///
    /// The chunk count and byte count are added to the store statistics.
    pub fn split_chunks(&self, data: &[u8]) -> (ChunkMetadata, Vec<Vec<u8>>) {
        let chunk_size = if self.config.chunk_size_bytes == 0 {
            DEFAULT_CHUNK_SIZE_BYTES
        } else {
            self.config.chunk_size_bytes
        };
        let total_chunks = if data.is_empty() {
            1
        } else {
            data.len().div_ceil(chunk_size)
        };
        let checksum = if self.config.checksum_enabled {
            Some(crc32fast::hash(data))
        } else {
            None
        };
        // `timestamp()` is a signed i64. A clock set before the Unix epoch would
        // wrap to an enormous value under `as u64`, so clamp it to 0 instead.
        let created_at = u64::try_from(chrono::Utc::now().timestamp()).unwrap_or(0);
        let metadata = ChunkMetadata {
            total_chunks,
            chunk_size,
            total_size: data.len(),
            checksum,
            created_at,
        };
        let chunks: Vec<Vec<u8>> = if data.is_empty() {
            vec![Vec::new()]
        } else {
            data.chunks(chunk_size).map(|c| c.to_vec()).collect()
        };
        self.stats
            .record_store(total_chunks as u64, data.len() as u64);
        (metadata, chunks)
    }

    /// Concatenates `chunks` back into the original payload and verifies the
    /// result against `metadata`. `metadata.chunk_size` is not checked.
    ///
    /// On success the chunk count is added to the load statistics.
    ///
    /// # Errors
    /// Returns an error, without recording a load, when any of these hold:
    /// - the number of chunks differs from `metadata.total_chunks`;
    /// - their combined length differs from `metadata.total_size`;
    /// - a checksum is present and the CRC32 of the result does not match it.
    ///   This case also increments the checksum-failure counter.
    pub fn reassemble_chunks(
        &self,
        metadata: &ChunkMetadata,
        chunks: &[Vec<u8>],
    ) -> Result<Vec<u8>, std::io::Error> {
        if chunks.len() != metadata.total_chunks {
            return Err(std::io::Error::other(format!(
                "chunk count mismatch: expected {}, got {}",
                metadata.total_chunks,
                chunks.len()
            )));
        }
        // Validate the size against the real chunk lengths *before* allocating.
        // `metadata` usually comes from stored bytes, so a corrupt `total_size`
        // passed to `Vec::with_capacity` could panic (capacity overflow) or
        // trigger a huge allocation.
        let actual_size: usize = chunks.iter().map(Vec::len).sum();
        if actual_size != metadata.total_size {
            return Err(std::io::Error::other(format!(
                "reassembled size mismatch: expected {}, got {}",
                metadata.total_size, actual_size
            )));
        }
        // `concat` allocates exactly `actual_size` bytes once.
        let result = chunks.concat();
        if let Some(expected) = metadata.checksum {
            let actual = crc32fast::hash(&result);
            if actual != expected {
                self.stats.record_checksum_failure();
                return Err(std::io::Error::other(format!(
                    "checksum mismatch: expected {expected}, got {actual}"
                )));
            }
        }
        self.stats.record_load(metadata.total_chunks as u64);
        Ok(result)
    }

    /// Storage keys for each chunk: `{base_key}:chunk:{i}` for `i` in `0..total_chunks`.
    ///
    /// NOTE(review): this allocates `total_chunks` strings. If the count comes from a
    /// parsed sentinel, consider bounding it first.
    pub fn chunk_keys(base_key: &str, total_chunks: usize) -> Vec<String> {
        (0..total_chunks)
            .map(|i| format!("{base_key}:chunk:{i}"))
            .collect()
    }

    /// Storage key for the chunk metadata: `{base_key}:chunks`.
    pub fn metadata_key(base_key: &str) -> String {
        format!("{base_key}:chunks")
    }

    /// The configuration this chunker was created with.
    pub fn config(&self) -> &ChunkingConfig {
        &self.config
    }

    /// The activity counters for this chunker.
    pub fn stats(&self) -> &ChunkingStats {
        &self.stats
    }
}
#[cfg(test)]
mod tests {
    // Unit tests covering config defaults, the chunking threshold, split/reassemble
    // round-trips, checksum handling, sentinel encoding, key naming and statistics.
    use super::*;
    // The default config is enabled with checksums and uses the module's default sizes.
    #[test]
    fn test_chunking_config_defaults() {
        let config = ChunkingConfig::default();
        assert!(config.enabled);
        assert_eq!(config.threshold_bytes, 524_288);
        assert_eq!(config.chunk_size_bytes, 262_144);
        assert!(config.checksum_enabled);
    }
    // The threshold is exclusive (a payload exactly at it is not chunked), and a
    // disabled config never chunks regardless of size.
    #[test]
    fn test_needs_chunking() {
        let chunker = ResultChunker::new(
            ChunkingConfig::new()
                .with_threshold(100)
                .with_chunk_size(50),
        );
        let small = vec![0u8; 50];
        assert!(!chunker.needs_chunking(&small));
        let exact = vec![0u8; 100];
        assert!(!chunker.needs_chunking(&exact));
        let large = vec![0u8; 101];
        assert!(chunker.needs_chunking(&large));
        let disabled_chunker = ResultChunker::new(ChunkingConfig::disabled());
        let huge = vec![0u8; 1_000_000];
        assert!(!disabled_chunker.needs_chunking(&huge));
    }
    // Split followed by reassemble returns the original bytes across sizes around
    // the 50-byte chunk boundary, including empty input.
    #[test]
    fn test_split_and_reassemble() {
        let chunker =
            ResultChunker::new(ChunkingConfig::new().with_threshold(10).with_chunk_size(50));
        for size in [0, 1, 49, 50, 51, 100, 150, 255, 1000] {
            let data: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
            let (metadata, chunks) = chunker.split_chunks(&data);
            let reassembled = chunker
                .reassemble_chunks(&metadata, &chunks)
                .expect("reassemble failed");
            assert_eq!(reassembled, data, "roundtrip failed for size {size}");
        }
    }
    // A payload that is an exact multiple of the chunk size yields only full chunks.
    #[test]
    fn test_split_exact_boundary() {
        let chunk_size = 64;
        let chunker = ResultChunker::new(
            ChunkingConfig::new()
                .with_threshold(0)
                .with_chunk_size(chunk_size),
        );
        let data = vec![0xABu8; chunk_size * 3];
        let (metadata, chunks) = chunker.split_chunks(&data);
        assert_eq!(metadata.total_chunks, 3);
        assert_eq!(chunks.len(), 3);
        for chunk in &chunks {
            assert_eq!(chunk.len(), chunk_size);
        }
        let reassembled = chunker
            .reassemble_chunks(&metadata, &chunks)
            .expect("reassemble failed");
        assert_eq!(reassembled, data);
    }
    // One byte past a chunk boundary produces a trailing one-byte chunk.
    #[test]
    fn test_split_one_byte_over() {
        let chunk_size = 64;
        let chunker = ResultChunker::new(
            ChunkingConfig::new()
                .with_threshold(0)
                .with_chunk_size(chunk_size),
        );
        let data = vec![0xCDu8; chunk_size * 2 + 1];
        let (metadata, chunks) = chunker.split_chunks(&data);
        assert_eq!(metadata.total_chunks, 3);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].len(), chunk_size);
        assert_eq!(chunks[1].len(), chunk_size);
        assert_eq!(chunks[2].len(), 1);
        let reassembled = chunker
            .reassemble_chunks(&metadata, &chunks)
            .expect("reassemble failed");
        assert_eq!(reassembled, data);
    }
    // Corrupting a byte (same length) is caught by the CRC32 check, and the
    // failure is counted in the stats.
    #[test]
    fn test_checksum_verification() {
        let chunker = ResultChunker::new(
            ChunkingConfig::new()
                .with_threshold(0)
                .with_chunk_size(50)
                .with_checksum(true),
        );
        let data = vec![0xFFu8; 100];
        let (metadata, mut chunks) = chunker.split_chunks(&data);
        if let Some(byte) = chunks[1].first_mut() {
            *byte = byte.wrapping_add(1);
        }
        let result = chunker.reassemble_chunks(&metadata, &chunks);
        assert!(result.is_err());
        assert!(result
            .as_ref()
            .err()
            .is_some_and(|e| e.to_string().contains("checksum mismatch")));
        assert_eq!(chunker.stats().checksum_failures(), 1);
    }
    // With checksums disabled no checksum is stored, so a same-length corruption
    // is not detected; total_size is recomputed from the chunks so the size check passes.
    #[test]
    fn test_checksum_disabled() {
        let chunker = ResultChunker::new(
            ChunkingConfig::new()
                .with_threshold(0)
                .with_chunk_size(50)
                .with_checksum(false),
        );
        let data = vec![0xFFu8; 100];
        let (mut metadata, mut chunks) = chunker.split_chunks(&data);
        if let Some(byte) = chunks[0].first_mut() {
            *byte = byte.wrapping_add(1);
        }
        assert!(metadata.checksum.is_none());
        metadata.total_size = chunks.iter().map(|c| c.len()).sum();
        let result = chunker.reassemble_chunks(&metadata, &chunks);
        assert!(result.is_ok());
    }
    // Metadata encoded into a sentinel parses back with every field intact.
    #[test]
    fn test_sentinel_roundtrip() {
        let chunker = ResultChunker::new(ChunkingConfig::new());
        let metadata = ChunkMetadata {
            total_chunks: 5,
            chunk_size: 256,
            total_size: 1234,
            checksum: Some(0xDEADBEEF),
            created_at: 1700000000,
        };
        let sentinel = chunker.create_sentinel(&metadata);
        assert!(ResultChunker::is_chunked(&sentinel));
        let parsed = ResultChunker::parse_sentinel(&sentinel).expect("parse sentinel failed");
        assert_eq!(parsed.total_chunks, 5);
        assert_eq!(parsed.chunk_size, 256);
        assert_eq!(parsed.total_size, 1234);
        assert_eq!(parsed.checksum, Some(0xDEADBEEF));
        assert_eq!(parsed.created_at, 1700000000);
    }
    // Data without the sentinel prefix is not chunked and fails to parse.
    #[test]
    fn test_sentinel_not_chunked() {
        let data = b"just regular data";
        assert!(!ResultChunker::is_chunked(data));
        let result = ResultChunker::parse_sentinel(data);
        assert!(result.is_err());
    }
    // Chunk keys follow the `{base}:chunk:{i}` pattern with zero-based indices.
    #[test]
    fn test_chunk_keys() {
        let keys = ResultChunker::chunk_keys("celery-task-meta-abc123", 3);
        assert_eq!(keys.len(), 3);
        assert_eq!(keys[0], "celery-task-meta-abc123:chunk:0");
        assert_eq!(keys[1], "celery-task-meta-abc123:chunk:1");
        assert_eq!(keys[2], "celery-task-meta-abc123:chunk:2");
    }
    // The metadata key appends `:chunks` to the base key.
    #[test]
    fn test_metadata_key() {
        let key = ResultChunker::metadata_key("celery-task-meta-abc123");
        assert_eq!(key, "celery-task-meta-abc123:chunks");
    }
    // Supplying fewer chunks than the metadata declares is rejected.
    #[test]
    fn test_chunk_count_mismatch() {
        let chunker =
            ResultChunker::new(ChunkingConfig::new().with_threshold(0).with_chunk_size(50));
        let data = vec![0u8; 100];
        let (metadata, chunks) = chunker.split_chunks(&data);
        assert_eq!(chunks.len(), 2);
        let result = chunker.reassemble_chunks(&metadata, &chunks[..1]);
        assert!(result.is_err());
        assert!(result
            .as_ref()
            .err()
            .is_some_and(|e| e.to_string().contains("chunk count mismatch")));
    }
    // Empty input becomes a single empty chunk and reassembles to an empty payload.
    #[test]
    fn test_empty_data() {
        let chunker =
            ResultChunker::new(ChunkingConfig::new().with_threshold(0).with_chunk_size(50));
        let data: Vec<u8> = Vec::new();
        let (metadata, chunks) = chunker.split_chunks(&data);
        assert_eq!(metadata.total_chunks, 1);
        assert_eq!(metadata.total_size, 0);
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].is_empty());
        let reassembled = chunker
            .reassemble_chunks(&metadata, &chunks)
            .expect("reassemble failed");
        assert!(reassembled.is_empty());
    }
    // Split and reassemble update the counters, and reset zeroes all of them.
    #[test]
    fn test_stats_tracking() {
        let chunker = ResultChunker::new(
            ChunkingConfig::new()
                .with_threshold(0)
                .with_chunk_size(50)
                .with_checksum(true),
        );
        let data = vec![0u8; 120];
        let (metadata, chunks) = chunker.split_chunks(&data);
        assert_eq!(chunker.stats().chunks_stored(), 3);
        assert_eq!(chunker.stats().bytes_chunked(), 120);
        let _reassembled = chunker
            .reassemble_chunks(&metadata, &chunks)
            .expect("reassemble failed");
        assert_eq!(chunker.stats().chunks_loaded(), 3);
        chunker.stats().reset();
        assert_eq!(chunker.stats().chunks_stored(), 0);
        assert_eq!(chunker.stats().chunks_loaded(), 0);
        assert_eq!(chunker.stats().bytes_chunked(), 0);
        assert_eq!(chunker.stats().checksum_failures(), 0);
    }
}