// dbrest_core/schema_cache/mod.rs

//! Schema Cache module
//!
//! The schema cache is the heart of dbrest. It introspects the PostgreSQL database
//! and caches:
//! - Tables/Views metadata
//! - Column information
//! - Foreign key relationships (stored in both directions)
//! - Functions/Procedures
//! - Available timezones
//!
//! # Architecture
//!
//! Each [`SchemaCache`] is immutable once built. [`SchemaCacheHolder`] wraps it in
//! an `ArcSwap` for lock-free reads and atomic replacement during schema reload.
//!
//! ```text
//! ┌──────────────────────────────────────────────────────────────────┐
//! │                           SchemaCache                            │
//! ├──────────────────────────────────────────────────────────────────┤
//! │ tables:          HashMap<QualifiedIdentifier, Table>             │
//! │ relationships:   HashMap<(QualifiedIdentifier, String),          │
//! │                          Vec<AnyRelationship>>                   │
//! │ routines:        HashMap<QualifiedIdentifier, Vec<Routine>>      │
//! │ timezones:       HashSet<String>                                 │
//! │ representations: RepresentationsMap                              │
//! │ media_handlers:  MediaHandlerMap                                 │
//! └──────────────────────────────────────────────────────────────────┘
//! ```
25
26pub mod db;
27pub mod media_handler;
28pub mod queries;
29pub mod relationship;
30pub mod representations;
31pub mod routine;
32pub mod table;
33
34// Re-export main types
35pub use db::{ComputedFieldRow, DbIntrospector, RelationshipRow, RoutineRow, TableRow};
36pub use media_handler::{MediaHandler, MediaHandlerMap, ResolvedHandler};
37pub use relationship::{
38    AnyRelationship, Cardinality, ComputedRelationship, Junction, Relationship,
39};
40pub use representations::{DataRepresentation, RepresentationsMap};
41pub use routine::{PgType, ReturnType, Routine, RoutineParam, Volatility};
42pub use table::{Column, ComputedField, Table};
43
44use std::collections::{HashMap, HashSet};
45use std::sync::Arc;
46
47use arc_swap::ArcSwap;
48
49use crate::config::AppConfig;
50use crate::error::Error;
51use crate::types::QualifiedIdentifier;
52
/// Type alias for the tables map: qualified table/view name -> table metadata
pub type TablesMap = HashMap<QualifiedIdentifier, Table>;

/// Type alias for the relationships map
/// Key: (source_table, schema) -> list of relationships from that table.
/// The `schema` component is the source table's own schema, i.e. the same
/// value as `source_table.schema`.
pub type RelationshipsMap = HashMap<(QualifiedIdentifier, String), Vec<AnyRelationship>>;

