use crate::base::scan::{ConvertSymbols, EntrySymbol, ScannerError};
use core::fmt;
use octseq::builder::{
EmptyBuilder, FreezeBuilder, FromBuilder, OctetsBuilder, ShortBuf,
};
#[cfg(feature = "std")]
use std::string::String;
/// Decodes the complete Base 64 string `s` into a new octets sequence.
///
/// # Errors
///
/// Returns a [`DecodeError`] if `s` contains a character that is not valid
/// at its position, has input after the padding, ends in the middle of a
/// four-character group, or if the octets builder runs out of space.
pub fn decode<Octets>(s: &str) -> Result<Octets, DecodeError>
where
    Octets: FromBuilder,
    <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
{
    let mut decoder = Decoder::<<Octets as FromBuilder>::Builder>::new();
    s.chars().try_for_each(|ch| decoder.push(ch))?;
    decoder.finalize()
}
/// Writes the Base 64 encoding of `bytes` into `f`.
///
/// Every group of up to three input bytes produces exactly four output
/// characters; a short final group is filled up with `=` padding.
///
/// # Errors
///
/// Returns an error only if writing to `f` fails.
pub fn display<B, W>(bytes: &B, f: &mut W) -> fmt::Result
where
    B: AsRef<[u8]> + ?Sized,
    W: fmt::Write,
{
    // Looks up the character for a 6-bit value.
    let sextet = |idx: u8| ENCODE_ALPHABET[usize::from(idx)];

    for group in bytes.as_ref().chunks(3) {
        // `chunks` never yields an empty slice, so the first byte exists.
        let b0 = group[0];
        let b1 = group.get(1).copied();
        let b2 = group.get(2).copied();

        f.write_char(sextet(b0 >> 2))?;
        f.write_char(sextet(((b0 & 0x03) << 4) | (b1.unwrap_or(0) >> 4)))?;
        match b1 {
            Some(b1) => f.write_char(sextet(
                ((b1 & 0x0F) << 2) | (b2.unwrap_or(0) >> 6),
            ))?,
            None => f.write_char('=')?,
        }
        match b2 {
            Some(b2) => f.write_char(sextet(b2 & 0x3F))?,
            None => f.write_char('=')?,
        }
    }
    Ok(())
}
/// Returns the Base 64 encoding of `bytes` as a new string.
#[cfg(feature = "std")]
pub fn encode_string<B: AsRef<[u8]> + ?Sized>(bytes: &B) -> String {
    let len = bytes.as_ref().len();
    // Every started group of three input bytes yields exactly four
    // characters, so this is the exact output length (the former
    // `(len / 3 + 1) * 4` over-allocated whenever `len % 3 == 0`).
    let mut res = String::with_capacity((len + 2) / 3 * 4);
    display(bytes, &mut res).expect("writing to a String cannot fail");
    res
}
/// Returns a value that formats `octets` as Base 64 when displayed.
///
/// No encoding happens until the returned value is actually formatted.
pub fn encode_display<Octets: AsRef<[u8]>>(
    octets: &Octets,
) -> impl fmt::Display + '_ {
    /// Borrowing adapter that forwards `Display` to [`display`].
    struct Base64Display<'a> {
        data: &'a [u8],
    }

    impl fmt::Display for Base64Display<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            display(self.data, f)
        }
    }

    Base64Display {
        data: octets.as_ref(),
    }
}
/// Serde support for octets sequences encoded as Base 64.
///
/// `serialize` and `deserialize` have the signatures expected by serde's
/// `with` field attribute. Human-readable formats use a Base 64 string;
/// other formats use the octets sequence's own binary representation.
#[cfg(feature = "serde")]
pub mod serde {
    use super::encode_display;
    use core::fmt;
    use octseq::builder::{EmptyBuilder, FromBuilder, OctetsBuilder};
    use octseq::serde::{DeserializeOctets, SerializeOctets};

