use crate::core::dictionary::Dictionary;
/// Well-known base64 alphabets that [`identify_base64_variant`] can recognise.
///
/// Both alphabets share `A-Z`, `a-z`, `0-9` for digits 0 through 61 and differ
/// only in the last two symbols.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DictionaryVariant {
    /// Alphabet ending in `+` and `/`.
    Base64Standard,
    /// Alphabet ending in `-` and `_`.
    Base64Url,
}
/// Well-known base32 alphabets that [`identify_base32_variant`] can recognise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Base32Variant {
    /// Alphabet `A-Z` followed by `2-7`.
    Rfc4648,
    /// Alphabet `0-9` followed by `A-V`.
    Rfc4648Hex,
}
/// Determines which well-known base64 alphabet `dict` implements, if any.
///
/// The symbols for digits 62 and 63 are read first to pick a candidate
/// (`+`/`/` for the standard alphabet, `-`/`_` for the URL-safe one); the
/// candidate is then confirmed against the whole alphabet with
/// [`verify_base64_dictionary`].
///
/// Returns `None` when the dictionary's base is not 64, when either of the
/// two trailing digits cannot be encoded, when the trailing pair is not one
/// of the recognised pairs, or when the full comparison fails.
pub fn identify_base64_variant(dict: &Dictionary) -> Option<DictionaryVariant> {
    if dict.base() != 64 {
        return None;
    }

    // Only the last two symbols distinguish the known alphabets.
    let tail = (dict.encode_digit(62)?, dict.encode_digit(63)?);
    let guess = if tail == ('+', '/') {
        DictionaryVariant::Base64Standard
    } else if tail == ('-', '_') {
        DictionaryVariant::Base64Url
    } else {
        return None;
    };

    Some(guess).filter(|&variant| verify_base64_dictionary(dict, variant))
}
/// Checks that `dict` is exactly the base64 alphabet named by `variant`.
///
/// Returns `true` only when the dictionary's base is 64 and every digit
/// `0..64` encodes to the character at the same position of the reference
/// alphabet; any missing or differing digit yields `false`.
pub fn verify_base64_dictionary(dict: &Dictionary, variant: DictionaryVariant) -> bool {
    let alphabet = match variant {
        DictionaryVariant::Base64Standard => {
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
        }
        DictionaryVariant::Base64Url => {
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
        }
    };

    dict.base() == 64
        && alphabet
            .chars()
            .enumerate()
            .all(|(digit, symbol)| dict.encode_digit(digit) == Some(symbol))
}
/// Determines which well-known base32 alphabet `dict` implements, if any.
///
/// Each known variant is checked in turn with [`verify_base32_dictionary`],
/// which compares all 32 symbols and stops at the first mismatch. The two
/// reference alphabets already differ at digit 0 (`A` vs `0`), so a
/// non-matching variant is rejected after a single lookup; no intermediate
/// buffer or separate fingerprint check is needed.
///
/// Returns `None` when the dictionary's base is not 32 or when it matches
/// neither alphabet exactly.
pub fn identify_base32_variant(dict: &Dictionary) -> Option<Base32Variant> {
    if dict.base() != 32 {
        return None;
    }

    [Base32Variant::Rfc4648, Base32Variant::Rfc4648Hex]
        .iter()
        .copied()
        .find(|&variant| verify_base32_dictionary(dict, variant))
}
/// Checks that `dict` is exactly the base32 alphabet named by `variant`.
///
/// Returns `true` only when the dictionary's base is 32 and every digit
/// `0..32` encodes to the character at the same position of the reference
/// alphabet; any missing or differing digit yields `false`.
pub fn verify_base32_dictionary(dict: &Dictionary, variant: Base32Variant) -> bool {
    let alphabet = match variant {
        Base32Variant::Rfc4648 => "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567",
        Base32Variant::Rfc4648Hex => "0123456789ABCDEFGHIJKLMNOPQRSTUV",
    };

    dict.base() == 32
        && alphabet
            .chars()
            .enumerate()
            .all(|(digit, symbol)| dict.encode_digit(digit) == Some(symbol))
}
/// One contiguous run of a dictionary in which the character's code point and
/// its digit index differ by a constant.
///
/// In the tables defined in this file `index_end` is one past the last index
/// of the run, `char_end` is the last character itself, and `offset` is the
/// value added to a digit index to obtain the code point (for example index
/// 0 plus offset 65 gives `'A'`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CharRange {
    /// First digit index covered by the run.
    pub index_start: u8,
    /// End of the index run; the tables in this file use it exclusively
    /// (e.g. 26 for the 26 letters starting at index 0).
    pub index_end: u8,
    /// Character that `index_start` maps to.
    pub char_start: char,
    /// Last character of the run (inclusive in this file's tables).
    pub char_end: char,
    /// Code point minus digit index for entries of the run; may be negative.
    pub offset: i8,
}
/// How the digit indices of a dictionary map onto characters, as classified
/// by `DictionaryMetadata::from_dictionary`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TranslationStrategy {
    /// Digit `i` maps to the code point `start_codepoint + i` for every digit.
    Sequential { start_codepoint: u32 },
    /// The dictionary is one of the known alphabets (standard or URL-safe
    /// base64, upper- or lower-case hex) and `ranges` is its static table.
    Ranged { ranges: &'static [CharRange] },
    /// No structure was recognised; `dictionary_size` is the number of
    /// symbols the classification saw.
    Arbitrary { dictionary_size: usize },
}
/// Lookup-table choice reported by `DictionaryMetadata::lut_strategy`.
///
/// Only the selection rule is defined in this file; what each choice means
/// for the encoder is decided by the consumers of this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LutStrategy {
    /// Reported for `Sequential` and `Ranged` dictionaries.
    NotApplicable,
    /// Reported for arbitrary dictionaries of at most 16 symbols.
    SmallDirect,
    /// Reported for arbitrary dictionaries of 17 to 64 symbols.
    LargePlatformDependent,
    /// Reported for arbitrary dictionaries of more than 64 symbols.
    ScalarOnly,
}
/// Range table returned for the standard base64 alphabet.
///
/// Covers digit indices 0 through 61 (`A-Z`, `a-z`, `0-9`). The last two
/// symbols (`+` and `/`, indices 62 and 63) have no entry here.
/// NOTE(review): presumably consumers special-case those two symbols — confirm.
static BASE64_STANDARD_RANGES: &[CharRange] = &[
    // `A-Z`: code point = index + 65.
    CharRange {
        index_start: 0,
        index_end: 26,
        char_start: 'A',
        char_end: 'Z',
        offset: 65,
    },
    // `a-z`: code point = index + 71 (index 26 -> 97, `'a'`).
    CharRange {
        index_start: 26,
        index_end: 52,
        char_start: 'a',
        char_end: 'z',
        offset: 71,
    },
    // `0-9`: code point = index - 4 (index 52 -> 48, `'0'`).
    CharRange {
        index_start: 52,
        index_end: 62,
        char_start: '0',
        char_end: '9',
        offset: -4,
    },
];
/// Range table returned for the URL-safe base64 alphabet.
///
/// Its entries are identical to those of `BASE64_STANDARD_RANGES`: only digit
/// indices 0 through 61 (`A-Z`, `a-z`, `0-9`) are described. The last two
/// symbols (`-` and `_`, indices 62 and 63) have no entry here.
/// NOTE(review): presumably consumers special-case those two symbols — confirm.
static BASE64_URL_RANGES: &[CharRange] = &[
    // `A-Z`: code point = index + 65.
    CharRange {
        index_start: 0,
        index_end: 26,
        char_start: 'A',
        char_end: 'Z',
        offset: 65,
    },
    // `a-z`: code point = index + 71 (index 26 -> 97, `'a'`).
    CharRange {
        index_start: 26,
        index_end: 52,
        char_start: 'a',
        char_end: 'z',
        offset: 71,
    },
    // `0-9`: code point = index - 4 (index 52 -> 48, `'0'`).
    CharRange {
        index_start: 52,
        index_end: 62,
        char_start: '0',
        char_end: '9',
        offset: -4,
    },
];
/// Range table returned for the upper-case hexadecimal alphabet
/// `0123456789ABCDEF`; covers all 16 digit indices.
static HEX_UPPER_RANGES: &[CharRange] = &[
    // `0-9`: code point = index + 48.
    CharRange {
        index_start: 0,
        index_end: 10,
        char_start: '0',
        char_end: '9',
        offset: 48,
    },
    // `A-F`: code point = index + 55 (index 10 -> 65, `'A'`).
    CharRange {
        index_start: 10,
        index_end: 16,
        char_start: 'A',
        char_end: 'F',
        offset: 55,
    },
];
/// Range table returned for the lower-case hexadecimal alphabet
/// `0123456789abcdef`; covers all 16 digit indices.
static HEX_LOWER_RANGES: &[CharRange] = &[
    // `0-9`: code point = index + 48.
    CharRange {
        index_start: 0,
        index_end: 10,
        char_start: '0',
        char_end: '9',
        offset: 48,
    },
    // `a-f`: code point = index + 87 (index 10 -> 97, `'a'`).
    CharRange {
        index_start: 10,
        index_end: 16,
        char_start: 'a',
        char_end: 'f',
        offset: 87,
    },
];
/// Structural facts about a dictionary, computed once by
/// `DictionaryMetadata::from_dictionary`.
#[derive(Debug, Clone)]
pub struct DictionaryMetadata {
    /// The dictionary's base as reported by `Dictionary::base`.
    pub base: usize,
    /// `log2(base)` when `base` is a power of two; 0 otherwise.
    pub bits_per_symbol: u8,
    /// How digit indices map onto characters.
    pub strategy: TranslationStrategy,
    /// `true` when `bits_per_symbol` is 4, 5, 6 or 8 and `strategy` is not
    /// `Arbitrary`.
    pub simd_compatible: bool,
}
impl DictionaryMetadata {
    /// Returns the `simd_compatible` flag.
    pub fn simd_available(&self) -> bool {
        self.simd_compatible
    }

