1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use crate::common::*;
use crate::decode::DecodeError;
use mimalloc::MiMalloc;
mod common;
mod decode;
mod encode;
// Registers `MiMalloc` (from the `mimalloc` crate) as the global allocator, so
// heap allocations such as the `Vec` buffers built in `encode` and `decode`
// are served by it instead of the default allocator.
// NOTE(review): if this file is a library root, the allocator choice is imposed
// on every binary that links the crate — confirm that this is intended.
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
/// Encodes `input` as Base64 text and returns it as an owned `String`.
///
/// The work is split in two passes: `encode::encode_u128_chunks` handles the
/// bulk of the input, then `encode::encode_u128_remainder` handles whatever is
/// left after the last whole chunk and reports how many output bytes it wrote.
///
/// The output buffer is allocated up front with room for one extra 4-byte
/// group and is trimmed to the exact length once encoding is finished.
pub fn encode(input: &[u8]) -> String {
    // Four output bytes per three input bytes, plus one spare group for the tail.
    let capacity = ((input.len() / 3) + 1) * 4;
    let mut out = vec![0; capacity];

    // Offsets at which the whole-chunk region ends in the input and output.
    let whole_chunks = input.len() / (ENC_CHUNK_SIZE * 3);
    let consumed = ENC_CHUNK_SIZE * whole_chunks * 3;
    let produced = ENC_CHUNK_SIZE * whole_chunks * 4;

    encode::encode_u128_chunks(input, &mut out);
    let tail_len = encode::encode_u128_remainder(&input[consumed..], &mut out[produced..]);

    // Drop the unused part of the over-allocated buffer.
    out.truncate(produced + tail_len);

    // SAFETY: the buffer starts zero-filled (NUL is valid UTF-8) and is only
    // written by the two encode helpers above.
    // NOTE(review): this relies on those helpers emitting ASCII bytes only —
    // confirm against `encode.rs`.
    unsafe { String::from_utf8_unchecked(out) }
}
/// Decodes the Base64 text in `encoded` into the bytes it represents.
///
/// The input is split into a leading run of whole chunks, handled by
/// `decode::decode_u64_chunks`, and a trailing remainder, handled by
/// `decode::decode_u64_remainder`, which reports how many bytes it produced.
///
/// # Errors
///
/// Returns the `DecodeError` reported by either helper; the tests in this file
/// show `DecodeError::InvalidByte` for bytes outside the Base64 alphabet.
pub fn decode(encoded: &str) -> Result<Vec<u8>, DecodeError> {
    let bytes = encoded.as_bytes();

    // Three output bytes for every (possibly partial) group of four input bytes.
    let groups = (bytes.len() + 3) / 4;
    let mut out = vec![0; groups * 3];

    // The chunk count is derived from `len - 2`, so the final bytes always go
    // through the remainder routine (presumably because they may be `=`
    // padding — confirm against `decode.rs`).
    let whole_chunks = bytes.len().saturating_sub(2) / (DEC_CHUNK_SIZE * 4);
    let (head, tail) = bytes.split_at(whole_chunks * DEC_CHUNK_SIZE * 4);
    let head_out = whole_chunks * DEC_CHUNK_SIZE * 3;

    decode::decode_u64_chunks(head, &mut out)?;
    let tail_out = decode::decode_u64_remainder(tail, &mut out[head_out..])?;

    // Trim the over-allocated buffer down to the bytes actually produced.
    out.truncate(head_out + tail_out);
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    // Known plain-text / Base64 pairs checked in the encoding direction.
    #[test]
    fn should_encode_following_base64_spec() {
        let zero_word = 0u32.to_ne_bytes();
        let cases: [(&[u8], &str); 6] = [
            (b"Hello!", "SGVsbG8h"),
            (b"0123456789", "MDEyMzQ1Njc4OQ=="),
            (
                b"https://foo.bar/q?a=2&b=3#fr",
                "aHR0cHM6Ly9mb28uYmFyL3E/YT0yJmI9MyNmcg==",
            ),
            // "ICA=" is the encoding of TWO spaces (0x20 0x20); a single space
            // would be "IA==". Hex escapes keep the literal robust against
            // whitespace collapsing.
            (b"\x20\x20", "ICA="),
            (b"", ""),
            (&zero_word, "AAAAAA=="),
        ];
        for &(plain, expected) in &cases {
            assert_eq!(encode(plain), expected);
        }
    }

    // The same pairs checked in the decoding direction.
    #[test]
    fn should_decode_following_base64_spec() {
        let zero_word = 0u32.to_ne_bytes();
        let cases: [(&str, &[u8]); 6] = [
            ("SGVsbG8h", b"Hello!"),
            ("MDEyMzQ1Njc4OQ==", b"0123456789"),
            (
                "aHR0cHM6Ly9mb28uYmFyL3E/YT0yJmI9MyNmcg==",
                b"https://foo.bar/q?a=2&b=3#fr",
            ),
            // Two spaces, matching the "ICA=" vector (see the encode test).
            ("ICA=", b"\x20\x20"),
            ("", b""),
            ("AAAAAA==", &zero_word),
        ];
        for &(encoded, expected) in &cases {
            assert_eq!(decode(encoded).unwrap(), expected);
        }
    }

    // Round-trips random payloads of every length below 512 plus one large
    // (3 MiB) payload through `encode` and `decode`.
    #[test]
    fn should_preserve_original_input() {
        let sizes = (0..512).chain(std::iter::once(3 * 1024 * 1024));
        for size in sizes {
            let bytes = random_bytes(size);
            assert_eq!(decode(&encode(&bytes)).unwrap(), bytes);
        }
    }

    // Bytes outside the Base64 alphabet must be reported, not silently skipped.
    #[test]
    fn should_error_when_decode_non_base64_input() {
        let cases = [
            ("AAA^AAA==", b'^'),
            ("!", b'!'),
            ("\nNjc4OQ==", b'\n'),
        ];
        for &(encoded, bad_byte) in &cases {
            assert_eq!(
                decode(encoded).unwrap_err(),
                DecodeError::InvalidByte(bad_byte)
            );
        }
        assert_eq!(
            "Invalid non-base64 byte '#'",
            format!("{}", DecodeError::InvalidByte(b'#'))
        );
    }

    // Builds `size` random bytes from an entropy-seeded `SmallRng`; the data
    // differs on every run.
    fn random_bytes(size: usize) -> Vec<u8> {
        let mut rng = SmallRng::from_entropy();
        (0..size).map(|_| rng.gen::<u8>()).collect()
    }
}