use rsa::pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey};
use std::fmt;
use subtle::{Choice, ConstantTimeEq};
#[cfg(feature = "memquota-memcost")]
use {derive_deftly::Deftly, tor_memquota_cost::derive_deftly_template_HasMemoryCost};
use crate::util::{ct::CtByteArray, rng::RngCompat};
pub use rsa::Error;
/// Length in bytes of an [`RsaIdentity`]: the size of a SHA-1 digest (160 bits).
pub const RSA_ID_LEN: usize = 160 / 8;
/// An RSA identity: the SHA-1 digest of an RSA public key's DER encoding,
/// stored as [`RSA_ID_LEN`] raw bytes.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(
    feature = "memquota-memcost",
    derive(Deftly),
    derive_deftly(HasMemoryCost)
)]
pub struct RsaIdentity {
    /// The raw identity bytes; equality checks in this module go through `ct_eq`.
    id: CtByteArray<RSA_ID_LEN>,
}
impl ConstantTimeEq for RsaIdentity {
    /// Compare two identities by delegating to the inner byte array's `ct_eq`.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (mine, theirs) = (&self.id, &other.id);
        mine.ct_eq(theirs)
    }
}
impl fmt::Display for RsaIdentity {
    /// Formats the identity as a `$` followed by its lowercase hex encoding.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw: &[u8] = &self.id.as_ref()[..];
        write!(f, "${}", hex::encode(raw))
    }
}
impl fmt::Debug for RsaIdentity {
    /// Wraps the `Display` form in `RsaIdentity { ... }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("RsaIdentity {{ {} }}", self))
    }
}
impl safelog::Redactable for RsaIdentity {
    /// Shows only the first byte of the identity (hex), followed by an ellipsis.
    fn display_redacted(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let leading_byte = &self.id.as_ref()[..1];
        write!(f, "${}…", hex::encode(leading_byte))
    }
    /// Wraps the redacted display form in `RsaIdentity { ... }`.
    fn debug_redacted(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("RsaIdentity {{ {} }}", self.redacted()))
    }
}
impl serde::Serialize for RsaIdentity {
    /// Serializes as a lowercase hex string for human-readable formats,
    /// and as raw bytes for all other formats.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let raw: &[u8] = &self.id.as_ref()[..];
        if !serializer.is_human_readable() {
            return serializer.serialize_bytes(raw);
        }
        serializer.serialize_str(&hex::encode(raw))
    }
}
impl<'de> serde::Deserialize<'de> for RsaIdentity {
    /// Mirror of the `Serialize` impl: expects a hex string from human-readable
    /// formats and a raw byte string from all other formats.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor for the human-readable (hex string) encoding.
        struct HexIdVisitor;
        impl<'de> serde::de::Visitor<'de> for HexIdVisitor {
            type Value = RsaIdentity;
            fn expecting(&self, fmt: &mut std::fmt::Formatter<'_>) -> fmt::Result {
                fmt.write_str("hex-encoded RSA identity")
            }
            fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match RsaIdentity::from_hex(s) {
                    Some(identity) => Ok(identity),
                    None => Err(E::custom("wrong encoding for RSA identity")),
                }
            }
        }

        /// Visitor for the compact (raw byte string) encoding.
        struct RawIdVisitor;
        impl<'de> serde::de::Visitor<'de> for RawIdVisitor {
            type Value = RsaIdentity;
            fn expecting(&self, fmt: &mut std::fmt::Formatter<'_>) -> fmt::Result {
                fmt.write_str("RSA identity")
            }
            fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match RsaIdentity::from_bytes(bytes) {
                    Some(identity) => Ok(identity),
                    None => Err(E::custom("wrong length for RSA identity")),
                }
            }
        }

        if deserializer.is_human_readable() {
            deserializer.deserialize_str(HexIdVisitor)
        } else {
            deserializer.deserialize_bytes(RawIdVisitor)
        }
    }
}
impl RsaIdentity {
    /// Return the raw bytes of this identity.
    pub fn as_bytes(&self) -> &[u8] {
        &self.id.as_ref()[..]
    }
    /// Return this identity as uppercase hexadecimal, without any `$` prefix.
    pub fn as_hex_upper(&self) -> String {
        hex::encode_upper(self.as_bytes())
    }
    /// Construct an identity from a byte slice.
    ///
    /// Returns `None` unless `bytes` is exactly [`RSA_ID_LEN`] bytes long.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let array = <[u8; RSA_ID_LEN]>::try_from(bytes).ok()?;
        Some(RsaIdentity {
            id: CtByteArray::from(array),
        })
    }
    /// Decode an identity from a hexadecimal string (no `$` prefix).
    ///
    /// Returns `None` if `s` is not exactly `2 * RSA_ID_LEN` hex digits, or if
    /// it contains non-hex characters.
    pub fn from_hex(s: &str) -> Option<Self> {
        // Size the buffer from RSA_ID_LEN rather than a literal, so it can
        // never drift from the identity length used by the rest of the type.
        let mut array = [0_u8; RSA_ID_LEN];
        hex::decode_to_slice(s, &mut array).ok()?;
        Some(RsaIdentity::from(array))
    }
    /// Return true if every byte of this identity is zero.
    ///
    /// The comparison is done with `ct_eq` against an all-zero array.
    pub fn is_zero(&self) -> bool {
        self.id.ct_eq(&[0; RSA_ID_LEN].into()).into()
    }
}
impl From<[u8; RSA_ID_LEN]> for RsaIdentity {
    /// Wrap a raw [`RSA_ID_LEN`]-byte array as an identity.
    fn from(id: [u8; RSA_ID_LEN]) -> RsaIdentity {
        RsaIdentity { id: id.into() }
    }
}
/// An RSA public key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PublicKey(rsa::RsaPublicKey);
/// An RSA private key, from which the matching [`PublicKey`] can be derived.
pub struct KeyPair(rsa::RsaPrivateKey);
impl KeyPair {
pub fn generate<R: rand_core::RngCore + rand_core::CryptoRng>(
csprng: &mut R,
) -> Result<Self, tor_error::Bug> {
Ok(Self(
rsa::RsaPrivateKey::new(&mut RngCompat::new(csprng), 1024).map_err(|_| {
tor_error::internal!("Generating RSA key failed, despite fixed exponent and size")
})?,
))
}
pub fn to_public_key(&self) -> PublicKey {
PublicKey(self.0.to_public_key())
}
pub fn from_der(der: &[u8]) -> Option<Self> {
Some(KeyPair(rsa::RsaPrivateKey::from_pkcs1_der(der).ok()?))
}
pub fn as_key(&self) -> &rsa::RsaPrivateKey {
&self.0
}
pub fn sign(&self, message: &[u8]) -> Result<Vec<u8>, rsa::Error> {
self.0.sign(rsa::Pkcs1v15Sign::new_unprefixed(), message)
}
}
impl PublicKey {
    /// Return true if this key's public exponent equals `e`.
    pub fn exponent_is(&self, e: u32) -> bool {
        use rsa::traits::PublicKeyParts;
        *self.0.e() == rsa::BigUint::new(vec![e])
    }
    /// Return the bit length of this key's modulus.
    pub fn bits(&self) -> usize {
        use rsa::traits::PublicKeyParts;
        self.0.n().bits()
    }
    /// Verify `sig` over the pre-computed digest `hashed`.
    ///
    /// Uses unprefixed PKCS#1 v1.5 padding, so `hashed` is not wrapped in a
    /// DigestInfo structure.
    ///
    /// # Errors
    ///
    /// Returns an opaque [`signature::Error`] on any verification failure.
    pub fn verify(&self, hashed: &[u8], sig: &[u8]) -> Result<(), signature::Error> {
        let padding = rsa::pkcs1v15::Pkcs1v15Sign::new_unprefixed();
        self.0
            .verify(padding, hashed, sig)
            .map_err(|_| signature::Error::new())
    }
    /// Parse a PKCS#1 DER-encoded RSA public key.
    ///
    /// Returns `None` if `der` cannot be parsed.
    pub fn from_der(der: &[u8]) -> Option<Self> {
        Some(PublicKey(rsa::RsaPublicKey::from_pkcs1_der(der).ok()?))
    }
    /// Encode this key as a DER `SEQUENCE { n INTEGER, e INTEGER }`.
    ///
    /// # Panics
    ///
    /// Panics if the DER serializer fails, which is not expected for two integers.
    pub fn to_der(&self) -> Vec<u8> {
        use der_parser::ber::BerObject;
        use rsa::traits::PublicKeyParts;

        /// Turn big-endian magnitude bytes into the content octets of a
        /// non-negative DER INTEGER.
        fn unsigned_integer_content(mut bytes: Vec<u8>) -> Vec<u8> {
            let needs_zero_prefix = match bytes.first() {
                // A DER INTEGER needs at least one content octet; encode as zero.
                None => true,
                // A set high bit would make the INTEGER read as negative.
                Some(&leading) => leading & 0x80 != 0,
            };
            if needs_zero_prefix {
                bytes.insert(0, 0_u8);
            }
            bytes
        }

        let n = unsigned_integer_content(self.0.n().to_bytes_be());
        let e = unsigned_integer_content(self.0.e().to_bytes_be());
        let asn1 = BerObject::from_seq(vec![
            BerObject::from_int_slice(&n),
            BerObject::from_int_slice(&e),
        ]);
        asn1.to_vec().expect("RSA key not encodable as DER")
    }
    /// Compute this key's [`RsaIdentity`]: the SHA-1 digest of [`Self::to_der`].
    pub fn to_rsa_identity(&self) -> RsaIdentity {
        use crate::d::Sha1;
        use digest::Digest;
        let id: [u8; RSA_ID_LEN] = Sha1::digest(self.to_der()).into();
        RsaIdentity { id: id.into() }
    }
    /// Borrow the underlying `rsa` public key.
    pub fn as_key(&self) -> &rsa::RsaPublicKey {
        &self.0
    }
}
impl<'a> From<&'a KeyPair> for PublicKey {
fn from(value: &'a KeyPair) -> Self {
PublicKey(value.to_public_key().0)
}
}
impl From<rsa::RsaPrivateKey> for KeyPair {
fn from(value: rsa::RsaPrivateKey) -> Self {
Self(value)
}
}
impl From<rsa::RsaPublicKey> for PublicKey {
fn from(value: rsa::RsaPublicKey) -> Self {
Self(value)
}
}
/// An RSA signature bundled with the key and hash it should be checked
/// against, so that validity can be tested later via `is_valid`.
pub struct ValidatableRsaSignature {
    /// Key expected to have produced `sig`.
    key: PublicKey,
    /// Hash value the signature is expected to cover.
    expected_hash: Vec<u8>,
    /// Raw signature bytes.
    sig: Vec<u8>,
}
impl ValidatableRsaSignature {
pub fn new(key: &PublicKey, sig: &[u8], expected_hash: &[u8]) -> Self {
ValidatableRsaSignature {
key: key.clone(),
sig: sig.into(),
expected_hash: expected_hash.into(),
}
}
}
impl super::ValidatableSignature for ValidatableRsaSignature {
    /// Return true if `sig` verifies against `expected_hash` under `key`.
    fn is_valid(&self) -> bool {
        let outcome = self
            .key
            .verify(self.expected_hash.as_slice(), self.sig.as_slice());
        outcome.is_ok()
    }
}