use common::alphabet;
use common::alphabet::Alphabet;
use common::cipher::Cipher;
use num::integer::gcd;
use rulinalg::matrix::{BaseMatrix, BaseMatrixMut, Matrix};
/// A Hill cipher.
///
/// Instances are created through `Cipher::new` (from a ready-made key matrix)
/// or `Hill::from_phrase` (from a phrase whose letters fill the matrix). Both
/// paths validate the key before a `Hill` value is returned.
pub struct Hill {
// Square key matrix. It is converted to `f64` on demand whenever a message is
// encrypted or decrypted.
key: Matrix<isize>,
}
impl Cipher for Hill {
    type Key = Matrix<isize>;
    type Algorithm = Hill;

    /// Initialise a Hill cipher from a key matrix.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// * the key is not a square matrix;
    /// * the key has no (real) inverse;
    /// * the determinant of the key shares a factor with 26, in which case no
    ///   modular inverse exists and decryption would be impossible.
    ///
    /// # Panics
    ///
    /// Panics if the key cannot be converted to a matrix of `f64`.
    fn new(key: Matrix<isize>) -> Result<Hill, &'static str> {
        if key.cols() != key.rows() {
            return Err("Key must be a square matrix.");
        }

        let m: Matrix<f64> = key.clone()
            .try_into()
            .expect("Could not convert Matrix of type `isize` to `f64`.");

        // `m` is cloned for every check except the last so that it stays
        // available for the checks that follow.
        if m.clone().inverse().is_err() {
            return Err("The inverse of this matrix cannot be calculated for decryption.");
        }

        // The determinant is computed in floating point, so it is rounded to
        // the nearest integer rather than truncated (488.999.. must mean 489).
        // This check runs before `calc_inverse_key`, so a determinant without
        // a modular inverse is reported as an error at this point.
        if gcd(m.clone().det().round() as isize, 26) != 1 {
            return Err("The inverse determinant of the key cannot be calculated.");
        }

        if Hill::calc_inverse_key(m).is_err() {
            return Err("The inverse of this matrix cannot be calculated for decryption.");
        }

