1use deadpool::managed::{PoolError, QueueMode};
2use deadpool_postgres::{Config, ManagerConfig, Pool, RecyclingMethod, Runtime};
3use rustls::{ClientConfig, RootCertStore};
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use tokio_postgres::NoTls;
7use tokio_postgres_rustls::MakeRustlsConnect;
8use tracing::{debug, info, instrument, trace};
9
/// Checks that `identifier` is safe to interpolate into SQL as a bare
/// identifier: non-empty, at most 63 bytes (PostgreSQL's identifier limit),
/// starting with an ASCII letter or underscore, and containing only ASCII
/// alphanumerics and underscores.
///
/// `name` is the human-readable parameter name used in error messages.
fn validate_sql_identifier(identifier: &str, name: &str) -> Result<(), String> {
    if identifier.is_empty() {
        return Err(format!("{} cannot be empty", name));
    }

    // Byte length, matching PostgreSQL's NAMEDATALEN-1 (63 byte) limit.
    if identifier.len() > 63 {
        return Err(format!("{} exceeds PostgreSQL's 63 character limit", name));
    }

    // Non-empty was checked above, so a first char always exists.
    let leading = identifier.chars().next().unwrap();
    if !(leading.is_ascii_alphabetic() || leading == '_') {
        return Err(format!("{} must start with a letter or underscore", name));
    }

    // Reject on the first character outside [A-Za-z0-9_].
    if let Some(bad) = identifier
        .chars()
        .find(|&c| !c.is_ascii_alphanumeric() && c != '_')
    {
        return Err(format!(
            "{} contains invalid character '{}'. Only alphanumeric and underscore allowed",
            name, bad
        ));
    }

    Ok(())
}
47
/// PostgreSQL-backed JSON register: stores JSON objects deduplicated by
/// content and hands out their integer ids, over a deadpool connection pool.
pub struct Db {
    /// Connection pool configured in `Db::new`.
    pool: Pool,
    /// Pre-rendered single-object insert-or-lookup query (built in `Db::new`
    /// from the validated table/column identifiers).
    register_query: String,
    /// Pre-rendered batch insert-or-lookup query (built in `Db::new`).
    register_batch_query: String,
    /// Count of register queries attempted (single + batch), relaxed ordering.
    queries_executed: AtomicU64,
    /// Count of failed pool acquisitions and failed query executions.
    query_errors: AtomicU64,
}
60
61impl Db {
    /// Creates the connection pool and pre-renders both registration queries.
    ///
    /// `table_name`, `id_column` and `jsonb_column` are validated as SQL
    /// identifiers *before* being interpolated into the query strings below,
    /// which is what makes that interpolation injection-safe.
    ///
    /// Defaults: `acquire_timeout_secs` falls back to 5s; `use_tls` to false.
    ///
    /// # Errors
    /// Returns `JsonRegisterError::Configuration` when an identifier is
    /// invalid or pool creation fails; pool errors are passed through
    /// `sanitize_connection_string` so credentials embedded in the
    /// connection string do not leak into error messages.
    #[allow(clippy::too_many_arguments)]
    pub async fn new(
        connection_string: &str,
        table_name: &str,
        id_column: &str,
        jsonb_column: &str,
        pool_size: u32,
        acquire_timeout_secs: Option<u64>,
        idle_timeout_secs: Option<u64>,
        max_lifetime_secs: Option<u64>,
        use_tls: Option<bool>,
    ) -> Result<Self, crate::errors::JsonRegisterError> {
        // Reject unsafe identifiers before they reach the format!() calls below.
        validate_sql_identifier(table_name, "table_name")
            .map_err(crate::errors::JsonRegisterError::Configuration)?;
        validate_sql_identifier(id_column, "id_column")
            .map_err(crate::errors::JsonRegisterError::Configuration)?;
        validate_sql_identifier(jsonb_column, "jsonb_column")
            .map_err(crate::errors::JsonRegisterError::Configuration)?;

        let acquire_timeout = Duration::from_secs(acquire_timeout_secs.unwrap_or(5));
        // NOTE(review): idle_timeout_secs and max_lifetime_secs are converted
        // but never applied — deadpool's Timeouts below only covers
        // wait/create/recycle. Confirm whether these knobs should be wired up
        // (e.g. via a custom recycle check) or dropped from the signature.
        let _idle_timeout = idle_timeout_secs.map(Duration::from_secs);
        let _max_lifetime = max_lifetime_secs.map(Duration::from_secs);

        let mut cfg = Config::new();
        cfg.url = Some(connection_string.to_string());
        cfg.manager = Some(ManagerConfig {
            // Fast recycling: no round-trip health check on checkout.
            recycling_method: RecyclingMethod::Fast,
        });
        cfg.pool = Some(deadpool_postgres::PoolConfig {
            max_size: pool_size as usize,
            timeouts: deadpool_postgres::Timeouts {
                wait: Some(acquire_timeout),
                create: Some(Duration::from_secs(10)),
                recycle: Some(Duration::from_secs(10)),
            },
            queue_mode: QueueMode::Fifo,
        });

        // The two branches differ only in the TLS connector handed to deadpool.
        let pool = if use_tls.unwrap_or(false) {
            // Trust the bundled webpki root CAs; no client certificate.
            let mut root_store = RootCertStore::empty();
            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

            let config = ClientConfig::builder()
                .with_root_certificates(root_store)
                .with_no_client_auth();

            let tls = MakeRustlsConnect::new(config);

            cfg.create_pool(Some(Runtime::Tokio1), tls).map_err(|e| {
                // Scrub any echoed connection string before surfacing the error.
                let error_msg = e.to_string();
                let sanitized_msg = crate::sanitize_connection_string(&error_msg);
                crate::errors::JsonRegisterError::Configuration(sanitized_msg)
            })?
        } else {
            cfg.create_pool(Some(Runtime::Tokio1), NoTls).map_err(|e| {
                let error_msg = e.to_string();
                let sanitized_msg = crate::sanitize_connection_string(&error_msg);
                crate::errors::JsonRegisterError::Configuration(sanitized_msg)
            })?
        };

        // Insert-or-lookup for one object. $1 and $2 are both bound to the
        // same JSON value by register_object: the INSERT consumes $1, and the
        // fallback SELECT finds the pre-existing row via $2 when the
        // ON CONFLICT clause suppressed the insert.
        let register_query = format!(
            r#"
            WITH inserted AS (
                INSERT INTO {table_name} ({jsonb_column})
                VALUES ($1)
                ON CONFLICT ({jsonb_column}) DO NOTHING
                RETURNING {id_column}
            )
            SELECT {id_column}::INT FROM inserted
            UNION ALL
            SELECT {id_column}::INT FROM {table_name}
            WHERE {jsonb_column} = $2
            AND NOT EXISTS (SELECT 1 FROM inserted)
            LIMIT 1
            "#
        );

        // Batch variant: unnest the jsonb[] parameter (keeping input order via
        // WITH ORDINALITY), bulk-insert new values, then resolve each input's
        // id from either the freshly inserted rows or the existing table rows.
        // NOTE(review): duplicate JSON values within one batch would join to
        // the same row more than once — presumably callers de-duplicate or
        // rely on ORDER BY original_order; confirm against call sites.
        let register_batch_query = format!(
            r#"
            WITH input_objects AS (
                SELECT
                    ord as original_order,
                    value as json_value
                FROM unnest($1::jsonb[]) WITH ORDINALITY AS t(value, ord)
            ),
            inserted AS (
                INSERT INTO {table_name} ({jsonb_column})
                SELECT json_value FROM input_objects
                ON CONFLICT ({jsonb_column}) DO NOTHING
                RETURNING {id_column}, {jsonb_column}
            ),
            existing AS (
                SELECT t.{id_column}, t.{jsonb_column}
                FROM {table_name} t
                JOIN input_objects io ON t.{jsonb_column} = io.json_value
            )
            SELECT COALESCE(i.{id_column}, e.{id_column})::INT as {id_column}, io.original_order
            FROM input_objects io
            LEFT JOIN inserted i ON io.json_value = i.{jsonb_column}
            LEFT JOIN existing e ON io.json_value = e.{jsonb_column}
            ORDER BY io.original_order
            "#
        );

        info!(
            table = table_name,
            pool_size,
            tls = use_tls.unwrap_or(false),
            "JSON register database connected"
        );

        Ok(Self {
            pool,
            register_query,
            register_batch_query,
            queries_executed: AtomicU64::new(0),
            query_errors: AtomicU64::new(0),
        })
    }
210
211 #[instrument(skip(self, value), fields(query_type = "single"))]
221 pub async fn register_object(
222 &self,
223 value: &serde_json::Value,
224 ) -> Result<i32, tokio_postgres::Error> {
225 self.queries_executed.fetch_add(1, Ordering::Relaxed);
226 trace!("acquiring database connection");
227
228 let client = self.pool.get().await.map_err(|e| {
229 self.query_errors.fetch_add(1, Ordering::Relaxed);
230 match e {
231 PoolError::Backend(db_err) => db_err,
232 PoolError::Timeout(_) => tokio_postgres::Error::__private_api_timeout(),
233 _ => tokio_postgres::Error::__private_api_timeout(),
234 }
235 })?;
236
237 trace!("executing register query");
238 let result = client
239 .query_one(&self.register_query, &[value, value])
240 .await;
241
242 match result {
243 Ok(row) => {
244 let id: i32 = row.get(0);
246 debug!(id, "object registered");
247 Ok(id)
248 }
249 Err(e) => {
250 self.query_errors.fetch_add(1, Ordering::Relaxed);
251 debug!(error = %e, "register query failed");
252 Err(e)
253 }
254 }
255 }
256
257 #[instrument(skip(self, values), fields(query_type = "batch", batch_size = values.len()))]
267 pub async fn register_batch_objects(
268 &self,
269 values: &[serde_json::Value],
270 ) -> Result<Vec<i32>, tokio_postgres::Error> {
271 if values.is_empty() {
272 trace!("empty batch, returning early");
273 return Ok(vec![]);
274 }
275
276 self.queries_executed.fetch_add(1, Ordering::Relaxed);
277 trace!("acquiring database connection");
278
279 let client = self.pool.get().await.map_err(|e| {
280 self.query_errors.fetch_add(1, Ordering::Relaxed);
281 match e {
282 PoolError::Backend(db_err) => db_err,
283 PoolError::Timeout(_) => tokio_postgres::Error::__private_api_timeout(),
284 _ => tokio_postgres::Error::__private_api_timeout(),
285 }
286 })?;
287
288 trace!("executing batch register query");
289 let values_vec: Vec<&serde_json::Value> = values.iter().collect();
290 let result = client
291 .query(&self.register_batch_query, &[&values_vec])
292 .await;
293
294 match result {
295 Ok(rows) => {
296 let mut ids = Vec::with_capacity(rows.len());
297 for row in rows {
298 let id: i32 = row.get(0);
300 ids.push(id);
301 }
302 debug!(registered = ids.len(), "batch registered");
303 Ok(ids)
304 }
305 Err(e) => {
306 self.query_errors.fetch_add(1, Ordering::Relaxed);
307 debug!(error = %e, "batch register query failed");
308 Err(e)
309 }
310 }
311 }
312
313 pub fn pool_size(&self) -> usize {
322 self.pool.status().size
323 }
324
325 pub fn idle_connections(&self) -> usize {
334 self.pool.status().available
335 }
336
337 pub fn is_closed(&self) -> bool {
345 self.pool.status().max_size == 0
347 }
348
349 pub fn queries_executed(&self) -> u64 {
355 self.queries_executed.load(Ordering::Relaxed)
356 }
357
358 pub fn query_errors(&self) -> u64 {
364 self.query_errors.load(Ordering::Relaxed)
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_sql_identifier_valid() {
        assert!(validate_sql_identifier("table_name", "test").is_ok());
        assert!(validate_sql_identifier("_underscore", "test").is_ok());
        assert!(validate_sql_identifier("table123", "test").is_ok());
        assert!(validate_sql_identifier("CamelCase", "test").is_ok());
        assert!(validate_sql_identifier("snake_case_123", "test").is_ok());
    }

    #[test]
    fn test_validate_sql_identifier_empty() {
        let result = validate_sql_identifier("", "test_name");
        assert!(result.is_err());
        // The error type is already String — no `.to_string()` round trip needed.
        assert!(result.unwrap_err().contains("cannot be empty"));
    }

    #[test]
    fn test_validate_sql_identifier_too_long() {
        let long_name = "a".repeat(64);
        let result = validate_sql_identifier(&long_name, "test_name");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("63 character limit"));
    }

    #[test]
    fn test_validate_sql_identifier_starts_with_number() {
        let result = validate_sql_identifier("123table", "test_name");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("must start with a letter or underscore"));
    }

    #[test]
    fn test_validate_sql_identifier_invalid_characters() {
        // A fixed-size array is enough here; no Vec allocation required.
        let test_cases = [
            "table-name",
            "table.name",
            "table name",
            "table;name",
            "table'name",
            "table\"name",
            "table(name)",
            "table*name",
            "table/name",
        ];

        for test_case in test_cases {
            let result = validate_sql_identifier(test_case, "test_name");
            assert!(result.is_err(), "Expected '{}' to be invalid", test_case);
            assert!(result.unwrap_err().contains("invalid character"));
        }
    }

    #[test]
    fn test_validate_sql_identifier_boundary_cases() {
        // Minimum-length identifiers.
        assert!(validate_sql_identifier("a", "test").is_ok());
        assert!(validate_sql_identifier("_", "test").is_ok());
        // Exactly at the 63-byte PostgreSQL limit.
        let exactly_63 = "a".repeat(63);
        assert!(validate_sql_identifier(&exactly_63, "test").is_ok());
    }
}