use std::collections::{hash_map::Entry, HashMap};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use async_trait::async_trait;
use bitcoin::constants::ChainHash;
use bitcoin::hashes::{sha256::Hash as Sha256, Hash};
use bitcoin::secp256k1::PublicKey;
use bitcoin::{Network, ScriptBuf, TxOut};
use lightning::ln::chan_utils::make_funding_redeemscript;
use lightning::ln::features::{ChannelFeatures, NodeFeatures};
use lightning::ln::msgs::{
    ErrorAction, LightningError as LdkError, UnsignedChannelAnnouncement, UnsignedChannelUpdate,
};
use lightning::ln::{PaymentHash, PaymentPreimage};
use lightning::routing::gossip::{NetworkGraph, NodeId};
use lightning::routing::router::{find_route, Path, PaymentParameters, Route, RouteParameters};
use lightning::routing::scoring::ProbabilisticScorer;
use lightning::routing::utxo::{UtxoLookup, UtxoResult};
use lightning::util::logger::{Level, Logger, Record};
use thiserror::Error;
use tokio::select;
use tokio::sync::oneshot::{channel, Receiver, Sender};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use triggered::{Listener, Trigger};

use crate::ShortChannelID;
use crate::{
    LightningError, LightningNode, NodeInfo, PaymentOutcome, PaymentResult, SimulationError,
};
#[derive(Debug, Error)]
pub enum ForwardingError {
#[error("ZeroAmountHtlc")]
ZeroAmountHtlc,
#[error("ChannelNotFound: {0}")]
ChannelNotFound(ShortChannelID),
#[error("NodeNotFound: {0:?}")]
NodeNotFound(PublicKey),
#[error("PaymentHashExists: {0:?}")]
PaymentHashExists(PaymentHash),
#[error("PaymentHashNotFound: {0:?}")]
PaymentHashNotFound(PaymentHash),
#[error("InsufficientBalance: amount: {0} > balance: {1}")]
InsufficientBalance(u64, u64),
#[error("LessThanMinimum: amount: {0} < minimum: {1}")]
LessThanMinimum(u64, u64),
#[error("MoreThanMaximum: amount: {0} > maximum: {1}")]
MoreThanMaximum(u64, u64),
#[error("ExceedsInFlightCount: total in flight: {0} > maximum count: {1}")]
ExceedsInFlightCount(u64, u64),
#[error("ExceedsInFlightTotal: total in flight amount: {0} > maximum amount: {0}")]
ExceedsInFlightTotal(u64, u64),
#[error("ExpiryInSeconds: cltv expressed in seconds: {0}")]
ExpiryInSeconds(u32, u32),
#[error("InsufficientCltvDelta: cltv delta: {0} < required: {1}")]
InsufficientCltvDelta(u32, u32),
#[error("InsufficientFee: offered fee: {0} (base: {1}, prop: {2}) < expected: {3}")]
InsufficientFee(u64, u64, u64, u64),
#[error("FeeOverflow: htlc amount: {0} (base: {1}, prop: {2})")]
FeeOverflow(u64, u64, u64),
#[error("SanityCheckFailed: node balance: {0} != capacity: {1}")]
SanityCheckFailed(u64, u64),
}
impl ForwardingError {
    /// Reports whether this error indicates a bug or inconsistency in the simulation itself,
    /// as opposed to an ordinary policy/balance failure that a real node could also hit.
    ///
    /// Payment propagation uses this to decide whether to shut the simulation down. Every
    /// variant is listed explicitly so that adding a new one forces a classification decision.
    fn is_critical(&self) -> bool {
        match self {
            ForwardingError::ZeroAmountHtlc
            | ForwardingError::ChannelNotFound(_)
            | ForwardingError::NodeNotFound(_)
            | ForwardingError::PaymentHashExists(_)
            | ForwardingError::PaymentHashNotFound(_)
            | ForwardingError::SanityCheckFailed(_, _)
            | ForwardingError::FeeOverflow(_, _, _) => true,
            ForwardingError::InsufficientBalance(_, _)
            | ForwardingError::LessThanMinimum(_, _)
            | ForwardingError::MoreThanMaximum(_, _)
            | ForwardingError::ExceedsInFlightCount(_, _)
            | ForwardingError::ExceedsInFlightTotal(_, _)
            | ForwardingError::ExpiryInSeconds(_, _)
            | ForwardingError::InsufficientCltvDelta(_, _)
            | ForwardingError::InsufficientFee(_, _, _, _) => false,
        }
    }
}
/// An HTLC held in one side's in-flight set of a simulated channel.
#[derive(Copy, Clone)]
struct Htlc {
/// Amount locked in the HTLC, in millisatoshis.
amount_msat: u64,
/// CLTV expiry carried by the HTLC. In this simulation `add_htlcs` sets it to the sum of the
/// remaining hops' CLTV deltas; values large enough to look like a timestamp are rejected by
/// `ChannelState::check_outgoing_addition`.
cltv_expiry: u32,
}
/// Forwarding policy that one node enforces on its side of a simulated channel.
#[derive(Clone)]
pub struct ChannelPolicy {
/// The node this policy belongs to; used to pick the matching side of the channel.
pub pubkey: PublicKey,
/// Maximum number of outgoing HTLCs this side may have in flight at once.
pub max_htlc_count: u64,
/// Maximum total amount (msat) of outgoing HTLCs this side may have in flight at once.
pub max_in_flight_msat: u64,
/// Smallest outgoing HTLC amount (msat) this side will add.
pub min_htlc_size_msat: u64,
/// Largest outgoing HTLC amount (msat) this side will add.
pub max_htlc_size_msat: u64,
/// Minimum CLTV delta this node requires when forwarding onto this channel.
pub cltv_expiry_delta: u32,
/// Base forwarding fee, in msat.
pub base_fee: u64,
/// Proportional forwarding fee, in parts per million of the forwarded amount.
pub fee_rate_prop: u64,
}
impl ChannelPolicy {
    /// Checks that this policy is compatible with a channel of `capacity_msat`.
    ///
    /// # Errors
    /// `SimulatedNetworkError` if either `max_in_flight_msat` or `max_htlc_size_msat` exceeds the
    /// channel capacity (the in-flight limit is checked first).
    fn validate(&self, capacity_msat: u64) -> Result<(), SimulationError> {
        if self.max_in_flight_msat > capacity_msat {
            Err(SimulationError::SimulatedNetworkError(format!(
                "max_in_flight_msat {} > capacity {}",
                self.max_in_flight_msat, capacity_msat
            )))
        } else if self.max_htlc_size_msat > capacity_msat {
            Err(SimulationError::SimulatedNetworkError(format!(
                "max_htlc_size_msat {} > capacity {}",
                self.max_htlc_size_msat, capacity_msat
            )))
        } else {
            Ok(())
        }
    }
}
/// Returns early from the enclosing function with
/// `Err(ForwardingError::$error_variant($value_1, $value_2, $opt...))` when
/// `$value_1 $op $value_2` holds.
///
/// Both value expressions appear twice in the expansion (once in the comparison, once in the
/// error), so they should be cheap and free of side effects.
macro_rules! fail_forwarding_inequality {
($value_1:expr, $op:tt, $value_2:expr, $error_variant:ident $(, $opt:expr)*) => {
if $value_1 $op $value_2 {
return Err(ForwardingError::$error_variant(
$value_1,
$value_2
$(
, $opt
)*
));
}
};
}
/// One side's view of a simulated channel: its spendable balance, the outgoing HTLCs it has
/// offered that are still pending, and the policy it enforces.
#[derive(Clone)]
struct ChannelState {
/// Balance (msat) available for new outgoing HTLCs; amounts of in-flight HTLCs are not included.
local_balance_msat: u64,
/// Outgoing HTLCs offered by this side that have not been removed yet, keyed by payment hash.
in_flight: HashMap<PaymentHash, Htlc>,
/// Policy used to validate HTLCs this side adds and forwards.
policy: ChannelPolicy,
}
impl ChannelState {
    /// Creates the state for one side of a channel with `policy` and a starting
    /// `local_balance_msat`, with no HTLCs in flight.
    fn new(policy: ChannelPolicy, local_balance_msat: u64) -> Self {
        ChannelState {
            local_balance_msat,
            in_flight: HashMap::new(),
            policy,
        }
    }

