use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use sqlrite_ask::{AskConfig, AskError, CacheTtl, ask_with_schema};
/// Single-table schema text handed to `ask_with_schema` by every test in this file.
///
/// `end_to_end_against_localhost_mock` checks that the `CREATE TABLE users`
/// fragment shows up in the second `system` block of the outgoing request.
const TEST_SCHEMA: &str = "\
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT
);
";
/// One-shot HTTP mock listening on `127.0.0.1`.
///
/// Created by `Mock::start`, which spawns a worker thread that serves a single
/// request with a canned response. The `Drop` impl unblocks the server and
/// joins that worker.
struct Mock {
/// The listening server, shared with the worker thread; `Drop` calls
/// `unblock()` on it.
server: Arc<tiny_http::Server>,
/// `http://` followed by the server's bound address; tests pass it as
/// `AskConfig::base_url`.
addr: String,
/// Slot the worker fills with the request it received; emptied by
/// `Mock::captured`.
captured: Arc<Mutex<Option<CapturedRequest>>>,
/// Worker thread handle, wrapped in `Option` so `Drop` can take and join it.
handle: Option<thread::JoinHandle<()>>,
}
/// What the mock's worker thread recorded about the one request it served.
struct CapturedRequest {
/// Request body parsed as JSON; `serde_json::Value::Null` when the body did
/// not parse.
body: serde_json::Value,
/// `(field, value)` pairs copied from the request's headers.
headers: Vec<(String, String)>,
}
impl Mock {
fn start(canned_status: u16, canned_body: &'static str) -> Self {
let server = Arc::new(tiny_http::Server::http("127.0.0.1:0").expect("bind localhost"));
let addr = format!("http://{}", server.server_addr());
let captured: Arc<Mutex<Option<CapturedRequest>>> = Arc::new(Mutex::new(None));
let server_for_thread = server.clone();
let captured_for_thread = captured.clone();
let handle = thread::spawn(move || {
if let Ok(mut req) = server_for_thread.recv() {
let headers: Vec<(String, String)> = req
.headers()
.iter()
.map(|h| (h.field.as_str().to_string(), h.value.as_str().to_string()))
.collect();
let mut body = String::new();
req.as_reader().read_to_string(&mut body).unwrap();
let parsed: serde_json::Value =
serde_json::from_str(&body).unwrap_or(serde_json::Value::Null);
*captured_for_thread.lock().unwrap() = Some(CapturedRequest {
body: parsed,
headers,
});
let response = tiny_http::Response::from_string(canned_body)
.with_status_code(canned_status)
.with_header(
"Content-Type: application/json"
.parse::<tiny_http::Header>()
.unwrap(),
);
let _ = req.respond(response);
}
});
Self {
server,
addr,
captured,
handle: Some(handle),
}
}
fn captured(&self) -> Option<CapturedRequest> {
self.captured.lock().unwrap().take()
}
}
impl Drop for Mock {
    /// Shuts the mock down: unblocks the server so a worker still waiting in
    /// `recv()` can return, then joins the worker thread.
    fn drop(&mut self) {
        self.server.unblock();

        let worker = self.handle.take();
        if let Some(worker) = worker {
            // The join result is deliberately discarded.
            let _ = worker.join();
        }
    }
}
/// Canned success payload served by the mock in the happy-path tests.
///
/// It carries one `text` content block whose text is itself a JSON object with
/// `sql` and `explanation` fields, plus a `usage` object whose token counters
/// are asserted by `end_to_end_against_localhost_mock`.
const SUCCESS_BODY: &str = r#"{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": "claude-sonnet-4-6",
"content": [
{"type": "text", "text": "{\"sql\": \"SELECT * FROM users\", \"explanation\": \"reads all users\"}"}
],
"stop_reason": "end_turn",
"usage": {"input_tokens": 1234, "output_tokens": 56, "cache_creation_input_tokens": 1000, "cache_read_input_tokens": 0}
}"#;
/// Happy path: a 200 response from the mock is decoded into SQL, explanation
/// and usage counters, and the outgoing request carries the expected JSON
/// fields and headers.
#[test]
fn end_to_end_against_localhost_mock() {
    let mock = Mock::start(200, SUCCESS_BODY);
    let config = AskConfig {
        base_url: Some(mock.addr.clone()),
        api_key: Some(String::from("test-key")),
        ..AskConfig::default()
    };

    let resp = ask_with_schema(TEST_SCHEMA, "list all users", &config).expect("ask succeeds");

    // Decoded response.
    assert_eq!(resp.sql, "SELECT * FROM users");
    assert_eq!(resp.explanation, "reads all users");
    let usage = &resp.usage;
    assert_eq!(usage.input_tokens, 1234);
    assert_eq!(usage.cache_creation_input_tokens, 1000);
    assert_eq!(usage.cache_read_input_tokens, 0);

    // Request body as seen by the server.
    let captured = mock.captured().expect("server received request");
    let body = &captured.body;
    assert_eq!(body["model"], "claude-sonnet-4-6");
    assert_eq!(body["max_tokens"], 1024);

    let first_message = &body["messages"][0];
    assert_eq!(first_message["role"], "user");
    assert_eq!(first_message["content"], "list all users");

    // The schema is expected in the second system block, marked ephemeral.
    let schema_block = &body["system"][1];
    let schema_text = schema_block["text"].as_str().unwrap();
    assert!(schema_text.contains("CREATE TABLE users"));
    assert_eq!(schema_block["cache_control"]["type"], "ephemeral");

    // Header names are matched case-insensitively, values exactly.
    let has_header = |name: &str, value: &str| {
        captured
            .headers
            .iter()
            .any(|(k, v)| k.eq_ignore_ascii_case(name) && v == value)
    };
    assert!(
        has_header("x-api-key", "test-key"),
        "missing x-api-key header; saw: {:?}",
        captured.headers
    );
    assert!(
        has_header("anthropic-version", "2023-06-01"),
        "missing anthropic-version header; saw: {:?}",
        captured.headers
    );
}
/// Setting `CacheTtl::OneHour` must appear as `"ttl": "1h"` in the
/// `cache_control` of the second system block of the request.
#[test]
fn cache_ttl_one_hour_propagates_to_request() {
    let mock = Mock::start(200, SUCCESS_BODY);
    let config = AskConfig {
        cache_ttl: CacheTtl::OneHour,
        base_url: Some(mock.addr.clone()),
        api_key: Some(String::from("test-key")),
        ..AskConfig::default()
    };

    // Only the outgoing request matters here; the decoded response is discarded.
    let _ = ask_with_schema(TEST_SCHEMA, "anything", &config).unwrap();

    let captured = mock.captured().unwrap();
    let cache_control = &captured.body["system"][1]["cache_control"];
    assert_eq!(cache_control["ttl"], "1h");
}
/// A 400 reply with an error payload must surface as `AskError::ApiStatus`
/// carrying the status code and a detail string that includes both the error
/// type and its message.
#[test]
fn api_error_response_is_surfaced() {
    let error_body = r#"{"type":"error","error":{"type":"invalid_request_error","message":"max_tokens too large"}}"#;
    let mock = Mock::start(400, error_body);
    let config = AskConfig {
        base_url: Some(mock.addr.clone()),
        api_key: Some(String::from("test-key")),
        ..AskConfig::default()
    };

    let err = ask_with_schema(TEST_SCHEMA, "anything", &config).unwrap_err();

    if let AskError::ApiStatus { status, detail } = &err {
        assert_eq!(*status, 400);
        let mentions_type = detail.contains("invalid_request_error");
        let mentions_message = detail.contains("max_tokens too large");
        assert!(mentions_type && mentions_message, "got: {detail}");
    } else {
        panic!("expected ApiStatus, got {err:?}");
    }
}
/// Pointing the client at a port the test has just released must produce
/// `AskError::Http`.
#[test]
fn http_transport_error_is_surfaced() {
    // Bind a listener on port 0 to learn a port number, then drop it so the
    // test no longer holds that port.
    let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
    let port = probe.local_addr().unwrap().port();
    drop(probe);
    std::thread::sleep(Duration::from_millis(10));

    let base_url = format!("http://127.0.0.1:{port}");
    let config = AskConfig {
        base_url: Some(base_url),
        api_key: Some(String::from("test-key")),
        ..AskConfig::default()
    };

    let err = ask_with_schema(TEST_SCHEMA, "anything", &config).unwrap_err();
    match err {
        AskError::Http(_) => {}
        other => panic!("expected Http error, got {other:?}"),
    }
}