Skip to main content

hirn_storage/
store.rs

1use std::sync::Arc;
2
3use arrow_array::RecordBatch;
4use arrow_schema::SchemaRef;
5use async_trait::async_trait;
6use datafusion::catalog::TableProvider;
7
8use crate::error::HirnDbError;
9use crate::reranker::Reranker;
10
11// ── Distance Metrics ──
12
13/// Re-exported from `hirn-core` — single canonical definition across the codebase.
14pub use hirn_core::DistanceMetric;
15
16// ── Normalize Method ──
17
/// Normalization method carried by [`HybridSearchOptions::normalize`].
///
/// The variants only name the method; how each one is applied is decided by
/// the store implementation, which is not part of this file.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum NormalizeMethod {
    /// The default method.
    #[default]
    Score,
    /// The alternative to [`NormalizeMethod::Score`].
    Rank,
}
24
25// ── Index Types ──
26
/// Kind of index requested through [`IndexConfig::index_type`].
///
/// NOTE(review): the variant names suggest four vector-index flavours
/// (`Ivf*`), a BM25 full-text index, and three scalar index kinds — confirm
/// the exact mapping against the store implementation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IndexType {
    IvfHnswSq,
    IvfHnswPq,
    IvfPq,
    IvfRq,
    Bm25,
    BTree,
    Bitmap,
    LabelList,
}
38
39// ── Index Parameters ──
40
/// Optional tuning parameters attached to an [`IndexConfig`].
///
/// Every field is optional and `Default` leaves all of them `None`. Which
/// fields are meaningful for which [`IndexType`] is decided by the store
/// implementation and cannot be determined from this file.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct IndexParams {
    pub num_partitions: Option<u32>,
    pub num_sub_vectors: Option<u32>,
    pub num_edges: Option<u32>,
    pub ef_construction: Option<u32>,
    pub sample_rate: Option<u32>,
    pub num_bits: Option<u32>,
}
50
51// ── Index Config ──
52
53#[derive(Debug, Clone, PartialEq, Eq)]
54pub struct IndexConfig {
55    pub columns: Vec<String>,
56    pub index_type: IndexType,
57    pub params: IndexParams,
58    pub replace: bool,
59}
60
61// ── Scan Ordering ──
62
/// One sort key for a scan: a column plus direction and null placement.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScanOrdering {
    /// Column to order by.
    pub column: String,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
    /// Whether nulls sort before non-null values.
    pub nulls_first: bool,
}

impl ScanOrdering {
    /// Shared constructor: both public constructors leave `nulls_first` off.
    fn directed(column: String, ascending: bool) -> Self {
        Self {
            column,
            ascending,
            nulls_first: false,
        }
    }

    /// Ascending order on `column`, with `nulls_first` set to `false`.
    #[must_use]
    pub fn asc(column: impl Into<String>) -> Self {
        Self::directed(column.into(), true)
    }

    /// Descending order on `column`, with `nulls_first` set to `false`.
    #[must_use]
    pub fn desc(column: impl Into<String>) -> Self {
        Self::directed(column.into(), false)
    }
}
89
90// ── Scan Options ──
91
/// Structured exact-match predicate over UTF-8 columns.
///
/// Rendered to SQL text by [`ExactMatchFilter::to_predicate_sql`]. Prefer the
/// constructor functions over building variants by hand: only the
/// constructors run the column-name check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExactMatchFilter {
    /// Matches rows where `column` equals any of `values`.
    Utf8In {
        column: String,
        values: Vec<String>,
    },
    /// Matches rows where `column_a = value` OR `column_b = value`.
    /// Used for bidirectional edge lookup (source OR target equals a given id).
    Utf8MultiColumnOr {
        columns: Vec<String>,
        value: String,
    },
}

impl ExactMatchFilter {
    /// Validate that a column name is safe to interpolate into a SQL predicate.
    ///
    /// Column names in hirn are always statically known lowercase snake_case
    /// identifiers, so anything outside `[a-z0-9_]` (or longer than 64 bytes)
    /// indicates a programming error.
    ///
    /// This is a `debug_assert!`: it is compiled out of release builds and it
    /// is not run when a variant is constructed directly through its public
    /// fields. It is a development-time guard, not a runtime sanitizer.
    fn assert_safe_column(col: &str) {
        debug_assert!(
            !col.is_empty()
                && col.len() <= 64
                && col
                    .chars()
                    .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_'),
            "column name '{col}' contains unsafe characters — only [a-z0-9_] are allowed"
        );
    }

