use std::{arch::x86_64, mem};
use crate::DidInner;
/// A 16-byte buffer holding either a decoded `did:plc` identifier or an
/// "invalid" marker.
///
/// Byte 0 acts as a tag: `0` means the remaining 15 bytes carry a decoded
/// identifier, any other value means decoding failed.
#[repr(transparent)]
pub struct OptionDidPlc([u8; 16]);

impl OptionDidPlc {
    /// The value returned for input that is not a well-formed `did:plc`:
    /// tag byte `1`, every other byte zero.
    pub const INVALID: OptionDidPlc = {
        let mut raw = [0u8; 16];
        raw[0] = 1;
        OptionDidPlc(raw)
    };
}
impl TryFrom<OptionDidPlc> for DidInner {
    type Error = ();

    /// Reinterprets a successfully decoded buffer as a `DidInner`.
    ///
    /// Returns `Err(())` when the tag byte (byte 0) is non-zero, i.e. when the
    /// decoder flagged the input as invalid.
    fn try_from(val: OptionDidPlc) -> Result<Self, Self::Error> {
        // Compile-time guard: the conversion below relies on `DidInner::Plc`
        // being laid out with a zero byte at offset 0.
        // NOTE(review): this presumes `DidInner` is a 16-byte enum whose first
        // byte is the discriminant — confirm against its definition.
        const {
            let probe =
                unsafe { mem::transmute::<DidInner, [u8; 16]>(DidInner::Plc([0xff; 15])) };
            assert!(probe[0] == 0, "The discriminant of `DidInner::Plc` should be 0");
        }

        if val.0[0] != 0 {
            return Err(());
        }
        // SAFETY: `transmute` enforces equal sizes at compile time, and the
        // const block above checked that a leading 0 byte is what
        // `DidInner::Plc` starts with.
        Ok(unsafe { mem::transmute::<OptionDidPlc, DidInner>(val) })
    }
}
/// Decodes a 32-byte `did:plc:…` string into an [`OptionDidPlc`].
///
/// Picks the AVX2 implementation when the running CPU supports it and the
/// scalar implementation otherwise.
#[inline]
pub fn decode_plc(plc_str: &[u8; 32]) -> OptionDidPlc {
    if !is_x86_feature_detected!("avx2") {
        return decode_plc_non_avx(plc_str);
    }
    // SAFETY: AVX2 availability was verified at runtime just above.
    unsafe { decode_plc_avx2(plc_str) }
}
/// AVX2 implementation of `decode_plc`.
///
/// Checks that the input is `did:plc:` followed by 24 characters from
/// `a..=z` / `2..=7`, turns every character into its 5-bit base32 value and
/// packs those values into the returned 16-byte buffer. Byte 0 of the result
/// is overwritten with the validity tag: `0` when every input byte passed
/// validation, `1` otherwise. The tests in this file compare the output with
/// a reference base32 decoder prefixed by a zero byte.
#[target_feature(enable = "avx2")]
#[inline]
fn decode_plc_avx2(plc_str: &[u8; 32]) -> OptionDidPlc {
// SAFETY: `plc_str` points to exactly 32 readable bytes and the `loadu`
// variant has no alignment requirement.
let data = unsafe { x86_64::_mm256_loadu_si256(plc_str.as_ptr() as _) };
// Per-byte mask computed as `!(data > upper) & (data > lower)` using signed
// comparisons. `_mm256_set_epi64x` lists the highest quadword first, so the
// last argument covers input bytes 0..8:
//   * bytes 8..32: upper = 'z' (0x7a), lower = 'a' - 1 (0x60) => lowercase letter;
//   * bytes 0..8:  upper = "did:plc:" (little-endian 0x3a636c703a646964),
//                  lower = each prefix byte minus one => exact prefix match.
let alpha_mask = {
x86_64::_mm256_andnot_si256(
x86_64::_mm256_cmpgt_epi8(
data,
x86_64::_mm256_set_epi64x(
0x7a7a7a7a7a7a7a7a, 0x7a7a7a7a7a7a7a7a,
0x7a7a7a7a7a7a7a7a,
0x3a636c703a646964, ),
),
x86_64::_mm256_cmpgt_epi8(
data,
x86_64::_mm256_set_epi64x(
0x6060606060606060, 0x6060606060606060,
0x6060606060606060,
0x39626b6f39636863, ),
),
)
};
// Same construction for the digits: bytes 8..32 must satisfy
// '1' (0x31) < byte <= '7' (0x37); bytes 0..8 again must equal the prefix,
// so a correct prefix sets both masks.
let num_mask = {
x86_64::_mm256_andnot_si256(
x86_64::_mm256_cmpgt_epi8(
data,
x86_64::_mm256_set_epi64x(
0x3737373737373737, 0x3737373737373737,
0x3737373737373737,
0x3a636c703a646964, ),
),
x86_64::_mm256_cmpgt_epi8(
data,
x86_64::_mm256_set_epi64x(
0x3131313131313131, 0x3131313131313131,
0x3131313131313131,
0x39626b6f39636863, ),
),
)
};
// Per-byte offset to subtract: 'a' where `alpha_mask` is set (letters map to
// 0..=25), otherwise '2' - 26 (digits map to 26..=31).
let char_to_val = x86_64::_mm256_blendv_epi8(
x86_64::_mm256_set1_epi8((b'2' - 26) as i8),
x86_64::_mm256_set1_epi8(b'a' as i8),
alpha_mask,
);
let values = x86_64::_mm256_sub_epi8(data, char_to_val);
// The input is valid only if every one of the 32 bytes matched the letter
// mask or the digit mask (the prefix bytes match both).
let is_valid = {
let alpha = x86_64::_mm256_movemask_epi8(alpha_mask) as u32;
let num = x86_64::_mm256_movemask_epi8(num_mask) as u32;
let base32 = alpha | num;
base32 == !0 };
// Quadword selection [2, 1, 2, 3]: the prefix quadword is dropped and
// replaced by a copy of quadword 2, so only identifier characters remain.
let reg1 = x86_64::_mm256_permute4x64_epi64::<0b11100110>(values);
// Byte shuffle, identical in both 128-bit halves.
// NOTE(review): this appears to regroup the 5-bit values so that the
// per-32-bit shifts below can be applied — lane bookkeeping not re-derived.
#[rustfmt::skip]
let reg1 = x86_64::_mm256_shuffle_epi8(
reg1,
x86_64::_mm256_set_epi8(
15, 14, 7, 6,
13, 12, 5, 4,
11, 10, 3, 2,
9, 8, 1, 0,
15, 14, 7, 6,
13, 12, 5, 4,
11, 10, 3, 2,
9, 8, 1, 0,
),
);
// Split into the high byte (reg2) and low byte (reg3) of every 16-bit element.
let reg2 = x86_64::_mm256_and_si256(reg1, x86_64::_mm256_set1_epi16(0xff00u16 as i16));
let reg3 = x86_64::_mm256_and_si256(reg1, x86_64::_mm256_set1_epi16(0x00ffu16 as i16));
// Variable shifts per 32-bit element (arguments are listed highest element
// first): right by 2/4/6/8 for reg2, left by 3/1/7/5 for reg3.
// NOTE(review): presumably these move each 5-bit value to its bit offset in
// the packed output — confirm before changing any constant.
let reg2 = x86_64::_mm256_srlv_epi32(
reg2,
x86_64::_mm256_set_epi32(
8, 6, 4, 2, 8, 6, 4, 2, ),
);
let reg3 = x86_64::_mm256_sllv_epi32(
reg3,
x86_64::_mm256_set_epi32(
5, 7, 1, 3, 5, 7, 1, 3, ),
);
// Move the shifted fragments to their output byte positions; an index of -1
// produces a zero byte.
#[rustfmt::skip]
let reg2 = x86_64::_mm256_shuffle_epi8(
reg2,
x86_64::_mm256_set_epi8(
14, 10, 6, 2, 3, 12, 8, 4, -1, -1, -1, 7, -1, -1, -1, -1, 5, -1, -1, -1, -1, 7, -1, -1, 0, 1, 14, 10, 6, 2, 3, -1, ),
);
#[rustfmt::skip]
let reg3 = x86_64::_mm256_shuffle_epi8(
reg3,
x86_64::_mm256_set_epi8(
14, 15, 11, 6, 2, 12, 13, 9, -1, 10, -1, -1, -1, -1, 8, -1, -1, -1, -1, 10, -1, -1, -1, -1, 4, 0, 14, 15, 11, 6, 2, -1, ),
);
// Merge the fragments, then fold 256 bits into 128: quadwords [1, 3] are
// OR-ed with quadwords [0, 2].
let reduce1 = x86_64::_mm256_or_si256(reg2, reg3);
let reduce_hi =
x86_64::_mm256_castsi256_si128(x86_64::_mm256_permute4x64_epi64::<0b1101>(reduce1));
let reduce_lo =
x86_64::_mm256_castsi256_si128(x86_64::_mm256_permute4x64_epi64::<0b1000>(reduce1));
let reduce2 = x86_64::_mm_or_si128(reduce_hi, reduce_lo);
let mut out = OptionDidPlc([0; 16]);
// SAFETY: `out.0` is a 16-byte array and the `storeu` variant has no
// alignment requirement.
unsafe { x86_64::_mm_storeu_si128(out.0.as_mut_ptr() as _, reduce2) };
// Byte 0 is the validity tag; whatever the store put there is replaced.
out.0[0] = if is_valid { 0 } else { 1 };
out
}
/// Scalar implementation of `decode_plc`.
///
/// Returns [`OptionDidPlc::INVALID`] unless `plc_str` is `did:plc:` followed
/// by 24 characters from `a..=z` / `2..=7`. Otherwise the 24 characters are
/// decoded as base32 (five bits each, first character most significant) into
/// bytes `1..16` of the result, with byte 0 left at `0` to mark success.
#[inline]
fn decode_plc_non_avx(plc_str: &[u8; 32]) -> OptionDidPlc {
    /// Packs eight already-validated base32 characters into the low 40 bits
    /// of a `u64`, with the first character in the most significant position.
    #[inline]
    fn pack_bytes(ident_bytes: &[u8]) -> u64 {
        debug_assert_eq!(ident_bytes.len(), 8);
        // First character goes into the most significant byte.
        let bytes = u64::from_be_bytes([
            ident_bytes[0],
            ident_bytes[1],
            ident_bytes[2],
            ident_bytes[3],
            ident_bytes[4],
            ident_bytes[5],
            ident_bytes[6],
            ident_bytes[7],
        ]);
        // Letters (0x61..=0x7a) have bit 6 set, digits (0x32..=0x37) do not,
        // so this yields a 1 in every byte that holds a letter.
        let alpha_mask = 0x4040404040404040_u64;
        let alpha_flags = (bytes & alpha_mask) >> 6;
        // Bring letters down by ('z' - '2' + 1), then subtract 24 from every
        // byte: 'a'..='z' ends up as 0..=25 and '2'..='7' as 26..=31. No
        // byte can borrow from its neighbour because the input was validated.
        let values = bytes - alpha_flags * (b'z' - b'2' + 1) as u64;
        let values = values - 0x1818181818181818_u64;
        if is_x86_feature_detected!("bmi2") {
            // SAFETY: BMI2 support was verified at runtime just above.
            unsafe { x86_64::_pext_u64(values, 0x1f1f1f1f1f1f1f1f) }
        } else {
            // Concatenate the eight 5-bit values, most significant byte first.
            values
                .to_be_bytes()
                .iter()
                .fold(0_u64, |acc, &v| (acc << 5) | u64::from(v))
        }
    }

    let Some(ident) = plc_str.strip_prefix(b"did:plc:") else {
        return OptionDidPlc::INVALID;
    };
    if !ident.iter().all(|&b| matches!(b, b'a'..=b'z' | b'2'..=b'7')) {
        return OptionDidPlc::INVALID;
    }
    debug_assert_eq!(ident.len(), 24);

    let mut out = OptionDidPlc([0u8; 16]);
    // Every group of 8 characters yields 40 bits = 5 output bytes, written
    // big-endian after the tag byte. Safe chunking replaces unchecked
    // indexing; both iterators produce exactly three chunks.
    for (group, dst) in ident.chunks_exact(8).zip(out.0[1..].chunks_exact_mut(5)) {
        let packed = pack_bytes(group).to_be_bytes();
        dst.copy_from_slice(&packed[3..]);
    }
    out
}
/// Writes the textual `did:plc:…` form of `val` into `out`.
///
/// `val` is expected to be the `DidInner::Plc` variant; this is only checked
/// in debug builds. The AVX2 encoder is used when the CPU supports it and the
/// scalar encoder otherwise.
#[allow(dead_code)]
pub fn encode_plc(val: DidInner, out: &mut [u8; 32]) {
    debug_assert!(matches!(val, DidInner::Plc(_)), "Input should be `DidInner::Plc`");
    // SAFETY: `transmute` enforces equal sizes at compile time and every bit
    // pattern is a valid `[u8; 16]`.
    // NOTE(review): assumes `DidInner` has no padding bytes — confirm.
    let raw: [u8; 16] = unsafe { mem::transmute::<DidInner, [u8; 16]>(val) };
    if !is_x86_feature_detected!("avx2") {
        encode_plc_non_avx(raw, out);
        return;
    }
    // SAFETY: AVX2 availability was verified at runtime just above.
    unsafe { encode_plc_avx2(raw, out) }
}
/// AVX2 implementation of `encode_plc`.
///
/// Takes the 16 raw bytes of a value (byte 0 is the discriminant, bytes 1..16
/// the payload) and writes the 32-character `did:plc:…` text into `out`.
/// The payload is split into 5-bit groups which are then mapped to the
/// characters `a..=z` / `2..=7`.
#[target_feature(enable = "avx2")]
#[inline]
fn encode_plc_avx2(bytes_with_discr: [u8; 16], out: &mut [u8; 32]) {
// SAFETY: `bytes_with_discr` is a 16-byte array and the `loadu` variant has
// no alignment requirement.
let data = unsafe { x86_64::_mm_loadu_si128(bytes_with_discr.as_ptr() as _) };
// Duplicate the 16 bytes into both 128-bit halves so each half can pick the
// source bytes it needs.
let data_x2 = x86_64::_mm256_broadcastsi128_si256(data);
// First pipeline: gather source bytes (-1 yields a zero byte), shift each
// 32-bit element left by 2/6/4/8 (arguments are listed highest element
// first), then place the resulting bytes.
// NOTE(review): the index tables were not re-derived; presumably each shift
// aligns one 5-bit group to a byte boundary — confirm before editing.
#[rustfmt::skip]
let half1 = x86_64::_mm256_shuffle_epi8(
data_x2,
x86_64::_mm256_set_epi8(
-1, 15, -1, 10, 12, 13, 7, 8,
-1, 14, -1, 9, 11, 12, 6, 7,
-1, 5, -1, -1, 2, 3, -1, -1,
-1, 4, -1, -1, 1, 2, -1, -1,
)
);
#[rustfmt::skip]
let half1 = x86_64::_mm256_sllv_epi32(half1, x86_64::_mm256_set_epi32(
8, 4,
6, 2,
8, 4,
6, 2,
));
// `half1` supplies the odd-indexed output bytes; even-indexed bytes and the
// lowest eight bytes are zeroed here.
#[rustfmt::skip]
let half1 = x86_64::_mm256_shuffle_epi8(
half1,
x86_64::_mm256_set_epi8(
15, -1, 7, -1, 11, -1, 3, -1,
13, -1, 5, -1, 9, -1, 1, -1,
15, -1, 7, -1, 11, -1, 3, -1,
-1, -1, -1, -1, -1, -1, -1, -1,
)
);
// Second pipeline: same idea with right shifts by 11/7/9/5 per 32-bit
// element.
#[rustfmt::skip]
let half2 = x86_64::_mm256_shuffle_epi8(
data_x2,
x86_64::_mm256_set_epi8(
14, 15, 9, 10, 12, -1, 7, -1,
13, 14, 8, 9, 11, -1, 6, -1,
4, 5, -1, -1, 2, -1, -1, -1,
3, 4, -1, -1, 1, -1, -1, -1,
),
);
#[rustfmt::skip]
let half2 = x86_64::_mm256_srlv_epi32(half2, x86_64::_mm256_set_epi32(
5, 9,
7, 11,
5, 9,
7, 11,
));
// `half2` supplies the even-indexed output bytes; odd-indexed bytes and the
// lowest eight bytes are zeroed here.
#[rustfmt::skip]
let half2 = x86_64::_mm256_shuffle_epi8(
half2,
x86_64::_mm256_set_epi8(
-1, 14, -1, 6, -1, 10, -1, 2,
-1, 12, -1, 4, -1, 8, -1, 0,
-1, 14, -1, 6, -1, 10, -1, 2,
-1, -1, -1, -1, -1, -1, -1, -1,
),
);
// Interleave the two pipelines and keep only the low five bits of each byte.
let combined = x86_64::_mm256_or_si256(half1, half2);
let combined = x86_64::_mm256_and_si256(combined, x86_64::_mm256_set1_epi16(0x1f1f));
// Values above 25 become digits, the rest become letters.
let alpha_or_num_mask =
x86_64::_mm256_cmpgt_epi8(combined, x86_64::_mm256_set1_epi8((b'z' - b'a') as i8));
// Offset to add per byte: 'a' for 0..=25, and '2' - 26 for 26..=31 so that
// 26 maps to '2'.
let add_vec = x86_64::_mm256_blendv_epi8(
x86_64::_mm256_set1_epi8(b'a' as i8),
x86_64::_mm256_set1_epi8((b'2' - (b'z' - b'a') - 1) as i8),
alpha_or_num_mask,
);
let chars = x86_64::_mm256_add_epi8(combined, add_vec);
// SAFETY: `out` is a 32-byte array and the `storeu` variant has no alignment
// requirement.
unsafe { x86_64::_mm256_storeu_si256(out.as_mut_ptr() as _, chars); }
// The first eight stored bytes are placeholders; replace them with the prefix.
out[..8].copy_from_slice(b"did:plc:");
}
/// Scalar implementation of `encode_plc`.
///
/// `bytes_with_discr[0]` is the discriminant and is ignored; the remaining
/// 15 bytes are base32-encoded (lowercase alphabet `a..=z`, `2..=7`, no
/// padding) into `out[8..32]`, and `out[..8]` receives the `did:plc:` prefix.
#[inline]
fn encode_plc_non_avx(bytes_with_discr: [u8; 16], out: &mut [u8; 32]) {
    /// Maps a 5-bit value (`0..=31`) to its base32 character.
    fn byte_to_base32(val: u8) -> u8 {
        match val {
            0..=25 => val + b'a',
            26..=31 => val - 26 + b'2',
            // Callers mask with 0x1f, so larger values cannot occur.
            _ => unreachable!(),
        }
    }

    out[..8].copy_from_slice(b"did:plc:");

    let payload = &bytes_with_discr[1..];
    // 5 payload bytes (40 bits) become 8 characters of 5 bits each.
    for (group, chars) in payload.chunks_exact(5).zip(out[8..].chunks_exact_mut(8)) {
        // A fixed-width `u64` is required here: the value is 40 bits wide and
        // `usize` is narrower than that on 32-bit-pointer targets.
        let packed = u64::from_be_bytes([
            0, 0, 0, group[0], group[1], group[2], group[3], group[4],
        ]);
        for (k, ch) in chars.iter_mut().enumerate() {
            // The first character takes the most significant five bits.
            let shift = 35 - 5 * k;
            *ch = byte_to_base32((packed >> shift) as u8 & 0x1f);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the per-character decode check against the AVX2 decoder.
    /// Fails outright on machines without AVX2.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn individual_bytes_decode_ok_avx2() {
        assert!(is_x86_feature_detected!("avx2"), "AVX2 feature not detected");
        // SAFETY: AVX2 availability was asserted above.
        test_individual_bytes_decode(|x| unsafe { decode_plc_avx2(x) });
    }

    /// Runs the per-character decode check against the scalar decoder.
    #[test]
    fn individual_bytes_decode_ok_non_avx() {
        test_individual_bytes_decode(decode_plc_non_avx);
    }

    /// Starting from an all-`a` identifier, replaces one character at a time
    /// with every other base32 symbol and compares the decoder's output with
    /// the `base32` crate's result (prefixed with a zero discriminant byte).
    /// All mismatches are collected and reported in a single panic.
    fn test_individual_bytes_decode<F: Fn(&[u8; 32]) -> OptionDidPlc>(decoder: F) {
        const ALPHABET: &[u8; 32] = b"abcdefghijklmnopqrstuvwxyz234567";

        let mut did = String::from("did:plc:aaaaaaaaaaaaaaaaaaaaaaaa");
        let mut failures = Vec::new();

        for pos in 8..32 {
            // 'a' is the baseline, so only the other 31 symbols are tried.
            for &symbol in ALPHABET.iter().skip(1) {
                // SAFETY: one ASCII byte is replaced by another ASCII byte,
                // so the string stays valid UTF-8.
                unsafe { did.as_bytes_mut()[pos] = symbol };

                let input: &[u8; 32] = did.as_bytes().as_array().unwrap();
                let decoded = DidInner::try_from(decoder(input))
                    .unwrap_or_else(|_| panic!("Decoder failed on {did}"));

                let mut expected_bytes =
                    base32::decode(base32::Alphabet::Rfc4648Lower { padding: false }, &did[8..])
                        .unwrap();
                expected_bytes.insert(0, 0);

                let decoded_bytes = unsafe { mem::transmute::<DidInner, [u8; 16]>(decoded) };
                if decoded_bytes.as_slice() != expected_bytes.as_slice() {
                    failures.push((did.clone(), decoded_bytes, expected_bytes));
                }
            }
            // Restore the baseline character before moving to the next position.
            // SAFETY: same ASCII-for-ASCII replacement as above.
            unsafe { did.as_bytes_mut()[pos] = b'a' };
        }

        if failures.is_empty() {
            return;
        }

        let ref_did = "did:plc:abcdefghijklmnopqrstuvwx";
        let byte_indices = (0..16)
            .map(|i| format!("{i:02x}"))
            .collect::<Vec<_>>()
            .join(", ");

        let mut report = format!("{} error(s):\n", failures.len());
        report.push_str(" ");
        report.push('\n');
        for (did, result, expected) in failures {
            report += &format!("Ref DID: {ref_did}\n");
            report += &format!("DID: {did}\n");
            report += &format!("Indices: {byte_indices}\n");
            report += &format!("Result: {result:02x?}\n");
            report += &format!("Expected: {expected:02x?}\n\n");
        }
        panic!("{report}");
    }
}