use std::env;
pub mod prompt;
pub mod provider;
#[cfg(feature = "http")]
pub use provider::anthropic::AnthropicProvider;
pub use provider::{Provider, Request, Response, Usage};
use prompt::{CacheControl, UserMessage, build_system};
use provider::Request as ProviderRequest;
/// Model identifier used by [`AskConfig::default`].
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
/// Default for [`AskConfig::max_tokens`]; forwarded as the request's `max_tokens`.
pub const DEFAULT_MAX_TOKENS: u32 = 1024;
/// Result of a successful ask: the generated SQL plus metadata.
#[derive(Debug, Clone)]
pub struct AskResponse {
/// SQL taken from the model's `"sql"` JSON field, trimmed and with trailing `;` stripped.
pub sql: String,
/// The model's `"explanation"` JSON field; empty when the model omitted it.
pub explanation: String,
/// Token usage reported by the provider for this request.
pub usage: Usage,
}
/// Prompt-cache lifetime requested for the schema part of the system prompt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheTtl {
    /// Mark the schema block with `CacheControl::ephemeral()`.
    FiveMinutes,
    /// Mark the schema block with `CacheControl::ephemeral_1h()`.
    OneHour,
    /// Send no cache marker at all.
    Off,
}

impl CacheTtl {
    /// Translates the setting into the marker handed to `build_system`.
    ///
    /// Returns `None` for [`CacheTtl::Off`], meaning no caching is requested.
    fn into_marker(self) -> Option<CacheControl> {
        let marker = match self {
            CacheTtl::Off => return None,
            CacheTtl::OneHour => CacheControl::ephemeral_1h(),
            CacheTtl::FiveMinutes => CacheControl::ephemeral(),
        };
        Some(marker)
    }
}
/// LLM backend used to answer questions. Only Anthropic is supported today.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Anthropic, served by `AnthropicProvider`.
    Anthropic,
}

