use bincode::Options;
use serde::de::DeserializeOwned;
use std::io::Read;
/// Multiplier applied to an input's byte length to derive its deserialization byte limit.
const MAX_EXPANSION_FACTOR: u64 = 4;
/// Smallest byte limit ever handed to bincode: 1 MiB (2^20 bytes).
const MIN_SIZE_LIMIT: u64 = 1 << 20;
/// Largest byte limit ever handed to bincode: 1 GiB (2^30 bytes).
const MAX_SIZE_LIMIT: u64 = 1 << 30;
#[inline]
pub fn deserialize_with_limit<T: DeserializeOwned>(bytes: &[u8]) -> Result<T, bincode::Error> {
let limit = calculate_limit(bytes.len() as u64);
bincode::options().with_limit(limit).deserialize(bytes)
}
#[inline]
pub fn deserialize_from_with_limit<T: DeserializeOwned, R: Read>(
reader: R,
expected_size: u64,
) -> Result<T, bincode::Error> {
let limit = calculate_limit(expected_size);
bincode::options()
.with_limit(limit)
.deserialize_from(reader)
}
/// Derives the bincode byte limit for an input of `input_size` bytes.
///
/// The size is multiplied by `MAX_EXPANSION_FACTOR`. On overflow, the product
/// saturates to `u64::MAX` instead of wrapping or panicking. The result is then
/// clamped to `[MIN_SIZE_LIMIT, MAX_SIZE_LIMIT]`.
#[inline]
fn calculate_limit(input_size: u64) -> u64 {
    // For unsigned integers, falling back to `u64::MAX` on overflow is exactly saturation.
    let scaled = input_size
        .checked_mul(MAX_EXPANSION_FACTOR)
        .unwrap_or(u64::MAX);
    // Raise to the floor first, then cap at the ceiling.
    scaled.max(MIN_SIZE_LIMIT).min(MAX_SIZE_LIMIT)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestStruct {
        name: String,
        values: Vec<u32>,
    }

    /// Encodes `value` with the same `bincode::options()` configuration that the
    /// functions under test decode with.
    fn serialize_with_options<T: Serialize>(value: &T) -> Result<Vec<u8>, bincode::Error> {
        bincode::options().serialize(value)
    }

    /// Builds an encoded `String` whose length prefix claims 2 MiB but whose
    /// payload bytes have been stripped.
    ///
    /// The prefix is obtained by encoding a real 2 MiB string and truncating the
    /// payload, so the test does not hard-code bincode's length encoding. Since
    /// the input is only a few bytes long, the derived limit is `MIN_SIZE_LIMIT`
    /// (1 MiB). The declared 2 MiB length therefore exceeds it.
    fn oversized_string_prefix() -> Vec<u8> {
        let payload = "x".repeat(2 * 1024 * 1024);
        let mut bytes = serialize_with_options(&payload).unwrap();
        let prefix_len = bytes.len() - payload.len();
        bytes.truncate(prefix_len);
        bytes
    }

    #[test]
    fn test_deserialize_valid() {
        let original = TestStruct { name: "test".to_owned(), values: vec![1, 2, 3, 4, 5] };
        let bytes = serialize_with_options(&original).unwrap();
        let restored: TestStruct = deserialize_with_limit(&bytes).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_deserialize_empty() {
        let original: Vec<u8> = Vec::new();
        let bytes = serialize_with_options(&original).unwrap();
        let restored: Vec<u8> = deserialize_with_limit(&bytes).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_calculate_limit() {
        assert_eq!(calculate_limit(100), MIN_SIZE_LIMIT);

        let medium = 10 * 1024 * 1024;
        assert_eq!(calculate_limit(medium), medium * MAX_EXPANSION_FACTOR);

        let huge = 10 * 1024 * 1024 * 1024;
        assert_eq!(calculate_limit(huge), MAX_SIZE_LIMIT);
    }

    #[test]
    fn test_calculate_limit_boundaries() {
        assert_eq!(calculate_limit(0), MIN_SIZE_LIMIT);
        assert_eq!(calculate_limit(MIN_SIZE_LIMIT / MAX_EXPANSION_FACTOR), MIN_SIZE_LIMIT);
        assert_eq!(calculate_limit(MAX_SIZE_LIMIT / MAX_EXPANSION_FACTOR), MAX_SIZE_LIMIT);
        // The multiplication must saturate rather than overflow on extreme inputs.
        assert_eq!(calculate_limit(u64::MAX), MAX_SIZE_LIMIT);
    }

    #[test]
    fn test_deserialize_rejects_length_over_limit() {
        let bytes = oversized_string_prefix();
        let err = deserialize_with_limit::<String>(&bytes).unwrap_err();
        assert!(matches!(*err, bincode::ErrorKind::SizeLimit), "unexpected error: {:?}", err);
    }

    #[test]
    fn test_deserialize_from_reader() {
        let original = TestStruct { name: "reader_test".to_owned(), values: vec![10, 20, 30] };
        let bytes = serialize_with_options(&original).unwrap();
        let cursor = std::io::Cursor::new(&bytes);
        let restored: TestStruct = deserialize_from_with_limit(cursor, bytes.len() as u64).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_deserialize_from_reader_rejects_length_over_limit() {
        let bytes = oversized_string_prefix();
        let cursor = std::io::Cursor::new(&bytes);
        let err =
            deserialize_from_with_limit::<String, _>(cursor, bytes.len() as u64).unwrap_err();
        assert!(matches!(*err, bincode::ErrorKind::SizeLimit), "unexpected error: {:?}", err);
    }
}