1use std::sync::Arc;
2
3use arrow_array::RecordBatch;
4use arrow_schema::SchemaRef;
5use async_trait::async_trait;
6use datafusion::catalog::TableProvider;
7
8use crate::error::HirnDbError;
9use crate::reranker::Reranker;
10
11pub use hirn_core::DistanceMetric;
15
/// How hybrid-search scores are normalized before the vector and full-text result
/// lists are combined. Defaults to [`NormalizeMethod::Score`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum NormalizeMethod {
    /// Normalize using the raw scores (the default).
    #[default]
    Score,
    /// Normalize using result positions rather than raw scores.
    Rank,
}
24
/// Kind of index requested through an [`IndexConfig`].
///
/// Variant names mirror the backend's index families; the exact build behavior is
/// decided by the `PhysicalStore` implementation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IndexType {
    /// Vector index (IVF + HNSW + scalar quantization, per the name).
    IvfHnswSq,
    /// Vector index (IVF + HNSW + product quantization, per the name).
    IvfHnswPq,
    /// Vector index (IVF + product quantization, per the name).
    IvfPq,
    /// Vector index (IVF + `Rq` quantization, per the name).
    IvfRq,
    /// Full-text index — presumably the one used by `fts_search`.
    Bm25,
    /// Scalar B-tree index.
    BTree,
    /// Scalar bitmap index.
    Bitmap,
    /// Index over list-valued (label) columns.
    LabelList,
}
38
/// Optional tuning knobs for index construction.
///
/// Every field is optional; `None` means the value is not specified by the caller.
/// Which fields are honored depends on the [`IndexType`] and the backend.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct IndexParams {
    /// Number of partitions (IVF-style indices).
    pub num_partitions: Option<u32>,
    /// Number of sub-vectors (product-quantized indices).
    pub num_sub_vectors: Option<u32>,
    /// Graph edge count (HNSW-style indices).
    pub num_edges: Option<u32>,
    /// Construction-time search breadth (HNSW-style indices).
    pub ef_construction: Option<u32>,
    /// Sampling rate used while training the index.
    pub sample_rate: Option<u32>,
    /// Quantization bit width.
    pub num_bits: Option<u32>,
}
50
51#[derive(Debug, Clone, PartialEq, Eq)]
54pub struct IndexConfig {
55 pub columns: Vec<String>,
56 pub index_type: IndexType,
57 pub params: IndexParams,
58 pub replace: bool,
59}
60
/// One sort key for a scan: the column, its direction and where nulls go.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScanOrdering {
    /// Column to order by.
    pub column: String,
    /// `true` sorts ascending, `false` descending.
    pub ascending: bool,
    /// Whether null values sort before non-null values.
    pub nulls_first: bool,
}

impl ScanOrdering {
    /// Common constructor behind [`ScanOrdering::asc`] and [`ScanOrdering::desc`];
    /// both place nulls last.
    fn directed(column: String, ascending: bool) -> Self {
        Self {
            column,
            ascending,
            nulls_first: false,
        }
    }

    /// Ascending order on `column`, nulls last.
    #[must_use]
    pub fn asc(column: impl Into<String>) -> Self {
        Self::directed(column.into(), true)
    }

    /// Descending order on `column`, nulls last.
    #[must_use]
    pub fn desc(column: impl Into<String>) -> Self {
        Self::directed(column.into(), false)
    }
}
89
/// Equality filter over UTF-8 columns, rendered to SQL by
/// [`ExactMatchFilter::to_predicate_sql`].
///
/// Values are emitted as quoted string literals, but column names are interpolated into
/// the SQL verbatim, so they must be plain `[a-z0-9_]` identifiers of 1–64 bytes. The
/// constructors and `to_predicate_sql` check this with `debug_assert!` only.
///
/// NOTE(review): release builds perform no column-name check — never build a filter
/// from an untrusted column name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExactMatchFilter {
    /// Renders as `column IN ('v1', 'v2', ...)`; an empty `values` list renders as `1 = 0`.
    Utf8In {
        column: String,
        values: Vec<String>,
    },
    /// Renders as `(c1 = 'value' OR c2 = 'value' ...)` (no parentheses for a single
    /// column); an empty `columns` list renders as `1 = 0`.
    Utf8MultiColumnOr {
        columns: Vec<String>,
        value: String,
    },
}

impl ExactMatchFilter {
    /// Longest column name (in bytes) accepted as safe.
    const MAX_COLUMN_LEN: usize = 64;