    /// Returns the total amount (msat) of all outgoing HTLCs currently in flight on this side.
    fn in_flight_total(&self) -> u64 {
        self.in_flight.values().map(|h| h.amount_msat).sum()
    }

    /// Checks whether this side's policy allows forwarding an HTLC of `amt` msat onto this
    /// channel, given the CLTV delta `cltv_delta` and fee `fee` (msat) offered for the forward.
    ///
    /// The expected fee is `base_fee + amt * fee_rate_prop / 1_000_000`.
    ///
    /// # Errors
    /// - `InsufficientCltvDelta` if `cltv_delta` is below the policy's `cltv_expiry_delta`.
    /// - `FeeOverflow` if computing the expected fee overflows `u64`.
    /// - `InsufficientFee` if `fee` is below the expected fee.
    fn check_htlc_forward(
        &self,
        cltv_delta: u32,
        amt: u64,
        fee: u64,
    ) -> Result<(), ForwardingError> {
        fail_forwarding_inequality!(cltv_delta, <, self.policy.cltv_expiry_delta, InsufficientCltvDelta);

        let expected_fee = amt
            .checked_mul(self.policy.fee_rate_prop)
            .and_then(|prop_fee| (prop_fee / 1_000_000).checked_add(self.policy.base_fee))
            .ok_or(ForwardingError::FeeOverflow(
                amt,
                self.policy.base_fee,
                self.policy.fee_rate_prop,
            ))?;

        fail_forwarding_inequality!(
            fee, <, expected_fee, InsufficientFee, self.policy.base_fee, self.policy.fee_rate_prop
        );

        Ok(())
    }

    /// Checks whether `htlc` may be added as a new outgoing HTLC on this side.
    ///
    /// The checks run in this order: maximum size, minimum size, in-flight count, in-flight
    /// total, local balance, and finally that the CLTV expiry is a block height rather than a
    /// timestamp. The first failing check is returned.
    fn check_outgoing_addition(&self, htlc: &Htlc) -> Result<(), ForwardingError> {
        // Per BOLT #2, a cltv_expiry greater than or equal to this value would be interpreted as
        // a unix timestamp instead of a block height and must be rejected.
        const CLTV_TIMESTAMP_THRESHOLD: u32 = 500_000_000;

        fail_forwarding_inequality!(htlc.amount_msat, >, self.policy.max_htlc_size_msat, MoreThanMaximum);
        fail_forwarding_inequality!(htlc.amount_msat, <, self.policy.min_htlc_size_msat, LessThanMinimum);
        fail_forwarding_inequality!(
            self.in_flight.len() as u64 + 1, >, self.policy.max_htlc_count, ExceedsInFlightCount
        );
        // Saturate instead of wrapping/panicking so an absurdly large amount is reported as
        // exceeding the in-flight limit rather than slipping past it.
        fail_forwarding_inequality!(
            self.in_flight_total().saturating_add(htlc.amount_msat), >, self.policy.max_in_flight_msat, ExceedsInFlightTotal
        );
        fail_forwarding_inequality!(htlc.amount_msat, >, self.local_balance_msat, InsufficientBalance);
        fail_forwarding_inequality!(htlc.cltv_expiry, >=, CLTV_TIMESTAMP_THRESHOLD, ExpiryInSeconds);

        Ok(())
    }

    /// Validates `htlc` and, if accepted, moves its amount from the local balance into the
    /// in-flight set under `hash`.
    ///
    /// # Errors
    /// Any error from `check_outgoing_addition`, or `PaymentHashExists` if an HTLC with the same
    /// hash is already in flight on this side.
    fn add_outgoing_htlc(&mut self, hash: PaymentHash, htlc: Htlc) -> Result<(), ForwardingError> {
        self.check_outgoing_addition(&htlc)?;
        if self.in_flight.contains_key(&hash) {
            return Err(ForwardingError::PaymentHashExists(hash));
        }

        // Cannot underflow: check_outgoing_addition rejected amounts above the local balance.
        self.local_balance_msat -= htlc.amount_msat;
        self.in_flight.insert(hash, htlc);
        Ok(())
    }

    /// Removes and returns the in-flight HTLC for `hash` without touching any balance; the
    /// caller settles the amount with `settle_outgoing_htlc`/`settle_incoming_htlc`.
    ///
    /// # Errors
    /// `PaymentHashNotFound` if no HTLC with that hash is in flight on this side.
    fn remove_outgoing_htlc(&mut self, hash: &PaymentHash) -> Result<Htlc, ForwardingError> {
        self.in_flight
            .remove(hash)
            .ok_or(ForwardingError::PaymentHashNotFound(*hash))
    }

    /// Settles an outgoing HTLC of `amt` on this side: on failure the amount, which was deducted
    /// when the HTLC was added, is refunded; on success it stays with the counterparty.
    fn settle_outgoing_htlc(&mut self, amt: u64, success: bool) {
        if !success {
            self.local_balance_msat += amt
        }
    }

