1use std::env;
75
76mod prompt;
77mod provider;
78
79pub use provider::anthropic::AnthropicProvider;
80pub use provider::{Provider, Request, Response, Usage};
81
82use prompt::{CacheControl, UserMessage, build_system};
83use provider::Request as ProviderRequest;
84
/// Model identifier that [`AskConfig::default`] places in `AskConfig::model`.
///
/// [`AskConfig::from_env`] replaces it when `SQLRITE_LLM_MODEL` is set to a
/// non-empty value.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
89
/// `max_tokens` value that [`AskConfig::default`] uses; it is forwarded
/// verbatim in the provider request.
///
/// [`AskConfig::from_env`] replaces it when `SQLRITE_LLM_MAX_TOKENS` is set.
pub const DEFAULT_MAX_TOKENS: u32 = 1024;
95
/// Parsed result of one "ask" round trip.
#[derive(Debug, Clone)]
pub struct AskResponse {
    /// Value of the `sql` field from the model's JSON output, with surrounding
    /// whitespace and trailing semicolons removed.
    pub sql: String,
    /// Value of the optional `explanation` field; the empty string when the
    /// model omitted it or it was not a JSON string.
    pub explanation: String,
    /// Usage figures handed back by the provider, passed through unchanged.
    pub usage: Usage,
}
112
/// Which cache-control marker, if any, is handed to `build_system` when the
/// system prompt is assembled.
///
/// Selectable through `SQLRITE_LLM_CACHE_TTL` (see [`AskConfig::from_env`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheTtl {
    /// Maps to `CacheControl::ephemeral()`. This is the default.
    FiveMinutes,
    /// Maps to `CacheControl::ephemeral_1h()`.
    OneHour,
    /// No cache-control marker is attached.
    Off,
}
133
134impl CacheTtl {
135 fn into_marker(self) -> Option<CacheControl> {
136 match self {
137 CacheTtl::FiveMinutes => Some(CacheControl::ephemeral()),
138 CacheTtl::OneHour => Some(CacheControl::ephemeral_1h()),
139 CacheTtl::Off => None,
140 }
141 }
142}
143
/// Backend that serves completion requests.
///
/// Only one provider is currently supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Requests are served through [`AnthropicProvider`].
    Anthropic,
}
151
152impl ProviderKind {
153 fn parse(s: &str) -> Result<Self, AskError> {
154 match s.to_ascii_lowercase().as_str() {
155 "anthropic" => Ok(ProviderKind::Anthropic),
156 other => Err(AskError::UnknownProvider(other.to_string())),
157 }
158 }
159}
160
/// Settings for a single "ask" call.
///
/// Build one with [`AskConfig::default`], [`AskConfig::from_env`], or a struct
/// literal using `..AskConfig::default()`.
#[derive(Debug, Clone)]
pub struct AskConfig {
    /// Backend to talk to.
    pub provider: ProviderKind,
    /// API key handed to the provider. [`ask_with_schema`] fails with
    /// [`AskError::MissingApiKey`] when this is `None`.
    pub api_key: Option<String>,
    /// Model identifier forwarded in the provider request.
    pub model: String,
    /// `max_tokens` value forwarded in the provider request.
    pub max_tokens: u32,
    /// Cache-control choice applied when the system prompt is built.
    pub cache_ttl: CacheTtl,
    /// When set, [`ask_with_schema`] constructs the provider with
    /// `AnthropicProvider::with_base_url` instead of `AnthropicProvider::new`.
    pub base_url: Option<String>,
}
174
175impl Default for AskConfig {
176 fn default() -> Self {
177 Self {
178 provider: ProviderKind::Anthropic,
179 api_key: None,
180 model: DEFAULT_MODEL.to_string(),
181 max_tokens: DEFAULT_MAX_TOKENS,
182 cache_ttl: CacheTtl::FiveMinutes,
183 base_url: None,
184 }
185 }
186}
187
188impl AskConfig {
189 pub fn from_env() -> Result<Self, AskError> {
201 let mut cfg = AskConfig::default();
202 if let Ok(p) = env::var("SQLRITE_LLM_PROVIDER") {
203 cfg.provider = ProviderKind::parse(&p)?;
204 }
205 if let Ok(k) = env::var("SQLRITE_LLM_API_KEY") {
206 if !k.is_empty() {
207 cfg.api_key = Some(k);
208 }
209 }
210 if let Ok(m) = env::var("SQLRITE_LLM_MODEL") {
211 if !m.is_empty() {
212 cfg.model = m;
213 }
214 }
215 if let Ok(t) = env::var("SQLRITE_LLM_MAX_TOKENS") {
216 cfg.max_tokens = t
217 .parse()
218 .map_err(|_| AskError::Config(format!("SQLRITE_LLM_MAX_TOKENS not a u32: {t}")))?;
219 }
220 if let Ok(c) = env::var("SQLRITE_LLM_CACHE_TTL") {
221 cfg.cache_ttl = match c.to_ascii_lowercase().as_str() {
222 "5m" | "5min" | "5minutes" => CacheTtl::FiveMinutes,
223 "1h" | "1hr" | "1hour" => CacheTtl::OneHour,
224 "off" | "none" | "disabled" => CacheTtl::Off,
225 other => {
226 return Err(AskError::Config(format!(
227 "SQLRITE_LLM_CACHE_TTL: unknown value '{other}'"
228 )));
229 }
230 };
231 }
232 Ok(cfg)
233 }
234}
235
/// Everything that can go wrong while configuring or running an "ask" call.
///
/// `Display` text comes from the `#[error(...)]` attributes via `thiserror`.
#[derive(Debug, thiserror::Error)]
pub enum AskError {
    /// `AskConfig::api_key` was `None` when [`ask_with_schema`] was called.
    #[error("missing API key (set SQLRITE_LLM_API_KEY or AskConfig.api_key)")]
    MissingApiKey,

