Skip to main content

sqlrite_ask/
lib.rs

1//! `sqlrite-ask` — natural-language → SQL adapter for SQLRite.
2//!
3//! Phase 7g.1 (foundational). One sync call:
4//!
5//! ```no_run
6//! use sqlrite::Connection;
7//! use sqlrite_ask::{ask, AskConfig};
8//!
9//! let conn = Connection::open("foo.sqlrite")?;
10//! let config = AskConfig::from_env()?;  // reads SQLRITE_LLM_API_KEY etc.
11//! let response = ask(&conn, "How many users are over 30?", &config)?;
12//! println!("Generated SQL: {}", response.sql);
13//! println!("Why: {}", response.explanation);
14//! # Ok::<(), sqlrite_ask::AskError>(())
15//! ```
16//!
17//! ## What this crate is
18//!
19//! - Reflects the schema of an open `Connection` into CREATE TABLE
20//!   text the LLM can ground on.
//! - Wraps that schema in a stable system prompt with a
22//!   `cache_control: ephemeral` breakpoint so the schema dump is
23//!   served from Anthropic's prompt cache after the first call.
24//! - Sends one HTTP POST to the LLM provider per `ask()` call.
25//! - Parses the response into `AskResponse { sql, explanation }`.
26//!
27//! ## What this crate is NOT
28//!
29//! - **Not an executor.** The library deliberately does not run the
30//!   generated SQL — the caller decides whether to execute it. SDK
31//!   layers (`Python.Connection.ask_run`, `Node.db.askRun`, etc.)
32//!   add a one-shot generate-and-execute helper for the common
33//!   case, but the default API is "generate, return, let me decide".
34//! - **Not multi-turn.** Stateless — every call is a fresh prompt.
35//! - **Not multi-provider yet.** Anthropic-first per Phase 7 plan
36//!   Q4. OpenAI + Ollama follow-ups slot into [`provider`] without
37//!   changing the public surface.
38//!
39//! ## Configuration
40//!
41//! [`AskConfig`] resolves in this priority order:
42//! 1. Explicit values you set on the struct (`AskConfig { api_key: Some(...), .. }`)
43//! 2. Environment variables (`SQLRITE_LLM_*`)
44//! 3. Built-in defaults (model = `claude-sonnet-4-6`, max_tokens = 1024,
45//!    cache TTL = 5 min)
46
47use std::env;
48
49use sqlrite::Connection;
50
51mod prompt;
52mod provider;
53pub mod schema;
54
55pub use provider::anthropic::AnthropicProvider;
56pub use provider::{Provider, Request, Response, Usage};
57
58use prompt::{CacheControl, UserMessage, build_system};
59use provider::Request as ProviderRequest;
60
/// Default model — Sonnet 4.6 hits the cost-quality sweet spot for
/// NL→SQL. Override via [`AskConfig::model`] or the `SQLRITE_LLM_MODEL`
/// env var. See `docs/phase-7-plan.md` for the model-choice rationale.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";

