use rand::rngs::OsRng;
use rsa::errors::Error;
use rsa::padding::PaddingScheme;
use rsa::pkcs8::ToPublicKey;
use rsa::{PublicKey, RsaPrivateKey, RsaPublicKey};
use sha2::Digest;
use crate::deon::error::DeonError;
/// First byte of a serialized [`ClientRequest::Put`]; `deserialize` dispatches on it.
const PUT_REQUEST_TAG: u8 = 0x00;
/// First byte of a serialized [`ClientRequest::Pop`]; `deserialize` dispatches on it.
const POP_REQUEST_TAG: u8 = 0x01;
/// A client request that can be turned into bytes with `serialize` and parsed
/// back with `deserialize`.
// NOTE(review): the lint presumably fires because `Pop` holds an optional
// `RsaPrivateKey`, making it much larger than `Put`. Boxing the field would change
// the public field type, so the lint is silenced instead — confirm this is intended.
#[allow(clippy::large_enum_variant)]
pub enum ClientRequest {
    /// A put request (wire tag `PUT_REQUEST_TAG`).
    Put {
        /// When present, it is sent as a PKCS#8 PEM and the non-empty `payload` and
        /// `surb` are RSA-encrypted for it during serialization.
        public_key: Option<RsaPublicKey>,
        /// Request body; `None` and an empty vector are encoded identically.
        payload: Option<Vec<u8>>,
        /// Opaque SURB bytes; `None` and an empty vector are encoded identically.
        surb: Option<Vec<u8>>,
    },
    /// A signed pop request (wire tag `POP_REQUEST_TAG`).
    Pop {
        /// Key sent as a PKCS#8 PEM; deserialization verifies the signature against it.
        public_key: RsaPublicKey,
        /// Required to sign when serializing; always `None` after deserializing.
        secret_key: Option<RsaPrivateKey>,
        /// Opaque bytes covered by the signature; the decoder reads exactly 96 of them.
        wingman: Vec<u8>,
    },
}
impl ClientRequest {
fn encrypt(pk: &RsaPublicKey, cleartext: Vec<u8>) -> Result<Vec<u8>, Error> {
let mut rng = OsRng;
let padding = PaddingScheme::new_pkcs1v15_encrypt();
pk.encrypt(&mut rng, padding, &cleartext)
}
fn decrypt(sk: &RsaPrivateKey, ciphertext: &[u8]) -> Result<Vec<u8>, Error> {
let padding = PaddingScheme::new_pkcs1v15_encrypt();
sk.decrypt(padding, ciphertext)
}
fn sign(sk: &RsaPrivateKey, digest_in: &[u8]) -> Result<Vec<u8>, Error> {
let hash = rsa::Hash::SHA2_256;
let padding = PaddingScheme::new_pkcs1v15_sign(Some(hash));
let digest = sha2::Sha256::digest(digest_in).to_vec();
sk.sign(padding, &digest)
}
fn verify(pk: &RsaPublicKey, digest_in: &[u8], sig: &[u8]) -> Result<(), Error> {
let hash = rsa::Hash::SHA2_256;
let padding = PaddingScheme::new_pkcs1v15_sign(Some(hash));
let digest = sha2::Sha256::digest(digest_in).to_vec();
pk.verify(padding, &digest, sig)
}
fn serialize_put(
public_key: Option<RsaPublicKey>,
payload: Option<Vec<u8>>,
surb: Option<Vec<u8>>,
) -> Result<Vec<u8>, DeonError> {
let mut payload = match payload {
None => Vec::<u8>::new(),
Some(payload) => payload,
};
let mut surb = match surb {
None => Vec::<u8>::new(),
Some(surb) => surb,
};
let pem = match public_key {
None => Vec::<u8>::new(),
Some(public_key) => match public_key.to_public_key_pem() {
Err(e) => {
return Err(DeonError::RsaPkcs8(format!(
"Failed to create PEM from public key: {}",
e
)))
}
Ok(pem) => {
if !payload.is_empty() {
payload = match Self::encrypt(&public_key, payload) {
Err(e) => {
return Err(DeonError::RsaCrypto(format!(
"Failed to encrypt payload: {}",
e
)))
}
Ok(cyphertext) => cyphertext,
};
}
if !surb.is_empty() {
surb = match Self::encrypt(&public_key, surb) {
Err(e) => {
return Err(DeonError::RsaCrypto(format!(
"Failed to encrypt SURB: {}",
e
)))
}
Ok(cyphertext) => cyphertext,
};
}
pem.as_bytes().to_vec()
}
},
};
let surb_len_bytes = (surb.len() as u64).to_be_bytes();
let payload_len_bytes = (payload.len() as u64).to_be_bytes();
let pem_len_bytes = (pem.len() as u64).to_be_bytes();
Ok(std::iter::once(PUT_REQUEST_TAG)
.chain(pem_len_bytes)
.chain(pem)
.chain(payload_len_bytes)
.chain(payload)
.chain(surb_len_bytes)
.chain(surb)
.collect())
}
fn serialize_pop(
public_key: RsaPublicKey,
secret_key: Option<RsaPrivateKey>,
wingman: Vec<u8>,
) -> Result<Vec<u8>, DeonError> {
let secret_key = match secret_key {
None => {
return Err(DeonError::NoSecretKey(
"Cannot sign pop request without secret key".to_string(),
))
}
Some(secret_key) => secret_key,
};
let pem = match public_key.to_public_key_pem() {
Err(e) => {
return Err(DeonError::RsaPkcs8(format!(
"Failed to create PEM from public key: {}",
e
)))
}
Ok(pem) => pem.as_bytes().to_vec(),
};
let pem_len_bytes = (pem.len() as u64).to_be_bytes();
let mut vector: Vec<u8> = std::iter::once(POP_REQUEST_TAG)
.chain(wingman)
.chain(pem_len_bytes)
.chain(pem)
.collect();
let mut sig = match Self::sign(&secret_key, &vector) {
Err(e) => {
return Err(DeonError::RsaCrypto(format!(
"Failed to sign vector: {}",
e
)))
}
Ok(sig) => sig,
};
let sig_len_bytes = (sig.len() as u64).to_be_bytes();
vector.append(&mut sig_len_bytes.to_vec());
vector.append(&mut sig);
Ok(vector)
}
pub fn serialize(self) -> Result<Vec<u8>, DeonError> {
match self {
ClientRequest::Put {
public_key,
payload,
surb,
} => Ok(Self::serialize_put(public_key, payload, surb)?),
ClientRequest::Pop {
public_key,
secret_key,
wingman,
} => Ok(Self::serialize_pop(public_key, secret_key, wingman)?),
}
}
fn deserialize_put(b: &[u8], secret_key: Option<&RsaPrivateKey>) -> Result<Self, DeonError> {
if b.len() < 1 + 3 * std::mem::size_of::<u64>() {
return Err(DeonError::Malformed(format!(
"Expected at least {} bytes, got {}",
1 + 3 * std::mem::size_of::<u64>(),
b.len()
)));
}
let pem_len_bytes = u64::from_be_bytes(
b[1..1 + std::mem::size_of::<u64>()]
.as_ref()
.try_into()
.unwrap(),
);
if pem_len_bytes > (b.len() - 1 - 3 * std::mem::size_of::<u64>()) as u64 {
return Err(DeonError::Malformed(format!(
"Expected at most {} bytes for PEM, got {}",
b.len() - 1 - 3 * std::mem::size_of::<u64>(),
pem_len_bytes
)));
}
let pem_bound = 1 + std::mem::size_of::<u64>() + pem_len_bytes as usize;
let pem = &b[1 + std::mem::size_of::<u64>()..pem_bound].to_vec();
let payload_len_bytes = u64::from_be_bytes(
b[pem_bound..pem_bound + std::mem::size_of::<u64>()]
.as_ref()
.try_into()
.unwrap(),
);
if payload_len_bytes > (b.len() - pem_bound - 2 * std::mem::size_of::<u64>()) as u64 {
return Err(DeonError::Malformed(format!(
"Expected at most {} bytes for payload, got {}",
b.len() - pem_bound - 2 * std::mem::size_of::<u64>(),
payload_len_bytes
)));
}
let payload_bound = pem_bound + std::mem::size_of::<u64>() + payload_len_bytes as usize;
let payload = &b[pem_bound + std::mem::size_of::<u64>()..payload_bound].to_vec();
let surb_len_bytes = u64::from_be_bytes(
b[payload_bound..payload_bound + std::mem::size_of::<u64>()]
.as_ref()
.try_into()
.unwrap(),
);
if surb_len_bytes > (b.len() - payload_bound - std::mem::size_of::<u64>()) as u64 {
return Err(DeonError::Malformed(format!(
"Expected at most {} bytes for SURB, got {}",
b.len() - payload_bound - std::mem::size_of::<u64>(),
surb_len_bytes
)));
}
let surb_bound = payload_bound + std::mem::size_of::<u64>() + surb_len_bytes as usize;
let surb = &b[payload_bound + std::mem::size_of::<u64>()..surb_bound].to_vec();
let mut payload = match payload.is_empty() {
true => None,
false => Some(payload.to_owned()),
};
let mut surb = match surb.is_empty() {
true => None,
false => Some(surb.to_owned()),
};
let public_key = match pem.is_empty() {
true => None,
false => {
match secret_key {
None => {
payload = None;
surb = None;
}
Some(sk) => {
match payload {
None => {}
Some(given_payload) => {
payload = match Self::decrypt(sk, &given_payload) {
Err(e) => {
return Err(DeonError::RsaCrypto(format!(
"Failed to decrypt payload: {}",
e
)))
}
Ok(cleartext) => Some(cleartext),
};
}
}
match surb {
None => {}
Some(given_surb) => {
surb = match Self::decrypt(sk, &given_surb) {
Err(e) => {
return Err(DeonError::RsaCrypto(format!(
"Failed to decrypt SURB: {}",
e
)))
}
Ok(cleartext) => Some(cleartext),
};
}
}
}
}
match String::from_utf8(pem.to_owned()) {
Err(e) => {
return Err(DeonError::Malformed(format!(
"Failed to create PEM string from bytes: {}",
e
)))
}
Ok(string) => {
match rsa::pkcs8::FromPublicKey::from_public_key_pem(string.as_str()) {
Err(e) => {
return Err(DeonError::RsaPkcs8(format!(
"Failed to create public key from PEM: {}",
e
)))
}
Ok(public_key) => Some(public_key),
}
}
}
}
};
Ok(ClientRequest::Put {
public_key,
payload,
surb,
})
}
fn deserialize_pop(b: &[u8]) -> Result<Self, DeonError> {
if b.len() < 97 + 2 * std::mem::size_of::<u64>() {
return Err(DeonError::Malformed(format!(
"Expected at least {} bytes, got {}",
97 + 2 * std::mem::size_of::<u64>(),
b.len()
)));
}
let wingman = &b[1..97];
let pem_len_bytes = u64::from_be_bytes(
b[97..97 + std::mem::size_of::<u64>()]
.as_ref()
.try_into()
.unwrap(),
);
if pem_len_bytes > (b.len() - 97 - 2 * std::mem::size_of::<u64>()) as u64 {
return Err(DeonError::Malformed(format!(
"Expected at most {} bytes for PEM, got {}",
b.len() - 97 - 2 * std::mem::size_of::<u64>(),
b.len()
)));
}
let pem_bound = 97 + std::mem::size_of::<u64>() + pem_len_bytes as usize;
let pem = &b[97 + std::mem::size_of::<u64>()..pem_bound];
let sig_len_bytes = u64::from_be_bytes(
b[pem_bound..pem_bound + std::mem::size_of::<u64>()]
.as_ref()
.try_into()
.unwrap(),
);
if sig_len_bytes > (b.len() - pem_bound - std::mem::size_of::<u64>()) as u64 {
return Err(DeonError::Malformed(format!(
"Expected at most {} bytes for signature, got {}",
b.len() - pem_bound - std::mem::size_of::<u64>(),
sig_len_bytes
)));
}
let sig_bound = pem_bound + std::mem::size_of::<u64>() + sig_len_bytes as usize;
let sig = &b[pem_bound + std::mem::size_of::<u64>()..sig_bound];
let public_key = match String::from_utf8(pem.to_owned()) {
Err(e) => {
return Err(DeonError::Malformed(format!(
"Failed to create PEM string from bytes: {}",
e
)))
}
Ok(string) => match rsa::pkcs8::FromPublicKey::from_public_key_pem(string.as_str()) {
Err(e) => {
return Err(DeonError::RsaPkcs8(format!(
"Failed to create public key from PEM: {}",
e
)))
}
Ok(public_key) => public_key,
},
};
let digest_in: Vec<u8> = std::iter::once(POP_REQUEST_TAG)
.chain(wingman.to_vec())
.chain(pem_len_bytes.to_be_bytes())
.chain(pem.to_vec())
.collect();
match Self::verify(&public_key, &digest_in, sig) {
Err(e) => Err(DeonError::RsaCrypto(format!(
"Failed to verify pop request: {}",
e
))),
Ok(_) => Ok(ClientRequest::Pop {
public_key,
secret_key: None,
wingman: wingman.to_vec(),
}),
}
}
pub fn deserialize(b: &[u8], sk: Option<&RsaPrivateKey>) -> Result<Self, DeonError> {
if b.len() < std::mem::size_of::<u8>() {
return Err(DeonError::Malformed(format!(
"Expected at least {} bytes, got {}",
std::mem::size_of::<u8>(),
b.len()
)));
}
match b[0] {
PUT_REQUEST_TAG => Self::deserialize_put(b, sk),
POP_REQUEST_TAG => Self::deserialize_pop(b),
_ => Err(DeonError::Malformed(format!(
"Expected tag {} or {}, got {}",
PUT_REQUEST_TAG, POP_REQUEST_TAG, b[0]
))),
}
}
}