use std::{
fmt::{Display, Formatter},
time::{Duration, SystemTime},
};
use crate::crypto::prelude::*;
use crate::internal::errors::CoreTypesError;
use crate::internal::prelude::TicketBuilder;
use crate::primitive::prelude::*;
/// Lifecycle state of a channel.
///
/// `Display` renders the variant name in PascalCase (`Closed`, `Open`, `PendingToClose`).
/// The derived `ChannelStatusDiscriminants` enum (`repr(i8)`) mirrors the variants without
/// their payload; its numeric values follow declaration order (0, 1, 2) and must stay in sync
/// with `From<ChannelStatus> for i8`.
///
/// NOTE(review): `PartialEq` is hand-written and treats two `PendingToClose` values as equal
/// when their closure times are less than one second apart.
#[derive(
    Copy, Clone, Debug, smart_default::SmartDefault, strum::Display, strum::EnumDiscriminants,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[strum_discriminants(vis(pub))]
#[strum_discriminants(derive(strum::FromRepr, strum::EnumCount), repr(i8))]
#[cfg_attr(
    feature = "serde",
    strum_discriminants(derive(serde::Serialize, serde::Deserialize))
)]
#[strum(serialize_all = "PascalCase")]
pub enum ChannelStatus {
    /// The channel is closed. This is the `Default` value.
    #[default]
    Closed,
    /// The channel is open.
    Open,
    /// Closure has been initiated. The payload is the closure time;
    /// `closure_time_elapsed` reports `true` once the current time reaches it.
    #[strum(serialize = "PendingToClose")]
    PendingToClose(SystemTime),
}
impl ChannelStatus {
pub fn closure_time_elapsed(&self, current_time: &SystemTime) -> bool {
match self {
ChannelStatus::Closed => true,
ChannelStatus::Open => false,
ChannelStatus::PendingToClose(closure_time) => closure_time <= current_time,
}
}
}
impl From<ChannelStatus> for i8 {
fn from(value: ChannelStatus) -> Self {
match value {
ChannelStatus::Closed => 0,
ChannelStatus::Open => 1,
ChannelStatus::PendingToClose(_) => 2,
}
}
}
impl PartialEq for ChannelStatus {
    /// Compares two statuses variant by variant.
    ///
    /// Two `PendingToClose` values are equal when their closure times are less than one whole
    /// second apart, in either direction. The remaining variants compare by variant only.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Closed, Self::Closed) | (Self::Open, Self::Open) => true,
            (Self::PendingToClose(a), Self::PendingToClose(b)) => {
                // Subtract the earlier time from the later one so the gap is never negative.
                let (earlier, later) = if a <= b { (a, b) } else { (b, a) };
                later.saturating_sub(*earlier).as_secs() == 0
            }
            _ => false,
        }
    }
}

