use crate::byte_storage::{ByteStorage, ByteStorageError};
use crate::ffi::error::CachekitError;
use std::panic::catch_unwind;
use std::slice;
/// Compresses `input_len` bytes starting at `input` into the caller-provided
/// `output` buffer.
///
/// On entry `*output_len` must hold the capacity of `output` in bytes. On
/// success it is overwritten with the number of bytes actually written. If the
/// buffer is too small nothing is copied, `*output_len` is set to the required
/// size, and [`CachekitError::BufferTooSmall`] is returned so the caller can
/// retry with a larger buffer.
///
/// # Errors
///
/// * [`CachekitError::NullPointer`] if `input`, `output` or `output_len` is
///   null. A null `input` is rejected even when `input_len` is zero.
/// * [`CachekitError::InputTooLarge`] if `input_len` exceeds `isize::MAX` or
///   the storage layer rejects the input as too large.
/// * [`CachekitError::BufferTooSmall`] if the result does not fit in `output`.
/// * [`CachekitError::InvalidInput`] for any other storage error, and also if
///   a panic is caught while compressing.
///
/// # Safety
///
/// * `input` must be valid for reads of `input_len` bytes.
/// * `output_len` must be valid for reads and writes of one `usize`.
/// * `output` must be valid for writes of `*output_len` bytes.
#[no_mangle]
pub unsafe extern "C" fn cachekit_compress(
    input: *const u8,
    input_len: usize,
    output: *mut u8,
    output_len: *mut usize,
) -> CachekitError {
    // Panics must never unwind across the C ABI boundary, so the whole body
    // runs inside `catch_unwind`.
    let result = catch_unwind(|| {
        if input.is_null() {
            return CachekitError::NullPointer;
        }
        if output.is_null() {
            return CachekitError::NullPointer;
        }
        if output_len.is_null() {
            return CachekitError::NullPointer;
        }

        // `slice::from_raw_parts` is undefined behaviour for lengths above
        // `isize::MAX`; a foreign caller can pass any `usize`, so reject such
        // lengths before the slice is ever constructed.
        if input_len > isize::MAX as usize {
            return CachekitError::InputTooLarge;
        }

        // SAFETY: `output_len` is non-null (checked above) and the caller
        // guarantees it points to a readable `usize`.
        let available_size = unsafe { *output_len };
        // SAFETY: `input` is non-null, `input_len <= isize::MAX`, and the
        // caller guarantees `input` is readable for `input_len` bytes.
        let input_slice = unsafe { slice::from_raw_parts(input, input_len) };

        let storage = ByteStorage::new(None);
        let compressed = match storage.store(input_slice, None) {
            Ok(data) => data,
            Err(e) => {
                return match e {
                    ByteStorageError::InputTooLarge => CachekitError::InputTooLarge,
                    ByteStorageError::SerializationFailed(_) => CachekitError::InvalidInput,
                    // Every other storage error is reported as invalid input.
                    _ => CachekitError::InvalidInput,
                };
            }
        };

        if compressed.len() > available_size {
            // Report the required size so the caller can retry.
            // SAFETY: `output_len` is non-null and writable per the contract.
            unsafe {
                *output_len = compressed.len();
            }
            return CachekitError::BufferTooSmall;
        }

        // SAFETY: `output` is non-null and writable for `available_size`
        // bytes, `compressed.len() <= available_size`, and `compressed` is a
        // freshly produced buffer that cannot overlap the caller's output.
        unsafe {
            std::ptr::copy_nonoverlapping(compressed.as_ptr(), output, compressed.len());
            *output_len = compressed.len();
        }
        CachekitError::Ok
    });

    // A caught panic is surfaced as a generic invalid-input error.
    result.unwrap_or(CachekitError::InvalidInput)
}
/// Decompresses `input_len` bytes starting at `input` into the caller-provided
/// `output` buffer.
///
/// On entry `*output_len` must hold the capacity of `output` in bytes. On
/// success it is overwritten with the number of bytes actually written. If the
/// buffer is too small nothing is copied, `*output_len` is set to the required
/// size, and [`CachekitError::BufferTooSmall`] is returned so the caller can
/// retry with a larger buffer.
///
/// # Errors
///
/// * [`CachekitError::NullPointer`] if `input`, `output` or `output_len` is
///   null. A null `input` is rejected even when `input_len` is zero.
/// * [`CachekitError::InputTooLarge`] if `input_len` exceeds `isize::MAX` or
///   the storage layer rejects the input as too large.
/// * [`CachekitError::ChecksumMismatch`], [`CachekitError::DecompressionBomb`]
///   and [`CachekitError::SizeValidationFailed`] mirror the storage-layer
///   errors of the same name.
/// * [`CachekitError::DecompressionFailed`] mirrors the storage-layer error of
///   the same name and is also returned if a panic is caught while
///   decompressing.
/// * [`CachekitError::BufferTooSmall`] if the result does not fit in `output`.
/// * [`CachekitError::InvalidInput`] for any other storage error.
///
/// # Safety
///
/// * `input` must be valid for reads of `input_len` bytes.
/// * `output_len` must be valid for reads and writes of one `usize`.
/// * `output` must be valid for writes of `*output_len` bytes.
#[no_mangle]
pub unsafe extern "C" fn cachekit_decompress(
    input: *const u8,
    input_len: usize,
    output: *mut u8,
    output_len: *mut usize,
) -> CachekitError {
    // Panics must never unwind across the C ABI boundary, so the whole body
    // runs inside `catch_unwind`.
    let result = catch_unwind(|| {
        if input.is_null() {
            return CachekitError::NullPointer;
        }
        if output.is_null() {
            return CachekitError::NullPointer;
        }
        if output_len.is_null() {
            return CachekitError::NullPointer;
        }

        // `slice::from_raw_parts` is undefined behaviour for lengths above
        // `isize::MAX`; a foreign caller can pass any `usize`, so reject such
        // lengths before the slice is ever constructed.
        if input_len > isize::MAX as usize {
            return CachekitError::InputTooLarge;
        }

        // SAFETY: `output_len` is non-null (checked above) and the caller
        // guarantees it points to a readable `usize`.
        let available_size = unsafe { *output_len };
        // SAFETY: `input` is non-null, `input_len <= isize::MAX`, and the
        // caller guarantees `input` is readable for `input_len` bytes.
        let input_slice = unsafe { slice::from_raw_parts(input, input_len) };

        let storage = ByteStorage::new(None);
        // The second tuple element returned by `retrieve` is not needed here.
        let (decompressed, _format) = match storage.retrieve(input_slice) {
            Ok(data) => data,
            Err(e) => {
                return match e {
                    ByteStorageError::ChecksumMismatch => CachekitError::ChecksumMismatch,
                    ByteStorageError::DecompressionFailed => CachekitError::DecompressionFailed,
                    ByteStorageError::DecompressionBomb => CachekitError::DecompressionBomb,
                    ByteStorageError::InputTooLarge => CachekitError::InputTooLarge,
                    ByteStorageError::SizeValidationFailed => CachekitError::SizeValidationFailed,
                    ByteStorageError::DeserializationFailed(_) => CachekitError::InvalidInput,
                    // Every other storage error is reported as invalid input.
                    _ => CachekitError::InvalidInput,
                };
            }
        };

        if decompressed.len() > available_size {
            // Report the required size so the caller can retry.
            // SAFETY: `output_len` is non-null and writable per the contract.
            unsafe {
                *output_len = decompressed.len();
            }
            return CachekitError::BufferTooSmall;
        }

        // SAFETY: `output` is non-null and writable for `available_size`
        // bytes, `decompressed.len() <= available_size`, and `decompressed` is
        // a freshly produced buffer that cannot overlap the caller's output.
        unsafe {
            std::ptr::copy_nonoverlapping(decompressed.as_ptr(), output, decompressed.len());
            *output_len = decompressed.len();
        }
        CachekitError::Ok
    });

    // A caught panic is surfaced as a decompression failure.
    result.unwrap_or(CachekitError::DecompressionFailed)
}
/// Returns a buffer size intended to be large enough for the output of
/// compressing `input_len` bytes.
///
/// The estimate is `input_len + input_len / 255 + 16` (the LZ4 term) plus a
/// fixed 120-byte allowance for the MessagePack envelope. Every addition
/// saturates, so very large inputs yield `usize::MAX` instead of wrapping.
#[no_mangle]
pub extern "C" fn cachekit_compressed_bound(input_len: usize) -> usize {
    /// Constant term of the LZ4 bound.
    const LZ4_FIXED_OVERHEAD: usize = 16;
    /// Fixed allowance for the MessagePack envelope.
    const MSGPACK_OVERHEAD: usize = 120;

    input_len
        .saturating_add(input_len / 255)
        .saturating_add(LZ4_FIXED_OVERHEAD)
        .saturating_add(MSGPACK_OVERHEAD)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `cachekit_compress`, advertising the full length of `dst` as the
    /// output capacity. Returns the status and the length reported back
    /// through `output_len`.
    fn run_compress(src: &[u8], dst: &mut [u8]) -> (CachekitError, usize) {
        let mut reported_len = dst.len();
        // SAFETY: both pointers come from live slices whose lengths are passed
        // alongside them, and `reported_len` is a valid local `usize`.
        let status = unsafe {
            cachekit_compress(src.as_ptr(), src.len(), dst.as_mut_ptr(), &mut reported_len)
        };
        (status, reported_len)
    }

    /// Calls `cachekit_decompress`, advertising the full length of `dst` as
    /// the output capacity. Returns the status and the length reported back
    /// through `output_len`.
    fn run_decompress(src: &[u8], dst: &mut [u8]) -> (CachekitError, usize) {
        let mut reported_len = dst.len();
        // SAFETY: both pointers come from live slices whose lengths are passed
        // alongside them, and `reported_len` is a valid local `usize`.
        let status = unsafe {
            cachekit_decompress(src.as_ptr(), src.len(), dst.as_mut_ptr(), &mut reported_len)
        };
        (status, reported_len)
    }

    /// Compressing and then decompressing returns the original bytes.
    #[test]
    fn test_compress_decompress_roundtrip() {
        let payload = b"Hello, World! This is a test of the FFI compression interface.";

        let mut packed = vec![0u8; cachekit_compressed_bound(payload.len())];
        let (status, packed_len) = run_compress(payload, &mut packed);
        assert_eq!(status, CachekitError::Ok);
        assert!(packed_len > 0);
        assert!(packed_len < packed.len());

        let mut unpacked = vec![0u8; payload.len() + 1000];
        let (status, unpacked_len) = run_decompress(&packed[..packed_len], &mut unpacked);
        assert_eq!(status, CachekitError::Ok);
        assert_eq!(unpacked_len, payload.len());
        assert_eq!(&unpacked[..unpacked_len], &payload[..]);
    }

    /// Each of the three pointer arguments of `cachekit_compress` is checked
    /// for null.
    #[test]
    fn test_null_pointer_checks() {
        let input = b"test";
        let mut output = vec![0u8; 1024];
        let mut output_len = output.len();

        let null_input = unsafe {
            cachekit_compress(std::ptr::null(), 4, output.as_mut_ptr(), &mut output_len)
        };
        assert_eq!(null_input, CachekitError::NullPointer);

        let null_output = unsafe {
            cachekit_compress(input.as_ptr(), 4, std::ptr::null_mut(), &mut output_len)
        };
        assert_eq!(null_output, CachekitError::NullPointer);

        let null_len = unsafe {
            cachekit_compress(input.as_ptr(), 4, output.as_mut_ptr(), std::ptr::null_mut())
        };
        assert_eq!(null_len, CachekitError::NullPointer);
    }

    /// An undersized output buffer yields `BufferTooSmall` and reports the
    /// size that would have been needed.
    #[test]
    fn test_buffer_too_small() {
        let payload = b"Hello, World!";
        let mut tiny = vec![0u8; 10];

        let (status, needed) = run_compress(payload, &mut tiny);
        assert_eq!(status, CachekitError::BufferTooSmall);
        assert!(needed > 10);
    }

    /// Flipping a byte in the middle of the compressed data makes
    /// decompression fail.
    #[test]
    fn test_checksum_mismatch() {
        let payload = b"Hello, World!";

        let mut packed = vec![0u8; cachekit_compressed_bound(payload.len())];
        let (status, packed_len) = run_compress(payload, &mut packed);
        assert_eq!(status, CachekitError::Ok);

        if packed_len > 10 {
            packed[packed_len / 2] ^= 0xFF;
        }

        let mut unpacked = vec![0u8; payload.len() + 1000];
        let (status, _) = run_decompress(&packed[..packed_len], &mut unpacked);
        assert!(status != CachekitError::Ok);
    }

    /// The bound is strictly larger than the input but stays below
    /// `2 * size + 200` for the sampled sizes.
    #[test]
    fn test_compressed_bound_adequate() {
        for size in [0usize, 1, 10, 100, 1000, 10000] {
            let bound = cachekit_compressed_bound(size);
            assert!(bound > size);
            assert!(bound < size * 2 + 200);
        }
    }

    /// A zero-length input compresses to a non-empty frame and decompresses
    /// back to zero bytes.
    #[test]
    fn test_empty_data_roundtrip() {
        let payload: &[u8] = b"";
        assert_eq!(payload.len(), 0);

        let mut packed = vec![0u8; cachekit_compressed_bound(payload.len())];
        let (status, packed_len) = run_compress(payload, &mut packed);
        assert_eq!(status, CachekitError::Ok);
        assert!(packed_len > 0);

        let mut unpacked = vec![0u8; 100];
        let (status, unpacked_len) = run_decompress(&packed[..packed_len], &mut unpacked);
        assert_eq!(status, CachekitError::Ok);
        assert_eq!(unpacked_len, 0);
    }
}