use alloc::string::{String, ToString};
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use hex::FromHex;
use keetanetwork_account::{AccountPublicKey, GenericAccount, KeyPairType};
use keetanetwork_asn1::vote as transport;
use keetanetwork_block::{AccountRef, BlockHash};
use num_bigint::BigInt;
use num_traits::Num;
use crate::error::VoteError;
use crate::fee::Fees;
use crate::validity::Validity;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SignatureAlgo {
Ed25519,
EcdsaWithSha3_256,
}
impl SignatureAlgo {
pub(crate) fn from_issuer(account: &GenericAccount) -> Result<Self, VoteError> {
match account.to_keypair_type() {
KeyPairType::ED25519 => Ok(Self::Ed25519),
KeyPairType::ECDSASECP256K1 | KeyPairType::ECDSASECP256R1 => Ok(Self::EcdsaWithSha3_256),
_ => Err(VoteError::MalformedVoteSignatureSchemeDoesNotMatchIssuer),
}
}
fn matches_issuer(self, account: &GenericAccount) -> bool {
matches!(
(self, account.to_keypair_type()),
(Self::Ed25519, KeyPairType::ED25519)
| (Self::EcdsaWithSha3_256, KeyPairType::ECDSASECP256K1 | KeyPairType::ECDSASECP256R1)
)
}
fn to_transport(self) -> transport::VoteSignatureAlgo {
match self {
Self::Ed25519 => transport::VoteSignatureAlgo::Ed25519,
Self::EcdsaWithSha3_256 => transport::VoteSignatureAlgo::EcdsaWithSha3_256,
}
}
fn from_transport(value: transport::VoteSignatureAlgo) -> Self {
match value {
transport::VoteSignatureAlgo::Ed25519 => Self::Ed25519,
transport::VoteSignatureAlgo::EcdsaWithSha3_256 => Self::EcdsaWithSha3_256,
}
}
}
/// A vote certificate after `decode_wrapper` has parsed it and checked its internal
/// consistency.
///
/// The signature is carried but is not checked by `decode_wrapper`.
/// NOTE(review): presumably the caller verifies `signature` over `tbs_bytes`; confirm.
#[derive(Debug, Clone)]
pub(crate) struct DecodedVote {
    /// Certificate serial number taken from the to-be-signed body.
    pub(crate) serial: BigInt,
    /// Signature algorithm (the body and the wrapper were required to agree).
    pub(crate) signature_algo: SignatureAlgo,
    /// Account parsed from the issuer common name.
    pub(crate) issuer: AccountRef,
    /// Validity window of the vote.
    pub(crate) validity: Validity,
    /// Block hashes carried in the `HASH_DATA` extension.
    pub(crate) blocks: Vec<BlockHash>,
    /// Fees from the optional `FEES` extension.
    pub(crate) fees: Option<Fees>,
    /// Raw signature bytes from the wrapper.
    pub(crate) signature: Vec<u8>,
    /// Encoded to-be-signed bytes as returned by the transport decoder.
    pub(crate) tbs_bytes: Vec<u8>,
}
/// Builds the to-be-signed certificate body for a vote.
///
/// # Errors
///
/// - `VoteError::MalformedVoteSignatureSchemeDoesNotMatchIssuer` when `signature_algo`
///   is not the algorithm implied by `issuer`'s key type, or the key type is unsupported.
/// - Any error raised while building the subject public key or the extension list.
pub(crate) fn build_tbs(
    serial: &BigInt,
    signature_algo: SignatureAlgo,
    issuer: &AccountRef,
    validity: Validity,
    blocks: &[BlockHash],
    fees: Option<&Fees>,
) -> Result<transport::TbsCertificate, VoteError> {
    let implied_algo = SignatureAlgo::from_issuer(issuer)?;
    if implied_algo == signature_algo {
        build_tbs_inner(serial, signature_algo, issuer, validity, blocks, fees)
    } else {
        Err(VoteError::MalformedVoteSignatureSchemeDoesNotMatchIssuer)
    }
}
pub(crate) fn encode_tbs(tbs: &transport::TbsCertificate) -> Result<Vec<u8>, VoteError> {
Ok(transport::encode_tbs(tbs)?)
}
pub(crate) fn encode_vote(
tbs: transport::TbsCertificate,
signature_algo: SignatureAlgo,
signature: Vec<u8>,
) -> Result<Vec<u8>, VoteError> {
let value = transport::VoteCertificate { tbs, signature_algo: signature_algo.to_transport(), signature };
Ok(transport::encode_vote(&value)?)
}
/// Decodes an encoded vote certificate and checks its internal consistency.
///
/// Checks run in this order:
/// 1. the body's signature algorithm equals the wrapper's;
/// 2. the issuer common name parses as an account whose key type fits that algorithm;
/// 3. the subject `SERIAL_NUMBER` (hex) equals the body serial;
/// 4. the subject public key is the issuer's key;
/// 5. the validity window is accepted by `Validity::try_new`;
/// 6. the extensions yield block hashes and optional fees.
///
/// The signature is not verified here. It is returned together with the raw
/// to-be-signed bytes.
///
/// # Errors
///
/// Returns the `VoteError` for the first failed step. Transport decode failures are
/// mapped through `decode_error_to_vote`.
pub(crate) fn decode_wrapper(bytes: &[u8]) -> Result<DecodedVote, VoteError> {
    let decoded = transport::decode_vote(bytes).map_err(decode_error_to_vote)?;
    if decoded.tbs.signature_algo != decoded.signature_algo {
        return Err(VoteError::MalformedVoteSignatureSchemeDoesNotMatchWrapper);
    }
    let signature_algo = SignatureAlgo::from_transport(decoded.signature_algo);
    // `decoded` is owned and only partially borrowed below, so move the serial out
    // instead of cloning the BigInt.
    let serial = decoded.tbs.serial_number;
    let issuer_string = take_dn_value(&decoded.tbs.issuer, &transport::oids::COMMON_NAME)
        .ok_or(VoteError::MalformedVoteIssuerInformation)?;
    let issuer: AccountRef = Arc::new(
        issuer_string
            .parse::<GenericAccount>()
            .map_err(|_| VoteError::MalformedVoteIssuerInformation)?,
    );
    if !signature_algo.matches_issuer(&issuer) {
        return Err(VoteError::MalformedVoteSignatureSchemeDoesNotMatchIssuer);
    }
    // The subject carries the serial again as hex; both copies must agree numerically.
    let subject_serial_hex =
        take_dn_value(&decoded.tbs.subject, &transport::oids::SERIAL_NUMBER).ok_or(VoteError::MalformedVoteSerial)?;
    let subject_serial = parse_lower_hex_bigint(&subject_serial_hex).map_err(|_| VoteError::MalformedVoteSerial)?;
    if subject_serial != serial {
        return Err(VoteError::SerialMismatch);
    }
    // The embedded public key must be the issuer's own key.
    let subject_public_key = decode_subject_public_key(&decoded.tbs.subject_public_key)?;
    if !accounts_match(&subject_public_key, &issuer) {
        return Err(VoteError::MalformedVoteSignatureSchemeDoesNotMatchIssuer);
    }
    let validity = Validity::try_new(decoded.tbs.validity.not_before.into(), decoded.tbs.validity.not_after.into())?;
    let (blocks, fees) = collect_extensions(&decoded.tbs.extensions)?;
    Ok(DecodedVote {
        serial,
        signature_algo,
        issuer,
        validity,
        blocks,
        fees,
        signature: decoded.signature,
        tbs_bytes: decoded.tbs_bytes,
    })
}
/// Assembles the to-be-signed body without re-checking that the algorithm fits the issuer.
///
/// The issuer DN holds the account string as a common name. The subject DN holds the
/// serial as lower-case hex.
fn build_tbs_inner(
    serial: &BigInt,
    signature_algo: SignatureAlgo,
    issuer: &AccountRef,
    validity: Validity,
    blocks: &[BlockHash],
    fees: Option<&Fees>,
) -> Result<transport::TbsCertificate, VoteError> {
    let serial_number = serial.clone();
    let wire_algo = signature_algo.to_transport();
    let issuer_dn = dn_with_attribute(transport::oids::COMMON_NAME, issuer.to_string());
    let window = transport::Validity { not_before: validity.from.into(), not_after: validity.to.into() };
    let subject_dn = dn_with_attribute(transport::oids::SERIAL_NUMBER, bigint_to_lower_hex(serial));
    // Keep the public key before the extensions so error precedence matches field order.
    let subject_public_key = subject_public_key_for_issuer(issuer)?;
    let extensions = build_extension_list(blocks, fees)?;
    Ok(transport::TbsCertificate {
        serial_number,
        signature_algo: wire_algo,
        issuer: issuer_dn,
        validity: window,
        subject: subject_dn,
        subject_public_key,
        extensions,
    })
}
/// Builds a distinguished name with one RDN holding a single `oid = value` attribute.
fn dn_with_attribute(oid: transport::VoteOid, value: String) -> transport::DistinguishedName {
    let attribute = transport::AttributeTypeAndValue { oid, value };
    let only_rdn = vec![attribute];
    transport::DistinguishedName { rdns: vec![only_rdn] }
}
/// Returns a copy of the value of the last attribute in `dn` whose OID equals `oid`.
///
/// If the OID occurs more than once, the last occurrence (in RDN order) wins.
/// Returns `None` when the OID is absent.
fn take_dn_value(dn: &transport::DistinguishedName, oid: &transport::VoteOid) -> Option<String> {
    // Walk backwards so the first hit is the last occurrence in forward order.
    dn.rdns
        .iter()
        .rev()
        .flat_map(|rdn| rdn.iter().rev())
        .find(|attribute| &attribute.oid == oid)
        .map(|attribute| attribute.value.clone())
}
/// Builds the subject public key field from the issuer's key.
///
/// The first byte of `to_public_key_with_type()` is dropped. That is the type tag that
/// `decode_subject_public_key` re-adds from the variant/curve. The remaining bytes are
/// stored as-is.
///
/// # Errors
///
/// - `MalformedVoteSubjectPublicKeyInformation` if the typed key is empty (checked first).
/// - `MalformedVoteSignatureSchemeDoesNotMatchIssuer` for unsupported key types.
fn subject_public_key_for_issuer(issuer: &AccountRef) -> Result<transport::VoteSubjectPublicKey, VoteError> {
    let typed_key = issuer.to_public_key_with_type();
    let key = match typed_key.get(1..) {
        Some(untyped) => untyped.to_vec(),
        None => return Err(VoteError::MalformedVoteSubjectPublicKeyInformation),
    };
    let curve = match issuer.to_keypair_type() {
        KeyPairType::ED25519 => return Ok(transport::VoteSubjectPublicKey::Ed25519 { key }),
        KeyPairType::ECDSASECP256K1 => transport::EcdsaCurve::Secp256k1,
        KeyPairType::ECDSASECP256R1 => transport::EcdsaCurve::Secp256r1,
        _ => return Err(VoteError::MalformedVoteSignatureSchemeDoesNotMatchIssuer),
    };
    Ok(transport::VoteSubjectPublicKey::Ecdsa { curve, key })
}
/// Rebuilds an account from a subject public key field.
///
/// This is the inverse of `subject_public_key_for_issuer`: it prepends the key-type tag
/// implied by the variant/curve and parses the result as a hex-encoded account.
///
/// # Errors
///
/// Returns `MalformedVoteSubjectPublicKeyInformation` if the tagged bytes do not parse
/// as an account.
fn decode_subject_public_key(value: &transport::VoteSubjectPublicKey) -> Result<AccountRef, VoteError> {
    let (key_type, raw) = match value {
        transport::VoteSubjectPublicKey::Ed25519 { key } => (KeyPairType::ED25519, key),
        transport::VoteSubjectPublicKey::Ecdsa { curve: transport::EcdsaCurve::Secp256k1, key } => {
            (KeyPairType::ECDSASECP256K1, key)
        }
        transport::VoteSubjectPublicKey::Ecdsa { curve: transport::EcdsaCurve::Secp256r1, key } => {
            (KeyPairType::ECDSASECP256R1, key)
        }
    };
    let typed_key: Vec<u8> = core::iter::once(key_type as u8).chain(raw.iter().copied()).collect();
    match GenericAccount::from_hex(hex::encode(typed_key)) {
        Ok(account) => Ok(Arc::new(account)),
        Err(_) => Err(VoteError::MalformedVoteSubjectPublicKeyInformation),
    }
}
/// Builds the vote's extension list: a critical `HASH_DATA` extension (SHA3-256 block
/// hashes), followed by a critical `FEES` extension when `fees` is provided.
fn build_extension_list(blocks: &[BlockHash], fees: Option<&Fees>) -> Result<Vec<transport::Extension>, VoteError> {
    let hash_extension = transport::Extension {
        oid: transport::oids::HASH_DATA,
        critical: true,
        value: transport::encode_hash_data(&transport::HashData {
            algorithm: transport::oids::SHA3_256,
            hashes: blocks.iter().map(|block| block.as_bytes().to_vec()).collect(),
        })?,
    };
    let fee_extension = match fees {
        Some(schedule) => {
            let wire_fees = schedule.to_transport()?;
            Some(transport::Extension {
                oid: transport::oids::FEES,
                critical: true,
                value: transport::encode_fees(&wire_fees)?,
            })
        }
        None => None,
    };
    let mut list = vec![hash_extension];
    list.extend(fee_extension);
    Ok(list)
}
fn collect_extensions(extensions: &[transport::Extension]) -> Result<(Vec<BlockHash>, Option<Fees>), VoteError> {
let mut blocks: Option<Vec<BlockHash>> = None;
let mut fees: Option<Fees> = None;
for extension in extensions {
if extension.oid == transport::oids::HASH_DATA {
let hash_data = transport::decode_hash_data(&extension.value).map_err(hash_data_error)?;
if hash_data.algorithm != transport::oids::SHA3_256 {
return Err(VoteError::MalformedHashesFromVoteDataUnsupportedHashFunc);
}
let mut converted = Vec::with_capacity(hash_data.hashes.len());
for raw in hash_data.hashes {
let hash = BlockHash::try_from(raw.as_slice())
.map_err(|_| VoteError::MalformedHashesFromVoteDataUnsupportedHashType)?;
converted.push(hash);
}
blocks = Some(converted);
} else if extension.oid == transport::oids::FEES {
fees = Some(Fees::from_transport(transport::decode_fees(&extension.value).map_err(fees_error)?)?);
} else if extension.critical {
return Err(VoteError::MalformedVoteExtensionsValueCriticalType);
}
}
let blocks = blocks.ok_or(VoteError::MalformedVoteNoBlocksFound)?;
Ok((blocks, fees))
}
/// Maps a transport decode error to the matching `VoteError`.
///
/// An invalid version and per-slot decode failures get dedicated variants. Every other
/// `Asn1Error` goes through the generic `Into<VoteError>` conversion.
fn decode_error_to_vote(error: keetanetwork_asn1::Asn1Error) -> VoteError {
    use keetanetwork_asn1::vote::VoteDecodeSlot as Slot;
    use keetanetwork_asn1::Asn1Error;
    let slot = match error {
        Asn1Error::InvalidVoteVersion => return VoteError::InvalidVersion,
        Asn1Error::VoteDecode { slot } => slot,
        other => return other.into(),
    };
    // NOTE(review): `TbsContent -> MalformedVoteWrapper` and `Version -> MalformedVoteContent`
    // look shifted by one slot. They may deliberately mirror reference error codes; confirm.
    match slot {
        Slot::Wrapper | Slot::WrapperExtraData => VoteError::MalformedWrapper,
        Slot::TbsContent => VoteError::MalformedVoteWrapper,
        Slot::Version => VoteError::MalformedVoteContent,
        Slot::VersionValue => VoteError::MalformedVoteVersion,
        Slot::Serial => VoteError::MalformedVoteSerial,
        Slot::SignatureAlgorithm | Slot::WrapperSignatureAlgorithm => VoteError::MalformedVoteSignatureInformation,
        Slot::Issuer => VoteError::MalformedVoteIssuerInformation,
        Slot::Validity => VoteError::MalformedVoteValidityInformation,
        Slot::Subject => VoteError::MalformedVoteSubjectInformation,
        Slot::SubjectPublicKey => VoteError::MalformedVoteSubjectPublicKeyInformation,
        Slot::Extensions => VoteError::MalformedVoteExtensions,
        Slot::TbsExtraData => VoteError::MalformedVoteContentExtraData,
        Slot::SignatureValue => VoteError::MalformedVoteSignatureValue,
    }
}
/// Collapses any `HASH_DATA` decode failure into a single vote error; the detail is discarded.
fn hash_data_error(_: keetanetwork_asn1::Asn1Error) -> VoteError {
    VoteError::MalformedHashesFromVoteInvalidInput
}
/// Collapses any `FEES` decode failure into a single vote error; the detail is discarded.
fn fees_error(_: keetanetwork_asn1::Asn1Error) -> VoteError {
    VoteError::MalformedFeesFromVoteInvalidInput
}
/// Returns `true` when both accounts have byte-identical type-tagged public keys.
fn accounts_match(left: &AccountRef, right: &AccountRef) -> bool {
    let left_key = left.to_public_key_with_type();
    let right_key = right.to_public_key_with_type();
    left_key == right_key
}
fn bigint_to_lower_hex(value: &BigInt) -> String {
value.to_str_radix(16)
}
fn parse_lower_hex_bigint(value: &str) -> Result<BigInt, num_bigint::ParseBigIntError> {
BigInt::from_str_radix(value, 16)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::{
        ed25519_issuer, secp256k1_issuer, secp256r1_issuer, sign_simple_vote, single_fees, token_account,
        validity_millis,
    };

    /// Each supported key type maps to its algorithm; a token account is not a valid issuer.
    #[test]
    fn test_signature_algo_from_issuer() {
        let ed25519 = ed25519_issuer(b"a");
        assert!(matches!(SignatureAlgo::from_issuer(ed25519.as_ref()), Ok(SignatureAlgo::Ed25519)));

        for ecdsa in [secp256k1_issuer(b"a"), secp256r1_issuer(b"a")] {
            let algo = SignatureAlgo::from_issuer(ecdsa.as_ref());
            assert!(matches!(algo, Ok(SignatureAlgo::EcdsaWithSha3_256)));
        }

        let token = token_account(b"a");
        assert!(matches!(
            SignatureAlgo::from_issuer(token.as_ref()),
            Err(VoteError::MalformedVoteSignatureSchemeDoesNotMatchIssuer)
        ));
    }

    /// `build_tbs` refuses an algorithm that does not fit the issuer's key type.
    #[test]
    fn test_build_tbs_rejects_algo_issuer_mismatch() {
        let ecdsa_issuer = secp256k1_issuer(b"a");
        let serial = BigInt::from(1u8);
        let blocks = [BlockHash::from([1u8; 32])];
        let outcome = build_tbs(
            &serial,
            SignatureAlgo::Ed25519,
            &ecdsa_issuer,
            validity_millis(0, 60_000),
            &blocks,
            None,
        );
        assert!(matches!(outcome, Err(VoteError::MalformedVoteSignatureSchemeDoesNotMatchIssuer)));
    }

    /// Hex formatting and parsing agree, and non-hex input is rejected.
    #[test]
    fn test_bigint_hex_round_trip() {
        let value = BigInt::from(255u16);
        assert_eq!(bigint_to_lower_hex(&value), "ff");
        assert_eq!(parse_lower_hex_bigint("ff").ok(), Some(value));
        assert!(parse_lower_hex_bigint("zz").is_err());
    }

    /// A single-attribute DN yields its value for the matching OID only.
    #[test]
    fn test_take_dn_value() {
        let dn = dn_with_attribute(transport::oids::COMMON_NAME, String::from("hello"));
        assert_eq!(take_dn_value(&dn, &transport::oids::COMMON_NAME).as_deref(), Some("hello"));
        assert!(take_dn_value(&dn, &transport::oids::SERIAL_NUMBER).is_none());
    }

    /// Encoding then decoding the subject public key reproduces every supported issuer.
    #[test]
    fn test_subject_public_key_round_trip() -> Result<(), VoteError> {
        let issuers = [ed25519_issuer(b"k"), secp256k1_issuer(b"k"), secp256r1_issuer(b"k")];
        for issuer in &issuers {
            let wire = subject_public_key_for_issuer(issuer)?;
            let restored = decode_subject_public_key(&wire)?;
            assert!(accounts_match(&restored, issuer), "decoded subject public key must match issuer");
        }
        Ok(())
    }

    /// Hashes and fees survive a build/collect round trip of the extension list.
    #[test]
    fn test_extension_list_round_trip() -> Result<(), VoteError> {
        let blocks = vec![BlockHash::from([1u8; 32]), BlockHash::from([2u8; 32])];
        let fee_schedule = single_fees(7);
        let encoded = build_extension_list(&blocks, Some(&fee_schedule))?;
        let (round_tripped_blocks, round_tripped_fees) = collect_extensions(&encoded)?;
        assert_eq!(round_tripped_blocks, blocks, "round-tripped block hashes must match");
        assert!(round_tripped_fees.is_some(), "round-tripped fees must be present");
        Ok(())
    }

    /// An extension list without `HASH_DATA` is rejected.
    #[test]
    fn test_collect_extensions_requires_blocks() {
        let no_extensions: [transport::Extension; 0] = [];
        assert!(matches!(collect_extensions(&no_extensions), Err(VoteError::MalformedVoteNoBlocksFound)));
    }

    /// A freshly signed vote decodes back to the inputs it was built from.
    #[test]
    fn test_decode_wrapper_round_trip() -> Result<(), VoteError> {
        let alice = ed25519_issuer(b"alice");
        let blocks = vec![BlockHash::from([3u8; 32])];
        let signed = sign_simple_vote(&alice, 9, validity_millis(0, 60_000), blocks.clone(), None);
        let DecodedVote { serial, issuer, blocks: decoded_blocks, fees, .. } = decode_wrapper(signed.as_bytes())?;
        assert_eq!(serial, BigInt::from(9u8), "decoded serial must match");
        assert_eq!(decoded_blocks, blocks, "decoded blocks must match");
        assert!(accounts_match(&issuer, &alice), "decoded issuer must match");
        assert!(fees.is_none(), "fee-free vote must decode without fees");
        Ok(())
    }

    /// The version error gets its dedicated variant.
    #[test]
    fn test_decode_error_to_vote_maps_version() {
        let mapped = decode_error_to_vote(keetanetwork_asn1::Asn1Error::InvalidVoteVersion);
        assert!(matches!(mapped, VoteError::InvalidVersion));
    }
}