use crate::elliptic_curve::{Curve, ECOp, ECPoint, ECScalar};
use crate::hashing::hash_input;
use crate::internals::{decrypt, encrypt, ByteVector, PREError};
use crate::internals::{Nonce, NONCE_SIZE};
use sha2::{Sha256, Sha512};
/// Key material produced by `PREState::generate_re_encryption_key` and consumed by
/// `PREState::re_encrypt`. All three components are stored in serialized (byte) form.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ReEncryptionKey {
/// Serialized point `G * (k - h)`: `k` is a fresh random scalar and `h` is the scalar
/// obtained from the SHA-256 hash of (tag, private key).
r1: ByteVector,
/// Serialized point `P * k`, where `P` is the public key handed to
/// `generate_re_encryption_key` and `k` is the same random scalar used for `r1`.
r2: ByteVector,
/// Serialized scalar returned by `Curve::get_scalar_from_hash` over (tag, private key);
/// `re_encrypt` feeds it into the overall-checksum recomputation.
r3: ByteVector,
}
/// Output of `PREState::self_encrypt`; input to `PREState::self_decrypt` and
/// `PREState::re_encrypt`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EncryptedMessage {
/// Caller-supplied tag, stored unchanged; it is hashed together with the private key
/// during encryption and decryption.
tag: ByteVector,
/// Serialized sum of the per-message random point and the base point multiplied by the
/// (tag, private key) hash scalar.
encrypted_key: ByteVector,
/// SHA-512 digest over (plaintext, serialized per-message random point).
message_check_sum: ByteVector,
/// SHA-512 digest over (`encrypted_key`, `data`, `message_check_sum`, serialized
/// `get_scalar_from_hash(tag, private key)` scalar).
overall_check_sum: ByteVector,
/// Symmetric ciphertext of the plaintext message.
data: ByteVector,
}
/// Output of `PREState::re_encrypt`; input to `PREState::re_decrypt`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ReEncryptedMessage {
/// Serialized point: (original `encrypted_key` + re-encryption key `r1`) multiplied by a
/// blinding scalar hashed from the remaining components.
d1: ByteVector,
/// Symmetric ciphertext, copied from `EncryptedMessage::data`.
d2: ByteVector,
/// Message checksum, copied from `EncryptedMessage::message_check_sum`.
d3: ByteVector,
/// Re-encryption key component `r2`, copied unchanged.
d4: ByteVector,
/// Serialized point `G * t` for the random scalar `t` drawn inside `re_encrypt`.
d5: ByteVector,
}
/// A party's key pair together with the curve it lives on; every encryption,
/// re-encryption and decryption operation in this file is a method on this type.
pub struct PREState {
/// Curve providing the base point and the byte <-> point/scalar conversions.
pub curve: Curve,
/// Secret scalar; `new` draws it from `Curve::get_random_scalar`. Kept private to the module.
private_key: ECScalar,
/// `curve.base_point` multiplied by `private_key`.
pub public_key: ECPoint,
}
impl PREState {
#[allow(dead_code)]
pub fn new(curve: Curve) -> Self {
let private_key = Curve::get_random_scalar();
let public_key = curve.base_point.multiply(&private_key);
PREState {
curve,
private_key,
public_key,
}
}
#[allow(dead_code)]
pub fn generate_re_encryption_key(
&self,
pub_k: &ByteVector,
tag: ByteVector,
) -> Result<ReEncryptionKey, PREError> {
let public_key = self.curve.get_point_from_bytes(pub_k)?;
let private_key = self.private_key.to_bytes();
let random_scalar = Curve::get_random_scalar();
let hash_output = hash_input::<Sha256, 32>(vec![tag.clone(), private_key.clone()])
.expect("Re-Encryption Key Generation Error: unable to hash (tag, private key) pair");
let h = self.curve.get_scalar_from_bytes(&hash_output.to_vec())?;
let factor = random_scalar.eval(Some(h), ECOp::Subtract).unwrap();
Ok(ReEncryptionKey {
r1: self.curve.base_point.multiply(&factor).to_bytes(),
r2: public_key.multiply(&random_scalar).to_bytes(),
r3: self
.curve
.get_scalar_from_hash(vec![tag.clone(), private_key])
.unwrap()
.to_bytes(),
})
}
#[allow(dead_code)]
async fn encrypt_symmetric(&self, data: &ByteVector, key_hash: &ByteVector) -> ByteVector {
let key: &[u8] = &key_hash[0..32];
let nonce: &[u8] = &key_hash[32..32 + NONCE_SIZE];
let cipher_text = encrypt(data, key, Some(Nonce::from_slice(nonce)), false).await;
match cipher_text {
Ok(encrypted_msg) => encrypted_msg,
Err(error) => {
panic!("Encrypt Symmetric Error : {} \n", error);
}
}
}
#[allow(dead_code)]
async fn decrypt_symmetric(
&self,
ciphertext: &ByteVector,
key_hash: &ByteVector,
) -> ByteVector {
let key: &[u8] = &key_hash[0..32];
let nonce: &[u8] = &key_hash[32..32 + NONCE_SIZE];
let original_plain_text =
decrypt(ciphertext, key, Some(Nonce::from_slice(nonce)), false).await;
match original_plain_text {
Ok(message) => message,
Err(error) => {
panic!("Decrypt Symmetric Error : {} \n", error);
}
}
}
#[allow(dead_code)]
pub async fn self_encrypt(
&self,
message: ByteVector,
tag: ByteVector,
) -> Result<EncryptedMessage, PREError> {
let random_scalar: ECScalar = Curve::get_random_scalar();
let public_key: ECPoint = self.curve.base_point.multiply(&random_scalar);
let private_key_vector: Vec<u8> = self.private_key.to_bytes();
let tag_private_key_hash: [u8; 32] =
hash_input::<Sha256, 32>(vec![tag.clone(), private_key_vector.clone()])?;
let scalar_tag_private_key: ECScalar = self
.curve
.get_scalar_from_bytes(&(tag_private_key_hash.to_vec()))?;
let hg: ECPoint = self.curve.base_point.multiply(&scalar_tag_private_key);
let encrypted_key: Vec<u8> = public_key.eval(&hg, ECOp::Add).unwrap().to_bytes();
let symmetric_encryption_key: [u8; 64] =
hash_input::<Sha512, 64>(vec![public_key.to_bytes()])?;
let data: Vec<u8> = self
.encrypt_symmetric(&message, &symmetric_encryption_key.to_vec())
.await;
let message_check_sum: [u8; 64] =
hash_input::<Sha512, 64>(vec![message, public_key.to_bytes()]).unwrap();
let alp: Vec<u8> = self
.curve
.get_scalar_from_hash(vec![tag.clone(), self.private_key.to_bytes()])
.unwrap()
.to_bytes();
let overall_check_sum: [u8; 64] = hash_input::<Sha512, 64>(vec![
encrypted_key.clone(),
data.clone(),
message_check_sum.to_vec(),
alp,
])?;
Ok(EncryptedMessage {
tag,
encrypted_key,
message_check_sum: message_check_sum.to_vec(),
overall_check_sum: overall_check_sum.to_vec(),
data,
})
}
#[allow(dead_code)]
pub async fn self_decrypt(
&self,
encrypted_message: EncryptedMessage,
) -> Result<ByteVector, PREError> {
let private_key: Vec<u8> = self.private_key.to_bytes();
let alp: Vec<u8> = self
.curve
.get_scalar_from_hash(vec![encrypted_message.tag.clone(), private_key.clone()])?
.to_bytes();
let first_check: [u8; 64] = hash_input::<Sha512, 64>(vec![
encrypted_message.encrypted_key.clone(),
encrypted_message.data.clone(),
encrypted_message.message_check_sum.clone(),
alp.clone(),
])?;
let overall_checksum_length: usize = encrypted_message.overall_check_sum.len();
let mut matching_values: usize = encrypted_message
.overall_check_sum
.iter()
.zip(&first_check.to_vec())
.filter(|&(x, y)| x == y)
.count();
if matching_values != overall_checksum_length {
return Err(PREError::OverallCheckSumFailure(String::from(
"self-decrypt: overall checksum failure.\n",
)));
}
let tag_private_key_hash: [u8; 32] =
hash_input::<Sha256, 32>(vec![encrypted_message.tag, private_key])?;
let scalar_tag_private_key: ECScalar = self
.curve
.get_scalar_from_bytes(&tag_private_key_hash.to_vec())?;
let hg: ECPoint = self.curve.base_point.multiply(&scalar_tag_private_key);
let encrypted_key: ECPoint = self
.curve
.get_point_from_bytes(&encrypted_message.encrypted_key)?;
let recovered_public_key: ByteVector =
encrypted_key.eval(&hg, ECOp::Subtract).unwrap().to_bytes();
let key: [u8; 64] = hash_input::<Sha512, 64>(vec![recovered_public_key.clone()]).unwrap();
let data: Vec<u8> = self
.decrypt_symmetric(&encrypted_message.data, &key.to_vec())
.await;
let check2: [u8; 64] = hash_input::<Sha512, 64>(vec![data.clone(), recovered_public_key])?;
matching_values = encrypted_message
.message_check_sum
.iter()
.zip(&check2.to_vec())
.filter(|&(x, y)| x == y)
.count();
if matching_values != encrypted_message.message_check_sum.len() {
return Err(PREError::MessageCheckSumFailure(String::from(
"self-decrypt: message checksum failure.",
)));
}
Ok(data)
}
#[allow(dead_code)]
pub fn re_encrypt(
&self,
public_key: &ByteVector,
message: EncryptedMessage,
re_encryption_key: ReEncryptionKey,
_curve: &Curve,
) -> Result<ReEncryptedMessage, PREError> {
let check1 = hash_input::<Sha512, 64>(vec![
message.encrypted_key.clone(),
message.data.clone(),
message.message_check_sum.clone(),
re_encryption_key.r3,
])?;
let matching_values = message
.overall_check_sum
.iter()
.zip(&check1.to_vec())
.filter(|&(x, y)| x == y)
.count();
if matching_values != message.overall_check_sum.len() {
return Err(PREError::OverallCheckSumFailure(String::from("")));
}
let p = self.curve.get_point_from_bytes(public_key)?;
let t = Curve::get_random_scalar();
let txg = p.multiply(&t);
let bet = self
.curve
.get_scalar_from_hash(vec![
txg.to_bytes(),
message.data.clone(),
message.message_check_sum.clone(),
re_encryption_key.r2.clone(),
self.curve.base_point.multiply(&t).to_bytes(),
])
.unwrap();
let r1 = self.curve.get_point_from_bytes(&re_encryption_key.r1)?;
let encrypted_key = self
.curve
.get_point_from_bytes(&message.encrypted_key)
.unwrap()
.eval(&r1, ECOp::Add)?;
Ok(ReEncryptedMessage {
d1: encrypted_key.multiply(&bet).to_bytes(),
d2: message.data,
d3: message.message_check_sum,
d4: re_encryption_key.r2,
d5: self.curve.base_point.multiply(&t).to_bytes(),
})
}
#[allow(dead_code)]
pub async fn re_decrypt(
&self,
re_encrypted_message: ReEncryptedMessage,
) -> Result<ByteVector, PREError> {
let d1 = self.curve.get_point_from_bytes(&re_encrypted_message.d1)?;
let d4 = self.curve.get_point_from_bytes(&re_encrypted_message.d4)?;
let d5 = self.curve.get_point_from_bytes(&re_encrypted_message.d5)?;
let txg = d5.multiply(&self.private_key);
let b_inv = self
.curve
.get_scalar_from_hash(vec![
txg.to_bytes(),
re_encrypted_message.d2.clone(),
re_encrypted_message.d3.clone(),
re_encrypted_message.d4,
re_encrypted_message.d5,
])
.unwrap()
.eval(None, ECOp::Invert)?;
let private_key_inv = self.private_key.eval(None, ECOp::Invert)?;
let t1 = d1.multiply(&b_inv);
let t2 = d4.multiply(&private_key_inv);
let t_buf = t1.eval(&t2, ECOp::Subtract)?.to_bytes();
let key = hash_input::<Sha512, 64>(vec![t_buf.clone()])?;
let data = self
.decrypt_symmetric(&re_encrypted_message.d2, &key.to_vec())
.await;
let check2 = hash_input::<Sha512, 64>(vec![data.clone(), t_buf])?;
let matching_values = re_encrypted_message
.d3
.iter()
.zip(&check2.to_vec())
.filter(|&(x, y)| x == y)
.count();
if matching_values != re_encrypted_message.d3.len() {
return Err(PREError::DefaultError(String::from("181?")));
}
Ok(data)
}
}
#[cfg(test)]
mod self_encryption_tests {
    use super::*;
    use crate::test_utils::{bytes_to_str_utf8, generate_test_files};
    use futures::executor::block_on;
    use std::fs;
    use std::path::Path;