    /// Settles an HTLC of `amt` offered by the counterparty: on success this side is credited.
    fn settle_incoming_htlc(&mut self, amt: u64, success: bool) {
        if success {
            self.local_balance_msat += amt
        }
    }
}
/// A payment channel between two simulated nodes, tracking the state of each side.
#[derive(Clone)]
pub struct SimulatedChannel {
/// Total capacity in msat. Both sides' balances plus their in-flight HTLC amounts must always
/// sum to this value (checked by `sanity_check`).
capacity_msat: u64,
/// Identifier of the channel within the simulated graph.
short_channel_id: ShortChannelID,
/// State of the side whose policy pubkey is the first node.
node_1: ChannelState,
/// State of the side whose policy pubkey is the second node.
node_2: ChannelState,
}
impl SimulatedChannel {
pub fn new(
capacity_msat: u64,
short_channel_id: ShortChannelID,
node_1: ChannelPolicy,
node_2: ChannelPolicy,
) -> Self {
SimulatedChannel {
capacity_msat,
short_channel_id,
node_1: ChannelState::new(node_1, capacity_msat / 2),
node_2: ChannelState::new(node_2, capacity_msat / 2),
}
}
fn validate(&self) -> Result<(), SimulationError> {
if self.node_1.policy.pubkey == self.node_2.policy.pubkey {
return Err(SimulationError::SimulatedNetworkError(format!(
"Channel should have distinct node pubkeys, got: {} for both nodes.",
self.node_1.policy.pubkey
)));
}
self.node_1.policy.validate(self.capacity_msat)?;
self.node_2.policy.validate(self.capacity_msat)?;
Ok(())
}
fn get_node_mut(&mut self, pubkey: &PublicKey) -> Result<&mut ChannelState, ForwardingError> {
if pubkey == &self.node_1.policy.pubkey {
Ok(&mut self.node_1)
} else if pubkey == &self.node_2.policy.pubkey {
Ok(&mut self.node_2)
} else {
Err(ForwardingError::NodeNotFound(*pubkey))
}
}
fn get_node(&self, pubkey: &PublicKey) -> Result<&ChannelState, ForwardingError> {
if pubkey == &self.node_1.policy.pubkey {
Ok(&self.node_1)
} else if pubkey == &self.node_2.policy.pubkey {
Ok(&self.node_2)
} else {
Err(ForwardingError::NodeNotFound(*pubkey))
}
}
fn add_htlc(
&mut self,
sending_node: &PublicKey,
hash: PaymentHash,
htlc: Htlc,
) -> Result<(), ForwardingError> {
if htlc.amount_msat == 0 {
return Err(ForwardingError::ZeroAmountHtlc);
}
self.get_node_mut(sending_node)?
.add_outgoing_htlc(hash, htlc)?;
self.sanity_check()
}
fn sanity_check(&self) -> Result<(), ForwardingError> {
let node_1_total = self.node_1.local_balance_msat + self.node_1.in_flight_total();
let node_2_total = self.node_2.local_balance_msat + self.node_2.in_flight_total();
fail_forwarding_inequality!(node_1_total + node_2_total, !=, self.capacity_msat, SanityCheckFailed);
Ok(())
}
fn remove_htlc(
&mut self,
sending_node: &PublicKey,
hash: &PaymentHash,
success: bool,
) -> Result<(), ForwardingError> {
let htlc = self
.get_node_mut(sending_node)?
.remove_outgoing_htlc(hash)?;
self.settle_htlc(sending_node, htlc.amount_msat, success)?;
self.sanity_check()
}
fn settle_htlc(
&mut self,
sending_node: &PublicKey,
amount_msat: u64,
success: bool,
) -> Result<(), ForwardingError> {
if sending_node == &self.node_1.policy.pubkey {
self.node_1.settle_outgoing_htlc(amount_msat, success);
self.node_2.settle_incoming_htlc(amount_msat, success);
Ok(())
} else if sending_node == &self.node_2.policy.pubkey {
self.node_2.settle_outgoing_htlc(amount_msat, success);
self.node_1.settle_incoming_htlc(amount_msat, success);
Ok(())
} else {
Err(ForwardingError::NodeNotFound(*sending_node))
}
}
fn check_htlc_forward(
&self,
forwarding_node: &PublicKey,
cltv_delta: u32,
amount_msat: u64,
fee_msat: u64,
) -> Result<(), ForwardingError> {
self.get_node(forwarding_node)?
.check_htlc_forward(cltv_delta, amount_msat, fee_msat)
}
}
/// The simulated payment network that a `SimNode` dispatches payments to and queries for node
/// information. Implemented by `SimGraph` (and mocked in this module's tests).
#[async_trait]
trait SimNetwork: Send + Sync {
/// Starts sending a payment along `route` from `source`; the outcome is reported through
/// `sender` rather than returned.
fn dispatch_payment(
&mut self,
source: PublicKey,
route: Route,
payment_hash: PaymentHash,
sender: Sender<Result<PaymentResult, LightningError>>,
);
/// Looks up `node`, returning its `NodeInfo` and a list of `u64`s. NOTE(review): in `SimGraph`'s
/// implementation these are the capacities (msat) of the node's channels.
async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError>;
}
/// A `LightningNode` implementation backed by a simulated payment network.
struct SimNode<'a, T: SimNetwork> {
/// Node information returned by `get_info`.
info: NodeInfo,
/// Shared simulated network used to dispatch payments and look up nodes.
network: Arc<Mutex<T>>,
/// Result receivers for payments sent but not yet tracked, keyed by payment hash.
in_flight: HashMap<PaymentHash, Receiver<Result<PaymentResult, LightningError>>>,
/// LDK network graph used for pathfinding when sending payments.
pathfinding_graph: Arc<NetworkGraph<&'a WrappedLog>>,
}
impl<'a, T: SimNetwork> SimNode<'a, T> {
    /// Creates a simulated node for `pubkey` that sends payments over `payment_network` and
    /// finds routes in `pathfinding_graph`. No payments are in flight initially.
    pub fn new(
        pubkey: PublicKey,
        payment_network: Arc<Mutex<T>>,
        pathfinding_graph: Arc<NetworkGraph<&'a WrappedLog>>,
    ) -> Self {
        let info = node_info(pubkey);
        Self {
            info,
            pathfinding_graph,
            network: payment_network,
            in_flight: HashMap::new(),
        }
    }
}
/// Builds the `NodeInfo` reported for a simulated node: empty alias, keysend (optional) as the
/// only advertised feature.
fn node_info(pubkey: PublicKey) -> NodeInfo {
    let features = {
        let mut node_features = NodeFeatures::empty();
        node_features.set_keysend_optional();
        node_features
    };

    NodeInfo {
        pubkey,
        alias: String::new(),
        features,
    }
}
/// Finds a single-path route from `source` to `dest` for `amount_msat` in `pathfinding_graph`
/// using LDK's router with a probabilistic scorer.
///
/// # Errors
/// `SimulatedNetworkError` carrying LDK's error message if no route is found.
fn find_payment_route(
    source: &PublicKey,
    dest: PublicKey,
    amount_msat: u64,
    pathfinding_graph: &NetworkGraph<&WrappedLog>,
) -> Result<Route, SimulationError> {
    // Final CLTV delta of 0, no cap on the total CLTV delta, a single path (no MPP), and a
    // channel saturation limit expressed as a power of 1/2 (1 => half of a channel's capacity).
    let payment_params = PaymentParameters::from_node_id(dest, 0)
        .with_max_total_cltv_expiry_delta(u32::MAX)
        .with_max_path_count(1)
        .with_max_channel_saturation_power_of_half(1);
    let route_params = RouteParameters {
        payment_params,
        final_value_msat: amount_msat,
        max_total_routing_fee_msat: None,
    };

    let scorer = ProbabilisticScorer::new(Default::default(), pathfinding_graph, &WrappedLog {});
    // Fixed seed bytes: any randomization LDK applies during pathfinding is deterministic.
    let random_seed_bytes = [0; 32];

    find_route(
        source,
        &route_params,
        pathfinding_graph,
        None,
        &WrappedLog {},
        &scorer,
        &Default::default(),
        &random_seed_bytes,
    )
    .map_err(|e| SimulationError::SimulatedNetworkError(e.err))
}
#[async_trait]
impl<T: SimNetwork> LightningNode for SimNode<'_, T> {
fn get_info(&self) -> &NodeInfo {
&self.info
}
async fn get_network(&mut self) -> Result<Network, LightningError> {
Ok(Network::Regtest)
}
async fn send_payment(
&mut self,
dest: PublicKey,
amount_msat: u64,
) -> Result<PaymentHash, LightningError> {
let (sender, receiver) = channel();
let preimage = PaymentPreimage(rand::random());
let payment_hash = PaymentHash(Sha256::hash(&preimage.0).to_byte_array());
match self.in_flight.entry(payment_hash) {
Entry::Occupied(_) => {
return Err(LightningError::SendPaymentError(
"payment hash exists".to_string(),
));
},
Entry::Vacant(vacant) => {
vacant.insert(receiver);
},
}
let route = match find_payment_route(
&self.info.pubkey,
dest,
amount_msat,
&self.pathfinding_graph,
) {
Ok(path) => path,
Err(e) => {
log::trace!("Could not find path for payment: {:?}.", e);
if let Err(e) = sender.send(Ok(PaymentResult {
htlc_count: 0,
payment_outcome: PaymentOutcome::RouteNotFound,
})) {
log::error!("Could not send payment result: {:?}.", e);
}
return Ok(payment_hash);
},
};
self.network
.lock()
.await
.dispatch_payment(self.info.pubkey, route, payment_hash, sender);
Ok(payment_hash)
}
async fn track_payment(
&mut self,
hash: &PaymentHash,
listener: Listener,
) -> Result<PaymentResult, LightningError> {
match self.in_flight.remove(hash) {
Some(receiver) => {
select! {
biased;
_ = listener => Err(
LightningError::TrackPaymentError("shutdown during payment tracking".to_string()),
),
res = receiver => {
res.map_err(|e| LightningError::TrackPaymentError(format!("channel receive err: {}", e)))?
},
}
},
None => Err(LightningError::TrackPaymentError(format!(
"payment hash {} not found",
hex::encode(hash.0),
))),
}
}
async fn get_node_info(&mut self, node_id: &PublicKey) -> Result<NodeInfo, LightningError> {
Ok(self.network.lock().await.lookup_node(node_id).await?.0)
}
async fn list_channels(&mut self) -> Result<Vec<u64>, LightningError> {
Ok(self
.network
.lock()
.await
.lookup_node(&self.info.pubkey)
.await?
.1)
}
}
/// The simulated payment network: channel state shared with payment tasks, plus a node index.
pub struct SimGraph {
/// For each node pubkey, the capacities (msat) of the channels it is a party to.
nodes: HashMap<PublicKey, Vec<u64>>,
/// All simulated channels keyed by short channel id, shared with spawned payment tasks.
channels: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
/// Spawned payment propagation tasks, awaited by `wait_for_shutdown`.
tasks: JoinSet<()>,
/// Handed to payment tasks, which trigger it on critical forwarding errors.
shutdown_trigger: Trigger,
}
impl SimGraph {
pub fn new(
graph_channels: Vec<SimulatedChannel>,
shutdown_trigger: Trigger,
) -> Result<Self, SimulationError> {
let mut nodes: HashMap<PublicKey, Vec<u64>> = HashMap::new();
let mut channels = HashMap::new();
for channel in graph_channels.iter() {
channel.validate()?;
match channels.entry(channel.short_channel_id) {
Entry::Occupied(_) => {
return Err(SimulationError::SimulatedNetworkError(format!(
"Simulated short channel ID should be unique: {} duplicated",
channel.short_channel_id
)))
},
Entry::Vacant(v) => v.insert(channel.clone()),
};
for pubkey in [channel.node_1.policy.pubkey, channel.node_2.policy.pubkey] {
match nodes.entry(pubkey) {
Entry::Occupied(o) => o.into_mut().push(channel.capacity_msat),
Entry::Vacant(v) => {
v.insert(vec![channel.capacity_msat]);
},
}
}
}
Ok(SimGraph {
nodes,
channels: Arc::new(Mutex::new(channels)),
tasks: JoinSet::new(),
shutdown_trigger,
})
}
pub async fn wait_for_shutdown(&mut self) {
log::debug!("Waiting for simulated graph to shutdown.");
while let Some(res) = self.tasks.join_next().await {
if let Err(e) = res {
log::error!("Graph task exited with error: {e}");
}
}
log::debug!("Simulated graph shutdown.");
}
}
/// Creates a `SimNode` for every node in `graph`, all sharing the same simulated network and
/// pathfinding graph, keyed by pubkey.
pub async fn ln_node_from_graph<'a>(
    graph: Arc<Mutex<SimGraph>>,
    routing_graph: Arc<NetworkGraph<&'_ WrappedLog>>,
) -> HashMap<PublicKey, Arc<Mutex<dyn LightningNode + '_>>> {
    let sim_graph = graph.lock().await;
    let mut ln_nodes: HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>> = HashMap::new();

    for &pubkey in sim_graph.nodes.keys() {
        let node = SimNode::new(pubkey, Arc::clone(&graph), Arc::clone(&routing_graph));
        ln_nodes.insert(pubkey, Arc::new(Mutex::new(node)));
    }

    ln_nodes
}
pub fn populate_network_graph<'a>(
channels: Vec<SimulatedChannel>,
) -> Result<NetworkGraph<&'a WrappedLog>, LdkError> {
let graph = NetworkGraph::new(Network::Regtest, &WrappedLog {});
let chain_hash = ChainHash::using_genesis_block(Network::Regtest);
for channel in channels {
let announcement = UnsignedChannelAnnouncement {
features: ChannelFeatures::empty(),
chain_hash,
short_channel_id: channel.short_channel_id.into(),
node_id_1: NodeId::from_pubkey(&channel.node_1.policy.pubkey),
node_id_2: NodeId::from_pubkey(&channel.node_2.policy.pubkey),
bitcoin_key_1: NodeId::from_pubkey(&channel.node_1.policy.pubkey),
bitcoin_key_2: NodeId::from_pubkey(&channel.node_2.policy.pubkey),
excess_data: Vec::new(),
};
let utxo_validator = UtxoValidator {
amount_sat: channel.capacity_msat / 1000,
script: make_funding_redeemscript(
&channel.node_1.policy.pubkey,
&channel.node_2.policy.pubkey,
)
.to_v0_p2wsh(),
};
graph.update_channel_from_unsigned_announcement(&announcement, &Some(&utxo_validator))?;
for (i, node) in [channel.node_1, channel.node_2].iter().enumerate() {
let update = UnsignedChannelUpdate {
chain_hash,
short_channel_id: channel.short_channel_id.into(),
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as u32,
flags: i as u8,
cltv_expiry_delta: node.policy.cltv_expiry_delta as u16,
htlc_minimum_msat: node.policy.min_htlc_size_msat,
htlc_maximum_msat: node.policy.max_htlc_size_msat,
fee_base_msat: node.policy.base_fee as u32,
fee_proportional_millionths: node.policy.fee_rate_prop as u32,
excess_data: Vec::new(),
};
graph.update_channel_unsigned(&update)?;
}
}
Ok(graph)
}
#[async_trait]
impl SimNetwork for SimGraph {
fn dispatch_payment(
&mut self,
source: PublicKey,
route: Route,
payment_hash: PaymentHash,
sender: Sender<Result<PaymentResult, LightningError>>,
) {
let path = match route.paths.first() {
Some(p) => p,
None => {
log::warn!("Find route did not return expected number of paths.");
if let Err(e) = sender.send(Ok(PaymentResult {
htlc_count: 0,
payment_outcome: PaymentOutcome::RouteNotFound,
})) {
log::error!("Could not send payment result: {:?}.", e);
}
return;
},
};
self.tasks.spawn(propagate_payment(
self.channels.clone(),
source,
path.clone(),
payment_hash,
sender,
self.shutdown_trigger.clone(),
));
}
async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError> {
self.nodes
.get(node)
.map(|channels| (node_info(*node), channels.clone()))
.ok_or(LightningError::GetNodeInfoError(
"Node not found".to_string(),
))
}
}
/// Locks in the HTLCs for `route` hop by hop, starting from `source`.
///
/// For hop `i`, an HTLC is added on channel `i` from the previous node (or `source`). It carries
/// the amount and CLTV still owed to the rest of the path. For every hop except the last, hop
/// `i`'s forwarding policy on channel `i + 1` is also checked against the fee and CLTV delta it
/// is being offered.
///
/// On failure returns `(fail_idx, err)`. `fail_idx` is the index of the last hop whose HTLC was
/// successfully added, so the caller must unwind HTLCs `0..=fail_idx`. It is `None` if no HTLC
/// was added.
async fn add_htlcs(
nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
source: PublicKey,
route: Path,
payment_hash: PaymentHash,
) -> Result<(), (Option<usize>, ForwardingError)> {
let mut outgoing_node = source;
// The first channel carries everything: all forwarding fees plus the final value.
let mut outgoing_amount = route.fee_msat() + route.final_value_msat();
// The first HTLC's CLTV is the sum of every hop's delta; each hop's delta is subtracted below.
let mut outgoing_cltv = route.hops.iter().map(|hop| hop.cltv_expiry_delta).sum();
// Index of the last hop whose HTLC was added; None until the first add succeeds.
let mut fail_idx = None;
for (i, hop) in route.hops.iter().enumerate() {
// The channel map is locked once per hop, so other payment tasks may interleave between hops.
let mut node_lock = nodes.lock().await;
let scid = ShortChannelID::from(hop.short_channel_id);
if let Some(channel) = node_lock.get_mut(&scid) {
channel
.add_htlc(
&outgoing_node,
payment_hash,
Htlc {
amount_msat: outgoing_amount,
cltv_expiry: outgoing_cltv,
},
)
.map_err(|e| (fail_idx, e))?;
fail_idx = Some(i);
if i != route.hops.len() - 1 {
// hop.pubkey forwards onto the next channel, keeping hop.fee_msat; check its policy
// there. If the next channel is missing, this check is skipped and ChannelNotFound
// is returned on the next iteration instead.
if let Some(channel) =
node_lock.get(&ShortChannelID::from(route.hops[i + 1].short_channel_id))
{
channel
.check_htlc_forward(
&hop.pubkey,
hop.cltv_expiry_delta,
outgoing_amount - hop.fee_msat,
hop.fee_msat,
)
.map_err(|e| (fail_idx, e))?;
}
}
} else {
return Err((fail_idx, ForwardingError::ChannelNotFound(scid)));
}
// Advance to the next hop: it becomes the sender and owes this hop's fee and delta less.
outgoing_node = hop.pubkey;
outgoing_amount -= hop.fee_msat;
outgoing_cltv -= hop.cltv_expiry_delta;
}
Ok(())
}
/// Removes the HTLCs of `route` for hops `0..=resolution_idx`, last hop first. Each is settled
/// if `success` is true and failed back otherwise.
///
/// The HTLC on hop `i`'s channel was offered by the previous node on the path (or by `source`
/// for the first hop), so that node is passed as the sending side.
///
/// # Errors
/// `ChannelNotFound` if a hop's channel is missing, or any error from
/// `SimulatedChannel::remove_htlc`. Processing stops at the first error.
///
/// # Panics
/// If `resolution_idx` is not a valid hop index.
async fn remove_htlcs(
    nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
    resolution_idx: usize,
    source: PublicKey,
    route: Path,
    payment_hash: PaymentHash,
    success: bool,
) -> Result<(), ForwardingError> {
    let resolved_hops = &route.hops[..=resolution_idx];

    for (idx, hop) in resolved_hops.iter().enumerate().rev() {
        let sender_node = match idx.checked_sub(1) {
            Some(prev) => route.hops[prev].pubkey,
            None => source,
        };
        let scid = ShortChannelID::from(hop.short_channel_id);

        let mut channels = nodes.lock().await;
        let channel = channels
            .get_mut(&scid)
            .ok_or(ForwardingError::ChannelNotFound(scid))?;
        channel.remove_htlc(&sender_node, &payment_hash, success)?;
    }

    Ok(())
}
/// Runs one simulated payment along `route` from `source` and reports the outcome on `sender`.
///
/// HTLCs are first added hop by hop. If that fails, any HTLCs already added are failed back and
/// the payment is reported with `PaymentOutcome::Unknown`. Otherwise every HTLC on the path is
/// settled and the payment is reported as `PaymentOutcome::Success`.
///
/// A critical `ForwardingError` (per `ForwardingError::is_critical`) fires `shutdown`. Errors
/// while unwinding or settling HTLCs are logged on both paths.
async fn propagate_payment(
    nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
    source: PublicKey,
    route: Path,
    payment_hash: PaymentHash,
    sender: Sender<Result<PaymentResult, LightningError>>,
    shutdown: Trigger,
) {
    let notify_result = match add_htlcs(Arc::clone(&nodes), source, route.clone(), payment_hash)
        .await
    {
        Err((fail_idx, err)) => {
            if err.is_critical() {
                shutdown.trigger();
            }

            // Fail back the HTLCs that were locked in before the failing hop, if any.
            if let Some(resolution_idx) = fail_idx {
                if let Err(e) =
                    remove_htlcs(nodes, resolution_idx, source, route, payment_hash, false).await
                {
                    if e.is_critical() {
                        shutdown.trigger();
                    }
                    // Log non-critical failures too, matching the success path.
                    log::error!("Could not remove htlcs from channel: {e}.");
                }
            }

            log::debug!(
                "Forwarding failure for simulated payment {}: {err}",
                hex::encode(payment_hash.0)
            );
            PaymentResult {
                htlc_count: 0,
                payment_outcome: PaymentOutcome::Unknown,
            }
        },
        Ok(()) => {
            // Settle every hop, last one first. Assumes the path has at least one hop;
            // otherwise `len() - 1` underflows.
            if let Err(e) = remove_htlcs(
                nodes,
                route.hops.len() - 1,
                source,
                route,
                payment_hash,
                true,
            )
            .await
            {
                if e.is_critical() {
                    shutdown.trigger();
                }
                log::error!("Could not remove htlcs from channel: {e}.");
            }
            PaymentResult {
                htlc_count: 1,
                payment_outcome: PaymentOutcome::Success,
            }
        },
    };

    if let Err(e) = sender.send(Ok(notify_result)) {
        log::error!("Could not notify payment result: {:?}.", e);
    }
}
/// Adapter that forwards LDK log records to the `log` crate.
pub struct WrappedLog {}