    /// Returns `true` when `col` is non-empty, at most `MAX_COLUMN_LEN` bytes long and
    /// consists only of `[a-z0-9_]`.
    fn is_safe_column(col: &str) -> bool {
        !col.is_empty()
            && col.len() <= Self::MAX_COLUMN_LEN
            && col
                .bytes()
                .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'_')
    }

    /// Debug-build guard: panics if `col` fails [`Self::is_safe_column`].
    /// Compiled to a no-op when `debug_assertions` is disabled.
    fn assert_safe_column(col: &str) {
        debug_assert!(
            Self::is_safe_column(col),
            "column name '{col}' contains unsafe characters — only [a-z0-9_] are allowed"
        );
    }

    /// Renders `value` as a single-quoted SQL string literal, doubling embedded `'`.
    fn quote_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }

    /// Builds the single-value filter `column IN ('value')`.
    ///
    /// # Panics
    /// In debug builds, if `column` is not a safe identifier.
    #[must_use]
    pub fn utf8_value(column: impl Into<String>, value: impl Into<String>) -> Self {
        let column = column.into();
        Self::assert_safe_column(&column);
        Self::Utf8In {
            column,
            values: vec![value.into()],
        }
    }

    /// Builds `column IN (values...)`, or returns `None` when `values` is empty.
    ///
    /// # Panics
    /// In debug builds, if `values` is non-empty and `column` is not a safe identifier.
    #[must_use]
    pub fn utf8_values<I, S>(column: impl Into<String>, values: I) -> Option<Self>
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let values: Vec<String> = values.into_iter().map(Into::into).collect();
        if values.is_empty() {
            return None;
        }

        let column = column.into();
        Self::assert_safe_column(&column);
        Some(Self::Utf8In { column, values })
    }

    /// Builds a filter matching rows where ANY of `columns` equals `value`.
    ///
    /// # Panics
    /// In debug builds, if any column is not a safe identifier.
    #[must_use]
    pub fn utf8_multi_column_or(columns: Vec<String>, value: impl Into<String>) -> Self {
        for col in &columns {
            Self::assert_safe_column(col);
        }
        Self::Utf8MultiColumnOr {
            columns,
            value: value.into(),
        }
    }

    /// Renders the filter as a SQL boolean expression.
    ///
    /// Values are single-quoted with embedded quotes doubled; column names are emitted
    /// as-is. Empty lists render as the always-false `1 = 0`. A multi-column OR chain is
    /// wrapped in parentheses so it remains a single operand when a caller joins it with
    /// another predicate via `AND`.
    ///
    /// # Panics
    /// In debug builds, if a rendered column name is not a safe identifier. This also
    /// covers filters built directly from the enum variants, bypassing the constructors.
    #[must_use]
    pub fn to_predicate_sql(&self) -> String {
        match self {
            Self::Utf8In { column, values } => {
                if values.is_empty() {
                    return "1 = 0".to_string();
                }
                Self::assert_safe_column(column);

                let in_list = values
                    .iter()
                    .map(|value| Self::quote_literal(value))
                    .collect::<Vec<_>>()
                    .join(", ");
                format!("{column} IN ({in_list})")
            }
            Self::Utf8MultiColumnOr { columns, value } => {
                if columns.is_empty() {
                    return "1 = 0".to_string();
                }

                let literal = Self::quote_literal(value);
                let disjunction = columns
                    .iter()
                    .map(|col| {
                        Self::assert_safe_column(col);
                        format!("{col} = {literal}")
                    })
                    .collect::<Vec<_>>()
                    .join(" OR ");

                // Without parentheses, `a = 'x' OR b = 'x' AND <other>` would bind as
                // `a = 'x' OR (b = 'x' AND <other>)` once combined by a caller.
                if columns.len() == 1 {
                    disjunction
                } else {
                    format!("({disjunction})")
                }
            }
        }
    }
}
192
193#[derive(Debug, Clone, Default)]
194pub struct ScanOptions {
195 pub filter: Option<String>,
196 pub exact_filter: Option<ExactMatchFilter>,
197 pub columns: Option<Vec<String>>,
198 pub order_by: Option<Vec<ScanOrdering>>,
199 pub limit: Option<usize>,
200 pub offset: Option<usize>,
201}
202
203#[derive(Debug, Clone)]
206pub struct VectorSearchOptions {
207 pub column: String,
208 pub query: Vec<f32>,
209 pub metric: DistanceMetric,
210 pub limit: usize,
211 pub filter: Option<String>,
212 pub nprobes: Option<usize>,
213 pub refine_factor: Option<u32>,
214}
215
216impl Default for VectorSearchOptions {
217 fn default() -> Self {
218 Self {
219 column: String::new(),
220 query: Vec::new(),
221 metric: DistanceMetric::default(),
222 limit: 10,
223 filter: None,
224 nprobes: None,
225 refine_factor: None,
226 }
227 }
228}
229
/// Options for a full-text search.
#[derive(Clone, Debug)]
pub struct FtsSearchOptions {
    /// Text columns to search.
    pub columns: Vec<String>,
    /// Full-text query string.
    pub query: String,
    /// Maximum number of results.
    pub limit: usize,
    /// Optional SQL filter expression.
    pub filter: Option<String>,
}
239
240#[derive(Clone)]
243pub struct HybridSearchOptions {
244 pub vector_column: String,
245 pub query_vector: Vec<f32>,
246 pub fts_columns: Vec<String>,
247 pub fts_query: String,
248 pub normalize: NormalizeMethod,
249 pub metric: DistanceMetric,
250 pub limit: usize,
251 pub filter: Option<String>,
252 pub reranker: Option<Arc<dyn Reranker>>,
254}
255
256impl std::fmt::Debug for HybridSearchOptions {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 f.debug_struct("HybridSearchOptions")
259 .field("vector_column", &self.vector_column)
260 .field("fts_columns", &self.fts_columns)
261 .field("fts_query", &self.fts_query)
262 .field("normalize", &self.normalize)
263 .field("metric", &self.metric)
264 .field("limit", &self.limit)
265 .field("filter", &self.filter)
266 .field("reranker", &self.reranker.as_ref().map(|_| ".."))
267 .finish()
268 }
269}
270
/// Query input for a multivector search: one vector or a set of vectors.
#[derive(Clone, Debug)]
pub enum MultivectorQuery {
    /// A single query vector.
    Single(Vec<f32>),
    /// Several query vectors.
    Multi(Vec<Vec<f32>>),
}
278
279#[derive(Debug, Clone)]
280pub struct MultivectorSearchOptions {
281 pub column: String,
283 pub query: MultivectorQuery,
284 pub metric: DistanceMetric,
285 pub limit: usize,
286 pub filter: Option<String>,
287 pub dense_column: Option<String>,
291 pub first_stage_limit: Option<usize>,
293}
294
/// Tuning for [`PhysicalStore::compact`]; `Default` leaves both limits unset.
#[derive(Clone, Debug, Default)]
pub struct CompactOptions {
    /// Optional cap on rows per group.
    pub max_rows_per_group: Option<usize>,
    /// Optional target row count per fragment.
    pub target_rows_per_fragment: Option<usize>,
}
302
/// Counters reported by [`PhysicalStore::compact`].
#[derive(Clone, Debug, Default)]
pub struct CompactResult {
    /// Fragments removed by the compaction.
    pub fragments_removed: u64,
    /// Fragments written by the compaction.
    pub fragments_added: u64,
    /// Rows removed by the compaction.
    pub rows_removed: u64,
}
309
/// A named tag pointing at a dataset version.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VersionTag {
    /// Tag name.
    pub name: String,
    /// Version the tag refers to.
    pub version: u64,
    /// Creation time as an `i64` — unit/epoch not visible here (TODO confirm, e.g. ms).
    pub created_at: i64,
}
318
319#[derive(Debug, Clone)]
322pub struct DatasetInfo {
323 pub name: String,
324 pub version: u64,
325 pub row_count: u64,
326 pub schema: SchemaRef,
327}
328
329pub type RecordBatchStream =
330 std::pin::Pin<Box<dyn futures::Stream<Item = Result<RecordBatch, HirnDbError>> + Send>>;
331
332#[derive(Debug, Clone)]
335pub enum ColumnTransform {
336 AddColumn {
337 name: String,
338 data_type: arrow_schema::DataType,
339 nullable: bool,
340 default_value: Option<String>,
341 },
342 RenameColumn {
343 old_name: String,
344 new_name: String,
345 },
346}
347
348#[async_trait]
355pub trait PhysicalStore: Send + Sync {
356 async fn append(&self, dataset: &str, batch: RecordBatch) -> Result<(), HirnDbError>;
360
361 async fn append_batches(
363 &self,
364 dataset: &str,
365 batches: Vec<RecordBatch>,
366 ) -> Result<(), HirnDbError>;
367
368 async fn append_stream(
380 &self,
381 dataset: &str,
382 mut stream: RecordBatchStream,
383 ) -> Result<(), HirnDbError> {
384 use futures::StreamExt as _;
385 const MAX_STREAM_BATCH_ROWS: usize = 50_000;
386 let mut buffer: Vec<RecordBatch> = Vec::new();
387 let mut buffered_rows: usize = 0;
388 while let Some(result) = stream.next().await {
389 let batch = result?;
390 if batch.num_rows() == 0 {
391 continue;
392 }
393 buffered_rows += batch.num_rows();
394 buffer.push(batch);
395 if buffered_rows >= MAX_STREAM_BATCH_ROWS {
396 self.append_batches(dataset, std::mem::take(&mut buffer))
397 .await?;
398 buffered_rows = 0;
399 }
400 }
401 if !buffer.is_empty() {
402 self.append_batches(dataset, buffer).await?;
403 }
404 Ok(())
405 }
406
407 async fn scan(&self, dataset: &str, opts: ScanOptions)
409 -> Result<Vec<RecordBatch>, HirnDbError>;
410
411 async fn scan_stream(
413 &self,
414 dataset: &str,
415 opts: ScanOptions,
416 ) -> Result<RecordBatchStream, HirnDbError>;
417
418 #[doc(hidden)]
426 async fn delete(&self, dataset: &str, predicate: &str) -> Result<u64, HirnDbError>;
427
428 async fn delete_exact(
430 &self,
431 dataset: &str,
432 filter: &ExactMatchFilter,
433 ) -> Result<u64, HirnDbError> {
434 let predicate = filter.to_predicate_sql();
435 self.delete(dataset, &predicate).await
436 }
437
438 async fn merge_insert(
440 &self,
441 dataset: &str,
442 on: &[&str],
443 batch: RecordBatch,
444 ) -> Result<(), HirnDbError>;
445
446 async fn update_where(
455 &self,
456 dataset: &str,
457 filter: &str,
458 updates: &[(&str, &str)],
459 ) -> Result<u64, HirnDbError>;
460
461 async fn count(&self, dataset: &str, filter: Option<&str>) -> Result<u64, HirnDbError>;
463
464 async fn vector_search(
468 &self,
469 dataset: &str,
470 opts: VectorSearchOptions,
471 ) -> Result<Vec<RecordBatch>, HirnDbError>;
472
473 async fn vector_search_many(
475 &self,
476 dataset: &str,
477 queries: Vec<VectorSearchOptions>,
478 ) -> Result<Vec<Vec<RecordBatch>>, HirnDbError>;
479
480 async fn fts_search(
482 &self,
483 dataset: &str,
484 opts: FtsSearchOptions,
485 ) -> Result<Vec<RecordBatch>, HirnDbError>;
486
487 async fn hybrid_search(
489 &self,
490 dataset: &str,
491 opts: HybridSearchOptions,
492 ) -> Result<Vec<RecordBatch>, HirnDbError>;
493
494 async fn multivector_search(
496 &self,
497 dataset: &str,
498 opts: MultivectorSearchOptions,
499 ) -> Result<Vec<RecordBatch>, HirnDbError>;
500
501 async fn create_index(&self, dataset: &str, config: IndexConfig) -> Result<(), HirnDbError>;
505
506 async fn optimize_indices(&self, dataset: &str) -> Result<(), HirnDbError>;
508
509 async fn compact(
513 &self,
514 dataset: &str,
515 opts: CompactOptions,
516 ) -> Result<CompactResult, HirnDbError>;
517
518 async fn version(&self, dataset: &str) -> Result<u64, HirnDbError>;
522
523 async fn tag(&self, dataset: &str, tag: &str) -> Result<(), HirnDbError>;
525
526 async fn checkout(&self, dataset: &str, version: u64) -> Result<(), HirnDbError>;
528
529 async fn list_tags(&self, dataset: &str) -> Result<Vec<VersionTag>, HirnDbError>;
531
532 async fn list_datasets(&self) -> Result<Vec<DatasetInfo>, HirnDbError>;
536
537 async fn exists(&self, dataset: &str) -> Result<bool, HirnDbError>;
539
540 async fn list_namespaces(&self) -> Result<Vec<String>, HirnDbError>;
544
545 async fn create_namespace(&self, name: &str) -> Result<(), HirnDbError>;
547
548 async fn drop_namespace(&self, name: &str) -> Result<(), HirnDbError>;
550
551 async fn add_columns(
555 &self,
556 dataset: &str,
557 transforms: Vec<ColumnTransform>,
558 ) -> Result<(), HirnDbError>;
559
560 async fn drop_columns(&self, dataset: &str, columns: &[&str]) -> Result<(), HirnDbError>;
562
563 async fn table_provider(&self, dataset: &str) -> Option<Arc<dyn TableProvider>>;
573}