use std::io::Read;
use std::ptr::null_mut;
use log::info;
use zstd::zstd_safe::compress_bound;
use crate::lol::io::bytes;
use crate::lol::io::bytes::MMap;
use crate::lol::io::bytes::Impl;
use crate::{lol_trace_func, lol_throw_if};
/// Byte buffer that carries a raw data pointer, a logical size, and the
/// backing storage (`Impl`) the write methods operate on.
#[derive(Debug, Clone)]
pub struct Buffer {
/// Raw pointer handed out by `data()`. `Default` initialises it to null.
/// NOTE(review): nothing in this part of the file assigns it after
/// construction — confirm who is expected to keep it in sync with `impl_`.
pub data_: *mut u8,
/// Logical size in bytes, as reported by `size()`.
pub size_: usize,
/// Backing storage: a reference count, an owned byte vector and a memory-map record.
pub impl_: Impl
}
impl Default for Buffer {
    /// Builds an empty buffer: null data pointer, zero size, an empty byte
    /// vector, a zero reference count and a memory-map record whose pointers
    /// are null and whose size is zero.
    fn default() -> Self {
        Self {
            data_: null_mut(),
            size_: 0,
            impl_: Impl {
                ref_count: 0,
                vec: Vec::new(),
                mmp: MMap {
                    file: null_mut(),
                    data: null_mut(),
                    size: 0,
                },
            },
        }
    }
}
impl Buffer {
/// Number of bytes in one kibibyte (1024).
pub const KIB: usize = 1024;
/// Number of bytes in one mebibyte (1024 KiB).
pub const MIB: usize = 1024 * Self::KIB;
/// Number of bytes in one gibibyte (1024 MiB).
pub const GIB: usize = 1024 * Self::MIB;
/// Returns the raw `data_` pointer as stored in the struct.
///
/// The pointer is returned unchanged; it is null for a default-constructed
/// buffer. NOTE(review): the write methods in this file fill `impl_.vec` and
/// do not update `data_` — verify the pointer is valid before dereferencing.
pub fn data(&mut self) -> *mut u8 {
self.data_
}
/// Returns the logical size of the buffer in bytes (the `size_` field).
pub fn size(&self) -> usize {
self.size_
}
pub fn write_decompress_zstd(&mut self, pos: usize, count: usize, src: &[u8], src_count: usize) -> Result<(), std::io::Error> {
let trace = lol_trace_func!(write_decompress_zstd, lol_trace_var!("{:#x}", size_),
lol_trace_var!("{:#x}", pos),
lol_trace_var!("{:#x}", count),
lol_trace_var!("{:p}", src),
lol_trace_var!("{:#x}", src_count));
let maxendpos = pos +count;
lol_throw_if!(maxendpos< pos, trace);
lol_throw_if!(self.impl_reserve(maxendpos), trace);
let mut de = zstd::Decoder::new(&src[..src_count])
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to create zstd decoder: {}", e)))?;
de.read_exact(&mut self.impl_.vec)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to decompress zstd data: {}", e)))?;
self.size_ = self.size_.max(pos +count);
Ok(())
}
pub fn impl_reserve(&mut self, size: usize) -> bool {
self.impl_.vec.reserve(size);
self.impl_.vec.resize(size, 0u8);
if self.impl_.vec.len() == size {
false
} else {
true
}
}
pub fn write_decompress_zstd_hack(&mut self, mut pos: usize, mut count: usize, src: &[u8], mut src_count: usize) -> Result<(), std::io::Error> {
let trace = lol_trace_func!(write_decompress_zstd_hack, lol_trace_var!("{:#x}", self.size_),
lol_trace_var!("{:#x}", pos),
lol_trace_var!("{:#x}", count),
lol_trace_var!("{:p}", src),
lol_trace_var!("{:#x}", src_count));
let maxendpos = pos + count;
lol_throw_if!(maxendpos < pos, trace);
lol_throw_if!(self.impl_reserve(maxendpos), trace);
let farst_frame_start = Self::find_zstd_magic(&src);
println!("ZstdMulti in write_decompress_zstd_hack frame_start {:x}", farst_frame_start.unwrap());
if farst_frame_start.is_some() {
let mut i = 0;
while let Some(next_farme_start) = Self::find_zstd_magic(&src[i..]) {
let mut pos = pos;
let mut count = count;
let mut src_count = src_count;
pos += next_farme_start+i;
count -= next_farme_start+i;
src_count -= next_farme_start+i;
let _= self.write_decompress_zstd(pos, count, &src[next_farme_start+i..], src_count).map_err(|e| {
info!("{}", e);
});
if count == 0 && src_count == 0 {
return Ok(());
}
i += pos +4;
if i >= src.len() {
break;
}
};
pos += farst_frame_start.unwrap();
count -= farst_frame_start.unwrap();
src_count -= farst_frame_start.unwrap();
}
self.write_decompress_zstd(pos, count, &src[farst_frame_start.unwrap()..], src_count).unwrap();
Ok(())
}
/// Collects the offsets of every zstd frame magic (`28 B5 2F FD`) in `src`.
///
/// Offsets are returned in ascending order. A match at offset 0 is recorded
/// as `0xffffffff` instead of `0`; every other match is recorded as its
/// byte offset. An input shorter than four bytes yields an empty vector.
pub fn find_zstd_magic_offset_vec(src: &[u8]) -> Vec<usize> {
    const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    src.windows(ZSTD_MAGIC.len())
        .enumerate()
        .filter_map(|(offset, chunk)| {
            if chunk != ZSTD_MAGIC {
                return None;
            }
            // Offset zero is reported through the marker value.
            Some(if offset == 0 { 0xffffffff } else { offset })
        })
        .collect()
}
/// Returns the offset of the first zstd frame magic (`28 B5 2F FD`) in `src`.
///
/// Returns `None` when the magic does not occur, which includes inputs
/// shorter than four bytes and inputs that contain only skippable-frame
/// magics (`5x 2A 4D 18`).
pub fn find_zstd_magic(src: &[u8]) -> Option<usize> {
    const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    // `windows` yields nothing for inputs shorter than the magic, so no
    // separate length check is needed.
    src.windows(ZSTD_MAGIC.len())
        .position(|window| window == ZSTD_MAGIC)
}
/// Compresses `src[..src_count]` with zstd and stores the result at `pos`.
///
/// `level` is clamped to `-7..=22`; a clamped level is reported via `info!`.
/// On success the backing vector holds its previous first `pos` bytes
/// (zero-filled if it was shorter) followed by the compressed data, and
/// `size_` is raised to at least `pos + compressed_len`.
///
/// # Errors
/// * `InvalidInput` if `src_count` is larger than `src.len()`.
/// * `InvalidData` if compression fails.
///
/// The overflow and reservation guards are reported through `lol_throw_if!`.
pub fn write_compress_zstd(&mut self, pos: usize, src: &[u8], src_count: usize, level: i32) -> Result<(), std::io::Error> {
    let trace = lol_trace_func!(write_compress_zstd, lol_trace_var!("{:#x}", self.size_),
        lol_trace_var!("{:#x}", pos),
        lol_trace_var!("{:p}", src),
        lol_trace_var!("{:#x}", src_count));
    if src_count > src.len() {
        return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput,
            format!("src_count ({}) exceeds src length ({})", src_count, src.len())));
    }
    let safe_level = level.clamp(-7, 22);
    if safe_level != level {
        info!("Warning: zstd compression level {} clamped to {}", level, safe_level);
    }
    let actual_src = &src[..src_count];
    let bound = compress_bound(actual_src.len());
    // Wrapping add keeps the `maxendpos < pos` guard meaningful: a plain `+`
    // would panic in debug builds before the guard could run.
    let maxendpos = pos.wrapping_add(bound);
    lol_throw_if!(maxendpos < pos, trace);
    lol_throw_if!(self.impl_reserve(maxendpos), trace);
    let result = zstd::encode_all(actual_src, safe_level)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to compress zstd data: {}", e)))?;
    // Keep the bytes in front of `pos` and append the compressed data there,
    // so the vector length and `size_` agree. For `pos == 0` the vector ends
    // up holding exactly the compressed bytes.
    self.impl_.vec.resize(pos, 0u8);
    self.impl_.vec.extend_from_slice(&result);
    self.size_ = self.size_.max(self.impl_.vec.len());
    Ok(())
}
/// Compresses `src` into a fresh `bytes::Bytes` by delegating to
/// `copy_compress_zstd_compatible` with compression level `0`.
pub fn copy_compress_zstd(&mut self, src: &[u8]) -> bytes::Bytes {
    // Level 0 lies inside the 0..=22 clamp applied downstream, so it is
    // forwarded to the encoder unchanged.
    let level: i32 = 0;
    self.copy_compress_zstd_compatible(src, level)
}
/// Compresses all of `src` at `level` into a newly created `bytes::Bytes`.
///
/// The receiver itself is not read or modified. Compression is best effort:
/// a failure is logged with `info!` and the `Bytes` value is returned in
/// whatever state the failed write left it.
pub fn copy_compress_zstd_compatible(&mut self, src: &[u8], level: i32) -> bytes::Bytes {
    let mut compressed = bytes::Bytes::bytes();
    match compressed.0.write_compress_zstd_compatible(0, src, src.len(), level) {
        Ok(()) => {}
        Err(e) => {
            info!("Failed to compress zstd data: {}", e);
        }
    }
    compressed
}
/// Compresses `src[..src_count]` with the streaming zstd encoder and stores
/// the result at `pos`.
///
/// The encoder is told the exact source size up front via
/// `set_pledged_src_size` (presumably so the frame records its content
/// size — confirm against the zstd documentation). `level` is clamped to
/// `0..=22`. On success the backing vector holds its previous first `pos`
/// bytes (zero-filled if it was shorter) followed by the compressed data,
/// and `size_` is raised to at least `pos + compressed_len`.
///
/// # Errors
/// * `InvalidInput` if `src_count` is larger than `src.len()`.
/// * `InvalidData` if any encoder step fails.
///
/// The overflow and reservation guards are reported through `lol_throw_if!`.
pub fn write_compress_zstd_compatible(&mut self, pos: usize, src: &[u8], src_count: usize, level: i32) -> Result<(), std::io::Error> {
    if src_count > src.len() {
        return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput,
            format!("src_count ({}) exceeds src length ({})", src_count, src.len())));
    }
    let safe_level = level.clamp(0, 22);
    let actual_src = &src[..src_count];
    let mut encoder = zstd::stream::Encoder::new(Vec::new(), safe_level)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to create zstd encoder: {}", e)))?;
    encoder.set_pledged_src_size(Some(actual_src.len() as u64))
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to set src size: {}", e)))?;
    std::io::Write::write_all(&mut encoder, actual_src)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write data: {}", e)))?;
    let result = encoder.finish()
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to finish compression: {}", e)))?;
    let bound = compress_bound(actual_src.len());
    // Wrapping add keeps the `maxendpos < pos` guard meaningful: a plain `+`
    // would panic in debug builds before the guard could run.
    let maxendpos = pos.wrapping_add(bound);
    let trace_compat = lol_trace_func!(write_compress_zstd_compatible, lol_trace_var!("{:#x}", pos));
    lol_throw_if!(maxendpos < pos, trace_compat);
    lol_throw_if!(self.impl_reserve(maxendpos), trace_compat);
    // Keep the bytes in front of `pos` and append the compressed data there,
    // so the vector length and `size_` agree. For `pos == 0` the vector ends
    // up holding exactly the compressed bytes.
    self.impl_.vec.resize(pos, 0u8);
    self.impl_.vec.extend_from_slice(&result);
    self.size_ = self.size_.max(self.impl_.vec.len());
    Ok(())
}
/// Compresses `src` with `zstd::encode_all` and stores the result at `pos`.
///
/// At most `src_count` bytes are compressed; a `src_count` larger than
/// `src.len()` is silently limited to the slice length. `level` is clamped
/// to `0..=22`. On success the backing vector holds its previous first
/// `pos` bytes (zero-filled if it was shorter) followed by the compressed
/// data, and `size_` is raised to at least `pos + compressed_len`.
///
/// # Errors
/// `InvalidData` if compression fails.
///
/// The overflow and reservation guards are reported through `lol_throw_if!`.
pub fn write_compress_zstd_no_dict(&mut self, pos: usize, src: &[u8], src_count: usize, level: i32) -> Result<(), std::io::Error> {
    let safe_level = level.clamp(0, 22);
    let actual_src = &src[..src_count.min(src.len())];
    let result = zstd::encode_all(actual_src, safe_level)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to compress: {}", e)))?;
    let bound = compress_bound(actual_src.len());
    // Wrapping add keeps the `maxendpos < pos` guard meaningful: a plain `+`
    // would panic in debug builds before the guard could run.
    let maxendpos = pos.wrapping_add(bound);
    let trace_no_dict = lol_trace_func!(write_compress_zstd_no_dict, lol_trace_var!("{:#x}", pos));
    lol_throw_if!(maxendpos < pos, trace_no_dict);
    lol_throw_if!(self.impl_reserve(maxendpos), trace_no_dict);
    // Keep the bytes in front of `pos` and append the compressed data there,
    // so the vector length and `size_` agree. For `pos == 0` the vector ends
    // up holding exactly the compressed bytes.
    self.impl_.vec.resize(pos, 0u8);
    self.impl_.vec.extend_from_slice(&result);
    self.size_ = self.size_.max(self.impl_.vec.len());
    Ok(())
}
/// Compresses all of `src` at level `0` into a newly created `bytes::Bytes`
/// using `write_compress_zstd_no_dict`.
///
/// The receiver itself is not read or modified. Compression is best effort:
/// a failure is logged with `info!` and the `Bytes` value is returned in
/// whatever state the failed write left it.
pub fn copy_compress_zstd_simple(&mut self, src: &[u8]) -> bytes::Bytes {
    let mut compressed = bytes::Bytes::bytes();
    match compressed.0.write_compress_zstd_no_dict(0, src, src.len(), 0) {
        Ok(()) => {}
        Err(e) => {
            info!("Failed to compress zstd data (simple): {}", e);
        }
    }
    compressed
}
}