use std::{
cmp::Ordering,
convert::{TryFrom, TryInto},
fmt,
hash::{Hash, Hasher},
marker::PhantomData,
ops::BitXor,
};
use blake2::{
Blake2bVar,
digest::{Update, VariableOutput},
};
use serde::{Deserialize, Deserializer, Serialize, de};
use tari_utilities::{
ByteArray,
ByteArrayError,
hex::{Hex, to_hex},
};
use thiserror::Error;
use crate::types::CommsPublicKey;
/// Raw fixed-size byte storage backing a [`NodeId`]; its length is `NodeId::byte_size()`.
pub(super) type NodeIdArray = [u8; NodeId::byte_size()];
/// Errors that can arise when working with [`NodeId`] values.
#[derive(Debug, Error, Clone)]
pub enum NodeIdError {
    /// The input did not contain exactly `NodeId::byte_size()` bytes
    /// (returned by `TryFrom<&[u8]> for NodeId`).
    #[error("Incorrect byte count (expected {} bytes)", NodeId::byte_size())]
    IncorrectByteCount,
    /// A digest output size was rejected as invalid.
    #[error("Invalid digest output size")]
    InvalidDigestOutputSize,
}
/// Fixed-size node identifier made of `NodeId::byte_size()` raw bytes.
///
/// `Default` produces an all-zero id. Equality, ordering and hashing are implemented
/// manually below, all in terms of the underlying byte array.
#[derive(Clone, Eq, Deserialize, Serialize, Default)]
pub struct NodeId(NodeIdArray);
impl NodeId {
    /// Creates a `NodeId` whose bytes are all zero (same as `NodeId::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Number of bytes in a `NodeId` (13).
    pub const fn byte_size() -> usize {
        13
    }

    /// Derives a `NodeId` by hashing the key's byte representation with Blake2b,
    /// configured for a variable output of exactly [`NodeId::byte_size()`] bytes.
    ///
    /// # Panics
    /// Only if Blake2b rejects the configured output size, or if the output buffer does not
    /// match it. Both would be programming errors in `byte_size()`, not runtime conditions.
    pub fn from_key<K: ByteArray>(key: &K) -> Self {
        let bytes = key.as_bytes();
        let mut buf = [0u8; NodeId::byte_size()];
        Blake2bVar::new(NodeId::byte_size())
            .expect("NodeId::byte_size() is invalid")
            .chain(bytes)
            // `buf` is sized from the same constant passed to `Blake2bVar::new`, so the
            // output-length check performed by `finalize_variable` cannot fail.
            .finalize_variable(&mut buf)
            .expect("NodeId buffer length must match the Blake2b output size");
        NodeId(buf)
    }

    /// Derives a `NodeId` from a comms public key; identical to [`NodeId::from_key`].
    pub fn from_public_key(key: &CommsPublicKey) -> Self {
        Self::from_key(key)
    }

    /// Consumes the id and returns its raw byte array.
    pub fn into_inner(self) -> NodeIdArray {
        self.0
    }

    /// Returns the hex encoding of the first 8 bytes of the id (16 hex characters),
    /// as a compact human-readable prefix.
    pub fn short_str(&self) -> String {
        // `byte_size()` is 13, so an 8-byte prefix always exists.
        to_hex(self.0.get(..8).expect("Index should exist"))
    }
}
impl ByteArray for NodeId {
fn from_canonical_bytes(bytes: &[u8]) -> Result<Self, ByteArrayError> {
bytes.try_into().map_err(|err| ByteArrayError::ConversionError {
reason: format!("{err:?}"),
})
}
fn as_bytes(&self) -> &[u8] {
self.0.as_ref()
}
}
impl ByteArray for Box<NodeId> {
    /// Builds a boxed `NodeId` from exactly `NodeId::byte_size()` bytes.
    ///
    /// Length errors are reported the same way as for the unboxed impl.
    fn from_canonical_bytes(bytes: &[u8]) -> Result<Self, ByteArrayError> {
        NodeId::try_from(bytes)
            .map(Box::new)
            .map_err(|err| ByteArrayError::ConversionError {
                reason: format!("{err:?}"),
            })
    }

