athena-driver 3.18.0

Backend driver primitives for Athena, starting with Scylla and Supabase health-aware clients
Documentation
use reqwest::{Client, StatusCode, Url};
use serde_json::{Map, Value, json};
use tokio::time::{Duration, sleep};

/// Engine identifier compared (case-insensitively, after trimming) with the metadata
/// `dbEngine` field; metadata without a string `dbEngine` is treated as this engine.
pub const D1_ENGINE_NAME: &str = "cloudflare-d1";
/// Header name from which the D1 session bookmark is read on proxy responses.
pub const HEADER_D1_BOOKMARK: &str = "x-athena-d1-bookmark";
/// Header name for a D1 session mode.
// NOTE(review): not referenced in this file section — presumably read by callers to obtain
// the `requested_session_mode` passed to `execute_query_via_proxy`; confirm.
pub const HEADER_D1_SESSION_MODE: &str = "x-athena-d1-session-mode";

/// Settings for reaching a Cloudflare D1 database through an HTTP proxy worker.
///
/// Normally built from the `cloudflareD1` metadata section by
/// `D1ConnectionInfo::from_metadata`, which validates every field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct D1ConnectionInfo {
    /// Base URL of the proxy worker; queries are POSTed to `{worker_base_url}/query`.
    /// `from_metadata` only accepts https, or plain http for local hosts.
    pub worker_base_url: String,
    /// Name of the environment variable holding the bearer token sent to the proxy.
    pub auth_token_env_var: String,
    /// Binding name forwarded as `databaseBinding` in every request body.
    pub database_binding: String,
    /// Session mode used when a request does not specify a non-blank one.
    pub default_session_mode: Option<String>,
}

/// Outcome of a query executed through the D1 proxy, decoded from its JSON response.
#[derive(Debug, Clone)]
pub struct D1ExecutionResult {
    /// Entries of the response `rows` array (empty when absent or not an array).
    pub rows: Vec<Value>,
    /// String entries of the response `columns` array; non-string entries are skipped.
    pub columns: Vec<String>,
    /// Response `durationMs`, when present as an unsigned integer.
    pub duration_ms: Option<u64>,
    /// Session bookmark from the bookmark response header, otherwise from the body.
    pub bookmark: Option<String>,
    /// Response `count`, when present as an unsigned integer.
    pub count: Option<u64>,
    /// Response `meta` value, or an empty JSON object when absent.
    pub meta: Value,
}

impl D1ConnectionInfo {
    /// Builds connection info from a client-metadata JSON document.
    ///
    /// Returns `Ok(None)` when `metadata` has no `cloudflareD1` key, or when `dbEngine` is a
    /// string naming an engine other than [`D1_ENGINE_NAME`] (compared case-insensitively
    /// after trimming). A missing or non-string `dbEngine` is treated as D1.
    ///
    /// # Errors
    ///
    /// For D1 metadata only:
    /// - `cloudflareD1` is not an object;
    /// - `worker_base_url`, `auth_token_env_var` or `database_binding` is missing or blank;
    /// - `default_session_mode` is neither a string nor null;
    /// - `worker_base_url` is not a valid URL, or uses plain http for a non-local host.
    pub fn from_metadata(metadata: &Value) -> Result<Option<Self>, String> {
        let Some(section) = metadata.get("cloudflareD1") else {
            return Ok(None);
        };

        // Decide on the engine before validating the section's shape, so metadata that
        // belongs to another engine is never rejected because of a stray `cloudflareD1` key.
        let engine = metadata
            .get("dbEngine")
            .and_then(Value::as_str)
            .map(str::trim)
            .unwrap_or(D1_ENGINE_NAME);
        if !engine.eq_ignore_ascii_case(D1_ENGINE_NAME) {
            return Ok(None);
        }

        let Some(cloudflare_d1) = section.as_object() else {
            return Err("cloudflareD1 metadata must be an object".to_string());
        };

        let worker_base_url = required_string(
            cloudflare_d1,
            "worker_base_url",
            "cloudflareD1.worker_base_url",
        )?;
        validate_worker_base_url(&worker_base_url)?;

        let auth_token_env_var = required_string(
            cloudflare_d1,
            "auth_token_env_var",
            "cloudflareD1.auth_token_env_var",
        )?;
        let database_binding = required_string(
            cloudflare_d1,
            "database_binding",
            "cloudflareD1.database_binding",
        )?;
        let default_session_mode = optional_string(
            cloudflare_d1,
            "default_session_mode",
            "cloudflareD1.default_session_mode",
        )?;

        Ok(Some(Self {
            worker_base_url,
            auth_token_env_var,
            database_binding,
            default_session_mode,
        }))
    }

