1use std::env;
48
49use sqlrite::Connection;
50
51mod prompt;
52mod provider;
53pub mod schema;
54
55pub use provider::anthropic::AnthropicProvider;
56pub use provider::{Provider, Request, Response, Usage};
57
58use prompt::{CacheControl, UserMessage, build_system};
59use provider::Request as ProviderRequest;
60
/// Model identifier that `AskConfig::default()` stores in `AskConfig::model`.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";

/// Default `max_tokens` value that `AskConfig::default()` stores in `AskConfig::max_tokens`.
pub const DEFAULT_MAX_TOKENS: u32 = 1024;
71
/// Successful result of an `ask` call, assembled by `extract_fields` from the model's JSON reply.
#[derive(Debug, Clone)]
pub struct AskResponse {
    /// The `sql` string from the reply, trimmed and with trailing `;` characters removed.
    pub sql: String,
    /// The `explanation` string from the reply; empty when the field is absent or not a string.
    pub explanation: String,
    /// Usage figures returned by the provider alongside the reply text, passed through unchanged.
    pub usage: Usage,
}
88
/// Cache-control choice for the system prompt; converted by `into_marker` before the prompt is
/// built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheTtl {
    /// Converted to `CacheControl::ephemeral()`.
    FiveMinutes,
    /// Converted to `CacheControl::ephemeral_1h()`.
    OneHour,
    /// No cache-control marker is produced.
    Off,
}
109
110impl CacheTtl {
111 fn into_marker(self) -> Option<CacheControl> {
112 match self {
113 CacheTtl::FiveMinutes => Some(CacheControl::ephemeral()),
114 CacheTtl::OneHour => Some(CacheControl::ephemeral_1h()),
115 CacheTtl::Off => None,
116 }
117 }
118}
119
/// Which LLM backend `ask` constructs a provider for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Selects `AnthropicProvider`.
    Anthropic,
}
127
128impl ProviderKind {
129 fn parse(s: &str) -> Result<Self, AskError> {
130 match s.to_ascii_lowercase().as_str() {
131 "anthropic" => Ok(ProviderKind::Anthropic),
132 other => Err(AskError::UnknownProvider(other.to_string())),
133 }
134 }
135}
136
/// Settings consumed by `ask` and `ask_with_provider`.
#[derive(Debug, Clone)]
pub struct AskConfig {
    /// Backend that `ask` builds a provider for.
    pub provider: ProviderKind,
    /// API key handed to the provider constructor; `ask` returns `AskError::MissingApiKey`
    /// when this is `None`.
    pub api_key: Option<String>,
    /// Model identifier placed in the provider request.
    pub model: String,
    /// `max_tokens` value placed in the provider request.
    pub max_tokens: u32,
    /// Cache-control choice applied when the system prompt is built.
    pub cache_ttl: CacheTtl,
    /// When `Some`, `ask` builds the provider with `AnthropicProvider::with_base_url`;
    /// when `None`, it uses `AnthropicProvider::new`.
    pub base_url: Option<String>,
}
150
151impl Default for AskConfig {
152 fn default() -> Self {
153 Self {
154 provider: ProviderKind::Anthropic,
155 api_key: None,
156 model: DEFAULT_MODEL.to_string(),
157 max_tokens: DEFAULT_MAX_TOKENS,
158 cache_ttl: CacheTtl::FiveMinutes,
159 base_url: None,
160 }
161 }
162}
163
164impl AskConfig {
165 pub fn from_env() -> Result<Self, AskError> {
177 let mut cfg = AskConfig::default();
178 if let Ok(p) = env::var("SQLRITE_LLM_PROVIDER") {
179 cfg.provider = ProviderKind::parse(&p)?;
180 }
181 if let Ok(k) = env::var("SQLRITE_LLM_API_KEY") {
182 if !k.is_empty() {
183 cfg.api_key = Some(k);
184 }
185 }
186 if let Ok(m) = env::var("SQLRITE_LLM_MODEL") {
187 if !m.is_empty() {
188 cfg.model = m;
189 }
190 }
191 if let Ok(t) = env::var("SQLRITE_LLM_MAX_TOKENS") {
192 cfg.max_tokens = t
193 .parse()
194 .map_err(|_| AskError::Config(format!("SQLRITE_LLM_MAX_TOKENS not a u32: {t}")))?;
195 }
196 if let Ok(c) = env::var("SQLRITE_LLM_CACHE_TTL") {
197 cfg.cache_ttl = match c.to_ascii_lowercase().as_str() {
198 "5m" | "5min" | "5minutes" => CacheTtl::FiveMinutes,
199 "1h" | "1hr" | "1hour" => CacheTtl::OneHour,
200 "off" | "none" | "disabled" => CacheTtl::Off,
201 other => {
202 return Err(AskError::Config(format!(
203 "SQLRITE_LLM_CACHE_TTL: unknown value '{other}'"
204 )));
205 }
206 };
207 }
208 Ok(cfg)
209 }
210}
211
/// Errors returned by the `ask` entry points and by `AskConfig::from_env`.
#[derive(Debug, thiserror::Error)]
pub enum AskError {
    /// `ask` was called with `AskConfig::api_key` set to `None`.
    #[error("missing API key (set SQLRITE_LLM_API_KEY or AskConfig.api_key)")]
    MissingApiKey,