    /// Render `value` as a single-quoted SQL string literal, doubling any
    /// embedded single quotes.
    fn sql_string_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }

    /// Build a filter matching rows where `column` equals `value`.
    #[must_use]
    pub fn utf8_value(column: impl Into<String>, value: impl Into<String>) -> Self {
        let column = column.into();
        Self::assert_safe_column(&column);
        Self::Utf8In {
            column,
            values: vec![value.into()],
        }
    }

    /// Build a filter matching rows where `column` equals any of `values`.
    ///
    /// Returns `None` when `values` yields no items, so callers cannot
    /// accidentally build an always-false filter.
    #[must_use]
    pub fn utf8_values<I, S>(column: impl Into<String>, values: I) -> Option<Self>
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let values: Vec<String> = values.into_iter().map(Into::into).collect();
        if values.is_empty() {
            return None;
        }

        let column = column.into();
        Self::assert_safe_column(&column);
        Some(Self::Utf8In { column, values })
    }

    /// Build a filter matching rows where any of `columns` equals `value`.
    #[must_use]
    pub fn utf8_multi_column_or(columns: Vec<String>, value: impl Into<String>) -> Self {
        for col in &columns {
            Self::assert_safe_column(col);
        }
        Self::Utf8MultiColumnOr {
            columns,
            value: value.into(),
        }
    }

    /// Render the filter as a SQL predicate string.
    ///
    /// * Values are emitted as single-quoted literals with `'` doubled.
    /// * Column names are interpolated verbatim.
    /// * An empty value list or empty column list renders as `1 = 0`
    ///   (matches nothing).
    /// * A multi-column disjunction with more than one term is wrapped in
    ///   parentheses, so the result keeps its meaning when a caller joins it
    ///   to another predicate with `AND` (which binds tighter than `OR`).
    #[must_use]
    pub fn to_predicate_sql(&self) -> String {
        match self {
            Self::Utf8In { column, values } => {
                if values.is_empty() {
                    return "1 = 0".to_string();
                }

                let in_list = values
                    .iter()
                    .map(|value| Self::sql_string_literal(value))
                    .collect::<Vec<_>>()
                    .join(", ");
                format!("{column} IN ({in_list})")
            }
            Self::Utf8MultiColumnOr { columns, value } => {
                if columns.is_empty() {
                    return "1 = 0".to_string();
                }

                // Escape once; the same literal is compared against every column.
                let literal = Self::sql_string_literal(value);
                let disjunction = columns
                    .iter()
                    .map(|col| format!("{col} = {literal}"))
                    .collect::<Vec<_>>()
                    .join(" OR ");

                if columns.len() > 1 {
                    // Guard against `other AND a = 'v' OR b = 'v'` mis-grouping.
                    format!("({disjunction})")
                } else {
                    disjunction
                }
            }
        }
    }
}
192
193#[derive(Debug, Clone, Default)]
194pub struct ScanOptions {
195    pub filter: Option<String>,
196    pub exact_filter: Option<ExactMatchFilter>,
197    pub columns: Option<Vec<String>>,
198    pub order_by: Option<Vec<ScanOrdering>>,
199    pub limit: Option<usize>,
200    pub offset: Option<usize>,
201}
202
203// ── Vector Search Options ──
204
205#[derive(Debug, Clone)]
206pub struct VectorSearchOptions {
207    pub column: String,
208    pub query: Vec<f32>,
209    pub metric: DistanceMetric,
210    pub limit: usize,
211    pub filter: Option<String>,
212    pub nprobes: Option<usize>,
213    pub refine_factor: Option<u32>,
214}
215
216impl Default for VectorSearchOptions {
217    fn default() -> Self {
218        Self {
219            column: String::new(),
220            query: Vec::new(),
221            metric: DistanceMetric::default(),
222            limit: 10,
223            filter: None,
224            nprobes: None,
225            refine_factor: None,
226        }
227    }
228}
229
230// ── FTS Search Options ──
231
/// Options for `PhysicalStore::fts_search`.
#[derive(Clone, Debug)]
pub struct FtsSearchOptions {
    /// Text columns to search.
    pub columns: Vec<String>,
    /// Full-text query string.
    pub query: String,
    /// Maximum number of results.
    pub limit: usize,
    /// Optional raw predicate string.
    pub filter: Option<String>,
}
239
240// ── Hybrid Search Options ──
241
242#[derive(Clone)]
243pub struct HybridSearchOptions {
244    pub vector_column: String,
245    pub query_vector: Vec<f32>,
246    pub fts_columns: Vec<String>,
247    pub fts_query: String,
248    pub normalize: NormalizeMethod,
249    pub metric: DistanceMetric,
250    pub limit: usize,
251    pub filter: Option<String>,
252    /// Optional reranker. Defaults to [`RRFReranker`](crate::reranker::RRFReranker) if `None`.
253    pub reranker: Option<Arc<dyn Reranker>>,
254}
255
256impl std::fmt::Debug for HybridSearchOptions {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        f.debug_struct("HybridSearchOptions")
259            .field("vector_column", &self.vector_column)
260            .field("fts_columns", &self.fts_columns)
261            .field("fts_query", &self.fts_query)
262            .field("normalize", &self.normalize)
263            .field("metric", &self.metric)
264            .field("limit", &self.limit)
265            .field("filter", &self.filter)
266            .field("reranker", &self.reranker.as_ref().map(|_| ".."))
267            .finish()
268    }
269}
270
271// ── Multivector Search ──
272
/// Query payload for a multivector search.
#[derive(Clone, Debug)]
pub enum MultivectorQuery {
    /// A single query vector.
    Single(Vec<f32>),
    /// Several query vectors.
    Multi(Vec<Vec<f32>>),
}
278
279#[derive(Debug, Clone)]
280pub struct MultivectorSearchOptions {
281    /// Multivector column (`List<FixedSizeList<Float32>>`) for MaxSim scoring.
282    pub column: String,
283    pub query: MultivectorQuery,
284    pub metric: DistanceMetric,
285    pub limit: usize,
286    pub filter: Option<String>,
287    /// Optional dense embedding column for first-stage ANN retrieval.
288    /// When set, enables two-stage search: ANN over this column → MaxSim
289    /// re-scoring using `column`. When `None`, falls back to brute-force scan.
290    pub dense_column: Option<String>,
291    /// Number of candidates to retrieve in the first stage (default: `limit * 10`).
292    pub first_stage_limit: Option<usize>,
293}
294
295// ── Compact Options / Result ──
296
/// Options for `PhysicalStore::compact`; `Default` leaves both unset.
#[derive(Clone, Debug, Default)]
pub struct CompactOptions {
    /// Optional cap on rows per group.
    pub max_rows_per_group: Option<usize>,
    /// Optional target row count per fragment.
    pub target_rows_per_fragment: Option<usize>,
}
302
/// Counters returned by `PhysicalStore::compact`; `Default` is all zeros.
#[derive(Clone, Debug, Default)]
pub struct CompactResult {
    /// Number of fragments removed.
    pub fragments_removed: u64,
    /// Number of fragments added.
    pub fragments_added: u64,
    /// Number of rows removed.
    pub rows_removed: u64,
}
309
310// ── Version Tag ──
311
/// A named tag attached to a dataset version, as returned by
/// `PhysicalStore::list_tags`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VersionTag {
    /// Tag name.
    pub name: String,
    /// Dataset version the tag refers to.
    pub version: u64,
    /// Creation timestamp.
    /// NOTE(review): unit and epoch are not specified in this file — confirm.
    pub created_at: i64,
}
318
319// ── Dataset Info ──
320
321#[derive(Debug, Clone)]
322pub struct DatasetInfo {
323    pub name: String,
324    pub version: u64,
325    pub row_count: u64,
326    pub schema: SchemaRef,
327}
328
329pub type RecordBatchStream =
330    std::pin::Pin<Box<dyn futures::Stream<Item = Result<RecordBatch, HirnDbError>> + Send>>;
331
332// ── Column Transform ──
333
334#[derive(Debug, Clone)]
335pub enum ColumnTransform {
336    AddColumn {
337        name: String,
338        data_type: arrow_schema::DataType,
339        nullable: bool,
340        default_value: Option<String>,
341    },
342    RenameColumn {
343        old_name: String,
344        new_name: String,
345    },
346}
347
348// ── PhysicalStore Trait ──
349
350/// Physical storage operations on Lance datasets.
351///
352/// `LancePhysicalStore` implements this directly against lance 4.0 Dataset + LanceNamespace.
353/// `MemoryStore` implements this for tests with real Arrow data, brute-force search, etc.
354#[async_trait]
355pub trait PhysicalStore: Send + Sync {
356    // ── CRUD ──
357
358    /// Append rows to a dataset. Creates the dataset if it doesn't exist.
359    async fn append(&self, dataset: &str, batch: RecordBatch) -> Result<(), HirnDbError>;
360
361    /// Append multiple record batches in one logical storage operation.
362    async fn append_batches(
363        &self,
364        dataset: &str,
365        batches: Vec<RecordBatch>,
366    ) -> Result<(), HirnDbError>;
367
368    /// Append a streaming sequence of record batches to a dataset.
369    ///
370    /// Batches are buffered up to `MAX_STREAM_BATCH_ROWS` rows before each
371    /// flush to `append_batches`, bounding peak memory for large streams.
372    /// This is the correct API for pipeline or operator-driven writes where
373    /// the total row count is not known up front.
374    ///
375    /// The default implementation collects bounded buffers and calls
376    /// `append_batches`. Store implementations may override to stream
377    /// directly into the underlying storage engine without intermediate
378    /// materialization.
379    async fn append_stream(
380        &self,
381        dataset: &str,
382        mut stream: RecordBatchStream,
383    ) -> Result<(), HirnDbError> {
384        use futures::StreamExt as _;
385        const MAX_STREAM_BATCH_ROWS: usize = 50_000;
386        let mut buffer: Vec<RecordBatch> = Vec::new();
387        let mut buffered_rows: usize = 0;
388        while let Some(result) = stream.next().await {
389            let batch = result?;
390            if batch.num_rows() == 0 {
391                continue;
392            }
393            buffered_rows += batch.num_rows();
394            buffer.push(batch);
395            if buffered_rows >= MAX_STREAM_BATCH_ROWS {
396                self.append_batches(dataset, std::mem::take(&mut buffer))
397                    .await?;
398                buffered_rows = 0;
399            }
400        }
401        if !buffer.is_empty() {
402            self.append_batches(dataset, buffer).await?;
403        }
404        Ok(())
405    }
406
407    /// Scan with predicate pushdown, projection, and optional limit/offset.
408    async fn scan(&self, dataset: &str, opts: ScanOptions)
409    -> Result<Vec<RecordBatch>, HirnDbError>;
410
411    /// Stream batches incrementally instead of materializing the whole scan.
412    async fn scan_stream(
413        &self,
414        dataset: &str,
415        opts: ScanOptions,
416    ) -> Result<RecordBatchStream, HirnDbError>;
417
418    /// Delete rows by predicate. Returns count of deleted rows.
419    ///
420    /// # Security note
421    /// This method accepts a raw SQL predicate string. All callers **must** ensure
422    /// values are constructed from system-generated identifiers (ULIDs, integers) or
423    /// properly escaped via `str::replace('\'', "''")`. Prefer [`Self::delete_exact`]
424    /// for single-column exact matches.
425    #[doc(hidden)]
426    async fn delete(&self, dataset: &str, predicate: &str) -> Result<u64, HirnDbError>;
427
428    /// Delete rows by structured exact-match filter. Returns count of deleted rows.
429    async fn delete_exact(
430        &self,
431        dataset: &str,
432        filter: &ExactMatchFilter,
433    ) -> Result<u64, HirnDbError> {
434        let predicate = filter.to_predicate_sql();
435        self.delete(dataset, &predicate).await
436    }
437
438    /// Merge-insert (upsert): insert new rows, update matching rows.
439    async fn merge_insert(
440        &self,
441        dataset: &str,
442        on: &[&str],
443        batch: RecordBatch,
444    ) -> Result<(), HirnDbError>;
445
446    /// Targeted in-place column update.
447    ///
448    /// Executes a narrow `SET col = expr [, …] WHERE filter` statement — no
449    /// full-row read-modify-write.  `updates` is a slice of `(column, sql_expr)`
450    /// pairs where `sql_expr` is a SQL literal or expression understood by the
451    /// backing store (e.g. `"true"`, `"'hello'"`, `"42"`).
452    ///
453    /// This avoids the RMW race inherent in scan → modify → merge_insert.
454    async fn update_where(
455        &self,
456        dataset: &str,
457        filter: &str,
458        updates: &[(&str, &str)],
459    ) -> Result<u64, HirnDbError>;
460
461    /// Count rows (optionally filtered). Uses fast metadata path when no filter.
462    async fn count(&self, dataset: &str, filter: Option<&str>) -> Result<u64, HirnDbError>;
463
464    // ── Search ──
465
466    /// Vector ANN search.
467    async fn vector_search(
468        &self,
469        dataset: &str,
470        opts: VectorSearchOptions,
471    ) -> Result<Vec<RecordBatch>, HirnDbError>;
472
473    /// Batched vector ANN search preserving query order.
474    async fn vector_search_many(
475        &self,
476        dataset: &str,
477        queries: Vec<VectorSearchOptions>,
478    ) -> Result<Vec<Vec<RecordBatch>>, HirnDbError>;
479
480    /// Full-text search (BM25).
481    async fn fts_search(
482        &self,
483        dataset: &str,
484        opts: FtsSearchOptions,
485    ) -> Result<Vec<RecordBatch>, HirnDbError>;
486
487    /// Hybrid search (vector + FTS fusion with configurable reranker + normalization).
488    async fn hybrid_search(
489        &self,
490        dataset: &str,
491        opts: HybridSearchOptions,
492    ) -> Result<Vec<RecordBatch>, HirnDbError>;
493
494    /// Multivector search (ColBERT/ColPaLi-style late interaction with MaxSim).
495    async fn multivector_search(
496        &self,
497        dataset: &str,
498        opts: MultivectorSearchOptions,
499    ) -> Result<Vec<RecordBatch>, HirnDbError>;
500
501    // ── Indexing ──
502
503    /// Create or replace an index (vector, scalar, FTS).
504    async fn create_index(&self, dataset: &str, config: IndexConfig) -> Result<(), HirnDbError>;
505
506    /// Optimize existing indices.
507    async fn optimize_indices(&self, dataset: &str) -> Result<(), HirnDbError>;
508
509    // ── Compaction ──
510
511    /// Compact fragments + prune deleted rows.
512    async fn compact(
513        &self,
514        dataset: &str,
515        opts: CompactOptions,
516    ) -> Result<CompactResult, HirnDbError>;
517
518    // ── Versioning ──
519
520    /// Get current dataset version.
521    async fn version(&self, dataset: &str) -> Result<u64, HirnDbError>;
522
523    /// Snapshot (tag) the current version.
524    async fn tag(&self, dataset: &str, tag: &str) -> Result<(), HirnDbError>;
525
526    /// Checkout a historical version (read-only).
527    async fn checkout(&self, dataset: &str, version: u64) -> Result<(), HirnDbError>;
528
529    /// List all tags.
530    async fn list_tags(&self, dataset: &str) -> Result<Vec<VersionTag>, HirnDbError>;
531
532    // ── Dataset management ──
533
534    /// List all datasets in the current namespace.
535    async fn list_datasets(&self) -> Result<Vec<DatasetInfo>, HirnDbError>;
536
537    /// Check existence.
538    async fn exists(&self, dataset: &str) -> Result<bool, HirnDbError>;
539
540    // ── Namespace ──
541
542    /// List sub-namespaces.
543    async fn list_namespaces(&self) -> Result<Vec<String>, HirnDbError>;
544
545    /// Create a new namespace.
546    async fn create_namespace(&self, name: &str) -> Result<(), HirnDbError>;
547
548    /// Drop a namespace and all its tables.
549    async fn drop_namespace(&self, name: &str) -> Result<(), HirnDbError>;
550
551    // ── Schema evolution ──
552
553    /// Add columns to a dataset.
554    async fn add_columns(
555        &self,
556        dataset: &str,
557        transforms: Vec<ColumnTransform>,
558    ) -> Result<(), HirnDbError>;
559
560    /// Drop columns from a dataset.
561    async fn drop_columns(&self, dataset: &str, columns: &[&str]) -> Result<(), HirnDbError>;
562
563    // ── DataFusion Integration ──
564
565    /// Return a DataFusion `TableProvider` for the named dataset.
566    ///
567    /// Lance-backed stores return a `LanceTableProvider` with native projection
568    /// and filter pushdown. Non-Lance stores (e.g. `MemoryStore`) return `None`,
569    /// triggering a fallback to empty `MemTable` stubs.
570    ///
571    /// Wrapper stores (e.g. `PolicyEnforcedStore`) delegate to their inner store.
572    async fn table_provider(&self, dataset: &str) -> Option<Arc<dyn TableProvider>>;
573}