#[allow(dead_code)]
mod support;
use std::alloc::{GlobalAlloc, Layout, System};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use structured_zstd::decoding::{DictionaryHandle, FrameDecoder};
use structured_zstd::encoding::FrameCompressor;
use support::{
LevelConfig, Scenario, ScenarioClass, benchmark_scenarios, build_training_samples,
dictionary_size_for, kernel_report_line, ldm_parameters, supported_levels_filtered,
};
/// Zero-sized allocator type; its `GlobalAlloc` impl forwards to `System` and keeps the byte
/// counters below up to date.
struct TrackingAllocator;
/// Sum of the requested sizes of all live allocations that were stamped as counted.
static ALLOC_CURRENT: AtomicUsize = AtomicUsize::new(0);
/// High-water mark of `ALLOC_CURRENT`; reset to the baseline at the start of each `measure_peak`.
static ALLOC_PEAK: AtomicUsize = AtomicUsize::new(0);
/// Value of `ALLOC_CURRENT` captured when the most recent `measure_peak` call started.
static ALLOC_BASELINE: AtomicUsize = AtomicUsize::new(0);
/// While `true`, newly created allocations are stamped as counted and added to the counters.
static TRACKING_ENABLED: AtomicBool = AtomicBool::new(false);
/// Minimum number of bytes reserved in front of every allocation for the tracking header.
const HEADER_BYTES: usize = 16;
/// Header flag: the allocation was created while tracking was off and must not touch the counters.
const FLAG_UNCOUNTED: u8 = 0;
/// Header flag: the allocation was added to `ALLOC_CURRENT` and must be subtracted when released.
const FLAG_COUNTED: u8 = 1;
#[inline]
fn tracker_header(layout: Layout) -> usize {
layout.align().max(HEADER_BYTES)
}
/// Builds the layout actually requested from `System` for a caller-supplied `layout`.
///
/// The result is `layout.size()` plus the tracking header, aligned to the header width.
/// Returns `None` when the enlarged size overflows or is not a valid `Layout`.
#[inline]
fn augmented_layout(layout: Layout) -> Option<Layout> {
    let prefix = tracker_header(layout);
    layout
        .size()
        .checked_add(prefix)
        .and_then(|padded| Layout::from_size_align(padded, prefix).ok())
}
// SAFETY: every request is forwarded to `System` with a layout enlarged by
// `tracker_header(layout)` bytes. The pointer handed to callers sits exactly one header past the
// pointer `System` returned; the header is a multiple of the requested alignment, so the caller's
// alignment is preserved, and `dealloc`/`realloc` undo the same offset before talking to `System`.
unsafe impl GlobalAlloc for TrackingAllocator {
    /// Allocates `layout` plus a tracking header and stamps the header's first byte with
    /// `FLAG_COUNTED` or `FLAG_UNCOUNTED` depending on `TRACKING_ENABLED` at allocation time.
    ///
    /// Counted allocations add `layout.size()` to `ALLOC_CURRENT` and may raise `ALLOC_PEAK`.
    /// Returns null when the enlarged layout is invalid or `System` fails.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let Some(augmented) = augmented_layout(layout) else {
            return core::ptr::null_mut();
        };
        let header = tracker_header(layout);
        // SAFETY: `augmented` is a valid layout whose size is at least `HEADER_BYTES`, i.e. non-zero.
        let raw = unsafe { System.alloc(augmented) };
        if raw.is_null() {
            return raw;
        }
        let counted = TRACKING_ENABLED.load(Ordering::Relaxed);
        let flag = if counted { FLAG_COUNTED } else { FLAG_UNCOUNTED };
        // SAFETY: `raw` is non-null and points at `augmented.size()` >= 1 writable bytes.
        unsafe { raw.write(flag) };
        if counted {
            let prev = ALLOC_CURRENT.fetch_add(layout.size(), Ordering::Relaxed);
            ALLOC_PEAK.fetch_max(prev + layout.size(), Ordering::Relaxed);
        }
        // SAFETY: `header <= augmented.size()`, so the result stays inside the allocation.
        unsafe { raw.add(header) }
    }

    /// Releases an allocation created by this allocator.
    ///
    /// Whether `ALLOC_CURRENT` is decremented depends on the flag stamped at allocation time, not
    /// on the current value of `TRACKING_ENABLED`, so a counted block freed after tracking stopped
    /// is still subtracted.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        let header = tracker_header(layout);
        // SAFETY: `ptr` came from `alloc`/`realloc` with this `layout`, so the header precedes it.
        let raw = unsafe { ptr.sub(header) };
        // SAFETY: the first header byte was initialised by `alloc` and preserved by `realloc`.
        let counted = unsafe { raw.read() } == FLAG_COUNTED;
        if counted {
            ALLOC_CURRENT.fetch_sub(layout.size(), Ordering::Relaxed);
        }
        let augmented = Layout::from_size_align(layout.size() + header, header)
            .expect("layout round-trips on dealloc");
        // SAFETY: `raw` and `augmented` are exactly what was passed to / returned by `System`.
        unsafe { System.dealloc(raw, augmented) };
    }

    /// Resizes an allocation in place or by moving it, keeping its counted/uncounted flag.
    ///
    /// Counted allocations adjust `ALLOC_CURRENT` by the size difference and may raise
    /// `ALLOC_PEAK` when growing. Returns null, leaving the original block untouched, when the
    /// enlarged request is not a valid layout or `System` fails.
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let header = tracker_header(layout);
        let Some(new_total) = new_size.checked_add(header) else {
            return core::ptr::null_mut();
        };
        // `new_size` is only guaranteed to be a legal size for the caller's alignment. Once the
        // header is added it may exceed what `System.realloc` is allowed to receive, so validate
        // the enlarged request the same way `alloc` does instead of forwarding it blindly.
        if Layout::from_size_align(new_total, header).is_err() {
            return core::ptr::null_mut();
        }
        // SAFETY: `ptr` came from `alloc`/`realloc` with this `layout`, so the header precedes it.
        let raw = unsafe { ptr.sub(header) };
        // SAFETY: the first header byte was initialised by `alloc`.
        let counted = unsafe { raw.read() } == FLAG_COUNTED;
        let old_augmented = Layout::from_size_align(layout.size() + header, header)
            .expect("layout round-trips on realloc");
        // SAFETY: `raw`/`old_augmented` match the live `System` allocation and `new_total` was
        // validated above for the same alignment.
        let new_raw = unsafe { System.realloc(raw, old_augmented, new_total) };
        if new_raw.is_null() {
            return core::ptr::null_mut();
        }
        if counted {
            if new_size >= layout.size() {
                let delta = new_size - layout.size();
                let prev = ALLOC_CURRENT.fetch_add(delta, Ordering::Relaxed);
                ALLOC_PEAK.fetch_max(prev + delta, Ordering::Relaxed);
            } else {
                let delta = layout.size() - new_size;
                ALLOC_CURRENT.fetch_sub(delta, Ordering::Relaxed);
            }
        }
        // SAFETY: `header <= new_total`, so the result stays inside the new allocation.
        unsafe { new_raw.add(header) }
    }
}
/// Registers `TrackingAllocator` as this binary's global allocator, so Rust heap allocations go
/// through its `alloc`/`dealloc`/`realloc` bookkeeping.
#[global_allocator]
static GLOBAL: TrackingAllocator = TrackingAllocator;
/// Runs `f` with allocation tracking switched on and reports how far the counted heap usage rose.
///
/// Returns `f`'s result together with `ALLOC_PEAK - baseline`, where the baseline is the value of
/// `ALLOC_CURRENT` just before tracking was enabled (saturating at zero). Tracking is switched
/// off again when the function returns or unwinds.
fn measure_peak<R>(f: impl FnOnce() -> R) -> (R, usize) {
    /// Drop guard that turns tracking back off, including when `f` panics.
    struct DisableTracking;

    impl Drop for DisableTracking {
        fn drop(&mut self) {
            TRACKING_ENABLED.store(false, Ordering::Relaxed);
        }
    }

    // Start both the recorded baseline and the high-water mark from the current usage.
    let floor = ALLOC_CURRENT.load(Ordering::Relaxed);
    for counter in [&ALLOC_BASELINE, &ALLOC_PEAK] {
        counter.store(floor, Ordering::Relaxed);
    }

    TRACKING_ENABLED.store(true, Ordering::Relaxed);
    let _disable_on_exit = DisableTracking;

    let outcome = f();
    // Read the peak while `outcome` is still alive; only the high-water mark matters here.
    let growth = ALLOC_PEAK.load(Ordering::Relaxed).saturating_sub(floor);
    (outcome, growth)
}
/// Bytes reserved in front of every block handed to the C library.
const FFI_HEADER: usize = 16;
/// Alignment requested from `System` for C-library blocks.
const FFI_ALIGN: usize = 16;
/// Offset inside the header of the counted/uncounted flag byte (directly after the stored size).
const FFI_FLAG_OFFSET: usize = core::mem::size_of::<usize>();
/// Flag value: the block was allocated while tracking was off.
const FFI_FLAG_UNCOUNTED: u8 = 0;
/// Flag value: the block was added to `ALLOC_CURRENT` and must be subtracted when freed.
const FFI_FLAG_COUNTED: u8 = 1;
// The header must have room for the size word and the flag byte.
const _: () = assert!(FFI_FLAG_OFFSET < FFI_HEADER);

