package htlcswitch
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"sync"
"sync/atomic"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
)
const (
defaultDbDirectory = "sharedhashes"
)
var (
sharedHashBucket = []byte("shared-hash")
batchReplayBucket = []byte("batch-replay")
)
var (
ErrDecayedLogInit = errors.New("unable to initialize decayed log")
ErrDecayedLogCorrupted = errors.New("decayed log structure corrupted")
)
type DecayedLog struct {
started int32 stopped int32
dbPath string
db kvdb.Backend
notifier chainntnfs.ChainNotifier
wg sync.WaitGroup
quit chan struct{}
}
func NewDecayedLog(dbPath string,
notifier chainntnfs.ChainNotifier) *DecayedLog {
if dbPath == "" {
dbPath = defaultDbDirectory
}
return &DecayedLog{
dbPath: dbPath,
notifier: notifier,
quit: make(chan struct{}),
}
}
func (d *DecayedLog) Start() error {
if !atomic.CompareAndSwapInt32(&d.started, 0, 1) {
return nil
}
var err error
d.db, err = kvdb.Create(
kvdb.BoltBackendName, d.dbPath, true,
)
if err != nil {
return fmt.Errorf("could not open boltdb: %v", err)
}
if err := d.initBuckets(); err != nil {
return err
}
if d.notifier != nil {
epochClient, err := d.notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return fmt.Errorf("unable to register for epoch "+
"notifications: %v", err)
}
d.wg.Add(1)
go d.garbageCollector(epochClient)
}
return nil
}
func (d *DecayedLog) initBuckets() error {
return kvdb.Update(d.db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(sharedHashBucket)
if err != nil {
return ErrDecayedLogInit
}
_, err = tx.CreateTopLevelBucket(batchReplayBucket)
if err != nil {
return ErrDecayedLogInit
}
return nil
})
}
func (d *DecayedLog) Stop() error {
if !atomic.CompareAndSwapInt32(&d.stopped, 0, 1) {
return nil
}
close(d.quit)
d.wg.Wait()
d.db.Close()
return nil
}
func (d *DecayedLog) garbageCollector(epochClient *chainntnfs.BlockEpochEvent) {
defer d.wg.Done()
defer epochClient.Cancel()
for {
select {
case epoch, ok := <-epochClient.Epochs:
if !ok {
log.Infof("Block epoch canceled, " +
"decaying hash log shutting down")
return
}
height := uint32(epoch.Height)
numExpired, err := d.gcExpiredHashes(height)
if err != nil {
log.Errorf("unable to expire hashes at "+
"height=%d", height)
}
if numExpired > 0 {
log.Infof("Garbage collected %v shared "+
"secret hashes at height=%v",
numExpired, height)
}
case <-d.quit:
log.Infof("Decaying hash log received " +
"shutdown request")
return
}
}
}
func (d *DecayedLog) gcExpiredHashes(height uint32) (uint32, error) {
var numExpiredHashes uint32
err := kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
numExpiredHashes = 0
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return fmt.Errorf("sharedHashBucket " +
"is nil")
}
var expiredCltv [][]byte
if err := sharedHashes.ForEach(func(k, v []byte) error {
cltv := uint32(binary.BigEndian.Uint32(v))
if cltv < height {
expiredCltv = append(expiredCltv, k)
numExpiredHashes++
}
return nil
}); err != nil {
return err
}
for _, hash := range expiredCltv {
err := sharedHashes.Delete(hash)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return 0, err
}
return numExpiredHashes, nil
}
func (d *DecayedLog) Delete(hash *sphinx.HashPrefix) error {
return kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
return sharedHashes.Delete(hash[:])
})
}
func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) {
var value uint32
err := kvdb.View(d.db, func(tx kvdb.ReadTx) error {
sharedHashes := tx.ReadBucket(sharedHashBucket)
if sharedHashes == nil {
return fmt.Errorf("sharedHashes is nil, could " +
"not retrieve CLTV value")
}
valueBytes := sharedHashes.Get(hash[:])
if valueBytes == nil {
return sphinx.ErrLogEntryNotFound
}
value = uint32(binary.BigEndian.Uint32(valueBytes))
return nil
})
if err != nil {
return value, err
}
return value, nil
}
func (d *DecayedLog) Put(hash *sphinx.HashPrefix, cltv uint32) error {
var scratch [4]byte
binary.BigEndian.PutUint32(scratch[:], cltv)
return kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
valueBytes := sharedHashes.Get(hash[:])
if valueBytes != nil {
return sphinx.ErrReplayedPacket
}
return sharedHashes.Put(hash[:], scratch[:])
})
}
func (d *DecayedLog) PutBatch(b *sphinx.Batch) (*sphinx.ReplaySet, error) {
var replays *sphinx.ReplaySet
if err := kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
batchReplayBkt := tx.ReadWriteBucket(batchReplayBucket)
if batchReplayBkt == nil {
return ErrDecayedLogCorrupted
}
replayBytes := batchReplayBkt.Get(b.ID)
if replayBytes != nil {
replays = sphinx.NewReplaySet()
return replays.Decode(bytes.NewReader(replayBytes))
}
var scratch [4]byte
replays = sphinx.NewReplaySet()
err := b.ForEach(func(seqNum uint16, hashPrefix *sphinx.HashPrefix, cltv uint32) error {
valueBytes := sharedHashes.Get(hashPrefix[:])
if valueBytes != nil {
replays.Add(seqNum)
return nil
}
binary.BigEndian.PutUint32(scratch[:], cltv)
return sharedHashes.Put(hashPrefix[:], scratch[:])
})
if err != nil {
return err
}
replays.Merge(b.ReplaySet)
var replayBuf bytes.Buffer
if err := replays.Encode(&replayBuf); err != nil {
return err
}
return batchReplayBkt.Put(b.ID, replayBuf.Bytes())
}); err != nil {
return nil, err
}
b.ReplaySet = replays
b.IsCommitted = true
return replays, nil
}
var _ sphinx.ReplayLog = (*DecayedLog)(nil)