use std::sync::Arc;
use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion::catalog::TableProvider;
use crate::error::HirnDbError;
use crate::reranker::Reranker;
pub use hirn_core::DistanceMetric;
/// Score-normalization strategy carried by `HybridSearchOptions::normalize`.
///
/// `Score` is the default. NOTE(review): presumably `Score` normalizes raw
/// relevance scores and `Rank` uses result positions — confirm in the backend.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum NormalizeMethod {
    /// Default strategy.
    #[default]
    Score,
    /// Alternative, rank-based strategy.
    Rank,
}
/// Kind of index requested through [`IndexConfig::index_type`].
///
/// NOTE(review): by name, the `Ivf*` variants are vector indexes, `Bm25` is a
/// full-text index and `BTree` / `Bitmap` / `LabelList` are scalar indexes;
/// the exact mapping lives in the backend — confirm there.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IndexType {
    IvfHnswSq,
    IvfHnswPq,
    IvfPq,
    IvfRq,
    Bm25,
    BTree,
    Bitmap,
    LabelList,
}

/// Optional tuning knobs for index construction; every knob is `Option<u32>`
/// and defaults to `None`.
///
/// NOTE(review): `None` presumably means "let the backend choose" — confirm.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct IndexParams {
    pub num_partitions: Option<u32>,
    pub num_sub_vectors: Option<u32>,
    pub num_edges: Option<u32>,
    pub ef_construction: Option<u32>,
    pub sample_rate: Option<u32>,
    pub num_bits: Option<u32>,
}

/// Full request passed to `PhysicalStore::create_index`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IndexConfig {
    /// Columns the index covers.
    pub columns: Vec<String>,
    /// Which index kind to build.
    pub index_type: IndexType,
    /// Tuning parameters for the chosen kind.
    pub params: IndexParams,
    /// NOTE(review): presumably "overwrite an existing index" — confirm.
    pub replace: bool,
}
/// One sort key for a scan: a column, a direction and null placement.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScanOrdering {
    pub column: String,
    pub ascending: bool,
    pub nulls_first: bool,
}

impl ScanOrdering {
    /// Ascending order on `column`, with nulls placed last.
    #[must_use]
    pub fn asc(column: impl Into<String>) -> Self {
        let column = column.into();
        Self {
            column,
            ascending: true,
            nulls_first: false,
        }
    }

    /// Descending order on `column`, with nulls placed last (same
    /// null placement as [`ScanOrdering::asc`]).
    #[must_use]
    pub fn desc(column: impl Into<String>) -> Self {
        Self {
            ascending: false,
            ..Self::asc(column)
        }
    }
}
/// An exact-match predicate on UTF-8 columns, rendered to a SQL filter string
/// by [`ExactMatchFilter::to_predicate_sql`].
///
/// Prefer the constructors (`utf8_value`, `utf8_values`,
/// `utf8_multi_column_or`). In debug builds they assert that every column name
/// is a plain `[a-z0-9_]` identifier. The variants are public, so a filter can
/// also be built directly. For that reason `to_predicate_sql` re-checks column
/// names in every build and never splices an unsafe one into SQL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExactMatchFilter {
    /// `column IN (values…)`: matches rows whose `column` equals any of `values`.
    Utf8In {
        column: String,
        values: Vec<String>,
    },
    /// `col_1 = value OR col_2 = value OR …`: matches rows where any of
    /// `columns` equals `value`.
    Utf8MultiColumnOr {
        columns: Vec<String>,
        value: String,
    },
}

impl ExactMatchFilter {
    /// Returns `true` when `col` is 1–64 bytes long and consists only of
    /// `[a-z0-9_]`, i.e. it can be emitted into SQL unquoted.
    fn is_safe_column(col: &str) -> bool {
        !col.is_empty()
            && col.len() <= 64
            && col
                .chars()
                .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
    }

    /// Debug-build guard used by the constructors. It is a no-op in release
    /// builds; release safety comes from the re-check in `to_predicate_sql`.
    fn assert_safe_column(col: &str) {
        debug_assert!(
            Self::is_safe_column(col),
            "column name '{col}' contains unsafe characters — only [a-z0-9_] are allowed"
        );
    }