// NOTE(review): the sub-second tolerance above is not transitive
// (t == t+0.6s and t+0.6s == t+1.2s, yet t != t+1.2s), which technically violates the `Eq`
// contract. Hashing uses only the variant tag, so it stays consistent with this equality.
impl Eq for ChannelStatus {}
impl std::hash::Hash for ChannelStatus {
    /// Hashes only the numeric variant tag produced by `From<ChannelStatus> for i8`.
    ///
    /// The closure time of `PendingToClose` is left out. This keeps hashing consistent with
    /// equality, which tolerates sub-second differences in that time.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let tag: i8 = (*self).into();
        std::hash::Hash::hash(&tag, state);
    }
}
/// Direction of a channel relative to a given address (see `ChannelEntry::direction`).
///
/// `Display` and `FromStr` use the lowercase variant names (`incoming`, `outgoing`).
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, strum::Display, strum::EnumString)]
#[strum(serialize_all = "lowercase")]
pub enum ChannelDirection {
    /// The address is the channel's destination.
    Incoming = 0,
    /// The address is the channel's source.
    Outgoing = 1,
}
/// Identifier of a channel, derived from its source and destination by `generate_channel_id`.
pub type ChannelId = Hash;
/// Ordered pair of channel endpoints: `(source, destination)`.
///
/// The order is significant: the pair converts into the [`ChannelId`] of the channel running
/// from the first address to the second.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ChannelParties(Address, Address);
impl Display for ChannelParties {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{} -> {}", self.0, self.1)
}
}
impl ChannelParties {
pub fn new(source: Address, destination: Address) -> Self {
Self(source, destination)
}
pub fn source(&self) -> &Address {
&self.0
}
pub fn destination(&self) -> &Address {
&self.1
}
}
impl<'a> From<&'a ChannelParties> for ChannelId {
    /// Computes the ID of the channel running from the pair's source to its destination.
    fn from(value: &'a ChannelParties) -> Self {
        generate_channel_id(value.source(), value.destination())
    }
}
impl<'a> From<&'a ChannelEntry> for ChannelParties {
fn from(value: &'a ChannelEntry) -> Self {
Self(value.source, value.destination)
}
}
/// Builder for [`ChannelEntry`] that validates its inputs in `ChannelBuilder::build`.
///
/// `source`, `destination` and `balance` must be set explicitly. The remaining fields default
/// to ticket index `0`, status [`ChannelStatus::Open`] and epoch `1`.
#[derive(Debug, Copy, Clone, smart_default::SmartDefault)]
pub struct ChannelBuilder {
    /// Required; `build` fails with "missing source" if unset.
    source: Option<Address>,
    /// Required; `build` fails with "missing destination" if unset.
    destination: Option<Address>,
    /// Required; `build` fails with "missing balance" if unset.
    balance: Option<HoprBalance>,
    #[default(0)]
    ticket_index: u64,
    #[default(ChannelStatus::Open)]
    status: ChannelStatus,
    #[default(1)]
    channel_epoch: u32,
}
impl ChannelBuilder {
    /// Upper bound on a funding amount: `10^25`.
    ///
    /// NOTE(review): not enforced by [`ChannelBuilder::build`]; presumably checked by
    /// callers — confirm.
    pub const MAX_FUNDING_AMOUNT: u128 = 10_u128.pow(25);
    /// Largest balance a channel may hold: `2^96 - 1`.
    ///
    /// [`ChannelBuilder::build`] rejects balances above this value.
    pub const MAX_CHANNEL_STAKE: u128 = (1 << 96) - 1;

    /// Sets the source end of the channel.
    #[must_use]
    pub fn source<A: Into<Address>>(mut self, source: A) -> Self {
        self.source = Some(source.into());
        self
    }

    /// Sets the destination end of the channel.
    #[must_use]
    pub fn destination<A: Into<Address>>(mut self, destination: A) -> Self {
        self.destination = Some(destination.into());
        self
    }

    /// Sets both endpoints at once: the channel runs from `source` to `destination`.
    #[must_use]
    pub fn between<A: Into<Address>, B: Into<Address>>(
        mut self,
        source: A,
        destination: B,
    ) -> Self {
        self.source = Some(source.into());
        self.destination = Some(destination.into());
        self
    }

    /// Sets the balance from a raw amount, converted with `HoprBalance::from`.
    #[must_use]
    pub fn amount<A: Into<U256>>(mut self, amount: A) -> Self {
        self.balance = Some(HoprBalance::from(amount));
        self
    }

    /// Sets the balance from an existing [`HoprBalance`].
    #[must_use]
    pub fn balance(mut self, balance: HoprBalance) -> Self {
        self.balance = Some(balance);
        self
    }

    /// Sets the ticket index (default `0`).
    #[must_use]
    pub fn ticket_index(mut self, ticket_index: u64) -> Self {
        self.ticket_index = ticket_index;
        self
    }

    /// Sets the channel status (default [`ChannelStatus::Open`]).
    #[must_use]
    pub fn status(mut self, status: ChannelStatus) -> Self {
        self.status = status;
        self
    }

    /// Sets the channel epoch (default `1`).
    #[must_use]
    pub fn epoch(mut self, channel_epoch: u32) -> Self {
        self.channel_epoch = channel_epoch;
        self
    }

