use core::num::NonZeroU16;
use std::collections::BTreeSet;
use sqlx::Row as _;
use sqlx::SqlSafeStr as _;
use super::{ConcurrentlyConfig, Error, FillFactor, SqlFragment};
use crate::identifier::{AccessMethod, Index, QualifiedTable, Table};
pub struct Input {
pub qualified_table: QualifiedTable,
pub index: Index,
pub key_expression: SqlFragment,
pub unique: bool,
pub method: AccessMethod,
pub include: Option<SqlFragment>,
pub where_clause: Option<SqlFragment>,
pub fillfactor: Option<FillFactor>,
pub concurrently: ConcurrentlyConfig,
}
#[derive(Debug)]
pub struct Result {
pub elapsed: std::time::Duration,
pub partitions: Vec<Partition>,
}
impl std::fmt::Display for Result {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Created {} partition indexes in {:.2}s",
self.partitions.len(),
self.elapsed.as_secs_f64()
)
}
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Partition {
#[serde(flatten)]
pub qualified_table: QualifiedTable,
pub index: Index,
#[serde(deserialize_with = "super::sql_str_serde::deserialize")]
pub create_index_statement: sqlx::SqlStr,
#[serde(deserialize_with = "super::sql_str_serde::deserialize")]
pub attach_statement: sqlx::SqlStr,
}
#[derive(Debug, Clone, serde::Deserialize)]
struct PartitionRow {
#[serde(flatten)]
qualified_table: QualifiedTable,
index: Index,
#[serde(deserialize_with = "super::sql_str_serde::deserialize")]
create_index_statement: sqlx::SqlStr,
#[serde(deserialize_with = "super::sql_str_serde::deserialize")]
create_index_statement_concurrently: sqlx::SqlStr,
#[serde(deserialize_with = "super::sql_str_serde::deserialize")]
attach_statement: sqlx::SqlStr,
}
pub struct Statements {
pub parent_create: sqlx::SqlStr,
pub partitions: Vec<Partition>,
}
pub async fn fetch_statements(
config: &crate::Config,
input: &Input,
) -> core::result::Result<Statements, Error> {
let method = input.method.as_str();
let unique_keyword = if input.unique { "UNIQUE " } else { "" };
let key_expression = input.key_expression.as_str();
let include_clause = input
.include
.as_ref()
.map(|i| i.as_str())
.unwrap_or_default();
let where_clause = input
.where_clause
.as_ref()
.map(|w| w.as_str())
.unwrap_or_default();
let fillfactor = input
.fillfactor
.map(|value| value.as_u8().to_string())
.unwrap_or_default();
let (parent_create, partitions) = config
.with_sqlx_connection(async |connection| {
let row = sqlx::query(indoc::indoc! {"
WITH
params
( parent_table
, schema_name
, parent_index
, access_method
, unique_keyword
, key_expression
, include_clause
, fillfactor
, where_clause
) AS (
VALUES
( $1::text
, $2::text
, $3::text
, $4::text
, $5::text
, $6::text
, $7::text
, $8::text
, $9::text
)
)
, fragments AS (
SELECT
derived.include_clause
, derived.storage_clause
, derived.where_clause
, format
( 'CREATE %sINDEX %I ON ONLY %I.%I USING %I (%s)%s%s%s'
, params.unique_keyword
, params.parent_index
, params.schema_name
, params.parent_table
, params.access_method
, params.key_expression
, derived.include_clause
, derived.storage_clause
, derived.where_clause
) AS parent_create_statement
, format
( 'CREATE %sINDEX %%I ON %%I.%%I USING %I (%s)%s%s%s'
, params.unique_keyword
, params.access_method
, params.key_expression
, derived.include_clause
, derived.storage_clause
, derived.where_clause
) AS create_index_template
, format
( 'CREATE %sINDEX CONCURRENTLY %%I ON %%I.%%I USING %I (%s)%s%s%s'
, params.unique_keyword
, params.access_method
, params.key_expression
, derived.include_clause
, derived.storage_clause
, derived.where_clause
) AS create_index_template_concurrently
, format
( 'ALTER INDEX %I.%I ATTACH PARTITION %%I.%%I'
, params.schema_name
, params.parent_index
) AS attach_index_template
FROM
params
CROSS JOIN LATERAL (
SELECT
CASE WHEN params.include_clause = '' THEN '' ELSE ' INCLUDE (' || params.include_clause || ')' END
, CASE WHEN params.fillfactor = '' THEN '' ELSE ' WITH (fillfactor = ' || params.fillfactor || ')' END
, CASE WHEN params.where_clause = '' THEN '' ELSE ' WHERE ' || params.where_clause END
) AS derived(include_clause, storage_clause, where_clause)
)
, partitions AS (
SELECT
child_namespace.nspname AS schema
, child_class.relname AS table
, derived.partition_index_name AS index
, format
( fragments.create_index_template
, derived.partition_index_name
, child_namespace.nspname
, child_class.relname
) AS create_index_statement
, format
( fragments.create_index_template_concurrently
, derived.partition_index_name
, child_namespace.nspname
, child_class.relname
) AS create_index_statement_concurrently
, format
( fragments.attach_index_template
, child_namespace.nspname
, derived.partition_index_name
) AS attach_statement
FROM
params
CROSS JOIN
fragments
CROSS JOIN
pg_inherits
JOIN
pg_class AS parent_class
ON
parent_class.oid = pg_inherits.inhparent
JOIN
pg_class AS child_class
ON
child_class.oid = pg_inherits.inhrelid
CROSS JOIN LATERAL (
SELECT
CASE
WHEN octet_length(base_name) <= 63 THEN base_name
ELSE left(base_name, 54) || '_' || substr(md5(base_name), 1, 8)
END AS partition_index_name
FROM (
SELECT params.parent_index || '_' || child_class.relname AS base_name
) AS base
) AS derived(partition_index_name)
JOIN
pg_namespace AS parent_namespace
ON
parent_namespace.oid = parent_class.relnamespace
JOIN
pg_namespace AS child_namespace
ON
child_namespace.oid = child_class.relnamespace
WHERE
parent_class.relkind = 'p'
AND
parent_class.relname = params.parent_table
AND
parent_namespace.nspname = params.schema_name
)
SELECT
fragments.parent_create_statement
, (
SELECT
COALESCE(jsonb_agg(
jsonb_build_object
( 'schema', partitions.schema
, 'table', partitions.table
, 'index', partitions.index
, 'create_index_statement', partitions.create_index_statement
, 'create_index_statement_concurrently', partitions.create_index_statement_concurrently
, 'attach_statement', partitions.attach_statement
)
ORDER BY partitions.schema, partitions.table
), '[]'::jsonb)
FROM
partitions
) AS partitions
FROM
fragments
"})
.bind(input.qualified_table.table.as_str())
.bind(input.qualified_table.schema.as_str())
.bind(input.index.as_str())
.bind(method)
.bind(unique_keyword)
.bind(key_expression)
.bind(include_clause)
.bind(fillfactor.clone())
.bind(where_clause)
.fetch_one(connection)
.await?;
let parent_create: String = row.get("parent_create_statement");
let partitions_json: serde_json::Value = row.get("partitions");
let partitions: Vec<PartitionRow> =
serde_json::from_value(partitions_json).expect("valid partition JSON from database");
Ok::<_, sqlx::Error>((parent_create, partitions))
})
.await??;
if partitions.is_empty() {
return Err(Error::NoPartitions {
qualified_table: input.qualified_table.clone(),
});
}
let partition_tables: BTreeSet<Table> = partitions
.iter()
.map(|partition| partition.qualified_table.table.clone())
.collect();
validate_concurrently_tables(&input.concurrently, &partition_tables)?;
let partitions = partitions
.into_iter()
.map(|partition| {
let create_index_statement = if input
.concurrently
.is_concurrent_for(&partition.qualified_table.table)
{
partition.create_index_statement_concurrently
} else {
partition.create_index_statement
};
Partition {
qualified_table: partition.qualified_table,
index: partition.index,
create_index_statement,
attach_statement: partition.attach_statement,
}
})
.collect();
Ok(Statements {
parent_create: sqlx::AssertSqlSafe(parent_create).into_sql_str(),
partitions,
})
}
fn validate_concurrently_tables(
concurrently: &ConcurrentlyConfig,
partition_tables: &BTreeSet<Table>,
) -> core::result::Result<(), Error> {
let requested_tables = match concurrently {
ConcurrentlyConfig::Except(tables) => Some(tables),
ConcurrentlyConfig::None | ConcurrentlyConfig::All => None,
};
let Some(requested_tables) = requested_tables else {
return Ok(());
};
let unknown_tables: BTreeSet<Table> = requested_tables
.difference(partition_tables)
.cloned()
.collect();
if unknown_tables.is_empty() {
return Ok(());
}
Err(Error::UnknownPartitionTables {
tables: unknown_tables,
})
}
async fn worker(
config: std::sync::Arc<crate::Config>,
queue: std::sync::Arc<tokio::sync::Mutex<std::collections::VecDeque<Partition>>>,
) -> core::result::Result<(), Error> {
Ok(config
.as_ref()
.with_sqlx_connection(async move |connection| {
loop {
let partition = queue.lock().await.pop_front();
let Some(partition) = partition else {
break;
};
log::info!(
"Creating index {} on {}",
partition.index,
partition.qualified_table
);
sqlx::raw_sql(partition.create_index_statement.clone())
.execute(&mut *connection)
.await?;
log::info!(
"Created index {} on {}",
partition.index,
partition.qualified_table
);
}
Ok::<(), sqlx::Error>(())
})
.await??)
}
pub async fn run(
config: &crate::Config,
input: &Input,
jobs: NonZeroU16,
dry_run: bool,
) -> core::result::Result<Result, Error> {
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
let start = std::time::Instant::now();
let Statements {
parent_create,
partitions,
..
} = fetch_statements(config, input).await?;
if dry_run {
for partition in &partitions {
log::info!("[dry-run] {};", partition.create_index_statement.as_str());
}
log::info!("[dry-run] {};", parent_create.as_str());
for partition in &partitions {
log::info!("[dry-run] {};", partition.attach_statement.as_str());
}
} else {
let partitions_for_workers = partitions.clone();
let shared_config = Arc::new(config.clone());
let shared_queue = Arc::new(Mutex::new(VecDeque::from(partitions_for_workers)));
let mut join_set = JoinSet::new();
for _ in 0..jobs.get() {
let worker_config = Arc::clone(&shared_config);
let worker_queue = Arc::clone(&shared_queue);
join_set.spawn(async move { worker(worker_config, worker_queue).await });
}
while let Some(result) = join_set.join_next().await {
match result {
Ok(worker_result) => worker_result?,
Err(join_error) => return Err(Error::WorkerPanic(join_error)),
}
}
config
.with_sqlx_connection(async |connection| {
log::info!("Creating parent index {}", input.index);
sqlx::raw_sql(parent_create.clone())
.execute(&mut *connection)
.await?;
log::info!("Created parent index {}", input.index);
for partition in &partitions {
sqlx::raw_sql(partition.attach_statement.clone())
.execute(&mut *connection)
.await?;
}
Ok::<(), sqlx::Error>(())
})
.await??;
}
Ok(Result {
elapsed: start.elapsed(),
partitions,
})
}