// json_register/db.rs

1use deadpool::managed::{PoolError, QueueMode};
2use deadpool_postgres::{Config, ManagerConfig, Pool, RecyclingMethod, Runtime};
3use rustls::{ClientConfig, RootCertStore};
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use tokio_postgres::NoTls;
7use tokio_postgres_rustls::MakeRustlsConnect;
8use tracing::{debug, info, instrument, trace};
9
/// Checks that an SQL identifier (table or column name) is safe to splice
/// into query text.
///
/// Accepts only ASCII alphanumerics and underscores, starting with a letter
/// or underscore, at most 63 bytes (PostgreSQL's identifier limit).
///
/// # Arguments
///
/// * `identifier` - The identifier to validate.
/// * `name` - A descriptive label used in error messages (e.g., "table_name").
///
/// # Returns
///
/// `Ok(())` if valid, or an error describing the issue.
fn validate_sql_identifier(identifier: &str, name: &str) -> Result<(), String> {
    // An empty string has no first character; reject it up front.
    let Some(first) = identifier.chars().next() else {
        return Err(format!("{} cannot be empty", name));
    };

    if identifier.len() > 63 {
        return Err(format!("{} exceeds PostgreSQL's 63 character limit", name));
    }

    // Leading character must be a letter or underscore (no digits).
    if !(first.is_ascii_alphabetic() || first == '_') {
        return Err(format!("{} must start with a letter or underscore", name));
    }

    // Report the first character outside the allowed set, if any.
    match identifier
        .chars()
        .find(|&c| !c.is_ascii_alphanumeric() && c != '_')
    {
        Some(bad) => Err(format!(
            "{} contains invalid character '{}'. Only alphanumeric and underscore allowed",
            name, bad
        )),
        None => Ok(()),
    }
}
47
/// Handles database interactions for registering JSON objects.
///
/// This struct manages the connection pool and executes SQL queries to insert
/// or retrieve JSON objects. It uses optimized queries to handle concurrency
/// and minimize round-trips.
pub struct Db {
    /// Connection pool shared by all query methods.
    pool: Pool,
    /// Pre-rendered single-object upsert query (built once in `new`).
    register_query: String,
    /// Pre-rendered batch upsert query (built once in `new`).
    register_batch_query: String,
    /// Total queries attempted; incremented before execution, so it also
    /// counts attempts that later fail.
    queries_executed: AtomicU64,
    /// Total failures: failed pool acquisitions plus failed query executions.
    query_errors: AtomicU64,
}
60
impl Db {
    /// Creates a new `Db` instance.
    ///
    /// Validates the identifiers, builds the connection pool (optionally over
    /// TLS), and pre-renders the single and batch registration SQL so each
    /// call only has to bind parameters.
    ///
    /// # Arguments
    ///
    /// * `connection_string` - The PostgreSQL connection string.
    /// * `table_name` - The name of the table.
    /// * `id_column` - The name of the ID column.
    /// * `jsonb_column` - The name of the JSONB column.
    /// * `pool_size` - The maximum number of connections in the pool.
    /// * `acquire_timeout_secs` - Optional timeout for acquiring connections (default: 5s).
    /// * `idle_timeout_secs` - Accepted for backwards compatibility but
    ///   currently NOT applied to the pool; see the note in the body.
    /// * `max_lifetime_secs` - Accepted for backwards compatibility but
    ///   currently NOT applied to the pool; see the note in the body.
    /// * `use_tls` - Optional flag to enable TLS (default: false for backwards compatibility).
    ///
    /// # Returns
    ///
    /// A `Result` containing the new `Db` instance or a `JsonRegisterError`.
    #[allow(clippy::too_many_arguments)]
    pub async fn new(
        connection_string: &str,
        table_name: &str,
        id_column: &str,
        jsonb_column: &str,
        pool_size: u32,
        acquire_timeout_secs: Option<u64>,
        idle_timeout_secs: Option<u64>,
        max_lifetime_secs: Option<u64>,
        use_tls: Option<bool>,
    ) -> Result<Self, crate::errors::JsonRegisterError> {
        // Validate SQL identifiers to prevent SQL injection: these names are
        // spliced directly into the query text below, so they must be vetted.
        validate_sql_identifier(table_name, "table_name")
            .map_err(crate::errors::JsonRegisterError::Configuration)?;
        validate_sql_identifier(id_column, "id_column")
            .map_err(crate::errors::JsonRegisterError::Configuration)?;
        validate_sql_identifier(jsonb_column, "jsonb_column")
            .map_err(crate::errors::JsonRegisterError::Configuration)?;

        // Use the provided acquire timeout or a sensible default.
        let acquire_timeout = Duration::from_secs(acquire_timeout_secs.unwrap_or(5));
        // NOTE(review): the two values below are computed but never wired into
        // the pool configuration — the `PoolConfig`/`Timeouts` used here has no
        // idle-timeout or max-lifetime slot. The parameters are kept so the
        // public signature stays stable; apply them if the pool library gains
        // support for these settings.
        let _idle_timeout = idle_timeout_secs.map(Duration::from_secs);
        let _max_lifetime = max_lifetime_secs.map(Duration::from_secs);

        // Parse connection string into deadpool config.
        let mut cfg = Config::new();
        cfg.url = Some(connection_string.to_string());
        cfg.manager = Some(ManagerConfig {
            // Fast recycling skips the extra health-check round-trip.
            recycling_method: RecyclingMethod::Fast,
        });
        cfg.pool = Some(deadpool_postgres::PoolConfig {
            max_size: pool_size as usize,
            timeouts: deadpool_postgres::Timeouts {
                wait: Some(acquire_timeout),
                create: Some(Duration::from_secs(10)),
                recycle: Some(Duration::from_secs(10)),
            },
            queue_mode: QueueMode::Fifo,
        });

        let pool = if use_tls.unwrap_or(false) {
            // Create TLS connector trusting the compiled-in webpki (Mozilla)
            // root certificates — not the OS trust store.
            let mut root_store = RootCertStore::empty();
            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

            let config = ClientConfig::builder()
                .with_root_certificates(root_store)
                .with_no_client_auth();

            let tls = MakeRustlsConnect::new(config);

            // Sanitize the error text so credentials embedded in the
            // connection string never leak into logs or error messages.
            cfg.create_pool(Some(Runtime::Tokio1), tls).map_err(|e| {
                let error_msg = e.to_string();
                let sanitized_msg = crate::sanitize_connection_string(&error_msg);
                crate::errors::JsonRegisterError::Configuration(sanitized_msg)
            })?
        } else {
            // Plaintext connection; same credential-sanitizing error path.
            cfg.create_pool(Some(Runtime::Tokio1), NoTls).map_err(|e| {
                let error_msg = e.to_string();
                let sanitized_msg = crate::sanitize_connection_string(&error_msg);
                crate::errors::JsonRegisterError::Configuration(sanitized_msg)
            })?
        };

        // Query to register a single object.
        // It attempts to insert the object. If it exists (ON CONFLICT), it does nothing.
        // Then it selects the ID, either from the inserted row ($1) or the
        // existing row matched by value ($2 — same value, bound twice).
        // NOTE(review): if a concurrent transaction inserts the same value
        // after this statement's snapshot is taken, both branches can come up
        // empty — confirm callers handle/retry that case.
        let register_query = format!(
            r#"
            WITH inserted AS (
                INSERT INTO {table_name} ({jsonb_column})
                VALUES ($1)
                ON CONFLICT ({jsonb_column}) DO NOTHING
                RETURNING {id_column}
            )
            SELECT {id_column}::INT FROM inserted
            UNION ALL
            SELECT {id_column}::INT FROM {table_name}
            WHERE {jsonb_column} = $2
              AND NOT EXISTS (SELECT 1 FROM inserted)
            LIMIT 1
            "#
        );

        // Query to register a batch of objects.
        // It uses `unnest` to handle the array of inputs, attempts to insert new ones,
        // and then joins the results to ensure every input gets its corresponding ID
        // in the correct order (COALESCE prefers the freshly inserted row).
        let register_batch_query = format!(
            r#"
            WITH input_objects AS (
                SELECT
                    ord as original_order,
                    value as json_value
                FROM unnest($1::jsonb[]) WITH ORDINALITY AS t(value, ord)
            ),
            inserted AS (
                INSERT INTO {table_name} ({jsonb_column})
                SELECT json_value FROM input_objects
                ON CONFLICT ({jsonb_column}) DO NOTHING
                RETURNING {id_column}, {jsonb_column}
            ),
            existing AS (
                SELECT t.{id_column}, t.{jsonb_column}
                FROM {table_name} t
                JOIN input_objects io ON t.{jsonb_column} = io.json_value
            )
            SELECT COALESCE(i.{id_column}, e.{id_column})::INT as {id_column}, io.original_order
            FROM input_objects io
            LEFT JOIN inserted i ON io.json_value = i.{jsonb_column}
            LEFT JOIN existing e ON io.json_value = e.{jsonb_column}
            ORDER BY io.original_order
            "#
        );

        info!(
            table = table_name,
            pool_size,
            tls = use_tls.unwrap_or(false),
            "JSON register database connected"
        );

        Ok(Self {
            pool,
            register_query,
            register_batch_query,
            queries_executed: AtomicU64::new(0),
            query_errors: AtomicU64::new(0),
        })
    }

