use std::{
convert::{TryFrom, TryInto},
fmt,
io::{self, Cursor, Read, Write},
};
use crate::{
__impl_index, __impl_u8_array,
address::{Address, FromPublicKeysError},
content, crypto,
io::{LenBm, ReadFrom, SizedReadFrom, WriteTo},
message,
net::{Addr, AddrExt, OnionV3Addr, ParseOnionV3AddrError, SocketAddrExt},
priv_util::ToHexString,
time::Time,
var_type::VarInt,
};
pub use crate::stream::StreamNumber;
/// Numeric object type as carried in an object [`Header`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct ObjectType(u32);
impl ObjectType {
    /// Wraps a raw object type number.
    pub const fn new(value: u32) -> Self {
        ObjectType(value)
    }
    /// Returns the raw object type number.
    pub fn as_u32(self) -> u32 {
        let ObjectType(raw) = self;
        raw
    }
}
impl fmt::Display for ObjectType {
    /// Formats the raw number, forwarding the formatter's flags (width, fill, ...).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl From<u32> for ObjectType {
    fn from(value: u32) -> Self {
        Self::new(value)
    }
}
impl WriteTo for ObjectType {
fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
self.0.write_to(w)
}
}
impl ReadFrom for ObjectType {
fn read_from(r: &mut dyn Read) -> io::Result<Self>
where
Self: Sized,
{
Ok(Self(u32::read_from(r)?))
}
}
impl LenBm for ObjectType {
fn len_bm(&self) -> usize {
self.0.len_bm()
}
}
/// Object version number, stored as a [`VarInt`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct ObjectVersion(VarInt);
impl fmt::Display for ObjectVersion {
    /// Delegates to the wrapped var_int's `Display`, forwarding formatter flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl ObjectVersion {
    /// Builds a version from a plain integer.
    pub fn new(value: u64) -> Self {
        Self::from(value)
    }
    /// Returns the version as a plain integer.
    pub fn as_u64(self) -> u64 {
        let ObjectVersion(inner) = self;
        inner.as_u64()
    }
}
impl From<u64> for ObjectVersion {
    fn from(value: u64) -> Self {
        ObjectVersion(value.into())
    }
}
impl WriteTo for ObjectVersion {
fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
self.0.write_to(w)
}
}
impl ReadFrom for ObjectVersion {
fn read_from(r: &mut dyn Read) -> io::Result<Self>
where
Self: Sized,
{
Ok(Self(VarInt::read_from(r)?))
}
}
impl LenBm for ObjectVersion {
fn len_bm(&self) -> usize {
self.0.len_bm()
}
}
/// Header shared by every object: expiry time, object type, version and stream.
///
/// The field order below is also the comparison order of the derived
/// `PartialOrd`/`Ord`, so it must not be rearranged casually.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Header {
    expires_time: Time,
    object_type: ObjectType,
    version: ObjectVersion,
    stream_number: StreamNumber,
}
impl Header {
    /// Assembles a header from its four components.
    pub fn new(
        expires_time: Time,
        object_type: ObjectType,
        version: ObjectVersion,
        stream_number: StreamNumber,
    ) -> Self {
        Header {
            stream_number,
            version,
            object_type,
            expires_time,
        }
    }
    /// Time at which the object expires.
    pub fn expires_time(&self) -> Time {
        let Header { expires_time, .. } = *self;
        expires_time
    }
    /// Numeric object type.
    pub fn object_type(&self) -> ObjectType {
        let Header { object_type, .. } = *self;
        object_type
    }
    /// Object version.
    pub fn version(&self) -> ObjectVersion {
        let Header { version, .. } = *self;
        version
    }
    /// Stream the object belongs to.
    pub fn stream_number(&self) -> StreamNumber {
        let Header { stream_number, .. } = *self;
        stream_number
    }
}
impl WriteTo for Header {
fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
self.expires_time.write_to(w)?;
self.object_type.write_to(w)?;
self.version.write_to(w)?;
self.stream_number.write_to(w)?;
Ok(())
}
}
impl ReadFrom for Header {
fn read_from(r: &mut dyn Read) -> io::Result<Self>
where
Self: Sized,
{
Ok(Self {
expires_time: Time::read_from(r)?,
object_type: ObjectType::read_from(r)?,
version: ObjectVersion::read_from(r)?,
stream_number: StreamNumber::read_from(r)?,
})
}
}
impl LenBm for Header {
fn len_bm(&self) -> usize {
self.expires_time.len_bm()
+ self.object_type.len_bm()
+ self.version.len_bm()
+ self.stream_number.len_bm()
}
}
#[test]
fn test_header_write_to() {
    let test = Header::new(
        0x0123_4567_89ab_cdef.into(),
        2.into(),
        3u64.into(),
        1u32.into(),
    );
    let mut bytes = Vec::new();
    test.write_to(&mut bytes).unwrap();
    let expected = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x00, 0x00, 0x02, 3, 1,
    ];
    assert_eq!(bytes, expected);
}
#[test]
fn test_header_read_from() {
    // `Cursor` is already in scope from the module-level `std::io` import.
    let mut bytes = Cursor::new([
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x00, 0x00, 0x02, 3, 1,
    ]);
    let test = Header::read_from(&mut bytes).unwrap();
    let expected = Header::new(
        0x0123_4567_89ab_cdef.into(),
        2.into(),
        3u64.into(),
        1u32.into(),
    );
    assert_eq!(test, expected);
}
/// Writing a header and reading it back must yield the same header, and
/// `len_bm` must agree with the number of bytes `write_to` produced.
#[test]
fn test_header_round_trip() {
    let header = Header::new(
        0x0123_4567_89ab_cdef.into(),
        2.into(),
        3u64.into(),
        1u32.into(),
    );
    let mut bytes = Vec::new();
    header.write_to(&mut bytes).unwrap();
    assert_eq!(header.len_bm(), bytes.len());
    let decoded = Header::read_from(&mut Cursor::new(bytes)).unwrap();
    assert_eq!(decoded, header);
}
/// The object kinds this module knows how to recognise.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum ObjectKind {
    Getpubkey,
    Pubkey,
    Msg,
    Broadcast,
    Onionpeer,
}
// Wire values of the object types mapped to `ObjectKind`.
const OBJECT_GETPUBKEY: u32 = 0x0;
const OBJECT_PUBKEY: u32 = 0x1;
const OBJECT_MSG: u32 = 0x2;
const OBJECT_BROADCAST: u32 = 0x3;
// The low three bytes spell ASCII "tor".
const OBJECT_ONIONPEER: u32 = 0x0074_6f72;
/// Returned when an [`ObjectType`] does not map to any [`ObjectKind`];
/// carries the offending type.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct TryFromObjectTypeError(ObjectType);
impl fmt::Display for TryFromObjectTypeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TryFromObjectTypeError(object_type) = self;
        write!(f, "unknown object type {}", object_type)
    }
}
impl std::error::Error for TryFromObjectTypeError {}
impl TryFrom<ObjectType> for ObjectKind {
type Error = TryFromObjectTypeError;
fn try_from(t: ObjectType) -> Result<Self, <Self as TryFrom<ObjectType>>::Error> {
match t.as_u32() {
OBJECT_GETPUBKEY => Ok(Self::Getpubkey),
OBJECT_PUBKEY => Ok(Self::Pubkey),
OBJECT_MSG => Ok(Self::Msg),
OBJECT_BROADCAST => Ok(Self::Broadcast),
OBJECT_ONIONPEER => Ok(Self::Onionpeer),
_ => Err(TryFromObjectTypeError(t)),
}
}
}
impl From<ObjectKind> for ObjectType {
fn from(kind: ObjectKind) -> Self {
match kind {
ObjectKind::Getpubkey => OBJECT_GETPUBKEY.into(),
ObjectKind::Pubkey => OBJECT_PUBKEY.into(),
ObjectKind::Msg => OBJECT_MSG.into(),
ObjectKind::Broadcast => OBJECT_BROADCAST.into(),
ObjectKind::Onionpeer => OBJECT_ONIONPEER.into(),
}
}
}
/// Version 4 broadcast payload: only an encrypted blob.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct BroadcastV4 {
    encrypted: Vec<u8>,
}
impl BroadcastV4 {
    /// Borrows the encrypted bytes.
    pub fn encrypted(&self) -> &[u8] {
        self.encrypted.as_slice()
    }
}
impl WriteTo for BroadcastV4 {
    /// Serializes the encrypted bytes.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        self.encrypted.write_to(w)
    }
}
impl SizedReadFrom for BroadcastV4 {
    /// Reads the encrypted payload, using `len` as its size.
    fn sized_read_from(r: &mut dyn Read, len: usize) -> io::Result<Self>
    where
        Self: Sized,
    {
        Vec::<u8>::sized_read_from(r, len).map(|encrypted| BroadcastV4 { encrypted })
    }
}
/// 32-byte tag written ahead of the encrypted payload of a V5 broadcast.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Tag([u8; 32]);
impl Tag {
    /// Wraps 32 raw tag bytes.
    pub fn new(value: [u8; 32]) -> Self {
        Tag(value)
    }
}
__impl_u8_array!(Tag);
__impl_index!(Tag, 0, u8);
impl WriteTo for Tag {
    /// Serializes the 32 tag bytes.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        self.0.write_to(w)
    }
}
impl ReadFrom for Tag {
    /// Reads 32 bytes and wraps them as a tag.
    fn read_from(r: &mut dyn Read) -> io::Result<Self>
    where
        Self: Sized,
    {
        <[u8; 32]>::read_from(r).map(Tag)
    }
}
impl LenBm for Tag {
    /// Always the array length, i.e. 32.
    fn len_bm(&self) -> usize {
        let Tag(bytes) = self;
        bytes.len()
    }
}
/// Version 5 broadcast payload: a [`Tag`] followed by an encrypted blob.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct BroadcastV5 {
    tag: Tag,
    encrypted: Vec<u8>,
}
impl BroadcastV5 {
    /// Borrows the tag.
    pub fn tag(&self) -> &Tag {
        &self.tag
    }
    /// Borrows the encrypted bytes.
    pub fn encrypted(&self) -> &[u8] {
        self.encrypted.as_slice()
    }
}
impl WriteTo for BroadcastV5 {
    /// Writes the tag, then the encrypted bytes.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        self.tag
            .write_to(w)
            .and_then(|()| self.encrypted.write_to(w))
    }
}
impl SizedReadFrom for BroadcastV5 {
    /// Reads a V5 broadcast spanning `len` bytes: a 32-byte [`Tag`] followed by
    /// the encrypted payload filling the rest.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidData`] error when `len` is smaller
    /// than the tag. Errors from reading the tag or the payload are propagated.
    fn sized_read_from(r: &mut dyn Read, len: usize) -> io::Result<Self>
    where
        Self: Sized,
    {
        let tag = Tag::read_from(r)?;
        // `len` is caller-supplied. A plain `len - tag.len_bm()` would panic in
        // debug builds, or wrap to an enormous length in release builds, when
        // `len` is shorter than the tag.
        let encrypted_len = len.checked_sub(tag.len_bm()).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "broadcast v5 payload is shorter than its tag",
            )
        })?;
        let encrypted = Vec::<u8>::sized_read_from(r, encrypted_len)?;
        Ok(Self { tag, encrypted })
    }
}
/// Failures that can occur while decrypting and validating a [`Broadcast`].
#[derive(Debug)]
pub enum DecryptError {
    /// Reading or writing serialized data failed.
    IoError(io::Error),
    /// Decrypting the encrypted payload failed.
    DecryptError(crypto::DecryptError),
    /// The stream number in the decrypted content differs from the header's.
    StreamsNotMatch {
        headers: StreamNumber,
        contents: StreamNumber,
    },
    /// Obtaining the sender address from the decrypted content failed.
    FromPublicKeysError(FromPublicKeysError),
    /// The address recovered from the content was rejected (wrong address
    /// version for the broadcast version, or not the expected address).
    InvalidAddress {
        expected: Address,
        actual: Address,
    },
    /// Verifying the decrypted content failed.
    VerifyError(crypto::VerifyError),
}
impl fmt::Display for DecryptError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DecryptError::IoError(inner) => fmt::Display::fmt(inner, f),
            DecryptError::DecryptError(inner) => fmt::Display::fmt(inner, f),
            DecryptError::StreamsNotMatch { headers, contents } => write!(
                f,
                "streams not match: header's is {}, but content's is {}",
                headers, contents
            ),
            DecryptError::FromPublicKeysError(inner) => fmt::Display::fmt(inner, f),
            DecryptError::InvalidAddress { expected, actual } => write!(
                f,
                "address is expected to be {}, but actual is {}",
                expected, actual
            ),
            DecryptError::VerifyError(inner) => fmt::Display::fmt(inner, f),
        }
    }
}
impl std::error::Error for DecryptError {}
impl From<io::Error> for DecryptError {
fn from(err: io::Error) -> Self {
Self::IoError(err)
}
}
impl From<crypto::DecryptError> for DecryptError {
fn from(err: crypto::DecryptError) -> Self {
Self::DecryptError(err)
}
}
impl From<FromPublicKeysError> for DecryptError {
fn from(err: FromPublicKeysError) -> Self {
Self::FromPublicKeysError(err)
}
}
impl From<crypto::VerifyError> for DecryptError {
fn from(err: crypto::VerifyError) -> Self {
Self::VerifyError(err)
}
}
/// A broadcast object payload in one of the two supported wire versions.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Broadcast {
    V4(BroadcastV4),
    V5(BroadcastV5),
}
impl Broadcast {
    /// Borrows the encrypted bytes, whatever the version.
    fn encrypted(&self) -> &[u8] {
        match self {
            Broadcast::V4(inner) => inner.encrypted(),
            Broadcast::V5(inner) => inner.encrypted(),
        }
    }

