use std::sync::Arc;
use crate::{Dataset, Error, Result};
use arrow_array::ArrayRef;
use arrow_array::{Array, RecordBatch, RecordBatchReader, new_null_array};
use arrow_row::{OwnedRow, RowConverter, Rows, SortField};
use arrow_schema::{DataType as ArrowDataType, SchemaRef};
use arrow_select::interleave::interleave;
use dashmap::{DashMap, ReadOnlyView};
use futures::{StreamExt, TryStreamExt};
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use tokio::task;
/// `HashJoiner` hashes the join-key column of a right-hand-side table so that
/// left joins can be served against it (see `collect` / `collect_with_fallback`).
pub struct HashJoiner {
    // Maps a row-encoded join key to the location of its source row as
    // `(batch index, row index within that batch)`. Built once in `try_new`,
    // then frozen read-only.
    index_map: ReadOnlyView<OwnedRow, (usize, usize)>,
    // Arrow type of the join column; probe columns must match it exactly.
    index_type: ArrowDataType,
    // Right-side batches with the join column projected away
    // (their schema is `out_schema`).
    batches: Vec<RecordBatch>,
    // Input schema minus the join column.
    out_schema: SchemaRef,
}
/// Encode `column` into arrow-row format so its values can be used as
/// comparable/hashable join keys.
///
/// Fails if the column's data type is not supported by [`RowConverter`].
fn column_to_rows(column: ArrayRef) -> Result<Rows> {
    let field = SortField::new(column.data_type().clone());
    let converter = RowConverter::new(vec![field])?;
    Ok(converter.convert_columns(&[column])?)
}
impl HashJoiner {
pub async fn try_new(reader: Box<dyn RecordBatchReader + Send>, on: &str) -> Result<Self> {
let schema = reader.schema();
schema.field_with_name(on)?;
let batches = tokio::task::spawn_blocking(|| {
reader.collect::<std::result::Result<Vec<RecordBatch>, _>>()
})
.await
.unwrap()?;
if batches.is_empty() {
return Err(Error::invalid_input("HashJoiner: No data".to_string()));
};
let map = DashMap::new();
let keep_indices: Vec<usize> = schema
.fields()
.iter()
.enumerate()
.filter_map(|(i, field)| if field.name() == on { None } else { Some(i) })
.collect();
let out_schema: Arc<arrow_schema::Schema> = Arc::new(schema.project(&keep_indices)?);
let right_batches = batches
.iter()
.map(|batch| {
let mut columns = Vec::with_capacity(keep_indices.len());
for i in &keep_indices {
columns.push(batch.column(*i).clone());
}
RecordBatch::try_new(out_schema.clone(), columns).unwrap()
})
.collect::<Vec<_>>();
let map = Arc::new(map);
futures::stream::iter(batches.iter().enumerate().map(Ok::<_, Error>))
.try_for_each_concurrent(get_num_compute_intensive_cpus(), |(batch_i, batch)| {
let map = map.clone();
async move {
let column = batch[on].clone();
let task_result = task::spawn_blocking(move || {
let rows = column_to_rows(column)?;
for (row_i, row) in rows.iter().enumerate() {
map.insert(row.owned(), (batch_i, row_i));
}
Ok(())
})
.await;
match task_result {
Ok(Ok(_)) => Ok(()),
Ok(Err(err)) => Err(err),
Err(err) => Err(Error::invalid_input(format!("HashJoiner: {}", err))),
}
}
})
.await?;
let map = Arc::try_unwrap(map)
.expect("HashJoiner: No remaining tasks should still be referencing map.");
let index_type = batches[0]
.schema()
.field_with_name(on)
.unwrap()
.data_type()
.clone();
Ok(Self {
index_map: map.into_read_only(),
index_type,
batches: right_batches,
out_schema,
})
}
pub fn out_schema(&self) -> &SchemaRef {
&self.out_schema
}
pub(super) async fn collect(
&self,
dataset: &Dataset,
index_column: ArrayRef,
) -> Result<RecordBatch> {
if index_column.data_type() != &self.index_type {
return Err(Error::invalid_input(format!(
"Index column type mismatch: expected {}, got {}",
self.index_type,
index_column.data_type()
)));
}
let null_index = self.batches.len();
let indices = column_to_rows(index_column)?
.into_iter()
.map(|row| {
self.index_map
.get(&row.owned())
.map(|(batch_i, row_i)| (*batch_i, *row_i))
.unwrap_or((null_index, 0))
})
.collect::<Vec<_>>();
let indices = Arc::new(indices);
let columns = futures::stream::iter(0..self.batches[0].num_columns())
.map(|column_i| {
let mut arrays = Vec::with_capacity(self.batches.len() + 1);
for batch in &self.batches {
arrays.push(batch.column(column_i).clone());
}
arrays.push(Arc::new(new_null_array(arrays[0].data_type(), 1)));
let indices = indices.clone();
async move {
let task_result = task::spawn_blocking(move || {
let array_refs = arrays.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
interleave(array_refs.as_ref(), indices.as_ref())
.map_err(|err| Error::invalid_input(format!("HashJoiner: {}", err)))
})
.await;
match task_result {
Ok(Ok(array)) => {
Self::check_lance_support_null(&array, dataset)?;
Ok(array)
}
Ok(Err(err)) => Err(err),
Err(err) => Err(Error::io(format!("HashJoiner: {}", err))),
}
}
})
.buffered(get_num_compute_intensive_cpus())
.try_collect::<Vec<_>>()
.await?;
Ok(RecordBatch::try_new(self.batches[0].schema(), columns)?)
}
pub fn check_lance_support_null(array: &ArrayRef, dataset: &Dataset) -> Result<()> {
if array.null_count() > 0 && !dataset.lance_supports_nulls(array.data_type()) {
return Err(Error::invalid_input(format!(
"Join produced null values for type: {:?}, but storing \
nulls for this data type is not supported by the \
dataset's current Lance file format version: {:?}. This \
can be caused by an explicit null in the new data.",
array.data_type(),
dataset
.manifest()
.data_storage_format
.lance_file_version()
.unwrap()
)));
}
Ok(())
}
pub(super) async fn collect_with_fallback(
&self,
left_batch: &RecordBatch,
index_column: ArrayRef,
dataset: &Dataset,
) -> Result<RecordBatch> {
if index_column.data_type() != &self.index_type {
return Err(Error::invalid_input(format!(
"Index column type mismatch: expected {}, got {}",
self.index_type,
index_column.data_type()
)));
}
let left_batch_index = self.batches.len();
let indices = column_to_rows(index_column)?
.into_iter()
.enumerate()
.map(|(left_rowi, row)| {
self.index_map
.get(&row.owned())
.map(|(batch_i, row_i)| (*batch_i, *row_i))
.unwrap_or((left_batch_index, left_rowi))
})
.collect::<Vec<_>>();
let indices = Arc::new(indices);
let columns = futures::stream::iter(0..self.batches[0].num_columns())
.map(|column_i| {
let mut arrays = Vec::with_capacity(self.batches.len() + 1);
for batch in &self.batches {
arrays.push(batch.column(column_i).clone());
}
arrays.push(left_batch.column(column_i).clone());
let indices = indices.clone();
async move {
let task_result = task::spawn_blocking(move || {
let array_refs = arrays.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
interleave(array_refs.as_ref(), indices.as_ref())
.map_err(|err| Error::invalid_input(format!("HashJoiner: {}", err)))
})
.await;
match task_result {
Ok(Ok(array)) => {
Self::check_lance_support_null(&array, dataset)?;
Ok(array)
}
Ok(Err(err)) => Err(err),
Err(err) => Err(Error::invalid_input(format!("HashJoiner: {}", err))),
}
}
})
.buffered(get_num_compute_intensive_cpus())
.try_collect::<Vec<_>>()
.await?;
Ok(RecordBatch::try_new(self.batches[0].schema(), columns)?)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int32Array, RecordBatchIterator, StringArray, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use lance_core::utils::tempfile::TempDir;

    /// Write an empty single-column dataset into a temp dir and reopen it.
    async fn create_dataset() -> Dataset {
        let uri = TempDir::default().path_str();
        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
        let batches = RecordBatchIterator::new(std::iter::empty().map(Ok), schema.clone());
        Dataset::write(batches, &uri, None).await.unwrap();
        Dataset::open(&uri).await.unwrap()
    }

    #[tokio::test]
    async fn test_joiner_collect() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, true),
            Field::new("s", DataType::Utf8, true),
        ]));

        // Five batches of ten consecutive keys each: 0..10, 10..20, ..., 40..50.
        let mut source = Vec::with_capacity(5);
        for v in 0..5 {
            let keys: Vec<i32> = (v * 10..v * 10 + 10).collect();
            let labels = StringArray::from_iter_values(keys.iter().map(|v| format!("str_{}", v)));
            let batch = RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(keys)), Arc::new(labels)],
            )
            .unwrap();
            source.push(batch);
        }
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(RecordBatchIterator::new(
            source.into_iter().map(Ok),
            schema.clone(),
        ));
        let joiner = HashJoiner::try_new(reader, "i").await.unwrap();

        let dataset = create_dataset().await;
        // Probe with hits, explicit nulls, and one key (11111) with no match.
        let probe = Arc::new(Int32Array::from(vec![
            Some(15),
            None,
            Some(10),
            Some(0),
            None,
            None,
            Some(22),
            Some(11111),
        ]));
        let results = joiner.collect(&dataset, probe).await.unwrap();

        let expected = StringArray::from(vec![
            Some("str_15"),
            None,
            Some("str_10"),
            Some("str_0"),
            None,
            None,
            Some("str_22"),
            None,
        ]);
        assert_eq!(results.column_by_name("s").unwrap().as_ref(), &expected);
        // The join column itself is projected out of the result.
        assert_eq!(results.num_columns(), 1);
    }

    #[tokio::test]
    async fn test_reject_invalid() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, true),
            Field::new("s", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["a", "b"])),
            ],
        )
        .unwrap();
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(RecordBatchIterator::new(
            std::iter::once(Ok(batch)),
            schema.clone(),
        ));
        let joiner = HashJoiner::try_new(reader, "i").await.unwrap();

        let dataset = create_dataset().await;
        // Probing with UInt32 while the join key is Int32 must be rejected.
        let probe = Arc::new(UInt32Array::from_iter(&[Some(15)]));
        let result = joiner.collect(&dataset, probe).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Index column type mismatch: expected Int32, got UInt32"));
    }
}