rustlol 0.1.1

A library for WAD files
Documentation
use std::io::Read;
use std::ptr::null_mut;
use log::info;
use zstd::zstd_safe::compress_bound;

use crate::lol::io::bytes;
use crate::lol::io::bytes::MMap;
use crate::lol::io::bytes::Impl;
use crate::{lol_trace_func, lol_throw_if};

/// Byte buffer that pairs a raw data pointer and a logical size with its
/// backing storage.
///
/// NOTE(review): `Clone` is derived while `data_` is a raw pointer, so a clone
/// carries the same pointer value as the original — confirm that aliasing is
/// intended.
#[derive(Debug, Clone)]
pub struct Buffer {
        /// Raw pointer returned by `data()`; null in the `Default` value.
        pub data_: *mut u8,
        /// Logical size in bytes, as returned by `size()`.
        pub size_: usize,
        /// Backing storage (`Impl` from `crate::lol::io::bytes`).
        pub impl_: Impl
}


impl Default for Buffer {
    /// Builds an empty buffer: a null data pointer, zero size, and an `Impl`
    /// with a zero reference count, an empty vector and an `MMap` whose
    /// pointers are null.
    fn default() -> Self {
        let unmapped = MMap {
            file: null_mut(),
            data: null_mut(),
            size: 0,
        };
        let storage = Impl {
            ref_count: 0,
            vec: Vec::new(),
            mmp: unmapped,
        };
        Self {
            data_: null_mut(),
            size_: 0,
            impl_: storage,
        }
    }
}





impl Buffer {
    
    /// Number of bytes in one kibibyte.
    pub const KIB: usize = 1024; 
    /// Number of bytes in one mebibyte (1024 KiB).
    pub const MIB: usize = 1024 * Self::KIB;
    /// Number of bytes in one gibibyte (1024 MiB).
    pub const GIB: usize = 1024 * Self::MIB;

    /// Returns the raw pointer stored in `data_`.
    pub fn data(&mut self) -> *mut u8 {
        self.data_
    }

    /// Returns the logical size stored in `size_`.
    pub fn size(&self) -> usize {
        self.size_
    }


