#![deny(unsafe_code)]
use crate::types::{MessagePriority, NetworkError, NetworkMessage, PeerId};
use blake3::Hash;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, warn};
/// Default upper bound on the data carried by one chunk, in bytes (64 KiB).
const MAX_CHUNK_SIZE: usize = 64 * 1024;
/// Hard limit on how many chunks a single message may be split into.
const MAX_CHUNKS: usize = 10_000;
/// Default idle time allowed for a partially reassembled message.
const CHUNK_TIMEOUT: Duration = Duration::from_secs(30);
/// Metadata carried by every chunk of a split message.
///
/// NOTE(review): with non-self-describing serde formats the field order is
/// likely part of the wire format — avoid reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkHeader {
/// Identifier of the message this chunk belongs to (copied from `NetworkMessage::id`).
pub message_id: String,
/// Number of chunks the message was split into.
pub total_chunks: u32,
/// Zero-based position of this chunk; must be `< total_chunks`.
pub chunk_index: u32,
/// Length in bytes of this chunk's data; must equal `data.len()`.
pub chunk_size: usize,
/// BLAKE3 hash of the full (possibly compressed) byte stream that was split.
pub message_hash: [u8; 32],
/// Length in bytes of the original, uncompressed payload.
pub original_size: usize,
}
/// One chunk on the wire: its header plus its slice of the message bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkedMessage {
/// Position and integrity metadata for this chunk.
pub header: ChunkHeader,
/// This chunk's bytes.
pub data: Vec<u8>,
}
/// Chunk data buffered while the rest of its message is still arriving.
#[derive(Debug)]
pub struct StreamingChunk {
/// The chunk's bytes.
pub data: Vec<u8>,
/// When the chunk was stored. NOTE(review): not read by any code in this file.
pub received_at: Instant,
}
/// Progress of reassembling a single message from its chunks.
#[derive(Debug)]
pub struct ReassemblyState {
/// Chunks received so far, keyed by `chunk_index`.
pub chunks: HashMap<u32, StreamingChunk>,
/// Expected chunk count, taken from the first chunk seen.
pub total_chunks: u32,
/// Uncompressed payload size, taken from the first chunk seen.
pub original_size: usize,
/// Expected BLAKE3 hash of the concatenated chunk data, taken from the first chunk seen.
pub message_hash: [u8; 32],
/// When the first chunk of this message arrived.
pub started_at: Instant,
/// Refreshed on every received chunk; compared against `chunk_timeout` during cleanup.
pub last_activity: Instant,
}
/// Splits outgoing messages into chunks and reassembles incoming ones.
pub struct MessageChunker {
/// In-progress reassemblies, keyed by message id.
reassembly_states: Arc<RwLock<HashMap<String, ReassemblyState>>>,
/// LRU cache of fully reassembled payloads, keyed by message id.
chunk_cache: Arc<Mutex<lru::LruCache<String, Vec<u8>>>>,
/// Chunk size, timeout, compression and cache settings.
config: ChunkerConfig,
}
/// Tuning parameters for `MessageChunker`.
#[derive(Debug, Clone)]
pub struct ChunkerConfig {
/// Maximum bytes of data per chunk; payloads no larger than this are not chunked.
pub max_chunk_size: usize,
/// Idle time after which a partial reassembly is discarded during cleanup.
pub chunk_timeout: Duration,
/// Whether `chunk_message` compresses payloads larger than `compression_threshold`.
pub enable_compression: bool,
/// Payload size in bytes above which compression is attempted.
pub compression_threshold: usize,
/// Capacity of the LRU cache of reassembled messages.
pub cache_size: usize,
}
impl Default for ChunkerConfig {
fn default() -> Self {
Self {
max_chunk_size: MAX_CHUNK_SIZE,
chunk_timeout: CHUNK_TIMEOUT,
enable_compression: true,
compression_threshold: 1024,
cache_size: 1000,
}
}
}
impl MessageChunker {
    /// Creates a chunker with the given configuration.
    ///
    /// A `cache_size` of zero is clamped to one: the LRU cache of reassembled
    /// messages requires a non-zero capacity, and a config value should not
    /// be able to panic the constructor.
    pub fn new(config: ChunkerConfig) -> Self {
        let cache_capacity = std::num::NonZeroUsize::new(config.cache_size.max(1))
            .expect("cache capacity is clamped to at least 1");
        Self {
            reassembly_states: Arc::new(RwLock::new(HashMap::new())),
            chunk_cache: Arc::new(Mutex::new(lru::LruCache::new(cache_capacity))),
            config,
        }
    }

