use std::{collections::HashMap, fmt::Display};
use aes::Aes256;
use cosmian_fpe::ff1::{FF1h, FlexibleNumeralString};
use itertools::Itertools;
use super::AnoError;
use crate::{ano_ensure, core::KEY_LENGTH};
/// Minimum number of distinct values the FF1 input domain must have.
///
/// A plaintext of `n` characters over an alphabet of `r` characters can take `r^n`
/// values; [`min_plaintext_length`] returns the smallest `n` for which that count
/// reaches this threshold. (Presumably mirrors the `radix^minlen >= 1_000_000`
/// domain-size requirement for FF1 in NIST SP 800-38G Rev. 1 — TODO confirm intent.)
pub const RECOMMENDED_THRESHOLD: usize = 1_000_000;

/// Returns the smallest plaintext length `n` such that `alphabet_len^n` is at least
/// [`RECOMMENDED_THRESHOLD`].
///
/// Exact integer arithmetic is used on purpose. A floating-point logarithm can come
/// out a hair above an exact integer answer, for example 6 for a 10-character
/// alphabet. `ceil` would then push that result one too high, depending on the
/// platform's logarithm accuracy.
///
/// An alphabet with fewer than 2 characters can never reach the threshold. For such
/// alphabets `usize::MAX` is returned.
#[must_use]
pub fn min_plaintext_length(alphabet_len: usize) -> usize {
    if alphabet_len < 2 {
        return usize::MAX;
    }
    let mut length = 0;
    // Number of distinct plaintexts of the current `length`. Saturating multiplication
    // avoids overflow for large alphabets; a saturated value already exceeds the
    // threshold, so the loop still terminates with the right answer.
    let mut domain_size: usize = 1;
    while domain_size < RECOMMENDED_THRESHOLD {
        domain_size = domain_size.saturating_mul(alphabet_len);
        length += 1;
    }
    length
}
/// An ordered set of characters used as the numeral domain for FF1
/// format-preserving encryption.
///
/// The constructors in this file keep `chars` sorted and free of duplicates. This is
/// what allows `char_to_position` to locate a character with a binary search.
#[derive(Debug, Clone)]
pub struct Alphabet {
/// The alphabet's characters; a character's index in this vector is its FF1 numeral.
pub(crate) chars: Vec<char>,
/// Minimum number of in-alphabet characters a plaintext must contain to be
/// encrypted, cached from `min_plaintext_length(chars.len())`.
pub(crate) min_text_length: usize,
}
impl TryFrom<&str> for Alphabet {
    type Error = AnoError;

    /// Builds an alphabet from the distinct characters of `alphabet`.
    ///
    /// Characters are sorted and de-duplicated, so neither their order nor their
    /// repetition in the input matters.
    ///
    /// # Errors
    ///
    /// Returns [`AnoError::FPE`] when the alphabet has fewer than 2 or more than 2^16
    /// distinct characters. Every character position must fit in a `u16` numeral
    /// (0..=65535), which allows at most exactly 2^16 characters.
    fn try_from(alphabet: &str) -> Result<Self, Self::Error> {
        let chars = alphabet.chars().sorted().unique().collect_vec();
        // Inclusive upper bound: the previous `>= 1 << 16` test rejected exactly 2^16
        // characters even though the error message (and `u16` numerals) allow it.
        if !(2..=1 << 16).contains(&chars.len()) {
            return Err(AnoError::FPE(format!(
                "Alphabet must contain between 2 and 2^16 characters. This alphabet contains {} \
                 characters",
                chars.len()
            )));
        }
        Ok(Self {
            min_text_length: min_plaintext_length(chars.len()),
            chars,
        })
    }
}
impl TryFrom<&String> for Alphabet {
    type Error = AnoError;

