use crate::ALPHABET;
use sha2::{Digest, Sha256};
/// Errors returned when decoding base58 input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
    /// A character outside the base58 alphabet was found at this byte offset of the input.
    InvalidChar(usize),
    /// The trailing 4-byte checksum did not match the double SHA-256 of the payload.
    Checksum,
    /// Checksum validation was requested but the decoded data is shorter than 4 bytes.
    InvalidLength,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidChar(pos) => write!(f, "invalid base58 character at byte offset {pos}"),
            Self::Checksum => f.write_str("base58 checksum mismatch"),
            Self::InvalidLength => {
                f.write_str("decoded data is too short to contain a 4-byte checksum")
            }
        }
    }
}

/// Lets callers propagate decode failures through `Box<dyn Error>` and similar types.
impl std::error::Error for DecodeError {}
/// Decodes a base58 string without verifying any checksum.
///
/// Equivalent to calling [`decode_full`] with `validate_checksum` set to `false`.
///
/// # Errors
///
/// Returns [`DecodeError::InvalidChar`] if `input` contains a character outside the
/// base58 alphabet.
#[inline]
pub fn decode(input: &str) -> Result<Vec<u8>, DecodeError> {
    let validate_checksum = false;
    decode_full(input, validate_checksum)
}
/// Decodes a base58 string, optionally verifying and stripping a 4-byte checksum.
///
/// Each leading `'1'` becomes one leading `0x00` byte; the remaining characters are
/// interpreted as a big-endian base58 number.
///
/// When `validate_checksum` is `true`, the last four decoded bytes must equal the first
/// four bytes of `SHA256(SHA256(payload))`; on success they are removed from the result.
///
/// # Errors
///
/// * [`DecodeError::InvalidChar`] with the byte offset of the first character that is not
///   in the base58 alphabet (any non-ASCII byte included).
/// * [`DecodeError::InvalidLength`] if checksum validation is requested but fewer than
///   four bytes were decoded (this includes empty input).
/// * [`DecodeError::Checksum`] if the checksum does not match.
#[inline]
pub fn decode_full(input: &str, validate_checksum: bool) -> Result<Vec<u8>, DecodeError> {
    let bytes = input.as_bytes();
    let zeros = bytes.iter().take_while(|&&b| b == b'1').count();
    let digits = &bytes[zeros..];

    // `get` instead of indexing: the lookup table need not cover every byte value, and
    // non-ASCII UTF-8 bytes must be reported as invalid rather than panicking.
    if let Some(i) = digits
        .iter()
        .position(|&ch| DIGIT_TO_VAL.get(usize::from(ch)).is_none_or(|&v| v == 255))
    {
        return Err(DecodeError::InvalidChar(zeros + i));
    }

    // Each leading '1' yields exactly one byte. Every other digit carries
    // log(58)/log(256) ≈ 0.732 bytes, so ceil(3/4 · digits) is a safe upper bound.
    let mut output = Vec::with_capacity(zeros + digits.len() - digits.len() / 4);
    if digits.is_empty() {
        output.extend(std::iter::repeat_n(0u8, zeros));
    } else {
        dispatch_decode(&mut output, digits, zeros);
    }
    // Empty input also goes through here so that checksum validation rejects it.
    finish_decode(output, validate_checksum)
}