    /// Builds the bytes handed to `content.verify`: the serialized header,
    /// followed by the tag for V5 broadcasts.
    fn signed_header(&self, header: &Header) -> Result<Vec<u8>, io::Error> {
        let mut signed = Vec::new();
        header.write_to(&mut signed)?;
        if let Broadcast::V5(inner) = self {
            inner.tag.write_to(&mut signed)?;
        }
        Ok(signed)
    }

    /// Decrypts this broadcast with `address`'s broadcast key and validates it.
    ///
    /// After decryption, the content is checked in this order:
    /// 1. its stream number must equal `header`'s;
    /// 2. the address obtained from the content must fit the broadcast version:
    ///    for V4, address version 2 or 3 with the same hash as `address`; for V5,
    ///    address version 4 or later with the same broadcast key as `address`;
    /// 3. `content.verify` is called with the signed header bytes.
    ///
    /// # Errors
    ///
    /// The first failing step is reported as a [`DecryptError`].
    pub fn decrypt(
        &self,
        header: &Header,
        address: &Address,
    ) -> Result<content::Broadcast, DecryptError> {
        let payload = self.encrypted();
        let mut payload_reader = Cursor::new(payload);
        let encrypted = crypto::Encrypted::sized_read_from(&mut payload_reader, payload.len())?;
        let broadcast_key = address.broadcast_key();
        let plaintext = encrypted.decrypt(&broadcast_key)?;
        let mut plaintext_reader = Cursor::new(plaintext);
        let content = content::Broadcast::read_from(&mut plaintext_reader)?;

        let (headers, contents) = (header.stream_number(), content.stream_number());
        if contents != headers {
            return Err(DecryptError::StreamsNotMatch { headers, contents });
        }

        let actual = content.address()?;
        let version = actual.version().as_u64();
        let acceptable = match self {
            Broadcast::V4(_) => (2..=3).contains(&version) && actual.hash() == address.hash(),
            Broadcast::V5(_) => {
                version >= 4 && actual.broadcast_key() == address.broadcast_key()
            }
        };
        if !acceptable {
            return Err(DecryptError::InvalidAddress {
                expected: address.clone(),
                actual,
            });
        }

        content.verify(self.signed_header(header)?)?;
        Ok(content)
    }
}
/// Failures when interpreting a generic object as a [`Broadcast`].
#[derive(Debug)]
pub enum TryIntoBroadcastError {
    /// The object's type is unknown or is not the broadcast type.
    InvalidType(ObjectType),
    /// Parsing the payload failed.
    IoError(io::Error),
    /// The object version is not one that is handled.
    UnsupportedVersion(ObjectVersion),
}
impl fmt::Display for TryIntoBroadcastError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TryIntoBroadcastError::InvalidType(t) => write!(f, "invalid object type: {}", t),
            TryIntoBroadcastError::IoError(e) => fmt::Display::fmt(e, f),
            TryIntoBroadcastError::UnsupportedVersion(v) => {
                write!(f, "unsupported broadcast version: {}", v)
            }
        }
    }
}
impl std::error::Error for TryIntoBroadcastError {}
impl From<io::Error> for TryIntoBroadcastError {
    fn from(e: io::Error) -> Self {
        TryIntoBroadcastError::IoError(e)
    }
}
impl TryFrom<message::Object> for Broadcast {
    type Error = TryIntoBroadcastError;

