use std::io::{Cursor, Read as _, Write as _, copy as copy_std};
use mlua::prelude::*;
use blocking::unblock;
use futures_lite::io::{BufReader, copy};
use lz4::{Decoder, EncoderBuilder};
use async_compression::{
Level::Best as CompressionQuality,
Level::Precise as PreciseCompressionQuality,
futures::bufread::{
BrotliDecoder, BrotliEncoder, GzipDecoder, GzipEncoder, ZlibDecoder, ZlibEncoder,
ZstdDecoder, ZstdEncoder,
},
};
/// A compression format understood by this module's `compress` / `decompress`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressDecompressFormat {
    Brotli,
    GZip,
    LZ4,
    ZLib,
    Zstd,
}

// NOTE(review): the detectors below are not called from this file (hence the
// `dead_code` allow); presumably kept for future format auto-detection — confirm.
#[allow(dead_code)]
impl CompressDecompressFormat {
    /// Guesses the compression format from the leading bytes of `bytes`.
    ///
    /// Checks are tried in order, first match wins:
    /// - Zstd: frame magic `0xFD2FB528` (little-endian) in bytes 0..4.
    /// - LZ4: magic `0x184D2204` (frame) or `0x184C2102` (legacy) in bytes 0..4.
    /// - LZ4: frame magic `0x184D2204` in bytes 4..8. This is the layout produced by
    ///   this module's LZ4 compressor, which writes a 4-byte length before the frame.
    ///   It is checked before the byte-prefix formats below, because that length can
    ///   look like a gzip or zlib header.
    /// - Brotli: `E1 97 80`, `E1 97 81` or `E1 97 82`. At least 4 bytes of input are
    ///   required even though only 3 are compared.
    /// - Gzip: `1F 8B 08`.
    /// - Zlib: `78` followed by `01`, `5E`, `9C` or `DA`.
    ///
    /// Returns `None` when nothing matches, including for input that is too short.
    /// These are header heuristics, not a validation of the stream. Never panics.
    pub fn detect_from_bytes(bytes: impl AsRef<[u8]>) -> Option<Self> {
        match bytes.as_ref() {
            // 0xFD2FB528, little-endian
            [0x28, 0xB5, 0x2F, 0xFD, ..] => Some(Self::Zstd),
            // 0x184D2204 (frame) | 0x184C2102 (legacy), little-endian
            [0x04, 0x22, 0x4D, 0x18, ..] | [0x02, 0x21, 0x4C, 0x18, ..] => Some(Self::LZ4),
            // 4-byte length prefix, then the 0x184D2204 frame magic
            [_, _, _, _, 0x04, 0x22, 0x4D, 0x18, ..] => Some(Self::LZ4),
            // The trailing `_` keeps the historical "at least 4 bytes" requirement.
            [0xE1, 0x97, 0x80..=0x82, _, ..] => Some(Self::Brotli),
            [0x1F, 0x8B, 0x08, ..] => Some(Self::GZip),
            [0x78, 0x01 | 0x5E | 0x9C | 0xDA, ..] => Some(Self::ZLib),
            _ => None,
        }
    }

    /// Maps a single encoding token, such as an HTTP `Content-Encoding` value, to a format.
    ///
    /// Matching ignores surrounding whitespace and ASCII case. `x-gzip` is accepted as
    /// an alias of `gzip`, as RFC 9110 recommends. Comma-separated lists of encodings
    /// are not split; they return `None`, as does any unknown token. LZ4 has no token
    /// here, so it is never returned.
    pub fn detect_from_header_str(header: impl AsRef<str>) -> Option<Self> {
        match header.as_ref().trim().to_ascii_lowercase().as_str() {
            "br" | "brotli" => Some(Self::Brotli),
            // The HTTP "deflate" coding is the zlib (RFC 1950) wrapper, not raw deflate.
            "deflate" => Some(Self::ZLib),
            "gz" | "gzip" | "x-gzip" => Some(Self::GZip),
            "zst" | "zstd" => Some(Self::Zstd),
            _ => None,
        }
    }
}
impl FromLua for CompressDecompressFormat {
    /// Parses a Lua string naming a compression format.
    ///
    /// The name is compared after trimming surrounding whitespace and lowercasing ASCII
    /// letters. Accepted names are `brotli`, `gzip`, `lz4`, `zlib` and `zstd`.
    ///
    /// # Errors
    ///
    /// Returns `LuaError::FromLuaConversionError` in two cases:
    /// - `value` is not a string (no message);
    /// - the string is not a known format name (the message lists the valid names).
    fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
        let conversion_error = |message: Option<String>| LuaError::FromLuaConversionError {
            from: value.type_name(),
            to: "CompressDecompressFormat".to_string(),
            message,
        };

        let lowered = match &value {
            LuaValue::String(s) => s.to_string_lossy().to_ascii_lowercase(),
            _ => return Err(conversion_error(None)),
        };

