use std::collections::HashSet;
use std::sync::Arc;
use aws_config::sts::AssumeRoleProvider;
use aws_config::{Region, SdkConfig};
use aws_credential_types::provider::SharedCredentialsProvider;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::catalog::MemTable;
use datafusion::datasource::file_format::csv::CsvFormat;
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::json::JsonFormat;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl,
};
use datafusion::datasource::TableProvider;
use datafusion::prelude::SessionContext;
use hamelin_executor::config::{
ColumnConfig, Compression, FileFormat, FilesConfig, StaticAwsCredentials,
};
use hamelin_executor::executor::ExecutorError;
use hamelin_lib::catalog::Column;
use hamelin_lib::types::Type;
use object_store::aws::{AmazonS3Builder, AwsCredentialProvider};
use crate::arrow::arrow_schema_to_columns;
use crate::aws::AwsCredentialAdapter;
use crate::struct_expansion::hamelin_type_to_arrow;
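
/// Registers an S3 object store for `bucket` on the session's runtime
/// environment under the URL `{scheme}://{bucket}`.
///
/// Credentials are resolved in priority order: static credentials from the
/// config, an assumed role when `role_arn` is set, and otherwise the default
/// credentials provider from `sdk_config`.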
pub async fn register_s3_bucket(
ctx: &SessionContext,
sdk_config: &SdkConfig,
scheme: &str,
bucket: &str,
region: Option<&str>,
role_arn: Option<&str>,
static_credentials: &Option<StaticAwsCredentials>,
endpoint_url: Option<&str>,
force_path_style: Option<bool>,
allow_http: Option<bool>,
) -> Result<(), ExecutorError> {
let resolved_region = region.or(sdk_config.region().map(|r| r.as_ref()));
let mut builder = AmazonS3Builder::new().with_bucket_name(bucket);
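    // Static credentials take precedence; otherwise fall back to the SDK
    // credential chain, optionally wrapped in an assume-role provider.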
    if let Some(s) = static_credentials {
builder = builder
.with_access_key_id(&s.access_id)
.with_secret_access_key(&s.secret_key);
} else {
let sdk_provider: SharedCredentialsProvider = if let Some(arn) = role_arn {
let mut arp = AssumeRoleProvider::builder(arn)
.session_name("hamelin")
.configure(sdk_config);
if let Some(r) = resolved_region {
arp = arp.region(Region::new(r.to_owned()));
}
SharedCredentialsProvider::new(arp.build().await)
} else {
sdk_config.credentials_provider().ok_or_else(|| {
ExecutorError::ConfigurationError(
anyhow::anyhow!("No AWS credentials configured").into(),
)
})?
};
let credentials: AwsCredentialProvider = Arc::new(AwsCredentialAdapter::new(sdk_provider));
builder = builder.with_credentials(credentials);
}
if let Some(r) = resolved_region {
builder = builder.with_region(r);
}
if let Some(e) = endpoint_url {
builder = builder.with_endpoint(e);
}
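    // `object_store` defaults to virtual-hosted-style requests, so forcing
    // path-style is expressed as the inverse of that flag.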
if let Some(f) = force_path_style {
builder = builder.with_virtual_hosted_style_request(!f);
}
if let Some(a) = allow_http {
builder = builder.with_allow_http(a);
}
let store = builder.build().map_err(|e| {
ExecutorError::ConfigurationError(
anyhow::anyhow!(
"Failed to build S3 object store for bucket '{}': {}",
bucket,
e
)
.into(),
)
})?;
let url = url::Url::parse(&format!("{}://{}", scheme, bucket)).map_err(|e| {
ExecutorError::ConfigurationError(
anyhow::anyhow!("Invalid S3 URL for bucket '{}': {}", bucket, e).into(),
)
})?;
ctx.runtime_env()
.register_object_store(&url, Arc::new(store));
Ok(())
}
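
/// Returns the distinct `(scheme, bucket)` pairs referenced by `s3://` and
/// `s3a://` paths; non-S3 paths are ignored. For example,
/// `["s3://bucket-a/x.csv", "s3a://bucket-a/y.csv"]` yields both
/// `("s3", "bucket-a")` and `("s3a", "bucket-a")`.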
pub fn extract_s3_buckets(paths: &[String]) -> Vec<(String, String)> {
let mut seen = HashSet::new();
for path in paths {
let (scheme, rest) = if let Some(rest) = path.strip_prefix("s3://") {
("s3", rest)
} else if let Some(rest) = path.strip_prefix("s3a://") {
("s3a", rest)
} else {
continue;
};
if let Some(bucket) = rest.split('/').next() {
seen.insert((scheme.to_string(), bucket.to_string()));
}
}
seen.into_iter().collect()
}
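
/// Resolves a `FilesConfig` into Hamelin catalog columns and a DataFusion
/// `ListingTable` provider over the configured paths.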
pub async fn resolve_files(
files: &FilesConfig,
ctx: &SessionContext,
) -> Result<(Vec<Column>, Arc<dyn TableProvider>), ExecutorError> {
let table_paths: Vec<ListingTableUrl> = files
.paths
.iter()
.map(|p| ListingTableUrl::parse(p).map_err(|e| ExecutorError::ConfigurationError(e.into())))
.collect::<Result<_, _>>()?;
let format = files.format.as_ref();
let compression = files.compression.as_ref();
let file_extension = files.file_extension.as_deref();
let columns = files.columns.as_deref();
let header = files.header;
let delimiter = files.delimiter.as_deref();
let is_lines = matches!(format, Some(FileFormat::Lines));
if is_lines && columns.is_some() {
return Err(ExecutorError::ConfigurationError(
anyhow::anyhow!(
"The 'lines' format has a fixed schema and does not support custom columns"
)
.into(),
));
}
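    // With no overrides at all, let DataFusion infer both the file format and
    // the schema directly from the paths.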
let all_defaults = format.is_none()
&& compression.is_none()
&& file_extension.is_none()
&& columns.is_none()
&& header.is_none()
&& delimiter.is_none();
let (listing_config, hamelin_columns) = if all_defaults {
let config = ListingTableConfig::new_with_multi_paths(table_paths)
.infer(&ctx.state())
.await
.map_err(|e| ExecutorError::ConfigurationError(e.into()))?;
let schema = config.file_schema.as_ref().ok_or_else(|| {
ExecutorError::ConfigurationError(
anyhow::anyhow!("Failed to infer schema for files dataset").into(),
)
})?;
let cols = arrow_schema_to_columns(schema);
(config, cols)
} else {
let ct = match compression {
None => FileCompressionType::UNCOMPRESSED,
Some(Compression::Gzip) => FileCompressionType::GZIP,
Some(Compression::Bzip2) => FileCompressionType::BZIP2,
Some(Compression::Xz) => FileCompressionType::XZ,
Some(Compression::Zstd) => FileCompressionType::ZSTD,
};
let file_format: Arc<dyn datafusion::datasource::file_format::FileFormat> = match format {
Some(FileFormat::Jsonl) => {
Arc::new(JsonFormat::default().with_file_compression_type(ct))
}
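            // Parquet handles compression internally, so `ct` is only applied
            // to the text-based formats.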
Some(FileFormat::Parquet) => Arc::new(ParquetFormat::default()),
Some(FileFormat::CSV) | None => {
let mut csv = CsvFormat::default()
.with_has_header(header.unwrap_or(true))
.with_file_compression_type(ct);
                if let Some(delim) = delimiter {
                    // Reject anything other than a single-byte delimiter rather
                    // than silently using only its first byte.
                    let bytes = delim.as_bytes();
                    if bytes.len() != 1 {
                        return Err(ExecutorError::ConfigurationError(
                            anyhow::anyhow!("Delimiter must be a single byte character").into(),
                        ));
                    }
                    csv = csv.with_delimiter(bytes[0]);
                }
Arc::new(csv)
}
            Some(FileFormat::Lines) => {
                // Treat each line as a single opaque string column: use control
                // bytes that will not appear in text as the delimiter and quote
                // so the CSV reader never splits or unquotes the line.
                let csv = CsvFormat::default()
                    .with_has_header(false)
                    .with_delimiter(b'\x01')
                    .with_quote(b'\x00')
                    .with_file_compression_type(ct);
                Arc::new(csv)
            }
};
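        // An empty extension matches every file (used for 'lines'); JSONL gets
        // `.jsonl` since DataFusion's JSON format defaults to `.json`.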
let ext = file_extension.or(match format {
Some(FileFormat::Lines) => Some(""),
Some(FileFormat::Jsonl) => Some(".jsonl"),
_ => None,
});
let listing_options = ListingOptions::new(file_format).with_file_extension_opt(ext);
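        // Schema resolution: explicit columns win, 'lines' has a fixed
        // single-column schema, and otherwise the schema is inferred from the
        // first listed path.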
let (schema, hamelin_cols) = if let Some(cols) = columns {
columns_to_schema(cols)?
} else if is_lines {
let schema = Arc::new(Schema::new(vec![Field::new("line", DataType::Utf8, false)]));
let cols = vec![Column {
name: "line".to_string(),
typ: Type::String.into(),
}];
(schema, cols)
} else {
let inferred = listing_options
.infer_schema(
&ctx.state(),
table_paths.first().ok_or_else(|| {
ExecutorError::ConfigurationError(
anyhow::anyhow!("No paths provided for files dataset").into(),
)
})?,
)
.await
.map_err(|e| ExecutorError::ConfigurationError(e.into()))?;
let cols = arrow_schema_to_columns(&inferred);
(inferred, cols)
};
let config = ListingTableConfig::new_with_multi_paths(table_paths)
.with_listing_options(listing_options)
.with_schema(schema);
(config, hamelin_cols)
};
let listing_table: Arc<dyn TableProvider> = Arc::new(
ListingTable::try_new(listing_config)
.map_err(|e| ExecutorError::ConfigurationError(e.into()))?,
);
Ok((hamelin_columns, listing_table))
}
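
/// Converts configured columns into an Arrow schema (all fields nullable)
/// plus the matching Hamelin catalog columns.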
fn columns_to_schema(
columns: &[ColumnConfig],
) -> Result<(Arc<Schema>, Vec<Column>), ExecutorError> {
let mut fields = Vec::with_capacity(columns.len());
let mut hamelin_columns = Vec::with_capacity(columns.len());
for c in columns {
let typ: Type = c
.typ
.clone()
.try_into()
.map_err(|e: anyhow::Error| ExecutorError::ConfigurationError(e.into()))?;
fields.push(Field::new(&c.name, hamelin_type_to_arrow(&typ), true));
hamelin_columns.push(Column {
name: c.name.clone(),
typ: c.typ.clone(),
});
}
Ok((Arc::new(Schema::new(fields)), hamelin_columns))
}
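
/// Builds an empty in-memory table with the configured schema, returning the
/// Hamelin columns and a `MemTable` provider seeded with a single empty batch.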
pub fn resolve_mem(
columns: &[ColumnConfig],
) -> Result<(Vec<Column>, Arc<dyn TableProvider>), ExecutorError> {
let (schema, hamelin_columns) = columns_to_schema(columns)?;
let batch = RecordBatch::new_empty(schema.clone());
let provider: Arc<dyn TableProvider> = Arc::new(
MemTable::try_new(schema, vec![vec![batch]])
.map_err(|e| ExecutorError::ConfigurationError(e.into()))?,
);
Ok((hamelin_columns, provider))
}