use std::io::{self, Read};
/// Marker bytes that identify the LEB128-compressed tensor layout.
///
/// `read_compressed_tensor_i16` compares the first 17 bytes of its input against this value;
/// a match selects the compressed decoding path, anything else is treated as raw data.
pub const LEB128_MAGIC: &[u8] = b"COMPRESSED_LEB128";
/// Reads a single signed LEB128 integer from `reader`, one byte at a time.
///
/// Each byte contributes its low seven bits, least-significant group first. A clear high bit
/// marks the final byte, and bit 6 of that final byte is the sign bit that gets extended into
/// the remaining high bits of the result. At most ten bytes are consumed.
///
/// # Errors
///
/// * Any error from `reader.read_exact` is returned unchanged (for example `UnexpectedEof`
///   when the stream ends in the middle of a value).
/// * `InvalidData` when the encoding does not fit in an `i64`: the tenth byte either has its
///   continuation bit set or carries a payload other than `0x00` / `0x7f`.
pub fn read_signed_leb128<R: Read>(reader: &mut R) -> io::Result<i64> {
    // Shift applied to the tenth byte; only bit 63 of the result is still free at that point.
    const LAST_GROUP_SHIFT: u32 = 63;

    let mut result: i64 = 0;
    let mut shift: u32 = 0;
    let mut byte = [0u8; 1];
    loop {
        reader.read_exact(&mut byte)?;
        let payload = byte[0] & 0x7f;
        let is_final = byte[0] & 0x80 == 0;

        // The tenth byte holds bit 63 plus six bits of pure sign extension, so a value that
        // fits in an i64 must end with 0x00 (non-negative) or 0x7f (negative). Anything else
        // would be truncated silently by the shift below.
        if shift == LAST_GROUP_SHIFT && (!is_final || (payload != 0x00 && payload != 0x7f)) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "LEB128 overflow: value too large",
            ));
        }

        result |= i64::from(payload) << shift;
        shift += 7;

        if is_final {
            // Sign-extend when the sign bit is set and there are still unfilled high bits.
            if shift < 64 && payload & 0x40 != 0 {
                result |= !0i64 << shift;
            }
            return Ok(result);
        }
    }
}
/// Decodes one signed LEB128 integer from the start of `data`.
///
/// Returns the decoded value together with the number of bytes it occupied, so the caller can
/// advance past it. Bytes after the final byte of the value are ignored.
///
/// # Errors
///
/// * `UnexpectedEof` when `data` is empty or ends before a byte with a clear continuation bit.
/// * `InvalidData` when the encoding does not fit in an `i64`: the tenth byte either has its
///   continuation bit set or carries a payload other than `0x00` / `0x7f`.
fn decode_single_leb128(data: &[u8]) -> io::Result<(i64, usize)> {
    // Shift applied to the tenth byte; only bit 63 of the result is still free at that point.
    const LAST_GROUP_SHIFT: u32 = 63;

    let mut result: i64 = 0;
    let mut shift: u32 = 0;
    for (index, &b) in data.iter().enumerate() {
        let payload = b & 0x7f;
        let is_final = b & 0x80 == 0;

        // The tenth byte holds bit 63 plus six bits of pure sign extension, so a value that
        // fits in an i64 must end with 0x00 (non-negative) or 0x7f (negative). Anything else
        // would be truncated silently by the shift below.
        if shift == LAST_GROUP_SHIFT && (!is_final || (payload != 0x00 && payload != 0x7f)) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "LEB128 overflow: value too large",
            ));
        }

        result |= i64::from(payload) << shift;
        shift += 7;

        if is_final {
            // Sign-extend when the sign bit is set and there are still unfilled high bits.
            if shift < 64 && payload & 0x40 != 0 {
                result |= !0i64 << shift;
            }
            return Ok((result, index + 1));
        }
    }

    // Ran out of bytes without seeing a terminating byte.
    Err(io::Error::new(
        io::ErrorKind::UnexpectedEof,
        "Unexpected end of LEB128 data",
    ))
}
/// Reads `count` 16-bit integers from `reader`, accepting either the LEB128-compressed layout
/// or raw little-endian data.
///
/// The first 17 bytes of the stream are always consumed and compared with `LEB128_MAGIC`:
///
/// * **Match** - a little-endian `u32` payload length follows, then that many bytes, which are
///   passed to `decode_leb128_array_i16` together with `count`.
/// * **No match** - the 17 bytes are interpreted as the start of raw little-endian `i16`
///   values. Up to eight values come out of the header; if more are needed, the 17th header
///   byte is the low byte of the ninth value and the rest is read from `reader`.
///
/// Note that all 17 header bytes are consumed even when `count` is 8 or less; header bytes
/// beyond the requested values are discarded.
///
/// # Errors
///
/// * Any error from `reader.read_exact` (e.g. `UnexpectedEof` on a short stream).
/// * `InvalidData` when the compressed payload length is zero or exceeds 100 MiB.
/// * Whatever `decode_leb128_array_i16` reports for the compressed payload.
pub fn read_compressed_tensor_i16<R: Read>(reader: &mut R, count: usize) -> io::Result<Vec<i16>> {
    const HEADER_LEN: usize = 17;
    const MAX_COMPRESSED_SIZE: usize = 100 * 1024 * 1024;

    let mut header = [0u8; HEADER_LEN];
    reader.read_exact(&mut header)?;

    if header == LEB128_MAGIC {
        let mut len_bytes = [0u8; 4];
        reader.read_exact(&mut len_bytes)?;
        let compressed_size = u32::from_le_bytes(len_bytes) as usize;
        if !(1..=MAX_COMPRESSED_SIZE).contains(&compressed_size) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Invalid compressed size: {compressed_size} (max: {MAX_COMPRESSED_SIZE})"),
            ));
        }
        let mut payload = vec![0u8; compressed_size];
        reader.read_exact(&mut payload)?;
        return decode_leb128_array_i16(&payload, count);
    }

    // Raw layout: the header already holds the first eight complete values.
    let mut values: Vec<i16> = Vec::with_capacity(count);
    values.extend(
        header
            .chunks_exact(2)
            .take(count)
            .map(|pair| i16::from_le_bytes([pair[0], pair[1]])),
    );

    if values.len() < count {
        // The odd 17th header byte is the low half of the next value; fetch its high half.
        let mut high = [0u8; 1];
        reader.read_exact(&mut high)?;
        values.push(i16::from_le_bytes([header[HEADER_LEN - 1], high[0]]));

        let mut pair = [0u8; 2];
        while values.len() < count {
            reader.read_exact(&mut pair)?;
            values.push(i16::from_le_bytes(pair));
        }
    }

    Ok(values)
}
/// Decodes `count` consecutive signed LEB128 values from `data` into `i16`s.
///
/// Values are decoded back to back with `decode_single_leb128`; bytes left over after the
/// last requested value are ignored.
///
/// # Errors
///
/// * Any error from `decode_single_leb128`, e.g. when `data` ends before `count` values have
///   been decoded.
/// * `InvalidData` when a decoded value lies outside the `i16` range (previously such values
///   were silently wrapped).
fn decode_leb128_array_i16(data: &[u8], count: usize) -> io::Result<Vec<i16>> {
    // Every value occupies at least one byte, so `data.len()` bounds how many can be decoded.
    // Capping the reservation keeps a bogus `count` from triggering a huge allocation.
    let mut values = Vec::with_capacity(count.min(data.len()));
    let mut rest = data;
    for _ in 0..count {
        let (value, used) = decode_single_leb128(rest)?;
        if value < i64::from(i16::MIN) || value > i64::from(i16::MAX) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("LEB128 value {value} does not fit in i16"),
            ));
        }
        // Lossless: the range was checked just above.
        values.push(value as i16);
        rest = &rest[used..];
    }
    Ok(values)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Asserts that `bytes` decodes to `value` using exactly `consumed` bytes.
    fn expect_decode(bytes: &[u8], value: i64, consumed: usize) {
        assert_eq!(decode_single_leb128(bytes).unwrap(), (value, consumed));
    }

    /// Asserts only the decoded value of `bytes`, ignoring the consumed length.
    fn expect_value(bytes: &[u8], value: i64) {
        let (decoded, _) = decode_single_leb128(bytes).unwrap();
        assert_eq!(decoded, value);
    }

    #[test]
    fn test_decode_single_leb128_positive() {
        expect_decode(&[0x00], 0, 1);
        expect_decode(&[0x01], 1, 1);
        expect_decode(&[0x3F], 63, 1);
        // Bit 6 of the final byte is the sign bit, so 64..=127 need a second byte.
        expect_decode(&[0xC0, 0x00], 64, 2);
        expect_decode(&[0xFF, 0x00], 127, 2);
        expect_decode(&[0x80, 0x01], 128, 2);
    }

    #[test]
    fn test_decode_single_leb128_negative() {
        expect_value(&[0x7F], -1);
        expect_value(&[0x40], -64);
        expect_value(&[0xBF, 0x7F], -65);
        expect_value(&[0x80, 0x7F], -128);
    }

    #[test]
    fn test_read_compressed_tensor_i16_uncompressed() {
        // Three raw little-endian values, zero-padded to the 17 bytes the reader always takes.
        let mut data: Vec<u8> = vec![0x01, 0x00, 0x02, 0x00, 0x03, 0x00];
        data.resize(17, 0x00);
        let mut cursor = Cursor::new(data);
        let values = read_compressed_tensor_i16(&mut cursor, 3).unwrap();
        assert_eq!(values, vec![1, 2, 3]);
    }

    #[test]
    fn test_decode_single_leb128_early_eof() {
        // A continuation byte with nothing after it.
        let truncated = decode_single_leb128(&[0x80]);
        assert!(truncated.is_err());
        assert!(truncated.unwrap_err().to_string().contains("Unexpected end"));

        let empty = decode_single_leb128(&[]);
        assert!(empty.is_err());
    }

    #[test]
    fn test_decode_single_leb128_large_values() {
        expect_decode(&[0xAC, 0x02], 300, 2);
        expect_decode(&[0x80, 0x80, 0x01], 16384, 3);
    }

    #[test]
    fn test_decode_leb128_array_count_mismatch() {
        // Only two values are present but ten are requested.
        let data = [0x00, 0x01];
        assert!(decode_leb128_array_i16(&data, 10).is_err());
    }

    #[test]
    fn test_read_signed_leb128_stream() {
        // Encodes 0, -1 and 128 back to back.
        let bytes: Vec<u8> = vec![0x00, 0x7F, 0x80, 0x01];
        let mut cursor = Cursor::new(bytes);
        for expected in [0i64, -1, 128] {
            assert_eq!(read_signed_leb128(&mut cursor).unwrap(), expected);
        }
    }

    #[test]
    fn test_read_signed_leb128_i16_range() {
        let (max, _) = decode_single_leb128(&[0xFF, 0xFF, 0x01]).unwrap();
        assert_eq!(max, 32767);
        assert_eq!(max as i16, i16::MAX);

        let (min, _) = decode_single_leb128(&[0x80, 0x80, 0x7E]).unwrap();
        assert_eq!(min, -32768);
        assert_eq!(min as i16, i16::MIN);
    }
}