use reqwest::{Client, StatusCode, Url};
use serde_json::{Map, Value, json};
use tokio::time::{Duration, sleep};
/// Engine identifier compared (after trimming, ASCII case-insensitively) with the `dbEngine`
/// metadata field; it is also the assumed engine when `dbEngine` is absent.
pub const D1_ENGINE_NAME: &str = "cloudflare-d1";
/// Response header from which the proxy's D1 bookmark is read after a query.
pub const HEADER_D1_BOOKMARK: &str = "x-athena-d1-bookmark";
/// Header name for a D1 session mode. NOTE(review): not referenced in this section —
/// presumably read from incoming requests elsewhere; confirm against callers.
pub const HEADER_D1_SESSION_MODE: &str = "x-athena-d1-session-mode";
/// Settings for reaching a Cloudflare D1 database through an HTTP proxy worker.
///
/// Normally built by [`D1ConnectionInfo::from_metadata`] from the `cloudflareD1` metadata object.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct D1ConnectionInfo {
    /// Base URL of the proxy worker; `/query` is appended to it when executing queries.
    /// `from_metadata` requires `https` unless the host is local.
    pub worker_base_url: String,
    /// Name of the environment variable that holds the proxy bearer token (the token itself is
    /// read at request time, not stored here).
    pub auth_token_env_var: String,
    /// D1 binding name, forwarded to the proxy as `databaseBinding`.
    pub database_binding: String,
    /// Session mode used when a request supplies none (e.g. `first-unconstrained`); when this is
    /// also `None`, `sessionMode` is sent as `null`.
    pub default_session_mode: Option<String>,
}
/// Result of a query executed through the D1 proxy, decoded from the proxy's JSON response.
#[derive(Debug, Clone)]
pub struct D1ExecutionResult {
    /// Entries of the response's `rows` array (empty when absent or not an array).
    pub rows: Vec<Value>,
    /// Entries of the response's `columns` array; non-string entries are dropped.
    pub columns: Vec<String>,
    /// The response's `durationMs` field, when it is an unsigned integer.
    pub duration_ms: Option<u64>,
    /// Bookmark taken from the bookmark response header or, failing that, the `bookmark` field.
    pub bookmark: Option<String>,
    /// The response's `count` field, when it is an unsigned integer.
    /// NOTE(review): presumably a row count — confirm with the proxy contract.
    pub count: Option<u64>,
    /// The response's `meta` value as returned, or `{}` when absent.
    pub meta: Value,
}
impl D1ConnectionInfo {
    /// Extracts D1 proxy settings from a metadata document.
    ///
    /// Returns `Ok(None)` when there is no `cloudflareD1` key, or when `dbEngine` is a string
    /// that (trimmed, ignoring ASCII case) differs from [`D1_ENGINE_NAME`]. A missing or
    /// non-string `dbEngine` is treated as D1. The object-shape check runs before the engine
    /// check, so a non-object `cloudflareD1` is an error even for other engines.
    ///
    /// # Errors
    ///
    /// Fails when `cloudflareD1` is not an object, when `worker_base_url`,
    /// `auth_token_env_var` or `database_binding` is missing, blank or not a string, when the
    /// worker URL does not parse or is not `https` (plain `http` is allowed only for local
    /// hosts), or when `default_session_mode` is present but not a string.
    pub fn from_metadata(metadata: &Value) -> Result<Option<Self>, String> {
        let section = match metadata.get("cloudflareD1") {
            None => return Ok(None),
            Some(Value::Object(section)) => section,
            Some(_) => return Err("cloudflareD1 metadata must be an object".to_string()),
        };

        let engine_is_d1 = match metadata.get("dbEngine").and_then(Value::as_str) {
            Some(engine) => engine.trim().eq_ignore_ascii_case(D1_ENGINE_NAME),
            None => true,
        };
        if !engine_is_d1 {
            return Ok(None);
        }

        let worker_base_url =
            required_string(section, "worker_base_url", "cloudflareD1.worker_base_url")?;
        validate_worker_base_url(&worker_base_url)?;

        // Struct-literal fields are evaluated in the order written, so validation errors are
        // reported in the same order as the field list.
        Ok(Some(Self {
            worker_base_url,
            auth_token_env_var: required_string(
                section,
                "auth_token_env_var",
                "cloudflareD1.auth_token_env_var",
            )?,
            database_binding: required_string(
                section,
                "database_binding",
                "cloudflareD1.database_binding",
            )?,
            default_session_mode: optional_string(
                section,
                "default_session_mode",
                "cloudflareD1.default_session_mode",
            )?,
        }))
    }