impl ProviderKind {
    /// Parses a provider name, ignoring ASCII case and surrounding whitespace.
    ///
    /// # Errors
    ///
    /// Returns [`AskError::UnknownProvider`] carrying the name as the user
    /// wrote it (trimmed, original casing), so the message matches their input.
    fn parse(s: &str) -> Result<Self, AskError> {
        let name = s.trim();
        match name.to_ascii_lowercase().as_str() {
            "anthropic" => Ok(ProviderKind::Anthropic),
            _ => Err(AskError::UnknownProvider(name.to_string())),
        }
    }
}
/// Settings for an ask request: which provider and model to call, and how.
///
/// Build one with [`AskConfig::default`], [`AskConfig::from_env`], or
/// struct-update syntax (`AskConfig { api_key: ..., ..AskConfig::default() }`).
#[derive(Debug, Clone)]
pub struct AskConfig {
/// LLM backend to use.
pub provider: ProviderKind,
/// API key; `ask_with_schema` fails with [`AskError::MissingApiKey`] when this is `None`.
pub api_key: Option<String>,
/// Model identifier sent with each request (default [`DEFAULT_MODEL`]).
pub model: String,
/// `max_tokens` value sent with each request (default [`DEFAULT_MAX_TOKENS`]).
pub max_tokens: u32,
/// Cache marker applied when building the system prompt from the schema
/// (default [`CacheTtl::FiveMinutes`]).
pub cache_ttl: CacheTtl,
/// Optional endpoint override passed to `AnthropicProvider::with_base_url`;
/// `None` uses `AnthropicProvider::new`. Not read by `from_env`.
pub base_url: Option<String>,
}
impl Default for AskConfig {
fn default() -> Self {
Self {
provider: ProviderKind::Anthropic,
api_key: None,
model: DEFAULT_MODEL.to_string(),
max_tokens: DEFAULT_MAX_TOKENS,
cache_ttl: CacheTtl::FiveMinutes,
base_url: None,
}
}
}
impl AskConfig {
    /// Builds a configuration from `SQLRITE_LLM_*` environment variables,
    /// starting from [`AskConfig::default`].
    ///
    /// Every value is trimmed; a variable that is unset or blank keeps the
    /// default. Recognised variables:
    ///
    /// - `SQLRITE_LLM_PROVIDER`: provider name, case-insensitive (`anthropic`).
    /// - `SQLRITE_LLM_API_KEY`: API key.
    /// - `SQLRITE_LLM_MODEL`: model identifier.
    /// - `SQLRITE_LLM_MAX_TOKENS`: positive `u32`.
    /// - `SQLRITE_LLM_CACHE_TTL`: `5m`/`5min`/`5minutes`, `1h`/`1hr`/`1hour`,
    ///   or `off`/`none`/`disabled` (case-insensitive).
    ///
    /// `base_url` is not read from the environment and stays `None`.
    ///
    /// # Errors
    ///
    /// - [`AskError::UnknownProvider`] for an unsupported provider name.
    /// - [`AskError::Config`] when a variable is not valid UTF-8, when
    ///   `SQLRITE_LLM_MAX_TOKENS` is not a positive `u32`, or when
    ///   `SQLRITE_LLM_CACHE_TTL` is not one of the accepted spellings.
    pub fn from_env() -> Result<Self, AskError> {
        let mut cfg = AskConfig::default();
        if let Some(p) = Self::env_setting("SQLRITE_LLM_PROVIDER")? {
            cfg.provider = ProviderKind::parse(&p)?;
        }
        if let Some(k) = Self::env_setting("SQLRITE_LLM_API_KEY")? {
            cfg.api_key = Some(k);
        }
        if let Some(m) = Self::env_setting("SQLRITE_LLM_MODEL")? {
            cfg.model = m;
        }
        if let Some(t) = Self::env_setting("SQLRITE_LLM_MAX_TOKENS")? {
            let n: u32 = t.parse().map_err(|_| {
                AskError::Config(format!("SQLRITE_LLM_MAX_TOKENS not a u32: {t}"))
            })?;
            // A request allowing zero output tokens can never yield SQL; fail
            // here with a clear message instead of at the API.
            if n == 0 {
                return Err(AskError::Config(
                    "SQLRITE_LLM_MAX_TOKENS must be greater than 0".to_string(),
                ));
            }
            cfg.max_tokens = n;
        }
        if let Some(c) = Self::env_setting("SQLRITE_LLM_CACHE_TTL")? {
            cfg.cache_ttl = match c.to_ascii_lowercase().as_str() {
                "5m" | "5min" | "5minutes" => CacheTtl::FiveMinutes,
                "1h" | "1hr" | "1hour" => CacheTtl::OneHour,
                "off" | "none" | "disabled" => CacheTtl::Off,
                // Echo the value as the user wrote it, not the lowercased copy.
                _ => {
                    return Err(AskError::Config(format!(
                        "SQLRITE_LLM_CACHE_TTL: unknown value '{c}'"
                    )));
                }
            };
        }
        Ok(cfg)
    }