    /// Reads the proxy bearer token from the environment variable named by
    /// `auth_token_env_var`, returning it with surrounding whitespace removed.
    ///
    /// # Errors
    ///
    /// Returns an error when the variable is unset, is not valid UTF-8, or is blank.
    pub fn resolve_auth_token(&self) -> Result<String, String> {
        let value = std::env::var(&self.auth_token_env_var).map_err(|error| match error {
            std::env::VarError::NotPresent => format!(
                "missing Cloudflare D1 proxy auth token env var '{}'",
                self.auth_token_env_var
            ),
            // Previously reported as "missing", which sent operators looking for an unset var.
            std::env::VarError::NotUnicode(_) => format!(
                "Cloudflare D1 proxy auth token env var '{}' is not valid UTF-8",
                self.auth_token_env_var
            ),
        })?;
        let trimmed = value.trim();
        if trimmed.is_empty() {
            return Err(format!(
                "Cloudflare D1 proxy auth token env var '{}' is empty",
                self.auth_token_env_var
            ));
        }
        Ok(trimmed.to_string())
    }

    /// Picks the session mode for a request: the trimmed `requested` value when it is
    /// non-blank, otherwise `default_session_mode`.
    pub fn effective_session_mode(&self, requested: Option<&str>) -> Option<String> {
        requested
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .map(str::to_string)
            .or_else(|| self.default_session_mode.clone())
    }
}

pub async fn execute_query_via_proxy(
    http: &Client,
    info: &D1ConnectionInfo,
    query: &str,
    params: Vec<Value>,
    requested_session_mode: Option<&str>,
    bookmark: Option<&str>,
    retry_writes: bool,
) -> Result<D1ExecutionResult, String> {
    let token = info.resolve_auth_token()?;
    let endpoint = format!("{}/query", info.worker_base_url.trim_end_matches('/'));
    let session_mode = info.effective_session_mode(requested_session_mode);
    let bookmark = bookmark
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(str::to_string);
    let body = json!({
        "query": query,
        "params": params,
        "databaseBinding": info.database_binding,
        "sessionMode": session_mode,
        "bookmark": bookmark,
    });

    let is_write = retry_writes && sql_is_write(query);
    let max_attempts = if is_write { 3 } else { 1 };
    let mut last_error = String::new();

    for attempt in 1..=max_attempts {
        let request = http
            .post(&endpoint)
            .bearer_auth(&token)
            .header("Content-Type", "application/json")
            .json(&body);

        match request.send().await {
            Ok(response) => {
                let status = response.status();
                let response_bookmark = response
                    .headers()
                    .get(HEADER_D1_BOOKMARK)
                    .and_then(|value| value.to_str().ok())
                    .map(str::trim)
                    .filter(|value| !value.is_empty())
                    .map(str::to_string);
                let payload = response.json::<Value>().await.map_err(|error| {
                    format!("Cloudflare D1 proxy returned invalid JSON: {error}")
                })?;

                if status.is_success() {
                    return Ok(parse_execution_result(payload, response_bookmark));
                }

                let retryable = is_write && is_retryable_status(status);
                let message = payload
                    .get("message")
                    .and_then(Value::as_str)
                    .or_else(|| payload.get("error").and_then(Value::as_str))
                    .unwrap_or("Cloudflare D1 proxy request failed");
                last_error = format!(
                    "Cloudflare D1 proxy request failed with status {}: {}",
                    status.as_u16(),
                    message
                );

                if retryable && attempt < max_attempts {
                    sleep(retry_delay(attempt)).await;
                    continue;
                }

                return Err(last_error);
            }
            Err(error) => {
                last_error = format!("Cloudflare D1 proxy request failed: {error}");
                if is_write && attempt < max_attempts {
                    sleep(retry_delay(attempt)).await;
                    continue;
                }
                return Err(last_error);
            }
        }
    }

    Err(last_error)
}

