use common::alphabet::Alphabet;
use common::cipher::Cipher;
use common::{alphabet, substitute};
/// A Caesar cipher: each letter of the standard alphabet is shifted by a
/// fixed number of positions.
///
/// Instances are created through `Cipher::new`, which validates the shift.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Caesar {
    /// Number of positions each letter is shifted. `Cipher::new` only
    /// accepts values in `1..=26`.
    shift: usize,
}
impl Cipher for Caesar {
    type Key = usize;
    type Algorithm = Caesar;

    /// Creates a Caesar cipher that shifts letters by `shift` positions.
    ///
    /// # Errors
    ///
    /// Returns an error if `shift` is outside the inclusive range `1..=26`.
    // NOTE(review): a shift of 26 is accepted; assuming the standard
    // alphabet has 26 letters, it maps every letter to itself.
    fn new(shift: usize) -> Result<Caesar, &'static str> {
        if (1..=26).contains(&shift) {
            Ok(Caesar { shift })
        } else {
            Err("Invalid shift factor. Must be in the range 1-26")
        }
    }

    /// Encrypts `message` by moving each letter's alphabet index forward by
    /// `self.shift`. The result is reduced with `alphabet::STANDARD.modulo`.
    ///
    /// Per-character handling (case, non-alphabetic characters) is delegated
    /// to `substitute::shift_substitution`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `substitute::shift_substitution`.
    fn encrypt(&self, message: &str) -> Result<String, &'static str> {
        substitute::shift_substitution(message, |idx| {
            alphabet::STANDARD.modulo((idx + self.shift) as isize)
        })
    }

    /// Decrypts `ciphertext` by moving each letter's alphabet index back by
    /// `self.shift`. The result is reduced with `alphabet::STANDARD.modulo`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `substitute::shift_substitution`.
    fn decrypt(&self, ciphertext: &str) -> Result<String, &'static str> {
        substitute::shift_substitution(ciphertext, |idx| {
            // `idx - shift` can go negative, so subtract in signed arithmetic
            // and let `modulo` wrap it back into the alphabet. Both casts are
            // lossless: `shift` is at most 26 and `idx` is an alphabet index.
            alphabet::STANDARD.modulo(idx as isize - self.shift as isize)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encrypt_message() {
        let c = Caesar::new(2).unwrap();
        assert_eq!("Cvvcem cv fcyp!", c.encrypt("Attack at dawn!").unwrap());
    }

    #[test]
    fn decrypt_message() {
        let c = Caesar::new(2).unwrap();
        assert_eq!("Attack at dawn!", c.decrypt("Cvvcem cv fcyp!").unwrap());
    }

    #[test]
    fn with_utf8() {
        let c = Caesar::new(3).unwrap();
        let message = "Peace, Freedom and Liberty! 🗡️";
        let encrypted = c.encrypt(message).unwrap();
        let decrypted = c.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, message);
    }

    /// Letters near the end of the alphabet must wrap around to the start
    /// when encrypting, and back again when decrypting.
    #[test]
    fn wraps_around_alphabet() {
        let c = Caesar::new(3).unwrap();
        assert_eq!("abcABC", c.encrypt("xyzXYZ").unwrap());
        assert_eq!("xyzXYZ", c.decrypt("abcABC").unwrap());
    }

    /// Every valid shift must round-trip the full alphabet in both cases.
    #[test]
    fn exhaustive_encrypt() {
        let message = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
        for shift in 1..=26 {
            let c = Caesar::new(shift).unwrap();
            let encrypted = c.encrypt(message).unwrap();
            let decrypted = c.decrypt(&encrypted).unwrap();
            assert_eq!(decrypted, message, "round trip failed for shift {}", shift);
        }
    }

    #[test]
    fn key_at_bounds() {
        assert!(Caesar::new(1).is_ok());
        assert!(Caesar::new(26).is_ok());
    }

    #[test]
    fn key_too_small() {
        assert!(Caesar::new(0).is_err());
    }

    #[test]
    fn key_too_big() {
        assert!(Caesar::new(27).is_err());
    }
}