1use sqlx::postgres::PgPoolOptions;
2use sqlx::{PgPool, Row};
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6fn validate_sql_identifier(identifier: &str, name: &str) -> Result<(), sqlx::Error> {
17 if identifier.is_empty() {
18 return Err(sqlx::Error::Configuration(
19 format!("{} cannot be empty", name).into(),
20 ));
21 }
22
23 if identifier.len() > 63 {
24 return Err(sqlx::Error::Configuration(
25 format!("{} exceeds PostgreSQL's 63 character limit", name).into(),
26 ));
27 }
28
29 let first_char = identifier.chars().next().unwrap();
32 if !first_char.is_ascii_alphabetic() && first_char != '_' {
33 return Err(sqlx::Error::Configuration(
34 format!("{} must start with a letter or underscore", name).into(),
35 ));
36 }
37
38 for c in identifier.chars() {
39 if !c.is_ascii_alphanumeric() && c != '_' {
40 return Err(sqlx::Error::Configuration(
41 format!(
42 "{} contains invalid character '{}'. Only alphanumeric and underscore allowed",
43 name, c
44 )
45 .into(),
46 ));
47 }
48 }
49
50 Ok(())
51}
52
/// A PostgreSQL connection pool bundled with the pre-built SQL text and the counters used by
/// the `register_*` methods.
pub struct Db {
    /// Connection pool that every query is executed against.
    pool: PgPool,
    /// SQL text used by `register_object`; built once in `Db::new`.
    register_query: String,
    /// SQL text used by `register_batch_objects`; built once in `Db::new`.
    register_batch_query: String,
    /// Number of queries the register methods have issued.
    queries_executed: AtomicU64,
    /// Number of failures recorded by the register methods.
    query_errors: AtomicU64,
}
65
66impl Db {
67 #[allow(clippy::too_many_arguments)]
84 pub async fn new(
85 connection_string: &str,
86 table_name: &str,
87 id_column: &str,
88 jsonb_column: &str,
89 pool_size: u32,
90 acquire_timeout_secs: Option<u64>,
91 idle_timeout_secs: Option<u64>,
92 max_lifetime_secs: Option<u64>,
93 ) -> Result<Self, sqlx::Error> {
94 validate_sql_identifier(table_name, "table_name")?;
96 validate_sql_identifier(id_column, "id_column")?;
97 validate_sql_identifier(jsonb_column, "jsonb_column")?;
98
99 let acquire_timeout = Duration::from_secs(acquire_timeout_secs.unwrap_or(5));
101 let idle_timeout = idle_timeout_secs.map(Duration::from_secs);
102 let max_lifetime = max_lifetime_secs.map(Duration::from_secs);
103
104 let pool = PgPoolOptions::new()
105 .max_connections(pool_size)
106 .acquire_timeout(acquire_timeout)
108 .idle_timeout(idle_timeout.or(Some(Duration::from_secs(600))))
110 .max_lifetime(max_lifetime.or(Some(Duration::from_secs(1800))))
112 .connect(connection_string)
113 .await
114 .map_err(|e| {
115 let error_msg = e.to_string();
117 let sanitized_msg = crate::sanitize_connection_string(&error_msg);
118 sqlx::Error::Configuration(sanitized_msg.into())
119 })?;
120
121 let register_query = format!(
125 r#"
126 WITH inserted AS (
127 INSERT INTO {table_name} ({jsonb_column})
128 VALUES ($1::jsonb)
129 ON CONFLICT ({jsonb_column}) DO NOTHING
130 RETURNING {id_column}
131 )
132 SELECT {id_column} FROM inserted
133 UNION ALL
134 SELECT {id_column} FROM {table_name}
135 WHERE {jsonb_column} = $2::jsonb
136 AND NOT EXISTS (SELECT 1 FROM inserted)
137 LIMIT 1
138 "#
139 );
140
141 let register_batch_query = format!(
146 r#"
147 WITH input_objects AS (
148 SELECT
149 ord as original_order,
150 value as json_value
151 FROM unnest($1::jsonb[]) WITH ORDINALITY AS t(value, ord)
152 ),
153 inserted AS (
154 INSERT INTO {table_name} ({jsonb_column})
155 SELECT json_value FROM input_objects
156 ON CONFLICT ({jsonb_column}) DO NOTHING
157 RETURNING {id_column}, {jsonb_column}
158 ),
159 existing AS (
160 SELECT t.{id_column}, t.{jsonb_column}
161 FROM {table_name} t
162 JOIN input_objects io ON t.{jsonb_column} = io.json_value
163 )
164 SELECT COALESCE(i.{id_column}, e.{id_column}) as {id_column}, io.original_order
165 FROM input_objects io
166 LEFT JOIN inserted i ON io.json_value = i.{jsonb_column}
167 LEFT JOIN existing e ON io.json_value = e.{jsonb_column}
168 ORDER BY io.original_order
169 "#
170 );
171
172 Ok(Self {
173 pool,
174 register_query,
175 register_batch_query,
176 queries_executed: AtomicU64::new(0),
177 query_errors: AtomicU64::new(0),
178 })
179 }
180
181 pub async fn register_object(&self, json_str: &str) -> Result<i32, sqlx::Error> {
191 self.queries_executed.fetch_add(1, Ordering::Relaxed);
192
193 let result = sqlx::query(&self.register_query)
194 .bind(json_str) .bind(json_str) .fetch_one(&self.pool)
197 .await;
198
199 match result {
200 Ok(row) => row.try_get(0),
201 Err(e) => {
202 self.query_errors.fetch_add(1, Ordering::Relaxed);
203 Err(e)
204 }
205 }
206 }
207
208 pub async fn register_batch_objects(
218 &self,
219 json_strs: &[String],
220 ) -> Result<Vec<i32>, sqlx::Error> {
221 if json_strs.is_empty() {
222 return Ok(vec![]);
223 }
224
225 self.queries_executed.fetch_add(1, Ordering::Relaxed);
226
227 let result = sqlx::query(&self.register_batch_query)
228 .bind(json_strs) .fetch_all(&self.pool)
230 .await;
231
232 match result {
233 Ok(rows) => {
234 let mut ids = Vec::with_capacity(rows.len());
235 for row in rows {
236 let id: i32 = row.try_get(0)?;
237 ids.push(id);
238 }
239 Ok(ids)
240 }
241 Err(e) => {
242 self.query_errors.fetch_add(1, Ordering::Relaxed);
243 Err(e)
244 }
245 }
246 }
247
248 pub fn pool_size(&self) -> usize {
257 self.pool.size() as usize
258 }
259
260 pub fn idle_connections(&self) -> usize {
269 self.pool.num_idle()
270 }
271
272 pub fn is_closed(&self) -> bool {
280 self.pool.is_closed()
281 }
282
283 pub fn queries_executed(&self) -> u64 {
289 self.queries_executed.load(Ordering::Relaxed)
290 }
291
292 pub fn query_errors(&self) -> u64 {
298 self.query_errors.load(Ordering::Relaxed)
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `identifier` is rejected and returns the rendered error message.
    fn rejection_message(identifier: &str) -> String {
        let outcome = validate_sql_identifier(identifier, "test_name");
        assert!(outcome.is_err(), "Expected '{}' to be invalid", identifier);
        outcome.unwrap_err().to_string()
    }

    /// Well-formed identifiers of several shapes are accepted.
    #[test]
    fn test_validate_sql_identifier_valid() {
        let accepted = [
            "table_name",
            "_underscore",
            "table123",
            "CamelCase",
            "snake_case_123",
        ];

        for identifier in accepted {
            assert!(validate_sql_identifier(identifier, "test").is_ok());
        }
    }

    /// The empty string is rejected with the "cannot be empty" message.
    #[test]
    fn test_validate_sql_identifier_empty() {
        let message = rejection_message("");
        assert!(message.contains("cannot be empty"));
    }

    /// A 64-character identifier is one past the limit and is rejected.
    #[test]
    fn test_validate_sql_identifier_too_long() {
        let oversized = "a".repeat(64);
        let message = rejection_message(&oversized);
        assert!(message.contains("63 character limit"));
    }

    /// A leading digit is rejected by the first-character rule.
    #[test]
    fn test_validate_sql_identifier_starts_with_number() {
        let message = rejection_message("123table");
        assert!(message.contains("must start with a letter or underscore"));
    }

    /// Punctuation, whitespace and quote characters are all rejected.
    #[test]
    fn test_validate_sql_identifier_invalid_characters() {
        let rejected = [
            "table-name",
            "table.name",
            "table name",
            "table;name",
            "table'name",
            "table\"name",
            "table(name)",
            "table*name",
            "table/name",
        ];

        for identifier in rejected {
            let message = rejection_message(identifier);
            assert!(message.contains("invalid character"));
        }
    }

    /// The shortest identifiers and one of exactly 63 characters are accepted.
    #[test]
    fn test_validate_sql_identifier_boundary_cases() {
        let at_limit = "a".repeat(63);

        for identifier in ["a", "_", at_limit.as_str()] {
            assert!(validate_sql_identifier(identifier, "test").is_ok());
        }
    }
}