package utils
import (
"runtime"
"sync"
"testing"
"time"
"go.uber.org/atomic"
"github.com/benbjohnson/clock"
"github.com/stretchr/testify/require"
)
// UnstableTest is the reason string handed to t.Skip by tests whose results
// are currently timing-dependent and therefore disabled.
const (
	UnstableTest = "UNSTABLE TEST"
)
// Limiter is the minimal rate-limiter contract these tests exercise.
// Take may block (TestMaxSlack expects it to wait) and returns a time.Time
// that the tests treat as the moment the call was admitted (TestInitial
// measures the gaps between successive results).
type Limiter interface {
Take() time.Time
}
// config holds the settings produced by applying a list of Options.
// buildConfig supplies the defaults: the real clock, slack 10, per one second.
type config struct {
// clock is the time source handed to the limiter.
clock Clock
// slack is expressed in requests; runTest's constructor turns it into a
// duration by multiplying with the per-request interval.
slack int
// per is the window the rate refers to (runTest computes the per-request
// interval as per / rate).
per time.Duration
}
// buildConfig starts from the defaults (real clock, slack of 10, rate
// expressed per second) and then applies every option in order, so later
// options override earlier ones.
func buildConfig(opts []Option) config {
	var c config
	c.clock = clock.New()
	c.slack = 10
	c.per = time.Second
	for i := range opts {
		opts[i].apply(&c)
	}
	return c
}
// Option tweaks a config; buildConfig calls apply on each Option in order.
type Option interface {
apply(*config)
}
// clockOption replaces the time source stored in the config.
type clockOption struct {
	clock Clock
}

func (co clockOption) apply(cfg *config) {
	cfg.clock = co.clock
}

// WithClock returns an Option that makes the limiter use the given Clock
// instead of the default one (the test runner passes its *clock.Mock here).
func WithClock(clock Clock) Option {
	return clockOption{clock}
}
// slackOption stores its own integer value as the config's slack.
type slackOption int

func (s slackOption) apply(cfg *config) {
	cfg.slack = int(s)
}

// WithoutSlack sets the slack to zero.
var WithoutSlack Option = WithSlack(0)

// WithSlack returns an Option setting the slack, measured in requests.
func WithSlack(slack int) Option {
	var opt Option = slackOption(slack)
	return opt
}
// perOption stores the window the configured rate refers to.
type perOption time.Duration

func (po perOption) apply(cfg *config) {
	cfg.per = time.Duration(po)
}