    /// Splits `message.payload` into chunks of at most `config.max_chunk_size` bytes.
    ///
    /// Returns an empty vector when the payload already fits within
    /// `max_chunk_size`, i.e. the message does not need chunking.
    ///
    /// When `enable_compression` is set and the payload is larger than
    /// `compression_threshold`, the payload is zstd-compressed first. The
    /// compressed form is only used when it is strictly smaller than the
    /// payload, because the receiver detects compression by the reassembled
    /// data being shorter than `original_size`.
    ///
    /// # Errors
    ///
    /// * `NetworkError::ValidationError` if `max_chunk_size` is zero (for a
    ///   non-empty payload) or more than `MAX_CHUNKS` chunks would be needed.
    /// * `NetworkError::Internal` if compression fails.
    pub async fn chunk_message(
        &self,
        message: &NetworkMessage,
    ) -> Result<Vec<ChunkedMessage>, NetworkError> {
        let payload = &message.payload;
        let max_chunk_size = self.config.max_chunk_size;
        if payload.len() <= max_chunk_size {
            return Ok(vec![]);
        }
        if max_chunk_size == 0 {
            // `slice::chunks(0)` panics and the chunk count would divide by zero.
            return Err(NetworkError::ValidationError(
                "max_chunk_size must be greater than zero".into(),
            ));
        }

        let data = if self.config.enable_compression
            && payload.len() > self.config.compression_threshold
        {
            let compressed = self.compress_data(payload)?;
            // Incompressible input can grow under zstd; fall back to the raw
            // bytes so that `data.len() < original_size` reliably means
            // "compressed" on the receiving side.
            if compressed.len() < payload.len() {
                compressed
            } else {
                payload.clone()
            }
        } else {
            payload.clone()
        };

        let message_hash = *blake3::hash(&data).as_bytes();
        // Ceiling division without the `len + size - 1` overflow risk, checked
        // in `usize` before any narrowing cast.
        let chunk_count =
            data.len() / max_chunk_size + usize::from(data.len() % max_chunk_size != 0);
        if chunk_count > MAX_CHUNKS {
            return Err(NetworkError::ValidationError(format!(
                "Message too large: {} chunks exceeds maximum {}",
                chunk_count, MAX_CHUNKS
            )));
        }
        // Cannot truncate: chunk_count <= MAX_CHUNKS, which fits in u32.
        let total_chunks = chunk_count as u32;
        let original_size = payload.len();

        let chunks: Vec<ChunkedMessage> = data
            .chunks(max_chunk_size)
            .enumerate()
            .map(|(index, chunk_data)| ChunkedMessage {
                header: ChunkHeader {
                    message_id: message.id.clone(),
                    total_chunks,
                    // Cannot truncate: index < chunk_count <= MAX_CHUNKS.
                    chunk_index: index as u32,
                    chunk_size: chunk_data.len(),
                    message_hash,
                    original_size,
                },
                data: chunk_data.to_vec(),
            })
            .collect();

        debug!(
            "Chunked message {} into {} chunks (original: {} bytes, chunked: {} bytes)",
            message.id,
            chunks.len(),
            payload.len(),
            data.len()
        );
        Ok(chunks)
    }

