package gwp
import (
"context"
"io"
pb "github.com/GrafeoDB/gql-wire-protocol/go/gen/gql"
)
// GqlSession is a client-side handle for one server session. Every request it
// issues carries sessionID so the server can associate it with this session.
//
// NOTE(review): no field is guarded by a mutex (including closed), so
// concurrent use of one GqlSession looks unsupported — confirm with callers.
type GqlSession struct {
	sessionID     string                  // identifier sent as SessionId on every request
	sessionClient pb.SessionServiceClient // used for Configure, Reset, Ping and Close
	gqlClient     pb.GqlServiceClient     // used for Execute and BeginTransaction
	closed        bool                    // set by Close; makes later Close calls no-ops
}
// SessionID reports the identifier this session attaches to every request it
// sends to the server.
func (s *GqlSession) SessionID() string {
	id := s.sessionID
	return id
}
// Execute sends a single GQL statement for this session and returns a cursor
// over the streamed response.
//
// Each entry of params is converted with valueToProto and sent as a named
// parameter; a nil or empty map sends no parameters. Only the error from
// opening the stream is returned here — errors that occur while reading the
// response surface through the returned ResultCursor.
func (s *GqlSession) Execute(ctx context.Context, statement string, params map[string]any) (*ResultCursor, error) {
	req := &pb.ExecuteRequest{
		SessionId:  s.sessionID,
		Statement:  statement,
		Parameters: make(map[string]*pb.Value, len(params)),
	}
	for name, value := range params {
		req.Parameters[name] = valueToProto(value)
	}

	stream, err := s.gqlClient.Execute(ctx, req)
	if err != nil {
		return nil, err
	}
	return newResultCursor(stream), nil
}
// BeginTransaction starts an explicit transaction in this session.
//
// readOnly selects READ_ONLY mode; otherwise READ_WRITE is requested.
// It fails with:
//   - the RPC error, if the call itself fails;
//   - a *GqlStatusError, if the response status is classified as an
//     exception by IsException;
//   - a *TransactionError, if the server returns an empty transaction ID.
func (s *GqlSession) BeginTransaction(ctx context.Context, readOnly bool) (*Transaction, error) {
	req := &pb.BeginRequest{
		SessionId: s.sessionID,
		Mode:      pb.TransactionMode_READ_WRITE,
	}
	if readOnly {
		req.Mode = pb.TransactionMode_READ_ONLY
	}

	resp, err := s.gqlClient.BeginTransaction(ctx, req)
	// Cases are evaluated in order, so resp is never dereferenced when err != nil.
	switch {
	case err != nil:
		return nil, err
	case resp.Status != nil && IsException(resp.Status.Code):
		return nil, &GqlStatusError{Code: resp.Status.Code, Message: resp.Status.Message}
	case resp.TransactionId == "":
		return nil, &TransactionError{Message: "server returned empty transaction ID"}
	}

	return &Transaction{
		sessionID:     s.sessionID,
		transactionID: resp.TransactionId,
		gqlClient:     s.gqlClient,
	}, nil
}
// SetGraph asks the server to make name the current graph of this session.
// Only the RPC error is reported; the Configure response body is not inspected.
func (s *GqlSession) SetGraph(ctx context.Context, name string) error {
	req := &pb.ConfigureRequest{SessionId: s.sessionID}
	req.Property = &pb.ConfigureRequest_Graph{Graph: name}
	if _, err := s.sessionClient.Configure(ctx, req); err != nil {
		return err
	}
	return nil
}
// SetSchema asks the server to make name the current schema of this session.
// Only the RPC error is reported; the Configure response body is not inspected.
func (s *GqlSession) SetSchema(ctx context.Context, name string) error {
	schema := &pb.ConfigureRequest_Schema{Schema: name}
	req := &pb.ConfigureRequest{SessionId: s.sessionID, Property: schema}
	if _, err := s.sessionClient.Configure(ctx, req); err != nil {
		return err
	}
	return nil
}
// SetTimeZone configures the session's time-zone offset, given in minutes
// (presumably relative to UTC — confirm against the protocol spec).
// Only the RPC error is reported; the Configure response body is not inspected.
func (s *GqlSession) SetTimeZone(ctx context.Context, offsetMinutes int32) error {
	tz := &pb.ConfigureRequest_TimeZoneOffsetMinutes{TimeZoneOffsetMinutes: offsetMinutes}
	if _, err := s.sessionClient.Configure(ctx, &pb.ConfigureRequest{
		SessionId: s.sessionID,
		Property:  tz,
	}); err != nil {
		return err
	}
	return nil
}
// Reset sends a Reset request with the RESET_ALL target for this session.
// Which session state that clears is decided by the server.
func (s *GqlSession) Reset(ctx context.Context) error {
	req := &pb.ResetRequest{
		Target:    pb.ResetTarget_RESET_ALL,
		SessionId: s.sessionID,
	}
	if _, err := s.sessionClient.Reset(ctx, req); err != nil {
		return err
	}
	return nil
}
// Ping performs a round trip to the server for this session and returns the
// Timestamp field of the response (0 when the call fails).
func (s *GqlSession) Ping(ctx context.Context) (int64, error) {
	req := &pb.PingRequest{SessionId: s.sessionID}
	resp, err := s.sessionClient.Ping(ctx, req)
	if err == nil {
		return resp.Timestamp, nil
	}
	return 0, err
}
// Close ends the session on the server. It is idempotent: once Close has been
// called, later calls return nil without contacting the server.
func (s *GqlSession) Close(ctx context.Context) error {
	if !s.closed {
		_, err := s.sessionClient.Close(ctx, &pb.CloseRequest{SessionId: s.sessionID})
		// Mark closed even when the RPC fails, so a failed Close is not retried
		// by a later call.
		s.closed = true
		return err
	}
	return nil
}
// resultCursorStream is the only part of the Execute response stream that
// ResultCursor uses. Keeping it this narrow lets any value with a matching
// Recv method drive a cursor.
type resultCursorStream interface {
	// Recv returns the next response frame; the cursor treats io.EOF as the
	// normal end of the stream and any other error as a failure.
	Recv() (*pb.ExecuteResponse, error)
}
// newResultCursor wraps stream in a cursor. No frames are read until the
// cursor is first used.
func newResultCursor(stream resultCursorStream) *ResultCursor {
	cursor := new(ResultCursor)
	cursor.stream = stream
	return cursor
}
// ResultCursor reads a streamed Execute response lazily. The stream carries
// header frames (column metadata), row-batch frames, and a summary frame.
// Frames are pulled from the stream only as the caller asks for data.
type ResultCursor struct {
	stream       resultCursorStream // source of response frames
	header       *pb.ResultHeader   // set when a header frame arrives
	summary      *pb.ResultSummary  // set when a summary frame arrives
	bufferedRows [][]any            // decoded rows; entries before rowIndex were already returned
	rowIndex     int                // index into bufferedRows of the next row to return
	done         bool               // true once io.EOF, a Recv error, or the summary has been seen
}
// consumeUntilRowsOrDone reads frames from c.stream until at least one
// not-yet-returned row is buffered or the cursor is done.
//
// Header frames update c.header. Row-batch frames are decoded with
// valueFromProto and appended to c.bufferedRows. A summary frame is stored in
// c.summary and ends the cursor. io.EOF is a normal end (nil is returned). Any
// other Recv error ends the cursor and is returned. Once the cursor is done,
// further calls return nil without reading.
func (c *ResultCursor) consumeUntilRowsOrDone() error {
	for !c.done && c.rowIndex >= len(c.bufferedRows) {
		// Every buffered row has already been handed out. Drop them before
		// reading more, so a long-lived cursor only holds the current batch
		// instead of accumulating the entire result set in memory.
		c.bufferedRows = nil
		c.rowIndex = 0

		resp, err := c.stream.Recv()
		if err == io.EOF {
			c.done = true
			return nil
		}
		if err != nil {
			c.done = true
			return err
		}

		switch f := resp.Frame.(type) {
		case *pb.ExecuteResponse_Header:
			c.header = f.Header
		case *pb.ExecuteResponse_RowBatch:
			for _, row := range f.RowBatch.Rows {
				values := make([]any, len(row.Values))
				for i, v := range row.Values {
					values[i] = valueFromProto(v)
				}
				c.bufferedRows = append(c.bufferedRows, values)
			}
		case *pb.ExecuteResponse_Summary:
			c.summary = f.Summary
			c.done = true
			// The cursor treats the summary as the last frame, but gRPC only
			// releases a client stream's resources once Recv returns a
			// non-nil error (normally io.EOF) or the context is cancelled.
			// Read to the end so callers using a long-lived ctx don't leak
			// the stream. Frames or errors after the summary are discarded:
			// the result is already complete.
			for {
				if _, drainErr := c.stream.Recv(); drainErr != nil {
					break
				}
			}
		}
	}
	return nil
}
// ColumnNames returns the column names from the result header.
//
// If the header has not been seen yet, frames are read until a row is
// buffered or the stream ends, which may read past the header into the first
// row batch. Those rows stay buffered for NextRow. It returns (nil, nil) when
// no header frame arrived, and the read error if reading fails.
func (c *ResultCursor) ColumnNames() ([]string, error) {
	if c.header == nil {
		if err := c.consumeUntilRowsOrDone(); err != nil {
			return nil, err
		}
		if c.header == nil {
			return nil, nil
		}
	}

	cols := c.header.Columns
	names := make([]string, 0, len(cols))
	for _, col := range cols {
		names = append(names, col.Name)
	}
	return names, nil
}
// NextRow returns the next row of the result, reading more frames from the
// stream only when no buffered row is left.
//
// At the end of the result it returns (nil, nil). A non-nil error is the
// failure from reading the stream.
func (c *ResultCursor) NextRow() ([]any, error) {
	if c.rowIndex >= len(c.bufferedRows) {
		if err := c.consumeUntilRowsOrDone(); err != nil {
			return nil, err
		}
		if c.rowIndex >= len(c.bufferedRows) {
			// Stream finished without producing another row.
			return nil, nil
		}
	}

	next := c.bufferedRows[c.rowIndex]
	c.rowIndex++
	return next, nil
}
// CollectRows reads every remaining row via NextRow and returns them in order.
// If reading fails, it returns the rows gathered so far together with the error.
// When no rows remain, the returned slice is nil.
func (c *ResultCursor) CollectRows() ([][]any, error) {
	var collected [][]any
	for {
		row, err := c.NextRow()
		switch {
		case err != nil:
			return collected, err
		case row == nil:
			return collected, nil
		}
		collected = append(collected, row)
	}
}
// Summary reads the stream to its end and returns the result summary.
//
// Any rows that have not been returned by NextRow yet are skipped and become
// unreachable. It returns (nil, nil) when the stream ended without a summary
// frame, and the read error if reading fails.
func (c *ResultCursor) Summary() (*ResultSummary, error) {
	for !c.done {
		// Mark all buffered rows as consumed so the next call keeps reading frames.
		c.rowIndex = len(c.bufferedRows)
		if err := c.consumeUntilRowsOrDone(); err != nil {
			return nil, err
		}
	}

	if c.summary == nil {
		return nil, nil
	}
	return &ResultSummary{proto: c.summary}, nil
}
// IsSuccess drains the cursor (see Summary) and reports whether the summary's
// status code counts as success. A missing summary yields false with no error.
func (c *ResultCursor) IsSuccess() (bool, error) {
	summary, err := c.Summary()
	switch {
	case err != nil:
		return false, err
	case summary == nil:
		return false, nil
	default:
		return summary.IsSuccess(), nil
	}
}
// RowsAffected drains the cursor (see Summary) and returns the summary's
// affected-row count. A missing summary yields 0 with no error.
func (c *ResultCursor) RowsAffected() (int64, error) {
	summary, err := c.Summary()
	switch {
	case err != nil:
		return 0, err
	case summary == nil:
		return 0, nil
	default:
		return summary.RowsAffected(), nil
	}
}
// ResultSummary wraps the summary frame received at the end of a streamed
// Execute response.
type ResultSummary struct {
	proto *pb.ResultSummary // summary message exactly as received from the server
}
// StatusCode returns the status code carried by the summary, or "" when no
// status is present.
//
// It is safe to call on a nil *ResultSummary (which ResultCursor.Summary
// returns when the stream ended without a summary) and on a zero-value
// ResultSummary; both yield "" instead of panicking.
func (s *ResultSummary) StatusCode() string {
	if s == nil || s.proto == nil || s.proto.Status == nil {
		return ""
	}
	return s.proto.Status.Code
}
// Message returns the status message carried by the summary, or "" when no
// status is present.
//
// It is safe to call on a nil *ResultSummary (which ResultCursor.Summary
// returns when the stream ended without a summary) and on a zero-value
// ResultSummary; both yield "" instead of panicking.
func (s *ResultSummary) Message() string {
	if s == nil || s.proto == nil || s.proto.Status == nil {
		return ""
	}
	return s.proto.Status.Message
}
// RowsAffected returns the affected-row count reported in the summary.
//
// It is safe to call on a nil *ResultSummary (which ResultCursor.Summary
// returns when the stream ended without a summary) and on a zero-value
// ResultSummary; both yield 0 instead of panicking.
func (s *ResultSummary) RowsAffected() int64 {
	if s == nil || s.proto == nil {
		return 0
	}
	return s.proto.RowsAffected
}
// IsSuccess reports whether the summary's status code is classified as
// success by the package-level IsSuccess function.
func (s *ResultSummary) IsSuccess() bool {
	code := s.StatusCode()
	return IsSuccess(code)
}