use std::{
cmp::min,
time::{Duration, Instant},
};
use ed25519_dalek::{Signature, SigningKey, PUBLIC_KEY_LENGTH, SIGNATURE_LENGTH};
use rand_core::OsRng;
use sha2::Digest;
use x25519_dalek::StaticSecret;
use crate::{
buffer::OutputBuffer,
error::RnsError,
hash::{AddressHash, Hash, ADDRESS_HASH_SIZE},
identity::{DecryptIdentity, DerivedKey, EncryptIdentity, Identity, PrivateIdentity},
packet::{
DestinationType, Header, Packet, PacketContext, PacketDataBuffer, PacketType, PACKET_MDU,
},
};
use super::DestinationDesc;
/// Number of link MTU/signalling bytes that may follow the two public keys in a link
/// request; when present they are echoed back (and covered by the signature) in the proof.
const LINK_MTU_SIZE: usize = 3;
/// Lifecycle state of a [`Link`].
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum LinkStatus {
    /// Request sent (initiator) or link just created; set by `new`, `request` and `restart`.
    Pending = 0x00,
    /// Peer identity known and link key derived; set by `handshake`.
    Handshake = 0x01,
    /// Proof exchanged; set when a proof validates (initiator) or in `prove` (responder).
    Active = 0x02,
    /// NOTE(review): not assigned anywhere in this file; presumably set by an external
    /// liveness watchdog — confirm.
    Stale = 0x03,
    /// Link torn down; set by `close`.
    Closed = 0x04,
}
impl LinkStatus {
    /// Reports whether link establishment is still in progress, i.e. the link is either
    /// waiting for its request to be answered or has only completed the key handshake.
    pub fn not_yet_active(&self) -> bool {
        matches!(self, LinkStatus::Pending | LinkStatus::Handshake)
    }
}
/// Identifier of a link: a truncated hash derived from the link-request packet
/// (see the `From<&Packet>` implementation below).
pub type LinkId = AddressHash;
/// Fixed-capacity container for one decrypted link payload plus its packet metadata.
#[derive(Clone)]
pub struct LinkPayload {
    // Inline storage; only the first `len` bytes are meaningful.
    buffer: [u8; PACKET_MDU],
    len: usize,
    // Context of the packet this payload arrived in.
    context: PacketContext,
    // Set for request packets: the first ADDRESS_HASH_SIZE bytes of the packet hash.
    request_id: Option<[u8; ADDRESS_HASH_SIZE]>,
}
impl LinkPayload {
    /// Creates an empty payload with `PacketContext::None` and no request id.
    pub fn new() -> Self {
        Self {
            buffer: [0u8; PACKET_MDU],
            len: 0,
            context: PacketContext::None,
            request_id: None,
        }
    }

    /// Copies `data` into a new payload with `PacketContext::None`.
    ///
    /// At most `PACKET_MDU` bytes are kept; longer input is silently truncated.
    pub fn new_from_slice(data: &[u8]) -> Self {
        Self::new_from_slice_with_context(data, PacketContext::None)
    }

    /// Copies `data` into a new payload tagged with `context`.
    ///
    /// At most `PACKET_MDU` bytes are kept; longer input is silently truncated.
    pub fn new_from_slice_with_context(data: &[u8], context: PacketContext) -> Self {
        let mut buffer = [0u8; PACKET_MDU];
        let len = min(data.len(), buffer.len());
        buffer[..len].copy_from_slice(&data[..len]);
        Self {
            buffer,
            len,
            context,
            request_id: None,
        }
    }

    /// Like [`LinkPayload::new_from_slice_with_context`], additionally recording `request_id`.
    pub fn new_from_slice_with_context_and_request_id(
        data: &[u8],
        context: PacketContext,
        request_id: Option<[u8; ADDRESS_HASH_SIZE]>,
    ) -> Self {
        let mut payload = Self::new_from_slice_with_context(data, context);
        payload.request_id = request_id;
        payload
    }

    /// Copies `data` into a new payload with `PacketContext::None`.
    ///
    /// Identical to [`LinkPayload::new_from_slice`]. The length is the number of bytes
    /// actually copied; the previous implementation recorded `data.len()` even when the
    /// input was truncated, which made `as_slice()` panic for inputs over `PACKET_MDU`.
    pub fn new_from_vec(data: &[u8]) -> Self {
        Self::new_from_slice(data)
    }