    /// Feeds one received chunk into reassembly.
    ///
    /// Returns `Ok(Some(payload))` once every chunk of the message has arrived,
    /// or immediately if the message id is still in the reassembled-message
    /// cache. Returns `Ok(None)` while chunks are still missing.
    ///
    /// # Errors
    ///
    /// `NetworkError::ValidationError` if the chunk fails validation, disagrees
    /// with earlier chunks on `total_chunks`, or the reassembled data fails
    /// hash/size verification; `NetworkError::Internal` if decompression fails.
    pub async fn process_chunk(
        &self,
        chunk: ChunkedMessage,
    ) -> Result<Option<Vec<u8>>, NetworkError> {
        let message_id = chunk.header.message_id.clone();
        self.validate_chunk(&chunk)?;
        // Late or duplicate chunks of an already-completed message get the cached payload.
        if let Some(cached) = self.chunk_cache.lock().await.get(&message_id) {
            return Ok(Some(cached.clone()));
        }
        let mut states = self.reassembly_states.write().await;
        let state = states.entry(message_id.clone()).or_insert_with(|| {
            ReassemblyState {
                chunks: HashMap::new(),
                total_chunks: chunk.header.total_chunks,
                original_size: chunk.header.original_size,
                message_hash: chunk.header.message_hash,
                started_at: Instant::now(),
                last_activity: Instant::now(),
            }
        });
        state.last_activity = Instant::now();
        if state.total_chunks != chunk.header.total_chunks {
            return Err(NetworkError::ValidationError(
                "Inconsistent chunk count".into()
            ));
        }
        state.chunks.insert(
            chunk.header.chunk_index,
            StreamingChunk {
                data: chunk.data,
                received_at: Instant::now(),
            },
        );
        if state.chunks.len() == state.total_chunks as usize {
            // If reassembly fails, the state is left in place. A later chunk with
            // the same index replaces the stored one and triggers another attempt,
            // until cleanup_expired drops the state.
            let reassembled = self.reassemble_message(state)?;
            self.chunk_cache.lock().await.put(message_id.clone(), reassembled.clone());
            states.remove(&message_id);
            Ok(Some(reassembled))
        } else {
            debug!(
                "Received chunk {}/{} for message {}",
                chunk.header.chunk_index + 1,
                state.total_chunks,
                message_id
            );
            Ok(None)
        }
    }

    /// Rejects chunks whose header is internally inconsistent or out of bounds.
    fn validate_chunk(&self, chunk: &ChunkedMessage) -> Result<(), NetworkError> {
        if chunk.header.chunk_index >= chunk.header.total_chunks {
            return Err(NetworkError::ValidationError(
                "Invalid chunk index".into()
            ));
        }
        // Mirror the sender-side limit so a peer cannot open a reassembly that
        // waits for an arbitrarily large number of chunks.
        if chunk.header.total_chunks > MAX_CHUNKS as u32 {
            return Err(NetworkError::ValidationError(
                "Chunk count exceeds maximum".into()
            ));
        }
        if chunk.data.len() != chunk.header.chunk_size {
            return Err(NetworkError::ValidationError(
                "Chunk size mismatch".into()
            ));
        }
        if chunk.header.chunk_size > self.config.max_chunk_size {
            return Err(NetworkError::ValidationError(
                "Chunk size exceeds maximum".into()
            ));
        }
        Ok(())
    }

    /// Concatenates all chunks in index order, verifies the hash, and
    /// decompresses if the data is shorter than the declared original size.
    ///
    /// Decompression is driven by the data itself rather than this node's
    /// `enable_compression` setting, so a receiver with compression disabled
    /// still correctly decodes a compressed message from a peer.
    fn reassemble_message(&self, state: &ReassemblyState) -> Result<Vec<u8>, NetworkError> {
        // Size the buffer from the bytes actually received. `original_size`
        // comes from an untrusted header and describes the *uncompressed*
        // payload, so pre-allocating from it could panic or abort.
        let received_len: usize = state.chunks.values().map(|c| c.data.len()).sum();
        let mut data = Vec::with_capacity(received_len);
        for i in 0..state.total_chunks {
            let chunk = state.chunks.get(&i).ok_or_else(|| {
                NetworkError::ValidationError(format!("Missing chunk {}", i))
            })?;
            data.extend_from_slice(&chunk.data);
        }
        let computed_hash = blake3::hash(&data);
        if *computed_hash.as_bytes() != state.message_hash {
            return Err(NetworkError::ValidationError(
                "Message hash verification failed".into()
            ));
        }
        if data.len() < state.original_size {
            let decompressed = self.decompress_data(&data)?;
            if decompressed.len() != state.original_size {
                return Err(NetworkError::ValidationError(format!(
                    "Decompressed size {} does not match declared size {}",
                    decompressed.len(),
                    state.original_size
                )));
            }
            Ok(decompressed)
        } else {
            Ok(data)
        }
    }