/// Converts a successful proxy JSON payload into a [`D1ExecutionResult`].
///
/// Field mapping:
/// - `rows` (array) → `rows`;
/// - string entries of `columns` → `columns`;
/// - `durationMs` / `count` (unsigned integers) → `duration_ms` / `count`;
/// - `meta` → `meta`, defaulting to `{}`.
///
/// The bookmark comes from `response_bookmark` (the response header) when present, else
/// from the body's `bookmark` field. The body value is trimmed and ignored when blank, the
/// same normalisation applied to the header value.
fn parse_execution_result(
    mut payload: Value,
    response_bookmark: Option<String>,
) -> D1ExecutionResult {
    // `payload` is owned, so move the potentially large row set out instead of deep-cloning it.
    let rows = match payload.get_mut("rows").map(Value::take) {
        Some(Value::Array(rows)) => rows,
        _ => Vec::new(),
    };
    let columns: Vec<String> = payload
        .get("columns")
        .and_then(Value::as_array)
        .map(|values| {
            values
                .iter()
                .filter_map(Value::as_str)
                .map(str::to_string)
                .collect()
        })
        .unwrap_or_default();
    let duration_ms = payload.get("durationMs").and_then(Value::as_u64);
    let count = payload.get("count").and_then(Value::as_u64);
    let bookmark = response_bookmark.or_else(|| {
        payload
            .get("bookmark")
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .map(str::to_string)
    });
    let meta = payload
        .get_mut("meta")
        .map(Value::take)
        .unwrap_or_else(|| json!({}));

    D1ExecutionResult {
        rows,
        columns,
        duration_ms,
        bookmark,
        count,
        meta,
    }
}

/// Fetches `key` from `object` as a trimmed, non-empty string.
///
/// # Errors
///
/// Returns `"{label} is required"` when the key is absent, not a string, or blank.
fn required_string(object: &Map<String, Value>, key: &str, label: &str) -> Result<String, String> {
    match object.get(key).and_then(Value::as_str).map(str::trim) {
        Some(text) if !text.is_empty() => Ok(text.to_owned()),
        _ => Err(format!("{label} is required")),
    }
}

/// Fetches `key` from `object` as an optional trimmed string.
///
/// An absent key, a JSON `null`, or a blank string all yield `Ok(None)`.
///
/// # Errors
///
/// Returns `"{label} must be a string"` when the value is present but of another JSON type.
fn optional_string(
    object: &Map<String, Value>,
    key: &str,
    label: &str,
) -> Result<Option<String>, String> {
    let raw = match object.get(key) {
        Some(Value::String(raw)) => raw,
        Some(Value::Null) | None => return Ok(None),
        Some(_) => return Err(format!("{label} must be a string")),
    };

    let trimmed = raw.trim();
    Ok((!trimmed.is_empty()).then(|| trimmed.to_owned()))
}

/// Checks that `raw` is a URL the proxy client may talk to.
///
/// `https` URLs are always accepted. Plain `http` is accepted only when the host satisfies
/// `is_local_host`. Every other scheme is rejected.
///
/// # Errors
///
/// Returns a message when `raw` does not parse as a URL or fails the scheme/host rule.
fn validate_worker_base_url(raw: &str) -> Result<(), String> {
    let parsed = Url::parse(raw)
        .map_err(|error| format!("invalid cloudflareD1.worker_base_url: {error}"))?;
    let scheme = parsed.scheme();

    let permitted = if scheme.eq_ignore_ascii_case("https") {
        true
    } else if scheme.eq_ignore_ascii_case("http") {
        is_local_host(parsed.host_str().unwrap_or_default())
    } else {
        false
    };

    if permitted {
        Ok(())
    } else {
        Err("cloudflareD1.worker_base_url must use https outside local development".to_string())
    }
}

/// Reports whether `host` names the local machine.
///
/// Accepts `localhost` (case-insensitive) and any loopback IP address, IPv4 (`127.0.0.0/8`)
/// or IPv6 (`::1`). IPv6 literals may be wrapped in brackets, which is the form
/// `Url::host_str` returns (e.g. `[::1]`). Matching only the bare `"::1"` meant
/// `http://[::1]:…` worker URLs were never recognised as local.
fn is_local_host(host: &str) -> bool {
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }

    matches!(host.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
}