    /// Reads `name` from the environment and returns its trimmed value.
    ///
    /// Returns `None` when the variable is unset or blank. A value that is not
    /// valid UTF-8 is reported as [`AskError::Config`] instead of being
    /// silently ignored.
    fn env_setting(name: &str) -> Result<Option<String>, AskError> {
        match env::var(name) {
            Ok(raw) => {
                let value = raw.trim();
                Ok((!value.is_empty()).then(|| value.to_string()))
            }
            Err(env::VarError::NotPresent) => Ok(None),
            Err(env::VarError::NotUnicode(_)) => {
                Err(AskError::Config(format!("{name} is not valid UTF-8")))
            }
        }
    }
}
/// Errors from configuration, provider calls, and parsing of model output.
#[derive(Debug, thiserror::Error)]
pub enum AskError {
/// No usable API key was configured.
#[error("missing API key (set SQLRITE_LLM_API_KEY or AskConfig.api_key)")]
MissingApiKey,
/// A configuration value was invalid (e.g. a malformed `SQLRITE_LLM_*`
/// environment variable); holds a human-readable description.
#[error("config error: {0}")]
Config(String),
/// The provider name was not recognised; holds the offending name.
#[error("unknown provider: {0} (supported: anthropic)")]
UnknownProvider(String),
/// Transport-level failure. Not constructed in this module; presumably
/// raised by provider implementations (TODO confirm).
#[error("HTTP transport error: {0}")]
Http(String),
/// The API answered with a non-success status. Presumably raised by
/// provider implementations (TODO confirm).
#[error("API returned status {status}: {detail}")]
ApiStatus { status: u16, detail: String },
/// The API response carried no text. Presumably raised by provider
/// implementations (TODO confirm).
#[error("API returned no text content")]
EmptyResponse,
/// The model's output could not be parsed as JSON; holds the raw output.
#[error("model output not valid JSON: {0}")]
OutputNotJson(String),
/// The model's JSON lacked a required field (or it had the wrong type);
/// holds the field name.
#[error("model output JSON missing required field '{0}'")]
OutputMissingField(&'static str),
/// A `serde_json` error, convertible via `From`/`?`.
#[error("JSON serialization error: {0}")]
Json(#[from] serde_json::Error),
}
/// Generates SQL for `question` against `schema_dump`, calling the provider
/// selected by `config` over HTTP.
///
/// Convenience wrapper around [`ask_with_schema_and_provider`]. It builds an
/// [`AnthropicProvider`] from `config.api_key`, plus `config.base_url` when set.
///
/// # Errors
///
/// Returns [`AskError::MissingApiKey`] when `config.api_key` is `None` or
/// blank. Otherwise it propagates errors from [`ask_with_schema_and_provider`].
#[cfg(feature = "http")]
pub fn ask_with_schema(
    schema_dump: &str,
    question: &str,
    config: &AskConfig,
) -> Result<AskResponse, AskError> {
    // A blank key can never authenticate, so reject it up front. This matches
    // `AskConfig::from_env`, which never stores an empty key.
    let api_key = config
        .api_key
        .as_deref()
        .filter(|key| !key.trim().is_empty())
        .ok_or(AskError::MissingApiKey)?
        .to_string();
    let provider = match config.provider {
        ProviderKind::Anthropic => match &config.base_url {
            Some(url) => AnthropicProvider::with_base_url(api_key, url.clone()),
            None => AnthropicProvider::new(api_key),
        },
    };
    ask_with_schema_and_provider(schema_dump, question, config, &provider)
}
/// Generates SQL for `question` using an explicit [`Provider`].
///
/// Builds the system prompt from `schema_dump`, with the cache marker chosen
/// by `config.cache_ttl`. Sends `question` as the single user message, then
/// parses the reply with [`parse_response`]. This is the entry point for tests
/// (with a mock provider) and for custom `Provider` implementations.
///
/// # Errors
///
/// Propagates the provider's error and any error from [`parse_response`].
pub fn ask_with_schema_and_provider<P>(
    schema_dump: &str,
    question: &str,
    config: &AskConfig,
    provider: &P,
) -> Result<AskResponse, AskError>
where
    P: Provider,
{
    let cache_marker = config.cache_ttl.into_marker();
    let system_blocks = build_system(schema_dump, cache_marker);
    let user_turn = [UserMessage::new(question)];
    let reply = provider.complete(ProviderRequest {
        model: &config.model,
        max_tokens: config.max_tokens,
        system: &system_blocks,
        messages: &user_turn,
    })?;
    parse_response(&reply.text, reply.usage)
}
/// Parses raw model output into an [`AskResponse`], attaching `usage`.
///
/// Accepts three shapes:
/// - bare JSON;
/// - JSON wrapped in a Markdown code fence;
/// - JSON surrounded by prose, recovered by extracting a balanced `{...}`
///   object from the text as a fallback.
///
/// # Errors
///
/// - [`AskError::OutputNotJson`] (carrying the untrimmed `raw`) when no JSON
///   value can be parsed.
/// - [`AskError::OutputMissingField`] when the parsed JSON has no string
///   `"sql"` field.
pub fn parse_response(raw: &str, usage: Usage) -> Result<AskResponse, AskError> {
    let trimmed = raw.trim();
    let candidate = match strip_markdown_fence(trimmed) {
        Some(inner) => inner,
        None => trimmed,
    };
    // Try the whole candidate first; only scan for an embedded object if that fails.
    let parsed = serde_json::from_str::<serde_json::Value>(candidate)
        .ok()
        .or_else(|| {
            extract_first_json_object(candidate)
                .and_then(|block| serde_json::from_str::<serde_json::Value>(&block).ok())
        });
    match parsed {
        Some(value) => extract_fields(&value, usage),
        None => Err(AskError::OutputNotJson(raw.to_string())),
    }
}
/// Builds an [`AskResponse`] from a parsed JSON value.
///
/// `"sql"` is required and must be a string. Leading whitespace is trimmed, and
/// any trailing run of `;` and whitespace is removed, so `"SELECT 1 ;"` becomes
/// `"SELECT 1"`. `"explanation"` is optional and defaults to an empty string.
///
/// # Errors
///
/// Returns [`AskError::OutputMissingField`]`("sql")` when `"sql"` is absent or
/// not a string.
fn extract_fields(value: &serde_json::Value, usage: Usage) -> Result<AskResponse, AskError> {
    let sql = value
        .get("sql")
        .and_then(|v| v.as_str())
        .ok_or(AskError::OutputMissingField("sql"))?
        .trim_start()
        // Strip `;` together with interleaved whitespace. A plain
        // `.trim().trim_end_matches(';')` leaves a trailing space in "SELECT 1 ;".
        .trim_end_matches(|c: char| c == ';' || c.is_whitespace())
        .to_string();
    let explanation = value
        .get("explanation")
        .and_then(|v| v.as_str())
        .unwrap_or("")
        .to_string();
    Ok(AskResponse {
        sql,
        explanation,
        usage,
    })
}
/// Returns the trimmed contents of `s` when it is a Markdown code fence.
///
/// A fence is three backticks, an optional simple language tag (`json`,
/// `JSON`, `sql`, none, ...), then a newline. The opening line may end in
/// `\r\n`. A missing closing fence is tolerated. Returns `None` when `s`
/// (after trimming) does not start with such an opening line.
fn strip_markdown_fence(s: &str) -> Option<&str> {
    let after_ticks = s.trim().strip_prefix("```")?;
    let (info, rest) = after_ticks.split_once('\n')?;
    // Only accept a plain language tag. Anything else (e.g. a `{` on the
    // opening line) means this is not a fence we know how to unwrap.
    let info_is_tag = info
        .trim()
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.'));
    if !info_is_tag {
        return None;
    }
    let body = rest.trim_end();
    let body = body.strip_suffix("```").unwrap_or(body);
    Some(body.trim())
}
/// Finds the first `{...}` span in `s` that can be a JSON object and returns it.
///
/// A `{` only starts a candidate when the next non-whitespace character is
/// `"` or `}`, the only characters that can follow `{` in a JSON object. This
/// skips stray braces in surrounding prose such as "use the {users} table".
/// From the chosen `{`, braces are matched while ignoring those inside JSON
/// string literals (escaped quotes included).
///
/// Returns `None` when there is no candidate or the chosen candidate is never
/// closed. The span is not validated as JSON; callers still have to parse it.
fn extract_first_json_object(s: &str) -> Option<String> {
    let start = s
        .match_indices('{')
        .map(|(idx, _)| idx)
        // `{` is one byte, so `idx + 1` is always a char boundary.
        .find(|&idx| could_open_json_object(&s[idx + 1..]))?;
    let end = matching_close_brace(s, start)?;
    Some(s[start..=end].to_string())
}

/// Whether `rest`, the text right after a `{`, can continue a JSON object.
fn could_open_json_object(rest: &str) -> bool {
    matches!(rest.trim_start().chars().next(), Some('"' | '}'))
}

/// Returns the byte index of the `}` closing the `{` at byte `start`.
///
/// Braces inside JSON string literals are skipped. Returns `None` if the
/// `{` is never closed.
fn matching_close_brace(s: &str, start: usize) -> Option<usize> {
    let mut depth = 0_i32;
    let mut in_string = false;
    let mut escape = false;
    // Byte-wise scanning is sound for UTF-8: the ASCII bytes matched here never
    // occur inside multi-byte sequences, so the returned index is a char boundary.
    for (i, &b) in s.as_bytes().iter().enumerate().skip(start) {
        if escape {
            escape = false;
            continue;
        }
        match b {
            b'\\' if in_string => escape = true,
            b'"' => in_string = !in_string,
            b'{' if !in_string => depth += 1,
            b'}' if !in_string => {
                depth -= 1;
                if depth == 0 {
                    return Some(i);
                }
            }
            _ => {}
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    /// Small schema dump for tests that don't inspect the schema text.
    const FIXTURE_SCHEMA: &str = "\
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT
);
";

    /// Default config with a dummy API key so provider-based calls proceed.
    fn cfg() -> AskConfig {
        AskConfig {
            api_key: Some("test-key".to_string()),
            ..AskConfig::default()
        }
    }

    #[test]
    fn ask_with_mock_provider_returns_parsed_sql() {
        let provider = MockProvider::new(
            r#"{"sql": "SELECT COUNT(*) FROM users", "explanation": "counts users"}"#,
        );
        let resp =
            ask_with_schema_and_provider(FIXTURE_SCHEMA, "how many users?", &cfg(), &provider)
                .unwrap();
        assert_eq!(resp.sql, "SELECT COUNT(*) FROM users");
        assert_eq!(resp.explanation, "counts users");
    }

    #[test]
    fn schema_dump_appears_in_system_block() {
        let schema = "CREATE TABLE widgets (\n id INTEGER PRIMARY KEY,\n name TEXT\n);\n";
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(schema, "anything", &cfg(), &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        let schema_block = &captured.system_blocks[1];
        assert!(
            schema_block.contains("CREATE TABLE widgets"),
            "got: {schema_block}"
        );
        assert!(schema_block.contains("name TEXT"), "got: {schema_block}");
    }

    #[test]
    fn cache_ttl_off_omits_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let mut config = cfg();
        config.cache_ttl = CacheTtl::Off;
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &config, &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(!captured.schema_block_has_cache_control);
    }

    #[test]
    fn cache_ttl_5m_sets_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &cfg(), &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(captured.schema_block_has_cache_control);
    }

    #[test]
    fn user_question_arrives_in_messages_unchanged() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let q = "Find users with email containing '@example.com'";
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, q, &cfg(), &provider).unwrap();
        assert_eq!(
            provider
                .last_request
                .borrow()
                .as_ref()
                .unwrap()
                .user_message,
            q
        );
    }

