package chanbackup
import (
"bytes"
"fmt"
"io"
"net"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
)
// SingleBackupVersion denotes the version of a single static channel backup.
// Serialize writes it as the first byte of the encoding, and Deserialize
// rejects any value other than the known versions declared below.
type SingleBackupVersion byte
// Known single backup versions. These numeric values end up in every
// serialized Single (see Serialize/Deserialize), so existing values must never
// be renumbered; new versions should only be appended.
//
// NOTE: these are untyped integer constants, not SingleBackupVersion values.
const (
// DefaultSingleVersion is the base version. NewSingle assigns it to channels
// that report neither anchors nor a tweakless commitment.
DefaultSingleVersion = 0
// TweaklessCommitVersion is assigned by NewSingle to channels whose type
// reports IsTweakless() but not HasAnchors().
TweaklessCommitVersion = 1
// AnchorsCommitVersion is assigned by NewSingle to channels whose type
// reports HasAnchors().
AnchorsCommitVersion = 2
)
// Single is a static backup of a single channel. It holds the subset of the
// channel's state that Serialize writes out and Deserialize reads back.
type Single struct {
// Version is the encoding version of this backup. NewSingle derives it from
// the channel type; Serialize/Deserialize reject unknown versions.
Version SingleBackupVersion
// IsInitiator is copied from the channel's IsInitiator flag.
IsInitiator bool
// ChainHash is copied from the channel's ChainHash.
ChainHash chainhash.Hash
// FundingOutpoint is the channel's funding outpoint. PackStaticChanBackups
// also uses it as the key of its result map.
FundingOutpoint wire.OutPoint
// ShortChannelID is the channel's short channel ID. When its block height is
// zero, NewSingle substitutes the channel's FundingBroadcastHeight.
ShortChannelID lnwire.ShortChannelID
// RemoteNodePub is the remote peer's identity key (channel.IdentityPub).
RemoteNodePub *btcec.PublicKey
// Addresses are the network addresses passed to NewSingle.
Addresses []net.Addr
// Capacity is copied from the channel's Capacity.
Capacity btcutil.Amount
// LocalChanCfg is our channel config. Only the CSV delay and the key family
// and index of each base point are serialized; local public keys and other
// fields are not.
LocalChanCfg channeldb.ChannelConfig
// RemoteChanCfg is the remote channel config. Only the CSV delay and the
// public key of each base point are serialized.
RemoteChanCfg channeldb.ChannelConfig
// ShaChainRootDesc describes the revocation (sha chain) root. NewSingle fills
// in the public key derived from the channel's revocation producer and the
// KeyFamilyRevocationRoot family. A nil PubKey is serialized as 33 zero bytes.
ShaChainRootDesc keychain.KeyDescriptor
}
// NewSingle builds a static channel backup from an open channel and the
// network addresses at which the remote node can be reached.
//
// The backup version is chosen from the channel type. The first matching
// check wins, so a channel reporting HasAnchors() gets AnchorsCommitVersion
// even if it also reports IsTweakless().
func NewSingle(channel *channeldb.OpenChannel,
	nodeAddrs []net.Addr) Single {

	version := SingleBackupVersion(DefaultSingleVersion)
	switch {
	case channel.ChanType.HasAnchors():
		version = AnchorsCommitVersion
	case channel.ChanType.IsTweakless():
		version = TweaklessCommitVersion
	}

	// An unconfirmed channel has a zero block height in its short channel
	// ID, so fall back to the channel's FundingBroadcastHeight.
	// NOTE(review): presumably so a restore still has a usable height hint.
	shortID := channel.ShortChanID()
	if shortID.BlockHeight == 0 {
		shortID.BlockHeight = channel.FundingBroadcastHeight
	}

	// The encoded revocation producer is interpreted as a private key, and
	// only the matching public point is stored in the backup. The Encode
	// error is ignored, as in the original implementation.
	// NOTE(review): writes into a bytes.Buffer cannot fail; confirm that
	// Encode has no other failure mode.
	var producerBuf bytes.Buffer
	channel.RevocationProducer.Encode(&producerBuf)
	_, rootPub := btcec.PrivKeyFromBytes(btcec.S256(), producerBuf.Bytes())

	return Single{
		Version:         version,
		IsInitiator:     channel.IsInitiator,
		ChainHash:       channel.ChainHash,
		FundingOutpoint: channel.FundingOutpoint,
		ShortChannelID:  shortID,
		RemoteNodePub:   channel.IdentityPub,
		Addresses:       nodeAddrs,
		Capacity:        channel.Capacity,
		LocalChanCfg:    channel.LocalChanCfg,
		RemoteChanCfg:   channel.RemoteChanCfg,
		ShaChainRootDesc: keychain.KeyDescriptor{
			PubKey: rootPub,
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyRevocationRoot,
			},
		},
	}
}
// Serialize writes the plaintext encoding of the backup to w.
//
// The encoding is the version byte, then a uint16 payload length, then the
// payload itself. The payload holds these fields, in order:
//   - the channel identifiers
//   - our CSV delay and the family/index of each local base point
//   - the remote CSV delay and each remote base point public key
//   - the 33-byte sha chain root public key (all zeros when absent)
//   - the sha chain root's family and index
//
// An error is returned for an unknown version, for a failed field encoding,
// or if the payload is too large for its uint16 length prefix.
func (s *Single) Serialize(w io.Writer) error {
	// Refuse to emit a version that Deserialize would reject.
	switch s.Version {
	case DefaultSingleVersion:
	case TweaklessCommitVersion:
	case AnchorsCommitVersion:
	default:
		return fmt.Errorf("unable to serialize w/ unknown "+
			"version: %v", s.Version)
	}

	// A nil sha chain root key is encoded as 33 zero bytes, which
	// Deserialize interprets as "no key".
	var shaChainPub [33]byte
	if s.ShaChainRootDesc.PubKey != nil {
		copy(
			shaChainPub[:],
			s.ShaChainRootDesc.PubKey.SerializeCompressed(),
		)
	}

	// Encode the payload into a temporary buffer first so that its length
	// is known before anything is written to w.
	var singleBytes bytes.Buffer
	if err := lnwire.WriteElements(
		&singleBytes,
		s.IsInitiator,
		s.ChainHash[:],
		s.FundingOutpoint,
		s.ShortChannelID,
		s.RemoteNodePub,
		s.Addresses,
		s.Capacity,
		s.LocalChanCfg.CsvDelay,
		uint32(s.LocalChanCfg.MultiSigKey.Family),
		s.LocalChanCfg.MultiSigKey.Index,
		uint32(s.LocalChanCfg.RevocationBasePoint.Family),
		s.LocalChanCfg.RevocationBasePoint.Index,
		uint32(s.LocalChanCfg.PaymentBasePoint.Family),
		s.LocalChanCfg.PaymentBasePoint.Index,
		uint32(s.LocalChanCfg.DelayBasePoint.Family),
		s.LocalChanCfg.DelayBasePoint.Index,
		uint32(s.LocalChanCfg.HtlcBasePoint.Family),
		s.LocalChanCfg.HtlcBasePoint.Index,
		s.RemoteChanCfg.CsvDelay,
		s.RemoteChanCfg.MultiSigKey.PubKey,
		s.RemoteChanCfg.RevocationBasePoint.PubKey,
		s.RemoteChanCfg.PaymentBasePoint.PubKey,
		s.RemoteChanCfg.DelayBasePoint.PubKey,
		s.RemoteChanCfg.HtlcBasePoint.PubKey,
		shaChainPub[:],
		uint32(s.ShaChainRootDesc.KeyLocator.Family),
		s.ShaChainRootDesc.KeyLocator.Index,
	); err != nil {
		return err
	}

	// The length prefix is a uint16. A larger payload (for example, one with
	// a very long address list) would otherwise wrap around silently and
	// produce a corrupt prefix, so reject it explicitly.
	const maxPayloadLen = 1<<16 - 1
	payloadLen := singleBytes.Len()
	if payloadLen > maxPayloadLen {
		return fmt.Errorf("unable to serialize single: payload of %d "+
			"bytes exceeds the maximum of %d bytes", payloadLen,
			maxPayloadLen)
	}

	return lnwire.WriteElements(
		w,
		byte(s.Version),
		uint16(payloadLen),
		singleBytes.Bytes(),
	)
}
// PackToWriter serializes the backup with Serialize and passes the resulting
// plaintext, together with keyRing, to encryptPayloadToWriter, which writes
// the packed form to w. Any serialization error is returned unchanged.
func (s *Single) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
	var plaintext bytes.Buffer
	err := s.Serialize(&plaintext)
	if err != nil {
		return err
	}

	return encryptPayloadToWriter(plaintext, w, keyRing)
}
// readLocalKeyDesc reads the descriptor of one of our own keys. Serialize
// writes such a key as a uint32 key family followed by the key index; no
// public key is read.
//
// If reading the index fails, the returned descriptor still carries the
// family that was already decoded.
func readLocalKeyDesc(r io.Reader) (keychain.KeyDescriptor, error) {
	var (
		desc   keychain.KeyDescriptor
		family uint32
	)

	err := lnwire.ReadElements(r, &family)
	if err == nil {
		desc.Family = keychain.KeyFamily(family)
		err = lnwire.ReadElements(r, &desc.Index)
	}

	return desc, err
}
// readRemoteKeyDesc reads the descriptor of a remote key, which is encoded
// as a bare 33-byte serialized public key. Only PubKey is populated; on any
// read or parse error an empty descriptor is returned.
func readRemoteKeyDesc(r io.Reader) (keychain.KeyDescriptor, error) {
	var rawPub [33]byte
	if _, err := io.ReadFull(r, rawPub[:]); err != nil {
		return keychain.KeyDescriptor{}, err
	}

	pubKey, err := btcec.ParsePubKey(rawPub[:], btcec.S256())
	if err != nil {
		return keychain.KeyDescriptor{}, err
	}

	// The curve reference is cleared on the parsed key.
	// NOTE(review): presumably so deserialized keys compare equal (e.g. via
	// reflect.DeepEqual) to keys built without a curve; confirm with callers.
	pubKey.Curve = nil

	return keychain.KeyDescriptor{PubKey: pubKey}, nil
}
// Deserialize reads a plaintext backup, as written by Serialize, from r into
// s.
//
// It reads the version byte first and rejects unknown versions. It then
// reads the uint16 payload length and the payload fields in the order
// Serialize writes them.
//
// NOTE: the payload length is read only to advance past it; it is not used
// to bound the payload. Fields that are not part of the encoding are left
// untouched, and so is ShaChainRootDesc.PubKey when the stored key is all
// zeros.
func (s *Single) Deserialize(r io.Reader) error {
	var rawVersion byte
	if err := lnwire.ReadElements(r, &rawVersion); err != nil {
		return err
	}
	s.Version = SingleBackupVersion(rawVersion)

	switch s.Version {
	case DefaultSingleVersion, TweaklessCommitVersion, AnchorsCommitVersion:
	default:
		return fmt.Errorf("unable to de-serialize w/ unknown "+
			"version: %v", s.Version)
	}

	var payloadLen uint16
	if err := lnwire.ReadElements(r, &payloadLen); err != nil {
		return err
	}

	if err := lnwire.ReadElements(
		r, &s.IsInitiator, s.ChainHash[:], &s.FundingOutpoint,
		&s.ShortChannelID, &s.RemoteNodePub, &s.Addresses, &s.Capacity,
	); err != nil {
		return err
	}

	// Our side: the CSV delay, then family/index for each base point.
	if err := lnwire.ReadElements(r, &s.LocalChanCfg.CsvDelay); err != nil {
		return err
	}
	localKeys := []*keychain.KeyDescriptor{
		&s.LocalChanCfg.MultiSigKey,
		&s.LocalChanCfg.RevocationBasePoint,
		&s.LocalChanCfg.PaymentBasePoint,
		&s.LocalChanCfg.DelayBasePoint,
		&s.LocalChanCfg.HtlcBasePoint,
	}
	for _, dst := range localKeys {
		// Store the result even on error to match the prior
		// assign-then-check behavior.
		desc, err := readLocalKeyDesc(r)
		*dst = desc
		if err != nil {
			return err
		}
	}

	// Remote side: the CSV delay, then a public key for each base point.
	if err := lnwire.ReadElements(r, &s.RemoteChanCfg.CsvDelay); err != nil {
		return err
	}
	remoteKeys := []*keychain.KeyDescriptor{
		&s.RemoteChanCfg.MultiSigKey,
		&s.RemoteChanCfg.RevocationBasePoint,
		&s.RemoteChanCfg.PaymentBasePoint,
		&s.RemoteChanCfg.DelayBasePoint,
		&s.RemoteChanCfg.HtlcBasePoint,
	}
	for _, dst := range remoteKeys {
		desc, err := readRemoteKeyDesc(r)
		*dst = desc
		if err != nil {
			return err
		}
	}

	// The sha chain root key is optional. An all-zero encoding means it was
	// absent when the backup was serialized.
	var rootPub, emptyPub [33]byte
	if err := lnwire.ReadElements(r, rootPub[:]); err != nil {
		return err
	}
	if rootPub != emptyPub {
		parsed, err := btcec.ParsePubKey(rootPub[:], btcec.S256())
		s.ShaChainRootDesc.PubKey = parsed
		if err != nil {
			return err
		}
	}

	var rootFamily uint32
	if err := lnwire.ReadElements(r, &rootFamily); err != nil {
		return err
	}
	s.ShaChainRootDesc.KeyLocator.Family = keychain.KeyFamily(rootFamily)

	return lnwire.ReadElements(r, &s.ShaChainRootDesc.KeyLocator.Index)
}
// UnpackFromReader recovers a backup previously written by PackToWriter. It
// obtains the plaintext from r via decryptPayloadFromReader using keyRing,
// then decodes it into s with Deserialize.
func (s *Single) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
	plaintext, err := decryptPayloadFromReader(r, keyRing)
	if err != nil {
		return err
	}

	return s.Deserialize(bytes.NewReader(plaintext))
}
// PackStaticChanBackups packs every backup in backups with PackToWriter using
// keyRing. It returns a map from each backup's funding outpoint to its packed
// bytes.
//
// If two backups share a funding outpoint, the later one wins. The first
// packing failure aborts the whole operation and is returned with the
// offending outpoint for context.
func PackStaticChanBackups(backups []Single,
	keyRing keychain.KeyRing) (map[wire.OutPoint][]byte, error) {

	// The number of entries is known up front, so size the map once.
	packedBackups := make(map[wire.OutPoint][]byte, len(backups))
	for i := range backups {
		// Index into the slice instead of ranging by value so the
		// fairly large Single is not copied on every iteration.
		// PackToWriter only reads the backup.
		chanBackup := &backups[i]
		chanPoint := chanBackup.FundingOutpoint

		var b bytes.Buffer
		if err := chanBackup.PackToWriter(&b, keyRing); err != nil {
			return nil, fmt.Errorf("unable to pack chan backup "+
				"for %v: %v", chanPoint, err)
		}

		packedBackups[chanPoint] = b.Bytes()
	}

	return packedBackups, nil
}
// PackedSingles is a list of packed single channel backups, each produced by
// Single.PackToWriter.
type PackedSingles [][]byte

// Unpack decodes every packed backup with keyRing. It returns the backups in
// the same order as p, or nil and the first error encountered.
func (p PackedSingles) Unpack(keyRing keychain.KeyRing) ([]Single, error) {
	singles := make([]Single, len(p))
	for i := range p {
		// Each element starts as a zero Single, so unpacking in place is
		// equivalent to decoding into a fresh temporary.
		err := singles[i].UnpackFromReader(bytes.NewReader(p[i]), keyRing)
		if err != nil {
			return nil, err
		}
	}

	return singles, nil
}