use core::num::NonZeroU16;
use core::str::FromStr as _;
use std::collections::BTreeSet;
use sqlx::Row as _;
use sqlx::SqlSafeStr as _;
use crate::identifier::{QualifiedTable, Schema, Table};
/// A unit of work for a worker: one table paired with the `ANALYZE`
/// statement that refreshes its planner statistics.
#[derive(Debug)]
struct AnalyzeTask {
    // Used only to identify the table in log messages.
    qualified_table: QualifiedTable,
    // Pre-built `ANALYZE schema.table` statement; identifiers are quoted
    // server-side via `format('%I.%I', ...)` in `fetch_tasks`.
    statement: sqlx::SqlStr,
}
/// Selects which schemas' tables get analyzed.
#[derive(Debug, Clone)]
pub enum Schemas {
    /// Analyze tables in every schema.
    All,
    /// Analyze only tables belonging to the named schemas.
    Specific(BTreeSet<Schema>),
}
/// Errors that can occur while fetching tasks or running `ANALYZE`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Failed to obtain or use a database connection.
    #[error(transparent)]
    Connection(#[from] crate::sqlx::ConnectionError),
    /// A spawned worker task did not complete normally (panicked, or was
    /// cancelled — `JoinError` covers both).
    #[error("worker task panicked: {0}")]
    WorkerPanic(tokio::task::JoinError),
    /// A catalog query or `ANALYZE` statement failed.
    #[error("SQL error: {0}")]
    Sql(#[from] sqlx::Error),
}
/// Summary of a completed `run_all` invocation.
///
/// NOTE: the name shadows `std::result::Result` within this module, which is
/// why the functions below spell out `core::result::Result` in signatures.
#[derive(Debug)]
pub struct Result {
    /// Wall-clock duration of the whole run, including the catalog query.
    pub elapsed: std::time::Duration,
    /// Number of tables that were queued for analysis.
    pub table_count: u64,
}
pub async fn run_all(
config: &crate::Config,
schemas: &Schemas,
jobs: NonZeroU16,
) -> core::result::Result<Result, Error> {
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
let start = std::time::Instant::now();
let tasks = fetch_tasks(config, schemas).await?;
let table_count = u64::try_from(tasks.len()).expect("task count fits in u64");
let shared_config = Arc::new(config.clone());
let shared_queue = Arc::new(Mutex::new(VecDeque::from(tasks)));
let mut join_set = JoinSet::new();
for _ in 0..jobs.get() {
let worker_config = Arc::clone(&shared_config);
let worker_queue = Arc::clone(&shared_queue);
join_set.spawn(async move { worker(worker_config, worker_queue).await });
}
while let Some(result) = join_set.join_next().await {
match result {
Ok(worker_result) => worker_result?,
Err(join_error) => return Err(Error::WorkerPanic(join_error)),
}
}
Ok(Result {
elapsed: start.elapsed(),
table_count,
})
}
async fn worker(
config: std::sync::Arc<crate::Config>,
queue: std::sync::Arc<tokio::sync::Mutex<std::collections::VecDeque<AnalyzeTask>>>,
) -> core::result::Result<(), Error> {
config
.as_ref()
.with_sqlx_connection(async move |connection| {
loop {
let task = queue.lock().await.pop_front();
let Some(task) = task else {
break;
};
log::info!("Analyzing {}", task.qualified_table);
sqlx::raw_sql(task.statement.clone())
.execute(&mut *connection)
.await?;
log::info!("Analyzed {}", task.qualified_table);
}
Ok(())
})
.await?
}
/// Queries the catalog for every concrete (non-partitioned, `relkind != 'p'`)
/// table in the requested schemas and pairs each with a ready-to-run
/// `ANALYZE` statement built server-side with `format('%I.%I', ...)`.
///
/// # Errors
///
/// Returns an error if the connection cannot be established or the catalog
/// query fails.
async fn fetch_tasks(
    config: &crate::Config,
    schemas: &Schemas,
) -> core::result::Result<Vec<AnalyzeTask>, Error> {
    config
        .with_sqlx_connection(async |connection| {
            // The two statements differ only by the schema filter; each arm
            // owns the connection for its single fetch.
            let rows = match schemas {
                Schemas::All => {
                    sqlx::query(indoc::indoc! {"
                        SELECT
                        pg_tables.schemaname AS schema_name
                        , pg_tables.tablename AS table_name
                        , format('ANALYZE %I.%I', pg_tables.schemaname, pg_tables.tablename) AS statement
                        FROM
                        pg_tables
                        JOIN
                        pg_class ON pg_class.relname = pg_tables.tablename
                        JOIN
                        pg_namespace ON pg_namespace.oid = pg_class.relnamespace AND pg_namespace.nspname = pg_tables.schemaname
                        WHERE
                        pg_class.relkind != 'p'
                        ORDER BY
                        pg_tables.schemaname
                        , pg_tables.tablename
                    "})
                    .fetch_all(connection)
                    .await?
                }
                Schemas::Specific(schema_set) => {
                    let schema_names: Vec<&str> = schema_set.iter().map(Schema::as_ref).collect();
                    sqlx::query(indoc::indoc! {"
                        SELECT
                        pg_tables.schemaname AS schema_name
                        , pg_tables.tablename AS table_name
                        , format('ANALYZE %I.%I', pg_tables.schemaname, pg_tables.tablename) AS statement
                        FROM
                        pg_tables
                        JOIN
                        pg_class ON pg_class.relname = pg_tables.tablename
                        JOIN
                        pg_namespace ON pg_namespace.oid = pg_class.relnamespace AND pg_namespace.nspname = pg_tables.schemaname
                        WHERE
                        pg_class.relkind != 'p'
                        AND
                        pg_tables.schemaname = ANY($1)
                        ORDER BY
                        pg_tables.schemaname
                        , pg_tables.tablename
                    "})
                    .bind(&schema_names)
                    .fetch_all(connection)
                    .await?
                }
            };
            let mut tasks = Vec::with_capacity(rows.len());
            for row in rows {
                let schema_name: String = row.get("schema_name");
                let table_name: String = row.get("table_name");
                let statement: String = row.get("statement");
                tasks.push(AnalyzeTask {
                    qualified_table: QualifiedTable {
                        // Catalog identifiers are assumed to round-trip
                        // through our identifier parsers; a failure here is
                        // a bug, hence expect.
                        schema: Schema::from_str(&schema_name)
                            .expect("schema name from database should be valid"),
                        table: Table::from_str(&table_name)
                            .expect("table name from database should be valid"),
                    },
                    // Safe: the statement was built with %I quoting above.
                    statement: sqlx::AssertSqlSafe(statement).into_sql_str(),
                });
            }
            Ok(tasks)
        })
        .await?
}