impl Logger for WrappedLog {
    /// Maps LDK levels onto `log` levels. Gossip and trace both go to `trace!`, and debug and
    /// info both go to `debug!`, which keeps LDK's info output out of the simulator's info logs.
    fn log(&self, record: Record) {
        match record.level {
            Level::Gossip | Level::Trace => log::trace!("{}", record.args),
            Level::Debug | Level::Info => log::debug!("{}", record.args),
            Level::Warn => log::warn!("{}", record.args),
            Level::Error => log::error!("{}", record.args),
        }
    }
}
/// Stub UTXO source used when announcing a simulated channel to the LDK graph. For any short
/// channel id, it reports a funding output with the configured amount and script.
struct UtxoValidator {
    /// Funding output value, in satoshis.
    amount_sat: u64,
    /// Funding output script.
    script: ScriptBuf,
}

impl UtxoLookup for UtxoValidator {
    /// Always answers synchronously with the configured output, ignoring the lookup arguments.
    fn get_utxo(&self, _genesis_hash: &ChainHash, _short_channel_id: u64) -> UtxoResult {
        let funding_output = TxOut {
            script_pubkey: self.script.clone(),
            value: self.amount_sat,
        };
        UtxoResult::Sync(Ok(funding_output))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::get_random_keypair;
use bitcoin::secp256k1::PublicKey;
use lightning::routing::router::Route;
use mockall::mock;
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;
fn create_test_policy(max_in_flight_msat: u64) -> ChannelPolicy {
let (_, pk) = get_random_keypair();
ChannelPolicy {
pubkey: pk,
max_htlc_count: 10,
max_in_flight_msat,
min_htlc_size_msat: 2,
max_htlc_size_msat: max_in_flight_msat / 2,
cltv_expiry_delta: 10,
base_fee: 1000,
fee_rate_prop: 5000,
}
}
fn create_simulated_channels(n: u64, capacity_msat: u64) -> Vec<SimulatedChannel> {
let mut channels: Vec<SimulatedChannel> = vec![];
let (_, first_node) = get_random_keypair();
let mut node_1 = first_node;
for i in 0..n {
let (_, node_2) = get_random_keypair();
let node_1_to_2 = ChannelPolicy {
pubkey: node_1,
max_htlc_count: 483,
max_in_flight_msat: capacity_msat / 2,
min_htlc_size_msat: 1,
max_htlc_size_msat: capacity_msat / 2,
cltv_expiry_delta: 40,
base_fee: 1000 * i,
fee_rate_prop: 1500 * i,
};
let node_2_to_1 = ChannelPolicy {
pubkey: node_2,
max_htlc_count: 483,
max_in_flight_msat: capacity_msat / 2,
min_htlc_size_msat: 1,
max_htlc_size_msat: capacity_msat / 2,
cltv_expiry_delta: 40 + 10 * i as u32,
base_fee: 2000 * i,
fee_rate_prop: i,
};
channels.push(SimulatedChannel {
capacity_msat,
short_channel_id: ShortChannelID::from(i),
node_1: ChannelState::new(node_1_to_2, capacity_msat),
node_2: ChannelState::new(node_2_to_1, 0),
});
node_1 = node_2;
}
channels
}
macro_rules! assert_channel_balances {
($channel_state:expr, $local_balance:expr, $in_flight_len:expr, $in_flight_total:expr) => {
assert_eq!($channel_state.local_balance_msat, $local_balance);
assert_eq!($channel_state.in_flight.len(), $in_flight_len);
assert_eq!($channel_state.in_flight_total(), $in_flight_total);
};
}
#[test]
fn test_channel_state_transitions() {
let local_balance = 100_000_000;
let mut channel_state =
ChannelState::new(create_test_policy(local_balance / 2), local_balance);
assert_channel_balances!(channel_state, local_balance, 0, 0);
let hash_1 = PaymentHash([1; 32]);
let htlc_1 = Htlc {
amount_msat: 1000,
cltv_expiry: 40,
};
assert!(channel_state.add_outgoing_htlc(hash_1, htlc_1).is_ok());
assert_channel_balances!(
channel_state,
local_balance - htlc_1.amount_msat,
1,
htlc_1.amount_msat
);
assert!(matches!(
channel_state.add_outgoing_htlc(hash_1, htlc_1),
Err(ForwardingError::PaymentHashExists(_))
));
let hash_2 = PaymentHash([2; 32]);
let htlc_2 = Htlc {
amount_msat: 1000,
cltv_expiry: 40,
};
assert!(channel_state.add_outgoing_htlc(hash_2, htlc_2).is_ok());
assert_channel_balances!(
channel_state,
local_balance - htlc_1.amount_msat - htlc_2.amount_msat,
2,
htlc_1.amount_msat + htlc_2.amount_msat
);
assert!(channel_state.remove_outgoing_htlc(&hash_2).is_ok());
channel_state.settle_outgoing_htlc(htlc_2.amount_msat, false);
assert_channel_balances!(
channel_state,
local_balance - htlc_1.amount_msat,
1,
htlc_1.amount_msat
);
assert!(matches!(
channel_state.remove_outgoing_htlc(&hash_2),
Err(ForwardingError::PaymentHashNotFound(_))
));
assert!(channel_state.remove_outgoing_htlc(&hash_1).is_ok());
channel_state.settle_outgoing_htlc(htlc_1.amount_msat, true);
assert_channel_balances!(channel_state, local_balance - htlc_1.amount_msat, 0, 0);
}
#[test]
fn test_htlc_forward() {
let local_balance = 140_000;
let channel_state = ChannelState::new(create_test_policy(local_balance / 2), local_balance);
assert!(matches!(
channel_state.check_htlc_forward(channel_state.policy.cltv_expiry_delta - 1, 0, 0),
Err(ForwardingError::InsufficientCltvDelta(_, _))
));
let htlc_amount = 1000;
let htlc_fee = channel_state.policy.base_fee
+ (channel_state.policy.fee_rate_prop * htlc_amount) / 1e6 as u64;
assert!(matches!(
channel_state.check_htlc_forward(
channel_state.policy.cltv_expiry_delta,
htlc_amount,
htlc_fee - 1
),
Err(ForwardingError::InsufficientFee(_, _, _, _))
));
assert!(channel_state
.check_htlc_forward(
channel_state.policy.cltv_expiry_delta,
htlc_amount,
htlc_fee,
)
.is_ok());
assert!(channel_state
.check_htlc_forward(
channel_state.policy.cltv_expiry_delta * 2,
htlc_amount,
htlc_fee * 3
)
.is_ok());
}
#[test]
fn test_check_outgoing_addition() {
let local_balance = 100_000;
let mut channel_state =
ChannelState::new(create_test_policy(local_balance / 2), local_balance);
let mut htlc = Htlc {
amount_msat: channel_state.policy.max_htlc_size_msat + 1,
cltv_expiry: channel_state.policy.cltv_expiry_delta,
};
assert!(matches!(
channel_state.check_outgoing_addition(&htlc),
Err(ForwardingError::MoreThanMaximum(_, _))
));
htlc.amount_msat = channel_state.policy.min_htlc_size_msat - 1;
assert!(matches!(
channel_state.check_outgoing_addition(&htlc),
Err(ForwardingError::LessThanMinimum(_, _))
));
let hash_1 = PaymentHash([1; 32]);
let htlc_1 = Htlc {
amount_msat: channel_state.policy.max_in_flight_msat / 2,
cltv_expiry: channel_state.policy.cltv_expiry_delta,
};
assert!(channel_state.check_outgoing_addition(&htlc_1).is_ok());
assert!(channel_state.add_outgoing_htlc(hash_1, htlc_1).is_ok());
let hash_2 = PaymentHash([2; 32]);
let htlc_2 = Htlc {
amount_msat: channel_state.policy.max_in_flight_msat / 2,
cltv_expiry: channel_state.policy.cltv_expiry_delta,
};
assert!(channel_state.check_outgoing_addition(&htlc_2).is_ok());
assert!(channel_state.add_outgoing_htlc(hash_2, htlc_2).is_ok());
htlc.amount_msat = channel_state.policy.min_htlc_size_msat;
assert!(matches!(
channel_state.check_outgoing_addition(&htlc),
Err(ForwardingError::ExceedsInFlightTotal(_, _))
));
assert!(channel_state.remove_outgoing_htlc(&hash_1).is_ok());
channel_state.settle_outgoing_htlc(htlc_1.amount_msat, true);
assert!(channel_state.remove_outgoing_htlc(&hash_2).is_ok());
channel_state.settle_outgoing_htlc(htlc_2.amount_msat, true);
for i in 0..channel_state.policy.max_htlc_count {
let hash = PaymentHash([i.try_into().unwrap(); 32]);
assert!(channel_state.check_outgoing_addition(&htlc).is_ok());
assert!(channel_state.add_outgoing_htlc(hash, htlc).is_ok());
}
let htlc_3 = Htlc {
amount_msat: channel_state.policy.min_htlc_size_msat,
cltv_expiry: channel_state.policy.cltv_expiry_delta,
};
assert!(matches!(
channel_state.check_outgoing_addition(&htlc_3),
Err(ForwardingError::ExceedsInFlightCount(_, _))
));
for i in 0..channel_state.policy.max_htlc_count {
let hash = PaymentHash([i.try_into().unwrap(); 32]);
assert!(channel_state.remove_outgoing_htlc(&hash).is_ok());
channel_state.settle_outgoing_htlc(htlc.amount_msat, true)
}
let hash_4 = PaymentHash([1; 32]);
let htlc_4 = Htlc {
amount_msat: channel_state.policy.max_htlc_size_msat,
cltv_expiry: channel_state.policy.cltv_expiry_delta,
};
assert!(channel_state.check_outgoing_addition(&htlc_4).is_ok());
assert!(channel_state.add_outgoing_htlc(hash_4, htlc_4).is_ok());
assert!(channel_state.remove_outgoing_htlc(&hash_4).is_ok());
channel_state.settle_outgoing_htlc(htlc_4.amount_msat, true);
assert!(channel_state.local_balance_msat < channel_state.policy.max_htlc_size_msat);
assert!(matches!(
channel_state.check_outgoing_addition(&htlc_4),
Err(ForwardingError::InsufficientBalance(_, _))
));
}
#[test]
fn test_simulated_channel() {
let capacity_msat = 500_000_000;
let node_1 = ChannelState::new(create_test_policy(capacity_msat / 2), capacity_msat);
let node_2 = ChannelState::new(create_test_policy(capacity_msat / 2), 0);
let mut simulated_channel = SimulatedChannel {
capacity_msat,
short_channel_id: ShortChannelID::from(123),
node_1: node_1.clone(),
node_2: node_2.clone(),
};
let hash_1 = PaymentHash([1; 32]);
let htlc_1 = Htlc {
amount_msat: node_2.policy.min_htlc_size_msat,
cltv_expiry: node_1.policy.cltv_expiry_delta,
};
assert!(matches!(
simulated_channel.add_htlc(&node_2.policy.pubkey, hash_1, htlc_1),
Err(ForwardingError::InsufficientBalance(_, _))
));
let hash_2 = PaymentHash([1; 32]);
let htlc_2 = Htlc {
amount_msat: node_1.policy.max_htlc_size_msat,
cltv_expiry: node_2.policy.cltv_expiry_delta,
};
assert!(simulated_channel
.add_htlc(&node_1.policy.pubkey, hash_2, htlc_2)
.is_ok());
assert!(simulated_channel
.remove_htlc(&node_1.policy.pubkey, &hash_2, true)
.is_ok());
assert!(simulated_channel
.add_htlc(&node_2.policy.pubkey, hash_2, htlc_2)
.is_ok());
let (_, pk) = get_random_keypair();
assert!(matches!(
simulated_channel.add_htlc(&pk, hash_2, htlc_2),
Err(ForwardingError::NodeNotFound(_))
));
assert!(matches!(
simulated_channel.remove_htlc(&pk, &hash_2, true),
Err(ForwardingError::NodeNotFound(_))
));
}
mock! {
Network{}
#[async_trait]
impl SimNetwork for Network{
fn dispatch_payment(
&mut self,
source: PublicKey,
route: Route,
payment_hash: PaymentHash,
sender: Sender<Result<PaymentResult, LightningError>>,
);
async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError>;
}
}
#[tokio::test]
async fn test_simulated_node() {
let mock = MockNetwork::new();
let sim_network = Arc::new(Mutex::new(mock));
let channels = create_simulated_channels(5, 300000000);
let graph = populate_network_graph(channels.clone()).unwrap();
let pk = channels[0].node_1.policy.pubkey;
let mut node = SimNode::new(pk, sim_network.clone(), Arc::new(graph));
let lookup_pk = channels[3].node_1.policy.pubkey;
sim_network
.lock()
.await
.expect_lookup_node()
.returning(move |_| Ok((node_info(lookup_pk), vec![1, 2, 3])));
let node_info = node.get_node_info(&lookup_pk).await.unwrap();
assert_eq!(lookup_pk, node_info.pubkey);
assert_eq!(node.list_channels().await.unwrap().len(), 3);
let dest_1 = channels[2].node_1.policy.pubkey;
let dest_2 = channels[4].node_1.policy.pubkey;
sim_network
.lock()
.await
.expect_dispatch_payment()
.returning(
move |_, route: Route, _, sender: Sender<Result<PaymentResult, LightningError>>| {
let receiver = route.paths[0].hops.last().unwrap().pubkey;
let result = if receiver == dest_1 {
PaymentResult {
htlc_count: 2,
payment_outcome: PaymentOutcome::Success,
}
} else if receiver == dest_2 {
PaymentResult {
htlc_count: 0,
payment_outcome: PaymentOutcome::InsufficientBalance,
}
} else {
panic!("unknown mocked receiver");
};
sender.send(Ok(result)).unwrap();
},
);
let hash_1 = node.send_payment(dest_1, 10_000).await.unwrap();
let hash_2 = node.send_payment(dest_2, 15_000).await.unwrap();
let (_, shutdown_listener) = triggered::trigger();
let result_1 = node
.track_payment(&hash_1, shutdown_listener.clone())
.await
.unwrap();
assert!(matches!(result_1.payment_outcome, PaymentOutcome::Success));
let result_2 = node
.track_payment(&hash_2, shutdown_listener.clone())
.await
.unwrap();
assert!(matches!(
result_2.payment_outcome,
PaymentOutcome::InsufficientBalance
));
}
/// Test harness for exercising `SimGraph::dispatch_payment` over a small chain of simulated
/// channels.
///
/// `nodes` holds the `node_1` pubkey of every channel followed by the `node_2` pubkey of the last
/// channel (presumably the channels are chained node_2 -> next node_1 by
/// `create_simulated_channels` — TODO confirm against that helper).
struct DispatchPaymentTestKit<'a> {
    /// Simulated graph whose channel balances are changed by dispatched payments.
    graph: SimGraph,
    /// Node public keys, in the order described above.
    nodes: Vec<PublicKey>,
    /// LDK graph used for pathfinding. It is built once from the initial channels and is not
    /// updated by this kit when balances change.
    routing_graph: NetworkGraph<&'a WrappedLog>,
    /// Trigger handed to `SimGraph::new`; fire it to shut the simulated graph down.
    shutdown: triggered::Trigger,
}

impl<'a> DispatchPaymentTestKit<'a> {
    /// Creates a kit with three channels of `capacity` msat each and asserts that all liquidity
    /// starts on the `node_1` side of every channel.
    async fn new(capacity: u64) -> Self {
        let (shutdown, _listener) = triggered::trigger();
        let channels = create_simulated_channels(3, capacity);

        // One pubkey per channel's node_1, plus the last channel's node_2.
        let mut nodes = channels
            .iter()
            .map(|c| c.node_1.policy.pubkey)
            .collect::<Vec<PublicKey>>();
        nodes.push(
            channels
                .last()
                .expect("test network should contain at least one channel")
                .node_2
                .policy
                .pubkey,
        );

        let kit = DispatchPaymentTestKit {
            graph: SimGraph::new(channels.clone(), shutdown.clone())
                .expect("could not create test graph"),
            nodes,
            routing_graph: populate_network_graph(channels)
                .expect("could not populate routing graph"),
            shutdown,
        };

        assert_eq!(
            kit.channel_balances().await,
            vec![(capacity, 0), (capacity, 0), (capacity, 0)]
        );

        kit
    }

