use serde::{Deserialize, Serialize};
use crate::error::CliError;
use drizzle_migrations::Migrations;
use super::{AppliedMigrationRecord, MigrationResult};
/// JSON body posted to the D1 REST API.
#[derive(Serialize)]
struct Request<'q> {
    /// SQL text to execute.
    sql: &'q str,
    /// Bind parameters for `sql`; the `params` key is left out of the JSON when this is empty.
    #[serde(skip_serializing_if = "<[serde_json::Value]>::is_empty")]
    params: &'q [serde_json::Value],
}
/// Response envelope returned by the D1 REST API.
#[derive(Debug, Deserialize)]
struct Response {
    /// Envelope-level success flag; `D1HttpClient::post` turns `false` into an error.
    success: bool,
    /// Result entries carrying returned rows; empty when the key is absent.
    #[serde(default)]
    result: Vec<ResultEntry>,
    /// Error entries; empty when the key is absent.
    #[serde(default)]
    errors: Vec<ApiError>,
}
/// One element of the envelope's `result` array.
#[derive(Debug, Deserialize)]
struct ResultEntry {
    /// Rows produced for this entry; `None` when the key is missing or `null`.
    #[serde(default)]
    results: Option<Rows>,
}
/// Row payload of a single result entry, accepted in either of two shapes.
///
/// With `#[serde(untagged)]` the variants are tried in declaration order: a JSON array of
/// objects becomes `Objects`; an object with `columns` and `rows` arrays becomes `Values`.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum Rows {
    /// One JSON object per row, keyed by column name.
    Objects(Vec<serde_json::Map<String, serde_json::Value>>),
    /// Column names plus one value array per row, aligned with `columns` by position.
    ///
    /// Both fields are read when rows are converted to objects, so no `dead_code`
    /// allowance is needed here.
    Values {
        columns: Vec<String>,
        rows: Vec<Vec<serde_json::Value>>,
    },
}
/// One `{ code, message }` entry from the envelope's `errors` array.
#[derive(Debug, Deserialize)]
struct ApiError {
    /// Numeric error code reported by the API.
    code: i64,
    /// Human-readable error description.
    message: String,
}
/// Minimal Cloudflare D1 REST client bound to a single database.
pub(super) struct D1HttpClient {
    /// Underlying HTTP client, built once and reused for every request.
    http: reqwest::Client,
    /// Per-database API root; endpoint paths such as `/query` are appended to it.
    base_url: String,
    /// Precomputed `Authorization` header value (`Bearer <token>`).
    auth_header: String,
}
impl D1HttpClient {
    /// Creates a client for one D1 database through the Cloudflare v4 REST API.
    ///
    /// Every request carries a `drizzle-cli/<crate version>` user agent and an
    /// `Authorization: Bearer <token>` header.
    ///
    /// # Errors
    ///
    /// Returns [`CliError::ConnectionError`] if the `reqwest` client cannot be built.
    pub fn new(account_id: &str, database_id: &str, token: &str) -> Result<Self, CliError> {
        let http = reqwest::Client::builder()
            .user_agent(concat!("drizzle-cli/", env!("CARGO_PKG_VERSION")))
            .build()
            .map_err(|e| {
                CliError::ConnectionError(format!("Failed to build reqwest client: {e}"))
            })?;
        Ok(Self {
            http,
            base_url: format!(
                "https://api.cloudflare.com/client/v4/accounts/{account_id}/d1/database/{database_id}"
            ),
            auth_header: format!("Bearer {token}"),
        })
    }

    /// POSTs `sql` (plus `params`, omitted from the JSON when empty) to `base_url + path`
    /// and decodes the response envelope.
    ///
    /// # Errors
    ///
    /// - [`CliError::ConnectionError`] if the request cannot be sent.
    /// - [`CliError::Other`] if the body cannot be read or does not parse as an envelope;
    ///   the HTTP status and the raw body are included to aid debugging.
    /// - [`CliError::MigrationError`] if the envelope reports `success: false`; the message
    ///   lists every `code: message` pair from `errors`, or just the HTTP status if none.
    async fn post(
        &self,
        path: &str,
        sql: &str,
        params: &[serde_json::Value],
    ) -> Result<Response, CliError> {
        let url = format!("{}{}", self.base_url, path);
        let resp = self
            .http
            .post(&url)
            .header(reqwest::header::AUTHORIZATION, &self.auth_header)
            .header(reqwest::header::CONTENT_TYPE, "application/json")
            .json(&Request { sql, params })
            .send()
            .await
            .map_err(|e| CliError::ConnectionError(format!("D1 HTTP request failed: {e}")))?;
        let status = resp.status();
        // Read the body as text first so a parse failure can report what was actually received.
        let text = resp
            .text()
            .await
            .map_err(|e| CliError::Other(format!("D1 response read failed ({status}): {e}")))?;
        let body: Response = serde_json::from_str(&text).map_err(|e| {
            CliError::Other(format!(
                "D1 response parse failed ({status}): {e}\nbody: {text}"
            ))
        })?;
        if !body.success {
            let msg = if body.errors.is_empty() {
                format!("HTTP {status}")
            } else {
                body.errors
                    .iter()
                    .map(|e| format!("{}: {}", e.code, e.message))
                    .collect::<Vec<_>>()
                    .join("\n")
            };
            return Err(CliError::MigrationError(format!("D1 API error: {msg}")));
        }
        Ok(body)
    }

