use std::{
convert::TryFrom,
fmt,
io::{self, Cursor, Read, Write},
};
use crate::{
__impl_index, __impl_u8_array,
address::Address,
content, crypto,
io::{LenBm, ReadFrom, SizedReadFrom, WriteTo},
message,
object::{DecryptError, Header, ObjectKind, ObjectType, ObjectVersion, TryFromObjectTypeError},
priv_util::ToHexString,
};
/// A version 4 broadcast object: nothing but the still-encrypted payload bytes.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct BroadcastV4 {
    /// Encrypted payload, stored exactly as it was read or supplied.
    encrypted: Vec<u8>,
}

impl BroadcastV4 {
    /// Constructs a version 4 broadcast from its encrypted payload.
    ///
    /// Mirrors `BroadcastV5::new`, so that a version 4 broadcast can be built
    /// directly instead of only being obtainable by parsing.
    pub fn new(encrypted: Vec<u8>) -> Self {
        Self { encrypted }
    }

    /// Returns the encrypted payload as a byte slice.
    pub fn encrypted(&self) -> &[u8] {
        &self.encrypted
    }
}
impl WriteTo for BroadcastV4 {
    /// Writes this broadcast to `w` by delegating to `write_to` on the
    /// encrypted byte vector; any I/O error is returned unchanged.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        self.encrypted.write_to(w)
    }
}
impl SizedReadFrom for BroadcastV4 {
    /// Reads a version 4 broadcast from `r`.
    ///
    /// `len` is forwarded unchanged to `Vec::<u8>::sized_read_from`, and the
    /// resulting bytes become the encrypted payload.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying byte-vector read reports.
    fn sized_read_from(r: &mut dyn Read, len: usize) -> io::Result<Self>
    where
        Self: Sized,
    {
        Vec::<u8>::sized_read_from(r, len).map(|encrypted| Self { encrypted })
    }
}
/// A 32-byte tag carried by version 5 broadcasts.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Tag([u8; 32]);

impl Tag {
    /// Wraps the given 32 bytes in a `Tag` without modifying them.
    pub fn new(value: [u8; 32]) -> Self {
        Tag(value)
    }
}
// Crate-internal helper macros applied to `Tag`; their definitions are not
// visible in this file.
// NOTE(review): presumably `__impl_u8_array!` generates byte-array style
// conversions/accessors for the newtype -- confirm against the macro definition.
__impl_u8_array!(Tag);
// NOTE(review): presumably generates indexing over tuple field `0` yielding
// `u8` elements -- confirm against the macro definition.
__impl_index!(Tag, 0, u8);
impl WriteTo for Tag {
    /// Writes the wrapped 32-byte array to `w` via the array's own `write_to`.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        let Self(bytes) = self;
        bytes.write_to(w)
    }
}
impl ReadFrom for Tag {
    /// Reads a `[u8; 32]` from `r` and wraps it in a `Tag`.
    ///
    /// # Errors
    ///
    /// Propagates the error from the underlying array read.
    fn read_from(r: &mut dyn Read) -> io::Result<Self>
    where
        Self: Sized,
    {
        let bytes = <[u8; 32]>::read_from(r)?;
        Ok(Self(bytes))
    }
}
impl LenBm for Tag {
    /// Returns the length of the wrapped byte array (always 32).
    fn len_bm(&self) -> usize {
        let Self(bytes) = self;
        bytes.len()
    }
}
/// A version 5 broadcast object: a [`Tag`] plus the still-encrypted payload bytes.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct BroadcastV5 {
    /// Tag stored alongside the encrypted payload.
    tag: Tag,
    /// Encrypted payload, stored exactly as it was read or supplied.
    encrypted: Vec<u8>,
}

impl BroadcastV5 {
    /// Constructs a version 5 broadcast from a tag and an encrypted payload.
    pub fn new(tag: Tag, encrypted: Vec<u8>) -> Self {
        BroadcastV5 { tag, encrypted }
    }

    /// Returns a reference to the tag.
    pub fn tag(&self) -> &Tag {
        let Self { tag, .. } = self;
        tag
    }