    /// Returns `(node_1, node_2)` local balances in msat for every channel, ordered by short
    /// channel id.
    ///
    /// # Panics
    /// Panics if the channel ids are not exactly `0..n` (assumed to be how
    /// `create_simulated_channels` numbers them — TODO confirm).
    async fn channel_balances(&self) -> Vec<(u64, u64)> {
        // Hold the lock once so the channel count and every balance come from one consistent
        // snapshot, instead of re-locking for each channel.
        let channels = self.graph.channels.lock().await;
        (0..channels.len() as u64)
            .map(|scid| {
                let channel = channels
                    .get(&ShortChannelID::from(scid))
                    .expect("channel ids should be contiguous from zero");
                (
                    channel.node_1.local_balance_msat,
                    channel.node_2.local_balance_msat,
                )
            })
            .collect()
    }

    /// Finds a route for `amt` msat from `source` to `dest` on the routing graph and dispatches it
    /// on the simulated graph.
    ///
    /// Asserts that the result receiver resolves within 10ms. The payment outcome itself is not
    /// inspected; callers verify effects through `channel_balances`. Returns the dispatched route
    /// so callers can compute expected fees.
    async fn send_test_payemnt(
        &mut self,
        source: PublicKey,
        dest: PublicKey,
        amt: u64,
    ) -> Route {
        let route = find_payment_route(&source, dest, amt, &self.routing_graph)
            .expect("could not find test payment route");
        let (sender, receiver) = oneshot::channel();
        self.graph
            .dispatch_payment(source, route.clone(), PaymentHash([1; 32]), sender);
        assert!(
            timeout(Duration::from_millis(10), receiver).await.is_ok(),
            "payment result receiver did not resolve within timeout"
        );
        route
    }