    /// Filter matching rows whose `column` equals `value`.
    ///
    /// # Panics
    /// In debug builds, if `column` is not a safe `[a-z0-9_]` identifier.
    #[must_use]
    pub fn utf8_value(column: impl Into<String>, value: impl Into<String>) -> Self {
        let column = column.into();
        Self::assert_safe_column(&column);
        Self::Utf8In {
            column,
            values: vec![value.into()],
        }
    }

    /// Filter matching rows whose `column` equals any of `values`.
    ///
    /// Returns `None` when `values` is empty. That case is checked before the
    /// column name is validated.
    ///
    /// # Panics
    /// In debug builds, if `values` is non-empty and `column` is not a safe
    /// `[a-z0-9_]` identifier.
    #[must_use]
    pub fn utf8_values<I, S>(column: impl Into<String>, values: I) -> Option<Self>
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let values: Vec<String> = values.into_iter().map(Into::into).collect();
        if values.is_empty() {
            return None;
        }
        let column = column.into();
        Self::assert_safe_column(&column);
        Some(Self::Utf8In { column, values })
    }

    /// Filter matching rows where any of `columns` equals `value`.
    ///
    /// An empty `columns` list is accepted and renders as a match-nothing
    /// predicate.
    ///
    /// # Panics
    /// In debug builds, if any column is not a safe `[a-z0-9_]` identifier.
    #[must_use]
    pub fn utf8_multi_column_or(columns: Vec<String>, value: impl Into<String>) -> Self {
        for col in &columns {
            Self::assert_safe_column(col);
        }
        Self::Utf8MultiColumnOr {
            columns,
            value: value.into(),
        }
    }

    /// Renders the filter as a SQL boolean expression.
    ///
    /// - `Utf8In` renders as `column IN ('v1', 'v2', …)`.
    /// - `Utf8MultiColumnOr` renders as `col = 'v'` for a single column and
    ///   `(a = 'v' OR b = 'v')` for several. The parentheses keep the `OR`
    ///   grouped when the fragment is combined with other predicates via `AND`.
    ///
    /// Values are emitted as single-quoted literals with embedded `'` doubled.
    ///
    /// Returns the always-false `1 = 0` in these cases:
    /// - there is nothing to match (no values, or no columns);
    /// - any column name fails the safe-identifier check.
    ///
    /// The second case fails closed rather than splicing untrusted text into
    /// SQL. That matters for `PhysicalStore::delete_exact`, where an injected
    /// or over-broad predicate would delete rows.
    #[must_use]
    pub fn to_predicate_sql(&self) -> String {
        const MATCH_NOTHING: &str = "1 = 0";

        /// Single-quoted SQL string literal with `'` escaped as `''`.
        fn quote(value: &str) -> String {
            format!("'{}'", value.replace('\'', "''"))
        }

        match self {
            Self::Utf8In { column, values } => {
                if values.is_empty() || !Self::is_safe_column(column) {
                    return MATCH_NOTHING.to_string();
                }
                let in_list = values
                    .iter()
                    .map(String::as_str)
                    .map(quote)
                    .collect::<Vec<_>>()
                    .join(", ");
                format!("{column} IN ({in_list})")
            }
            Self::Utf8MultiColumnOr { columns, value } => {
                if columns.is_empty()
                    || !columns.iter().map(String::as_str).all(Self::is_safe_column)
                {
                    return MATCH_NOTHING.to_string();
                }
                let literal = quote(value);
                let joined = columns
                    .iter()
                    .map(|col| format!("{col} = {literal}"))
                    .collect::<Vec<_>>()
                    .join(" OR ");
                if columns.len() > 1 {
                    format!("({joined})")
                } else {
                    joined
                }
            }
        }
    }
}
/// Options for `PhysicalStore::scan` and `PhysicalStore::scan_stream`.
///
/// Every field is optional; the `Default` value has all fields set to `None`.
#[derive(Clone, Debug, Default)]
pub struct ScanOptions {
    /// Raw SQL predicate string.
    pub filter: Option<String>,
    /// Structured exact-match predicate.
    /// NOTE(review): how a backend combines this with `filter` (presumably
    /// `AND`) is not visible here — confirm.
    pub exact_filter: Option<ExactMatchFilter>,
    /// Column projection.
    pub columns: Option<Vec<String>>,
    /// Sort keys, applied in order.
    pub order_by: Option<Vec<ScanOrdering>>,
    pub limit: Option<usize>,
    pub offset: Option<usize>,
}
#[derive(Debug, Clone)]
pub struct VectorSearchOptions {
pub column: String,
pub query: Vec<f32>,
pub metric: DistanceMetric,
pub limit: usize,
pub filter: Option<String>,
pub nprobes: Option<usize>,
pub refine_factor: Option<u32>,
}
impl Default for VectorSearchOptions {
fn default() -> Self {
Self {
column: String::new(),
query: Vec::new(),
metric: DistanceMetric::default(),
limit: 10,
filter: None,
nprobes: None,
refine_factor: None,
}
}
}
/// Options for `PhysicalStore::fts_search`.
#[derive(Clone, Debug)]
pub struct FtsSearchOptions {
    /// Text columns to search.
    pub columns: Vec<String>,
    /// Query text.
    pub query: String,
    /// Maximum number of hits.
    pub limit: usize,
    /// Optional raw SQL predicate.
    pub filter: Option<String>,
}
/// Options for `PhysicalStore::hybrid_search`, which combines a vector query
/// with a full-text query.
#[derive(Clone)]
pub struct HybridSearchOptions {
    /// Vector column searched with `query_vector`.
    pub vector_column: String,
    /// Query embedding.
    pub query_vector: Vec<f32>,
    /// Text columns searched with `fts_query`.
    pub fts_columns: Vec<String>,
    /// Full-text query string.
    pub fts_query: String,
    /// Score-normalization strategy.
    pub normalize: NormalizeMethod,
    pub metric: DistanceMetric,
    /// Maximum number of hits.
    pub limit: usize,
    /// Optional raw SQL predicate.
    pub filter: Option<String>,
    /// Optional reranker trait object.
    pub reranker: Option<Arc<dyn Reranker>>,
}

