use flate2::read::ZlibDecoder;
use std::io::Read;
use crate::innodb::vendor::VendorInfo;
/// Compression algorithm recognised for an InnoDB tablespace or page.
///
/// Covers both the MySQL codes decoded by `detect_compression` and the MariaDB
/// algorithm ids decoded by `mariadb_algo_from_id`. `None` doubles as the
/// result for "uncompressed" and for ids this module does not recognise.
///
/// `Hash` is derived alongside `Eq` so the algorithm can be used as a map/set
/// key (e.g. when tallying per-algorithm statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionAlgorithm {
    /// No compression, or an unrecognised algorithm id.
    None,
    /// zlib.
    Zlib,
    /// LZ4.
    Lz4,
    /// LZO (only produced by the MariaDB id mapping).
    Lzo,
    /// LZMA (only produced by the MariaDB id mapping).
    Lzma,
    /// bzip2 (only produced by the MariaDB id mapping).
    Bzip2,
    /// Snappy (only produced by the MariaDB id mapping).
    Snappy,
    /// Zstandard (only produced by the MySQL flag decoding).
    Zstd,
}
/// Determines the compression algorithm advertised by a tablespace's FSP flags.
///
/// Resolution order:
/// 1. `vendor_info` reports a MariaDB `full_crc32` tablespace: the algorithm id is
///    `(fsp_flags & MARIADB_FSP_FLAGS_FCRC32_COMPRESSED_ALGO_MASK) >> 5`, mapped
///    through the MariaDB id table.
/// 2. Any other MariaDB tablespace: `Zlib` when `MARIADB_FSP_FLAGS_PAGE_COMPRESSION`
///    is set, otherwise `None`. No specific algorithm is read from the flags here.
/// 3. Everything else (no vendor info, or a non-MariaDB vendor): bits 11-12 of
///    `fsp_flags` select `Zlib` (1), `Lz4` (2), `Zstd` (3) or `None` (0).
///
/// NOTE(review): confirm that bits 11-12 carry a compression code in the MySQL
/// versions this tool targets.
pub fn detect_compression(
    fsp_flags: u32,
    vendor_info: Option<&VendorInfo>,
) -> CompressionAlgorithm {
    use crate::innodb::constants::*;

    match vendor_info {
        Some(vi) if vi.is_full_crc32() => {
            // The full_crc32 layout stores the algorithm id starting at bit 5.
            let algo_id = (fsp_flags & MARIADB_FSP_FLAGS_FCRC32_COMPRESSED_ALGO_MASK) >> 5;
            mariadb_algo_from_id(algo_id as u8)
        }
        Some(vi) if vi.vendor == crate::innodb::vendor::InnoDbVendor::MariaDB => {
            if fsp_flags & MARIADB_FSP_FLAGS_PAGE_COMPRESSION != 0 {
                CompressionAlgorithm::Zlib
            } else {
                CompressionAlgorithm::None
            }
        }
        _ => match (fsp_flags >> 11) & 0b11 {
            1 => CompressionAlgorithm::Zlib,
            2 => CompressionAlgorithm::Lz4,
            3 => CompressionAlgorithm::Zstd,
            _ => CompressionAlgorithm::None,
        },
    }
}
pub fn detect_mariadb_page_compression(page_data: &[u8]) -> Option<CompressionAlgorithm> {
if page_data.len() < 27 {
return None;
}
let algo_id = page_data[26];
Some(mariadb_algo_from_id(algo_id))
}
/// Maps a MariaDB page-compression algorithm id to [`CompressionAlgorithm`].
///
/// Ids 1-6 map to Zlib, LZ4, LZO, LZMA, bzip2 and Snappy respectively; 0 and
/// every other value map to `CompressionAlgorithm::None`.
fn mariadb_algo_from_id(id: u8) -> CompressionAlgorithm {
    // Position in this table == MariaDB algorithm id.
    const BY_ID: [CompressionAlgorithm; 7] = [
        CompressionAlgorithm::None,
        CompressionAlgorithm::Zlib,
        CompressionAlgorithm::Lz4,
        CompressionAlgorithm::Lzo,
        CompressionAlgorithm::Lzma,
        CompressionAlgorithm::Bzip2,
        CompressionAlgorithm::Snappy,
    ];

    BY_ID
        .get(usize::from(id))
        .copied()
        .unwrap_or(CompressionAlgorithm::None)
}
/// Inflates a zlib-wrapped stream using flate2's `ZlibDecoder`.
///
/// The decoder is read to end of stream and all output is returned. The
/// output size is not bounded. Returns `None` if the decoder reports any error.
pub fn decompress_zlib(compressed: &[u8]) -> Option<Vec<u8>> {
    let mut inflated = Vec::new();
    ZlibDecoder::new(compressed)
        .read_to_end(&mut inflated)
        .ok()
        .map(|_| inflated)
}
/// Decompresses a raw LZ4 block (no size prefix) with `lz4_flex::decompress`.
///
/// `uncompressed_len` is forwarded as the expected decompressed size. Any
/// decode error is reported as `None`.
pub fn decompress_lz4(compressed: &[u8], uncompressed_len: usize) -> Option<Vec<u8>> {
    let decoded = lz4_flex::decompress(compressed, uncompressed_len);
    decoded.ok()
}
/// Decodes Zstandard data with ruzstd's `StreamingDecoder`.
///
/// Returns `None` if the decoder cannot be constructed from `compressed` or if
/// reading the decoded stream fails; otherwise returns every decoded byte.
pub fn decompress_zstd(compressed: &[u8]) -> Option<Vec<u8>> {
    let mut frame_reader = ruzstd::decoding::StreamingDecoder::new(compressed).ok()?;
    let mut plain = Vec::new();
    std::io::Read::read_to_end(&mut frame_reader, &mut plain)
        .ok()
        .map(|_| plain)
}
/// Heuristically decides whether `page_data` holds a hole-punched page.
///
/// Only the final quarter of the page, bytes `[page_size * 3 / 4, page_size)`,
/// is inspected. The page counts as punched when every byte there is zero.
/// Bytes beyond `page_size` are ignored.
///
/// Returns `false` when `page_data` is shorter than `page_size` or when
/// `page_size` is zero. A zero size would make the inspected range empty, and
/// `all` over an empty range is vacuously `true`.
pub fn is_hole_punched(page_data: &[u8], page_size: u32) -> bool {
    let page_len = page_size as usize;
    if page_len == 0 || page_data.len() < page_len {
        return false;
    }
    // Compute in u64 so `page_size * 3` cannot overflow a 32-bit usize.
    let tail_start = (u64::from(page_size) * 3 / 4) as usize;
    page_data[tail_start..page_len].iter().all(|&b| b == 0)
}
impl std::fmt::Display for CompressionAlgorithm {
    /// Writes the human-readable algorithm name (e.g. `"LZ4"`, `"bzip2"`).
    ///
    /// The name is emitted via [`std::fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags (`{:>8}`, `{:<6}`) are honoured exactly
    /// as they are for a plain `&str`. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            CompressionAlgorithm::None => "None",
            CompressionAlgorithm::Zlib => "Zlib",
            CompressionAlgorithm::Lz4 => "LZ4",
            CompressionAlgorithm::Lzo => "LZO",
            CompressionAlgorithm::Lzma => "LZMA",
            CompressionAlgorithm::Bzip2 => "bzip2",
            CompressionAlgorithm::Snappy => "Snappy",
            CompressionAlgorithm::Zstd => "ZSTD",
        };
        f.pad(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::innodb::vendor::MariaDbFormat;

    #[test]
    fn test_detect_compression_mysql() {
        assert_eq!(detect_compression(0, None), CompressionAlgorithm::None);
        assert_eq!(
            detect_compression(1 << 11, None),
            CompressionAlgorithm::Zlib
        );
        assert_eq!(detect_compression(2 << 11, None), CompressionAlgorithm::Lz4);
        assert_eq!(
            detect_compression(3 << 11, None),
            CompressionAlgorithm::Zstd
        );
        // Bits outside 11-12 must not influence the MySQL decoding.
        assert_eq!(
            detect_compression(0xFF | (1 << 11), None),
            CompressionAlgorithm::Zlib
        );
    }

    #[test]
    fn test_detect_compression_mariadb_full_crc32() {
        let vendor = VendorInfo::mariadb(MariaDbFormat::FullCrc32);
        // The algorithm id sits at bit 5 and above in the full_crc32 flags.
        let flags = 0x10 | (1 << 5);
        assert_eq!(
            detect_compression(flags, Some(&vendor)),
            CompressionAlgorithm::Zlib
        );
        let flags = 0x10 | (2 << 5);
        assert_eq!(
            detect_compression(flags, Some(&vendor)),
            CompressionAlgorithm::Lz4
        );
        let flags = 0x10 | (3 << 5);
        assert_eq!(
            detect_compression(flags, Some(&vendor)),
            CompressionAlgorithm::Lzo
        );
    }

    #[test]
    fn test_detect_mariadb_page_compression() {
        let mut page = vec![0u8; 38];
        page[26] = 2;
        assert_eq!(
            detect_mariadb_page_compression(&page),
            Some(CompressionAlgorithm::Lz4)
        );
        page[26] = 6;
        assert_eq!(
            detect_mariadb_page_compression(&page),
            Some(CompressionAlgorithm::Snappy)
        );
    }

    #[test]
    fn test_detect_mariadb_page_compression_short_page() {
        // A buffer that does not reach byte 26 yields no answer instead of panicking.
        assert_eq!(detect_mariadb_page_compression(&[0u8; 26]), None);
    }

    #[test]
    fn test_decompress_zlib() {
        use flate2::write::ZlibEncoder;
        use flate2::Compression;
        use std::io::Write;
        let original = b"Hello, InnoDB compression test data!";
        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(original).unwrap();
        let compressed = encoder.finish().unwrap();
        let result = decompress_zlib(&compressed).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_decompress_lz4() {
        // Exercise the crate's own wrapper, not lz4_flex directly.
        let original = b"Hello, LZ4 compression test data for InnoDB!";
        let compressed = lz4_flex::compress(original);
        let result = decompress_lz4(&compressed, original.len()).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_detect_compression_mysql_zstd() {
        assert_eq!(
            detect_compression(3 << 11, None),
            CompressionAlgorithm::Zstd
        );
    }

    #[test]
    fn test_decompress_zstd() {
        let original = b"Hello, ZSTD compression test data for InnoDB!";
        let compressed = ruzstd::encoding::compress_to_vec(
            &original[..],
            ruzstd::encoding::CompressionLevel::Fastest,
        );
        let result = decompress_zstd(&compressed).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_zstd_display() {
        assert_eq!(format!("{}", CompressionAlgorithm::Zstd), "ZSTD");
    }

    #[test]
    fn test_display_names() {
        let cases = [
            (CompressionAlgorithm::None, "None"),
            (CompressionAlgorithm::Zlib, "Zlib"),
            (CompressionAlgorithm::Lz4, "LZ4"),
            (CompressionAlgorithm::Lzo, "LZO"),
            (CompressionAlgorithm::Lzma, "LZMA"),
            (CompressionAlgorithm::Bzip2, "bzip2"),
            (CompressionAlgorithm::Snappy, "Snappy"),
            (CompressionAlgorithm::Zstd, "ZSTD"),
        ];
        for (algo, name) in cases.iter() {
            assert_eq!(algo.to_string(), *name);
        }
    }

    #[test]
    fn test_decompress_zstd_invalid() {
        assert!(decompress_zstd(&[0xFF, 0xFE]).is_none());
    }

    #[test]
    fn test_is_hole_punched() {
        let page_size = 16384u32;
        let mut page = vec![0u8; page_size as usize];
        assert!(is_hole_punched(&page, page_size));
        // Non-zero bytes outside the inspected final quarter do not matter.
        page[0] = 0xFF;
        page[100] = 0xAB;
        assert!(is_hole_punched(&page, page_size));
        page[page_size as usize - 10] = 0x01;
        assert!(!is_hole_punched(&page, page_size));
    }
}