use std::sync::Arc;
use arrow::{
array::{RecordBatch, StringArray},
datatypes::Schema,
error::ArrowError,
};
use bytes::Bytes;
use datafusion::{
datasource::{
file_format::file_compression_type::{self, FileCompressionType},
physical_plan::FileOpener,
},
error::DataFusionError,
};
use exon_fasta::FASTAConfig;
use futures::{Stream, StreamExt, TryStreamExt};
use object_store::{GetOptions, GetRange, GetResultPayload};
use crate::datasources::indexed_file::{fai::FAIFileRange, region::RegionObjectStoreExtension};
/// [`FileOpener`] that reads a single region out of an indexed FASTA file.
///
/// The opener carries everything `open` needs that is not attached to the
/// individual file: the shared FASTA configuration (which supplies the object
/// store and the output schema) and the file's compression type.
#[derive(Debug)]
pub struct IndexedFASTAOpener {
    /// Shared FASTA scan configuration; `open` reads its `object_store` and `file_schema`.
    config: Arc<FASTAConfig>,
    /// Compression of the underlying file; the byte-range path only accepts uncompressed input.
    file_compression_type: FileCompressionType,
}

impl IndexedFASTAOpener {
    /// Creates an opener from a shared FASTA configuration and the file's compression type.
    pub fn new(config: Arc<FASTAConfig>, file_compression_type: FileCompressionType) -> Self {
        Self {
            file_compression_type,
            config,
        }
    }
}
fn record_batch_stream(
sequence_name: &str,
sequence_literal: &[u8],
record_schema: Arc<Schema>,
) -> impl Stream<Item = arrow::error::Result<RecordBatch>> {
let record_batch = RecordBatch::try_new(
Arc::clone(&record_schema),
vec![
Arc::new(StringArray::from(vec![String::from(sequence_name)])),
Arc::new(StringArray::from(Vec::<Option<String>>::from([None]))),
Arc::new(StringArray::from(vec![String::from_utf8_lossy(
sequence_literal,
)
.to_string()])),
],
);
let s = futures::stream::iter(vec![record_batch]);
s.map_err(ArrowError::from)
}
impl FileOpener for IndexedFASTAOpener {
    /// Opens one region of an indexed FASTA file and yields it as a single record batch.
    ///
    /// The access path is chosen by the extension attached to `file_meta`:
    ///
    /// * [`FAIFileRange`]: the region's byte range is fetched with a ranged `get_opts`
    ///   request against the configured object store. Only uncompressed files are
    ///   supported on this path. The fetched bytes are raw file content, so line
    ///   terminators (`\n`, `\r`) are stripped to leave only sequence residues.
    /// * [`RegionObjectStoreExtension`]: the object must resolve to a local file
    ///   (`GetResultPayload::File`), which is opened with noodles' indexed FASTA reader
    ///   and queried for the region directly.
    ///
    /// # Errors
    ///
    /// The returned future resolves to [`DataFusionError::Execution`] when a compressed
    /// file is opened through the byte-range path, when neither extension is present, or
    /// when the region path is used on an object that is not a local file. Object-store
    /// and I/O errors are propagated.
    fn open(
        &self,
        file_meta: datafusion::datasource::physical_plan::FileMeta,
    ) -> datafusion::error::Result<datafusion::datasource::physical_plan::FileOpenFuture> {
        let config = Arc::clone(&self.config);
        let schema = Arc::clone(&self.config.file_schema);
        let file_compression_type = self.file_compression_type;

        Ok(Box::pin(async move {
            let fai_file_range = file_meta
                .extensions
                .as_ref()
                .and_then(|ext| ext.downcast_ref::<FAIFileRange>());

            match fai_file_range {
                Some(fai_file_range) => {
                    // Reject unsupported input before issuing the ranged request, so a
                    // compressed file does not cost a (possibly remote) round trip.
                    if file_compression_type
                        != file_compression_type::FileCompressionType::UNCOMPRESSED
                    {
                        return Err(DataFusionError::Execution(
                            "Indexed FASTA from remote storage only supports uncompressed files."
                                .to_string(),
                        ));
                    }

                    let get_options = GetOptions {
                        range: Some(GetRange::Bounded(std::ops::Range {
                            start: fai_file_range.start as usize,
                            end: fai_file_range.end as usize,
                        })),
                        ..Default::default()
                    };
                    let get_result = config
                        .object_store
                        .get_opts(file_meta.location(), get_options)
                        .await?;
                    let get_stream =
                        Box::pin(get_result.into_stream().map_err(DataFusionError::from));

                    let chunks = file_compression_type
                        .convert_stream(get_stream)?
                        .try_collect::<Vec<Bytes>>()
                        .await?;

                    // FASTA sequences are wrapped across lines, so a byte range spanning
                    // several lines contains line breaks; keep only the residues.
                    let sequence = chunks
                        .into_iter()
                        .flatten()
                        .filter(|&b| b != b'\n' && b != b'\r')
                        .collect::<Vec<u8>>();

                    let record_batch =
                        record_batch_stream(&fai_file_range.region_name, &sequence, schema);
                    Ok(record_batch.boxed())
                }
                None => {
                    let region_extension = file_meta
                        .extensions
                        .as_ref()
                        .and_then(|ext| ext.downcast_ref::<RegionObjectStoreExtension>())
                        .ok_or_else(|| {
                            DataFusionError::Execution("Region extension not found".to_string())
                        })?;

                    let get_result = config.object_store.get(file_meta.location()).await?;
                    match get_result.payload {
                        GetResultPayload::File(_, path) => {
                            // NOTE(review): the noodles reader below does synchronous file
                            // I/O inside this async block; consider a blocking task if this
                            // path becomes hot.
                            let mut reader = noodles::fasta::indexed_reader::Builder::default()
                                .build_from_path(path)?;
                            let record = reader.query(&region_extension.region)?;
                            let sequence = record.sequence();
                            let record_batch = record_batch_stream(
                                &region_extension.region_name(),
                                sequence.as_ref(),
                                schema,
                            );
                            Ok(record_batch.boxed())
                        }
                        _ => Err(DataFusionError::Execution(
                            "Direct region access only supported on local files.".to_string(),
                        )),
                    }
                }
            }
        }))
    }
}