    /// Returns the encrypted payload as a byte slice.
    pub fn encrypted(&self) -> &[u8] {
        self.encrypted.as_slice()
    }
}
impl WriteTo for BroadcastV5 {
    /// Writes the tag first and then the encrypted payload to `w`.
    ///
    /// Stops at the first failing write and returns its error.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        let Self { tag, encrypted } = self;
        tag.write_to(w)?;
        encrypted.write_to(w)
    }
}
impl SizedReadFrom for BroadcastV5 {
    /// Reads a version 5 broadcast from `r`, where `len` is the total size of
    /// the object (tag plus encrypted payload).
    ///
    /// The tag is read first; the remaining `len - tag.len_bm()` bytes are then
    /// read as the encrypted payload.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the tag or the payload. Additionally
    /// returns an [`io::ErrorKind::UnexpectedEof`] error when `len` is smaller
    /// than the tag itself.
    fn sized_read_from(r: &mut dyn Read, len: usize) -> io::Result<Self>
    where
        Self: Sized,
    {
        let tag = Tag::read_from(r)?;
        // A plain `len - tag.len_bm()` underflows when `len` is shorter than
        // the tag: a panic in debug builds, and a wrapped, enormous length in
        // release builds. Reject that input explicitly instead.
        let encrypted_len = len.checked_sub(tag.len_bm()).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "broadcast v5 payload is shorter than its tag",
            )
        })?;
        let encrypted = Vec::<u8>::sized_read_from(r, encrypted_len)?;
        Ok(Self { tag, encrypted })
    }
}
/// A broadcast object in one of the two layouts this module handles.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Broadcast {
/// Version 4 layout: encrypted payload only.
V4(BroadcastV4),
/// Version 5 layout: a tag followed by the encrypted payload.
V5(BroadcastV5),
}
impl Broadcast {
    /// Returns the encrypted payload of whichever version this broadcast is.
    fn encrypted(&self) -> &[u8] {
        match self {
            Self::V4(inner) => inner.encrypted(),
            Self::V5(inner) => inner.encrypted(),
        }
    }

    /// Builds the header bytes passed to signature verification: the
    /// serialized `header`, followed by the tag for version 5 broadcasts
    /// (version 4 contributes nothing extra).
    fn signed_header(&self, header: &Header) -> Result<Vec<u8>, io::Error> {
        let mut buf = Vec::new();
        header.write_to(&mut buf)?;
        if let Self::V5(v5) = self {
            v5.tag.write_to(&mut buf)?;
        }
        Ok(buf)
    }

