use std::collections::HashMap;
/// Encodes `bytes` as Base64 text using the standard alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`) with `=` as the padding character.
///
/// This is a convenience wrapper around [`encode_with_alphabet`].
pub fn encode(bytes: &Vec<u8>) -> String {
    // 64 value symbols followed by the padding character.
    const STANDARD_ALPHABET: &str =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";
    encode_with_alphabet(bytes, &STANDARD_ALPHABET.to_string())
}
/// Decodes Base64 text written with the standard alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`) and `=` padding back into raw bytes.
///
/// This is a convenience wrapper around [`decode_with_alphabet`] and panics
/// in the same situations that function does.
pub fn decode(base64_string: &String) -> Vec<u8> {
    // 64 value symbols followed by the padding character.
    const STANDARD_ALPHABET: &str =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";
    decode_with_alphabet(base64_string, &STANDARD_ALPHABET.to_string())
}
/// Encodes `bytes` as Base64 text using a caller-supplied alphabet.
///
/// `alphabet` must contain exactly 65 distinct characters: the symbols for
/// the values 0 through 63 in order, followed by the padding character.
///
/// Empty input returns an empty string immediately, before the alphabet is
/// validated.
///
/// # Panics
///
/// Panics with `"Invalid alphabet!"` when `alphabet` fails validation.
pub fn encode_with_alphabet(bytes: &Vec<u8>, alphabet: &String) -> String {
    if bytes.is_empty() {
        return String::new();
    }
    if !validate_alphabet(alphabet) {
        panic!("Invalid alphabet!");
    }

    let pad_char = get_pad_char(alphabet);
    let symbols = get_encode_lookup(&remove_pad_char(alphabet));
    let pads = get_number_of_pads(bytes);
    let mut output = String::new();

    for group in bytes.chunks(3) {
        // Pack up to three bytes into the top of a 24-bit word; bytes missing
        // from a short final group count as zero.
        let mut number: usize = 0;
        for offset in 0..3 {
            let byte = group.get(offset).copied().unwrap_or(0);
            number = (number << 8) + byte as usize;
        }

        // Only a short final group is padded; it carries fewer data symbols.
        let data_symbols = if group.len() == 3 {
            4
        } else {
            match pads {
                1 => 3,
                2 => 2,
                _ => panic!("pads had invalid value???"),
            }
        };

        for slot in 0..4 {
            if slot < data_symbols {
                // Slot 0 is the most significant 6-bit segment.
                let segment = (number >> (18 - 6 * slot)) & 0x3f;
                output.push(symbols[segment]);
            } else {
                output.push(pad_char);
            }
        }
    }

    output
}
/// Decodes Base64 text that was written with a caller-supplied alphabet.
///
/// `alphabet` must contain exactly 65 distinct characters: the symbols for
/// the values 0 through 63 in order, followed by the padding character.
///
/// An empty `base64_string` returns an empty vector immediately, before the
/// alphabet is validated.
///
/// # Panics
///
/// * `"Invalid alphabet!"` when `alphabet` fails validation.
/// * `"Invalid base64 length! ..."` when the input is not a whole number of
///   4-character groups.
/// * `"Invalid pad_count"` when the input holds more than two padding
///   characters.
/// * `"Padding is only allowed at the end of the input!"` when a padding
///   character appears anywhere but the tail.
/// * `"Didn't find that key! ..."` when a character is not in the alphabet.
pub fn decode_with_alphabet(base64_string: &String, alphabet: &String) -> Vec<u8> {
    if base64_string.is_empty() {
        return Vec::new();
    }
    if !validate_alphabet(alphabet) {
        panic!("Invalid alphabet!");
    }

    let pad_char: char = get_pad_char(alphabet);
    let fixed_alphabet: String = remove_pad_char(alphabet);
    let decode_lookup: HashMap<char, usize> = get_decode_lookup(&fixed_alphabet, &pad_char);
    let base64_vector: Vec<char> = base64_string.chars().collect();

    // Reject truncated input up front instead of failing later with an
    // index-out-of-bounds panic in the middle of a group.
    if base64_vector.len() % 4 != 0 {
        panic!("Invalid base64 length! {}", base64_vector.len());
    }

    let pad_count: usize = get_pad_count(&base64_vector, &pad_char);
    if pad_count > 2 {
        panic!("Invalid pad_count");
    }

    // The lookup maps the padding character to 0 so a padded final group can
    // be decoded like any other. Without this check a padding character in
    // the middle of the input would be silently read as value 0 while still
    // shortening the final group.
    let trailing_pads = base64_vector
        .iter()
        .rev()
        .take_while(|symbol| **symbol == pad_char)
        .count();
    if trailing_pads != pad_count {
        panic!("Padding is only allowed at the end of the input!");
    }

    let group_count = base64_vector.len() / 4;
    let mut output_vector: Vec<u8> = Vec::with_capacity(group_count * 3);

    for (group_index, group) in base64_vector.chunks(4).enumerate() {
        // Rebuild the 24-bit word from four 6-bit values.
        let mut number: usize = 0;
        for symbol in group {
            let value = match decode_lookup.get(symbol) {
                Some(value) => *value,
                None => panic!("Didn't find that key! {}", symbol),
            };
            number = (number << 6) + value;
        }

        let decoded = [
            ((number >> 16) & 0xff) as u8,
            ((number >> 8) & 0xff) as u8,
            (number & 0xff) as u8,
        ];

        // Each padding character in the final group removes one output byte.
        let keep = if group_index + 1 == group_count {
            3 - pad_count
        } else {
            3
        };
        output_vector.extend_from_slice(&decoded[..keep]);
    }

    output_vector
}
/// Builds the encoding table: element `i` of the returned vector is the
/// `i`-th character of `alphabet`.
fn get_encode_lookup(alphabet: &String) -> Vec<char> {
    alphabet.chars().collect()
}
/// Builds the decoding table: each character of `alphabet` maps to its
/// position in the string.
///
/// `pad_char` is inserted last with the value 0, replacing any entry the
/// alphabet itself produced for that character.
fn get_decode_lookup(alphabet: &String, pad_char: &char) -> HashMap<char, usize> {
    let mut symbol_values: HashMap<char, usize> = alphabet
        .chars()
        .enumerate()
        .map(|(position, symbol)| (symbol, position))
        .collect();
    symbol_values.insert(*pad_char, 0);
    symbol_values
}
/// Returns `alphabet` without its final character (the padding symbol).
///
/// Works on whole characters rather than bytes, so multi-byte alphabets are
/// handled correctly. An empty alphabet yields an empty string instead of
/// panicking.
fn remove_pad_char(alphabet: &String) -> String {
    let mut symbols = alphabet.chars();
    // Drop the last character; `as_str` then exposes the untouched prefix.
    symbols.next_back();
    symbols.as_str().to_string()
}
/// Returns how many bytes are missing for `bytes` to fill a whole number of
/// 3-byte groups: 0, 1 or 2.
fn get_number_of_pads(bytes: &Vec<u8>) -> usize {
    match bytes.len() % 3 {
        0 => 0,
        remainder => 3 - remainder,
    }
}
/// Returns the padding symbol, which is the last character of `alphabet`.
///
/// # Panics
///
/// Panics when `alphabet` is empty, because there is no character to return.
fn get_pad_char(alphabet: &String) -> char {
    // Read the last character directly rather than collecting the whole
    // alphabet into a vector first.
    alphabet
        .chars()
        .next_back()
        .expect("alphabet must not be empty: its last character is the padding symbol")
}
/// Counts every occurrence of `pad_char` in `bytes`, wherever it appears.
fn get_pad_count(bytes: &Vec<char>, pad_char: &char) -> usize {
    bytes.iter().filter(|symbol| *symbol == pad_char).count()
}
/// Checks that `alphabet` is usable for encoding and decoding.
///
/// An alphabet is valid when it has exactly 65 characters (64 value symbols
/// plus the padding character) and no character occurs twice. The reason for
/// a rejection is printed to standard output before `false` is returned.
fn validate_alphabet(alphabet: &String) -> bool {
    let symbol_count = alphabet.chars().count();
    if symbol_count != 65 {
        println!("Invalid alphabet length! {}", symbol_count);
        return false;
    }

    let mut seen: HashMap<char, bool> = HashMap::with_capacity(symbol_count);
    for symbol in alphabet.chars() {
        // `insert` returns the previous value when the key was already there.
        if seen.insert(symbol, true).is_some() {
            println!("Duplicate key! {}", symbol);
            return false;
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain-text fixtures. Entry `i` is the decoded form of entry `i` in
    /// `get_base64_strings` and `get_nonstandard_base64_strings`.
    fn get_plain_strings() -> Vec<String> {
        vec![
            String::from("Foo Bar"),
            String::from("This string\nhas newlines and\ttabs"),
            String::from("mañana"),
            String::from("Iñtërnâtiônàlizætiøn☃💩"),
            String::from("🇺🇸🇺🇸"),
        ]
    }

    /// The plain-text fixtures encoded with the standard alphabet.
    fn get_base64_strings() -> Vec<String> {
        vec![
            String::from("Rm9vIEJhcg=="),
            String::from("VGhpcyBzdHJpbmcKaGFzIG5ld2xpbmVzIGFuZAl0YWJz"),
            String::from("bWHDsWFuYQ=="),
            String::from("ScOxdMOrcm7DonRpw7Ruw6BsaXrDpnRpw7hu4piD8J+SqQ=="),
            String::from("8J+HuvCfh7jwn4e68J+HuA=="),
        ]
    }

    /// The plain-text fixtures encoded with the Cyrillic alphabet at index 2
    /// of `get_alphabets`, whose padding character is `+`.
    fn get_nonstandard_base64_strings() -> Vec<String> {
        vec![
            String::from("РёэоЗДИбЬа++"),
            String::from("ХЁбиЬсБтЭЖИиЫёЬЙЪЁЕтЗЁщеЭцриЫёХтЗЁЕнЩАефШЦИт"),
            String::from("ЫЦЖГлЦЕнШП++"),
            String::from("СЬНрЭЛНкЬёыГзжРипыРнпъБлЪЧкГижРипыбншивГьИюСйП++"),
            String::from("ьИюЖноВЯбыгпжшЮъьИюЖнА++"),
        ]
    }

    /// Alphabet fixtures, addressed by index in the tests below:
    /// 0 = standard, 2 = Cyrillic with `+` padding, 4 = too short,
    /// 5 = too long, 6 = correct length but with a repeated symbol.
    /// Indices 1 and 3 are not referenced by the tests in this module.
    fn get_alphabets() -> Vec<String> {
        vec![
            String::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="),
            String::from("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+/="),
            String::from("АБВГДЕЁЖЗИЙКЛМНОПРСТФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстфхцчшщъыьэюя+"),
            String::from("ABCDEFGHIJKLMNOPQRSTUVWXYZабвгдеёжзийклмнопрстyфхцчшщъыьэюя01234+"),
            String::from("Foo"),
            String::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/#$%"),
            // Exactly 65 characters with 'A' repeated, so validation gets
            // past the length check and must fail on the duplicate.
            String::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/A"),
        ]
    }

    mod encode_tests {
        use super::*;

        #[test]
        fn should_encode_empty_vector_correctly() {
            let actual = encode(&vec![]);
            assert_eq!(actual, String::from(""));
        }

        #[test]
        fn should_encode_plain_strings_correctly() {
            let plain_strings = get_plain_strings();
            let base64_strings = get_base64_strings();
            for i in 0..plain_strings.len() {
                let test_vector = plain_strings[i].as_bytes().to_vec();
                let actual = encode(&test_vector);
                assert_eq!(actual, base64_strings[i]);
            }
        }
    }

    mod decode_tests {
        use super::*;

        #[test]
        fn decode_empty_string_should_return_empty_vector() {
            let actual: Vec<u8> = decode(&String::from(""));
            assert_eq!(actual, []);
        }

        #[test]
        fn decode_should_decode_base64_strings_correctly() {
            let base64_strings = get_base64_strings();
            let plain_strings = get_plain_strings();
            for i in 0..base64_strings.len() {
                let actual: String = String::from_utf8(decode(&base64_strings[i]))
                    .unwrap_or_else(|e| panic!("Invalid utf-8 sequence: {:?}", e));
                assert_eq!(actual, plain_strings[i]);
            }
        }
    }

    mod encode_with_alphabet_tests {
        use super::*;

        #[test]
        fn encode_with_alphabet_should_encode_empty_vector_correctly() {
            let alphabet_vector: Vec<String> = get_alphabets();
            let actual = encode_with_alphabet(&vec![], &alphabet_vector[2]);
            assert_eq!(actual, String::from(""));
        }

        #[test]
        fn encode_with_alphabet_should_encode_plain_strings_correctly() {
            let alphabet_vector: Vec<String> = get_alphabets();
            let plain_strings = get_plain_strings();
            let nonstandard_base64_strings = get_nonstandard_base64_strings();
            for i in 0..plain_strings.len() {
                let test_vector = plain_strings[i].as_bytes().to_vec();
                let actual = encode_with_alphabet(&test_vector, &alphabet_vector[2]);
                assert_eq!(actual, nonstandard_base64_strings[i]);
            }
        }
    }

    mod decode_with_alphabet_tests {
        use super::*;

        #[test]
        fn decode_with_alphabet_should_decode_empty_string_correctly() {
            let alphabet_vector: Vec<String> = get_alphabets();
            let actual: Vec<u8> = decode_with_alphabet(&String::from(""), &alphabet_vector[2]);
            assert_eq!(actual, []);
        }

        #[test]
        fn decode_with_alphabet_should_decode_base64_strings_correctly() {
            let alphabet_vector: Vec<String> = get_alphabets();
            let nonstandard_base64_strings = get_nonstandard_base64_strings();
            let plain_strings = get_plain_strings();
            for i in 0..nonstandard_base64_strings.len() {
                let decoded =
                    decode_with_alphabet(&nonstandard_base64_strings[i], &alphabet_vector[2]);
                let actual: String = String::from_utf8(decoded)
                    .unwrap_or_else(|e| panic!("Invalid utf-8 sequence: {:?}", e));
                assert_eq!(actual, plain_strings[i]);
            }
        }
    }

    mod get_encode_lookup_tests {
        use super::*;

        #[test]
        fn should_give_expected_output() {
            let test_string: String = String::from("Test");
            let expected: Vec<char> = vec!['T', 'e', 's', 't'];
            assert_eq!(expected, get_encode_lookup(&test_string));
        }
    }

    mod get_decode_lookup_tests {
        use super::*;

        #[test]
        fn should_build_decode_lookup_correctly() {
            let mut test_data: HashMap<char, usize> = HashMap::new();
            test_data.insert('A', 0);
            test_data.insert('B', 1);
            test_data.insert('C', 2);
            // The padding character is expected to decode to zero.
            test_data.insert('+', 0);
            let test_alphabet: String = String::from("ABC");
            let pad_char: char = '+';
            assert_eq!(test_data, get_decode_lookup(&test_alphabet, &pad_char));
        }
    }

    mod get_pad_char_tests {
        use super::*;

        #[test]
        fn get_pad_char_should_return_correct_char() {
            let alphabet =
                String::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=");
            assert_eq!(get_pad_char(&alphabet), '=');
        }

        #[test]
        fn get_pad_char_should_return_correct_char_nonstandard_alphabet() {
            let alphabet =
                String::from("АБВГДЕЁЖЗИЙКЛМНОПРСТФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстфхцчшщъыьэюя+");
            assert_eq!(get_pad_char(&alphabet), '+');
        }
    }

    mod remove_pad_char_tests {
        use super::*;

        #[test]
        fn remove_pad_char_should_return_modified_alphabet_standard() {
            let alphabet =
                String::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=");
            let expected =
                String::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
            assert_eq!(remove_pad_char(&alphabet), expected);
        }

        #[test]
        fn remove_pad_char_should_return_modified_alphabet_nonstandard() {
            let alphabet =
                String::from("АБВГДЕЁЖЗИЙКЛМНОПРСТФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстфхцчшщъыьэюя+");
            let expected =
                String::from("АБВГДЕЁЖЗИЙКЛМНОПРСТФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстфхцчшщъыьэюя");
            assert_eq!(remove_pad_char(&alphabet), expected);
        }
    }

    mod get_number_of_pads_tests {
        use super::*;

        #[test]
        fn get_number_of_pads_should_return_zero() {
            let test_vec = String::from("Foo").into_bytes();
            assert_eq!(get_number_of_pads(&test_vec), 0);
        }

        #[test]
        fn get_number_of_pads_should_return_one() {
            let test_vec = String::from("Fooox").into_bytes();
            assert_eq!(get_number_of_pads(&test_vec), 1);
        }

        #[test]
        fn get_number_of_pads_should_return_two() {
            let test_vec = String::from("Foo Bar").into_bytes();
            assert_eq!(get_number_of_pads(&test_vec), 2);
        }
    }

    mod get_pad_count_tests {
        use super::*;

        #[test]
        fn get_pad_count_should_return_zero() {
            let input: Vec<char> = String::from("FooBar").chars().collect();
            let pad_char: char = '=';
            let actual: usize = get_pad_count(&input, &pad_char);
            assert_eq!(actual, 0);
        }

        #[test]
        fn get_pad_count_should_return_one() {
            let input: Vec<char> = String::from("FooBar=").chars().collect();
            let pad_char: char = '=';
            let actual: usize = get_pad_count(&input, &pad_char);
            assert_eq!(actual, 1);
        }

        #[test]
        fn get_pad_count_should_return_two() {
            let input: Vec<char> = String::from("FooBar==").chars().collect();
            let pad_char: char = '=';
            let actual: usize = get_pad_count(&input, &pad_char);
            assert_eq!(actual, 2);
        }
    }

    mod validate_alphabet_tests {
        use super::*;

        #[test]
        fn validate_alphabet_should_return_true_with_standard_alphabet() {
            let alphabet_vector: Vec<String> = get_alphabets();
            assert_eq!(validate_alphabet(&alphabet_vector[0]), true);
        }

        #[test]
        fn validate_alphabet_should_return_true_with_nonstandard_alphabet() {
            let alphabet_vector: Vec<String> = get_alphabets();
            assert_eq!(validate_alphabet(&alphabet_vector[2]), true);
        }

        #[test]
        fn validate_alphabet_should_return_false_with_short_alphabet() {
            let alphabet_vector: Vec<String> = get_alphabets();
            assert_eq!(validate_alphabet(&alphabet_vector[4]), false);
        }

        #[test]
        fn validate_alphabet_should_return_false_with_long_alphabet() {
            let alphabet_vector: Vec<String> = get_alphabets();
            assert_eq!(validate_alphabet(&alphabet_vector[5]), false);
        }

        #[test]
        fn validate_alphabet_should_return_false_with_nonunique_symbols() {
            let alphabet_vector: Vec<String> = get_alphabets();
            assert_eq!(validate_alphabet(&alphabet_vector[6]), false);
        }
    }
}