    /// Overwrites both local balances of channel `scid` and asserts that the channel still passes
    /// its sanity check.
    async fn set_channel_balance(&mut self, scid: &ShortChannelID, balance: (u64, u64)) {
        let mut channels = self.graph.channels.lock().await;
        let channel = channels
            .get_mut(scid)
            .expect("channel to update should exist");
        channel.node_1.local_balance_msat = balance.0;
        channel.node_2.local_balance_msat = balance.1;
        assert!(channel.sanity_check().is_ok());
    }
}
/// Runs a sequence of successful and failing dispatches over the alice -> bob -> carol -> dave
/// chain and checks channel balances after each one.
#[tokio::test]
async fn test_successful_dispatch() {
    let chan_capacity = 500_000_000;
    let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;

    // Multi-hop alice -> dave. Each channel moves the amount plus the fees still owed downstream.
    let mut amt = 20_000;
    let route = test_kit
        .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt)
        .await;

    let route_total = amt + route.get_total_fees();
    let hop_1_amt = amt + route.paths[0].hops[1].fee_msat;

    let alice_to_bob = (chan_capacity - route_total, route_total);
    let mut bob_to_carol = (chan_capacity - hop_1_amt, hop_1_amt);
    let carol_to_dave = (chan_capacity - amt, amt);

