package psrpc
import (
"context"
"google.golang.org/protobuf/proto"
"github.com/livekit/protocol/utils"
"github.com/livekit/psrpc"
)
// WithSuppressClientErrors returns a client option that installs the
// error-screening interceptors for both regular and multi RPC calls.
// Errors seen by those interceptors are passed through utils.ScreenError
// together with errs (presumably dropping the listed errors — confirm
// against ScreenError's definition).
func WithSuppressClientErrors(errs ...error) psrpc.ClientOption {
	rpcOpt := psrpc.WithClientRPCInterceptors(newClientRPCErrorInterceptor(errs...))
	multiOpt := psrpc.WithClientMultiRPCInterceptors(newMultiRPCErorrInterceptor(errs...))
	return psrpc.WithClientOptions(rpcOpt, multiOpt)
}
// WithSuppressServerErrors returns a server option that installs the
// error-screening server RPC interceptor. Handler errors are passed through
// utils.ScreenError together with errs (presumably dropping the listed
// errors — confirm against ScreenError's definition).
func WithSuppressServerErrors(errs ...error) psrpc.ServerOption {
	screen := newServerRPCErorrInterceptor(errs...)
	return psrpc.WithServerOptions(psrpc.WithServerRPCInterceptors(screen))
}
// newClientRPCErrorInterceptor builds a client RPC interceptor that calls the
// next handler unchanged and returns its response as-is, while the returned
// error is replaced by the result of utils.ScreenError(err, errs...).
func newClientRPCErrorInterceptor(errs ...error) psrpc.ClientRPCInterceptor {
	return func(_ psrpc.RPCInfo, next psrpc.ClientRPCHandler) psrpc.ClientRPCHandler {
		screened := func(ctx context.Context, req proto.Message, opts ...psrpc.RequestOption) (proto.Message, error) {
			res, err := next(ctx, req, opts...)
			return res, utils.ScreenError(err, errs...)
		}
		return screened
	}
}
// newServerRPCErorrInterceptor builds a server RPC interceptor that invokes
// the handler with the incoming request and returns its response as-is,
// while the returned error is replaced by the result of
// utils.ScreenError(err, errs...).
func newServerRPCErorrInterceptor(errs ...error) psrpc.ServerRPCInterceptor {
	return func(ctx context.Context, req proto.Message, _ psrpc.RPCInfo, handler psrpc.ServerRPCHandler) (proto.Message, error) {
		res, handlerErr := handler(ctx, req)
		screened := utils.ScreenError(handlerErr, errs...)
		return res, screened
	}
}
// newMultiRPCErorrInterceptor builds a client multi-RPC interceptor that wraps
// the next handler in a multiRPCErorrInterceptor carrying errs, so that errors
// delivered to Recv are screened before reaching the wrapped handler.
func newMultiRPCErorrInterceptor(errs ...error) psrpc.ClientMultiRPCInterceptor {
	return func(_ psrpc.RPCInfo, next psrpc.ClientMultiRPCHandler) psrpc.ClientMultiRPCHandler {
		wrapped := &multiRPCErorrInterceptor{errors: errs}
		wrapped.ClientMultiRPCHandler = next
		return wrapped
	}
}
// multiRPCErorrInterceptor wraps a psrpc.ClientMultiRPCHandler and carries the
// list of errors handed to utils.ScreenError when a response is received.
type multiRPCErorrInterceptor struct {
// Embedded so every handler method other than Recv is delegated unchanged.
psrpc.ClientMultiRPCHandler
// errors is the variadic list forwarded to utils.ScreenError in Recv.
errors []error
}
// Recv forwards msg to the wrapped handler's Recv, substituting err with the
// result of utils.ScreenError(err, r.errors...).
func (r *multiRPCErorrInterceptor) Recv(msg proto.Message, err error) {
	screened := utils.ScreenError(err, r.errors...)
	r.ClientMultiRPCHandler.Recv(msg, screened)
}