/// Reports whether `sql` begins with a statement keyword that modifies the database.
///
/// Leading whitespace and SQL comments (`-- …` up to end of line, `/* … */`) are skipped
/// first. Tagged queries such as `/* job=42 */ INSERT …` are therefore still classified as
/// writes; before, they silently lost write retries.
///
/// The keyword ends at the first character that is not an ASCII letter, so `VACUUM;`
/// counts. It is compared case-insensitively.
fn sql_is_write(sql: &str) -> bool {
    const WRITE_KEYWORDS: [&str; 9] = [
        "INSERT", "UPDATE", "DELETE", "REPLACE", "CREATE", "ALTER", "DROP", "PRAGMA", "VACUUM",
    ];

    let statement = skip_leading_sql_comments(sql);
    let keyword_end = statement
        .find(|c: char| !c.is_ascii_alphabetic())
        .unwrap_or(statement.len());
    let keyword = &statement[..keyword_end];

    WRITE_KEYWORDS
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(keyword))
}

/// Returns `sql` without leading whitespace and leading `--` / `/* */` comments.
/// An unterminated comment consumes the rest of the input.
fn skip_leading_sql_comments(sql: &str) -> &str {
    let mut rest = sql.trim_start();
    loop {
        if let Some(after) = rest.strip_prefix("--") {
            rest = after.split_once('\n').map_or("", |(_, tail)| tail);
        } else if let Some(after) = rest.strip_prefix("/*") {
            rest = after.split_once("*/").map_or("", |(_, tail)| tail);
        } else {
            return rest;
        }
        rest = rest.trim_start();
    }
}

/// Reports whether a failed proxy response status is worth retrying.
///
/// Only throttling (429) and the gateway/server family 500, 502, 503 and 504 qualify.
/// Every other status is treated as final.
fn is_retryable_status(status: StatusCode) -> bool {
    const RETRYABLE: [StatusCode; 5] = [
        StatusCode::TOO_MANY_REQUESTS,
        StatusCode::INTERNAL_SERVER_ERROR,
        StatusCode::BAD_GATEWAY,
        StatusCode::SERVICE_UNAVAILABLE,
        StatusCode::GATEWAY_TIMEOUT,
    ];
    RETRYABLE.contains(&status)
}

