use std::{
cmp::Ordering,
convert::TryInto,
fmt::{self, Debug, Display, Formatter},
hash::Hash,
marker::PhantomData,
path::Path,
str,
time::{SystemTime, UNIX_EPOCH},
};
use anyhow::Context;
use datasize::DataSize;
use hex_fmt::HexFmt;
use nid::Nid;
use openssl::{
asn1::{Asn1Integer, Asn1IntegerRef, Asn1Time},
bn::{BigNum, BigNumContext},
ec,
error::ErrorStack,
hash::{DigestBytes, MessageDigest},
nid,
pkey::{PKey, PKeyRef, Private},
sha,
ssl::{SslAcceptor, SslConnector, SslContextBuilder, SslMethod, SslVerifyMode, SslVersion},
x509::{X509Builder, X509Name, X509NameBuilder, X509NameRef, X509Ref, X509},
};
#[cfg(test)]
use rand::{
distributions::{Distribution, Standard},
Rng,
};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
use crate::utils::{read_file, write_file};
/// Serde helper for (de)serializing arrays longer than serde's built-in array impls cover.
///
/// `BigArray` is used via `#[serde(with = "big_array::BigArray")]` on [`Sha512`]'s 64-byte array.
mod big_array {
    use serde_big_array::big_array;
    big_array! { BigArray; }
}
/// Signature algorithm every accepted certificate must be signed with (ECDSA over SHA-512).
const SIGNATURE_ALGORITHM: Nid = Nid::ECDSA_WITH_SHA512;
/// Elliptic curve every accepted certificate key must be on (secp521r1 / NIST P-521).
const SIGNATURE_CURVE: Nid = Nid::SECP521R1;
/// Digest used when signing and fingerprinting certificates; asserted to equal `Sha512::NID`.
const SIGNATURE_DIGEST: Nid = Nid::SHA512;
/// Shorthand for results whose error is an OpenSSL [`ErrorStack`].
type SslResult<T> = Result<T, ErrorStack>;
/// A SHA-512 digest stored as a fixed 64-byte array.
///
/// The array is (de)serialized through `big_array::BigArray`.
#[derive(Copy, Clone, DataSize, Deserialize, Serialize)]
struct Sha512(#[serde(with = "big_array::BigArray")] [u8; Sha512::SIZE]);
impl Sha512 {
    /// Length of a SHA-512 digest in bytes.
    const SIZE: usize = 64;
    /// OpenSSL identifier of the SHA-512 digest.
    const NID: Nid = Nid::SHA512;

    /// Computes the SHA-512 digest of `data` using OpenSSL.
    fn new<B: AsRef<[u8]>>(data: B) -> Self {
        let mut hasher = sha::Sha512::new();
        hasher.update(data.as_ref());
        let digest = hasher.finish();
        Sha512(digest)
    }

    /// Borrows the raw digest bytes.
    fn bytes(&self) -> &[u8] {
        let raw: &[u8] = &self.0;
        debug_assert_eq!(raw.len(), Self::SIZE);
        raw
    }

    /// Copies an OpenSSL digest result into a `Sha512`.
    ///
    /// # Panics
    ///
    /// Panics if `digest` is shorter than [`Sha512::SIZE`]; in debug builds it also panics
    /// if the length differs at all.
    fn from_openssl_digest(digest: &DigestBytes) -> Self {
        let source: &[u8] = digest.as_ref();
        debug_assert_eq!(
            source.len(),
            Self::SIZE,
            "digest is not the right size - check constants in `tls.rs`"
        );
        let mut out = [0; Self::SIZE];
        out.copy_from_slice(&source[..Self::SIZE]);
        Sha512(out)
    }

    /// Returns the OpenSSL `MessageDigest` matching [`Sha512::NID`].
    fn create_message_digest() -> MessageDigest {
        MessageDigest::from_nid(Self::NID).expect("Sha512::NID is invalid")
    }
}
/// SHA-512 fingerprint of an entire certificate.
#[derive(Copy, Clone, DataSize, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
pub(crate) struct CertFingerprint(Sha512);

impl Debug for CertFingerprint {
    /// Renders `CertFingerprint(<hex>)`, with the hex formatted at width 10.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let hex = HexFmt(self.0.bytes());
        write!(f, "CertFingerprint({:10})", hex)
    }
}
/// SHA-512 fingerprint of a certificate's public key.
#[derive(Copy, Clone, DataSize, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
pub struct KeyFingerprint(Sha512);