    /// Number of valid payload bytes.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Context of the packet the payload belongs to.
    pub fn context(&self) -> PacketContext {
        self.context
    }

    /// Request id, present only for payloads that arrived in request packets.
    pub fn request_id(&self) -> Option<[u8; ADDRESS_HASH_SIZE]> {
        self.request_id
    }

    /// Returns `true` when the payload holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// The valid payload bytes.
    pub fn as_slice(&self) -> &[u8] {
        &self.buffer[..self.len]
    }
}
impl Default for LinkPayload {
fn default() -> Self {
Self::new()
}
}
impl From<&Packet> for LinkId {
fn from(packet: &Packet) -> Self {
let data = packet.data.as_slice();
let data_diff = if data.len() > PUBLIC_KEY_LENGTH * 2 {
data.len() - PUBLIC_KEY_LENGTH * 2
} else {
0
};
let hashable_data = &data[..data.len() - data_diff];
AddressHash::new_from_hash(&Hash::new(
Hash::generator()
.chain_update([packet.header.to_meta() & 0b00001111])
.chain_update(packet.destination.as_slice())
.chain_update([packet.context as u8])
.chain_update(hashable_data)
.finalize()
.into(),
))
}
}
/// Outcome of [`Link::handle_packet`], telling the caller what follow-up action is needed.
// The `Proof` variant holds a whole `Packet` by value while the other variants are unit-like,
// which trips clippy's `large_enum_variant`. NOTE(review): presumably allowed deliberately to
// avoid a heap allocation per returned proof — confirm.
#[allow(clippy::large_enum_variant)]
pub enum LinkHandleResult {
    /// Nothing to do: packet not for this link, ignored, undecryptable or a keep-alive reply.
    None,
    /// A link-request proof validated; the link is now `Active`.
    Activated,
    /// A data packet was accepted; the contained proof packet acknowledges it.
    Proof(Packet),
    /// A keep-alive request (first byte 0xFF) was received. NOTE(review): the caller
    /// presumably answers with a 0xFE keep-alive — confirm.
    KeepAlive,
}
/// Event broadcast by a [`Link`] on its event channel.
#[derive(Clone)]
pub enum LinkEvent {
    /// The link became active (after proving, or after a valid proof was received).
    Activated,
    /// Decrypted application data received over the link.
    Data(Box<LinkPayload>),
    /// The link was closed via `close`.
    Closed,
}
/// A [`LinkEvent`] tagged with the link it came from.
#[derive(Clone)]
pub struct LinkEventData {
    /// Id of the link that emitted the event.
    pub id: LinkId,
    /// Address hash of the link's destination.
    pub address_hash: AddressHash,
    /// The event itself.
    pub event: LinkEvent,
}
/// One end of an encrypted link to a destination, either initiated locally
/// (`Link::new` + `request`) or accepted from a peer's request (`Link::new_from_request`).
pub struct Link {
    // Link id; empty until `request` runs (initiator) or derived from the request (responder).
    id: LinkId,
    // Destination this link targets; its address hash tags every emitted event.
    destination: DestinationDesc,
    // Local key pair for this link (random encryption key; signing key random or supplied).
    priv_identity: PrivateIdentity,
    // Peer's identity, filled in by `handshake`.
    peer_identity: Identity,
    // Symmetric key derived in `handshake`, salted with the link id.
    derived_key: DerivedKey,
    // Signalling (MTU) bytes from an incoming request, echoed in the proof.
    signalling: Option<[u8; LINK_MTU_SIZE]>,
    status: LinkStatus,
    // Set at creation/request; refreshed on received data and keep-alives. Basis for
    // `elapsed()` and for the RTT measurement.
    request_time: Instant,
    // Time from request to validated proof (initiator side only).
    rtt: Duration,
    // Channel on which `LinkEventData` is broadcast.
    event_tx: tokio::sync::broadcast::Sender<LinkEventData>,
}
impl Link {
    /// Creates an outgoing (initiator-side) link towards `destination`.
    ///
    /// A random private identity is generated for the link. The link id stays empty until
    /// [`Link::request`] builds the link-request packet.
    pub fn new(
        destination: DestinationDesc,
        event_tx: tokio::sync::broadcast::Sender<LinkEventData>,
    ) -> Self {
        Self {
            id: AddressHash::new_empty(),
            destination,
            priv_identity: PrivateIdentity::new_from_rand(OsRng),
            peer_identity: Identity::default(),
            derived_key: DerivedKey::new_empty(),
            signalling: None,
            status: LinkStatus::Pending,
            request_time: Instant::now(),
            rtt: Duration::from_secs(0),
            event_tx,
        }
    }

