#![allow(unused_unsafe)]
use super::super::common;
use crate::core::dictionary::Dictionary;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// Encodes `data` as a string with one dictionary character per input byte.
///
/// A 256-entry lookup table is filled from `dictionary.encode_digit` for every
/// byte value before any input is processed.
///
/// Returns `None` if the dictionary has no character for some value in
/// `0..256`; otherwise returns the encoded string.
#[cfg(target_arch = "aarch64")]
pub fn encode(data: &[u8], dictionary: &Dictionary) -> Option<String> {
    // Byte value -> output character; stop at the first digit the dictionary lacks.
    let mut lut = ['\0'; 256];
    lut.iter_mut().enumerate().try_for_each(|(digit, slot)| {
        *slot = dictionary.encode_digit(digit)?;
        Some(())
    })?;

    // Capacity is a lower bound in bytes: characters wider than one UTF-8 byte
    // make the string grow beyond it.
    let mut encoded = String::with_capacity(data.len());

    // SAFETY: `encode_neon_impl` requires the `neon` target feature. This
    // function is only compiled for aarch64 — assumes NEON is available on
    // every aarch64 target this crate is built for (TODO confirm).
    unsafe {
        encode_neon_impl(data, &lut, &mut encoded);
    }
    Some(encoded)
}
/// Decodes a string produced by [`encode`] back into bytes, one byte per character.
///
/// A reverse map (character -> byte value) is built from
/// `dictionary.encode_digit` for every value in `0..256`; values the
/// dictionary does not cover are simply left out of the map.
///
/// Returns `None` if `encoded` contains a character that is not in the map.
#[cfg(target_arch = "aarch64")]
pub fn decode(encoded: &str, dictionary: &Dictionary) -> Option<Vec<u8>> {
    use std::collections::HashMap;

    // If the dictionary maps two digits to the same character, the later digit wins.
    let mut reverse_map: HashMap<char, u8> = HashMap::with_capacity(256);
    reverse_map.extend(
        (0..256).filter_map(|digit| dictionary.encode_digit(digit).map(|ch| (ch, digit as u8))),
    );

    let chars: Vec<char> = encoded.chars().collect();
    let mut decoded = Vec::with_capacity(chars.len());

    // SAFETY: `decode_neon_impl` requires the `neon` target feature. This
    // function is only compiled for aarch64 — assumes NEON is available on
    // every aarch64 target this crate is built for (TODO confirm).
    let all_known = unsafe { decode_neon_impl(&chars, &reverse_map, &mut decoded) };
    if all_known {
        Some(decoded)
    } else {
        None
    }
}
/// Appends the dictionary character for every byte of `data` to `result`.
///
/// Inputs shorter than one 16-byte block go straight to the scalar path.
/// Longer inputs are processed in 16-byte blocks that are loaded through a
/// NEON register and spilled to a stack buffer; the table lookup itself is
/// scalar. Bytes past `simd_bytes` are handled by `encode_scalar_remainder`.
///
/// # Safety
///
/// The caller must ensure the `neon` target feature is available.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn encode_neon_impl(data: &[u8], lut: &[char; 256], result: &mut String) {
    const BLOCK_SIZE: usize = 16;

    if data.len() < BLOCK_SIZE {
        encode_scalar_remainder(data, lut, result);
        return;
    }

    let (num_rounds, simd_bytes) = common::calculate_blocks(data.len(), BLOCK_SIZE);

    for round in 0..num_rounds {
        let offset = round * BLOCK_SIZE;
        let mut block = [0u8; 16];

        // SAFETY: reads 16 bytes starting at `offset` and writes them to a
        // 16-byte local array. The read is in bounds only if
        // `num_rounds * BLOCK_SIZE <= data.len()` — presumably guaranteed by
        // `common::calculate_blocks`; verify against that helper.
        unsafe {
            let lanes = vld1q_u8(data.as_ptr().add(offset));
            vst1q_u8(block.as_mut_ptr(), lanes);
        }

        result.extend(block.iter().map(|&byte| lut[byte as usize]));
    }

    // Whatever the block loop did not cover is encoded one byte at a time.
    if simd_bytes < data.len() {
        encode_scalar_remainder(&data[simd_bytes..], lut, result);
    }
}
/// Appends `lut[byte]` to `result` for every byte of `data`, in order.
///
/// Existing contents of `result` are kept; an empty `data` appends nothing.
fn encode_scalar_remainder(data: &[u8], lut: &[char; 256], result: &mut String) {
    // A `u8` index can never exceed the 256-entry table, so the lookup cannot panic.
    result.extend(data.iter().map(|&byte| lut[usize::from(byte)]));
}
/// Looks up every character of `chars` in `reverse_map` and appends the
/// resulting bytes to `result`.
///
/// The input is walked in 16-character blocks first, then the tail starting
/// at `simd_bytes`. Despite the name, no NEON intrinsics are used here; every
/// lookup is a scalar hash-map access.
///
/// Returns `false` at the first character missing from `reverse_map` (bytes
/// decoded before it stay in `result`), and `true` once everything is decoded.
///
/// # Safety
///
/// The caller must ensure the `neon` target feature is available.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn decode_neon_impl(
    chars: &[char],
    reverse_map: &std::collections::HashMap<char, u8>,
    result: &mut Vec<u8>,
) -> bool {
    const BLOCK_SIZE: usize = 16;
    let (num_rounds, simd_bytes) = common::calculate_blocks(chars.len(), BLOCK_SIZE);

    // Decodes one character; reports whether it was present in the map.
    let mut decode_one = |ch: &char| -> bool {
        match reverse_map.get(ch) {
            Some(&byte_val) => {
                result.push(byte_val);
                true
            }
            None => false,
        }
    };

    // `all` stops at the first unknown character, like an early `return false`.
    for round in 0..num_rounds {
        let mut block = chars.iter().skip(round * BLOCK_SIZE).take(BLOCK_SIZE);
        if !block.all(&mut decode_one) {
            return false;
        }
    }

    chars[simd_bytes..].iter().all(&mut decode_one)
}
// Gated on aarch64 as well as `test`: every function exercised here is only
// compiled for that architecture.
#[cfg(all(test, target_arch = "aarch64"))]
// NOTE(review): presumably some registry/dictionary constructor used below is
// deprecated — confirm and migrate when possible.
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::core::config::{DictionaryRegistry, EncodingMode};

    /// Builds the `base256_matrix` dictionary from the default registry in chunked mode.
    fn make_base256_dict() -> crate::core::dictionary::Dictionary {
        let config = DictionaryRegistry::load_default().unwrap();
        let dict_config = config.get_dictionary("base256_matrix").unwrap();
        let chars: Vec<char> = dict_config.effective_chars().unwrap().chars().collect();
        crate::core::dictionary::Dictionary::new_with_mode(chars, EncodingMode::Chunked, None)
            .unwrap()
    }

    /// Encodes `data`, checks that one character was produced per byte, decodes
    /// the result and checks it equals `data`.
    ///
    /// Panics (failing the test) if either direction returns `None`, so a
    /// broken codec can never make a test pass silently. `label` is appended
    /// to every failure message.
    fn assert_round_trip(
        dictionary: &crate::core::dictionary::Dictionary,
        data: &[u8],
        label: &str,
    ) {
        let encoded =
            encode(data, dictionary).unwrap_or_else(|| panic!("Encode failed {}", label));
        assert_eq!(
            encoded.chars().count(),
            data.len(),
            "Unexpected encoded length {}",
            label
        );
        let decoded =
            decode(&encoded, dictionary).unwrap_or_else(|| panic!("Decode failed {}", label));
        assert_eq!(decoded, data, "Round-trip failed {}", label);
    }

    #[test]
    fn test_encode_simple() {
        let dictionary = make_base256_dict();
        assert_round_trip(&dictionary, b"Hello", "for simple input");
    }

    #[test]
    fn test_encode_all_bytes() {
        let dictionary = make_base256_dict();
        let test_data: Vec<u8> = (0..=255).collect();
        assert_round_trip(&dictionary, &test_data, "for all byte values");
    }

    #[test]
    fn test_decode_round_trip() {
        let dictionary = make_base256_dict();
        // Lengths 0..100 cross the 16-byte block boundary several times.
        for len in 0..100 {
            let original: Vec<u8> = (0..len).map(|i| (i * 7) as u8).collect();
            assert_round_trip(&dictionary, &original, &format!("at length {}", len));
        }
    }

    #[test]
    fn test_encode_empty() {
        let dictionary = make_base256_dict();
        let encoded = encode(&[], &dictionary).expect("Encode failed for empty input");
        assert_eq!(encoded, "");
    }

    #[test]
    fn test_encode_single_byte() {
        let dictionary = make_base256_dict();
        assert_round_trip(&dictionary, &[0xFF], "for single byte");
    }

    #[test]
    fn test_encode_large_input() {
        let dictionary = make_base256_dict();
        let test_data: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
        assert_round_trip(&dictionary, &test_data, "for large input");
    }

    #[test]
    fn test_decode_invalid_char() {
        let dictionary = make_base256_dict();
        let mut encoded = encode(&[65u8], &dictionary).expect("Encode failed for valid input");
        // The test relies on this character not being part of the dictionary.
        encoded.push('🦀');
        assert_eq!(decode(&encoded, &dictionary), None);
    }

    #[test]
    fn test_simd_boundary() {
        let dictionary = make_base256_dict();
        // Exactly one 16-byte block, no scalar tail.
        let test_data: Vec<u8> = (0..16).collect();
        assert_round_trip(&dictionary, &test_data, "at the block boundary");
    }

    #[test]
    fn test_simd_boundary_plus_one() {
        let dictionary = make_base256_dict();
        // One full block followed by a single tail byte.
        let test_data: Vec<u8> = (0..17).collect();
        assert_round_trip(&dictionary, &test_data, "one byte past the block boundary");
    }
}