package keychain
import (
"fmt"
"io/ioutil"
"math/rand"
"os"
"testing"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcwallet/snacl"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/davecgh/go-spew/spew"
_ "github.com/btcsuite/btcwallet/walletdb/bdb" )
var versionZeroKeyFamilies = []KeyFamily{
KeyFamilyMultiSig,
KeyFamilyRevocationBase,
KeyFamilyHtlcBase,
KeyFamilyPaymentBase,
KeyFamilyDelayBase,
KeyFamilyRevocationRoot,
KeyFamilyNodeKey,
}
var (
testHDSeed = chainhash.Hash{
0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab,
0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4,
0x4f, 0x2f, 0x6f, 0x25, 0x98, 0xa3, 0xef, 0xb9,
0x69, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53,
}
)
func createTestBtcWallet(coinType uint32) (func(), *wallet.Wallet, error) {
fastScrypt := waddrmgr.FastScryptOptions
keyGen := func(passphrase *[]byte, config *waddrmgr.ScryptOptions) (
*snacl.SecretKey, error) {
return snacl.NewSecretKey(
passphrase, fastScrypt.N, fastScrypt.R, fastScrypt.P,
)
}
waddrmgr.SetSecretKeyGen(keyGen)
tempDir, err := ioutil.TempDir("", "keyring-lnwallet")
if err != nil {
return nil, nil, err
}
loader := wallet.NewLoader(&chaincfg.SimNetParams, tempDir, true, 0)
pass := []byte("test")
baseWallet, err := loader.CreateNewWallet(
pass, pass, testHDSeed[:], time.Time{},
)
if err != nil {
return nil, nil, err
}
if err := baseWallet.Unlock(pass, nil); err != nil {
return nil, nil, err
}
chainKeyScope := waddrmgr.KeyScope{
Purpose: BIP0043Purpose,
Coin: coinType,
}
_, err = baseWallet.Manager.FetchScopedKeyManager(chainKeyScope)
if err != nil {
err := walletdb.Update(baseWallet.Database(), func(tx walletdb.ReadWriteTx) error {
addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
_, err := baseWallet.Manager.NewScopedKeyManager(
addrmgrNs, chainKeyScope, lightningAddrSchema,
)
return err
})
if err != nil {
return nil, nil, err
}
}
cleanUp := func() {
baseWallet.Lock()
os.RemoveAll(tempDir)
}
return cleanUp, baseWallet, nil
}
func assertEqualKeyLocator(t *testing.T, a, b KeyLocator) {
t.Helper()
if a != b {
t.Fatalf("mismatched key locators: expected %v, "+
"got %v", spew.Sdump(a), spew.Sdump(b))
}
}
type keyRingConstructor func() (string, func(), KeyRing, error)
func TestKeyRingDerivation(t *testing.T) {
t.Parallel()
keyRingImplementations := []keyRingConstructor{
func() (string, func(), KeyRing, error) {
cleanUp, wallet, err := createTestBtcWallet(
CoinTypeBitcoin,
)
if err != nil {
t.Fatalf("unable to create wallet: %v", err)
}
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeBitcoin)
return "btcwallet", cleanUp, keyRing, nil
},
func() (string, func(), KeyRing, error) {
cleanUp, wallet, err := createTestBtcWallet(
CoinTypeLitecoin,
)
if err != nil {
t.Fatalf("unable to create wallet: %v", err)
}
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeLitecoin)
return "ltcwallet", cleanUp, keyRing, nil
},
func() (string, func(), KeyRing, error) {
cleanUp, wallet, err := createTestBtcWallet(
CoinTypeTestnet,
)
if err != nil {
t.Fatalf("unable to create wallet: %v", err)
}
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeTestnet)
return "testwallet", cleanUp, keyRing, nil
},
}
const numKeysToDerive = 10
for _, keyRingConstructor := range keyRingImplementations {
keyRingName, cleanUp, keyRing, err := keyRingConstructor()
if err != nil {
t.Fatalf("unable to create key ring %v: %v", keyRingName,
err)
}
defer cleanUp()
success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) {
for _, keyFam := range versionZeroKeyFamilies {
keyDesc, err := keyRing.DeriveNextKey(keyFam)
if err != nil {
t.Fatalf("unable to derive next for "+
"keyFam=%v: %v", keyFam, err)
}
assertEqualKeyLocator(t,
KeyLocator{
Family: keyFam,
Index: 0,
}, keyDesc.KeyLocator,
)
keyLoc := KeyLocator{
Family: keyFam,
Index: 0,
}
firstKeyDesc, err := keyRing.DeriveKey(keyLoc)
if err != nil {
t.Fatalf("unable to derive first key for "+
"keyFam=%v: %v", keyFam, err)
}
if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) {
t.Fatalf("mismatched keys: expected %x, "+
"got %x",
keyDesc.PubKey.SerializeCompressed(),
firstKeyDesc.PubKey.SerializeCompressed())
}
assertEqualKeyLocator(t,
KeyLocator{
Family: keyFam,
Index: 0,
}, firstKeyDesc.KeyLocator,
)
for i := 0; i < numKeysToDerive+1; i++ {
keyLoc := KeyLocator{
Family: keyFam,
Index: uint32(i),
}
keyDesc, err := keyRing.DeriveKey(keyLoc)
if err != nil {
t.Fatalf("unable to derive first key for "+
"keyFam=%v: %v", keyFam, err)
}
assertEqualKeyLocator(
t, keyLoc, keyDesc.KeyLocator,
)
}
randKeyIndex := uint32(rand.Int31())
keyLoc = KeyLocator{
Family: keyFam,
Index: randKeyIndex,
}
keyDesc, err = keyRing.DeriveKey(keyLoc)
if err != nil {
t.Fatalf("unable to derive key_index=%v "+
"for keyFam=%v: %v",
randKeyIndex, keyFam, err)
}
assertEqualKeyLocator(
t, keyLoc, keyDesc.KeyLocator,
)
}
})
if !success {
break
}
}
}
type secretKeyRingConstructor func() (string, func(), SecretKeyRing, error)
func TestSecretKeyRingDerivation(t *testing.T) {
t.Parallel()
secretKeyRingImplementations := []secretKeyRingConstructor{
func() (string, func(), SecretKeyRing, error) {
cleanUp, wallet, err := createTestBtcWallet(
CoinTypeBitcoin,
)
if err != nil {
t.Fatalf("unable to create wallet: %v", err)
}
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeBitcoin)
return "btcwallet", cleanUp, keyRing, nil
},
func() (string, func(), SecretKeyRing, error) {
cleanUp, wallet, err := createTestBtcWallet(
CoinTypeLitecoin,
)
if err != nil {
t.Fatalf("unable to create wallet: %v", err)
}
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeLitecoin)
return "ltcwallet", cleanUp, keyRing, nil
},
func() (string, func(), SecretKeyRing, error) {
cleanUp, wallet, err := createTestBtcWallet(
CoinTypeTestnet,
)
if err != nil {
t.Fatalf("unable to create wallet: %v", err)
}
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeTestnet)
return "testwallet", cleanUp, keyRing, nil
},
}
for _, secretKeyRingConstructor := range secretKeyRingImplementations {
keyRingName, cleanUp, secretKeyRing, err := secretKeyRingConstructor()
if err != nil {
t.Fatalf("unable to create secret key ring %v: %v",
keyRingName, err)
}
defer cleanUp()
success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) {
for _, keyFam := range versionZeroKeyFamilies {
randKeyIndex := uint32(rand.Int31())
keyLoc := KeyLocator{
Family: keyFam,
Index: randKeyIndex,
}
pubKeyDesc, err := secretKeyRing.DeriveKey(keyLoc)
if err != nil {
t.Fatalf("unable to derive pubkey "+
"(fam=%v, index=%v): %v",
keyLoc.Family,
keyLoc.Index, err)
}
privKey, err := secretKeyRing.DerivePrivKey(KeyDescriptor{
KeyLocator: keyLoc,
})
if err != nil {
t.Fatalf("unable to derive priv "+
"(fam=%v, index=%v): %v", keyLoc.Family,
keyLoc.Index, err)
}
if !pubKeyDesc.PubKey.IsEqual(privKey.PubKey()) {
t.Fatalf("pubkeys mismatched: expected %x, got %x",
pubKeyDesc.PubKey.SerializeCompressed(),
privKey.PubKey().SerializeCompressed())
}
keyDesc, err := secretKeyRing.DeriveNextKey(keyFam)
if err != nil {
t.Fatalf("unable to derive key: %v", err)
}
keyDesc = KeyDescriptor{
PubKey: keyDesc.PubKey,
KeyLocator: KeyLocator{
Family: keyFam,
},
}
privKey, err = secretKeyRing.DerivePrivKey(keyDesc)
if err != nil {
t.Fatalf("unable to derive priv key "+
"via scanning: %v", err)
}
if !keyDesc.PubKey.IsEqual(privKey.PubKey()) {
t.Fatalf("pubkeys mismatched: expected %x, got %x",
pubKeyDesc.PubKey.SerializeCompressed(),
privKey.PubKey().SerializeCompressed())
}
_, pub := btcec.PrivKeyFromBytes(
btcec.S256(), testHDSeed[:],
)
keyDesc.PubKey = pub
privKey, err = secretKeyRing.DerivePrivKey(
keyDesc,
)
if err != ErrCannotDerivePrivKey {
t.Fatalf("expected %T, instead got %v",
ErrCannotDerivePrivKey, err)
}
}
})
if !success {
break
}
}
}
func init() {
MaxKeyRangeScan = 3
}