json_register/
db.rs

1use sqlx::postgres::PgPoolOptions;
2use sqlx::{PgPool, Row};
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6/// Validates that an SQL identifier (table or column name) is safe to use.
7///
8/// # Arguments
9///
10/// * `identifier` - The identifier to validate.
11/// * `name` - A descriptive name for error messages (e.g., "table_name", "column_name").
12///
13/// # Returns
14///
15/// `Ok(())` if valid, or an error describing the issue.
16fn validate_sql_identifier(identifier: &str, name: &str) -> Result<(), sqlx::Error> {
17    if identifier.is_empty() {
18        return Err(sqlx::Error::Configuration(
19            format!("{} cannot be empty", name).into(),
20        ));
21    }
22
23    if identifier.len() > 63 {
24        return Err(sqlx::Error::Configuration(
25            format!("{} exceeds PostgreSQL's 63 character limit", name).into(),
26        ));
27    }
28
29    // Validate that identifier contains only safe characters: alphanumeric, underscore
30    // Must start with a letter or underscore
31    let first_char = identifier.chars().next().unwrap();
32    if !first_char.is_ascii_alphabetic() && first_char != '_' {
33        return Err(sqlx::Error::Configuration(
34            format!("{} must start with a letter or underscore", name).into(),
35        ));
36    }
37
38    for c in identifier.chars() {
39        if !c.is_ascii_alphanumeric() && c != '_' {
40            return Err(sqlx::Error::Configuration(
41                format!(
42                    "{} contains invalid character '{}'. Only alphanumeric and underscore allowed",
43                    name, c
44                )
45                .into(),
46            ));
47        }
48    }
49
50    Ok(())
51}
52
/// Handles database interactions for registering JSON objects.
///
/// This struct owns the PostgreSQL connection pool, the two SQL statements
/// built once in `Db::new` (single-object and batch registration), and a pair
/// of counters that track how many queries were issued and how many failed.
pub struct Db {
    /// Connection pool every query is executed against.
    pool: PgPool,
    /// SQL text used by `register_object`; built in `new` from the validated
    /// table and column names.
    register_query: String,
    /// SQL text used by `register_batch_objects`; built in `new` from the
    /// validated table and column names.
    register_batch_query: String,
    /// Number of queries issued; incremented once per register call that
    /// reaches the database.
    queries_executed: AtomicU64,
    /// Number of queries that returned an error.
    query_errors: AtomicU64,
}
65
66impl Db {
67    /// Creates a new `Db` instance.
68    ///
69    /// # Arguments
70    ///
71    /// * `connection_string` - The PostgreSQL connection string.
72    /// * `table_name` - The name of the table.
73    /// * `id_column` - The name of the ID column.
74    /// * `jsonb_column` - The name of the JSONB column.
75    /// * `pool_size` - The maximum number of connections in the pool.
76    /// * `acquire_timeout_secs` - Optional timeout for acquiring connections (default: 5s).
77    /// * `idle_timeout_secs` - Optional timeout for idle connections (default: 600s).
78    /// * `max_lifetime_secs` - Optional maximum lifetime for connections (default: 1800s).
79    ///
80    /// # Returns
81    ///
82    /// A `Result` containing the new `Db` instance or a `sqlx::Error`.
83    #[allow(clippy::too_many_arguments)]
84    pub async fn new(
85        connection_string: &str,
86        table_name: &str,
87        id_column: &str,
88        jsonb_column: &str,
89        pool_size: u32,
90        acquire_timeout_secs: Option<u64>,
91        idle_timeout_secs: Option<u64>,
92        max_lifetime_secs: Option<u64>,
93    ) -> Result<Self, sqlx::Error> {
94        // Validate SQL identifiers to prevent SQL injection
95        validate_sql_identifier(table_name, "table_name")?;
96        validate_sql_identifier(id_column, "id_column")?;
97        validate_sql_identifier(jsonb_column, "jsonb_column")?;
98
99        // Use provided timeouts or sensible defaults
100        let acquire_timeout = Duration::from_secs(acquire_timeout_secs.unwrap_or(5));
101        let idle_timeout = idle_timeout_secs.map(Duration::from_secs);
102        let max_lifetime = max_lifetime_secs.map(Duration::from_secs);
103
104        let pool = PgPoolOptions::new()
105            .max_connections(pool_size)
106            // Acquire timeout: get a connection from the pool
107            .acquire_timeout(acquire_timeout)
108            // Idle timeout: close connections idle for too long (default: 10 min)
109            .idle_timeout(idle_timeout.or(Some(Duration::from_secs(600))))
110            // Max lifetime: close connections after max age (default: 30 min)
111            .max_lifetime(max_lifetime.or(Some(Duration::from_secs(1800))))
112            .connect(connection_string)
113            .await
114            .map_err(|e| {
115                // Sanitize any connection strings that might appear in error messages
116                let error_msg = e.to_string();
117                let sanitized_msg = crate::sanitize_connection_string(&error_msg);
118                sqlx::Error::Configuration(sanitized_msg.into())
119            })?;
120
121        // Query to register a single object.
122        // It attempts to insert the object. If it exists (ON CONFLICT), it does nothing.
123        // Then it selects the ID, either from the inserted row or the existing row.
124        let register_query = format!(
125            r#"
126            WITH inserted AS (
127                INSERT INTO {table_name} ({jsonb_column})
128                VALUES ($1::jsonb)
129                ON CONFLICT ({jsonb_column}) DO NOTHING
130                RETURNING {id_column}
131            )
132            SELECT {id_column} FROM inserted
133            UNION ALL
134            SELECT {id_column} FROM {table_name}
135            WHERE {jsonb_column} = $2::jsonb
136              AND NOT EXISTS (SELECT 1 FROM inserted)
137            LIMIT 1
138            "#
139        );
140
141        // Query to register a batch of objects.
142        // It uses `unnest` to handle the array of inputs, attempts to insert new ones,
143        // and then joins the results to ensure every input gets its corresponding ID
144        // in the correct order.
145        let register_batch_query = format!(
146            r#"
147            WITH input_objects AS (
148                SELECT
149                    ord as original_order,
150                    value as json_value
151                FROM unnest($1::jsonb[]) WITH ORDINALITY AS t(value, ord)
152            ),
153            inserted AS (
154                INSERT INTO {table_name} ({jsonb_column})
155                SELECT json_value FROM input_objects
156                ON CONFLICT ({jsonb_column}) DO NOTHING
157                RETURNING {id_column}, {jsonb_column}
158            ),
159            existing AS (
160                SELECT t.{id_column}, t.{jsonb_column}
161                FROM {table_name} t
162                JOIN input_objects io ON t.{jsonb_column} = io.json_value
163            )
164            SELECT COALESCE(i.{id_column}, e.{id_column}) as {id_column}, io.original_order
165            FROM input_objects io
166            LEFT JOIN inserted i ON io.json_value = i.{jsonb_column}
167            LEFT JOIN existing e ON io.json_value = e.{jsonb_column}
168            ORDER BY io.original_order
169            "#
170        );
171
172        Ok(Self {
173            pool,
174            register_query,
175            register_batch_query,
176            queries_executed: AtomicU64::new(0),
177            query_errors: AtomicU64::new(0),
178        })
179    }
180
181    /// Registers a single JSON object string in the database.
182    ///
183    /// # Arguments
184    ///
185    /// * `json_str` - The canonicalised JSON string.
186    ///
187    /// # Returns
188    ///
189    /// A `Result` containing the ID (i32) or a `sqlx::Error`.
190    pub async fn register_object(&self, json_str: &str) -> Result<i32, sqlx::Error> {
191        self.queries_executed.fetch_add(1, Ordering::Relaxed);
192
193        let result = sqlx::query(&self.register_query)
194            .bind(json_str) // $1
195            .bind(json_str) // $2
196            .fetch_one(&self.pool)
197            .await;
198
199        match result {
200            Ok(row) => row.try_get(0),
201            Err(e) => {
202                self.query_errors.fetch_add(1, Ordering::Relaxed);
203                Err(e)
204            }
205        }
206    }
207
208    /// Registers a batch of JSON object strings in the database.
209    ///
210    /// # Arguments
211    ///
212    /// * `json_strs` - A slice of canonicalised JSON strings.
213    ///
214    /// # Returns
215    ///
216    /// A `Result` containing a vector of IDs or a `sqlx::Error`.
217    pub async fn register_batch_objects(
218        &self,
219        json_strs: &[String],
220    ) -> Result<Vec<i32>, sqlx::Error> {
221        if json_strs.is_empty() {
222            return Ok(vec![]);
223        }
224
225        self.queries_executed.fetch_add(1, Ordering::Relaxed);
226
227        let result = sqlx::query(&self.register_batch_query)
228            .bind(json_strs) // $1::jsonb[]
229            .fetch_all(&self.pool)
230            .await;
231
232        match result {
233            Ok(rows) => {
234                let mut ids = Vec::with_capacity(rows.len());
235                for row in rows {
236                    let id: i32 = row.try_get(0)?;
237                    ids.push(id);
238                }
239                Ok(ids)
240            }
241            Err(e) => {
242                self.query_errors.fetch_add(1, Ordering::Relaxed);
243                Err(e)
244            }
245        }
246    }
247
248    /// Returns the current size of the connection pool.
249    ///
250    /// This is the total number of connections (both idle and active) currently
251    /// in the pool. Useful for monitoring pool utilization.
252    ///
253    /// # Returns
254    ///
255    /// The number of connections in the pool.
256    pub fn pool_size(&self) -> usize {
257        self.pool.size() as usize
258    }
259
260    /// Returns the number of idle connections in the pool.
261    ///
262    /// Idle connections are available for immediate use. A low idle count
263    /// during high load may indicate the pool is undersized.
264    ///
265    /// # Returns
266    ///
267    /// The number of idle connections.
268    pub fn idle_connections(&self) -> usize {
269        self.pool.num_idle()
270    }
271
272    /// Checks if the connection pool is closed.
273    ///
274    /// A closed pool cannot create new connections and will error on acquire attempts.
275    ///
276    /// # Returns
277    ///
278    /// `true` if the pool is closed, `false` otherwise.
279    pub fn is_closed(&self) -> bool {
280        self.pool.is_closed()
281    }
282
283    /// Returns the total number of database queries executed.
284    ///
285    /// # Returns
286    ///
287    /// The total number of queries executed since instance creation.
288    pub fn queries_executed(&self) -> u64 {
289        self.queries_executed.load(Ordering::Relaxed)
290    }
291
292    /// Returns the total number of database query errors.
293    ///
294    /// # Returns
295    ///
296    /// The total number of failed queries since instance creation.
297    pub fn query_errors(&self) -> u64 {
298        self.query_errors.load(Ordering::Relaxed)
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `identifier` under the label "test_name" and returns the
    /// rendered error text, panicking if the identifier was accepted.
    fn rejection_text(identifier: &str) -> String {
        match validate_sql_identifier(identifier, "test_name") {
            Ok(()) => panic!("Expected '{}' to be invalid", identifier),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn test_validate_sql_identifier_valid() {
        // Well-formed identifiers are accepted.
        let accepted = [
            "table_name",
            "_underscore",
            "table123",
            "CamelCase",
            "snake_case_123",
        ];

        for identifier in accepted {
            assert!(validate_sql_identifier(identifier, "test").is_ok());
        }
    }

    #[test]
    fn test_validate_sql_identifier_empty() {
        // The empty string is rejected with a dedicated message.
        let message = rejection_text("");
        assert!(message.contains("cannot be empty"));
    }

    #[test]
    fn test_validate_sql_identifier_too_long() {
        // One character past the 63-character limit is rejected.
        let oversized = "a".repeat(64);
        let message = rejection_text(&oversized);
        assert!(message.contains("63 character limit"));
    }

    #[test]
    fn test_validate_sql_identifier_starts_with_number() {
        // A leading digit is rejected.
        let message = rejection_text("123table");
        assert!(message.contains("must start with a letter or underscore"));
    }

    #[test]
    fn test_validate_sql_identifier_invalid_characters() {
        // Each entry contains exactly one kind of forbidden character.
        let rejected = [
            "table-name",  // hyphen
            "table.name",  // dot
            "table name",  // space
            "table;name",  // semicolon
            "table'name",  // quote
            "table\"name", // double quote
            "table(name)", // parentheses
            "table*name",  // asterisk
            "table/name",  // slash
        ];

        for identifier in rejected {
            let message = rejection_text(identifier);
            assert!(message.contains("invalid character"));
        }
    }

    #[test]
    fn test_validate_sql_identifier_boundary_cases() {
        // Shortest possible identifiers: one letter, or a lone underscore.
        for minimal in ["a", "_"] {
            assert!(validate_sql_identifier(minimal, "test").is_ok());
        }

        // Exactly at the length limit is still accepted.
        let at_limit = "a".repeat(63);
        assert!(validate_sql_identifier(&at_limit, "test").is_ok());
    }
}