use clap::{Args, Parser, Subcommand};
use std::fs;
use std::io::{self, Write};
use crate::bounded_input::{read_file_bounded, read_line_bounded};
use crate::caesar_cipher::{decrypt, decrypt_safe, encrypt, encrypt_safe};
use crate::config::{
DEFAULT_SHIFT, MAX_BRUTE_FORCE_SHIFT, MAX_INPUT_SIZE, MAX_SHIFT, MAX_SHIFT_LINE_SIZE, MIN_SHIFT,
};
// Top-level command-line definition for the `caesar_cipher` binary; `run_cli`
// builds it with `Cli::parse()` and dispatches on `command`.
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn `///`
// doc comments into help text, so doc comments here would change `--help` output.
#[derive(Parser)]
#[command(name = "caesar_cipher")]
#[command(about = "A Caesar cipher encryption/decryption tool")]
#[command(version)]
pub struct Cli {
// The subcommand chosen by the user (see `Commands`).
#[command(subcommand)]
pub command: Commands,
}
// Options shared by the `Encrypt` and `Decrypt` subcommands.
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn `///`
// doc comments into help text, so doc comments here would change `--help` output.
#[derive(Args)]
pub struct CipherArgs {
// Input text given directly on the command line; declared as conflicting with `file`.
#[arg(short, long, conflicts_with = "file")]
pub text: Option<String>,
// Path of a file to read the input text from (short flag `-f`); declared as
// conflicting with `text`.
#[arg(short = 'f', long, conflicts_with = "text")]
pub file: Option<String>,
// Shift amount handed to the cipher functions; defaults to `DEFAULT_SHIFT`.
#[arg(short, long, default_value_t = DEFAULT_SHIFT)]
pub shift: i16,
// Optional path to write the result to; when absent, `output_result` prints the
// result to stdout instead.
#[arg(short, long)]
pub output: Option<String>,
// When set, `run_cli` calls `encrypt_safe`/`decrypt_safe` (which can fail) instead
// of `encrypt`/`decrypt`.
#[arg(long)]
pub safe: bool,
}
// Subcommands of the CLI; `run_cli` matches on these to pick the action.
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn `///`
// doc comments into help text, so doc comments here would change `--help` output.
#[derive(Subcommand)]
pub enum Commands {
// Encrypt the input using the options in `CipherArgs`.
Encrypt(CipherArgs),
// Decrypt the input using the options in `CipherArgs`.
Decrypt(CipherArgs),
// Start the prompt-driven loop implemented by `run_interactive_mode`.
Interactive,
// Print the decryption of the input for every shift tried by `run_brute_force`.
BruteForce {
// Input text given directly on the command line; declared as conflicting with `file`.
#[arg(short, long, conflicts_with = "file")]
text: Option<String>,
// Path of a file to read the input text from (short flag `-f`); declared as
// conflicting with `text`.
#[arg(short = 'f', long, conflicts_with = "text")]
file: Option<String>,
},
}
pub fn run_cli() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
match cli.command {
Commands::Encrypt(args) => {
let input_text = get_input_text(args.text, args.file)?;
let result = if args.safe {
encrypt_safe(&input_text, args.shift)?
} else {
encrypt(&input_text, args.shift)
};
output_result(&result, args.output)?;
}
Commands::Decrypt(args) => {
let input_text = get_input_text(args.text, args.file)?;
let result = if args.safe {
decrypt_safe(&input_text, args.shift)?
} else {
decrypt(&input_text, args.shift)
};
output_result(&result, args.output)?;
}
Commands::Interactive => {
run_interactive_mode()?;
}
Commands::BruteForce { text, file } => {
let input_text = get_input_text(text, file)?;
run_brute_force(&input_text);
}
}
Ok(())
}
/// Resolves the text to operate on from, in priority order: the inline `text`
/// argument, the `file` path, or an interactive "Enter text: " prompt on stdin.
///
/// Inline text is returned unchanged. Text read from a file or from stdin has
/// its trailing `\n` / `\r` characters removed.
///
/// # Errors
///
/// Returns an error when inline text is longer than `MAX_INPUT_SIZE` bytes,
/// when the bounded file/line read fails, or when stdout cannot be flushed.
fn get_input_text(
    text: Option<String>,
    file: Option<String>,
) -> Result<String, Box<dyn std::error::Error>> {
    match (text, file) {
        // Inline text wins even if a file path is also present.
        (Some(inline), _) if inline.len() > MAX_INPUT_SIZE => Err(format!(
            "Input text exceeds maximum size of {} bytes",
            MAX_INPUT_SIZE
        )
        .into()),
        (Some(inline), _) => Ok(inline),
        (None, Some(path)) => {
            let contents =
                read_file_bounded(&path, MAX_INPUT_SIZE).map_err(|e| e.to_string())?;
            Ok(trim_trailing_newline(&contents).to_string())
        }
        (None, None) => {
            print!("Enter text: ");
            // The prompt has no newline, so flush to make it visible before blocking.
            io::stdout().flush()?;
            let mut reader = io::stdin().lock();
            let line =
                read_line_bounded(&mut reader, MAX_INPUT_SIZE).map_err(|e| e.to_string())?;
            Ok(trim_trailing_newline(&line).to_string())
        }
    }
}
/// Returns `input` with every trailing `\n` and `\r` character removed.
///
/// Other whitespace (including trailing spaces) and any line breaks that are
/// not at the very end are left untouched.
fn trim_trailing_newline(input: &str) -> &str {
    input.trim_end_matches(|c: char| matches!(c, '\n' | '\r'))
}
/// Emits `result` either to a file or to stdout.
///
/// With `Some(path)`, the text is written to that path verbatim (no trailing
/// newline added) and a confirmation line is printed. With `None`, the text is
/// printed to stdout followed by a newline.
///
/// # Errors
///
/// Returns the I/O error from `fs::write` if the file cannot be written.
fn output_result(
    result: &str,
    output_file: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
    if let Some(path) = output_file {
        fs::write(&path, result)?;
        println!("Result written to file: {}", path);
    } else {
        println!("{}", result);
    }
    Ok(())
}
/// Prints `prompt` (without a newline), then reads one line from stdin.
///
/// The read is capped at `MAX_INPUT_SIZE` via `read_line_bounded`; trailing
/// `\n` / `\r` characters are removed from the returned text.
///
/// # Errors
///
/// Returns the I/O error if flushing stdout or the bounded read fails.
fn prompt_for_text(prompt: &str) -> io::Result<String> {
    print!("{}", prompt);
    // No newline was printed, so flush to make the prompt visible before blocking.
    io::stdout().flush()?;
    let mut reader = io::stdin().lock();
    read_line_bounded(&mut reader, MAX_INPUT_SIZE)
        .map(|line| trim_trailing_newline(&line).to_string())
}
/// Interprets a user-typed shift value.
///
/// Returns the shift to use together with an optional message for the user:
///
/// * blank input (empty or whitespace only) → `DEFAULT_SHIFT`, no message;
/// * an `i16` within `MIN_SHIFT..=MAX_SHIFT` → that value, no message;
/// * an `i16` outside that range → that value unchanged, plus a warning;
/// * anything that does not parse as `i16` → `DEFAULT_SHIFT`, plus an
///   "Invalid shift value" message.
///
/// Surrounding whitespace is ignored before parsing.
pub(crate) fn validate_shift_input(input: &str) -> (i16, Option<String>) {
    let candidate = input.trim();
    if candidate.is_empty() {
        return (DEFAULT_SHIFT, None);
    }

    match candidate.parse::<i16>() {
        Ok(value) if (MIN_SHIFT..=MAX_SHIFT).contains(&value) => (value, None),
        Ok(value) => (
            value,
            Some(format!(
                "Warning: shift {} is outside the typical range ({} to {}). Value will be normalized.",
                value, MIN_SHIFT, MAX_SHIFT
            )),
        ),
        Err(_) => (
            DEFAULT_SHIFT,
            Some(format!(
                "Invalid shift value, using default ({})",
                DEFAULT_SHIFT
            )),
        ),
    }
}
/// Prompts for a shift value on stdin and returns the shift to use.
///
/// The line read is capped at `MAX_SHIFT_LINE_SIZE` and interpreted by
/// `validate_shift_input`; any message it produces is printed to stdout.
///
/// # Errors
///
/// Returns the I/O error if flushing stdout or the bounded read fails.
fn prompt_for_shift() -> io::Result<i16> {
    print!("Enter shift value (default: {}): ", DEFAULT_SHIFT);
    // No newline was printed, so flush to make the prompt visible before blocking.
    io::stdout().flush()?;
    let mut reader = io::stdin().lock();
    let raw = read_line_bounded(&mut reader, MAX_SHIFT_LINE_SIZE)?;
    match validate_shift_input(&raw) {
        (shift, Some(message)) => {
            println!("{}", message);
            Ok(shift)
        }
        (shift, None) => Ok(shift),
    }
}
/// Runs the prompt-driven loop: repeatedly asks for an operation, performs it,
/// and prints the result until the user chooses to quit.
///
/// Accepted choices (case-insensitive, surrounding whitespace ignored):
/// `e`/`encrypt`, `d`/`decrypt`, `b`/`brute`/`bruteforce`, `q`/`quit`.
/// Any other input prints a hint and asks again.
///
/// # Errors
///
/// Returns the underlying I/O error if any prompt fails to read or flush.
fn run_interactive_mode() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Caesar Cipher Interactive Mode ===");
    println!("Type 'quit' to exit");
    // NOTE(review): if `read_line_bounded` yields an empty line at end-of-input
    // instead of an error, this loop would repeat the invalid-option message
    // forever. Confirm its EOF behaviour.
    loop {
        let choice =
            prompt_for_text("\nChoose operation (e)ncrypt, (d)ecrypt, (b)rute force, or (q)uit: ")?;
        // `prompt_for_text` only strips line terminators, so trim here: a stray
        // space (" e", "quit ") should not be rejected as an invalid option.
        let choice = choice.trim().to_lowercase();
        match choice.as_str() {
            "e" | "encrypt" => {
                let text = prompt_for_text("Enter text to encrypt: ")?;
                let shift = prompt_for_shift()?;
                let result = encrypt(&text, shift);
                println!("Encrypted: {}", result);
            }
            "d" | "decrypt" => {
                let text = prompt_for_text("Enter text to decrypt: ")?;
                let shift = prompt_for_shift()?;
                let result = decrypt(&text, shift);
                println!("Decrypted: {}", result);
            }
            "b" | "brute" | "bruteforce" => {
                let text = prompt_for_text("Enter text to brute force decrypt: ")?;
                run_brute_force(&text);
            }
            "q" | "quit" => {
                println!("Goodbye!");
                break;
            }
            _ => {
                println!("Invalid option. Please choose e, d, b, or q.");
            }
        }
    }
    Ok(())
}
/// Prints `text` decrypted with every shift from 0 through
/// `MAX_BRUTE_FORCE_SHIFT` (inclusive), one candidate per line, preceded by a
/// short header that echoes the original text.
fn run_brute_force(text: &str) {
    println!("\n=== Brute Force Decryption ===");
    println!("Original: {}", text);
    println!("Trying all possible shifts:");

    // Lazy iterator: each candidate is decrypted right before it is printed.
    let candidates = (0..=MAX_BRUTE_FORCE_SHIFT).map(|shift| (shift, decrypt(text, shift)));
    for (shift, candidate) in candidates {
        println!("Shift {:2}: {}", shift, candidate);
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the CLI helpers: input resolution, newline trimming,
    // result output, and shift validation.
    use super::*;
    use std::fs;
    use std::io::Write;
    use tempfile::NamedTempFile;

    // ---- get_input_text -------------------------------------------------

    #[test]
    fn test_get_input_text_from_string() {
        let result = get_input_text(Some("Hello".to_string()), None).unwrap();
        assert_eq!(result, "Hello");
    }

    #[test]
    fn test_get_input_text_from_file() {
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "Test content").unwrap();
        let result =
            get_input_text(None, Some(temp_file.path().to_string_lossy().to_string())).unwrap();
        // Compare exactly (no `.trim()`): the trailing newline written by `writeln!`
        // must already have been stripped by `get_input_text`.
        assert_eq!(result, "Test content");
    }

    #[test]
    fn test_get_input_text_prefers_text_when_both_provided() {
        // The file path is never opened because inline text takes priority.
        let result = get_input_text(Some("Hello".to_string()), Some("file.txt".to_string()));
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello");
    }

    #[test]
    fn test_get_input_text_oversized_text_error() {
        let oversized = "A".repeat(MAX_INPUT_SIZE + 1);
        let result = get_input_text(Some(oversized), None);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("exceeds maximum size"));
    }

    #[test]
    fn test_get_input_text_text_at_max_size_succeeds() {
        let at_limit = "A".repeat(MAX_INPUT_SIZE);
        let result = get_input_text(Some(at_limit), None);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), MAX_INPUT_SIZE);
    }

    #[test]
    fn test_get_input_text_oversized_file_error() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let oversized_data = vec![b'A'; MAX_INPUT_SIZE + 1];
        temp_file.write_all(&oversized_data).unwrap();
        temp_file.flush().unwrap();
        let result = get_input_text(None, Some(temp_file.path().to_string_lossy().to_string()));
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("exceeds maximum size"));
    }

    #[test]
    fn test_get_input_text_file_at_max_size_succeeds() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let data = vec![b'A'; MAX_INPUT_SIZE];
        temp_file.write_all(&data).unwrap();
        temp_file.flush().unwrap();
        let result = get_input_text(None, Some(temp_file.path().to_string_lossy().to_string()));
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), MAX_INPUT_SIZE);
    }

    // ---- trim_trailing_newline -------------------------------------------

    #[test]
    fn test_trim_trailing_newline_preserves_leading_and_trailing_spaces() {
        let input = " keep spaces \n";
        let result = trim_trailing_newline(input);
        assert_eq!(result, " keep spaces ");
    }

    #[test]
    fn test_trim_trailing_newline_strips_crlf() {
        assert_eq!(trim_trailing_newline("line\r\n"), "line");
        // Line breaks that are not at the end are left alone.
        assert_eq!(trim_trailing_newline("two\nlines\n"), "two\nlines");
    }

    // ---- output_result ---------------------------------------------------

    #[test]
    fn test_output_result_to_file() {
        let temp_file = NamedTempFile::new().unwrap();
        let file_path = temp_file.path().to_string_lossy().to_string();
        output_result("Test output", Some(file_path.clone())).unwrap();
        let content = fs::read_to_string(file_path).unwrap();
        assert_eq!(content, "Test output");
    }

    // ---- bounded_input (exercised through this module's imports) ----------

    #[test]
    fn test_read_file_bounded_rejects_directory() {
        let dir = tempfile::tempdir().unwrap();
        let result = crate::bounded_input::read_file_bounded(dir.path(), MAX_INPUT_SIZE);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("not a regular file"));
    }

    // ---- validate_shift_input --------------------------------------------

    #[test]
    fn test_validate_shift_input_valid_value() {
        let (shift, warning) = validate_shift_input("3");
        assert_eq!(shift, 3);
        assert!(warning.is_none());
    }

    #[test]
    fn test_validate_shift_input_zero() {
        let (shift, warning) = validate_shift_input("0");
        assert_eq!(shift, 0);
        assert!(warning.is_none());
    }

    #[test]
    fn test_validate_shift_input_max_boundary() {
        let (shift, warning) = validate_shift_input("25");
        assert_eq!(shift, 25);
        assert!(warning.is_none());
    }

    #[test]
    fn test_validate_shift_input_min_boundary() {
        let (shift, warning) = validate_shift_input("-25");
        assert_eq!(shift, -25);
        assert!(warning.is_none());
    }

    #[test]
    fn test_validate_shift_input_out_of_range_positive() {
        let (shift, warning) = validate_shift_input("26");
        assert_eq!(shift, 26);
        assert!(warning.is_some());
        assert!(warning.unwrap().contains("Warning"));
    }

    #[test]
    fn test_validate_shift_input_out_of_range_negative() {
        let (shift, warning) = validate_shift_input("-26");
        assert_eq!(shift, -26);
        assert!(warning.is_some());
        assert!(warning.unwrap().contains("Warning"));
    }

    #[test]
    fn test_validate_shift_input_far_out_of_range() {
        let (shift, warning) = validate_shift_input("9999");
        assert_eq!(shift, 9999);
        assert!(warning.is_some());
        assert!(warning.unwrap().contains("9999"));
    }

    #[test]
    fn test_validate_shift_input_invalid_string() {
        let (shift, warning) = validate_shift_input("abc");
        assert_eq!(shift, DEFAULT_SHIFT);
        assert!(warning.is_some());
        assert!(warning.unwrap().contains("Invalid"));
    }

    #[test]
    fn test_validate_shift_input_empty_string() {
        let (shift, warning) = validate_shift_input("");
        assert_eq!(shift, DEFAULT_SHIFT);
        assert!(warning.is_none());
    }

    #[test]
    fn test_validate_shift_input_whitespace_only() {
        let (shift, warning) = validate_shift_input(" ");
        assert_eq!(shift, DEFAULT_SHIFT);
        assert!(warning.is_none());
    }

    #[test]
    fn test_validate_shift_input_with_surrounding_whitespace() {
        let (shift, warning) = validate_shift_input(" 5 ");
        assert_eq!(shift, 5);
        assert!(warning.is_none());
    }
}