package pool_test
import (
"bytes"
crand "crypto/rand"
"fmt"
"io"
"math/rand"
"testing"
"time"
"github.com/lightningnetwork/lnd/buffer"
"github.com/lightningnetwork/lnd/pool"
)
// workerPoolTest describes a single worker pool flavor to be run through the
// shared submission scenarios.
type workerPoolTest struct {
	// name is the prefix used when naming the subtests for this pool.
	name string

	// newPool constructs a fresh pool instance. The result is returned as
	// an empty interface and is type-switched by the *Generic helpers.
	newPool func() interface{}

	// numWorkers is the worker count handed to the submission scenarios,
	// which size their task batches relative to it.
	numWorkers int
}
func TestConcreteWorkerPools(t *testing.T) {
const (
gcInterval = time.Second
expiryInterval = 250 * time.Millisecond
numWorkers = 5
workerTimeout = 500 * time.Millisecond
)
tests := []workerPoolTest{
{
name: "write pool",
newPool: func() interface{} {
bp := pool.NewWriteBuffer(
gcInterval, expiryInterval,
)
return pool.NewWrite(
bp, numWorkers, workerTimeout,
)
},
numWorkers: numWorkers,
},
{
name: "read pool",
newPool: func() interface{} {
bp := pool.NewReadBuffer(
gcInterval, expiryInterval,
)
return pool.NewRead(
bp, numWorkers, workerTimeout,
)
},
numWorkers: numWorkers,
},
}
for _, test := range tests {
testWorkerPool(t, test)
}
}
// testWorkerPool runs the non blocking, blocking and partial blocking
// submission scenarios as parallel subtests. Every subtest constructs its own
// pool via test.newPool, starts it, and stops it again when the subtest exits.
func testWorkerPool(t *testing.T, test workerPoolTest) {
	scenarios := []struct {
		suffix string
		submit func(*testing.T, interface{}, int)
	}{
		{" non blocking", submitNonblockingGeneric},
		{" blocking", submitBlockingGeneric},
		{" partial blocking", submitPartialBlockingGeneric},
	}

	for _, scenario := range scenarios {
		// The subtest body only runs after t.Parallel returns, which is
		// after this loop has finished. Take a per-iteration copy so the
		// closure never observes a later scenario on toolchains where
		// the loop variable is shared.
		scenario := scenario

		t.Run(test.name+scenario.suffix, func(t *testing.T) {
			t.Parallel()

			p := test.newPool()
			startGeneric(t, p)
			defer stopGeneric(t, p)

			scenario.submit(t, p, test.numWorkers)
		})
	}
}
// submitNonblockingGeneric submits twice as many tasks as there are workers,
// all gated on one shared channel, and then releases every task at once by
// closing that channel. It asserts that no result arrives before the release,
// that each task then reports a nil error in a timely manner, and that no
// additional results show up afterwards.
func submitNonblockingGeneric(t *testing.T, p interface{}, nWorkers int) {
	var (
		numTasks = 2 * nWorkers
		results  = make(chan error)
		release  = make(chan struct{})
	)

	// Launch every submission on its own goroutine; each one reports the
	// outcome of its submission on the results channel.
	for launched := 0; launched < numTasks; launched++ {
		go func() {
			results <- submitGeneric(p, release)
		}()
	}

	// While the gate is shut, no submission should produce a result.
	pullNothing(t, results)

	// Closing the gate unblocks every task simultaneously.
	close(release)
	pullParllel(t, numTasks, results)

	// All tasks are accounted for, so nothing further should arrive.
	pullNothing(t, results)
}
// submitBlockingGeneric submits twice as many tasks as there are workers, all
// gated on one shared channel, and then releases them one at a time. It
// asserts that no result arrives before the first release, that each release
// is followed by exactly one nil result, and that nothing arrives after the
// final task has been collected.
func submitBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
	var (
		numTasks = 2 * nWorkers
		results  = make(chan error)
		gate     = make(chan struct{})
	)

	// Launch every submission on its own goroutine; each one reports the
	// outcome of its submission on the results channel.
	for launched := 0; launched < numTasks; launched++ {
		go func() {
			results <- submitGeneric(p, gate)
		}()
	}

	// Nothing has been released yet, so no result is expected.
	pullNothing(t, results)

	// Hand out one token per task, collecting a result after each token.
	pullSequntial(t, numTasks, results, gate)

	// All tasks are accounted for, so nothing further should arrive.
	pullNothing(t, results)
}
// submitPartialBlockingGeneric first submits nWorkers-1 tasks gated on one
// channel, then submits the remainder of a 2*nWorkers batch gated on a second
// channel. The second group is released all at once while the first group is
// still held, and the test asserts that the second group completes regardless.
// The first group is then released one task at a time.
func submitPartialBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
	heldCount := nWorkers - 1
	freeCount := 2*nWorkers - heldCount

	results := make(chan error)

	// launch starts count submissions, each on its own goroutine and each
	// gated on the provided channel.
	launch := func(count int, gate chan struct{}) {
		for launched := 0; launched < count; launched++ {
			go func() {
				results <- submitGeneric(p, gate)
			}()
		}
	}

	// Submit the group that stays held until the very end.
	heldGate := make(chan struct{})
	launch(heldCount, heldGate)
	pullNothing(t, results)

	// Submit the second group behind its own gate; nothing should finish
	// while both gates are shut.
	freeGate := make(chan struct{})
	launch(freeCount, freeGate)
	pullNothing(t, results)

	// Release the second group in one go and collect all of its results
	// while the first group is still held.
	close(freeGate)
	pullParllel(t, freeCount, results)
	pullNothing(t, results)

	// Finally release the held group one token at a time.
	pullSequntial(t, heldCount, results, heldGate)
	pullNothing(t, results)
}
// pullNothing asserts that errChan stays silent for one full second. Any value
// received in that window, a nil error included, fails the test immediately.
func pullNothing(t *testing.T, errChan chan error) {
	t.Helper()

	quiet := time.NewTimer(time.Second)
	defer quiet.Stop()

	select {
	case <-quiet.C:
		// The channel stayed silent for the whole window.
		return

	case err := <-errChan:
		t.Fatalf("received unexpected error before semaphore "+
			"release: %v", err)
	}
}
// pullParllel receives n results from errChan. The test fails on the first
// non-nil error, or if any individual result takes longer than one second to
// arrive.
func pullParllel(t *testing.T, n int, errChan chan error) {
	t.Helper()

	for pulled := 0; pulled < n; pulled++ {
		var err error
		select {
		case err = <-errChan:
		case <-time.After(time.Second):
			t.Fatalf("task %d was not processed in time", pulled)
		}

		if err != nil {
			t.Fatal(err)
		}
	}
}
// pullSequntial performs n rounds of sending a single token on semChan and then
// receiving a single result from errChan. The test fails if a token is not
// accepted within one second, if a result does not arrive within one second of
// its token, or if any result is a non-nil error.
func pullSequntial(t *testing.T, n int, errChan chan error, semChan chan struct{}) {
	t.Helper()

	for round := 0; round < n; round++ {
		// Hand out one token, which some receiver on semChan must pick
		// up within the deadline.
		select {
		case semChan <- struct{}{}:
		case <-time.After(time.Second):
			t.Fatalf("task %d was not unblocked", round)
		}

		// Collect the result that the token is expected to produce.
		var err error
		select {
		case err = <-errChan:
		case <-time.After(time.Second):
			t.Fatalf("task %d was not processed in time", round)
		}

		if err != nil {
			t.Fatal(err)
		}
	}
}
// startGeneric starts the given worker pool, which must be either a *pool.Write
// or a *pool.Read. The test fails if the pool is of any other type or if
// starting it returns an error.
func startGeneric(t *testing.T, p interface{}) {
	t.Helper()

	// Resolve the concrete Start method first so that the call and its
	// error handling only need to be written once.
	var start func() error
	switch wp := p.(type) {
	case *pool.Write:
		start = wp.Start
	case *pool.Read:
		start = wp.Start
	default:
		t.Fatalf("unknown worker pool type: %T", p)
	}

	if err := start(); err != nil {
		t.Fatalf("unable to start worker pool: %v", err)
	}
}
// stopGeneric stops the given worker pool, which must be either a *pool.Write
// or a *pool.Read. The test fails if the pool is of any other type or if
// stopping it returns an error.
func stopGeneric(t *testing.T, p interface{}) {
	t.Helper()

	// Resolve the concrete Stop method first so that the call and its
	// error handling only need to be written once.
	var stop func() error
	switch wp := p.(type) {
	case *pool.Write:
		stop = wp.Stop
	case *pool.Read:
		stop = wp.Stop
	default:
		t.Fatalf("unknown worker pool type: %T", p)
	}

	if err := stop(); err != nil {
		t.Fatalf("unable to stop worker pool: %v", err)
	}
}
func submitGeneric(p interface{}, sem <-chan struct{}) error {
var err error
switch pp := p.(type) {
case *pool.Write:
err = pp.Submit(func(buf *bytes.Buffer) error {
if buf.Len() != 0 {
return fmt.Errorf("buf should be length zero, "+
"instead has length %d", buf.Len())
}
if buf.Cap() != buffer.WriteSize {
return fmt.Errorf("buf should have capacity "+
"%d, instead has capacity %d",
buffer.WriteSize, buf.Cap())
}
b := make([]byte, rand.Intn(buf.Cap()))
_, err := io.ReadFull(crand.Reader, b)
if err != nil {
return err
}
_, err = buf.Write(b)
<-sem
return err
})
case *pool.Read:
err = pp.Submit(func(buf *buffer.Read) error {
for i := range buf[:] {
if buf[i] != 0x00 {
return fmt.Errorf("byte %d of "+
"buffer.Read should be "+
"0, instead is %d", i, buf[i])
}
}
_, err := io.ReadFull(crand.Reader, buf[:])
<-sem
return err
})
default:
return fmt.Errorf("unknown worker pool type: %T", p)
}
if err != nil {
return fmt.Errorf("unable to submit task: %v", err)
}
return nil
}