    /// Executes `sql` with `params` via `/query` and returns the rows of every result entry
    /// as JSON objects keyed by column name.
    ///
    /// # Errors
    ///
    /// Same as `post`.
    pub async fn query(
        &self,
        sql: &str,
        params: &[serde_json::Value],
    ) -> Result<Vec<serde_json::Map<String, serde_json::Value>>, CliError> {
        let body = self.post("/query", sql, params).await?;
        Ok(extract_row_objects(body))
    }

    /// Executes `sql` without parameters via `/query`, discarding any returned rows.
    ///
    /// # Errors
    ///
    /// Same as `post`.
    pub async fn run(&self, sql: &str) -> Result<(), CliError> {
        self.post("/query", sql, &[]).await?;
        Ok(())
    }

    /// Sends all `statements` to `/query` as one SQL script in a single request.
    ///
    /// Blank statements are skipped, and no request is made when none remain. See
    /// `join_batch_sql` for how the statements are combined.
    ///
    /// # Errors
    ///
    /// Same as `post`.
    pub async fn batch(&self, statements: &[&str]) -> Result<(), CliError> {
        let script = Self::join_batch_sql(statements);
        if script.is_empty() {
            return Ok(());
        }
        self.post("/query", &script, &[]).await?;
        Ok(())
    }