    /// An environment variable held a value that could not be interpreted;
    /// the payload is the human-readable detail.
    #[error("config error: {0}")]
    Config(String),

    /// The provider name was not recognised; the payload is the (lower-cased)
    /// name that was rejected.
    #[error("unknown provider: {0} (supported: anthropic)")]
    UnknownProvider(String),

    /// Transport-level failure. Not constructed in this file — presumably
    /// produced by provider implementations; confirm there.
    #[error("HTTP transport error: {0}")]
    Http(String),

    /// The API answered with an error status. Not constructed in this file —
    /// presumably produced by provider implementations; confirm there.
    #[error("API returned status {status}: {detail}")]
    ApiStatus { status: u16, detail: String },

    /// The API response carried no text. Not constructed in this file —
    /// presumably produced by provider implementations; confirm there.
    #[error("API returned no text content")]
    EmptyResponse,

    /// The model's reply contained no parseable JSON; the payload is the raw
    /// reply text.
    #[error("model output not valid JSON: {0}")]
    OutputNotJson(String),

    /// The model's JSON lacked a required string field; the payload names it.
    #[error("model output JSON missing required field '{0}'")]
    OutputMissingField(&'static str),

    /// A `serde_json` error, converted automatically through `From`.
    #[error("JSON serialization error: {0}")]
    Json(#[from] serde_json::Error),
}
267
268pub fn ask_with_schema(
279 schema_dump: &str,
280 question: &str,
281 config: &AskConfig,
282) -> Result<AskResponse, AskError> {
283 let api_key = config.api_key.clone().ok_or(AskError::MissingApiKey)?;
284
285 let provider = match config.provider {
286 ProviderKind::Anthropic => match &config.base_url {
287 Some(url) => AnthropicProvider::with_base_url(api_key, url.clone()),
288 None => AnthropicProvider::new(api_key),
289 },
290 };
291
292 ask_with_schema_and_provider(schema_dump, question, config, &provider)
293}
294
295pub fn ask_with_schema_and_provider<P: Provider>(
305 schema_dump: &str,
306 question: &str,
307 config: &AskConfig,
308 provider: &P,
309) -> Result<AskResponse, AskError> {
310 let system = build_system(schema_dump, config.cache_ttl.into_marker());
311 let messages = [UserMessage::new(question)];
312
313 let req = ProviderRequest {
314 model: &config.model,
315 max_tokens: config.max_tokens,
316 system: &system,
317 messages: &messages,
318 };
319
320 let resp = provider.complete(req)?;
321 parse_response(&resp.text, resp.usage)
322}
323
324fn parse_response(raw: &str, usage: Usage) -> Result<AskResponse, AskError> {
332 let trimmed = raw.trim();
334 let body = strip_markdown_fence(trimmed).unwrap_or(trimmed);
335
336 if let Ok(value) = serde_json::from_str::<serde_json::Value>(body) {
338 return extract_fields(&value, usage);
339 }
340
341 if let Some(json_block) = extract_first_json_object(body) {
346 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&json_block) {
347 return extract_fields(&value, usage);
348 }
349 }
350
351 Err(AskError::OutputNotJson(raw.to_string()))
352}
353
354fn extract_fields(value: &serde_json::Value, usage: Usage) -> Result<AskResponse, AskError> {
355 let sql = value
356 .get("sql")
357 .and_then(|v| v.as_str())
358 .ok_or(AskError::OutputMissingField("sql"))?
359 .trim()
360 .trim_end_matches(';')
361 .to_string();
362 let explanation = value
363 .get("explanation")
364 .and_then(|v| v.as_str())
365 .unwrap_or("")
366 .to_string();
367 Ok(AskResponse {
368 sql,
369 explanation,
370 usage,
371 })
372}
373
/// Removes a surrounding Markdown code fence from `s`, if one is present.
///
/// The input is trimmed first. A fence is recognised when the text starts
/// with three backticks followed by an empty info string or `json` (ASCII
/// case-insensitive) and then a line break. Whitespace around the info
/// string is ignored, so CRLF line endings and trailing spaces are fine.
/// A closing fence at the end is removed when present but is not required.
///
/// Returns the trimmed fence contents, or `None` when `s` does not start with
/// a recognised opening fence (for example a fence tagged `sql`, or a fence
/// with no line break after it).
fn strip_markdown_fence(s: &str) -> Option<&str> {
    let rest = s.trim().strip_prefix("```")?;

    // The opening fence line runs to the first '\n'; what precedes it is the
    // info string (possibly with a trailing '\r' or spaces).
    let (info, body) = rest.split_once('\n')?;
    let info = info.trim();
    if !(info.is_empty() || info.eq_ignore_ascii_case("json")) {
        return None;
    }

    let body = body.trim_end();
    let body = body.strip_suffix("```").unwrap_or(body);
    Some(body.trim())
}
387
/// Returns the first brace-balanced `{ ... }` span in `s`, copied into a new
/// `String`.
///
/// Scanning starts at the first `{`. Braces inside double-quoted strings are
/// ignored, and a backslash inside a string escapes the following byte. The
/// span is not validated as JSON. Returns `None` when `s` has no `{` or the
/// braces never balance.
fn extract_first_json_object(s: &str) -> Option<String> {
    let start = s.find('{')?;

    let mut depth = 0_i32;
    let mut inside_string = false;
    let mut escaped = false;

    // Every byte of interest is ASCII, so a byte scan never lands inside a
    // multi-byte character and the final slice is on a char boundary.
    for (offset, byte) in s.as_bytes()[start..].iter().copied().enumerate() {
        if escaped {
            // This byte is the target of a backslash escape; skip it.
            escaped = false;
            continue;
        }

        if inside_string {
            match byte {
                b'\\' => escaped = true,
                b'"' => inside_string = false,
                _ => {}
            }
            continue;
        }

        match byte {
            b'"' => inside_string = true,
            b'{' => depth += 1,
            b'}' => {
                depth -= 1;
                if depth == 0 {
                    let end = start + offset;
                    return Some(s[start..=end].to_string());
                }
            }
            _ => {}
        }
    }

    None
}
414
// Unit tests for the ask pipeline. Network access is avoided by driving
// `ask_with_schema_and_provider` with `MockProvider`, which is defined in
// `crate::provider` (not visible here); from its use below it appears to
// return a canned reply and to record the last request — confirm there.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    // Minimal schema dump shared by tests that do not care about its content.
    const FIXTURE_SCHEMA: &str = "\
CREATE TABLE users (
 id INTEGER PRIMARY KEY,
 name TEXT
);
";

    // Default config with a dummy API key filled in.
    fn cfg() -> AskConfig {
        AskConfig {
            api_key: Some("test-key".to_string()),
            ..AskConfig::default()
        }
    }

    // Happy path: a clean JSON reply is parsed into `sql` and `explanation`.
    #[test]
    fn ask_with_mock_provider_returns_parsed_sql() {
        let provider = MockProvider::new(
            r#"{"sql": "SELECT COUNT(*) FROM users", "explanation": "counts users"}"#,
        );
        let resp =
            ask_with_schema_and_provider(FIXTURE_SCHEMA, "how many users?", &cfg(), &provider)
                .unwrap();
        assert_eq!(resp.sql, "SELECT COUNT(*) FROM users");
        assert_eq!(resp.explanation, "counts users");
    }

    // The schema dump must reach the provider inside the system prompt.
    #[test]
    fn schema_dump_appears_in_system_block() {
        let schema = "CREATE TABLE widgets (\n id INTEGER PRIMARY KEY,\n name TEXT\n);\n";
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(schema, "anything", &cfg(), &provider).unwrap();

        let captured = provider.last_request.borrow().clone().unwrap();
        // Index 1 is where the test expects the schema block; the layout is
        // decided by `build_system`, which is not visible here.
        let schema_block = &captured.system_blocks[1];
        assert!(
            schema_block.contains("CREATE TABLE widgets"),
            "got: {schema_block}"
        );
        assert!(schema_block.contains("name TEXT"), "got: {schema_block}");
    }

    // `CacheTtl::Off` must not attach a cache-control marker.
    #[test]
    fn cache_ttl_off_omits_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let mut config = cfg();
        config.cache_ttl = CacheTtl::Off;
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &config, &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(!captured.schema_block_has_cache_control);
    }

    // The default TTL (five minutes) must attach a cache-control marker.
    #[test]
    fn cache_ttl_5m_sets_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &cfg(), &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(captured.schema_block_has_cache_control);
    }

    // The question is forwarded verbatim as the user message.
    #[test]
    fn user_question_arrives_in_messages_unchanged() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let q = "Find users with email containing '@example.com'";
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, q, &cfg(), &provider).unwrap();
        assert_eq!(
            provider
                .last_request
                .borrow()
                .as_ref()
                .unwrap()
                .user_message,
            q
        );
    }

    // Without an API key `ask_with_schema` fails before building a provider.
    #[test]
    fn missing_api_key_errors_clearly() {
        let config = AskConfig {
            api_key: None,
            ..AskConfig::default()
        };
        let err = ask_with_schema(FIXTURE_SCHEMA, "test", &config).unwrap_err();
        match err {
            AskError::MissingApiKey => {}
            other => panic!("expected MissingApiKey, got {other:?}"),
        }
    }

    // Trailing semicolons are removed from the returned SQL.
    #[test]
    fn parse_response_strips_trailing_semicolon() {
        let resp = parse_response(
            r#"{"sql": "SELECT 1;", "explanation": "demo"}"#,
            Usage::default(),
        )
        .unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // A reply wrapped in a ```json fence is still parsed.
    #[test]
    fn parse_response_handles_markdown_fence() {
        let raw = "```json\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}\n```";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // Prose before the JSON object is skipped by the balanced-brace fallback.
    #[test]
    fn parse_response_handles_leading_prose() {
        let raw =
            "Here is the query you asked for:\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // Text with no JSON at all is reported as `OutputNotJson`.
    #[test]
    fn parse_response_rejects_non_json() {
        let err = parse_response("just some prose, no JSON here", Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputNotJson(_)));
    }

    // Valid JSON without `sql` is reported as a missing field.
    #[test]
    fn parse_response_rejects_missing_sql_field() {
        let err = parse_response(r#"{"explanation": "no sql key"}"#, Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputMissingField("sql")));
    }

    // `explanation` is optional and defaults to the empty string.
    #[test]
    fn parse_response_allows_missing_explanation() {
        let resp = parse_response(r#"{"sql": "SELECT 1"}"#, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
        assert_eq!(resp.explanation, "");
    }

    // Usage figures supplied by the provider are returned untouched.
    #[test]
    fn parse_response_passes_usage_through() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 20,
            cache_creation_input_tokens: 80,
            cache_read_input_tokens: 0,
        };
        let resp =
            parse_response(r#"{"sql": "SELECT 1", "explanation": ""}"#, usage.clone()).unwrap();
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.cache_creation_input_tokens, 80);
    }
}