impl Debug for KeyFingerprint {
    /// Renders `KeyFingerprint(<hex>)`, with the hex formatted at width 10.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let hex = HexFmt(self.0.bytes());
        write!(f, "KeyFingerprint({:10})", hex)
    }
}
/// Test-only: wraps raw digest bytes directly into a `KeyFingerprint`, bypassing hashing.
#[cfg(test)]
impl From<[u8; Sha512::SIZE]> for KeyFingerprint {
    fn from(raw_bytes: [u8; Sha512::SIZE]) -> Self {
        let digest = Sha512(raw_bytes);
        Self(digest)
    }
}
/// Test-only: lets `rng.gen::<KeyFingerprint>()` produce fingerprints made of random bytes.
#[cfg(test)]
impl Distribution<KeyFingerprint> for Standard {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> KeyFingerprint {
        let mut raw = [0u8; Sha512::SIZE];
        rng.fill(&mut raw[..]);
        KeyFingerprint::from(raw)
    }
}
/// Raw signature bytes.
#[derive(Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
struct Signature(Vec<u8>);

impl Debug for Signature {
    /// Renders `Signature(<hex>)`, with the hex formatted at width 10.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let hex = HexFmt(self.0.as_slice());
        write!(f, "Signature({:10})", hex)
    }
}
/// An X.509 certificate that passed [`validate_cert`], plus its precomputed fingerprints.
///
/// The fields are private and the only constructor in this file is `validate_cert`, which is
/// also used when deserializing. Holding a `TlsCert` therefore means the certificate passed
/// validation at the time the value was created.
#[derive(Clone, DataSize)]
pub struct TlsCert {
    /// The underlying OpenSSL certificate (skipped by `DataSize` accounting).
    #[data_size(skip)]
    x509: X509,
    /// SHA-512 digest of the certificate.
    cert_fingerprint: CertFingerprint,
    /// SHA-512 digest of the certificate's compressed EC public-key point.
    key_fingerprint: KeyFingerprint,
}
impl TlsCert {
pub(crate) fn fingerprint(&self) -> CertFingerprint {
self.cert_fingerprint
}
pub(crate) fn public_key_fingerprint(&self) -> KeyFingerprint {
self.key_fingerprint
}
pub(crate) fn as_x509(&self) -> &X509 {
&self.x509
}
}
/// Shows only the certificate fingerprint, not the full certificate.
impl Debug for TlsCert {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let fingerprint = self.cert_fingerprint;
        write!(f, "TlsCert({:?})", fingerprint)
    }
}
/// Hashes by certificate fingerprint, matching the fingerprint-based `PartialEq`.
impl Hash for TlsCert {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        Hash::hash(&self.cert_fingerprint, state);
    }
}
/// Two certificates are equal when their SHA-512 certificate fingerprints are equal.
impl PartialEq for TlsCert {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.cert_fingerprint, other.cert_fingerprint);
        lhs == rhs
    }
}

