use std::fmt::{Display, Formatter};
use openssl::error::ErrorStack;
use openssl::pkey::Private;
use openssl::rsa::{Padding, Rsa};
use openssl::symm::{decrypt, encrypt, Cipher};
use sha2::{Digest, Sha256 as sha2_256, Sha512 as sha2_512};
use crate::error::{InvalidArgumentError, LibError, MissingArgumentError};
/// Error value produced by the encryption/decryption helpers in this file.
///
/// It only carries a human-readable message; equality (`PartialEq`) compares
/// that message.
#[derive(PartialEq, Debug)]
pub struct CryptoError {
/// Human-readable description of the failure.
message: String,
}
impl Default for CryptoError {
fn default() -> Self {
CryptoError {
message: "암호화 처리중 오류가 발생하였습니다.".to_owned(),
}
}
}
/// Renders a fixed summary line; the stored message is not part of the output.
impl Display for CryptoError {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Encrypt/Decrypt error.")
    }
}
impl From<&str> for CryptoError {
fn from(value: &str) -> Self {
CryptoError {
message: value.to_owned(),
}
}
}
/// Exposes the stored message and the concrete type name through `LibError`.
impl LibError for CryptoError {
    /// Returns the message this error was created with.
    fn get_message(&self) -> &str {
        &self.message
    }

