package main
import "C"
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/authzed/spicedb/pkg/cmd/datastore"
"github.com/authzed/spicedb/pkg/cmd/server"
"github.com/authzed/spicedb/pkg/cmd/util"
spicedbdatastore "github.com/authzed/spicedb/pkg/datastore"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
type Instance struct {
server server.RunnableServer
transport string clientConn *grpc.ClientConn
streamingAddr string
streamingListener net.Listener
streamingServer *grpc.Server
cancel context.CancelFunc
wg *errgroup.Group
tracerProvider *sdktrace.TracerProvider
prevTracerProvider oteltrace.TracerProvider
metricsServer *http.Server
}
// Registry of live instances keyed by handle. Handles come from atomically
// incrementing nextID, so the first handle issued is 1. instanceMu guards
// the instances map.
var (
	nextID uint64

	instanceMu sync.RWMutex
	instances  = map[uint64]*Instance{}
)
// Response is the JSON envelope produced by makeError and makeSuccess.
// Error and Data are omitted from the encoding when empty; Data carries an
// already-encoded JSON payload.
type Response struct {
	Success bool            `json:"success"`
	Error   string          `json:"error,omitempty"`
	Data    json.RawMessage `json:"data,omitempty"`
}
func makeError(msg string) *C.char {
resp := Response{Success: false, Error: msg}
data, _ := json.Marshal(resp)
return C.CString(string(data))
}
func makeSuccess(data interface{}) *C.char {
var rawData json.RawMessage
if data != nil {
rawData, _ = json.Marshal(data)
}
resp := Response{Success: true, Data: rawData}
respData, _ := json.Marshal(resp)
return C.CString(string(respData))
}
func spicedb_free(ptr *C.char) {
if ptr != nil {
C.free(unsafe.Pointer(ptr))
}
}
// StartOptions is the JSON document accepted by spicedb_start.
type StartOptions struct {
	// Datastore names the engine; empty selects datastore.MemoryEngine.
	// DatastoreURI is mandatory for postgres, cockroachdb, spanner and mysql.
	Datastore    string `json:"datastore"`
	DatastoreURI string `json:"datastore_uri"`

	// Engine-specific settings, applied only when the matching engine is used.
	SpannerCredentialsFile string `json:"spanner_credentials_file"`
	SpannerEmulatorHost    string `json:"spanner_emulator_host"`
	MySQLTablePrefix       string `json:"mysql_table_prefix"`

	// MetricsEnabled gates every observability option below. The *bool
	// toggles default to true when omitted. OTLPEndpoint enables trace
	// export; MetricsPort > 0 serves /metrics on MetricsHost (empty means
	// 0.0.0.0).
	MetricsEnabled          bool   `json:"metrics_enabled"`
	DatastoreMetricsEnabled *bool  `json:"datastore_metrics_enabled,omitempty"`
	CacheMetricsEnabled     *bool  `json:"cache_metrics_enabled,omitempty"`
	OTLPEndpoint            string `json:"otlp_endpoint,omitempty"`
	MetricsPort             int    `json:"metrics_port,omitempty"`
	MetricsHost             string `json:"metrics_host,omitempty"`
}
func spicedb_start(optionsJSON *C.char) *C.char {
opts := parseStartOptions(optionsJSON)
engine := opts.Datastore
if engine == "" {
engine = datastore.MemoryEngine
}
remoteEngines := map[string]bool{
"postgres": true, "cockroachdb": true, "spanner": true, "mysql": true,
}
if remoteEngines[engine] && opts.DatastoreURI == "" {
return makeError(fmt.Sprintf("datastore_uri is required for %s datastore", engine))
}
ctx, cancel := context.WithCancel(context.Background())
instance := &Instance{cancel: cancel}
id := atomic.AddUint64(&nextID, 1)
ds, err := newDatastoreFromOpts(ctx, opts)
if err != nil {
cancel()
return makeError(fmt.Sprintf("failed to create datastore: %v", err))
}
srv, err := newSpiceDBServerFromDatastore(ctx, "", "memory", ds, opts.MetricsEnabled && boolPtrOrDefault(opts.CacheMetricsEnabled, true))
if err != nil {
cancel()
return makeError(fmt.Sprintf("failed to create server: %v", err))
}
instance.server = srv
instance.transport = "memory"
if opts.MetricsEnabled && opts.OTLPEndpoint != "" {
prev, tp, err := setupOTelTracing(ctx, opts.OTLPEndpoint)
if err != nil {
cancel()
return makeError(fmt.Sprintf("failed to configure OTLP tracing: %v", err))
}
instance.tracerProvider = tp
instance.prevTracerProvider = prev
}
if opts.MetricsEnabled && opts.MetricsPort > 0 {
host := opts.MetricsHost
if host == "" {
host = "0.0.0.0"
}
srv, err := startMetricsServer(host, opts.MetricsPort)
if err != nil {
cancel()
cleanupObservability(instance)
return makeError(fmt.Sprintf("failed to start metrics server: %v", err))
}
instance.metricsServer = srv
}
var wg errgroup.Group
wg.Go(func() error {
if err := instance.server.Run(ctx); err != nil && ctx.Err() == nil {
return err
}
return nil
})
instance.wg = &wg
dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second)
defer dialCancel()
for i := 0; i < 20; i++ {
conn, err := instance.server.GRPCDialContext(dialCtx, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err == nil {
instance.clientConn = conn
break
}
time.Sleep(25 * time.Millisecond)
}
if instance.clientConn == nil {
cancel()
_ = wg.Wait()
cleanupObservability(instance)
return makeError("failed to dial in-memory server")
}
proxyAddr, err := startStreamingProxy(ctx, instance, id, &wg)
if err != nil {
cancel()
_ = wg.Wait()
cleanupObservability(instance)
return makeError(fmt.Sprintf("failed to start streaming proxy: %v", err))
}
instance.streamingAddr = proxyAddr
instanceMu.Lock()
instances[id] = instance
instanceMu.Unlock()
streamingTransport := "unix"
if runtime.GOOS == "windows" {
streamingTransport = "tcp"
}
data := map[string]interface{}{
"handle": id,
"grpc_transport": "memory",
"streaming_address": instance.streamingAddr,
"streaming_transport": streamingTransport,
}
return makeSuccess(data)
}
func parseStartOptions(optionsJSON *C.char) StartOptions {
opts := StartOptions{}
if optionsJSON == nil {
return opts
}
s := C.GoString(optionsJSON)
if s == "" {
return opts
}
_ = json.Unmarshal([]byte(s), &opts)
return opts
}
// newDatastoreFromOpts builds a SpiceDB datastore from the caller's options.
// An empty engine selects datastore.MemoryEngine. Request hedging is always
// disabled. Datastore metrics are enabled only when MetricsEnabled is set and
// DatastoreMetricsEnabled is not explicitly false. Spanner and MySQL specific
// options are passed only for their own engine.
func newDatastoreFromOpts(ctx context.Context, opts StartOptions) (spicedbdatastore.Datastore, error) {
	engine := opts.Datastore
	if engine == "" {
		engine = datastore.MemoryEngine
	}

	withMetrics := opts.MetricsEnabled && boolPtrOrDefault(opts.DatastoreMetricsEnabled, true)
	configOpts := []datastore.ConfigOption{
		datastore.DefaultDatastoreConfig().ToOption(),
		datastore.WithEngine(engine),
		datastore.WithRequestHedgingEnabled(false),
		datastore.WithEnableDatastoreMetrics(withMetrics),
	}
	if uri := opts.DatastoreURI; uri != "" {
		configOpts = append(configOpts, datastore.WithURI(uri))
	}

	switch engine {
	case "spanner":
		if credsFile := opts.SpannerCredentialsFile; credsFile != "" {
			configOpts = append(configOpts, datastore.WithSpannerCredentialsFile(credsFile))
		}
		if emulator := opts.SpannerEmulatorHost; emulator != "" {
			configOpts = append(configOpts, datastore.WithSpannerEmulatorHost(emulator))
		}
	case "mysql":
		if prefix := opts.MySQLTablePrefix; prefix != "" {
			configOpts = append(configOpts, datastore.WithTablePrefix(prefix))
		}
	}

	return datastore.NewDatastore(ctx, configOpts...)
}
// newSpiceDBServerFromDatastore builds (but does not run) a SpiceDB server on
// top of ds.
//
// transport selects the gRPC listener:
//   - "memory": an in-process buffered network with a 1 MiB buffer; addr is ignored.
//   - "tcp": a TCP listener on addr.
//   - anything else: a unix socket at addr.
//
// Authentication always succeeds. The HTTP gateway and the server's own
// metrics API are disabled. All caches are disabled, and their metrics flag
// follows cacheMetricsEnabled.
func newSpiceDBServerFromDatastore(ctx context.Context, addr string, transport string, ds spicedbdatastore.Datastore, cacheMetricsEnabled bool) (server.RunnableServer, error) {
	grpcConfig := util.GRPCServerConfig{Enabled: true}
	switch transport {
	case "memory":
		grpcConfig.Network = util.BufferedNetwork
		grpcConfig.Address = ""
		grpcConfig.BufferSize = 1024 * 1024
	case "tcp":
		grpcConfig.Network = "tcp"
		grpcConfig.Address = addr
	default:
		grpcConfig.Network = "unix"
		grpcConfig.Address = addr
	}

	allowAll := func(ctx context.Context) (context.Context, error) {
		return ctx, nil
	}
	disabledCache := server.CacheConfig{Enabled: false, Metrics: cacheMetricsEnabled}

	return server.NewConfigWithOptionsAndDefaults(
		server.WithGRPCServer(grpcConfig),
		server.WithGRPCAuthFunc(allowAll),
		server.WithHTTPGateway(util.HTTPServerConfig{HTTPEnabled: false}),
		server.WithMetricsAPI(util.HTTPServerConfig{HTTPEnabled: false}),
		server.WithDispatchCacheConfig(disabledCache),
		server.WithNamespaceCacheConfig(disabledCache),
		server.WithClusterDispatchCacheConfig(disabledCache),
		server.WithDatastore(ds),
	).Complete(ctx)
}
// boolPtrOrDefault dereferences b, falling back to def when b is nil. It is
// used for optional JSON booleans, where an omitted key decodes to nil.
func boolPtrOrDefault(b *bool, def bool) bool {
	if b != nil {
		return *b
	}
	return def
}
// setupOTelTracing installs a global OpenTelemetry tracer provider that batches
// spans to an OTLP/gRPC collector at endpoint, using an insecure connection.
//
// It returns the previously installed global provider (so it can be
// restored later) and the new provider (so it can be shut down). If the
// exporter cannot be created, the global provider is left untouched.
func setupOTelTracing(ctx context.Context, endpoint string) (prev oteltrace.TracerProvider, tp *sdktrace.TracerProvider, err error) {
	exp, expErr := otlptracegrpc.New(ctx, otlptracegrpc.WithEndpoint(endpoint), otlptracegrpc.WithInsecure())
	if expErr != nil {
		return nil, nil, fmt.Errorf("creating OTLP gRPC exporter: %w", expErr)
	}

	previous := otel.GetTracerProvider()
	provider := sdktrace.NewTracerProvider(sdktrace.WithBatcher(exp))
	otel.SetTracerProvider(provider)
	return previous, provider, nil
}
// startMetricsServer binds host:port over TCP and serves the Prometheus
// handler at /metrics on a background goroutine. Binding happens
// synchronously, so a busy port is reported to the caller.
//
// The returned server should be stopped with Shutdown; its serving goroutine
// ends when the server is shut down.
func startMetricsServer(host string, port int) (*http.Server, error) {
	// JoinHostPort brackets IPv6 literals (e.g. "::1" -> "[::1]:9090"); plain
	// "%s:%d" formatting produced an address net.Listen could not parse.
	addr := net.JoinHostPort(host, fmt.Sprint(port))
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("binding metrics port %d: %w", port, err)
	}

	mux := http.NewServeMux()
	mux.Handle("/metrics", promhttp.Handler())
	srv := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second, // stop slow-header clients from pinning connections
	}
	go func() {
		// Serve returns http.ErrServerClosed after Shutdown. Any other error
		// stops this best-effort endpoint without affecting the SpiceDB server.
		_ = srv.Serve(ln)
	}()
	return srv, nil
}
// cleanupObservability releases the instance's optional observability
// resources, giving each shutdown up to 5 seconds:
//   - it first shuts down the metrics HTTP server;
//   - it then shuts down the tracer provider and restores the global provider
//     captured when tracing was set up.
//
// Shutdown errors are ignored.
func cleanupObservability(instance *Instance) {
	stopWithTimeout := func(stop func(context.Context) error) {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = stop(ctx)
	}

	if ms := instance.metricsServer; ms != nil {
		stopWithTimeout(ms.Shutdown)
	}
	if tp := instance.tracerProvider; tp != nil {
		stopWithTimeout(tp.Shutdown)
		otel.SetTracerProvider(instance.prevTracerProvider)
	}
}
func spicedb_dispose(handle C.ulonglong) *C.char {
id := uint64(handle)
instanceMu.Lock()
instance, ok := instances[id]
if ok {
delete(instances, id)
}
instanceMu.Unlock()
if !ok {
return makeError(fmt.Sprintf("invalid handle: %d", id))
}
instance.cancel()
if instance.clientConn != nil {
_ = instance.clientConn.Close()
}
if instance.streamingServer != nil {
instance.streamingServer.GracefulStop()
}
if instance.streamingListener != nil {
_ = instance.streamingListener.Close()
}
_ = instance.wg.Wait()
cleanupObservability(instance)
if instance.streamingAddr != "" && runtime.GOOS != "windows" {
os.Remove(instance.streamingAddr)
}
return makeSuccess(nil)
}
// main is intentionally empty. This package's entry points are the cgo
// spicedb_* functions, but Go still requires a main function in package main.
// The package is presumably built with -buildmode=c-shared or c-archive —
// TODO confirm against the build scripts.
func main() {}