1use chrono::prelude::*;
4use dashmap::DashMap;
5use datafusion::catalog::SchemaProvider;
6use datafusion::catalog::{CatalogProvider, CatalogProviderList};
7use datafusion::common::DataFusionError;
8use datafusion::datasource::TableProvider;
9use moka::Expiry;
10use moka::future::Cache;
11use std::any::Any;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tracing::error;
15
16use super::models::{
17 GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse,
18 TableTempCredentialsResponse, TemporaryTableCredentials,
19};
20use super::{DataCatalogResult, UnityCatalog, UnityCatalogError};
21use deltalake_core::{DeltaTableBuilder, ensure_table_uri};
22
23#[derive(Debug)]
25pub struct UnityCatalogList {
26 pub catalogs: DashMap<String, Arc<dyn CatalogProvider>>,
28}
29
30impl UnityCatalogList {
31 pub async fn try_new(client: Arc<UnityCatalog>) -> DataCatalogResult<Self> {
33 let catalogs = match client.list_catalogs().await? {
34 ListCatalogsResponse::Success { catalogs, .. } => {
35 let mut providers = Vec::new();
36 for catalog in catalogs {
37 let provider =
38 UnityCatalogProvider::try_new(client.clone(), &catalog.name).await?;
39 providers.push((catalog.name, Arc::new(provider) as Arc<dyn CatalogProvider>));
40 }
41 providers
42 }
43 _ => vec![],
44 };
45 Ok(Self {
46 catalogs: catalogs.into_iter().collect(),
47 })
48 }
49}
50
51impl CatalogProviderList for UnityCatalogList {
52 fn as_any(&self) -> &dyn Any {
53 self
54 }
55
56 fn register_catalog(
57 &self,
58 name: String,
59 catalog: Arc<dyn CatalogProvider>,
60 ) -> Option<Arc<dyn CatalogProvider>> {
61 self.catalogs.insert(name, catalog)
62 }
63
64 fn catalog_names(&self) -> Vec<String> {
65 self.catalogs.iter().map(|c| c.key().clone()).collect()
66 }
67
68 fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
69 self.catalogs.get(name).map(|c| c.value().clone())
70 }
71}
72
73#[derive(Debug)]
75pub struct UnityCatalogProvider {
76 pub schemas: DashMap<String, Arc<dyn SchemaProvider>>,
78}
79
80impl UnityCatalogProvider {
81 pub async fn try_new(
83 client: Arc<UnityCatalog>,
84 catalog_name: impl Into<String>,
85 ) -> DataCatalogResult<Self> {
86 let catalog_name = catalog_name.into();
87 let schemas = match client.list_schemas(&catalog_name).await? {
88 ListSchemasResponse::Success { schemas } => {
89 let mut providers = Vec::new();
90 for schema in schemas {
91 let provider =
92 UnitySchemaProvider::try_new(client.clone(), &catalog_name, &schema.name)
93 .await?;
94 providers.push((schema.name, Arc::new(provider) as Arc<dyn SchemaProvider>));
95 }
96 providers
97 }
98 _ => vec![],
99 };
100 Ok(Self {
101 schemas: schemas.into_iter().collect(),
102 })
103 }
104}
105
106impl CatalogProvider for UnityCatalogProvider {
107 fn as_any(&self) -> &dyn Any {
108 self
109 }
110
111 fn schema_names(&self) -> Vec<String> {
112 self.schemas.iter().map(|c| c.key().clone()).collect()
113 }
114
115 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
116 self.schemas.get(name).map(|c| c.value().clone())
117 }
118}
119
120struct TokenExpiry;
121
122impl Expiry<String, TemporaryTableCredentials> for TokenExpiry {
123 fn expire_after_read(
124 &self,
125 _key: &String,
126 value: &TemporaryTableCredentials,
127 _read_at: Instant,
128 _duration_until_expiry: Option<Duration>,
129 _last_modified_at: Instant,
130 ) -> Option<Duration> {
131 let time_to_expire = value.expiration_time - Utc::now();
132 tracing::info!("Token {_key} expires in {time_to_expire}");
133 time_to_expire.to_std().ok()
134 }
135}
136
137#[derive(Debug)]
139pub struct UnitySchemaProvider {
140 client: Arc<UnityCatalog>,
141 catalog_name: String,
142 schema_name: String,
143
144 table_names: Vec<String>,
146 token_cache: Cache<String, TemporaryTableCredentials>,
147}
148
149impl UnitySchemaProvider {
150 pub async fn try_new(
152 client: Arc<UnityCatalog>,
153 catalog_name: impl Into<String>,
154 schema_name: impl Into<String>,
155 ) -> DataCatalogResult<Self> {
156 let catalog_name = catalog_name.into();
157 let schema_name = schema_name.into();
158 let table_names = match client
159 .list_table_summaries(&catalog_name, &schema_name)
160 .await?
161 {
162 ListTableSummariesResponse::Success { tables, .. } => tables
163 .into_iter()
164 .filter_map(|t| t.full_name.split('.').next_back().map(|n| n.into()))
165 .collect(),
166 ListTableSummariesResponse::Error(_) => vec![],
167 };
168 let token_cache = Cache::builder().expire_after(TokenExpiry).build();
169 Ok(Self {
170 client,
171 table_names,
172 catalog_name,
173 schema_name,
174 token_cache,
175 })
176 }
177
178 async fn get_creds(
179 &self,
180 catalog: &str,
181 schema: &str,
182 table: &str,
183 ) -> Result<TemporaryTableCredentials, UnityCatalogError> {
184 tracing::debug!("Fetching new credential for: {catalog}.{schema}.{table}",);
185 match self
186 .client
187 .get_temp_table_credentials_with_permission(catalog, schema, table, "READ_WRITE")
188 .await
189 {
190 Ok(TableTempCredentialsResponse::Success(temp_creds)) => Ok(temp_creds),
191 Ok(TableTempCredentialsResponse::Error(rw_error)) => match self
192 .client
193 .get_temp_table_credentials(catalog, schema, table)
194 .await?
195 {
196 TableTempCredentialsResponse::Success(temp_creds) => Ok(temp_creds),
197 TableTempCredentialsResponse::Error(read_error) => {
198 Err(UnityCatalogError::TemporaryCredentialsFetchFailure {
199 error_code: read_error.error_code,
200 message: format!(
201 "READ_WRITE failed: {}. READ failed: {}",
202 rw_error.message, read_error.message
203 ),
204 })
205 }
206 },
207 Err(err) => Err(err),
208 }
209 }
210}
211
212#[async_trait::async_trait]
213impl SchemaProvider for UnitySchemaProvider {
214 fn as_any(&self) -> &dyn Any {
215 self
216 }
217
218 fn table_names(&self) -> Vec<String> {
219 self.table_names.clone()
220 }
221
222 async fn table(
223 &self,
224 name: &str,
225 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
226 let maybe_table = self
227 .client
228 .get_table(&self.catalog_name, &self.schema_name, name)
229 .await
230 .map_err(|err| DataFusionError::External(Box::new(err)))?;
231
232 match maybe_table {
233 GetTableResponse::Success(table) => {
234 let temp_creds = self
235 .token_cache
236 .try_get_with(
237 table.table_id,
238 self.get_creds(&self.catalog_name, &self.schema_name, name),
239 )
240 .await
241 .map_err(|err| DataFusionError::External(err.into()))?;
242
243 let new_storage_opts = temp_creds.get_credentials().ok_or_else(|| {
244 DataFusionError::External(UnityCatalogError::MissingCredential.into())
245 })?;
246 let table_url = ensure_table_uri(&table.storage_location)
247 .map_err(|e| DataFusionError::External(Box::new(e)))?;
248 let table = DeltaTableBuilder::from_url(table_url)
249 .map_err(|e| DataFusionError::External(Box::new(e)))?
250 .with_storage_options(new_storage_opts)
251 .load()
252 .await?;
253 Ok(Some(table.table_provider().await?))
254 }
255 GetTableResponse::Error(err) => {
256 error!("failed to fetch table from unity catalog: {}", err.message);
257 Err(DataFusionError::External(Box::new(err)))
258 }
259 }
260 }
261
262 fn table_exist(&self, name: &str) -> bool {
263 self.table_names.contains(&String::from(name))
264 }
265}