1use sqlx::postgres::{PgPoolOptions, PgRow};
2use sqlx::{PgPool, Row};
3use std::time::Duration;
4
5fn validate_sql_identifier(identifier: &str, name: &str) -> Result<(), sqlx::Error> {
16 if identifier.is_empty() {
17 return Err(sqlx::Error::Configuration(
18 format!("{} cannot be empty", name).into(),
19 ));
20 }
21
22 if identifier.len() > 63 {
23 return Err(sqlx::Error::Configuration(
24 format!("{} exceeds PostgreSQL's 63 character limit", name).into(),
25 ));
26 }
27
28 let first_char = identifier.chars().next().unwrap();
31 if !first_char.is_ascii_alphabetic() && first_char != '_' {
32 return Err(sqlx::Error::Configuration(
33 format!("{} must start with a letter or underscore", name).into(),
34 ));
35 }
36
37 for c in identifier.chars() {
38 if !c.is_ascii_alphanumeric() && c != '_' {
39 return Err(sqlx::Error::Configuration(
40 format!(
41 "{} contains invalid character '{}'. Only alphanumeric and underscore allowed",
42 name, c
43 )
44 .into(),
45 ));
46 }
47 }
48
49 Ok(())
50}
51
52pub struct Db {
58 pool: PgPool,
59 register_query: String,
60 register_batch_query: String,
61}
62
63impl Db {
64 #[allow(clippy::too_many_arguments)]
81 pub async fn new(
82 connection_string: &str,
83 table_name: &str,
84 id_column: &str,
85 jsonb_column: &str,
86 pool_size: u32,
87 acquire_timeout_secs: Option<u64>,
88 idle_timeout_secs: Option<u64>,
89 max_lifetime_secs: Option<u64>,
90 ) -> Result<Self, sqlx::Error> {
91 validate_sql_identifier(table_name, "table_name")?;
93 validate_sql_identifier(id_column, "id_column")?;
94 validate_sql_identifier(jsonb_column, "jsonb_column")?;
95
96 let acquire_timeout = Duration::from_secs(acquire_timeout_secs.unwrap_or(5));
98 let idle_timeout = idle_timeout_secs.map(Duration::from_secs);
99 let max_lifetime = max_lifetime_secs.map(Duration::from_secs);
100
101 let pool = PgPoolOptions::new()
102 .max_connections(pool_size)
103 .acquire_timeout(acquire_timeout)
105 .idle_timeout(idle_timeout.or(Some(Duration::from_secs(600))))
107 .max_lifetime(max_lifetime.or(Some(Duration::from_secs(1800))))
109 .connect(connection_string)
110 .await
111 .map_err(|e| {
112 let error_msg = e.to_string();
114 let sanitized_msg = crate::sanitize_connection_string(&error_msg);
115 sqlx::Error::Configuration(sanitized_msg.into())
116 })?;
117
118 let register_query = format!(
122 r#"
123 WITH inserted AS (
124 INSERT INTO {table_name} ({jsonb_column})
125 VALUES ($1::jsonb)
126 ON CONFLICT ({jsonb_column}) DO NOTHING
127 RETURNING {id_column}
128 )
129 SELECT {id_column} FROM inserted
130 UNION ALL
131 SELECT {id_column} FROM {table_name}
132 WHERE {jsonb_column} = $2::jsonb
133 AND NOT EXISTS (SELECT 1 FROM inserted)
134 LIMIT 1
135 "#
136 );
137
138 let register_batch_query = format!(
143 r#"
144 WITH input_objects AS (
145 SELECT
146 ord as original_order,
147 value as json_value
148 FROM unnest($1::jsonb[]) WITH ORDINALITY AS t(value, ord)
149 ),
150 inserted AS (
151 INSERT INTO {table_name} ({jsonb_column})
152 SELECT json_value FROM input_objects
153 ON CONFLICT ({jsonb_column}) DO NOTHING
154 RETURNING {id_column}, {jsonb_column}
155 ),
156 existing AS (
157 SELECT t.{id_column}, t.{jsonb_column}
158 FROM {table_name} t
159 JOIN input_objects io ON t.{jsonb_column} = io.json_value
160 )
161 SELECT COALESCE(i.{id_column}, e.{id_column}) as {id_column}, io.original_order
162 FROM input_objects io
163 LEFT JOIN inserted i ON io.json_value = i.{jsonb_column}
164 LEFT JOIN existing e ON io.json_value = e.{jsonb_column}
165 ORDER BY io.original_order
166 "#
167 );
168
169 Ok(Self {
170 pool,
171 register_query,
172 register_batch_query,
173 })
174 }
175
176 pub async fn register_object(&self, json_str: &str) -> Result<i32, sqlx::Error> {
186 let row: PgRow = sqlx::query(&self.register_query)
187 .bind(json_str) .bind(json_str) .fetch_one(&self.pool)
190 .await?;
191
192 row.try_get(0)
193 }
194
195 pub async fn register_batch_objects(
205 &self,
206 json_strs: &[String],
207 ) -> Result<Vec<i32>, sqlx::Error> {
208 if json_strs.is_empty() {
209 return Ok(vec![]);
210 }
211
212 let rows = sqlx::query(&self.register_batch_query)
213 .bind(json_strs) .fetch_all(&self.pool)
215 .await?;
216
217 let mut ids = Vec::with_capacity(rows.len());
218 for row in rows {
219 let id: i32 = row.try_get(0)?;
220 ids.push(id);
221 }
222
223 Ok(ids)
224 }
225
226 pub fn pool_size(&self) -> usize {
235 self.pool.size() as usize
236 }
237
238 pub fn idle_connections(&self) -> usize {
247 self.pool.num_idle()
248 }
249
250 pub fn is_closed(&self) -> bool {
258 self.pool.is_closed()
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `identifier` is rejected and that the error text contains `fragment`.
    fn assert_rejected(identifier: &str, fragment: &str) {
        let message = match validate_sql_identifier(identifier, "test_name") {
            Ok(()) => panic!("Expected '{}' to be invalid", identifier),
            Err(err) => err.to_string(),
        };
        assert!(
            message.contains(fragment),
            "error {:?} should mention {:?}",
            message,
            fragment
        );
    }

    /// Asserts that `identifier` is accepted.
    fn assert_accepted(identifier: &str) {
        assert!(
            validate_sql_identifier(identifier, "test").is_ok(),
            "expected '{}' to be valid",
            identifier
        );
    }

    #[test]
    fn test_validate_sql_identifier_valid() {
        let accepted = [
            "table_name",
            "_underscore",
            "table123",
            "CamelCase",
            "snake_case_123",
        ];
        for identifier in accepted {
            assert_accepted(identifier);
        }
    }

    #[test]
    fn test_validate_sql_identifier_empty() {
        assert_rejected("", "cannot be empty");
    }

    #[test]
    fn test_validate_sql_identifier_too_long() {
        assert_rejected(&"a".repeat(64), "63 character limit");
    }

    #[test]
    fn test_validate_sql_identifier_starts_with_number() {
        assert_rejected("123table", "must start with a letter or underscore");
    }

    #[test]
    fn test_validate_sql_identifier_invalid_characters() {
        let rejected = [
            "table-name",
            "table.name",
            "table name",
            "table;name",
            "table'name",
            "table\"name",
            "table(name)",
            "table*name",
            "table/name",
        ];
        for identifier in rejected {
            assert_rejected(identifier, "invalid character");
        }
    }

    #[test]
    fn test_validate_sql_identifier_boundary_cases() {
        let exactly_63 = "a".repeat(63);
        for identifier in ["a", "_", exactly_63.as_str()] {
            assert_accepted(identifier);
        }
    }
}