use crate::sources::DataSourceType;
use anyhow::{Context, Result};
use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider};
use datafusion::datasource::TableProvider;
use datafusion::prelude::SessionContext;
use futures::stream::{self, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
/// Maximum number of table-provider build futures that [`build_catalog`] polls concurrently.
const CATALOG_BUILD_CONCURRENCY: usize = 8;
/// Short, copyable description of a data source, used to prefix log lines and error
/// messages.
///
/// Its [`fmt::Display`] output has the shape `<kind> <level> '<name>'`, where `<kind>` is
/// the `Display` form of [`DataSourceType`] and `<level>` is [`HierarchyLevel::as_str`].
#[derive(Debug, Clone, Copy)]
pub struct SourceLabel<'a> {
    /// Kind of upstream source; rendered through its own `Display` impl.
    pub kind: DataSourceType,
    /// Hierarchy level of the source (`table` or `catalog`).
    pub hierarchy: HierarchyLevel,
    /// Name of the source as shown in messages (wrapped in single quotes).
    pub name: &'a str,
}

impl<'src> SourceLabel<'src> {
    /// Bundles a source kind, hierarchy level and name into a label.
    pub fn new(kind: DataSourceType, hierarchy: HierarchyLevel, name: &'src str) -> Self {
        SourceLabel {
            name,
            hierarchy,
            kind,
        }
    }
}

impl fmt::Display for SourceLabel<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{kind} {level} '{name}'",
            kind = self.kind,
            level = self.hierarchy.as_str(),
            name = self.name,
        )
    }
}
/// Time limit applied to each individual attempt made by [`retry_with_timeout`].
pub const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Total number of attempts (including the first, not just extra retries) that
/// [`retry_with_timeout`] makes before giving up.
pub const MAX_RETRIES: u32 = 3;
/// Hierarchy level of a data source.
///
/// Serialized and deserialized as lowercase strings (`"table"` / `"catalog"`), and
/// defaults to [`HierarchyLevel::Table`].
#[derive(Debug, Clone, Copy, Deserialize, Serialize, Hash, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum HierarchyLevel {
    /// Table level; this is the default.
    #[default]
    Table,
    /// Catalog level.
    Catalog,
}

