package sqlrite
import "C"
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"strings"
"unsafe"
)
// AskConfig configures a natural-language-to-SQL request made through Ask,
// AskContext, AskRun and AskRunContext.
//
// The struct is marshalled with encoding/json and handed to the native
// library as a C string. Every field is tagged omitempty, so zero values
// ("" or 0) are left out of the JSON entirely instead of being sent.
// NOTE(review): presumably the native side applies its own defaults for
// omitted fields — confirm against the FFI implementation.
type AskConfig struct {
	// Provider names the LLM backend (AskConfigFromEnv defaults it to
	// "anthropic").
	Provider string `json:"provider,omitempty"`
	// APIKey is the provider credential. String() never prints its value.
	APIKey string `json:"api_key,omitempty"`
	// Model is the provider-specific model identifier.
	Model string `json:"model,omitempty"`
	// MaxTokens caps the response length; 0 is omitted from the JSON.
	MaxTokens uint32 `json:"max_tokens,omitempty"`
	// CacheTTL is a duration string such as "5m" (the AskConfigFromEnv
	// default). NOTE(review): accepted formats are decided by the native side.
	CacheTTL string `json:"cache_ttl,omitempty"`
	// BaseURL is not read by AskConfigFromEnv.
	// NOTE(review): presumably overrides the provider API endpoint — confirm.
	BaseURL string `json:"base_url,omitempty"`
}
// AskConfigFromEnv builds an AskConfig from SQLRITE_LLM_* environment
// variables. Unset or empty variables fall back to these defaults:
//
//   - SQLRITE_LLM_PROVIDER   → "anthropic"
//   - SQLRITE_LLM_API_KEY    → "" (no default)
//   - SQLRITE_LLM_MODEL      → "claude-sonnet-4-6"
//   - SQLRITE_LLM_MAX_TOKENS → 1024 (must parse as a base-10 uint32)
//   - SQLRITE_LLM_CACHE_TTL  → "5m"
//
// BaseURL is not read from the environment and is left empty.
//
// The only error is an invalid SQLRITE_LLM_MAX_TOKENS. That error wraps the
// underlying *strconv.NumError, so callers can check strconv.ErrSyntax or
// strconv.ErrRange with errors.Is.
func AskConfigFromEnv() (*AskConfig, error) {
	cfg := &AskConfig{
		Provider:  envOrDefault("SQLRITE_LLM_PROVIDER", "anthropic"),
		APIKey:    os.Getenv("SQLRITE_LLM_API_KEY"),
		Model:     envOrDefault("SQLRITE_LLM_MODEL", "claude-sonnet-4-6"),
		MaxTokens: 1024,
		CacheTTL:  envOrDefault("SQLRITE_LLM_CACHE_TTL", "5m"),
	}
	if v := os.Getenv("SQLRITE_LLM_MAX_TOKENS"); v != "" {
		n, err := strconv.ParseUint(v, 10, 32)
		if err != nil {
			// Wrap rather than flatten: the NumError already quotes the
			// offending value and tells syntax errors apart from overflow.
			return nil, fmt.Errorf("sqlrite: SQLRITE_LLM_MAX_TOKENS not a u32: %w", err)
		}
		cfg.MaxTokens = uint32(n)
	}
	return cfg, nil
}
// String renders the configuration for logging and debugging. The API key
// itself is never included; only whether it is set ("<set>") or empty
// ("<unset>"). BaseURL is not part of the output.
func (c *AskConfig) String() string {
	var apiKey string
	switch {
	case c.APIKey == "":
		apiKey = "<unset>"
	default:
		apiKey = "<set>"
	}
	const layout = "AskConfig(provider=%q, model=%q, maxTokens=%d, cacheTtl=%q, apiKey=%s)"
	return fmt.Sprintf(layout, c.Provider, c.Model, c.MaxTokens, c.CacheTTL, apiKey)
}
// envOrDefault returns the value of environment variable key. It returns
// dflt when the variable is unset or set to the empty string.
func envOrDefault(key, dflt string) string {
	val := os.Getenv(key)
	if val == "" {
		return dflt
	}
	return val
}
// AskResponse is the result of an Ask call, decoded from the JSON string
// returned by the native library.
type AskResponse struct {
	// SQL is the statement generated for the question. It may be empty or
	// whitespace-only; AskRunContext treats that as the model declining to
	// generate SQL.
	SQL string `json:"sql"`
	// Explanation is free text accompanying the SQL. AskRunContext puts it
	// in the error it returns when SQL is empty.
	Explanation string `json:"explanation"`
	// Usage carries the token accounting for the request.
	Usage AskUsage `json:"usage"`
}
// AskUsage holds the token counts decoded from the "usage" object of the
// native library's Ask response.
//
// NOTE(review): these meanings are inferred from the JSON key names, which
// mirror LLM provider usage reports; confirm with the FFI side:
// InputTokens / OutputTokens are prompt and completion tokens, and the Cache*
// fields count input tokens written to / read from the provider's prompt cache.
type AskUsage struct {
	InputTokens              uint64 `json:"input_tokens"`
	OutputTokens             uint64 `json:"output_tokens"`
	CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64 `json:"cache_read_input_tokens"`
}
// Ask is AskContext with context.Background(). It asks the sqlrite-backed db
// to turn question into SQL, using cfg (which may be nil) for the LLM
// settings. The generated SQL is returned but not executed.
func Ask(db *sql.DB, question string, cfg *AskConfig) (*AskResponse, error) {
	bg := context.Background()
	return AskContext(bg, db, question, cfg)
}
// AskContext sends question to the sqlrite native library and returns the
// generated SQL together with its explanation and token usage. It does not
// run the SQL; AskRunContext does that.
//
// db must be backed by the sqlrite driver. A dedicated connection is checked
// out with db.Conn, and the driver-level *conn is reached through Conn.Raw.
// Any other driver produces an error. cfg may be nil.
//
// ctx is only used to obtain the connection. It is not passed to the
// underlying native call.
func AskContext(ctx context.Context, db *sql.DB, question string, cfg *AskConfig) (*AskResponse, error) {
	if db == nil {
		return nil, errors.New("sqlrite: Ask: db is nil")
	}

	pinned, err := db.Conn(ctx)
	if err != nil {
		return nil, fmt.Errorf("sqlrite: Ask: %w", err)
	}
	defer pinned.Close()

	var result *AskResponse
	rawErr := pinned.Raw(func(dc any) error {
		sc, isSqlrite := dc.(*conn)
		if !isSqlrite {
			return fmt.Errorf("sqlrite: Ask: driver connection is %T, not *sqlrite.conn — Ask requires a sqlrite-backed *sql.DB", dc)
		}
		r, askErr := sc.ask(question, cfg)
		if askErr == nil {
			// Only publish the response on success.
			result = r
		}
		return askErr
	})
	return result, rawErr
}
// AskRun is AskRunContext with context.Background(). It generates SQL for
// question and, if the model produced a statement, executes it against db.
// The caller must Close the returned rows.
func AskRun(db *sql.DB, question string, cfg *AskConfig) (*sql.Rows, error) {
	ctx := context.Background()
	return AskRunContext(ctx, db, question, cfg)
}
// AskRunContext generates SQL for question via AskContext. If the model
// produced a statement, it runs it with db.QueryContext and returns the
// rows, which the caller must Close.
//
// The generated SQL is executed verbatim after trimming surrounding
// whitespace. There is no validation or allow-listing, so only use this
// where running model-written SQL against db is acceptable.
//
// If the model returns empty or whitespace-only SQL, no query is run.
// Instead, an error carrying the model's explanation is returned, or
// "(no explanation)" if there was none.
//
// NOTE(review): the query goes through db's pool and may use a different
// connection than the one AskContext used; confirm that is fine for the
// database mode in use.
func AskRunContext(ctx context.Context, db *sql.DB, question string, cfg *AskConfig) (*sql.Rows, error) {
	resp, err := AskContext(ctx, db, question, cfg)
	if err != nil {
		return nil, err
	}

	query := strings.TrimSpace(resp.SQL)
	if query != "" {
		return db.QueryContext(ctx, query)
	}

	reason := "(no explanation)"
	if resp.Explanation != "" {
		reason = resp.Explanation
	}
	return nil, fmt.Errorf("sqlrite: AskRun: model declined to generate SQL: %s", reason)
}
// ask performs the native sqlrite_ask call on this connection while holding
// c.mu.
//
// When cfg is non-nil it is marshalled to JSON and passed as a C string.
// When cfg is nil, a nil (NULL) config pointer is passed instead. The
// response string written to out is copied into Go memory with C.GoString,
// released with sqlrite_free_string, and decoded into an AskResponse.
//
// It returns:
//   - driver.ErrBadConn if the connection is already closed;
//   - the error produced by wrapErr for a non-ok status;
//   - an error if the status is ok but out is still nil;
//   - a decode error, including the raw JSON, if the response is malformed.
func (c *conn) ask(question string, cfg *AskConfig) (*AskResponse, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return nil, driver.ErrBadConn
	}
	// Serialize the optional config; it stays "" when cfg is nil.
	var configJSON string
	if cfg != nil {
		raw, err := json.Marshal(cfg)
		if err != nil {
			return nil, fmt.Errorf("sqlrite: ask: marshal config: %w", err)
		}
		configJSON = string(raw)
	}
	// Convert the inputs with cString; each is released by a deferred
	// freeCString when ask returns.
	// NOTE(review): presumably cString allocates C memory — confirm.
	cQuestion := cString(question)
	defer freeCString(cQuestion)
	// cConfig stays nil (passed as NULL) when no config JSON was produced.
	var cConfig *C.char
	if configJSON != "" {
		cConfig = cString(configJSON)
		defer freeCString(cConfig)
	}
	var out *C.char
	status := Status(C.sqlrite_ask(c.handle, cQuestion, cConfig, &out))
	// NOTE(review): out is freed only after a successful status. If
	// sqlrite_ask can populate out on a non-ok status, that string would
	// leak here — confirm the FFI ownership contract.
	if err := wrapErr(status, "ask"); err != nil {
		return nil, err
	}
	if out == nil {
		return nil, errors.New("sqlrite: ask: FFI returned status=ok but null response")
	}
	// C.GoString copies the bytes, so the deferred free of out is safe after
	// jsonStr has been created.
	defer C.sqlrite_free_string(out)
	jsonStr := C.GoString(out)
	var resp AskResponse
	if err := json.Unmarshal([]byte(jsonStr), &resp); err != nil {
		return nil, fmt.Errorf("sqlrite: ask: parse response JSON: %w (raw=%q)", err, jsonStr)
	}
	return &resp, nil
}
// Keeps the "unsafe" import referenced so the file compiles.
// NOTE(review): unsafe appears to have no other use in this file; if that
// holds package-wide, this line and the import can be removed together.
var _ = unsafe.Sizeof(C.int(0))