    /// Borrows the raw bytes of the boxed id.
    fn as_bytes(&self) -> &[u8] {
        let node_id: &NodeId = self;
        &node_id.0
    }
}
impl PartialEq for NodeId {
fn eq(&self, nid: &NodeId) -> bool {
self.0 == nid.0
}
}
impl PartialOrd<NodeId> for NodeId {
    /// Delegates to the total order from [`Ord`], so this never returns `None`.
    fn partial_cmp(&self, other: &NodeId) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for NodeId {
fn cmp(&self, other: &Self) -> Ordering {
self.0.cmp(&other.0)
}
}
impl BitXor for &NodeId {
    type Output = NodeIdArray;

    /// Returns the bytewise XOR of two node ids as a raw byte array.
    ///
    /// Both operands are always `NodeId::byte_size()` bytes long. Zipping the three arrays
    /// therefore visits every byte, with no index arithmetic or panicking lookups.
    fn bitxor(self, rhs: Self) -> Self::Output {
        let mut xor = [0u8; NodeId::byte_size()];
        for (out, (a, b)) in xor.iter_mut().zip(self.0.iter().zip(rhs.0.iter())) {
            *out = a ^ b;
        }
        xor
    }
}
impl TryFrom<&[u8]> for NodeId {
type Error = NodeIdError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() != NodeId::byte_size() {
return Err(NodeIdError::IncorrectByteCount);
}
let mut buf = [0; NodeId::byte_size()];
buf.copy_from_slice(bytes);
Ok(NodeId(buf))
}
}
impl From<CommsPublicKey> for NodeId {
fn from(pk: CommsPublicKey) -> Self {
NodeId::from_public_key(&pk)
}
}
impl Hash for NodeId {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}
impl AsRef<[u8]> for NodeId {
    /// Views the id as a byte slice.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
impl fmt::Display for NodeId {
    /// Writes the full id as a hex string (e.g. `901c6a70dcc5d87709d92a4d9f`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let hex = to_hex(&self.0);
        f.write_str(&hex)
    }
}
impl fmt::Debug for NodeId {
    /// Writes the id as `NodeId(<hex>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let hex = to_hex(&self.0);
        write!(f, "NodeId({})", hex)
    }
}
pub fn deserialize_node_id_from_hex<'de, D>(des: D) -> Result<NodeId, D::Error>
where D: Deserializer<'de> {
struct KeyStringVisitor<K> {
marker: PhantomData<K>,
}
impl de::Visitor<'_> for KeyStringVisitor<NodeId> {
type Value = NodeId;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a node id in hex format")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where E: de::Error {
NodeId::from_hex(v).map_err(E::custom)
}
}
des.deserialize_str(KeyStringVisitor { marker: PhantomData })
}
#[cfg(test)]
mod test {
    #![allow(clippy::indexing_slicing)]
    use tari_crypto::keys::SecretKey;

    use super::*;
    use crate::types::CommsSecretKey;

    /// `Display` renders every byte as two hex digits.
    #[test]
    fn display() {
        let raw = [144u8, 28, 106, 112, 220, 197, 216, 119, 9, 217, 42, 77, 159];
        let node_id = NodeId::try_from(&raw[..]).unwrap();
        let rendered = node_id.to_string();
        assert_eq!(rendered, "901c6a70dcc5d87709d92a4d9f");
    }

    /// An id derived from a random public key is neither all-zero nor the raw key.
    #[test]
    fn test_from_public_key() {
        let mut rng = rand::rngs::OsRng;
        let secret_key = CommsSecretKey::random(&mut rng);
        let public_key = CommsPublicKey::from_secret_key(&secret_key);
        let node_id = NodeId::from_key(&public_key);

        let zero_id = NodeId::new();
        assert_ne!(node_id.0.to_vec(), zero_id.0.to_vec());

        let mut key_bytes = [0u8; 32];
        key_bytes.copy_from_slice(public_key.as_bytes());
        // NOTE(review): this compares 13 bytes against 32, so it is trivially unequal.
        assert_ne!(node_id.0.to_vec(), key_bytes.to_vec());
    }

    /// Ids built from identical bytes compare equal.
    #[test]
    fn partial_eq() {
        let raw = [173u8, 218, 34, 188, 211, 173, 235, 82, 18, 159, 55, 47, 242];
        let first = NodeId::try_from(&raw[..]).unwrap();
        let second = NodeId::try_from(&raw[..]).unwrap();
        assert_eq!(first, second);
    }
}