1use crate::events::{Event, EventQueue};
13use crate::lsps0::ser::{
14 LSPSMessage, ProtocolMessageHandler, RequestId, ResponseError,
15 JSONRPC_INTERNAL_ERROR_ERROR_CODE, JSONRPC_INTERNAL_ERROR_ERROR_MESSAGE,
16 LSPS0_CLIENT_REJECTED_ERROR_CODE,
17};
18use crate::lsps2::event::LSPS2ServiceEvent;
19use crate::lsps2::payment_queue::{InterceptedHTLC, PaymentQueue};
20use crate::lsps2::utils::{
21 compute_opening_fee, is_expired_opening_fee_params, is_valid_opening_fee_params,
22};
23use crate::message_queue::MessageQueue;
24use crate::prelude::hash_map::Entry;
25use crate::prelude::{new_hash_map, HashMap, String, ToString, Vec};
26use crate::sync::{Arc, Mutex, MutexGuard, RwLock};
27
28use lightning::events::HTLCDestination;
29use lightning::ln::channelmanager::{AChannelManager, InterceptId};
30use lightning::ln::msgs::{ErrorAction, LightningError};
31use lightning::ln::types::ChannelId;
32use lightning::util::errors::APIError;
33use lightning::util::logger::Level;
34
35use lightning_types::payment::PaymentHash;
36
37use bitcoin::secp256k1::PublicKey;
38
39use core::ops::Deref;
40use core::sync::atomic::{AtomicUsize, Ordering};
41
42use crate::lsps2::msgs::{
43 BuyRequest, BuyResponse, GetInfoRequest, GetInfoResponse, LSPS2Message, LSPS2Request,
44 LSPS2Response, OpeningFeeParams, RawOpeningFeeParams,
45 LSPS2_BUY_REQUEST_INVALID_OPENING_FEE_PARAMS_ERROR_CODE,
46 LSPS2_BUY_REQUEST_PAYMENT_SIZE_TOO_LARGE_ERROR_CODE,
47 LSPS2_BUY_REQUEST_PAYMENT_SIZE_TOO_SMALL_ERROR_CODE,
48 LSPS2_GET_INFO_REQUEST_UNRECOGNIZED_OR_STALE_TOKEN_ERROR_CODE,
49};
50
/// Upper bound on pending requests tracked for a single peer.
// NOTE(review): the enforcement site is not in the visible part of this file — confirm.
const MAX_PENDING_REQUESTS_PER_PEER: usize = 10;
/// Upper bound on pending requests tracked across all peers.
// NOTE(review): presumably compared against `total_pending_requests`; the enforcement site is
// not in the visible part of this file — confirm.
const MAX_TOTAL_PENDING_REQUESTS: usize = 1000;
/// Upper bound on the number of peers with tracked state. Once the peer map has reached this
/// size, `get_or_insert_peer_state_entry!` refuses to create state for a previously unknown peer.
const MAX_TOTAL_PEERS: usize = 100000;
54
/// Configuration for the LSPS2 service handler.
#[derive(Clone, Debug)]
pub struct LSPS2ServiceConfig {
    /// 32-byte secret passed to `RawOpeningFeeParams::into_opening_fee_params` when the
    /// `get_info` response menu is built.
    // NOTE(review): presumably keys the `promise` attached to each fee-params entry — confirm
    // against `into_opening_fee_params`.
    pub promise_secret: [u8; 32],
}
63
/// Parameters carried by [`HTLCInterceptedAction::OpenChannel`] and surfaced to the user in the
/// `LSPS2ServiceEvent::OpenChannel` event.
#[derive(Clone, Debug, PartialEq)]
struct OpenChannelParams {
    /// Opening fee computed via `compute_opening_fee` for the expected payment size.
    opening_fee_msat: u64,
    /// Expected payment size minus the opening fee (saturating at zero).
    amt_to_forward_msat: u64,
}
71
/// A payment popped from the payment queue that is to be forwarded with the opening fee deducted.
#[derive(Clone, Debug, PartialEq)]
struct FeePayment {
    /// The intercepted HTLCs making up the payment.
    htlcs: Vec<InterceptedHTLC>,
    /// The opening fee; consumers pass it, together with `htlcs`, to
    /// `calculate_amount_to_forward_per_htlc` to derive the per-HTLC forward amounts.
    opening_fee_msat: u64,
}
78
79#[derive(Debug)]
80struct ChannelStateError(String);
81
82impl From<ChannelStateError> for LightningError {
83 fn from(value: ChannelStateError) -> Self {
84 LightningError { err: value.0, action: ErrorAction::IgnoreAndLog(Level::Info) }
85 }
86}
87
/// Follow-up the service handler must perform after an HTLC was fed into the channel state
/// machine.
#[derive(Debug, PartialEq)]
enum HTLCInterceptedAction {
    /// Emit an `LSPS2ServiceEvent::OpenChannel` event with the given parameters.
    OpenChannel(OpenChannelParams),
    /// Forward the just-intercepted HTLC unchanged over the given channel.
    ForwardHTLC(ChannelId),
    /// Forward the HTLCs of the given payment over the given channel, with amounts derived from
    /// the opening fee.
    ForwardPayment(ChannelId, FeePayment),
}
97
/// Instruction to forward the HTLCs of a fee-paying payment over the given channel.
/// Returned by the `channel_ready` and `htlc_handling_failed` state transitions.
#[derive(Debug, PartialEq)]
struct ForwardPaymentAction(ChannelId, FeePayment);
101
/// Instruction to forward the given queued HTLCs over the given channel at their expected
/// outbound amounts. Returned by the `payment_forwarded` state transition.
#[derive(Debug, PartialEq)]
struct ForwardHTLCsAction(ChannelId, Vec<InterceptedHTLC>);
105
/// State machine of an outbound JIT channel.
///
/// Every transition method returns the successor state (plus an optional action) instead of
/// mutating in place; the payment queue is shared between successive states through an `Arc`.
#[derive(Debug)]
enum OutboundJITChannelState {
    /// Initial state: intercepted HTLCs are queued until they add up to the expected payment
    /// size and leave a non-zero amount to forward after the opening fee.
    PendingInitialPayment { payment_queue: Arc<Mutex<PaymentQueue>> },
    /// An open-channel action was emitted; waiting for `channel_ready`. Newly intercepted HTLCs
    /// are only queued.
    PendingChannelOpen { payment_queue: Arc<Mutex<PaymentQueue>>, opening_fee_msat: u64 },
    /// A payment was handed out for forwarding over `channel_id`; waiting for either
    /// `payment_forwarded` or `htlc_handling_failed`. Newly intercepted HTLCs are only queued.
    PendingPaymentForward {
        payment_queue: Arc<Mutex<PaymentQueue>>,
        opening_fee_msat: u64,
        channel_id: ChannelId,
    },
    /// A forward attempt failed and `pop_greater_than_msat(opening_fee_msat)` yielded no queued
    /// payment; waiting for further HTLCs.
    PendingPayment {
        payment_queue: Arc<Mutex<PaymentQueue>>,
        opening_fee_msat: u64,
        channel_id: ChannelId,
    },
    /// A payment was forwarded; any further intercepted HTLC is forwarded directly over
    /// `channel_id`.
    PaymentForwarded { channel_id: ChannelId },
}
136
137impl OutboundJITChannelState {
138 fn new() -> Self {
139 OutboundJITChannelState::PendingInitialPayment {
140 payment_queue: Arc::new(Mutex::new(PaymentQueue::new())),
141 }
142 }
143
144 fn htlc_intercepted(
145 &mut self, opening_fee_params: &OpeningFeeParams, payment_size_msat: &Option<u64>,
146 htlc: InterceptedHTLC,
147 ) -> Result<(Self, Option<HTLCInterceptedAction>), ChannelStateError> {
148 match self {
149 OutboundJITChannelState::PendingInitialPayment { payment_queue } => {
150 let (total_expected_outbound_amount_msat, num_htlcs) =
151 payment_queue.lock().unwrap().add_htlc(htlc);
152
153 let (expected_payment_size_msat, mpp_mode) =
154 if let Some(payment_size_msat) = payment_size_msat {
155 (*payment_size_msat, true)
156 } else {
157 debug_assert_eq!(num_htlcs, 1);
158 if num_htlcs != 1 {
159 return Err(ChannelStateError(
160 "Paying via multiple HTLCs is disallowed in \"no-MPP+var-invoice\" mode.".to_string()
161 ));
162 }
163 (total_expected_outbound_amount_msat, false)
164 };
165
166 if expected_payment_size_msat < opening_fee_params.min_payment_size_msat
167 || expected_payment_size_msat > opening_fee_params.max_payment_size_msat
168 {
169 return Err(ChannelStateError(
170 format!("Payment size violates our limits: expected_payment_size_msat = {}, min_payment_size_msat = {}, max_payment_size_msat = {}",
171 expected_payment_size_msat,
172 opening_fee_params.min_payment_size_msat,
173 opening_fee_params.max_payment_size_msat
174 )));
175 }
176
177 let opening_fee_msat = compute_opening_fee(
178 expected_payment_size_msat,
179 opening_fee_params.min_fee_msat,
180 opening_fee_params.proportional.into(),
181 ).ok_or(ChannelStateError(
182 format!("Could not compute valid opening fee with min_fee_msat = {}, proportional = {}, and expected_payment_size_msat = {}",
183 opening_fee_params.min_fee_msat,
184 opening_fee_params.proportional,
185 expected_payment_size_msat
186 )
187 ))?;
188
189 let amt_to_forward_msat =
190 expected_payment_size_msat.saturating_sub(opening_fee_msat);
191
192 if total_expected_outbound_amount_msat >= expected_payment_size_msat
194 && amt_to_forward_msat > 0
195 {
196 let pending_channel_open = OutboundJITChannelState::PendingChannelOpen {
197 payment_queue: Arc::clone(&payment_queue),
198 opening_fee_msat,
199 };
200 let open_channel = HTLCInterceptedAction::OpenChannel(OpenChannelParams {
201 opening_fee_msat,
202 amt_to_forward_msat,
203 });
204 Ok((pending_channel_open, Some(open_channel)))
205 } else {
206 if mpp_mode {
207 let pending_initial_payment =
208 OutboundJITChannelState::PendingInitialPayment {
209 payment_queue: Arc::clone(&payment_queue),
210 };
211 Ok((pending_initial_payment, None))
212 } else {
213 Err(ChannelStateError(
214 "Intercepted HTLC is too small to pay opening fee".to_string(),
215 ))
216 }
217 }
218 },
219 OutboundJITChannelState::PendingChannelOpen { payment_queue, opening_fee_msat } => {
220 let mut payment_queue_lock = payment_queue.lock().unwrap();
221 payment_queue_lock.add_htlc(htlc);
222 let pending_channel_open = OutboundJITChannelState::PendingChannelOpen {
223 payment_queue: payment_queue.clone(),
224 opening_fee_msat: *opening_fee_msat,
225 };
226 Ok((pending_channel_open, None))
227 },
228 OutboundJITChannelState::PendingPaymentForward {
229 payment_queue,
230 opening_fee_msat,
231 channel_id,
232 } => {
233 let mut payment_queue_lock = payment_queue.lock().unwrap();
234 payment_queue_lock.add_htlc(htlc);
235 let pending_payment_forward = OutboundJITChannelState::PendingPaymentForward {
236 payment_queue: payment_queue.clone(),
237 opening_fee_msat: *opening_fee_msat,
238 channel_id: *channel_id,
239 };
240 Ok((pending_payment_forward, None))
241 },
242 OutboundJITChannelState::PendingPayment {
243 payment_queue,
244 opening_fee_msat,
245 channel_id,
246 } => {
247 let mut payment_queue_lock = payment_queue.lock().unwrap();
248 payment_queue_lock.add_htlc(htlc);
249 if let Some((_payment_hash, htlcs)) =
250 payment_queue_lock.pop_greater_than_msat(*opening_fee_msat)
251 {
252 let pending_payment_forward = OutboundJITChannelState::PendingPaymentForward {
253 payment_queue: payment_queue.clone(),
254 opening_fee_msat: *opening_fee_msat,
255 channel_id: *channel_id,
256 };
257 let forward_payment = HTLCInterceptedAction::ForwardPayment(
258 *channel_id,
259 FeePayment { htlcs, opening_fee_msat: *opening_fee_msat },
260 );
261 Ok((pending_payment_forward, Some(forward_payment)))
262 } else {
263 let pending_payment = OutboundJITChannelState::PendingPayment {
264 payment_queue: payment_queue.clone(),
265 opening_fee_msat: *opening_fee_msat,
266 channel_id: *channel_id,
267 };
268 Ok((pending_payment, None))
269 }
270 },
271 OutboundJITChannelState::PaymentForwarded { channel_id } => {
272 let payment_forwarded =
273 OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
274 let forward = HTLCInterceptedAction::ForwardHTLC(*channel_id);
275 Ok((payment_forwarded, Some(forward)))
276 },
277 }
278 }
279
280 fn channel_ready(
281 &self, channel_id: ChannelId,
282 ) -> Result<(Self, ForwardPaymentAction), ChannelStateError> {
283 match self {
284 OutboundJITChannelState::PendingChannelOpen { payment_queue, opening_fee_msat } => {
285 let mut payment_queue_lock = payment_queue.lock().unwrap();
286 if let Some((_payment_hash, htlcs)) =
287 payment_queue_lock.pop_greater_than_msat(*opening_fee_msat)
288 {
289 let pending_payment_forward = OutboundJITChannelState::PendingPaymentForward {
290 payment_queue: Arc::clone(&payment_queue),
291 opening_fee_msat: *opening_fee_msat,
292 channel_id,
293 };
294 let forward_payment = ForwardPaymentAction(
295 channel_id,
296 FeePayment { opening_fee_msat: *opening_fee_msat, htlcs },
297 );
298 Ok((pending_payment_forward, forward_payment))
299 } else {
300 Err(ChannelStateError(
301 "No forwardable payment available when moving to channel ready."
302 .to_string(),
303 ))
304 }
305 },
306 state => Err(ChannelStateError(format!(
307 "Channel ready received when JIT Channel was in state: {:?}",
308 state
309 ))),
310 }
311 }
312
313 fn htlc_handling_failed(
314 &mut self,
315 ) -> Result<(Self, Option<ForwardPaymentAction>), ChannelStateError> {
316 match self {
317 OutboundJITChannelState::PendingPaymentForward {
318 payment_queue,
319 opening_fee_msat,
320 channel_id,
321 } => {
322 let mut payment_queue_lock = payment_queue.lock().unwrap();
323 if let Some((_payment_hash, htlcs)) =
324 payment_queue_lock.pop_greater_than_msat(*opening_fee_msat)
325 {
326 let pending_payment_forward = OutboundJITChannelState::PendingPaymentForward {
327 payment_queue: payment_queue.clone(),
328 opening_fee_msat: *opening_fee_msat,
329 channel_id: *channel_id,
330 };
331 let forward_payment = ForwardPaymentAction(
332 *channel_id,
333 FeePayment { htlcs, opening_fee_msat: *opening_fee_msat },
334 );
335 Ok((pending_payment_forward, Some(forward_payment)))
336 } else {
337 let pending_payment = OutboundJITChannelState::PendingPayment {
338 payment_queue: payment_queue.clone(),
339 opening_fee_msat: *opening_fee_msat,
340 channel_id: *channel_id,
341 };
342 Ok((pending_payment, None))
343 }
344 },
345 OutboundJITChannelState::PendingPayment {
346 payment_queue,
347 opening_fee_msat,
348 channel_id,
349 } => {
350 let pending_payment = OutboundJITChannelState::PendingPayment {
351 payment_queue: payment_queue.clone(),
352 opening_fee_msat: *opening_fee_msat,
353 channel_id: *channel_id,
354 };
355 Ok((pending_payment, None))
356 },
357 OutboundJITChannelState::PaymentForwarded { channel_id } => {
358 let payment_forwarded =
359 OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
360 Ok((payment_forwarded, None))
361 },
362 state => Err(ChannelStateError(format!(
363 "HTLC handling failed when JIT Channel was in state: {:?}",
364 state
365 ))),
366 }
367 }
368
369 fn payment_forwarded(
370 &mut self,
371 ) -> Result<(Self, Option<ForwardHTLCsAction>), ChannelStateError> {
372 match self {
373 OutboundJITChannelState::PendingPaymentForward {
374 payment_queue, channel_id, ..
375 } => {
376 let mut payment_queue_lock = payment_queue.lock().unwrap();
377 let payment_forwarded =
378 OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
379 let forward_htlcs = ForwardHTLCsAction(*channel_id, payment_queue_lock.clear());
380 Ok((payment_forwarded, Some(forward_htlcs)))
381 },
382 OutboundJITChannelState::PaymentForwarded { channel_id } => {
383 let payment_forwarded =
384 OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
385 Ok((payment_forwarded, None))
386 },
387 state => Err(ChannelStateError(format!(
388 "Payment forwarded when JIT Channel was in state: {:?}",
389 state
390 ))),
391 }
392 }
393}
394
/// A JIT channel being opened to a peer, together with the parameters agreed in the `buy`
/// request it originates from.
struct OutboundJITChannel {
    /// Current position in the channel state machine.
    state: OutboundJITChannelState,
    /// Identifier supplied by the user in `invoice_parameters_generated`; echoed in the
    /// open-channel event and used to look the channel up again in `channel_ready`.
    user_channel_id: u128,
    /// Fee parameters taken from the `buy` request.
    opening_fee_params: OpeningFeeParams,
    /// Payment size taken from the `buy` request; `None` selects the no-MPP mode of the state
    /// machine.
    payment_size_msat: Option<u64>,
}
401
402impl OutboundJITChannel {
403 fn new(
404 payment_size_msat: Option<u64>, opening_fee_params: OpeningFeeParams, user_channel_id: u128,
405 ) -> Self {
406 Self {
407 user_channel_id,
408 state: OutboundJITChannelState::new(),
409 opening_fee_params,
410 payment_size_msat,
411 }
412 }
413
414 fn htlc_intercepted(
415 &mut self, htlc: InterceptedHTLC,
416 ) -> Result<Option<HTLCInterceptedAction>, LightningError> {
417 let (new_state, action) =
418 self.state.htlc_intercepted(&self.opening_fee_params, &self.payment_size_msat, htlc)?;
419 self.state = new_state;
420 Ok(action)
421 }
422
423 fn htlc_handling_failed(&mut self) -> Result<Option<ForwardPaymentAction>, LightningError> {
424 let (new_state, action) = self.state.htlc_handling_failed()?;
425 self.state = new_state;
426 Ok(action)
427 }
428
429 fn channel_ready(
430 &mut self, channel_id: ChannelId,
431 ) -> Result<ForwardPaymentAction, LightningError> {
432 let (new_state, action) = self.state.channel_ready(channel_id)?;
433 self.state = new_state;
434 Ok(action)
435 }
436
437 fn payment_forwarded(&mut self) -> Result<Option<ForwardHTLCsAction>, LightningError> {
438 let (new_state, action) = self.state.payment_forwarded()?;
439 self.state = new_state;
440 Ok(action)
441 }
442
443 fn is_pending_initial_payment(&self) -> bool {
444 matches!(self.state, OutboundJITChannelState::PendingInitialPayment { .. })
445 }
446
447 fn is_prunable(&self) -> bool {
448 let is_expired = is_expired_opening_fee_params(&self.opening_fee_params);
451 self.is_pending_initial_payment() && is_expired
452 }
453}
454
/// All state tracked for a single peer.
struct PeerState {
    /// JIT channels keyed by the intercept scid handed out in the `buy` response.
    outbound_channels_by_intercept_scid: HashMap<u64, OutboundJITChannel>,
    /// Lookup from the user-supplied channel id to the intercept scid; filled in
    /// `invoice_parameters_generated`.
    intercept_scid_by_user_channel_id: HashMap<u128, u64>,
    /// Lookup from the actual channel id to the intercept scid; filled in `channel_ready`.
    intercept_scid_by_channel_id: HashMap<ChannelId, u64>,
    /// Requests received from the peer that have not been answered yet.
    pending_requests: HashMap<RequestId, LSPS2Request>,
}
461
462impl PeerState {
463 fn new() -> Self {
464 let outbound_channels_by_intercept_scid = new_hash_map();
465 let pending_requests = new_hash_map();
466 let intercept_scid_by_user_channel_id = new_hash_map();
467 let intercept_scid_by_channel_id = new_hash_map();
468 Self {
469 outbound_channels_by_intercept_scid,
470 pending_requests,
471 intercept_scid_by_user_channel_id,
472 intercept_scid_by_channel_id,
473 }
474 }
475
476 fn insert_outbound_channel(&mut self, intercept_scid: u64, channel: OutboundJITChannel) {
477 self.outbound_channels_by_intercept_scid.insert(intercept_scid, channel);
478 }
479
480 fn prune_expired_request_state(&mut self) {
481 self.pending_requests.retain(|_, entry| {
482 match entry {
483 LSPS2Request::GetInfo(_) => false,
484 LSPS2Request::Buy(request) => {
485 !is_expired_opening_fee_params(&request.opening_fee_params)
487 },
488 }
489 });
490
491 self.outbound_channels_by_intercept_scid.retain(|intercept_scid, entry| {
492 if entry.is_prunable() {
493 self.intercept_scid_by_channel_id.retain(|_, iscid| intercept_scid != iscid);
495 self.intercept_scid_by_user_channel_id.retain(|_, iscid| intercept_scid != iscid);
496 return false;
497 }
498 true
499 });
500 }
501
502 fn pending_requests_and_channels(&self) -> usize {
503 let pending_requests = self.pending_requests.len();
504 let pending_outbound_channels = self
505 .outbound_channels_by_intercept_scid
506 .iter()
507 .filter(|(_, v)| v.is_pending_initial_payment())
508 .count();
509 pending_requests + pending_outbound_channels
510 }
511
512 fn is_prunable(&self) -> bool {
513 self.pending_requests.is_empty() && self.outbound_channels_by_intercept_scid.is_empty()
515 }
516}
517
// Looks up the `Mutex<PeerState>` for `$counterparty_node_id` in `$outer_state_lock`, creating
// an empty one if the peer is unknown. Evaluates to a mutable reference to the map entry.
//
// If the peer is unknown and the map already holds `MAX_TOTAL_PEERS` entries, no state is
// created. Instead the macro releases `$outer_state_lock`, enqueues an `LSPSMessage::Invalid`
// carrying a JSON-RPC internal error for the peer, and `return`s an `Err(LightningError)` from
// the *enclosing function* — so it may only be used in functions returning
// `Result<_, LightningError>`.
macro_rules! get_or_insert_peer_state_entry {
    ($self: ident, $outer_state_lock: expr, $counterparty_node_id: expr) => {{
        // Evaluated up front, before `entry()` takes a mutable borrow of the map.
        let is_limited_by_max_total_peers = $outer_state_lock.len() >= MAX_TOTAL_PEERS;
        match $outer_state_lock.entry(*$counterparty_node_id) {
            Entry::Vacant(e) => {
                if is_limited_by_max_total_peers {
                    let error_response = ResponseError {
                        code: JSONRPC_INTERNAL_ERROR_ERROR_CODE,
                        message: JSONRPC_INTERNAL_ERROR_ERROR_MESSAGE.to_string(), data: None,
                    };

                    let msg = LSPSMessage::Invalid(error_response);
                    // Release the peer-map lock before enqueueing the message.
                    drop($outer_state_lock);
                    $self.pending_messages.enqueue($counterparty_node_id, msg);

                    let err = format!(
                        "Dropping request from peer {} due to reaching maximally allowed number of total peers: {}",
                        $counterparty_node_id, MAX_TOTAL_PEERS
                    );

                    return Err(LightningError { err, action: ErrorAction::IgnoreAndLog(Level::Error) });
                } else {
                    e.insert(Mutex::new(PeerState::new()))
                }
            }
            Entry::Occupied(e) => {
                e.into_mut()
            }
        }

    }}
}
551
/// The main object for handling the server side of the LSPS2 (JIT channel) protocol.
pub struct LSPS2ServiceHandler<CM: Deref + Clone>
where
    CM::Target: AChannelManager,
{
    /// Used to forward or fail intercepted HTLCs.
    channel_manager: CM,
    /// Queue of outbound protocol messages to peers.
    pending_messages: Arc<MessageQueue>,
    /// Queue of events surfaced to the user.
    pending_events: Arc<EventQueue>,
    /// Per-peer state; the outer lock guards the map, the inner mutex each peer's state.
    per_peer_state: RwLock<HashMap<PublicKey, Mutex<PeerState>>>,
    /// Lookup from intercept scid to peer; filled in `invoice_parameters_generated`.
    peer_by_intercept_scid: RwLock<HashMap<u64, PublicKey>>,
    /// Lookup from channel id to peer; filled in `channel_ready`.
    peer_by_channel_id: RwLock<HashMap<ChannelId, PublicKey>>,
    /// Counter initialised to zero in `new`.
    // NOTE(review): presumably tracks pending requests across all peers against
    // `MAX_TOTAL_PENDING_REQUESTS`; its updates are not in the visible part of the file — confirm.
    total_pending_requests: AtomicUsize,
    /// Service configuration.
    config: LSPS2ServiceConfig,
}
566
567impl<CM: Deref + Clone> LSPS2ServiceHandler<CM>
568where
569 CM::Target: AChannelManager,
570{
571 pub(crate) fn new(
573 pending_messages: Arc<MessageQueue>, pending_events: Arc<EventQueue>, channel_manager: CM,
574 config: LSPS2ServiceConfig,
575 ) -> Self {
576 Self {
577 pending_messages,
578 pending_events,
579 per_peer_state: RwLock::new(new_hash_map()),
580 peer_by_intercept_scid: RwLock::new(new_hash_map()),
581 peer_by_channel_id: RwLock::new(new_hash_map()),
582 total_pending_requests: AtomicUsize::new(0),
583 channel_manager,
584 config,
585 }
586 }
587
588 pub fn invalid_token_provided(
594 &self, counterparty_node_id: &PublicKey, request_id: RequestId,
595 ) -> Result<(), APIError> {
596 let (result, response) = {
597 let outer_state_lock = self.per_peer_state.read().unwrap();
598
599 match outer_state_lock.get(counterparty_node_id) {
600 Some(inner_state_lock) => {
601 let mut peer_state_lock = inner_state_lock.lock().unwrap();
602
603 match self.remove_pending_request(&mut peer_state_lock, &request_id) {
604 Some(LSPS2Request::GetInfo(_)) => {
605 let response = LSPS2Response::GetInfoError(ResponseError {
606 code: LSPS2_GET_INFO_REQUEST_UNRECOGNIZED_OR_STALE_TOKEN_ERROR_CODE,
607 message: "an unrecognized or stale token was provided".to_string(),
608 data: None,
609 });
610 (Ok(()), Some(response))
611 },
612 _ => (
613 Err(APIError::APIMisuseError {
614 err: format!(
615 "No pending get_info request for request_id: {:?}",
616 request_id
617 ),
618 }),
619 None,
620 ),
621 }
622 },
623 None => (
624 Err(APIError::APIMisuseError {
625 err: format!(
626 "No state for the counterparty exists: {:?}",
627 counterparty_node_id
628 ),
629 }),
630 None,
631 ),
632 }
633 };
634
635 if let Some(response) = response {
636 let msg = LSPS2Message::Response(request_id, response).into();
637 self.pending_messages.enqueue(counterparty_node_id, msg);
638 }
639
640 result
641 }
642
643 pub fn opening_fee_params_generated(
649 &self, counterparty_node_id: &PublicKey, request_id: RequestId,
650 opening_fee_params_menu: Vec<RawOpeningFeeParams>,
651 ) -> Result<(), APIError> {
652 let (result, response) = {
653 let outer_state_lock = self.per_peer_state.read().unwrap();
654
655 match outer_state_lock.get(counterparty_node_id) {
656 Some(inner_state_lock) => {
657 let mut peer_state_lock = inner_state_lock.lock().unwrap();
658
659 match self.remove_pending_request(&mut peer_state_lock, &request_id) {
660 Some(LSPS2Request::GetInfo(_)) => {
661 let response = LSPS2Response::GetInfo(GetInfoResponse {
662 opening_fee_params_menu: opening_fee_params_menu
663 .into_iter()
664 .map(|param| {
665 param.into_opening_fee_params(&self.config.promise_secret)
666 })
667 .collect(),
668 });
669 (Ok(()), Some(response))
670 },
671 _ => (
672 Err(APIError::APIMisuseError {
673 err: format!(
674 "No pending get_info request for request_id: {:?}",
675 request_id
676 ),
677 }),
678 None,
679 ),
680 }
681 },
682 None => (
683 Err(APIError::APIMisuseError {
684 err: format!(
685 "No state for the counterparty exists: {:?}",
686 counterparty_node_id
687 ),
688 }),
689 None,
690 ),
691 }
692 };
693
694 if let Some(response) = response {
695 let msg = LSPS2Message::Response(request_id, response).into();
696 self.pending_messages.enqueue(counterparty_node_id, msg);
697 }
698
699 result
700 }
701
702 pub fn invoice_parameters_generated(
708 &self, counterparty_node_id: &PublicKey, request_id: RequestId, intercept_scid: u64,
709 cltv_expiry_delta: u32, client_trusts_lsp: bool, user_channel_id: u128,
710 ) -> Result<(), APIError> {
711 let (result, response) = {
712 let outer_state_lock = self.per_peer_state.read().unwrap();
713
714 match outer_state_lock.get(counterparty_node_id) {
715 Some(inner_state_lock) => {
716 let mut peer_state_lock = inner_state_lock.lock().unwrap();
717
718 match self.remove_pending_request(&mut peer_state_lock, &request_id) {
719 Some(LSPS2Request::Buy(buy_request)) => {
720 {
721 let mut peer_by_intercept_scid =
722 self.peer_by_intercept_scid.write().unwrap();
723 peer_by_intercept_scid
724 .insert(intercept_scid, *counterparty_node_id);
725 }
726
727 let outbound_jit_channel = OutboundJITChannel::new(
728 buy_request.payment_size_msat,
729 buy_request.opening_fee_params,
730 user_channel_id,
731 );
732
733 peer_state_lock
734 .intercept_scid_by_user_channel_id
735 .insert(user_channel_id, intercept_scid);
736 peer_state_lock
737 .insert_outbound_channel(intercept_scid, outbound_jit_channel);
738
739 let response = LSPS2Response::Buy(BuyResponse {
740 jit_channel_scid: intercept_scid.into(),
741 lsp_cltv_expiry_delta: cltv_expiry_delta,
742 client_trusts_lsp,
743 });
744 (Ok(()), Some(response))
745 },
746 _ => (
747 Err(APIError::APIMisuseError {
748 err: format!(
749 "No pending buy request for request_id: {:?}",
750 request_id
751 ),
752 }),
753 None,
754 ),
755 }
756 },
757 None => (
758 Err(APIError::APIMisuseError {
759 err: format!(
760 "No state for the counterparty exists: {:?}",
761 counterparty_node_id
762 ),
763 }),
764 None,
765 ),
766 }
767 };
768
769 if let Some(response) = response {
770 let msg = LSPS2Message::Response(request_id, response).into();
771 self.pending_messages.enqueue(counterparty_node_id, msg);
772 }
773
774 result
775 }
776
777 pub fn htlc_intercepted(
790 &self, intercept_scid: u64, intercept_id: InterceptId, expected_outbound_amount_msat: u64,
791 payment_hash: PaymentHash,
792 ) -> Result<(), APIError> {
793 let peer_by_intercept_scid = self.peer_by_intercept_scid.read().unwrap();
794 if let Some(counterparty_node_id) = peer_by_intercept_scid.get(&intercept_scid) {
795 let outer_state_lock = self.per_peer_state.read().unwrap();
796 match outer_state_lock.get(counterparty_node_id) {
797 Some(inner_state_lock) => {
798 let mut peer_state = inner_state_lock.lock().unwrap();
799 if let Some(jit_channel) =
800 peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid)
801 {
802 let htlc = InterceptedHTLC {
803 intercept_id,
804 expected_outbound_amount_msat,
805 payment_hash,
806 };
807 match jit_channel.htlc_intercepted(htlc) {
808 Ok(Some(HTLCInterceptedAction::OpenChannel(open_channel_params))) => {
809 let event = Event::LSPS2Service(LSPS2ServiceEvent::OpenChannel {
810 their_network_key: counterparty_node_id.clone(),
811 amt_to_forward_msat: open_channel_params.amt_to_forward_msat,
812 opening_fee_msat: open_channel_params.opening_fee_msat,
813 user_channel_id: jit_channel.user_channel_id,
814 intercept_scid,
815 });
816 self.pending_events.enqueue(event);
817 },
818 Ok(Some(HTLCInterceptedAction::ForwardHTLC(channel_id))) => {
819 self.channel_manager.get_cm().forward_intercepted_htlc(
820 intercept_id,
821 &channel_id,
822 *counterparty_node_id,
823 expected_outbound_amount_msat,
824 )?;
825 },
826 Ok(Some(HTLCInterceptedAction::ForwardPayment(
827 channel_id,
828 FeePayment { opening_fee_msat, htlcs },
829 ))) => {
830 let amounts_to_forward_msat =
831 calculate_amount_to_forward_per_htlc(&htlcs, opening_fee_msat);
832
833 for (intercept_id, amount_to_forward_msat) in
834 amounts_to_forward_msat
835 {
836 self.channel_manager.get_cm().forward_intercepted_htlc(
837 intercept_id,
838 &channel_id,
839 *counterparty_node_id,
840 amount_to_forward_msat,
841 )?;
842 }
843 },
844 Ok(None) => {},
845 Err(e) => {
846 self.channel_manager
847 .get_cm()
848 .fail_intercepted_htlc(intercept_id)?;
849 peer_state
850 .outbound_channels_by_intercept_scid
851 .remove(&intercept_scid);
852 return Err(APIError::APIMisuseError { err: e.err });
854 },
855 }
856 }
857 },
858 None => {
859 return Err(APIError::APIMisuseError {
860 err: format!("No counterparty found for scid: {}", intercept_scid),
861 });
862 },
863 }
864 }
865
866 Ok(())
867 }
868
869 pub fn htlc_handling_failed(
877 &self, failed_next_destination: HTLCDestination,
878 ) -> Result<(), APIError> {
879 if let HTLCDestination::NextHopChannel { channel_id, .. } = failed_next_destination {
880 let peer_by_channel_id = self.peer_by_channel_id.read().unwrap();
881 if let Some(counterparty_node_id) = peer_by_channel_id.get(&channel_id) {
882 let outer_state_lock = self.per_peer_state.read().unwrap();
883 match outer_state_lock.get(counterparty_node_id) {
884 Some(inner_state_lock) => {
885 let mut peer_state = inner_state_lock.lock().unwrap();
886 if let Some(intercept_scid) =
887 peer_state.intercept_scid_by_channel_id.get(&channel_id).copied()
888 {
889 if let Some(jit_channel) = peer_state
890 .outbound_channels_by_intercept_scid
891 .get_mut(&intercept_scid)
892 {
893 match jit_channel.htlc_handling_failed() {
894 Ok(Some(ForwardPaymentAction(
895 channel_id,
896 FeePayment { opening_fee_msat, htlcs },
897 ))) => {
898 let amounts_to_forward_msat =
899 calculate_amount_to_forward_per_htlc(
900 &htlcs,
901 opening_fee_msat,
902 );
903
904 for (intercept_id, amount_to_forward_msat) in
905 amounts_to_forward_msat
906 {
907 self.channel_manager
908 .get_cm()
909 .forward_intercepted_htlc(
910 intercept_id,
911 &channel_id,
912 *counterparty_node_id,
913 amount_to_forward_msat,
914 )?;
915 }
916 },
917 Ok(None) => {},
918 Err(e) => {
919 return Err(APIError::APIMisuseError {
920 err: format!("Unable to fail HTLC: {}.", e.err),
921 });
922 },
923 }
924 }
925 }
926 },
927 None => {},
928 }
929 }
930 }
931
932 Ok(())
933 }
934
935 pub fn payment_forwarded(&self, next_channel_id: ChannelId) -> Result<(), APIError> {
947 if let Some(counterparty_node_id) =
948 self.peer_by_channel_id.read().unwrap().get(&next_channel_id)
949 {
950 let outer_state_lock = self.per_peer_state.read().unwrap();
951 match outer_state_lock.get(counterparty_node_id) {
952 Some(inner_state_lock) => {
953 let mut peer_state = inner_state_lock.lock().unwrap();
954 if let Some(intercept_scid) =
955 peer_state.intercept_scid_by_channel_id.get(&next_channel_id).copied()
956 {
957 if let Some(jit_channel) =
958 peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid)
959 {
960 match jit_channel.payment_forwarded() {
961 Ok(Some(ForwardHTLCsAction(channel_id, htlcs))) => {
962 for htlc in htlcs {
963 self.channel_manager.get_cm().forward_intercepted_htlc(
964 htlc.intercept_id,
965 &channel_id,
966 *counterparty_node_id,
967 htlc.expected_outbound_amount_msat,
968 )?;
969 }
970 },
971 Ok(None) => {},
972 Err(e) => {
973 return Err(APIError::APIMisuseError {
974 err: format!(
975 "Forwarded payment was not applicable for JIT channel: {}",
976 e.err
977 ),
978 })
979 },
980 }
981 }
982 } else {
983 return Err(APIError::APIMisuseError {
984 err: format!("No state for for channel id: {}", next_channel_id),
985 });
986 }
987 },
988 None => {
989 return Err(APIError::APIMisuseError {
990 err: format!("No counterparty state for: {}", counterparty_node_id),
991 });
992 },
993 }
994 }
995
996 Ok(())
997 }
998
999 pub fn channel_ready(
1006 &self, user_channel_id: u128, channel_id: &ChannelId, counterparty_node_id: &PublicKey,
1007 ) -> Result<(), APIError> {
1008 {
1009 let mut peer_by_channel_id = self.peer_by_channel_id.write().unwrap();
1010 peer_by_channel_id.insert(*channel_id, *counterparty_node_id);
1011 }
1012 let outer_state_lock = self.per_peer_state.read().unwrap();
1013 match outer_state_lock.get(counterparty_node_id) {
1014 Some(inner_state_lock) => {
1015 let mut peer_state = inner_state_lock.lock().unwrap();
1016 if let Some(intercept_scid) =
1017 peer_state.intercept_scid_by_user_channel_id.get(&user_channel_id).copied()
1018 {
1019 peer_state.intercept_scid_by_channel_id.insert(*channel_id, intercept_scid);
1020 if let Some(jit_channel) =
1021 peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid)
1022 {
1023 match jit_channel.channel_ready(*channel_id) {
1024 Ok(ForwardPaymentAction(
1025 channel_id,
1026 FeePayment { opening_fee_msat, htlcs },
1027 )) => {
1028 let amounts_to_forward_msat =
1029 calculate_amount_to_forward_per_htlc(&htlcs, opening_fee_msat);
1030
1031 for (intercept_id, amount_to_forward_msat) in
1032 amounts_to_forward_msat
1033 {
1034 self.channel_manager.get_cm().forward_intercepted_htlc(
1035 intercept_id,
1036 &channel_id,
1037 *counterparty_node_id,
1038 amount_to_forward_msat,
1039 )?;
1040 }
1041 },
1042 Err(e) => {
1043 return Err(APIError::APIMisuseError {
1044 err: format!(
1045 "Failed to transition to channel ready: {}",
1046 e.err
1047 ),
1048 })
1049 },
1050 }
1051 } else {
1052 return Err(APIError::APIMisuseError {
1053 err: format!(
1054 "Could not find a channel with user_channel_id {}",
1055 user_channel_id
1056 ),
1057 });
1058 }
1059 } else {
1060 return Err(APIError::APIMisuseError {
1061 err: format!(
1062 "Could not find a channel with that user_channel_id {}",
1063 user_channel_id
1064 ),
1065 });
1066 }
1067 },
1068 None => {
1069 return Err(APIError::APIMisuseError {
1070 err: format!("No counterparty state for: {}", counterparty_node_id),
1071 });
1072 },
1073 }
1074
1075 Ok(())
1076 }
1077
1078 fn handle_get_info_request(
1079 &self, request_id: RequestId, counterparty_node_id: &PublicKey, params: GetInfoRequest,
1080 ) -> Result<(), LightningError> {
1081 let (result, response) = {
1082 let mut outer_state_lock = self.per_peer_state.write().unwrap();
1083 let inner_state_lock =
1084 get_or_insert_peer_state_entry!(self, outer_state_lock, counterparty_node_id);
1085 let mut peer_state_lock = inner_state_lock.lock().unwrap();
1086 let request = LSPS2Request::GetInfo(params.clone());
1087 match self.insert_pending_request(
1088 &mut peer_state_lock,
1089 request_id.clone(),
1090 *counterparty_node_id,
1091 request,
1092 ) {
1093 (Ok(()), msg) => {
1094 let event = Event::LSPS2Service(LSPS2ServiceEvent::GetInfo {
1095 request_id,
1096 counterparty_node_id: *counterparty_node_id,
1097 token: params.token,
1098 });
1099 self.pending_events.enqueue(event);
1100
1101 (Ok(()), msg)
1102 },
1103 (e, msg) => (e, msg),
1104 }
1105 };
1106
1107 if let Some(msg) = response {
1108 self.pending_messages.enqueue(counterparty_node_id, msg);
1109 }
1110
1111 result
1112 }
1113
1114 fn handle_buy_request(
1115 &self, request_id: RequestId, counterparty_node_id: &PublicKey, params: BuyRequest,
1116 ) -> Result<(), LightningError> {
1117 if let Some(payment_size_msat) = params.payment_size_msat {
1118 if payment_size_msat < params.opening_fee_params.min_payment_size_msat {
1119 let response = LSPS2Response::BuyError(ResponseError {
1120 code: LSPS2_BUY_REQUEST_PAYMENT_SIZE_TOO_SMALL_ERROR_CODE,
1121 message: "payment size is below our minimum supported payment size".to_string(),
1122 data: None,
1123 });
1124 let msg = LSPS2Message::Response(request_id, response).into();
1125 self.pending_messages.enqueue(counterparty_node_id, msg);
1126
1127 return Err(LightningError {
1128 err: "payment size is below our minimum supported payment size".to_string(),
1129 action: ErrorAction::IgnoreAndLog(Level::Info),
1130 });
1131 }
1132
1133 if payment_size_msat > params.opening_fee_params.max_payment_size_msat {
1134 let response = LSPS2Response::BuyError(ResponseError {
1135 code: LSPS2_BUY_REQUEST_PAYMENT_SIZE_TOO_LARGE_ERROR_CODE,
1136 message: "payment size is above our maximum supported payment size".to_string(),
1137 data: None,
1138 });
1139 let msg = LSPS2Message::Response(request_id, response).into();
1140 self.pending_messages.enqueue(counterparty_node_id, msg);
1141 return Err(LightningError {
1142 err: "payment size is above our maximum supported payment size".to_string(),
1143 action: ErrorAction::IgnoreAndLog(Level::Info),
1144 });
1145 }
1146
1147 match compute_opening_fee(
1148 payment_size_msat,
1149 params.opening_fee_params.min_fee_msat,
1150 params.opening_fee_params.proportional.into(),
1151 ) {
1152 Some(opening_fee) => {
1153 if opening_fee >= payment_size_msat {
1154 let response = LSPS2Response::BuyError(ResponseError {
1155 code: LSPS2_BUY_REQUEST_PAYMENT_SIZE_TOO_SMALL_ERROR_CODE,
1156 message: "payment size is too small to cover the opening fee"
1157 .to_string(),
1158 data: None,
1159 });
1160 let msg = LSPS2Message::Response(request_id, response).into();
1161 self.pending_messages.enqueue(counterparty_node_id, msg);
1162 return Err(LightningError {
1163 err: "payment size is too small to cover the opening fee".to_string(),
1164 action: ErrorAction::IgnoreAndLog(Level::Info),
1165 });
1166 }
1167 },
1168 None => {
1169 let response = LSPS2Response::BuyError(ResponseError {
1170 code: LSPS2_BUY_REQUEST_PAYMENT_SIZE_TOO_LARGE_ERROR_CODE,
1171 message: "overflow error when calculating opening_fee".to_string(),
1172 data: None,
1173 });
1174 let msg = LSPS2Message::Response(request_id, response).into();
1175 self.pending_messages.enqueue(counterparty_node_id, msg);
1176 return Err(LightningError {
1177 err: "overflow error when calculating opening_fee".to_string(),
1178 action: ErrorAction::IgnoreAndLog(Level::Info),
1179 });
1180 },
1181 }
1182 }
1183
1184 if !is_valid_opening_fee_params(¶ms.opening_fee_params, &self.config.promise_secret) {
1186 let response = LSPS2Response::BuyError(ResponseError {
1187 code: LSPS2_BUY_REQUEST_INVALID_OPENING_FEE_PARAMS_ERROR_CODE,
1188 message: "valid_until is already past OR the promise did not match the provided parameters".to_string(),
1189 data: None,
1190 });
1191 let msg = LSPS2Message::Response(request_id, response).into();
1192 self.pending_messages.enqueue(counterparty_node_id, msg);
1193 return Err(LightningError {
1194 err: "invalid opening fee parameters were supplied by client".to_string(),
1195 action: ErrorAction::IgnoreAndLog(Level::Info),
1196 });
1197 }
1198
1199 let (result, response) = {
1200 let mut outer_state_lock = self.per_peer_state.write().unwrap();
1201 let inner_state_lock =
1202 get_or_insert_peer_state_entry!(self, outer_state_lock, counterparty_node_id);
1203 let mut peer_state_lock = inner_state_lock.lock().unwrap();
1204
1205 let request = LSPS2Request::Buy(params.clone());
1206 match self.insert_pending_request(
1207 &mut peer_state_lock,
1208 request_id.clone(),
1209 *counterparty_node_id,
1210 request,
1211 ) {
1212 (Ok(()), msg) => {
1213 let event = Event::LSPS2Service(LSPS2ServiceEvent::BuyRequest {
1214 request_id,
1215 counterparty_node_id: *counterparty_node_id,
1216 opening_fee_params: params.opening_fee_params,
1217 payment_size_msat: params.payment_size_msat,
1218 });
1219 self.pending_events.enqueue(event);
1220
1221 (Ok(()), msg)
1222 },
1223 (e, msg) => (e, msg),
1224 }
1225 };
1226
1227 if let Some(msg) = response {
1228 self.pending_messages.enqueue(counterparty_node_id, msg);
1229 }
1230
1231 result
1232 }
1233
1234 fn insert_pending_request<'a>(
1235 &self, peer_state_lock: &mut MutexGuard<'a, PeerState>, request_id: RequestId,
1236 counterparty_node_id: PublicKey, request: LSPS2Request,
1237 ) -> (Result<(), LightningError>, Option<LSPSMessage>) {
1238 if self.total_pending_requests.load(Ordering::Relaxed) >= MAX_TOTAL_PENDING_REQUESTS {
1239 let response = LSPS2Response::BuyError(ResponseError {
1240 code: LSPS0_CLIENT_REJECTED_ERROR_CODE,
1241 message: "Reached maximum number of pending requests. Please try again later."
1242 .to_string(),
1243 data: None,
1244 });
1245 let msg = Some(LSPS2Message::Response(request_id, response).into());
1246
1247 let err = format!(
1248 "Peer {} reached maximum number of total pending requests: {}",
1249 counterparty_node_id, MAX_TOTAL_PENDING_REQUESTS
1250 );
1251 let result =
1252 Err(LightningError { err, action: ErrorAction::IgnoreAndLog(Level::Debug) });
1253 return (result, msg);
1254 }
1255
1256 if peer_state_lock.pending_requests_and_channels() < MAX_PENDING_REQUESTS_PER_PEER {
1257 peer_state_lock.pending_requests.insert(request_id, request);
1258 self.total_pending_requests.fetch_add(1, Ordering::Relaxed);
1259 (Ok(()), None)
1260 } else {
1261 let response = LSPS2Response::BuyError(ResponseError {
1262 code: LSPS0_CLIENT_REJECTED_ERROR_CODE,
1263 message: "Reached maximum number of pending requests. Please try again later."
1264 .to_string(),
1265 data: None,
1266 });
1267 let msg = Some(LSPS2Message::Response(request_id, response).into());
1268
1269 let err = format!(
1270 "Peer {} reached maximum number of pending requests: {}",
1271 counterparty_node_id, MAX_PENDING_REQUESTS_PER_PEER
1272 );
1273 let result =
1274 Err(LightningError { err, action: ErrorAction::IgnoreAndLog(Level::Debug) });
1275
1276 (result, msg)
1277 }
1278 }
1279
1280 fn remove_pending_request<'a>(
1281 &self, peer_state_lock: &mut MutexGuard<'a, PeerState>, request_id: &RequestId,
1282 ) -> Option<LSPS2Request> {
1283 match peer_state_lock.pending_requests.remove(request_id) {
1284 Some(req) => {
1285 let res = self.total_pending_requests.fetch_update(
1286 Ordering::Relaxed,
1287 Ordering::Relaxed,
1288 |x| Some(x.saturating_sub(1)),
1289 );
1290 match res {
1291 Ok(previous_value) if previous_value == 0 => debug_assert!(
1292 false,
1293 "total_pending_requests counter out-of-sync! This should never happen!"
1294 ),
1295 Err(previous_value) if previous_value == 0 => debug_assert!(
1296 false,
1297 "total_pending_requests counter out-of-sync! This should never happen!"
1298 ),
1299 _ => {},
1300 }
1301 Some(req)
1302 },
1303 res => res,
1304 }
1305 }
1306
1307 #[cfg(debug_assertions)]
1308 fn verify_pending_request_counter(&self) {
1309 let mut num_requests = 0;
1310 let outer_state_lock = self.per_peer_state.read().unwrap();
1311 for (_, inner) in outer_state_lock.iter() {
1312 let inner_state_lock = inner.lock().unwrap();
1313 num_requests += inner_state_lock.pending_requests.len();
1314 }
1315 debug_assert_eq!(
1316 num_requests,
1317 self.total_pending_requests.load(Ordering::Relaxed),
1318 "total_pending_requests counter out-of-sync! This should never happen!"
1319 );
1320 }
1321
1322 pub(crate) fn peer_disconnected(&self, counterparty_node_id: PublicKey) {
1323 let mut outer_state_lock = self.per_peer_state.write().unwrap();
1324 let is_prunable =
1325 if let Some(inner_state_lock) = outer_state_lock.get(&counterparty_node_id) {
1326 let mut peer_state_lock = inner_state_lock.lock().unwrap();
1327 peer_state_lock.prune_expired_request_state();
1328 peer_state_lock.is_prunable()
1329 } else {
1330 return;
1331 };
1332 if is_prunable {
1333 outer_state_lock.remove(&counterparty_node_id);
1334 }
1335 }
1336
1337 #[allow(clippy::bool_comparison)]
1338 pub(crate) fn prune_peer_state(&self) {
1339 let mut outer_state_lock = self.per_peer_state.write().unwrap();
1340 outer_state_lock.retain(|_, inner_state_lock| {
1341 let mut peer_state_lock = inner_state_lock.lock().unwrap();
1342 peer_state_lock.prune_expired_request_state();
1343 peer_state_lock.is_prunable() == false
1344 });
1345 }
1346}
1347
1348impl<CM: Deref + Clone> ProtocolMessageHandler for LSPS2ServiceHandler<CM>
1349where
1350 CM::Target: AChannelManager,
1351{
1352 type ProtocolMessage = LSPS2Message;
1353 const PROTOCOL_NUMBER: Option<u16> = Some(2);
1354
1355 fn handle_message(
1356 &self, message: Self::ProtocolMessage, counterparty_node_id: &PublicKey,
1357 ) -> Result<(), LightningError> {
1358 match message {
1359 LSPS2Message::Request(request_id, request) => {
1360 let res = match request {
1361 LSPS2Request::GetInfo(params) => {
1362 self.handle_get_info_request(request_id, counterparty_node_id, params)
1363 },
1364 LSPS2Request::Buy(params) => {
1365 self.handle_buy_request(request_id, counterparty_node_id, params)
1366 },
1367 };
1368 #[cfg(debug_assertions)]
1369 self.verify_pending_request_counter();
1370 res
1371 },
1372 _ => {
1373 debug_assert!(
1374 false,
1375 "Service handler received LSPS2 response message. This should never happen."
1376 );
1377 Err(LightningError { err: format!("Service handler received LSPS2 response message from node {:?}. This should never happen.", counterparty_node_id), action: ErrorAction::IgnoreAndLog(Level::Info)})
1378 },
1379 }
1380 }
1381}
1382
1383fn calculate_amount_to_forward_per_htlc(
1384 htlcs: &[InterceptedHTLC], total_fee_msat: u64,
1385) -> Vec<(InterceptId, u64)> {
1386 let total_expected_outbound_msat: u64 =
1388 htlcs.iter().map(|htlc| htlc.expected_outbound_amount_msat).sum();
1389 if total_fee_msat > total_expected_outbound_msat {
1390 debug_assert!(false, "Fee is larger than the total expected outbound amount.");
1391 return Vec::new();
1392 }
1393
1394 let mut fee_remaining_msat = total_fee_msat;
1395 let mut per_htlc_forwards = vec![];
1396 for (index, htlc) in htlcs.iter().enumerate() {
1397 let proportional_fee_amt_msat = (total_fee_msat as u128
1398 * htlc.expected_outbound_amount_msat as u128
1399 / total_expected_outbound_msat as u128) as u64;
1400
1401 let mut actual_fee_amt_msat = core::cmp::min(fee_remaining_msat, proportional_fee_amt_msat);
1402 actual_fee_amt_msat =
1403 core::cmp::min(actual_fee_amt_msat, htlc.expected_outbound_amount_msat);
1404 fee_remaining_msat -= actual_fee_amt_msat;
1405
1406 if index == htlcs.len() - 1 {
1407 actual_fee_amt_msat += fee_remaining_msat;
1408 }
1409
1410 let amount_to_forward_msat =
1411 htlc.expected_outbound_amount_msat.saturating_sub(actual_fee_amt_msat);
1412
1413 per_htlc_forwards.push((htlc.intercept_id, amount_to_forward_msat))
1414 }
1415 per_htlc_forwards
1416}
1417
1418#[cfg(test)]
1419mod tests {
1420
1421 use super::*;
1422 use chrono::TimeZone;
1423 use chrono::Utc;
1424 use proptest::prelude::*;
1425
1426 const MAX_VALUE_MSAT: u64 = 21_000_000_0000_0000_000;
1427
1428 fn arb_forward_amounts() -> impl Strategy<Value = (u64, u64, u64, u64)> {
1429 (1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT, 1u64..MAX_VALUE_MSAT)
1430 .prop_map(|(a, b, c, d)| {
1431 (a, b, c, core::cmp::min(d, a.saturating_add(b).saturating_add(c)))
1432 })
1433 }
1434
1435 proptest! {
1436 #[test]
1437 fn proptest_calculate_amount_to_forward((o_0, o_1, o_2, total_fee_msat) in arb_forward_amounts()) {
1438 let htlcs = vec![
1439 InterceptedHTLC {
1440 intercept_id: InterceptId([0; 32]),
1441 expected_outbound_amount_msat: o_0,
1442 payment_hash: PaymentHash([0; 32]),
1443 },
1444 InterceptedHTLC {
1445 intercept_id: InterceptId([1; 32]),
1446 expected_outbound_amount_msat: o_1,
1447 payment_hash: PaymentHash([0; 32]),
1448 },
1449 InterceptedHTLC {
1450 intercept_id: InterceptId([2; 32]),
1451 expected_outbound_amount_msat: o_2,
1452 payment_hash: PaymentHash([0; 32]),
1453 },
1454 ];
1455
1456 let result = calculate_amount_to_forward_per_htlc(&htlcs, total_fee_msat);
1457 let total_received_msat = o_0 + o_1 + o_2;
1458
1459 if total_received_msat < total_fee_msat {
1460 assert_eq!(result.len(), 0);
1461 } else {
1462 assert_ne!(result.len(), 0);
1463 assert_eq!(result[0].0, htlcs[0].intercept_id);
1464 assert_eq!(result[1].0, htlcs[1].intercept_id);
1465 assert_eq!(result[2].0, htlcs[2].intercept_id);
1466 assert!(result[0].1 <= o_0);
1467 assert!(result[1].1 <= o_1);
1468 assert!(result[2].1 <= o_2);
1469
1470 let result_sum = result.iter().map(|(_, f)| f).sum::<u64>();
1471 assert_eq!(total_received_msat - result_sum, total_fee_msat);
1472 let five_pct = result_sum as f32 * 0.05;
1473 let fair_share_0 = (o_0 as f32 / total_received_msat as f32) * result_sum as f32;
1474 assert!(result[0].1 as f32 <= fair_share_0 + five_pct);
1475 let fair_share_1 = (o_1 as f32 / total_received_msat as f32) * result_sum as f32;
1476 assert!(result[1].1 as f32 <= fair_share_1 + five_pct);
1477 let fair_share_2 = (o_2 as f32 / total_received_msat as f32) * result_sum as f32;
1478 assert!(result[2].1 as f32 <= fair_share_2 + five_pct);
1479 }
1480 }
1481 }
1482
1483 #[test]
1484 fn test_calculate_amount_to_forward() {
1485 let htlcs = vec![
1486 InterceptedHTLC {
1487 intercept_id: InterceptId([0; 32]),
1488 expected_outbound_amount_msat: 2,
1489 payment_hash: PaymentHash([0; 32]),
1490 },
1491 InterceptedHTLC {
1492 intercept_id: InterceptId([1; 32]),
1493 expected_outbound_amount_msat: 6,
1494 payment_hash: PaymentHash([0; 32]),
1495 },
1496 InterceptedHTLC {
1497 intercept_id: InterceptId([2; 32]),
1498 expected_outbound_amount_msat: 2,
1499 payment_hash: PaymentHash([0; 32]),
1500 },
1501 ];
1502 let result = calculate_amount_to_forward_per_htlc(&htlcs, 5);
1503 assert_eq!(
1504 result,
1505 vec![
1506 (htlcs[0].intercept_id, 1),
1507 (htlcs[1].intercept_id, 3),
1508 (htlcs[2].intercept_id, 1),
1509 ]
1510 );
1511 }
1512
1513 #[test]
1514 fn test_jit_channel_state_mpp() {
1515 let payment_size_msat = Some(500_000_000);
1516 let opening_fee_params = OpeningFeeParams {
1517 min_fee_msat: 10_000_000,
1518 proportional: 10_000,
1519 valid_until: Utc.timestamp_opt(3000, 0).unwrap(),
1520 min_lifetime: 4032,
1521 max_client_to_self_delay: 2016,
1522 min_payment_size_msat: 10_000_000,
1523 max_payment_size_msat: 1_000_000_000,
1524 promise: "ignore".to_string(),
1525 };
1526 let mut state = OutboundJITChannelState::new();
1527 {
1529 let (new_state, action) = state
1530 .htlc_intercepted(
1531 &opening_fee_params,
1532 &payment_size_msat,
1533 InterceptedHTLC {
1534 intercept_id: InterceptId([0; 32]),
1535 expected_outbound_amount_msat: 200_000_000,
1536 payment_hash: PaymentHash([100; 32]),
1537 },
1538 )
1539 .unwrap();
1540 assert!(matches!(new_state, OutboundJITChannelState::PendingInitialPayment { .. }));
1541 assert!(action.is_none());
1542 state = new_state;
1543 }
1544 {
1546 let (new_state, action) = state
1547 .htlc_intercepted(
1548 &opening_fee_params,
1549 &payment_size_msat,
1550 InterceptedHTLC {
1551 intercept_id: InterceptId([1; 32]),
1552 expected_outbound_amount_msat: 1_000_000,
1553 payment_hash: PaymentHash([101; 32]),
1554 },
1555 )
1556 .unwrap();
1557 assert!(matches!(new_state, OutboundJITChannelState::PendingInitialPayment { .. }));
1558 assert!(action.is_none());
1559 state = new_state;
1560 }
1561 {
1564 let (new_state, action) = state
1565 .htlc_intercepted(
1566 &opening_fee_params,
1567 &payment_size_msat,
1568 InterceptedHTLC {
1569 intercept_id: InterceptId([2; 32]),
1570 expected_outbound_amount_msat: 300_000_000,
1571 payment_hash: PaymentHash([100; 32]),
1572 },
1573 )
1574 .unwrap();
1575 assert!(matches!(new_state, OutboundJITChannelState::PendingChannelOpen { .. }));
1576 assert!(matches!(action, Some(HTLCInterceptedAction::OpenChannel(_))));
1577 state = new_state;
1578 }
1579 {
1581 let (new_state, ForwardPaymentAction(channel_id, payment)) =
1582 state.channel_ready(ChannelId([200; 32])).unwrap();
1583 assert_eq!(channel_id, ChannelId([200; 32]));
1584 assert_eq!(payment.opening_fee_msat, 10_000_000);
1585 assert_eq!(
1586 payment.htlcs,
1587 vec![
1588 InterceptedHTLC {
1589 intercept_id: InterceptId([0; 32]),
1590 expected_outbound_amount_msat: 200_000_000,
1591 payment_hash: PaymentHash([100; 32]),
1592 },
1593 InterceptedHTLC {
1594 intercept_id: InterceptId([2; 32]),
1595 expected_outbound_amount_msat: 300_000_000,
1596 payment_hash: PaymentHash([100; 32]),
1597 },
1598 ]
1599 );
1600 state = new_state;
1601 }
1602 {
1604 let (new_state, action) = state
1605 .htlc_intercepted(
1606 &opening_fee_params,
1607 &payment_size_msat,
1608 InterceptedHTLC {
1609 intercept_id: InterceptId([3; 32]),
1610 expected_outbound_amount_msat: 2_000_000,
1611 payment_hash: PaymentHash([102; 32]),
1612 },
1613 )
1614 .unwrap();
1615 assert!(matches!(new_state, OutboundJITChannelState::PendingPaymentForward { .. }));
1616 assert!(action.is_none());
1617 state = new_state;
1618 }
1619 {
1621 let (new_state, action) = state.htlc_handling_failed().unwrap();
1622 assert!(matches!(new_state, OutboundJITChannelState::PendingPayment { .. }));
1623 assert!(action.is_none());
1625 state = new_state;
1626 }
1627 {
1629 let (new_state, action) = state
1630 .htlc_intercepted(
1631 &opening_fee_params,
1632 &payment_size_msat,
1633 InterceptedHTLC {
1634 intercept_id: InterceptId([4; 32]),
1635 expected_outbound_amount_msat: 500_000_000,
1636 payment_hash: PaymentHash([101; 32]),
1637 },
1638 )
1639 .unwrap();
1640 assert!(matches!(new_state, OutboundJITChannelState::PendingPaymentForward { .. }));
1641 match action {
1642 Some(HTLCInterceptedAction::ForwardPayment(channel_id, payment)) => {
1643 assert_eq!(channel_id, ChannelId([200; 32]));
1644 assert_eq!(payment.opening_fee_msat, 10_000_000);
1645 assert_eq!(
1646 payment.htlcs,
1647 vec![
1648 InterceptedHTLC {
1649 intercept_id: InterceptId([1; 32]),
1650 expected_outbound_amount_msat: 1_000_000,
1651 payment_hash: PaymentHash([101; 32]),
1652 },
1653 InterceptedHTLC {
1654 intercept_id: InterceptId([4; 32]),
1655 expected_outbound_amount_msat: 500_000_000,
1656 payment_hash: PaymentHash([101; 32]),
1657 },
1658 ]
1659 );
1660 },
1661 _ => panic!("Unexpected action when intercepted HTLC."),
1662 }
1663 state = new_state;
1664 }
1665 {
1667 let (new_state, action) = state.payment_forwarded().unwrap();
1668 assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
1669 match action {
1670 Some(ForwardHTLCsAction(channel_id, htlcs)) => {
1671 assert_eq!(channel_id, ChannelId([200; 32]));
1672 assert_eq!(
1673 htlcs,
1674 vec![InterceptedHTLC {
1675 intercept_id: InterceptId([3; 32]),
1676 expected_outbound_amount_msat: 2_000_000,
1677 payment_hash: PaymentHash([102; 32]),
1678 }]
1679 );
1680 },
1681 _ => panic!("Unexpected action when forwarded payment."),
1682 }
1683 state = new_state;
1684 }
1685 {
1687 let (new_state, action) = state
1688 .htlc_intercepted(
1689 &opening_fee_params,
1690 &payment_size_msat,
1691 InterceptedHTLC {
1692 intercept_id: InterceptId([5; 32]),
1693 expected_outbound_amount_msat: 200_000_000,
1694 payment_hash: PaymentHash([103; 32]),
1695 },
1696 )
1697 .unwrap();
1698 assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
1699 assert!(
1700 matches!(action, Some(HTLCInterceptedAction::ForwardHTLC(channel_id)) if channel_id == ChannelId([200; 32]))
1701 );
1702 }
1703 }
1704
1705 #[test]
1706 fn test_jit_channel_state_no_mpp() {
1707 let payment_size_msat = None;
1708 let opening_fee_params = OpeningFeeParams {
1709 min_fee_msat: 10_000_000,
1710 proportional: 10_000,
1711 valid_until: Utc.timestamp_opt(3000, 0).unwrap(),
1712 min_lifetime: 4032,
1713 max_client_to_self_delay: 2016,
1714 min_payment_size_msat: 10_000_000,
1715 max_payment_size_msat: 1_000_000_000,
1716 promise: "ignore".to_string(),
1717 };
1718 let mut state = OutboundJITChannelState::new();
1719 {
1721 let (new_state, action) = state
1722 .htlc_intercepted(
1723 &opening_fee_params,
1724 &payment_size_msat,
1725 InterceptedHTLC {
1726 intercept_id: InterceptId([0; 32]),
1727 expected_outbound_amount_msat: 500_000_000,
1728 payment_hash: PaymentHash([100; 32]),
1729 },
1730 )
1731 .unwrap();
1732 assert!(matches!(new_state, OutboundJITChannelState::PendingChannelOpen { .. }));
1733 assert!(matches!(action, Some(HTLCInterceptedAction::OpenChannel(_))));
1734 state = new_state;
1735 }
1736 {
1738 let (new_state, action) = state
1739 .htlc_intercepted(
1740 &opening_fee_params,
1741 &payment_size_msat,
1742 InterceptedHTLC {
1743 intercept_id: InterceptId([1; 32]),
1744 expected_outbound_amount_msat: 600_000_000,
1745 payment_hash: PaymentHash([101; 32]),
1746 },
1747 )
1748 .unwrap();
1749 assert!(matches!(new_state, OutboundJITChannelState::PendingChannelOpen { .. }));
1750 assert!(action.is_none());
1751 state = new_state;
1752 }
1753 {
1755 let (new_state, ForwardPaymentAction(channel_id, payment)) =
1756 state.channel_ready(ChannelId([200; 32])).unwrap();
1757 assert_eq!(channel_id, ChannelId([200; 32]));
1758 assert_eq!(payment.opening_fee_msat, 10_000_000);
1759 assert_eq!(
1760 payment.htlcs,
1761 vec![InterceptedHTLC {
1762 intercept_id: InterceptId([0; 32]),
1763 expected_outbound_amount_msat: 500_000_000,
1764 payment_hash: PaymentHash([100; 32]),
1765 },]
1766 );
1767 state = new_state;
1768 }
1769 {
1771 let (new_state, action) = state
1772 .htlc_intercepted(
1773 &opening_fee_params,
1774 &payment_size_msat,
1775 InterceptedHTLC {
1776 intercept_id: InterceptId([2; 32]),
1777 expected_outbound_amount_msat: 500_000_000,
1778 payment_hash: PaymentHash([102; 32]),
1779 },
1780 )
1781 .unwrap();
1782 assert!(matches!(new_state, OutboundJITChannelState::PendingPaymentForward { .. }));
1783 assert!(action.is_none());
1784 state = new_state;
1785 }
1786 {
1788 let (new_state, action) = state.htlc_handling_failed().unwrap();
1789 assert!(matches!(new_state, OutboundJITChannelState::PendingPaymentForward { .. }));
1790 match action {
1791 Some(ForwardPaymentAction(channel_id, payment)) => {
1792 assert_eq!(channel_id, ChannelId([200; 32]));
1793 assert_eq!(
1794 payment.htlcs,
1795 vec![InterceptedHTLC {
1796 intercept_id: InterceptId([1; 32]),
1797 expected_outbound_amount_msat: 600_000_000,
1798 payment_hash: PaymentHash([101; 32]),
1799 },]
1800 );
1801 },
1802 _ => panic!("Unexpected action when HTLC handling failed."),
1803 }
1804 state = new_state;
1805 }
1806 {
1808 let (new_state, action) = state.payment_forwarded().unwrap();
1809 assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
1810 match action {
1811 Some(ForwardHTLCsAction(channel_id, htlcs)) => {
1812 assert_eq!(channel_id, ChannelId([200; 32]));
1813 assert_eq!(
1814 htlcs,
1815 vec![InterceptedHTLC {
1816 intercept_id: InterceptId([2; 32]),
1817 expected_outbound_amount_msat: 500_000_000,
1818 payment_hash: PaymentHash([102; 32]),
1819 }]
1820 );
1821 },
1822 _ => panic!("Unexpected action when forwarded payment."),
1823 }
1824 state = new_state;
1825 }
1826 {
1828 let (new_state, action) = state
1829 .htlc_intercepted(
1830 &opening_fee_params,
1831 &payment_size_msat,
1832 InterceptedHTLC {
1833 intercept_id: InterceptId([3; 32]),
1834 expected_outbound_amount_msat: 200_000_000,
1835 payment_hash: PaymentHash([103; 32]),
1836 },
1837 )
1838 .unwrap();
1839 assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
1840 assert!(
1841 matches!(action, Some(HTLCInterceptedAction::ForwardHTLC(channel_id)) if channel_id == ChannelId([200; 32]))
1842 );
1843 }
1844 }
1845}