    /// Creates an incoming (responder-side) link from a received link-request packet.
    ///
    /// The request payload starts with the peer's `public_key` followed by its
    /// `verifying_key` (`PUBLIC_KEY_LENGTH` bytes each). These may be followed by
    /// `LINK_MTU_SIZE` signalling bytes, which are kept so [`Link::prove`] can echo them.
    /// The link key is derived immediately, leaving the link in [`LinkStatus::Handshake`].
    ///
    /// # Errors
    /// Returns [`RnsError::InvalidArgument`] if the payload is shorter than the two keys.
    pub fn new_from_request(
        packet: &Packet,
        signing_key: SigningKey,
        destination: DestinationDesc,
        event_tx: tokio::sync::broadcast::Sender<LinkEventData>,
    ) -> Result<Self, RnsError> {
        if packet.data.len() < PUBLIC_KEY_LENGTH * 2 {
            return Err(RnsError::InvalidArgument);
        }
        let data = packet.data.as_slice();
        let peer_identity = Identity::new_from_slices(
            &data[..PUBLIC_KEY_LENGTH],
            &data[PUBLIC_KEY_LENGTH..PUBLIC_KEY_LENGTH * 2],
        );
        let signalling = if data.len() >= PUBLIC_KEY_LENGTH * 2 + LINK_MTU_SIZE {
            let mut bytes = [0u8; LINK_MTU_SIZE];
            bytes.copy_from_slice(
                &data[PUBLIC_KEY_LENGTH * 2..PUBLIC_KEY_LENGTH * 2 + LINK_MTU_SIZE],
            );
            Some(bytes)
        } else {
            None
        };
        let link_id = LinkId::from(packet);
        log::debug!("link: create from request {}", link_id);
        let mut link = Self {
            id: link_id,
            destination,
            priv_identity: PrivateIdentity::new(StaticSecret::random_from_rng(OsRng), signing_key),
            peer_identity,
            derived_key: DerivedKey::new_empty(),
            signalling,
            status: LinkStatus::Pending,
            request_time: Instant::now(),
            rtt: Duration::from_secs(0),
            event_tx,
        };
        link.handshake(peer_identity);
        Ok(link)
    }

    /// Builds the link-request packet carrying this side's public and verifying keys.
    ///
    /// Resets the link to `Pending`, derives the link id from the packet and restarts the
    /// request timer used for the RTT measurement.
    pub fn request(&mut self) -> Packet {
        let mut packet_data = PacketDataBuffer::new();
        packet_data.safe_write(self.priv_identity.as_identity().public_key.as_bytes());
        packet_data.safe_write(self.priv_identity.as_identity().verifying_key.as_bytes());
        let packet = Packet {
            header: Header {
                packet_type: PacketType::LinkRequest,
                ..Default::default()
            },
            ifac: None,
            destination: self.destination.address_hash,
            transport: None,
            context: PacketContext::None,
            data: packet_data,
        };
        self.status = LinkStatus::Pending;
        self.id = LinkId::from(&packet);
        self.request_time = Instant::now();
        packet
    }

    /// Builds the responder's link-request proof and marks the link `Active`.
    ///
    /// The signature covers `link id || public_key || verifying_key [|| signalling]`. The
    /// emitted payload is `signature || public_key [|| signalling]`. An `Activated` event is
    /// posted the first time the link becomes active.
    pub fn prove(&mut self) -> Packet {
        log::debug!("link({}): prove", self.id);
        if self.status != LinkStatus::Active {
            self.status = LinkStatus::Active;
            self.post_event(LinkEvent::Activated);
        }
        let mut packet_data = PacketDataBuffer::new();
        packet_data.safe_write(self.id.as_slice());
        packet_data.safe_write(self.priv_identity.as_identity().public_key.as_bytes());
        packet_data.safe_write(self.priv_identity.as_identity().verifying_key.as_bytes());
        if let Some(signalling) = self.signalling {
            packet_data.safe_write(&signalling);
        }
        let signature = self.priv_identity.sign(packet_data.as_slice());
        // Reuse the buffer: the signed data is replaced by the actual proof payload.
        packet_data.reset();
        packet_data.safe_write(&signature.to_bytes()[..]);
        packet_data.safe_write(self.priv_identity.as_identity().public_key.as_bytes());
        if let Some(signalling) = self.signalling {
            packet_data.safe_write(&signalling);
        }
        Packet {
            header: Header {
                packet_type: PacketType::Proof,
                destination_type: DestinationType::Link,
                ..Default::default()
            },
            ifac: None,
            destination: self.id,
            transport: None,
            context: PacketContext::LinkRequestProof,
            data: packet_data,
        }
    }