    /// Convenience conversion for owned strings.
    ///
    /// Delegates to the `TryFrom<&str>` implementation, so it has exactly the same
    /// behavior and errors.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        let borrowed: &str = value;
        Self::try_from(borrowed)
    }
}
impl Alphabet {
    /// Creates an alphabet from the distinct characters of `alphabet`.
    ///
    /// This is a named alias for `Alphabet::try_from(alphabet)`.
    ///
    /// # Errors
    ///
    /// Returns an [`AnoError`] when `alphabet` has too few or too many distinct
    /// characters to be used as an FF1 radix.
    pub fn instantiate(alphabet: &str) -> Result<Self, AnoError> {
        Self::try_from(alphabet)
    }

    /// Merges `additional_characters` into the alphabet and refreshes the cached
    /// minimum plaintext length.
    ///
    /// `chars` is re-sorted and de-duplicated so that `char_to_position` can keep
    /// using a binary search.
    fn extend_(&mut self, additional_characters: Vec<char>) {
        self.chars.extend(additional_characters);
        // Sorting then dropping consecutive duplicates yields the same sorted, unique
        // set as `sorted().unique()`, but in place and without a second vector.
        self.chars.sort_unstable();
        self.chars.dedup();
        self.min_text_length = min_plaintext_length(self.chars.len());
    }

    /// Returns the minimum number of in-alphabet characters a plaintext must
    /// contain for [`Alphabet::encrypt`] to accept it.
    #[must_use]
    pub fn minimum_plaintext_length(&self) -> usize {
        self.min_text_length
    }

    /// Adds every character of `additional_characters` to the alphabet.
    ///
    /// Characters already present are ignored. No upper bound on the alphabet size is
    /// enforced here. NOTE(review): an oversized alphabet is presumably only rejected
    /// later, when FF1 is instantiated with it as the radix — TODO confirm.
    pub fn extend_with(&mut self, additional_characters: &str) {
        self.extend_(additional_characters.chars().collect::<Vec<_>>());
    }

    /// Returns the number of distinct characters in the alphabet, which is the radix
    /// passed to FF1.
    #[must_use]
    pub fn alphabet_len(&self) -> usize {
        self.chars.len()
    }

    /// Returns the numeral of `c`, which is its index in the sorted alphabet.
    ///
    /// Returns `None` in two cases:
    /// - `c` is not part of the alphabet;
    /// - its index does not fit in a `u16` numeral. The index is rejected rather than
    ///   silently truncated to a wrong numeral.
    pub(crate) fn char_to_position(&self, c: char) -> Option<u16> {
        let pos = self.chars.binary_search(&c).ok()?;
        u16::try_from(pos).ok()
    }

    /// Returns the character whose numeral is `position`, or `None` if `position` is
    /// out of range for this alphabet.
    pub(crate) fn char_from_position(&self, position: u16) -> Option<char> {
        self.chars.get(usize::from(position)).copied()
    }

    /// Splits `input` into the numerals of its in-alphabet characters and the
    /// characters that are not in the alphabet.
    ///
    /// The returned map is keyed by character index (not byte offset) in `input`.
    /// `debase` uses it to put those characters back at the same place, unencrypted.
    fn rebase(&self, input: &str) -> (Vec<u16>, HashMap<usize, char>) {
        let mut stripped_input: Vec<u16> = vec![];
        let mut non_alphabet_chars = HashMap::<usize, char>::new();
        for (idx, c) in input.chars().enumerate() {
            match self.char_to_position(c) {
                Some(pos) => stripped_input.push(pos),
                None => {
                    non_alphabet_chars.insert(idx, c);
                }
            }
        }
        (stripped_input, non_alphabet_chars)
    }

