use crate::error::{Error, Result};
/// The 64 output symbols used by the base64 encoder, indexed by 6-bit value:
/// `A`-`Z`, `a`-`z`, `0`-`9`, then `+` and `/`.
const STD_ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Padding byte emitted by the encoder to fill an incomplete final 4-symbol
/// group, and recognised by the decoder.
const PAD: u8 = b'=';
/// Encodes `input` as padded base64 text using [`STD_ALPHABET`].
///
/// Every 3 input bytes become 4 output symbols. A trailing chunk of 1 or 2
/// bytes is zero-extended to 3 bytes and the unused symbol positions are
/// filled with [`PAD`], so the output length is always a multiple of 4.
/// Empty input yields an empty string.
#[must_use]
pub fn base64_encode(input: &[u8]) -> String {
    // Maps the low 6 bits of `sextet` to its alphabet symbol.
    let symbol = |sextet: u32| char::from(STD_ALPHABET[(sextet & 0x3f) as usize]);

    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for chunk in input.chunks(3) {
        // Zero-extend a short final chunk so the bit packing is uniform.
        let mut triple = [0u8; 3];
        triple[..chunk.len()].copy_from_slice(chunk);
        let word = u32::from_be_bytes([0, triple[0], triple[1], triple[2]]);

        // The first two symbols always carry data from the first byte.
        encoded.push(symbol(word >> 18));
        encoded.push(symbol(word >> 12));
        // The last two symbols exist only when the 2nd / 3rd byte is present.
        encoded.push(if chunk.len() >= 2 {
            symbol(word >> 6)
        } else {
            char::from(PAD)
        });
        encoded.push(if chunk.len() == 3 {
            symbol(word)
        } else {
            char::from(PAD)
        });
    }
    encoded
}
/// Returns the 6-bit value of the base64 symbol `c`, or `None` when `c` is
/// not one of `A`-`Z`, `a`-`z`, `0`-`9`, `+` or `/`.
///
/// The padding byte `=` is not a symbol and therefore maps to `None`.
#[inline]
fn b64_value(c: u8) -> Option<u8> {
    let sextet = match c {
        b'+' => 62,
        b'/' => 63,
        upper if upper.is_ascii_uppercase() => upper - b'A',
        lower if lower.is_ascii_lowercase() => lower - b'a' + 26,
        digit if digit.is_ascii_digit() => digit - b'0' + 52,
        _ => return None,
    };
    Some(sextet)
}
pub fn base64_decode(input: &str) -> Result<Vec<u8>> {
let bytes = input.as_bytes();
if bytes.len() % 4 != 0 {
return Err(Error::MalformedNote(
"base64 length is not a multiple of 4".into(),
));
}
if bytes.is_empty() {
return Ok(Vec::new());
}
let pad = bytes.iter().rev().take_while(|&&c| c == PAD).count();
if pad > 2 {
return Err(Error::MalformedNote("too much base64 padding".into()));
}
let mut out = Vec::with_capacity(bytes.len() / 4 * 3);
for group in bytes.chunks(4) {
let mut acc = 0u32;
let mut real = 0usize;
for (i, &c) in group.iter().enumerate() {
if c == PAD {
acc <<= 6;
} else {
let v = b64_value(c)
.ok_or_else(|| Error::MalformedNote("invalid base64 character".into()))?;
acc = (acc << 6) | u32::from(v);
real = i + 1;
}
}
match real {
4 => {
out.push((acc >> 16) as u8);
out.push((acc >> 8) as u8);
out.push(acc as u8);
}
3 => {
if acc & 0xff != 0 {
return Err(Error::MalformedNote("non-canonical base64".into()));
}
out.push((acc >> 16) as u8);
out.push((acc >> 8) as u8);
}
2 => {
if acc & 0xffff != 0 {
return Err(Error::MalformedNote("non-canonical base64".into()));
}
out.push((acc >> 16) as u8);
}
_ => return Err(Error::MalformedNote("invalid base64 group".into())),
}
}
Ok(out)
}
/// Renders `bytes` as lowercase hexadecimal, two digits per byte with the
/// high nibble first. Empty input yields an empty string.
#[must_use]
pub fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    let mut hex = String::with_capacity(bytes.len() * 2);
    hex.extend(
        bytes
            .iter()
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nibble| DIGITS[usize::from(nibble)]),
    );
    hex
}
pub fn hex_decode(s: &str) -> Result<Vec<u8>> {
let bytes = s.as_bytes();
if bytes.len() % 2 != 0 {
return Err(Error::MalformedNote("odd-length hex".into()));
}
let val = |c: u8| -> Result<u8> {
match c {
b'0'..=b'9' => Ok(c - b'0'),
b'a'..=b'f' => Ok(c - b'a' + 10),
b'A'..=b'F' => Ok(c - b'A' + 10),
_ => Err(Error::MalformedNote("invalid hex character".into())),
}
};
let mut out = Vec::with_capacity(bytes.len() / 2);
for pair in bytes.chunks(2) {
out.push((val(pair[0])? << 4) | val(pair[1])?);
}
Ok(out)
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Checks the encoder against the RFC 4648 "foobar" vectors and the
    /// decoder against a padded and an unpadded subset of them.
    #[test]
    fn base64_known_vectors() {
        let encode_cases = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (plain, encoded) in encode_cases {
            assert_eq!(base64_encode(plain.as_bytes()), encoded);
        }

        let decode_cases = [("Zg==", "f"), ("Zm8=", "fo"), ("Zm9vYmFy", "foobar")];
        for (encoded, plain) in decode_cases {
            assert_eq!(base64_decode(encoded).unwrap(), plain.as_bytes());
        }
    }

    /// Each input violates one decoder rule and must be rejected.
    #[test]
    fn base64_rejects_bad_input() {
        let malformed = [
            "Zg=",  // length is not a multiple of 4
            "Z===", // more than two padding bytes
            "Zm9*", // byte outside the alphabet
            "Zh==", // non-zero bits in the discarded part of the last symbol
        ];
        for bad in malformed {
            assert!(base64_decode(bad).is_err(), "accepted {:?}", bad);
        }
    }

    /// Fixed hex vector in both directions, plus the two decode failures.
    #[test]
    fn hex_roundtrip_and_width() {
        let raw = [0x00, 0x0f, 0xa3, 0xff];
        let text = "000fa3ff";
        assert_eq!(hex_encode(&raw), text);
        assert_eq!(hex_decode(text).unwrap(), raw.to_vec());

        // Odd length, then a non-hex digit.
        assert!(hex_decode("abc").is_err());
        assert!(hex_decode("zz").is_err());
    }

    proptest! {
        /// Decoding an encoding returns the original bytes.
        #[test]
        fn base64_roundtrip(payload: Vec<u8>) {
            let decoded = base64_decode(&base64_encode(&payload)).unwrap();
            prop_assert_eq!(decoded, payload);
        }

        /// Decoding an encoding returns the original bytes.
        #[test]
        fn hex_roundtrip(payload: Vec<u8>) {
            let decoded = hex_decode(&hex_encode(&payload)).unwrap();
            prop_assert_eq!(decoded, payload);
        }
    }
}