package htlcswitch
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"
"time"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/fastsha256"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker"
)
// zeroCircuit is an empty channeldb.CircuitKey literal.
var zeroCircuit = channeldb.CircuitKey{}
// genPreimage fills a 32-byte array from the crypto/rand reader and returns
// it. Any error reported by the read is returned alongside the array.
func genPreimage() ([32]byte, error) {
	var buf [32]byte
	_, err := io.ReadFull(rand.Reader, buf[:])
	return buf, err
}
// TestSwitchAddDuplicateLink checks that the switch rejects a second AddLink
// call for a link it already holds, both while the link still carries the
// empty short channel ID and after its live short channel ID has been set,
// and that the link can be added again once it has been removed.
func TestSwitchAddDuplicateLink(t *testing.T) {
	t.Parallel()

	alice, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	sw, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := sw.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer func() { _ = sw.Stop() }()

	aliceCID, _, aliceSCID, _ := genIDs()

	// The link is created with the empty short channel ID.
	link := newMockChannelLink(
		sw, aliceCID, lnwire.ShortChannelID{}, alice, false,
	)

	// First insertion succeeds, an immediate repeat must not.
	if err := sw.AddLink(link); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := sw.AddLink(link); err == nil {
		t.Fatalf("adding duplicate link should have failed")
	}

	// Give the link its live short channel ID and tell the switch.
	link.setLiveShortChanID(aliceSCID)
	err = sw.UpdateShortChanID(aliceCID)
	if err != nil {
		t.Fatalf("unable to update alice short_chan_id: %v", err)
	}

	// The link is still a duplicate under its live ID.
	if err := sw.AddLink(link); err == nil {
		t.Fatalf("adding duplicate link should have failed")
	}

	// After removal the same link may be added once more.
	sw.RemoveLink(aliceCID)
	if err := sw.AddLink(link); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
}
// TestSwitchHasActiveLink checks that HasActiveLink reports false for a link
// that still carries the empty short channel ID, still reports false once the
// live short channel ID is set but the link is marked ineligible, and reports
// true only after the link is marked eligible.
func TestSwitchHasActiveLink(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, _, aliceChanID, _ := genIDs()

	// Create the link with the empty short channel ID.
	pendingChanID := lnwire.ShortChannelID{}
	aliceChannelLink := newMockChannelLink(
		s, chanID1, pendingChanID, alicePeer, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	// While pending, the link must not be reported as active.
	if s.HasActiveLink(chanID1) {
		t.Fatalf("link should not be active yet, still pending")
	}

	// Assign the live short channel ID and notify the switch.
	aliceChannelLink.setLiveShortChanID(aliceChanID)
	err = s.UpdateShortChanID(chanID1)
	if err != nil {
		t.Fatalf("unable to update alice short_chan_id: %v", err)
	}

	// With the eligible flag cleared the link is still not active.
	aliceChannelLink.eligible = false
	if s.HasActiveLink(chanID1) {
		t.Fatalf("link should not be active yet, still ineligible")
	}

	// Once eligible, the link must be reported as active.
	aliceChannelLink.eligible = true
	if !s.HasActiveLink(chanID1) {
		t.Fatalf("link should be active now")
	}
}
// TestSwitchSendPending checks that forwarding an add toward a link that is
// only registered under the empty short channel ID fails with an unknown next
// peer error and opens no circuit, and that the packet is delivered to the
// link once its live short channel ID has been set via UpdateShortChanID.
func TestSwitchSendPending(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, _, aliceChanID, bobChanID := genIDs()

	// Create the link with the empty short channel ID.
	pendingChanID := lnwire.ShortChannelID{}
	aliceChannelLink := newMockChannelLink(
		s, chanID1, pendingChanID, alicePeer, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	// Build an add that targets alice's live short channel ID, which the
	// switch does not know about yet.
	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}
	rhash := fastsha256.Sum256(preimage[:])
	packet := &htlcPacket{
		incomingChanID: bobChanID,
		incomingHTLCID: 0,
		outgoingChanID: aliceChanID,
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// The forward must fail with an unknown next peer link error.
	err = s.forward(packet)
	linkErr, ok := err.(*LinkError)
	if !ok {
		t.Fatalf("expected link error, got: %T", err)
	}
	if linkErr.WireMessage().Code() != lnwire.CodeUnknownNextPeer {
		// Print the code's value; %T would only print its type name.
		t.Fatalf("expected fail unknown next peer, got: %v",
			linkErr.WireMessage().Code())
	}

	// Nothing may reach the link, and no circuit may have been opened.
	select {
	case <-aliceChannelLink.packets:
		t.Fatal("expected not to receive message")
	case <-time.After(time.Second):
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}

	// Assign the live short channel ID and notify the switch.
	aliceChannelLink.setLiveShortChanID(aliceChanID)
	err = s.UpdateShortChanID(chanID1)
	if err != nil {
		t.Fatalf("unable to update alice short_chan_id: %v", err)
	}

	// Retry under a new incoming HTLC ID, presumably so the retry is not
	// treated as a duplicate of the first attempt.
	packet.incomingHTLCID++
	if err := s.forward(packet); err != nil {
		t.Fatalf("unexpected forward failure: %v", err)
	}

	// This time the packet must arrive at alice's link.
	select {
	case <-aliceChannelLink.packets:
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to alice")
	}
}
// TestSwitchForward sends an add from alice's link to bob's link through the
// switch, verifies a circuit is opened for it, then sends the matching settle
// back and verifies it reaches alice's link and the circuit is removed.
func TestSwitchForward(t *testing.T) {
	t.Parallel()

	alice, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}
	bob, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	sw, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := sw.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer func() { _ = sw.Stop() }()

	// Register one link per peer with the switch.
	aliceCID, bobCID, aliceSCID, bobSCID := genIDs()
	aliceLink := newMockChannelLink(sw, aliceCID, aliceSCID, alice, true)
	bobLink := newMockChannelLink(sw, bobCID, bobSCID, bob, true)
	if err := sw.AddLink(aliceLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := sw.AddLink(bobLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Build the add request travelling from alice to bob.
	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}
	paymentHash := fastsha256.Sum256(preimage[:])
	addPkt := &htlcPacket{
		incomingChanID: aliceLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: paymentHash,
			Amount:      1,
		},
	}
	if err := sw.forward(addPkt); err != nil {
		t.Fatal(err)
	}

	// The add must show up on bob's link, where the circuit is completed.
	select {
	case <-bobLink.packets:
		err := bobLink.completeCircuit(addPkt)
		if err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s := sw.circuits.NumOpen(); s != 1 {
		t.Fatal("wrong amount of circuits")
	}
	if !sw.IsForwardedHTLC(bobLink.ShortChanID(), 0) {
		t.Fatal("htlc should be identified as forwarded")
	}

	// Now send the settle back from bob's side.
	settlePkt := &htlcPacket{
		outgoingChanID: bobLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}
	if err := sw.forward(settlePkt); err != nil {
		t.Fatal(err)
	}

	// The settle must arrive at alice's link, which deletes the circuit.
	select {
	case received := <-aliceLink.packets:
		err := aliceLink.deleteCircuit(received)
		if err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	if s := sw.circuits.NumOpen(); s != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchForwardFailAfterFullAdd forwards an add through the switch and
// completes its circuit, then restarts the switch on the same database. After
// the restart it checks that the circuit is still present, that a fail for
// the HTLC reaches alice's link and removes the circuit, and that a duplicate
// fail is rejected.
func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	// NOTE(review): tempPath is never removed after the test; cleaning it
	// up needs the os package, which this file does not import.
	tempPath, err := ioutil.TempDir("", "circuitdb")
	if err != nil {
		t.Fatalf("unable to create temporary path: %v", err)
	}

	cdb, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to open channeldb: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, bobPeer, true,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Build the add request travelling from alice to bob.
	preimage := [sha256.Size]byte{1}
	rhash := fastsha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// No circuits exist before the forward.
	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	if err := s.forward(ogPacket); err != nil {
		t.Fatal(err)
	}

	// After the forward there is one pending circuit, none open yet.
	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Receive the add on bob's link and complete the circuit.
	select {
	case packet := <-bobChannelLink.packets:
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Shut down the switch and database, then bring both back up on the
	// same path. t.Fatal is used so the error text is never interpreted
	// as a format string.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}
	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to reopen channeldb: %v", err)
	}
	// Deferred before s2.Stop so the database is closed only after the
	// restarted switch has been stopped.
	defer cdb2.Close()

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	if err != nil {
		t.Fatalf("unable reinit switch: %v", err)
	}
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, bobPeer, true,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// The circuit must still be there after the restart.
	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Send a fail for the HTLC from bob's side.
	fail := &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc:           &lnwire.UpdateFailHTLC{},
	}
	if err := s2.forward(fail); err != nil {
		t.Fatal(err)
	}

	// The fail must arrive at alice's link.
	select {
	case pkt := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// The circuit must be gone now.
	if s2.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Sending the same fail again must be rejected.
	if err := s2.forward(fail); err == nil {
		t.Fatalf("expected failure when sending duplicate fail " +
			"with no pending circuit")
	}
}
// TestSwitchForwardSettleAfterFullAdd forwards an add through the switch and
// completes its circuit, then restarts the switch on the same database. After
// the restart it checks that the circuit is still present, that a settle for
// the HTLC reaches alice's link and removes the circuit, and that forwarding
// the same settle a second time returns no error.
func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	// NOTE(review): tempPath is never removed after the test; cleaning it
	// up needs the os package, which this file does not import.
	tempPath, err := ioutil.TempDir("", "circuitdb")
	if err != nil {
		t.Fatalf("unable to create temporary path: %v", err)
	}

	cdb, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to open channeldb: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, bobPeer, true,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Build the add request travelling from alice to bob.
	preimage := [sha256.Size]byte{1}
	rhash := fastsha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// No circuits exist before the forward.
	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	if err := s.forward(ogPacket); err != nil {
		t.Fatal(err)
	}

	// After the forward there is one pending circuit, none open yet.
	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Receive the add on bob's link and complete the circuit.
	select {
	case packet := <-bobChannelLink.packets:
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Shut down the switch and database, then bring both back up on the
	// same path. t.Fatal is used so the error text is never interpreted
	// as a format string.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}
	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to reopen channeldb: %v", err)
	}
	// Deferred before s2.Stop so the database is closed only after the
	// restarted switch has been stopped.
	defer cdb2.Close()

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	if err != nil {
		t.Fatalf("unable reinit switch: %v", err)
	}
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, bobPeer, true,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// The circuit must still be there after the restart.
	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Send the settle for the HTLC from bob's side.
	settle := &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}
	if err := s2.forward(settle); err != nil {
		t.Fatal(err)
	}

	// The settle must arrive at alice's link.
	select {
	case packet := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete circuit with in key=%s: %v",
				packet.inKey(), err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// The circuit must be gone now.
	if s2.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Forwarding the same settle again is expected to return no error.
	if err := s2.forward(settle); err != nil {
		t.Fatalf("expected success when sending duplicate settle "+
			"with no pending circuit: %v", err)
	}
}
// TestSwitchForwardDropAfterFullAdd forwards an add through the switch and
// completes its circuit, then restarts the switch on the same database. After
// the restart it forwards the original add again and checks that the switch
// returns ErrDuplicateAdd and delivers nothing to either link.
func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	// NOTE(review): tempPath is never removed after the test; cleaning it
	// up needs the os package, which this file does not import.
	tempPath, err := ioutil.TempDir("", "circuitdb")
	if err != nil {
		t.Fatalf("unable to create temporary path: %v", err)
	}

	cdb, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to open channeldb: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, bobPeer, true,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Build the add request travelling from alice to bob.
	preimage := [sha256.Size]byte{1}
	rhash := fastsha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// No circuits exist before the forward.
	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	if err := s.forward(ogPacket); err != nil {
		t.Fatal(err)
	}

	// After the forward there is one pending circuit, none open yet.
	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Receive the add on bob's link and complete the circuit.
	select {
	case packet := <-bobChannelLink.packets:
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Shut down the switch and database, then bring both back up on the
	// same path. t.Fatal is used so the error text is never interpreted
	// as a format string.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}
	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to reopen channeldb: %v", err)
	}
	// Deferred before s2.Stop so the database is closed only after the
	// restarted switch has been stopped.
	defer cdb2.Close()

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	if err != nil {
		t.Fatalf("unable reinit switch: %v", err)
	}
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, bobPeer, true,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// The fully opened circuit must still be there after the restart.
	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Forwarding the original add again must be reported as a duplicate.
	err = s2.forward(ogPacket)
	if err != ErrDuplicateAdd {
		t.Fatalf("unexpected error when reforwarding a "+
			"failed packet: %v", err)
	}

	// Neither link may receive anything as a result.
	select {
	case <-aliceChannelLink.packets:
		t.Fatal("request should not have returned to source")
	case <-bobChannelLink.packets:
		t.Fatal("request should not have forwarded to destination")
	case <-time.After(time.Second):
	}
}
// TestSwitchForwardFailAfterHalfAdd forwards an add through the switch but
// never completes its circuit, then restarts the switch on the same database.
// After the restart it forwards the original add again and checks that the
// switch returns a LinkError with the OutgoingFailureIncompleteForward detail
// and that a packet is delivered to alice's link.
func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	// NOTE(review): tempPath is never removed after the test; cleaning it
	// up needs the os package, which this file does not import.
	tempPath, err := ioutil.TempDir("", "circuitdb")
	if err != nil {
		t.Fatalf("unable to create temporary path: %v", err)
	}

	cdb, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to open channeldb: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, bobPeer, true,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Build the add request travelling from alice to bob.
	preimage := [sha256.Size]byte{1}
	rhash := fastsha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// No circuits exist before the forward.
	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	if err := s.forward(ogPacket); err != nil {
		t.Fatal(err)
	}

	// After the forward there is one pending circuit, none open yet.
	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Take the add off bob's link, deliberately leaving the circuit
	// incomplete.
	select {
	case <-bobChannelLink.packets:
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Shut down the switch and database, then bring both back up on the
	// same path. t.Fatal is used so the error text is never interpreted
	// as a format string.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}
	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to reopen channeldb: %v", err)
	}
	// Deferred before s2.Stop so the database is closed only after the
	// restarted switch has been stopped.
	defer cdb2.Close()

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	if err != nil {
		t.Fatalf("unable reinit switch: %v", err)
	}
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, bobPeer, true,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// The half-added circuit must still be pending, with none open.
	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Forwarding the original add again must fail as an incomplete
	// forward.
	err = s2.forward(ogPacket)
	linkErr, ok := err.(*LinkError)
	if !ok {
		t.Fatalf("expected link error, got: %T", err)
	}
	if linkErr.FailureDetail != OutgoingFailureIncompleteForward {
		t.Fatalf("expected incomplete forward, got: %v",
			linkErr.FailureDetail)
	}

	// A packet must be delivered to alice's link.
	select {
	case <-aliceChannelLink.packets:
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}
}
// TestSwitchForwardCircuitPersistence checks the circuit counts reported by
// the switch across two restarts on the same database: first after an add has
// been forwarded but its circuit not yet completed, and again after the
// circuit has been completed, settled and removed.
func TestSwitchForwardCircuitPersistence(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	// NOTE(review): tempPath is never removed after the test; cleaning it
	// up needs the os package, which this file does not import.
	tempPath, err := ioutil.TempDir("", "circuitdb")
	if err != nil {
		t.Fatalf("unable to create temporary path: %v", err)
	}

	cdb, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to open channeldb: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, bobPeer, true,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Build the add request travelling from alice to bob.
	preimage := [sha256.Size]byte{1}
	rhash := fastsha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// No circuits exist before the forward.
	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	if err := s.forward(ogPacket); err != nil {
		t.Fatal(err)
	}

	// After the forward there is one pending circuit, none open yet.
	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Keep the packet received on bob's link; its circuit is completed
	// only after the first restart.
	var packet *htlcPacket
	select {
	case packet = <-bobChannelLink.packets:
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// First restart. t.Fatal is used so the error text is never
	// interpreted as a format string.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}
	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to reopen channeldb: %v", err)
	}

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	if err != nil {
		t.Fatalf("unable reinit switch: %v", err)
	}
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, bobPeer, true,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// The half-added circuit must still be pending, with none open.
	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Complete the circuit on the restarted switch.
	if err := bobChannelLink.completeCircuit(packet); err != nil {
		t.Fatalf("unable to complete payment circuit: %v", err)
	}

	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Send the settle for the HTLC from bob's side.
	ogPacket = &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}
	if err := s2.forward(ogPacket); err != nil {
		t.Fatal(err)
	}

	// The settle must arrive at alice's link.
	select {
	case packet = <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete circuit with in key=%s: %v",
				packet.inKey(), err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	// The circuit must be gone now.
	if s2.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits, want 0, got %d",
			s2.circuits.NumPending())
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}

	// Second restart.
	if err := s2.Stop(); err != nil {
		t.Fatal(err)
	}
	if err := cdb2.Close(); err != nil {
		t.Fatal(err)
	}

	cdb3, err := channeldb.Open(tempPath)
	if err != nil {
		t.Fatalf("unable to reopen channeldb: %v", err)
	}
	// Deferred before s3.Stop so the database is closed only after the
	// restarted switch has been stopped.
	defer cdb3.Close()

	s3, err := initSwitchWithDB(testStartingHeight, cdb3)
	if err != nil {
		t.Fatalf("unable reinit switch: %v", err)
	}
	if err := s3.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}
	defer s3.Stop()

	aliceChannelLink = newMockChannelLink(
		s3, chanID1, aliceChanID, alicePeer, true,
	)
	bobChannelLink = newMockChannelLink(
		s3, chanID2, bobChanID, bobPeer, true,
	)
	if err := s3.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s3.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// No circuits may reappear after the second restart.
	if s3.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s3.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}
}
// multiHopFwdTest describes a single case run by
// testSkipIneligibleLinksMultiHopForward.
type multiHopFwdTest struct {
	// name is used as the subtest name.
	name string

	// eligible1 and eligible2 are passed as the last argument of
	// newMockChannelLink when creating bob's first and second link.
	eligible1 bool
	eligible2 bool

	// failure1 and failure2 are assigned to the checkHtlcForwardResult
	// field of bob's first and second link.
	failure1 *LinkError
	failure2 *LinkError

	// expectedReply is the failure code the forward is expected to
	// produce; lnwire.CodeNone means the forward should succeed.
	expectedReply lnwire.FailCode
}
// TestCircularForwards forwards an add whose incoming and outgoing channel
// are the same link, once with AllowCircularRoute enabled and once with it
// disabled, and compares the error returned by the switch with the expected
// one. In neither case may a circuit end up open.
func TestCircularForwards(t *testing.T) {
	aliceCID, aliceSCID := genID()
	preimage := [sha256.Size]byte{1}
	paymentHash := fastsha256.Sum256(preimage[:])

	type circularCase struct {
		name                 string
		allowCircularPayment bool
		expectedErr          error
	}

	cases := []circularCase{
		{
			name:                 "circular payment allowed",
			allowCircularPayment: true,
		},
		{
			name: "circular payment disallowed",
			expectedErr: NewDetailedLinkError(
				lnwire.NewTemporaryChannelFailure(nil),
				OutgoingFailureCircularRoute,
			),
		},
	}

	for _, tc := range cases {
		// Shadow the loop variable for the parallel closure.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			alice, err := newMockServer(
				t, "alice", testStartingHeight, nil,
				testDefaultDelta,
			)
			if err != nil {
				t.Fatalf("unable to create alice server: %v",
					err)
			}

			sw, err := initSwitchWithDB(testStartingHeight, nil)
			if err != nil {
				t.Fatalf("unable to init switch: %v", err)
			}
			if err := sw.Start(); err != nil {
				t.Fatalf("unable to start switch: %v", err)
			}
			defer func() { _ = sw.Stop() }()

			// Apply the per-case circular route setting.
			sw.cfg.AllowCircularRoute = tc.allowCircularPayment

			link := newMockChannelLink(
				sw, aliceCID, aliceSCID, alice, true,
			)
			if err := sw.AddLink(link); err != nil {
				t.Fatalf("unable to add alice link: %v", err)
			}

			// The add enters and leaves through the same link.
			obf := NewMockObfuscator()
			pkt := &htlcPacket{
				incomingChanID: link.ShortChanID(),
				outgoingChanID: link.ShortChanID(),
				htlc: &lnwire.UpdateAddHTLC{
					PaymentHash: paymentHash,
					Amount:      1,
				},
				obfuscator: obf,
			}

			err = sw.forward(pkt)
			if !reflect.DeepEqual(err, tc.expectedErr) {
				t.Fatalf("expected: %v, got: %v",
					tc.expectedErr, err)
			}

			if sw.circuits.NumOpen() > 0 {
				t.Fatal("do not expect any open circuits")
			}
		})
	}
}
// TestCheckCircularForward calls checkCircularForward with equal and
// differing incoming/outgoing short channel IDs, with circular routes both
// allowed and disallowed, and compares the returned error with the expected
// one. Only the equal-IDs, disallowed case expects a non-nil error.
func TestCheckCircularForward(t *testing.T) {
	type circularCheck struct {
		name          string
		allowCircular bool
		incomingLink  lnwire.ShortChannelID
		outgoingLink  lnwire.ShortChannelID
		expectedErr   *LinkError
	}

	cases := []circularCheck{
		{
			name:          "not circular, allowed in config",
			allowCircular: true,
			incomingLink:  lnwire.NewShortChanIDFromInt(123),
			outgoingLink:  lnwire.NewShortChanIDFromInt(321),
		},
		{
			name:         "not circular, not allowed in config",
			incomingLink: lnwire.NewShortChanIDFromInt(123),
			outgoingLink: lnwire.NewShortChanIDFromInt(321),
		},
		{
			name:          "circular, allowed in config",
			allowCircular: true,
			incomingLink:  lnwire.NewShortChanIDFromInt(123),
			outgoingLink:  lnwire.NewShortChanIDFromInt(123),
		},
		{
			name:         "circular, not allowed in config",
			incomingLink: lnwire.NewShortChanIDFromInt(123),
			outgoingLink: lnwire.NewShortChanIDFromInt(123),
			expectedErr: NewDetailedLinkError(
				lnwire.NewTemporaryChannelFailure(nil),
				OutgoingFailureCircularRoute,
			),
		},
	}

	for _, tc := range cases {
		// Shadow the loop variable for the parallel closure.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := checkCircularForward(
				tc.incomingLink, tc.outgoingLink,
				tc.allowCircular, lntypes.Hash{},
			)
			if !reflect.DeepEqual(got, tc.expectedErr) {
				t.Fatalf("expected: %v, got: %v",
					tc.expectedErr, got)
			}
		})
	}
}
// TestSkipIneligibleLinksMultiHopForward runs
// testSkipIneligibleLinksMultiHopForward as a subtest for each combination of
// link eligibility and forwarding-check failure listed in the table below.
func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
	cases := []multiHopFwdTest{
		// Neither link is eligible.
		{
			name:          "not eligible",
			expectedReply: lnwire.CodeUnknownNextPeer,
		},

		// Only the first link is eligible, and its check fails.
		{
			name:      "policy fail",
			eligible1: true,
			failure1: NewLinkError(
				lnwire.NewFinalIncorrectCltvExpiry(0),
			),
			expectedReply: lnwire.CodeFinalIncorrectCltvExpiry,
		},

		// Only the second link is eligible, with no failure set.
		{
			name:          "non-strict success",
			eligible2:     true,
			expectedReply: lnwire.CodeNone,
		},

		// Both links are eligible and both checks fail.
		{
			name:      "non-strict policy fail",
			eligible1: true,
			failure1: NewDetailedLinkError(
				lnwire.NewTemporaryChannelFailure(nil),
				OutgoingFailureInsufficientBalance,
			),
			eligible2: true,
			failure2: NewLinkError(
				lnwire.NewFinalIncorrectCltvExpiry(0),
			),
			expectedReply: lnwire.CodeTemporaryChannelFailure,
		},
	}

	for i := range cases {
		// Each subtest gets its own copy of the case.
		tc := cases[i]

		t.Run(tc.name, func(t *testing.T) {
			testSkipIneligibleLinksMultiHopForward(t, &tc)
		})
	}
}
// testSkipIneligibleLinksMultiHopForward executes a single multiHopFwdTest
// case. It registers one incoming mock link for alice and two outgoing mock
// links for bob, configures bob's links from the test case (eligibility and
// checkHtlcForwardResult), forwards one htlc addressed to bob's first link,
// and compares the outcome with testCase.expectedReply. No circuit may be
// open once the forward call has returned.
func testSkipIneligibleLinksMultiHopForward(t *testing.T,
	testCase *multiHopFwdTest) {

	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	// Alice's link is the incoming side and is always eligible.
	chanID1, aliceChanID := genID()
	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, alicePeer, true,
	)

	// Bob gets two links whose eligibility and forwarding-check result
	// are dictated by the test case.
	chanID2, bobChanID2 := genID()
	bobChannelLink1 := newMockChannelLink(
		s, chanID2, bobChanID2, bobPeer, testCase.eligible1,
	)
	bobChannelLink1.checkHtlcForwardResult = testCase.failure1

	chanID3, bobChanID3 := genID()
	bobChannelLink2 := newMockChannelLink(
		s, chanID3, bobChanID3, bobPeer, testCase.eligible2,
	)
	bobChannelLink2.checkHtlcForwardResult = testCase.failure2

	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink1); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}
	if err := s.AddLink(bobChannelLink2); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Build an add coming in over alice's link and addressed to bob's
	// first link.
	preimage := [sha256.Size]byte{1}
	rhash := fastsha256.Sum256(preimage[:])
	obfuscator := NewMockObfuscator()
	packet := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink1.ShortChanID(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
		obfuscator: obfuscator,
	}

	err = s.forward(packet)

	// The failure, if any, is read back from the mock obfuscator.
	failure := obfuscator.(*mockObfuscator).failure
	if testCase.expectedReply == lnwire.CodeNone {
		if err != nil {
			t.Fatalf("forwarding should have succeeded")
		}
		if failure != nil {
			t.Fatalf("unexpected failure %T", failure)
		}
	} else {
		if err == nil {
			t.Fatalf("forwarding should have failed due to " +
				"inactive link")
		}

		// Guard against a failed forward that recorded no failure
		// message; calling Code() on a nil failure would panic
		// instead of reporting a test failure.
		if failure == nil {
			t.Fatalf("expected failure code %v, but no failure "+
				"was recorded", testCase.expectedReply)
		}
		if failure.Code() != testCase.expectedReply {
			t.Fatalf("unexpected failure %T", failure)
		}
	}

	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSkipIneligibleLinksLocalForward runs testSkipLinkLocalForward with a