    /// Drops reassemblies with no chunk activity for longer than `chunk_timeout`,
    /// logging a warning for each one discarded.
    pub async fn cleanup_expired(&self) {
        let mut states = self.reassembly_states.write().await;
        let now = Instant::now();
        let timeout = self.config.chunk_timeout;
        states.retain(|id, state| {
            let still_active = now.duration_since(state.last_activity) < timeout;
            if !still_active {
                warn!(
                    "Cleaning up expired message reassembly for {} ({}/{} chunks received)",
                    id,
                    state.chunks.len(),
                    state.total_chunks
                );
            }
            still_active
        });
    }

    /// zstd-compresses `data` at level 3.
    fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>, NetworkError> {
        zstd::encode_all(data, 3)
            .map_err(|e| NetworkError::Internal(format!("Compression failed: {}", e)))
    }

    /// zstd-decompresses `data`.
    ///
    /// NOTE(review): output size is not bounded here; a crafted payload could
    /// expand to far more than its declared `original_size` before the size
    /// check in `reassemble_message` runs.
    fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>, NetworkError> {
        zstd::decode_all(data)
            .map_err(|e| NetworkError::Internal(format!("Decompression failed: {}", e)))
    }

    /// Returns a snapshot of the number of in-progress reassemblies, cached
    /// reassembled messages, and buffered chunks.
    pub async fn get_stats(&self) -> ChunkerStats {
        let states = self.reassembly_states.read().await;
        let cache = self.chunk_cache.lock().await;
        ChunkerStats {
            active_reassemblies: states.len(),
            cache_size: cache.len(),
            total_chunks_waiting: states.values().map(|s| s.chunks.len()).sum(),
        }
    }
}
/// Snapshot of chunker activity, as returned by `MessageChunker::get_stats`.
#[derive(Debug, Clone)]
pub struct ChunkerStats {
/// Number of messages with a reassembly in progress.
pub active_reassemblies: usize,
/// Number of reassembled messages currently held in the LRU cache.
pub cache_size: usize,
/// Total chunks buffered across all in-progress reassemblies.
pub total_chunks_waiting: usize,
}
/// Messages that can report whether they need splitting and be wrapped for chunked transport.
pub trait ChunkableMessage {
/// Returns `true` if the message should be split given a chunk limit of `max_size` bytes.
fn needs_chunking(&self, max_size: usize) -> bool;
/// Consumes the message and wraps it in a `ChunkedNetworkMessage`.
fn into_chunked(self) -> ChunkedNetworkMessage;
}
impl ChunkableMessage for NetworkMessage {
    /// A message needs chunking exactly when its payload is strictly longer
    /// than `max_size` bytes.
    fn needs_chunking(&self, max_size: usize) -> bool {
        let payload_len = self.payload.len();
        max_size < payload_len
    }

    /// Wraps the message as-is; no chunks are computed, so `chunks` is `None`.
    fn into_chunked(self) -> ChunkedNetworkMessage {
        let chunks = None;
        ChunkedNetworkMessage { chunks, base: self }
    }
}
/// A `NetworkMessage` paired with its chunks, when they have been computed.
#[derive(Debug, Clone)]
pub struct ChunkedNetworkMessage {
/// The message being carried.
pub base: NetworkMessage,
/// Chunks of `base`; `NetworkMessage::into_chunked` leaves this as `None`.
pub chunks: Option<Vec<ChunkedMessage>>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Builds a `NetworkMessage` with a fresh id and the given payload.
    fn message_with_payload(payload: Vec<u8>) -> NetworkMessage {
        NetworkMessage {
            id: Uuid::new_v4().to_string(),
            source: vec![1],
            destination: vec![2],
            payload,
            priority: MessagePriority::Normal,
            ttl: Duration::from_secs(60),
        }
    }

