use actix_web::{
HttpRequest, Responder, post,
web::{Data, Json},
};
use serde_json::json;
use sqlx::{PgPool, Row};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::ffi::OsString;
use tokio::process::Command;
use super::types::{
DdlObjectFamily, ExportedDdlObject, ManagementDdlExportFiltersApplied,
ManagementDdlExportRequest, ManagementDdlExportResponse,
};
use super::{
MANAGEMENT_READ_RIGHT, authorize_management_request, registered_client_for,
required_client_name,
};
use crate::AppState;
use crate::api::client_context::pool_for_client;
use crate::api::response::{api_success, bad_request, internal_error};
use crate::drivers::postgresql::sqlx_driver::RegisteredClient;
use crate::parser::resolve_compatible_postgres_uri;
use crate::utils::pg_tools::ensure_pg_tools;
use crate::utils::request_logging::{log_operation_event, log_request};
/// Validated form of a `ManagementDdlExportRequest`, as produced by `normalize_request`.
#[derive(Debug, Clone)]
struct DdlExportRequestNormalized {
/// Trimmed client name from the request body; verified to equal the caller's bound client.
client_name: String,
/// Normalized schema names, sorted and de-duplicated; `None` when no non-empty name was given.
schemas: Option<Vec<String>>,
/// Set view of `schemas`, used for membership checks in `schema_allowed`.
schema_filter: Option<HashSet<String>>,
/// Object families exactly as listed in the request body (order preserved).
object_families: Vec<DdlObjectFamily>,
/// De-duplicated set built from `object_families`.
requested_families: HashSet<DdlObjectFamily>,
/// When false, system schemas and `pg_`-prefixed roles are left out of the export.
include_system: bool,
/// Roles are exported only when this is true and the roles family was requested.
include_roles: bool,
/// The database definition is exported only when this is true and the databases family was requested.
include_databases: bool,
/// Trimmed, lower-cased export format; `normalize_request` only accepts "sql".
format: String,
}
/// A single pg_dump statement after it has been attributed to an object family and object.
#[derive(Debug, Clone)]
struct ClassifiedStatement {
/// Object family the statement belongs to.
family: DdlObjectFamily,
/// Schema of the owning object; `None` for extensions.
schema_name: Option<String>,
/// Name of the owning object.
object_name: String,
/// Extra identity text; set to the argument signature for functions, `None` otherwise.
identity: Option<String>,
/// Statement text with leading whitespace removed.
ddl: String,
/// Position of the statement within the parsed dump.
original_index: usize,
/// Family ordering weight taken from `family_sort_weight`.
sort_weight: usize,
/// Schema name used for sorting/grouping (mirrors `schema_name`).
sort_schema_name: Option<String>,
/// Object name used for sorting/grouping; differs from `object_name` for
/// `ALTER SEQUENCE ... OWNED BY` and reference-carrying `ALTER TABLE` statements.
sort_object_name: String,
}
/// A `CREATE SCHEMA` statement that is emitted ahead of the exported objects.
#[derive(Debug, Clone)]
struct SchemaPreamble {
/// Unquoted name of the schema being created.
schema_name: String,
/// The statement text, with leading SQL comments stripped.
ddl: String,
}
/// All DDL statements collected for one exported object, plus the keys used to order it.
#[derive(Debug, Clone)]
struct ExportedDdlGroup {
/// Object family of the group.
family: DdlObjectFamily,
/// Schema of the object; `None` for extensions, roles and databases.
schema_name: Option<String>,
/// Name of the object.
object_name: String,
/// Extra identity text (function argument signature), when present.
identity: Option<String>,
/// Statements belonging to the object, in the order they were encountered.
statements: Vec<String>,
/// Index of the first statement of the group in its source (dump position or query row).
original_index: usize,
/// Family ordering weight taken from `family_sort_weight`.
sort_weight: usize,
/// Schema name used when sorting non-dump families.
sort_schema_name: Option<String>,
/// Object name used when sorting non-dump families.
sort_object_name: String,
}
/// Outcome of `run_ddl_export`.
#[derive(Debug, Clone)]
struct DdlExportResult {
/// Complete SQL script, including the generated header comments.
sql: String,
/// Per-object DDL entries in output order.
objects: Vec<ExportedDdlObject>,
/// Non-fatal notes collected while exporting.
warnings: Vec<String>,
}
/// `POST /management/ddl/export`: exports DDL for the caller's bound client.
///
/// The request is validated, the client is checked for registration, the caller is
/// authorized for `MANAGEMENT_READ_RIGHT`, and the export is then run. Every early
/// return hands back the error response produced by the failing helper; only an export
/// failure and the success path emit an operation log event.
#[post("/management/ddl/export")]
pub async fn management_export_ddl(
req: HttpRequest,
body: Json<ManagementDdlExportRequest>,
app_state: Data<AppState>,
) -> impl Responder {
// Started before any work so the logged duration covers validation and authorization too.
let started = std::time::Instant::now();
let caller_client = match required_client_name(&req) {
Ok(value) => value,
Err(resp) => return resp,
};
// Rejects bodies whose client_name differs from the caller's bound client.
let normalized = match normalize_request(&body.0, &caller_client) {
Ok(value) => value,
Err(resp) => return resp,
};
// Registration is checked once before authorization; the client itself is fetched again below.
if let Err(resp) = registered_client_for(app_state.get_ref(), &normalized.client_name) {
return resp;
}
let auth = match authorize_management_request(
&req,
app_state.get_ref(),
&normalized.client_name,
vec![MANAGEMENT_READ_RIGHT.to_string()],
)
.await
{
Ok(auth) => auth,
Err(resp) => return resp,
};
// The request is logged only after authorization, using the request id and log context it produced.
let logged_request = log_request(
req.clone(),
Some(app_state.get_ref()),
Some(auth.request_id.clone()),
Some(&auth.log_context),
);
let client = match registered_client_for(app_state.get_ref(), &normalized.client_name) {
Ok(value) => value,
Err(resp) => return resp,
};
let pool = match pool_for_client(app_state.get_ref(), &normalized.client_name).await {
Ok(value) => value,
Err(resp) => return resp,
};
let export = match run_ddl_export(&normalized, &client, &pool).await {
Ok(value) => value,
Err(err) => {
// Export failures are recorded as an operation event before the 500 response is returned.
log_operation_event(
Some(app_state.get_ref()),
&logged_request,
"management_export_ddl",
None,
started.elapsed().as_millis(),
actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
Some(json!({ "error": err })),
);
return internal_error("Failed to export DDL", err);
}
};
// Echo the normalized filters back so the caller can see what was actually applied.
let filters_applied = ManagementDdlExportFiltersApplied {
client_name: normalized.client_name.clone(),
schemas: normalized.schemas.clone(),
object_families: normalized.object_families.clone(),
include_system: normalized.include_system,
include_roles: normalized.include_roles,
include_databases: normalized.include_databases,
format: normalized.format.clone(),
};
let response = ManagementDdlExportResponse {
sql: export.sql,
objects: export.objects,
warnings: export.warnings,
filters_applied,
};
log_operation_event(
Some(app_state.get_ref()),
&logged_request,
"management_export_ddl",
None,
started.elapsed().as_millis(),
actix_web::http::StatusCode::OK,
Some(json!({
"object_count": response.objects.len(),
"warning_count": response.warnings.len(),
})),
);
api_success("Exported DDL", json!(response))
}
/// Validates the raw export request and converts it into its normalized form.
///
/// Checks, in order: the body `client_name` is non-empty and equals `caller_client`,
/// the format is `sql` (case-insensitive, surrounding whitespace ignored), and at
/// least one object family was selected. Schema names are normalized, blank entries
/// dropped, and the remainder sorted and de-duplicated.
///
/// # Errors
/// Returns a bad-request response describing the first failed check.
fn normalize_request(
    body: &ManagementDdlExportRequest,
    caller_client: &str,
) -> Result<DdlExportRequestNormalized, actix_web::HttpResponse> {
    let client_name = body.client_name.trim();
    if client_name.is_empty() {
        return Err(bad_request(
            "Invalid DDL export request",
            "client_name is required",
        ));
    }
    if client_name != caller_client {
        let detail = format!(
            "Body client_name '{}' must match the bound client header '{}'.",
            client_name, caller_client
        );
        return Err(bad_request("Invalid DDL export request", detail));
    }

    let format = body.format.trim().to_ascii_lowercase();
    if format != "sql" {
        // The message quotes the format exactly as it was sent, not the normalized value.
        let detail = format!(
            "Unsupported export format '{}'. Only 'sql' is supported.",
            body.format
        );
        return Err(bad_request("Invalid DDL export request", detail));
    }

    let mut requested_families: HashSet<DdlObjectFamily> = HashSet::new();
    requested_families.extend(body.object_families.iter().copied());
    if requested_families.is_empty() {
        return Err(bad_request(
            "Invalid DDL export request",
            "At least one object family must be selected.",
        ));
    }

    // An ordered set yields the same result as sorting and de-duplicating a vector.
    let unique_schemas: std::collections::BTreeSet<String> = body
        .schemas
        .iter()
        .flatten()
        .map(|raw| normalize_schema_name(raw))
        .filter(|name| !name.is_empty())
        .collect();
    let schemas: Option<Vec<String>> = if unique_schemas.is_empty() {
        None
    } else {
        Some(unique_schemas.into_iter().collect())
    };
    let schema_filter: Option<HashSet<String>> = schemas
        .as_ref()
        .map(|names| names.iter().cloned().collect());

    Ok(DdlExportRequestNormalized {
        client_name: client_name.to_string(),
        schemas,
        schema_filter,
        object_families: body.object_families.clone(),
        requested_families,
        include_system: body.include_system,
        include_roles: body.include_roles,
        include_databases: body.include_databases,
        format,
    })
}
/// Runs the full export: the schema dump (when any dump-backed family is requested),
/// then roles, then the current database, and finally assembles the SQL script.
///
/// Roles and databases are only exported when both the family is requested and the
/// matching `include_*` flag is set; otherwise a warning explains the omission.
///
/// # Errors
/// Propagates the error string of any failing step.
async fn run_ddl_export(
    request: &DdlExportRequestNormalized,
    client: &RegisteredClient,
    pool: &PgPool,
) -> Result<DdlExportResult, String> {
    let mut warnings: Vec<String> = Vec::new();

    let (schema_preambles, mut groups): (Vec<SchemaPreamble>, Vec<ExportedDdlGroup>) =
        if needs_schema_dump(request) {
            let dump_sql = run_pg_dump_schema_only(request, client).await?;
            let (preambles, dumped, dump_warnings) = classify_dump_sql(&dump_sql, request);
            warnings.extend(dump_warnings);
            (preambles, dumped)
        } else {
            (Vec::new(), Vec::new())
        };

    let roles_requested = request.requested_families.contains(&DdlObjectFamily::Roles);
    if roles_requested && !request.include_roles {
        warnings.push(
            "Roles were requested but omitted because include_roles is disabled.".to_string(),
        );
    } else if roles_requested {
        let role_groups = export_roles(pool, request).await?;
        if role_groups.is_empty() {
            warnings.push("No database roles matched the requested export filters.".to_string());
        } else {
            groups.extend(role_groups);
            warnings.push(
                "Role export omits passwords, memberships, and object grants in this v1 DDL export."
                    .to_string(),
            );
        }
    }

    let databases_requested = request
        .requested_families
        .contains(&DdlObjectFamily::Databases);
    if databases_requested && !request.include_databases {
        warnings.push(
            "Databases were requested but omitted because include_databases is disabled."
                .to_string(),
        );
    } else if databases_requested {
        let database_groups = export_current_database(pool).await?;
        if database_groups.is_empty() {
            warnings.push(
                "No database definition was available for the connected client.".to_string(),
            );
        } else {
            groups.extend(database_groups);
            warnings.push(
                "Database export emits the connected database definition only and omits ownership, tablespace, and database-level grants."
                    .to_string(),
            );
        }
    }

    groups.sort_by(compare_exported_group_order);

    let generated_at = chrono::Utc::now().to_rfc3339();

    let mut objects: Vec<ExportedDdlObject> = Vec::with_capacity(groups.len());
    for (position, group) in groups.iter().enumerate() {
        // The order key embeds the final position so it sorts the same way as the output.
        let order_key = format!(
            "{:02}:{:04}:{}:{}:{}",
            group.sort_weight,
            position,
            group.schema_name.as_deref().unwrap_or_default(),
            group.object_name,
            group.identity.as_deref().unwrap_or_default()
        );
        objects.push(ExportedDdlObject {
            family: group.family,
            schema_name: group.schema_name.clone(),
            object_name: group.object_name.clone(),
            identity: group.identity.clone(),
            ddl: group.statements.join("\n\n"),
            order_key,
        });
    }

    let sql = build_output_sql(&generated_at, request, &schema_preambles, &objects);
    Ok(DdlExportResult {
        sql,
        objects,
        warnings,
    })
}
/// Returns true when at least one requested family is produced from the pg_dump output
/// (tables, views, materialized views, functions, types, extensions or sequences).
fn needs_schema_dump(request: &DdlExportRequestNormalized) -> bool {
    for family in &request.requested_families {
        match family {
            DdlObjectFamily::Tables
            | DdlObjectFamily::Views
            | DdlObjectFamily::MaterializedViews
            | DdlObjectFamily::Functions
            | DdlObjectFamily::Types
            | DdlObjectFamily::Extensions
            | DdlObjectFamily::Sequences => return true,
            _ => {}
        }
    }
    false
}
/// Invokes `pg_dump` for the client's database and returns its stdout as UTF-8 text.
///
/// The password, when present in the connection URI, is removed from the URI and
/// handed to the child process through the `PGPASSWORD` environment variable.
///
/// # Errors
/// Returns a message when the pg tools cannot be resolved, no connection URI is
/// configured, the process cannot be started, it exits unsuccessfully, or its
/// output is not valid UTF-8.
async fn run_pg_dump_schema_only(
    request: &DdlExportRequestNormalized,
    client: &RegisteredClient,
) -> Result<String, String> {
    let pg_tools = match ensure_pg_tools().await {
        Ok(tools) => tools,
        Err(err) => return Err(format!("pg_dump resolution failed: {err}")),
    };

    let Some(connection_uri) = resolve_registered_client_connection_uri(client) else {
        return Err(format!(
            "No usable PostgreSQL connection URI is configured for client '{}'.",
            request.client_name
        ));
    };
    let (sanitized_uri, pg_password) = extract_pg_password(&connection_uri);

    let mut command = Command::new(pg_tools.pg_dump);
    if let Some(password) = pg_password {
        command.env("PGPASSWORD", password);
    }
    command
        .args(pg_dump_export_cli_args(&sanitized_uri, request))
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped());

    let output = command
        .output()
        .await
        .map_err(|err| format!("Failed to run pg_dump: {err}"))?;

    if output.status.success() {
        return String::from_utf8(output.stdout)
            .map_err(|err| format!("pg_dump returned non-UTF8 schema output: {err}"));
    }

    // Prefer stderr for the failure detail, then stdout, then a fixed placeholder.
    let captured = [
        String::from_utf8_lossy(&output.stderr).trim().to_string(),
        String::from_utf8_lossy(&output.stdout).trim().to_string(),
    ];
    let detail = captured
        .into_iter()
        .find(|text| !text.is_empty())
        .unwrap_or_else(|| "No output from pg_dump.".to_string());
    let code = output.status.code().unwrap_or(-1);
    Err(format!("pg_dump exited with code {code}: {detail}"))
}
/// Builds the `pg_dump` argument list: a schema-only dump without owner, privilege
/// or comment statements, plus one `--schema <name>` pair per requested schema.
fn pg_dump_export_cli_args(
    pg_uri_safe: &str,
    request: &DdlExportRequestNormalized,
) -> Vec<OsString> {
    let fixed = [
        "--dbname",
        pg_uri_safe,
        "--schema-only",
        "--no-owner",
        "--no-privileges",
        "--no-comments",
    ];
    let mut args: Vec<OsString> = fixed.into_iter().map(OsString::from).collect();
    for schema in request.schemas.iter().flatten() {
        args.extend([OsString::from("--schema"), OsString::from(schema)]);
    }
    args
}
fn classify_dump_sql(
dump_sql: &str,
request: &DdlExportRequestNormalized,
) -> (Vec<SchemaPreamble>, Vec<ExportedDdlGroup>, Vec<String>) {
let mut preambles: Vec<SchemaPreamble> = Vec::new();
let mut warnings: Vec<String> = Vec::new();
let mut groups: HashMap<String, ExportedDdlGroup> = HashMap::new();
let mut order: Vec<String> = Vec::new();
for statement in split_pg_dump_sql_statements(dump_sql) {
let trimmed = statement.sql.trim();
if trimmed.is_empty() {
continue;
}
let classified_sql = strip_leading_sql_comments(trimmed);
if classified_sql.is_empty() {
continue;
}
if let Some(schema_name) = classify_schema_preamble(classified_sql) {
if schema_allowed(Some(&schema_name), request) {
preambles.push(SchemaPreamble {
schema_name,
ddl: classified_sql.to_string(),
});
}
continue;
}
let Some(classified) =
classify_schema_statement(classified_sql, statement.index, &mut warnings)
else {
continue;
};
if !request.requested_families.contains(&classified.family) {
continue;
}
if !schema_allowed(classified.schema_name.as_deref(), request) {
continue;
}
let group_key = format!(
"{:?}|{}|{}|{}|{}|{}",
classified.family,
classified.schema_name.clone().unwrap_or_default(),
classified.object_name,
classified.identity.clone().unwrap_or_default(),
classified.sort_schema_name.clone().unwrap_or_default(),
classified.sort_object_name
);
if let Some(group) = groups.get_mut(&group_key) {
group.statements.push(classified.ddl);
} else {
order.push(group_key.clone());
groups.insert(
group_key,
ExportedDdlGroup {
family: classified.family,
schema_name: classified.schema_name,
object_name: classified.object_name,
identity: classified.identity,
statements: vec![classified.ddl],
original_index: classified.original_index,
sort_weight: classified.sort_weight,
sort_schema_name: classified.sort_schema_name,
sort_object_name: classified.sort_object_name,
},
);
}
}
let ordered_groups = order
.into_iter()
.filter_map(|key| groups.remove(&key))
.collect::<Vec<_>>();
preambles.sort_by(|left, right| left.schema_name.cmp(&right.schema_name));
preambles.dedup_by(|left, right| left.schema_name == right.schema_name);
(preambles, ordered_groups, warnings)
}
/// Removes leading whitespace and any run of leading `--` line comments and
/// `/* ... */` block comments, returning the remaining statement text.
///
/// Returns an empty string when the input ends inside a comment (a line comment
/// without a newline, or an unterminated block comment).
fn strip_leading_sql_comments(statement: &str) -> &str {
    let mut remaining = statement.trim_start();
    loop {
        let after_comment = if let Some(body) = remaining.strip_prefix("--") {
            body.find('\n').map(|pos| &body[pos + 1..])
        } else if let Some(body) = remaining.strip_prefix("/*") {
            body.find("*/").map(|pos| &body[pos + 2..])
        } else {
            // Not positioned on a comment: this is the start of the statement proper.
            return remaining;
        };
        match after_comment {
            Some(tail) => remaining = tail.trim_start(),
            None => return "",
        }
    }
}
/// One statement produced by `split_pg_dump_sql_statements`.
#[derive(Debug, Clone)]
struct ParsedStatement {
/// Ordinal of the statement among the non-empty statements of the input.
index: usize,
/// Raw statement text, including its terminating `;` when one was present.
sql: String,
}
/// Splits a SQL script into statements at top-level `;` characters.
///
/// Semicolons inside single-quoted strings, double-quoted identifiers, `--` line
/// comments, `/* */` block comments and dollar-quoted bodies do not end a statement.
/// Each returned statement keeps everything since the previous split point, so
/// comments that precede a statement are part of its text. Whitespace-only pieces
/// are dropped, and text after the last `;` becomes a final statement.
fn split_pg_dump_sql_statements(input: &str) -> Vec<ParsedStatement> {
// Work on chars so every index below is a character position, not a byte offset.
let chars: Vec<char> = input.chars().collect();
let mut statements: Vec<ParsedStatement> = Vec::new();
let mut statement_start = 0usize;
let mut in_single_quote = false;
let mut in_double_quote = false;
let mut in_line_comment = false;
let mut in_block_comment = false;
// Opening tag (including both `$`) of the dollar-quoted body currently being skipped.
let mut dollar_tag: Option<String> = None;
let mut idx = 0usize;
while idx < chars.len() {
let ch = chars[idx];
let next = chars.get(idx + 1).copied();
// A line comment runs until the next newline.
if in_line_comment {
if ch == '\n' {
in_line_comment = false;
}
idx += 1;
continue;
}
// A block comment ends at the first `*/`; nesting is not tracked.
if in_block_comment {
if ch == '*' && next == Some('/') {
in_block_comment = false;
idx += 2;
} else {
idx += 1;
}
continue;
}
// Inside a dollar-quoted body only the identical closing tag is significant.
if let Some(tag) = &dollar_tag {
if ch == '$' {
let candidate = read_dollar_tag(&chars, idx);
if candidate.as_deref() == Some(tag.as_str()) {
// Tags contain only ASCII characters, so the byte length equals the char count.
let tag_len = tag.len();
dollar_tag = None;
idx += tag_len;
continue;
}
}
idx += 1;
continue;
}
// Comments and dollar quotes can only start outside quoted text.
if !in_single_quote && !in_double_quote {
if ch == '-' && next == Some('-') {
in_line_comment = true;
idx += 2;
continue;
}
if ch == '/' && next == Some('*') {
in_block_comment = true;
idx += 2;
continue;
}
if ch == '$'
&& let Some(tag) = read_dollar_tag(&chars, idx)
{
dollar_tag = Some(tag.clone());
idx += tag.len();
continue;
}
}
if ch == '\'' && !in_double_quote {
// A doubled quote inside a string is an escaped quote, not a terminator.
if in_single_quote && next == Some('\'') {
idx += 2;
continue;
}
in_single_quote = !in_single_quote;
idx += 1;
continue;
}
if ch == '"' && !in_single_quote {
// A doubled quote inside a quoted identifier is an escaped quote.
if in_double_quote && next == Some('"') {
idx += 2;
continue;
}
in_double_quote = !in_double_quote;
idx += 1;
continue;
}
// A semicolon outside quotes ends the statement; the `;` stays in the statement text.
if ch == ';' && !in_single_quote && !in_double_quote {
let sql: String = chars[statement_start..=idx].iter().collect();
if !sql.trim().is_empty() {
statements.push(ParsedStatement {
index: statements.len(),
sql,
});
}
statement_start = idx + 1;
}
idx += 1;
}
// Anything left after the last semicolon is emitted as a final statement.
if statement_start < chars.len() {
let sql: String = chars[statement_start..].iter().collect();
if !sql.trim().is_empty() {
statements.push(ParsedStatement {
index: statements.len(),
sql,
});
}
}
statements
}
/// Reads a dollar-quote delimiter (`$$` or `$tag$`) that starts at `chars[start]`.
///
/// Returns the delimiter including both `$` characters, or `None` when `start` is
/// out of range, does not point at `$`, or the text that follows is not a valid
/// delimiter. A tag may contain ASCII letters, digits and underscores, but must not
/// begin with a digit: PostgreSQL reads `$1` as a positional parameter, so `$1$`
/// is not a dollar-quote delimiter.
fn read_dollar_tag(chars: &[char], start: usize) -> Option<String> {
    if chars.get(start).copied() != Some('$') {
        return None;
    }
    let mut tag = String::from("$");
    // `start` is in range (checked above), so `start + 1..` is a valid, possibly empty, slice.
    for (offset, &ch) in chars[start + 1..].iter().enumerate() {
        match ch {
            '$' => {
                tag.push('$');
                return Some(tag);
            }
            // Leading digit: positional parameter syntax, not a tag.
            '0'..='9' if offset == 0 => return None,
            '_' | 'a'..='z' | 'A'..='Z' | '0'..='9' => tag.push(ch),
            _ => return None,
        }
    }
    // Ran out of input before the closing `$`.
    None
}
/// Recognizes a `CREATE SCHEMA` statement and returns the unquoted schema name.
///
/// The keyword match is case-insensitive; an upper-case `IF NOT EXISTS ` clause is
/// skipped. Returns `None` for any other statement or when no identifier can be read.
fn classify_schema_preamble(statement: &str) -> Option<String> {
    let compact = statement.trim_start();
    let lower = compact.to_ascii_lowercase();
    if !lower.starts_with("create schema") {
        return None;
    }

    // ASCII lower-casing keeps byte offsets, so an offset found in `lower` is valid in `compact`.
    let keyword_end = lower.find("schema").unwrap_or(0) + "schema".len();
    let mut rest = compact[keyword_end..].trim();
    if let Some(stripped) = rest.strip_prefix("IF NOT EXISTS ") {
        rest = stripped;
    }

    let (qualified, _) = read_qualified_identifier(rest)?;
    let parts = split_qualified_identifier_parts(&qualified);
    parts.last().map(|value| unquote_identifier(value))
}
fn classify_schema_statement(
statement: &str,
original_index: usize,
warnings: &mut Vec<String>,
) -> Option<ClassifiedStatement> {
let compact = statement.trim_start();
let lower = compact.to_ascii_lowercase();
if lower.starts_with("create extension") {
let rest = compact[lower.find("extension").unwrap_or(0) + "extension".len()..].trim();
let rest = rest.strip_prefix("IF NOT EXISTS ").unwrap_or(rest);
let (name, _) = read_qualified_identifier(rest)?;
return Some(ClassifiedStatement {
family: DdlObjectFamily::Extensions,
schema_name: None,
object_name: unquote_identifier(&name),
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Extensions),
sort_schema_name: None,
sort_object_name: unquote_identifier(&name),
});
}
if lower.starts_with("create type") || lower.starts_with("create domain") {
let keyword = if lower.starts_with("create type") {
"type"
} else {
"domain"
};
let rest = compact[lower.find(keyword).unwrap_or(0) + keyword.len()..].trim();
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::Types,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Types),
});
}
if lower.starts_with("alter type") || lower.starts_with("alter domain") {
let keyword = if lower.starts_with("alter type") {
"type"
} else {
"domain"
};
let rest = compact[lower.find(keyword).unwrap_or(0) + keyword.len()..].trim();
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::Types,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Types),
});
}
if lower.starts_with("create sequence") {
let rest = compact[lower.find("sequence").unwrap_or(0) + "sequence".len()..].trim();
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::Sequences,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Sequences),
});
}
if lower.starts_with("alter sequence") {
let rest = compact[lower.find("sequence").unwrap_or(0) + "sequence".len()..].trim();
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
let is_owned_by_statement = extract_owned_by_sort_target(compact).is_some();
return Some(ClassifiedStatement {
family: DdlObjectFamily::Sequences,
sort_schema_name: schema_name.clone(),
sort_object_name: if is_owned_by_statement {
grouped_statement_sort_object_name(&object_name, "owned_by", original_index)
} else {
object_name.clone()
},
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Sequences),
});
}
if lower.starts_with("create table") {
let rest = compact[lower.find("table").unwrap_or(0) + "table".len()..].trim();
let rest = rest.strip_prefix("IF NOT EXISTS ").unwrap_or(rest);
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::Tables,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Tables),
});
}
if lower.starts_with("alter table") {
let rest = compact[lower.find("table").unwrap_or(0) + "table".len()..].trim();
let rest = rest.strip_prefix("ONLY ").unwrap_or(rest);
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
let is_reference_alter = contains_references_keyword(compact);
return Some(ClassifiedStatement {
family: DdlObjectFamily::Tables,
sort_schema_name: schema_name.clone(),
sort_object_name: if is_reference_alter {
grouped_statement_sort_object_name(&object_name, "references", original_index)
} else {
object_name.clone()
},
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Tables),
});
}
if lower.starts_with("create index")
|| lower.starts_with("create unique index")
|| lower.starts_with("create index concurrently")
|| lower.starts_with("create unique index concurrently")
{
let Some(qualified) = extract_on_target_qualified_name(compact) else {
warnings.push(format!(
"Skipped CREATE INDEX statement with unrecognized target: {compact}"
));
return None;
};
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::Tables,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Tables),
});
}
if lower.starts_with("create materialized view") {
let rest = compact[lower.find("view").unwrap_or(0) + "view".len()..].trim();
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::MaterializedViews,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::MaterializedViews),
});
}
if lower.starts_with("alter materialized view") {
let rest = compact[lower.find("view").unwrap_or(0) + "view".len()..].trim();
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::MaterializedViews,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::MaterializedViews),
});
}
if lower.starts_with("create view") {
let rest = compact[lower.find("view").unwrap_or(0) + "view".len()..].trim();
let (qualified, _) = read_qualified_identifier(rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::Views,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: None,
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Views),
});
}
if lower.starts_with("create function") || lower.starts_with("create or replace function") {
let function_rest =
compact[lower.find("function").unwrap_or(0) + "function".len()..].trim();
let (qualified, args, _) = parse_function_signature(function_rest)?;
let (schema_name, object_name) = split_qualified_identifier(&qualified);
return Some(ClassifiedStatement {
family: DdlObjectFamily::Functions,
sort_schema_name: schema_name.clone(),
sort_object_name: object_name.clone(),
schema_name,
object_name,
identity: Some(args),
ddl: compact.to_string(),
original_index,
sort_weight: family_sort_weight(DdlObjectFamily::Functions),
});
}
None
}
async fn export_roles(
pool: &PgPool,
request: &DdlExportRequestNormalized,
) -> Result<Vec<ExportedDdlGroup>, String> {
let rows = sqlx::query(
r#"
SELECT
rolname,
rolsuper,
rolinherit,
rolcreaterole,
rolcreatedb,
rolcanlogin,
rolreplication,
rolbypassrls,
rolconnlimit
FROM pg_roles
WHERE ($1::bool OR rolname NOT LIKE 'pg\_%' ESCAPE '\')
ORDER BY rolname ASC
"#,
)
.bind(request.include_system)
.fetch_all(pool)
.await
.map_err(|err| format!("Failed to query database roles: {err}"))?;
Ok(rows
.into_iter()
.enumerate()
.map(|(index, row)| {
let role_name = row.try_get::<String, _>("rolname").unwrap_or_default();
let ddl = format!(
"CREATE ROLE {} WITH {} {} {} {} {} {} {} CONNECTION LIMIT {};",
quote_identifier(&role_name),
if row.try_get::<bool, _>("rolsuper").unwrap_or(false) {
"SUPERUSER"
} else {
"NOSUPERUSER"
},
if row.try_get::<bool, _>("rolinherit").unwrap_or(false) {
"INHERIT"
} else {
"NOINHERIT"
},
if row.try_get::<bool, _>("rolcreaterole").unwrap_or(false) {
"CREATEROLE"
} else {
"NOCREATEROLE"
},
if row.try_get::<bool, _>("rolcreatedb").unwrap_or(false) {
"CREATEDB"
} else {
"NOCREATEDB"
},
if row.try_get::<bool, _>("rolcanlogin").unwrap_or(false) {
"LOGIN"
} else {
"NOLOGIN"
},
if row.try_get::<bool, _>("rolreplication").unwrap_or(false) {
"REPLICATION"
} else {
"NOREPLICATION"
},
if row.try_get::<bool, _>("rolbypassrls").unwrap_or(false) {
"BYPASSRLS"
} else {
"NOBYPASSRLS"
},
row.try_get::<i32, _>("rolconnlimit").unwrap_or(-1),
);
ExportedDdlGroup {
family: DdlObjectFamily::Roles,
schema_name: None,
object_name: role_name.clone(),
identity: None,
statements: vec![ddl],
original_index: index,
sort_weight: family_sort_weight(DdlObjectFamily::Roles),
sort_schema_name: None,
sort_object_name: role_name,
}
})
.collect())
}
/// Builds a `CREATE DATABASE` group for the database the pool is connected to.
///
/// The DDL carries the encoding, `LC_COLLATE` and `LC_CTYPE` read from `pg_database`.
///
/// # Errors
/// Returns a message when the query fails or when the database name, collation or
/// ctype cannot be decoded. These were previously replaced by empty strings, which
/// produced a statement with an empty database name or empty locale settings. The
/// encoding keeps its explicit `UTF8` fallback.
async fn export_current_database(pool: &PgPool) -> Result<Vec<ExportedDdlGroup>, String> {
    let rows = sqlx::query(
        r#"
SELECT
datname,
pg_encoding_to_char(encoding) AS encoding,
datcollate,
datctype
FROM pg_database
WHERE datname = current_database()
"#,
    )
    .fetch_all(pool)
    .await
    .map_err(|err| format!("Failed to query current database metadata: {err}"))?;

    rows.into_iter()
        .enumerate()
        .map(|(index, row)| -> Result<ExportedDdlGroup, String> {
            let text = |column: &str| -> Result<String, String> {
                row.try_get::<String, _>(column)
                    .map_err(|err| format!("Failed to read pg_database.{column}: {err}"))
            };

            let datname = text("datname")?;
            // Deliberate best-effort: an unreadable encoding falls back to UTF8.
            let encoding = row
                .try_get::<String, _>("encoding")
                .unwrap_or_else(|_| "UTF8".to_string());
            let datcollate = text("datcollate")?;
            let datctype = text("datctype")?;

            let ddl = format!(
                "CREATE DATABASE {} WITH ENCODING = '{}' LC_COLLATE = '{}' LC_CTYPE = '{}';",
                quote_identifier(&datname),
                escape_sql_literal(&encoding),
                escape_sql_literal(&datcollate),
                escape_sql_literal(&datctype),
            );

            Ok(ExportedDdlGroup {
                family: DdlObjectFamily::Databases,
                schema_name: None,
                object_name: datname.clone(),
                identity: None,
                statements: vec![ddl],
                original_index: index,
                sort_weight: family_sort_weight(DdlObjectFamily::Databases),
                sort_schema_name: None,
                sort_object_name: datname,
            })
        })
        .collect()
}
/// Assembles the final SQL script: a comment header describing the export, the
/// schema preambles (if any), then every exported object preceded by a label comment.
///
/// Text interpolated into `--` comments (client name, schema list, object labels)
/// has its line breaks replaced by spaces. Schema names come from the request and
/// object or role names from the database, and either may contain a newline; left
/// as-is, the text after the newline would land outside the comment and corrupt
/// the script. Output is unchanged for text without line breaks.
///
/// The result always ends with exactly one trailing newline.
fn build_output_sql(
    generated_at: &str,
    request: &DdlExportRequestNormalized,
    preambles: &[SchemaPreamble],
    objects: &[ExportedDdlObject],
) -> String {
    // Keeps interpolated text on a single line so it cannot end a `--` comment early.
    fn comment_safe(text: &str) -> String {
        text.replace(['\r', '\n'], " ")
    }

    let selected_schemas = match &request.schemas {
        Some(schemas) => schemas.join(", "),
        None => "all".to_string(),
    };
    let selected_families = request
        .object_families
        .iter()
        .map(|family| family_label(*family))
        .collect::<Vec<_>>()
        .join(", ");

    let mut out = vec![
        "-- Athena generated DDL".to_string(),
        format!("-- Client: {}", comment_safe(&request.client_name)),
        format!("-- Generated at: {generated_at}"),
        format!("-- Schemas: {}", comment_safe(&selected_schemas)),
        format!("-- Families: {selected_families}"),
        format!(
            "-- Roles included: {} | Databases included: {}",
            if request.include_roles { "yes" } else { "no" },
            if request.include_databases {
                "yes"
            } else {
                "no"
            }
        ),
        String::new(),
    ];

    if !preambles.is_empty() {
        out.push("-- Schema preamble".to_string());
        out.push(String::new());
        for preamble in preambles {
            out.push(preamble.ddl.trim().to_string());
            out.push(String::new());
        }
    }

    for object in objects {
        let label = comment_safe(&exported_object_label(object));
        out.push(format!("-- {label}"));
        out.push(object.ddl.trim().to_string());
        out.push(String::new());
    }

    out.join("\n").trim_end().to_string() + "\n"
}
fn family_sort_weight(family: DdlObjectFamily) -> usize {
match family {
DdlObjectFamily::Extensions => 0,
DdlObjectFamily::Types => 1,
DdlObjectFamily::Sequences => 2,
DdlObjectFamily::Tables => 3,
DdlObjectFamily::Views => 4,
DdlObjectFamily::MaterializedViews => 5,
DdlObjectFamily::Functions => 6,
DdlObjectFamily::Roles => 7,
DdlObjectFamily::Databases => 8,
}
}
fn family_label(family: DdlObjectFamily) -> &'static str {
match family {
DdlObjectFamily::Tables => "tables",
DdlObjectFamily::Views => "views",
DdlObjectFamily::MaterializedViews => "materialized_views",
DdlObjectFamily::Functions => "functions",
DdlObjectFamily::Types => "types",
DdlObjectFamily::Roles => "roles",
DdlObjectFamily::Databases => "databases",
DdlObjectFamily::Extensions => "extensions",
DdlObjectFamily::Sequences => "sequences",
}
}
fn exported_object_label(object: &ExportedDdlObject) -> String {
let family = match object.family {
DdlObjectFamily::Tables => "Table",
DdlObjectFamily::Views => "View",
DdlObjectFamily::MaterializedViews => "Materialized View",
DdlObjectFamily::Functions => "Function",
DdlObjectFamily::Types => "Type",
DdlObjectFamily::Roles => "Role",
DdlObjectFamily::Databases => "Database",
DdlObjectFamily::Extensions => "Extension",
DdlObjectFamily::Sequences => "Sequence",
};
let qualified = if let Some(schema_name) = &object.schema_name {
format!("{schema_name}.{}", object.object_name)
} else {
object.object_name.clone()
};
if let Some(identity) = &object.identity {
format!("{family}: {qualified}({identity})")
} else {
format!("{family}: {qualified}")
}
}
/// Picks the connection URI source for a client and passes it through
/// `resolve_compatible_postgres_uri`.
///
/// Sources are tried in order, skipping blank values: the explicit `pg_uri`, then
/// the environment variable name (wrapped as `${NAME}`), then the config URI template.
fn resolve_registered_client_connection_uri(client: &RegisteredClient) -> Option<String> {
    let explicit_uri = client
        .pg_uri
        .as_ref()
        .filter(|value| !value.trim().is_empty());
    let env_var = client
        .pg_uri_env_var
        .as_ref()
        .filter(|value| !value.trim().is_empty());

    match (explicit_uri, env_var) {
        (Some(uri), _) => Some(resolve_compatible_postgres_uri(uri)),
        (None, Some(name)) => Some(resolve_compatible_postgres_uri(&format!("${{{name}}}"))),
        (None, None) => client
            .config_uri_template
            .as_ref()
            .filter(|value| !value.trim().is_empty())
            .map(|template| resolve_compatible_postgres_uri(template)),
    }
}
/// Splits the password out of a `postgres://` or `postgresql://` URI.
///
/// Returns the URI with the `:password` part removed plus the percent-decoded
/// password (the raw text when decoding fails). URIs with another scheme, without
/// an `@`, or without a `:` in the user-info part are returned unchanged with `None`.
fn extract_pg_password(pg_uri: &str) -> (String, Option<String>) {
    let unchanged = || (pg_uri.to_string(), None);

    let Some(prefix) = ["postgresql://", "postgres://"]
        .into_iter()
        .find(|scheme| pg_uri.starts_with(*scheme))
    else {
        return unchanged();
    };
    let after_scheme = &pg_uri[prefix.len()..];

    // Split on the last '@' so an unencoded '@' inside the password stays in the user-info part.
    let Some((userinfo, host_part)) = after_scheme.rsplit_once('@') else {
        return unchanged();
    };
    let Some((user, raw_password)) = userinfo.split_once(':') else {
        return unchanged();
    };

    let password =
        decode_percent_component(raw_password).unwrap_or_else(|| raw_password.to_string());
    (format!("{prefix}{user}@{host_part}"), Some(password))
}
/// Percent-decodes a URI component.
///
/// Returns `None` when a `%` is not followed by two hex digits or when the decoded
/// bytes are not valid UTF-8.
fn decode_percent_component(value: &str) -> Option<String> {
    let mut decoded: Vec<u8> = Vec::with_capacity(value.len());
    let mut pending = value.bytes();
    while let Some(byte) = pending.next() {
        if byte != b'%' {
            decoded.push(byte);
            continue;
        }
        // A '%' must be followed by exactly two hex digits.
        let high = decode_hex_nibble(pending.next()?)?;
        let low = decode_hex_nibble(pending.next()?)?;
        decoded.push((high << 4) | low);
    }
    String::from_utf8(decoded).ok()
}
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value, or `None` for
/// any other byte.
fn decode_hex_nibble(value: u8) -> Option<u8> {
    char::from(value)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
/// Ordering used to sort exported DDL groups.
///
/// Groups whose family is part of the schema dump (per `is_schema_dump_family`)
/// always precede the others and keep their `original_index` order among
/// themselves. The remaining groups are ordered by sort weight, sort schema
/// name, sort object name, identity and finally `original_index`.
fn compare_exported_group_order(left: &ExportedDdlGroup, right: &ExportedDdlGroup) -> Ordering {
    let left_in_dump = is_schema_dump_family(left.family);
    let right_in_dump = is_schema_dump_family(right.family);

    if left_in_dump != right_in_dump {
        return if left_in_dump {
            Ordering::Less
        } else {
            Ordering::Greater
        };
    }
    if left_in_dump {
        return left.original_index.cmp(&right.original_index);
    }

    // Tuples compare lexicographically, matching the original tie-break chain.
    let left_key = (
        left.sort_weight,
        &left.sort_schema_name,
        &left.sort_object_name,
        &left.identity,
        left.original_index,
    );
    let right_key = (
        right.sort_weight,
        &right.sort_schema_name,
        &right.sort_object_name,
        &right.identity,
        right.original_index,
    );
    left_key.cmp(&right_key)
}
/// Returns `true` for every family except `Roles` and `Databases`.
fn is_schema_dump_family(family: DdlObjectFamily) -> bool {
    let excluded = matches!(family, DdlObjectFamily::Roles | DdlObjectFamily::Databases);
    !excluded
}
/// Decides whether an object in `schema_name` passes the request's schema rules.
///
/// Objects without a schema are always allowed. Otherwise system schemas are
/// rejected unless `include_system` is set, and when a schema filter is present
/// the normalized schema name must be a member of it.
fn schema_allowed(schema_name: Option<&str>, request: &DdlExportRequestNormalized) -> bool {
    let Some(name) = schema_name else {
        return true;
    };
    if !request.include_system && is_system_schema(name) {
        return false;
    }
    match &request.schema_filter {
        Some(filter) => filter.contains(&normalize_schema_name(name)),
        None => true,
    }
}
/// Returns `true` when the normalized schema name is `pg_catalog`,
/// `information_schema` or `pg_toast`.
fn is_system_schema(schema_name: &str) -> bool {
    const SYSTEM_SCHEMAS: [&str; 3] = ["pg_catalog", "information_schema", "pg_toast"];
    let normalized = normalize_schema_name(schema_name);
    SYSTEM_SCHEMAS.contains(&normalized.as_str())
}
/// Canonical form of a schema name for comparisons: unquoted, trimmed and
/// ASCII-lowercased.
fn normalize_schema_name(value: &str) -> String {
    let mut normalized = unquote_identifier(value).trim().to_string();
    normalized.make_ascii_lowercase();
    normalized
}
/// Strips one pair of surrounding double quotes from a trimmed identifier and
/// collapses doubled quotes (`""`) inside it to a single `"`.
///
/// Input that is not wrapped in a matching pair of quotes is returned trimmed
/// but otherwise untouched.
fn unquote_identifier(value: &str) -> String {
    let candidate = value.trim();
    let inner = candidate
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));
    match inner {
        Some(body) => body.replace("\"\"", "\""),
        None => candidate.to_string(),
    }
}
/// Splits a possibly schema-qualified name into `(schema, object)`.
///
/// The last dotted part is the object and the part before it the schema; a
/// single-part name is assigned to the `public` schema. When the value yields
/// no parts at all, the trimmed input is returned with no schema.
fn split_qualified_identifier(value: &str) -> (Option<String>, String) {
    let parts = split_qualified_identifier_parts(value);
    let Some((object, qualifiers)) = parts.split_last() else {
        return (None, value.trim().to_string());
    };
    let schema = match qualifiers.last() {
        Some(schema) => unquote_identifier(schema),
        None => "public".to_string(),
    };
    (Some(schema), unquote_identifier(object))
}
/// Derives the `(schema, table)` named by the `OWNED BY` clause of a statement.
///
/// Returns `None` when the statement has no `OWNED BY` clause, when the clause
/// is `OWNED BY NONE`, or when no identifier follows it. A one- or two-part
/// target (`table` / `table.column`) is assigned to the `public` schema; with
/// three or more parts the first two are taken as schema and table.
///
/// A trailing `;` after the clause is ignored, so `OWNED BY NONE;` is
/// recognised and a single-part target does not pick up the terminator.
fn extract_owned_by_sort_target(statement: &str) -> Option<(Option<String>, String)> {
    const CLAUSE: &str = "owned by";
    // ASCII lowercasing preserves byte offsets, so the index is valid in `statement`.
    let clause_start = statement.to_ascii_lowercase().find(CLAUSE)?;
    let owned_by = statement[clause_start + CLAUSE.len()..]
        .trim()
        .trim_end_matches(';')
        .trim_end();
    if owned_by.eq_ignore_ascii_case("none") {
        return None;
    }
    let (qualified, _) = read_qualified_identifier(owned_by)?;
    let parts = split_qualified_identifier_parts(&qualified);
    match parts.as_slice() {
        [] => None,
        [table] | [table, _] => Some((Some("public".to_string()), unquote_identifier(table))),
        [schema, table, ..] => Some((Some(unquote_identifier(schema)), unquote_identifier(table))),
    }
}
/// Builds the sort key `<object_name>__<group_kind>__<original_index>`, with
/// the index zero-padded to at least six digits.
fn grouped_statement_sort_object_name(
    object_name: &str,
    group_kind: &str,
    original_index: usize,
) -> String {
    let padded_index = format!("{original_index:06}");
    [object_name, group_kind, padded_index.as_str()].join("__")
}
/// Returns `true` when any whitespace-separated token of `statement` equals
/// `references`, ignoring ASCII case.
fn contains_references_keyword(statement: &str) -> bool {
    for word in statement.split_whitespace() {
        if word.eq_ignore_ascii_case("references") {
            return true;
        }
    }
    false
}
/// Splits a dotted identifier into its parts, honouring double-quoted segments.
///
/// Dots inside quotes do not split. Quote characters are kept in the returned
/// parts, except that an escaped `""` inside a quoted segment is emitted as a
/// single `"`. Each part is trimmed and empty parts are dropped.
fn split_qualified_identifier_parts(value: &str) -> Vec<String> {
    // Moves the pending segment into `parts` (when non-blank) and resets it.
    fn flush(segment: &mut String, parts: &mut Vec<String>) {
        let trimmed = segment.trim();
        if !trimmed.is_empty() {
            parts.push(trimmed.to_string());
        }
        segment.clear();
    }

    let mut parts: Vec<String> = Vec::new();
    let mut segment = String::new();
    let mut quoted = false;
    let mut remaining = value.trim().chars().peekable();

    while let Some(ch) = remaining.next() {
        match ch {
            '"' if quoted && remaining.peek() == Some(&'"') => {
                // Escaped quote: consume the second `"` and keep only one.
                remaining.next();
                segment.push('"');
            }
            '"' => {
                quoted = !quoted;
                segment.push('"');
            }
            '.' if !quoted => flush(&mut segment, &mut parts),
            other => segment.push(other),
        }
    }
    flush(&mut segment, &mut parts);
    parts
}
/// Reads the identifier at the start of `input`, skipping leading whitespace.
///
/// The identifier ends at the first whitespace or `(` outside double quotes
/// (a doubled `""` inside quotes is treated as an escaped quote). Returns the
/// trimmed identifier text together with the position just past it, counted in
/// `char`s rather than bytes, or `None` when no identifier text was found.
fn read_qualified_identifier(input: &str) -> Option<(String, usize)> {
    let chars: Vec<char> = input.chars().collect();
    let start = chars
        .iter()
        .position(|ch| !ch.is_whitespace())
        .unwrap_or(chars.len());

    let mut end = start;
    let mut quoted = false;
    while let Some(&ch) = chars.get(end) {
        match ch {
            '"' if quoted && chars.get(end + 1) == Some(&'"') => end += 2,
            '"' => {
                quoted = !quoted;
                end += 1;
            }
            terminator if !quoted && (terminator.is_whitespace() || terminator == '(') => break,
            _ => end += 1,
        }
    }

    let raw: String = chars[start..end].iter().collect();
    let identifier = raw.trim();
    if identifier.is_empty() {
        None
    } else {
        Some((identifier.to_string(), end))
    }
}
/// Parses `name(args)` from the start of `input`.
///
/// On success returns `(qualified_name, args, end)`, where `args` is the trimmed
/// text between the outermost parentheses and `end` is the position just past
/// the closing `)`, counted in `char`s (indices into `input.chars()`).
///
/// Returns `None` when no identifier can be read, when the identifier is not
/// followed (after optional whitespace) by `(`, or when the parentheses never
/// balance. Parentheses inside single-quoted strings, double-quoted identifiers
/// and dollar-quoted sections are not counted.
fn parse_function_signature(input: &str) -> Option<(String, String, usize)> {
// `idx` is a char index, matching what `read_qualified_identifier` returns.
let (qualified, mut idx) = read_qualified_identifier(input)?;
let chars: Vec<char> = input.chars().collect();
while idx < chars.len() && chars[idx].is_whitespace() {
idx += 1;
}
if chars.get(idx).copied() != Some('(') {
return None;
}
let args_start = idx + 1;
// Depth starts at 1 for the opening parenthesis consumed just below.
let mut depth = 1isize;
idx += 1;
let mut in_single = false;
let mut in_double = false;
let mut dollar_tag: Option<String> = None;
while idx < chars.len() {
let ch = chars[idx];
// Inside a dollar-quoted section: skip everything until the same tag reappears.
if let Some(tag) = &dollar_tag {
if ch == '$' {
let candidate = read_dollar_tag(&chars, idx);
if candidate.as_deref() == Some(tag.as_str()) {
// NOTE(review): `tag.len()` is a byte length added to a char index; this
// presumably relies on `read_dollar_tag` returning the full ASCII tag
// including both `$` delimiters — confirm against its definition.
let tag_len = tag.len();
dollar_tag = None;
idx += tag_len;
continue;
}
}
idx += 1;
continue;
}
// Outside any quoting, a `$` that starts a dollar tag opens a dollar-quoted section.
if !in_single
&& !in_double
&& ch == '$'
&& let Some(tag) = read_dollar_tag(&chars, idx)
{
dollar_tag = Some(tag.clone());
idx += tag.len();
continue;
}
// Single-quoted string; a doubled `''` inside it is an escaped quote.
if ch == '\'' && !in_double {
if in_single && chars.get(idx + 1).copied() == Some('\'') {
idx += 2;
continue;
}
in_single = !in_single;
idx += 1;
continue;
}
// Double-quoted identifier; a doubled `""` inside it is an escaped quote.
if ch == '"' && !in_single {
if in_double && chars.get(idx + 1).copied() == Some('"') {
idx += 2;
continue;
}
in_double = !in_double;
idx += 1;
continue;
}
// Only unquoted parentheses affect nesting depth.
if !in_single && !in_double {
if ch == '(' {
depth += 1;
} else if ch == ')' {
depth -= 1;
if depth == 0 {
let args: String = chars[args_start..idx].iter().collect();
return Some((qualified, args.trim().to_string(), idx + 1));
}
}
}
idx += 1;
}
// Ran out of input before the opening parenthesis was closed.
None
}
/// Returns the qualified identifier that follows the first ` on ` (matched
/// ASCII case-insensitively, with a single space on each side) in `statement`,
/// or `None` when the marker or the identifier is missing.
fn extract_on_target_qualified_name(statement: &str) -> Option<String> {
    const MARKER: &str = " on ";
    // ASCII lowercasing preserves byte offsets, so the index is valid in `statement`.
    let marker_start = statement.to_ascii_lowercase().find(MARKER)?;
    let target = &statement[marker_start + MARKER.len()..];
    read_qualified_identifier(target).map(|(qualified, _)| qualified)
}
/// Wraps `value` in double quotes, doubling every embedded `"`.
fn quote_identifier(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for ch in value.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Doubles every single quote in `value`; no surrounding quotes are added.
fn escape_sql_literal(value: &str) -> String {
    let pieces: Vec<&str> = value.split('\'').collect();
    pieces.join("''")
}
/// Unit tests for the pg_dump statement splitter, the dump classifier and
/// grouping order, password extraction, and object labels.
#[cfg(test)]
mod tests {
use super::{
DdlExportRequestNormalized, DdlObjectFamily, classify_dump_sql,
compare_exported_group_order, exported_object_label, extract_pg_password,
family_sort_weight, split_pg_dump_sql_statements,
};
use crate::api::management::types::ExportedDdlObject;
use std::collections::HashSet;
/// Builds a normalized request for the given families, restricted to the
/// `public` schema, with system objects, roles and databases excluded.
fn request(families: &[DdlObjectFamily]) -> DdlExportRequestNormalized {
DdlExportRequestNormalized {
client_name: "athena_logging".to_string(),
schemas: Some(vec!["public".to_string()]),
schema_filter: Some(HashSet::from(["public".to_string()])),
object_families: families.to_vec(),
requested_families: families.iter().copied().collect(),
include_system: false,
include_roles: false,
include_databases: false,
format: "sql".to_string(),
}
}
/// A `;` inside a dollar-quoted function body must not end the statement.
#[test]
fn splitter_handles_dollar_quoted_function_bodies() {
let sql = r#"
CREATE FUNCTION public.echo(value text) RETURNS text
LANGUAGE plpgsql
AS $function$
BEGIN
RETURN value || ';';
END;
$function$;
CREATE VIEW public.sample AS SELECT 1 AS id;
"#;
let statements = split_pg_dump_sql_statements(sql);
assert_eq!(statements.len(), 2);
assert!(statements[0].sql.contains("RETURN value || ';'"));
}
/// Classifying a mixed dump yields one group per requested object, and the
/// `users` table group carries three statements.
#[test]
fn classify_dump_groups_requested_families() {
let sql = r#"
CREATE SCHEMA analytics;
CREATE TYPE public.status_enum AS ENUM ('pending', 'done');
CREATE TABLE public.users (id uuid);
ALTER TABLE ONLY public.users ADD CONSTRAINT users_pkey PRIMARY KEY (id);
CREATE INDEX users_id_idx ON public.users USING btree (id);
CREATE VIEW public.active_users AS SELECT id FROM public.users;
CREATE MATERIALIZED VIEW public.rollup AS SELECT id FROM public.users;
CREATE FUNCTION public.echo(value text) RETURNS text LANGUAGE sql AS $$ SELECT value; $$;
"#;
let (_, groups, warnings) = classify_dump_sql(
sql,
&request(&[
DdlObjectFamily::Tables,
DdlObjectFamily::Views,
DdlObjectFamily::MaterializedViews,
DdlObjectFamily::Functions,
DdlObjectFamily::Types,
]),
);
assert!(warnings.is_empty());
assert_eq!(groups.len(), 5);
let users = groups
.iter()
.find(|group| group.object_name == "users")
.expect("users table group");
assert_eq!(users.family, DdlObjectFamily::Tables);
// Presumably the CREATE TABLE plus its constraint and index; the grouping
// logic lives in `classify_dump_sql`, outside this module.
assert_eq!(users.statements.len(), 3);
}
/// pg_dump `-- Name: ...` comment lines must not create groups or warnings.
#[test]
fn classify_dump_skips_pg_dump_toc_comments() {
let sql = r#"
-- Name: users; Type: TABLE; Schema: public; Owner: postgres
CREATE TABLE public.users (id uuid);
-- Name: users users_pkey; Type: CONSTRAINT; Schema: public; Owner: postgres
ALTER TABLE ONLY public.users ADD CONSTRAINT users_pkey PRIMARY KEY (id);
"#;
let (_, groups, warnings) = classify_dump_sql(sql, &request(&[DdlObjectFamily::Tables]));
assert!(warnings.is_empty());
assert_eq!(groups.len(), 1);
assert_eq!(groups[0].object_name, "users");
assert_eq!(groups[0].statements.len(), 2);
}
/// After sorting, `ALTER SEQUENCE ... OWNED BY` must still come after the
/// `CREATE TABLE` it refers to.
#[test]
fn classify_dump_keeps_owned_by_after_table_create() {
let sql = r#"
CREATE SEQUENCE public.users_id_seq;
CREATE TABLE public.users (id integer NOT NULL);
ALTER SEQUENCE public.users_id_seq OWNED BY public.users.id;
"#;
let (_, mut groups, warnings) = classify_dump_sql(
sql,
&request(&[DdlObjectFamily::Tables, DdlObjectFamily::Sequences]),
);
assert!(warnings.is_empty());
groups.sort_by(compare_exported_group_order);
// Flatten the sorted groups back into the emitted statement order.
let statements = groups
.iter()
.flat_map(|group| group.statements.iter().map(String::as_str))
.collect::<Vec<_>>();
assert_eq!(statements.len(), 3);
assert!(statements[0].starts_with("CREATE SEQUENCE public.users_id_seq"));
assert!(statements[1].starts_with("CREATE TABLE public.users"));
assert!(statements[2].starts_with("ALTER SEQUENCE public.users_id_seq OWNED BY"));
}
/// After sorting, a foreign-key constraint must come after the `CREATE TABLE`
/// of the table it references.
#[test]
fn classify_dump_keeps_foreign_keys_after_referenced_table_create() {
let sql = r#"
CREATE TABLE public.child (parent_id uuid NOT NULL);
CREATE TABLE public.parent (id uuid NOT NULL);
ALTER TABLE ONLY public.child ADD CONSTRAINT child_parent_fkey FOREIGN KEY (parent_id) REFERENCES public.parent(id);
"#;
let (_, mut groups, warnings) =
classify_dump_sql(sql, &request(&[DdlObjectFamily::Tables]));
assert!(warnings.is_empty());
groups.sort_by(compare_exported_group_order);
// Flatten the sorted groups back into the emitted statement order.
let statements = groups
.iter()
.flat_map(|group| group.statements.iter().map(String::as_str))
.collect::<Vec<_>>();
assert_eq!(statements.len(), 3);
assert!(statements[0].starts_with("CREATE TABLE public.child"));
assert!(statements[1].starts_with("CREATE TABLE public.parent"));
assert!(statements[2].starts_with("ALTER TABLE ONLY public.child ADD CONSTRAINT"));
}
/// The password is removed from the URI and returned percent-decoded.
#[test]
fn extract_pg_password_decodes_percent_encoding() {
let (sanitized, password) =
extract_pg_password("postgres://user:p%40ss%2Fword@localhost/example");
assert_eq!(sanitized, "postgres://user@localhost/example");
assert_eq!(password.as_deref(), Some("p@ss/word"));
}
/// A function label combines the family name, the qualified name and the
/// identity (argument list) in parentheses.
#[test]
fn object_label_is_family_aware() {
let label = exported_object_label(&ExportedDdlObject {
family: DdlObjectFamily::Functions,
schema_name: Some("public".to_string()),
object_name: "echo".to_string(),
identity: Some("value text".to_string()),
ddl: "CREATE FUNCTION public.echo(value text) ...".to_string(),
order_key: format!("{:02}:0000", family_sort_weight(DdlObjectFamily::Functions)),
});
assert_eq!(label, "Function: public.echo(value text)");
}
}