    /// Builds a delivery proof for `packet`: its hash followed by a signature over that hash.
    pub fn prove_packet(&self, packet: &Packet) -> Packet {
        let hash = packet.hash().to_bytes();
        let signature = self.priv_identity.sign(&hash).to_bytes();
        let mut packet_data = PacketDataBuffer::new();
        packet_data.safe_write(&hash);
        packet_data.safe_write(&signature);
        Packet {
            header: Header {
                packet_type: PacketType::Proof,
                destination_type: DestinationType::Link,
                ..Default::default()
            },
            ifac: None,
            destination: self.id,
            transport: None,
            context: PacketContext::LinkProof,
            data: packet_data,
        }
    }

    /// Handles a `Data` packet addressed to this link.
    ///
    /// Application contexts are decrypted, posted as `LinkEvent::Data` and answered with a
    /// delivery proof. Keep-alive packets refresh the link timer: 0xFF is a request, 0xFE a
    /// response. All other contexts are ignored.
    fn handle_data_packet(&mut self, packet: &Packet) -> LinkHandleResult {
        if self.status != LinkStatus::Active {
            log::warn!("link({}): handling data packet in inactive state", self.id);
        }
        match packet.context {
            PacketContext::None
            | PacketContext::Request
            | PacketContext::Response
            | PacketContext::LinkIdentify => {
                let mut buffer = [0u8; PACKET_MDU];
                if let Ok(plain_text) = self.decrypt(packet.data.as_slice(), &mut buffer[..]) {
                    // The plaintext preview goes only through the log facade at trace level
                    // (arguments are formatted only when trace is enabled), so decrypted bytes
                    // never reach stderr unconditionally.
                    let preview_len = plain_text.len().min(32);
                    log::trace!(
                        "link({}): data {}B preview={}",
                        self.id,
                        plain_text.len(),
                        bytes_to_hex(&plain_text[..preview_len])
                    );
                    self.request_time = Instant::now();
                    // Requests are identified by the truncated hash of the carrying packet.
                    let request_id = if packet.context == PacketContext::Request {
                        let hash = packet.hash().to_bytes();
                        let mut id = [0u8; ADDRESS_HASH_SIZE];
                        id.copy_from_slice(&hash[..ADDRESS_HASH_SIZE]);
                        Some(id)
                    } else {
                        None
                    };
                    self.post_event(LinkEvent::Data(Box::new(
                        LinkPayload::new_from_slice_with_context_and_request_id(
                            plain_text,
                            packet.context,
                            request_id,
                        ),
                    )));
                    return LinkHandleResult::Proof(self.prove_packet(packet));
                } else {
                    log::error!("link({}): can't decrypt packet", self.id);
                }
            }
            PacketContext::KeepAlive => {
                if !packet.data.is_empty() && packet.data.as_slice()[0] == 0xFF {
                    self.request_time = Instant::now();
                    log::trace!("link({}): keep-alive request", self.id);
                    return LinkHandleResult::KeepAlive;
                }
                if !packet.data.is_empty() && packet.data.as_slice()[0] == 0xFE {
                    log::trace!("link({}): keep-alive response", self.id);
                    self.request_time = Instant::now();
                    return LinkHandleResult::None;
                }
            }
            _ => {}
        }
        LinkHandleResult::None
    }

    /// Dispatches a packet addressed to this link.
    ///
    /// Packets for other link ids are ignored. `Data` packets go to the data handler. A
    /// `LinkRequestProof` received while `Pending` is validated. On success the link key is
    /// derived, the link becomes `Active`, the RTT is recorded and `Activated` is posted.
    pub fn handle_packet(&mut self, packet: &Packet) -> LinkHandleResult {
        if packet.destination != self.id {
            return LinkHandleResult::None;
        }
        match packet.header.packet_type {
            PacketType::Data => return self.handle_data_packet(packet),
            PacketType::Proof => {
                if self.status == LinkStatus::Pending
                    && packet.context == PacketContext::LinkRequestProof
                {
                    if let Ok(identity) = validate_proof_packet(&self.destination, &self.id, packet)
                    {
                        log::debug!("link({}): has been proved", self.id);
                        self.handshake(identity);
                        self.status = LinkStatus::Active;
                        self.rtt = self.request_time.elapsed();
                        log::debug!("link({}): activated", self.id);
                        self.post_event(LinkEvent::Activated);
                        return LinkHandleResult::Activated;
                    } else {
                        log::warn!("link({}): proof is not valid", self.id);
                    }
                }
            }
            _ => {}
        }
        LinkHandleResult::None
    }