    // Directory the file-based test iterates over; presumably populated by
    // `generate_test_files` (confirm in `test_utils`).
    const TEST_DIR_PATH: &str = "test-files";

    /// Returns the raw OS bytes of `path` (Unix only; currently unused).
    #[cfg(unix)]
    fn _path_to_bytes<P: AsRef<Path>>(path: P) -> Vec<u8> {
        use std::os::unix::ffi::OsStrExt;
        let os_str = path.as_ref().as_os_str();
        os_str.as_bytes().to_vec()
    }

    /// Fixed `(message, tag)` pairs used by the in-memory round-trip test.
    fn generate_plaintext_messages() -> Vec<(&'static str, &'static str)> {
        vec![
            ("first message to encrypt", "tag1"),
            ("second message to encrypt", "tag2"),
            ("third message to encrypt", "tag3"),
        ]
    }

    /// Encrypting then decrypting with the same state returns the original bytes and
    /// leaves the tag untouched.
    #[test]
    fn test_simple_self_encryption() {
        let samples = generate_plaintext_messages();
        let pre_state = PREState::new(Curve::new());

        for (message, tag) in samples {
            let plaintext = message.as_bytes();
            let tag_bytes = tag.as_bytes();

            let sealed: EncryptedMessage =
                block_on(pre_state.self_encrypt(plaintext.into(), tag_bytes.into())).unwrap();
            let opened: ByteVector = block_on(pre_state.self_decrypt(sealed.clone())).unwrap();

            assert_eq!(opened.as_slice(), plaintext);
            assert_eq!(tag_bytes, sealed.tag.as_slice());
        }
    }

