use std::{
convert::TryFrom,
fmt::{self, Display, Formatter},
pin::Pin,
sync::{Arc, Weak},
};
use bytes::{Bytes, BytesMut};
use openssl::ssl::SslRef;
use pin_project::pin_project;
#[cfg(test)]
use rand::RngCore;
use static_assertions::const_assert;
use tokio_serde::{Deserializer, Serializer};
use tracing::{trace, warn};
#[cfg(test)]
use casper_types::testing::TestRng;
use casper_types::Digest;
use super::{tls::KeyFingerprint, Message, Metrics, Payload};
use crate::{types::NodeId, utils};
/// Short identifier attached to a single message in trace logs.
///
/// Displayed as 16 lowercase hexadecimal characters (two per byte).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct TraceId([u8; 8]);

impl Display for TraceId {
    /// Writes the bytes as lowercase hex directly into the formatter, without building an
    /// intermediate `String` first.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.0
            .iter()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
/// Wrapper around a (de)serialization format that counts the messages passing through it.
///
/// Every message that is serialized or deserialized has its size and kind recorded in the
/// metrics and gets a trace event carrying a per-message trace ID.
#[pin_project]
#[derive(Debug)]
pub struct CountingFormat<F> {
    /// The wrapped format that performs the actual (de)serialization.
    #[pin]
    inner: F,
    /// Identifier of the connection this format is used on; seeds the trace IDs.
    connection_id: ConnectionId,
    /// Number of messages serialized so far.
    out_count: u64,
    /// Number of messages deserialized so far.
    in_count: u64,
    /// Our role on the connection; selects the direction flag mixed into trace IDs.
    role: Role,
    /// Metrics that payload sizes are recorded into (held as a `Weak` reference).
    metrics: Weak<Metrics>,
}

impl<F> CountingFormat<F> {
    /// Wraps `inner`, with both the outgoing and incoming message counters starting at zero.
    #[inline]
    pub(super) fn new(
        metrics: Weak<Metrics>,
        connection_id: ConnectionId,
        role: Role,
        inner: F,
    ) -> Self {
        let (out_count, in_count) = (0, 0);
        CountingFormat {
            inner,
            connection_id,
            out_count,
            in_count,
            role,
            metrics,
        }
    }
}
impl<F, P> Serializer<Arc<Message<P>>> for CountingFormat<F>
where
    F: Serializer<Arc<Message<P>>>,
    P: Payload,
{
    type Error = F::Error;

    /// Serializes `item` with the wrapped serializer, then records outgoing payload metrics and
    /// emits a `net_out` trace event tagged with a per-message trace ID.
    ///
    /// Errors from the inner serializer are returned unchanged. In that case no metrics are
    /// recorded and the outgoing counter is not advanced.
    #[inline]
    fn serialize(self: Pin<&mut Self>, item: &Arc<Message<P>>) -> Result<Bytes, Self::Error> {
        let projected = self.project();

        let bytes = <F as Serializer<Arc<Message<P>>>>::serialize(projected.inner, item)?;
        let msg_size = bytes.len() as u64;
        let msg_kind = item.classify();
        Metrics::record_payload_out(projected.metrics, msg_kind, msg_size);

        // The trace ID uses the counter value *before* it is bumped, so the first message sent
        // on a connection gets sequence number 0.
        let sequence = *projected.out_count;
        let trace_id = projected
            .connection_id
            .create_trace_id(projected.role.out_flag(), sequence);
        *projected.out_count = sequence + 1;

        trace!(target: "net_out",
            msg_id = %trace_id,
            msg_size,
            msg_kind = %msg_kind, "sending");

        Ok(bytes)
    }
}
impl<F, P> Deserializer<Message<P>> for CountingFormat<F>
where
    F: Deserializer<Message<P>>,
    P: Payload,
{
    type Error = F::Error;

    /// Deserializes a message from `src` with the wrapped deserializer, then records incoming
    /// payload metrics and emits a `net_in` trace event tagged with a per-message trace ID.
    ///
    /// Errors from the inner deserializer are returned unchanged. In that case no metrics are
    /// recorded and the incoming counter is not advanced.
    #[inline]
    fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Message<P>, Self::Error> {
        let projected = self.project();

        // The recorded size is that of the raw input frame, taken before decoding.
        let msg_size = src.len() as u64;
        let message = <F as Deserializer<Message<P>>>::deserialize(projected.inner, src)?;
        let msg_kind = message.classify();
        Metrics::record_payload_in(projected.metrics, msg_kind, msg_size);

        // The trace ID uses the counter value *before* it is bumped.
        let sequence = *projected.in_count;
        let trace_id = projected
            .connection_id
            .create_trace_id(projected.role.in_flag(), sequence);
        *projected.in_count = sequence + 1;

        trace!(target: "net_in",
            msg_id = %trace_id,
            msg_size,
            msg_kind = %msg_kind, "received");

        Ok(message)
    }
}
/// Identifier of a single connection, computed identically by both of its endpoints.
///
/// Derived from TLS handshake randomness and the node IDs of both peers
/// (see `ConnectionId::create`).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(super) struct ConnectionId([u8; Digest::LENGTH]);

// `ConnectionId::create` slices `Digest::LENGTH` bytes out of each node ID's hash bytes; presumably
// those are `KeyFingerprint::LENGTH` long — TODO confirm against `NodeId::hash_bytes`.
const_assert!(Digest::LENGTH <= KeyFingerprint::LENGTH);
// `create_trace_id` XORs the message count into bytes 4..12 of the ID.
const_assert!(12 <= Digest::LENGTH);
/// Randomness taken from a TLS handshake, used as input when creating a `ConnectionId`.
#[derive(Copy, Clone, Debug)]
pub(super) struct TlsRandomData {
    /// XOR of the 12 bytes read from the server random and the 12 bytes read from the client
    /// random.
    combined_random: [u8; 12],
}

/// Twelve zero bytes; a TLS random value equal to this triggers a warning in
/// `TlsRandomData::collect`.
const ZERO_RANDOMNESS: [u8; 12] = [0; 12];

impl TlsRandomData {
    /// Reads the server and client random values from the TLS session `ssl` and combines them
    /// with XOR.
    ///
    /// Only 12 bytes of each value are read, which is the size of the destination buffers.
    /// A warning is logged for each value that comes back as all zeros.
    fn collect(ssl: &SslRef) -> Self {
        let mut server_random = ZERO_RANDOMNESS;
        let mut client_random = ZERO_RANDOMNESS;

        ssl.server_random(&mut server_random);
        if server_random == ZERO_RANDOMNESS {
            warn!("TLS server random is all zeros");
        }

        ssl.client_random(&mut client_random);
        // Check the client value itself. Re-testing `server_random` here would never report
        // an all-zero client random.
        if client_random == ZERO_RANDOMNESS {
            warn!("TLS client random is all zeros");
        }

        utils::xor(&mut server_random, &client_random);
        Self {
            combined_random: server_random,
        }
    }

