use once_cell::sync::Lazy;
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
// Maps a cache key (normally lowercased "schema.table") to a map of
// lowercase column name -> "data_type|udt_name" descriptor string.
type TableSchemaCache = HashMap<String, HashMap<String, String>>;
// Maps a table name to its UNIQUE constraints, columns in ordinal order.
type TableUniqueConstraintsCache = HashMap<String, Vec<UniqueConstraintMetadata>>;
/// Metadata describing a single UNIQUE constraint on a table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UniqueConstraintMetadata {
    // Constraint name as reported by information_schema.
    pub constraint_name: String,
    // Constrained columns, in ordinal-position order.
    pub columns: Vec<String>,
}
// Process-wide cache of column type descriptors, keyed by lowercased
// "schema.table" (see `metadata_cache_key`).
static TABLE_COLUMN_TYPE_CACHE: Lazy<Arc<RwLock<TableSchemaCache>>> =
    Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
// Process-wide cache of UNIQUE constraint metadata, keyed by bare table name
// (only "public"-schema tables are stored here).
static TABLE_UNIQUE_CONSTRAINT_CACHE: Lazy<Arc<RwLock<TableUniqueConstraintsCache>>> =
    Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Builds the canonical column-cache key for a table: the ASCII-lowercased
/// schema and table names joined with a dot, e.g. `"public.users"`.
fn metadata_cache_key(schema_name: &str, table_name: &str) -> String {
    let mut key = schema_name.to_ascii_lowercase();
    key.push('.');
    key.push_str(&table_name.to_ascii_lowercase());
    key
}
pub async fn get_table_column_types(
pool: &PgPool,
schema_name: &str,
table_name: &str,
) -> Result<HashMap<String, String>, sqlx::Error> {
let cache_key: String = metadata_cache_key(schema_name, table_name);
{
let cache: RwLockReadGuard<'_, HashMap<String, HashMap<String, String>>> =
TABLE_COLUMN_TYPE_CACHE.read().await;
if let Some(columns) = cache.get(&cache_key).or_else(|| cache.get(table_name)) {
return Ok(columns.clone());
}
}
let rows: Vec<(String, String, String)> = sqlx::query_as::<_, (String, String, String)>(
r#"
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = $1
AND table_name = $2
"#,
)
.bind(schema_name)
.bind(table_name)
.fetch_all(pool)
.await?;
let columns: HashMap<String, String> = rows
.into_iter()
.map(|(column, data_type, udt_name)| {
(
column.to_ascii_lowercase(),
format!("{}|{}", data_type, udt_name),
)
})
.collect::<HashMap<_, _>>();
let mut cache: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
TABLE_COLUMN_TYPE_CACHE.write().await;
cache.insert(cache_key, columns.clone());
Ok(columns)
}
/// Returns `true` when a `"data_type|udt_name"` descriptor (as produced by
/// `get_table_column_types`) denotes a Postgres BIGINT column. Either half of
/// the descriptor may match, and surrounding whitespace is ignored.
pub fn postgres_column_descriptor_is_bigint(descriptor: &str) -> bool {
    let mut fields = descriptor.split('|').map(str::trim);
    let data_type_is_bigint = fields
        .next()
        .is_some_and(|dt| dt.eq_ignore_ascii_case("bigint"));
    let udt_is_int8 = fields
        .next()
        .is_some_and(|udt| udt.eq_ignore_ascii_case("int8"));
    data_type_is_bigint || udt_is_int8
}
/// Returns `true` when a `"data_type|udt_name"` descriptor (as produced by
/// `get_table_column_types`) denotes a TIMESTAMP WITH TIME ZONE column.
/// Either half of the descriptor may match; whitespace is ignored.
pub fn postgres_column_descriptor_is_timestamptz(descriptor: &str) -> bool {
    let mut fields = descriptor.split('|').map(str::trim);
    let data_type_matches = fields
        .next()
        .is_some_and(|dt| dt.eq_ignore_ascii_case("timestamp with time zone"));
    let udt_matches = fields
        .next()
        .is_some_and(|udt| udt.eq_ignore_ascii_case("timestamptz"));
    data_type_matches || udt_matches
}
/// Convenience wrapper over [`get_table_column_types`] fixed to the "public"
/// schema.
///
/// # Errors
/// Propagates any `sqlx::Error` from the underlying query.
pub async fn get_public_table_column_types(
    pool: &PgPool,
    table_name: &str,
) -> Result<HashMap<String, String>, sqlx::Error> {
    const PUBLIC_SCHEMA: &str = "public";
    get_table_column_types(pool, PUBLIC_SCHEMA, table_name).await
}
/// Returns the UNIQUE constraints declared on a "public"-schema table, each
/// with its columns in ordinal order, sorted by constraint name. Results are
/// memoized in a process-wide cache keyed by the bare table name.
///
/// # Errors
/// Propagates any `sqlx::Error` raised while querying the database.
pub async fn get_public_table_unique_constraints(
    pool: &PgPool,
    table_name: &str,
) -> Result<Vec<UniqueConstraintMetadata>, sqlx::Error> {
    // Fast path: serve a previously computed constraint list.
    if let Some(cached) = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await.get(table_name) {
        return Ok(cached.clone());
    }

    // Slow path: one row per (constraint, column), ordered so that columns
    // arrive in ordinal position within each constraint.
    let rows: Vec<(String, String)> = sqlx::query_as(
        r#"
        SELECT tc.constraint_name, kcu.column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        WHERE tc.table_schema = 'public'
          AND tc.table_name = $1
          AND tc.constraint_type = 'UNIQUE'
        ORDER BY tc.constraint_name, kcu.ordinal_position
        "#,
    )
    .bind(table_name)
    .fetch_all(pool)
    .await?;

    // Group columns by constraint. A BTreeMap keeps groups ordered by
    // constraint name, so no separate sort pass is needed, and row order
    // preserves ordinal column order within each group.
    let mut grouped: std::collections::BTreeMap<String, Vec<String>> =
        std::collections::BTreeMap::new();
    for (constraint_name, column_name) in rows {
        grouped.entry(constraint_name).or_default().push(column_name);
    }
    let constraints: Vec<UniqueConstraintMetadata> = grouped
        .into_iter()
        .map(|(constraint_name, columns)| UniqueConstraintMetadata {
            constraint_name,
            columns,
        })
        .collect();

    TABLE_UNIQUE_CONSTRAINT_CACHE
        .write()
        .await
        .insert(table_name.to_string(), constraints.clone());
    Ok(constraints)
}
pub async fn invalidate_public_table_metadata(table_name: &str) {
let mut column_cache: RwLockWriteGuard<'_, TableSchemaCache> =
TABLE_COLUMN_TYPE_CACHE.write().await;
column_cache.remove(&metadata_cache_key("public", table_name));
column_cache.remove(table_name);
drop(column_cache);
let mut constraint_cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
constraint_cache.remove(table_name);
}
pub async fn invalidate_all_public_table_metadata() {
let mut column_cache: RwLockWriteGuard<'_, TableSchemaCache> =
TABLE_COLUMN_TYPE_CACHE.write().await;
column_cache.clear();
drop(column_cache);
let mut constraint_cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
constraint_cache.clear();
}
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;
    use tokio::sync::Mutex;

    // Serializes the tests that mutate the shared process-wide caches so
    // concurrently running tests cannot observe each other's state.
    static SCHEMA_CACHE_TEST_MUTEX: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));

    #[test]
    fn postgres_column_descriptor_is_bigint_matches_data_type_udt_pair() {
        assert!(postgres_column_descriptor_is_bigint("bigint|int8"));
        // Whitespace around either half of the descriptor is ignored.
        assert!(postgres_column_descriptor_is_bigint(" bigint | int8 "));
        assert!(!postgres_column_descriptor_is_bigint("uuid|uuid"));
        assert!(!postgres_column_descriptor_is_bigint("text|text"));
    }

    #[test]
    fn postgres_column_descriptor_is_timestamptz_matches_descriptor_pair() {
        assert!(postgres_column_descriptor_is_timestamptz(
            "timestamp with time zone|timestamptz"
        ));
        // Whitespace around either half of the descriptor is ignored.
        assert!(postgres_column_descriptor_is_timestamptz(
            " timestamp with time zone | timestamptz "
        ));
        assert!(!postgres_column_descriptor_is_timestamptz("bigint|int8"));
    }

    #[tokio::test]
    async fn invalidate_public_table_metadata_removes_only_target_table() {
        let _guard: tokio::sync::MutexGuard<'_, ()> = SCHEMA_CACHE_TEST_MUTEX.lock().await;
        // Seed the column cache with two tables under bare (un-namespaced) keys.
        {
            let mut columns: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
                TABLE_COLUMN_TYPE_CACHE.write().await;
            columns.clear();
            columns.insert(
                "users".to_string(),
                HashMap::from([("id".to_string(), "uuid".to_string())]),
            );
            columns.insert("orders".to_string(), HashMap::new());
        }
        // Seed the constraint cache with the same two tables.
        {
            let mut constraints: RwLockWriteGuard<
                '_,
                HashMap<String, Vec<UniqueConstraintMetadata>>,
            > = TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
            constraints.clear();
            constraints.insert(
                "users".to_string(),
                vec![UniqueConstraintMetadata {
                    constraint_name: "users_email_key".to_string(),
                    columns: vec!["email".to_string()],
                }],
            );
            constraints.insert("orders".to_string(), Vec::new());
        }
        invalidate_public_table_metadata("users").await;
        // Only the invalidated table is evicted; the other entry survives.
        {
            let columns = TABLE_COLUMN_TYPE_CACHE.read().await;
            assert!(!columns.contains_key("users"));
            assert!(columns.contains_key("orders"));
        }
        {
            let constraints = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
            assert!(!constraints.contains_key("users"));
            assert!(constraints.contains_key("orders"));
        }
    }

    #[tokio::test]
    async fn invalidate_all_public_table_metadata_clears_both_caches() {
        let _guard: tokio::sync::MutexGuard<'_, ()> = SCHEMA_CACHE_TEST_MUTEX.lock().await;
        // Seed both caches with one entry each.
        {
            let mut columns: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
                TABLE_COLUMN_TYPE_CACHE.write().await;
            columns.clear();
            columns.insert(
                "users".to_string(),
                HashMap::from([("id".to_string(), "uuid".to_string())]),
            );
        }
        {
            let mut constraints: RwLockWriteGuard<
                '_,
                HashMap<String, Vec<UniqueConstraintMetadata>>,
            > = TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
            constraints.clear();
            constraints.insert(
                "users".to_string(),
                vec![UniqueConstraintMetadata {
                    constraint_name: "users_email_key".to_string(),
                    columns: vec!["email".to_string()],
                }],
            );
        }
        invalidate_all_public_table_metadata().await;
        // Both caches must be completely empty afterwards.
        {
            let columns = TABLE_COLUMN_TYPE_CACHE.read().await;
            assert!(columns.is_empty());
        }
        {
            let constraints = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
            assert!(constraints.is_empty());
        }
    }
}