/// `customAlloc` callback for zstd: allocates `size` bytes from `System` behind a small header.
///
/// The header stores the requested `size` and a flag recording whether the block was counted,
/// i.e. whether `TRACKING_ENABLED` was set at allocation time. Counted blocks add `size` to
/// `ALLOC_CURRENT` and may raise `ALLOC_PEAK`. Returns null on overflow or allocation failure.
///
/// # Safety
///
/// The returned pointer must only be released through [`ffi_free`].
unsafe extern "C" fn ffi_alloc(
    _opaque: *mut core::ffi::c_void,
    size: usize,
) -> *mut core::ffi::c_void {
    let Some(total) = size.checked_add(FFI_HEADER) else {
        return core::ptr::null_mut();
    };
    let Ok(layout) = Layout::from_size_align(total, FFI_ALIGN) else {
        return core::ptr::null_mut();
    };
    // SAFETY: `layout` is valid and its size is at least `FFI_HEADER`, i.e. non-zero.
    let raw = unsafe { System.alloc(layout) };
    if raw.is_null() {
        return core::ptr::null_mut();
    }
    let counted = TRACKING_ENABLED.load(Ordering::Relaxed);
    let flag = if counted {
        FFI_FLAG_COUNTED
    } else {
        FFI_FLAG_UNCOUNTED
    };
    // SAFETY: `raw` is aligned to `FFI_ALIGN` and points at >= `FFI_HEADER` writable bytes, which
    // covers both the size word at offset 0 and the flag byte at `FFI_FLAG_OFFSET`.
    unsafe {
        raw.cast::<usize>().write(size);
        raw.add(FFI_FLAG_OFFSET).write(flag);
    }
    if counted {
        let prev = ALLOC_CURRENT.fetch_add(size, Ordering::Relaxed);
        ALLOC_PEAK.fetch_max(prev + size, Ordering::Relaxed);
    }
    // SAFETY: `FFI_HEADER <= total`, so the result stays inside the allocation.
    unsafe { raw.add(FFI_HEADER) as *mut core::ffi::c_void }
}

