uri_register/postgres.rs
1// Copyright TELICENT LTD
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::cache::{create_cache, Cache, CacheStrategy};
16use crate::error::{ConfigurationError, Result};
17use crate::service::UriService;
18use async_trait::async_trait;
19use deadpool_postgres::{ManagerConfig, Pool, RecyclingMethod, Runtime};
20use rustls::RootCertStore;
21use std::sync::Arc;
22use tokio_postgres::{Config, NoTls};
23use tokio_postgres_rustls::MakeRustlsConnect;
24use tracing::{debug, info, instrument, trace};
25use url::Url;
26
/// PostgreSQL-based URI register implementation with configurable caching
///
/// This implementation uses a PostgreSQL table to store URI-to-ID mappings
/// with an in-memory cache (W-TinyLFU by default, or LRU) to reduce database round-trips.
/// It's designed for high concurrency with connection pooling and batch operations.
///
/// ## Prerequisites
///
/// The database schema must be initialized before using this service.
/// See `schema.sql` for the DDL statements.
///
/// ## URI Validation
///
/// All URIs are validated before registration by parsing them with `url::Url`.
/// URIs that fail to parse are rejected with an error and never reach the
/// cache or the database.
///
/// ## Performance
///
/// Indicative figures with default logged tables on typical hardware
/// (approximate; they depend on hardware, schema and workload):
/// - Batch insert: ~10K-50K URIs/sec
/// - Batch lookup (cached): ~100K-1M+ URIs/sec (no DB round-trip)
/// - Batch lookup (uncached): ~100K-200K URIs/sec
/// - Query overhead: ~2-10ms per query (2 round-trips)
///
/// The cache (W-TinyLFU or LRU) significantly improves performance for repeated URI lookups.
/// Cache strategy and size are configurable when creating the register instance.
///
/// For faster writes at the cost of durability, the table can be configured
/// as UNLOGGED (see `schema.sql` for options).
pub struct PostgresUriRegister {
    /// Connection pool from which every query obtains its database client
    pool: Pool,
    /// Cache for URI-to-ID mappings (W-TinyLFU or LRU)
    cache: Arc<dyn Cache>,
    /// Name of the database table to use; validated as a plain SQL identifier
    /// at construction time because it is interpolated into query text
    table_name: String,
}
63
64impl PostgresUriRegister {
65 /// Create a new PostgreSQL URI register service with configurable cache
66 ///
67 /// # Arguments
68 ///
69 /// * `database_url` - PostgreSQL connection string (e.g., "postgres://user:password@host:port/database")
70 /// * `table_name` - Name of the database table to use (must be a valid SQL identifier, default: "uri_register")
71 /// * `max_connections` - Maximum number of connections in the pool (recommended: 10-50)
72 /// * `cache_size` - Number of URI-to-ID mappings to cache in memory (recommended: 1,000-100,000)
73 /// * `cache_strategy` - Cache strategy to use (Moka/W-TinyLFU is default and recommended for most workloads)
74 ///
75 /// # Prerequisites
76 ///
77 /// The database schema must be initialized before using this service.
78 /// See the `schema.sql` file and README.md for setup instructions.
79 ///
80 /// # Example
81 ///
82 /// ```rust,no_run
83 /// use uri_register::PostgresUriRegister;
84 ///
85 /// #[tokio::main]
86 /// async fn main() -> uri_register::Result<()> {
87 /// let register = PostgresUriRegister::new(
88 /// "postgres://localhost/mydb",
89 /// "uri_register", // table name
90 /// 20, // max connections
91 /// 10_000 // cache size (defaults to Moka/W-TinyLFU)
92 /// ).await?;
93 /// Ok(())
94 /// }
95 /// ```
96 pub async fn new(
97 database_url: &str,
98 table_name: &str,
99 max_connections: u32,
100 cache_size: usize,
101 ) -> Result<Self> {
102 Self::new_with_cache_strategy(
103 database_url,
104 table_name,
105 max_connections,
106 cache_size,
107 None, // Default to Moka
108 None, // Default to no TLS
109 )
110 .await
111 }
112
113 /// Create a new PostgreSQL URI register with a specific cache strategy and TLS
114 ///
115 /// This is identical to `new()` but allows specifying a cache strategy and TLS option.
116 /// Most users should use `new()` which defaults to the recommended Moka (W-TinyLFU) cache and no TLS.
117 ///
118 /// # Arguments
119 ///
120 /// * `cache_strategy` - Optional cache strategy (None = Moka default, or specify CacheStrategy::Lru)
121 /// * `use_tls` - Optional TLS flag (None/false = no TLS, true = TLS with webpki root certificates)
122 ///
123 /// # Example
124 ///
125 /// ```rust,no_run
126 /// use uri_register::{CacheStrategy, PostgresUriRegister};
127 ///
128 /// #[tokio::main]
129 /// async fn main() -> uri_register::Result<()> {
130 /// // Use LRU instead of default Moka, with TLS enabled
131 /// let register = PostgresUriRegister::new_with_cache_strategy(
132 /// "postgres://localhost/mydb",
133 /// "uri_register",
134 /// 20,
135 /// 10_000,
136 /// Some(CacheStrategy::Lru),
137 /// Some(true) // Enable TLS
138 /// ).await?;
139 /// Ok(())
140 /// }
141 /// ```
142 pub async fn new_with_cache_strategy(
143 database_url: &str,
144 table_name: &str,
145 max_connections: u32,
146 cache_size: usize,
147 cache_strategy: Option<CacheStrategy>,
148 use_tls: Option<bool>,
149 ) -> Result<Self> {
150 // Validate inputs
151 if cache_size == 0 {
152 return Err(ConfigurationError::InvalidCacheSize(cache_size).into());
153 }
154
155 if max_connections == 0 {
156 return Err(ConfigurationError::InvalidMaxConnections(max_connections).into());
157 }
158
159 // Validate table name as SQL identifier
160 Self::validate_table_name(table_name)?;
161
162 // Parse the database URL into tokio-postgres Config
163 let pg_config: Config = database_url.parse().map_err(|e| {
164 ConfigurationError::InvalidBackoff(format!("Failed to parse database URL: {}", e))
165 })?;
166
167 // Create deadpool configuration
168 let mut cfg = deadpool_postgres::Config::new();
169 cfg.dbname = pg_config.get_dbname().map(|s| s.to_string());
170 cfg.host = pg_config.get_hosts().first().map(|h| match h {
171 tokio_postgres::config::Host::Tcp(s) => s.to_string(),
172 #[cfg(unix)]
173 tokio_postgres::config::Host::Unix(p) => p.to_str().unwrap_or_default().to_string(),
174 });
175 cfg.port = pg_config.get_ports().first().copied();
176 cfg.user = pg_config.get_user().map(|s| s.to_string());
177 cfg.password = pg_config
178 .get_password()
179 .map(|p| std::str::from_utf8(p).unwrap_or_default().to_string());
180 cfg.manager = Some(ManagerConfig {
181 recycling_method: RecyclingMethod::Fast,
182 });
183 cfg.pool = Some(deadpool_postgres::PoolConfig {
184 max_size: max_connections as usize,
185 timeouts: deadpool_postgres::Timeouts {
186 wait: Some(std::time::Duration::from_secs(10)),
187 create: Some(std::time::Duration::from_secs(10)),
188 recycle: Some(std::time::Duration::from_secs(10)),
189 },
190 ..Default::default()
191 });
192
193 // Create the pool with or without TLS based on use_tls parameter
194 let pool = if use_tls.unwrap_or(false) {
195 // Configure TLS with webpki root certificates
196 let mut root_store = RootCertStore::empty();
197 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
198
199 let tls_config = rustls::ClientConfig::builder()
200 .with_root_certificates(root_store)
201 .with_no_client_auth();
202
203 let tls = MakeRustlsConnect::new(tls_config);
204
205 cfg.create_pool(Some(Runtime::Tokio1), tls).map_err(|e| {
206 ConfigurationError::InvalidBackoff(format!(
207 "Failed to create connection pool with TLS: {}",
208 e
209 ))
210 })?
211 } else {
212 // No TLS
213 cfg.create_pool(Some(Runtime::Tokio1), NoTls).map_err(|e| {
214 ConfigurationError::InvalidBackoff(format!(
215 "Failed to create connection pool: {}",
216 e
217 ))
218 })?
219 };
220
221 let cache = create_cache(cache_strategy.unwrap_or_default(), cache_size);
222
223 info!(
224 table = table_name,
225 max_connections,
226 cache_size,
227 tls = use_tls.unwrap_or(false),
228 "URI register connected"
229 );
230
231 Ok(Self {
232 pool,
233 cache,
234 table_name: table_name.to_string(),
235 })
236 }
237
238 /// Validate that a table name is a valid SQL identifier
239 ///
240 /// Prevents SQL injection by ensuring the table name only contains
241 /// alphanumeric characters and underscores, and doesn't start with a digit.
242 fn validate_table_name(name: &str) -> Result<()> {
243 if name.is_empty() {
244 return Err(ConfigurationError::InvalidTableName(
245 "table name cannot be empty".to_string(),
246 )
247 .into());
248 }
249
250 if name.len() > 63 {
251 return Err(ConfigurationError::InvalidTableName(format!(
252 "table name too long (max 63 characters): '{}'",
253 name
254 ))
255 .into());
256 }
257
258 // First character must be a letter or underscore
259 let first_char = name.chars().next().unwrap();
260 if !first_char.is_ascii_alphabetic() && first_char != '_' {
261 return Err(ConfigurationError::InvalidTableName(format!(
262 "table name must start with a letter or underscore: '{}'",
263 name
264 ))
265 .into());
266 }
267
268 // All characters must be alphanumeric or underscore
269 if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
270 return Err(ConfigurationError::InvalidTableName(format!(
271 "table name can only contain letters, numbers, and underscores: '{}'",
272 name
273 ))
274 .into());
275 }
276
277 Ok(())
278 }
279
280 /// Get statistics about the URI register
281 ///
282 /// Returns the total number of URIs and the storage size.
283 ///
284 /// # Example
285 ///
286 /// ```rust,no_run
287 /// use uri_register::PostgresUriRegister;
288 ///
289 /// #[tokio::main]
290 /// async fn main() -> uri_register::Result<()> {
291 /// let register = PostgresUriRegister::new(
292 /// "postgres://localhost/mydb",
293 /// "uri_register",
294 /// 20,
295 /// 10_000
296 /// ).await?;
297 /// let stats = register.stats().await?;
298 /// println!("Total URIs: {}", stats.total_uris);
299 /// println!("Size: {} bytes", stats.size_bytes);
300 /// Ok(())
301 /// }
302 /// ```
303 pub async fn stats(&self) -> Result<RegisterStats> {
304 // Build query with validated table name (safe from SQL injection)
305 let query = format!(
306 r#"
307 SELECT
308 COUNT(*)::bigint as count,
309 pg_total_relation_size('{}')::bigint as size_bytes
310 FROM {}
311 "#,
312 self.table_name, self.table_name
313 );
314
315 // Execute with retry logic
316 let client = self.pool.get().await.map_err(|e| {
317 crate::error::Error::Database(format!("Failed to get database connection: {}", e))
318 })?;
319
320 let rows = client
321 .query(&query, &[])
322 .await
323 .map_err(|e| crate::error::Error::Database(e.to_string()))?;
324
325 let row = rows.into_iter().next().ok_or_else(|| {
326 crate::error::Error::Database("No rows returned from stats query".to_string())
327 })?;
328
329 // Get cache statistics
330 let cache_stats = self.cache.stats();
331
332 // Get connection pool statistics
333 let status = self.pool.status();
334 let pool_stats = PoolStats {
335 connections_active: (status.size - status.available) as u32,
336 connections_idle: status.available as u32,
337 connections_max: status.max_size as u32,
338 };
339
340 Ok(RegisterStats {
341 total_uris: row.get::<_, i64>("count") as u64,
342 size_bytes: row.get::<_, i64>("size_bytes") as u64,
343 cache: cache_stats,
344 pool: pool_stats,
345 })
346 }
347
348 /// Clone the register instance (shares pool and cache)
349 ///
350 /// This is a shallow clone that shares both the connection pool and cache.
351 /// Both pool and cache use Arc internally, so this clone is cheap and shares
352 /// the underlying resources.
353 ///
354 /// This method is primarily used for Python bindings where we need to move
355 /// data into async closures.
356 #[cfg(feature = "python")]
357 pub(crate) fn clone_inner(&self) -> Self {
358 PostgresUriRegister {
359 pool: self.pool.clone(),
360 cache: self.cache.clone(), // Clone the Arc, shares the same cache
361 table_name: self.table_name.clone(),
362 }
363 }
364
365 /// Validate that a string is a valid URI according to RFC 3986
366 fn validate_uri(uri: &str) -> Result<()> {
367 Url::parse(uri).map_err(|e| {
368 crate::error::Error::InvalidUri(format!("Invalid URI '{}': {}", uri, e))
369 })?;
370 Ok(())
371 }
372}
373
374#[async_trait]
375impl UriService for PostgresUriRegister {
376 #[instrument(skip(self), fields(table = %self.table_name))]
377 async fn register_uri(&self, uri: &str) -> Result<u64> {
378 // Validate URI first
379 Self::validate_uri(uri)?;
380
381 // Check cache first
382 if let Some(id) = self.cache.get(uri) {
383 trace!(id, "cache hit");
384 return Ok(id);
385 }
386 trace!("cache miss, querying database");
387
388 // Insert and return ID (ON CONFLICT handles race conditions and existing URIs)
389 // Build query with validated table name (safe from SQL injection)
390 let query = format!(
391 r#"
392 INSERT INTO {} (uri)
393 VALUES ($1)
394 ON CONFLICT (uri_hash) DO UPDATE SET uri = EXCLUDED.uri
395 RETURNING id
396 "#,
397 self.table_name
398 );
399
400 // Execute with retry logic
401 let client = self.pool.get().await.map_err(|e| {
402 crate::error::Error::Database(format!("Failed to get database connection: {}", e))
403 })?;
404
405 let rows = client
406 .query(&query, &[&uri])
407 .await
408 .map_err(|e| crate::error::Error::Database(e.to_string()))?;
409
410 let result = rows.into_iter().next().ok_or_else(|| {
411 crate::error::Error::Database("No rows returned from register_uri query".to_string())
412 })?;
413
414 let id = result.get::<_, i64>("id") as u64;
415
416 // Update cache
417 self.cache.put(uri.to_string(), id);
418
419 Ok(id)
420 }
421
422 #[instrument(skip(self, uris), fields(table = %self.table_name, batch_size = uris.len()))]
423 async fn register_uri_batch(&self, uris: &[String]) -> Result<Vec<u64>> {
424 if uris.is_empty() {
425 trace!("empty batch, returning early");
426 return Ok(Vec::new());
427 }
428
429 // Validate all URIs first
430 for uri in uris {
431 Self::validate_uri(uri)?;
432 }
433
434 // CORRECTNESS GUARANTEE: Order preservation
435 // We maintain strict correspondence between input URIs and output IDs
436 // by tracking the original index of each URI and using URI strings
437 // (not SQL result order) to map IDs back to their positions.
438
439 let mut result_ids = vec![None; uris.len()];
440 let mut uncached_indices = Vec::new();
441 let mut uncached_uris_dedup = Vec::new();
442 let mut seen_uncached = std::collections::HashMap::new();
443
444 // Step 1: Check cache for all URIs
445 for (idx, uri) in uris.iter().enumerate() {
446 if let Some(id) = self.cache.get(uri) {
447 result_ids[idx] = Some(id);
448 } else {
449 uncached_indices.push(idx);
450 // Deduplicate uncached URIs for DB query
451 if !seen_uncached.contains_key(uri) {
452 seen_uncached.insert(uri.clone(), uncached_uris_dedup.len());
453 uncached_uris_dedup.push(uri.clone());
454 }
455 }
456 }
457
458 // If everything was cached, return early
459 if uncached_uris_dedup.is_empty() {
460 debug!(cached = uris.len(), "all URIs found in cache");
461 return Ok(result_ids.into_iter().map(|id| id.unwrap()).collect());
462 }
463
464 let cached_count = uris.len() - uncached_indices.len();
465 debug!(
466 cached = cached_count,
467 uncached = uncached_uris_dedup.len(),
468 "cache lookup complete, querying database"
469 );
470
471 // Step 2: Register deduplicated uncached URIs in batch
472 // IMPORTANT: SQL may return results in ANY order (not guaranteed to match input order)
473 // We use "RETURNING id, uri" to get BOTH values together, then map by URI string
474 // Build query with validated table name (safe from SQL injection)
475 let query = format!(
476 r#"
477 INSERT INTO {} (uri)
478 SELECT unnest($1::text[])
479 ON CONFLICT (uri_hash) DO UPDATE SET uri = EXCLUDED.uri
480 RETURNING id, uri
481 "#,
482 self.table_name
483 );
484
485 // Execute with retry logic
486 let client = self.pool.get().await.map_err(|e| {
487 crate::error::Error::Database(format!("Failed to get database connection: {}", e))
488 })?;
489
490 let rows = client
491 .query(&query, &[&uncached_uris_dedup])
492 .await
493 .map_err(|e| crate::error::Error::Database(e.to_string()))?;
494
495 // Build a map of URI -> ID from database results
496 // This allows us to look up IDs by URI string (order-independent)
497 let mut uri_to_id = std::collections::HashMap::new();
498 for row in rows {
499 let uri: String = row.get("uri");
500 let id: i64 = row.get("id");
501 uri_to_id.insert(uri, id as u64);
502 }
503
504 // Step 3: Fill in the result vector and update cache
505 // CORRECTNESS: We use the saved indices and look up by URI string,
506 // guaranteeing that result_ids[i] corresponds to uris[i]
507 for idx in uncached_indices {
508 let uri = &uris[idx]; // Get URI from original position
509 if let Some(&id) = uri_to_id.get(uri) {
510 // Look up ID by URI string
511 result_ids[idx] = Some(id); // Store at original index
512 self.cache.put(uri.clone(), id);
513 }
514 }
515
516 // Convert Option<u64> to u64 (all should be Some at this point)
517 Ok(result_ids
518 .into_iter()
519 .map(|id| id.expect("All URIs should have IDs"))
520 .collect())
521 }
522
523 #[instrument(skip(self, uris), fields(table = %self.table_name, batch_size = uris.len()))]
524 async fn register_uri_batch_hashmap(
525 &self,
526 uris: &[String],
527 ) -> Result<std::collections::HashMap<String, u64>> {
528 if uris.is_empty() {
529 trace!("empty batch, returning early");
530 return Ok(std::collections::HashMap::new());
531 }
532
533 // Validate all URIs first
534 for uri in uris {
535 Self::validate_uri(uri)?;
536 }
537
538 // CORRECTNESS GUARANTEE: URI-to-ID mapping accuracy
539 // Each URI in the result HashMap is guaranteed to map to its correct ID
540 // because SQL returns both 'id' and 'uri' together in each row (RETURNING id, uri).
541 // We never rely on positional correspondence, eliminating ordering errors.
542
543 let mut result = std::collections::HashMap::new();
544 let mut uncached_uris = Vec::new();
545
546 // Step 1: Deduplicate input and check cache
547 let unique_uris: std::collections::HashSet<_> = uris.iter().collect();
548
549 for uri in unique_uris {
550 if let Some(id) = self.cache.get(uri) {
551 result.insert(uri.clone(), id);
552 } else {
553 uncached_uris.push(uri.clone());
554 }
555 }
556
557 // If everything was cached, return early
558 if uncached_uris.is_empty() {
559 debug!(cached = result.len(), "all URIs found in cache");
560 return Ok(result);
561 }
562
563 debug!(
564 cached = result.len(),
565 uncached = uncached_uris.len(),
566 "cache lookup complete, querying database"
567 );
568
569 // Step 2: Register uncached URIs in batch
570 // Build query with validated table name (safe from SQL injection)
571 let query = format!(
572 r#"
573 INSERT INTO {} (uri)
574 SELECT unnest($1::text[])
575 ON CONFLICT (uri_hash) DO UPDATE SET uri = EXCLUDED.uri
576 RETURNING id, uri
577 "#,
578 self.table_name
579 );
580
581 // Execute with retry logic
582 let client = self.pool.get().await.map_err(|e| {
583 crate::error::Error::Database(format!("Failed to get database connection: {}", e))
584 })?;
585
586 let rows = client
587 .query(&query, &[&uncached_uris])
588 .await
589 .map_err(|e| crate::error::Error::Database(e.to_string()))?;
590
591 // Step 3: Add database results to result map and update cache
592 // CORRECTNESS: Each row contains both URI and ID from same DB row,
593 // guaranteeing correct mapping (no opportunity for misalignment)
594 for row in rows {
595 let uri: String = row.get("uri");
596 let id: i64 = row.get("id");
597 let id_u64 = id as u64;
598
599 result.insert(uri.clone(), id_u64); // URI and ID are from same row
600 self.cache.put(uri, id_u64);
601 }
602
603 Ok(result)
604 }
605}
606
/// Statistics about the URI register for observability and OpenTelemetry
///
/// Produced by `PostgresUriRegister::stats`, which combines one database
/// query with the cache's own counters and the pool's current status.
#[derive(Debug, Clone)]
pub struct RegisterStats {
    /// Total number of URIs in the register (`COUNT(*)` over the register table)
    pub total_uris: u64,
    /// Total storage size in bytes as reported by `pg_total_relation_size`
    /// (includes indexes)
    pub size_bytes: u64,
    /// Cache performance metrics
    pub cache: crate::cache::CacheStats,
    /// Connection pool metrics
    pub pool: PoolStats,
}
619
/// Connection pool statistics for observability
///
/// A point-in-time sample; the values can change as soon as it is taken.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Number of connections currently being used
    /// (pool size minus available connections at sampling time)
    pub connections_active: u32,
    /// Number of idle connections in the pool
    pub connections_idle: u32,
    /// Maximum number of connections allowed in the pool
    pub connections_max: u32,
}