        match lowered.trim() {
            "brotli" => Ok(Self::Brotli),
            "gzip" => Ok(Self::GZip),
            "lz4" => Ok(Self::LZ4),
            "zlib" => Ok(Self::ZLib),
            "zstd" => Ok(Self::Zstd),
            kind => Err(conversion_error(Some(format!(
                "Invalid format '{kind}', valid formats are: brotli, gzip, lz4, zlib, zstd"
            )))),
        }
    }
}
pub async fn compress(
source: impl AsRef<[u8]>,
format: CompressDecompressFormat,
level: Option<i32>,
) -> LuaResult<Vec<u8>> {
if let CompressDecompressFormat::LZ4 = format {
let source = source.as_ref().to_vec();
return unblock(move || compress_lz4(source)).await.into_lua_err();
}
let mut bytes = Vec::new();
let reader = BufReader::new(source.as_ref());
let compression_quality = match level {
Some(l) => PreciseCompressionQuality(l),
None => CompressionQuality,
};
match format {
CompressDecompressFormat::Brotli => {
let mut encoder = BrotliEncoder::with_quality(reader, compression_quality);
copy(&mut encoder, &mut bytes).await?;
}
CompressDecompressFormat::GZip => {
let mut encoder = GzipEncoder::with_quality(reader, compression_quality);
copy(&mut encoder, &mut bytes).await?;
}
CompressDecompressFormat::ZLib => {
let mut encoder = ZlibEncoder::with_quality(reader, compression_quality);
copy(&mut encoder, &mut bytes).await?;
}
CompressDecompressFormat::Zstd => {
let mut encoder = ZstdEncoder::with_quality(reader, compression_quality);
copy(&mut encoder, &mut bytes).await?;
}
CompressDecompressFormat::LZ4 => unreachable!(),
}
Ok(bytes)
}
/// Decompresses `source`, which must be encoded in `format`.
///
/// LZ4 input is expected in the layout written by `compress_lz4`: a 4-byte length
/// prefix followed by an LZ4 frame. LZ4 work is offloaded with `blocking::unblock`.
/// The other formats are decoded inline while this future is polled.
///
/// # Errors
///
/// Returns an error if decoding fails, for example on corrupt or truncated input or
/// input in a different format.
pub async fn decompress(
    source: impl AsRef<[u8]>,
    format: CompressDecompressFormat,
) -> LuaResult<Vec<u8>> {
    if let CompressDecompressFormat::LZ4 = format {
        let source = source.as_ref().to_vec();
        // `decompress_lz4` already yields a `LuaResult`; return it unchanged instead of
        // re-wrapping the `LuaError` inside another external error via `into_lua_err`.
        return unblock(move || decompress_lz4(source)).await;
    }

    let reader = BufReader::new(source.as_ref());
    let mut output = Vec::new();

    match format {
        CompressDecompressFormat::Brotli => {
            let mut decoder = BrotliDecoder::new(reader);
            copy(&mut decoder, &mut output).await?;
        }
        CompressDecompressFormat::GZip => {
            let mut decoder = GzipDecoder::new(reader);
            copy(&mut decoder, &mut output).await?;
        }
        CompressDecompressFormat::ZLib => {
            let mut decoder = ZlibDecoder::new(reader);
            copy(&mut decoder, &mut output).await?;
        }
        CompressDecompressFormat::Zstd => {
            let mut decoder = ZstdDecoder::new(reader);
            copy(&mut decoder, &mut output).await?;
        }
        CompressDecompressFormat::LZ4 => unreachable!("LZ4 is handled by the early return above"),
    }

    Ok(output)
}
/// Compresses `input` into this module's LZ4 layout.
///
/// The output is:
/// - a 4-byte little-endian length header for `input`;
/// - an LZ4 frame produced by the `lz4` encoder, configured at level 16 with content
///   checksums enabled and independent blocks.
///
/// # Errors
///
/// Propagates I/O errors from building the encoder, copying the input through it, or
/// finishing the frame.
fn compress_lz4(input: Vec<u8>) -> LuaResult<Vec<u8>> {
    // The `as` cast silently truncates lengths above `u32::MAX`. `decompress_lz4` in this
    // file only uses the header as a capacity hint, so a truncated value still round-trips.
    let length_header = (input.len() as u32).to_le_bytes();
    let mut source = Cursor::new(input);

    let mut sink = Cursor::new(Vec::new());
    sink.write_all(&length_header)?;

    let mut encoder = EncoderBuilder::new()
        .level(16)
        .checksum(lz4::ContentChecksum::ChecksumEnabled)
        .block_mode(lz4::BlockMode::Independent)
        .build(sink)?;
    copy_std(&mut source, &mut encoder)?;

    // `finish` hands back the writer together with the result of flushing the frame.
    let (sink, finish_result) = encoder.finish();
    finish_result?;
    Ok(sink.into_inner())
}
/// Decompresses data in this module's LZ4 layout: a 4-byte little-endian length header
/// followed by an LZ4 frame.
///
/// The header is used only as a capacity hint for the output buffer, and that hint is
/// capped. Input that is corrupt or malicious therefore cannot force a huge up-front
/// allocation. The vector still grows as needed for genuinely large outputs.
///
/// # Errors
///
/// Returns an error if the input is shorter than the 4-byte header, or if the LZ4 decoder
/// rejects the remaining bytes.
fn decompress_lz4(input: Vec<u8>) -> LuaResult<Vec<u8>> {
    // Upper bound on how much memory is reserved on trust of the length header. The
    // value is an arbitrary but generous limit, not a format constant.
    const MAX_PREALLOCATED_CAPACITY: usize = 64 * 1024 * 1024;

    let mut input = Cursor::new(input);
    let mut size = [0; 4];
    input.read_exact(&mut size)?;

    let declared_len = usize::try_from(u32::from_le_bytes(size)).unwrap_or(usize::MAX);
    let mut output = Vec::with_capacity(declared_len.min(MAX_PREALLOCATED_CAPACITY));

    let mut decoder = Decoder::new(input)?;
    copy_std(&mut decoder, &mut output)?;
    Ok(output)
}