/// `customFree` callback for zstd: releases a block obtained from [`ffi_alloc`].
///
/// A null `address` is ignored. `ALLOC_CURRENT` is decremented exactly when the block was counted
/// at allocation time. The decision comes from the flag in the block's header rather than from
/// the current `TRACKING_ENABLED` value, so an alloc/free pair that straddles a measurement
/// boundary cannot wrap the counter below zero or leave counted bytes behind.
///
/// # Safety
///
/// `address` must be null or a pointer previously returned by [`ffi_alloc`] and not yet freed.
unsafe extern "C" fn ffi_free(_opaque: *mut core::ffi::c_void, address: *mut core::ffi::c_void) {
    if address.is_null() {
        return;
    }
    // SAFETY: `address` was produced by `ffi_alloc`, so the header sits directly in front of it.
    let header_ptr = unsafe { (address as *mut u8).sub(FFI_HEADER) };
    // SAFETY: both header fields were initialised by `ffi_alloc`.
    let (size, counted) = unsafe {
        (
            header_ptr.cast::<usize>().read(),
            header_ptr.add(FFI_FLAG_OFFSET).read() == FFI_FLAG_COUNTED,
        )
    };
    let layout = Layout::from_size_align(size + FFI_HEADER, FFI_ALIGN)
        .expect("layout round-trips from ffi_alloc");
    if counted {
        ALLOC_CURRENT.fetch_sub(size, Ordering::Relaxed);
    }
    // SAFETY: `header_ptr`/`layout` are exactly what `ffi_alloc` obtained from `System`.
    unsafe { System.dealloc(header_ptr, layout) };
}
/// Builds the `ZSTD_customMem` descriptor that makes the C library allocate through
/// `ffi_alloc` and release through `ffi_free`, with no opaque state attached.
fn ffi_custom_mem() -> zstd::zstd_safe::zstd_sys::ZSTD_customMem {
    use zstd::zstd_safe::zstd_sys::ZSTD_customMem;

    // The callbacks ignore their opaque argument, so no state pointer is supplied.
    let no_state = core::ptr::null_mut();
    ZSTD_customMem {
        opaque: no_state,
        customFree: Some(ffi_free),
        customAlloc: Some(ffi_alloc),
    }
}
/// Compresses `input` into a single frame with the C zstd library, using a compression context
/// whose allocations go through `ffi_custom_mem()`.
///
/// * `level` is passed as `ZSTD_c_compressionLevel`.
/// * `ldm` enables `ZSTD_c_enableLongDistanceMatching` when `true`.
/// * `dict`, when present, is loaded into the context before compression.
///
/// The content-size flag is always set, the checksum flag follows the `hash` cargo feature, and
/// inputs of at most 16 KiB pin `ZSTD_c_windowLog` to 14. The input is fed in pieces of
/// `ZSTD_CStreamInSize()` bytes and drained through a `ZSTD_CStreamOutSize()`-byte scratch buffer.
///
/// # Panics
///
/// Panics if the context cannot be created or any zstd call reports an error.
fn ffi_encode(input: &[u8], level: i32, ldm: bool, dict: Option<&[u8]>) -> Vec<u8> {
    use zstd::zstd_safe::zstd_sys;
    use zstd::zstd_safe::zstd_sys::{ZSTD_EndDirective, ZSTD_cParameter};

    // SAFETY: `ffi_custom_mem()` supplies a matching alloc/free pair.
    let cctx = unsafe { zstd_sys::ZSTD_createCCtx_advanced(ffi_custom_mem()) };
    assert!(!cctx.is_null(), "ZSTD_createCCtx_advanced returned null");

    // Sets one compression parameter and aborts on any zstd error.
    // SAFETY (closure body): `cctx` is the non-null context created above and is still alive.
    let set_parameter = |param: ZSTD_cParameter, value: i32| unsafe {
        let rc = zstd_sys::ZSTD_CCtx_setParameter(cctx, param, value);
        assert!(zstd_sys::ZSTD_isError(rc) == 0);
    };

    set_parameter(ZSTD_cParameter::ZSTD_c_compressionLevel, level);
    set_parameter(
        ZSTD_cParameter::ZSTD_c_checksumFlag,
        i32::from(cfg!(feature = "hash")),
    );
    if ldm {
        set_parameter(ZSTD_cParameter::ZSTD_c_enableLongDistanceMatching, 1);
    }
    set_parameter(ZSTD_cParameter::ZSTD_c_contentSizeFlag, 1);
    if input.len() <= (1 << 14) {
        set_parameter(ZSTD_cParameter::ZSTD_c_windowLog, 14);
    }

    if let Some(dict) = dict {
        // SAFETY: `dict` is a live slice; pointer and length describe exactly its bytes.
        let rc = unsafe {
            zstd_sys::ZSTD_CCtx_loadDictionary(cctx, dict.as_ptr().cast(), dict.len())
        };
        // SAFETY: `ZSTD_isError` only inspects the integer code.
        assert!(
            unsafe { zstd_sys::ZSTD_isError(rc) } == 0,
            "CCtx_loadDictionary failed"
        );
    }

    // SAFETY: `cctx` is valid; the pledged size is the exact number of bytes fed below.
    unsafe {
        let rc = zstd_sys::ZSTD_CCtx_setPledgedSrcSize(cctx, input.len() as u64);
        assert!(zstd_sys::ZSTD_isError(rc) == 0);
    }

    // SAFETY: both calls take no arguments and only return size hints.
    let (step_in, step_out) = unsafe {
        (
            zstd_sys::ZSTD_CStreamInSize(),
            zstd_sys::ZSTD_CStreamOutSize(),
        )
    };

    // Deliberately not pre-sized: the allocation pattern here is part of what callers measure.
    let mut output: Vec<u8> = Vec::new();
    let mut scratch = vec![0u8; step_out];
    let mut consumed: usize = 0;

    loop {
        // Expose the input up to the end of the current piece; zstd resumes at `consumed`.
        let window_end = input.len().min(consumed + step_in);
        let is_last = window_end == input.len();
        let mode = if is_last {
            ZSTD_EndDirective::ZSTD_e_end
        } else {
            ZSTD_EndDirective::ZSTD_e_continue
        };
        let mut zin = zstd_sys::ZSTD_inBuffer {
            src: input.as_ptr().cast(),
            size: window_end,
            pos: consumed,
        };

        loop {
            let mut zout = zstd_sys::ZSTD_outBuffer {
                dst: scratch.as_mut_ptr().cast(),
                size: scratch.len(),
                pos: 0,
            };
            // SAFETY: `zin` covers `input[..window_end]` and `zout` covers all of `scratch`; both
            // buffers outlive the call.
            let remaining =
                unsafe { zstd_sys::ZSTD_compressStream2(cctx, &mut zout, &mut zin, mode) };
            // SAFETY: `ZSTD_isError` only inspects the integer code.
            unsafe {
                assert!(zstd_sys::ZSTD_isError(remaining) == 0);
            }
            output.extend_from_slice(&scratch[..zout.pos]);

            // The final piece is done once zstd has flushed everything; earlier pieces are done
            // once their input has been fully consumed.
            let finished = if is_last {
                remaining == 0
            } else {
                zin.pos == zin.size
            };
            if finished {
                break;
            }
        }

        consumed = zin.pos;
        if is_last && consumed == input.len() {
            break;
        }
    }

    // SAFETY: `cctx` was created above and is not used afterwards.
    unsafe { zstd_sys::ZSTD_freeCCtx(cctx) };
    output
}
/// Decompresses `compressed` with the C zstd library into a buffer of `expected_len` bytes,
/// using a decompression context whose allocations go through `ffi_custom_mem()`.
///
/// When `dict` is present it is loaded into the context first.
///
/// # Panics
///
/// Panics if the context cannot be created, if zstd reports an error, or if the number of bytes
/// written differs from `expected_len`.
fn ffi_decode(compressed: &[u8], expected_len: usize, dict: Option<&[u8]>) -> Vec<u8> {
    use zstd::zstd_safe::zstd_sys;

    // SAFETY: `ffi_custom_mem()` supplies a matching alloc/free pair.
    let dctx = unsafe { zstd_sys::ZSTD_createDCtx_advanced(ffi_custom_mem()) };
    assert!(!dctx.is_null(), "ZSTD_createDCtx_advanced returned null");

    if let Some(dict) = dict {
        // SAFETY: `dctx` is non-null; pointer and length describe exactly the `dict` slice.
        let rc = unsafe {
            zstd_sys::ZSTD_DCtx_loadDictionary(dctx, dict.as_ptr().cast(), dict.len())
        };
        // SAFETY: `ZSTD_isError` only inspects the integer code.
        assert!(
            unsafe { zstd_sys::ZSTD_isError(rc) } == 0,
            "DCtx_loadDictionary failed"
        );
    }

    let mut output = vec![0u8; expected_len];
    // SAFETY: destination and source pointers each come from a live buffer and are paired with
    // that buffer's own length.
    let written = unsafe {
        zstd_sys::ZSTD_decompressDCtx(
            dctx,
            output.as_mut_ptr().cast(),
            output.len(),
            compressed.as_ptr().cast(),
            compressed.len(),
        )
    };
    // SAFETY: `ZSTD_isError` only inspects the integer code.
    unsafe {
        assert!(zstd_sys::ZSTD_isError(written) == 0);
    }
    assert_eq!(
        written, expected_len,
        "ffi_decode wrote {written} bytes, expected {expected_len}",
    );
    output.truncate(written);

    // SAFETY: `dctx` was created above and is not used afterwards.
    unsafe { zstd_sys::ZSTD_freeDCtx(dctx) };
    output
}
/// Escapes `label` for embedding inside a double-quoted report field: every backslash and every
/// double quote is prefixed with a backslash; all other characters are copied unchanged.
fn escape_report_label(label: &str) -> String {
    let mut escaped = String::with_capacity(label.len());
    for ch in label.chars() {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
/// Prints one `REPORT_MEM` line to stdout for a scenario/level/stage combination.
///
/// The line carries the scenario id, the escaped scenario label, the level name, the stage name
/// and the peak allocation byte counts measured for the Rust and FFI implementations.
fn emit_report(
    scenario: &Scenario,
    level: LevelConfig,
    stage: &str,
    rust_peak: usize,
    ffi_peak: usize,
) {
    println!(
        "REPORT_MEM scenario={} label=\"{}\" level={} stage={} rust_peak_alloc_bytes={} ffi_peak_alloc_bytes={}",
        scenario.id,
        // The label is emitted inside double quotes, so quotes and backslashes are escaped.
        escape_report_label(&scenario.label),
        level.name,
        stage,
        rust_peak,
        ffi_peak
    );
}
/// Trains a zstd dictionary from samples derived from `source`.
///
/// The requested dictionary size is `dictionary_size_for(source.len())`, raised to at least 256
/// bytes and then capped at `source.len() - 64` (saturating at zero). Returns `None` when
/// `zstd::dict::from_samples` reports an error.
fn train_ffi_dictionary(source: &[u8]) -> Option<Vec<u8>> {
    let samples = build_training_samples(source);

    // The cap is applied last, so it wins over the 256-byte floor for very small sources.
    let ceiling = source.len().saturating_sub(64);
    let requested = dictionary_size_for(source.len()).max(256);
    let dict_size = requested.min(ceiling);

    let trained = zstd::dict::from_samples(&samples, dict_size);
    trained.ok()
}
/// Benchmark entry point: for every scenario and every supported level, measures the peak counted
/// heap usage of the Rust implementation and of the C library (via FFI) for compression and for
/// decompression of both resulting streams, printing one `REPORT_MEM` line per stage.
fn main() {
// Header line describing the kernel configuration, as produced by the support module.
println!("{}", kernel_report_line());
let scenarios = benchmark_scenarios();
for scenario in &scenarios {
for level in supported_levels_filtered() {
let ldm_params = ldm_parameters(&level);
let expected_len = scenario.len();
// Dictionary levels take a separate path and end with `continue`, so the plain path below
// only runs for non-dictionary levels.
if level.dict {
// Dictionary variants are only measured for the Small and Corpus scenario classes.
if !matches!(scenario.class, ScenarioClass::Small | ScenarioClass::Corpus) {
continue;
}
// Training happens outside `measure_peak`, so its allocations are not part of any report.
let Some(dict) = train_ffi_dictionary(&scenario.bytes) else {
eprintln!(
"BENCH_WARN skipping dict memory variant for {} (FastCOVER training failed)",
scenario.id
);
continue;
};
// Rust compression with the dictionary attached; everything allocated inside the closure,
// including the output vector, counts towards `rust_peak`.
let (rust_compressed, rust_peak) = measure_peak(|| {
let mut compressor: FrameCompressor = FrameCompressor::new(level.rust_level);
if let Some(params) = &ldm_params {
compressor.set_parameters(params);
}
// The checksum setting follows the `hash` cargo feature, matching `ffi_encode`.
compressor.set_content_checksum(cfg!(feature = "hash"));
compressor
.set_dictionary_from_bytes(&dict)
.expect("dictionary should attach");
let mut out = Vec::new();
compressor.compress_independent_frame_into(&scenario.bytes[..], &mut out);
out
});
// Same input, level and dictionary through the C library.
let (ffi_compressed, ffi_peak) = measure_peak(|| {
ffi_encode(&scenario.bytes[..], level.ffi_level, level.ldm, Some(&dict))
});
emit_report(scenario, level, "compress", rust_peak, ffi_peak);
// Decode both streams with both decoders; the stage name records which encoder produced
// the stream being decoded.
for (source, compressed) in [
("rust_stream", &rust_compressed),
("c_stream", &ffi_compressed),
] {
// Dictionary parsing happens inside the closure, so it is part of the Rust decode peak.
let (_, rust_decode_peak) = measure_peak(|| {
let handle = DictionaryHandle::decode_dict(&dict)
.expect("dictionary handle parse should succeed");
let mut target = vec![0u8; expected_len];
let mut decoder = FrameDecoder::new();
let written = decoder
.decode_all_with_dict_handle(
compressed.as_slice(),
&mut target,
&handle,
)
.expect("rust decode-with-dict should succeed");
assert_eq!(written, expected_len);
target
});
let (_, ffi_decode_peak) = measure_peak(|| {
ffi_decode(compressed.as_slice(), expected_len, Some(&dict))
});
emit_report(
scenario,
level,
&format!("decompress-{source}"),
rust_decode_peak,
ffi_decode_peak,
);
}
continue;
}
// Plain (non-dictionary) path: Rust compression first, then the C library.
let (rust_compressed, rust_peak) = measure_peak(|| {
let mut compressor: FrameCompressor = FrameCompressor::new(level.rust_level);
if let Some(params) = &ldm_params {
compressor.set_parameters(params);
}
compressor.set_content_checksum(cfg!(feature = "hash"));
compressor.compress_independent_frame(&scenario.bytes[..])
});
let (ffi_compressed, ffi_peak) =
measure_peak(|| ffi_encode(&scenario.bytes[..], level.ffi_level, level.ldm, None));
emit_report(scenario, level, "compress", rust_peak, ffi_peak);
// Decode both streams with both decoders, as in the dictionary path.
for (source, compressed) in [
("rust_stream", &rust_compressed),
("c_stream", &ffi_compressed),
] {
let (_, rust_decode_peak) = measure_peak(|| {
let mut target = vec![0u8; expected_len];
let mut decoder = FrameDecoder::new();
let written = decoder
.decode_all(compressed.as_slice(), &mut target)
.unwrap();
assert_eq!(written, expected_len);
target
});
let (_, ffi_decode_peak) =
measure_peak(|| ffi_decode(compressed.as_slice(), expected_len, None));
emit_report(
scenario,
level,
&format!("decompress-{source}"),
rust_decode_peak,
ffi_decode_peak,
);
}
}
}
}