    let mut expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    // Dave only holds `amt` on his side, so sending twice that to bob fails and moves nothing.
    let _ = test_kit
        .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[1], amt * 2)
        .await;
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    // Two single-hop (fee-free) bob -> carol payments of half of bob's balance each.
    amt = bob_to_carol.0 / 2;
    let _ = test_kit
        .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt)
        .await;

    // Subtract the amount actually sent rather than halving again: if bob's balance is odd,
    // `x / 2` and `x - x / 2` differ by one msat.
    bob_to_carol = (bob_to_carol.0 - amt, bob_to_carol.1 + amt);
    expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    let _ = test_kit
        .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt)
        .await;

    // For an even starting balance this is exactly (0, chan_capacity). For an odd one, bob keeps
    // 1 msat.
    bob_to_carol = (bob_to_carol.0 - amt, bob_to_carol.1 + amt);
    expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    // Bob has (at most 1 msat) left towards carol, so another alice -> dave payment fails and
    // leaves every balance untouched.
    let _ = test_kit
        .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], 20_000)
        .await;
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    test_kit.shutdown.trigger();
    test_kit.graph.wait_for_shutdown().await;
}
/// Sends one payment across all three channels (alice -> dave). Checks that each channel moved
/// the payment amount plus the fees owed further along the route.
#[tokio::test]
async fn test_successful_multi_hop() {
    let chan_capacity = 500_000_000;
    let mut kit = DispatchPaymentTestKit::new(chan_capacity).await;

    let amount = 20_000;
    let (alice, dave) = (kit.nodes[0], kit.nodes[3]);
    let route = kit.send_test_payemnt(alice, dave, amount).await;

    // Amounts moved across each channel, from source to destination:
    // - first channel: amount plus every fee;
    // - second channel: amount plus the fee recorded on the second hop;
    // - last channel: just the amount.
    let first_channel = amount + route.get_total_fees();
    let second_channel = amount + route.paths[0].hops[1].fee_msat;
    let expected: Vec<(u64, u64)> = [first_channel, second_channel, amount]
        .iter()
        .map(|&moved| (chan_capacity - moved, moved))
        .collect();
    assert_eq!(kit.channel_balances().await, expected);

    kit.shutdown.trigger();
    kit.graph.wait_for_shutdown().await;
}
/// Checks a fee-free payment over a single channel. Then checks that a payment from a node with
/// no local balance fails without moving funds.
#[tokio::test]
async fn test_single_hop_payments() {
    let chan_capacity = 500_000_000;
    let mut kit = DispatchPaymentTestKit::new(chan_capacity).await;
    let amount = 150_000;

    // alice -> bob over the first channel: only that channel's balances shift.
    let (alice, bob) = (kit.nodes[0], kit.nodes[1]);
    let _ = kit.send_test_payemnt(alice, bob, amount).await;
    let mut expected = vec![(chan_capacity, 0); 3];
    expected[0] = (chan_capacity - amount, amount);
    assert_eq!(kit.channel_balances().await, expected);

    // dave -> carol: dave holds nothing on his side of the last channel, so nothing changes.
    let (carol, dave) = (kit.nodes[2], kit.nodes[3]);
    let _ = kit.send_test_payemnt(dave, carol, amount).await;
    assert_eq!(kit.channel_balances().await, expected);

    kit.shutdown.trigger();
    kit.graph.wait_for_shutdown().await;
}
/// Checks that multi-hop payments which run into a depleted intermediate channel fail in both
/// directions and leave every balance unchanged.
#[tokio::test]
async fn test_multi_hop_failure() {
    let chan_capacity = 500_000_000;
    let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;

    // Move all bob <-> carol liquidity to carol's side so bob cannot forward towards dave.
    test_kit
        .set_channel_balance(&ShortChannelID::from(1), (0, chan_capacity))
        .await;
    let mut expected_balances =
        vec![(chan_capacity, 0), (0, chan_capacity), (chan_capacity, 0)];
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    // The routing graph was built from the initial balances and still yields a path. Dispatch
    // fails at bob, so nothing moves.
    let amt = 150_000;
    let _ = test_kit
        .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt)
        .await;
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    // Give dave all of carol <-> dave and try the reverse direction. Carol can forward to bob,
    // but bob holds nothing towards alice, so this payment fails too.
    test_kit
        .set_channel_balance(&ShortChannelID::from(2), (0, chan_capacity))
        .await;
    expected_balances[2] = (0, chan_capacity);
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    let _ = test_kit
        .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[0], amt)
        .await;
    assert_eq!(test_kit.channel_balances().await, expected_balances);

    test_kit.shutdown.trigger();
    test_kit.graph.wait_for_shutdown().await;
}
}