athena_driver/postgresql/
schema_cache.rs1use once_cell::sync::Lazy;
2use sqlx::PgPool;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
6
7type TableSchemaCache = HashMap<String, HashMap<String, String>>;
8type TableUniqueConstraintsCache = HashMap<String, Vec<UniqueConstraintMetadata>>;
9
/// Name and member columns of a single unique constraint.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UniqueConstraintMetadata {
    /// Name of the constraint as reported by the database catalog.
    pub constraint_name: String,
    /// Columns that belong to the constraint, in the order they were collected.
    pub columns: Vec<String>,
}
15
16static TABLE_COLUMN_TYPE_CACHE: Lazy<Arc<RwLock<TableSchemaCache>>> =
17 Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
18
19static TABLE_UNIQUE_CONSTRAINT_CACHE: Lazy<Arc<RwLock<TableUniqueConstraintsCache>>> =
20 Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
21
/// Builds the column-cache key `"<schema>.<table>"` with both parts ASCII-lowercased.
fn metadata_cache_key(schema_name: &str, table_name: &str) -> String {
    let schema = schema_name.to_ascii_lowercase();
    let table = table_name.to_ascii_lowercase();
    format!("{schema}.{table}")
}
29
30pub async fn get_table_column_types(
31 pool: &PgPool,
32 schema_name: &str,
33 table_name: &str,
34) -> Result<HashMap<String, String>, sqlx::Error> {
35 let cache_key: String = metadata_cache_key(schema_name, table_name);
36 {
37 let cache: RwLockReadGuard<'_, HashMap<String, HashMap<String, String>>> =
38 TABLE_COLUMN_TYPE_CACHE.read().await;
39 if let Some(columns) = cache.get(&cache_key).or_else(|| cache.get(table_name)) {
40 return Ok(columns.clone());
41 }
42 }
43
44 let rows: Vec<(String, String, String)> = sqlx::query_as::<_, (String, String, String)>(
45 r#"
46 SELECT column_name, data_type, udt_name
47 FROM information_schema.columns
48 WHERE table_schema = $1
49 AND table_name = $2
50 "#,
51 )
52 .bind(schema_name)
53 .bind(table_name)
54 .fetch_all(pool)
55 .await?;
56
57 let columns: HashMap<String, String> = rows
58 .into_iter()
59 .map(|(column, data_type, udt_name)| {
60 (
61 column.to_ascii_lowercase(),
62 format!("{}|{}", data_type, udt_name),
63 )
64 })
65 .collect::<HashMap<_, _>>();
66
67 let mut cache: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
68 TABLE_COLUMN_TYPE_CACHE.write().await;
69 cache.insert(cache_key, columns.clone());
70
71 Ok(columns)
72}
73
/// Reports whether a `"data_type|udt_name"` descriptor denotes a 64-bit integer column.
///
/// The descriptor is split on `'|'` and only the first two fields are inspected, each with
/// surrounding whitespace trimmed. The result is `true` when the first field equals `bigint`
/// or the second equals `int8`, compared ASCII case-insensitively.
pub fn postgres_column_descriptor_is_bigint(descriptor: &str) -> bool {
    let mut fields = descriptor.split('|').map(str::trim);
    let data_type = fields.next();
    let udt_name = fields.next();

    let named_bigint = matches!(data_type, Some(name) if name.eq_ignore_ascii_case("bigint"));
    let named_int8 = matches!(udt_name, Some(name) if name.eq_ignore_ascii_case("int8"));
    named_bigint || named_int8
}
85
/// Reports whether a `"data_type|udt_name"` descriptor denotes a timezone-aware timestamp.
///
/// The descriptor is split on `'|'` and only the first two fields are inspected, each with
/// surrounding whitespace trimmed. The result is `true` when the first field equals
/// `timestamp with time zone` or the second equals `timestamptz`, compared ASCII
/// case-insensitively.
pub fn postgres_column_descriptor_is_timestamptz(descriptor: &str) -> bool {
    let mut fields = descriptor.split('|').map(str::trim);
    let data_type = fields.next();
    let udt_name = fields.next();

    let named_with_zone =
        matches!(data_type, Some(name) if name.eq_ignore_ascii_case("timestamp with time zone"));
    let named_timestamptz =
        matches!(udt_name, Some(name) if name.eq_ignore_ascii_case("timestamptz"));
    named_with_zone || named_timestamptz
}
95
96pub async fn get_public_table_column_types(
111 pool: &PgPool,
112 table_name: &str,
113) -> Result<HashMap<String, String>, sqlx::Error> {
114 get_table_column_types(pool, "public", table_name).await
115}
116
117pub async fn get_public_table_unique_constraints(
129 pool: &PgPool,
130 table_name: &str,
131) -> Result<Vec<UniqueConstraintMetadata>, sqlx::Error> {
132 {
133 let cache: RwLockReadGuard<'_, TableUniqueConstraintsCache> =
134 TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
135 if let Some(constraints) = cache.get(table_name) {
136 return Ok(constraints.clone());
137 }
138 }
139
140 let rows: Vec<(String, String)> = sqlx::query_as::<_, (String, String)>(
142 r#"
143 SELECT tc.constraint_name, kcu.column_name
144 FROM information_schema.table_constraints AS tc
145 JOIN information_schema.key_column_usage AS kcu
146 ON tc.constraint_name = kcu.constraint_name
147 AND tc.table_schema = kcu.table_schema
148 AND tc.table_name = kcu.table_name
149 WHERE tc.table_schema = 'public'
150 AND tc.table_name = $1
151 AND tc.constraint_type = 'UNIQUE'
152 ORDER BY tc.constraint_name, kcu.ordinal_position
153 "#,
154 )
155 .bind(table_name)
156 .fetch_all(pool)
157 .await?;
158
159 let mut grouped: HashMap<String, Vec<String>> = HashMap::new();
160 for (constraint_name, column_name) in rows {
161 grouped
162 .entry(constraint_name)
163 .or_default()
164 .push(column_name);
165 }
166
167 let mut constraints: Vec<UniqueConstraintMetadata> = grouped
168 .into_iter()
169 .map(|(constraint_name, columns)| UniqueConstraintMetadata {
170 constraint_name,
171 columns,
172 })
173 .collect();
174 constraints.sort_by(|a, b| a.constraint_name.cmp(&b.constraint_name));
175
176 let mut cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
177 TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
178 cache.insert(table_name.to_string(), constraints.clone());
179
180 Ok(constraints)
181}
182
183pub async fn invalidate_public_table_metadata(table_name: &str) {
185 let mut column_cache: RwLockWriteGuard<'_, TableSchemaCache> =
186 TABLE_COLUMN_TYPE_CACHE.write().await;
187 column_cache.remove(&metadata_cache_key("public", table_name));
188 column_cache.remove(table_name);
189 drop(column_cache);
190
191 let mut constraint_cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
192 TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
193 constraint_cache.remove(table_name);
194}
195
196pub async fn invalidate_all_public_table_metadata() {
198 let mut column_cache: RwLockWriteGuard<'_, TableSchemaCache> =
199 TABLE_COLUMN_TYPE_CACHE.write().await;
200 column_cache.clear();
201 drop(column_cache);
202
203 let mut constraint_cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
204 TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
205 constraint_cache.clear();
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;
    use tokio::sync::Mutex;

    /// Serializes the tests that mutate the process-wide caches.
    static SCHEMA_CACHE_TEST_MUTEX: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));

    /// Column map with a single `id -> uuid` entry, used as the `users` fixture.
    fn users_columns() -> HashMap<String, String> {
        let mut columns = HashMap::new();
        columns.insert("id".to_string(), "uuid".to_string());
        columns
    }

    /// Constraint list with a single `users_email_key(email)` entry.
    fn users_constraints() -> Vec<UniqueConstraintMetadata> {
        vec![UniqueConstraintMetadata {
            constraint_name: "users_email_key".to_string(),
            columns: vec!["email".to_string()],
        }]
    }

    /// Replaces the whole column cache with `entries`.
    async fn seed_column_cache(entries: Vec<(&str, HashMap<String, String>)>) {
        let mut cache = TABLE_COLUMN_TYPE_CACHE.write().await;
        cache.clear();
        for (table, columns) in entries {
            cache.insert(table.to_string(), columns);
        }
    }

    /// Replaces the whole unique-constraint cache with `entries`.
    async fn seed_constraint_cache(entries: Vec<(&str, Vec<UniqueConstraintMetadata>)>) {
        let mut cache = TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
        cache.clear();
        for (table, constraints) in entries {
            cache.insert(table.to_string(), constraints);
        }
    }

    #[test]
    fn postgres_column_descriptor_is_bigint_matches_data_type_udt_pair() {
        for descriptor in ["bigint|int8", " bigint | int8 "] {
            assert!(postgres_column_descriptor_is_bigint(descriptor));
        }
        for descriptor in ["uuid|uuid", "text|text"] {
            assert!(!postgres_column_descriptor_is_bigint(descriptor));
        }
    }

    #[test]
    fn postgres_column_descriptor_is_timestamptz_matches_descriptor_pair() {
        for descriptor in [
            "timestamp with time zone|timestamptz",
            " timestamp with time zone | timestamptz ",
        ] {
            assert!(postgres_column_descriptor_is_timestamptz(descriptor));
        }
        assert!(!postgres_column_descriptor_is_timestamptz("bigint|int8"));
    }

    #[tokio::test]
    async fn invalidate_public_table_metadata_removes_only_target_table() {
        let _guard = SCHEMA_CACHE_TEST_MUTEX.lock().await;

        seed_column_cache(vec![
            ("users", users_columns()),
            ("orders", HashMap::new()),
        ])
        .await;
        seed_constraint_cache(vec![
            ("users", users_constraints()),
            ("orders", Vec::new()),
        ])
        .await;

        invalidate_public_table_metadata("users").await;

        {
            let columns = TABLE_COLUMN_TYPE_CACHE.read().await;
            assert!(!columns.contains_key("users"));
            assert!(columns.contains_key("orders"));
        }

        {
            let constraints = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
            assert!(!constraints.contains_key("users"));
            assert!(constraints.contains_key("orders"));
        }
    }

    #[tokio::test]
    async fn invalidate_all_public_table_metadata_clears_both_caches() {
        let _guard = SCHEMA_CACHE_TEST_MUTEX.lock().await;

        seed_column_cache(vec![("users", users_columns())]).await;
        seed_constraint_cache(vec![("users", users_constraints())]).await;

        invalidate_all_public_table_metadata().await;

        assert!(TABLE_COLUMN_TYPE_CACHE.read().await.is_empty());
        assert!(TABLE_UNIQUE_CONSTRAINT_CACHE.read().await.is_empty());
    }
}