impl std::fmt::Debug for HybridSearchOptions {
    /// Prints every field, with two fields abbreviated.
    ///
    /// - `query_vector` is shown only as its dimensionality (`[f32; N]`) rather
    ///   than its contents. Embeddings can be large, so this keeps logs compact.
    ///   Previously the field was silently omitted.
    /// - `reranker` is shown as `Some("..")` / `None`. NOTE(review): presumably
    ///   because `dyn Reranker` is not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HybridSearchOptions")
            .field("vector_column", &self.vector_column)
            .field(
                "query_vector",
                &format_args!("[f32; {}]", self.query_vector.len()),
            )
            .field("fts_columns", &self.fts_columns)
            .field("fts_query", &self.fts_query)
            .field("normalize", &self.normalize)
            .field("metric", &self.metric)
            .field("limit", &self.limit)
            .field("filter", &self.filter)
            .field("reranker", &self.reranker.as_ref().map(|_| ".."))
            .finish()
    }
}
/// Query for `PhysicalStore::multivector_search`: either one vector or
/// several vectors.
#[derive(Clone, Debug)]
pub enum MultivectorQuery {
    /// A single query vector.
    Single(Vec<f32>),
    /// Multiple query vectors.
    Multi(Vec<Vec<f32>>),
}
/// Options for `PhysicalStore::multivector_search`.
#[derive(Clone, Debug)]
pub struct MultivectorSearchOptions {
    /// Multivector column to search.
    pub column: String,
    pub query: MultivectorQuery,
    pub metric: DistanceMetric,
    /// Maximum number of hits.
    pub limit: usize,
    /// Optional raw SQL predicate.
    pub filter: Option<String>,
    /// NOTE(review): presumably a dense column and candidate count for a
    /// first-stage retrieval pass — confirm against the backend.
    pub dense_column: Option<String>,
    pub first_stage_limit: Option<usize>,
}
/// Options for `PhysicalStore::compact`; both knobs default to `None`.
#[derive(Clone, Debug, Default)]
pub struct CompactOptions {
    pub max_rows_per_group: Option<usize>,
    pub target_rows_per_fragment: Option<usize>,
}
/// Counters returned by `PhysicalStore::compact`; all zero by default.
#[derive(Clone, Debug, Default)]
pub struct CompactResult {
    pub fragments_removed: u64,
    pub fragments_added: u64,
    pub rows_removed: u64,
}
/// A named tag pointing at a dataset version, as listed by
/// `PhysicalStore::list_tags`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VersionTag {
    pub name: String,
    pub version: u64,
    /// NOTE(review): presumably a Unix timestamp; the unit is not visible
    /// here — confirm.
    pub created_at: i64,
}
/// Summary of one dataset, as returned by `PhysicalStore::list_datasets`.
#[derive(Clone, Debug)]
pub struct DatasetInfo {
    pub name: String,
    pub version: u64,
    pub row_count: u64,
    /// Arrow schema of the dataset.
    pub schema: SchemaRef,
}
/// Pinned, boxed, `Send` stream of Arrow record batches. Each item may fail
/// with [`HirnDbError`].
///
/// Consumed by `PhysicalStore::append_stream` and produced by
/// `PhysicalStore::scan_stream`.
pub type RecordBatchStream = core::pin::Pin<
    Box<dyn futures::Stream<Item = Result<RecordBatch, HirnDbError>> + Send>,