        Ok(Hill { key: key })
    }

    /// Encrypt `message` with the key matrix.
    ///
    /// The message is processed in chunks whose length equals the number of
    /// rows of the key; a message that does not fill the last chunk is padded
    /// with `'a'`.
    ///
    /// # Errors
    ///
    /// Returns an error if the message contains a symbol that is not part of
    /// the standard alphabet.
    fn encrypt(&self, message: &str) -> Result<String, &'static str> {
        let key: Matrix<f64> = self.key
            .clone()
            .try_into()
            .expect("Could not convert Matrix of type `isize` to `f64`.");

        Hill::transform_message(&key, message)
    }

    /// Decrypt `ciphertext` with the modular inverse of the key matrix.
    ///
    /// Padding that was added during encryption is not removed.
    ///
    /// # Errors
    ///
    /// Returns an error if the inverse key cannot be calculated or if the
    /// ciphertext contains a symbol that is not part of the standard alphabet.
    fn decrypt(&self, ciphertext: &str) -> Result<String, &'static str> {
        let key: Matrix<f64> = self.key
            .clone()
            .try_into()
            .expect("Could not convert Matrix of type `isize` to `f64`.");
        let inverse_key = Hill::calc_inverse_key(key)?;

        Hill::transform_message(&inverse_key, ciphertext)
    }
}
impl Hill {
pub fn from_phrase(phrase: &str, chunk_size: usize) -> Result<Hill, &'static str> {
if chunk_size < 2 {
return Err("The chunk size must be greater than 1.");
}
if chunk_size * chunk_size != phrase.len() {
return Err("The square of the chunk size must equal the length of the phrase.");
}
let mut matrix: Vec<isize> = Vec::new();
for c in phrase.chars() {
match alphabet::STANDARD.find_position(c) {
Some(pos) => matrix.push(pos as isize),
None => return Err("Phrase cannot contain non-alphabetic symbols."),
}
}
let key = Matrix::new(chunk_size, chunk_size, matrix);
Hill::new(key)
}
fn transform_message(key: &Matrix<f64>, message: &str) -> Result<String, &'static str> {
for c in message.chars() {
if alphabet::STANDARD.find_position(c).is_none() {
return Err(
"Invalid message. Please strip any whitespace or non-alphabetic symbols.",
);
}
}
let mut transformed_message = String::new();
let mut buffer = message.to_string();
let chunk_size = key.rows();
if buffer.len() % chunk_size > 0 {
let padding = chunk_size - (buffer.len() % chunk_size);
for _ in 0..padding {
buffer.push('a');
}
}
let mut i = 0;
while i < buffer.len() {
match Hill::transform_chunk(key, &buffer[i..(i + chunk_size)]) {
Ok(s) => transformed_message.push_str(&s),
Err(e) => return Err(e),
}
i += chunk_size;
}
Ok(transformed_message)
}
fn transform_chunk(key: &Matrix<f64>, chunk: &str) -> Result<String, &'static str> {
let mut transformed = String::new();
if key.rows() != chunk.len() {
return Err("Cannot perform transformation on unequal vector lengths");
}
let mut index_representation: Vec<f64> = Vec::new();
for c in chunk.chars() {
index_representation.push(alphabet::STANDARD
.find_position(c)
.expect("Attempted transformation of non-alphabetic symbol")
as f64);
}
let mut product = key * Matrix::new(index_representation.len(), 1, index_representation);
product = product.apply(&|x| (x % 26.0).round());
for (i, pos) in product.iter().enumerate() {
let orig = chunk
.chars()
.nth(i)
.expect("Expected to find char at index.");
transformed.push(
alphabet::STANDARD
.get_letter(*pos as usize, orig.is_uppercase())
.expect("Calculate index is invalid."),
);
}
Ok(transformed)
}
fn calc_inverse_key(key: Matrix<f64>) -> Result<Matrix<f64>, &'static str> {
let det = key.clone().det();
let det_inv = alphabet::STANDARD
.multiplicative_inverse(det as isize)
.expect("Inverse for determinant could not be found.");
Ok(key.inverse().unwrap().apply(&|x| {
let y = (x * det as f64).round() as isize;
(alphabet::STANDARD.modulo(y) as f64 * det_inv as f64) % 26.0
}))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The valid 3x3 key shared by several of the tests below.
    fn reference_key() -> Matrix<isize> {
        Matrix::new(3, 3, vec![2, 4, 5, 9, 2, 1, 3, 17, 7])
    }

    #[test]
    fn keygen_from_phrase() {
        let cipher = Hill::from_phrase("CEFJCBDRH", 3);
        assert!(cipher.is_ok());
    }

    #[test]
    fn invalid_phrase() {
        // Six letters cannot fill a 2x2 matrix.
        let cipher = Hill::from_phrase("killer", 2);
        assert!(cipher.is_err());
    }

    #[test]
    fn encrypt_no_padding_req() {
        let cipher = Hill::new(reference_key()).unwrap();
        let plaintext = "ATTACKatDAWN";

        // Twelve letters fill four chunks exactly, so a round trip must give
        // back the original message.
        let ciphertext = cipher.encrypt(plaintext).unwrap();
        let round_trip = cipher.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext, round_trip);
    }

    #[test]
    fn encrypt_with_symbols() {
        let cipher = Hill::from_phrase("CEFJCBDRH", 3).unwrap();
        let outcome = cipher.encrypt("This won!t w@rk");
        assert!(outcome.is_err());
    }

    #[test]
    fn decrypt_with_symbols() {
        let cipher = Hill::from_phrase("CEFJCBDRH", 3).unwrap();
        let outcome = cipher.decrypt("This won!t w@rk");
        assert!(outcome.is_err());
    }

    #[test]
    fn encrypt_padding_req() {
        let cipher = Hill::new(reference_key()).unwrap();
        let plaintext = "ATTACKATDAWNz";

        let ciphertext = cipher.encrypt(plaintext).unwrap();
        assert_eq!("PFOGOANPGXFXyrx", ciphertext);

        // The padding added during encryption stays in the decrypted text.
        let decrypted = cipher.decrypt(&ciphertext).unwrap();
        assert_eq!("ATTACKATDAWNzaa", decrypted);
    }

    #[test]
    fn valid_key() {
        let cipher = Hill::new(reference_key());
        assert!(cipher.is_ok());
    }

    #[test]
    fn non_square_matrix() {
        let key = Matrix::new(3, 2, vec![2, 4, 9, 2, 3, 17]);
        assert!(Hill::new(key).is_err());
    }

    #[test]
    fn non_invertable_matrix() {
        let key = Matrix::new(3, 3, vec![2, 2, 3, 6, 6, 9, 1, 4, 8]);
        assert!(Hill::new(key).is_err());
    }
}