/// Default `max_tokens`. SQL generation rarely needs more than ~500
/// output tokens (single-statement queries + a one-sentence
/// explanation). 1024 leaves headroom while staying under the SDK
/// timeout cap, so we don't have to stream. Override via
/// [`AskConfig::max_tokens`] or the `SQLRITE_LLM_MAX_TOKENS` env var.
pub const DEFAULT_MAX_TOKENS: u32 = 1024;
71
/// Result returned from a successful [`ask`] call.
///
/// `sql` is the generated query text — empty string if the model
/// determined the question can't be answered against the schema.
/// `explanation` is the model's one-sentence rationale; useful in
/// REPL "confirm before run" UIs.
///
/// `usage` surfaces token counts (input/output/cache hit/cache write).
/// Inspect it to verify prompt-caching is actually working — see
/// `docs/phase-7-plan.md` Q3-adjacent for the audit checklist.
#[derive(Debug, Clone)]
pub struct AskResponse {
    /// Generated SQL, whitespace-trimmed and with trailing `;`
    /// removed. Empty when the model declined to produce a query.
    pub sql: String,
    /// The model's rationale; empty when the reply carried no
    /// `explanation` string.
    pub explanation: String,
    /// Token accounting reported by the provider for this call.
    pub usage: Usage,
}
88
/// Cache-TTL knob exposed on [`AskConfig`].
///
/// Anthropic's `ephemeral` cache supports two TTLs:
/// - **5 minutes** (default) — break-even at 2 calls per cached
///   prefix; right for interactive REPL use where users ask a few
///   questions in a session.
/// - **1 hour** — costs 2× write premium instead of 1.25×; needs
///   3+ calls per prefix to break even. Worth it for long-running
///   editor / desktop sessions where the same DB is queried
///   sporadically over an hour.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheTtl {
    /// Five-minute cache lifetime; the value used by
    /// `AskConfig::default()`.
    FiveMinutes,
    /// One-hour cache lifetime.
    OneHour,
    /// Disables caching — schema block is sent without a
    /// `cache_control` marker. Useful when the schema is below the
    /// model's minimum cacheable prefix size (~2K tokens for Sonnet,
    /// ~4K for Haiku/Opus); marking it would be a no-op.
    Off,
}
109
110impl CacheTtl {
111    fn into_marker(self) -> Option<CacheControl> {
112        match self {
113            CacheTtl::FiveMinutes => Some(CacheControl::ephemeral()),
114            CacheTtl::OneHour => Some(CacheControl::ephemeral_1h()),
115            CacheTtl::Off => None,
116        }
117    }
118}
119
/// Which LLM provider [`ask`] talks to. Anthropic-only in 7g.1; the
/// enum is here so adding OpenAI/Ollama later doesn't break the
/// `AskConfig` shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Anthropic backend; [`ask`] drives it through
    /// [`AnthropicProvider`].
    Anthropic,
}
127
128impl ProviderKind {
129    fn parse(s: &str) -> Result<Self, AskError> {
130        match s.to_ascii_lowercase().as_str() {
131            "anthropic" => Ok(ProviderKind::Anthropic),
132            other => Err(AskError::UnknownProvider(other.to_string())),
133        }
134    }
135}
136
137/// Knobs for an `ask()` call. Construct directly, or via
138/// [`AskConfig::from_env`] to pull defaults from the environment.
139#[derive(Debug, Clone)]
140pub struct AskConfig {
141    pub provider: ProviderKind,
142    pub api_key: Option<String>,
143    pub model: String,
144    pub max_tokens: u32,
145    pub cache_ttl: CacheTtl,
146    /// Override the API base URL. Production callers leave this
147    /// `None`; tests point it at a localhost mock.
148    pub base_url: Option<String>,
149}
150
151impl Default for AskConfig {
152    fn default() -> Self {
153        Self {
154            provider: ProviderKind::Anthropic,
155            api_key: None,
156            model: DEFAULT_MODEL.to_string(),
157            max_tokens: DEFAULT_MAX_TOKENS,
158            cache_ttl: CacheTtl::FiveMinutes,
159            base_url: None,
160        }
161    }
162}
163
164impl AskConfig {
165    /// Build a config from environment variables, with built-in
166    /// defaults for anything not set.
167    ///
168    /// Recognized vars:
169    /// - `SQLRITE_LLM_PROVIDER` — `anthropic` (only currently supported)
170    /// - `SQLRITE_LLM_API_KEY` — required at call time, but a missing
171    ///   var is not an error here (lets you build a config to inspect
172    ///   without the secret loaded)
173    /// - `SQLRITE_LLM_MODEL` — overrides [`DEFAULT_MODEL`]
174    /// - `SQLRITE_LLM_MAX_TOKENS` — overrides [`DEFAULT_MAX_TOKENS`]
175    /// - `SQLRITE_LLM_CACHE_TTL` — `5m` (default) | `1h` | `off`
176    pub fn from_env() -> Result<Self, AskError> {
177        let mut cfg = AskConfig::default();
178        if let Ok(p) = env::var("SQLRITE_LLM_PROVIDER") {
179            cfg.provider = ProviderKind::parse(&p)?;
180        }
181        if let Ok(k) = env::var("SQLRITE_LLM_API_KEY") {
182            if !k.is_empty() {
183                cfg.api_key = Some(k);
184            }
185        }
186        if let Ok(m) = env::var("SQLRITE_LLM_MODEL") {
187            if !m.is_empty() {
188                cfg.model = m;
189            }
190        }
191        if let Ok(t) = env::var("SQLRITE_LLM_MAX_TOKENS") {
192            cfg.max_tokens = t
193                .parse()
194                .map_err(|_| AskError::Config(format!("SQLRITE_LLM_MAX_TOKENS not a u32: {t}")))?;
195        }
196        if let Ok(c) = env::var("SQLRITE_LLM_CACHE_TTL") {
197            cfg.cache_ttl = match c.to_ascii_lowercase().as_str() {
198                "5m" | "5min" | "5minutes" => CacheTtl::FiveMinutes,
199                "1h" | "1hr" | "1hour" => CacheTtl::OneHour,
200                "off" | "none" | "disabled" => CacheTtl::Off,
201                other => {
202                    return Err(AskError::Config(format!(
203                        "SQLRITE_LLM_CACHE_TTL: unknown value '{other}'"
204                    )));
205                }
206            };
207        }
208        Ok(cfg)
209    }
210}
211
/// Errors `ask()` can return. Includes every failure mode along the
/// path: config / network / API / parsing.
#[derive(Debug, thiserror::Error)]
pub enum AskError {
    /// [`ask`] was called with `AskConfig::api_key` set to `None`.
    #[error("missing API key (set SQLRITE_LLM_API_KEY or AskConfig.api_key)")]
    MissingApiKey,