    /// Decompresses zstd data from `src` into this buffer at byte offset `pos`.
    ///
    /// The first `src_count` bytes of `src` are the compressed input, and
    /// exactly `count` decompressed bytes are stored in `pos..pos + count` of
    /// the backing vector. On success `size_` is raised to at least
    /// `pos + count`.
    ///
    /// # Errors
    /// * `InvalidInput` when `src_count` exceeds `src.len()`.
    /// * `InvalidData` when the decoder cannot be created, the input yields
    ///   fewer than `count` bytes, or the backing vector does not cover the
    ///   destination range.
    ///
    /// `lol_throw_if!` additionally rejects an overflowing `pos + count` and a
    /// failed reservation; how that is reported is defined by the macro.
    pub fn write_decompress_zstd(&mut self, pos: usize, count: usize, src: &[u8],  src_count: usize) -> Result<(), std::io::Error> {
        let trace = lol_trace_func!(write_decompress_zstd, lol_trace_var!("{:#x}", self.size_),
                   lol_trace_var!("{:#x}", pos),
                   lol_trace_var!("{:#x}", count),
                   lol_trace_var!("{:p}", src),
                   lol_trace_var!("{:#x}", src_count));

        // Checked slicing: a too-large `src_count` is a caller error, not a panic.
        let compressed = src.get(..src_count).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("src_count ({}) exceeds src length ({})", src_count, src.len()),
            )
        })?;

        // `wrapping_add` keeps the overflow guard below reachable; a plain `+`
        // would panic in debug builds before the guard could run.
        let maxendpos = pos.wrapping_add(count);
        lol_throw_if!(maxendpos < pos, trace);
        // `impl_reserve` resizes to exactly the value it is given, so never ask
        // for less than the current length or bytes past `maxendpos` are lost.
        let reserve_to = maxendpos.max(self.impl_.vec.len());
        lol_throw_if!(self.impl_reserve(reserve_to), trace);

        let mut de = zstd::Decoder::new(compressed)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to create zstd decoder: {}", e)))?;

        // Write only into the requested window instead of the whole vector, so
        // `pos` is honoured and exactly `count` bytes are required.
        let dst = self.impl_.vec.get_mut(pos..maxendpos).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("destination range {:#x}..{:#x} is outside the buffer", pos, maxendpos),
            )
        })?;
        de.read_exact(dst)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to decompress zstd data: {}", e)))?;

        self.size_ = self.size_.max(maxendpos);
        Ok(())
    }

    /// Resizes the backing vector to exactly `size` bytes, zero-filling any
    /// newly added tail.
    ///
    /// Returns `false` when the vector ends up `size` bytes long and `true`
    /// otherwise, which is the polarity callers pass to `lol_throw_if!`.
    ///
    /// NOTE(review): because this is a resize, a `size` smaller than the
    /// current length truncates the vector, and `data_` is not refreshed here
    /// after a reallocation — confirm both are intended.
    pub fn impl_reserve(&mut self, size: usize) -> bool {
        // `Vec::resize` grows the allocation itself. A preceding
        // `reserve(size)` would request `size` bytes *beyond* the current
        // length and over-allocate.
        self.impl_.vec.resize(size, 0u8);
        self.impl_.vec.len() != size
    }

    pub fn write_decompress_zstd_hack(&mut self, mut pos: usize, mut count: usize, src: &[u8],  mut src_count: usize) -> Result<(), std::io::Error> {
        let trace = lol_trace_func!(write_decompress_zstd_hack, lol_trace_var!("{:#x}", self.size_),
                   lol_trace_var!("{:#x}", pos),
                   lol_trace_var!("{:#x}", count),
                   lol_trace_var!("{:p}", src),
                   lol_trace_var!("{:#x}", src_count));
        let maxendpos = pos + count;
        lol_throw_if!(maxendpos < pos, trace);
        lol_throw_if!(self.impl_reserve(maxendpos), trace);        

        let farst_frame_start = Self::find_zstd_magic(&src);
        println!("ZstdMulti in  write_decompress_zstd_hack frame_start {:x}", farst_frame_start.unwrap());
        
        if farst_frame_start.is_some() {

            let mut i = 0;
             while let Some(next_farme_start) = Self::find_zstd_magic(&src[i..]) {
       
                let mut pos = pos;
                let mut count = count;
                let mut src_count = src_count;

                
                pos += next_farme_start+i;
                count -= next_farme_start+i;
                src_count -= next_farme_start+i;


                let _= self.write_decompress_zstd(pos, count, &src[next_farme_start+i..], src_count).map_err(|e| {
                    info!("{}", e);
                });
                    
                if count == 0 && src_count == 0 {
                    return Ok(());
                }
          

                i += pos +4;
                if i >= src.len() {
                
                   break;
                } 
            };

            pos += farst_frame_start.unwrap();
            count -= farst_frame_start.unwrap();
            src_count -= farst_frame_start.unwrap();
        }
        

 
        self.write_decompress_zstd(pos, count, &src[farst_frame_start.unwrap()..], src_count).unwrap();
        


        Ok(())
    }


    /// Collects the offsets of every occurrence of the zstd frame magic
    /// (`28 B5 2F FD`) in `src`, in ascending order.
    ///
    /// An occurrence at offset 0 is reported as `0xffffffff` rather than `0`.
    pub fn find_zstd_magic_offset_vec(src: &[u8]) -> Vec<usize> {
        const FRAME_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
        // Marker stored in place of offset 0.
        const AT_START_MARKER: usize = 0xffffffff;

        src.windows(4)
            .enumerate()
            .filter(|(_, window)| *window == FRAME_MAGIC)
            .map(|(offset, _)| if offset == 0 { AT_START_MARKER } else { offset })
            .collect()
    }


    /// Returns the offset of the first standard zstd frame magic
    /// (`28 B5 2F FD`) in `src`, or `None` when it does not occur.
    ///
    /// Skippable-frame magics (`5x 2A 4D 18`) are not reported: input that
    /// contains only those yields `None`.
    pub fn find_zstd_magic(src: &[u8]) -> Option<usize> {
        const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];

        // `windows(4)` yields nothing for inputs shorter than four bytes, so
        // short and empty inputs fall out as `None` without a separate check.
        // The former second pass over skippable-frame magics returned `None`
        // on every path, so dropping it does not change any result.
        src.windows(4).position(|w| w == ZSTD_MAGIC)
    }

    /// Compresses the first `src_count` bytes of `src` with zstd and stores the
    /// compressed bytes in this buffer starting at byte offset `pos`.
    ///
    /// `level` is clamped to `-7..=22`; a clamped value is logged. Bytes of the
    /// backing vector outside `pos..pos + compressed_len` are left untouched,
    /// and `size_` is raised to at least `pos + compressed_len`.
    ///
    /// # Errors
    /// * `InvalidInput` when `src_count` exceeds `src.len()`.
    /// * `InvalidData` when compression fails or the backing vector does not
    ///   cover the destination range.
    ///
    /// `lol_throw_if!` additionally rejects an overflowing end offset and a
    /// failed reservation; how that is reported is defined by the macro.
    pub fn write_compress_zstd(&mut self, pos: usize,  src: &[u8],  src_count: usize, level: i32) -> Result<(), std::io::Error> {
        let trace = lol_trace_func!(write_compress_zstd, lol_trace_var!("{:#x}", self.size_),
                    lol_trace_var!("{:#x}", pos),
                    lol_trace_var!("{:p}", src),
                    lol_trace_var!("{:#x}", src_count));

        // Validate and slice in one step; only the requested prefix is compressed.
        let actual_src = src.get(..src_count).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("src_count ({}) exceeds src length ({})", src_count, src.len()),
            )
        })?;

        // Keep the level inside the range this method accepts.
        let safe_level = level.clamp(-7, 22);
        if safe_level != level {
            info!("Warning: zstd compression level {} clamped to {}", level, safe_level);
        }

        // Reject offsets that cannot hold even the worst-case output before
        // doing the expensive work. `wrapping_add` keeps the guard reachable
        // in debug builds, where a plain `+` would panic first.
        let bound = compress_bound(actual_src.len());
        let worst_end = pos.wrapping_add(bound);
        lol_throw_if!(worst_end < pos, trace);

        let result = zstd::encode_all(actual_src, safe_level)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to compress zstd data: {}", e)))?;

        let endpos = pos.wrapping_add(result.len());
        lol_throw_if!(endpos < pos, trace);
        // `impl_reserve` resizes to exactly the value it is given; never shrink.
        let reserve_to = endpos.max(self.impl_.vec.len());
        lol_throw_if!(self.impl_reserve(reserve_to), trace);

        // Copy into the window at `pos` instead of replacing the whole vector,
        // so existing contents survive and `size_` stays within the vector.
        let dst = self.impl_.vec.get_mut(pos..endpos).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("destination range {:#x}..{:#x} is outside the buffer", pos, endpos),
            )
        })?;
        dst.copy_from_slice(&result);

        self.size_ = self.size_.max(endpos);
        Ok(())
    }
    /// Compresses `src` into a new `bytes::Bytes` by delegating to
    /// `copy_compress_zstd_compatible` with level `0`.
    pub fn copy_compress_zstd(&mut self, src: &[u8]) -> bytes::Bytes {
        // The original note describes this level as matching pyzstd's
        // default (level 3).
        let default_level: i32 = 0;
        self.copy_compress_zstd_compatible(src, default_level)
    }

    /// Compresses all of `src` at `level` into a new `bytes::Bytes` using
    /// `write_compress_zstd_compatible`.
    ///
    /// A compression failure is only logged; the `Bytes` value is returned in
    /// whatever state the failed call left it.
    pub fn copy_compress_zstd_compatible(&mut self, src: &[u8], level: i32) -> bytes::Bytes {
        let mut compressed = bytes::Bytes::bytes();
        let outcome = compressed
            .0
            .write_compress_zstd_compatible(0, src, src.len(), level);
        match outcome {
            Ok(()) => {}
            Err(err) => info!("Failed to compress zstd data: {}", err),
        }
        compressed
    }

    /// Compresses the first `src_count` bytes of `src` with a streaming zstd
    /// encoder that has the source size pledged up front, and stores the
    /// compressed bytes in this buffer starting at byte offset `pos`.
    ///
    /// `level` is clamped to `0..=22`. Bytes of the backing vector outside
    /// `pos..pos + compressed_len` are left untouched, and `size_` is raised to
    /// at least `pos + compressed_len`.
    ///
    /// # Errors
    /// * `InvalidInput` when `src_count` exceeds `src.len()`.
    /// * `InvalidData` when any encoder step fails or the backing vector does
    ///   not cover the destination range.
    ///
    /// `lol_throw_if!` additionally rejects an overflowing end offset and a
    /// failed reservation; how that is reported is defined by the macro.
    pub fn write_compress_zstd_compatible(&mut self, pos: usize, src: &[u8], src_count: usize, level: i32) -> Result<(), std::io::Error> {
        // Validate and slice in one step; only the requested prefix is compressed.
        let actual_src = src.get(..src_count).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("src_count ({}) exceeds src length ({})", src_count, src.len()),
            )
        })?;

        // Level range chosen for pyzstd compatibility (per the original note,
        // pyzstd typically uses levels 1-22).
        let safe_level = level.clamp(0, 22);

        // The stream encoder is used so the source size can be pledged.
        let mut encoder = zstd::stream::Encoder::new(Vec::new(), safe_level)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to create zstd encoder: {}", e)))?;

        encoder.set_pledged_src_size(Some(actual_src.len() as u64))
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to set src size: {}", e)))?;

        // Feed the data and finish the frame.
        std::io::Write::write_all(&mut encoder, actual_src)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write data: {}", e)))?;

        let result = encoder.finish()
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to finish compression: {}", e)))?;

        let trace_compat = lol_trace_func!(write_compress_zstd_compatible, lol_trace_var!("{:#x}", pos));

        // Same worst-case guard as before; `wrapping_add` keeps it reachable in
        // debug builds, where a plain `+` would panic first.
        let bound = compress_bound(actual_src.len());
        let worst_end = pos.wrapping_add(bound);
        lol_throw_if!(worst_end < pos, trace_compat);

        let endpos = pos.wrapping_add(result.len());
        lol_throw_if!(endpos < pos, trace_compat);
        // `impl_reserve` resizes to exactly the value it is given; never shrink.
        let reserve_to = endpos.max(self.impl_.vec.len());
        lol_throw_if!(self.impl_reserve(reserve_to), trace_compat);

        // Copy into the window at `pos` instead of replacing the whole vector,
        // so existing contents survive and `size_` stays within the vector.
        let dst = self.impl_.vec.get_mut(pos..endpos).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("destination range {:#x}..{:#x} is outside the buffer", pos, endpos),
            )
        })?;
        dst.copy_from_slice(&result);

        self.size_ = self.size_.max(endpos);
        Ok(())
    }
    /// Compression routine intended for pyzstd compatibility: plain
    /// `zstd::encode_all` with no extra encoder parameters.
    ///
    /// At most `src_count` bytes of `src` are compressed (a larger `src_count`
    /// is silently limited to `src.len()`), `level` is clamped to `0..=22`, and
    /// the compressed bytes are stored starting at byte offset `pos`. Bytes of
    /// the backing vector outside `pos..pos + compressed_len` are left
    /// untouched, and `size_` is raised to at least `pos + compressed_len`.
    ///
    /// # Errors
    /// `InvalidData` when compression fails or the backing vector does not
    /// cover the destination range.
    ///
    /// `lol_throw_if!` additionally rejects an overflowing end offset and a
    /// failed reservation; how that is reported is defined by the macro.
    pub fn write_compress_zstd_no_dict(&mut self, pos: usize, src: &[u8], src_count: usize, level: i32) -> Result<(), std::io::Error> {
        let safe_level = level.clamp(0, 22);
        let actual_src = &src[..src_count.min(src.len())];

        // Most basic zstd compression; no dictionary or special parameters.
        let result = zstd::encode_all(actual_src, safe_level)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to compress: {}", e)))?;

        let trace_no_dict = lol_trace_func!(write_compress_zstd_no_dict, lol_trace_var!("{:#x}", pos));

        // Same worst-case guard as before; `wrapping_add` keeps it reachable in
        // debug builds, where a plain `+` would panic first.
        let bound = compress_bound(actual_src.len());
        let worst_end = pos.wrapping_add(bound);
        lol_throw_if!(worst_end < pos, trace_no_dict);

        let endpos = pos.wrapping_add(result.len());
        lol_throw_if!(endpos < pos, trace_no_dict);
        // `impl_reserve` resizes to exactly the value it is given; never shrink.
        let reserve_to = endpos.max(self.impl_.vec.len());
        lol_throw_if!(self.impl_reserve(reserve_to), trace_no_dict);

        // Copy into the window at `pos` instead of replacing the whole vector,
        // so existing contents survive and `size_` stays within the vector.
        let dst = self.impl_.vec.get_mut(pos..endpos).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("destination range {:#x}..{:#x} is outside the buffer", pos, endpos),
            )
        })?;
        dst.copy_from_slice(&result);

        self.size_ = self.size_.max(endpos);
        Ok(())
    }

    /// Simple pyzstd-compatible compression: compresses all of `src` at level
    /// `0` into a new `bytes::Bytes` using `write_compress_zstd_no_dict`.
    ///
    /// A compression failure is only logged; the `Bytes` value is returned in
    /// whatever state the failed call left it.
    pub fn copy_compress_zstd_simple(&mut self, src: &[u8]) -> bytes::Bytes {
        let mut compressed = bytes::Bytes::bytes();
        let outcome = compressed
            .0
            .write_compress_zstd_no_dict(0, src, src.len(), 0);
        match outcome {
            Ok(()) => {}
            Err(err) => info!("Failed to compress zstd data (simple): {}", err),
        }
        compressed
    }
}