    /// Serializes `octets` as a Base 64 string for human-readable formats
    /// and via [`SerializeOctets`] otherwise.
    ///
    /// # Errors
    ///
    /// Returns whatever error the serializer reports.
    pub fn serialize<Octets, S>(
        octets: &Octets,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        Octets: AsRef<[u8]> + SerializeOctets,
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            serializer.collect_str(&encode_display(octets))
        } else {
            octets.serialize_octets(serializer)
        }
    }

    /// Deserializes an octets sequence written by [`serialize`].
    ///
    /// For human-readable formats the input is requested as a string and
    /// decoded as Base 64. Otherwise deserialization is delegated to
    /// [`DeserializeOctets::deserialize_with_visitor`].
    ///
    /// # Errors
    ///
    /// Fails if a string is not valid Base 64 or if the deserializer
    /// reports an error.
    pub fn deserialize<'de, Octets, D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Octets, D::Error>
    where
        Octets: FromBuilder + DeserializeOctets<'de>,
        <Octets as FromBuilder>::Builder: EmptyBuilder,
    {
        /// Visitor that decodes strings as Base 64 and forwards byte input
        /// to the octets type's own visitor.
        struct Visitor<'de, Octets: DeserializeOctets<'de>>(Octets::Visitor);

        impl<'de, Octets> serde::de::Visitor<'de> for Visitor<'de, Octets>
        where
            Octets: FromBuilder + DeserializeOctets<'de>,
            <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
        {
            type Value = Octets;

            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("a Base64-encoded string")
            }

            fn visit_str<E: serde::de::Error>(
                self,
                v: &str,
            ) -> Result<Self::Value, E> {
                super::decode(v).map_err(E::custom)
            }

            fn visit_borrowed_bytes<E: serde::de::Error>(
                self,
                value: &'de [u8],
            ) -> Result<Self::Value, E> {
                self.0.visit_borrowed_bytes(value)
            }

            #[cfg(feature = "std")]
            fn visit_byte_buf<E: serde::de::Error>(
                self,
                value: std::vec::Vec<u8>,
            ) -> Result<Self::Value, E> {
                self.0.visit_byte_buf(value)
            }
        }

        if deserializer.is_human_readable() {
            deserializer.deserialize_str(Visitor(Octets::visitor()))
        } else {
            Octets::deserialize_with_visitor(
                deserializer,
                Visitor(Octets::visitor()),
            )
        }
    }
}
/// Incremental Base 64 decoder that appends decoded octets to a builder.
///
/// Characters are fed one at a time via `Decoder::push`; the decoded
/// octets are obtained via `Decoder::finalize`.
pub struct Decoder<Builder> {
    /// Sextet values of the four-character group currently being collected.
    /// A padding character is stored as `0x80`.
    buf: [u8; 4],
    /// Number of characters collected in `buf`, or `0xF0` once a padded
    /// (and therefore final) group has been processed.
    next: usize,
    /// The builder receiving decoded octets, or the error that `push`
    /// recorded as terminating decoding.
    target: Result<Builder, DecodeError>,
}
impl<Builder: EmptyBuilder> Decoder<Builder> {
    /// Creates a decoder whose output goes into a fresh, empty builder.
    #[must_use]
    pub fn new() -> Self {
        let target = Ok(Builder::empty());
        Self {
            target,
            next: 0,
            buf: [0u8; 4],
        }
    }
}
impl<Builder: OctetsBuilder> Decoder<Builder> {
pub fn finalize(self) -> Result<Builder::Octets, DecodeError>
where
Builder: FreezeBuilder,
{
let (target, next) = (self.target, self.next);
target.and_then(|bytes| {
if next & 0x0F != 0 {
Err(DecodeError::ShortInput)
} else {
Ok(bytes.freeze())
}
})
}
pub fn push(&mut self, ch: char) -> Result<(), DecodeError> {
if self.next == 0xF0 {
self.target = Err(DecodeError::TrailingInput);
return Err(DecodeError::TrailingInput);
}
let val = if ch == PAD {
if self.next < 2 {
return Err(DecodeError::IllegalChar(ch));
}
0x80 } else {
if ch > (127 as char) {
return Err(DecodeError::IllegalChar(ch));
}
let val = DECODE_ALPHABET[ch as usize];
if val == 0xFF {
return Err(DecodeError::IllegalChar(ch));
}
val
};
self.buf[self.next] = val;
self.next += 1;
if self.next == 4 {
let target = self.target.as_mut().unwrap(); target
.append_slice(&[(self.buf[0] << 2) | (self.buf[1] >> 4)])
.map_err(Into::into)?;
if self.buf[2] != 0x80 {
target
.append_slice(&[(self.buf[1] << 4) | (self.buf[2] >> 2)])
.map_err(Into::into)?;
}
if self.buf[3] != 0x80 {
if self.buf[2] == 0x80 {
return Err(DecodeError::TrailingInput);
}
target
.append_slice(&[(self.buf[2] << 6) | self.buf[3]])
.map_err(Into::into)?;
self.next = 0
} else {
self.next = 0xF0
}
}
Ok(())
}
}
impl<Builder: EmptyBuilder> Default for Decoder<Builder> {
fn default() -> Self {
Self::new()
}
}
/// Converts a sequence of scanner symbols holding Base 64 data into octets.
///
/// Decoded octets are handed out one complete four-character group at a
/// time through its [`ConvertSymbols`] implementation.
#[derive(Clone, Debug, Default)]
pub struct SymbolConverter {
    /// Sextet values of the group being collected; a padding character is
    /// stored as `PAD_MARKER`.
    input: [u8; 4],
    /// Number of characters collected in `input`, or `EOF_MARKER` after a
    /// padded (final) group.
    next: usize,
    /// Octets decoded from the most recently completed group.
    output: [u8; 3],
}
impl SymbolConverter {
    /// Creates a converter in its initial state.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Feeds one Base 64 character into the converter.
    ///
    /// Returns `Ok(Some(octets))` when `ch` completes a four-character
    /// group and `Ok(None)` while a group is still being collected.
    ///
    /// # Errors
    ///
    /// Fails if `ch` is not in the Base 64 alphabet, if padding appears in
    /// one of the first two positions of a group, if a data character
    /// follows padding, or if any input follows a padded group. Once a
    /// completed group has been rejected, every further call fails too.
    fn process_char<Error: ScannerError>(
        &mut self,
        ch: char,
    ) -> Result<Option<&[u8]>, Error> {
        if self.next == EOF_MARKER {
            return Err(Error::custom("trailing Base 64 data"));
        }
        if self.next >= self.input.len() {
            // A previously completed group was rejected and left `next` at
            // 4. Keep rejecting input instead of indexing past `input`.
            return Err(Error::custom("illegal Base 64 data"));
        }
        let val = if ch == PAD {
            // Padding may only fill the last one or two positions of a group.
            if self.next < 2 {
                return Err(Error::custom("illegal Base 64 data"));
            }
            PAD_MARKER
        } else {
            if ch > (127 as char) {
                return Err(Error::custom("illegal Base 64 data"));
            }
            let val = DECODE_ALPHABET[ch as usize];
            if val == 0xFF {
                return Err(Error::custom("illegal Base 64 data"));
            }
            val
        };
        self.input[self.next] = val;
        self.next += 1;
        if self.next < 4 {
            return Ok(None);
        }

        self.output[0] = (self.input[0] << 2) | (self.input[1] >> 4);
        if self.input[2] == PAD_MARKER {
            if self.input[3] == PAD_MARKER {
                self.next = EOF_MARKER;
                Ok(Some(&self.output[..1]))
            } else {
                // Data after padding; `next` stays at 4 (see the guard above).
                Err(Error::custom("illegal Base 64 data"))
            }
        } else {
            self.output[1] = (self.input[1] << 4) | (self.input[2] >> 2);
            if self.input[3] == PAD_MARKER {
                self.next = EOF_MARKER;
                Ok(Some(&self.output[..2]))
            } else {
                self.output[2] = (self.input[2] << 6) | self.input[3];
                self.next = 0;
                Ok(Some(&self.output))
            }
        }
    }
}
impl<Sym, Error> ConvertSymbols<Sym, Error> for SymbolConverter
where
    Sym: Into<EntrySymbol>,
    Error: ScannerError,
{
    /// Feeds one scanner symbol into the converter.
    ///
    /// End-of-token markers produce no output and leave the partially
    /// collected group untouched; any other symbol must be a character.
    fn process_symbol(
        &mut self,
        symbol: Sym,
    ) -> Result<Option<&[u8]>, Error> {
        let symbol = match symbol.into() {
            EntrySymbol::EndOfToken => return Ok(None),
            EntrySymbol::Symbol(symbol) => symbol,
        };
        let ch = symbol
            .into_char()
            .map_err(|_| Error::custom("illegal Base 64 data"))?;
        self.process_char(ch)
    }

    /// Checks that the input did not end in the middle of a group.
    fn process_tail(&mut self) -> Result<Option<&[u8]>, Error> {
        // Both 0 and `EOF_MARKER` have an empty low nibble; anything else
        // means a group is still incomplete.
        match self.next & 0x0F {
            0 => Ok(None),
            _ => Err(Error::custom("incomplete Base 64 data")),
        }
    }
}
/// An error happened while decoding Base 64 data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// A character that is not valid at its position was encountered.
    IllegalChar(char),
    /// Input followed the padding that ends the encoded data.
    TrailingInput,
    /// The input ended in the middle of a four-character group.
    ShortInput,
    /// The octets builder ran out of space.
    ShortBuf,
}
impl From<ShortBuf> for DecodeError {
fn from(_: ShortBuf) -> Self {
DecodeError::ShortBuf
}
}
impl fmt::Display for DecodeError {
    /// Writes a short, lower-case description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IllegalChar(ch) => write!(f, "illegal character '{}'", ch),
            Self::TrailingInput => f.write_str("trailing input"),
            Self::ShortInput => f.write_str("incomplete input"),
            // Reuse the builder error's own message.
            Self::ShortBuf => ShortBuf.fmt(f),
        }
    }
}
// `DecodeError` is a standard error type when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for DecodeError {}
/// Maps an ASCII code point to its 6-bit Base 64 value; `0xFF` marks
/// characters outside the alphabet. Non-ASCII characters must be rejected
/// before indexing, since the table only has 128 entries.
const DECODE_ALPHABET: [u8; 128] = [
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x3F, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, ];
/// Maps a 6-bit value to its Base 64 character (`A`-`Z`, `a`-`z`, `0`-`9`,
/// `+`, `/`).
const ENCODE_ALPHABET: [char; 64] = [
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/', ];
/// The padding character that fills up a short final group.
const PAD: char = '=';
/// Sextet value stored for a padding character. It lies outside the 6-bit
/// range, so it cannot be confused with a real data value.
const PAD_MARKER: u8 = 0x80;
/// Value of a converter's `next` field after a padded (final) group has
/// been processed; no further input is accepted in that state.
const EOF_MARKER: usize = 0xF0;
#[cfg(all(test, feature = "std"))]
mod test {
    use super::*;
    use std::vec::Vec;

    /// Binary data paired with its canonical Base 64 encoding (the RFC 4648
    /// section 10 test vectors).
    const HAPPY_CASES: &[(&[u8], &str)] = &[
        (b"", ""),
        (b"f", "Zg=="),
        (b"fo", "Zm8="),
        (b"foo", "Zm9v"),
        (b"foob", "Zm9vYg=="),
        (b"fooba", "Zm9vYmE="),
        (b"foobar", "Zm9vYmFy"),
    ];

    #[test]
    fn decode_str() {
        fn decode(s: &str) -> Result<Vec<u8>, DecodeError> {
            super::decode(s)
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&decode(text).unwrap(), bin, "decode {}", text)
        }
        assert_eq!(decode("FPucA").unwrap_err(), DecodeError::ShortInput);
        assert_eq!(
            decode("FPucA=").unwrap_err(),
            DecodeError::IllegalChar('=')
        );
        assert_eq!(decode("FPucAw=").unwrap_err(), DecodeError::ShortInput);
        assert_eq!(
            decode("FPucAw=a").unwrap_err(),
            DecodeError::TrailingInput
        );
        assert_eq!(
            decode("FPucAw==a").unwrap_err(),
            DecodeError::TrailingInput
        );
        // Characters outside the alphabet, including non-ASCII ones.
        assert_eq!(
            decode("Zm9v!").unwrap_err(),
            DecodeError::IllegalChar('!')
        );
        assert_eq!(
            decode("Zm\u{e9}v").unwrap_err(),
            DecodeError::IllegalChar('\u{e9}')
        );
        // Padding in the first position of a group.
        assert_eq!(decode("=").unwrap_err(), DecodeError::IllegalChar('='));
        // Data after padding within a group.
        assert_eq!(decode("Zg=a").unwrap_err(), DecodeError::TrailingInput);
    }

    #[test]
    fn symbol_converter() {
        use crate::base::scan::Symbols;

        fn decode(s: &str) -> Result<Vec<u8>, std::io::Error> {
            let mut convert = SymbolConverter::new();
            let convert: &mut dyn ConvertSymbols<_, std::io::Error> =
                &mut convert;
            let mut res = Vec::new();
            for sym in Symbols::new(s.chars()) {
                if let Some(octs) = convert.process_symbol(sym)? {
                    res.extend_from_slice(octs);
                }
            }
            if let Some(octs) = convert.process_tail()? {
                res.extend_from_slice(octs);
            }
            Ok(res)
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&decode(text).unwrap(), bin, "convert {}", text)
        }
        // Incomplete group, data after padding, input after the final
        // group, illegal character, premature padding.
        for &bad in &["Zg=", "Zg=a", "Zg==Zg==", "Zg!=", "Z==="] {
            assert!(decode(bad).is_err(), "convert {}", bad);
        }
    }

    #[test]
    fn display_bytes() {
        fn fmt(s: &[u8]) -> String {
            let mut out = String::new();
            display(s, &mut out).unwrap();
            out
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&fmt(bin), text, "fmt {}", text);
        }
    }

    #[test]
    fn encode() {
        use core::fmt::Write as _;

        for (bin, text) in HAPPY_CASES {
            assert_eq!(encode_string(bin), *text, "encode_string {}", text);
            let mut out = String::new();
            write!(out, "{}", encode_display(bin)).unwrap();
            assert_eq!(out, *text, "encode_display {}", text);
        }
    }
}