package sqlrite
import "C"
import (
"context"
"database/sql/driver"
"errors"
"io"
"sync"
)
// conn is a single database connection wrapping a native
// *C.SqlriteConnection handle obtained through cgo. It implements
// driver.Conn.
type conn struct {
// mu is acquired by the methods in this file before they read or write
// handle or closed.
mu sync.Mutex
// handle is the native connection; Close sets it to nil after releasing it.
handle *C.SqlriteConnection
// closed reports whether Close has already run.
closed bool
}
// newConn opens a native connection and wraps it in a *conn.
//
// Selection order:
//   - readOnly set: sqlrite_open_read_only(name). readOnly takes precedence,
//     so a read-only ":memory:" request also goes through this call with the
//     literal name.
//   - name == ":memory:": sqlrite_open_in_memory.
//   - otherwise: sqlrite_open(name).
//
// A non-OK native status is converted by wrapErr and returned with a nil conn.
func newConn(name string, readOnly bool) (*conn, error) {
	nameC := cString(name)
	defer freeCString(nameC)

	var (
		h  *C.SqlriteConnection
		rc Status
	)
	switch {
	case readOnly:
		rc = Status(C.sqlrite_open_read_only(nameC, &h))
	case name == ":memory:":
		rc = Status(C.sqlrite_open_in_memory(&h))
	default:
		rc = Status(C.sqlrite_open(nameC, &h))
	}

	if err := wrapErr(rc, "open"); err != nil {
		return nil, err
	}
	return &conn{handle: h}, nil
}
// Compile-time checks that *conn satisfies driver.Conn plus every optional
// database/sql/driver interface it implements. ConnPrepareContext is listed
// because conn defines PrepareContext; asserting it here keeps that method's
// signature from silently drifting away from the interface.
var (
	_ driver.Conn               = (*conn)(nil)
	_ driver.ConnBeginTx        = (*conn)(nil)
	_ driver.ConnPrepareContext = (*conn)(nil)
	_ driver.ExecerContext      = (*conn)(nil)
	_ driver.QueryerContext     = (*conn)(nil)
	_ driver.Pinger             = (*conn)(nil)
)
// Prepare implements driver.Conn. It is the context-free form of
// PrepareContext and delegates to it with a background context.
func (c *conn) Prepare(query string) (driver.Stmt, error) {
	bg := context.Background()
	return c.PrepareContext(bg, query)
}
// PrepareContext implements driver.ConnPrepareContext. No native call is made
// here: the returned stmt only carries the connection and the SQL text.
//
// A closed connection yields driver.ErrBadConn, the same signal exec, query
// and Ping use. database/sql then discards this connection and retries on a
// fresh one instead of surfacing an ad-hoc error.
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return nil, driver.ErrBadConn
	}
	return &stmt{conn: c, sql: query}, nil
}
// Close implements driver.Conn. The native handle is released on the first
// call only; later calls are no-ops. Close always returns nil, and any result
// of sqlrite_close is discarded.
func (c *conn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		C.sqlrite_close(c.handle)
		c.handle, c.closed = nil, true
	}
	return nil
}
// BeginTx implements driver.ConnBeginTx by executing BEGIN on the connection.
//
// Read-only transactions are rejected. No other option is consulted,
// including the isolation level, and the context is ignored.
func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
	if opts.ReadOnly {
		return nil, errors.New("sqlrite: read-only transactions aren't supported via TxOptions (open the db via OpenReadOnly instead)")
	}

	err := c.exec("BEGIN")
	if err != nil {
		return nil, err
	}
	return &tx{conn: c}, nil
}
// Begin implements the legacy driver.Conn method by delegating to BeginTx
// with zero-valued options and a background context.
func (c *conn) Begin() (driver.Tx, error) {
	var defaults driver.TxOptions
	return c.BeginTx(context.Background(), defaults)
}
// Ping implements driver.Pinger. It makes no native call. It only reports
// driver.ErrBadConn once the connection has been closed, so database/sql can
// drop it from the pool.
func (c *conn) Ping(context.Context) error {
	c.mu.Lock()
	isClosed := c.closed
	c.mu.Unlock()

	if isClosed {
		return driver.ErrBadConn
	}
	return nil
}
// ExecContext implements driver.ExecerContext, running query directly on the
// connection without preparing it first.
//
// This fast path cannot bind arguments, because c.exec accepts only the SQL
// text. When positional arguments are supplied it returns driver.ErrSkip.
// database/sql then falls back to Prepare + Stmt.Exec instead of running the
// statement with its placeholders silently unbound. Named parameters are
// still rejected first by rejectNamedParamsForNow. The context is not
// consulted.
func (c *conn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	if err := rejectNamedParamsForNow(args); err != nil {
		return nil, err
	}
	if len(args) > 0 {
		return nil, driver.ErrSkip
	}
	if err := c.exec(query); err != nil {
		return nil, err
	}
	return execResult{}, nil
}
// QueryContext implements driver.QueryerContext, running query directly on
// the connection without preparing it first.
//
// c.query binds no arguments. When positional arguments are supplied this
// method returns driver.ErrSkip, so database/sql falls back to Prepare +
// Stmt.Query instead of silently discarding them. Named parameters are still
// rejected first by rejectNamedParamsForNow. The context is not consulted.
func (c *conn) QueryContext(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	if err := rejectNamedParamsForNow(args); err != nil {
		return nil, err
	}
	if len(args) > 0 {
		return nil, driver.ErrSkip
	}
	return c.query(query)
}
// exec runs query through sqlrite_execute while holding mu and converts the
// native status into a Go error via wrapErr. It returns driver.ErrBadConn if
// the connection has been closed. No arguments are bound.
func (c *conn) exec(query string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return driver.ErrBadConn
	}

	sqlC := cString(query)
	defer freeCString(sqlC)

	return wrapErr(Status(C.sqlrite_execute(c.handle, sqlC)), "execute")
}
// query runs query through sqlrite_query and returns a *rows wrapping the
// resulting native statement handle.
//
// The column names are read eagerly while mu is held. If any metadata call
// fails after sqlrite_query succeeds, the statement is finalized here before
// the error is returned, so no error path leaks the handle. On success, the
// returned rows holds stmtHandle (presumably finalizing it on Close; see
// rows). A closed connection yields driver.ErrBadConn.
func (c *conn) query(query string) (driver.Rows, error) {
	c.mu.Lock()
	// One deferred unlock covers every return path. The original
	// hand-written Unlock on each branch was easy to get out of sync.
	defer c.mu.Unlock()

	if c.closed {
		return nil, driver.ErrBadConn
	}

	cQuery := cString(query)
	defer freeCString(cQuery)

	var stmtHandle *C.SqlriteStatement
	if err := wrapErr(Status(C.sqlrite_query(c.handle, cQuery, &stmtHandle)), "query"); err != nil {
		return nil, err
	}

	// abort finalizes the statement before reporting a failed metadata call.
	abort := func(st Status, op string) (driver.Rows, error) {
		C.sqlrite_finalize(stmtHandle)
		return nil, wrapErr(st, op)
	}

	var colCount C.int
	if st := Status(C.sqlrite_column_count(stmtHandle, &colCount)); st != statusOk {
		return abort(st, "column_count")
	}

	cols := make([]string, int(colCount))
	for i := range cols {
		var name *C.char
		if st := Status(C.sqlrite_column_name(stmtHandle, C.int(i), &name)); st != statusOk {
			return abort(st, "column_name")
		}
		// C.GoString copies the bytes into Go memory, so the native string
		// can be released immediately.
		cols[i] = C.GoString(name)
		C.sqlrite_free_string(name)
	}

	return &rows{conn: c, handle: stmtHandle, cols: cols}, nil
}
// tx is the driver.Tx returned by BeginTx. It keeps only a pointer to its
// connection; ending the transaction issues the matching SQL statement there.
type tx struct {
	conn *conn
}

// Commit ends the transaction by executing COMMIT on the owning connection.
func (t *tx) Commit() error {
	owner := t.conn
	return owner.exec("COMMIT")
}

// Rollback abandons the transaction by executing ROLLBACK on the owning
// connection.
func (t *tx) Rollback() error {
	owner := t.conn
	return owner.exec("ROLLBACK")
}
// execResult is the driver.Result returned by ExecContext.
//
// exec's native call (sqlrite_execute) reports only a status, so neither the
// last insert ID nor the affected-row count is available. Both methods say so
// with an explicit error rather than returning a fabricated 0, nil, which
// callers could not tell apart from a genuine "0 rows affected".
type execResult struct{}

// LastInsertId implements driver.Result. It always returns an error because
// the native layer does not report an insert ID.
func (execResult) LastInsertId() (int64, error) {
	return 0, errors.New("sqlrite: LastInsertId is not supported")
}

// RowsAffected implements driver.Result. It always returns an error because
// the native layer does not report an affected-row count.
func (execResult) RowsAffected() (int64, error) {
	return 0, errors.New("sqlrite: RowsAffected is not supported")
}
// Compile-time check that *rows (declared elsewhere in the package)
// satisfies io.Closer.
var _ = io.Closer((*rows)(nil))