    /// Validates the collected values and produces a [`ChannelEntry`].
    ///
    /// The channel ID is derived from the source and destination via `generate_channel_id`.
    ///
    /// # Errors
    /// Returns `CoreTypesError::InvalidInputData`, checked in this order, when:
    /// - the source, destination or balance has not been set,
    /// - the source and destination are the same address,
    /// - the balance exceeds [`Self::MAX_CHANNEL_STAKE`],
    /// - the ticket index exceeds `TicketBuilder::MAX_TICKET_INDEX`,
    /// - the epoch exceeds `TicketBuilder::MAX_CHANNEL_EPOCH`.
    pub fn build(self) -> crate::internal::errors::Result<ChannelEntry> {
        // Errors are constructed lazily so the success path performs no allocations.
        let invalid = |reason: &'static str| CoreTypesError::InvalidInputData(reason.into());

        let source = self.source.ok_or_else(|| invalid("missing source"))?;
        let destination = self
            .destination
            .ok_or_else(|| invalid("missing destination"))?;
        let balance = self.balance.ok_or_else(|| invalid("missing balance"))?;

        if source == destination {
            return Err(invalid("source and destination cannot be the same"));
        }
        if balance > Self::MAX_CHANNEL_STAKE.into() {
            return Err(invalid("balance too high"));
        }
        if self.ticket_index > TicketBuilder::MAX_TICKET_INDEX {
            return Err(invalid("ticket index too high"));
        }
        if self.channel_epoch > TicketBuilder::MAX_CHANNEL_EPOCH {
            return Err(invalid("channel epoch too high"));
        }

        Ok(ChannelEntry {
            source,
            destination,
            balance,
            ticket_index: self.ticket_index,
            status: self.status,
            channel_epoch: self.channel_epoch,
            id: generate_channel_id(&source, &destination),
        })
    }
}
/// Snapshot of a channel running from `source` to `destination`.
///
/// Use [`ChannelBuilder`] to create validated instances. Derived equality compares every
/// field; `status` is compared with `ChannelStatus`'s hand-written `PartialEq`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ChannelEntry {
    /// Source end of the channel (`ChannelDirection::Outgoing` from its point of view).
    pub source: Address,
    /// Destination end of the channel (`ChannelDirection::Incoming` from its point of view).
    pub destination: Address,
    /// Channel balance; the builder rejects values above `ChannelBuilder::MAX_CHANNEL_STAKE`.
    pub balance: HoprBalance,
    /// Ticket index; the builder rejects values above `TicketBuilder::MAX_TICKET_INDEX`.
    pub ticket_index: u64,
    /// Current lifecycle status.
    pub status: ChannelStatus,
    /// Channel epoch; the builder rejects values above `TicketBuilder::MAX_CHANNEL_EPOCH`.
    pub channel_epoch: u32,
    /// ID derived from `source` and `destination` when the entry is constructed.
    /// It is not updated if those public fields are modified afterwards.
    id: ChannelId,
}
impl ChannelEntry {
    /// Starts building a new entry; equivalent to `ChannelBuilder::default()`.
    #[must_use]
    pub fn builder() -> ChannelBuilder {
        ChannelBuilder::default()
    }

    /// Creates an entry directly from its parts, deriving the ID from `source` and
    /// `destination`.
    ///
    /// Unlike [`ChannelBuilder::build`], this performs no validation of the inputs.
    #[deprecated(since = "1.3.0", note = "use ChannelBuilder instead")]
    pub fn new(
        source: Address,
        destination: Address,
        balance: HoprBalance,
        ticket_index: u64,
        status: ChannelStatus,
        channel_epoch: u32,
    ) -> Self {
        ChannelEntry {
            source,
            destination,
            balance,
            ticket_index,
            status,
            channel_epoch,
            id: generate_channel_id(&source, &destination),
        }
    }

    /// Returns the channel ID computed when this entry was constructed.
    pub fn get_id(&self) -> &ChannelId {
        &self.id
    }

    /// Returns `true` if the closure time has been reached at `current_time`.
    ///
    /// Always `true` for a closed channel and `false` for an open one.
    pub fn closure_time_passed(&self, current_time: SystemTime) -> bool {
        self.status.closure_time_elapsed(&current_time)
    }

    /// Returns how much time remains until the closure time, measured from `current_time`.
    ///
    /// - `None` for an open channel,
    /// - `Some(Duration::ZERO)` for a closed channel,
    /// - for a channel pending closure, the remaining time, which is zero once the closure
    ///   time has passed.
    pub fn remaining_closure_time(&self, current_time: SystemTime) -> Option<Duration> {
        match self.status {
            ChannelStatus::Open => None,
            ChannelStatus::PendingToClose(closure_time) => {
                Some(closure_time.saturating_sub(current_time))
            }
            ChannelStatus::Closed => Some(Duration::ZERO),
        }
    }

