datafusion_cli/
catalog.rs1use std::any::Any;
19use std::sync::{Arc, Weak};
20
21use crate::object_storage::{get_object_store, AwsOptions, GcpOptions};
22
23use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider};
24
25use datafusion::common::plan_datafusion_err;
26use datafusion::datasource::listing::ListingTableUrl;
27use datafusion::datasource::TableProvider;
28use datafusion::error::Result;
29use datafusion::execution::context::SessionState;
30use datafusion::execution::session_state::SessionStateBuilder;
31
32use async_trait::async_trait;
33use dirs::home_dir;
34use parking_lot::RwLock;
35
36#[derive(Debug)]
38pub struct DynamicObjectStoreCatalog {
39 inner: Arc<dyn CatalogProviderList>,
40 state: Weak<RwLock<SessionState>>,
41}
42
43impl DynamicObjectStoreCatalog {
44 pub fn new(
45 inner: Arc<dyn CatalogProviderList>,
46 state: Weak<RwLock<SessionState>>,
47 ) -> Self {
48 Self { inner, state }
49 }
50}
51
52impl CatalogProviderList for DynamicObjectStoreCatalog {
53 fn as_any(&self) -> &dyn Any {
54 self
55 }
56
57 fn register_catalog(
58 &self,
59 name: String,
60 catalog: Arc<dyn CatalogProvider>,
61 ) -> Option<Arc<dyn CatalogProvider>> {
62 self.inner.register_catalog(name, catalog)
63 }
64
65 fn catalog_names(&self) -> Vec<String> {
66 self.inner.catalog_names()
67 }
68
69 fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
70 let state = self.state.clone();
71 self.inner.catalog(name).map(|catalog| {
72 Arc::new(DynamicObjectStoreCatalogProvider::new(catalog, state)) as _
73 })
74 }
75}
76
77#[derive(Debug)]
79struct DynamicObjectStoreCatalogProvider {
80 inner: Arc<dyn CatalogProvider>,
81 state: Weak<RwLock<SessionState>>,
82}
83
84impl DynamicObjectStoreCatalogProvider {
85 pub fn new(
86 inner: Arc<dyn CatalogProvider>,
87 state: Weak<RwLock<SessionState>>,
88 ) -> Self {
89 Self { inner, state }
90 }
91}
92
93impl CatalogProvider for DynamicObjectStoreCatalogProvider {
94 fn as_any(&self) -> &dyn Any {
95 self
96 }
97
98 fn schema_names(&self) -> Vec<String> {
99 self.inner.schema_names()
100 }
101
102 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
103 let state = self.state.clone();
104 self.inner.schema(name).map(|schema| {
105 Arc::new(DynamicObjectStoreSchemaProvider::new(schema, state)) as _
106 })
107 }
108
109 fn register_schema(
110 &self,
111 name: &str,
112 schema: Arc<dyn SchemaProvider>,
113 ) -> Result<Option<Arc<dyn SchemaProvider>>> {
114 self.inner.register_schema(name, schema)
115 }
116}
117
118#[derive(Debug)]
121struct DynamicObjectStoreSchemaProvider {
122 inner: Arc<dyn SchemaProvider>,
123 state: Weak<RwLock<SessionState>>,
124}
125
126impl DynamicObjectStoreSchemaProvider {
127 pub fn new(
128 inner: Arc<dyn SchemaProvider>,
129 state: Weak<RwLock<SessionState>>,
130 ) -> Self {
131 Self { inner, state }
132 }
133}
134
135#[async_trait]
136impl SchemaProvider for DynamicObjectStoreSchemaProvider {
137 fn as_any(&self) -> &dyn Any {
138 self
139 }
140
141 fn table_names(&self) -> Vec<String> {
142 self.inner.table_names()
143 }
144
145 fn register_table(
146 &self,
147 name: String,
148 table: Arc<dyn TableProvider>,
149 ) -> Result<Option<Arc<dyn TableProvider>>> {
150 self.inner.register_table(name, table)
151 }
152
153 async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
154 let inner_table = self.inner.table(name).await;
155 if inner_table.is_ok() {
156 if let Some(inner_table) = inner_table? {
157 return Ok(Some(inner_table));
158 }
159 }
160
161 let mut state = self
164 .state
165 .upgrade()
166 .ok_or_else(|| plan_datafusion_err!("locking error"))?
167 .read()
168 .clone();
169 let mut builder = SessionStateBuilder::from(state.clone());
170 let optimized_name = substitute_tilde(name.to_owned());
171 let table_url = ListingTableUrl::parse(optimized_name.as_str())?;
172 let scheme = table_url.scheme();
173 let url = table_url.as_ref();
174
175 match state.runtime_env().object_store_registry.get_store(url) {
180 Ok(_) => { }
181 Err(_) => {
182 match scheme {
185 "s3" | "oss" | "cos" => {
186 if let Some(table_options) = builder.table_options() {
187 table_options.extensions.insert(AwsOptions::default())
188 }
189 }
190 "gs" | "gcs" => {
191 if let Some(table_options) = builder.table_options() {
192 table_options.extensions.insert(GcpOptions::default())
193 }
194 }
195 _ => {}
196 };
197 state = builder.build();
198 let store = get_object_store(
199 &state,
200 table_url.scheme(),
201 url,
202 &state.default_table_options(),
203 )
204 .await?;
205 state.runtime_env().register_object_store(url, store);
206 }
207 }
208 self.inner.table(name).await
209 }
210
211 fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
212 self.inner.deregister_table(name)
213 }
214
215 fn table_exist(&self, name: &str) -> bool {
216 self.inner.table_exist(name)
217 }
218}
219
/// Expands a leading `~` in `cur` to the current user's home directory.
///
/// The tilde is expanded only when it is the whole string or is immediately
/// followed by a path separator (`~`, `~/data/file.parquet`). Anything else
/// that merely starts with `~`, such as `~other/file` or a file literally
/// named `~backup.csv`, is returned unchanged rather than being glued onto
/// the home directory.
///
/// `cur` is also returned unchanged when the home directory is unknown, is
/// not valid UTF-8, or is empty.
pub fn substitute_tilde(cur: String) -> String {
    let expandable = match cur.strip_prefix('~') {
        Some(rest) => rest.is_empty() || rest.starts_with(std::path::is_separator),
        None => false,
    };
    if !expandable {
        return cur;
    }

    match home_dir() {
        Some(home_path) => match home_path.to_str() {
            // `cur` starts with the one-byte `~`, so index 1 is a char boundary.
            Some(home) if !home.is_empty() => format!("{home}{}", &cur[1..]),
            _ => cur,
        },
        None => cur,
    }
}
#[cfg(test)]
mod tests {

