use super::depth_limited::{would_exceed_depth_limit, MAX_DESERIALIZE_DEPTH};
use super::utils::{check_data_size, compress_data, decompress_data};
use super::Serializer;
use crate::error::{CacheError, Result};
use serde::{Deserialize, Serialize};
/// Cache serializer that stores a byte payload as a JSON string holding its
/// base64 text, optionally compressing the resulting JSON document.
#[derive(Debug, Clone)]
pub struct JsonSerializer {
    /// When `true`, the JSON document is compressed after encoding and
    /// decompressed before decoding.
    compress: bool,
}
/// Maximum size in bytes (5 MiB) accepted by the JSON size checks in `deserialize`.
const MAX_JSON_SIZE: usize = 5 << 20;
mod byte_array {
use base64::prelude::*;
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S>(bytes: &[u8], serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
let encoded = BASE64_STANDARD.encode(bytes);
serializer.serialize_str(&encoded)
}
pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let encoded = String::deserialize(deserializer)?;
BASE64_STANDARD.decode(&encoded).map_err(serde::de::Error::custom)
}
}
/// Owned byte buffer whose serde form is a base64 string (through the
/// `byte_array` adapter) rather than a JSON array of numbers.
#[derive(Serialize, Deserialize)]
struct ByteArrayWrapper(
    #[serde(with = "byte_array")]
    Vec<u8>,
);
impl JsonSerializer {
pub fn new() -> Self {
Self { compress: false }
}
pub fn with_compression() -> Self {
Self { compress: true }
}
}
impl Default for JsonSerializer {
fn default() -> Self {
Self::new()
}
}
impl Serializer for JsonSerializer {
fn serialize(&self, _type_name: &str, data: &[u8]) -> Result<Vec<u8>> {
let wrapper = ByteArrayWrapper(data.to_vec());
let json_bytes = serde_json::to_vec(&wrapper).map_err(|e| CacheError::Serialization(e.to_string()))?;
if self.compress {
compress_data(&json_bytes)
} else {
Ok(json_bytes)
}
}
fn deserialize(&self, _type_name: &str, data: &[u8]) -> Result<Vec<u8>> {
check_data_size(data, MAX_JSON_SIZE, "JSON")?;
let json_bytes = if self.compress {
decompress_data(data)?
} else {
data.to_vec()
};
if would_exceed_depth_limit(&json_bytes, MAX_DESERIALIZE_DEPTH)
.map_err(|e| CacheError::Serialization(e.to_string()))?
{
return Err(CacheError::InvalidInput(format!(
"JSON 嵌套深度超过最大限制 {}",
MAX_DESERIALIZE_DEPTH
)));
}
let wrapper: ByteArrayWrapper =
serde_json::from_slice(&json_bytes).map_err(|e| CacheError::Serialization(e.to_string()))?;
Ok(wrapper.0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base64_serialization() {
        let serializer = JsonSerializer::new();
        let data = vec![0, 1, 2, 255, 254, 253];
        let serialized = serializer.serialize("test", &data).unwrap();
        // The document is a bare JSON string holding standard base64; pin the exact
        // bytes so a format regression (e.g. falling back to a number array) is caught.
        assert_eq!(serialized, b"\"AAEC//79\"");
        let deserialized = serializer.deserialize("test", &serialized).unwrap();
        assert_eq!(data, deserialized);
    }

    #[test]
    fn test_base64_vs_array_size() {
        let serializer = JsonSerializer::new();
        let data: Vec<u8> = (0..=255).collect();
        let serialized = serializer.serialize("test", &data).unwrap();
        // Rough size of the same bytes as a JSON number array (up to "255," per byte).
        let old_size = data.len() * 4;
        let new_size = serialized.len();
        assert!(new_size < old_size, "base64 编码应该比 JSON 数组更小");
    }

    #[test]
    fn test_compression() {
        let serializer = JsonSerializer::with_compression();
        let data = vec![0u8; 1000];
        let serialized = serializer.serialize("test", &data).unwrap();
        let deserialized = serializer.deserialize("test", &serialized).unwrap();
        assert_eq!(data, deserialized);
    }

    #[test]
    fn test_compression_all_byte_values() {
        let serializer = JsonSerializer::with_compression();
        let data: Vec<u8> = (0..=255).cycle().take(4096).collect();
        let serialized = serializer.serialize("test", &data).unwrap();
        let deserialized = serializer.deserialize("test", &serialized).unwrap();
        assert_eq!(data, deserialized);
    }

    #[test]
    fn test_empty_bytes() {
        let serializer = JsonSerializer::new();
        let data: Vec<u8> = vec![];
        let serialized = serializer.serialize("test", &data).unwrap();
        let deserialized = serializer.deserialize("test", &serialized).unwrap();
        assert_eq!(data, deserialized);
    }

    #[test]
    fn test_max_size_limit() {
        let serializer = JsonSerializer::new();
        let large_data = vec![0u8; MAX_JSON_SIZE + 1];
        let result = serializer.deserialize("test", &large_data);
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_invalid_base64() {
        let serializer = JsonSerializer::new();
        assert!(serializer.deserialize("test", b"\"not base64!\"").is_err());
    }

    #[test]
    fn test_rejects_non_string_json() {
        let serializer = JsonSerializer::new();
        // The old number-array representation must not be silently accepted.
        assert!(serializer.deserialize("test", b"[0, 1, 2]").is_err());
    }
}