    /// `AskConfig::from_env` found an environment variable value it could not interpret.
    #[error("config error: {0}")]
    Config(String),

    /// `ProviderKind::parse` was given a name that is not a supported provider.
    #[error("unknown provider: {0} (supported: anthropic)")]
    UnknownProvider(String),

    /// Not constructed in this file; presumably raised by provider implementations when the
    /// HTTP request itself fails (confirm in the `provider` module).
    #[error("HTTP transport error: {0}")]
    Http(String),

    /// Not constructed in this file; presumably raised by provider implementations for a
    /// non-success HTTP status (confirm in the `provider` module).
    #[error("API returned status {status}: {detail}")]
    ApiStatus { status: u16, detail: String },

    /// Not constructed in this file; presumably raised by provider implementations when the
    /// reply carries no text (confirm in the `provider` module).
    #[error("API returned no text content")]
    EmptyResponse,

    /// `parse_response` could not find parseable JSON in the model output; carries the raw
    /// output text.
    #[error("model output not valid JSON: {0}")]
    OutputNotJson(String),

    /// The parsed JSON has no string value for the named field (`extract_fields` raises this
    /// for `"sql"`).
    #[error("model output JSON missing required field '{0}'")]
    OutputMissingField(&'static str),

    /// Not constructed in this file; presumably raised by the `schema` module (confirm there).
    #[error("schema introspection failed: {0}")]
    Schema(String),

    /// Wraps a `serde_json::Error`; `?` converts into it automatically via `#[from]`.
    #[error("JSON serialization error: {0}")]
    Json(#[from] serde_json::Error),

    /// Wraps a `sqlrite::SQLRiteError`, displayed transparently; `?` converts via `#[from]`.
    #[error(transparent)]
    Engine(#[from] sqlrite::SQLRiteError),
}
249
/// Extension trait that exposes `ask` as a method on a connection type.
pub trait ConnectionAskExt {
    /// Asks `question` using `config` and returns the parsed response.
    ///
    /// The implementation for `Connection` in this file forwards to the free function `ask`.
    fn ask(&self, question: &str, config: &AskConfig) -> Result<AskResponse, AskError>;
}
269
/// Implements the extension trait for `Connection` by delegating to the free function `ask`.
impl ConnectionAskExt for Connection {
    fn ask(&self, question: &str, config: &AskConfig) -> Result<AskResponse, AskError> {
        // The bare name `ask` resolves to the module-level function, not to this method,
        // so this call does not recurse.
        ask(self, question, config)
    }
}
275
276pub fn ask(conn: &Connection, question: &str, config: &AskConfig) -> Result<AskResponse, AskError> {
284 let api_key = config.api_key.clone().ok_or(AskError::MissingApiKey)?;
285
286 let provider = match config.provider {
287 ProviderKind::Anthropic => match &config.base_url {
288 Some(url) => AnthropicProvider::with_base_url(api_key, url.clone()),
289 None => AnthropicProvider::new(api_key),
290 },
291 };
292
293 ask_with_provider(conn, question, config, &provider)
294}
295
296pub fn ask_with_provider<P: Provider>(
302 conn: &Connection,
303 question: &str,
304 config: &AskConfig,
305 provider: &P,
306) -> Result<AskResponse, AskError> {
307 let schema_dump = schema::dump_schema(conn);
308 let system = build_system(&schema_dump, config.cache_ttl.into_marker());
309 let messages = [UserMessage::new(question)];
310
311 let req = ProviderRequest {
312 model: &config.model,
313 max_tokens: config.max_tokens,
314 system: &system,
315 messages: &messages,
316 };
317
318 let resp = provider.complete(req)?;
319 parse_response(&resp.text, resp.usage)
320}
321
322fn parse_response(raw: &str, usage: Usage) -> Result<AskResponse, AskError> {
330 let trimmed = raw.trim();
332 let body = strip_markdown_fence(trimmed).unwrap_or(trimmed);
333
334 if let Ok(value) = serde_json::from_str::<serde_json::Value>(body) {
336 return extract_fields(&value, usage);
337 }
338
339 if let Some(json_block) = extract_first_json_object(body) {
344 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&json_block) {
345 return extract_fields(&value, usage);
346 }
347 }
348
349 Err(AskError::OutputNotJson(raw.to_string()))
350}
351
352fn extract_fields(value: &serde_json::Value, usage: Usage) -> Result<AskResponse, AskError> {
353 let sql = value
354 .get("sql")
355 .and_then(|v| v.as_str())
356 .ok_or(AskError::OutputMissingField("sql"))?
357 .trim()
358 .trim_end_matches(';')
359 .to_string();
360 let explanation = value
361 .get("explanation")
362 .and_then(|v| v.as_str())
363 .unwrap_or("")
364 .to_string();
365 Ok(AskResponse {
366 sql,
367 explanation,
368 usage,
369 })
370}
371
/// Removes a surrounding Markdown code fence from `s`, if one is present.
///
/// The text (after trimming) must start with three backticks followed by an info string
/// that is either empty or `json` in any ASCII case, terminated by a newline. Trailing
/// spaces after the tag and CRLF line endings are tolerated. A closing fence at the end is
/// removed when present; a missing closing fence is tolerated.
///
/// Returns the trimmed fence contents, or `None` when `s` does not start with such a fence
/// (including fences tagged with any other language).
fn strip_markdown_fence(s: &str) -> Option<&str> {
    let rest = s.trim().strip_prefix("```")?;
    // Everything up to the first newline is the fence's info string (language tag).
    let (info, body) = rest.split_once('\n')?;
    let info = info.trim();
    if !(info.is_empty() || info.eq_ignore_ascii_case("json")) {
        return None;
    }
    let body = body.trim_end();
    let body = body.strip_suffix("```").unwrap_or(body);
    Some(body.trim())
}
385
/// Returns the first brace-balanced `{...}` span in `s`, starting at the first `{`.
///
/// Braces inside double-quoted strings are ignored, and a backslash inside a string causes
/// the following byte to be skipped, so escaped quotes do not end the string. The span is
/// not validated as JSON. Returns `None` when `s` has no `{` or the braces never balance.
fn extract_first_json_object(s: &str) -> Option<String> {
    let start = s.find('{')?;
    let mut depth = 0_i32;
    let mut in_string = false;
    let mut escaped = false;

    for (offset, byte) in s[start..].bytes().enumerate() {
        if escaped {
            // This byte was escaped by a preceding backslash; it has no structural meaning.
            escaped = false;
            continue;
        }
        if in_string {
            match byte {
                b'\\' => escaped = true,
                b'"' => in_string = false,
                _ => {}
            }
            continue;
        }
        match byte {
            b'"' => in_string = true,
            b'{' => depth += 1,
            b'}' => {
                depth -= 1;
                if depth == 0 {
                    let end = start + offset;
                    return Some(s[start..=end].to_string());
                }
            }
            _ => {}
        }
    }
    None
}
412
// Unit tests: the request-building path is exercised through `MockProvider`, and
// `parse_response` is exercised directly on literal model outputs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;
    use sqlrite::Connection;

    /// Opens a fresh in-memory connection for a test.
    fn open() -> Connection {
        Connection::open_in_memory().unwrap()
    }

    /// Default configuration with a placeholder API key filled in.
    fn cfg() -> AskConfig {
        AskConfig {
            api_key: Some("test-key".to_string()),
            ..AskConfig::default()
        }
    }

    /// The canned JSON reply from the mock is parsed into `sql` and `explanation`.
    #[test]
    fn ask_with_mock_provider_returns_parsed_sql() {
        let mut conn = open();
        conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .unwrap();

        let provider = MockProvider::new(
            r#"{"sql": "SELECT COUNT(*) FROM users", "explanation": "counts users"}"#,
        );

        let resp = ask_with_provider(&conn, "how many users?", &cfg(), &provider).unwrap();
        assert_eq!(resp.sql, "SELECT COUNT(*) FROM users");
        assert_eq!(resp.explanation, "counts users");
    }

    /// The table definition created on the connection shows up in the captured system blocks.
    #[test]
    fn schema_dump_appears_in_system_block() {
        let mut conn = open();
        conn.execute("CREATE TABLE widgets (id INTEGER PRIMARY KEY, name TEXT)")
            .unwrap();
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_provider(&conn, "anything", &cfg(), &provider).unwrap();

        let captured = provider.last_request.borrow().clone().unwrap();
        // The test expects the schema text in the second captured system block (index 1).
        let schema_block = &captured.system_blocks[1];
        assert!(
            schema_block.contains("CREATE TABLE widgets"),
            "got: {schema_block}"
        );
        assert!(schema_block.contains("name TEXT"), "got: {schema_block}");
    }

    /// With `CacheTtl::Off` the captured request reports no cache control on the schema block.
    #[test]
    fn cache_ttl_off_omits_cache_control() {
        let conn = open();
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let mut config = cfg();
        config.cache_ttl = CacheTtl::Off;
        let _ = ask_with_provider(&conn, "test", &config, &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(!captured.schema_block_has_cache_control);
    }

    /// The default configuration (`CacheTtl::FiveMinutes`) sets cache control on the schema
    /// block.
    #[test]
    fn cache_ttl_5m_sets_cache_control() {
        let conn = open();
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_provider(&conn, "test", &cfg(), &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(captured.schema_block_has_cache_control);
    }

    /// The question text reaches the provider exactly as given, quotes included.
    #[test]
    fn user_question_arrives_in_messages_unchanged() {
        let conn = open();
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let q = "Find users with email containing '@example.com'";
        let _ = ask_with_provider(&conn, q, &cfg(), &provider).unwrap();
        assert_eq!(
            provider
                .last_request
                .borrow()
                .as_ref()
                .unwrap()
                .user_message,
            q
        );
    }

    /// `ask` fails with `MissingApiKey` when the configuration carries no API key.
    #[test]
    fn missing_api_key_errors_clearly() {
        let conn = open();
        let config = AskConfig {
            api_key: None,
            ..AskConfig::default()
        };
        let err = ask(&conn, "test", &config).unwrap_err();
        match err {
            AskError::MissingApiKey => {}
            other => panic!("expected MissingApiKey, got {other:?}"),
        }
    }

    /// A trailing `;` in the model's SQL is removed.
    #[test]
    fn parse_response_strips_trailing_semicolon() {
        let resp = parse_response(
            r#"{"sql": "SELECT 1;", "explanation": "demo"}"#,
            Usage::default(),
        )
        .unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    /// JSON wrapped in a fenced `json` code block is still parsed.
    #[test]
    fn parse_response_handles_markdown_fence() {
        let raw = "```json\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}\n```";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    /// JSON preceded by a line of prose is recovered by the first-object fallback.
    #[test]
    fn parse_response_handles_leading_prose() {
        let raw =
            "Here is the query you asked for:\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    /// Output without any JSON object yields `OutputNotJson`.
    #[test]
    fn parse_response_rejects_non_json() {
        let err = parse_response("just some prose, no JSON here", Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputNotJson(_)));
    }

    /// Valid JSON without a `sql` key yields `OutputMissingField("sql")`.
    #[test]
    fn parse_response_rejects_missing_sql_field() {
        let err = parse_response(r#"{"explanation": "no sql key"}"#, Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputMissingField("sql")));
    }

    /// A missing `explanation` key is tolerated and becomes the empty string.
    #[test]
    fn parse_response_allows_missing_explanation() {
        let resp = parse_response(r#"{"sql": "SELECT 1"}"#, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
        assert_eq!(resp.explanation, "");
    }

    /// The `Usage` value given to `parse_response` is returned unchanged in the response.
    #[test]
    fn parse_response_passes_usage_through() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 20,
            cache_creation_input_tokens: 80,
            cache_read_input_tokens: 0,
        };
        let resp =
            parse_response(r#"{"sql": "SELECT 1", "explanation": ""}"#, usage.clone()).unwrap();
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.cache_creation_input_tokens, 80);
    }
}