    /// Registers a single JSON object in the database.
    ///
    /// Inserts the value if new, otherwise resolves the existing row; either
    /// way the row's ID is returned.
    ///
    /// # Arguments
    ///
    /// * `value` - The JSON value to register.
    ///
    /// # Returns
    ///
    /// A `Result` containing the ID (i32) or a `tokio_postgres::Error`.
    #[instrument(skip(self, value), fields(query_type = "single"))]
    pub async fn register_object(
        &self,
        value: &serde_json::Value,
    ) -> Result<i32, tokio_postgres::Error> {
        // Counted up front, so failed attempts are included; failures are
        // additionally tracked in `query_errors`.
        self.queries_executed.fetch_add(1, Ordering::Relaxed);
        trace!("acquiring database connection");

        let client = self.pool.get().await.map_err(|e| {
            self.query_errors.fetch_add(1, Ordering::Relaxed);
            match e {
                // Surface real backend errors unchanged.
                PoolError::Backend(db_err) => db_err,
                // tokio_postgres exposes no public error constructor, so other
                // pool failures are mapped onto its timeout error.
                // NOTE(review): `__private_api_timeout` is not a stable API —
                // re-check this on every tokio-postgres upgrade.
                PoolError::Timeout(_) => tokio_postgres::Error::__private_api_timeout(),
                _ => tokio_postgres::Error::__private_api_timeout(),
            }
        })?;

        trace!("executing register query");
        // The same value is bound twice: $1 feeds the INSERT, $2 the fallback
        // SELECT of a pre-existing row.
        let result = client
            .query_one(&self.register_query, &[value, value])
            .await;

        match result {
            Ok(row) => {
                // Explicit ::INT cast in the SQL guarantees an i32 column.
                let id: i32 = row.get(0);
                debug!(id, "object registered");
                Ok(id)
            }
            Err(e) => {
                self.query_errors.fetch_add(1, Ordering::Relaxed);
                debug!(error = %e, "register query failed");
                Err(e)
            }
        }
    }