/// Runs the fastest available decoder on already-validated `digits`.
///
/// Falls back to [`decode_scalar`] for short inputs, CPUs lacking the required extension,
/// builds without the `simd` feature, and target architectures with no SIMD entry point.
fn dispatch_decode(output: &mut Vec<u8>, digits: &[u8], zeros: usize) {
    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        if digits.len() >= 32 && std::arch::is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was verified at runtime on the line above.
            unsafe { decode_simd_x86(output, digits, zeros) };
            return;
        }
    }
    #[cfg(all(feature = "simd", target_arch = "aarch64"))]
    {
        if digits.len() >= 16 && std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: NEON support was verified at runtime on the line above.
            unsafe { decode_simd_arm(output, digits, zeros) };
            return;
        }
    }
    decode_scalar(output, digits, zeros);
}
/// Portable base58 → base256 conversion.
///
/// Appends `zeros` zero bytes followed by the big-endian value of `digits` to `output`.
/// Every byte of `digits` must already be a valid base58 character; `decode_full`
/// validates its input before dispatching here.
// `as u8` below deliberately keeps only the low byte of the running carry.
#[allow(clippy::cast_possible_truncation)]
#[inline]
fn decode_scalar(output: &mut Vec<u8>, digits: &[u8], zeros: usize) {
    // Little-endian accumulator: growing the number pushes at the end (amortized O(1))
    // instead of inserting at the front, which would shift the whole buffer each time.
    let mut limbs: Vec<u8> = Vec::with_capacity(digits.len());
    for &ch in digits {
        // Seeding the carry with the digit value fuses "multiply by 58, then add" into
        // one pass. The carry stays below 2^16 (255 * 58 + small carry), so u32 suffices.
        let mut carry = u32::from(DIGIT_TO_VAL[usize::from(ch)]);
        for limb in &mut limbs {
            carry += u32::from(*limb) * 58;
            *limb = carry as u8;
            carry >>= 8;
        }
        while carry > 0 {
            limbs.push(carry as u8);
            carry >>= 8;
        }
    }
    output.reserve(zeros + limbs.len());
    output.extend(std::iter::repeat_n(0u8, zeros));
    output.extend(limbs.iter().rev());
}
/// AVX2 entry point for base58 digit decoding.
///
/// Currently it only delegates to [`decode_scalar`]; NOTE(review): presumably a
/// placeholder for a vectorized implementation — confirm before relying on a speedup.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`), because this function is compiled with
/// `#[target_feature(enable = "avx2")]`.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
#[target_feature(enable = "avx2")]
unsafe fn decode_simd_x86(output: &mut Vec<u8>, digits: &[u8], zeros: usize) {
    decode_scalar(output, digits, zeros);
}
/// NEON entry point for base58 digit decoding.
///
/// Currently it only delegates to [`decode_scalar`]; NOTE(review): presumably a
/// placeholder for a vectorized implementation — confirm before relying on a speedup.
///
/// # Safety
///
/// The caller must ensure the running CPU supports NEON (e.g. via
/// `is_aarch64_feature_detected!("neon")`), because this function is compiled with
/// `#[target_feature(enable = "neon")]`.
#[cfg(all(target_arch = "aarch64", feature = "simd"))]
#[target_feature(enable = "neon")]
unsafe fn decode_simd_arm(output: &mut Vec<u8>, digits: &[u8], zeros: usize) {
    decode_scalar(output, digits, zeros);
}
/// Applies optional checksum verification to fully decoded bytes.
///
/// With `validate_checksum == false` the bytes are returned untouched. Otherwise the last
/// four bytes are compared against the first four bytes of a double SHA-256 of the rest,
/// and stripped on success.
///
/// # Errors
///
/// * [`DecodeError::InvalidLength`] if fewer than four bytes are available.
/// * [`DecodeError::Checksum`] if the trailing bytes do not match the hash.
fn finish_decode(mut output: Vec<u8>, validate_checksum: bool) -> Result<Vec<u8>, DecodeError> {
    if !validate_checksum {
        return Ok(output);
    }
    // The checksum occupies the final 4 bytes, so anything shorter is malformed.
    let Some(payload_len) = output.len().checked_sub(4) else {
        return Err(DecodeError::InvalidLength);
    };
    let (payload, checksum) = output.split_at(payload_len);
    let double_hash = Sha256::digest(Sha256::digest(payload));
    if double_hash[..4] != *checksum {
        return Err(DecodeError::Checksum);
    }
    output.truncate(payload_len);
    Ok(output)
}
/// Reverse lookup table from an input byte to its base58 digit value.
///
/// It has one slot for every possible `u8`, so indexing with any input byte (including
/// non-ASCII UTF-8 bytes) is in bounds. Bytes absent from `ALPHABET` map to the sentinel
/// `255`.
const DIGIT_TO_VAL: [u8; 256] = {
    let mut table = [255u8; 256];
    let alphabet = &ALPHABET;
    let mut digit = 0u8;
    while (digit as usize) < 58 {
        let code = alphabet[digit as usize] as usize;
        // Guard keeps const evaluation in bounds even if the alphabet held wider values.
        if code < table.len() {
            table[code] = digit;
        }
        digit += 1;
    }
    table
};
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;

    #[test]
    fn decode_known_no_checksum() {
        assert_eq!(decode(""), Ok(vec![]));
        assert_eq!(decode("1"), Ok(vec![0u8]));
        assert_eq!(decode("Cn8eVZg"), Ok(b"hello".to_vec()));
        let encoded = "111114VYJtj3yEDffZem7N3PkK563wkLZZ8RjKzcfY";
        let genesis = hex!("000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f");
        assert_eq!(decode(encoded), Ok(genesis.to_vec()));
        // 'l' (index 4) is not part of the base58 alphabet.
        assert!(matches!(
            decode("invalid!"),
            Err(DecodeError::InvalidChar(4))
        ));
    }

    #[test]
    fn decode_with_checksum() {
        let addr = "1BitcoinEaterAddressDontSendf59kuE";
        let expected_payload = hex!("00759d6677091e973b9e9d99f19c68fbf43e3f05f9");
        assert_eq!(decode_full(addr, true).unwrap(), expected_payload.to_vec());
        let invalid_addr = "1BitcoinEaterAddressDontSendf59kuF";
        assert!(matches!(
            decode_full(invalid_addr, true),
            Err(DecodeError::Checksum)
        ));
    }

    #[test]
    fn decode_length_error() {
        assert!(matches!(
            decode_full("12", true),
            Err(DecodeError::InvalidLength)
        ));
    }

    #[test]
    fn leading_ones_only() {
        // Input made only of '1's decodes to the same number of zero bytes.
        assert_eq!(decode("111"), Ok(vec![0u8; 3]));
        assert_eq!(decode_full("111", true), Err(DecodeError::InvalidLength));
    }

    #[test]
    fn simd_dispatch() {
        assert_eq!(decode("Cn8eVZg"), Ok(b"hello".to_vec()));
        // Round-trip payload sizes whose encodings straddle the SIMD dispatch thresholds
        // (16 and 32 digits), so whichever path is taken is checked against known input.
        for len in 1..=48u8 {
            let data: Vec<u8> = (1..=len).collect();
            let enc = crate::encode(&data);
            assert_eq!(decode(&enc), Ok(data), "payload length {len}");
        }
    }

    #[test]
    fn simd_correctness() {
        let long = b"hello world bsv58 test payload for simd".repeat(10);
        let enc = crate::encode(&long);
        let dec = decode(&enc).unwrap();
        assert_eq!(dec, long);
    }

    #[test]
    fn chunked_correctness() {
        let input = "111114VYJtj3yEDffZem7N3PkK563wkLZZ8RjKzcfY";
        let expected = hex!("000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f");
        let dec = decode(input).unwrap();
        assert_eq!(dec, expected.to_vec());
    }

    #[test]
    fn large_decode() {
        let long = vec![42u8; 1024];
        let enc = crate::encode(&long);
        let dec = decode(&enc).unwrap();
        assert_eq!(dec, long);
    }
}