    /// Feeds `chunks` in order and returns the reassembled payload, if any.
    async fn reassemble_all(
        chunker: &MessageChunker,
        chunks: Vec<ChunkedMessage>,
    ) -> Option<Vec<u8>> {
        let mut reassembled = None;
        for chunk in chunks {
            if let Some(data) = chunker.process_chunk(chunk).await.unwrap() {
                reassembled = Some(data);
            }
        }
        reassembled
    }

    #[tokio::test]
    async fn test_message_chunking() {
        // Compression is disabled so the all-zero payload is not shrunk into
        // a single chunk, which would defeat the chunk-count assertion.
        let config = ChunkerConfig {
            max_chunk_size: 1024,
            enable_compression: false,
            ..Default::default()
        };
        let chunker = MessageChunker::new(config);
        let message = message_with_payload(vec![0u8; 3000]);
        let chunks = chunker.chunk_message(&message).await.unwrap();
        assert_eq!(chunks.len(), 3);
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.header.chunk_index, i as u32);
            assert_eq!(chunk.header.total_chunks, 3);
            assert!(chunk.header.chunk_size <= 1024);
        }
    }

    #[tokio::test]
    async fn test_small_message_is_not_chunked() {
        let config = ChunkerConfig {
            max_chunk_size: 1024,
            ..Default::default()
        };
        let chunker = MessageChunker::new(config);
        let message = message_with_payload(vec![1u8; 100]);
        let chunks = chunker.chunk_message(&message).await.unwrap();
        assert!(chunks.is_empty());
    }

    #[tokio::test]
    async fn test_message_reassembly() {
        let config = ChunkerConfig {
            max_chunk_size: 1024,
            enable_compression: false,
            ..Default::default()
        };
        let chunker = MessageChunker::new(config);
        let original_data = vec![42u8; 2500];
        let message = message_with_payload(original_data.clone());
        let chunks = chunker.chunk_message(&message).await.unwrap();
        let reassembled = reassemble_all(&chunker, chunks).await;
        assert_eq!(reassembled.unwrap(), original_data);
    }

    #[tokio::test]
    async fn test_compressed_reassembly() {
        let config = ChunkerConfig {
            max_chunk_size: 1024,
            ..Default::default()
        };
        let chunker = MessageChunker::new(config);
        let original_data = vec![7u8; 5000];
        let message = message_with_payload(original_data.clone());
        let chunks = chunker.chunk_message(&message).await.unwrap();
        assert!(!chunks.is_empty());
        let sent_bytes: usize = chunks.iter().map(|c| c.data.len()).sum();
        assert!(sent_bytes < original_data.len());
        let reassembled = reassemble_all(&chunker, chunks).await;
        assert_eq!(reassembled.unwrap(), original_data);
    }

    #[tokio::test]
    async fn test_out_of_order_reassembly() {
        let config = ChunkerConfig {
            max_chunk_size: 1024,
            enable_compression: false,
            ..Default::default()
        };
        let chunker = MessageChunker::new(config);
        let original_data = vec![99u8; 3072];
        let message = message_with_payload(original_data.clone());
        let chunks = chunker.chunk_message(&message).await.unwrap();
        assert!(chunker.process_chunk(chunks[2].clone()).await.unwrap().is_none());
        assert!(chunker.process_chunk(chunks[0].clone()).await.unwrap().is_none());
        let reassembled = chunker.process_chunk(chunks[1].clone()).await.unwrap();
        assert_eq!(reassembled.unwrap(), original_data);
    }
}