    /// Picks a lookup-table strategy from the translation strategy.
    ///
    /// Only `Arbitrary` dictionaries get a lookup table, chosen by size:
    /// at most 16 symbols gives `SmallDirect`, 17 to 64 gives
    /// `LargePlatformDependent`, anything larger gives `ScalarOnly`.
    /// `Sequential` and `Ranged` dictionaries report `NotApplicable`.
    pub fn lut_strategy(&self) -> LutStrategy {
        // Every variant is listed explicitly so that adding a new
        // `TranslationStrategy` forces this decision to be revisited.
        match self.strategy {
            TranslationStrategy::Arbitrary { dictionary_size } => match dictionary_size {
                0..=16 => LutStrategy::SmallDirect,
                17..=64 => LutStrategy::LargePlatformDependent,
                _ => LutStrategy::ScalarOnly,
            },
            TranslationStrategy::Sequential { .. } | TranslationStrategy::Ranged { .. } => {
                LutStrategy::NotApplicable
            }
        }
    }

    /// Analyses `dict` and records its base, bit width, translation strategy
    /// and SIMD compatibility.
    ///
    /// A base that is not a power of two (including 0) is reported as
    /// `Arbitrary` with `bits_per_symbol == 0` and `simd_compatible == false`.
    /// Otherwise the dictionary is SIMD compatible when it has 4, 5, 6 or 8
    /// bits per symbol and its strategy is not `Arbitrary`.
    pub fn from_dictionary(dict: &Dictionary) -> Self {
        let base = dict.base();
        if !base.is_power_of_two() {
            return Self {
                base,
                bits_per_symbol: 0,
                strategy: TranslationStrategy::Arbitrary {
                    dictionary_size: base,
                },
                simd_compatible: false,
            };
        }

        // For a power of two, log2 is the number of trailing zero bits. This
        // is exact integer arithmetic (no float rounding) and is always below
        // `usize::BITS`, so the narrowing cast cannot truncate.
        let bits_per_symbol = base.trailing_zeros() as u8;
        let strategy = Self::detect_strategy(dict);
        let simd_compatible = matches!(bits_per_symbol, 4 | 5 | 6 | 8)
            && !matches!(strategy, TranslationStrategy::Arbitrary { .. });

        Self {
            base,
            bits_per_symbol,
            strategy,
            simd_compatible,
        }
    }