    /// Interprets a generic object as a broadcast.
    ///
    /// # Errors
    ///
    /// * [`TryIntoBroadcastError::InvalidType`] if the object type is unknown,
    ///   or known but not the broadcast type;
    /// * [`TryIntoBroadcastError::UnsupportedVersion`] if the version is not 4 or 5;
    /// * [`TryIntoBroadcastError::IoError`] if the payload cannot be parsed.
    fn try_from(
        object: message::Object,
    ) -> Result<Self, <Self as TryFrom<message::Object>>::Error> {
        let object_type = object.header().object_type();
        // Unknown types and known non-broadcast types are reported identically,
        // so there is no need to inspect (and then unwrap) the conversion result.
        if !matches!(ObjectKind::try_from(object_type), Ok(ObjectKind::Broadcast)) {
            return Err(TryIntoBroadcastError::InvalidType(object_type));
        }
        let version = object.header().version();
        let payload = object.object_payload();
        let len = payload.len();
        let mut reader = Cursor::new(payload);
        match version.as_u64() {
            4 => Ok(Self::V4(BroadcastV4::sized_read_from(&mut reader, len)?)),
            5 => Ok(Self::V5(BroadcastV5::sized_read_from(&mut reader, len)?)),
            _ => Err(TryIntoBroadcastError::UnsupportedVersion(version)),
        }
    }
}
/// Payload of an `onionpeer` object: a var_int port plus raw address bytes.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Onionpeer {
    port: VarInt,
    addr: Vec<u8>,
}
/// Reasons an [`Onionpeer`] cannot be converted into a socket address.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum TryFromOnionpeerError {
    /// The port value does not fit in a `u16`.
    InvalidPort(u64),
    /// The address byte length is not one of the accepted lengths.
    InvalidLength(usize),
    /// The address does not start with the expected 6-byte prefix.
    InvalidPrefix([u8; 6]),
    /// The onion v3 address bytes were rejected.
    ParseOnionV3AddrError(ParseOnionV3AddrError),
}
impl fmt::Display for TryFromOnionpeerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TryFromOnionpeerError::InvalidPort(port) => write!(f, "invalid port: {}", port),
            TryFromOnionpeerError::InvalidLength(len) => write!(f, "invalid length: {}", len),
            TryFromOnionpeerError::InvalidPrefix(prefix) => {
                let hex = prefix.as_ref().to_hex_string();
                write!(f, "invalid prefix: {}", hex)
            }
            TryFromOnionpeerError::ParseOnionV3AddrError(inner) => fmt::Display::fmt(inner, f),
        }
    }
}
impl std::error::Error for TryFromOnionpeerError {}
impl From<ParseOnionV3AddrError> for TryFromOnionpeerError {
    fn from(inner: ParseOnionV3AddrError) -> Self {
        TryFromOnionpeerError::ParseOnionV3AddrError(inner)
    }
}
/// OnionCat IPv6 prefix (`fd87:d87e:eb43::/48`) that precedes onion addresses.
const ONION_PREFIX: [u8; 6] = [0xfd, 0x87, 0xd8, 0x7e, 0xeb, 0x43];
impl TryFrom<Onionpeer> for SocketAddrExt {
    type Error = TryFromOnionpeerError;
    /// Decodes an onionpeer payload.
    ///
    /// * 16 address bytes are converted through [`Addr`] into an [`AddrExt`]
    ///   (how IPv4-mapped or prefixed forms are classified is up to those
    ///   conversions — not visible here).
    /// * 41 address bytes must be the onion prefix followed by a 35-byte onion v3
    ///   address.
    /// * Any other length is rejected. The port is validated before the address.
    fn try_from(op: Onionpeer) -> Result<Self, <Self as TryFrom<Onionpeer>>::Error> {
        let raw_port = op.port.as_u64();
        let port = u16::try_from(raw_port)
            .map_err(|_| TryFromOnionpeerError::InvalidPort(raw_port))?;
        let addr: AddrExt = match op.addr.len() {
            16 => {
                let mut octets = [0u8; 16];
                octets.copy_from_slice(&op.addr);
                let plain: Addr = octets.into();
                plain.into()
            }
            // 6-byte onion prefix + 35-byte onion v3 address.
            41 => {
                let (prefix, onion) = op.addr.split_at(ONION_PREFIX.len());
                if prefix != &ONION_PREFIX[..] {
                    let mut bad = [0u8; 6];
                    bad.copy_from_slice(prefix);
                    return Err(TryFromOnionpeerError::InvalidPrefix(bad));
                }
                let mut v3 = [0u8; 35];
                v3.copy_from_slice(onion);
                AddrExt::OnionV3(OnionV3Addr::new(v3)?)
            }
            other => return Err(TryFromOnionpeerError::InvalidLength(other)),
        };
        Ok(Self::new(addr, port))
    }
}
/// IPv4-mapped IPv6 prefix (`::ffff:0:0/96`) used to embed an IPv4 address in
/// 16 bytes.
const IPV4_PREFIX: [u8; 12] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff];
impl From<SocketAddrExt> for Onionpeer {
    /// Encodes a socket address into onionpeer form:
    ///
    /// * IPv4: [`IPV4_PREFIX`] followed by the 4 octets;
    /// * IPv6: the 16 octets as-is;
    /// * onion v2/v3: `ONION_PREFIX` followed by the onion address bytes.
    ///
    /// The port is stored as a var_int in every case.
    fn from(addr: SocketAddrExt) -> Self {
        /// Concatenates a fixed prefix with the address-specific bytes.
        fn prefixed(prefix: &[u8], rest: &[u8]) -> Vec<u8> {
            let mut out = Vec::with_capacity(prefix.len() + rest.len());
            out.extend_from_slice(prefix);
            out.extend_from_slice(rest);
            out
        }

        // The prefix arrays are `Copy`, so they are borrowed directly instead of
        // being cloned first (clippy::clone_on_copy).
        let (port, addr) = match addr {
            SocketAddrExt::Ipv4(sa) => (
                sa.port() as u64,
                prefixed(&IPV4_PREFIX, &sa.ip().octets()),
            ),
            SocketAddrExt::Ipv6(sa) => (sa.port() as u64, sa.ip().octets().to_vec()),
            SocketAddrExt::OnionV2(sa) => (
                sa.port() as u64,
                prefixed(&ONION_PREFIX, sa.onion_addr().as_ref()),
            ),
            SocketAddrExt::OnionV3(sa) => (
                sa.port() as u64,
                prefixed(&ONION_PREFIX, sa.onion_addr().as_ref()),
            ),
        };
        Self {
            port: port.into(),
            addr,
        }
    }
}
impl WriteTo for Onionpeer {
    /// Writes the var_int port, then the address bytes.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        self.port
            .write_to(w)
            .and_then(|()| self.addr.write_to(w))
    }
}
impl SizedReadFrom for Onionpeer {
    /// Reads an onionpeer payload spanning `len` bytes: a var_int port followed
    /// by the address bytes filling the rest.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidData`] error when `len` is smaller
    /// than the encoded port. Errors from reading the port or the address are
    /// propagated.
    fn sized_read_from(r: &mut dyn Read, len: usize) -> io::Result<Self>
    where
        Self: Sized,
    {
        let port = VarInt::read_from(r)?;
        // `len` is caller-supplied. A plain `len - port.len_bm()` would panic in
        // debug builds, or wrap to an enormous length in release builds, when
        // `len` is shorter than the var_int that was just read.
        let addr_len = len.checked_sub(port.len_bm()).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "onionpeer payload is shorter than its port field",
            )
        })?;
        let addr = Vec::<u8>::sized_read_from(r, addr_len)?;
        Ok(Self { port, addr })
    }
}