    /// Registers a batch of JSON objects in the database.
    ///
    /// A single round-trip upserts every value and returns IDs in the same
    /// order as `values`. An empty slice short-circuits without touching the
    /// pool.
    ///
    /// # Arguments
    ///
    /// * `values` - A slice of JSON values to register.
    ///
    /// # Returns
    ///
    /// A `Result` containing a vector of IDs or a `tokio_postgres::Error`.
    #[instrument(skip(self, values), fields(query_type = "batch", batch_size = values.len()))]
    pub async fn register_batch_objects(
        &self,
        values: &[serde_json::Value],
    ) -> Result<Vec<i32>, tokio_postgres::Error> {
        if values.is_empty() {
            trace!("empty batch, returning early");
            return Ok(vec![]);
        }

        // The whole batch counts as one query; see `register_object` for the
        // counting convention.
        self.queries_executed.fetch_add(1, Ordering::Relaxed);
        trace!("acquiring database connection");

        let client = self.pool.get().await.map_err(|e| {
            self.query_errors.fetch_add(1, Ordering::Relaxed);
            match e {
                PoolError::Backend(db_err) => db_err,
                // See `register_object`: private-API fallback for non-backend
                // pool errors. NOTE(review): not a stable API.
                PoolError::Timeout(_) => tokio_postgres::Error::__private_api_timeout(),
                _ => tokio_postgres::Error::__private_api_timeout(),
            }
        })?;

        trace!("executing batch register query");
        // Borrow each element so the slice maps onto a jsonb[] parameter.
        let values_vec: Vec<&serde_json::Value> = values.iter().collect();
        let result = client
            .query(&self.register_batch_query, &[&values_vec])
            .await;

        match result {
            Ok(rows) => {
                let mut ids = Vec::with_capacity(rows.len());
                for row in rows {
                    // Explicit ::INT cast in the SQL guarantees an i32 column.
                    let id: i32 = row.get(0);
                    ids.push(id);
                }
                debug!(registered = ids.len(), "batch registered");
                Ok(ids)
            }
            Err(e) => {
                self.query_errors.fetch_add(1, Ordering::Relaxed);
                debug!(error = %e, "batch register query failed");
                Err(e)
            }
        }
    }