    /// Inverse of `rebase`: rebuilds a string from numerals plus the out-of-alphabet
    /// characters recorded at their character indices.
    ///
    /// # Errors
    ///
    /// Returns [`AnoError::FPE`] in either of these cases:
    /// - a numeral is out of range for this alphabet;
    /// - there are not enough numerals to fill every position that has no recorded
    ///   out-of-alphabet character.
    fn debase(
        &self,
        stripped_input: Vec<u16>,
        non_alphabet_chars: &HashMap<usize, char>,
    ) -> Result<String, AnoError> {
        let total_len = stripped_input.len() + non_alphabet_chars.len();
        // Consume numerals front-to-back with an iterator. `Vec::remove(0)` shifts the
        // whole vector on every call, which made this quadratic in the input length.
        let mut numerals = stripped_input.into_iter();
        let mut result = String::with_capacity(total_len);
        for i in 0..total_len {
            let c = if let Some(&c) = non_alphabet_chars.get(&i) {
                c
            } else {
                let position = numerals.next().ok_or_else(|| {
                    AnoError::FPE(format!(
                        "missing numeral at index {i} while rebuilding a string of {total_len} \
                         characters"
                    ))
                })?;
                self.char_from_position(position).ok_or_else(|| {
                    AnoError::FPE(format!(
                        "index {} out of bounds for alphabet of size {}",
                        position,
                        self.alphabet_len()
                    ))
                })?
            };
            result.push(c);
        }
        Ok(result)
    }

    /// Encrypts the in-alphabet characters of `plaintext` with FF1 over AES-256.
    /// Out-of-alphabet characters are left in place, unencrypted.
    ///
    /// # Errors
    ///
    /// Returns an [`AnoError`] in any of these cases:
    /// - `plaintext` has fewer in-alphabet characters than
    ///   [`Alphabet::minimum_plaintext_length`];
    /// - `key` is not `KEY_LENGTH` bytes long;
    /// - FF1 instantiation or encryption fails.
    pub fn encrypt(&self, key: &[u8], tweak: &[u8], plaintext: &str) -> Result<String, AnoError> {
        let (stripped_input, non_alphabet_chars) = self.rebase(plaintext);
        ano_ensure!(
            stripped_input.len() >= self.minimum_plaintext_length(),
            "The stripped input length of {} is too short. It should be at least {} given the \
             alphabet length of {}.",
            stripped_input.len(),
            self.minimum_plaintext_length(),
            self.alphabet_len()
        );
        Self::check_key_length(key)?;
        let fpe_ff = FF1h::<Aes256>::new(key, self.alphabet_len() as u32)
            .map_err(|e| AnoError::FPE(format!("failed instantiating FF1: {e}")))?;
        let ciphertext_ns = fpe_ff
            .encrypt(tweak, &FlexibleNumeralString::from(stripped_input))
            .map_err(|e| AnoError::FPE(format!("FF1 encryption failed: {e}")))?;
        self.debase(Vec::<u16>::from(ciphertext_ns), &non_alphabet_chars)
    }

    /// Decrypts a string produced by [`Alphabet::encrypt`] with the same `key` and
    /// `tweak`. Out-of-alphabet characters are left untouched.
    ///
    /// # Errors
    ///
    /// Returns an [`AnoError`] if `key` is not `KEY_LENGTH` bytes long, or if FF1
    /// instantiation or decryption fails.
    pub fn decrypt(&self, key: &[u8], tweak: &[u8], ciphertext: &str) -> Result<String, AnoError> {
        let (stripped_input, non_alphabet_chars) = self.rebase(ciphertext);
        // Same key validation as `encrypt`: a wrong key size becomes an error instead
        // of being handed to the FF1 backend unchecked.
        Self::check_key_length(key)?;
        let fpe_ff = FF1h::<Aes256>::new(key, self.alphabet_len() as u32)
            .map_err(|e| AnoError::FPE(format!("failed instantiating FF1: {e}")))?;
        let plaintext_ns = fpe_ff
            .decrypt(tweak, &FlexibleNumeralString::from(stripped_input))
            .map_err(|e| AnoError::FPE(format!("FF1 decryption failed: {e}")))?;
        self.debase(Vec::<u16>::from(plaintext_ns), &non_alphabet_chars)
    }