    /// Creates random data for tests without a TLS session.
    #[cfg(test)]
    fn random(rng: &mut TestRng) -> Self {
        let mut buffer = [0u8; 12];
        rng.fill_bytes(&mut buffer);
        Self {
            combined_random: buffer,
        }
    }
}
impl ConnectionId {
    /// Derives a connection ID from TLS randomness and both peers' node IDs.
    ///
    /// The hash of `random_data` is XORed with a `Digest::LENGTH`-byte prefix of each node ID's
    /// hash bytes. Because XOR is commutative, swapping `our_id` and `their_id` yields the same
    /// ID, so both endpoints agree provided they see the same `random_data`.
    fn create(random_data: TlsRandomData, our_id: NodeId, their_id: NodeId) -> ConnectionId {
        let mut id = Digest::hash(random_data.combined_random).value();
        utils::xor(&mut id, &our_id.hash_bytes()[0..Digest::LENGTH]);
        utils::xor(&mut id, &their_id.hash_bytes()[0..Digest::LENGTH]);
        ConnectionId(id)
    }

    /// Derives the trace ID of the `count`-th message travelling in the direction
    /// identified by `flag`.
    ///
    /// The dialer's outgoing flag equals the listener's incoming flag and vice versa, so sender
    /// and receiver derive the same trace ID for the same message.
    fn create_trace_id(&self, flag: u8, count: u64) -> TraceId {
        // Start from the shared connection ID.
        let mut buffer = self.0;

        // Mark the direction of travel in the first byte.
        buffer[0] ^= flag;

        // Mix in the message sequence number. A fixed byte order is used so that peers whose
        // native endianness differs still compute identical trace IDs; on little-endian hosts
        // this matches the native representation.
        utils::xor(&mut buffer[4..12], &count.to_le_bytes());

        // Hash and keep the first 8 bytes as the trace ID.
        let full_hash = Digest::hash(buffer);
        let truncated = TryFrom::try_from(&full_hash.value()[0..8]).expect("buffer size mismatch");
        TraceId(truncated)
    }