    /// Encrypts `data` with the link key and wraps it in a link `Data` packet.
    ///
    /// A warning is logged, but the packet is still built, when the link is not `Active`.
    ///
    /// # Errors
    /// Propagates any encryption error.
    pub fn data_packet(&self, data: &[u8]) -> Result<Packet, RnsError> {
        if self.status != LinkStatus::Active {
            log::warn!(
                "link({}): creating data packet for inactive link ({:?})",
                self.id,
                self.status
            );
        }
        let mut packet_data = PacketDataBuffer::new();
        let cipher_text_len = {
            let cipher_text = self.encrypt(data, packet_data.accuire_buf_max())?;
            cipher_text.len()
        };
        packet_data.resize(cipher_text_len);
        Ok(Packet {
            header: Header {
                destination_type: DestinationType::Link,
                packet_type: PacketType::Data,
                ..Default::default()
            },
            ifac: None,
            destination: self.id,
            transport: None,
            context: PacketContext::None,
            data: packet_data,
        })
    }

    /// Builds an unencrypted keep-alive packet carrying the single byte `data`
    /// (0xFF for a request, 0xFE for a response, as interpreted by the data handler).
    pub fn keep_alive_packet(&self, data: u8) -> Packet {
        log::trace!("link({}): create keep alive {}", self.id, data);
        let mut packet_data = PacketDataBuffer::new();
        packet_data.safe_write(&[data]);
        Packet {
            header: Header {
                destination_type: DestinationType::Link,
                packet_type: PacketType::Data,
                ..Default::default()
            },
            ifac: None,
            destination: self.id,
            transport: None,
            context: PacketContext::KeepAlive,
            data: packet_data,
        }
    }