    /// Checks that `key` has the `KEY_LENGTH` bytes expected for the AES-256 FF1 key.
    ///
    /// # Errors
    ///
    /// Returns [`AnoError::KeySize`] with the actual and expected lengths otherwise.
    fn check_key_length(key: &[u8]) -> Result<(), AnoError> {
        if key.len() == KEY_LENGTH {
            Ok(())
        } else {
            Err(AnoError::KeySize(key.len(), KEY_LENGTH))
        }
    }
}
impl Display for Alphabet {
    /// Renders the alphabet as the concatenation of its characters, in sorted order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered: String = self.chars.iter().collect();
        f.write_str(&rendered)
    }
}
/// Generates one infallible `Alphabet` constructor per `name => "characters"` pair.
///
/// Each constructor unwraps `Alphabet::try_from` on a literal. Every literal passed
/// in must therefore be a valid alphabet, or the constructor panics when called.
macro_rules! define_alphabet_constructors {
    ($($name:ident => $alphabet:expr),+) => {
        impl Alphabet {
            $(
                #[doc = "Creates an Alphabet with the given alphabet string: `"]
                #[doc = $alphabet]
                #[doc = "`."]
                #[must_use]
                pub fn $name() -> Self {
                    Self::try_from($alphabet).unwrap()
                }
            )+
        }
    };
}
// Predefined alphabets, exposed as `Alphabet::numeric()`, `Alphabet::hexa_decimal()`,
// `Alphabet::alpha_lower()`, `Alphabet::alpha_upper()`, `Alphabet::alpha()` and
// `Alphabet::alpha_numeric()`.
// Each string fixes the radix and the character-to-numeral mapping. Editing one of them
// changes every ciphertext produced with that alphabet.
define_alphabet_constructors! {
numeric => "0123456789",
hexa_decimal => "0123456789abcdef",
alpha_lower => "abcdefghijklmnopqrstuvwxyz",
alpha_upper => "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
alpha => "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ",
alpha_numeric => "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
}
impl Alphabet {
    /// Builds an alphabet from code points given in strictly increasing order.
    ///
    /// Values that are not valid `char`s (e.g. UTF-16 surrogates) are skipped.
    /// Requiring increasing order keeps `chars` sorted and duplicate-free, which
    /// `char_to_position`'s binary search relies on. Debug builds check this.
    fn from_ascending_code_points(code_points: impl IntoIterator<Item = u32>) -> Self {
        let chars = code_points
            .into_iter()
            .filter_map(char::from_u32)
            .collect::<Vec<char>>();
        debug_assert!(
            chars.windows(2).all(|pair| pair[0] < pair[1]),
            "code points must be strictly increasing"
        );
        Self {
            min_text_length: min_plaintext_length(chars.len()),
            chars,
        }
    }

    /// Alphabet of every valid `char` from U+0000 through U+10000 (inclusive).
    ///
    /// Surrogates (U+D800..=U+DFFF) are skipped because they are not valid `char`s.
    // NOTE(review): the inclusive bound `1 << 16` also pulls in U+10000, one code point
    // past the BMP. It is deliberately left unchanged: altering the alphabet changes the
    // radix and therefore every ciphertext produced with it.
    #[must_use]
    pub fn utf() -> Self {
        Self::from_ascending_code_points(0..=1 << 16_u32)
    }

    /// Alphabet of the code points U+4E00..=U+9FFF (the CJK Unified Ideographs block).
    #[must_use]
    pub fn chinese() -> Self {
        Self::from_ascending_code_points(0x4E00..=0x9FFF_u32)
    }

    /// Alphabet made of two ranges:
    /// - U+0021..=U+007E, printable ASCII without the space;
    /// - U+00C0..=U+00FF, which also includes `×` and `÷`.
    #[must_use]
    pub fn latin1sup() -> Self {
        Self::from_ascending_code_points((0x0021..=0x007E_u32).chain(0x00C0..=0x00FF))
    }

    /// Alphabet made of:
    /// - ASCII digits;
    /// - ASCII upper-case and lower-case letters;
    /// - U+00C0..=U+00FF, which also includes `×` and `÷`.
    #[must_use]
    pub fn latin1sup_alphanum() -> Self {
        Self::from_ascending_code_points(
            (0x0030..=0x0039_u32)
                .chain(0x0041..=0x005A)
                .chain(0x0061..=0x007A)
                .chain(0x00C0..=0x00FF),
        )
    }
}