// Per returns an Option that makes the rate mean "rate Takes per per";
// e.g. rate 7 with Per(time.Minute) is seven Takes per minute.
func Per(per time.Duration) Option {
	opt := perOption(per)
	return opt
}
// testRunner is the small DSL the rate-limiter tests are written against:
// build limiters bound to a mock clock, start background takers, and
// schedule checks on the shared take counter at offsets of mock time.
type testRunner interface {
// createLimiter builds a limiter for the given rate using the runner's mock clock.
createLimiter(int, ...Option) Limiter
// takeOnceAfter performs one Take (and one count increment) after the given mock delay.
takeOnceAfter(time.Duration, Limiter)
// startTaking starts a background goroutine that takes from every limiter in rls in a loop.
startTaking(rls ...Limiter)
// assertCountAt checks the counter equals count after d of mock time.
assertCountAt(d time.Duration, count int)
// assertCountAtWithNoise checks the counter is within ±noise of count after d of mock time.
assertCountAtWithNoise(d time.Duration, count int, noise int)
// afterFunc runs fn on a background goroutine after d of mock time.
afterFunc(d time.Duration, fn func())
// getClock exposes the mock clock so tests can advance it by hand.
getClock() *clock.Mock
}
// runnerImpl is the testRunner created by runTest for each implementation.
type runnerImpl struct {
// t is runTest's per-implementation subtest.
t *testing.T
// clock is the mock clock shared by the runner and every limiter it creates.
clock *clock.Mock
// constructor builds the limiter implementation under test.
constructor func(int, ...Option) Limiter
// count is incremented once per completed take round (startTaking / takeOnceAfter).
count atomic.Int32
// maxDuration is the largest delay passed to afterFunc; it sizes the clock step.
maxDuration time.Duration
// doneCh is closed by runTest to tell background goroutines to stop.
doneCh chan struct{}
// wg tracks timed callbacks (takeOnceAfter and assertions) that advanceUntilDone waits for.
wg sync.WaitGroup
}
// runTest runs fn once per limiter implementation, each in its own subtest.
// Each subtest gets a fresh mock clock initialised to the current wall time.
// After fn has registered its takers and assertions, the mock clock is driven
// forward until every pending timed callback has finished. doneCh is then
// closed so background taker loops stop.
func runTest(t *testing.T, fn func(testRunner)) {
	impls := []struct {
		name        string
		constructor func(int, ...Option) Limiter
	}{
		{
			name: "mutex",
			constructor: func(rate int, opts ...Option) Limiter {
				conf := buildConfig(opts)
				perRequest := conf.per / time.Duration(rate)
				cfg := leakyBucketConfig{
					perRequest: perRequest,
					// NOTE(review): slack is stored as a negative duration;
					// presumably LeakyBucket uses it as a lower bound — confirm.
					maxSlack: -1 * time.Duration(conf.slack) * perRequest,
				}
				l := &LeakyBucket{
					clock: conf.clock,
				}
				l.cfg.Store(&cfg)
				return l
			},
		},
	}
	for _, tt := range impls {
		t.Run(tt.name, func(t *testing.T) {
			clockMock := clock.NewMock()
			clockMock.Set(time.Now())
			r := runnerImpl{
				t:           t,
				clock:       clockMock,
				constructor: tt.constructor,
				doneCh:      make(chan struct{}),
			}
			// Deferred so goroutines selecting on doneCh are released even if
			// fn terminates this goroutine early (t.Skip/t.FailNow call
			// runtime.Goexit, which skips plain trailing statements).
			defer close(r.doneCh)
			fn(&r)
			r.advanceUntilDone()
		})
	}
}
// advanceUntilDone steps the mock clock forward until every callback tracked
// by r.wg has finished. If no timed callback was ever registered
// (maxDuration <= 0), it just waits without touching the clock.
func (r *runnerImpl) advanceUntilDone() {
	if r.maxDuration <= 0 {
		r.wg.Wait()
		return
	}
	allDone := make(chan struct{})
	go func() {
		defer close(allDone)
		r.wg.Wait()
	}()
	for step := r.clockAdvanceStep(); ; {
		select {
		case <-allDone:
			return
		default:
			r.clock.Add(step)
			// Let woken goroutines run before the next tick.
			runtime.Gosched()
		}
	}
}
// clockAdvanceStep picks the mock-clock increment used by advanceUntilDone:
// one thousandth of maxDuration, clamped to the range [1ms, 100ms].
func (r *runnerImpl) clockAdvanceStep() time.Duration {
	const (
		minStep = time.Millisecond
		maxStep = 100 * time.Millisecond
	)
	switch step := r.maxDuration / 1_000; {
	case step < minStep:
		return minStep
	case step > maxStep:
		return maxStep
	default:
		return step
	}
}
// createLimiter builds a limiter via the implementation under test, with
// WithClock(r.clock) appended to the options. When the constructor applies
// options in order (buildConfig does), the mock clock therefore wins over
// any clock the caller passed.
func (r *runnerImpl) createLimiter(rate int, opts ...Option) Limiter {
	return r.constructor(rate, append(opts, WithClock(r.clock))...)
}