    /// Returns the closure time if the channel is pending closure, otherwise `None`.
    pub fn closure_time_at(&self) -> Option<SystemTime> {
        match self.status {
            ChannelStatus::PendingToClose(closure_time) => Some(closure_time),
            ChannelStatus::Open | ChannelStatus::Closed => None,
        }
    }

    /// Returns the direction of this channel as seen from `me`, or `None` if `me` is neither
    /// endpoint.
    ///
    /// If `me` is both endpoints, `Outgoing` is reported. The builder rejects such channels.
    pub fn direction(&self, me: &Address) -> Option<ChannelDirection> {
        self.orientation(me).map(|(direction, _)| direction)
    }

    /// Like [`ChannelEntry::direction`], but also returns the address of the other endpoint.
    pub fn orientation(&self, me: &Address) -> Option<(ChannelDirection, Address)> {
        if self.source.eq(me) {
            Some((ChannelDirection::Outgoing, self.destination))
        } else if self.destination.eq(me) {
            Some((ChannelDirection::Incoming, self.source))
        } else {
            None
        }
    }

    /// Lists the fields that differ between `self` and `other`; see
    /// [`ChannelChange::diff_channels`].
    ///
    /// # Panics
    /// Panics if `other` has a different channel ID.
    pub fn diff(&self, other: &Self) -> Vec<ChannelChange> {
        ChannelChange::diff_channels(self, other)
    }
}
impl std::hash::Hash for ChannelEntry {
    /// Hashes only the channel ID and the epoch.
    ///
    /// Entries that compare equal share both fields, so this is consistent with `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let Self {
            id, channel_epoch, ..
        } = self;
        std::hash::Hash::hash(id, state);
        std::hash::Hash::hash(channel_epoch, state);
    }
}
impl Display for ChannelEntry {
    /// Renders as `<status> channel <id>`, e.g. `Open channel <id>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let (status, id) = (&self.status, self.get_id());
        write!(f, "{status} channel {id}")
    }
}
/// Derives the [`ChannelId`] of the channel running from `source` to `destination`.
///
/// The inputs are passed to `Hash::create` in the order source, then destination, so the
/// argument order is significant.
pub fn generate_channel_id(source: &Address, destination: &Address) -> Hash {
    Hash::create(&[source.as_ref(), destination.as_ref()])
}
/// One field that differs between two entries of the same channel, as reported by
/// `ChannelChange::diff_channels`.
///
/// In every variant, `left` holds the value from the first entry and `right` the value from
/// the second. `Display` renders e.g. `balance change: <left> -> <right>`.
#[derive(Clone, Copy, Debug, strum::Display)]
pub enum ChannelChange {
    /// The channel status differs.
    #[strum(to_string = "status change: {left} -> {right}")]
    Status {
        left: ChannelStatus,
        right: ChannelStatus,
    },
    /// The channel balance differs.
    #[strum(to_string = "balance change: {left} -> {right}")]
    Balance {
        left: HoprBalance,
        right: HoprBalance,
    },
    /// The channel epoch differs.
    #[strum(to_string = "epoch change: {left} -> {right}")]
    Epoch { left: u32, right: u32 },
    /// The ticket index differs.
    #[strum(to_string = "ticket index change: {left} -> {right}")]
    TicketIndex { left: u64, right: u64 },
}
impl ChannelChange {
    /// Compares two snapshots of the same channel and lists every tracked field that differs.
    ///
    /// The result is ordered status, balance, epoch, ticket index, and is empty when the
    /// entries are equal. Status is compared with `ChannelStatus`'s `PartialEq`, so closure
    /// times less than one second apart do not count as a change.
    ///
    /// # Panics
    /// Panics if `left` and `right` do not have the same channel ID.
    pub fn diff_channels(left: &ChannelEntry, right: &ChannelEntry) -> Vec<Self> {
        assert_eq!(left.id, right.id, "must have equal ids");

        // Fast path: identical snapshots need no per-field inspection.
        if left == right {
            return Vec::new();
        }

        let candidates = [
            (left.status != right.status).then_some(ChannelChange::Status {
                left: left.status,
                right: right.status,
            }),
            (left.balance != right.balance).then_some(ChannelChange::Balance {
                left: left.balance,
                right: right.balance,
            }),
            (left.channel_epoch != right.channel_epoch).then_some(ChannelChange::Epoch {
                left: left.channel_epoch,
                right: right.channel_epoch,
            }),
            (left.ticket_index != right.ticket_index).then_some(ChannelChange::TicketIndex {
                left: left.ticket_index,
                right: right.ticket_index,
            }),
        ];

        candidates.iter().flatten().copied().collect()
    }
}
/// Wraps the [`ChannelId`] of a channel entry flagged as corrupted.
///
/// NOTE(review): what counts as "corrupted" is decided by the code that creates this value,
/// not here.
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CorruptedChannelEntry(ChannelId);
impl From<ChannelId> for CorruptedChannelEntry {
fn from(value: ChannelId) -> Self {
CorruptedChannelEntry(value)
}
}
impl CorruptedChannelEntry {
pub fn channel_id(&self) -> &ChannelId {
&self.0
}
}
/// Source and destination addresses of a channel, in that order.
///
/// NOTE(review): overlaps with `ChannelParties`; consider consolidating the two types.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SrcDstPair(Address, Address);
impl From<ChannelEntry> for SrcDstPair {
fn from(channel: ChannelEntry) -> Self {
SrcDstPair(channel.source, channel.destination)
}
}
#[cfg(test)]
mod tests {
    use std::{
        ops::Add,
        str::FromStr,
        time::{Duration, SystemTime},
    };

