1use deadpool::managed::{PoolError, QueueMode};
2use deadpool_postgres::{Config, ManagerConfig, Pool, RecyclingMethod, Runtime};
3use rustls::{ClientConfig, RootCertStore};
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use tokio_postgres::NoTls;
7use tokio_postgres_rustls::MakeRustlsConnect;
8
/// Checks that `identifier` is safe to splice, unquoted, into generated SQL.
///
/// Accepted identifiers are non-empty, at most 63 bytes long (`str::len` counts bytes,
/// which matches PostgreSQL's identifier limit), start with an ASCII letter or `_`, and
/// contain only ASCII letters, digits and `_`. Checks run in that order, so the first
/// failing rule determines the message.
///
/// `name` is the human-readable label (e.g. `"table_name"`) used as the prefix of every
/// error message.
///
/// # Errors
///
/// Returns a descriptive `String` naming the first rule `identifier` violates.
fn validate_sql_identifier(identifier: &str, name: &str) -> Result<(), String> {
    let Some(leading) = identifier.chars().next() else {
        return Err(format!("{} cannot be empty", name));
    };

    if identifier.len() > 63 {
        return Err(format!("{} exceeds PostgreSQL's 63 character limit", name));
    }

    let starts_like_identifier = leading.is_ascii_alphabetic() || leading == '_';
    if !starts_like_identifier {
        return Err(format!("{} must start with a letter or underscore", name));
    }

    let is_word_char = |ch: char| ch.is_ascii_alphanumeric() || ch == '_';
    match identifier.chars().find(|&ch| !is_word_char(ch)) {
        Some(offending) => Err(format!(
            "{} contains invalid character '{}'. Only alphanumeric and underscore allowed",
            name, offending
        )),
        None => Ok(()),
    }
}
46
47pub struct Db {
53 pool: Pool,
54 register_query: String,
55 register_batch_query: String,
56 queries_executed: AtomicU64,
57 query_errors: AtomicU64,
58}
59
60impl Db {
61 #[allow(clippy::too_many_arguments)]
79 pub async fn new(
80 connection_string: &str,
81 table_name: &str,
82 id_column: &str,
83 jsonb_column: &str,
84 pool_size: u32,
85 acquire_timeout_secs: Option<u64>,
86 idle_timeout_secs: Option<u64>,
87 max_lifetime_secs: Option<u64>,
88 use_tls: Option<bool>,
89 ) -> Result<Self, crate::errors::JsonRegisterError> {
90 validate_sql_identifier(table_name, "table_name")
92 .map_err(crate::errors::JsonRegisterError::Configuration)?;
93 validate_sql_identifier(id_column, "id_column")
94 .map_err(crate::errors::JsonRegisterError::Configuration)?;
95 validate_sql_identifier(jsonb_column, "jsonb_column")
96 .map_err(crate::errors::JsonRegisterError::Configuration)?;
97
98 let acquire_timeout = Duration::from_secs(acquire_timeout_secs.unwrap_or(5));
100 let _idle_timeout = idle_timeout_secs.map(Duration::from_secs);
101 let _max_lifetime = max_lifetime_secs.map(Duration::from_secs);
102
103 let mut cfg = Config::new();
105 cfg.url = Some(connection_string.to_string());
106 cfg.manager = Some(ManagerConfig {
107 recycling_method: RecyclingMethod::Fast,
108 });
109 cfg.pool = Some(deadpool_postgres::PoolConfig {
110 max_size: pool_size as usize,
111 timeouts: deadpool_postgres::Timeouts {
112 wait: Some(acquire_timeout),
113 create: Some(Duration::from_secs(10)),
114 recycle: Some(Duration::from_secs(10)),
115 },
116 queue_mode: QueueMode::Fifo,
117 });
118
119 let pool = if use_tls.unwrap_or(false) {
120 let mut root_store = RootCertStore::empty();
122 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
123
124 let config = ClientConfig::builder()
125 .with_root_certificates(root_store)
126 .with_no_client_auth();
127
128 let tls = MakeRustlsConnect::new(config);
129
130 cfg.create_pool(Some(Runtime::Tokio1), tls).map_err(|e| {
131 let error_msg = e.to_string();
132 let sanitized_msg = crate::sanitize_connection_string(&error_msg);
133 crate::errors::JsonRegisterError::Configuration(sanitized_msg)
134 })?
135 } else {
136 cfg.create_pool(Some(Runtime::Tokio1), NoTls).map_err(|e| {
137 let error_msg = e.to_string();
138 let sanitized_msg = crate::sanitize_connection_string(&error_msg);
139 crate::errors::JsonRegisterError::Configuration(sanitized_msg)
140 })?
141 };
142
143 let register_query = format!(
147 r#"
148 WITH inserted AS (
149 INSERT INTO {table_name} ({jsonb_column})
150 VALUES ($1)
151 ON CONFLICT ({jsonb_column}) DO NOTHING
152 RETURNING {id_column}
153 )
154 SELECT {id_column}::INT FROM inserted
155 UNION ALL
156 SELECT {id_column}::INT FROM {table_name}
157 WHERE {jsonb_column} = $2
158 AND NOT EXISTS (SELECT 1 FROM inserted)
159 LIMIT 1
160 "#
161 );
162
163 let register_batch_query = format!(
168 r#"
169 WITH input_objects AS (
170 SELECT
171 ord as original_order,
172 value as json_value
173 FROM unnest($1::jsonb[]) WITH ORDINALITY AS t(value, ord)
174 ),
175 inserted AS (
176 INSERT INTO {table_name} ({jsonb_column})
177 SELECT json_value FROM input_objects
178 ON CONFLICT ({jsonb_column}) DO NOTHING
179 RETURNING {id_column}, {jsonb_column}
180 ),
181 existing AS (
182 SELECT t.{id_column}, t.{jsonb_column}
183 FROM {table_name} t
184 JOIN input_objects io ON t.{jsonb_column} = io.json_value
185 )
186 SELECT COALESCE(i.{id_column}, e.{id_column})::INT as {id_column}, io.original_order
187 FROM input_objects io
188 LEFT JOIN inserted i ON io.json_value = i.{jsonb_column}
189 LEFT JOIN existing e ON io.json_value = e.{jsonb_column}
190 ORDER BY io.original_order
191 "#
192 );
193
194 Ok(Self {
195 pool,
196 register_query,
197 register_batch_query,
198 queries_executed: AtomicU64::new(0),
199 query_errors: AtomicU64::new(0),
200 })
201 }
202
203 pub async fn register_object(
213 &self,
214 value: &serde_json::Value,
215 ) -> Result<i32, tokio_postgres::Error> {
216 self.queries_executed.fetch_add(1, Ordering::Relaxed);
217
218 let client = self.pool.get().await.map_err(|e| {
219 self.query_errors.fetch_add(1, Ordering::Relaxed);
220 match e {
221 PoolError::Backend(db_err) => db_err,
222 PoolError::Timeout(_) => tokio_postgres::Error::__private_api_timeout(),
223 _ => tokio_postgres::Error::__private_api_timeout(),
224 }
225 })?;
226
227 let result = client
228 .query_one(&self.register_query, &[value, value])
229 .await;
230
231 match result {
232 Ok(row) => {
233 let id: i32 = row.get(0);
235 Ok(id)
236 }
237 Err(e) => {
238 self.query_errors.fetch_add(1, Ordering::Relaxed);
239 Err(e)
240 }
241 }
242 }
243
244 pub async fn register_batch_objects(
254 &self,
255 values: &[serde_json::Value],
256 ) -> Result<Vec<i32>, tokio_postgres::Error> {
257 if values.is_empty() {
258 return Ok(vec![]);
259 }
260
261 self.queries_executed.fetch_add(1, Ordering::Relaxed);
262
263 let client = self.pool.get().await.map_err(|e| {
264 self.query_errors.fetch_add(1, Ordering::Relaxed);
265 match e {
266 PoolError::Backend(db_err) => db_err,
267 PoolError::Timeout(_) => tokio_postgres::Error::__private_api_timeout(),
268 _ => tokio_postgres::Error::__private_api_timeout(),
269 }
270 })?;
271
272 let values_vec: Vec<&serde_json::Value> = values.iter().collect();
273 let result = client
274 .query(&self.register_batch_query, &[&values_vec])
275 .await;
276
277 match result {
278 Ok(rows) => {
279 let mut ids = Vec::with_capacity(rows.len());
280 for row in rows {
281 let id: i32 = row.get(0);
283 ids.push(id);
284 }
285 Ok(ids)
286 }
287 Err(e) => {
288 self.query_errors.fetch_add(1, Ordering::Relaxed);
289 Err(e)
290 }
291 }
292 }
293
294 pub fn pool_size(&self) -> usize {
303 self.pool.status().size
304 }
305
306 pub fn idle_connections(&self) -> usize {
315 self.pool.status().available
316 }
317
318 pub fn is_closed(&self) -> bool {
326 self.pool.status().max_size == 0
328 }
329
330 pub fn queries_executed(&self) -> u64 {
336 self.queries_executed.load(Ordering::Relaxed)
337 }
338
339 pub fn query_errors(&self) -> u64 {
345 self.query_errors.load(Ordering::Relaxed)
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `identifier` is rejected with a message containing `fragment`.
    fn assert_rejected(identifier: &str, fragment: &str) {
        match validate_sql_identifier(identifier, "test_name") {
            Ok(()) => panic!("expected '{identifier}' to be rejected"),
            Err(message) => assert!(
                message.contains(fragment),
                "unexpected error for '{identifier}': {message}"
            ),
        }
    }

    #[test]
    fn test_validate_sql_identifier_valid() {
        for identifier in [
            "table_name",
            "_underscore",
            "table123",
            "CamelCase",
            "snake_case_123",
        ] {
            assert!(
                validate_sql_identifier(identifier, "test").is_ok(),
                "expected '{identifier}' to be valid"
            );
        }
    }

    #[test]
    fn test_validate_sql_identifier_empty() {
        assert_rejected("", "cannot be empty");
    }

    #[test]
    fn test_validate_sql_identifier_too_long() {
        assert_rejected(&"a".repeat(64), "63 character limit");
    }

    #[test]
    fn test_validate_sql_identifier_length_is_measured_in_bytes() {
        // 32 two-byte characters = 64 bytes; PostgreSQL's identifier limit is in bytes.
        assert_rejected(&"\u{e9}".repeat(32), "63 character limit");
    }

    #[test]
    fn test_validate_sql_identifier_starts_with_number() {
        assert_rejected("123table", "must start with a letter or underscore");
    }

    #[test]
    fn test_validate_sql_identifier_starts_with_symbol_or_non_ascii() {
        for identifier in ["-table", "$table", "\u{e9}table"] {
            assert_rejected(identifier, "must start with a letter or underscore");
        }
    }

    #[test]
    fn test_validate_sql_identifier_invalid_characters() {
        for identifier in [
            "table-name",
            "table.name",
            "table name",
            "table;name",
            "table'name",
            "table\"name",
            "table(name)",
            "table*name",
            "table/name",
            "tabl\u{e9}",
            "table\u{1F600}",
        ] {
            assert_rejected(identifier, "invalid character");
        }
    }

    #[test]
    fn test_validate_sql_identifier_reports_first_invalid_character() {
        assert_eq!(
            validate_sql_identifier("a-b;c", "col"),
            Err(
                "col contains invalid character '-'. Only alphanumeric and underscore allowed"
                    .to_string()
            )
        );
    }

    #[test]
    fn test_validate_sql_identifier_uses_label_in_message() {
        assert_eq!(
            validate_sql_identifier("", "table_name"),
            Err("table_name cannot be empty".to_string())
        );
    }

    #[test]
    fn test_validate_sql_identifier_boundary_cases() {
        assert!(validate_sql_identifier("a", "test").is_ok());
        assert!(validate_sql_identifier("_", "test").is_ok());
        assert!(validate_sql_identifier(&"a".repeat(63), "test").is_ok());
    }
}