/// Backoff before retry number `attempt` (1-based): 100 ms doubled per attempt, capped at 1 s.
///
/// Attempt 0 is treated like attempt 1. The exponent is computed with `checked_shl` and
/// saturating arithmetic. The previous `1u64 << (attempt - 1)` panicked with shift overflow
/// in debug builds for `attempt >= 65`, and its `usize as u32` cast could silently truncate.
fn retry_delay(attempt: usize) -> Duration {
    const BASE_MS: u64 = 100;
    const MAX_MS: u64 = 1_000;

    let exponent = u32::try_from(attempt.saturating_sub(1)).unwrap_or(u32::MAX);
    let factor = 1u64.checked_shl(exponent).unwrap_or(u64::MAX);
    Duration::from_millis(BASE_MS.saturating_mul(factor).min(MAX_MS))
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn parses_cloudflare_d1_connection_info() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "https://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB",
                "default_session_mode": "first-unconstrained"
            }
        });

        let info = D1ConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        assert_eq!(info.database_binding, "DB");
        assert_eq!(
            info.default_session_mode.as_deref(),
            Some("first-unconstrained")
        );
    }

    #[test]
    fn rejects_non_https_remote_worker_base_url() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "http://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB"
            }
        });

        let err = D1ConnectionInfo::from_metadata(&metadata).expect_err("http remote should fail");
        assert!(err.contains("must use https"));
    }

    #[test]
    fn allows_local_http_worker_base_url() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "http://localhost:8787",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB"
            }
        });

        let info = D1ConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        assert_eq!(info.worker_base_url, "http://localhost:8787");
    }

    #[test]
    fn returns_none_without_cloudflare_d1_section() {
        let metadata = json!({ "dbEngine": "postgres" });
        assert_eq!(D1ConnectionInfo::from_metadata(&metadata), Ok(None));
    }

    #[test]
    fn returns_none_for_other_engines() {
        let metadata = json!({
            "dbEngine": "postgres",
            "cloudflareD1": {
                "worker_base_url": "https://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB"
            }
        });
        assert_eq!(D1ConnectionInfo::from_metadata(&metadata), Ok(None));
    }

    #[test]
    fn defaults_to_d1_when_engine_is_absent() {
        let metadata = json!({
            "cloudflareD1": {
                "worker_base_url": "https://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB"
            }
        });

        let info = D1ConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        assert_eq!(info.default_session_mode, None);
    }

    #[test]
    fn reports_missing_required_field() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "https://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING"
            }
        });

        let err = D1ConnectionInfo::from_metadata(&metadata)
            .expect_err("missing database_binding should fail");
        assert_eq!(err, "cloudflareD1.database_binding is required");
    }

    #[test]
    fn rejects_non_string_session_mode() {
        let metadata = json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "https://d1-proxy.example.com",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN_REPORTING",
                "database_binding": "DB",
                "default_session_mode": 5
            }
        });

        let err = D1ConnectionInfo::from_metadata(&metadata)
            .expect_err("numeric session mode should fail");
        assert_eq!(err, "cloudflareD1.default_session_mode must be a string");
    }

    #[test]
    fn effective_session_mode_falls_back_to_default() {
        let info = D1ConnectionInfo {
            worker_base_url: "https://d1-proxy.example.com".to_string(),
            auth_token_env_var: "ATHENA_D1_PROXY_TOKEN_REPORTING".to_string(),
            database_binding: "DB".to_string(),
            default_session_mode: Some("first-unconstrained".to_string()),
        };

        assert_eq!(
            info.effective_session_mode(Some(" first-primary ")).as_deref(),
            Some("first-primary")
        );
        assert_eq!(
            info.effective_session_mode(Some("   ")).as_deref(),
            Some("first-unconstrained")
        );
        assert_eq!(
            info.effective_session_mode(None).as_deref(),
            Some("first-unconstrained")
        );
    }

    #[test]
    fn parse_execution_result_prefers_header_bookmark() {
        let payload = json!({
            "rows": [{ "id": 1 }],
            "columns": ["id", 7],
            "durationMs": 3,
            "count": 1,
            "bookmark": "from-body"
        });

        let result = parse_execution_result(payload.clone(), Some("from-header".to_string()));
        assert_eq!(result.rows, vec![json!({ "id": 1 })]);
        assert_eq!(result.columns, vec!["id".to_string()]);
        assert_eq!(result.duration_ms, Some(3));
        assert_eq!(result.count, Some(1));
        assert_eq!(result.bookmark.as_deref(), Some("from-header"));
        assert_eq!(result.meta, json!({}));

        let result = parse_execution_result(payload, None);
        assert_eq!(result.bookmark.as_deref(), Some("from-body"));
    }

    #[test]
    fn sql_write_detection_matches_expected_verbs() {
        assert!(sql_is_write("INSERT INTO users (id) VALUES (1)"));
        assert!(sql_is_write("  update users set active = 1"));
        assert!(sql_is_write("DELETE FROM users WHERE id = 1"));
        assert!(!sql_is_write("SELECT * FROM users"));
        assert!(!sql_is_write(""));
    }

    #[test]
    fn retry_delay_backs_off_and_caps() {
        assert_eq!(retry_delay(1), Duration::from_millis(100));
        assert_eq!(retry_delay(2), Duration::from_millis(200));
        assert_eq!(retry_delay(3), Duration::from_millis(400));
        assert_eq!(retry_delay(5), Duration::from_millis(1_000));
    }

    #[test]
    fn only_transient_statuses_are_retryable() {
        assert!(is_retryable_status(StatusCode::TOO_MANY_REQUESTS));
        assert!(is_retryable_status(StatusCode::SERVICE_UNAVAILABLE));
        assert!(is_retryable_status(StatusCode::GATEWAY_TIMEOUT));
        assert!(!is_retryable_status(StatusCode::BAD_REQUEST));
        assert!(!is_retryable_status(StatusCode::OK));
    }

    #[test]
    fn local_host_detection() {
        assert!(is_local_host("localhost"));
        assert!(is_local_host("127.0.0.1"));
        assert!(!is_local_host("d1-proxy.example.com"));
    }
}