json_register/
db.rs

1use sqlx::postgres::{PgPoolOptions, PgRow};
2use sqlx::{PgPool, Row};
3use std::time::Duration;
4
5/// Validates that an SQL identifier (table or column name) is safe to use.
6///
7/// # Arguments
8///
9/// * `identifier` - The identifier to validate.
10/// * `name` - A descriptive name for error messages (e.g., "table_name", "column_name").
11///
12/// # Returns
13///
14/// `Ok(())` if valid, or an error describing the issue.
15fn validate_sql_identifier(identifier: &str, name: &str) -> Result<(), sqlx::Error> {
16    if identifier.is_empty() {
17        return Err(sqlx::Error::Configuration(
18            format!("{} cannot be empty", name).into(),
19        ));
20    }
21
22    if identifier.len() > 63 {
23        return Err(sqlx::Error::Configuration(
24            format!("{} exceeds PostgreSQL's 63 character limit", name).into(),
25        ));
26    }
27
28    // Validate that identifier contains only safe characters: alphanumeric, underscore
29    // Must start with a letter or underscore
30    let first_char = identifier.chars().next().unwrap();
31    if !first_char.is_ascii_alphabetic() && first_char != '_' {
32        return Err(sqlx::Error::Configuration(
33            format!("{} must start with a letter or underscore", name).into(),
34        ));
35    }
36
37    for c in identifier.chars() {
38        if !c.is_ascii_alphanumeric() && c != '_' {
39            return Err(sqlx::Error::Configuration(
40                format!(
41                    "{} contains invalid character '{}'. Only alphanumeric and underscore allowed",
42                    name, c
43                )
44                .into(),
45            ));
46        }
47    }
48
49    Ok(())
50}
51
52/// Handles database interactions for registering JSON objects.
53///
54/// This struct manages the connection pool and executes SQL queries to insert
55/// or retrieve JSON objects. It uses optimized queries to handle concurrency
56/// and minimize round-trips.
57pub struct Db {
58    pool: PgPool,
59    register_query: String,
60    register_batch_query: String,
61}
62
63impl Db {
64    /// Creates a new `Db` instance.
65    ///
66    /// # Arguments
67    ///
68    /// * `connection_string` - The PostgreSQL connection string.
69    /// * `table_name` - The name of the table.
70    /// * `id_column` - The name of the ID column.
71    /// * `jsonb_column` - The name of the JSONB column.
72    /// * `pool_size` - The maximum number of connections in the pool.
73    /// * `acquire_timeout_secs` - Optional timeout for acquiring connections (default: 5s).
74    /// * `idle_timeout_secs` - Optional timeout for idle connections (default: 600s).
75    /// * `max_lifetime_secs` - Optional maximum lifetime for connections (default: 1800s).
76    ///
77    /// # Returns
78    ///
79    /// A `Result` containing the new `Db` instance or a `sqlx::Error`.
80    #[allow(clippy::too_many_arguments)]
81    pub async fn new(
82        connection_string: &str,
83        table_name: &str,
84        id_column: &str,
85        jsonb_column: &str,
86        pool_size: u32,
87        acquire_timeout_secs: Option<u64>,
88        idle_timeout_secs: Option<u64>,
89        max_lifetime_secs: Option<u64>,
90    ) -> Result<Self, sqlx::Error> {
91        // Validate SQL identifiers to prevent SQL injection
92        validate_sql_identifier(table_name, "table_name")?;
93        validate_sql_identifier(id_column, "id_column")?;
94        validate_sql_identifier(jsonb_column, "jsonb_column")?;
95
96        // Use provided timeouts or sensible defaults
97        let acquire_timeout = Duration::from_secs(acquire_timeout_secs.unwrap_or(5));
98        let idle_timeout = idle_timeout_secs.map(Duration::from_secs);
99        let max_lifetime = max_lifetime_secs.map(Duration::from_secs);
100
101        let pool = PgPoolOptions::new()
102            .max_connections(pool_size)
103            // Acquire timeout: get a connection from the pool
104            .acquire_timeout(acquire_timeout)
105            // Idle timeout: close connections idle for too long (default: 10 min)
106            .idle_timeout(idle_timeout.or(Some(Duration::from_secs(600))))
107            // Max lifetime: close connections after max age (default: 30 min)
108            .max_lifetime(max_lifetime.or(Some(Duration::from_secs(1800))))
109            .connect(connection_string)
110            .await
111            .map_err(|e| {
112                // Sanitize any connection strings that might appear in error messages
113                let error_msg = e.to_string();
114                let sanitized_msg = crate::sanitize_connection_string(&error_msg);
115                sqlx::Error::Configuration(sanitized_msg.into())
116            })?;
117
118        // Query to register a single object.
119        // It attempts to insert the object. If it exists (ON CONFLICT), it does nothing.
120        // Then it selects the ID, either from the inserted row or the existing row.
121        let register_query = format!(
122            r#"
123            WITH inserted AS (
124                INSERT INTO {table_name} ({jsonb_column})
125                VALUES ($1::jsonb)
126                ON CONFLICT ({jsonb_column}) DO NOTHING
127                RETURNING {id_column}
128            )
129            SELECT {id_column} FROM inserted
130            UNION ALL
131            SELECT {id_column} FROM {table_name}
132            WHERE {jsonb_column} = $2::jsonb
133              AND NOT EXISTS (SELECT 1 FROM inserted)
134            LIMIT 1
135            "#
136        );
137
138        // Query to register a batch of objects.
139        // It uses `unnest` to handle the array of inputs, attempts to insert new ones,
140        // and then joins the results to ensure every input gets its corresponding ID
141        // in the correct order.
142        let register_batch_query = format!(
143            r#"
144            WITH input_objects AS (
145                SELECT
146                    ord as original_order,
147                    value as json_value
148                FROM unnest($1::jsonb[]) WITH ORDINALITY AS t(value, ord)
149            ),
150            inserted AS (
151                INSERT INTO {table_name} ({jsonb_column})
152                SELECT json_value FROM input_objects
153                ON CONFLICT ({jsonb_column}) DO NOTHING
154                RETURNING {id_column}, {jsonb_column}
155            ),
156            existing AS (
157                SELECT t.{id_column}, t.{jsonb_column}
158                FROM {table_name} t
159                JOIN input_objects io ON t.{jsonb_column} = io.json_value
160            )
161            SELECT COALESCE(i.{id_column}, e.{id_column}) as {id_column}, io.original_order
162            FROM input_objects io
163            LEFT JOIN inserted i ON io.json_value = i.{jsonb_column}
164            LEFT JOIN existing e ON io.json_value = e.{jsonb_column}
165            ORDER BY io.original_order
166            "#
167        );
168
169        Ok(Self {
170            pool,
171            register_query,
172            register_batch_query,
173        })
174    }
175
176    /// Registers a single JSON object string in the database.
177    ///
178    /// # Arguments
179    ///
180    /// * `json_str` - The canonicalised JSON string.
181    ///
182    /// # Returns
183    ///
184    /// A `Result` containing the ID (i32) or a `sqlx::Error`.
185    pub async fn register_object(&self, json_str: &str) -> Result<i32, sqlx::Error> {
186        let row: PgRow = sqlx::query(&self.register_query)
187            .bind(json_str) // $1
188            .bind(json_str) // $2
189            .fetch_one(&self.pool)
190            .await?;
191
192        row.try_get(0)
193    }
194
195    /// Registers a batch of JSON object strings in the database.
196    ///
197    /// # Arguments
198    ///
199    /// * `json_strs` - A slice of canonicalised JSON strings.
200    ///
201    /// # Returns
202    ///
203    /// A `Result` containing a vector of IDs or a `sqlx::Error`.
204    pub async fn register_batch_objects(
205        &self,
206        json_strs: &[String],
207    ) -> Result<Vec<i32>, sqlx::Error> {
208        if json_strs.is_empty() {
209            return Ok(vec![]);
210        }
211
212        let rows = sqlx::query(&self.register_batch_query)
213            .bind(json_strs) // $1::jsonb[]
214            .fetch_all(&self.pool)
215            .await?;
216
217        let mut ids = Vec::with_capacity(rows.len());
218        for row in rows {
219            let id: i32 = row.try_get(0)?;
220            ids.push(id);
221        }
222
223        Ok(ids)
224    }
225
226    /// Returns the current size of the connection pool.
227    ///
228    /// This is the total number of connections (both idle and active) currently
229    /// in the pool. Useful for monitoring pool utilization.
230    ///
231    /// # Returns
232    ///
233    /// The number of connections in the pool.
234    pub fn pool_size(&self) -> usize {
235        self.pool.size() as usize
236    }
237
238    /// Returns the number of idle connections in the pool.
239    ///
240    /// Idle connections are available for immediate use. A low idle count
241    /// during high load may indicate the pool is undersized.
242    ///
243    /// # Returns
244    ///
245    /// The number of idle connections.
246    pub fn idle_connections(&self) -> usize {
247        self.pool.num_idle()
248    }
249
250    /// Checks if the connection pool is closed.
251    ///
252    /// A closed pool cannot create new connections and will error on acquire attempts.
253    ///
254    /// # Returns
255    ///
256    /// `true` if the pool is closed, `false` otherwise.
257    pub fn is_closed(&self) -> bool {
258        self.pool.is_closed()
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_sql_identifier_valid() {
        // Valid identifiers should pass
        assert!(validate_sql_identifier("table_name", "test").is_ok());
        assert!(validate_sql_identifier("_underscore", "test").is_ok());
        assert!(validate_sql_identifier("table123", "test").is_ok());
        assert!(validate_sql_identifier("CamelCase", "test").is_ok());
        assert!(validate_sql_identifier("snake_case_123", "test").is_ok());
    }

    #[test]
    fn test_validate_sql_identifier_empty() {
        // Empty identifier should fail
        let result = validate_sql_identifier("", "test_name");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_validate_sql_identifier_too_long() {
        // Identifier exceeding 63 characters should fail
        let long_name = "a".repeat(64);
        let result = validate_sql_identifier(&long_name, "test_name");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("63 character limit"));
    }

    #[test]
    fn test_validate_sql_identifier_starts_with_number() {
        // Identifier starting with number should fail
        let result = validate_sql_identifier("123table", "test_name");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("must start with a letter or underscore"));
    }

    #[test]
    fn test_validate_sql_identifier_invalid_characters() {
        // Identifiers with special characters should fail
        let test_cases = [
            "table-name",  // hyphen
            "table.name",  // dot
            "table name",  // space
            "table;name",  // semicolon
            "table'name",  // quote
            "table\"name", // double quote
            "table(name)", // parentheses
            "table*name",  // asterisk
            "table/name",  // slash
        ];

        for test_case in test_cases {
            let result = validate_sql_identifier(test_case, "test_name");
            assert!(result.is_err(), "Expected '{}' to be invalid", test_case);
            assert!(result
                .unwrap_err()
                .to_string()
                .contains("invalid character"));
        }
    }

    #[test]
    fn test_validate_sql_identifier_non_ascii() {
        // Only ASCII letters/digits are accepted, even though these characters
        // are alphabetic in Unicode.
        let result = validate_sql_identifier("éclair", "test_name");
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("must start with a letter or underscore"));

        let result = validate_sql_identifier("café", "test_name");
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("invalid character 'é'"));
    }

    #[test]
    fn test_validate_sql_identifier_error_mentions_name() {
        // The descriptive label must appear so callers know which setting is wrong.
        let err = validate_sql_identifier("", "jsonb_column").unwrap_err();
        assert!(err.to_string().contains("jsonb_column"));

        let err = validate_sql_identifier("bad-name", "id_column").unwrap_err();
        assert!(err.to_string().contains("id_column"));
    }

    #[test]
    fn test_validate_sql_identifier_boundary_cases() {
        // Test boundary cases
        assert!(validate_sql_identifier("a", "test").is_ok()); // Single character
        assert!(validate_sql_identifier("_", "test").is_ok()); // Just underscore

        let exactly_63 = "a".repeat(63);
        assert!(validate_sql_identifier(&exactly_63, "test").is_ok());
    }
}