package htlcswitch
import (
"bytes"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog"
"github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker"
)
// Default timing values for the switch.
const (
// DefaultFwdEventInterval is 15 seconds.
// NOTE(review): presumably the default period for Config.FwdEventTicker,
// which triggers FlushForwardingEvents in htlcForwarder — confirm where the
// ticker is constructed.
DefaultFwdEventInterval = 15 * time.Second
// DefaultLogInterval is 10 seconds. The stats log line in htlcForwarder
// hard-codes "the last 10 seconds" and divides by 10.
// NOTE(review): presumably the default period for Config.LogEventTicker —
// confirm where the ticker is constructed.
DefaultLogInterval = 10 * time.Second
// DefaultAckInterval is 15 seconds.
// NOTE(review): presumably the default period for Config.AckEventTicker —
// confirm where the ticker is constructed.
DefaultAckInterval = 15 * time.Second
// DefaultHTLCExpiry is one minute.
// NOTE(review): presumably the default for Config.HTLCExpiry — confirm at
// the site that builds Config.
DefaultHTLCExpiry = time.Minute
)
// Sentinel errors returned by the switch.
var (
// ErrChannelLinkNotFound is returned by getLink and getLinkByShortID when
// no link is registered under the requested identifier.
ErrChannelLinkNotFound = errors.New("channel link not found")
// ErrDuplicateAdd is returned by forward when committing the circuit for
// an add reports exactly one dropped circuit.
ErrDuplicateAdd = errors.New("duplicate add HTLC detected")
// ErrUnknownErrorDecryptor is not referenced in the visible part of this
// file. NOTE(review): confirm its users elsewhere in the package.
ErrUnknownErrorDecryptor = errors.New("unknown error decryptor")
// ErrSwitchExiting is returned by operations that observe the switch's
// quit channel being closed while they are waiting.
ErrSwitchExiting = errors.New("htlcswitch shutting down")
// ErrNoLinksFound is returned by getLinks when the interface index has no
// entry for the requested peer key.
ErrNoLinksFound = errors.New("no channel links found")
// ErrUnreadableFailureMessage is returned by parseFailedPayment when the
// error decrypter fails to decrypt the failure reason.
ErrUnreadableFailureMessage = errors.New("unreadable failure message")
)
// plexPacket is the command sent over Switch.htlcPlex to the htlcForwarder
// goroutine: a packet to process plus the channel on which the result of
// handlePacketForward is sent back.
type plexPacket struct {
// pkt is the packet handed to handlePacketForward.
pkt *htlcPacket
// err receives the value returned by handlePacketForward for pkt.
err chan error
}
// ChannelCloseType is an enumeration carried in ChanClose.CloseType to
// describe which kind of channel close is being requested.
type ChannelCloseType uint8
const (
// CloseRegular is the zero value of ChannelCloseType.
// NOTE(review): presumably a cooperative/regular close — confirm with the
// handler behind Config.LocalChannelClose.
CloseRegular ChannelCloseType = iota
// CloseBreach has value 1.
// NOTE(review): presumably a close caused by a contract breach — confirm
// with the handler behind Config.LocalChannelClose.
CloseBreach
)
// ChanClose is a channel close request. CloseLink builds one and sends it to
// htlcForwarder, which hands it to Config.LocalChannelClose together with the
// public key of the link's peer.
type ChanClose struct {
// CloseType is the kind of close requested.
CloseType ChannelCloseType
// ChanPoint identifies the channel; htlcForwarder derives the channel ID
// from it to look up the link.
ChanPoint *wire.OutPoint
// TargetFeePerKw is the fee rate passed through from CloseLink.
// NOTE(review): presumably the desired fee rate of the closing
// transaction — confirm with the close handler.
TargetFeePerKw chainfee.SatPerKWeight
// DeliveryScript is the delivery address passed through from CloseLink.
DeliveryScript lnwire.DeliveryAddress
// Updates is created by CloseLink with a buffer of 2 and is closed by
// CloseLink if the switch is shutting down before the request is accepted.
Updates chan interface{}
// Err is created by CloseLink with a buffer of 1; it receives
// ErrSwitchExiting on shutdown, or a lookup error from htlcForwarder when
// no link exists for ChanPoint.
Err chan error
}
// Config bundles the external dependencies and tunables of the Switch. New
// stores a pointer to its own copy of the value it is given.
type Config struct {
// FwdingLog receives batches of forwarding events from
// FlushForwardingEvents.
FwdingLog ForwardingLog
// LocalChannelClose is invoked (in its own goroutine) by htlcForwarder for
// each accepted close request, with the peer's public key bytes.
LocalChannelClose func(pubKey []byte, request *ChanClose)
// DB is the channel database; it backs the circuit map and network result
// store, and is used to load open channels and forwarding packages.
DB *channeldb.DB
// SwitchPackager is used to ack settle/fail references and to load a
// channel's forwarding packages.
SwitchPackager channeldb.FwdOperator
// ExtractErrorEncrypter is handed to the circuit map configuration in New.
ExtractErrorEncrypter hop.ErrorEncrypterExtracter
// FetchLastChannelUpdate returns the latest channel update for a short
// channel ID; it is used when building temporary channel failures and is
// also passed to the mail orchestrator.
FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error)
// Notifier is used by Start to register for block epoch notifications.
Notifier chainntnfs.ChainNotifier
// HtlcNotifier is told about link failures, settles and forwarding
// failures for locally dispatched HTLCs.
HtlcNotifier htlcNotifier
// FwdEventTicker triggers a flush of pending forwarding events.
FwdEventTicker ticker.Ticker
// LogEventTicker triggers the periodic throughput log line.
LogEventTicker ticker.Ticker
// AckEventTicker triggers batch acking of pending settle/fail references;
// htlcForwarder resumes it while references are pending and pauses it
// when none are.
AckEventTicker ticker.Ticker
// AllowCircularRoute permits forwarding an HTLC out over the same channel
// it arrived on (see checkCircularForward).
AllowCircularRoute bool
// RejectHTLC makes handlePacketForward fail every add that did not
// originate locally.
RejectHTLC bool
// Clock is passed to the mail orchestrator.
Clock clock.Clock
// HTLCExpiry is passed to the mail orchestrator as its expiry setting.
HTLCExpiry time.Duration
}
type Switch struct {
started int32 shutdown int32
bestHeight uint32
wg sync.WaitGroup
quit chan struct{}
cfg *Config
networkResults *networkResultStore
circuits CircuitMap
mailOrchestrator *mailOrchestrator
indexMtx sync.RWMutex
pendingLinkIndex map[lnwire.ChannelID]ChannelLink
linkIndex map[lnwire.ChannelID]ChannelLink
forwardingIndex map[lnwire.ShortChannelID]ChannelLink
interfaceIndex map[[33]byte]map[lnwire.ChannelID]ChannelLink
htlcPlex chan *plexPacket
chanCloseRequests chan *ChanClose
resolutionMsgs chan *resolutionMsg
fwdEventMtx sync.Mutex
pendingFwdingEvents []channeldb.ForwardingEvent
blockEpochStream *chainntnfs.BlockEpochEvent
pendingSettleFails []channeldb.SettleFailRef
}
// New builds a Switch from the given configuration and starting block height.
// It creates the circuit map (returning any error from that step), allocates
// the link indexes and request channels, and wires up the mail orchestrator.
// The returned switch is not running until Start is called.
func New(cfg Config, currentHeight uint32) (*Switch, error) {
	circuitCfg := &CircuitMapConfig{
		DB:                    cfg.DB,
		ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
	}
	circuits, err := NewCircuitMap(circuitCfg)
	if err != nil {
		return nil, err
	}

	sw := &Switch{
		bestHeight:        currentHeight,
		cfg:               &cfg,
		circuits:          circuits,
		linkIndex:         make(map[lnwire.ChannelID]ChannelLink),
		forwardingIndex:   make(map[lnwire.ShortChannelID]ChannelLink),
		interfaceIndex:    make(map[[33]byte]map[lnwire.ChannelID]ChannelLink),
		pendingLinkIndex:  make(map[lnwire.ChannelID]ChannelLink),
		networkResults:    newNetworkResultStore(cfg.DB),
		htlcPlex:          make(chan *plexPacket),
		chanCloseRequests: make(chan *ChanClose),
		resolutionMsgs:    make(chan *resolutionMsg),
		quit:              make(chan struct{}),
	}

	// The orchestrator forwards through the switch itself, so it can only be
	// created once the switch value exists.
	orchCfg := &mailOrchConfig{
		fetchUpdate:    sw.cfg.FetchLastChannelUpdate,
		forwardPackets: sw.ForwardPackets,
		clock:          sw.cfg.Clock,
		expiry:         sw.cfg.HTLCExpiry,
	}
	sw.mailOrchestrator = newMailOrchestrator(orchCfg)

	return sw, nil
}
// resolutionMsg wraps a contract resolution message with a completion signal
// so that ProcessContractResolution can block until htlcForwarder has
// processed it.
type resolutionMsg struct {
contractcourt.ResolutionMsg
// doneChan is closed by htlcForwarder after the message has been handled.
doneChan chan struct{}
}
// ProcessContractResolution hands a contract resolution message to the
// htlcForwarder goroutine and blocks until it has been processed. It returns
// ErrSwitchExiting if the switch shuts down while queuing or waiting.
func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) error {
	request := &resolutionMsg{
		ResolutionMsg: msg,
		doneChan:      make(chan struct{}),
	}

	// Queue the request for the forwarder.
	select {
	case s.resolutionMsgs <- request:
	case <-s.quit:
		return ErrSwitchExiting
	}

	// Wait for the forwarder to signal completion.
	select {
	case <-request.doneChan:
		return nil
	case <-s.quit:
		return ErrSwitchExiting
	}
}
// GetPaymentResult returns a channel on which the result of the locally
// initiated payment identified by paymentID will be delivered.
//
// If no circuit exists for the payment, the stored result is looked up
// immediately (and a lookup error is returned to the caller); otherwise the
// call subscribes to the result. The returned channel has a buffer of one. It
// receives a single PaymentResult, or is closed without a value if the switch
// shuts down first.
func (s *Switch) GetPaymentResult(paymentID uint64, paymentHash lntypes.Hash,
	deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) {

	circuitKey := CircuitKey{
		ChanID: hop.Source,
		HtlcID: paymentID,
	}

	var source <-chan *networkResult
	if s.circuits.LookupCircuit(circuitKey) != nil {
		// The payment is still in flight: wait for its result.
		subscription, err := s.networkResults.subscribeResult(paymentID)
		if err != nil {
			return nil, err
		}
		source = subscription
	} else {
		// No circuit: the result must already be stored.
		stored, err := s.networkResults.getResult(paymentID)
		if err != nil {
			return nil, err
		}
		ready := make(chan *networkResult, 1)
		ready <- stored
		source = ready
	}

	out := make(chan *PaymentResult, 1)

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()

		var netResult *networkResult
		select {
		case netResult = <-source:
		case <-s.quit:
			close(out)
			return
		}

		payResult, err := s.extractResult(
			deobfuscator, netResult, paymentID, paymentHash,
		)
		if err == nil {
			out <- payResult
			return
		}

		wrapped := fmt.Errorf("unable to extract result: %v", err)
		log.Error(wrapped)
		out <- &PaymentResult{
			Error: wrapped,
		}
	}()

	return out, nil
}
// SendHTLC dispatches a locally initiated HTLC. The packet is marked as
// originating from hop.Source with paymentID as its incoming HTLC ID and is
// sent out over firstHop. The error from forward is returned unchanged.
func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
	htlc *lnwire.UpdateAddHTLC) error {

	return s.forward(&htlcPacket{
		incomingChanID: hop.Source,
		incomingHTLCID: paymentID,
		outgoingChanID: firstHop,
		htlc:           htlc,
	})
}
// UpdateForwardingPolicies applies each policy in chanPolicies to the live
// link of the corresponding channel point. Channel points without a live link
// are logged and skipped.
func (s *Switch) UpdateForwardingPolicies(
	chanPolicies map[wire.OutPoint]ForwardingPolicy) {

	dumpPolicies := newLogClosure(func() string {
		return spew.Sdump(chanPolicies)
	})
	log.Tracef("Updating link policies: %v", dumpPolicies)

	s.indexMtx.RLock()
	for chanPoint, newPolicy := range chanPolicies {
		chanID := lnwire.NewChanIDFromOutPoint(&chanPoint)

		if link, found := s.linkIndex[chanID]; found {
			link.UpdateForwardingPolicy(newPolicy)
			continue
		}

		log.Debugf("Unable to find ChannelPoint(%v) to update "+
			"link policy", chanPoint)
	}
	s.indexMtx.RUnlock()
}
// IsForwardedHTLC reports whether an open circuit exists for the given
// channel and HTLC index whose incoming side is not hop.Source, i.e. the HTLC
// was forwarded rather than initiated locally.
func (s *Switch) IsForwardedHTLC(chanID lnwire.ShortChannelID,
	htlcIndex uint64) bool {

	key := channeldb.CircuitKey{
		ChanID: chanID,
		HtlcID: htlcIndex,
	}

	circuit := s.circuits.LookupOpenCircuit(key)
	if circuit == nil {
		return false
	}

	return circuit.Incoming.ChanID != hop.Source
}
// forward synchronously pushes a single packet through the switch.
//
// For an add, a payment circuit is committed first:
//   - a commit error is logged and returned;
//   - a dropped circuit yields ErrDuplicateAdd;
//   - a failed circuit for a locally initiated HTLC yields an error (there is
//     no incoming link to send a failure to);
//   - a failed circuit for a forwarded HTLC is failed back to the incoming
//     link via failAddPacket.
//
// Every other packet, and every successfully committed add, is handed to
// route and its result returned.
func (s *Switch) forward(packet *htlcPacket) error {
	switch htlc := packet.htlc.(type) {
	case *lnwire.UpdateAddHTLC:
		circuit := newPaymentCircuit(&htlc.PaymentHash, packet)
		actions, err := s.circuits.CommitCircuits(circuit)
		if err != nil {
			log.Errorf("unable to commit circuit in switch: %v", err)
			return err
		}

		switch {
		case len(actions.Drops) == 1:
			return ErrDuplicateAdd

		case len(actions.Fails) == 1:
			// err is necessarily nil here, so returning it would report
			// success for an HTLC that was never committed. Surface the
			// failure to the local caller explicitly.
			if packet.incomingChanID == hop.Source {
				return errors.New("local add HTLC failed")
			}

			// Use the latest channel update for the failure if one can
			// be fetched; otherwise fall back to a node-level failure.
			var failure lnwire.FailureMessage
			update, err := s.cfg.FetchLastChannelUpdate(
				packet.incomingChanID,
			)
			if err != nil {
				failure = &lnwire.FailTemporaryNodeFailure{}
			} else {
				failure = lnwire.NewTemporaryChannelFailure(update)
			}

			linkError := NewDetailedLinkError(
				failure, OutgoingFailureIncompleteForward,
			)

			return s.failAddPacket(packet, linkError)
		}

		packet.circuit = circuit
	}

	return s.route(packet)
}
// ForwardPackets pushes a batch of packets into the switch and returns a
// channel on which the per-packet forwarding results are relayed. The channel
// is closed once all results have been relayed or the switch stops.
//
// Adds are collected and their circuits committed in a single CommitCircuits
// call; all other packets are routed immediately. If either linkQuit or the
// switch's quit channel is already closed, a closed channel is returned and
// nothing is forwarded. A nil linkQuit never fires.
func (s *Switch) ForwardPackets(linkQuit chan struct{},
packets ...*htlcPacket) chan error {
var (
// fwdChan receives the raw results from htlcForwarder; errChan is what
// the caller reads. Both are buffered for every packet so senders do
// not block.
fwdChan = make(chan error, len(packets))
errChan = make(chan error, len(packets))
// numSent counts the packets successfully handed to routeAsync.
numSent int
)
if len(packets) == 0 {
close(errChan)
return errChan
}
// This local WaitGroup is released when ForwardPackets returns. The proxy
// goroutine waits on it, so it only reads numSent after this function has
// finished updating it.
var wg sync.WaitGroup
wg.Add(1)
defer wg.Done()
select {
case <-linkQuit:
close(errChan)
return errChan
case <-s.quit:
close(errChan)
return errChan
default:
// From here on the proxy goroutine owns closing errChan.
s.wg.Add(1)
go s.proxyFwdErrs(&numSent, &wg, fwdChan, errChan)
}
var circuits []*PaymentCircuit
var addBatch []*htlcPacket
for _, packet := range packets {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
// Defer adds until their circuits have been committed as a batch.
circuit := newPaymentCircuit(&htlc.PaymentHash, packet)
packet.circuit = circuit
circuits = append(circuits, circuit)
addBatch = append(addBatch, packet)
default:
// Non-add packets are routed right away.
err := s.routeAsync(packet, fwdChan, linkQuit)
if err != nil {
return errChan
}
numSent++
}
}
if len(circuits) == 0 {
return errChan
}
// NOTE(review): on error the code only logs and keeps using actions; this
// assumes CommitCircuits returns a non-nil actions value alongside the
// error — confirm against the CircuitMap implementation.
actions, err := s.circuits.CommitCircuits(circuits...)
if err != nil {
log.Errorf("unable to commit circuits in switch: %v", err)
}
// Partition the adds by matching each packet's circuit against the head of
// the Adds, Drops and Fails lists. Dropped packets are discarded.
// NOTE(review): this relies on each list preserving the relative order of
// the submitted circuits — confirm against the CircuitMap implementation.
var addedPackets, failedPackets []*htlcPacket
for _, packet := range addBatch {
switch {
case len(actions.Adds) > 0 && packet.circuit == actions.Adds[0]:
addedPackets = append(addedPackets, packet)
actions.Adds = actions.Adds[1:]
case len(actions.Drops) > 0 && packet.circuit == actions.Drops[0]:
actions.Drops = actions.Drops[1:]
case len(actions.Fails) > 0 && packet.circuit == actions.Fails[0]:
failedPackets = append(failedPackets, packet)
actions.Fails = actions.Fails[1:]
}
}
for _, packet := range addedPackets {
err := s.routeAsync(packet, fwdChan, linkQuit)
if err != nil {
return errChan
}
numSent++
}
if len(failedPackets) > 0 {
// One failure message is built from the first failed packet's incoming
// channel and reused for every failed packet in the batch.
var failure lnwire.FailureMessage
update, err := s.cfg.FetchLastChannelUpdate(
failedPackets[0].incomingChanID,
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(update)
}
linkError := NewDetailedLinkError(
failure, OutgoingFailureIncompleteForward,
)
for _, packet := range failedPackets {
// The result is deliberately ignored: failAddPacket logs its own
// errors and these failures are not reported on errChan.
_ = s.failAddPacket(packet, linkError)
}
}
return errChan
}
// proxyFwdErrs relays forwarding results from fwdChan to errChan. It first
// waits on wg, then reads *num to learn how many results to expect. errChan is
// always closed on exit; the relay stops early if the switch shuts down.
func (s *Switch) proxyFwdErrs(num *int, wg *sync.WaitGroup,
	fwdChan, errChan chan error) {

	defer s.wg.Done()
	defer close(errChan)

	// *num is only read once the waited-on WaitGroup has been released.
	wg.Wait()

	for remaining := *num; remaining > 0; remaining-- {
		select {
		case fwdErr := <-fwdChan:
			errChan <- fwdErr

		case <-s.quit:
			log.Errorf("unable to forward htlc packet " +
				"htlc switch was stopped")
			return
		}
	}
}
// route sends the packet to the htlcForwarder goroutine and blocks until the
// result of processing it is available. It returns ErrSwitchExiting if the
// switch shuts down while queuing or waiting.
func (s *Switch) route(packet *htlcPacket) error {
	request := &plexPacket{
		pkt: packet,
		err: make(chan error, 1),
	}

	select {
	case s.htlcPlex <- request:
	case <-s.quit:
		return ErrSwitchExiting
	}

	select {
	case result := <-request.err:
		return result
	case <-s.quit:
		return ErrSwitchExiting
	}
}
// routeAsync queues the packet for the htlcForwarder goroutine without waiting
// for the result; the result will later be sent on errChan. A non-nil error
// means the packet was not queued because the link or the switch is shutting
// down.
func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error,
	linkQuit chan struct{}) error {

	request := &plexPacket{
		pkt: packet,
		err: errChan,
	}

	var queueErr error
	select {
	case s.htlcPlex <- request:
	case <-linkQuit:
		queueErr = ErrLinkShuttingDown
	case <-s.quit:
		queueErr = errors.New("htlc switch was stopped")
	}

	return queueErr
}
// handleLocalDispatch processes a packet belonging to a locally initiated
// payment.
//
// Adds are validated against the outgoing link and, if acceptable, handed to
// it; a validation failure is reported to the HTLC notifier and returned.
// Any other packet is a response and is processed in a separate goroutine.
func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
	add, isAdd := pkt.htlc.(*lnwire.UpdateAddHTLC)
	if !isAdd {
		// Settle/fail for a local payment: handle asynchronously.
		s.wg.Add(1)
		go s.handleLocalResponse(pkt)

		return nil
	}

	link, linkErr := s.handleLocalAddHTLC(pkt, add)
	if linkErr == nil {
		return link.HandleSwitchPacket(pkt)
	}

	s.cfg.HtlcNotifier.NotifyLinkFailEvent(
		newHtlcKey(pkt),
		HtlcInfo{
			OutgoingTimeLock: add.Expiry,
			OutgoingAmt:      add.Amount,
		},
		HtlcEventTypeSend,
		linkErr,
		false,
	)

	return linkErr
}
// handleLocalAddHTLC finds the outgoing link for a locally initiated add and
// checks that it can carry the HTLC. It returns the link on success, or a
// LinkError when the link is unknown, not eligible to forward, or rejects the
// HTLC in CheckHtlcTransit.
func (s *Switch) handleLocalAddHTLC(pkt *htlcPacket,
	htlc *lnwire.UpdateAddHTLC) (ChannelLink, *LinkError) {

	s.indexMtx.RLock()
	outLink, lookupErr := s.getLinkByShortID(pkt.outgoingChanID)
	s.indexMtx.RUnlock()

	switch {
	case lookupErr != nil:
		log.Errorf("Link %v not found", pkt.outgoingChanID)
		return nil, NewLinkError(&lnwire.FailUnknownNextPeer{})

	case !outLink.EligibleToForward():
		log.Errorf("Link %v is not available to forward",
			pkt.outgoingChanID)
		return nil, NewDetailedLinkError(
			lnwire.NewTemporaryChannelFailure(nil),
			OutgoingFailureLinkNotEligible,
		)
	}

	height := atomic.LoadUint32(&s.bestHeight)
	if policyErr := outLink.CheckHtlcTransit(
		htlc.PaymentHash, htlc.Amount, htlc.Expiry, height,
	); policyErr != nil {
		log.Errorf("Link %v policy for local forward not "+
			"satisfied", pkt.outgoingChanID)
		return nil, policyErr
	}

	return outLink, nil
}
// handleLocalResponse finalizes a settle or fail for a locally initiated
// payment. In order, it stores the network result, acks the settle/fail
// reference if the packet has one, tears down the circuit and notifies the
// HTLC notifier. It stops at the first step that fails, after logging.
func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
	defer s.wg.Done()

	pid := pkt.incomingHTLCID

	result := &networkResult{
		msg:          pkt.htlc,
		unencrypted:  pkt.localFailure || pkt.convertedError,
		isResolution: pkt.isResolution,
	}

	// Persist the result before anything else.
	err := s.networkResults.storeResult(pid, result)
	if err != nil {
		log.Errorf("Unable to complete payment for pid=%v: %v",
			pid, err)
		return
	}

	if pkt.destRef != nil {
		err := s.ackSettleFail(*pkt.destRef)
		if err != nil {
			log.Warnf("Unable to ack settle/fail reference: %s: %v",
				*pkt.destRef, err)
			return
		}
	}

	err = s.teardownCircuit(pkt)
	if err != nil {
		log.Warnf("Unable to teardown circuit %s: %v",
			pkt.inKey(), err)
		return
	}

	htlcKey, evType := newHtlcKey(pkt), getEventType(pkt)

	switch pkt.htlc.(type) {
	case *lnwire.UpdateFulfillHTLC:
		s.cfg.HtlcNotifier.NotifySettleEvent(htlcKey, evType)

	case *lnwire.UpdateFailHTLC:
		s.cfg.HtlcNotifier.NotifyForwardingFailEvent(htlcKey, evType)
	}
}
// extractResult converts a stored network result into a PaymentResult. A
// fulfill yields the preimage, a fail yields the parsed payment error, and any
// other message type is reported as an error.
func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult,
	paymentID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) {

	if settle, ok := n.msg.(*lnwire.UpdateFulfillHTLC); ok {
		return &PaymentResult{
			Preimage: settle.PaymentPreimage,
		}, nil
	}

	if fail, ok := n.msg.(*lnwire.UpdateFailHTLC); ok {
		failErr := s.parseFailedPayment(
			deobfuscator, paymentID, paymentHash, n.unencrypted,
			n.isResolution, fail,
		)

		return &PaymentResult{
			Error: failErr,
		}, nil
	}

	return nil, fmt.Errorf("received unknown response type: %T",
		n.msg)
}
// parseFailedPayment turns the reason carried by an UpdateFailHTLC into the
// error reported for a locally initiated payment.
//
//   - unencrypted: the reason is decoded directly as a wire failure message.
//     A decode error is logged and mapped to a temporary channel failure with
//     detail OutgoingFailureDecodeError.
//   - isResolution with a nil reason: reported as a permanent channel failure
//     with detail OutgoingFailureOnChainTimeout.
//   - otherwise: the reason is decrypted with deobfuscator; if that fails,
//     ErrUnreadableFailureMessage is returned.
//
// paymentID and paymentHash are only used for logging.
func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter,
	paymentID uint64, paymentHash lntypes.Hash, unencrypted,
	isResolution bool, htlc *lnwire.UpdateFailHTLC) error {

	switch {
	case unencrypted:
		r := bytes.NewReader(htlc.Reason)
		failureMsg, err := lnwire.DecodeFailure(r, 0)
		if err != nil {
			linkError := NewDetailedLinkError(
				lnwire.NewTemporaryChannelFailure(nil),
				OutgoingFailureDecodeError,
			)

			log.Errorf("%v: (hash=%v, pid=%d): %v",
				linkError.FailureDetail.FailureString(),
				paymentHash, paymentID, err)

			return linkError
		}

		return NewLinkError(failureMsg)

	case isResolution && htlc.Reason == nil:
		linkError := NewDetailedLinkError(
			&lnwire.FailPermanentChannelFailure{},
			OutgoingFailureOnChainTimeout,
		)

		// Infof, not Info: the message is a format string and its
		// verbs must be expanded.
		log.Infof("%v: hash=%v, pid=%d",
			linkError.FailureDetail.FailureString(),
			paymentHash, paymentID)

		return linkError

	default:
		failure, err := deobfuscator.DecryptError(htlc.Reason)
		if err != nil {
			log.Errorf("unable to de-obfuscate onion failure "+
				"(hash=%v, pid=%d): %v",
				paymentHash, paymentID, err)

			return ErrUnreadableFailureMessage
		}

		return failure
	}
}
// handlePacketForward is the switch's core dispatch routine; in this file it
// is called from htlcForwarder.
//
// Adds are either handed to handleLocalDispatch (local origin) or forwarded
// to an eligible link to the target peer. Settles and fails close the
// corresponding circuit and are sent back to the incoming link's mailbox, or
// to handleLocalDispatch for local payments. Any other message type yields an
// error.
func (s *Switch) handlePacketForward(packet *htlcPacket) error {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
// With RejectHTLC set, every add that did not originate locally is
// failed back immediately.
if s.cfg.RejectHTLC && packet.incomingChanID != hop.Source {
failure := NewDetailedLinkError(
&lnwire.FailChannelDisabled{},
OutgoingFailureForwardsDisabled,
)
return s.failAddPacket(packet, failure)
}
if packet.incomingChanID == hop.Source {
return s.handleLocalDispatch(packet)
}
// Reject forwards that leave over the channel they arrived on unless
// circular routes are allowed.
linkErr := checkCircularForward(
packet.incomingChanID, packet.outgoingChanID,
s.cfg.AllowCircularRoute, htlc.PaymentHash,
)
if linkErr != nil {
return s.failAddPacket(packet, linkErr)
}
s.indexMtx.RLock()
targetLink, err := s.getLinkByShortID(packet.outgoingChanID)
if err != nil {
// The read lock must be released on this early-return path too.
s.indexMtx.RUnlock()
log.Debugf("unable to find link with "+
"destination %v", packet.outgoingChanID)
linkError := NewLinkError(
&lnwire.FailUnknownNextPeer{},
)
return s.failAddPacket(packet, linkError)
}
// Collect every link to the same peer: any of them may carry the HTLC,
// not just the one named in the packet. The lookup error is ignored; a
// miss simply leaves interfaceLinks empty.
targetPeerKey := targetLink.Peer().PubKey()
interfaceLinks, _ := s.getLinks(targetPeerKey)
s.indexMtx.RUnlock()
// linkErrs records why each candidate link was rejected.
linkErrs := make(map[lnwire.ShortChannelID]*LinkError)
var destination ChannelLink
for _, link := range interfaceLinks {
var failure *LinkError
if !link.EligibleToForward() {
failure = NewDetailedLinkError(
&lnwire.FailUnknownNextPeer{},
OutgoingFailureLinkNotEligible,
)
} else {
currentHeight := atomic.LoadUint32(&s.bestHeight)
failure = link.CheckHtlcForward(
htlc.PaymentHash, packet.incomingAmount,
packet.amount, packet.incomingTimeout,
packet.outgoingTimeout, currentHeight,
)
}
// The first link that accepts the HTLC wins.
if failure == nil {
destination = link
break
}
linkErrs[link.ShortChanID()] = failure
}
if destination == nil {
// No link accepted the HTLC: prefer the error recorded for the link
// the sender actually requested, otherwise fall back to an unknown
// next peer failure.
linkErr, ok := linkErrs[packet.outgoingChanID]
if !ok {
linkErr = NewLinkError(
&lnwire.FailUnknownNextPeer{},
)
log.Warnf("unable to find err source for "+
"outgoing_link=%v, errors=%v",
packet.outgoingChanID, newLogClosure(func() string {
return spew.Sdump(linkErrs)
}))
}
log.Tracef("incoming HTLC(%x) violated "+
"target outgoing link (id=%v) policy: %v",
htlc.PaymentHash[:], packet.outgoingChanID,
linkErr)
return s.failAddPacket(packet, linkErr)
}
// Record the link that was actually chosen, which may differ from the
// one originally requested.
packet.outgoingChanID = destination.ShortChanID()
return destination.HandleSwitchPacket(packet)
case *lnwire.UpdateFailHTLC, *lnwire.UpdateFulfillHTLC:
// closeCircuit fills in the packet's incoming side from the circuit. A
// nil circuit with a nil error means there is nothing to forward.
circuit, err := s.closeCircuit(packet)
if err != nil {
return err
}
if circuit == nil {
return nil
}
fail, isFail := htlc.(*lnwire.UpdateFailHTLC)
if isFail && !packet.hasSource {
switch {
// Without an error encrypter the reason is passed on unchanged.
case circuit.ErrorEncrypter == nil:
case packet.isResolution:
// Resolution failures are reported as a permanent channel failure
// encrypted as if this node were the first hop. An encryption
// error is only logged; processing continues.
var err error
failure := &lnwire.FailPermanentChannelFailure{}
fail.Reason, err = circuit.ErrorEncrypter.EncryptFirstHop(
failure,
)
if err != nil {
err = fmt.Errorf("unable to obfuscate "+
"error: %v", err)
log.Error(err)
}
case packet.convertedError:
log.Infof("Converting malformed HTLC error "+
"for circuit for Circuit(%x: "+
"(%s, %d) <-> (%s, %d))", packet.circuit.PaymentHash,
packet.incomingChanID, packet.incomingHTLCID,
packet.outgoingChanID, packet.outgoingHTLCID)
fail.Reason = circuit.ErrorEncrypter.EncryptMalformedError(
fail.Reason,
)
default:
// Regular failure from downstream: add this hop's encryption
// layer.
fail.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
fail.Reason,
)
}
} else if !isFail && circuit.Outgoing != nil {
// A settle on a fully opened circuit: record a forwarding event
// unless the payment originated locally.
localHTLC := packet.incomingChanID == hop.Source
if !localHTLC {
s.fwdEventMtx.Lock()
s.pendingFwdingEvents = append(
s.pendingFwdingEvents,
channeldb.ForwardingEvent{
Timestamp: time.Now(),
IncomingChanID: circuit.Incoming.ChanID,
OutgoingChanID: circuit.Outgoing.ChanID,
AmtIn: circuit.IncomingAmount,
AmtOut: circuit.OutgoingAmount,
},
)
s.fwdEventMtx.Unlock()
}
}
if packet.incomingChanID == hop.Source {
return s.handleLocalDispatch(packet)
}
// Hand the response to the mailbox of the incoming link.
return s.mailOrchestrator.Deliver(packet.incomingChanID, packet)
default:
return errors.New("wrong update type")
}
}
// checkCircularForward decides whether a forward that may leave over the
// channel it arrived on is acceptable. It returns nil when the channels differ
// or circular routes are allowed, and a temporary channel failure with detail
// OutgoingFailureCircularRoute otherwise.
func checkCircularForward(incoming, outgoing lnwire.ShortChannelID,
	allowCircular bool, paymentHash lntypes.Hash) *LinkError {

	switch {
	// Different channels: not circular.
	case incoming != outgoing:
		return nil

	// Circular, but explicitly permitted.
	case allowCircular:
		log.Debugf("allowing circular route over link: %v "+
			"(payment hash: %x)", incoming, paymentHash)
		return nil

	default:
		return NewDetailedLinkError(
			lnwire.NewTemporaryChannelFailure(nil),
			OutgoingFailureCircularRoute,
		)
	}
}
// failAddPacket fails an add back to the link it arrived on. The failure is
// encrypted with the packet's obfuscator, wrapped in an UpdateFailHTLC packet
// and delivered to the incoming link's mailbox.
//
// On success the supplied failure itself is returned as the error, so the
// result is non-nil in every case; encryption and delivery problems are
// returned as wrapped errors instead.
func (s *Switch) failAddPacket(packet *htlcPacket, failure *LinkError) error {
	encrypted, encErr := packet.obfuscator.EncryptFirstHop(
		failure.WireMessage(),
	)
	if encErr != nil {
		wrapped := fmt.Errorf("unable to obfuscate "+
			"error: %v", encErr)
		log.Error(wrapped)
		return wrapped
	}

	log.Error(failure.Error())

	response := &htlcPacket{
		sourceRef:       packet.sourceRef,
		incomingChanID:  packet.incomingChanID,
		incomingHTLCID:  packet.incomingHTLCID,
		outgoingChanID:  packet.outgoingChanID,
		outgoingHTLCID:  packet.outgoingHTLCID,
		incomingAmount:  packet.incomingAmount,
		amount:          packet.amount,
		incomingTimeout: packet.incomingTimeout,
		outgoingTimeout: packet.outgoingTimeout,
		circuit:         packet.circuit,
		linkFailure:     failure,
		htlc: &lnwire.UpdateFailHTLC{
			Reason: encrypted,
		},
	}

	deliverErr := s.mailOrchestrator.Deliver(
		response.incomingChanID, response,
	)
	if deliverErr == nil {
		return failure
	}

	wrapped := fmt.Errorf("source chanid=%v unable to "+
		"handle switch packet: %v",
		packet.incomingChanID, deliverErr)
	log.Error(wrapped)

	return wrapped
}
// closeCircuit closes the circuit a settle/fail packet belongs to.
//
// Packets that already carry their source are failed by incoming key.
// Otherwise the circuit is closed by outgoing key and, on success, the
// packet's incoming channel, HTLC ID, circuit and source reference are filled
// in from it.
//
// When the circuit is unknown, the packet's destination reference (if any) is
// queued for a later batch ack; a settle then yields (nil, nil) while a fail
// yields an error. All other errors are returned unchanged.
func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) {
	if pkt.hasSource {
		failed, err := s.circuits.FailCircuit(pkt.inKey())
		if err != nil {
			return nil, err
		}

		return failed, nil
	}

	closed, err := s.circuits.CloseCircuit(pkt.outKey())

	if err == nil {
		pkt.incomingChanID = closed.Incoming.ChanID
		pkt.incomingHTLCID = closed.Incoming.HtlcID
		pkt.circuit = closed
		pkt.sourceRef = &closed.AddRef

		kind := "SETTLE"
		if _, isFail := pkt.htlc.(*lnwire.UpdateFailHTLC); isFail {
			kind = "FAIL"
		}

		log.Debugf("Closed completed %s circuit for %x: "+
			"(%s, %d) <-> (%s, %d)", kind, pkt.circuit.PaymentHash,
			pkt.incomingChanID, pkt.incomingHTLCID,
			pkt.outgoingChanID, pkt.outgoingHTLCID)

		return closed, nil
	}

	// Anything other than an unknown circuit is passed straight back.
	if err != ErrUnknownCircuit {
		return nil, err
	}

	// Unknown circuit: remember the reference so it can be acked later.
	if pkt.destRef != nil {
		s.pendingSettleFails = append(s.pendingSettleFails, *pkt.destRef)
	}

	if _, isSettle := pkt.htlc.(*lnwire.UpdateFulfillHTLC); isSettle {
		return nil, nil
	}

	failErr := fmt.Errorf("unable to find target channel "+
		"for HTLC fail: channel ID = %s, "+
		"HTLC ID = %d", pkt.outgoingChanID,
		pkt.outgoingHTLCID)
	log.Error(failErr)

	return nil, failErr
}
// ackSettleFail acks the given settle/fail references through the switch
// packager inside a batched database transaction.
func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error {
	ackRefs := func(tx kvdb.RwTx) error {
		return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...)
	}

	return kvdb.Batch(s.cfg.DB.Backend, ackRefs)
}
// teardownCircuit deletes the circuit identified by the packet's incoming key.
// Only settle and fail packets are accepted; any other message type yields an
// error. Circuits with and without a keystone are deleted the same way and
// differ only in what is logged. A deletion error is logged and returned.
func (s *Switch) teardownCircuit(pkt *htlcPacket) error {
	var pktType string
	switch htlc := pkt.htlc.(type) {
	case *lnwire.UpdateFulfillHTLC:
		pktType = "SETTLE"
	case *lnwire.UpdateFailHTLC:
		pktType = "FAIL"
	default:
		err := fmt.Errorf("cannot tear down packet of type: %T", htlc)
		// Log the error value directly rather than passing its text as
		// a format string.
		log.Error(err)
		return err
	}

	switch {
	case pkt.circuit.HasKeystone():
		log.Debugf("Tearing down open circuit with %s pkt, removing circuit=%v "+
			"with keystone=%v", pktType, pkt.inKey(), pkt.outKey())

		err := s.circuits.DeleteCircuits(pkt.inKey())
		if err != nil {
			log.Warnf("Failed to tear down open circuit (%s, %d) <-> (%s, %d) "+
				"with payment_hash-%v using %s pkt",
				pkt.incomingChanID, pkt.incomingHTLCID,
				pkt.outgoingChanID, pkt.outgoingHTLCID,
				pkt.circuit.PaymentHash, pktType)
			return err
		}

		log.Debugf("Closed completed %s circuit for %x: "+
			"(%s, %d) <-> (%s, %d)", pktType, pkt.circuit.PaymentHash,
			pkt.incomingChanID, pkt.incomingHTLCID,
			pkt.outgoingChanID, pkt.outgoingHTLCID)

	default:
		log.Debugf("Tearing down incomplete circuit with %s for inkey=%v",
			pktType, pkt.inKey())

		err := s.circuits.DeleteCircuits(pkt.inKey())
		if err != nil {
			log.Warnf("Failed to tear down pending %s circuit for %x: "+
				"(%s, %d)", pktType, pkt.circuit.PaymentHash,
				pkt.incomingChanID, pkt.incomingHTLCID)
			return err
		}

		log.Debugf("Removed pending onion circuit for %x: "+
			"(%s, %d)", pkt.circuit.PaymentHash,
			pkt.incomingChanID, pkt.incomingHTLCID)
	}

	return nil
}
// CloseLink submits a close request for the channel at chanPoint and returns
// the channels on which updates and an error are reported. If the switch is
// shutting down, ErrSwitchExiting is placed on the error channel and the
// update channel is closed.
func (s *Switch) CloseLink(chanPoint *wire.OutPoint,
	closeType ChannelCloseType, targetFeePerKw chainfee.SatPerKWeight,
	deliveryScript lnwire.DeliveryAddress) (chan interface{}, chan error) {

	updates := make(chan interface{}, 2)
	errs := make(chan error, 1)

	request := &ChanClose{
		CloseType:      closeType,
		ChanPoint:      chanPoint,
		Updates:        updates,
		TargetFeePerKw: targetFeePerKw,
		DeliveryScript: deliveryScript,
		Err:            errs,
	}

	select {
	case s.chanCloseRequests <- request:
	case <-s.quit:
		errs <- ErrSwitchExiting
		close(updates)
	}

	return updates, errs
}
// htlcForwarder is the switch's main event loop; Start launches it as a
// goroutine. It serializes handling of block epochs, channel close requests,
// contract resolutions, packets arriving over htlcPlex and the three
// periodic tickers. On exit it removes and stops every link and flushes any
// pending forwarding events.
func (s *Switch) htlcForwarder() {
defer s.wg.Done()
defer func() {
s.blockEpochStream.Cancel()
// Remove all links from the indexes while holding the lock, but stop
// them only after the lock has been released.
var linksToStop []ChannelLink
s.indexMtx.Lock()
for _, link := range s.linkIndex {
activeLink := s.removeLink(link.ChanID())
if activeLink == nil {
log.Errorf("unable to remove ChannelLink(%v) "+
"on stop", link.ChanID())
continue
}
linksToStop = append(linksToStop, activeLink)
}
for _, link := range s.pendingLinkIndex {
pendingLink := s.removeLink(link.ChanID())
if pendingLink == nil {
log.Errorf("unable to remove ChannelLink(%v) "+
"on stop", link.ChanID())
continue
}
linksToStop = append(linksToStop, pendingLink)
}
s.indexMtx.Unlock()
// Stop the links concurrently and wait for all of them.
var wg sync.WaitGroup
for _, link := range linksToStop {
wg.Add(1)
go func(l ChannelLink) {
defer wg.Done()
l.Stop()
}(link)
}
wg.Wait()
// Persist whatever forwarding events are still buffered.
if err := s.FlushForwardingEvents(); err != nil {
log.Errorf("unable to flush forwarding events: %v", err)
}
}()
// Running totals as of the last stats log line; used to compute deltas.
var (
totalNumUpdates uint64
totalSatSent btcutil.Amount
totalSatRecv btcutil.Amount
)
s.cfg.LogEventTicker.Resume()
defer s.cfg.LogEventTicker.Stop()
s.cfg.FwdEventTicker.Resume()
defer s.cfg.FwdEventTicker.Stop()
// The ack ticker is resumed on demand inside the loop.
defer s.cfg.AckEventTicker.Stop()
out:
for {
// Only run the ack ticker while there is something to ack.
if len(s.pendingSettleFails) > 0 {
s.cfg.AckEventTicker.Resume()
}
select {
case blockEpoch, ok := <-s.blockEpochStream.Epochs:
// A closed epoch channel ends the loop.
if !ok {
break out
}
atomic.StoreUint32(&s.bestHeight, uint32(blockEpoch.Height))
case req := <-s.chanCloseRequests:
chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint)
s.indexMtx.RLock()
link, ok := s.linkIndex[chanID]
if !ok {
s.indexMtx.RUnlock()
req.Err <- fmt.Errorf("no peer for channel with "+
"chan_id=%x", chanID[:])
continue
}
s.indexMtx.RUnlock()
peerPub := link.Peer().PubKey()
log.Debugf("Requesting local channel close: peer=%v, "+
"chan_id=%x", link.Peer(), chanID[:])
// The close handler runs in its own goroutine, which is not tracked
// by s.wg.
go s.cfg.LocalChannelClose(peerPub[:], req)
case resolutionMsg := <-s.resolutionMsgs:
// Turn the resolution into a settle or fail packet on the outgoing
// side and push it through the regular forwarding path.
pkt := &htlcPacket{
outgoingChanID: resolutionMsg.SourceChan,
outgoingHTLCID: resolutionMsg.HtlcIndex,
isResolution: true,
}
if resolutionMsg.Failure != nil {
pkt.htlc = &lnwire.UpdateFailHTLC{}
} else {
// NOTE(review): PreImage is dereferenced without a nil check; this
// assumes a message without a Failure always carries a preimage —
// confirm with the sender.
pkt.htlc = &lnwire.UpdateFulfillHTLC{
PaymentPreimage: *resolutionMsg.PreImage,
}
}
log.Infof("Received outside contract resolution, "+
"mapping to: %v", spew.Sdump(pkt))
err := s.handlePacketForward(pkt)
if err != nil {
log.Errorf("Unable to forward resolution msg: %v", err)
}
// Unblock ProcessContractResolution regardless of the outcome.
close(resolutionMsg.doneChan)
case cmd := <-s.htlcPlex:
cmd.err <- s.handlePacketForward(cmd.pkt)
case <-s.cfg.FwdEventTicker.Ticks():
// Flush in a separate goroutine so the loop is not blocked on the
// forwarding log.
s.wg.Add(1)
go func() {
defer s.wg.Done()
if err := s.FlushForwardingEvents(); err != nil {
log.Errorf("unable to flush "+
"forwarding events: %v", err)
}
}()
case <-s.cfg.LogEventTicker.Ticks():
prevSatSent := totalSatSent
prevSatRecv := totalSatRecv
prevNumUpdates := totalNumUpdates
var (
newNumUpdates uint64
newSatSent btcutil.Amount
newSatRecv btcutil.Amount
)
// Sum the current stats over all live links.
s.indexMtx.RLock()
for _, link := range s.linkIndex {
updates, sent, recv := link.Stats()
newNumUpdates += updates
newSatSent += sent.ToSatoshis()
newSatRecv += recv.ToSatoshis()
}
s.indexMtx.RUnlock()
var (
diffNumUpdates uint64
diffSatSent btcutil.Amount
diffSatRecv btcutil.Amount
)
// With no previous totals the new sums are the delta.
if prevNumUpdates == 0 {
diffNumUpdates = newNumUpdates
diffSatSent = newSatSent
diffSatRecv = newSatRecv
} else {
diffNumUpdates = newNumUpdates - prevNumUpdates
diffSatSent = newSatSent - prevSatSent
diffSatRecv = newSatRecv - prevSatRecv
}
// Nothing happened since the last tick: stay quiet.
if diffNumUpdates == 0 {
continue
}
// The unsigned subtraction wrapped around, i.e. the new sum is smaller
// than the recorded total. Reset the totals without logging.
if int64(diffNumUpdates) < 0 {
totalNumUpdates = newNumUpdates
totalSatSent = newSatSent
totalSatRecv = newSatRecv
continue
}
// The message and the divisor hard-code a 10 second interval.
log.Debugf("Sent %d satoshis and received %d satoshis "+
"in the last 10 seconds (%f tx/sec)",
diffSatSent, diffSatRecv,
float64(diffNumUpdates)/10)
totalNumUpdates += diffNumUpdates
totalSatSent += diffSatSent
totalSatRecv += diffSatRecv
case <-s.cfg.AckEventTicker.Ticks():
// Pause the ticker when nothing is pending; the top of the loop
// resumes it when needed.
if len(s.pendingSettleFails) == 0 {
s.cfg.AckEventTicker.Pause()
continue
}
// On failure the batch is kept and retried on the next tick.
if err := s.ackSettleFail(s.pendingSettleFails...); err != nil {
log.Errorf("Unable to ack batch of settle/fails: %v", err)
continue
}
log.Tracef("Acked %d settle fails: %v", len(s.pendingSettleFails),
newLogClosure(func() string {
return spew.Sdump(s.pendingSettleFails)
}))
// Empty the batch while keeping its backing storage.
s.pendingSettleFails = s.pendingSettleFails[:0]
case <-s.quit:
return
}
}
}
// Start launches the switch. It may only succeed once; further calls return
// an error. It registers for block epoch notifications, starts the
// htlcForwarder goroutine and reforwards pending settle/fail responses. If
// reforwarding fails, the switch is stopped again and the error returned.
func (s *Switch) Start() error {
	alreadyStarted := !atomic.CompareAndSwapInt32(&s.started, 0, 1)
	if alreadyStarted {
		log.Warn("Htlc Switch already started")
		return errors.New("htlc switch already started")
	}

	log.Infof("Starting HTLC Switch")

	epochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil)
	if err != nil {
		return err
	}
	s.blockEpochStream = epochs

	s.wg.Add(1)
	go s.htlcForwarder()

	err = s.reforwardResponses()
	if err == nil {
		return nil
	}

	s.Stop()
	log.Errorf("unable to reforward responses: %v", err)

	return err
}
// reforwardResponses loads the forwarding packages of every open channel and
// reforwards their settle/fail responses. Channels whose short channel ID is
// hop.Source and channels marked pending are skipped. The first load error
// aborts the process and is returned.
func (s *Switch) reforwardResponses() error {
	channels, err := s.cfg.DB.FetchAllOpenChannels()
	if err != nil {
		return err
	}

	for _, channel := range channels {
		scid := channel.ShortChanID()

		// Nothing to reforward for channels that are not live yet.
		if scid == hop.Source || channel.IsPending {
			continue
		}

		pkgs, err := s.loadChannelFwdPkgs(scid)
		if err != nil {
			log.Errorf("unable to load forwarding "+
				"packages for %v: %v", scid, err)
			return err
		}

		s.reforwardSettleFails(pkgs)
	}

	return nil
}
// loadChannelFwdPkgs loads the forwarding packages stored for the given
// source channel within a single database update transaction.
func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
	var pkgs []*channeldb.FwdPkg

	load := func(tx kvdb.RwTx) error {
		loaded, err := s.cfg.SwitchPackager.LoadChannelFwdPkgs(
			tx, source,
		)
		pkgs = loaded

		return err
	}

	err := kvdb.Update(s.cfg.DB, load)
	if err != nil {
		return nil, err
	}

	return pkgs, nil
}
// reforwardSettleFails rebuilds settle and fail packets from the given
// forwarding packages and pushes them back through the switch. Entries already
// marked in a package's settle/fail filter are skipped, as are entries that
// are neither settles nor fails. A package whose updates cannot be converted
// is logged and skipped. Forwarding errors are only logged, in a separate
// goroutine per package.
func (s *Switch) reforwardSettleFails(fwdPkgs []*channeldb.FwdPkg) {
	for _, pkg := range fwdPkgs {
		descs, err := lnwallet.PayDescsFromRemoteLogUpdates(
			pkg.Source, pkg.Height, pkg.SettleFails,
		)
		if err != nil {
			log.Errorf("Unable to process remote log updates: %v",
				err)
			continue
		}

		batch := make([]*htlcPacket, 0, len(descs))
		for idx, desc := range descs {
			// Already handled in a previous run.
			if pkg.SettleFailFilter.Contains(uint16(idx)) {
				continue
			}

			response := &htlcPacket{
				outgoingChanID: pkg.Source,
				outgoingHTLCID: desc.ParentIndex,
				destRef:        desc.DestRef,
			}

			switch desc.EntryType {
			case lnwallet.Settle:
				response.htlc = &lnwire.UpdateFulfillHTLC{
					PaymentPreimage: desc.RPreimage,
				}

			case lnwallet.Fail:
				response.htlc = &lnwire.UpdateFailHTLC{
					Reason: lnwire.OpaqueReason(desc.FailReason),
				}

			default:
				continue
			}

			batch = append(batch, response)
		}

		// There is no link to signal shutdown here, hence the nil quit
		// channel.
		results := s.ForwardPackets(nil, batch...)
		go handleBatchFwdErrs(results, log)
	}
}
// handleBatchFwdErrs drains errChan until it is closed, logging every non-nil
// error through l.
func handleBatchFwdErrs(errChan chan error, l btclog.Logger) {
	for fwdErr := range errChan {
		if fwdErr != nil {
			l.Errorf("Unhandled error while reforwarding htlc "+
				"settle/fail over htlcswitch: %v", fwdErr)
		}
	}
}
// Stop shuts the switch down. It may only succeed once; further calls return
// an error. It closes the quit channel, waits for all tracked goroutines to
// finish and then stops the mail orchestrator.
func (s *Switch) Stop() error {
	alreadyStopped := !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1)
	if alreadyStopped {
		log.Warn("Htlc Switch already stopped")
		return errors.New("htlc switch already shutdown")
	}

	log.Infof("HTLC Switch shutting down")

	// Signal shutdown, then wait for everything tracked by wg.
	close(s.quit)
	s.wg.Wait()

	s.mailOrchestrator.Stop()

	return nil
}
// AddLink registers a link with the switch, attaches its mailbox and starts
// it. A link whose channel ID is already registered is rejected. Links whose
// short channel ID is still hop.Source are kept in the pending index; all
// others become live and have their mailbox bound to the short channel ID.
func (s *Switch) AddLink(link ChannelLink) error {
	s.indexMtx.Lock()
	defer s.indexMtx.Unlock()

	chanID := link.ChanID()

	// A successful lookup means the link is already known.
	if _, err := s.getLink(chanID); err == nil {
		return fmt.Errorf("unable to add ChannelLink(%v), already "+
			"active", chanID)
	}

	scid := link.ShortChanID()

	box := s.mailOrchestrator.GetOrCreateMailBox(chanID, scid)
	link.AttachMailBox(box)

	if err := link.Start(); err != nil {
		s.removeLink(chanID)
		return err
	}

	if scid == hop.Source {
		log.Infof("Adding pending link chan_id=%v, short_chan_id=%v",
			chanID, scid)

		s.pendingLinkIndex[chanID] = link

		return nil
	}

	log.Infof("Adding live link chan_id=%v, short_chan_id=%v",
		chanID, scid)

	s.addLiveLink(link)
	s.mailOrchestrator.BindLiveShortChanID(
		box, chanID, scid,
	)

	return nil
}
// addLiveLink records the link in the channel ID index, the short channel ID
// index and the per-peer interface index. It does no locking of its own; the
// callers in this file hold indexMtx.
func (s *Switch) addLiveLink(link ChannelLink) {
	s.linkIndex[link.ChanID()] = link
	s.forwardingIndex[link.ShortChanID()] = link

	peerKey := link.Peer().PubKey()

	// Create the peer's link set on first use.
	peerLinks, known := s.interfaceIndex[peerKey]
	if !known {
		peerLinks = make(map[lnwire.ChannelID]ChannelLink)
		s.interfaceIndex[peerKey] = peerLinks
	}

	peerLinks[link.ChanID()] = link
}
// GetLink returns the live or pending link registered under chanID, or
// ErrChannelLinkNotFound. The lookup is done under the index read lock.
func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) {
	s.indexMtx.RLock()
	defer s.indexMtx.RUnlock()

	link, err := s.getLink(chanID)

	return link, err
}
// getLink looks chanID up in the live index first and the pending index
// second. It returns ErrChannelLinkNotFound if neither contains it. It does no
// locking of its own.
func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
	if live, ok := s.linkIndex[chanID]; ok {
		return live, nil
	}

	if pending, ok := s.pendingLinkIndex[chanID]; ok {
		return pending, nil
	}

	return nil, ErrChannelLinkNotFound
}
// getLinkByShortID returns the live link registered under the given short
// channel ID, or ErrChannelLinkNotFound. It does no locking of its own.
func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, error) {
	if link, found := s.forwardingIndex[chanID]; found {
		return link, nil
	}

	return nil, ErrChannelLinkNotFound
}
// HasActiveLink reports whether a live link exists for chanID and that link is
// currently eligible to forward.
func (s *Switch) HasActiveLink(chanID lnwire.ChannelID) bool {
	s.indexMtx.RLock()
	defer s.indexMtx.RUnlock()

	link, found := s.linkIndex[chanID]

	return found && link.EligibleToForward()
}
// RemoveLink removes the link registered under chanID from the switch's
// indexes and, if one was found, stops it. The link is stopped after the index
// lock has been released.
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) {
	s.indexMtx.Lock()
	removed := s.removeLink(chanID)
	s.indexMtx.Unlock()

	if removed == nil {
		return
	}

	removed.Stop()
}
// removeLink deletes the link registered under chanID from every index and
// returns it, or returns nil if no such link exists. The link is not stopped
// here. It does no locking of its own.
func (s *Switch) removeLink(chanID lnwire.ChannelID) ChannelLink {
	log.Infof("Removing channel link with ChannelID(%v)", chanID)

	target, err := s.getLink(chanID)
	if err != nil {
		return nil
	}

	delete(s.pendingLinkIndex, target.ChanID())
	delete(s.linkIndex, target.ChanID())
	delete(s.forwardingIndex, target.ShortChanID())

	peerKey := target.Peer().PubKey()
	peerLinks, known := s.interfaceIndex[peerKey]
	if !known {
		return target
	}

	// Drop the peer's entry entirely once its last link is gone.
	delete(peerLinks, target.ChanID())
	if len(peerLinks) == 0 {
		delete(s.interfaceIndex, peerKey)
	}

	return target
}
// UpdateShortChanID promotes a pending link to a live link once its short
// channel ID is known. The link refreshes its own short channel ID, is moved
// from the pending index to the live indexes, and its mailbox is bound to the
// new short channel ID.
//
// An error is returned if no pending link exists for chanID, if the link
// fails to update its short channel ID, or if the updated ID is still
// hop.Source. In all error cases the link stays in the pending index.
func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID) error {
	s.indexMtx.Lock()
	defer s.indexMtx.Unlock()

	link, ok := s.pendingLinkIndex[chanID]
	if !ok {
		return fmt.Errorf("link %v not found", chanID)
	}

	// Captured before the update purely for the log line below.
	oldShortChanID := link.ShortChanID()

	shortChanID, err := link.UpdateShortChanID()
	if err != nil {
		return err
	}

	// A link must have a real short channel ID to go live. The message
	// needs the separating space between the concatenated literals.
	if shortChanID == hop.Source {
		return fmt.Errorf("refusing trivial short_chan_id for chan_id=%v "+
			"live link", chanID)
	}

	log.Infof("Updated short_chan_id for ChannelLink(%v): old=%v, new=%v",
		chanID, oldShortChanID, shortChanID)

	delete(s.pendingLinkIndex, chanID)
	s.addLiveLink(link)

	mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID, shortChanID)
	s.mailOrchestrator.BindLiveShortChanID(
		mailbox, chanID, shortChanID,
	)

	return nil
}
// GetLinksByInterface returns all live links to the peer with the given
// public key, or ErrNoLinksFound. The lookup is done under the index read
// lock.
func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelLink, error) {
	s.indexMtx.RLock()
	defer s.indexMtx.RUnlock()

	links, err := s.getLinks(hop)

	return links, err
}
// getLinks returns a fresh slice holding every link registered for the given
// peer key, or ErrNoLinksFound if the peer has no entry. The order of the
// links follows map iteration and is therefore unspecified. It does no
// locking of its own.
func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
	peerLinks, known := s.interfaceIndex[destination]
	if !known {
		return nil, ErrNoLinksFound
	}

	result := make([]ChannelLink, len(peerLinks))
	next := 0
	for _, link := range peerLinks {
		result[next] = link
		next++
	}

	return result, nil
}
// CircuitModifier exposes the switch's circuit map through the
// CircuitModifier interface.
func (s *Switch) CircuitModifier() CircuitModifier {
	var modifier CircuitModifier = s.circuits

	return modifier
}
// commitCircuits forwards the given circuits to the circuit map's
// CommitCircuits and returns its results unchanged.
func (s *Switch) commitCircuits(circuits ...*PaymentCircuit) (
	*CircuitFwdActions, error) {

	actions, err := s.circuits.CommitCircuits(circuits...)

	return actions, err
}
// openCircuits forwards the given keystones to the circuit map's OpenCircuits
// and returns its error unchanged.
func (s *Switch) openCircuits(keystones ...Keystone) error {
	err := s.circuits.OpenCircuits(keystones...)

	return err
}
// deleteCircuits forwards the given incoming keys to the circuit map's
// DeleteCircuits and returns its error unchanged.
func (s *Switch) deleteCircuits(inKeys ...CircuitKey) error {
	err := s.circuits.DeleteCircuits(inKeys...)

	return err
}
// FlushForwardingEvents writes all buffered forwarding events to the
// forwarding log. The buffer is copied and emptied under the event mutex; the
// write itself happens after the mutex has been released. It returns nil
// without touching the log when nothing is buffered.
func (s *Switch) FlushForwardingEvents() error {
	// takePending snapshots and clears the buffer while holding the lock.
	takePending := func() []channeldb.ForwardingEvent {
		s.fwdEventMtx.Lock()
		defer s.fwdEventMtx.Unlock()

		pending := len(s.pendingFwdingEvents)
		if pending == 0 {
			return nil
		}

		snapshot := make([]channeldb.ForwardingEvent, pending)
		copy(snapshot, s.pendingFwdingEvents)
		s.pendingFwdingEvents = s.pendingFwdingEvents[:0]

		return snapshot
	}

	events := takePending()
	if events == nil {
		return nil
	}

	return s.cfg.FwdingLog.AddForwardingEvents(events)
}
// BestHeight returns the switch's current best block height, read atomically.
func (s *Switch) BestHeight() uint32 {
	height := atomic.LoadUint32(&s.bestHeight)

	return height
}