    // `ask_with_schema` is only compiled with the `http` feature; without this
    // gate the whole test module fails to build when that feature is disabled.
    #[cfg(feature = "http")]
    #[test]
    fn missing_api_key_errors_clearly() {
        let config = AskConfig {
            api_key: None,
            ..AskConfig::default()
        };
        let err = ask_with_schema(FIXTURE_SCHEMA, "test", &config).unwrap_err();
        match err {
            AskError::MissingApiKey => {}
            other => panic!("expected MissingApiKey, got {other:?}"),
        }
    }

    #[test]
    fn parse_response_strips_trailing_semicolon() {
        let resp = parse_response(
            r#"{"sql": "SELECT 1;", "explanation": "demo"}"#,
            Usage::default(),
        )
        .unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_handles_markdown_fence() {
        let raw = "```json\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}\n```";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_handles_leading_prose() {
        let raw =
            "Here is the query you asked for:\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_rejects_non_json() {
        let err = parse_response("just some prose, no JSON here", Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputNotJson(_)));
    }

    #[test]
    fn parse_response_rejects_missing_sql_field() {
        let err = parse_response(r#"{"explanation": "no sql key"}"#, Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputMissingField("sql")));
    }

    #[test]
    fn parse_response_allows_missing_explanation() {
        let resp = parse_response(r#"{"sql": "SELECT 1"}"#, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
        assert_eq!(resp.explanation, "");
    }

    #[test]
    fn parse_response_passes_usage_through() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 20,
            cache_creation_input_tokens: 80,
            cache_read_input_tokens: 0,
        };
        let resp =
            parse_response(r#"{"sql": "SELECT 1", "explanation": ""}"#, usage.clone()).unwrap();
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.cache_creation_input_tokens, 80);
    }
}