    /// Returns the current size of the connection pool.
    ///
    /// This is the total number of connections (both idle and active) currently
    /// in the pool. Useful for monitoring pool utilization.
    ///
    /// # Returns
    ///
    /// The number of connections in the pool.
    pub fn pool_size(&self) -> usize {
        self.pool.status().size
    }

    /// Returns the number of idle connections in the pool.
    ///
    /// Idle connections are available for immediate use. A low idle count
    /// during high load may indicate the pool is undersized.
    ///
    /// # Returns
    ///
    /// The number of idle connections.
    pub fn idle_connections(&self) -> usize {
        self.pool.status().available
    }

    /// Checks if the connection pool is closed.
    ///
    /// A closed pool cannot create new connections and will error on acquire attempts.
    ///
    /// # Returns
    ///
    /// `true` if the pool is closed, `false` otherwise.
    pub fn is_closed(&self) -> bool {
        // NOTE(review): the deadpool version in use here apparently exposes no
        // closed flag, so a zero max_size is used as a proxy. Newer deadpool
        // releases provide `Pool::is_closed()` — prefer that if available.
        self.pool.status().max_size == 0
    }

    /// Returns the total number of database queries executed.
    ///
    /// Includes attempts that failed; see `query_errors` for the failure count.
    ///
    /// # Returns
    ///
    /// The total number of queries executed since instance creation.
    pub fn queries_executed(&self) -> u64 {
        self.queries_executed.load(Ordering::Relaxed)
    }

    /// Returns the total number of database query errors.
    ///
    /// Counts both failed pool acquisitions and failed query executions.
    ///
    /// # Returns
    ///
    /// The total number of failed queries since instance creation.
    pub fn query_errors(&self) -> u64 {
        self.query_errors.load(Ordering::Relaxed)
    }
}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed identifiers are accepted.
    #[test]
    fn test_validate_sql_identifier_valid() {
        assert!(validate_sql_identifier("table_name", "test").is_ok());
        assert!(validate_sql_identifier("_underscore", "test").is_ok());
        assert!(validate_sql_identifier("table123", "test").is_ok());
        assert!(validate_sql_identifier("CamelCase", "test").is_ok());
        assert!(validate_sql_identifier("snake_case_123", "test").is_ok());
    }

    /// Empty identifier is rejected with a specific message.
    ///
    /// `unwrap_err()` already yields a `String`, so the previous chained
    /// `.to_string()` calls were redundant (clippy: `string_to_string`).
    #[test]
    fn test_validate_sql_identifier_empty() {
        let result = validate_sql_identifier("", "test_name");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("cannot be empty"));
    }

    /// Identifiers over PostgreSQL's 63-character limit are rejected.
    #[test]
    fn test_validate_sql_identifier_too_long() {
        let long_name = "a".repeat(64);
        let result = validate_sql_identifier(&long_name, "test_name");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("63 character limit"));
    }

    /// Identifiers starting with a digit are rejected.
    #[test]
    fn test_validate_sql_identifier_starts_with_number() {
        let result = validate_sql_identifier("123table", "test_name");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("must start with a letter or underscore"));
    }

    /// Identifiers containing special characters are rejected.
    #[test]
    fn test_validate_sql_identifier_invalid_characters() {
        let test_cases = vec![
            "table-name",  // hyphen
            "table.name",  // dot
            "table name",  // space
            "table;name",  // semicolon
            "table'name",  // quote
            "table\"name", // double quote
            "table(name)", // parentheses
            "table*name",  // asterisk
            "table/name",  // slash
        ];

        for test_case in test_cases {
            let result = validate_sql_identifier(test_case, "test_name");
            assert!(result.is_err(), "Expected '{}' to be invalid", test_case);
            assert!(result.unwrap_err().contains("invalid character"));
        }
    }

    /// Boundary cases: minimal identifiers and the exact 63-character limit.
    #[test]
    fn test_validate_sql_identifier_boundary_cases() {
        assert!(validate_sql_identifier("a", "test").is_ok()); // Single character
        assert!(validate_sql_identifier("_", "test").is_ok()); // Just underscore

        let exactly_63 = "a".repeat(63);
        assert!(validate_sql_identifier(&exactly_63, "test").is_ok());
    }
}
447}