    use super::*;

    use datafusion::catalog::SchemaProvider;
    use datafusion::prelude::SessionContext;

    /// Creates a session whose catalog list is wrapped in a
    /// `DynamicObjectStoreCatalog` and returns it together with the first
    /// schema of the first catalog, looked up through a second wrapper around
    /// the session's catalog list.
    fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
        let ctx = SessionContext::new();
        ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        )));

        let provider = &DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        ) as &dyn CatalogProviderList;
        let catalog = provider
            .catalog(provider.catalog_names().first().unwrap())
            .unwrap();
        let schema = catalog
            .schema(catalog.schema_names().first().unwrap())
            .unwrap();
        (ctx, schema)
    }

    /// Looking up an `http://` location yields no table but leaves an HTTP
    /// store for that domain registered on the session's runtime environment.
    #[tokio::test]
    async fn query_http_location_test() -> Result<()> {
        let domain = "example.com";
        let location = format!("http://{domain}/file.parquet");

        let (ctx, schema) = setup_context();

        // No table is expected for this location.
        let table = schema.table(&location).await?;
        assert!(table.is_none());

        // The lookup must have registered a store for the URL.
        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        assert_eq!(format!("{store}"), "HttpStore");

        // The store's debug output must mention the requested domain.
        let expected_domain = format!("Domain(\"{domain}\")");
        assert!(format!("{store:?}").contains(&expected_domain));

        Ok(())
    }

    /// Looking up an `s3://` location yields no table but leaves an S3 store
    /// for that bucket registered on the session's runtime environment.
    #[tokio::test]
    async fn query_s3_location_test() -> Result<()> {
        let bucket = "examples3bucket";
        let location = format!("s3://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("AmazonS3({bucket})"));

        // The store's debug output must mention the requested bucket.
        let expected_bucket = format!("bucket: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    /// Looking up a `gs://` location yields no table but leaves a GCS store
    /// for that bucket registered on the session's runtime environment.
    #[tokio::test]
    async fn query_gs_location_test() -> Result<()> {
        let bucket = "examplegsbucket";
        let location = format!("gs://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("GoogleCloudStorage({bucket})"));

        // The store's debug output must mention the requested bucket.
        let expected_bucket = format!("bucket_name_encoded: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    /// A location with the `ts://` scheme makes the lookup fail.
    #[tokio::test]
    async fn query_invalid_location_test() {
        let location = "ts://file.parquet";
        let (_ctx, schema) = setup_context();

        assert!(schema.table(location).await.is_err());
    }

    /// `substitute_tilde` expands a leading `~` to the directory named by
    /// `HOME` and leaves paths without a leading `~` alone.
    ///
    /// The test is compiled only on non-Windows targets, so it deals with
    /// `HOME` and `/` exclusively.
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_substitute_tilde() {
        use std::env;

        /// Puts `HOME` back to its saved value when dropped, so the variable
        /// is restored even if an assertion below panics.
        struct RestoreHome(Option<std::ffi::OsString>);

        impl Drop for RestoreHome {
            fn drop(&mut self) {
                match self.0.take() {
                    Some(home) => std::env::set_var("HOME", home),
                    None => std::env::remove_var("HOME"),
                }
            }
        }

        let test_home_path = "/home/user";

        // Save the raw variable rather than a resolved home directory: it
        // needs no UTF-8 conversion and an originally unset `HOME` stays
        // unset afterwards.
        // NOTE(review): this mutates process-wide state while other tests may
        // run on other threads; none of them look up a `~` path.
        let _restore_home = RestoreHome(env::var_os("HOME"));
        env::set_var("HOME", test_home_path);

        let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet";
        let expected = format!(
            "{test_home_path}/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet"
        );
        assert_eq!(substitute_tilde(input.to_string()), expected);

        // A bare tilde is the home directory itself.
        assert_eq!(substitute_tilde("~".to_string()), test_home_path);

        // A tilde that is not at the start is not expanded.
        let untouched = "Code/~/part-0.parquet";
        assert_eq!(substitute_tilde(untouched.to_string()), untouched);
    }
}