pub(crate) mod background_cache;

use std::pin::Pin;
use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::{DataType, Schema, SchemaRef};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::RecordBatchStream;
use datafusion_physical_plan::SendableRecordBatchStream;
use futures::{FutureExt, Stream};
use lance::arrow::json::JsonDataType;
use lance::dataset::{ReadParams, WriteParams};
use lance::index::vector::utils::infer_vector_dim;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lazy_static::lazy_static;

use crate::error::{Error, Result};

lazy_static! {
    static ref TABLE_NAME_REGEX: regex::Regex =
        regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
    static ref NAMESPACE_NAME_REGEX: regex::Regex =
        regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
}
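
/// Patches [`ObjectStoreParams`] with a wrapping object store.
///
/// Patching fails if a wrapper is already present, rather than silently
/// replacing it. A minimal sketch (`my_wrapper` is a placeholder for any
/// `Arc<dyn WrappingObjectStore>` of yours):
///
/// ```ignore
/// let params: Option<ObjectStoreParams> = None;
/// let patched = params.patch_with_store_wrapper(my_wrapper)?;
/// assert!(patched.unwrap().object_store_wrapper.is_some());
/// ```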
pub trait PatchStoreParam {
    fn patch_with_store_wrapper(
        self,
        wrapper: Arc<dyn WrappingObjectStore>,
    ) -> Result<Option<ObjectStoreParams>>;
}

impl PatchStoreParam for Option<ObjectStoreParams> {
    fn patch_with_store_wrapper(
        self,
        wrapper: Arc<dyn WrappingObjectStore>,
    ) -> Result<Option<ObjectStoreParams>> {
        let mut params = self.unwrap_or_default();
        if params.object_store_wrapper.is_some() {
            return Err(Error::Other {
                message: "cannot patch params because an object store wrapper is already set"
                    .into(),
                source: None,
            });
        }
        params.object_store_wrapper = Some(wrapper);
        Ok(Some(params))
    }
}

pub trait PatchWriteParam {
    fn patch_with_store_wrapper(
        self,
        wrapper: Arc<dyn WrappingObjectStore>,
    ) -> Result<WriteParams>;
}

impl PatchWriteParam for WriteParams {
    fn patch_with_store_wrapper(
        mut self,
        wrapper: Arc<dyn WrappingObjectStore>,
    ) -> Result<WriteParams> {
        self.store_params = self.store_params.patch_with_store_wrapper(wrapper)?;
        Ok(self)
    }
}

pub trait PatchReadParam {
    fn patch_with_store_wrapper(self, wrapper: Arc<dyn WrappingObjectStore>) -> Result<ReadParams>;
}

impl PatchReadParam for ReadParams {
    fn patch_with_store_wrapper(
        mut self,
        wrapper: Arc<dyn WrappingObjectStore>,
    ) -> Result<ReadParams> {
        self.store_options = self.store_options.patch_with_store_wrapper(wrapper)?;
        Ok(self)
    }
}
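
/// Validates a table name: it must be non-empty and contain only
/// alphanumeric characters, underscores, hyphens, and periods.
///
/// Expected behavior, mirroring the tests below:
///
/// ```ignore
/// assert!(validate_table_name("my_table").is_ok());
/// assert!(validate_table_name("my/table").is_err());
/// ```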
pub fn validate_table_name(name: &str) -> Result<()> {
    if name.is_empty() {
        return Err(Error::InvalidTableName {
            name: name.to_string(),
            reason: "Table names cannot be empty strings".to_string(),
        });
    }
    if !TABLE_NAME_REGEX.is_match(name) {
        return Err(Error::InvalidTableName {
            name: name.to_string(),
            reason:
                "Table names can only contain alphanumeric characters, underscores, hyphens, and periods"
                    .to_string(),
        });
    }
    Ok(())
}
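
/// Validates a single namespace component, using the same character set as
/// table names but reporting failures as [`Error::InvalidInput`].
///
/// ```ignore
/// assert!(validate_namespace_name("my-namespace").is_ok());
/// assert!(validate_namespace_name("namespace with spaces").is_err());
/// ```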
pub fn validate_namespace_name(name: &str) -> Result<()> {
    if name.is_empty() {
        return Err(Error::InvalidInput {
            message: "Namespace names cannot be empty strings".to_string(),
        });
    }
    if !NAMESPACE_NAME_REGEX.is_match(name) {
        return Err(Error::InvalidInput {
            message: format!(
                "Invalid namespace name '{}': Namespace names can only contain alphanumeric characters, underscores, hyphens, and periods",
                name
            ),
        });
    }
    Ok(())
}
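
/// Validates every component of a namespace path; an empty path is valid.
///
/// ```ignore
/// assert!(validate_namespace(&["ns1".to_string(), "ns2".to_string()]).is_ok());
/// assert!(validate_namespace(&[]).is_ok());
/// ```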
pub fn validate_namespace(namespace: &[String]) -> Result<()> {
    for component in namespace {
        validate_namespace_name(component)?;
    }
    Ok(())
}
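
/// Infers the default vector column of `schema`, optionally restricted to
/// columns whose inferred dimension matches `dim`. Fails if there is no
/// candidate, or if the choice is ambiguous because several columns qualify.
///
/// ```ignore
/// // With exactly one vector-shaped column, that column is chosen
/// // (`schema_with_vec_col` as in the tests below):
/// assert_eq!(default_vector_column(&schema_with_vec_col, None).unwrap(), "vec");
/// ```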
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
    // A column qualifies if Lance can infer a vector dimension for its type
    // and, when a query dimension is given, that dimension matches.
    let candidates = schema
        .fields()
        .iter()
        .filter_map(|field| match infer_vector_dim(field.data_type()) {
            Ok(d) if dim.is_none() || dim == Some(d as i32) => Some(field.name()),
            _ => None,
        })
        .collect::<Vec<_>>();

    if candidates.is_empty() {
        Err(Error::InvalidInput {
            message: format!(
                "No vector column found to match the query vector dimension: {}",
                dim.unwrap_or_default()
            ),
        })
    } else if candidates.len() != 1 {
        Err(Error::Schema {
            message: format!(
                "More than one vector column found; \
                 please specify which column to use when creating an index or querying: {:?}",
                candidates
            ),
        })
    } else {
        Ok(candidates[0].clone())
    }
}
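
/// Whether `dtype` can back a BTree scalar index: any integer or float, plus
/// booleans, UTF-8 strings, dates, times, timestamps, and fixed-size binary.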
pub fn supported_btree_data_type(dtype: &DataType) -> bool {
    dtype.is_integer()
        || dtype.is_floating()
        || matches!(
            dtype,
            DataType::Boolean
                | DataType::Utf8
                | DataType::Time32(_)
                | DataType::Time64(_)
                | DataType::Date32
                | DataType::Date64
                | DataType::Timestamp(_, _)
                | DataType::FixedSizeBinary(_)
        )
}
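
/// Whether `dtype` can back a bitmap index: any integer, plus strings,
/// binary, and booleans.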
pub fn supported_bitmap_data_type(dtype: &DataType) -> bool {
    dtype.is_integer()
        || matches!(
            dtype,
            DataType::Utf8
                | DataType::LargeUtf8
                | DataType::Binary
                | DataType::LargeBinary
                | DataType::Boolean
        )
}
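
/// Whether `dtype` can back a label-list index: a list (fixed-size or
/// variable) whose element type supports bitmap indexing.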
pub fn supported_label_list_data_type(dtype: &DataType) -> bool {
    match dtype {
        DataType::List(field) => supported_bitmap_data_type(field.data_type()),
        DataType::FixedSizeList(field, _) => supported_bitmap_data_type(field.data_type()),
        _ => false,
    }
}
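
/// Whether `dtype` can back a full-text-search index: a string column, or a
/// list of strings nested at most one level deep.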
pub fn supported_fts_data_type(dtype: &DataType) -> bool {
    supported_fts_data_type_impl(dtype, false)
}

fn supported_fts_data_type_impl(dtype: &DataType, in_list: bool) -> bool {
    match (dtype, in_list) {
        (DataType::Utf8 | DataType::LargeUtf8, _) => true,
        (DataType::List(field) | DataType::LargeList(field), false) => {
            supported_fts_data_type_impl(field.data_type(), true)
        }
        _ => false,
    }
}
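
/// Whether `dtype` is a supported vector column type: a fixed-size list of
/// floats or `UInt8`, possibly nested inside variable-size lists.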
pub fn supported_vector_data_type(dtype: &DataType) -> bool {
    match dtype {
        DataType::FixedSizeList(field, _) => {
            field.data_type().is_floating() || field.data_type() == &DataType::UInt8
        }
        DataType::List(field) => supported_vector_data_type(field.data_type()),
        _ => false,
    }
}
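
/// Parses a [`DataType`] from either a bare type name or a full JSON type
/// description understood by Lance's [`JsonDataType`].
///
/// ```ignore
/// assert_eq!(string_to_datatype("int32"), Some(DataType::Int32));
/// ```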
pub fn string_to_datatype(s: &str) -> Option<DataType> {
    // Accept either a full JSON type description or a bare type name.
    let data_type: serde_json::Value =
        serde_json::from_str(s).unwrap_or_else(|_| serde_json::json!({ "type": s }));
    let json_type: JsonDataType = serde_json::from_value(data_type).ok()?;
    (&json_type).try_into().ok()
}
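
/// State machine for [`TimeoutStream`]. The deadline timer is created lazily
/// on the first poll, so the timeout window starts when the stream is first
/// consumed, not when it is constructed.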
enum TimeoutState {
    NotStarted {
        timeout: std::time::Duration,
    },
    Started {
        deadline: Pin<Box<tokio::time::Sleep>>,
        timeout: std::time::Duration,
    },
    Completed,
}
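
/// Wraps a [`SendableRecordBatchStream`] so that it fails with a
/// [`DataFusionError::Execution`] once `timeout` has elapsed. A zero timeout
/// fails on the first poll, and after the timeout fires the stream is fused
/// and yields `None`.
///
/// A usage sketch (a Tokio runtime is required, since the deadline is a
/// `tokio::time::sleep`):
///
/// ```ignore
/// let stream = TimeoutStream::new_boxed(inner, std::time::Duration::from_secs(30));
/// ```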
pub struct TimeoutStream {
    inner: SendableRecordBatchStream,
    state: TimeoutState,
}

impl TimeoutStream {
    pub fn new(inner: SendableRecordBatchStream, timeout: std::time::Duration) -> Self {
        Self {
            inner,
            state: TimeoutState::NotStarted { timeout },
        }
    }

    pub fn new_boxed(
        inner: SendableRecordBatchStream,
        timeout: std::time::Duration,
    ) -> SendableRecordBatchStream {
        Box::pin(Self::new(inner, timeout))
    }

    fn timeout_error(timeout: &std::time::Duration) -> DataFusionError {
        DataFusionError::Execution(format!("Query timeout after {} ms", timeout.as_millis()))
    }
}

impl RecordBatchStream for TimeoutStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}

impl Stream for TimeoutStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        match &mut self.state {
            TimeoutState::NotStarted { timeout } => {
                if timeout.is_zero() {
                    // Fail immediately and fuse the stream, matching the
                    // post-timeout behavior below.
                    let err = Self::timeout_error(timeout);
                    self.state = TimeoutState::Completed;
                    return std::task::Poll::Ready(Some(Err(err)));
                }
                // Start the deadline on the first poll, then poll again in
                // the `Started` state.
                let deadline = Box::pin(tokio::time::sleep(*timeout));
                self.state = TimeoutState::Started {
                    deadline,
                    timeout: *timeout,
                };
                self.poll_next(cx)
            }
            TimeoutState::Started { deadline, timeout } => match deadline.poll_unpin(cx) {
                std::task::Poll::Ready(_) => {
                    // The deadline fired: emit the timeout error once, then
                    // yield `None` on every subsequent poll.
                    let err = Self::timeout_error(timeout);
                    self.state = TimeoutState::Completed;
                    std::task::Poll::Ready(Some(Err(err)))
                }
                std::task::Poll::Pending => {
                    let inner = Pin::new(&mut self.inner);
                    inner.poll_next(cx)
                }
            },
            TimeoutState::Completed => std::task::Poll::Ready(None),
        }
    }
}
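
/// Wraps a [`SendableRecordBatchStream`] and re-slices any batch longer than
/// `max_batch_length` into consecutive zero-copy slices of at most that many
/// rows; shorter batches pass through untouched. A `max_batch_length` of 0
/// disables slicing.
///
/// ```ignore
/// // Batches of 2 and 7 rows with max_batch_length = 3 come out as
/// // batches of 2, 3, 3, and 1 rows (see the tests below).
/// let stream = MaxBatchLengthStream::new_boxed(inner, 3);
/// ```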
pub struct MaxBatchLengthStream {
    inner: SendableRecordBatchStream,
    max_batch_length: Option<usize>,
    buffered_batch: Option<RecordBatch>,
    buffered_offset: usize,
}

impl MaxBatchLengthStream {
    pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
        Self {
            inner,
            max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
            buffered_batch: None,
            buffered_offset: 0,
        }
    }

    pub fn new_boxed(
        inner: SendableRecordBatchStream,
        max_batch_length: usize,
    ) -> SendableRecordBatchStream {
        if max_batch_length == 0 {
            inner
        } else {
            Box::pin(Self::new(inner, max_batch_length))
        }
    }
}

impl RecordBatchStream for MaxBatchLengthStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}

impl Stream for MaxBatchLengthStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        loop {
            let Some(max_batch_length) = self.max_batch_length else {
                return Pin::new(&mut self.inner).poll_next(cx);
            };
            // Drain any buffered oversized batch first. Cloning a
            // `RecordBatch` is cheap (its columns are Arc'd) and `slice` is
            // zero-copy.
            if let Some(batch) = self.buffered_batch.clone() {
                if self.buffered_offset < batch.num_rows() {
                    let remaining = batch.num_rows() - self.buffered_offset;
                    let length = remaining.min(max_batch_length);
                    let sliced = batch.slice(self.buffered_offset, length);
                    self.buffered_offset += length;
                    if self.buffered_offset >= batch.num_rows() {
                        self.buffered_batch = None;
                        self.buffered_offset = 0;
                    }
                    return std::task::Poll::Ready(Some(Ok(sliced)));
                }
                self.buffered_batch = None;
                self.buffered_offset = 0;
            }
            match Pin::new(&mut self.inner).poll_next(cx) {
                std::task::Poll::Ready(Some(Ok(batch))) => {
                    // Small batches pass through; oversized ones are buffered
                    // and sliced on the next loop iteration.
                    if batch.num_rows() <= max_batch_length {
                        return std::task::Poll::Ready(Some(Ok(batch)));
                    }
                    self.buffered_batch = Some(batch);
                    self.buffered_offset = 0;
                }
                other => return other,
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use arrow_array::Int32Array;
    use arrow_schema::Field;
    use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{StreamExt, stream};
    use tokio::time::sleep;

    use super::*;

    #[test]
    fn test_guess_default_column() {
        let schema_no_vector = Schema::new(vec![
            Field::new("id", DataType::Int16, true),
            Field::new("tag", DataType::Utf8, false),
        ]);
        assert!(
            default_vector_column(&schema_no_vector, None)
                .unwrap_err()
                .to_string()
                .contains("No vector column")
        );

        let schema_with_vec_col = Schema::new(vec![
            Field::new("id", DataType::Int16, true),
            Field::new(
                "vec",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 10),
                false,
            ),
        ]);
        assert_eq!(
            default_vector_column(&schema_with_vec_col, None).unwrap(),
            "vec"
        );

        let multi_vec_col = Schema::new(vec![
            Field::new("id", DataType::Int16, true),
            Field::new(
                "vec",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 10),
                false,
            ),
            Field::new(
                "vec2",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 50),
                false,
            ),
        ]);
        assert!(
            default_vector_column(&multi_vec_col, None)
                .unwrap_err()
                .to_string()
                .contains("More than one")
        );
    }

    #[test]
    fn test_validate_table_name() {
        assert!(validate_table_name("my_table").is_ok());
        assert!(validate_table_name("my_table_1").is_ok());
        assert!(validate_table_name("123mytable").is_ok());
        assert!(validate_table_name("_12345table").is_ok());
        assert!(validate_table_name("table.12345").is_ok());
        assert!(validate_table_name("table.._dot_..12345").is_ok());
        assert!(validate_table_name("").is_err());
        assert!(validate_table_name("my_table!").is_err());
        assert!(validate_table_name("my/table").is_err());
        assert!(validate_table_name("my@table").is_err());
        assert!(validate_table_name("name with space").is_err());
    }

    #[test]
    fn test_validate_namespace_name() {
        assert!(validate_namespace_name("ns1").is_ok());
        assert!(validate_namespace_name("namespace_123").is_ok());
        assert!(validate_namespace_name("my-namespace").is_ok());
        assert!(validate_namespace_name("my.namespace").is_ok());
        assert!(validate_namespace_name("NS_1.2.3").is_ok());
        assert!(validate_namespace_name("a").is_ok());
        assert!(validate_namespace_name("123").is_ok());
        assert!(validate_namespace_name("_underscore").is_ok());
        assert!(validate_namespace_name("-hyphen").is_ok());
        assert!(validate_namespace_name(".period").is_ok());
        assert!(validate_namespace_name("").is_err());
        assert!(validate_namespace_name("namespace with spaces").is_err());
        assert!(validate_namespace_name("namespace/with/slashes").is_err());
        assert!(validate_namespace_name("namespace\\with\\backslashes").is_err());
        assert!(validate_namespace_name("namespace$with$delimiter").is_err());
        assert!(validate_namespace_name("namespace@special").is_err());
        assert!(validate_namespace_name("namespace#hash").is_err());
    }

    #[test]
    fn test_validate_namespace() {
        assert!(validate_namespace(&["ns1".to_string()]).is_ok());
        assert!(
            validate_namespace(&["ns1".to_string(), "ns2".to_string(), "ns3".to_string()]).is_ok()
        );
        assert!(validate_namespace(&[]).is_ok());
        assert!(validate_namespace(&["ns1".to_string(), "".to_string()]).is_err());
        assert!(validate_namespace(&["ns1".to_string(), "ns 2".to_string()]).is_err());
        assert!(validate_namespace(&["ns1".to_string(), "ns@2".to_string()]).is_err());
        assert!(validate_namespace(&["ns1".to_string(), "ns/2".to_string()]).is_err());
        assert!(validate_namespace(&["ns1".to_string(), "ns$2".to_string()]).is_err());
        assert!(
            validate_namespace(&["ns_1".to_string(), "ns-2".to_string(), "ns.3".to_string()])
                .is_ok()
        );
    }

    #[test]
    fn test_string_to_datatype() {
        let string = "int32";
        let expected = DataType::Int32;
        assert_eq!(string_to_datatype(string), Some(expected));
    }

    fn sample_batch(num_rows: i32) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "col1",
            DataType::Int32,
            false,
        )]));
        RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_timeout_stream() {
        let batch = sample_batch(3);
        let schema = batch.schema();
        let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
        let sendable_stream: SendableRecordBatchStream =
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));

        let timeout_duration = std::time::Duration::from_millis(10);
        let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);

        let first_result = timeout_stream.next().await;
        assert!(first_result.is_some());
        assert!(first_result.unwrap().is_ok());

        sleep(timeout_duration).await;

        let second_result = timeout_stream.next().await.unwrap();
        assert!(second_result.is_err());
        assert!(
            second_result
                .unwrap_err()
                .to_string()
                .contains("Query timeout")
        );
    }

    #[tokio::test]
    async fn test_timeout_stream_zero_duration() {
        let batch = sample_batch(3);
        let schema = batch.schema();
        let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
        let sendable_stream: SendableRecordBatchStream =
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));

        let timeout_duration = std::time::Duration::from_secs(0);
        let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);

        let result = timeout_stream.next().await.unwrap();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Query timeout"));
    }

    #[tokio::test]
    async fn test_timeout_stream_completes_normally() {
        let batch = sample_batch(3);
        let schema = batch.schema();
        let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
        let sendable_stream: SendableRecordBatchStream =
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));

        let timeout_duration = std::time::Duration::from_secs(1);
        let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);

        assert!(timeout_stream.next().await.unwrap().is_ok());
        assert!(timeout_stream.next().await.unwrap().is_ok());
        assert!(timeout_stream.next().await.is_none());
    }

    async fn collect_batch_sizes(
        stream: SendableRecordBatchStream,
        max_batch_length: usize,
    ) -> Vec<usize> {
        let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
        sliced_stream
            .by_ref()
            .map(|batch| batch.unwrap().num_rows())
            .collect::<Vec<_>>()
            .await
    }

    #[tokio::test]
    async fn test_max_batch_length_stream_behaviors() {
        let schema = sample_batch(7).schema();
        let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
        let sendable_stream: SendableRecordBatchStream =
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
        // An oversized batch is split into max-length slices plus a remainder.
        assert_eq!(
            collect_batch_sizes(sendable_stream, 3).await,
            vec![2, 3, 3, 1]
        );

        // A max length of 0 disables slicing entirely.
        let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
            schema,
            stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
        ));
        assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
    }
}