    /// A configuration value could not be interpreted; the payload
    /// names the offending setting.
    #[error("config error: {0}")]
    Config(String),

    /// The requested provider name is not supported; the payload is
    /// the rejected name.
    #[error("unknown provider: {0} (supported: anthropic)")]
    UnknownProvider(String),

    /// Transport-level failure talking to the provider.
    /// NOTE(review): presumably produced by the provider layer, which
    /// is not visible in this file — confirm.
    #[error("HTTP transport error: {0}")]
    Http(String),

    /// The provider answered with an error status.
    /// NOTE(review): presumably a non-success HTTP status plus the
    /// response detail — confirm against the provider module.
    #[error("API returned status {status}: {detail}")]
    ApiStatus { status: u16, detail: String },

    /// The provider reply contained no text to parse.
    /// NOTE(review): raised outside this file — confirm.
    #[error("API returned no text content")]
    EmptyResponse,

    /// The model's reply could not be parsed as JSON, even after
    /// fence stripping and first-object extraction; the payload is
    /// the raw reply.
    #[error("model output not valid JSON: {0}")]
    OutputNotJson(String),

    /// The reply parsed as JSON but lacked a required string field;
    /// the payload is the field name.
    #[error("model output JSON missing required field '{0}'")]
    OutputMissingField(&'static str),

    /// Schema reflection failed.
    /// NOTE(review): not constructed in this file — presumably
    /// reserved for the `schema` module; confirm.
    #[error("schema introspection failed: {0}")]
    Schema(String),

    /// A `serde_json` error, converted automatically via `?`.
    #[error("JSON serialization error: {0}")]
    Json(#[from] serde_json::Error),

    /// An error from the SQLRite engine, passed through unchanged.
    #[error(transparent)]
    Engine(#[from] sqlrite::SQLRiteError),
}
249
/// Extension trait that adds [`ConnectionAskExt::ask`] to
/// [`sqlrite::Connection`]. Lives here (not on the engine) to keep the
/// engine free of HTTP / TLS / serde deps. Bring it into scope with
/// `use sqlrite_ask::ConnectionAskExt;`.
pub trait ConnectionAskExt {
    /// Generate SQL from a natural-language question. Equivalent to
    /// the free-function [`ask`] but reads as a method:
    ///
    /// ```no_run
    /// use sqlrite::Connection;
    /// use sqlrite_ask::{AskConfig, ConnectionAskExt};
    ///
    /// let conn = Connection::open("foo.sqlrite")?;
    /// let cfg = AskConfig::from_env()?;
    /// let resp = conn.ask("how many users are over 30?", &cfg)?;
    /// # Ok::<(), sqlrite_ask::AskError>(())
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an [`AskError`]; the implementation for `Connection`
    /// surfaces exactly what the free-function [`ask`] returns.
    fn ask(&self, question: &str, config: &AskConfig) -> Result<AskResponse, AskError>;
}
269
impl ConnectionAskExt for Connection {
    // Pure delegation: the bare `ask` below resolves to the
    // module-level free function, not to this method.
    fn ask(&self, question: &str, config: &AskConfig) -> Result<AskResponse, AskError> {
        ask(self, question, config)
    }
}
275
276/// One-shot natural-language → SQL.
277///
278/// Walks `conn`'s schema, builds a cache-friendly prompt, calls the
279/// configured LLM, parses the JSON-shaped reply into [`AskResponse`].
280///
281/// The library does **not** execute the returned SQL — that's the
282/// caller's call. See module docs for rationale.
283pub fn ask(conn: &Connection, question: &str, config: &AskConfig) -> Result<AskResponse, AskError> {
284    let api_key = config.api_key.clone().ok_or(AskError::MissingApiKey)?;
285
286    let provider = match config.provider {
287        ProviderKind::Anthropic => match &config.base_url {
288            Some(url) => AnthropicProvider::with_base_url(api_key, url.clone()),
289            None => AnthropicProvider::new(api_key),
290        },
291    };
292
293    ask_with_provider(conn, question, config, &provider)
294}
295
296/// Lower-level entry point — same flow, but you supply the provider.
297///
298/// Used by the test suite (which passes a `MockProvider`) and by
299/// advanced callers who want to drive a custom backend (an internal
300/// LLM gateway, a recorded-replay test harness, etc.).
301pub fn ask_with_provider<P: Provider>(
302    conn: &Connection,
303    question: &str,
304    config: &AskConfig,
305    provider: &P,
306) -> Result<AskResponse, AskError> {
307    let schema_dump = schema::dump_schema(conn);
308    let system = build_system(&schema_dump, config.cache_ttl.into_marker());
309    let messages = [UserMessage::new(question)];
310
311    let req = ProviderRequest {
312        model: &config.model,
313        max_tokens: config.max_tokens,
314        system: &system,
315        messages: &messages,
316    };
317
318    let resp = provider.complete(req)?;
319    parse_response(&resp.text, resp.usage)
320}
321
322/// Pull `sql` and `explanation` out of the model's reply.
323///
324/// We accept three shapes — strict JSON object, JSON wrapped in a
325/// fenced code block, or "almost JSON" with leading/trailing prose —
326/// because real LLM output drifts even with strict instructions. The
327/// fence/prose tolerance matches what real callers do (better-sqlite3,
328/// rusqlite, etc.) when interfacing with model output.
329fn parse_response(raw: &str, usage: Usage) -> Result<AskResponse, AskError> {
330    // 1. Strip markdown fences if the model wrapped its JSON.
331    let trimmed = raw.trim();
332    let body = strip_markdown_fence(trimmed).unwrap_or(trimmed);
333
334    // 2. Try strict JSON first.
335    if let Ok(value) = serde_json::from_str::<serde_json::Value>(body) {
336        return extract_fields(&value, usage);
337    }
338
339    // 3. Fallback: extract the first {...} block. Some models tack
340    // prose like "Here is the SQL:" before the JSON despite the
341    // prompt instruction. Find the first balanced object and try
342    // parsing that.
343    if let Some(json_block) = extract_first_json_object(body) {
344        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&json_block) {
345            return extract_fields(&value, usage);
346        }
347    }
348
349    Err(AskError::OutputNotJson(raw.to_string()))
350}
351
352fn extract_fields(value: &serde_json::Value, usage: Usage) -> Result<AskResponse, AskError> {
353    let sql = value
354        .get("sql")
355        .and_then(|v| v.as_str())
356        .ok_or(AskError::OutputMissingField("sql"))?
357        .trim()
358        .trim_end_matches(';')
359        .to_string();
360    let explanation = value
361        .get("explanation")
362        .and_then(|v| v.as_str())
363        .unwrap_or("")
364        .to_string();
365    Ok(AskResponse {
366        sql,
367        explanation,
368        usage,
369    })
370}
371
/// Return the contents of a markdown code fence when `s` (ignoring
/// surrounding whitespace) begins with one, or `None` otherwise.
///
/// The opening line must be three backticks followed by either
/// nothing or a `json` language tag. The tag is matched ignoring
/// ASCII case, and trailing whitespace on the opening line (including
/// the `\r` of a CRLF line ending) is tolerated. A closing fence is
/// removed if present; its absence is not an error. The returned
/// slice is trimmed.
fn strip_markdown_fence(s: &str) -> Option<&str> {
    let after_ticks = s.trim().strip_prefix("```")?;
    // The opener must be on a line of its own, so a newline is required.
    let (tag, rest) = after_ticks.split_once('\n')?;
    let tag = tag.trim();
    if !(tag.is_empty() || tag.eq_ignore_ascii_case("json")) {
        return None;
    }
    // Strip trailing ``` (with or without a final newline).
    let body = rest.trim_end();
    let body = body.strip_suffix("```").unwrap_or(body);
    Some(body.trim())
}
385
/// Return the first brace-balanced `{...}` span in `s` as an owned
/// string, or `None` if there is no `{` or it is never closed.
///
/// Scanning starts at the first `{`. Braces inside double-quoted
/// strings are ignored, and a backslash inside a string skips the
/// following byte so an escaped quote does not end the string. The
/// span is not validated as JSON — the caller does that.
fn extract_first_json_object(s: &str) -> Option<String> {
    let start = s.find('{')?;
    let tail = &s[start..];

    let mut depth = 0_i32;
    let mut quoted = false;
    let mut skip_next = false;

    // All delimiters are ASCII, so a byte scan cannot mistake part of
    // a multi-byte character for one, and the closing `}` is always a
    // valid slice boundary.
    let end = tail.bytes().position(|byte| {
        if skip_next {
            skip_next = false;
            return false;
        }
        if quoted {
            match byte {
                b'\\' => skip_next = true,
                b'"' => quoted = false,
                _ => {}
            }
            return false;
        }
        match byte {
            b'"' => quoted = true,
            b'{' => depth += 1,
            b'}' => {
                depth -= 1;
                return depth == 0;
            }
            _ => {}
        }
        false
    })?;

    Some(tail[..=end].to_string())
}
412
// Unit tests for the ask flow and the reply parser. The end-to-end
// tests drive `ask_with_provider` with a `MockProvider`, then inspect
// the request the mock captured.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;
    use sqlrite::Connection;

    // Fresh in-memory database for each test.
    fn open() -> Connection {
        Connection::open_in_memory().unwrap()
    }

    // Default config with a dummy API key filled in.
    fn cfg() -> AskConfig {
        AskConfig {
            api_key: Some("test-key".to_string()),
            ..AskConfig::default()
        }
    }

    // Happy path: a well-formed JSON reply is surfaced field-for-field.
    #[test]
    fn ask_with_mock_provider_returns_parsed_sql() {
        let mut conn = open();
        conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .unwrap();

        let provider = MockProvider::new(
            r#"{"sql": "SELECT COUNT(*) FROM users", "explanation": "counts users"}"#,
        );

        let resp = ask_with_provider(&conn, "how many users?", &cfg(), &provider).unwrap();
        assert_eq!(resp.sql, "SELECT COUNT(*) FROM users");
        assert_eq!(resp.explanation, "counts users");
    }

    // The reflected schema must reach the provider: the system block at
    // index 1 is expected to contain the table's DDL.
    #[test]
    fn schema_dump_appears_in_system_block() {
        let mut conn = open();
        conn.execute("CREATE TABLE widgets (id INTEGER PRIMARY KEY, name TEXT)")
            .unwrap();
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_provider(&conn, "anything", &cfg(), &provider).unwrap();

        let captured = provider.last_request.borrow().clone().unwrap();
        let schema_block = &captured.system_blocks[1];
        assert!(
            schema_block.contains("CREATE TABLE widgets"),
            "got: {schema_block}"
        );
        assert!(schema_block.contains("name TEXT"), "got: {schema_block}");
    }

    // `CacheTtl::Off` must leave the schema block unmarked.
    #[test]
    fn cache_ttl_off_omits_cache_control() {
        let conn = open();
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let mut config = cfg();
        config.cache_ttl = CacheTtl::Off;
        let _ = ask_with_provider(&conn, "test", &config, &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(!captured.schema_block_has_cache_control);
    }

    // The default TTL (five minutes) must mark the schema block.
    #[test]
    fn cache_ttl_5m_sets_cache_control() {
        let conn = open();
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_provider(&conn, "test", &cfg(), &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(captured.schema_block_has_cache_control);
    }

    // The question is forwarded verbatim, quotes and punctuation intact.
    #[test]
    fn user_question_arrives_in_messages_unchanged() {
        let conn = open();
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let q = "Find users with email containing '@example.com'";
        let _ = ask_with_provider(&conn, q, &cfg(), &provider).unwrap();
        assert_eq!(
            provider
                .last_request
                .borrow()
                .as_ref()
                .unwrap()
                .user_message,
            q
        );
    }

    // `ask` must report a missing key as `MissingApiKey`.
    #[test]
    fn missing_api_key_errors_clearly() {
        let conn = open();
        // Default has api_key: None already; just be explicit for the
        // reader.
        let config = AskConfig {
            api_key: None,
            ..AskConfig::default()
        };
        let err = ask(&conn, "test", &config).unwrap_err();
        match err {
            AskError::MissingApiKey => {}
            other => panic!("expected MissingApiKey, got {other:?}"),
        }
    }

    // A trailing `;` is removed from the returned SQL.
    #[test]
    fn parse_response_strips_trailing_semicolon() {
        let resp = parse_response(
            r#"{"sql": "SELECT 1;", "explanation": "demo"}"#,
            Usage::default(),
        )
        .unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // JSON wrapped in a ```json fence still parses.
    #[test]
    fn parse_response_handles_markdown_fence() {
        let raw = "```json\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}\n```";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // Prose ahead of the JSON object is skipped by the fallback path.
    #[test]
    fn parse_response_handles_leading_prose() {
        let raw =
            "Here is the query you asked for:\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // A reply with no JSON at all is reported as `OutputNotJson`.
    #[test]
    fn parse_response_rejects_non_json() {
        let err = parse_response("just some prose, no JSON here", Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputNotJson(_)));
    }

    // `sql` is mandatory: valid JSON without it is an error.
    #[test]
    fn parse_response_rejects_missing_sql_field() {
        let err = parse_response(r#"{"explanation": "no sql key"}"#, Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputMissingField("sql")));
    }

    // `explanation` is optional and defaults to the empty string.
    #[test]
    fn parse_response_allows_missing_explanation() {
        let resp = parse_response(r#"{"sql": "SELECT 1"}"#, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
        assert_eq!(resp.explanation, "");
    }

    // Token accounting is handed back untouched.
    #[test]
    fn parse_response_passes_usage_through() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 20,
            cache_creation_input_tokens: 80,
            cache_read_input_tokens: 0,
        };
        let resp =
            parse_response(r#"{"sql": "SELECT 1", "explanation": ""}"#, usage.clone()).unwrap();
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.cache_creation_input_tokens, 80);
    }
}