    /// Round-trips the contents of every file in `TEST_DIR_PATH` through
    /// `self_encrypt` / `self_decrypt`.
    #[test]
    fn test_file_self_encryption() {
        generate_test_files();
        let pre_state = PREState::new(Curve::new());

        for dir_entry in fs::read_dir(TEST_DIR_PATH).unwrap() {
            let file_path = dir_entry.unwrap().path();
            let contents =
                fs::read_to_string(&file_path).expect("Unable to read test file contents");

            let sealed: EncryptedMessage = block_on(
                pre_state.self_encrypt(contents.as_bytes().to_vec(), b"dummy tag".to_vec()),
            )
            .expect("failed to encrypt file");

            let decrypted_file: ByteVector =
                block_on(pre_state.self_decrypt(sealed)).expect("failed to decrypt file");
            dbg!(bytes_to_str_utf8(decrypted_file.as_slice()));
            assert_eq!(decrypted_file.as_slice(), contents.as_bytes());
        }
    }
}
#[cfg(test)]
mod test_re_encryption {
    use super::*;
    use crate::test_utils::{bytes_to_str_utf8, generate_test_files, remove_test_files};
    use futures::executor::block_on;
    use std::fs;

    // Serialized length the key-generation test asserts for each re-encryption key component.
    const BYTE_LENGTH: usize = 32;
    // Directory the file-based test iterates over; presumably populated by
    // `generate_test_files` (confirm in `test_utils`).
    const TEST_DIR_PATH: &str = "test-files";