    /// Reads the proxy bearer token from the configured environment variable.
    ///
    /// # Errors
    ///
    /// Fails when the variable cannot be read (unset or not valid Unicode) or when its value is
    /// blank after trimming. The returned token is trimmed.
    pub fn resolve_auth_token(&self) -> Result<String, String> {
        let name = &self.auth_token_env_var;
        let raw = match std::env::var(name) {
            Ok(raw) => raw,
            Err(_) => {
                return Err(format!(
                    "missing Cloudflare D1 proxy auth token env var '{}'",
                    name
                ));
            }
        };
        match raw.trim() {
            "" => Err(format!(
                "Cloudflare D1 proxy auth token env var '{}' is empty",
                name
            )),
            token => Ok(token.to_owned()),
        }
    }

    /// Chooses the session mode for a request: the trimmed `requested` value when it is
    /// non-blank, otherwise the configured default (which may be `None`).
    pub fn effective_session_mode(&self, requested: Option<&str>) -> Option<String> {
        match requested.map(str::trim) {
            Some(mode) if !mode.is_empty() => Some(mode.to_owned()),
            _ => self.default_session_mode.clone(),
        }
    }
}
pub async fn execute_query_via_proxy(
http: &Client,
info: &D1ConnectionInfo,
query: &str,
params: Vec<Value>,
requested_session_mode: Option<&str>,
bookmark: Option<&str>,
retry_writes: bool,
) -> Result<D1ExecutionResult, String> {
let token = info.resolve_auth_token()?;
let endpoint = format!("{}/query", info.worker_base_url.trim_end_matches('/'));
let session_mode = info.effective_session_mode(requested_session_mode);
let bookmark = bookmark
.map(str::trim)
.filter(|value| !value.is_empty())
.map(str::to_string);
let body = json!({
"query": query,
"params": params,
"databaseBinding": info.database_binding,
"sessionMode": session_mode,
"bookmark": bookmark,
});
let is_write = retry_writes && sql_is_write(query);
let max_attempts = if is_write { 3 } else { 1 };
let mut last_error = String::new();
for attempt in 1..=max_attempts {
let request = http
.post(&endpoint)
.bearer_auth(&token)
.header("Content-Type", "application/json")
.json(&body);
match request.send().await {
Ok(response) => {
let status = response.status();
let response_bookmark = response
.headers()
.get(HEADER_D1_BOOKMARK)
.and_then(|value| value.to_str().ok())
.map(str::trim)
.filter(|value| !value.is_empty())
.map(str::to_string);
let payload = response.json::<Value>().await.map_err(|error| {
format!("Cloudflare D1 proxy returned invalid JSON: {error}")
})?;
if status.is_success() {
return Ok(parse_execution_result(payload, response_bookmark));
}
let retryable = is_write && is_retryable_status(status);
let message = payload
.get("message")
.and_then(Value::as_str)
.or_else(|| payload.get("error").and_then(Value::as_str))
.unwrap_or("Cloudflare D1 proxy request failed");
last_error = format!(
"Cloudflare D1 proxy request failed with status {}: {}",
status.as_u16(),
message
);
if retryable && attempt < max_attempts {
sleep(retry_delay(attempt)).await;
continue;
}
return Err(last_error);
}
Err(error) => {
last_error = format!("Cloudflare D1 proxy request failed: {error}");
if is_write && attempt < max_attempts {
sleep(retry_delay(attempt)).await;
continue;
}
return Err(last_error);
}
}
}
Err(last_error)
}
/// Converts a successful proxy JSON payload into a [`D1ExecutionResult`].
///
/// Missing or mistyped fields degrade to empty values instead of errors:
/// - `rows` must be an array;
/// - `columns` keeps only its string entries;
/// - `durationMs` and `count` must be unsigned integers;
/// - `meta` defaults to `{}` when absent.
///
/// `response_bookmark`, when present, takes precedence over the payload's `bookmark` field.
/// That field is trimmed and ignored when blank.
fn parse_execution_result(
    mut payload: Value,
    response_bookmark: Option<String>,
) -> D1ExecutionResult {
    // The payload is owned, so move the potentially large row array out instead of cloning it.
    let rows = match payload.get_mut("rows").map(Value::take) {
        Some(Value::Array(rows)) => rows,
        _ => Vec::new(),
    };
    let columns = payload
        .get("columns")
        .and_then(Value::as_array)
        .map(|values| {
            values
                .iter()
                .filter_map(|value| value.as_str().map(str::to_string))
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();
    let duration_ms = payload.get("durationMs").and_then(Value::as_u64);
    // Normalize the payload bookmark the same way the bookmark header is normalized, so a
    // blank value never surfaces as `Some("")`.
    let bookmark = response_bookmark.or_else(|| {
        payload
            .get("bookmark")
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .map(str::to_string)
    });
    let count = payload.get("count").and_then(Value::as_u64);
    let meta = payload
        .get_mut("meta")
        .map(Value::take)
        .unwrap_or_else(|| json!({}));
    D1ExecutionResult {
        rows,
        columns,
        duration_ms,
        bookmark,
        count,
        meta,
    }
}
fn required_string(object: &Map<String, Value>, key: &str, label: &str) -> Result<String, String> {
let value = object
.get(key)
.and_then(Value::as_str)
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| format!("{label} is required"))?;
Ok(value.to_string())
}
/// Reads an optional string field `key` from `object`.
///
/// Missing, `null` and blank (after trimming) values all yield `Ok(None)`; otherwise the
/// trimmed text is returned.
///
/// # Errors
///
/// Returns `"{label} must be a string"` when the field holds a non-string, non-null value.
fn optional_string(
    object: &Map<String, Value>,
    key: &str,
    label: &str,
) -> Result<Option<String>, String> {
    let text = match object.get(key) {
        None | Some(Value::Null) => return Ok(None),
        Some(Value::String(text)) => text.trim(),
        Some(_) => return Err(format!("{label} must be a string")),
    };
    Ok((!text.is_empty()).then(|| text.to_owned()))
}
/// Checks that `raw` is a URL the proxy client may call.
///
/// `https` URLs are always accepted. Plain `http` is accepted only when `is_local_host`
/// approves the host. Every other scheme is rejected.
///
/// # Errors
///
/// Returns a message when `raw` does not parse as a URL or fails the scheme rule above.
fn validate_worker_base_url(raw: &str) -> Result<(), String> {
    let parsed = Url::parse(raw)
        .map_err(|error| format!("invalid cloudflareD1.worker_base_url: {error}"))?;
    let scheme = parsed.scheme();
    let allowed = if scheme.eq_ignore_ascii_case("https") {
        true
    } else if scheme.eq_ignore_ascii_case("http") {
        is_local_host(parsed.host_str().unwrap_or_default())
    } else {
        false
    };
    if allowed {
        Ok(())
    } else {
        Err("cloudflareD1.worker_base_url must use https outside local development".to_string())
    }
}
/// Reports whether `host` names the local machine, for which plain `http` is acceptable.
///
/// Accepts `localhost` (ASCII case-insensitive) and any loopback IP address: IPv4
/// `127.0.0.0/8` and IPv6 `::1`. IPv6 literals may be given with or without the surrounding
/// brackets; `Url::host_str` includes the brackets, so `http://[::1]:8787` yields `"[::1]"`.
fn is_local_host(host: &str) -> bool {
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    host.eq_ignore_ascii_case("localhost")
        || matches!(host.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
}
/// Heuristically reports whether `sql` is a mutating statement, judged by its leading keyword.
///
/// Leading whitespace and SQL comments (`-- …` line comments and `/* … */` block comments)
/// are skipped. The keyword ends at the first character that is not an ASCII letter, so both
/// `"-- seed\nINSERT …"` and `"VACUUM;"` are detected.
///
/// Only the first keyword is inspected, so a `WITH … INSERT` statement is not classified as
/// a write.
fn sql_is_write(sql: &str) -> bool {
    let keyword = sql_leading_keyword(sql).to_ascii_uppercase();
    matches!(
        keyword.as_str(),
        "INSERT"
            | "UPDATE"
            | "DELETE"
            | "REPLACE"
            | "CREATE"
            | "ALTER"
            | "DROP"
            | "PRAGMA"
            | "VACUUM"
    )
}

/// Returns the run of ASCII letters that begins `sql` once leading whitespace and comments
/// are skipped; empty when there is none.
fn sql_leading_keyword(sql: &str) -> &str {
    let mut rest = sql;
    loop {
        rest = rest.trim_start();
        if let Some(comment) = rest.strip_prefix("--") {
            // Line comment: resume after the next newline, or stop if there is none.
            rest = comment.split_once('\n').map_or("", |(_, tail)| tail);
        } else if let Some(comment) = rest.strip_prefix("/*") {
            // Block comment: resume after the closing `*/`, or stop if it is unterminated.
            rest = comment.split_once("*/").map_or("", |(_, tail)| tail);
        } else {
            break;
        }
    }
    // ASCII letters are single bytes, so the found index is always a char boundary.
    let end = rest
        .find(|c: char| !c.is_ascii_alphabetic())
        .unwrap_or(rest.len());
    &rest[..end]
}
/// Returns `true` for the HTTP statuses treated as worth retrying: 429 Too Many Requests and
/// 500, 502, 503 and 504.
fn is_retryable_status(status: StatusCode) -> bool {
    const RETRYABLE: [StatusCode; 5] = [
        StatusCode::TOO_MANY_REQUESTS,
        StatusCode::INTERNAL_SERVER_ERROR,
        StatusCode::BAD_GATEWAY,
        StatusCode::SERVICE_UNAVAILABLE,
        StatusCode::GATEWAY_TIMEOUT,
    ];
    RETRYABLE.contains(&status)
}
/// Backoff before retrying after the given 1-based `attempt`.
///
/// The delay is 100 ms, doubled for each attempt and capped at 1 s. Attempts 0 and 1 both
/// yield 100 ms. Very large attempt numbers saturate at the cap instead of overflowing the
/// shift, which previously panicked in debug builds for `attempt >= 65`.
fn retry_delay(attempt: usize) -> Duration {
    const BASE_MS: u64 = 100;
    const MAX_MS: u64 = 1_000;
    let exponent = u32::try_from(attempt.saturating_sub(1)).unwrap_or(u32::MAX);
    // `checked_shl` is `None` once the shift reaches the bit width; saturate to the maximum.
    let factor = 1u64.checked_shl(exponent).unwrap_or(u64::MAX);
    Duration::from_millis(BASE_MS.saturating_mul(factor).min(MAX_MS))
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn parses_cloudflare_d1_connection_info() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "https://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB",
                "default_session_mode": "first-unconstrained"
            }
        });
        let info = D1ConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        assert_eq!(info.database_binding, "DB");
        assert_eq!(
            info.default_session_mode.as_deref(),
            Some("first-unconstrained")
        );
    }

    #[test]
    fn rejects_non_https_remote_worker_base_url() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "http://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB"
            }
        });
        let err = D1ConnectionInfo::from_metadata(&metadata).expect_err("http remote should fail");
        assert!(err.contains("must use https"));
    }

    #[test]
    fn allows_local_http_worker_base_url() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "http://localhost:8787",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB"
            }
        });
        let info = D1ConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        assert_eq!(info.worker_base_url, "http://localhost:8787");
    }

    #[test]
    fn allows_loopback_ipv4_http_worker_base_url() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "http://127.0.0.1:8787",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB"
            }
        });
        let info = D1ConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        assert_eq!(info.worker_base_url, "http://127.0.0.1:8787");
    }

    #[test]
    fn returns_none_without_cloudflare_d1_section() {
        let metadata = json!({ "dbEngine": "postgres" });
        assert_eq!(D1ConnectionInfo::from_metadata(&metadata), Ok(None));
    }

    #[test]
    fn ignores_cloudflare_d1_section_for_other_engines() {
        let metadata = json!({
            "dbEngine": "postgres",
            "cloudflareD1": {
                "worker_base_url": "https://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB"
            }
        });
        assert_eq!(D1ConnectionInfo::from_metadata(&metadata), Ok(None));
    }

    #[test]
    fn reports_missing_required_field() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "https://d1-proxy.example.com",
                "database_binding": "DB"
            }
        });
        let err = D1ConnectionInfo::from_metadata(&metadata)
            .expect_err("missing auth token env var should fail");
        assert!(err.contains("cloudflareD1.auth_token_env_var is required"));
    }

    #[test]
    fn effective_session_mode_prefers_non_blank_request() {
        let info = D1ConnectionInfo {
            worker_base_url: "https://d1-proxy.example.com".to_string(),
            auth_token_env_var: "ATHENA_D1_PROXY_TOKEN_REPORTING".to_string(),
            database_binding: "DB".to_string(),
            default_session_mode: Some("first-primary".to_string()),
        };
        assert_eq!(
            info.effective_session_mode(Some(" first-unconstrained ")).as_deref(),
            Some("first-unconstrained")
        );
        assert_eq!(
            info.effective_session_mode(Some("   ")).as_deref(),
            Some("first-primary")
        );
        assert_eq!(
            info.effective_session_mode(None).as_deref(),
            Some("first-primary")
        );
    }

    #[test]
    fn sql_write_detection_matches_expected_verbs() {
        assert!(sql_is_write("INSERT INTO users (id) VALUES (1)"));
        assert!(sql_is_write(" update users set active = 1"));
        assert!(sql_is_write("DELETE FROM users WHERE id = 1"));
        assert!(!sql_is_write("SELECT * FROM users"));
        assert!(!sql_is_write(""));
    }

    #[test]
    fn classifies_retryable_statuses() {
        assert!(is_retryable_status(StatusCode::TOO_MANY_REQUESTS));
        assert!(is_retryable_status(StatusCode::SERVICE_UNAVAILABLE));
        assert!(!is_retryable_status(StatusCode::BAD_REQUEST));
        assert!(!is_retryable_status(StatusCode::NOT_FOUND));
    }

    #[test]
    fn retry_delay_doubles_then_caps() {
        assert_eq!(retry_delay(1), Duration::from_millis(100));
        assert_eq!(retry_delay(2), Duration::from_millis(200));
        assert_eq!(retry_delay(3), Duration::from_millis(400));
        assert_eq!(retry_delay(10), Duration::from_millis(1_000));
    }
}