impl Eq for TlsCert {}
impl<'de> Deserialize<'de> for TlsCert {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
validate_cert(x509_serde::deserialize(deserializer)?).map_err(serde::de::Error::custom)
}
}
/// Serializes only the certificate (as PEM); fingerprints are recomputed on deserialization.
impl Serialize for TlsCert {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let certificate = self.as_x509();
        x509_serde::serialize(certificate, serializer)
    }
}
/// Serialized bytes of a value of type `V` paired with a [`Signature`].
///
/// `V` is tracked only at the type level via `PhantomData`. The `Display` impl decodes `data`
/// with `bincode`. NOTE(review): nothing in this file creates or verifies a `Signed` value —
/// confirm at the producer what exactly `signature` covers.
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct Signed<V> {
    /// bincode-encoded value (as decoded by the `Display` impl).
    data: Vec<u8>,
    /// Signature bytes associated with `data`.
    signature: Signature,
    _phantom: PhantomData<V>,
}
/// Generates a fresh secp521r1 private key and a self-signed certificate for it, with common
/// name `casper-node`.
///
/// # Errors
///
/// Returns the OpenSSL error stack if key generation or certificate construction fails.
pub fn generate_node_cert() -> SslResult<(X509, PKey<Private>)> {
    let key = generate_private_key()?;
    generate_cert(&key, "casper-node").map(|cert| (cert, key))
}
/// Builds a server-side TLS acceptor from Mozilla's "modern" (v5) profile, configured with the
/// node's certificate and key via `set_context_options`.
///
/// # Errors
///
/// Returns the OpenSSL error stack if the acceptor cannot be created or configured.
pub(crate) fn create_tls_acceptor(
    cert: &X509Ref,
    private_key: &PKeyRef<Private>,
) -> SslResult<SslAcceptor> {
    let mut acceptor_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls_server())?;
    set_context_options(&mut acceptor_builder, cert, private_key)?;
    let acceptor = acceptor_builder.build();
    Ok(acceptor)
}
/// Builds a client-side TLS connector configured with the node's certificate and key via
/// `set_context_options`.
///
/// # Errors
///
/// Returns the OpenSSL error stack if the connector cannot be created or configured.
pub(crate) fn create_tls_connector(
    cert: &X509Ref,
    private_key: &PKeyRef<Private>,
) -> SslResult<SslConnector> {
    let mut connector_builder = SslConnector::builder(SslMethod::tls_client())?;
    set_context_options(&mut connector_builder, cert, private_key)?;
    let connector = connector_builder.build();
    Ok(connector)
}
/// Applies the settings shared by the TLS acceptor and connector to `ctx`.
///
/// The settings are:
/// * TLS 1.3 is the minimum protocol version;
/// * `cert` and `private_key` are installed as the local identity and checked to match;
/// * `SslVerifyMode::PEER` is enabled with a callback that accepts every peer certificate.
///
/// # Errors
///
/// Returns the OpenSSL error stack if any setting fails, including when `private_key` does not
/// belong to `cert`.
fn set_context_options(
    ctx: &mut SslContextBuilder,
    cert: &X509Ref,
    private_key: &PKeyRef<Private>,
) -> SslResult<()> {
    ctx.set_min_proto_version(Some(SslVersion::TLS1_3))?;
    ctx.set_certificate(cert)?;
    ctx.set_private_key(private_key)?;
    // Consistency check between the installed certificate and key; must come after both are set.
    ctx.check_private_key()?;
    // The callback returns `true` unconditionally, so OpenSSL's own chain verification never
    // rejects the peer here. NOTE(review): presumably peer certificates (self-signed, see
    // `validate_cert`) are validated after the handshake — confirm at the call sites.
    ctx.set_verify_callback(SslVerifyMode::PEER, |_, _| true);
    Ok(())
}
/// Reasons why [`validate_cert`] can reject a certificate.
#[derive(Debug, Error)]
pub enum ValidationError {
    /// `X509::public_key` failed.
    #[error("error reading public key from certificate: {0:?}")]
    CannotReadPublicKey(#[source] ErrorStack),
    /// The subject or issuer name could not be rendered to a string.
    #[error("error reading subject or issuer name: {0:?}")]
    CorruptSubjectOrIssuer(#[source] ErrorStack),
    /// The certificate is not signed with `SIGNATURE_ALGORITHM`.
    #[error("wrong signature scheme")]
    WrongSignatureAlgorithm,
    /// Building the current time or comparing it to the validity bounds failed.
    #[error("there was an issue reading or converting times: {0:?}")]
    TimeIssue(#[source] ErrorStack),
    /// The validity period has not started yet (checked against `notBefore`).
    #[error("the certificate is not yet valid")]
    NotYetValid,
    /// The validity period is over (checked against `notAfter`).
    #[error("the certificate expired")]
    Expired,
    /// The serial number could not be converted for comparison.
    #[error("the serial number could not be compared to the reference: {0:?}")]
    InvalidSerialNumber(#[source] ErrorStack),
    /// The serial number is not `1`.
    #[error("wrong serial number")]
    WrongSerialNumber,
    /// The certificate's public key could not be read as an elliptic-curve key.
    #[error("no valid elliptic curve key could be extracted from certificate: {0:?}")]
    CouldNotExtractEcKey(#[source] ErrorStack),
    /// `EcKey::check_key` rejected the public key.
    #[error("the given public key fails basic sanity checks: {0:?}")]
    KeyFailsCheck(#[source] ErrorStack),
    /// The key is not on `SIGNATURE_CURVE`.
    #[error("underlying elliptic curve is wrong")]
    WrongCurve,
    /// Subject and issuer differ.
    #[error("certificate is not self-signed")]
    NotSelfSigned,
    /// OpenSSL reported an error while verifying the signature.
    #[error("the signature could not be validated")]
    FailedToValidateSignature(#[source] ErrorStack),
    /// The signature does not verify against the certificate's own public key.
    #[error("the signature is invalid")]
    InvalidSignature,
    /// The SHA-512 digest of the certificate could not be computed.
    #[error("failed to read fingerprint")]
    InvalidFingerprint(#[source] ErrorStack),
    /// `BigNumContext::new` failed.
    #[error("could not create a big num context")]
    BigNumContextNotAvailable(#[source] ErrorStack),
    /// The public-key point could not be encoded in compressed form.
    #[error("could not encode public key as bytes")]
    PublicKeyEncodingFailed(#[source] ErrorStack),
}
pub(crate) fn validate_cert(cert: X509) -> Result<TlsCert, ValidationError> {
if cert.signature_algorithm().object().nid() != SIGNATURE_ALGORITHM {
return Err(ValidationError::WrongSignatureAlgorithm);
}
let subject =
name_to_string(cert.subject_name()).map_err(ValidationError::CorruptSubjectOrIssuer)?;
let issuer =
name_to_string(cert.issuer_name()).map_err(ValidationError::CorruptSubjectOrIssuer)?;
if subject != issuer {
return Err(ValidationError::NotSelfSigned);
}
if !num_eq(cert.serial_number(), 1).map_err(ValidationError::InvalidSerialNumber)? {
return Err(ValidationError::WrongSerialNumber);
}
let asn1_now = Asn1Time::from_unix(now()).map_err(ValidationError::TimeIssue)?;
if asn1_now
.compare(cert.not_before())
.map_err(ValidationError::TimeIssue)?
!= Ordering::Greater
{
return Err(ValidationError::NotYetValid);
}
if asn1_now
.compare(cert.not_after())
.map_err(ValidationError::TimeIssue)?
!= Ordering::Less
{
return Err(ValidationError::Expired);
}
let public_key = cert
.public_key()
.map_err(ValidationError::CannotReadPublicKey)?;
let ec_key = public_key
.ec_key()
.map_err(ValidationError::CouldNotExtractEcKey)?;
ec_key.check_key().map_err(ValidationError::KeyFailsCheck)?;
if ec_key.group().curve_name() != Some(SIGNATURE_CURVE) {
return Err(ValidationError::WrongCurve);
}
if !cert
.verify(&public_key)
.map_err(ValidationError::FailedToValidateSignature)?
{
return Err(ValidationError::InvalidSignature);
}
assert_eq!(Sha512::NID, SIGNATURE_DIGEST);
let digest = &cert
.digest(Sha512::create_message_digest())
.map_err(ValidationError::InvalidFingerprint)?;
let cert_fingerprint = CertFingerprint(Sha512::from_openssl_digest(digest));
let mut big_num_context =
BigNumContext::new().map_err(ValidationError::BigNumContextNotAvailable)?;
let buf = ec_key
.public_key()
.to_bytes(
ec::EcGroup::from_curve_name(SIGNATURE_CURVE)
.expect("broken constant SIGNATURE_CURVE")
.as_ref(),
ec::PointConversionForm::COMPRESSED,
&mut big_num_context,
)
.map_err(ValidationError::PublicKeyEncodingFailed)?;
let key_fingerprint = KeyFingerprint(Sha512::new(&buf));
Ok(TlsCert {
x509: cert,
cert_fingerprint,
key_fingerprint,
})
}
/// Reads a PEM-encoded X.509 certificate from `src`.
///
/// The certificate is only parsed here, not validated; use [`validate_cert`] for that.
///
/// # Errors
///
/// Fails if the file cannot be read or does not contain a PEM certificate.
pub(crate) fn load_cert<P: AsRef<Path>>(src: P) -> anyhow::Result<X509> {
    let pem = read_file(src.as_ref()).context("failed to load certificate")?;
    X509::from_pem(&pem).context("parsing certificate")
}
/// Reads a PEM-encoded private key from `src`.
///
/// # Errors
///
/// Fails if the file cannot be read or does not contain a PEM private key.
pub(crate) fn load_private_key<P: AsRef<Path>>(src: P) -> anyhow::Result<PKey<Private>> {
    let pem = read_file(src.as_ref()).context("failed to load private key")?;
    PKey::private_key_from_pem(&pem).context("parsing private key")
}
/// Writes `cert` to `dest` in PEM format.
///
/// # Errors
///
/// Fails if PEM encoding or writing the file fails.
pub fn save_cert<P: AsRef<Path>>(cert: &X509Ref, dest: P) -> anyhow::Result<()> {
    let encoded = cert
        .to_pem()
        .context("converting certificate to PEM")?;
    write_file(dest, encoded).context("failed to write certificate")?;
    Ok(())
}
/// Writes `key` to `dest` as a PKCS#8 PEM document.
///
/// # Errors
///
/// Fails if PEM encoding or writing the file fails.
pub fn save_private_key<P: AsRef<Path>>(key: &PKeyRef<Private>, dest: P) -> anyhow::Result<()> {
    let encoded = key
        .private_key_to_pem_pkcs8()
        .context("converting private key to PEM")?;
    write_file(dest, encoded).context("failed to write private key")?;
    Ok(())
}
/// Returns the current Unix time in whole seconds.
///
/// # Panics
///
/// Panics if the system clock reads earlier than the Unix epoch, or if the seconds count does
/// not fit in an `i64`.
fn now() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Great Scott! Your clock is horribly broken, Marty.");
    since_epoch
        .as_secs()
        .try_into()
        .expect("32-bit systems and far future are not supported")
}
/// Converts `n` into an OpenSSL ASN.1 integer (via a `BigNum`).
fn mknum(n: u32) -> Result<Asn1Integer, ErrorStack> {
    BigNum::from_u32(n).and_then(|big| big.to_asn1_integer())
}
/// Builds an X.509 name from country `c`, organization `o` and common name `cn`.
///
/// Empty `c` or `o` values are omitted; the common name is always added. Entries are appended
/// in the order C, O, CN.
fn mkname(c: &str, o: &str, cn: &str) -> Result<X509Name, ErrorStack> {
    let mut builder = X509NameBuilder::new()?;
    for (field, value) in [("C", c), ("O", o)].iter() {
        if !value.is_empty() {
            builder.append_entry_by_text(field, value)?;
        }
    }
    builder.append_entry_by_text("CN", cn)?;
    Ok(builder.build())
}
/// Renders an X.509 name as `longName=value ` pairs in entry order.
///
/// Each pair is followed by a single space, including the last one, e.g.
/// `"countryName=sc organizationName=some_org commonName=some_cn "`.
///
/// # Errors
///
/// Fails if an entry's NID has no long name or its value is not valid UTF-8.
fn name_to_string(name: &X509NameRef) -> SslResult<String> {
    let mut output = String::new();
    for entry in name.entries() {
        output.push_str(entry.object().nid().long_name()?);
        output.push('=');
        output.push_str(entry.data().as_utf8()?.as_ref());
        output.push(' ');
    }
    Ok(output)
}
/// Returns whether the ASN.1 integer `num` equals `other`.
///
/// # Errors
///
/// Fails if either value cannot be converted to a `BigNum`.
fn num_eq(num: &Asn1IntegerRef, other: u32) -> SslResult<bool> {
    let l = num.to_bn()?;
    let r = BigNum::from_u32(other)?;
    // `r` comes from a `u32` and is never negative, so `l` must be non-negative with the same
    // magnitude.
    Ok(!l.is_negative() && l.ucmp(&r) == Ordering::Equal)
}
/// Generates a new EC private key on `SIGNATURE_CURVE`.
fn generate_private_key() -> SslResult<PKey<Private>> {
    let group = ec::EcGroup::from_curve_name(SIGNATURE_CURVE)?;
    ec::EcKey::generate(&group).and_then(PKey::from_ec_key)
}
/// Creates a self-signed X.509 v3 certificate for `private_key` with common name `cn`.
///
/// The certificate has serial number 1 and subject = issuer = `C=US, O=Casper Blockchain,
/// CN=<cn>`. It is valid from one minute before now until ten 365-day years from now, and is
/// signed with SHA-512.
///
/// # Panics
///
/// Panics if the freshly built certificate does not pass [`validate_cert`], or if the digest
/// constants disagree.
fn generate_cert(private_key: &PKey<Private>, cn: &str) -> SslResult<X509> {
    // Seconds subtracted from "now" for `notBefore`.
    const NOT_BEFORE_OFFSET: i64 = 60;
    // Seconds added to "now" for `notAfter` (ten 365-day years).
    const LIFETIME: i64 = 10 * 365 * 24 * 60 * 60;

    let mut builder = X509Builder::new()?;
    // Version field `2` encodes X.509 v3 (the field is zero-based).
    builder.set_version(2)?;
    let serial = mknum(1)?;
    builder.set_serial_number(&serial)?;

    // Self-signed: the same name is used as issuer and subject.
    let name = mkname("US", "Casper Blockchain", cn)?;
    builder.set_issuer_name(&name)?;
    builder.set_subject_name(&name)?;

    let issued_at = now();
    let not_before = Asn1Time::from_unix(issued_at - NOT_BEFORE_OFFSET)?;
    builder.set_not_before(&not_before)?;
    let not_after = Asn1Time::from_unix(issued_at + LIFETIME)?;
    builder.set_not_after(&not_after)?;

    builder.set_pubkey(private_key)?;
    assert_eq!(Sha512::NID, SIGNATURE_DIGEST);
    builder.sign(private_key, Sha512::create_message_digest())?;

    let cert = builder.build();
    // Everything we issue must pass our own acceptance policy.
    let self_check = validate_cert(cert.clone());
    assert!(
        self_check.is_ok(),
        "newly generated cert does not pass our own validity check"
    );
    Ok(cert)
}
mod x509_serde {
use std::str;
use openssl::x509::X509;
use serde::{Deserialize, Deserializer, Serializer};
use super::validate_cert;
pub(super) fn serialize<S>(value: &X509, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let encoded = value.to_pem().map_err(serde::ser::Error::custom)?;
serializer.serialize_str(str::from_utf8(&encoded).map_err(serde::ser::Error::custom)?)
}
pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<X509, D::Error>
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
let x509 = X509::from_pem(s.as_bytes()).map_err(serde::de::Error::custom)?;
validate_cert(x509)
.map_err(serde::de::Error::custom)
.map(|tc| tc.x509)
}
}
// Equality and ordering compare the digest bytes lexicographically. NOTE(review): these are
// presumably hand-written because `derive` did not support 64-element arrays on older toolchains.
impl PartialEq for Sha512 {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.bytes().eq(other.bytes())
    }
}

impl Eq for Sha512 {}

impl Ord for Sha512 {
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        self.bytes().cmp(other.bytes())
    }
}

impl PartialOrd for Sha512 {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Debug output is the full, untruncated hex digest.
impl Debug for Sha512 {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let hex = HexFmt(self.bytes());
        write!(f, "{}", hex)
    }
}
/// Display output is the hex digest formatted at width 10.
impl Display for Sha512 {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let hex = HexFmt(self.bytes());
        write!(f, "{:10}", hex)
    }
}
impl Display for CertFingerprint {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&self.0, f)
}
}
impl Display for KeyFingerprint {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{:10}", HexFmt(self.0.bytes()))
}
}
/// Displays the signature bytes as hex formatted at width 10.
impl Display for Signature {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let hex = HexFmt(self.0.as_slice());
        write!(f, "{:10}", hex)
    }
}
/// Formats as `signed[<signature>]<<item>>` by decoding `data` with `bincode`. If decoding
/// fails, it formats as `signed[<signature>]<CORRUPT>`.
impl<T> Display for Signed<T>
where
    T: Display + for<'de> Deserialize<'de>,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match bincode::deserialize::<T>(self.data.as_slice()) {
            // Show the decoded item itself; the payload is a value, not a byte count.
            Ok(item) => write!(f, "signed[{}]<{}>", self.signature, item),
            Err(_err) => write!(f, "signed[{}]<CORRUPT>", self.signature),
        }
    }
}
/// Feeds only the first eight digest bytes (read as a little-endian `u64`) into the hasher.
///
/// This is consistent with `Eq`, since equal digests share their first eight bytes.
/// NOTE(review): presumably a deliberate shortcut because digest bytes are already well mixed.
impl Hash for Sha512 {
    #[inline]
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let (head, _rest) = self.bytes().split_at(8);
        let mut word = [0u8; 8];
        word.copy_from_slice(head);
        state.write_u64(u64::from_le_bytes(word));
    }
}
#[cfg(test)]
mod test {
    use super::{generate_node_cert, mkname, name_to_string, validate_cert, TlsCert};

    /// Names render as `longName=value ` pairs in insertion order, each followed by a space.
    #[test]
    fn simple_name_to_string() {
        let name = mkname("sc", "some_org", "some_cn").expect("could not create name");
        let rendered = name_to_string(&name).expect("name to string failed");
        assert_eq!(
            rendered,
            "countryName=sc organizationName=some_org commonName=some_cn "
        );
    }

    /// A freshly generated certificate survives a bincode round trip unchanged.
    #[test]
    fn test_tls_cert_serde_roundtrip() {
        let (cert, _private_key) = generate_node_cert().expect("failed to generate key, cert pair");
        let original = validate_cert(cert).expect("generated cert is not valid");

        let first_pass = bincode::serialize(&original).expect("could not serialize");
        let restored: TlsCert =
            bincode::deserialize(first_pass.as_slice()).expect("could not deserialize");
        let second_pass = bincode::serialize(&restored).expect("could not serialize");

        assert_eq!(first_pass, second_pass);
    }
}