1use std::sync::Arc;
19
20use async_trait::async_trait;
21use datafusion_common::{HashMap, TableReference, error::Result, not_impl_err};
22use datafusion_execution::config::SessionConfig;
23
24use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider};
25
26#[derive(Debug)]
30struct ResolvedSchemaProvider {
31 owner_name: Option<String>,
32 cached_tables: HashMap<String, Arc<dyn TableProvider>>,
33}
34#[async_trait]
35impl SchemaProvider for ResolvedSchemaProvider {
36 fn owner_name(&self) -> Option<&str> {
37 self.owner_name.as_deref()
38 }
39
40 fn as_any(&self) -> &dyn std::any::Any {
41 self
42 }
43
44 fn table_names(&self) -> Vec<String> {
45 self.cached_tables.keys().cloned().collect()
46 }
47
48 async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
49 Ok(self.cached_tables.get(name).cloned())
50 }
51
52 fn register_table(
53 &self,
54 name: String,
55 _table: Arc<dyn TableProvider>,
56 ) -> Result<Option<Arc<dyn TableProvider>>> {
57 not_impl_err!(
58 "Attempt to register table '{name}' with ResolvedSchemaProvider which is not supported"
59 )
60 }
61
62 fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
63 not_impl_err!(
64 "Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported"
65 )
66 }
67
68 fn table_exist(&self, name: &str) -> bool {
69 self.cached_tables.contains_key(name)
70 }
71}
72
73struct ResolvedSchemaProviderBuilder {
75 owner_name: String,
76 async_provider: Arc<dyn AsyncSchemaProvider>,
77 cached_tables: HashMap<String, Option<Arc<dyn TableProvider>>>,
78}
79impl ResolvedSchemaProviderBuilder {
80 fn new(owner_name: String, async_provider: Arc<dyn AsyncSchemaProvider>) -> Self {
81 Self {
82 owner_name,
83 async_provider,
84 cached_tables: HashMap::new(),
85 }
86 }
87
88 async fn resolve_table(&mut self, table_name: &str) -> Result<()> {
89 if !self.cached_tables.contains_key(table_name) {
90 let resolved_table = self.async_provider.table(table_name).await?;
91 self.cached_tables
92 .insert(table_name.to_string(), resolved_table);
93 }
94 Ok(())
95 }
96
97 fn finish(self) -> Arc<dyn SchemaProvider> {
98 let cached_tables = self
99 .cached_tables
100 .into_iter()
101 .filter_map(|(key, maybe_value)| maybe_value.map(|value| (key, value)))
102 .collect();
103 Arc::new(ResolvedSchemaProvider {
104 owner_name: Some(self.owner_name),
105 cached_tables,
106 })
107 }
108}
109
110#[derive(Debug)]
114struct ResolvedCatalogProvider {
115 cached_schemas: HashMap<String, Arc<dyn SchemaProvider>>,
116}
117impl CatalogProvider for ResolvedCatalogProvider {
118 fn as_any(&self) -> &dyn std::any::Any {
119 self
120 }
121
122 fn schema_names(&self) -> Vec<String> {
123 self.cached_schemas.keys().cloned().collect()
124 }
125
126 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
127 self.cached_schemas.get(name).cloned()
128 }
129}
130
131struct ResolvedCatalogProviderBuilder {
133 cached_schemas: HashMap<String, Option<ResolvedSchemaProviderBuilder>>,
134 async_provider: Arc<dyn AsyncCatalogProvider>,
135}
136impl ResolvedCatalogProviderBuilder {
137 fn new(async_provider: Arc<dyn AsyncCatalogProvider>) -> Self {
138 Self {
139 cached_schemas: HashMap::new(),
140 async_provider,
141 }
142 }
143 fn finish(self) -> Arc<dyn CatalogProvider> {
144 let cached_schemas = self
145 .cached_schemas
146 .into_iter()
147 .filter_map(|(key, maybe_value)| {
148 maybe_value.map(|value| (key, value.finish()))
149 })
150 .collect();
151 Arc::new(ResolvedCatalogProvider { cached_schemas })
152 }
153}
154
155#[derive(Debug)]
159struct ResolvedCatalogProviderList {
160 cached_catalogs: HashMap<String, Arc<dyn CatalogProvider>>,
161}
162impl CatalogProviderList for ResolvedCatalogProviderList {
163 fn as_any(&self) -> &dyn std::any::Any {
164 self
165 }
166
167 fn register_catalog(
168 &self,
169 _name: String,
170 _catalog: Arc<dyn CatalogProvider>,
171 ) -> Option<Arc<dyn CatalogProvider>> {
172 unimplemented!("resolved providers cannot handle registration APIs")
173 }
174
175 fn catalog_names(&self) -> Vec<String> {
176 self.cached_catalogs.keys().cloned().collect()
177 }
178
179 fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
180 self.cached_catalogs.get(name).cloned()
181 }
182}
183
184#[async_trait]
200pub trait AsyncSchemaProvider: Send + Sync {
201 async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>>;
203 async fn resolve(
214 &self,
215 references: &[TableReference],
216 config: &SessionConfig,
217 catalog_name: &str,
218 schema_name: &str,
219 ) -> Result<Arc<dyn SchemaProvider>> {
220 let mut cached_tables = HashMap::<String, Option<Arc<dyn TableProvider>>>::new();
221
222 for reference in references {
223 let ref_catalog_name = reference
224 .catalog()
225 .unwrap_or(&config.options().catalog.default_catalog);
226
227 if ref_catalog_name != catalog_name {
229 continue;
230 }
231
232 let ref_schema_name = reference
233 .schema()
234 .unwrap_or(&config.options().catalog.default_schema);
235
236 if ref_schema_name != schema_name {
237 continue;
238 }
239
240 if !cached_tables.contains_key(reference.table()) {
241 let resolved_table = self.table(reference.table()).await?;
242 cached_tables.insert(reference.table().to_string(), resolved_table);
243 }
244 }
245
246 let cached_tables = cached_tables
247 .into_iter()
248 .filter_map(|(key, maybe_value)| maybe_value.map(|value| (key, value)))
249 .collect();
250
251 Ok(Arc::new(ResolvedSchemaProvider {
252 cached_tables,
253 owner_name: Some(catalog_name.to_string()),
254 }))
255 }
256}
257
258#[async_trait]
265pub trait AsyncCatalogProvider: Send + Sync {
266 async fn schema(&self, name: &str) -> Result<Option<Arc<dyn AsyncSchemaProvider>>>;
268
269 async fn resolve(
279 &self,
280 references: &[TableReference],
281 config: &SessionConfig,
282 catalog_name: &str,
283 ) -> Result<Arc<dyn CatalogProvider>> {
284 let mut cached_schemas =
285 HashMap::<String, Option<ResolvedSchemaProviderBuilder>>::new();
286
287 for reference in references {
288 let ref_catalog_name = reference
289 .catalog()
290 .unwrap_or(&config.options().catalog.default_catalog);
291
292 if ref_catalog_name != catalog_name {
294 continue;
295 }
296
297 let schema_name = reference
298 .schema()
299 .unwrap_or(&config.options().catalog.default_schema);
300
301 let schema = if let Some(schema) = cached_schemas.get_mut(schema_name) {
302 schema
303 } else {
304 let resolved_schema = self.schema(schema_name).await?;
305 let resolved_schema = resolved_schema.map(|resolved_schema| {
306 ResolvedSchemaProviderBuilder::new(
307 catalog_name.to_string(),
308 resolved_schema,
309 )
310 });
311 cached_schemas.insert(schema_name.to_string(), resolved_schema);
312 cached_schemas.get_mut(schema_name).unwrap()
313 };
314
315 let Some(schema) = schema else { continue };
317
318 schema.resolve_table(reference.table()).await?;
319 }
320
321 let cached_schemas = cached_schemas
322 .into_iter()
323 .filter_map(|(key, maybe_builder)| {
324 maybe_builder.map(|schema_builder| (key, schema_builder.finish()))
325 })
326 .collect::<HashMap<_, _>>();
327
328 Ok(Arc::new(ResolvedCatalogProvider { cached_schemas }))
329 }
330}
331
332#[async_trait]
338pub trait AsyncCatalogProviderList: Send + Sync {
339 async fn catalog(&self, name: &str) -> Result<Option<Arc<dyn AsyncCatalogProvider>>>;
341
342 async fn resolve(
352 &self,
353 references: &[TableReference],
354 config: &SessionConfig,
355 ) -> Result<Arc<dyn CatalogProviderList>> {
356 let mut cached_catalogs =
357 HashMap::<String, Option<ResolvedCatalogProviderBuilder>>::new();
358
359 for reference in references {
360 let catalog_name = reference
361 .catalog()
362 .unwrap_or(&config.options().catalog.default_catalog);
363
364 let catalog = if let Some(catalog) = cached_catalogs.get_mut(catalog_name) {
374 catalog
375 } else {
376 let resolved_catalog = self.catalog(catalog_name).await?;
377 let resolved_catalog =
378 resolved_catalog.map(ResolvedCatalogProviderBuilder::new);
379 cached_catalogs.insert(catalog_name.to_string(), resolved_catalog);
380 cached_catalogs.get_mut(catalog_name).unwrap()
381 };
382
383 let Some(catalog) = catalog else { continue };
385
386 let schema_name = reference
387 .schema()
388 .unwrap_or(&config.options().catalog.default_schema);
389
390 let schema = if let Some(schema) = catalog.cached_schemas.get_mut(schema_name)
391 {
392 schema
393 } else {
394 let resolved_schema = catalog.async_provider.schema(schema_name).await?;
395 let resolved_schema = resolved_schema.map(|async_schema| {
396 ResolvedSchemaProviderBuilder::new(
397 catalog_name.to_string(),
398 async_schema,
399 )
400 });
401 catalog
402 .cached_schemas
403 .insert(schema_name.to_string(), resolved_schema);
404 catalog.cached_schemas.get_mut(schema_name).unwrap()
405 };
406
407 let Some(schema) = schema else { continue };
409
410 schema.resolve_table(reference.table()).await?;
411 }
412
413 let cached_catalogs = cached_catalogs
415 .into_iter()
416 .filter_map(|(key, maybe_builder)| {
417 maybe_builder.map(|catalog_builder| (key, catalog_builder.finish()))
418 })
419 .collect::<HashMap<_, _>>();
420
421 Ok(Arc::new(ResolvedCatalogProviderList { cached_catalogs }))
422 }
423}
424
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    use arrow::datatypes::SchemaRef;
    use async_trait::async_trait;
    use datafusion_common::error::Result;
    use datafusion_common::{Statistics, TableReference};
    use datafusion_execution::config::SessionConfig;
    use datafusion_expr::{Expr, TableType};
    use datafusion_physical_plan::ExecutionPlan;

    use crate::{Session, TableProvider};

    use super::{AsyncCatalogProvider, AsyncCatalogProviderList, AsyncSchemaProvider};

    const MOCK_CATALOG: &str = "mock_catalog";
    const MOCK_SCHEMA: &str = "mock_schema";
    const MOCK_TABLE: &str = "mock_table";

    /// One test scenario: the references to resolve, how many async lookups
    /// are expected, the names that must be present afterwards, and the names
    /// that must be absent.
    type Case = (Vec<TableReference>, u32, Vec<&'static str>, Vec<&'static str>);

    /// Records one async lookup on a mock's counter.
    fn record_lookup(counter: &AtomicU32) {
        counter.fetch_add(1, Ordering::Release);
    }

    /// Reads how many lookups a mock has recorded.
    fn lookups(counter: &AtomicU32) -> u32 {
        counter.load(Ordering::Acquire)
    }

    /// Session config whose default catalog/schema are the mock names.
    fn test_config() -> SessionConfig {
        let mut config = SessionConfig::default();
        let catalog_options = &mut config.options_mut().catalog;
        catalog_options.default_catalog = MOCK_CATALOG.to_string();
        catalog_options.default_schema = MOCK_SCHEMA.to_string();
        config
    }

    /// Placeholder table; only its presence matters, none of its methods are
    /// called by these tests.
    #[derive(Debug)]
    struct MockTableProvider;

    #[async_trait]
    impl TableProvider for MockTableProvider {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            unimplemented!()
        }

        fn table_type(&self) -> TableType {
            unimplemented!()
        }

        async fn scan(
            &self,
            _state: &dyn Session,
            _projection: Option<&Vec<usize>>,
            _filters: &[Expr],
            _limit: Option<usize>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unimplemented!()
        }

        fn statistics(&self) -> Option<Statistics> {
            unimplemented!()
        }
    }

    /// Schema that only contains `MOCK_TABLE` and counts `table` calls.
    #[derive(Default)]
    struct MockAsyncSchemaProvider {
        lookup_count: AtomicU32,
    }

    #[async_trait]
    impl AsyncSchemaProvider for MockAsyncSchemaProvider {
        async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
            record_lookup(&self.lookup_count);
            let found = (name == MOCK_TABLE)
                .then(|| Arc::new(MockTableProvider) as Arc<dyn TableProvider>);
            Ok(found)
        }
    }

    /// Catalog that only contains `MOCK_SCHEMA` and counts `schema` calls.
    #[derive(Default)]
    struct MockAsyncCatalogProvider {
        lookup_count: AtomicU32,
    }

    #[async_trait]
    impl AsyncCatalogProvider for MockAsyncCatalogProvider {
        async fn schema(
            &self,
            name: &str,
        ) -> Result<Option<Arc<dyn AsyncSchemaProvider>>> {
            record_lookup(&self.lookup_count);
            let found = (name == MOCK_SCHEMA).then(|| {
                Arc::new(MockAsyncSchemaProvider::default()) as Arc<dyn AsyncSchemaProvider>
            });
            Ok(found)
        }
    }

    /// Catalog list that only contains `MOCK_CATALOG` and counts `catalog` calls.
    #[derive(Default)]
    struct MockAsyncCatalogProviderList {
        lookup_count: AtomicU32,
    }

    #[async_trait]
    impl AsyncCatalogProviderList for MockAsyncCatalogProviderList {
        async fn catalog(
            &self,
            name: &str,
        ) -> Result<Option<Arc<dyn AsyncCatalogProvider>>> {
            record_lookup(&self.lookup_count);
            let found = (name == MOCK_CATALOG).then(|| {
                Arc::new(MockAsyncCatalogProvider::default())
                    as Arc<dyn AsyncCatalogProvider>
            });
            Ok(found)
        }
    }

    #[tokio::test]
    async fn test_async_schema_provider_resolve() {
        let cases: Vec<Case> = vec![
            // One existing and one missing table: each looked up once.
            (
                vec![
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE),
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"),
                ],
                2,
                vec![MOCK_TABLE],
                vec!["not_exists"],
            ),
            // References into another schema or catalog trigger no lookups.
            (
                vec![
                    TableReference::full(MOCK_CATALOG, "foo", MOCK_TABLE),
                    TableReference::full("foo", MOCK_SCHEMA, MOCK_TABLE),
                ],
                0,
                vec![],
                vec![MOCK_TABLE],
            ),
            // Repeated references (found or missing) are looked up only once.
            (
                vec![
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE),
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE),
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"),
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"),
                ],
                2,
                vec![MOCK_TABLE],
                vec!["not_exists"],
            ),
        ];

        for (refs, expected_lookups, present, absent) in cases {
            let provider = MockAsyncSchemaProvider::default();
            let resolved = provider
                .resolve(&refs, &test_config(), MOCK_CATALOG, MOCK_SCHEMA)
                .await
                .unwrap();

            assert_eq!(lookups(&provider.lookup_count), expected_lookups);
            for name in present {
                assert!(resolved.table(name).await.unwrap().is_some());
            }
            for name in absent {
                assert!(resolved.table(name).await.unwrap().is_none());
            }
        }
    }

    #[tokio::test]
    async fn test_async_catalog_provider_resolve() {
        let cases: Vec<Case> = vec![
            // One existing and one missing schema: each looked up once.
            (
                vec![
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"),
                    TableReference::full(MOCK_CATALOG, "not_exists", "x"),
                ],
                2,
                vec![MOCK_SCHEMA],
                vec!["not_exists"],
            ),
            // References into another catalog trigger no lookups.
            (
                vec![TableReference::full("foo", MOCK_SCHEMA, "x")],
                0,
                vec![],
                vec![MOCK_SCHEMA],
            ),
            // Repeated references (found or missing) are looked up only once.
            (
                vec![
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"),
                    TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"),
                    TableReference::full(MOCK_CATALOG, "not_exists", "x"),
                    TableReference::full(MOCK_CATALOG, "not_exists", "x"),
                ],
                2,
                vec![MOCK_SCHEMA],
                vec!["not_exists"],
            ),
        ];

        for (refs, expected_lookups, present, absent) in cases {
            let provider = MockAsyncCatalogProvider::default();
            let resolved = provider
                .resolve(&refs, &test_config(), MOCK_CATALOG)
                .await
                .unwrap();

            assert_eq!(lookups(&provider.lookup_count), expected_lookups);
            for name in present {
                assert!(resolved.schema(name).is_some());
            }
            for name in absent {
                assert!(resolved.schema(name).is_none());
            }
        }
    }

    #[tokio::test]
    async fn test_async_catalog_provider_list_resolve() {
        let cases: Vec<Case> = vec![
            // One existing and one missing catalog: each looked up once.
            (
                vec![
                    TableReference::full(MOCK_CATALOG, "x", "x"),
                    TableReference::full("not_exists", "x", "x"),
                ],
                2,
                vec![MOCK_CATALOG],
                vec!["not_exists"],
            ),
            // Repeated references (found or missing) are looked up only once.
            (
                vec![
                    TableReference::full(MOCK_CATALOG, "x", "x"),
                    TableReference::full(MOCK_CATALOG, "x", "x"),
                    TableReference::full("not_exists", "x", "x"),
                    TableReference::full("not_exists", "x", "x"),
                ],
                2,
                vec![MOCK_CATALOG],
                vec!["not_exists"],
            ),
        ];

        for (refs, expected_lookups, present, absent) in cases {
            let provider = MockAsyncCatalogProviderList::default();
            let resolved = provider.resolve(&refs, &test_config()).await.unwrap();

            assert_eq!(lookups(&provider.lookup_count), expected_lookups);
            for name in present {
                assert!(resolved.catalog(name).is_some());
            }
            for name in absent {
                assert!(resolved.catalog(name).is_none());
            }
        }
    }

    #[tokio::test]
    async fn test_defaults() {
        // Fully, partially and un-qualified references must all land on the
        // configured default catalog/schema.
        let references = [
            TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE),
            TableReference::partial(MOCK_SCHEMA, MOCK_TABLE),
            TableReference::bare(MOCK_TABLE),
        ];

        for table_ref in &references {
            let provider = MockAsyncCatalogProviderList::default();
            let resolved = provider
                .resolve(std::slice::from_ref(table_ref), &test_config())
                .await
                .unwrap();

            let catalog_name = table_ref.catalog().unwrap_or(MOCK_CATALOG);
            let schema_name = table_ref.schema().unwrap_or(MOCK_SCHEMA);

            let catalog = resolved.catalog(catalog_name).unwrap();
            let schema = catalog.schema(schema_name).unwrap();
            let table = schema.table(table_ref.table()).await.unwrap();
            assert!(table.is_some());
        }
    }
}