1pub mod db;
27pub mod media_handler;
28pub mod queries;
29pub mod relationship;
30pub mod representations;
31pub mod routine;
32pub mod table;
33
34pub use db::{ComputedFieldRow, DbIntrospector, RelationshipRow, RoutineRow, TableRow};
36pub use media_handler::{MediaHandler, MediaHandlerMap, ResolvedHandler};
37pub use relationship::{
38 AnyRelationship, Cardinality, ComputedRelationship, Junction, Relationship,
39};
40pub use representations::{DataRepresentation, RepresentationsMap};
41pub use routine::{PgType, ReturnType, Routine, RoutineParam, Volatility};
42pub use table::{Column, ComputedField, Table};
43
44use std::collections::{HashMap, HashSet};
45use std::sync::Arc;
46
47use arc_swap::ArcSwap;
48
49use crate::config::AppConfig;
50use crate::error::Error;
51use crate::types::QualifiedIdentifier;
52
/// Tables keyed by their schema-qualified identifier.
pub type TablesMap = HashMap<QualifiedIdentifier, Table>;

/// Relationships keyed by `(source table, source table's schema name)`.
///
/// Each value lists every relationship whose source is that table. When built by
/// `SchemaCache::load`, each foreign key appears twice: once under its own table and
/// once (reversed) under the table it references.
pub type RelationshipsMap = HashMap<(QualifiedIdentifier, String), Vec<AnyRelationship>>;