    /// Decrypts this broadcast with the broadcast key derived from `address`
    /// and validates the result against `header`.
    ///
    /// Steps, in order:
    /// 1. parse the encrypted payload and decrypt it with
    ///    `address.broadcast_private_encryption_key()`;
    /// 2. parse the plaintext as `content::Broadcast`;
    /// 3. require the content's stream number to equal the header's;
    /// 4. require the address embedded in the content to be acceptable for
    ///    this broadcast version (see below);
    /// 5. verify the content against the signed header bytes.
    ///
    /// Address acceptance:
    /// * V4: the embedded address version must be 2 or 3 and its hash must
    ///   equal the hash of `address`;
    /// * V5: the embedded address version must be at least 4 and its
    ///   broadcast private encryption key must equal the key used to decrypt.
    ///
    /// # Errors
    ///
    /// Returns `DecryptError::StreamsNotMatch` on a stream number mismatch,
    /// `DecryptError::InvalidAddress` when the embedded address is rejected,
    /// and otherwise whatever error the parsing, key derivation, decryption,
    /// or verification step produced.
    pub fn decrypt(
        &self,
        header: &Header,
        address: &Address,
    ) -> Result<content::Broadcast, DecryptError> {
        let ciphertext = self.encrypted();
        let mut reader = Cursor::new(ciphertext);
        let encrypted = crypto::Encrypted::sized_read_from(&mut reader, ciphertext.len())?;
        let private_key = address.broadcast_private_encryption_key()?;
        let plaintext = encrypted.decrypt(&private_key)?;

        let mut reader = Cursor::new(plaintext);
        let content = content::Broadcast::read_from(&mut reader)?;

        let contents = content.stream_number();
        let headers = header.stream_number();
        if contents != headers {
            return Err(DecryptError::StreamsNotMatch { headers, contents });
        }

        let sender = content.address()?;
        // `||` short-circuits, so the second operand of each arm (including
        // the fallible key derivation for V5) is evaluated only once the
        // version check has passed.
        let rejected = match self {
            Self::V4(_) => {
                let version = sender.version().as_u64();
                !(2..=3).contains(&version) || sender.hash() != address.hash()
            }
            Self::V5(_) => {
                sender.version().as_u64() < 4
                    || sender.broadcast_private_encryption_key()? != private_key
            }
        };
        if rejected {
            return Err(DecryptError::InvalidAddress {
                expected: address.clone(),
                actual: sender,
            });
        }

        content.verify(self.signed_header(header)?)?;
        Ok(content)
    }
}
/// Error returned when a `message::Object` cannot be converted into a [`Broadcast`].
#[derive(Debug)]
pub enum TryIntoBroadcastError {
/// The object's type is unknown or is not the broadcast kind; carries the offending type.
InvalidType(ObjectType),
/// Reading the broadcast from the object payload failed.
IoError(io::Error),
/// The object's version is neither 4 nor 5; carries the offending version.
UnsupportedVersion(ObjectVersion),
}
impl fmt::Display for TryIntoBroadcastError {
    /// Formats a human-readable message for each variant; the I/O variant
    /// defers entirely to the wrapped error's own formatting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IoError(inner) => inner.fmt(f),
            Self::UnsupportedVersion(found) => {
                write!(f, "unsupported broadcast version: {}", found)
            }
            Self::InvalidType(found) => write!(f, "invalid object type: {}", found),
        }
    }
}
// Marks `TryIntoBroadcastError` as a standard error type. No trait methods are
// overridden, so only the defaults of `std::error::Error` apply.
impl std::error::Error for TryIntoBroadcastError {}
impl From<io::Error> for TryIntoBroadcastError {
fn from(err: io::Error) -> Self {
Self::IoError(err)
}
}
impl TryFrom<message::Object> for Broadcast {
    type Error = TryIntoBroadcastError;

    /// Parses the payload of `object` into a [`Broadcast`].
    ///
    /// The object's type must map to `ObjectKind::Broadcast`, and its version
    /// must be 4 or 5; the payload is then read in the matching layout using
    /// the full payload length as the size.
    ///
    /// # Errors
    ///
    /// * `TryIntoBroadcastError::InvalidType` when the object type is not
    ///   recognized or is not a broadcast;
    /// * `TryIntoBroadcastError::UnsupportedVersion` for any version other
    ///   than 4 or 5;
    /// * `TryIntoBroadcastError::IoError` when reading the payload fails.
    fn try_from(
        object: message::Object,
    ) -> Result<Self, <Self as TryFrom<message::Object>>::Error> {
        // Resolve the kind with one `match` rather than testing for `Err` and
        // then calling `unwrap()`, so no panic path remains.
        let kind = match ObjectKind::try_from(object.header().object_type()) {
            Ok(kind) => kind,
            Err(TryFromObjectTypeError(object_type)) => {
                return Err(TryIntoBroadcastError::InvalidType(object_type));
            }
        };
        if kind != ObjectKind::Broadcast {
            return Err(TryIntoBroadcastError::InvalidType(
                object.header().object_type(),
            ));
        }

        // Both supported versions read from the same payload with the same
        // declared size, so set the cursor up once.
        let payload = object.object_payload();
        let len = payload.len();
        let mut bytes = Cursor::new(payload);
        match object.header().version().as_u64() {
            4 => Ok(Self::V4(BroadcastV4::sized_read_from(&mut bytes, len)?)),
            5 => Ok(Self::V5(BroadcastV5::sized_read_from(&mut bytes, len)?)),
            _ => Err(TryIntoBroadcastError::UnsupportedVersion(
                object.header().version(),
            )),
        }
    }
}