// link that is not eligible and with no policy failure message.
func TestSkipIneligibleLinksLocalForward(t *testing.T) {
	t.Parallel()

	const linkEligible = false
	testSkipLinkLocalForward(t, linkEligible, nil)
}
func TestSkipPolicyUnsatisfiedLinkLocalForward(t *testing.T) {
t.Parallel()
testSkipLinkLocalForward(t, true, lnwire.NewTemporaryChannelFailure(nil))
}
// testSkipLinkLocalForward registers a single mock link for alice with the
// given eligibility, sets the link's checkHtlcTransitResult to a LinkError
// wrapping policyResult, and sends an htlc over that link via SendHTLC. The
// send is required to return an error and no circuit may be open afterwards.
func testSkipLinkLocalForward(t *testing.T, eligible bool,
	policyResult lnwire.FailureMessage) {

	alice, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	htlcSwitch, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	err = htlcSwitch.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer htlcSwitch.Stop()

	// Create alice's link and attach the canned transit-check result
	// before handing the link to the switch.
	chanID, _, shortID, _ := genIDs()
	aliceLink := newMockChannelLink(
		htlcSwitch, chanID, shortID, alice, eligible,
	)
	aliceLink.checkHtlcTransitResult = NewLinkError(policyResult)

	err = htlcSwitch.AddLink(aliceLink)
	if err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}

	// Attempt the local send; it has to be rejected.
	err = htlcSwitch.SendHTLC(
		aliceLink.ShortChanID(), 0, &lnwire.UpdateAddHTLC{
			PaymentHash: fastsha256.Sum256(preimage[:]),
			Amount:      1,
		},
	)
	if err == nil {
		t.Fatalf("local forward should fail due to inactive link")
	}

	if htlcSwitch.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchCancel forwards an add from alice's link to bob's link and then a