    /// Returns the raw bytes of the connection ID.
    #[inline]
    pub(crate) fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Computes the connection ID from the TLS randomness of `ssl` and the two node IDs.
    #[inline]
    pub(crate) fn from_connection(ssl: &SslRef, our_id: NodeId, their_id: NodeId) -> Self {
        Self::create(TlsRandomData::collect(ssl), our_id, their_id)
    }

    /// Creates a connection ID from random TLS data and two random node IDs (tests only).
    #[cfg(test)]
    pub(super) fn random(rng: &mut TestRng) -> Self {
        ConnectionId::create(
            TlsRandomData::random(rng),
            NodeId::random(rng),
            NodeId::random(rng),
        )
    }
}
/// Which side of a connection we are on.
#[derive(Copy, Clone, Debug)]
#[repr(u8)]
pub(super) enum Role {
    /// We opened the connection.
    Dialer,
    /// We accepted the connection.
    Listener,
}

impl Role {
    /// Base bit pattern from which both direction flags are derived.
    const MAGIC_FLAG: u8 = 0b1010_1010;

    /// Flag mixed into trace IDs of messages we receive.
    ///
    /// This is the bitwise complement of our outgoing flag, which makes it equal to the
    /// outgoing flag of the opposite role.
    #[inline]
    fn in_flag(self) -> u8 {
        let outgoing = self.out_flag();
        !outgoing
    }

    /// Flag mixed into trace IDs of messages we send.
    #[inline]
    fn out_flag(self) -> u8 {
        match self {
            Role::Dialer => Self::MAGIC_FLAG,
            Role::Listener => !Self::MAGIC_FLAG,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{ConnectionId, Role, TlsRandomData, TraceId};
    use crate::types::NodeId;

    #[test]
    fn trace_id_has_16_character() {
        // Eight bytes, two hex digits each.
        let rendered = TraceId([0, 1, 2, 3, 4, 5, 6, 7]).to_string();
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn can_create_deterministic_trace_id() {
        let mut rng = crate::new_rng();

        let node_a = NodeId::random(&mut rng);
        let node_b = NodeId::random(&mut rng);

        // A connection ID must not depend on which endpoint computes it.
        let first_random = TlsRandomData::random(&mut rng);
        let first_conn = ConnectionId::create(first_random, node_a, node_b);
        assert_eq!(
            first_conn,
            ConnectionId::create(first_random, node_b, node_a)
        );

        let second_random = TlsRandomData::random(&mut rng);
        let second_conn = ConnectionId::create(second_random, node_b, node_a);
        assert_eq!(
            second_conn,
            ConnectionId::create(second_random, node_a, node_b)
        );

        // Different TLS randomness yields a different connection.
        assert_ne!(first_conn, second_conn);

        // Messages from A (dialer) to B (listener), as labelled by each end.
        let a_to_b_sent =
            |count: u64| first_conn.create_trace_id(Role::Dialer.out_flag(), count);
        let a_to_b_received =
            |count: u64| first_conn.create_trace_id(Role::Listener.in_flag(), count);

        for count in 0..2 {
            assert_eq!(a_to_b_sent(count), a_to_b_received(count));
        }
        assert_ne!(a_to_b_sent(0), a_to_b_sent(1));

        // Messages in the opposite direction on the same connection.
        let b_to_a_sent = first_conn.create_trace_id(Role::Listener.out_flag(), 0);
        let b_to_a_received = first_conn.create_trace_id(Role::Dialer.in_flag(), 0);
        assert_eq!(b_to_a_sent, b_to_a_received);
        assert_ne!(b_to_a_sent, a_to_b_received(0));
    }
}