use std::sync::Arc;
use datafusion::prelude::SessionContext;
use hamelin_lib::catalog::CatalogProvider as HamelinCatalogProvider;
use hamelin_lib::err::TranslationErrors;
use hamelin_lib::parse_and_typecheck_with_options;
use hamelin_lib::provider::EnvironmentProvider;
use hamelin_lib::tree::ast::expression::Expression;
use hamelin_lib::tree::ast::identifier::{Identifier, SimpleIdentifier};
use hamelin_lib::tree::ast::query::Query;
use hamelin_lib::tree::options::{TranslateOptions, TypeCheckOptions};
use hamelin_lib::tree::typed_ast::query::TypedStatement;
use hamelin_lib::type_check_with_options;
use hamelin_translation::{normalize_with, NormalizationOptions};
use crate::arrow::arrow_to_hamelin_type;
use crate::statement::{translate_statement, TranslatedStatement};
/// Builds a Hamelin environment provider from the tables registered in a
/// DataFusion session.
///
/// Every catalog, schema, and table reachable through the session's catalog
/// list is visited. For each table whose provider can be fetched, the Arrow
/// fields are converted with `arrow_to_hamelin_type` and registered under:
///
/// * `catalog.schema.table` — always;
/// * `schema.table` — when the catalog is the session's default catalog;
/// * `table` — when both the catalog and the schema are the session defaults.
///
/// Catalogs or schemas that cannot be looked up, and tables whose lookup
/// returns an error or `None`, are skipped. A qualified alias is likewise
/// skipped when `Identifier::from_segments` returns `None`.
///
/// # Errors
///
/// Nothing in this body produces an `Err`; the function always returns `Ok`.
pub async fn catalog_provider_from_session(
    ctx: &SessionContext,
) -> Result<Arc<dyn EnvironmentProvider>, TranslationErrors> {
    let hamelin_catalog = HamelinCatalogProvider::default();
    let state = ctx.state();
    let df_catalogs = state.catalog_list();
    let options = state.config().options();
    let default_catalog = &options.catalog.default_catalog;
    let default_schema = &options.catalog.default_schema;

    for catalog_name in df_catalogs.catalog_names() {
        let df_catalog = match df_catalogs.catalog(&catalog_name) {
            Some(found) => found,
            None => continue,
        };
        // Decided once per catalog rather than once per table.
        let in_default_catalog = catalog_name == *default_catalog;

        for schema_name in df_catalog.schema_names() {
            let df_schema = match df_catalog.schema(&schema_name) {
                Some(found) => found,
                None => continue,
            };
            let in_default_schema = in_default_catalog && schema_name == *default_schema;

            for table_name in df_schema.table_names() {
                // Lookup failures and missing tables are both skipped.
                let Ok(Some(table)) = df_schema.table(&table_name).await else {
                    continue;
                };

                let arrow_schema = table.schema();
                let columns: ordermap::OrderMap<_, _> = arrow_schema
                    .fields()
                    .iter()
                    .map(|field| {
                        (
                            SimpleIdentifier::new(field.name()),
                            arrow_to_hamelin_type(field.data_type()),
                        )
                    })
                    .collect();

                let table_id = SimpleIdentifier::new(&table_name);
                let schema_id = SimpleIdentifier::new(&schema_name);
                let catalog_id = SimpleIdentifier::new(&catalog_name);

                // Fully qualified name: catalog.schema.table
                let fully_qualified =
                    Identifier::from_segments(vec![catalog_id, schema_id.clone(), table_id.clone()]);
                if let Some(full) = fully_qualified {
                    hamelin_catalog.set(full, columns.clone());
                }

                // Schema-qualified alias: schema.table
                if in_default_catalog {
                    let schema_qualified =
                        Identifier::from_segments(vec![schema_id, table_id.clone()]);
                    if let Some(partial) = schema_qualified {
                        hamelin_catalog.set(partial, columns.clone());
                    }
                }

                // Bare alias: table
                if in_default_schema {
                    hamelin_catalog.set(table_id.into(), columns);
                }
            }
        }
    }

    Ok(Arc::new(hamelin_catalog))
}
/// Translates a typed statement using a provider derived from `ctx`.
///
/// The environment provider is built by [`catalog_provider_from_session`] and
/// placed into otherwise-default [`TranslateOptions`]; the work is then
/// delegated to [`translate_with_options`].
///
/// # Errors
///
/// Propagates any error from building the provider or from
/// [`translate_with_options`].
pub async fn translate(
    typed: impl Into<Arc<TypedStatement>>,
    ctx: &SessionContext,
) -> Result<TranslatedStatement, TranslationErrors> {
    let session_provider = catalog_provider_from_session(ctx).await?;
    let translate_opts = TranslateOptions::builder()
        .provider(session_provider)
        .build();
    translate_with_options(typed, ctx, translate_opts).await
}
/// Lowers a typed statement and translates the result against `ctx`.
///
/// Normalization is configured from `opts`: the lower transform is enabled,
/// the registry and provider are always supplied, and the timestamp and
/// message fields are applied only when present. The lowered IR is then
/// handed to `translate_statement`.
///
/// # Errors
///
/// Returns the errors produced by lowering unchanged. A failure from
/// `translate_statement` is cloned out of its wrapper and converted with
/// `.single()` into [`TranslationErrors`].
pub async fn translate_with_options(
    typed: impl Into<Arc<TypedStatement>>,
    ctx: &SessionContext,
    opts: TranslateOptions,
) -> Result<TranslatedStatement, TranslationErrors> {
    let typed: Arc<TypedStatement> = typed.into();

    let normalizer: NormalizationOptions = normalize_with()
        .with_lower_transform()
        .with_registry(opts.registry)
        .with_provider(opts.provider);
    let normalizer = match opts.timestamp_field {
        Some(ts_field) => normalizer.with_timestamp_field(ts_field),
        None => normalizer,
    };
    let normalizer = match opts.message_field {
        Some(msg_field) => normalizer.with_message_field(msg_field),
        None => normalizer,
    };

    let ir = normalizer.lower(typed)?;
    let outcome = translate_statement(&ir, ctx).await;
    outcome.map_err(|err| err.as_ref().clone().single())
}
/// Type-checks and translates a query using a provider derived from `ctx`.
///
/// The environment provider is built by [`catalog_provider_from_session`] and
/// placed into otherwise-default [`TypeCheckOptions`]; the work is then
/// delegated to [`type_check_and_translate_with_options`].
///
/// # Errors
///
/// Propagates any error from building the provider or from
/// [`type_check_and_translate_with_options`].
pub async fn type_check_and_translate(
    ast: impl Into<Arc<Query>>,
    ctx: &SessionContext,
) -> Result<TranslatedStatement, TranslationErrors> {
    let session_provider = catalog_provider_from_session(ctx).await?;
    let check_opts = TypeCheckOptions::builder()
        .provider(session_provider)
        .build();
    type_check_and_translate_with_options(ast, ctx, check_opts).await
}
/// Type-checks `ast` with `opts`, then translates the typed statement.
///
/// The registry, provider, timestamp field, and message field are copied out
/// of `opts` into a [`TranslateOptions`] before `opts` is consumed by
/// `type_check_with_options`, so that both stages see the same settings.
///
/// # Errors
///
/// Returns the type-checking errors if `into_result()` fails, otherwise any
/// error from [`translate_with_options`].
pub async fn type_check_and_translate_with_options(
    ast: impl Into<Arc<Query>>,
    ctx: &SessionContext,
    opts: TypeCheckOptions,
) -> Result<TranslatedStatement, TranslationErrors> {
    // Snapshot the shared settings; `opts` is moved into the type checker below.
    let registry = opts.registry.clone();
    let provider = opts.provider.clone();
    let timestamp_field = opts.timestamp_field.clone();
    let message_field = opts.message_field.clone();
    let translate_opts = TranslateOptions::builder()
        .registry(registry)
        .provider(provider)
        .maybe_timestamp_field(timestamp_field)
        .maybe_message_field(message_field)
        .build();

    let checked = type_check_with_options(ast, opts).into_result()?;
    let typed = Arc::new(checked);
    translate_with_options(typed, ctx, translate_opts).await
}
/// Parses, type-checks, and translates `input` using a provider derived from
/// `ctx`.
///
/// The environment provider is built by [`catalog_provider_from_session`] and
/// placed into otherwise-default [`TypeCheckOptions`]; the work is then
/// delegated to [`parse_and_translate_with_options`].
///
/// # Errors
///
/// Propagates any error from building the provider or from
/// [`parse_and_translate_with_options`].
pub async fn parse_and_translate(
    input: impl Into<String>,
    ctx: &SessionContext,
) -> Result<TranslatedStatement, TranslationErrors> {
    let session_provider = catalog_provider_from_session(ctx).await?;
    let check_opts = TypeCheckOptions::builder()
        .provider(session_provider)
        .build();
    parse_and_translate_with_options(input, ctx, check_opts).await
}
/// Same as [`parse_and_translate`], but also forwards an optional time-range
/// expression to the type-check options.
///
/// The environment provider is built by [`catalog_provider_from_session`];
/// `time_range` is passed through `maybe_time_range`, so `None` leaves the
/// option unset.
///
/// # Errors
///
/// Propagates any error from building the provider or from
/// [`parse_and_translate_with_options`].
pub async fn parse_and_translate_with_time_range(
    input: impl Into<String>,
    ctx: &SessionContext,
    time_range: Option<Arc<Expression>>,
) -> Result<TranslatedStatement, TranslationErrors> {
    let session_provider = catalog_provider_from_session(ctx).await?;
    let check_opts = TypeCheckOptions::builder()
        .provider(session_provider)
        .maybe_time_range(time_range)
        .build();
    parse_and_translate_with_options(input, ctx, check_opts).await
}
/// Runs `parse_and_typecheck_with_options` on `input`, then translates the
/// typed statement.
///
/// The registry, provider, timestamp field, and message field are copied out
/// of `opts` into a [`TranslateOptions`] before `opts` is consumed by the
/// parse/type-check step, so that both stages see the same settings.
///
/// # Errors
///
/// Returns the parse/type-check errors if `into_result()` fails, otherwise
/// any error from [`translate_with_options`].
pub async fn parse_and_translate_with_options(
    input: impl Into<String>,
    ctx: &SessionContext,
    opts: TypeCheckOptions,
) -> Result<TranslatedStatement, TranslationErrors> {
    // Snapshot the shared settings; `opts` is moved into the front end below.
    let registry = opts.registry.clone();
    let provider = opts.provider.clone();
    let timestamp_field = opts.timestamp_field.clone();
    let message_field = opts.message_field.clone();
    let translate_opts = TranslateOptions::builder()
        .registry(registry)
        .provider(provider)
        .maybe_timestamp_field(timestamp_field)
        .maybe_message_field(message_field)
        .build();

    let checked = parse_and_typecheck_with_options(input, opts).into_result()?;
    let typed = Arc::new(checked);
    translate_with_options(typed, ctx, translate_opts).await
}