>;
/// A schema-evolution step passed to `PhysicalStore::add_columns`.
#[derive(Clone, Debug)]
pub enum ColumnTransform {
    /// Add a new column with the given Arrow type and nullability.
    AddColumn {
        name: String,
        data_type: arrow_schema::DataType,
        nullable: bool,
        /// NOTE(review): the textual format of the default (literal vs SQL
        /// expression) is backend-defined — confirm.
        default_value: Option<String>,
    },
    /// Rename `old_name` to `new_name`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
}
/// Async storage-backend abstraction. It covers these method families:
/// - appends and deletes;
/// - scans;
/// - vector, full-text, hybrid and multivector search;
/// - indexing and compaction;
/// - versioning and tags;
/// - namespaces and column changes.
///
/// Only [`append_stream`](Self::append_stream) and
/// [`delete_exact`](Self::delete_exact) have default bodies. Everything else
/// is backend-specific. Implementors must be `Send + Sync`.
#[async_trait]
pub trait PhysicalStore: Send + Sync {
    /// Write a single batch to `dataset`.
    async fn append(&self, dataset: &str, batch: RecordBatch) -> Result<(), HirnDbError>;

    /// Write several batches to `dataset` in one call.
    async fn append_batches(
        &self,
        dataset: &str,
        batches: Vec<RecordBatch>,
    ) -> Result<(), HirnDbError>;

    /// Drains `stream` into [`append_batches`](Self::append_batches) in chunks.
    ///
    /// Zero-row batches are skipped. Non-empty batches are buffered until at
    /// least 50 000 rows have accumulated; the buffer is then written in one
    /// `append_batches` call. Whatever remains after the stream ends is
    /// written in a final call. If the stream yields no non-empty batch,
    /// `append_batches` is never called.
    ///
    /// # Errors
    /// Returns the first stream error or `append_batches` error immediately.
    /// Chunks already written before that point are not undone by this method.
    async fn append_stream(
        &self,
        dataset: &str,
        mut stream: RecordBatchStream,
    ) -> Result<(), HirnDbError> {
        use futures::StreamExt as _;

        // Row count at which the pending buffer is flushed.
        const MAX_STREAM_BATCH_ROWS: usize = 50_000;

        let mut pending: Vec<RecordBatch> = Vec::new();
        let mut pending_rows = 0_usize;

        while let Some(item) = stream.next().await {
            let batch = item?;
            let rows = batch.num_rows();
            if rows > 0 {
                pending.push(batch);
                pending_rows += rows;
                if pending_rows >= MAX_STREAM_BATCH_ROWS {
                    let chunk = std::mem::take(&mut pending);
                    pending_rows = 0;
                    self.append_batches(dataset, chunk).await?;
                }
            }
        }

        if pending.is_empty() {
            return Ok(());
        }
        self.append_batches(dataset, pending).await?;
        Ok(())
    }

    /// Read batches from `dataset` according to `opts`, collected eagerly.
    async fn scan(&self, dataset: &str, opts: ScanOptions)
        -> Result<Vec<RecordBatch>, HirnDbError>;

