// athena_driver/postgresql/schema_cache.rs

1use once_cell::sync::Lazy;
2use sqlx::PgPool;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
6
/// Column-descriptor cache: cache key -> (column name -> `data_type|udt_name` descriptor).
type TableSchemaCache = HashMap<String, HashMap<String, String>>;
/// Unique-constraint cache: cache key -> constraints recorded for that table.
type TableUniqueConstraintsCache = HashMap<String, Vec<UniqueConstraintMetadata>>;

/// A single unique constraint together with the columns it spans.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UniqueConstraintMetadata {
    /// Name of the constraint.
    pub constraint_name: String,
    /// Columns that make up the constraint, in the order they were collected.
    pub columns: Vec<String>,
}
15
16static TABLE_COLUMN_TYPE_CACHE: Lazy<Arc<RwLock<TableSchemaCache>>> =
17    Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
18
19static TABLE_UNIQUE_CONSTRAINT_CACHE: Lazy<Arc<RwLock<TableUniqueConstraintsCache>>> =
20    Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
21
/// Builds the `schema.table` key used by the column-type cache.
///
/// Both parts are ASCII-lowercased (non-ASCII characters are left untouched) and joined with a
/// single `.`.
fn metadata_cache_key(schema_name: &str, table_name: &str) -> String {
    // One allocation sized for both parts plus the separator.
    let mut key = String::with_capacity(schema_name.len() + 1 + table_name.len());
    key.extend(schema_name.chars().map(|c| c.to_ascii_lowercase()));
    key.push('.');
    key.extend(table_name.chars().map(|c| c.to_ascii_lowercase()));
    key
}
29
30pub async fn get_table_column_types(
31    pool: &PgPool,
32    schema_name: &str,
33    table_name: &str,
34) -> Result<HashMap<String, String>, sqlx::Error> {
35    let cache_key: String = metadata_cache_key(schema_name, table_name);
36    {
37        let cache: RwLockReadGuard<'_, HashMap<String, HashMap<String, String>>> =
38            TABLE_COLUMN_TYPE_CACHE.read().await;
39        if let Some(columns) = cache.get(&cache_key).or_else(|| cache.get(table_name)) {
40            return Ok(columns.clone());
41        }
42    }
43
44    let rows: Vec<(String, String, String)> = sqlx::query_as::<_, (String, String, String)>(
45        r#"
46                SELECT column_name, data_type, udt_name
47        FROM information_schema.columns
48        WHERE table_schema = $1
49          AND table_name = $2
50        "#,
51    )
52    .bind(schema_name)
53    .bind(table_name)
54    .fetch_all(pool)
55    .await?;
56
57    let columns: HashMap<String, String> = rows
58        .into_iter()
59        .map(|(column, data_type, udt_name)| {
60            (
61                column.to_ascii_lowercase(),
62                format!("{}|{}", data_type, udt_name),
63            )
64        })
65        .collect::<HashMap<_, _>>();
66
67    let mut cache: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
68        TABLE_COLUMN_TYPE_CACHE.write().await;
69    cache.insert(cache_key, columns.clone());
70
71    Ok(columns)
72}
73
/// Tells whether a descriptor produced by [`get_table_column_types`] /
/// [`get_public_table_column_types`] denotes a `bigint` column.
///
/// A descriptor is `data_type`, `|`, `udt_name` (for example `bigint|int8`). The first two
/// `|`-separated fields are trimmed and compared ASCII-case-insensitively; the result is `true`
/// when the first is `bigint` or the second is `int8`. Any further fields are ignored.
pub fn postgres_column_descriptor_is_bigint(descriptor: &str) -> bool {
    let mut fields = descriptor.split('|').map(str::trim);
    match (fields.next(), fields.next()) {
        (Some(data_type), _) if data_type.eq_ignore_ascii_case("bigint") => true,
        (_, Some(udt_name)) => udt_name.eq_ignore_ascii_case("int8"),
        _ => false,
    }
}
85
/// Tells whether a descriptor produced by [`get_table_column_types`] denotes a
/// `timestamp with time zone` column (e.g. `timestamp with time zone|timestamptz`).
///
/// The first two `|`-separated fields are trimmed and compared ASCII-case-insensitively; the
/// result is `true` when the first is `timestamp with time zone` or the second is `timestamptz`.
/// Any further fields are ignored.
pub fn postgres_column_descriptor_is_timestamptz(descriptor: &str) -> bool {
    let mut fields = descriptor.split('|').map(str::trim);
    match (fields.next(), fields.next()) {
        (Some(data_type), _) if data_type.eq_ignore_ascii_case("timestamp with time zone") => true,
        (_, Some(udt_name)) => udt_name.eq_ignore_ascii_case("timestamptz"),
        _ => false,
    }
}
95
96/// ## `get_public_table_column_types`
97/// Get the column types for a public table.
98///
99/// # Arguments
100///
101/// * `pool` - The pool to use.
102/// * `table_name` - The name of the table.
103///
104/// # Returns
105///
106/// A `Result` containing the column types.
107///
108/// The column types for a public table.
109///
110pub async fn get_public_table_column_types(
111    pool: &PgPool,
112    table_name: &str,
113) -> Result<HashMap<String, String>, sqlx::Error> {
114    get_table_column_types(pool, "public", table_name).await
115}
116
117/// ## `get_public_table_unique_constraints`
118/// Get unique constraint metadata for a public table.
119///
120/// # Arguments
121///
122/// * `pool` - The pool to use.
123/// * `table_name` - The name of the table.
124///
125/// # Returns
126///
127/// A `Result` containing unique constraints and their ordered column lists.
128pub async fn get_public_table_unique_constraints(
129    pool: &PgPool,
130    table_name: &str,
131) -> Result<Vec<UniqueConstraintMetadata>, sqlx::Error> {
132    {
133        let cache: RwLockReadGuard<'_, TableUniqueConstraintsCache> =
134            TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
135        if let Some(constraints) = cache.get(table_name) {
136            return Ok(constraints.clone());
137        }
138    }
139
140    // TODO: MANUAL_QUERY needs to be replaced with a proper SQLx query builder or pre-built query for better performance and security.
141    let rows: Vec<(String, String)> = sqlx::query_as::<_, (String, String)>(
142        r#"
143        SELECT tc.constraint_name, kcu.column_name
144        FROM information_schema.table_constraints AS tc
145        JOIN information_schema.key_column_usage AS kcu
146          ON tc.constraint_name = kcu.constraint_name
147         AND tc.table_schema = kcu.table_schema
148         AND tc.table_name = kcu.table_name
149        WHERE tc.table_schema = 'public'
150          AND tc.table_name = $1
151          AND tc.constraint_type = 'UNIQUE'
152        ORDER BY tc.constraint_name, kcu.ordinal_position
153        "#,
154    )
155    .bind(table_name)
156    .fetch_all(pool)
157    .await?;
158
159    let mut grouped: HashMap<String, Vec<String>> = HashMap::new();
160    for (constraint_name, column_name) in rows {
161        grouped
162            .entry(constraint_name)
163            .or_default()
164            .push(column_name);
165    }
166
167    let mut constraints: Vec<UniqueConstraintMetadata> = grouped
168        .into_iter()
169        .map(|(constraint_name, columns)| UniqueConstraintMetadata {
170            constraint_name,
171            columns,
172        })
173        .collect();
174    constraints.sort_by(|a, b| a.constraint_name.cmp(&b.constraint_name));
175
176    let mut cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
177        TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
178    cache.insert(table_name.to_string(), constraints.clone());
179
180    Ok(constraints)
181}
182
183/// Invalidate cached public-table metadata for one table.
184pub async fn invalidate_public_table_metadata(table_name: &str) {
185    let mut column_cache: RwLockWriteGuard<'_, TableSchemaCache> =
186        TABLE_COLUMN_TYPE_CACHE.write().await;
187    column_cache.remove(&metadata_cache_key("public", table_name));
188    column_cache.remove(table_name);
189    drop(column_cache);
190
191    let mut constraint_cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
192        TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
193    constraint_cache.remove(table_name);
194}
195
196/// Invalidate all cached public-table metadata.
197pub async fn invalidate_all_public_table_metadata() {
198    let mut column_cache: RwLockWriteGuard<'_, TableSchemaCache> =
199        TABLE_COLUMN_TYPE_CACHE.write().await;
200    column_cache.clear();
201    drop(column_cache);
202
203    let mut constraint_cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
204        TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
205    constraint_cache.clear();
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;
    use tokio::sync::Mutex;

    /// Serialises the tests that mutate the process-wide caches, so they cannot see each
    /// other's entries when the harness runs them in parallel.
    static SCHEMA_CACHE_TEST_MUTEX: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));

    /// Descriptor map holding a single `id` column, used to seed the column cache.
    fn id_column() -> HashMap<String, String> {
        HashMap::from([("id".to_string(), "uuid|uuid".to_string())])
    }

    /// Constraint list holding a single `users_email_key` constraint on `email`.
    fn users_email_key() -> Vec<UniqueConstraintMetadata> {
        vec![UniqueConstraintMetadata {
            constraint_name: "users_email_key".to_string(),
            columns: vec!["email".to_string()],
        }]
    }

    #[test]
    fn postgres_column_descriptor_is_bigint_matches_data_type_udt_pair() {
        assert!(postgres_column_descriptor_is_bigint("bigint|int8"));
        assert!(postgres_column_descriptor_is_bigint(" bigint | int8 "));
        // Comparison is case-insensitive and either field alone is enough.
        assert!(postgres_column_descriptor_is_bigint("BIGINT|INT8"));
        assert!(postgres_column_descriptor_is_bigint("bigint"));
        assert!(postgres_column_descriptor_is_bigint("|int8"));
        assert!(!postgres_column_descriptor_is_bigint(""));
        assert!(!postgres_column_descriptor_is_bigint("integer|int4"));
        assert!(!postgres_column_descriptor_is_bigint("uuid|uuid"));
        assert!(!postgres_column_descriptor_is_bigint("text|text"));
    }

    #[test]
    fn postgres_column_descriptor_is_timestamptz_matches_descriptor_pair() {
        assert!(postgres_column_descriptor_is_timestamptz(
            "timestamp with time zone|timestamptz"
        ));
        assert!(postgres_column_descriptor_is_timestamptz(
            " timestamp with time zone | timestamptz "
        ));
        // Comparison is case-insensitive and either field alone is enough.
        assert!(postgres_column_descriptor_is_timestamptz(
            "TIMESTAMP WITH TIME ZONE|TIMESTAMPTZ"
        ));
        assert!(postgres_column_descriptor_is_timestamptz("|timestamptz"));
        assert!(!postgres_column_descriptor_is_timestamptz(""));
        assert!(!postgres_column_descriptor_is_timestamptz(
            "timestamp without time zone|timestamp"
        ));
        assert!(!postgres_column_descriptor_is_timestamptz("bigint|int8"));
    }

    #[tokio::test]
    async fn invalidate_public_table_metadata_removes_only_target_table() {
        let _guard: tokio::sync::MutexGuard<'_, ()> = SCHEMA_CACHE_TEST_MUTEX.lock().await;
        {
            let mut columns: RwLockWriteGuard<'_, TableSchemaCache> =
                TABLE_COLUMN_TYPE_CACHE.write().await;
            columns.clear();
            // Seed both key shapes: the schema-qualified key written by
            // `get_table_column_types` and the bare key.
            columns.insert(metadata_cache_key("public", "users"), id_column());
            columns.insert("users".to_string(), id_column());
            columns.insert(metadata_cache_key("public", "orders"), HashMap::new());
            columns.insert("orders".to_string(), HashMap::new());
            // Same table name in another schema must not be touched.
            columns.insert(metadata_cache_key("audit", "users"), id_column());
        }

        {
            let mut constraints: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
                TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
            constraints.clear();
            constraints.insert("users".to_string(), users_email_key());
            constraints.insert("orders".to_string(), Vec::new());
        }

        invalidate_public_table_metadata("users").await;

        {
            let columns = TABLE_COLUMN_TYPE_CACHE.read().await;
            assert!(!columns.contains_key("public.users"));
            assert!(!columns.contains_key("users"));
            assert!(columns.contains_key("public.orders"));
            assert!(columns.contains_key("orders"));
            assert!(columns.contains_key("audit.users"));
        }

        {
            let constraints = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
            assert!(!constraints.contains_key("users"));
            assert!(constraints.contains_key("orders"));
        }
    }

    #[tokio::test]
    async fn invalidate_all_public_table_metadata_clears_both_caches() {
        let _guard: tokio::sync::MutexGuard<'_, ()> = SCHEMA_CACHE_TEST_MUTEX.lock().await;
        {
            let mut columns: RwLockWriteGuard<'_, TableSchemaCache> =
                TABLE_COLUMN_TYPE_CACHE.write().await;
            columns.clear();
            columns.insert(metadata_cache_key("public", "users"), id_column());
            columns.insert("users".to_string(), id_column());
        }

        {
            let mut constraints: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
                TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
            constraints.clear();
            constraints.insert("users".to_string(), users_email_key());
        }

        invalidate_all_public_table_metadata().await;

        {
            let columns = TABLE_COLUMN_TYPE_CACHE.read().await;
            assert!(columns.is_empty());
        }

        {
            let constraints = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
            assert!(constraints.is_empty());
        }
    }
}
320}