// fail coming back from bob's link. It checks the circuit map counters after
// each step: one pending and one open circuit once the add has been completed
// on bob's link, and zero of both once the fail has been completed on alice's
// link.
func TestSwitchCancel(t *testing.T) {
	t.Parallel()

	alice, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	bob, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	htlcSwitch, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	err = htlcSwitch.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer htlcSwitch.Stop()

	// Register one eligible link per peer.
	aliceID, bobID, aliceShortID, bobShortID := genIDs()
	aliceLink := newMockChannelLink(
		htlcSwitch, aliceID, aliceShortID, alice, true,
	)
	bobLink := newMockChannelLink(
		htlcSwitch, bobID, bobShortID, bob, true,
	)

	err = htlcSwitch.AddLink(aliceLink)
	if err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	err = htlcSwitch.AddLink(bobLink)
	if err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}

	// Forward the add from alice towards bob.
	addPkt := &htlcPacket{
		incomingChanID: aliceLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: fastsha256.Sum256(preimage[:]),
			Amount:      1,
		},
	}
	err = htlcSwitch.forward(addPkt)
	if err != nil {
		t.Fatal(err)
	}

	// The add must show up on bob's link, where the circuit is completed.
	select {
	case received := <-bobLink.packets:
		err := bobLink.completeCircuit(received)
		if err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if htlcSwitch.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if htlcSwitch.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Now send a fail for the same htlc back from bob's link.
	failPkt := &htlcPacket{
		outgoingChanID: bobLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc:           &lnwire.UpdateFailHTLC{},
	}
	err = htlcSwitch.forward(failPkt)
	if err != nil {
		t.Fatal(err)
	}

	// The fail must reach alice's link, where the circuit is completed.
	select {
	case received := <-aliceLink.packets:
		err := aliceLink.completeCircuit(received)
		if err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	if htlcSwitch.circuits.NumPending() != 0 {
		t.Fatal("wrong amount of circuits")
	}
	if htlcSwitch.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchAddSamePayment forwards two adds that carry the same payment hash
// but different incoming htlc ids from alice's link to bob's link, and then
// fails both of them from bob's side. After every step the number of open
// circuits is checked: 1, 2, 1 and finally 0.
func TestSwitchAddSamePayment(t *testing.T) {
	t.Parallel()

	aliceID, bobID, aliceShortID, bobShortID := genIDs()

	alice, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	bob, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	htlcSwitch, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	err = htlcSwitch.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer htlcSwitch.Stop()

	aliceLink := newMockChannelLink(
		htlcSwitch, aliceID, aliceShortID, alice, true,
	)
	bobLink := newMockChannelLink(
		htlcSwitch, bobID, bobShortID, bob, true,
	)

	err = htlcSwitch.AddLink(aliceLink)
	if err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	err = htlcSwitch.AddLink(bobLink)
	if err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}
	rhash := fastsha256.Sum256(preimage[:])

	// forwardAdd forwards an add with the shared payment hash from alice
	// to bob under the given incoming htlc id and completes the circuit
	// on bob's link once the packet arrives there.
	forwardAdd := func(incomingID uint64) {
		add := &htlcPacket{
			incomingChanID: aliceLink.ShortChanID(),
			incomingHTLCID: incomingID,
			outgoingChanID: bobLink.ShortChanID(),
			obfuscator:     NewMockObfuscator(),
			htlc: &lnwire.UpdateAddHTLC{
				PaymentHash: rhash,
				Amount:      1,
			},
		}
		if err := htlcSwitch.forward(add); err != nil {
			t.Fatal(err)
		}

		select {
		case received := <-bobLink.packets:
			err := bobLink.completeCircuit(received)
			if err != nil {
				t.Fatalf("unable to complete payment "+
					"circuit: %v", err)
			}

		case <-time.After(time.Second):
			t.Fatal("request was not propagated to destination")
		}
	}

	// forwardFail forwards a fail for the given outgoing htlc id on bob's
	// link and completes the circuit on alice's link once the packet
	// arrives there.
	forwardFail := func(outgoingID uint64) {
		fail := &htlcPacket{
			outgoingChanID: bobLink.ShortChanID(),
			outgoingHTLCID: outgoingID,
			amount:         1,
			htlc:           &lnwire.UpdateFailHTLC{},
		}
		if err := htlcSwitch.forward(fail); err != nil {
			t.Fatal(err)
		}

		select {
		case received := <-aliceLink.packets:
			err := aliceLink.completeCircuit(received)
			if err != nil {
				t.Fatalf("unable to remove circuit: %v", err)
			}

		case <-time.After(time.Second):
			t.Fatal("request was not propagated to channelPoint")
		}
	}

	// First add: one open circuit.
	forwardAdd(0)
	if htlcSwitch.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Second add with the same hash but another htlc id: two circuits.
	forwardAdd(1)
	if htlcSwitch.circuits.NumOpen() != 2 {
		t.Fatal("wrong amount of circuits")
	}

	// Failing the first htlc leaves one circuit open.
	forwardFail(0)
	if htlcSwitch.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Failing the second htlc leaves none.
	forwardFail(1)
	if htlcSwitch.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}
// TestSwitchSendPayment sends a locally initiated htlc over alice's link and
// then fails it back. It checks that GetPaymentResult reports
// ErrPaymentIDNotFound before the htlc is sent, that exactly one circuit is
// open after the add has been completed on the link, that the htlc is not
// reported as forwarded, and that the failure finally surfaces to the sender
// with code lnwire.CodeIncorrectOrUnknownPaymentDetails.
func TestSwitchSendPayment(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	s, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, _, aliceChanID, _ := genIDs()
	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, alicePeer, true,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add link: %v", err)
	}

	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}
	rhash := fastsha256.Sum256(preimage[:])
	update := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      1,
	}
	paymentID := uint64(123)

	// Nothing has been sent under this payment id yet, so asking for its
	// result must fail.
	_, err = s.GetPaymentResult(
		paymentID, rhash, newMockDeobfuscator(),
	)
	if err != ErrPaymentIDNotFound {
		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
	}

	// The goroutine below sends the htlc, waits for its result and
	// reports exactly one value on errChan. The channel is buffered so
	// that this single send never blocks, letting the goroutine exit even
	// if the test has already stopped reading.
	errChan := make(chan error, 1)
	go func() {
		err := s.SendHTLC(
			aliceChannelLink.ShortChanID(), paymentID, update,
		)
		if err != nil {
			errChan <- err
			return
		}

		resultChan, err := s.GetPaymentResult(
			paymentID, rhash, newMockDeobfuscator(),
		)
		if err != nil {
			errChan <- err
			return
		}

		result, ok := <-resultChan
		if !ok {
			// A closed channel delivers only a zero value, so
			// report the shutdown and stop instead of inspecting
			// the result below.
			errChan <- fmt.Errorf("shutting down")
			return
		}

		if result.Error != nil {
			errChan <- result.Error
			return
		}

		errChan <- nil
	}()

	// Wait for the add to arrive on alice's link and complete its
	// circuit. An early error from the sender goroutine fails the test.
	select {
	case packet := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case err := <-errChan:
		if err != nil {
			t.Fatalf("unable to send payment: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	// Build an encrypted failure reason to send back for the htlc.
	obfuscator := NewMockObfuscator()
	failure := lnwire.NewFailIncorrectDetails(update.Amount, 100)
	reason, err := obfuscator.EncryptFirstHop(failure)
	if err != nil {
		t.Fatalf("unable obfuscate failure: %v", err)
	}

	// The htlc originated locally, so the switch must not report it as a
	// forwarded htlc.
	if s.IsForwardedHTLC(aliceChannelLink.ShortChanID(), update.ID) {
		t.Fatal("htlc should be identified as not forwarded")
	}

	packet := &htlcPacket{
		outgoingChanID: aliceChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFailHTLC{
			Reason: reason,
		},
	}
	if err := s.forward(packet); err != nil {
		t.Fatalf("can't forward htlc packet: %v", err)
	}

	// The sender goroutine must now report the failure.
	select {
	case err := <-errChan:
		assertFailureCode(
			t, err, lnwire.CodeIncorrectOrUnknownPaymentDetails,
		)

	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}
}
// TestLocalPaymentNoForwardingEvents makes a single payment from alice to bob
// over their direct channel in a three hop network, stops the network, and
// then checks that alice's forwarding log contains no events.
func TestLocalPaymentNoForwardingEvents(t *testing.T) {
	t.Parallel()

	channels, cleanUp, _, err := createClusterChannels(
		btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	if err != nil {
		t.Fatalf("unable to create channel: %v", err)
	}
	defer cleanUp()

	n := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
	)
	err = n.start()
	if err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}

	// Build a route consisting only of the first bob link and pay bob
	// over it.
	payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
	htlcAmt, timelock, route := generateHops(
		payAmt, testStartingHeight, n.firstBobChannelLink,
	)

	payment := makePayment(
		n.aliceServer, n.bobServer,
		n.firstBobChannelLink.ShortChanID(), route, payAmt, htlcAmt,
		timelock,
	)
	_, err = payment.Wait(30 * time.Second)
	if err != nil {
		t.Fatalf("unable to make the payment: %v", err)
	}

	// The network is stopped before the log is inspected.
	n.stop()

	aliceLog, isMock := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !isMock {
		t.Fatalf("mockForwardingLog assertion failed")
	}

	aliceLog.Lock()
	defer aliceLog.Unlock()

	if len(aliceLog.events) != 0 {
		t.Fatalf("log should have no events, instead has: %v",
			spew.Sdump(aliceLog.events))
	}
}
// TestMultiHopPaymentForwardingEvents sends numPayments payments from alice
// to carol through bob and inspects the forwarding logs of all three nodes.
// After the first half of the payments bob's event ticker is forced until his
// log holds that many events; after the second half the network is stopped.
// Alice's and carol's logs must then be empty, while bob's log must hold one
// event per payment, each with the expected channel ids and amounts.
func TestMultiHopPaymentForwardingEvents(t *testing.T) {
	t.Parallel()

	channels, cleanUp, _, err := createClusterChannels(
		btcutil.SatoshiPerBitcoin*3,
		btcutil.SatoshiPerBitcoin*5)
	if err != nil {
		t.Fatalf("unable to create channel: %v", err)
	}
	defer cleanUp()

	n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight)
	if err := n.start(); err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}

	const (
		// numPayments is the total number of payments sent.
		numPayments = 10

		// halfPayments is the size of each of the two batches the
		// payments are sent in.
		halfPayments = numPayments / 2
	)

	finalAmt := lnwire.NewMSatFromSatoshis(100000)
	htlcAmt, totalTimelock, hops := generateHops(
		finalAmt, testStartingHeight, n.firstBobChannelLink,
		n.carolChannelLink,
	)
	firstHop := n.firstBobChannelLink.ShortChanID()

	// Send the first batch of payments.
	for i := 0; i < halfPayments; i++ {
		_, err := makePayment(
			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
			htlcAmt, totalTimelock,
		).Wait(30 * time.Second)
		if err != nil {
			t.Fatalf("unable to send payment: %v", err)
		}
	}

	bobLog, ok := n.bobServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}

	bobTicker, ok := n.bobServer.htlcSwitch.cfg.FwdEventTicker.(*ticker.Force)
	if !ok {
		t.Fatalf("mockTicker assertion failed")
	}

	// Keep forcing bob's ticker and polling his log until the first batch
	// of events has shown up or the overall timeout expires.
	timeout := time.After(15 * time.Second)
	for {
		select {
		case bobTicker.Force <- time.Now():
		case <-time.After(1 * time.Second):
			t.Fatalf("unable to force tick")
		}

		bobLog.Lock()
		if len(bobLog.events) == halfPayments {
			bobLog.Unlock()
			break
		}
		bobLog.Unlock()

		select {
		case <-time.After(50 * time.Millisecond):
		case <-timeout:
			bobLog.Lock()
			defer bobLog.Unlock()
			t.Fatalf("expected %d events in event log, instead "+
				"found: %v", halfPayments,
				spew.Sdump(bobLog.events))
		}
	}

	// Send the second batch of payments.
	for i := halfPayments; i < numPayments; i++ {
		_, err := makePayment(
			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
			htlcAmt, totalTimelock,
		).Wait(30 * time.Second)
		if err != nil {
			t.Fatalf("unable to send payment: %v", err)
		}
	}

	// The network is stopped before the logs are inspected.
	n.stop()

	// Alice only sent payments, so her log must be empty.
	aliceLog, ok := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}
	aliceLog.Lock()
	defer aliceLog.Unlock()
	if len(aliceLog.events) != 0 {
		t.Fatalf("log should have no events, instead has: %v",
			spew.Sdump(aliceLog.events))
	}

	// Carol only received payments, so her log must be empty as well.
	carolLog, ok := n.carolServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
	if !ok {
		t.Fatalf("mockForwardingLog assertion failed")
	}
	carolLog.Lock()
	defer carolLog.Unlock()
	if len(carolLog.events) != 0 {
		t.Fatalf("log should have no events, instead has: %v",
			spew.Sdump(carolLog.events))
	}

	// Bob must have logged one event per payment.
	bobLog.Lock()
	defer bobLog.Unlock()
	if len(bobLog.events) != numPayments {
		t.Fatalf("log should have %d events, instead has: %v",
			numPayments, spew.Sdump(bobLog.events))
	}

	// Every event must name the alice->bob link as incoming and the
	// bob->carol link as outgoing, with the matching amounts. In each
	// message the expected value is printed first and the observed value
	// second.
	for _, event := range bobLog.events {
		if event.IncomingChanID != n.aliceChannelLink.ShortChanID() {
			t.Fatalf("chan id mismatch: expected %v, got %v",
				n.aliceChannelLink.ShortChanID(),
				event.IncomingChanID)
		}
		if event.OutgoingChanID != n.carolChannelLink.ShortChanID() {
			t.Fatalf("chan id mismatch: expected %v, got %v",
				n.carolChannelLink.ShortChanID(),
				event.OutgoingChanID)
		}

		if event.AmtIn != htlcAmt {
			t.Fatalf("incoming amt mismatch: expected %v, got %v",
				htlcAmt, event.AmtIn)
		}
		if event.AmtOut != finalAmt {
			t.Fatalf("outgoing amt mismatch: expected %v, got %v",
				finalAmt, event.AmtOut)
		}
	}
}
// TestUpdateFailMalformedHTLCErrorConversion sends a payment from alice to
// carol in a three hop network after an onion decoder has been told to fail,
// and checks that the sender receives a ClearTextError whose wire message is
// an *lnwire.FailInvalidOnionKey. The first subtest sets decodeFail on carol's
// decoder, the second one additionally sets it on bob's decoder.
func TestUpdateFailMalformedHTLCErrorConversion(t *testing.T) {
	t.Parallel()

	channels, cleanUp, _, err := createClusterChannels(
		btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	if err != nil {
		t.Fatalf("unable to create channel: %v", err)
	}
	defer cleanUp()

	n := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
	)
	if err := n.start(); err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}

	// Stop the network when the test ends so it does not outlive the
	// test; this runs before cleanUp.
	defer n.stop()

	// assertPaymentFailure attempts a payment from alice to carol and
	// requires it to fail with an invalid onion key failure.
	assertPaymentFailure := func(t *testing.T) {
		finalAmt := lnwire.NewMSatFromSatoshis(100000)
		htlcAmt, totalTimelock, hops := generateHops(
			finalAmt, testStartingHeight, n.firstBobChannelLink,
			n.carolChannelLink,
		)
		firstHop := n.firstBobChannelLink.ShortChanID()

		_, err := makePayment(
			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
			htlcAmt, totalTimelock,
		).Wait(30 * time.Second)
		if err == nil {
			t.Fatalf("payment should have failed, but succeeded")
		}

		// Use the checked form of the assertion so that an
		// unexpected error type fails the test instead of panicking.
		routingErr, ok := err.(ClearTextError)
		if !ok {
			t.Fatalf("expected ClearTextError, got: %T (%v)",
				err, err)
		}

		failureMsg := routingErr.WireMessage()
		if _, ok := failureMsg.(*lnwire.FailInvalidOnionKey); !ok {
			t.Fatalf("expected onion failure instead got: %v",
				routingErr.WireMessage())
		}
	}

	t.Run("multi-hop error conversion", func(t *testing.T) {
		// Make carol's onion decoder fail.
		n.carolOnionDecoder.decodeFail = true

		assertPaymentFailure(t)
	})

	t.Run("direct channel error conversion", func(t *testing.T) {
		// Make bob's onion decoder fail as well.
		n.bobOnionDecoder.decodeFail = true

		assertPaymentFailure(t)
	})
}
// TestSwitchGetPaymentResult exercises GetPaymentResult against a switch whose
// circuit map has been replaced by a mockCircuitMap fed through a lookup
// channel. It covers three situations: a nil lookup with no stored result
// (ErrPaymentIDNotFound), a non-nil lookup followed by storing a settle result
// (the preimage is delivered), and a nil lookup after the result has been
// stored (the preimage is delivered again).
func TestSwitchGetPaymentResult(t *testing.T) {
	t.Parallel()

	const paymentID = 123

	var preimg lntypes.Preimage
	preimg[0] = 3

	htlcSwitch, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	err = htlcSwitch.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer htlcSwitch.Stop()

	// Swap in the mock circuit map; every value pushed on circuitLookup
	// is handed to the mock.
	circuitLookup := make(chan *PaymentCircuit, 1)
	htlcSwitch.circuits = &mockCircuitMap{
		lookup: circuitLookup,
	}

	// A nil lookup while no result is stored must yield
	// ErrPaymentIDNotFound.
	circuitLookup <- nil
	_, err = htlcSwitch.GetPaymentResult(
		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
	)
	if err != ErrPaymentIDNotFound {
		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
	}

	// With a non-nil lookup the call must succeed.
	circuitLookup <- &PaymentCircuit{}
	resultChan, err := htlcSwitch.GetPaymentResult(
		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
	)
	if err != nil {
		t.Fatalf("unable to get payment result: %v", err)
	}

	// expectPreimage waits for a value on the current resultChan and
	// requires it to be an error-free result carrying preimg.
	expectPreimage := func() {
		select {
		case got, open := <-resultChan:
			if !open {
				t.Fatalf("channel was closed")
			}
			if got.Error != nil {
				t.Fatalf("got unexpected error result")
			}
			if got.Preimage != preimg {
				t.Fatalf("expected preimg %v, got %v",
					preimg, got.Preimage)
			}

		case <-time.After(1 * time.Second):
			t.Fatalf("result not received")
		}
	}

	// Store a settle result for the payment; it must be delivered on the
	// channel obtained above.
	settle := &networkResult{
		msg: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimg,
		},
		unencrypted:  true,
		isResolution: true,
	}
	err = htlcSwitch.networkResults.storeResult(paymentID, settle)
	if err != nil {
		t.Fatalf("unable to store result: %v", err)
	}
	expectPreimage()

	// With the result stored, a nil lookup must no longer produce an
	// error and the same preimage must be delivered.
	circuitLookup <- nil
	resultChan, err = htlcSwitch.GetPaymentResult(
		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
	)
	if err != nil {
		t.Fatalf("unable to get payment result: %v", err)
	}
	expectPreimage()
}
// TestInvalidFailure sends an htlc over alice's link, fails it back with an
// arbitrary three byte reason, and then fetches the payment result twice with
// different mock error decrypters. With a decrypter that returns
// ErrUnreadableFailureMessage the result error must be exactly that error.
// With a decrypter that returns source index 2 and the message bytes {200},
// the result must be a *ForwardingError with FailureSourceIdx 2 and a nil
// wire message.
func TestInvalidFailure(t *testing.T) {
	t.Parallel()

	alice, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}

	htlcSwitch, err := initSwitchWithDB(testStartingHeight, nil)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	err = htlcSwitch.Start()
	if err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer htlcSwitch.Stop()

	chanID, _, shortID, _ := genIDs()
	aliceLink := newMockChannelLink(
		htlcSwitch, chanID, shortID, alice, true,
	)
	err = htlcSwitch.AddLink(aliceLink)
	if err != nil {
		t.Fatalf("unable to add link: %v", err)
	}

	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}
	rhash := fastsha256.Sum256(preimage[:])

	// Send the htlc over alice's link.
	add := &lnwire.UpdateAddHTLC{
		PaymentHash: rhash,
		Amount:      1,
	}
	paymentID := uint64(123)
	err = htlcSwitch.SendHTLC(aliceLink.ShortChanID(), paymentID, add)
	if err != nil {
		t.Fatalf("unable to send payment: %v", err)
	}

	// Wait for it to arrive on the link and complete the circuit.
	select {
	case received := <-aliceLink.packets:
		err := aliceLink.completeCircuit(received)
		if err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Fail the htlc back with an arbitrary reason.
	failPkt := &htlcPacket{
		outgoingChanID: aliceLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFailHTLC{
			Reason: []byte{1, 2, 3},
		},
	}
	err = htlcSwitch.forward(failPkt)
	if err != nil {
		t.Fatalf("can't forward htlc packet: %v", err)
	}

	// First lookup: the decrypter reports an unreadable failure message.
	unreadable := &SphinxErrorDecrypter{
		OnionErrorDecrypter: &mockOnionErrorDecryptor{
			err: ErrUnreadableFailureMessage,
		},
	}
	resultChan, err := htlcSwitch.GetPaymentResult(
		paymentID, rhash, unreadable,
	)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case res := <-resultChan:
		if res.Error != ErrUnreadableFailureMessage {
			t.Fatal("expected unreadable failure message")
		}

	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}

	// Second lookup: the decrypter returns a source index together with
	// the message bytes {200}.
	withSource := &SphinxErrorDecrypter{
		OnionErrorDecrypter: &mockOnionErrorDecryptor{
			sourceIdx: 2,
			message:   []byte{200},
		},
	}
	resultChan, err = htlcSwitch.GetPaymentResult(
		paymentID, rhash, withSource,
	)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case res := <-resultChan:
		clearErr, isClear := res.Error.(ClearTextError)
		if !isClear {
			t.Fatal("expected ClearTextError")
		}

		fwdErr, isFwd := clearErr.(*ForwardingError)
		if !isFwd {
			t.Fatalf("expected forwarding error, got: %T", clearErr)
		}

		if fwdErr.FailureSourceIdx != 2 {
			t.Fatal("unexpected error source index")
		}
		if clearErr.WireMessage() != nil {
			t.Fatal("expected empty failure message")
		}

	case <-time.After(time.Second):
		t.Fatal("err wasn't received")
	}
}
// htlcNotifierEvents is the signature of a function that produces the htlc
// events expected for one payment. It receives the cluster channels, the htlc
// id, the timestamp to stamp events with, the htlc that was sent and the hop
// payloads, and returns one slice of expected events per node, in the order
// alice, bob, carol.
type htlcNotifierEvents func(
	channels *clusterChannels, htlcID uint64, ts time.Time,
	htlc *lnwire.UpdateAddHTLC, hops []*hop.Payload,
) (alice, bob, carol []interface{})
func TestHtlcNotifier(t *testing.T) {
tests := []struct {
name string
options []serverOption
expectedEvents htlcNotifierEvents
iterations int
}{
{
name: "successful three hop payment",
options: nil,
expectedEvents: func(channels *clusterChannels,
htlcID uint64, ts time.Time,
htlc *lnwire.UpdateAddHTLC,
hops []*hop.Payload) ([]interface{},
[]interface{}, []interface{}) {
return getThreeHopEvents(
channels, htlcID, ts, htlc, hops, nil,
)
},
iterations: 2,
},
{
name: "failed at forwarding link",
options: []serverOption{
serverOptionRejectHtlc(false, true, false),
},
expectedEvents: func(channels *clusterChannels,
htlcID uint64, ts time.Time,
htlc *lnwire.UpdateAddHTLC,
hops []*hop.Payload) ([]interface{},
[]interface{}, []interface{}) {
return getThreeHopEvents(
channels, htlcID, ts, htlc, hops,
&LinkError{
msg: &lnwire.FailChannelDisabled{},
FailureDetail: OutgoingFailureForwardsDisabled,
},
)
},
iterations: 1,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testHtcNotifier(
t, test.options, test.iterations,
test.expectedEvents,
)
})
}
}
// testHtcNotifier builds a three hop network whose nodes each get their own
// HtlcNotifier with a fixed clock, subscribes to the events of all three
// notifiers, and then sends the requested number of three hop payments. After
// every payment the events received from each notifier are compared with the
// ones returned by getEvents for that iteration.
func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int,
	getEvents htlcNotifierEvents) {

	t.Parallel()

	channels, cleanUp, _, err := createClusterChannels(
		btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
	)
	if err != nil {
		t.Fatalf("unable to create channel: %v", err)
	}
	defer cleanUp()

	// All notifiers share a clock frozen at the current time, so the
	// timestamp of every event is known in advance.
	fixedNow := time.Now()
	clock := func() time.Time {
		return fixedNow
	}

	aliceNotifier := NewHtlcNotifier(clock)
	err = aliceNotifier.Start()
	if err != nil {
		t.Fatalf("could not start alice notifier")
	}
	defer aliceNotifier.Stop()

	bobNotifier := NewHtlcNotifier(clock)
	err = bobNotifier.Start()
	if err != nil {
		t.Fatalf("could not start bob notifier")
	}
	defer bobNotifier.Stop()

	carolNotifier := NewHtlcNotifier(clock)
	err = carolNotifier.Start()
	if err != nil {
		t.Fatalf("could not start carol notifier")
	}
	defer carolNotifier.Stop()

	// Add the notifier option to whatever options the caller supplied.
	options := append(testOpts, serverOptionWithHtlcNotifier(
		aliceNotifier, bobNotifier, carolNotifier,
	))

	n := newThreeHopNetwork(
		t, channels.aliceToBob, channels.bobToAlice,
		channels.bobToCarol, channels.carolToBob, testStartingHeight,
		options...,
	)
	err = n.start()
	if err != nil {
		t.Fatalf("unable to start three hop network: %v", err)
	}
	defer n.stop()

	// Subscribe to every node's events before any payment is sent.
	aliceSub, err := aliceNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to alice's events: %v", err)
	}
	defer aliceSub.Cancel()

	bobSub, err := bobNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to bob's events: %v", err)
	}
	defer bobSub.Cancel()

	carolSub, err := carolNotifier.SubscribeHtlcEvents()
	if err != nil {
		t.Fatalf("could not subscribe to carol's events: %v", err)
	}
	defer carolSub.Cancel()

	// Send the payments one at a time, using the iteration counter as the
	// expected htlc id.
	for round := 0; round < iterations; round++ {
		htlc, hops := n.sendThreeHopPayment(t)

		wantAlice, wantBob, wantCarol := getEvents(
			channels, uint64(round), fixedNow, htlc, hops,
		)

		checkHtlcEvents(t, aliceSub.Updates(), wantAlice)
		checkHtlcEvents(t, bobSub.Updates(), wantBob)
		checkHtlcEvents(t, carolSub.Updates(), wantCarol)
	}
}
// checkHtlcEvents reads one value from events for every entry of
// expectedEvents, in order, and fails the test if a value is not deeply equal
// to its expected counterpart or if no value arrives within five seconds.
// Nothing is read when expectedEvents is empty.
func checkHtlcEvents(t *testing.T, events <-chan interface{},
	expectedEvents []interface{}) {

	t.Helper()

	for i := 0; i < len(expectedEvents); i++ {
		want := expectedEvents[i]

		select {
		case got := <-events:
			if reflect.DeepEqual(got, want) {
				continue
			}
			t.Fatalf("expected %v, got: %v", want, got)

		case <-time.After(5 * time.Second):
			t.Fatalf("expected event: %v", want)
		}
	}
}
// sendThreeHopPayment generates hops over the first bob link and the carol
// link for an amount of one bitcoin, creates a payment for that route, adds
// the invoice to carol's registry and sends the htlc through alice's switch.
// It returns the htlc that was sent together with the generated hop payloads.
// Any error along the way fails the test.
func (n *threeHopNetwork) sendThreeHopPayment(t *testing.T) (*lnwire.UpdateAddHTLC,
	[]*hop.Payload) {

	finalAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
	firstHopAmt, timelock, route := generateHops(
		finalAmt, testStartingHeight, n.firstBobChannelLink,
		n.carolChannelLink,
	)

	onionBlob, err := generateRoute(route...)
	if err != nil {
		t.Fatal(err)
	}

	invoice, add, paymentID, err := generatePayment(
		finalAmt, firstHopAmt, timelock, onionBlob,
	)
	if err != nil {
		t.Fatal(err)
	}

	// Carol needs to know the invoice before the htlc is sent.
	err = n.carolServer.registry.AddInvoice(*invoice, add.PaymentHash)
	if err != nil {
		t.Fatalf("unable to add invoice in carol registry: %v", err)
	}

	firstHop := n.firstBobChannelLink.ShortChanID()
	err = n.aliceServer.htlcSwitch.SendHTLC(firstHop, paymentID, add)
	if err != nil {
		t.Fatalf("could not send htlc")
	}

	return add, route
}
// getThreeHopEvents builds the htlc events expected on alice, bob and carol
// for one three hop payment with the given htlc id and timestamp.
//
// When linkError is nil the payment is treated as settled: alice gets a
// forwarding event and a settle event, bob gets a forwarding event and a
// settle event, and carol gets a settle event.
//
// When linkError is non-nil, alice gets a forwarding event followed by a
// forwarding fail event, bob gets a single outgoing link fail event carrying
// linkError, and the slice returned for carol is nil.
func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
	ts time.Time, htlc *lnwire.UpdateAddHTLC, hops []*hop.Payload,
	linkError *LinkError) ([]interface{}, []interface{}, []interface{}) {

	// circuit pairs the given channel with this payment's htlc id.
	circuit := func(chanID lnwire.ShortChannelID) channeldb.CircuitKey {
		return channeldb.CircuitKey{
			ChanID: chanID,
			HtlcID: htlcID,
		}
	}

	// Alice is the sender, so her incoming circuit is the zero circuit.
	aliceKey := HtlcKey{
		IncomingCircuit: zeroCircuit,
		OutgoingCircuit: circuit(channels.aliceToBob.ShortChanID()),
	}

	// Bob receives on his channel with alice and sends on his channel
	// with carol.
	bobKey := HtlcKey{
		IncomingCircuit: circuit(channels.bobToAlice.ShortChanID()),
		OutgoingCircuit: circuit(channels.bobToCarol.ShortChanID()),
	}
	bobInfo := HtlcInfo{
		IncomingTimeLock: htlc.Expiry,
		IncomingAmt:      htlc.Amount,
		OutgoingTimeLock: hops[1].FwdInfo.OutgoingCTLV,
		OutgoingAmt:      hops[1].FwdInfo.AmountToForward,
	}

	// Alice's send event is expected in both outcomes.
	aliceSend := &ForwardingEvent{
		HtlcKey: aliceKey,
		HtlcInfo: HtlcInfo{
			OutgoingTimeLock: htlc.Expiry,
			OutgoingAmt:      htlc.Amount,
		},
		HtlcEventType: HtlcEventTypeSend,
		Timestamp:     ts,
	}

	// Failure outcome.
	if linkError != nil {
		aliceEvents := []interface{}{
			aliceSend,
			&ForwardingFailEvent{
				HtlcKey:       aliceKey,
				HtlcEventType: HtlcEventTypeSend,
				Timestamp:     ts,
			},
		}

		bobEvents := []interface{}{
			&LinkFailEvent{
				HtlcKey:       bobKey,
				HtlcInfo:      bobInfo,
				HtlcEventType: HtlcEventTypeForward,
				LinkError:     linkError,
				Incoming:      false,
				Timestamp:     ts,
			},
		}

		return aliceEvents, bobEvents, nil
	}

	// Settled outcome.
	aliceEvents := []interface{}{
		aliceSend,
		&SettleEvent{
			HtlcKey:       aliceKey,
			HtlcEventType: HtlcEventTypeSend,
			Timestamp:     ts,
		},
	}

	bobEvents := []interface{}{
		&ForwardingEvent{
			HtlcKey:       bobKey,
			HtlcInfo:      bobInfo,
			HtlcEventType: HtlcEventTypeForward,
			Timestamp:     ts,
		},
		&SettleEvent{
			HtlcKey:       bobKey,
			HtlcEventType: HtlcEventTypeForward,
			Timestamp:     ts,
		},
	}

	// Carol is the receiver, so her outgoing circuit is the zero circuit.
	carolKey := HtlcKey{
		IncomingCircuit: circuit(channels.carolToBob.ShortChanID()),
		OutgoingCircuit: zeroCircuit,
	}
	carolEvents := []interface{}{
		&SettleEvent{
			HtlcKey:       carolKey,
			HtlcEventType: HtlcEventTypeReceive,
			Timestamp:     ts,
		},
	}

	return aliceEvents, bobEvents, carolEvents
}