package chanbackup
import (
"bytes"
"fmt"
"testing"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain"
)
var (
// testWalletPrivKey is a fixed 32-byte private key. mockKeyRing.DeriveKey
// derives its returned public key from these bytes, and because DeriveKey
// ignores the requested KeyLocator, every successful derivation in these
// tests yields the same public key.
testWalletPrivKey = []byte{
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
}
)
// mockKeyRing is a minimal test double passed as the key ring to
// encryptPayloadToWriter and decryptPayloadFromReader in these tests.
type mockKeyRing struct {
// fail, when true, makes DeriveKey return an error instead of a key.
// DeriveNextKey does not consult it.
fail bool
}
// DeriveNextKey ignores the requested key family and always succeeds with a
// zero-valued KeyDescriptor. Unlike DeriveKey, it does not honour the fail
// flag.
func (m *mockKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
	var empty keychain.KeyDescriptor
	return empty, nil
}
// DeriveKey returns a KeyDescriptor whose PubKey is the public half of
// testWalletPrivKey, regardless of keyLoc. When the mock is configured with
// fail set, it returns an empty descriptor and an error instead.
func (m *mockKeyRing) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) {
	var desc keychain.KeyDescriptor
	if m.fail {
		return desc, fmt.Errorf("fail")
	}

	// Only the public key is exposed; the private half is discarded.
	_, desc.PubKey = btcec.PrivKeyFromBytes(btcec.S256(), testWalletPrivKey)
	return desc, nil
}
// TestEncryptDecryptPayload round-trips plaintexts through
// encryptPayloadToWriter and decryptPayloadFromReader. A case may mutate the
// ciphertext before decryption. Cases marked invalid must fail to decrypt,
// and valid cases must decrypt back to the original plaintext.
func TestEncryptDecryptPayload(t *testing.T) {
	t.Parallel()

	payloadCases := []struct {
		plaintext []byte
		mutator   func(*[]byte)
		valid     bool
	}{
		// An untouched ciphertext should decrypt to the plaintext.
		{
			plaintext: []byte("payload test plain text"),
			mutator:   nil,
			valid:     true,
		},
		// Flipping the low bit of the first output byte must be
		// rejected.
		{
			plaintext: []byte("payload test plain text"),
			mutator: func(p *[]byte) {
				(*p)[0] ^= 1
			},
			valid: false,
		},
		// An empty ciphertext must be rejected.
		{
			plaintext: []byte("payload test plain text"),
			mutator: func(p *[]byte) {
				*p = []byte{}
			},
			valid: false,
		},
	}

	keyRing := &mockKeyRing{}

	for i, payloadCase := range payloadCases {
		var cipherBuffer bytes.Buffer
		payloadReader := bytes.NewBuffer(payloadCase.plaintext)
		err := encryptPayloadToWriter(
			*payloadReader, &cipherBuffer, keyRing,
		)
		if err != nil {
			t.Fatalf("#%v: unable to encrypt payload: %v", i, err)
		}

		// Corrupt a private copy of the ciphertext rather than the
		// slice returned by Bytes. That slice aliases the buffer's
		// own storage, which is then reset and written back into.
		if payloadCase.mutator != nil {
			cipherText := append([]byte(nil), cipherBuffer.Bytes()...)
			payloadCase.mutator(&cipherText)
			cipherBuffer.Reset()
			cipherBuffer.Write(cipherText)
		}

		plaintext, err := decryptPayloadFromReader(&cipherBuffer, keyRing)
		switch {
		case err != nil && payloadCase.valid:
			t.Fatalf("#%v: unable to decrypt valid payload: %v",
				i, err)
		case err == nil && !payloadCase.valid:
			t.Fatalf("#%v: payload was invalid yet was able to "+
				"decrypt", i)
		}

		if payloadCase.valid &&
			!bytes.Equal(plaintext, payloadCase.plaintext) {
			t.Fatalf("#%v: expected %v, got %v", i,
				payloadCase.plaintext, plaintext)
		}
	}
}
// TestInvalidKeyEncryption asserts that encryptPayloadToWriter reports an
// error when the key ring fails to derive a key.
func TestInvalidKeyEncryption(t *testing.T) {
	t.Parallel()

	// Use a non-empty payload and a separate output buffer. The original
	// passed one empty buffer as both the by-value payload and the
	// writer, so any failure could not be clearly attributed to the key
	// ring.
	payload := bytes.NewBuffer([]byte("payload test plain text"))
	var cipherBuffer bytes.Buffer

	err := encryptPayloadToWriter(
		*payload, &cipherBuffer, &mockKeyRing{fail: true},
	)
	if err == nil {
		t.Fatalf("expected error due to fail key gen")
	}
}
func TestInvalidKeyDecrytion(t *testing.T) {
t.Parallel()
var b bytes.Buffer
_, err := decryptPayloadFromReader(&b, &mockKeyRing{true})
if err == nil {
t.Fatalf("expected error due to fail key gen")
}
}