/// Routines keyed by qualified name; every routine sharing that name is kept in the `Vec`.
pub type RoutinesMap = HashMap<QualifiedIdentifier, Vec<Routine>>;
63
/// Cached schema metadata obtained from a [`DbIntrospector`].
///
/// Every collection sits behind an `Arc`, so cloning a `SchemaCache` shares the
/// underlying maps rather than copying them.
#[derive(Debug, Clone)]
pub struct SchemaCache {
    /// Tables keyed by qualified identifier.
    pub tables: Arc<TablesMap>,
    /// Relationships keyed by source table; see [`RelationshipsMap`].
    pub relationships: Arc<RelationshipsMap>,
    /// Routines grouped by qualified name.
    pub routines: Arc<RoutinesMap>,
    /// Timezone names accepted as valid.
    pub timezones: Arc<HashSet<String>>,
    /// Data representations; `load` leaves this empty.
    pub representations: Arc<RepresentationsMap>,
    /// Media handlers; `load` leaves this empty.
    pub media_handlers: Arc<MediaHandlerMap>,
}
83
84impl Default for SchemaCache {
85 fn default() -> Self {
86 Self::empty()
87 }
88}
89
90impl SchemaCache {
91 pub fn empty() -> Self {
93 Self {
94 tables: Arc::new(HashMap::new()),
95 relationships: Arc::new(HashMap::new()),
96 routines: Arc::new(HashMap::new()),
97 timezones: Arc::new(HashSet::new()),
98 representations: Arc::new(HashMap::new()),
99 media_handlers: Arc::new(HashMap::new()),
100 }
101 }
102
103 pub async fn load<I: DbIntrospector + ?Sized>(
105 introspector: &I,
106 config: &AppConfig,
107 ) -> Result<Self, Error> {
108 let schemas = &config.db_schemas;
109
110 tracing::info!("Loading schema cache for schemas: {:?}", schemas);
111
112 let mut all_schemas = config.db_schemas.clone();
114 for extra_schema in &config.db_extra_search_path {
115 if !all_schemas.contains(extra_schema) {
116 all_schemas.push(extra_schema.clone());
117 }
118 }
119
120 tracing::debug!("All schemas for computed fields query: {:?}", all_schemas);
121
122 let (tables_rows, rel_rows, routine_rows, computed_fields_rows, timezones) = tokio::try_join!(
124 introspector.query_tables(schemas),
125 introspector.query_relationships(),
126 introspector.query_routines(schemas),
127 introspector.query_computed_fields(&all_schemas),
128 introspector.query_timezones(),
129 )?;
130
131 tracing::debug!(
132 "Loaded: {} tables, {} relationships, {} routines, {} computed fields, {} timezones",
133 tables_rows.len(),
134 rel_rows.len(),
135 routine_rows.len(),
136 computed_fields_rows.len(),
137 timezones.len()
138 );
139
140 let mut tables = HashMap::with_capacity(tables_rows.len());
142 for row in tables_rows {
143 let table = row.into_table()?;
144 let qi = table.qi();
145 tables.insert(qi.clone(), table);
146 }
147
148 use crate::schema_cache::table::ComputedField;
150 use crate::types::QualifiedIdentifier as QI;
151
152 let mut attached_count = 0;
153 let mut not_found_count = 0;
154
155 for row in computed_fields_rows {
156 let table_qi = QI::new(&row.table_schema, &row.table_name);
157 if let Some(table) = tables.get_mut(&table_qi) {
158 let function_qi = QI::new(&row.function_schema, &row.function_name);
159 let computed_field = ComputedField {
160 function: function_qi,
161 return_type: row.return_type.into(),
162 returns_set: row.returns_set,
163 };
164 table
166 .computed_fields
167 .insert(row.function_name.clone().into(), computed_field);
168 tracing::trace!(
169 "Attached computed field '{}' to table {}.{}",
170 row.function_name,
171 row.table_schema,
172 row.table_name
173 );
174 attached_count += 1;
175 } else {
176 tracing::warn!(
177 "Computed field function {}.{} references non-existent table {}.{}",
178 row.function_schema,
179 row.function_name,
180 row.table_schema,
181 row.table_name
182 );
183 not_found_count += 1;
184 }
185 }
186
187 tracing::debug!(
188 "Attached {} computed fields to tables, {} referenced non-existent tables",
189 attached_count,
190 not_found_count
191 );
192
193 let mut relationships: RelationshipsMap = HashMap::new();
196 for row in rel_rows {
197 let rel = row.into_relationship();
198
199 let fwd_key = (rel.table.clone(), rel.table.schema.to_string());
201 let reverse = rel.reverse();
202 relationships
203 .entry(fwd_key)
204 .or_default()
205 .push(AnyRelationship::ForeignKey(rel));
206
207 let rev_key = (reverse.table.clone(), reverse.table.schema.to_string());
209 relationships
210 .entry(rev_key)
211 .or_default()
212 .push(AnyRelationship::ForeignKey(reverse));
213 }
214
215 let mut routines: RoutinesMap = HashMap::new();
217 for row in routine_rows {
218 let routine = row.into_routine()?;
219 let qi = routine.qi();
220 routines.entry(qi).or_default().push(routine);
221 }
222
223 let mut timezone_set: HashSet<String> = timezones.into_iter().collect();
225 timezone_set.insert("UTC".to_string());
226
227 Ok(Self {
228 tables: Arc::new(tables),
229 relationships: Arc::new(relationships),
230 routines: Arc::new(routines),
231 timezones: Arc::new(timezone_set),
232 representations: Arc::new(HashMap::new()),
233 media_handlers: Arc::new(HashMap::new()),
234 })
235 }
236
237 pub fn get_table(&self, qi: &QualifiedIdentifier) -> Option<&Table> {
239 self.tables.get(qi)
240 }
241
242 pub fn get_table_by_name(&self, schema: &str, name: &str) -> Option<&Table> {
244 let qi = QualifiedIdentifier::new(schema, name);
245 self.tables.get(&qi)
246 }
247
248 pub fn find_relationships(&self, source: &QualifiedIdentifier) -> &[AnyRelationship] {
250 let key = (source.clone(), source.schema.to_string());
251 self.relationships
252 .get(&key)
253 .map(|v| v.as_slice())
254 .unwrap_or(&[])
255 }
256
257 pub fn find_relationships_to(
259 &self,
260 source: &QualifiedIdentifier,
261 target_name: &str,
262 ) -> Vec<&AnyRelationship> {
263 self.find_relationships(source)
264 .iter()
265 .filter(|r| r.foreign_table().name.as_str() == target_name)
266 .collect()
267 }
268
269 pub fn get_routines(&self, qi: &QualifiedIdentifier) -> Option<&[Routine]> {
271 self.routines.get(qi).map(|v| v.as_slice())
272 }
273
274 pub fn get_routines_by_name(&self, schema: &str, name: &str) -> Option<&[Routine]> {
276 let qi = QualifiedIdentifier::new(schema, name);
277 self.routines.get(&qi).map(|v| v.as_slice())
278 }
279
280 pub fn is_valid_timezone(&self, tz: &str) -> bool {
282 self.timezones.contains(tz)
283 }
284
285 pub fn table_count(&self) -> usize {
287 self.tables.len()
288 }
289
290 pub fn relationship_count(&self) -> usize {
292 self.relationships.values().map(|v| v.len()).sum()
293 }
294
295 pub fn routine_count(&self) -> usize {
297 self.routines.values().map(|v| v.len()).sum()
298 }
299
300 pub fn summary(&self) -> String {
302 format!(
303 "{} tables, {} relationships, {} routines, {} timezones",
304 self.table_count(),
305 self.relationship_count(),
306 self.routine_count(),
307 self.timezones.len(),
308 )
309 }
310
311 pub fn tables_iter(&self) -> impl Iterator<Item = (&QualifiedIdentifier, &Table)> {
313 self.tables.iter()
314 }
315
316 pub fn tables_in_schema(&self, schema: &str) -> impl Iterator<Item = &Table> {
318 self.tables
319 .values()
320 .filter(move |t| t.schema.as_str() == schema)
321 }
322}
323
/// Shared, swappable slot holding the current [`SchemaCache`].
///
/// Backed by an `ArcSwap<Option<SchemaCache>>`. Readers take snapshot guards, and
/// writers replace the whole value. `None` means no cache is loaded.
#[derive(Debug)]
pub struct SchemaCacheHolder {
    // Current snapshot; `None` until a cache is stored, and again after `clear`.
    inner: ArcSwap<Option<SchemaCache>>,
}
331
332impl Default for SchemaCacheHolder {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
338impl SchemaCacheHolder {
339 pub fn new() -> Self {
341 Self {
342 inner: ArcSwap::from_pointee(None),
343 }
344 }
345
346 pub fn with_cache(cache: SchemaCache) -> Self {
348 Self {
349 inner: ArcSwap::from_pointee(Some(cache)),
350 }
351 }
352
353 pub fn get(&self) -> Option<arc_swap::Guard<Arc<Option<SchemaCache>>>> {
357 let guard = self.inner.load();
358 if guard.is_some() { Some(guard) } else { None }
359 }
360
361 pub fn replace(&self, cache: SchemaCache) {
363 self.inner.store(Arc::new(Some(cache)));
364 }
365
366 pub fn clear(&self) {
368 self.inner.store(Arc::new(None));
369 }
370
371 pub fn is_loaded(&self) -> bool {
373 self.inner.load().is_some()
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;

    /// Column JSON (format of `TableRow::columns_json`) for a non-null `integer` `id`.
    const ID_COLUMN_JSON: &str = r#"{"name":"id","description":null,"nullable":false,"data_type":"integer","nominal_type":"integer","max_length":null,"default":null,"enum_values":[]}"#;

    /// Column JSON (format of `TableRow::columns_json`) for a non-null `integer` `user_id`.
    const USER_ID_COLUMN_JSON: &str = r#"{"name":"user_id","description":null,"nullable":false,"data_type":"integer","nominal_type":"integer","max_length":null,"default":null,"enum_values":[]}"#;

    /// Builds a `public` base-table row with full privileges and an `id` primary key.
    ///
    /// `columns` are individual column JSON objects; they are joined into a JSON array.
    fn mock_table_row(name: &str, columns: &[&str]) -> TableRow {
        TableRow {
            table_schema: "public".to_string(),
            table_name: name.to_string(),
            table_description: None,
            is_view: false,
            insertable: true,
            updatable: true,
            deletable: true,
            readable: true,
            pk_cols: vec!["id".to_string()],
            columns_json: format!("[{}]", columns.join(",")),
        }
    }

    /// Hand-built cache: `users` and `posts` tables, one posts→users FK (forward
    /// direction only), one `get_user` routine, and two timezones.
    fn create_test_cache() -> SchemaCache {
        let mut tables = HashMap::new();

        let users_table = test_table()
            .schema("public")
            .name("users")
            .pk_col("id")
            .column(test_column().name("id").data_type("integer").build())
            .column(test_column().name("name").data_type("text").build())
            .build();

        let posts_table = test_table()
            .schema("public")
            .name("posts")
            .pk_col("id")
            .column(test_column().name("id").data_type("integer").build())
            .column(test_column().name("user_id").data_type("integer").build())
            .column(test_column().name("title").data_type("text").build())
            .build();

        tables.insert(users_table.qi(), users_table);
        tables.insert(posts_table.qi(), posts_table);

        let rel = test_relationship()
            .table("public", "posts")
            .foreign_table("public", "users")
            .m2o("fk_posts_user", &[("user_id", "id")])
            .build();

        let mut relationships = HashMap::new();
        let key = (
            QualifiedIdentifier::new("public", "posts"),
            "public".to_string(),
        );
        relationships.insert(key, vec![AnyRelationship::ForeignKey(rel)]);

        let routine = test_routine()
            .schema("public")
            .name("get_user")
            .param(test_param().name("user_id").pg_type("integer").build())
            .returns_setof_composite("public", "users")
            .build();

        let mut routines = HashMap::new();
        routines.insert(routine.qi(), vec![routine]);

        let mut timezones = HashSet::new();
        timezones.insert("UTC".to_string());
        timezones.insert("America/New_York".to_string());

        SchemaCache {
            tables: Arc::new(tables),
            relationships: Arc::new(relationships),
            routines: Arc::new(routines),
            timezones: Arc::new(timezones),
            representations: Arc::new(HashMap::new()),
            media_handlers: Arc::new(HashMap::new()),
        }
    }

    #[test]
    fn test_schema_cache_empty() {
        let cache = SchemaCache::empty();
        assert_eq!(cache.table_count(), 0);
        assert_eq!(cache.relationship_count(), 0);
        assert_eq!(cache.routine_count(), 0);
    }

    #[test]
    fn test_schema_cache_get_table() {
        let cache = create_test_cache();

        let qi = QualifiedIdentifier::new("public", "users");
        let table = cache.get_table(&qi).unwrap();
        assert_eq!(table.name.as_str(), "users");
    }

    #[test]
    fn test_schema_cache_get_table_by_name() {
        let cache = create_test_cache();

        let table = cache.get_table_by_name("public", "posts").unwrap();
        assert_eq!(table.name.as_str(), "posts");
        assert!(table.has_pk());
    }

    #[test]
    fn test_schema_cache_get_table_not_found() {
        let cache = create_test_cache();

        let qi = QualifiedIdentifier::new("public", "nonexistent");
        assert!(cache.get_table(&qi).is_none());
    }

    #[test]
    fn test_schema_cache_find_relationships() {
        let cache = create_test_cache();

        let source = QualifiedIdentifier::new("public", "posts");
        let rels = cache.find_relationships(&source);
        assert_eq!(rels.len(), 1);
        assert_eq!(rels[0].foreign_table().name.as_str(), "users");
    }

    #[test]
    fn test_schema_cache_find_relationships_to() {
        let cache = create_test_cache();

        let source = QualifiedIdentifier::new("public", "posts");
        let rels = cache.find_relationships_to(&source, "users");
        assert_eq!(rels.len(), 1);

        let rels = cache.find_relationships_to(&source, "nonexistent");
        assert!(rels.is_empty());
    }

    #[test]
    fn test_schema_cache_get_routines() {
        let cache = create_test_cache();

        let qi = QualifiedIdentifier::new("public", "get_user");
        let routines = cache.get_routines(&qi).unwrap();
        assert_eq!(routines.len(), 1);
        assert!(routines[0].returns_set());
    }

    #[test]
    fn test_schema_cache_get_routines_by_name() {
        let cache = create_test_cache();

        let routines = cache.get_routines_by_name("public", "get_user").unwrap();
        assert_eq!(routines.len(), 1);
    }

    #[test]
    fn test_schema_cache_is_valid_timezone() {
        let cache = create_test_cache();

        assert!(cache.is_valid_timezone("UTC"));
        assert!(cache.is_valid_timezone("America/New_York"));
        assert!(!cache.is_valid_timezone("Invalid/Zone"));
    }

    #[test]
    fn test_schema_cache_counts() {
        let cache = create_test_cache();

        assert_eq!(cache.table_count(), 2);
        assert_eq!(cache.relationship_count(), 1);
        assert_eq!(cache.routine_count(), 1);
    }

    #[test]
    fn test_schema_cache_summary() {
        let cache = create_test_cache();

        let summary = cache.summary();
        assert!(summary.contains("2 tables"));
        assert!(summary.contains("1 relationships"));
        assert!(summary.contains("1 routines"));
    }

    #[test]
    fn test_schema_cache_tables_iter() {
        let cache = create_test_cache();

        let table_names: Vec<_> = cache.tables_iter().map(|(_, t)| t.name.as_str()).collect();
        assert!(table_names.contains(&"users"));
        assert!(table_names.contains(&"posts"));
    }

    #[test]
    fn test_schema_cache_tables_in_schema() {
        let cache = create_test_cache();

        let public_tables: Vec<_> = cache.tables_in_schema("public").collect();
        assert_eq!(public_tables.len(), 2);

        let other_tables: Vec<_> = cache.tables_in_schema("other").collect();
        assert!(other_tables.is_empty());
    }

    #[test]
    fn test_schema_cache_holder_new() {
        let holder = SchemaCacheHolder::new();
        assert!(!holder.is_loaded());
        assert!(holder.get().is_none());
    }

    #[test]
    fn test_schema_cache_holder_with_cache() {
        let cache = create_test_cache();
        let holder = SchemaCacheHolder::with_cache(cache);
        assert!(holder.is_loaded());
        assert!(holder.get().is_some());
    }

    #[test]
    fn test_schema_cache_holder_replace() {
        let holder = SchemaCacheHolder::new();
        assert!(!holder.is_loaded());

        holder.replace(create_test_cache());
        assert!(holder.is_loaded());

        // The replaced cache must be what readers now observe.
        let guard = holder.get().expect("cache should be loaded after replace");
        let cache = (**guard).as_ref().expect("a returned guard always holds Some");
        assert_eq!(cache.table_count(), 2);
    }

    #[test]
    fn test_schema_cache_holder_clear() {
        let cache = create_test_cache();
        let holder = SchemaCacheHolder::with_cache(cache);
        assert!(holder.is_loaded());

        holder.clear();
        assert!(!holder.is_loaded());
    }

    #[tokio::test]
    async fn test_schema_cache_load_with_mock() {
        use db::MockDbIntrospector;

        let mut mock = MockDbIntrospector::new();

        mock.expect_query_tables()
            .returning(|_| Ok(vec![mock_table_row("test_table", &[ID_COLUMN_JSON])]));
        mock.expect_query_relationships().returning(|| Ok(vec![]));
        mock.expect_query_routines().returning(|_| Ok(vec![]));
        mock.expect_query_computed_fields().returning(|_| Ok(vec![]));
        mock.expect_query_timezones()
            .returning(|| Ok(vec!["UTC".to_string()]));

        let config = AppConfig::default();
        let cache = SchemaCache::load(&mock, &config).await.unwrap();

        assert_eq!(cache.table_count(), 1);
        let table = cache.get_table_by_name("public", "test_table").unwrap();
        assert!(table.has_pk());
        assert!(cache.is_valid_timezone("UTC"));
    }

    #[tokio::test]
    async fn test_schema_cache_load_with_relationships() {
        use db::MockDbIntrospector;

        let mut mock = MockDbIntrospector::new();

        mock.expect_query_tables().returning(|_| {
            Ok(vec![
                mock_table_row("users", &[ID_COLUMN_JSON]),
                mock_table_row("posts", &[ID_COLUMN_JSON, USER_ID_COLUMN_JSON]),
            ])
        });

        mock.expect_query_relationships().returning(|| {
            Ok(vec![RelationshipRow {
                table_schema: "public".to_string(),
                table_name: "posts".to_string(),
                foreign_table_schema: "public".to_string(),
                foreign_table_name: "users".to_string(),
                is_self: false,
                constraint_name: "fk_posts_user".to_string(),
                cols_and_fcols: vec![("user_id".to_string(), "id".to_string())],
                one_to_one: false,
            }])
        });

        mock.expect_query_routines().returning(|_| Ok(vec![]));
        mock.expect_query_computed_fields().returning(|_| Ok(vec![]));
        mock.expect_query_timezones().returning(|| Ok(vec![]));

        let config = AppConfig::default();
        let cache = SchemaCache::load(&mock, &config).await.unwrap();

        assert_eq!(cache.table_count(), 2);
        // One foreign key is stored in both directions.
        assert_eq!(cache.relationship_count(), 2);

        // posts -> users is the many-to-one (to-one) side.
        let source = QualifiedIdentifier::new("public", "posts");
        let rels = cache.find_relationships(&source);
        assert_eq!(rels.len(), 1);
        assert!(rels[0].is_to_one());

        // users -> posts is the reversed, one-to-many side.
        let source_rev = QualifiedIdentifier::new("public", "users");
        let rels_rev = cache.find_relationships(&source_rev);
        assert_eq!(rels_rev.len(), 1);
        assert!(!rels_rev[0].is_to_one());

        // "UTC" is accepted even though the timezone query returned nothing.
        assert!(cache.is_valid_timezone("UTC"));
    }
}