json_register/
db.rs

1use deadpool::managed::{PoolError, QueueMode};
2use deadpool_postgres::{Config, ManagerConfig, Pool, RecyclingMethod, Runtime};
3use rustls::{ClientConfig, RootCertStore};
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use tokio_postgres::NoTls;
7use tokio_postgres_rustls::MakeRustlsConnect;
8
/// Checks that a table or column name can be spliced into SQL text safely.
///
/// An identifier is accepted when it is non-empty, at most 63 bytes long,
/// begins with an ASCII letter or `_`, and consists solely of ASCII letters,
/// ASCII digits and `_`.
///
/// # Arguments
///
/// * `identifier` - The candidate identifier.
/// * `name` - Label used as the prefix of the error message
///   (e.g., "table_name", "column_name").
///
/// # Returns
///
/// `Ok(())` when the identifier is acceptable, otherwise a message describing
/// the first rule that was violated.
fn validate_sql_identifier(identifier: &str, name: &str) -> Result<(), String> {
    // Pulling the first character doubles as the emptiness check.
    let leading = match identifier.chars().next() {
        Some(ch) => ch,
        None => return Err(format!("{} cannot be empty", name)),
    };

    // `len()` counts bytes; for the ASCII-only names accepted below that is
    // the same as the character count.
    if identifier.len() > 63 {
        return Err(format!("{} exceeds PostgreSQL's 63 character limit", name));
    }

    if !(leading == '_' || leading.is_ascii_alphabetic()) {
        return Err(format!("{} must start with a letter or underscore", name));
    }

    // Report the first character outside [A-Za-z0-9_], if any.
    let offender = identifier
        .chars()
        .find(|ch| !(*ch == '_' || ch.is_ascii_alphanumeric()));

    match offender {
        Some(bad) => Err(format!(
            "{} contains invalid character '{}'. Only alphanumeric and underscore allowed",
            name, bad
        )),
        None => Ok(()),
    }
}
46
/// Handles database interactions for registering JSON objects.
///
/// This struct manages the connection pool and executes SQL queries to insert
/// or retrieve JSON objects. It uses optimized queries to handle concurrency
/// and minimize round-trips.
pub struct Db {
    /// Connection pool from which every registration call obtains its client.
    pool: Pool,
    /// SQL text for registering a single object; built once in `Db::new` from
    /// the validated table and column names.
    register_query: String,
    /// SQL text for registering a batch of objects; built once in `Db::new`
    /// from the validated table and column names.
    register_batch_query: String,
    /// Counter incremented once per registration call (single or non-empty
    /// batch), before a connection is acquired.
    queries_executed: AtomicU64,
    /// Counter incremented whenever acquiring a connection or running a
    /// registration query fails.
    query_errors: AtomicU64,
}
59
60impl Db {
61    /// Creates a new `Db` instance.
62    ///
63    /// # Arguments
64    ///
65    /// * `connection_string` - The PostgreSQL connection string.
66    /// * `table_name` - The name of the table.
67    /// * `id_column` - The name of the ID column.
68    /// * `jsonb_column` - The name of the JSONB column.
69    /// * `pool_size` - The maximum number of connections in the pool.
70    /// * `acquire_timeout_secs` - Optional timeout for acquiring connections (default: 5s).
71    /// * `idle_timeout_secs` - Optional timeout for idle connections (default: 600s).
72    /// * `max_lifetime_secs` - Optional maximum lifetime for connections (default: 1800s).
73    /// * `use_tls` - Optional flag to enable TLS (default: false for backwards compatibility).
74    ///
75    /// # Returns
76    ///
77    /// A `Result` containing the new `Db` instance or a `JsonRegisterError`.
78    #[allow(clippy::too_many_arguments)]
79    pub async fn new(
80        connection_string: &str,
81        table_name: &str,
82        id_column: &str,
83        jsonb_column: &str,
84        pool_size: u32,
85        acquire_timeout_secs: Option<u64>,
86        idle_timeout_secs: Option<u64>,
87        max_lifetime_secs: Option<u64>,
88        use_tls: Option<bool>,
89    ) -> Result<Self, crate::errors::JsonRegisterError> {
90        // Validate SQL identifiers to prevent SQL injection
91        validate_sql_identifier(table_name, "table_name")
92            .map_err(crate::errors::JsonRegisterError::Configuration)?;
93        validate_sql_identifier(id_column, "id_column")
94            .map_err(crate::errors::JsonRegisterError::Configuration)?;
95        validate_sql_identifier(jsonb_column, "jsonb_column")
96            .map_err(crate::errors::JsonRegisterError::Configuration)?;
97
98        // Use provided timeouts or sensible defaults
99        let acquire_timeout = Duration::from_secs(acquire_timeout_secs.unwrap_or(5));
100        let _idle_timeout = idle_timeout_secs.map(Duration::from_secs);
101        let _max_lifetime = max_lifetime_secs.map(Duration::from_secs);
102
103        // Parse connection string into deadpool config
104        let mut cfg = Config::new();
105        cfg.url = Some(connection_string.to_string());
106        cfg.manager = Some(ManagerConfig {
107            recycling_method: RecyclingMethod::Fast,
108        });
109        cfg.pool = Some(deadpool_postgres::PoolConfig {
110            max_size: pool_size as usize,
111            timeouts: deadpool_postgres::Timeouts {
112                wait: Some(acquire_timeout),
113                create: Some(Duration::from_secs(10)),
114                recycle: Some(Duration::from_secs(10)),
115            },
116            queue_mode: QueueMode::Fifo,
117        });
118
119        let pool = if use_tls.unwrap_or(false) {
120            // Create TLS connector with system root certificates
121            let mut root_store = RootCertStore::empty();
122            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
123
124            let config = ClientConfig::builder()
125                .with_root_certificates(root_store)
126                .with_no_client_auth();
127
128            let tls = MakeRustlsConnect::new(config);
129
130            cfg.create_pool(Some(Runtime::Tokio1), tls).map_err(|e| {
131                let error_msg = e.to_string();
132                let sanitized_msg = crate::sanitize_connection_string(&error_msg);
133                crate::errors::JsonRegisterError::Configuration(sanitized_msg)
134            })?
135        } else {
136            cfg.create_pool(Some(Runtime::Tokio1), NoTls).map_err(|e| {
137                let error_msg = e.to_string();
138                let sanitized_msg = crate::sanitize_connection_string(&error_msg);
139                crate::errors::JsonRegisterError::Configuration(sanitized_msg)
140            })?
141        };
142
143        // Query to register a single object.
144        // It attempts to insert the object. If it exists (ON CONFLICT), it does nothing.
145        // Then it selects the ID, either from the inserted row or the existing row.
146        let register_query = format!(
147            r#"
148            WITH inserted AS (
149                INSERT INTO {table_name} ({jsonb_column})
150                VALUES ($1)
151                ON CONFLICT ({jsonb_column}) DO NOTHING
152                RETURNING {id_column}
153            )
154            SELECT {id_column}::INT FROM inserted
155            UNION ALL
156            SELECT {id_column}::INT FROM {table_name}
157            WHERE {jsonb_column} = $2
158              AND NOT EXISTS (SELECT 1 FROM inserted)
159            LIMIT 1
160            "#
161        );
162
163        // Query to register a batch of objects.
164        // It uses `unnest` to handle the array of inputs, attempts to insert new ones,
165        // and then joins the results to ensure every input gets its corresponding ID
166        // in the correct order.
167        let register_batch_query = format!(
168            r#"
169            WITH input_objects AS (
170                SELECT
171                    ord as original_order,
172                    value as json_value
173                FROM unnest($1::jsonb[]) WITH ORDINALITY AS t(value, ord)
174            ),
175            inserted AS (
176                INSERT INTO {table_name} ({jsonb_column})
177                SELECT json_value FROM input_objects
178                ON CONFLICT ({jsonb_column}) DO NOTHING
179                RETURNING {id_column}, {jsonb_column}
180            ),
181            existing AS (
182                SELECT t.{id_column}, t.{jsonb_column}
183                FROM {table_name} t
184                JOIN input_objects io ON t.{jsonb_column} = io.json_value
185            )
186            SELECT COALESCE(i.{id_column}, e.{id_column})::INT as {id_column}, io.original_order
187            FROM input_objects io
188            LEFT JOIN inserted i ON io.json_value = i.{jsonb_column}
189            LEFT JOIN existing e ON io.json_value = e.{jsonb_column}
190            ORDER BY io.original_order
191            "#
192        );
193
194        Ok(Self {
195            pool,
196            register_query,
197            register_batch_query,
198            queries_executed: AtomicU64::new(0),
199            query_errors: AtomicU64::new(0),
200        })
201    }
202
203    /// Registers a single JSON object in the database.
204    ///
205    /// # Arguments
206    ///
207    /// * `value` - The JSON value to register.
208    ///
209    /// # Returns
210    ///
211    /// A `Result` containing the ID (i32) or a `tokio_postgres::Error`.
212    pub async fn register_object(
213        &self,
214        value: &serde_json::Value,
215    ) -> Result<i32, tokio_postgres::Error> {
216        self.queries_executed.fetch_add(1, Ordering::Relaxed);
217
218        let client = self.pool.get().await.map_err(|e| {
219            self.query_errors.fetch_add(1, Ordering::Relaxed);
220            match e {
221                PoolError::Backend(db_err) => db_err,
222                PoolError::Timeout(_) => tokio_postgres::Error::__private_api_timeout(),
223                _ => tokio_postgres::Error::__private_api_timeout(),
224            }
225        })?;
226
227        let result = client
228            .query_one(&self.register_query, &[value, value])
229            .await;
230
231        match result {
232            Ok(row) => {
233                // Explicitly cast to INT in SQL ensures we always get i32
234                let id: i32 = row.get(0);
235                Ok(id)
236            }
237            Err(e) => {
238                self.query_errors.fetch_add(1, Ordering::Relaxed);
239                Err(e)
240            }
241        }
242    }
243
244    /// Registers a batch of JSON objects in the database.
245    ///
246    /// # Arguments
247    ///
248    /// * `values` - A slice of JSON values to register.
249    ///
250    /// # Returns
251    ///
252    /// A `Result` containing a vector of IDs or a `tokio_postgres::Error`.
253    pub async fn register_batch_objects(
254        &self,
255        values: &[serde_json::Value],
256    ) -> Result<Vec<i32>, tokio_postgres::Error> {
257        if values.is_empty() {
258            return Ok(vec![]);
259        }
260
261        self.queries_executed.fetch_add(1, Ordering::Relaxed);
262
263        let client = self.pool.get().await.map_err(|e| {
264            self.query_errors.fetch_add(1, Ordering::Relaxed);
265            match e {
266                PoolError::Backend(db_err) => db_err,
267                PoolError::Timeout(_) => tokio_postgres::Error::__private_api_timeout(),
268                _ => tokio_postgres::Error::__private_api_timeout(),
269            }
270        })?;
271
272        let values_vec: Vec<&serde_json::Value> = values.iter().collect();
273        let result = client
274            .query(&self.register_batch_query, &[&values_vec])
275            .await;
276
277        match result {
278            Ok(rows) => {
279                let mut ids = Vec::with_capacity(rows.len());
280                for row in rows {
281                    // Explicitly cast to INT in SQL ensures we always get i32
282                    let id: i32 = row.get(0);
283                    ids.push(id);
284                }
285                Ok(ids)
286            }
287            Err(e) => {
288                self.query_errors.fetch_add(1, Ordering::Relaxed);
289                Err(e)
290            }
291        }
292    }
293
294    /// Returns the current size of the connection pool.
295    ///
296    /// This is the total number of connections (both idle and active) currently
297    /// in the pool. Useful for monitoring pool utilization.
298    ///
299    /// # Returns
300    ///
301    /// The number of connections in the pool.
302    pub fn pool_size(&self) -> usize {
303        self.pool.status().size
304    }
305
306    /// Returns the number of idle connections in the pool.
307    ///
308    /// Idle connections are available for immediate use. A low idle count
309    /// during high load may indicate the pool is undersized.
310    ///
311    /// # Returns
312    ///
313    /// The number of idle connections.
314    pub fn idle_connections(&self) -> usize {
315        self.pool.status().available
316    }
317
318    /// Checks if the connection pool is closed.
319    ///
320    /// A closed pool cannot create new connections and will error on acquire attempts.
321    ///
322    /// # Returns
323    ///
324    /// `true` if the pool is closed, `false` otherwise.
325    pub fn is_closed(&self) -> bool {
326        // deadpool doesn't have is_closed, check if pool is available
327        self.pool.status().max_size == 0
328    }
329
330    /// Returns the total number of database queries executed.
331    ///
332    /// # Returns
333    ///
334    /// The total number of queries executed since instance creation.
335    pub fn queries_executed(&self) -> u64 {
336        self.queries_executed.load(Ordering::Relaxed)
337    }
338
339    /// Returns the total number of database query errors.
340    ///
341    /// # Returns
342    ///
343    /// The total number of failed queries since instance creation.
344    pub fn query_errors(&self) -> u64 {
345        self.query_errors.load(Ordering::Relaxed)
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed identifiers are accepted.
    #[test]
    fn test_validate_sql_identifier_valid() {
        let accepted = [
            "table_name",
            "_underscore",
            "table123",
            "CamelCase",
            "snake_case_123",
        ];

        for candidate in accepted {
            assert!(validate_sql_identifier(candidate, "test").is_ok());
        }
    }

    /// The empty string is rejected with an "empty" message.
    #[test]
    fn test_validate_sql_identifier_empty() {
        let outcome = validate_sql_identifier("", "test_name");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err();
        assert!(message.contains("cannot be empty"));
    }

    /// Sixty-four characters is one past the allowed length.
    #[test]
    fn test_validate_sql_identifier_too_long() {
        let oversized = "a".repeat(64);

        let outcome = validate_sql_identifier(&oversized, "test_name");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err();
        assert!(message.contains("63 character limit"));
    }

    /// A leading digit is rejected.
    #[test]
    fn test_validate_sql_identifier_starts_with_number() {
        let outcome = validate_sql_identifier("123table", "test_name");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err();
        assert!(message.contains("must start with a letter or underscore"));
    }

    /// Punctuation and whitespace inside an identifier are rejected.
    #[test]
    fn test_validate_sql_identifier_invalid_characters() {
        let rejected = [
            "table-name",  // hyphen
            "table.name",  // dot
            "table name",  // space
            "table;name",  // semicolon
            "table'name",  // quote
            "table\"name", // double quote
            "table(name)", // parentheses
            "table*name",  // asterisk
            "table/name",  // slash
        ];

        for test_case in rejected {
            let outcome = validate_sql_identifier(test_case, "test_name");
            assert!(outcome.is_err(), "Expected '{}' to be invalid", test_case);

            let message = outcome.unwrap_err();
            assert!(message.contains("invalid character"));
        }
    }

    /// Shortest and longest acceptable identifiers pass.
    #[test]
    fn test_validate_sql_identifier_boundary_cases() {
        // Single character, then a lone underscore.
        for minimal in ["a", "_"] {
            assert!(validate_sql_identifier(minimal, "test").is_ok());
        }

        // Exactly at the length limit.
        let exactly_63 = "a".repeat(63);
        assert!(validate_sql_identifier(&exactly_63, "test").is_ok());
    }
}