package utils
import (
"context"
"errors"
"time"
"go.uber.org/multierr"
)
var ErrMaxAttemptsReached = errors.New("max attempts reached")
// HedgeParams configures a HedgeCall invocation whose operation yields a T.
type HedgeParams[T any] struct {
	// Timeout bounds the whole call: HedgeCall derives a context with this
	// timeout from the caller's context and shares it with every attempt.
	// RetryDelay is how long HedgeCall waits after launching one attempt before
	// launching the next; earlier attempts keep running concurrently.
	Timeout, RetryDelay time.Duration

	// MaxAttempts is the upper bound on how many attempts HedgeCall launches.
	MaxAttempts int

	// IsRecoverable reports whether a failed attempt still allows the call to
	// continue. When it returns false, HedgeCall returns immediately. A nil
	// IsRecoverable treats every error as recoverable.
	IsRecoverable func(error) bool

	// Func is the operation being hedged. Each attempt receives the shared,
	// timeout-bound context and should return promptly once it is canceled.
	Func func(context.Context) (T, error)
}
// HedgeCall runs params.Func and, while no attempt has succeeded, launches an
// additional concurrent attempt every params.RetryDelay, up to
// params.MaxAttempts attempts in total. It returns as soon as one of these
// happens:
//
//   - an attempt succeeds: its value is returned with a nil error;
//   - an attempt fails and params.IsRecoverable (when non-nil) rejects its
//     error;
//   - every launched attempt has failed: ErrMaxAttemptsReached is appended;
//   - the derived context ends (params.Timeout elapsed or ctx was canceled):
//     ctx.Err() of the derived context is appended.
//
// On failure the zero value of T is returned along with all attempt errors
// observed so far, combined with multierr.
//
// All attempts share one context derived from ctx via context.WithTimeout. It
// is canceled when HedgeCall returns, so attempts still in flight should stop
// promptly. A non-positive Timeout yields an already-expired context.
// MaxAttempts values below 1 are treated as 1, so exactly one attempt is made.
func HedgeCall[T any](ctx context.Context, params HedgeParams[T]) (v T, err error) {
	// At least one attempt is always launched. Clamping also prevents a
	// negative channel size (which panics in make) and guarantees a buffer
	// slot for every attempt.
	maxAttempts := params.MaxAttempts
	if maxAttempts < 1 {
		maxAttempts = 1
	}

	ctx, cancel := context.WithTimeout(ctx, params.Timeout)
	defer cancel()

	type result struct {
		value T
		err   error
	}
	// One slot per attempt, so every attempt can deliver its result without
	// blocking, even after HedgeCall has returned. Attempt goroutines therefore
	// never leak on the send.
	ch := make(chan result, maxAttempts)
	race := func() {
		val, callErr := params.Func(ctx)
		ch <- result{val, callErr}
	}

	var launched, failed int
	// A zero-duration timer fires right away, so the first attempt starts
	// without delay.
	delay := time.NewTimer(0)
	defer delay.Stop()
	for {
		select {
		case <-delay.C:
			go race()
			// Only re-arm while attempts remain. Otherwise the timer stays
			// drained and this case never fires again.
			if launched++; launched < maxAttempts {
				delay.Reset(params.RetryDelay)
			}
		case res := <-ch:
			if res.err == nil {
				return res.value, nil
			}
			err = multierr.Append(err, res.err)
			if params.IsRecoverable != nil && !params.IsRecoverable(res.err) {
				return v, err
			}
			if failed++; failed == maxAttempts {
				return v, multierr.Append(err, ErrMaxAttemptsReached)
			}
		case <-ctx.Done():
			return v, multierr.Append(err, ctx.Err())
		}
	}
}