    /// Classifies the index-to-character mapping of `dict`.
    ///
    /// Returns `Sequential` when consecutive digits map to consecutive code
    /// points, `Ranged` when the alphabet is one of the known tables, and
    /// `Arbitrary` otherwise. When some digit in `0..base` cannot be encoded,
    /// the result is `Arbitrary` sized by the number of digits that could.
    fn detect_strategy(dict: &Dictionary) -> TranslationStrategy {
        let base = dict.base();
        let chars: Vec<char> = (0..base).filter_map(|i| dict.encode_digit(i)).collect();
        if chars.len() != base {
            return TranslationStrategy::Arbitrary {
                dictionary_size: chars.len(),
            };
        }

        // An empty dictionary has no first symbol; report it as arbitrary
        // instead of indexing out of bounds.
        let first_codepoint = match chars.first() {
            Some(&first) => first as u32,
            None => return TranslationStrategy::Arbitrary { dictionary_size: 0 },
        };

        // Pair each symbol with the code point it would have in a run that
        // starts at `first_codepoint`.
        let is_sequential = chars
            .iter()
            .zip(first_codepoint..)
            .all(|(&symbol, expected)| symbol as u32 == expected);
        if is_sequential {
            return TranslationStrategy::Sequential {
                start_codepoint: first_codepoint,
            };
        }

        match Self::detect_ranges(&chars) {
            Some(ranges) => TranslationStrategy::Ranged { ranges },
            None => TranslationStrategy::Arbitrary {
                dictionary_size: base,
            },
        }
    }