// getClock returns the runner's mock clock so tests can advance it manually.
func (r *runnerImpl) getClock() *clock.Mock { return r.clock }
// startTaking launches a background goroutine. Each round it calls Take on
// every limiter in rls, in order, then bumps the shared counter once. It
// checks doneCh only between rounds, so a round in progress (possibly
// blocked in Take) is not interrupted.
func (r *runnerImpl) startTaking(rls ...Limiter) {
	takeRound := func() {
		for i := range rls {
			rls[i].Take()
		}
		r.count.Inc()
	}
	r.goWait(func() {
		for {
			takeRound()
			select {
			case <-r.doneCh:
				return
			default:
			}
		}
	})
}
// takeOnceAfter schedules a single Take on rl after d of mock time. The Take
// is followed by one counter increment, and the call is tracked in r.wg so
// advanceUntilDone waits for it.
func (r *runnerImpl) takeOnceAfter(d time.Duration, rl Limiter) {
	r.wg.Add(1)
	takeAndCount := func() {
		defer r.wg.Done()
		rl.Take()
		r.count.Inc()
	}
	r.afterFunc(d, takeAndCount)
}
// assertCountAt schedules a check, after d of mock time, that the shared
// counter equals count exactly. The check is tracked in r.wg so
// advanceUntilDone waits for it.
func (r *runnerImpl) assertCountAt(d time.Duration, count int) {
	r.wg.Add(1)
	r.afterFunc(d, func() {
		defer r.wg.Done()
		// This runs on a goroutine started by afterFunc, not the test
		// goroutine. FailNow (used by require.*) must not be called from
		// such goroutines, whereas Errorf is safe from any goroutine.
		if got := r.count.Load(); got != int32(count) {
			r.t.Errorf("count not as expected: want %d, got %d", count, got)
		}
	})
}
// assertCountAtWithNoise schedules a check, after d of mock time, that the
// shared counter is within ±noise of count (same rule as testify's InDelta:
// fail when |got-count| > noise). The check is tracked in r.wg.
func (r *runnerImpl) assertCountAtWithNoise(d time.Duration, count int, noise int) {
	r.wg.Add(1)
	r.afterFunc(d, func() {
		defer r.wg.Done()
		got := int(r.count.Load())
		diff := got - count
		if diff < 0 {
			diff = -diff
		}
		// Errorf rather than require.*: this is not the test goroutine, and
		// FailNow must only be called from the test goroutine.
		if diff > noise {
			r.t.Errorf("expected count to be within noise tolerance: want %d (+/- %d), got %d",
				count, noise, got)
		}
	})
}
// afterFunc runs fn on a background goroutine once the mock clock has moved
// d past the moment afterFunc was called, unless doneCh is closed first.
// It also records the largest d seen in maxDuration, which advanceUntilDone
// uses to choose its clock step.
func (r *runnerImpl) afterFunc(d time.Duration, fn func()) {
	if d > r.maxDuration {
		r.maxDuration = d
	}
	// Register the timer on the caller's goroutine so its deadline is anchored
	// to the current mock time. Creating it inside the spawned goroutine let
	// the clock advance first, shifting the deadline nondeterministically.
	timer := r.clock.After(d)
	r.goWait(func() {
		select {
		case <-r.doneCh:
			return
		case <-timer:
		}
		fn()
	})
}
// goWait starts fn on a new goroutine and returns once that goroutine has
// begun executing. It does not wait for fn to finish.
func (r *runnerImpl) goWait(fn func()) {
	started := make(chan struct{})
	go func() {
		close(started)
		fn()
	}()
	<-started
}
// TestRateLimiter shares one 100/s limiter (no slack) between four takers.
// It expects the combined count to track 100 per second for three seconds,
// within a noise of 2.
func TestRateLimiter(t *testing.T) {
	runTest(t, func(r testRunner) {
		rl := r.createLimiter(100, WithoutSlack)
		const takers = 4
		for i := 0; i < takers; i++ {
			r.startTaking(rl)
		}
		for sec := 1; sec <= 3; sec++ {
			r.assertCountAtWithNoise(time.Duration(sec)*time.Second, 100*sec, 2)
		}
	})
}
// TestDelayedRateLimiter has one taker go through a 10/s limiter and a 100/s
// limiter for 20s. Four extra fast-only takers then join, and the total at
// 30s is compared against 1200. Currently skipped as unstable.
func TestDelayedRateLimiter(t *testing.T) {
	t.Skip(UnstableTest)
	runTest(t, func(r testRunner) {
		slow := r.createLimiter(10, WithoutSlack)
		fast := r.createLimiter(100, WithoutSlack)
		r.startTaking(slow, fast)
		r.afterFunc(20*time.Second, func() {
			for n := 0; n < 4; n++ {
				r.startTaking(fast)
			}
		})
		r.assertCountAt(30*time.Second, 1200)
	})
}
// TestPer checks that Per changes the rate's window: two takers on a
// 7-per-minute limiter (no slack) are sampled at three checkpoints.
func TestPer(t *testing.T) {
	runTest(t, func(r testRunner) {
		rl := r.createLimiter(7, WithoutSlack, Per(time.Minute))
		for i := 0; i < 2; i++ {
			r.startTaking(rl)
		}
		checkpoints := []struct {
			at    time.Duration
			count int
		}{
			{1 * time.Second, 1},
			{1 * time.Minute, 8},
			{2 * time.Minute, 15},
		}
		for _, cp := range checkpoints {
			r.assertCountAt(cp.at, cp.count)
		}
	})
}
// TestInitial takes three times from a fresh 10/s limiter, advancing the mock
// clock by one per-request interval between calls. It checks that the
// returned timestamps start at "now" and are spaced exactly one interval
// apart, both with the default slack and with slack disabled.
func TestInitial(t *testing.T) {
	tests := []struct {
		msg  string
		opts []Option
	}{
		{
			msg: "With Slack",
		},
		{
			msg:  "Without Slack",
			opts: []Option{WithoutSlack},
		},
	}
	for _, tt := range tests {
		t.Run(tt.msg, func(t *testing.T) {
			runTest(t, func(r testRunner) {
				perRequest := 100 * time.Millisecond
				rl := r.createLimiter(10, tt.opts...)
				var (
					clk     = r.getClock()
					prev    = clk.Now()
					results = make(chan time.Time, 3)
					have    []time.Duration
				)
				results <- rl.Take()
				clk.Add(perRequest)
				results <- rl.Take()
				clk.Add(perRequest)
				results <- rl.Take()
				clk.Add(perRequest)
				for i := 0; i < 3; i++ {
					ts := <-results
					have = append(have, ts.Sub(prev))
					prev = ts
				}
				want := []time.Duration{0, perRequest, perRequest}
				mismatch := len(have) != len(want)
				for i := 0; !mismatch && i < len(want); i++ {
					mismatch = have[i] != want[i]
				}
				// This closure runs on runTest's nested subtest goroutine, so t
				// is a parent test here. Calling FailNow (require.*) on a parent
				// from a subtest is invalid; Errorf is safe from any goroutine.
				if mismatch {
					t.Errorf("bad timestamps for initial takes: want %v, have %v", want, have)
				}
			})
		})
	}
}
// TestMaxSlack builds a 1/s limiter with slack 1 and takes three times while
// advancing the mock clock by a second between takes. It then checks that a
// fourth Take does not return within 1ms of real time; it is released by
// advancing the clock one more second.
func TestMaxSlack(t *testing.T) {
	runTest(t, func(r testRunner) {
		// Named clk so it does not shadow the imported clock package.
		clk := r.getClock()
		rl := r.createLimiter(1, WithSlack(1))
		rl.Take()
		clk.Add(time.Second)
		rl.Take()
		clk.Add(time.Second)
		rl.Take()
		doneCh := make(chan struct{})
		go func() {
			rl.Take()
			close(doneCh)
		}()
		select {
		case <-doneCh:
			// t is the parent of runTest's subtest; Error (unlike require.Fail,
			// which calls FailNow) is valid from this goroutine.
			t.Error("expect rate limiter to be waiting")
		case <-time.After(time.Millisecond):
			clk.Add(time.Second)
		}
	})
}
// TestSlack runs one taker through a slow (10/s, no slack) and a fast
// (100/s, configurable) limiter. After 2s of mock time, two more fast-only
// takers start. The count is checked at 1s and against a per-case total at
// 3s, within a noise of 2. Currently skipped as unstable.
func TestSlack(t *testing.T) {
	t.Skip(UnstableTest)
	halfSecond := Per(500 * time.Millisecond)
	_ = halfSecond
	cases := []struct {
		name  string
		opts  []Option
		total int
	}{
		{name: "no option, defaults to 10", total: 130},
		{name: "slack of 10, like default", opts: []Option{WithSlack(10)}, total: 130},
		{name: "slack of 20", opts: []Option{WithSlack(20)}, total: 140},
		{name: "slack of 150", opts: []Option{WithSlack(150)}, total: 270},
		{name: "no option, defaults to 10, with per", opts: []Option{Per(500 * time.Millisecond)}, total: 230},
		{name: "slack of 10, like default, with per", opts: []Option{WithSlack(10), Per(500 * time.Millisecond)}, total: 230},
		{name: "slack of 20, with per", opts: []Option{WithSlack(20), Per(500 * time.Millisecond)}, total: 240},
		{name: "slack of 150, with per", opts: []Option{WithSlack(150), Per(500 * time.Millisecond)}, total: 370},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			runTest(t, func(r testRunner) {
				slow := r.createLimiter(10, WithoutSlack)
				fast := r.createLimiter(100, tc.opts...)
				r.startTaking(slow, fast)
				r.afterFunc(2*time.Second, func() {
					r.startTaking(fast)
					r.startTaking(fast)
				})
				r.assertCountAtWithNoise(1*time.Second, 10, 2)
				r.assertCountAtWithNoise(3*time.Second, tc.total, 2)
			})
		})
	}
}
func TestSetRateLimitOnTheFly(t *testing.T) {
t.Skip(UnstableTest)
runTest(t, func(r testRunner) {
limiter, ok := r.createLimiter(1, WithoutSlack).(*LeakyBucket)
if !ok {
t.Skip("Update is not supported")
}
r.startTaking(limiter)
r.assertCountAt(time.Second, 2)
r.getClock().Add(time.Second)
r.assertCountAt(time.Second, 3)
limiter.Update(2, 0)
r.getClock().Add(time.Second)
r.assertCountAt(time.Second, 4) r.getClock().Add(time.Second)
r.assertCountAt(time.Second, 6)
limiter.Update(1, 0)
r.getClock().Add(time.Second)
r.assertCountAt(time.Second, 7)
r.getClock().Add(time.Second)
r.assertCountAt(time.Second, 8)
slack := 3
require.GreaterOrEqual(t, limiter.sleepFor, time.Duration(0))
limiter.Update(1, slack)
r.getClock().Add(time.Second * time.Duration(slack))
r.assertCountAt(time.Second, 8+slack)
})
}