impl HierarchyLevel {
    /// Returns the lowercase name of this level, identical to its serde representation.
    pub fn as_str(&self) -> &'static str {
        match *self {
            HierarchyLevel::Catalog => "catalog",
            HierarchyLevel::Table => "table",
        }
    }
}
/// Reads the `allowed_schemas` option as a list of schema names.
///
/// The raw value is split on commas; every entry is trimmed and blank entries are
/// skipped. Order and any duplicates are kept exactly as written.
///
/// Returns `None` when `options` is `None`, when the `allowed_schemas` key is missing, or
/// when no non-blank entry remains after trimming.
pub fn parse_allowed_schemas(options: Option<&HashMap<String, String>>) -> Option<Vec<String>> {
    let raw = options?.get("allowed_schemas")?;

    let mut schemas = Vec::new();
    for entry in raw.split(',') {
        let entry = entry.trim();
        if !entry.is_empty() {
            schemas.push(entry.to_owned());
        }
    }

    (!schemas.is_empty()).then_some(schemas)
}
pub async fn retry_with_timeout<T, F, Fut>(
label: SourceLabel<'_>,
op_name: &str,
mut op: F,
) -> Result<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T>>,
{
let mut last_err: Option<anyhow::Error> = None;
for attempt in 1..=MAX_RETRIES {
match timeout(CONNECT_TIMEOUT, op()).await {
Ok(Ok(value)) => return Ok(value),
Ok(Err(e)) => {
tracing::warn!(
"{}: {} attempt {}/{} failed: {}",
label,
op_name,
attempt,
MAX_RETRIES,
e
);
last_err = Some(e);
}
Err(_) => {
let e = anyhow::anyhow!(
"Timed out after {}s during {} for {}. \
Check that the upstream is reachable and credentials are correct.",
CONNECT_TIMEOUT.as_secs(),
op_name,
label
);
tracing::warn!(
"{}: {} attempt {}/{} timed out",
label,
op_name,
attempt,
MAX_RETRIES
);
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| anyhow::anyhow!("retry_with_timeout: no attempts were made")))
}
/// Builds the requested table providers concurrently and registers them under a new
/// in-memory catalog named `catalog_name` on `session_ctx`.
///
/// * `schema_tables`: `(schema, table)` pairs to expose. Each pair must appear only once.
/// * `build_table`: called once per pair with owned `(schema, table)` names. The future it
///   returns produces that table's [`TableProvider`]. At most
///   [`CATALOG_BUILD_CONCURRENCY`] of these futures are polled at a time.
///
/// Schemas are created on first use as [`MemorySchemaProvider`]s. Tables are registered in
/// sorted `(schema, table)` order, whichever build finished first, so registration and its
/// debug logging are deterministic. The catalog is attached to `session_ctx` only after
/// every table has been registered; on any error nothing is attached.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * `schema_tables` lists the same `(schema, table)` pair more than once. This is checked
///   before any provider is built.
/// * Any `build_table` future fails. The first failure is returned with schema, table and
///   catalog context, and the remaining builds are abandoned.
/// * Registering a schema or table with the catalog fails.
pub async fn build_catalog<F, Fut>(
    session_ctx: &SessionContext,
    catalog_name: &str,
    schema_tables: Vec<(String, String)>,
    mut build_table: F,
) -> Result<()>
where
    F: FnMut(String, String) -> Fut,
    Fut: Future<Output = Result<Arc<dyn TableProvider>>>,
{
    // Reject duplicate pairs before doing any (potentially remote) work. A duplicate
    // would otherwise only surface when `register_table` refuses the second copy, after
    // both providers had already been built. The scope drops `seen` (which borrows
    // `schema_tables`) before the vector is consumed below.
    {
        let mut seen = std::collections::HashSet::with_capacity(schema_tables.len());
        for (schema, table_name) in &schema_tables {
            if !seen.insert((schema.as_str(), table_name.as_str())) {
                anyhow::bail!(
                    "Duplicate table '{}.{}' requested for catalog '{}'",
                    schema,
                    table_name,
                    catalog_name
                );
            }
        }
    }

    let catalog_provider = Arc::new(MemoryCatalogProvider::new());

    // `build_table` is invoked for every pair here. The returned futures are driven below
    // with bounded concurrency. They are awaited before this function returns, so they
    // can borrow `catalog_name` instead of each owning a copy.
    let provider_futures: Vec<_> = schema_tables
        .into_iter()
        .map(|(schema, table_name)| {
            let fut = build_table(schema.clone(), table_name.clone());
            async move {
                let provider = fut.await.with_context(|| {
                    format!(
                        "Failed to build table provider for '{}.{}' in catalog '{}'",
                        schema, table_name, catalog_name
                    )
                })?;
                Ok::<_, anyhow::Error>((schema, table_name, provider))
            }
        })
        .collect();

    let mut prepared: Vec<(String, String, Arc<dyn TableProvider>)> =
        stream::iter(provider_futures)
            .buffer_unordered(CATALOG_BUILD_CONCURRENCY)
            .try_collect()
            .await?;

    // `buffer_unordered` yields in completion order; sort so registration is deterministic.
    // Keys are unique after the duplicate check, so an unstable sort is sufficient.
    prepared.sort_unstable_by(|a, b| {
        (a.0.as_str(), a.1.as_str()).cmp(&(b.0.as_str(), b.1.as_str()))
    });

    for (schema, table_name, table_provider) in prepared {
        if catalog_provider.schema(&schema).is_none() {
            catalog_provider
                .register_schema(&schema, Arc::new(MemorySchemaProvider::new()))
                .map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to register schema '{}' for catalog '{}': {}",
                        schema,
                        catalog_name,
                        e
                    )
                })?;
        }
        // Defensive: the registration above should guarantee this lookup succeeds.
        let schema_provider = catalog_provider.schema(&schema).ok_or_else(|| {
            anyhow::anyhow!(
                "Schema '{}' was not found after registration in catalog '{}'",
                schema,
                catalog_name
            )
        })?;
        schema_provider
            .register_table(table_name.clone(), table_provider)
            .map_err(|e| {
                anyhow::anyhow!(
                    "Failed to register table '{}.{}' in catalog '{}': {}",
                    schema,
                    table_name,
                    catalog_name,
                    e
                )
            })?;
        tracing::debug!(
            "Prepared '{}.{}' in catalog '{}'",
            schema,
            table_name,
            catalog_name
        );
    }

    // NOTE(review): the return value (presumably any catalog previously registered under
    // this name) is discarded, so an existing catalog is silently replaced. Confirm intended.
    session_ctx.register_catalog(catalog_name, catalog_provider);
    Ok(())
}