    /// Returns the static range table for `chars` when it is exactly one of
    /// the known alphabets, or `None` otherwise.
    fn detect_ranges(chars: &[char]) -> Option<&'static [CharRange]> {
        match chars.len() {
            64 if Self::matches_base64_standard(chars) => Some(BASE64_STANDARD_RANGES),
            64 if Self::matches_base64_url(chars) => Some(BASE64_URL_RANGES),
            16 if Self::matches_hex_upper(chars) => Some(HEX_UPPER_RANGES),
            16 if Self::matches_hex_lower(chars) => Some(HEX_LOWER_RANGES),
            _ => None,
        }
    }

    /// Returns `true` when `chars` spells out `expected` exactly: same
    /// symbols, same order and same length. Unlike a `zip`-based comparison,
    /// a shorter or longer `chars` is rejected rather than compared on the
    /// common prefix only.
    fn matches_alphabet(chars: &[char], expected: &str) -> bool {
        chars.iter().copied().eq(expected.chars())
    }

    /// Returns `true` when `chars` is the standard base64 alphabet.
    fn matches_base64_standard(chars: &[char]) -> bool {
        Self::matches_alphabet(
            chars,
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
        )
    }

    /// Returns `true` when `chars` is the URL-safe base64 alphabet.
    fn matches_base64_url(chars: &[char]) -> bool {
        Self::matches_alphabet(
            chars,
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
        )
    }

    /// Returns `true` when `chars` is the upper-case hexadecimal alphabet.
    fn matches_hex_upper(chars: &[char]) -> bool {
        Self::matches_alphabet(chars, "0123456789ABCDEF")
    }

    /// Returns `true` when `chars` is the lower-case hexadecimal alphabet.
    fn matches_hex_lower(chars: &[char]) -> bool {
        Self::matches_alphabet(chars, "0123456789abcdef")
    }
}
// These tests exercise pure classification logic only, so they are built for
// every target rather than being limited to one architecture.
#[cfg(test)]
// NOTE(review): presumably needed because a `Dictionary` constructor used
// below is deprecated — confirm, or drop the attribute.
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::core::config::EncodingMode;

    /// Builds the standard base64 dictionary (`+`, `/`) with `=` padding.
    fn make_base64_standard_dict() -> Dictionary {
        let chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
            .chars()
            .collect();
        Dictionary::new_with_mode(chars, EncodingMode::Chunked, Some('=')).unwrap()
    }

    /// Builds the URL-safe base64 dictionary (`-`, `_`) with `=` padding.
    fn make_base64_url_dict() -> Dictionary {
        let chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
            .chars()
            .collect();
        Dictionary::new_with_mode(chars, EncodingMode::Chunked, Some('=')).unwrap()
    }

    /// Builds a dictionary from the characters of `alphabet`, in order.
    fn make_dict(alphabet: &str) -> Dictionary {
        Dictionary::new(alphabet.chars().collect()).unwrap()
    }

    #[test]
    fn test_identify_standard_base64() {
        let dict = make_base64_standard_dict();
        assert_eq!(
            identify_base64_variant(&dict),
            Some(DictionaryVariant::Base64Standard)
        );
    }

    #[test]
    fn test_identify_base64_url() {
        let dict = make_base64_url_dict();
        assert_eq!(
            identify_base64_variant(&dict),
            Some(DictionaryVariant::Base64Url)
        );
    }

    #[test]
    fn test_identify_non_base64() {
        let chars: Vec<char> = "0123456789ABCDEF".chars().collect();
        let dict = Dictionary::new(chars).unwrap();
        assert_eq!(identify_base64_variant(&dict), None);
    }

    #[test]
    fn test_identify_unknown_variant() {
        let chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789@$"
            .chars()
            .collect();
        let dict = Dictionary::new_with_mode(chars, EncodingMode::Chunked, None).unwrap();
        assert_eq!(identify_base64_variant(&dict), None);
    }

    #[test]
    fn test_verify_standard_dictionary() {
        let dict = make_base64_standard_dict();
        assert!(verify_base64_dictionary(
            &dict,
            DictionaryVariant::Base64Standard
        ));
        assert!(!verify_base64_dictionary(
            &dict,
            DictionaryVariant::Base64Url
        ));
    }

    #[test]
    fn test_verify_url_dictionary() {
        let dict = make_base64_url_dict();
        assert!(verify_base64_dictionary(
            &dict,
            DictionaryVariant::Base64Url
        ));
        assert!(!verify_base64_dictionary(
            &dict,
            DictionaryVariant::Base64Standard
        ));
    }

    #[test]
    fn test_identify_base32_rfc4648() {
        let dict = make_dict("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567");
        assert_eq!(identify_base32_variant(&dict), Some(Base32Variant::Rfc4648));
        assert!(verify_base32_dictionary(&dict, Base32Variant::Rfc4648));
        assert!(!verify_base32_dictionary(&dict, Base32Variant::Rfc4648Hex));
    }

    #[test]
    fn test_identify_base32_hex() {
        let dict = make_dict("0123456789ABCDEFGHIJKLMNOPQRSTUV");
        assert_eq!(
            identify_base32_variant(&dict),
            Some(Base32Variant::Rfc4648Hex)
        );
        assert!(verify_base32_dictionary(&dict, Base32Variant::Rfc4648Hex));
        assert!(!verify_base32_dictionary(&dict, Base32Variant::Rfc4648));
    }

    #[test]
    fn test_identify_non_base32() {
        // Wrong base.
        let hex = make_dict("0123456789ABCDEF");
        assert_eq!(identify_base32_variant(&hex), None);
        assert!(!verify_base32_dictionary(&hex, Base32Variant::Rfc4648Hex));

        // Right base and right symbols at the spot-checked positions, but
        // two symbols swapped in the body.
        let swapped = make_dict("ACBDEFGHIJKLMNOPQRSTUVWXYZ234567");
        assert_eq!(identify_base32_variant(&swapped), None);
    }

    #[test]
    fn test_sequential_dictionary_detection() {
        let chars: Vec<char> = (0x100..0x140)
            .map(|cp| char::from_u32(cp).unwrap())
            .collect();
        let dict = Dictionary::new_with_mode(chars, EncodingMode::Chunked, None).unwrap();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 64);
        assert_eq!(metadata.bits_per_symbol, 6);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Sequential {
                start_codepoint: 0x100
            }
        ));
        assert!(metadata.simd_compatible);
    }

    #[test]
    fn test_ranged_base64_standard_detection() {
        let dict = make_base64_standard_dict();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 64);
        assert_eq!(metadata.bits_per_symbol, 6);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Ranged { .. }
        ));
        assert!(metadata.simd_compatible);
    }

    #[test]
    fn test_ranged_base64_url_detection() {
        let dict = make_base64_url_dict();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 64);
        assert_eq!(metadata.bits_per_symbol, 6);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Ranged { .. }
        ));
        assert!(metadata.simd_compatible);
    }

    #[test]
    fn test_ranged_hex_upper_detection() {
        let chars: Vec<char> = "0123456789ABCDEF".chars().collect();
        let dict = Dictionary::new(chars).unwrap();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 16);
        assert_eq!(metadata.bits_per_symbol, 4);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Ranged { .. }
        ));
        assert!(metadata.simd_compatible);
    }

    #[test]
    fn test_ranged_hex_lower_detection() {
        let chars: Vec<char> = "0123456789abcdef".chars().collect();
        let dict = Dictionary::new(chars).unwrap();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 16);
        assert_eq!(metadata.bits_per_symbol, 4);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Ranged { .. }
        ));
        assert!(metadata.simd_compatible);
    }

    #[test]
    fn test_arbitrary_dictionary_detection() {
        let chars: Vec<char> = "ZYXWVUTSRQPONMLKJIHGFEDCBAzyxwvutsrqponmlkjihgfedcba9876543210+/"
            .chars()
            .collect();
        let dict = Dictionary::new_with_mode(chars, EncodingMode::Chunked, None).unwrap();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 64);
        assert_eq!(metadata.bits_per_symbol, 6);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Arbitrary {
                dictionary_size: 64
            }
        ));
        assert!(!metadata.simd_compatible);
    }

    #[test]
    fn test_non_power_of_two_detection() {
        let chars: Vec<char> = "0123456789".chars().collect();
        let dict = Dictionary::new(chars).unwrap();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 10);
        assert_eq!(metadata.bits_per_symbol, 0);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Arbitrary {
                dictionary_size: 10
            }
        ));
        assert!(!metadata.simd_compatible);
    }

    #[test]
    fn test_base32_sequential() {
        let chars: Vec<char> = (0x41..0x61).map(|cp| char::from_u32(cp).unwrap()).collect();
        let dict = Dictionary::new(chars).unwrap();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 32);
        assert_eq!(metadata.bits_per_symbol, 5);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Sequential {
                start_codepoint: 0x41
            }
        ));
        assert!(metadata.simd_compatible);
    }

    #[test]
    fn test_base256_sequential() {
        let chars: Vec<char> = (0x100..0x200)
            .map(|cp| char::from_u32(cp).unwrap())
            .collect();
        let dict = Dictionary::new(chars).unwrap();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 256);
        assert_eq!(metadata.bits_per_symbol, 8);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Sequential {
                start_codepoint: 0x100
            }
        ));
        assert!(metadata.simd_compatible);
    }

    #[test]
    fn test_sequential_starting_at_printable() {
        let chars: Vec<char> = (0x21..0x31).map(|cp| char::from_u32(cp).unwrap()).collect();
        let dict = Dictionary::new(chars).unwrap();
        let metadata = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(metadata.base, 16);
        assert_eq!(metadata.bits_per_symbol, 4);
        assert!(matches!(
            metadata.strategy,
            TranslationStrategy::Sequential {
                start_codepoint: 0x21
            }
        ));
        assert!(metadata.simd_compatible);
    }

    #[test]
    fn test_lut_strategy_selection() {
        // Recognised structure: no lookup table.
        let ranged = DictionaryMetadata::from_dictionary(&make_base64_standard_dict());
        assert_eq!(ranged.lut_strategy(), LutStrategy::NotApplicable);
        assert!(ranged.simd_available());

        // Arbitrary, at most 16 symbols.
        let decimal = DictionaryMetadata::from_dictionary(&make_dict("0123456789"));
        assert_eq!(decimal.lut_strategy(), LutStrategy::SmallDirect);
        assert!(!decimal.simd_available());

        // Arbitrary, 17 to 64 symbols.
        let chars: Vec<char> = "ZYXWVUTSRQPONMLKJIHGFEDCBAzyxwvutsrqponmlkjihgfedcba9876543210+/"
            .chars()
            .collect();
        let dict = Dictionary::new_with_mode(chars, EncodingMode::Chunked, None).unwrap();
        let scrambled = DictionaryMetadata::from_dictionary(&dict);
        assert_eq!(scrambled.lut_strategy(), LutStrategy::LargePlatformDependent);

        // Arbitrary, more than 64 symbols.
        let large = DictionaryMetadata {
            base: 100,
            bits_per_symbol: 0,
            strategy: TranslationStrategy::Arbitrary {
                dictionary_size: 100,
            },
            simd_compatible: false,
        };
        assert_eq!(large.lut_strategy(), LutStrategy::ScalarOnly);
    }
}