    use super::*;

    lazy_static::lazy_static! {
        static ref ADDRESS_1: Address = "3829b806aea42200c623c4d6b9311670577480ed".parse().expect("lazy static address should be constructible");
        static ref ADDRESS_2: Address = "1a34729c69e95d6e11c3a9b9be3ea0c62c6dc5b1".parse().expect("lazy static address should be constructible");
    }

    #[test]
    fn test_generate_id() -> anyhow::Result<()> {
        let from = Address::from_str("0xa460f2e47c641b64535f5f4beeb9ac6f36f9d27c")?;
        let to = Address::from_str("0xb8b75fef7efdf4530cf1688c933d94e4e519ccd1")?;
        let id = generate_channel_id(&from, &to).to_string();
        assert_eq!(
            "0x1a410210ce7265f3070bf0e8885705dce452efcfbd90a5467525d136fcefc64a",
            id
        );
        Ok(())
    }

    #[test]
    fn channel_status_names() {
        assert_eq!("Open", ChannelStatus::Open.to_string());
        assert_eq!("Closed", ChannelStatus::Closed.to_string());
        assert_eq!(
            "PendingToClose",
            ChannelStatus::PendingToClose(SystemTime::now()).to_string()
        );
    }

    #[test]
    fn channel_status_repr_compat() {
        assert_eq!(
            ChannelStatusDiscriminants::Open as i8,
            i8::from(ChannelStatus::Open)
        );
        assert_eq!(
            ChannelStatusDiscriminants::Closed as i8,
            i8::from(ChannelStatus::Closed)
        );
        assert_eq!(
            ChannelStatusDiscriminants::PendingToClose as i8,
            i8::from(ChannelStatus::PendingToClose(SystemTime::now()))
        );
    }