    /// Streaming counterpart of [`scan`](Self::scan).
    async fn scan_stream(
        &self,
        dataset: &str,
        opts: ScanOptions,
    ) -> Result<RecordBatchStream, HirnDbError>;

    /// Delete using a raw SQL `predicate` string. Hidden from docs; callers
    /// are steered toward [`delete_exact`](Self::delete_exact).
    #[doc(hidden)]
    async fn delete(&self, dataset: &str, predicate: &str) -> Result<u64, HirnDbError>;

    /// Renders `filter` with [`ExactMatchFilter::to_predicate_sql`] and
    /// forwards it to [`delete`](Self::delete), returning its result unchanged.
    /// NOTE(review): the `u64` is presumably the number of deleted rows.
    async fn delete_exact(
        &self,
        dataset: &str,
        filter: &ExactMatchFilter,
    ) -> Result<u64, HirnDbError> {
        self.delete(dataset, &filter.to_predicate_sql()).await
    }

    /// Upsert `batch` into `dataset`, keyed on the `on` columns.
    async fn merge_insert(
        &self,
        dataset: &str,
        on: &[&str],
        batch: RecordBatch,
    ) -> Result<(), HirnDbError>;

    /// Apply `(column, value)` updates to rows matching the raw SQL `filter`.
    /// NOTE(review): the value strings are presumably SQL expressions — confirm.
    async fn update_where(
        &self,
        dataset: &str,
        filter: &str,
        updates: &[(&str, &str)],
    ) -> Result<u64, HirnDbError>;

    /// Row count of `dataset`, optionally restricted by a raw SQL `filter`.
    async fn count(&self, dataset: &str, filter: Option<&str>) -> Result<u64, HirnDbError>;

    async fn vector_search(
        &self,
        dataset: &str,
        opts: VectorSearchOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError>;

    /// Batched form of [`vector_search`](Self::vector_search).
    /// NOTE(review): presumably one result list per query, in order — confirm.
    async fn vector_search_many(
        &self,
        dataset: &str,
        queries: Vec<VectorSearchOptions>,
    ) -> Result<Vec<Vec<RecordBatch>>, HirnDbError>;

    async fn fts_search(
        &self,
        dataset: &str,
        opts: FtsSearchOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError>;

    async fn hybrid_search(
        &self,
        dataset: &str,
        opts: HybridSearchOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError>;

    async fn multivector_search(
        &self,
        dataset: &str,
        opts: MultivectorSearchOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError>;

    async fn create_index(&self, dataset: &str, config: IndexConfig) -> Result<(), HirnDbError>;

    async fn optimize_indices(&self, dataset: &str) -> Result<(), HirnDbError>;

    async fn compact(
        &self,
        dataset: &str,
        opts: CompactOptions,
    ) -> Result<CompactResult, HirnDbError>;

    async fn version(&self, dataset: &str) -> Result<u64, HirnDbError>;

    async fn tag(&self, dataset: &str, tag: &str) -> Result<(), HirnDbError>;

    async fn checkout(&self, dataset: &str, version: u64) -> Result<(), HirnDbError>;

    async fn list_tags(&self, dataset: &str) -> Result<Vec<VersionTag>, HirnDbError>;

    async fn list_datasets(&self) -> Result<Vec<DatasetInfo>, HirnDbError>;

    async fn exists(&self, dataset: &str) -> Result<bool, HirnDbError>;

    async fn list_namespaces(&self) -> Result<Vec<String>, HirnDbError>;

    async fn create_namespace(&self, name: &str) -> Result<(), HirnDbError>;

    async fn drop_namespace(&self, name: &str) -> Result<(), HirnDbError>;

    async fn add_columns(
        &self,
        dataset: &str,
        transforms: Vec<ColumnTransform>,
    ) -> Result<(), HirnDbError>;

    async fn drop_columns(&self, dataset: &str, columns: &[&str]) -> Result<(), HirnDbError>;

    /// DataFusion table provider for `dataset`, or `None` if the backend does
    /// not offer one (this method has no error channel).
    async fn table_provider(&self, dataset: &str) -> Option<Arc<dyn TableProvider>>;
}