    /// Combines `statements` into one SQL script.
    ///
    /// Each statement has leading whitespace and trailing whitespace/`;` removed. Empty
    /// results are dropped, and the rest are separated by a `;` on its own line.
    ///
    /// The terminator must sit on its own line. A statement may end in a `-- line comment`.
    /// If the statements were joined with `"; "` on one line, that comment would swallow the
    /// terminator and the next statement, silently skipping it (e.g. a migration's tracking
    /// insert). Stripping the statement's own trailing `;` also avoids emitting empty `;;`
    /// statements.
    fn join_batch_sql(statements: &[&str]) -> String {
        statements
            .iter()
            .map(|s| {
                s.trim_start()
                    .trim_end_matches(|c: char| c == ';' || c.is_whitespace())
            })
            .filter(|s| !s.is_empty())
            .collect::<Vec<_>>()
            .join("\n;\n")
    }
}
/// Flattens every result entry of `body` into a list of row objects keyed by column name.
///
/// Entries without rows contribute nothing. Object-shaped rows are passed through as-is.
/// Columnar rows are converted by pairing each value with the column name at the same
/// position. Extra values or names beyond the shorter list are ignored, and a repeated
/// column name keeps the last value.
fn extract_row_objects(body: Response) -> Vec<serde_json::Map<String, serde_json::Value>> {
    body.result
        .into_iter()
        .filter_map(|entry| entry.results)
        .flat_map(|payload| match payload {
            Rows::Objects(objects) => objects,
            Rows::Values { columns, rows } => rows
                .into_iter()
                .map(|values| {
                    columns
                        .iter()
                        .cloned()
                        .zip(values)
                        .collect::<serde_json::Map<String, serde_json::Value>>()
                })
                .collect::<Vec<_>>(),
        })
        .collect()
}
/// Builds a single-threaded Tokio runtime with `enable_all()` drivers, used to drive the
/// async D1 client from these synchronous entry points.
///
/// # Errors
///
/// Returns [`CliError::Other`] if the runtime cannot be created.
fn rt() -> Result<tokio::runtime::Runtime, CliError> {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_all();
    builder
        .build()
        .map_err(|err| CliError::Other(format!("Failed to create async runtime: {err}")))
}
/// Constructs the D1 HTTP client for the given account, database and API token.
///
/// # Errors
///
/// Propagates any error from `D1HttpClient::new`.
fn client(account_id: &str, database_id: &str, token: &str) -> Result<D1HttpClient, CliError> {
    let built = D1HttpClient::new(account_id, database_id, token);
    built
}
/// Produces a migration plan from `set` and the records in the remote tracking table.
///
/// The tracking-table DDL is executed first, so this call may write to the database. The plan
/// itself comes from `super::build_migration_plan`.
///
/// # Errors
///
/// Propagates runtime, client, query and plan-building errors.
pub(super) fn inspect_migrations(
    set: &Migrations,
    account_id: &str,
    database_id: &str,
    token: &str,
) -> Result<super::MigrationPlan, CliError> {
    let runtime = rt()?;
    runtime.block_on(async {
        let d1 = client(account_id, database_id, token)?;
        ensure_tracking_table(&d1, set).await?;
        let records = query_applied_records(&d1, set).await?;
        super::build_migration_plan(set, &records)
    })
}
/// Applies every migration that `set` reports as pending, in order, stopping at the first
/// failure.
///
/// Each pending migration's non-blank statements are sent together with its record SQL
/// (`set.record_migration_sql`) in a single `D1HttpClient::batch` call.
///
/// Returns the number of migrations applied and their hashes, in application order.
///
/// # Errors
///
/// Runtime, client and query errors are propagated unchanged. A [`CliError::MigrationError`]
/// raised while applying a migration is re-wrapped with that migration's hash.
pub(super) fn run_migrations(
    set: &Migrations,
    account_id: &str,
    database_id: &str,
    token: &str,
) -> Result<MigrationResult, CliError> {
    let runtime = rt()?;
    runtime.block_on(async {
        let d1 = client(account_id, database_id, token)?;
        ensure_tracking_table(&d1, set).await?;
        let recorded = query_applied_names(&d1, set).await?;
        let pending: Vec<_> = set.pending(&recorded).collect();

        // With nothing pending the loop is skipped and a zero-count result is returned.
        let mut applied_migrations = Vec::new();
        for migration in &pending {
            let record_sql = set.record_migration_sql(migration);
            let script: Vec<&str> = migration
                .statements()
                .iter()
                .map(String::as_str)
                .filter(|sql| !sql.trim().is_empty())
                .chain(std::iter::once(&*record_sql))
                .collect();

            if let Err(err) = d1.batch(&script).await {
                return Err(match err {
                    CliError::MigrationError(inner) => CliError::MigrationError(format!(
                        "Migration '{}' failed: {}",
                        migration.hash(),
                        inner
                    )),
                    other => other,
                });
            }
            applied_migrations.push(migration.hash().to_string());
        }

        Ok(MigrationResult {
            applied_count: applied_migrations.len(),
            applied_migrations,
        })
    })
}
/// Executes `statements` against the D1 database as one batched `/query` request.
///
/// # Errors
///
/// Propagates runtime, client and request errors.
pub(super) fn execute_statements(
    account_id: &str,
    database_id: &str,
    token: &str,
    statements: &[String],
) -> Result<(), CliError> {
    let runtime = rt()?;
    runtime.block_on(async {
        let d1 = client(account_id, database_id, token)?;
        let borrowed = statements
            .iter()
            .map(|stmt| stmt.as_str())
            .collect::<Vec<&str>>();
        d1.batch(&borrowed).await
    })
}
/// Records the first migration of `set` via `set.record_migration_sql`, without executing
/// that migration's own statements.
///
/// The tracking table is created if needed. The currently applied names are checked by
/// `super::validate_init_metadata` first. If `set` has no migrations, nothing is recorded.
///
/// # Errors
///
/// Propagates runtime, client, query and validation errors.
pub(super) fn init_metadata(
    set: &Migrations,
    account_id: &str,
    database_id: &str,
    token: &str,
) -> Result<(), CliError> {
    let runtime = rt()?;
    runtime.block_on(async {
        let d1 = client(account_id, database_id, token)?;
        ensure_tracking_table(&d1, set).await?;
        let recorded = query_applied_names(&d1, set).await?;
        super::validate_init_metadata(&recorded, set)?;
        if let Some(first) = set.all().first() {
            d1.run(&set.record_migration_sql(first)).await?;
        }
        Ok(())
    })
}
/// Executes `set.create_table_sql()` so the tracking table exists before it is used.
///
/// NOTE(review): this runs on every command, so the DDL is presumably idempotent
/// (e.g. `CREATE TABLE IF NOT EXISTS`) — confirm in `Migrations::create_table_sql`.
async fn ensure_tracking_table(c: &D1HttpClient, set: &Migrations) -> Result<(), CliError> {
    let ddl = set.create_table_sql();
    c.run(&ddl).await
}
async fn query_applied_names(c: &D1HttpClient, set: &Migrations) -> Result<Vec<String>, CliError> {
let rows = c.query(&set.applied_names_sql(), &[]).await?;
Ok(rows
.into_iter()
.filter_map(|mut row| {
row.remove("name")
.and_then(|v| v.as_str().map(String::from))
})
.collect())
}
/// Reads `(hash, name)` pairs for every tracking-table row with a non-NULL `name`,
/// ordered by `id`.
///
/// Rows where either column is missing or is not a JSON string are skipped.
async fn query_applied_records(
    c: &D1HttpClient,
    set: &Migrations,
) -> Result<Vec<AppliedMigrationRecord>, CliError> {
    let table = set.table_ident_sql();
    let sql = format!(
        r#"SELECT "hash", "name" FROM {table} WHERE "name" IS NOT NULL ORDER BY id;"#
    );
    let records: Vec<AppliedMigrationRecord> = c
        .query(&sql, &[])
        .await?
        .into_iter()
        .filter_map(|mut row| {
            let hash = row.remove("hash")?;
            let name = row.remove("name")?;
            Some(AppliedMigrationRecord {
                hash: hash.as_str()?.to_owned(),
                name: name.as_str()?.to_owned(),
            })
        })
        .collect();
    Ok(records)
}