    #[test]
    fn channel_status_equality_ignores_subsecond_differences() {
        let base = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000);
        assert_eq!(
            ChannelStatus::PendingToClose(base),
            ChannelStatus::PendingToClose(base + Duration::from_millis(500))
        );
        assert_ne!(
            ChannelStatus::PendingToClose(base),
            ChannelStatus::PendingToClose(base + Duration::from_secs(2))
        );
        assert_ne!(ChannelStatus::Open, ChannelStatus::Closed);
        assert_ne!(ChannelStatus::Open, ChannelStatus::PendingToClose(base));
    }

    #[test]
    fn channel_builder_should_reject_invalid_values() -> anyhow::Result<()> {
        let builder = ChannelBuilder::default()
            .source(Address::from_str(
                "0x1234567890123456789012345678901234567890",
            )?)
            .destination(Address::from_str(
                "0xb8b75fef7efdf4530cf1688c933d94e4e519ccd1",
            )?)
            .amount(ChannelBuilder::MAX_CHANNEL_STAKE)
            .ticket_index(TicketBuilder::MAX_TICKET_INDEX)
            .status(ChannelStatus::Open)
            .epoch(TicketBuilder::MAX_CHANNEL_EPOCH);
        assert!(builder.build().is_ok());
        let builder = builder.destination(Address::from_str(
            "0x1234567890123456789012345678901234567890",
        )?);
        assert!(builder.build().is_err());
        let builder = builder
            .destination(Address::from_str(
                "0xb8b75fef7efdf4530cf1688c933d94e4e519ccd1",
            )?)
            .ticket_index(TicketBuilder::MAX_TICKET_INDEX + 1);
        assert!(builder.build().is_err());
        let builder = builder
            .ticket_index(TicketBuilder::MAX_TICKET_INDEX)
            .amount(ChannelBuilder::MAX_CHANNEL_STAKE + 1);
        assert!(builder.build().is_err());
        let builder = builder
            .amount(TicketBuilder::MAX_TICKET_AMOUNT)
            .epoch(TicketBuilder::MAX_CHANNEL_EPOCH + 1);
        assert!(builder.build().is_err());
        Ok(())
    }

    #[test]
    fn channel_entry_closure_time() -> anyhow::Result<()> {
        let mut ce = ChannelBuilder::default()
            .source(*ADDRESS_1)
            .destination(*ADDRESS_2)
            .amount(10)
            .ticket_index(23)
            .status(ChannelStatus::Open)
            .epoch(3)
            .build()?;
        assert!(
            !ce.closure_time_passed(SystemTime::now()),
            "opened channel cannot pass closure time"
        );
        assert!(
            ce.remaining_closure_time(SystemTime::now()).is_none(),
            "opened channel cannot have remaining closure time"
        );
        let current_time = SystemTime::now();
        ce.status = ChannelStatus::PendingToClose(current_time.add(Duration::from_secs(60)));
        assert!(
            !ce.closure_time_passed(current_time),
            "must not have passed closure time"
        );
        assert_eq!(
            60,
            ce.remaining_closure_time(current_time)
                .expect("must have closure time")
                .as_secs()
        );
        let current_time = current_time.add(Duration::from_secs(120));
        assert!(
            ce.closure_time_passed(current_time),
            "must have passed closure time"
        );
        assert_eq!(
            Duration::ZERO,
            ce.remaining_closure_time(current_time)
                .expect("must have closure time")
        );
        Ok(())
    }

    #[test]
    fn channel_entry_direction_and_orientation() -> anyhow::Result<()> {
        let ce = ChannelBuilder::default()
            .between(*ADDRESS_1, *ADDRESS_2)
            .amount(10)
            .build()?;
        assert_eq!(Some(ChannelDirection::Outgoing), ce.direction(&ADDRESS_1));
        assert_eq!(Some(ChannelDirection::Incoming), ce.direction(&ADDRESS_2));
        assert_eq!(
            Some((ChannelDirection::Outgoing, *ADDRESS_2)),
            ce.orientation(&ADDRESS_1)
        );
        assert_eq!(
            Some((ChannelDirection::Incoming, *ADDRESS_1)),
            ce.orientation(&ADDRESS_2)
        );
        Ok(())
    }

    #[test]
    fn channel_diff_reports_changed_fields_in_order() -> anyhow::Result<()> {
        let ce1 = ChannelBuilder::default()
            .between(*ADDRESS_1, *ADDRESS_2)
            .amount(10)
            .ticket_index(1)
            .status(ChannelStatus::Open)
            .epoch(1)
            .build()?;
        let ce2 = ChannelBuilder::default()
            .between(*ADDRESS_1, *ADDRESS_2)
            .amount(20)
            .ticket_index(5)
            .status(ChannelStatus::Open)
            .epoch(1)
            .build()?;

        assert!(ce1.diff(&ce1).is_empty(), "identical entries have no changes");

        let changes = ce1.diff(&ce2);
        assert_eq!(2, changes.len());
        assert!(matches!(changes[0], ChannelChange::Balance { .. }));
        assert!(matches!(
            changes[1],
            ChannelChange::TicketIndex { left: 1, right: 5 }
        ));
        Ok(())
    }
}