1use std::sync::Arc;
19
20use async_trait::async_trait;
21use datafusion_common::{HashMap, TableReference, error::Result, not_impl_err};
22use datafusion_execution::config::SessionConfig;
23
24use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider};
25
26#[derive(Debug)]
30struct ResolvedSchemaProvider {
31 owner_name: Option<String>,
32 cached_tables: HashMap<String, Arc<dyn TableProvider>>,
33}
34#[async_trait]
35impl SchemaProvider for ResolvedSchemaProvider {
36 fn owner_name(&self) -> Option<&str> {
37 self.owner_name.as_deref()
38 }
39
40 fn table_names(&self) -> Vec<String> {
41 self.cached_tables.keys().cloned().collect()
42 }
43
44 async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
45 Ok(self.cached_tables.get(name).cloned())
46 }
47
48 fn register_table(
49 &self,
50 name: String,
51 _table: Arc<dyn TableProvider>,
52 ) -> Result<Option<Arc<dyn TableProvider>>> {
53 not_impl_err!(
54 "Attempt to register table '{name}' with ResolvedSchemaProvider which is not supported"
55 )
56 }
57
58 fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
59 not_impl_err!(
60 "Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported"
61 )
62 }
63
64 fn table_exist(&self, name: &str) -> bool {
65 self.cached_tables.contains_key(name)
66 }
67}
68
69struct ResolvedSchemaProviderBuilder {
71 owner_name: String,
72 async_provider: Arc<dyn AsyncSchemaProvider>,
73 cached_tables: HashMap<String, Option<Arc<dyn TableProvider>>>,
74}
75impl ResolvedSchemaProviderBuilder {
76 fn new(owner_name: String, async_provider: Arc<dyn AsyncSchemaProvider>) -> Self {
77 Self {
78 owner_name,
79 async_provider,
80 cached_tables: HashMap::new(),
81 }
82 }
83
84 async fn resolve_table(&mut self, table_name: &str) -> Result<()> {
85 if !self.cached_tables.contains_key(table_name) {
86 let resolved_table = self.async_provider.table(table_name).await?;
87 self.cached_tables
88 .insert(table_name.to_string(), resolved_table);
89 }
90 Ok(())
91 }
92
93 fn finish(self) -> Arc<dyn SchemaProvider> {
94 let cached_tables = self
95 .cached_tables
96 .into_iter()
97 .filter_map(|(key, maybe_value)| maybe_value.map(|value| (key, value)))
98 .collect();
99 Arc::new(ResolvedSchemaProvider {
100 owner_name: Some(self.owner_name),
101 cached_tables,
102 })
103 }
104}
105
106#[derive(Debug)]
110struct ResolvedCatalogProvider {
111 cached_schemas: HashMap<String, Arc<dyn SchemaProvider>>,
112}
113impl CatalogProvider for ResolvedCatalogProvider {
114 fn schema_names(&self) -> Vec<String> {
115 self.cached_schemas.keys().cloned().collect()
116 }
117
118 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
119 self.cached_schemas.get(name).cloned()
120 }
121}
122
123struct ResolvedCatalogProviderBuilder {
125 cached_schemas: HashMap<String, Option<ResolvedSchemaProviderBuilder>>,
126 async_provider: Arc<dyn AsyncCatalogProvider>,
127}
128impl ResolvedCatalogProviderBuilder {
129 fn new(async_provider: Arc<dyn AsyncCatalogProvider>) -> Self {
130 Self {
131 cached_schemas: HashMap::new(),
132 async_provider,
133 }
134 }
135 fn finish(self) -> Arc<dyn CatalogProvider> {
136 let cached_schemas = self
137 .cached_schemas
138 .into_iter()
139 .filter_map(|(key, maybe_value)| {
140 maybe_value.map(|value| (key, value.finish()))
141 })
142 .collect();
143 Arc::new(ResolvedCatalogProvider { cached_schemas })
144 }
145}
146
147#[derive(Debug)]
151struct ResolvedCatalogProviderList {
152 cached_catalogs: HashMap<String, Arc<dyn CatalogProvider>>,
153}
154impl CatalogProviderList for ResolvedCatalogProviderList {
155 fn register_catalog(
156 &self,
157 _name: String,
158 _catalog: Arc<dyn CatalogProvider>,
159 ) -> Option<Arc<dyn CatalogProvider>> {
160 unimplemented!("resolved providers cannot handle registration APIs")
161 }
162
163 fn catalog_names(&self) -> Vec<String> {
164 self.cached_catalogs.keys().cloned().collect()
165 }
166
167 fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
168 self.cached_catalogs.get(name).cloned()
169 }
170}
171
172#[async_trait]
188pub trait AsyncSchemaProvider: Send + Sync {
189 async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>>;
191 async fn resolve(
202 &self,
203 references: &[TableReference],
204 config: &SessionConfig,
205 catalog_name: &str,
206 schema_name: &str,
207 ) -> Result<Arc<dyn SchemaProvider>> {
208 let mut cached_tables = HashMap::<String, Option<Arc<dyn TableProvider>>>::new();
209
210 for reference in references {
211 let ref_catalog_name = reference
212 .catalog()
213 .unwrap_or(&config.options().catalog.default_catalog);
214
215 if ref_catalog_name != catalog_name {
217 continue;
218 }
219
220 let ref_schema_name = reference
221 .schema()
222 .unwrap_or(&config.options().catalog.default_schema);
223
224 if ref_schema_name != schema_name {
225 continue;
226 }
227
228 if !cached_tables.contains_key(reference.table()) {
229 let resolved_table = self.table(reference.table()).await?;
230 cached_tables.insert(reference.table().to_string(), resolved_table);
231 }
232 }
233
234 let cached_tables = cached_tables
235 .into_iter()
236 .filter_map(|(key, maybe_value)| maybe_value.map(|value| (key, value)))
237 .collect();
238
239 Ok(Arc::new(ResolvedSchemaProvider {
240 cached_tables,
241 owner_name: Some(catalog_name.to_string()),
242 }))
243 }
244}
245
246#[async_trait]
253pub trait AsyncCatalogProvider: Send + Sync {
254 async fn schema(&self, name: &str) -> Result<Option<Arc<dyn AsyncSchemaProvider>>>;
256
257 async fn resolve(
267 &self,
268 references: &[TableReference],
269 config: &SessionConfig,
270 catalog_name: &str,
271 ) -> Result<Arc<dyn CatalogProvider>> {
272 let mut cached_schemas =
273 HashMap::<String, Option<ResolvedSchemaProviderBuilder>>::new();
274
275 for reference in references {
276 let ref_catalog_name = reference
277 .catalog()
278 .unwrap_or(&config.options().catalog.default_catalog);
279
280 if ref_catalog_name != catalog_name {
282 continue;
283 }
284
285 let schema_name = reference
286 .schema()
287 .unwrap_or(&config.options().catalog.default_schema);
288
289 let schema = if let Some(schema) = cached_schemas.get_mut(schema_name) {
290 schema
291 } else {
292 let resolved_schema = self.schema(schema_name).await?;
293 let resolved_schema = resolved_schema.map(|resolved_schema| {
294 ResolvedSchemaProviderBuilder::new(
295 catalog_name.to_string(),
296 resolved_schema,
297 )
298 });
299 cached_schemas.insert(schema_name.to_string(), resolved_schema);
300 cached_schemas.get_mut(schema_name).unwrap()
301 };
302
303 let Some(schema) = schema else { continue };
305
306 schema.resolve_table(reference.table()).await?;
307 }
308
309 let cached_schemas = cached_schemas
310 .into_iter()
311 .filter_map(|(key, maybe_builder)| {
312 maybe_builder.map(|schema_builder| (key, schema_builder.finish()))
313 })
314 .collect::<HashMap<_, _>>();
315
316 Ok(Arc::new(ResolvedCatalogProvider { cached_schemas }))
317 }
318}
319
320#[async_trait]
326pub trait AsyncCatalogProviderList: Send + Sync {
327 async fn catalog(&self, name: &str) -> Result<Option<Arc<dyn AsyncCatalogProvider>>>;
329
330 async fn resolve(
340 &self,
341 references: &[TableReference],
342 config: &SessionConfig,
343 ) -> Result<Arc<dyn CatalogProviderList>> {
344 let mut cached_catalogs =
345 HashMap::<String, Option<ResolvedCatalogProviderBuilder>>::new();
346
347 for reference in references {
348 let catalog_name = reference
349 .catalog()
350 .unwrap_or(&config.options().catalog.default_catalog);
351
352 let catalog = if let Some(catalog) = cached_catalogs.get_mut(catalog_name) {
362 catalog
363 } else {
364 let resolved_catalog = self.catalog(catalog_name).await?;
365 let resolved_catalog =
366 resolved_catalog.map(ResolvedCatalogProviderBuilder::new);
367 cached_catalogs.insert(catalog_name.to_string(), resolved_catalog);
368 cached_catalogs.get_mut(catalog_name).unwrap()
369 };
370
371 let Some(catalog) = catalog else { continue };
373
374 let schema_name = reference
375 .schema()
376 .unwrap_or(&config.options().catalog.default_schema);
377
378 let schema = if let Some(schema) = catalog.cached_schemas.get_mut(schema_name)
379 {
380 schema
381 } else {
382 let resolved_schema = catalog.async_provider.schema(schema_name).await?;
383 let resolved_schema = resolved_schema.map(|async_schema| {
384 ResolvedSchemaProviderBuilder::new(
385 catalog_name.to_string(),
386 async_schema,
387 )
388 });
389 catalog
390 .cached_schemas
391 .insert(schema_name.to_string(), resolved_schema);
392 catalog.cached_schemas.get_mut(schema_name).unwrap()
393 };
394
395 let Some(schema) = schema else { continue };
397
398 schema.resolve_table(reference.table()).await?;
399 }
400
401 let cached_catalogs = cached_catalogs
403 .into_iter()
404 .filter_map(|(key, maybe_builder)| {
405 maybe_builder.map(|catalog_builder| (key, catalog_builder.finish()))
406 })
407 .collect::<HashMap<_, _>>();
408
409 Ok(Arc::new(ResolvedCatalogProviderList { cached_catalogs }))
410 }
411}
412
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };

    use arrow::datatypes::SchemaRef;
    use async_trait::async_trait;
    use datafusion_common::{Statistics, TableReference, error::Result};
    use datafusion_execution::config::SessionConfig;
    use datafusion_expr::{Expr, TableType};
    use datafusion_physical_plan::ExecutionPlan;

    use crate::{Session, TableProvider};

    use super::{AsyncCatalogProvider, AsyncCatalogProviderList, AsyncSchemaProvider};

    const MOCK_CATALOG: &str = "mock_catalog";
    const MOCK_SCHEMA: &str = "mock_schema";
    const MOCK_TABLE: &str = "mock_table";

    /// Placeholder table returned by the mock schema. The tests only check
    /// whether it is present, so none of its methods are implemented.
    #[derive(Debug)]
    struct MockTableProvider {}

    #[async_trait]
    impl TableProvider for MockTableProvider {
        fn schema(&self) -> SchemaRef {
            unimplemented!()
        }

        fn table_type(&self) -> TableType {
            unimplemented!()
        }

        async fn scan(
            &self,
            _state: &dyn Session,
            _projection: Option<&Vec<usize>>,
            _filters: &[Expr],
            _limit: Option<usize>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unimplemented!()
        }

        fn statistics(&self) -> Option<Statistics> {
            unimplemented!()
        }
    }

    /// Schema that contains only `MOCK_TABLE` and counts every `table` call.
    #[derive(Default)]
    struct MockAsyncSchemaProvider {
        lookup_count: AtomicU32,
    }

    #[async_trait]
    impl AsyncSchemaProvider for MockAsyncSchemaProvider {
        async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
            self.lookup_count.fetch_add(1, Ordering::Release);
            if name != MOCK_TABLE {
                return Ok(None);
            }
            Ok(Some(Arc::new(MockTableProvider {})))
        }
    }

    /// Catalog that contains only `MOCK_SCHEMA` and counts every `schema` call.
    #[derive(Default)]
    struct MockAsyncCatalogProvider {
        lookup_count: AtomicU32,
    }

    #[async_trait]
    impl AsyncCatalogProvider for MockAsyncCatalogProvider {
        async fn schema(
            &self,
            name: &str,
        ) -> Result<Option<Arc<dyn AsyncSchemaProvider>>> {
            self.lookup_count.fetch_add(1, Ordering::Release);
            if name != MOCK_SCHEMA {
                return Ok(None);
            }
            Ok(Some(Arc::new(MockAsyncSchemaProvider::default())))
        }
    }

    /// Catalog list that contains only `MOCK_CATALOG` and counts every
    /// `catalog` call.
    #[derive(Default)]
    struct MockAsyncCatalogProviderList {
        lookup_count: AtomicU32,
    }

    #[async_trait]
    impl AsyncCatalogProviderList for MockAsyncCatalogProviderList {
        async fn catalog(
            &self,
            name: &str,
        ) -> Result<Option<Arc<dyn AsyncCatalogProvider>>> {
            self.lookup_count.fetch_add(1, Ordering::Release);
            if name != MOCK_CATALOG {
                return Ok(None);
            }
            Ok(Some(Arc::new(MockAsyncCatalogProvider::default())))
        }
    }

    /// Session config whose default catalog and schema are the mock names, so
    /// that partial and bare references resolve into the mocks.
    fn test_config() -> SessionConfig {
        let mut config = SessionConfig::default();
        config.options_mut().catalog.default_catalog = String::from(MOCK_CATALOG);
        config.options_mut().catalog.default_schema = String::from(MOCK_SCHEMA);
        config
    }

    #[tokio::test]
    async fn test_async_schema_provider_resolve() {
        /// Resolves `refs` against a fresh mock schema, then checks the number
        /// of lookups and which tables ended up in the result.
        async fn assert_resolves(
            refs: Vec<TableReference>,
            expected_lookup_count: u32,
            found_tables: &[&str],
            not_found_tables: &[&str],
        ) {
            let mock = MockAsyncSchemaProvider::default();
            let resolved = mock
                .resolve(&refs, &test_config(), MOCK_CATALOG, MOCK_SCHEMA)
                .await
                .unwrap();

            let lookups = mock.lookup_count.load(Ordering::Acquire);
            assert_eq!(lookups, expected_lookup_count);

            for name in found_tables {
                let table = resolved.table(name).await.unwrap();
                assert!(table.is_some(), "expected `{name}` to resolve");
            }
            for name in not_found_tables {
                let table = resolved.table(name).await.unwrap();
                assert!(table.is_none(), "expected `{name}` to be absent");
            }
        }

        // One hit and one miss: each name is looked up exactly once.
        assert_resolves(
            vec![
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE),
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"),
            ],
            2,
            &[MOCK_TABLE],
            &["not_exists"],
        )
        .await;

        // References into another schema or catalog are never looked up.
        assert_resolves(
            vec![
                TableReference::full(MOCK_CATALOG, "foo", MOCK_TABLE),
                TableReference::full("foo", MOCK_SCHEMA, MOCK_TABLE),
            ],
            0,
            &[],
            &[MOCK_TABLE],
        )
        .await;

        // Duplicate references, found or not, are looked up only once.
        assert_resolves(
            vec![
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE),
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE),
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"),
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "not_exists"),
            ],
            2,
            &[MOCK_TABLE],
            &["not_exists"],
        )
        .await;
    }

    #[tokio::test]
    async fn test_async_catalog_provider_resolve() {
        /// Resolves `refs` against a fresh mock catalog, then checks the number
        /// of schema lookups and which schemas ended up in the result.
        async fn assert_resolves(
            refs: Vec<TableReference>,
            expected_lookup_count: u32,
            found_schemas: &[&str],
            not_found_schemas: &[&str],
        ) {
            let mock = MockAsyncCatalogProvider::default();
            let resolved = mock
                .resolve(&refs, &test_config(), MOCK_CATALOG)
                .await
                .unwrap();

            let lookups = mock.lookup_count.load(Ordering::Acquire);
            assert_eq!(lookups, expected_lookup_count);

            for name in found_schemas {
                assert!(resolved.schema(name).is_some(), "expected `{name}` to resolve");
            }
            for name in not_found_schemas {
                assert!(resolved.schema(name).is_none(), "expected `{name}` to be absent");
            }
        }

        // One existing and one missing schema: each is looked up once.
        assert_resolves(
            vec![
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"),
                TableReference::full(MOCK_CATALOG, "not_exists", "x"),
            ],
            2,
            &[MOCK_SCHEMA],
            &["not_exists"],
        )
        .await;

        // A reference into another catalog triggers no lookup at all.
        assert_resolves(
            vec![TableReference::full("foo", MOCK_SCHEMA, "x")],
            0,
            &[],
            &[MOCK_SCHEMA],
        )
        .await;

        // Duplicate schema references are looked up only once.
        assert_resolves(
            vec![
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"),
                TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, "x"),
                TableReference::full(MOCK_CATALOG, "not_exists", "x"),
                TableReference::full(MOCK_CATALOG, "not_exists", "x"),
            ],
            2,
            &[MOCK_SCHEMA],
            &["not_exists"],
        )
        .await;
    }

    #[tokio::test]
    async fn test_async_catalog_provider_list_resolve() {
        /// Resolves `refs` against a fresh mock catalog list, then checks the
        /// number of catalog lookups and which catalogs ended up in the result.
        async fn assert_resolves(
            refs: Vec<TableReference>,
            expected_lookup_count: u32,
            found_catalogs: &[&str],
            not_found_catalogs: &[&str],
        ) {
            let mock = MockAsyncCatalogProviderList::default();
            let resolved = mock.resolve(&refs, &test_config()).await.unwrap();

            let lookups = mock.lookup_count.load(Ordering::Acquire);
            assert_eq!(lookups, expected_lookup_count);

            for name in found_catalogs {
                assert!(resolved.catalog(name).is_some(), "expected `{name}` to resolve");
            }
            for name in not_found_catalogs {
                assert!(resolved.catalog(name).is_none(), "expected `{name}` to be absent");
            }
        }

        // One existing and one missing catalog: each is looked up once.
        assert_resolves(
            vec![
                TableReference::full(MOCK_CATALOG, "x", "x"),
                TableReference::full("not_exists", "x", "x"),
            ],
            2,
            &[MOCK_CATALOG],
            &["not_exists"],
        )
        .await;

        // Duplicate catalog references are looked up only once.
        assert_resolves(
            vec![
                TableReference::full(MOCK_CATALOG, "x", "x"),
                TableReference::full(MOCK_CATALOG, "x", "x"),
                TableReference::full("not_exists", "x", "x"),
                TableReference::full("not_exists", "x", "x"),
            ],
            2,
            &[MOCK_CATALOG],
            &["not_exists"],
        )
        .await;
    }

    /// Full, partial and bare references must all reach the mock table once
    /// the config's defaults fill in the missing parts.
    #[tokio::test]
    async fn test_defaults() {
        let references = [
            TableReference::full(MOCK_CATALOG, MOCK_SCHEMA, MOCK_TABLE),
            TableReference::partial(MOCK_SCHEMA, MOCK_TABLE),
            TableReference::bare(MOCK_TABLE),
        ];

        for reference in &references {
            let mock = MockAsyncCatalogProviderList::default();
            let resolved = mock
                .resolve(std::slice::from_ref(reference), &test_config())
                .await
                .unwrap();

            let catalog_name = reference.catalog().unwrap_or(MOCK_CATALOG);
            let schema_name = reference.schema().unwrap_or(MOCK_SCHEMA);

            let catalog = resolved.catalog(catalog_name).unwrap();
            let schema = catalog.schema(schema_name).unwrap();
            let table = schema.table(reference.table()).await.unwrap();
            assert!(table.is_some());
        }
    }
}