    /// Encrypts `text` with the link's derived key into `out_buf`, returning the written part.
    pub fn encrypt<'a>(&self, text: &[u8], out_buf: &'a mut [u8]) -> Result<&'a [u8], RnsError> {
        self.priv_identity
            .encrypt(OsRng, text, &self.derived_key, out_buf)
    }

    /// Decrypts `text` with the link's derived key into `out_buf`, returning the written part.
    pub fn decrypt<'a>(&self, text: &[u8], out_buf: &'a mut [u8]) -> Result<&'a [u8], RnsError> {
        self.priv_identity
            .decrypt(OsRng, text, &self.derived_key, out_buf)
    }

    /// Destination this link belongs to.
    pub fn destination(&self) -> &DestinationDesc {
        &self.destination
    }

    /// Peer identity (default until the handshake has run).
    pub fn peer_identity(&self) -> &Identity {
        &self.peer_identity
    }

    /// Builds an encrypted `LinkRTT` packet carrying the measured RTT in seconds as a
    /// MessagePack float32.
    ///
    /// # Panics
    /// Panics if encryption into the packet buffer fails.
    pub fn create_rtt(&self) -> Packet {
        let rtt = self.rtt.as_secs_f32();
        // A MessagePack float32 is a 1-byte marker followed by 4 payload bytes.
        let mut buf: Vec<u8> = Vec::with_capacity(5);
        rmp::encode::write_f32(&mut buf, rtt).expect("writing to a Vec cannot fail");
        let mut packet_data = PacketDataBuffer::new();
        let token_len = {
            let token = self
                .encrypt(buf.as_slice(), packet_data.accuire_buf_max())
                .expect("encrypted data");
            token.len()
        };
        packet_data.resize(token_len);
        log::trace!("link: {} create rtt packet = {} sec", self.id, rtt);
        Packet {
            header: Header {
                destination_type: DestinationType::Link,
                ..Default::default()
            },
            ifac: None,
            destination: self.id,
            transport: None,
            context: PacketContext::LinkRTT,
            data: packet_data,
        }
    }

    /// Records the peer identity and derives the link key (salted with the link id),
    /// moving the link to `Handshake`.
    fn handshake(&mut self, peer_identity: Identity) {
        log::debug!("link({}): handshake", self.id);
        self.status = LinkStatus::Handshake;
        self.peer_identity = peer_identity;
        self.derived_key = self
            .priv_identity
            .derive_key(&self.peer_identity.public_key, Some(self.id.as_slice()));
    }

    /// Broadcasts `event` tagged with this link's id and destination. A send error
    /// (e.g. no receivers) is deliberately ignored.
    fn post_event(&self, event: LinkEvent) {
        let _ = self.event_tx.send(LinkEventData {
            id: self.id,
            address_hash: self.destination.address_hash,
            event,
        });
    }

    /// Marks the link `Closed` and posts a `Closed` event.
    pub fn close(&mut self) {
        self.status = LinkStatus::Closed;
        self.post_event(LinkEvent::Closed);
        log::warn!("link: close {}", self.id);
    }

    /// Moves the link back to `Pending`. Only the status changes; keys, id and timer are kept.
    pub fn restart(&mut self) {
        log::warn!(
            "link({}): restart after {}s",
            self.id,
            self.request_time.elapsed().as_secs()
        );
        self.status = LinkStatus::Pending;
    }

    /// Time since the request was sent or the last data/keep-alive was received.
    pub fn elapsed(&self) -> Duration {
        self.request_time.elapsed()
    }

    /// Current link status.
    pub fn status(&self) -> LinkStatus {
        self.status
    }

    /// Link id.
    pub fn id(&self) -> &LinkId {
        &self.id
    }
}
/// Renders `bytes` as a lowercase hexadecimal string, two digits per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    use std::fmt::Write;
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut hex, value| {
            // Writing into a String cannot fail.
            let _ = write!(hex, "{:02x}", value);
            hex
        })
}
/// Validates a link-request proof received from `destination` and returns the peer identity.
///
/// The proof payload is `signature || peer public_key [|| LINK_MTU_SIZE signalling bytes]`.
/// The signature must cover `link id || peer public_key || destination verifying key
/// [|| signalling bytes]` and verify under the destination's verifying key.
///
/// # Errors
/// - [`RnsError::PacketError`] if the payload is neither exactly the plain proof length nor
///   exactly the length of a proof carrying signalling bytes.
/// - [`RnsError::CryptoError`] if the signature bytes cannot be parsed.
/// - [`RnsError::IncorrectSignature`] if the signature does not verify.
/// - Any error from writing into the fixed-size signed-data buffer.
fn validate_proof_packet(
    destination: &DestinationDesc,
    id: &LinkId,
    packet: &Packet,
) -> Result<Identity, RnsError> {
    const MIN_PROOF_LEN: usize = SIGNATURE_LENGTH + PUBLIC_KEY_LENGTH;
    const MTU_PROOF_LEN: usize = SIGNATURE_LENGTH + PUBLIC_KEY_LENGTH + LINK_MTU_SIZE;
    const SIGN_DATA_LEN: usize = ADDRESS_HASH_SIZE + PUBLIC_KEY_LENGTH * 2 + LINK_MTU_SIZE;

    let data = packet.data.as_slice();
    // Only two proof shapes exist: without and with the signalling (MTU) bytes. Rejecting
    // every other length up front means a partial signalling field is never silently dropped,
    // and an oversized one is never copied into the fixed signed-data buffer.
    if data.len() != MIN_PROOF_LEN && data.len() != MTU_PROOF_LEN {
        return Err(RnsError::PacketError);
    }
    let (signature_bytes, rest) = data.split_at(SIGNATURE_LENGTH);
    // `signalling` is either empty or exactly LINK_MTU_SIZE bytes after the check above.
    let (peer_public_key, signalling) = rest.split_at(PUBLIC_KEY_LENGTH);
    let verifying_key = destination.identity.verifying_key.as_bytes();

    let mut proof_data = [0u8; SIGN_DATA_LEN];
    let sign_data_len = {
        let mut output = OutputBuffer::new(&mut proof_data[..]);
        output.write(id.as_slice())?;
        output.write(peer_public_key)?;
        output.write(verifying_key)?;
        if !signalling.is_empty() {
            output.write(signalling)?;
        }
        output.offset()
    };

    let identity = Identity::new_from_slices(peer_public_key, verifying_key);
    let signature = Signature::from_slice(signature_bytes).map_err(|_| RnsError::CryptoError)?;
    identity
        .verify(&proof_data[..sign_data_len], &signature)
        .map_err(|_| RnsError::IncorrectSignature)?;
    Ok(identity)
}