    /// Every component of a generated re-encryption key serializes to `BYTE_LENGTH` bytes.
    #[test]
    fn test_generate_re_encryption_key() {
        let pre_state = PREState::new(Curve::new());
        let own_public_key = pre_state.public_key.to_bytes();
        let tag = "dummy tag".as_bytes().to_vec();

        let key = pre_state
            .generate_re_encryption_key(&own_public_key, tag)
            .unwrap();

        for component in [&key.r1, &key.r2, &key.r3].iter() {
            assert_eq!(component.len(), BYTE_LENGTH);
        }
    }

    /// For every file in `TEST_DIR_PATH`: self-encrypt, re-encrypt towards the state's own
    /// public key, re-decrypt, and compare with the original contents.
    #[test]
    fn test_file_re_encrypt() {
        generate_test_files();
        let curve = Curve::new();
        let pre_state = PREState::new(curve.clone());

        for dir_entry in fs::read_dir(TEST_DIR_PATH).unwrap() {
            let path_buf = dir_entry.unwrap().path();
            dbg!(path_buf.clone());
            let file_contents =
                fs::read_to_string(&path_buf).expect("Unable to read test file contents.");

            let sealed: EncryptedMessage = block_on(
                pre_state.self_encrypt(file_contents.as_bytes().to_vec(), b"dummy tag".to_vec()),
            )
            .expect("failed to self-encrypt test file");

            // The state re-encrypts to itself, so one key pair plays both roles.
            let own_public_key: ByteVector = pre_state.public_key.to_bytes();
            let rekey: ReEncryptionKey = pre_state
                .generate_re_encryption_key(&own_public_key, sealed.tag.clone())
                .unwrap();
            let resealed: ReEncryptedMessage = pre_state
                .re_encrypt(&own_public_key, sealed, rekey, &curve)
                .expect("failed to re-encrypt test file");

            let decrypted_file_under_rekey: ByteVector = block_on(pre_state.re_decrypt(resealed))
                .expect("failed to decrypt re-encrypted test file");
            dbg!(bytes_to_str_utf8(decrypted_file_under_rekey.as_slice()));
            assert_eq!(
                decrypted_file_under_rekey.as_slice(),
                file_contents.as_bytes()
            );
        }
        remove_test_files();
    }
}