Skip to main content

deltalake_catalog_unity/
datafusion.rs

1//! Datafusion integration for UnityCatalog
2
3use chrono::prelude::*;
4use dashmap::DashMap;
5use datafusion::catalog::SchemaProvider;
6use datafusion::catalog::{CatalogProvider, CatalogProviderList};
7use datafusion::common::DataFusionError;
8use datafusion::datasource::TableProvider;
9use moka::Expiry;
10use moka::future::Cache;
11use std::any::Any;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tracing::error;
15
16use super::models::{
17    GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse,
18    TableTempCredentialsResponse, TemporaryTableCredentials,
19};
20use super::{DataCatalogResult, UnityCatalog, UnityCatalogError};
21use deltalake_core::{DeltaTableBuilder, ensure_table_uri};
22
23/// In-memory list of catalogs populated by unity catalog
24#[derive(Debug)]
25pub struct UnityCatalogList {
26    /// Collection of catalogs containing schemas and ultimately TableProviders
27    pub catalogs: DashMap<String, Arc<dyn CatalogProvider>>,
28}
29
30impl UnityCatalogList {
31    /// Create a new instance of [`UnityCatalogList`]
32    pub async fn try_new(client: Arc<UnityCatalog>) -> DataCatalogResult<Self> {
33        let catalogs = match client.list_catalogs().await? {
34            ListCatalogsResponse::Success { catalogs, .. } => {
35                let mut providers = Vec::new();
36                for catalog in catalogs {
37                    let provider =
38                        UnityCatalogProvider::try_new(client.clone(), &catalog.name).await?;
39                    providers.push((catalog.name, Arc::new(provider) as Arc<dyn CatalogProvider>));
40                }
41                providers
42            }
43            _ => vec![],
44        };
45        Ok(Self {
46            catalogs: catalogs.into_iter().collect(),
47        })
48    }
49}
50
51impl CatalogProviderList for UnityCatalogList {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn register_catalog(
57        &self,
58        name: String,
59        catalog: Arc<dyn CatalogProvider>,
60    ) -> Option<Arc<dyn CatalogProvider>> {
61        self.catalogs.insert(name, catalog)
62    }
63
64    fn catalog_names(&self) -> Vec<String> {
65        self.catalogs.iter().map(|c| c.key().clone()).collect()
66    }
67
68    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
69        self.catalogs.get(name).map(|c| c.value().clone())
70    }
71}
72
73/// A datafusion [`CatalogProvider`] backed by Databricks UnityCatalog
74#[derive(Debug)]
75pub struct UnityCatalogProvider {
76    /// Parent catalog for schemas of interest.
77    pub schemas: DashMap<String, Arc<dyn SchemaProvider>>,
78}
79
80impl UnityCatalogProvider {
81    /// Create a new instance of [`UnityCatalogProvider`]
82    pub async fn try_new(
83        client: Arc<UnityCatalog>,
84        catalog_name: impl Into<String>,
85    ) -> DataCatalogResult<Self> {
86        let catalog_name = catalog_name.into();
87        let schemas = match client.list_schemas(&catalog_name).await? {
88            ListSchemasResponse::Success { schemas } => {
89                let mut providers = Vec::new();
90                for schema in schemas {
91                    let provider =
92                        UnitySchemaProvider::try_new(client.clone(), &catalog_name, &schema.name)
93                            .await?;
94                    providers.push((schema.name, Arc::new(provider) as Arc<dyn SchemaProvider>));
95                }
96                providers
97            }
98            _ => vec![],
99        };
100        Ok(Self {
101            schemas: schemas.into_iter().collect(),
102        })
103    }
104}
105
106impl CatalogProvider for UnityCatalogProvider {
107    fn as_any(&self) -> &dyn Any {
108        self
109    }
110
111    fn schema_names(&self) -> Vec<String> {
112        self.schemas.iter().map(|c| c.key().clone()).collect()
113    }
114
115    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
116        self.schemas.get(name).map(|c| c.value().clone())
117    }
118}
119
120struct TokenExpiry;
121
122impl Expiry<String, TemporaryTableCredentials> for TokenExpiry {
123    fn expire_after_read(
124        &self,
125        _key: &String,
126        value: &TemporaryTableCredentials,
127        _read_at: Instant,
128        _duration_until_expiry: Option<Duration>,
129        _last_modified_at: Instant,
130    ) -> Option<Duration> {
131        let time_to_expire = value.expiration_time - Utc::now();
132        tracing::info!("Token {_key} expires in {time_to_expire}");
133        time_to_expire.to_std().ok()
134    }
135}
136
137/// A datafusion [`SchemaProvider`] backed by Databricks UnityCatalog
138#[derive(Debug)]
139pub struct UnitySchemaProvider {
140    client: Arc<UnityCatalog>,
141    catalog_name: String,
142    schema_name: String,
143
144    /// Parent catalog for schemas of interest.
145    table_names: Vec<String>,
146    token_cache: Cache<String, TemporaryTableCredentials>,
147}
148
149impl UnitySchemaProvider {
150    /// Create a new instance of [`UnitySchemaProvider`]
151    pub async fn try_new(
152        client: Arc<UnityCatalog>,
153        catalog_name: impl Into<String>,
154        schema_name: impl Into<String>,
155    ) -> DataCatalogResult<Self> {
156        let catalog_name = catalog_name.into();
157        let schema_name = schema_name.into();
158        let table_names = match client
159            .list_table_summaries(&catalog_name, &schema_name)
160            .await?
161        {
162            ListTableSummariesResponse::Success { tables, .. } => tables
163                .into_iter()
164                .filter_map(|t| t.full_name.split('.').next_back().map(|n| n.into()))
165                .collect(),
166            ListTableSummariesResponse::Error(_) => vec![],
167        };
168        let token_cache = Cache::builder().expire_after(TokenExpiry).build();
169        Ok(Self {
170            client,
171            table_names,
172            catalog_name,
173            schema_name,
174            token_cache,
175        })
176    }
177
178    async fn get_creds(
179        &self,
180        catalog: &str,
181        schema: &str,
182        table: &str,
183    ) -> Result<TemporaryTableCredentials, UnityCatalogError> {
184        tracing::debug!("Fetching new credential for: {catalog}.{schema}.{table}",);
185        match self
186            .client
187            .get_temp_table_credentials_with_permission(catalog, schema, table, "READ_WRITE")
188            .await
189        {
190            Ok(TableTempCredentialsResponse::Success(temp_creds)) => Ok(temp_creds),
191            Ok(TableTempCredentialsResponse::Error(rw_error)) => match self
192                .client
193                .get_temp_table_credentials(catalog, schema, table)
194                .await?
195            {
196                TableTempCredentialsResponse::Success(temp_creds) => Ok(temp_creds),
197                TableTempCredentialsResponse::Error(read_error) => {
198                    Err(UnityCatalogError::TemporaryCredentialsFetchFailure {
199                        error_code: read_error.error_code,
200                        message: format!(
201                            "READ_WRITE failed: {}. READ failed: {}",
202                            rw_error.message, read_error.message
203                        ),
204                    })
205                }
206            },
207            Err(err) => Err(err),
208        }
209    }
210}
211
212#[async_trait::async_trait]
213impl SchemaProvider for UnitySchemaProvider {
214    fn as_any(&self) -> &dyn Any {
215        self
216    }
217
218    fn table_names(&self) -> Vec<String> {
219        self.table_names.clone()
220    }
221
222    async fn table(
223        &self,
224        name: &str,
225    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
226        let maybe_table = self
227            .client
228            .get_table(&self.catalog_name, &self.schema_name, name)
229            .await
230            .map_err(|err| DataFusionError::External(Box::new(err)))?;
231
232        match maybe_table {
233            GetTableResponse::Success(table) => {
234                let temp_creds = self
235                    .token_cache
236                    .try_get_with(
237                        table.table_id,
238                        self.get_creds(&self.catalog_name, &self.schema_name, name),
239                    )
240                    .await
241                    .map_err(|err| DataFusionError::External(err.into()))?;
242
243                let new_storage_opts = temp_creds.get_credentials().ok_or_else(|| {
244                    DataFusionError::External(UnityCatalogError::MissingCredential.into())
245                })?;
246                let table_url = ensure_table_uri(&table.storage_location)
247                    .map_err(|e| DataFusionError::External(Box::new(e)))?;
248                let table = DeltaTableBuilder::from_url(table_url)
249                    .map_err(|e| DataFusionError::External(Box::new(e)))?
250                    .with_storage_options(new_storage_opts)
251                    .load()
252                    .await?;
253                Ok(Some(table.table_provider().await?))
254            }
255            GetTableResponse::Error(err) => {
256                error!("failed to fetch table from unity catalog: {}", err.message);
257                Err(DataFusionError::External(Box::new(err)))
258            }
259        }
260    }
261
262    fn table_exist(&self, name: &str) -> bool {
263        self.table_names.contains(&String::from(name))
264    }
265}