/// Type alias for the routines map
/// Key: QualifiedIdentifier -> list of overloaded functions sharing that name
pub type RoutinesMap = HashMap<QualifiedIdentifier, Vec<Routine>>;
63
/// Immutable schema cache
///
/// This structure holds all introspected database metadata. It is designed to be
/// immutable and wrapped in `ArcSwap` (see `SchemaCacheHolder`) for lock-free reads.
///
/// Every field sits behind an `Arc`, so cloning a `SchemaCache` only bumps
/// reference counts instead of copying the underlying maps.
#[derive(Debug, Clone)]
pub struct SchemaCache {
    /// All tables and views by qualified name
    pub tables: Arc<TablesMap>,
    /// Relationships indexed by `(source table, source table's schema)`.
    /// `SchemaCache::load` stores every foreign key in both directions.
    pub relationships: Arc<RelationshipsMap>,
    /// Functions/procedures indexed by qualified name; overloads share one entry
    pub routines: Arc<RoutinesMap>,
    /// Available PostgreSQL timezones (`SchemaCache::load` always adds `"UTC"`)
    pub timezones: Arc<HashSet<String>>,
    /// Data representation mappings (currently left empty by `SchemaCache::load`)
    pub representations: Arc<RepresentationsMap>,
    /// Media handler mappings (currently left empty by `SchemaCache::load`)
    pub media_handlers: Arc<MediaHandlerMap>,
}
83
84impl Default for SchemaCache {
85    fn default() -> Self {
86        Self::empty()
87    }
88}
89
90impl SchemaCache {
91    /// Create an empty schema cache
92    pub fn empty() -> Self {
93        Self {
94            tables: Arc::new(HashMap::new()),
95            relationships: Arc::new(HashMap::new()),
96            routines: Arc::new(HashMap::new()),
97            timezones: Arc::new(HashSet::new()),
98            representations: Arc::new(HashMap::new()),
99            media_handlers: Arc::new(HashMap::new()),
100        }
101    }
102
103    /// Load schema cache from database using the provided introspector
104    pub async fn load<I: DbIntrospector + ?Sized>(
105        introspector: &I,
106        config: &AppConfig,
107    ) -> Result<Self, Error> {
108        let schemas = &config.db_schemas;
109
110        tracing::info!("Loading schema cache for schemas: {:?}", schemas);
111
112        // Combine exposed schemas with extra search path for computed fields query
113        let mut all_schemas = config.db_schemas.clone();
114        for extra_schema in &config.db_extra_search_path {
115            if !all_schemas.contains(extra_schema) {
116                all_schemas.push(extra_schema.clone());
117            }
118        }
119
120        tracing::debug!("All schemas for computed fields query: {:?}", all_schemas);
121
122        // Query all data concurrently
123        let (tables_rows, rel_rows, routine_rows, computed_fields_rows, timezones) = tokio::try_join!(
124            introspector.query_tables(schemas),
125            introspector.query_relationships(),
126            introspector.query_routines(schemas),
127            introspector.query_computed_fields(&all_schemas),
128            introspector.query_timezones(),
129        )?;
130
131        tracing::debug!(
132            "Loaded: {} tables, {} relationships, {} routines, {} computed fields, {} timezones",
133            tables_rows.len(),
134            rel_rows.len(),
135            routine_rows.len(),
136            computed_fields_rows.len(),
137            timezones.len()
138        );
139
140        // Build tables map
141        let mut tables = HashMap::with_capacity(tables_rows.len());
142        for row in tables_rows {
143            let table = row.into_table()?;
144            let qi = table.qi();
145            tables.insert(qi.clone(), table);
146        }
147
148        // Group computed fields by table and attach them
149        use crate::schema_cache::table::ComputedField;
150        use crate::types::QualifiedIdentifier as QI;
151
152        let mut attached_count = 0;
153        let mut not_found_count = 0;
154
155        for row in computed_fields_rows {
156            let table_qi = QI::new(&row.table_schema, &row.table_name);
157            if let Some(table) = tables.get_mut(&table_qi) {
158                let function_qi = QI::new(&row.function_schema, &row.function_name);
159                let computed_field = ComputedField {
160                    function: function_qi,
161                    return_type: row.return_type.into(),
162                    returns_set: row.returns_set,
163                };
164                // Use function name as the key (not qualified, matching PostgREST behavior)
165                table
166                    .computed_fields
167                    .insert(row.function_name.clone().into(), computed_field);
168                tracing::trace!(
169                    "Attached computed field '{}' to table {}.{}",
170                    row.function_name,
171                    row.table_schema,
172                    row.table_name
173                );
174                attached_count += 1;
175            } else {
176                tracing::warn!(
177                    "Computed field function {}.{} references non-existent table {}.{}",
178                    row.function_schema,
179                    row.function_name,
180                    row.table_schema,
181                    row.table_name
182                );
183                not_found_count += 1;
184            }
185        }
186
187        tracing::debug!(
188            "Attached {} computed fields to tables, {} referenced non-existent tables",
189            attached_count,
190            not_found_count
191        );
192
193        // Build relationships map — store both forward (M2O) and reverse (O2M)
194        // directions so that resource embedding works in either direction.
195        let mut relationships: RelationshipsMap = HashMap::new();
196        for row in rel_rows {
197            let rel = row.into_relationship();
198
199            // Forward direction (M2O / O2O): keyed under the FK-holding table
200            let fwd_key = (rel.table.clone(), rel.table.schema.to_string());
201            let reverse = rel.reverse();
202            relationships
203                .entry(fwd_key)
204                .or_default()
205                .push(AnyRelationship::ForeignKey(rel));
206
207            // Reverse direction (O2M / O2O-parent): keyed under the referenced table
208            let rev_key = (reverse.table.clone(), reverse.table.schema.to_string());
209            relationships
210                .entry(rev_key)
211                .or_default()
212                .push(AnyRelationship::ForeignKey(reverse));
213        }
214
215        // Build routines map
216        let mut routines: RoutinesMap = HashMap::new();
217        for row in routine_rows {
218            let routine = row.into_routine()?;
219            let qi = routine.qi();
220            routines.entry(qi).or_default().push(routine);
221        }
222
223        // Convert timezones to HashSet, ensuring UTC is always included
224        let mut timezone_set: HashSet<String> = timezones.into_iter().collect();
225        timezone_set.insert("UTC".to_string());
226
227        Ok(Self {
228            tables: Arc::new(tables),
229            relationships: Arc::new(relationships),
230            routines: Arc::new(routines),
231            timezones: Arc::new(timezone_set),
232            representations: Arc::new(HashMap::new()),
233            media_handlers: Arc::new(HashMap::new()),
234        })
235    }
236
237    /// Get a table by qualified identifier
238    pub fn get_table(&self, qi: &QualifiedIdentifier) -> Option<&Table> {
239        self.tables.get(qi)
240    }
241
242    /// Get a table by schema and name
243    pub fn get_table_by_name(&self, schema: &str, name: &str) -> Option<&Table> {
244        let qi = QualifiedIdentifier::new(schema, name);
245        self.tables.get(&qi)
246    }
247
248    /// Find relationships from a source table
249    pub fn find_relationships(&self, source: &QualifiedIdentifier) -> &[AnyRelationship] {
250        let key = (source.clone(), source.schema.to_string());
251        self.relationships
252            .get(&key)
253            .map(|v| v.as_slice())
254            .unwrap_or(&[])
255    }
256
257    /// Find relationships from source to a specific target
258    pub fn find_relationships_to(
259        &self,
260        source: &QualifiedIdentifier,
261        target_name: &str,
262    ) -> Vec<&AnyRelationship> {
263        self.find_relationships(source)
264            .iter()
265            .filter(|r| r.foreign_table().name.as_str() == target_name)
266            .collect()
267    }
268
269    /// Get a routine by qualified identifier
270    pub fn get_routines(&self, qi: &QualifiedIdentifier) -> Option<&[Routine]> {
271        self.routines.get(qi).map(|v| v.as_slice())
272    }
273
274    /// Get a routine by schema and name
275    pub fn get_routines_by_name(&self, schema: &str, name: &str) -> Option<&[Routine]> {
276        let qi = QualifiedIdentifier::new(schema, name);
277        self.routines.get(&qi).map(|v| v.as_slice())
278    }
279
280    /// Check if a timezone is valid
281    pub fn is_valid_timezone(&self, tz: &str) -> bool {
282        self.timezones.contains(tz)
283    }
284
285    /// Get the number of tables
286    pub fn table_count(&self) -> usize {
287        self.tables.len()
288    }
289
290    /// Get the number of relationships
291    pub fn relationship_count(&self) -> usize {
292        self.relationships.values().map(|v| v.len()).sum()
293    }
294
295    /// Get the number of routines
296    pub fn routine_count(&self) -> usize {
297        self.routines.values().map(|v| v.len()).sum()
298    }
299
300    /// Get a summary string for logging
301    pub fn summary(&self) -> String {
302        format!(
303            "{} tables, {} relationships, {} routines, {} timezones",
304            self.table_count(),
305            self.relationship_count(),
306            self.routine_count(),
307            self.timezones.len(),
308        )
309    }
310
311    /// Iterate over all tables
312    pub fn tables_iter(&self) -> impl Iterator<Item = (&QualifiedIdentifier, &Table)> {
313        self.tables.iter()
314    }
315
316    /// Iterate over all tables in a specific schema
317    pub fn tables_in_schema(&self, schema: &str) -> impl Iterator<Item = &Table> {
318        self.tables
319            .values()
320            .filter(move |t| t.schema.as_str() == schema)
321    }
322}
323
/// Schema cache holder with atomic swap capability
///
/// Wraps the schema cache in `ArcSwap` for lock-free reads and atomic updates.
/// The inner `Option` is `None` until a cache is installed, and again after
/// `clear` is called.
#[derive(Debug)]
pub struct SchemaCacheHolder {
    // `None` means "not loaded"; set by `new`/`clear`, replaced by `with_cache`/`replace`.
    inner: ArcSwap<Option<SchemaCache>>,
}
331
332impl Default for SchemaCacheHolder {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
338impl SchemaCacheHolder {
339    /// Create a new empty holder
340    pub fn new() -> Self {
341        Self {
342            inner: ArcSwap::from_pointee(None),
343        }
344    }
345
346    /// Create a holder with an initial cache
347    pub fn with_cache(cache: SchemaCache) -> Self {
348        Self {
349            inner: ArcSwap::from_pointee(Some(cache)),
350        }
351    }
352
353    /// Get a reference to the current cache
354    ///
355    /// Returns None if the cache hasn't been loaded yet.
356    pub fn get(&self) -> Option<arc_swap::Guard<Arc<Option<SchemaCache>>>> {
357        let guard = self.inner.load();
358        if guard.is_some() { Some(guard) } else { None }
359    }
360
361    /// Replace the cache with a new one
362    pub fn replace(&self, cache: SchemaCache) {
363        self.inner.store(Arc::new(Some(cache)));
364    }
365
366    /// Clear the cache
367    pub fn clear(&self) {
368        self.inner.store(Arc::new(None));
369    }
370
371    /// Check if the cache is loaded
372    pub fn is_loaded(&self) -> bool {
373        self.inner.load().is_some()
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;

    /// JSON object describing a non-nullable `integer` column, in the shape the
    /// mock `TableRow::columns_json` values use in these tests.
    fn int_column_json(name: &str) -> String {
        format!(
            r#"{{"name":"{name}","description":null,"nullable":false,"data_type":"integer","nominal_type":"integer","max_length":null,"default":null,"enum_values":[]}}"#
        )
    }

    /// A `public` base-table row with an `id` primary key, flagged insertable,
    /// updatable, deletable and readable, with one `integer` column per entry
    /// of `columns`.
    fn table_row(name: &str, columns: &[&str]) -> TableRow {
        let columns_json = format!(
            "[{}]",
            columns
                .iter()
                .map(|column| int_column_json(column))
                .collect::<Vec<_>>()
                .join(",")
        );
        TableRow {
            table_schema: "public".to_string(),
            table_name: name.to_string(),
            table_description: None,
            is_view: false,
            insertable: true,
            updatable: true,
            deletable: true,
            readable: true,
            pk_cols: vec!["id".to_string()],
            columns_json,
        }
    }

    fn create_test_cache() -> SchemaCache {
        let mut tables = HashMap::new();

        let users_table = test_table()
            .schema("public")
            .name("users")
            .pk_col("id")
            .column(test_column().name("id").data_type("integer").build())
            .column(test_column().name("name").data_type("text").build())
            .build();

        let posts_table = test_table()
            .schema("public")
            .name("posts")
            .pk_col("id")
            .column(test_column().name("id").data_type("integer").build())
            .column(test_column().name("user_id").data_type("integer").build())
            .column(test_column().name("title").data_type("text").build())
            .build();

        tables.insert(users_table.qi(), users_table);
        tables.insert(posts_table.qi(), posts_table);

        // Create relationship posts -> users (forward direction only)
        let rel = test_relationship()
            .table("public", "posts")
            .foreign_table("public", "users")
            .m2o("fk_posts_user", &[("user_id", "id")])
            .build();

        let mut relationships = HashMap::new();
        let key = (
            QualifiedIdentifier::new("public", "posts"),
            "public".to_string(),
        );
        relationships.insert(key, vec![AnyRelationship::ForeignKey(rel)]);

        // Create routine
        let routine = test_routine()
            .schema("public")
            .name("get_user")
            .param(test_param().name("user_id").pg_type("integer").build())
            .returns_setof_composite("public", "users")
            .build();

        let mut routines = HashMap::new();
        routines.insert(routine.qi(), vec![routine]);

        let mut timezones = HashSet::new();
        timezones.insert("UTC".to_string());
        timezones.insert("America/New_York".to_string());

        SchemaCache {
            tables: Arc::new(tables),
            relationships: Arc::new(relationships),
            routines: Arc::new(routines),
            timezones: Arc::new(timezones),
            representations: Arc::new(HashMap::new()),
            media_handlers: Arc::new(HashMap::new()),
        }
    }

    #[test]
    fn test_schema_cache_empty() {
        let cache = SchemaCache::empty();
        assert_eq!(cache.table_count(), 0);
        assert_eq!(cache.relationship_count(), 0);
        assert_eq!(cache.routine_count(), 0);
    }

    #[test]
    fn test_schema_cache_get_table() {
        let cache = create_test_cache();

        let qi = QualifiedIdentifier::new("public", "users");
        let table = cache.get_table(&qi).unwrap();
        assert_eq!(table.name.as_str(), "users");
    }

    #[test]
    fn test_schema_cache_get_table_by_name() {
        let cache = create_test_cache();

        let table = cache.get_table_by_name("public", "posts").unwrap();
        assert_eq!(table.name.as_str(), "posts");
        assert!(table.has_pk());
    }

    #[test]
    fn test_schema_cache_get_table_not_found() {
        let cache = create_test_cache();

        let qi = QualifiedIdentifier::new("public", "nonexistent");
        assert!(cache.get_table(&qi).is_none());
    }

    #[test]
    fn test_schema_cache_find_relationships() {
        let cache = create_test_cache();

        let source = QualifiedIdentifier::new("public", "posts");
        let rels = cache.find_relationships(&source);
        assert_eq!(rels.len(), 1);
        assert_eq!(rels[0].foreign_table().name.as_str(), "users");
    }

    #[test]
    fn test_schema_cache_find_relationships_unknown_table() {
        let cache = create_test_cache();

        let source = QualifiedIdentifier::new("public", "nonexistent");
        assert!(cache.find_relationships(&source).is_empty());
        assert!(cache.find_relationships_to(&source, "users").is_empty());
    }

    #[test]
    fn test_schema_cache_find_relationships_to() {
        let cache = create_test_cache();

        let source = QualifiedIdentifier::new("public", "posts");
        let rels = cache.find_relationships_to(&source, "users");
        assert_eq!(rels.len(), 1);

        let rels = cache.find_relationships_to(&source, "nonexistent");
        assert!(rels.is_empty());
    }

    #[test]
    fn test_schema_cache_get_routines() {
        let cache = create_test_cache();

        let qi = QualifiedIdentifier::new("public", "get_user");
        let routines = cache.get_routines(&qi).unwrap();
        assert_eq!(routines.len(), 1);
        assert!(routines[0].returns_set());
    }

    #[test]
    fn test_schema_cache_get_routines_by_name() {
        let cache = create_test_cache();

        let routines = cache.get_routines_by_name("public", "get_user").unwrap();
        assert_eq!(routines.len(), 1);
    }

    #[test]
    fn test_schema_cache_is_valid_timezone() {
        let cache = create_test_cache();

        assert!(cache.is_valid_timezone("UTC"));
        assert!(cache.is_valid_timezone("America/New_York"));
        assert!(!cache.is_valid_timezone("Invalid/Zone"));
    }

    #[test]
    fn test_schema_cache_counts() {
        let cache = create_test_cache();

        assert_eq!(cache.table_count(), 2);
        assert_eq!(cache.relationship_count(), 1);
        assert_eq!(cache.routine_count(), 1);
    }

    #[test]
    fn test_schema_cache_summary() {
        let cache = create_test_cache();

        let summary = cache.summary();
        assert!(summary.contains("2 tables"));
        assert!(summary.contains("1 relationships"));
        assert!(summary.contains("1 routines"));
    }

    #[test]
    fn test_schema_cache_tables_iter() {
        let cache = create_test_cache();

        let table_names: Vec<_> = cache.tables_iter().map(|(_, t)| t.name.as_str()).collect();
        assert!(table_names.contains(&"users"));
        assert!(table_names.contains(&"posts"));
    }

    #[test]
    fn test_schema_cache_tables_in_schema() {
        let cache = create_test_cache();

        let public_tables: Vec<_> = cache.tables_in_schema("public").collect();
        assert_eq!(public_tables.len(), 2);

        let other_tables: Vec<_> = cache.tables_in_schema("other").collect();
        assert!(other_tables.is_empty());
    }

    // ========================================================================
    // SchemaCacheHolder Tests
    // ========================================================================

    #[test]
    fn test_schema_cache_holder_new() {
        let holder = SchemaCacheHolder::new();
        assert!(!holder.is_loaded());
        assert!(holder.get().is_none());
    }

    #[test]
    fn test_schema_cache_holder_with_cache() {
        let cache = create_test_cache();
        let holder = SchemaCacheHolder::with_cache(cache);
        assert!(holder.is_loaded());
        assert!(holder.get().is_some());
    }

    #[test]
    fn test_schema_cache_holder_replace() {
        let holder = SchemaCacheHolder::new();
        assert!(!holder.is_loaded());

        let cache = create_test_cache();
        holder.replace(cache);
        assert!(holder.is_loaded());

        // The guard must expose the cache that was just installed.
        let guard = holder.get().expect("cache should be loaded after replace");
        let loaded = (**guard).as_ref().expect("guard from get() always holds Some");
        assert_eq!(loaded.table_count(), 2);
    }

    #[test]
    fn test_schema_cache_holder_clear() {
        let cache = create_test_cache();
        let holder = SchemaCacheHolder::with_cache(cache);
        assert!(holder.is_loaded());

        holder.clear();
        assert!(!holder.is_loaded());
    }

    // ========================================================================
    // Mock-based Tests
    // ========================================================================

    #[tokio::test]
    async fn test_schema_cache_load_with_mock() {
        use db::MockDbIntrospector;

        let mut mock = MockDbIntrospector::new();

        // Set up mock expectations
        mock.expect_query_tables()
            .returning(|_| Ok(vec![table_row("test_table", &["id"])]));
        mock.expect_query_relationships().returning(|| Ok(vec![]));
        mock.expect_query_routines().returning(|_| Ok(vec![]));
        mock.expect_query_computed_fields()
            .returning(|_| Ok(vec![]));
        mock.expect_query_timezones()
            .returning(|| Ok(vec!["UTC".to_string()]));

        let config = AppConfig::default();
        let cache = SchemaCache::load(&mock, &config).await.unwrap();

        assert_eq!(cache.table_count(), 1);
        let table = cache.get_table_by_name("public", "test_table").unwrap();
        assert!(table.has_pk());
    }

    #[tokio::test]
    async fn test_schema_cache_load_always_includes_utc() {
        use db::MockDbIntrospector;

        let mut mock = MockDbIntrospector::new();

        mock.expect_query_tables().returning(|_| Ok(vec![]));
        mock.expect_query_relationships().returning(|| Ok(vec![]));
        mock.expect_query_routines().returning(|_| Ok(vec![]));
        mock.expect_query_computed_fields()
            .returning(|_| Ok(vec![]));
        // The database does not report UTC here.
        mock.expect_query_timezones()
            .returning(|| Ok(vec!["Europe/Paris".to_string()]));

        let config = AppConfig::default();
        let cache = SchemaCache::load(&mock, &config).await.unwrap();

        assert_eq!(cache.table_count(), 0);
        assert!(cache.is_valid_timezone("Europe/Paris"));
        assert!(cache.is_valid_timezone("UTC"));
    }

    #[tokio::test]
    async fn test_schema_cache_load_with_relationships() {
        use db::MockDbIntrospector;

        let mut mock = MockDbIntrospector::new();

        mock.expect_query_tables().returning(|_| {
            Ok(vec![
                table_row("users", &["id"]),
                table_row("posts", &["id", "user_id"]),
            ])
        });

        mock.expect_query_relationships().returning(|| {
            Ok(vec![RelationshipRow {
                table_schema: "public".to_string(),
                table_name: "posts".to_string(),
                foreign_table_schema: "public".to_string(),
                foreign_table_name: "users".to_string(),
                is_self: false,
                constraint_name: "fk_posts_user".to_string(),
                cols_and_fcols: vec![("user_id".to_string(), "id".to_string())],
                one_to_one: false,
            }])
        });

        mock.expect_query_routines().returning(|_| Ok(vec![]));
        mock.expect_query_computed_fields()
            .returning(|_| Ok(vec![]));
        mock.expect_query_timezones().returning(|| Ok(vec![]));

        let config = AppConfig::default();
        let cache = SchemaCache::load(&mock, &config).await.unwrap();

        assert_eq!(cache.table_count(), 2);
        // 2 relationships: forward M2O (posts→users) + reverse O2M (users→posts)
        assert_eq!(cache.relationship_count(), 2);

        // Forward: posts → users (M2O)
        let source = QualifiedIdentifier::new("public", "posts");
        let rels = cache.find_relationships(&source);
        assert_eq!(rels.len(), 1);
        assert!(rels[0].is_to_one()); // M2O

        // Reverse: users → posts (O2M)
        let source_rev = QualifiedIdentifier::new("public", "users");
        let rels_rev = cache.find_relationships(&source_rev);
        assert_eq!(rels_rev.len(), 1);
        assert!(!rels_rev[0].is_to_one()); // O2M is not to-one
    }
}