    /// Returns the type name of `CryptoError` as reported by `std::any::type_name`.
    fn get_type_name_from_instance(&self) -> &str {
        std::any::type_name::<Self>()
    }
}
/// Selects which SHA-2 digest the hashing helpers compute.
#[derive(PartialEq)]
// The SCREAMING_SNAKE type and variant names are kept as-is, hence the lint allowance.
#[allow(non_camel_case_types)]
pub enum SHA_TYPE {
/// SHA-256 digest.
SHA_256,
/// SHA-512 digest.
SHA_512,
}
/// Selects the AES key size used by the AES helpers (both run in CBC mode).
#[derive(PartialEq)]
// The SCREAMING_SNAKE type and variant names are kept as-is, hence the lint allowance.
#[allow(non_camel_case_types)]
pub enum AES_TYPE {
/// AES with a 128-bit key.
AES_128,
/// AES with a 256-bit key.
AES_256,
}
/// Computes a SHA-2 digest of `target`, optionally followed by a salt.
///
/// The digest input is `target` and then, when `salt` is `Some` and non-empty,
/// the UTF-8 bytes of the salt.
///
/// # Errors
///
/// Returns `MissingArgumentError` when `target` is empty.
pub fn make_sha_hash(
    hash_type: SHA_TYPE,
    target: &[u8],
    salt: Option<&str>,
) -> Result<Box<[u8]>, MissingArgumentError> {
    // Feeds the target and then the (non-empty) salt into a fresh digest `D`.
    fn digest_with<D: Digest>(target: &[u8], salt: Option<&str>) -> Box<[u8]> {
        let mut hasher = D::new();
        hasher.update(target);
        if let Some(salt) = salt.filter(|s| !s.is_empty()) {
            hasher.update(salt.as_bytes());
        }
        hasher.finalize().to_vec().into_boxed_slice()
    }

    if target.is_empty() {
        return Err(MissingArgumentError::from("Hash 대상이 빈 문자열 입니다."));
    }

    let digest = match hash_type {
        SHA_TYPE::SHA_256 => digest_with::<sha2_256>(target, salt),
        SHA_TYPE::SHA_512 => digest_with::<sha2_512>(target, salt),
    };
    Ok(digest)
}
/// Computes the same digest as `make_sha_hash` and renders it as lowercase hex.
///
/// Every digest byte becomes exactly two hex characters.
///
/// # Errors
///
/// Propagates the `MissingArgumentError` returned by `make_sha_hash`.
pub fn make_sha_hash_string(
    hash_type: SHA_TYPE,
    target: &[u8],
    salt: Option<&str>,
) -> Result<String, MissingArgumentError> {
    let digest = make_sha_hash(hash_type, target, salt)?;
    let hex: String = digest.iter().map(|byte| format!("{:02x}", byte)).collect();
    Ok(hex)
}
/// Outcome of an AES encryption: the cipher text plus the salt and IV that go with it.
#[derive(Debug)]
pub struct AESResult {
/// Salt that was supplied for key derivation, if any.
salt: Option<Vec<u8>>,
/// Raw encrypted bytes.
result: Vec<u8>,
/// Lowercase hex rendering of `result` (two characters per byte).
result_str: Option<String>,
/// Initialization vector used for the encryption.
iv: Vec<u8>,
}
impl AESResult {
fn new(salt: Option<&[u8]>, result: &[u8], iv: &[u8]) -> Self {
AESResult {
salt: match salt {
None => None,
Some(v) => Some(Vec::from(v)),
},
result: Vec::from(result),
result_str: {
let v = Vec::from(result);
let v: Vec<String> = v.iter().map(|b| format!("{:02x}", b)).collect();
Some(v.join(""))
},
iv: Vec::from(iv),
}
}
#[inline]
pub fn salt(&self) -> Option<&[u8]> {
return match &self.salt {
None => None,
Some(v) => {
return Some(v.as_ref());
}
};
}
#[inline]
pub fn result(&self) -> &[u8] {
self.result.as_ref()
}
#[inline]
pub fn result_str(&self) -> Option<&str> {
match &self.result_str {
None => None,
Some(v) => Some(v.as_str()),
}
}
#[inline]
pub fn iv(&self) -> &[u8] {
self.iv.as_ref()
}
#[deprecated(note = "salt(&self)로 대체. 삭제 예정.")]
pub fn get_salt(&self) -> Option<&[u8]> {
return match &self.salt {
None => None,
Some(v) => {
return Some(v.as_ref());
}
};
}
#[deprecated(note = "result(&self)로 대체. 삭제 예정.")]
pub fn get_result(&self) -> &[u8] {
return self.result.as_ref();
}
#[deprecated(note = "iv(&self)로 대체. 삭제 예정.")]
pub fn get_iv(&self) -> &[u8] {
return self.iv.as_ref();
}
}
/// Prints the salt, cipher text and IV using their pretty `Debug` renderings.
impl Display for AESResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The hex string is deliberately left out, exactly as before.
        let Self {
            salt, result, iv, ..
        } = self;
        write!(
            f,
            "salt : {:#?}\n, result : {:#?}\n, iv : {:#?}",
            salt, result, iv
        )
    }
}
/// Checks that an optional salt has an acceptable length.
///
/// A missing salt is accepted; a supplied salt must be exactly 8 bytes long.
///
/// # Errors
///
/// Returns `InvalidArgumentError` when a salt is supplied whose length is not 8.
pub fn validate_salt(salt: Option<&[u8]>) -> Result<(), InvalidArgumentError> {
    match salt {
        Some(bytes) if bytes.len() != 8 => Err(InvalidArgumentError::from(
            "Salt length is invalid(must 8 bytes)",
        )),
        _ => Ok(()),
    }
}
/// Encrypts `target` with AES in CBC mode.
///
/// The key and IV are derived from `secret`, the optional `salt` and
/// `repeat_count` via `openssl::pkcs5::bytes_to_key` using MD5 as the digest.
/// The returned [`AESResult`] carries the cipher text, the derived IV and a
/// copy of the salt.
///
/// # Arguments
///
/// * `enc_type` - selects AES-128-CBC or AES-256-CBC.
/// * `target` - plain bytes to encrypt; must not be empty.
/// * `secret` - secret material the key is derived from.
/// * `salt` - optional salt; when present it must be exactly 8 bytes.
/// * `repeat_count` - iteration count for the key derivation; must fit in `i32`.
///
/// # Errors
///
/// * `InvalidArgumentError` - `target` is empty, the salt length is invalid,
///   `repeat_count` does not fit in `i32`, or the encryption step itself fails.
/// * `CryptoError` - key/IV derivation fails or yields no IV.
pub fn aes_encrypt(
    enc_type: AES_TYPE,
    target: &[u8],
    secret: &[u8],
    salt: Option<&[u8]>,
    repeat_count: usize,
) -> Result<AESResult, Box<dyn LibError>> {
    if target.is_empty() {
        return Err(Box::from(InvalidArgumentError::from(
            "암호화 대상이 빈 문자열 입니다",
        )));
    }
    if let Err(e) = validate_salt(salt) {
        return Err(Box::from(e));
    }
    // A plain `as i32` cast would silently truncate or wrap large counts and
    // derive a key with a different iteration count than the caller asked for.
    let iterations = match <i32 as std::convert::TryFrom<usize>>::try_from(repeat_count) {
        Ok(count) => count,
        Err(_) => {
            return Err(Box::from(InvalidArgumentError::from(
                "repeat_count is out of range (must fit in i32)",
            )));
        }
    };

    let cipher = match enc_type {
        AES_TYPE::AES_128 => Cipher::aes_128_cbc(),
        AES_TYPE::AES_256 => Cipher::aes_256_cbc(),
    };

    let key_spec = match openssl::pkcs5::bytes_to_key(
        cipher,
        openssl::hash::MessageDigest::md5(),
        secret,
        salt,
        iterations,
    ) {
        Ok(spec) => spec,
        Err(e) => {
            eprintln!("AES error : {:#?}", Some(e));
            return Err(Box::from(CryptoError::from(
                "AES 암호화 처리 중 오류가 발생하였습니다.",
            )));
        }
    };
    let key = key_spec.key;
    // Report a missing IV as an error instead of panicking on `unwrap`.
    let iv = match key_spec.iv {
        Some(iv) => iv,
        None => {
            return Err(Box::from(CryptoError::from(
                "AES 암호화 처리 중 오류가 발생하였습니다.",
            )));
        }
    };

    match encrypt(cipher, key.as_slice(), Some(iv.as_slice()), target) {
        Ok(cipher_text) => Ok(AESResult::new(salt, cipher_text.as_slice(), iv.as_slice())),
        Err(e) => {
            eprintln!("AES encrypt error : {:#?}", e);
            Err(Box::from(InvalidArgumentError::from("암호화 처리 오류")))
        }
    }
}
/// Decrypts AES-CBC cipher text.
///
/// The key is derived from `secret`, the optional `salt` and `repeat_count`
/// via `openssl::pkcs5::bytes_to_key` using MD5 as the digest. The IV is NOT
/// taken from that derivation; the caller-supplied `iv` is used.
///
/// # Arguments
///
/// * `enc_type` - selects AES-128-CBC or AES-256-CBC.
/// * `target` - cipher text to decrypt; must be `Some` and non-empty.
/// * `secret` - secret material the key is derived from.
/// * `iv` - initialization vector used for decryption.
/// * `salt` - optional salt; when present it must be exactly 8 bytes.
/// * `repeat_count` - iteration count for the key derivation; must fit in `i32`.
///
/// # Errors
///
/// * `MissingArgumentError` - `target` is `None`.
/// * `InvalidArgumentError` - `target` is empty, the salt length is invalid,
///   `repeat_count` does not fit in `i32`, or the decryption step itself fails.
/// * `CryptoError` - key derivation fails.
pub fn aes_decrypt(
    enc_type: AES_TYPE,
    target: Option<&[u8]>,
    secret: &[u8],
    iv: &[u8],
    salt: Option<&[u8]>,
    repeat_count: usize,
) -> Result<Box<[u8]>, Box<dyn LibError>> {
    let cipher_text = match target {
        Some(bytes) => bytes,
        None => {
            return Err(Box::from(MissingArgumentError::from(
                "복호화 대상이 지정되지 않았습니다.",
            )));
        }
    };
    if cipher_text.is_empty() {
        return Err(Box::from(InvalidArgumentError::from(
            "복호화 대상의 길이가 0 입니다.",
        )));
    }
    if let Err(e) = validate_salt(salt) {
        return Err(Box::from(e));
    }
    // A plain `as i32` cast would silently truncate or wrap large counts and
    // derive a key with a different iteration count than the caller asked for.
    let iterations = match <i32 as std::convert::TryFrom<usize>>::try_from(repeat_count) {
        Ok(count) => count,
        Err(_) => {
            return Err(Box::from(InvalidArgumentError::from(
                "repeat_count is out of range (must fit in i32)",
            )));
        }
    };

    let cipher = match enc_type {
        AES_TYPE::AES_128 => Cipher::aes_128_cbc(),
        AES_TYPE::AES_256 => Cipher::aes_256_cbc(),
    };

    let key_spec = match openssl::pkcs5::bytes_to_key(
        cipher,
        openssl::hash::MessageDigest::md5(),
        secret,
        salt,
        iterations,
    ) {
        Ok(spec) => spec,
        Err(e) => {
            eprintln!("AES error: {:#?}", Some(e));
            return Err(Box::from(CryptoError::from(
                "AES 복호화 처리 중 오류가 발생하였습니다.",
            )));
        }
    };

    match decrypt(cipher, key_spec.key.as_slice(), Some(iv), cipher_text) {
        Ok(plain) => Ok(plain.into_boxed_slice()),
        Err(e) => {
            eprintln!("AES decrypt error: {:#?}", e);
            Err(Box::from(InvalidArgumentError::from("복호화 처리 오류")))
        }
    }
}
/// RSA key sizes supported by the RSA helpers in this file.
// The SCREAMING_SNAKE type and variant names are kept as-is, hence the lint allowance.
#[allow(non_camel_case_types)]
pub enum RSA_BIT {
/// 1024-bit key.
B_1024,
/// 2048-bit key.
B_2048,
/// 4096-bit key.
B_4096,
/// 8192-bit key.
B_8192,
}
impl RSA_BIT {
pub fn bit(&self) -> usize {
match self {
RSA_BIT::B_1024 => 1024usize,
RSA_BIT::B_2048 => 2048usize,
RSA_BIT::B_4096 => 4096usize,
RSA_BIT::B_8192 => 8192usize,
}
}
pub fn bytes(&self) -> u16 {
match self {
RSA_BIT::B_1024 => 128,
RSA_BIT::B_2048 => 256,
RSA_BIT::B_4096 => 512,
RSA_BIT::B_8192 => 1024,
}
}
}
/// Outcome of an RSA encryption together with the key material that was used.
pub struct RSAResult {
/// Public key bytes as passed to the constructor.
public_key: Vec<u8>,
/// Modulus of the public key.
public_modulus: Vec<u8>,
/// Exponent of the public key.
public_exponent: Vec<u8>,
/// Private key bytes as passed to the constructor.
private_key: Vec<u8>,
/// Modulus of the private key.
private_modulus: Vec<u8>,
/// Exponent of the private key.
private_exponent: Vec<u8>,
/// Raw encrypted bytes.
result: Vec<u8>,
/// Lowercase hex rendering of `result` (two characters per byte).
result_str: Option<String>,
}
impl RSAResult {
pub fn new(
pub_key: &[u8],
pub_mod: &[u8],
pub_exp: &[u8],
prv_key: &[u8],
prv_mod: &[u8],
prv_exp: &[u8],
result: &[u8],
) -> Self {
RSAResult {
public_key: Vec::from(pub_key),
public_modulus: Vec::from(pub_mod),
public_exponent: Vec::from(pub_exp),
private_key: Vec::from(prv_key),
private_modulus: Vec::from(prv_mod),
private_exponent: Vec::from(prv_exp),
result: Vec::from(result),
result_str: {
let v = Vec::from(result);
let v: Vec<String> = v.iter().map(|b| format!("{:02x}", b)).collect();
Some(v.join(""))
},
}
}
#[inline]
pub fn public_key(&self) -> &[u8] {
self.public_key.as_ref()
}
#[inline]
pub fn public_modulus(&self) -> &[u8] {
self.public_modulus.as_ref()
}
#[inline]
pub fn public_exponent(&self) -> &[u8] {
self.public_exponent.as_ref()
}
#[inline]
pub fn private_key(&self) -> &[u8] {
self.private_key.as_ref()
}
#[inline]
pub fn private_modulus(&self) -> &[u8] {
self.private_modulus.as_ref()
}
#[inline]
pub fn private_exponent(&self) -> &[u8] {
self.private_exponent.as_ref()
}
#[inline]
pub fn result(&self) -> &[u8] {
self.result.as_ref()
}
#[inline]
pub fn result_str(&self) -> Option<&str> {
match &self.result_str {
None => None,
Some(v) => Some(v.as_str()),
}
}
#[deprecated(note = "public_key(&self)로 대체. 삭제 예정.")]
pub fn get_public_key(&self) -> &[u8] {
self.public_key.as_ref()
}
#[deprecated(note = "public_modulus(&self)로 대체. 삭제 예정.")]
pub fn get_public_modulus(&self) -> &[u8] {
self.public_modulus.as_ref()
}
#[deprecated(note = "public_exponent(&self)로 대체. 삭제 예정.")]
pub fn get_public_exponent(&self) -> &[u8] {
self.public_exponent.as_ref()
}
#[deprecated(note = "private_key(&self)로 대체. 삭제 예정.")]
pub fn get_private_key(&self) -> &[u8] {
self.private_key.as_ref()
}
#[deprecated(note = "private_modulus(&self)로 대체. 삭제 예정.")]
pub fn get_private_modulus(&self) -> &[u8] {
self.private_modulus.as_ref()
}
#[deprecated(note = "private_exponent(&self)로 대체. 삭제 예정.")]
pub fn get_private_exponent(&self) -> &[u8] {
self.private_exponent.as_ref()
}
#[deprecated(note = "result(&self)로 대체. 삭제 예정.")]
pub fn get_result(&self) -> &[u8] {
self.result.as_ref()
}
}
/// Generates a fresh RSA key pair of the requested size.
///
/// # Errors
///
/// Returns `CryptoError` when OpenSSL fails to generate the key; the
/// underlying error stack is written to stderr.
pub fn generate_rsa_keypair(bit_size: RSA_BIT) -> Result<Rsa<Private>, CryptoError> {
    let modulus_bits = bit_size.bit() as u32;
    match Rsa::generate(modulus_bits) {
        Ok(key_pair) => Ok(key_pair),
        Err(stack) => {
            eprintln!("Generate RSA key pair fail : {:#?}", Some(stack));
            Err(CryptoError::from(
                "RSA key pair 생성 중 오류가 발생하였습니다.",
            ))
        }
    }
}
/// Generates a new RSA key pair, encrypts `target` with its public key and
/// returns the cipher text together with the key material.
///
/// The keys in the result are PEM-encoded. Both the public and the private
/// modulus fields are filled from the key pair's `n`; the exponents come from
/// `e` (public) and `d` (private).
///
/// # Errors
///
/// Returns `CryptoError` when key generation, PEM export or encryption fails.
pub fn rsa_encrypt_without_key(
    target: &[u8],
    bit_size: RSA_BIT,
) -> Result<Box<RSAResult>, CryptoError> {
    let key_pair: Rsa<Private> = generate_rsa_keypair(bit_size)?;

    // Both exports are performed before either outcome is inspected.
    let (public_pem, private_pem) = match (
        key_pair.public_key_to_pem(),
        key_pair.private_key_to_pem(),
    ) {
        (Err(stack), _) => {
            eprintln!("public key error: {:#?}", Some(stack));
            return Err(CryptoError::from("Public key에서 오류가 발생하였습니다."));
        }
        (_, Err(stack)) => {
            eprintln!("private key error: {:#?}", Some(stack));
            return Err(CryptoError::from("Private key에서 오류가 발생하였습니다."));
        }
        (Ok(public_pem), Ok(private_pem)) => (public_pem, private_pem),
    };

    let cipher_text = rsa_encrypt(target, &public_pem)?;

    let modulus = key_pair.n().to_vec();
    let public_exponent = key_pair.e().to_vec();
    let private_exponent = key_pair.d().to_vec();

    Ok(Box::new(RSAResult::new(
        &public_pem,
        &modulus,
        &public_exponent,
        &private_pem,
        &modulus,
        &private_exponent,
        &cipher_text,
    )))
}
/// Decrypts `target` with a PEM-encoded RSA private key using PKCS#1 padding.
///
/// Returns only the bytes the decryption reported as written.
///
/// # Errors
///
/// Returns `CryptoError` when the private key cannot be parsed or the
/// decryption fails; the underlying error stack is written to stderr.
pub fn rsa_decrypt(target: &[u8], prv_key: &[u8]) -> Result<Vec<u8>, CryptoError> {
    let rsa = match Rsa::private_key_from_pem(prv_key) {
        Ok(key) => key,
        Err(stack) => {
            eprintln!("개인키 생성 오류: {:#?}", Some(stack));
            return Err(CryptoError::from("개인키 오류가 발생하였습니다."));
        }
    };

    // The output buffer is sized to the key size.
    let mut plain: Vec<u8> = vec![0; rsa.size() as usize];
    match rsa.private_decrypt(target, &mut plain, Padding::PKCS1) {
        Ok(written) => {
            plain.truncate(written);
            Ok(plain)
        }
        Err(stack) => {
            eprintln!("RSA decrypt error : {:#?}", Some(stack));
            Err(CryptoError::from(
                "RSA 복호화 처리 중 오류가 발생하였습니다.",
            ))
        }
    }
}
/// Encrypts `target` with a PEM-encoded RSA public key using PKCS#1 padding.
///
/// Returns the bytes the encryption reported as written.
///
/// # Errors
///
/// Returns `CryptoError` when the public key cannot be parsed (previously this
/// case panicked) or when the encryption fails; the underlying error stack is
/// written to stderr.
fn rsa_encrypt(target: &[u8], pub_key: &[u8]) -> Result<Box<[u8]>, CryptoError> {
    // Report a malformed key as an error, matching how `rsa_decrypt` treats a
    // malformed private key.
    let rsa = match Rsa::public_key_from_pem(pub_key) {
        Ok(key) => key,
        Err(stack) => {
            eprintln!("공개키 생성 오류: {:#?}", stack);
            return Err(CryptoError::from("공개키 오류가 발생하였습니다."));
        }
    };

    // The output buffer is sized to the key size.
    let mut buffer: Vec<u8> = vec![0; rsa.size() as usize];
    match rsa.public_encrypt(target, &mut buffer, Padding::PKCS1) {
        Ok(written) => {
            buffer.truncate(written);
            Ok(buffer.into_boxed_slice())
        }
        Err(stack) => {
            eprintln!("RSA encrypt error : {:#?}", Some(stack));
            Err(CryptoError::from(
                "RSA 암호화 처리 중 오류가 발생하였습니다.",
            ))
        }
    }
}
// Unit tests for the hashing, AES and RSA helpers defined in this file.
#[cfg(test)]
mod tests {
use base64::prelude::*;
use super::*;
// Mixed ASCII / Hangul sample text shared by the AES and RSA round-trip tests.
const PLAIN_TEXT: &str = "This 이것, That 저것";
// Hashes "test" with the salt "salt" using SHA-256 and SHA-512, and checks that
// `make_sha_hash_string` equals the hex rendering of the raw SHA-512 digest.
#[test]
pub fn make_sha_hash_test() {
let mut result: Result<Box<[u8]>, MissingArgumentError> =
make_sha_hash(SHA_TYPE::SHA_256, "test".as_bytes(), Some("salt"));
assert!(!result.is_err());
let v: Vec<String> = result
.unwrap()
.iter()
.map(|b| format!("{:02x}", b))
.collect();
println!("SHA-256 result : {}", v.join(""));
result = make_sha_hash(SHA_TYPE::SHA_512, "test".as_bytes(), Some("salt"));
assert!(!result.is_err());
let v: Vec<String> = result
.unwrap()
.iter()
.map(|b| format!("{:02x}", b))
.collect();
let v = v.join("");
println!("SHA-512 result : {}", v);
let vv = make_sha_hash_string(SHA_TYPE::SHA_512, "test".as_bytes(), Some("salt"));
assert!(vv.is_ok(), "make_sha_hash_string error => {:#?}", vv.err());
assert_eq!(v, vv.unwrap(), "hash string 불일치")
}
// Checks that a 4-byte salt is rejected, then does an AES-128 encrypt/decrypt
// round trip with a valid 8-byte salt.
#[test]
pub fn aes_encrypt_test() {
let repeat_count = 10usize;
// "salt" is only 4 bytes, so salt validation must reject this call.
let result: Result<AESResult, Box<dyn LibError>> = aes_encrypt(
AES_TYPE::AES_128,
PLAIN_TEXT.as_bytes(),
"abc".as_bytes(),
Some("salt".as_bytes()),
10,
);
assert!(result.is_err());
let err = result.err().unwrap();
let err_name = err.get_type_name_from_instance();
assert_eq!(err_name, std::any::type_name::<InvalidArgumentError>());
println!("err_name : {}", err_name);
// Valid 8-byte salt: encryption is expected to succeed.
let encrypt_result = aes_encrypt(
AES_TYPE::AES_128,
PLAIN_TEXT.as_bytes(),
"abcdefgh".as_bytes(),
Some("saltsalt".as_bytes()), repeat_count,
);
assert!(!encrypt_result.is_err(), "aes 암호화 오류 발생");
let result_value = encrypt_result.unwrap();
println!("unwrapped value : {:#?}", result_value);
println!("unwrapped result value : {:#?}", result_value.result);
assert!(result_value.result_str().is_some());
// The stored hex string must equal a hex rendering of the raw result bytes.
let raw_result: Vec<String> = result_value
.result()
.iter()
.map(|b| format!("{:02x}", b))
.collect();
let raw_result: String = raw_result.join("");
assert_eq!(raw_result, result_value.result_str().unwrap());
println!("aes result str ===> {}", result_value.result_str().unwrap());
let encoded_value = BASE64_STANDARD.encode(result_value.result.clone());
println!("aes base64 encoded value : {:#?}", encoded_value);
// Move the salt out of the result so a slice of it can be passed to decryption.
let mut salt: Option<&[u8]> = None;
let unwrapped_salt: Vec<u8>;
if result_value.salt.is_some() {
unwrapped_salt = result_value.salt.unwrap();
salt = Some(unwrapped_salt.as_slice());
}
println!("final sal : {:#?}", salt);
// Decrypt with the same secret, salt and repeat count plus the IV from the result.
let decrypt_result = aes_decrypt(
AES_TYPE::AES_128,
Some(result_value.result.as_ref()),
b"abcdefgh",
result_value.iv.as_ref(),
salt,
repeat_count,
);
assert!(!decrypt_result.is_err(), "aes 복호화 오류 발생");
let decrypted_raw_value = decrypt_result.unwrap();
let decrypted_value = decrypted_raw_value.as_ref();
assert_eq!(
PLAIN_TEXT,
String::from_utf8_lossy(decrypted_value),
"복호화 값 불일치"
);
println!(
"decrypted text: {:?}",
String::from_utf8_lossy(decrypted_value)
);
}
// Encrypts with freshly generated 4096- and 8192-bit keys and checks the cipher
// text length, then does a 2048-bit round trip through `rsa_encrypt_without_key`
// and `rsa_decrypt`.
#[test]
pub fn rsa_encrypt_test() {
let key_pair = generate_rsa_keypair(RSA_BIT::B_4096);
let result1 = rsa_encrypt(
PLAIN_TEXT.as_bytes(),
key_pair.unwrap().public_key_to_pem().unwrap().as_slice(),
);
// NOTE(review): the message below says "RSA 2048" but this key is 4096-bit — confirm and fix.
assert!(!result1.is_err(), "RSA 2048 암호화 실패");
let result_raw = result1.unwrap();
assert_eq!(
result_raw.len(),
RSA_BIT::B_4096.bytes() as usize,
"암호화 결과 길이 불일치"
);
println!(
"rsa result(4096) : {:?}\nlength : {}",
result_raw,
result_raw.len()
);
let encoded_value = BASE64_STANDARD.encode(result_raw);
println!("rsa base 64 encoded value : {:?}", encoded_value);
// Same check with an 8192-bit key.
let key_pair = generate_rsa_keypair(RSA_BIT::B_8192);
let result1 = rsa_encrypt(
PLAIN_TEXT.as_bytes(),
key_pair.unwrap().public_key_to_pem().unwrap().as_slice(),
);
assert!(!result1.is_err(), "RSA 8192 암호화 실패");
let result_raw = result1.unwrap();
assert_eq!(
result_raw.len(),
RSA_BIT::B_8192.bytes() as usize,
"암호화 결과 길이 불일치"
);
println!(
"rsa result(8192) : {:?}\nlength : {}",
result_raw,
result_raw.len()
);
// Key generation plus encryption in one call; every field must be populated.
let result2 = rsa_encrypt_without_key(PLAIN_TEXT.as_bytes(), RSA_BIT::B_2048);
assert!(result2.is_ok());
let result2_raw = result2.unwrap();
assert!(result2_raw.private_key().len() > 0, "개인키 반환 실패");
assert!(
result2_raw.private_exponent().len() > 0,
"개인키 지수 반환 실패"
);
assert!(
result2_raw.private_modulus().len() > 0,
"개인키 계수 반환 실패"
);
assert!(result2_raw.public_key().len() > 0, "공개키 반환 실패");
assert!(
result2_raw.public_exponent().len() > 0,
"공개키 지수 반환 실패"
);
assert!(
result2_raw.public_modulus().len() > 0,
"공개키 계수 반환 실패"
);
assert!(result2_raw.result().len() > 0, "암호화 결과 반환 실패");
assert_eq!(
result2_raw.result().len(),
RSA_BIT::B_2048.bytes() as usize,
"암호화 결과 길이 불일치"
);
assert!(result2_raw.result_str().is_some());
// The stored hex string must equal a hex rendering of the raw result bytes.
let raw_result: Vec<String> = result2_raw
.result()
.iter()
.map(|b| format!("{:02x}", b))
.collect();
let raw_result = raw_result.join("");
assert_eq!(raw_result, result2_raw.result_str().unwrap());
println!("rsa result str ===> {}", result2_raw.result_str().unwrap());
// Decrypting with the returned private key must give back the original text.
let decrypt2 = rsa_decrypt(result2_raw.result(), result2_raw.private_key());
assert!(!decrypt2.is_err());
let decrypt2_raw = decrypt2.unwrap();
let decrypt2_result = String::from_utf8(decrypt2_raw.to_vec()).unwrap();
assert_eq!(decrypt2_result, PLAIN_TEXT, "복호화 실패");
println!("원문: {:?}\n복호화 결과: {:?}", PLAIN_TEXT, decrypt2_result);
}
}