Skip to main content

prax_pgvector/
index.rs

//! Vector index management for pgvector.
//!
//! pgvector supports two approximate nearest-neighbor (ANN) index types:
//!
//! | Index | Algorithm | Best For | Tradeoff |
//! |-------|-----------|----------|----------|
//! | **IVFFlat** | Inverted file of clustered, uncompressed ("flat") vectors | Large datasets, tunable recall | Needs a training step over existing rows, so build it after loading data |
//! | **HNSW** | Hierarchical navigable small world graph | Most workloads, no training needed | Slower builds and higher memory usage |
//!
//! # Choosing an Index
//!
//! - **HNSW** is recommended for most use cases — better recall/speed tradeoff,
//!   no training step, and supports concurrent inserts.
//! - **IVFFlat** is useful when memory or index build time is constrained, and
//!   for very large datasets where the table is already populated so the
//!   training step has representative rows to cluster.
16
17use std::fmt;
18
19use serde::{Deserialize, Serialize};
20
21use crate::error::{VectorError, VectorResult};
22use crate::ops::{BinaryDistanceMetric, DistanceMetric};
23
24/// The type of ANN index to create.
25#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26#[non_exhaustive]
27pub enum IndexType {
28    /// IVFFlat (Inverted File with Flat quantization).
29    IvfFlat(IvfFlatConfig),
30
31    /// HNSW (Hierarchical Navigable Small World).
32    Hnsw(HnswConfig),
33}
34
35impl fmt::Display for IndexType {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        match self {
38            Self::IvfFlat(_) => write!(f, "ivfflat"),
39            Self::Hnsw(_) => write!(f, "hnsw"),
40        }
41    }
42}
43
44/// Configuration for IVFFlat indexes.
45///
46/// IVFFlat divides vectors into `lists` number of clusters during a training phase.
47/// At query time, `probes` clusters are searched.
48///
49/// # Tuning Guidelines
50///
51/// - `lists`: Start with `rows / 1000` for up to 1M rows, `sqrt(rows)` for more.
52/// - `probes`: Start with `sqrt(lists)` and increase for better recall.
53#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
54pub struct IvfFlatConfig {
55    /// Number of inverted lists (clusters).
56    ///
57    /// More lists = faster search but potentially lower recall.
58    /// Recommended: `rows / 1000` for up to 1M rows.
59    pub lists: usize,
60}
61
62impl IvfFlatConfig {
63    /// Create a new IVFFlat config with the given number of lists.
64    pub fn new(lists: usize) -> Self {
65        Self { lists }
66    }
67
68    /// Create a config with the recommended number of lists for a given row count.
69    pub fn for_row_count(rows: usize) -> Self {
70        let lists = if rows <= 1_000_000 {
71            (rows / 1000).max(1)
72        } else {
73            (rows as f64).sqrt() as usize
74        };
75        Self { lists }
76    }
77}
78
79impl Default for IvfFlatConfig {
80    fn default() -> Self {
81        Self { lists: 100 }
82    }
83}
84
85/// Configuration for HNSW indexes.
86///
87/// HNSW builds a multi-layered graph that enables efficient approximate nearest-neighbor
88/// search without a separate training step.
89///
90/// # Tuning Guidelines
91///
92/// - `m`: Number of connections per node. Higher = better recall, more memory.
93///   Default: 16. Range: 2-100.
94/// - `ef_construction`: Size of the dynamic candidate list during index build.
95///   Higher = better recall, slower build. Default: 64. Range: 4-1000.
96#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
97pub struct HnswConfig {
98    /// Maximum number of connections per node per layer.
99    ///
100    /// Higher values improve recall but increase memory and build time.
101    /// Default: 16.
102    pub m: Option<usize>,
103
104    /// Size of the dynamic candidate list during construction.
105    ///
106    /// Higher values improve index quality but slow down build.
107    /// Default: 64.
108    pub ef_construction: Option<usize>,
109}
110
111impl HnswConfig {
112    /// Create a new HNSW config with defaults.
113    pub fn new() -> Self {
114        Self {
115            m: None,
116            ef_construction: None,
117        }
118    }
119
120    /// Set the `m` parameter (connections per node).
121    pub fn m(mut self, m: usize) -> Self {
122        self.m = Some(m);
123        self
124    }
125
126    /// Set the `ef_construction` parameter.
127    pub fn ef_construction(mut self, ef: usize) -> Self {
128        self.ef_construction = Some(ef);
129        self
130    }
131
132    /// High-recall configuration (slower build, better search quality).
133    pub fn high_recall() -> Self {
134        Self {
135            m: Some(32),
136            ef_construction: Some(128),
137        }
138    }
139
140    /// Fast-build configuration (faster build, lower recall).
141    pub fn fast_build() -> Self {
142        Self {
143            m: Some(8),
144            ef_construction: Some(32),
145        }
146    }
147}
148
149impl Default for HnswConfig {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155/// A vector index definition.
156///
157/// # Examples
158///
159/// ```rust
160/// use prax_pgvector::index::{VectorIndex, HnswConfig};
161/// use prax_pgvector::DistanceMetric;
162///
163/// // Create an HNSW index
164/// let index = VectorIndex::hnsw("idx_embedding", "documents", "embedding")
165///     .metric(DistanceMetric::Cosine)
166///     .config(HnswConfig::high_recall())
167///     .build()
168///     .unwrap();
169///
170/// let sql = index.to_create_sql();
171/// assert!(sql.contains("USING hnsw"));
172/// assert!(sql.contains("vector_cosine_ops"));
173/// ```
174#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
175pub struct VectorIndex {
176    /// Index name.
177    pub name: String,
178    /// Table name.
179    pub table: String,
180    /// Column name.
181    pub column: String,
182    /// Distance metric.
183    pub metric: DistanceMetric,
184    /// Index type and configuration.
185    pub index_type: IndexType,
186    /// Whether to create concurrently (non-blocking).
187    pub concurrent: bool,
188    /// Whether to add IF NOT EXISTS clause.
189    pub if_not_exists: bool,
190}
191
192impl VectorIndex {
193    /// Start building an HNSW index.
194    pub fn hnsw(
195        name: impl Into<String>,
196        table: impl Into<String>,
197        column: impl Into<String>,
198    ) -> VectorIndexBuilder {
199        VectorIndexBuilder {
200            name: name.into(),
201            table: table.into(),
202            column: column.into(),
203            metric: DistanceMetric::L2,
204            index_type: IndexType::Hnsw(HnswConfig::default()),
205            concurrent: false,
206            if_not_exists: false,
207        }
208    }
209
210    /// Start building an IVFFlat index.
211    pub fn ivfflat(
212        name: impl Into<String>,
213        table: impl Into<String>,
214        column: impl Into<String>,
215    ) -> VectorIndexBuilder {
216        VectorIndexBuilder {
217            name: name.into(),
218            table: table.into(),
219            column: column.into(),
220            metric: DistanceMetric::L2,
221            index_type: IndexType::IvfFlat(IvfFlatConfig::default()),
222            concurrent: false,
223            if_not_exists: false,
224        }
225    }
226
227    /// Generate the CREATE INDEX SQL statement.
228    pub fn to_create_sql(&self) -> String {
229        let concurrent = if self.concurrent { " CONCURRENTLY" } else { "" };
230        let if_not_exists = if self.if_not_exists {
231            " IF NOT EXISTS"
232        } else {
233            ""
234        };
235
236        let (method, with_clause) = match &self.index_type {
237            IndexType::IvfFlat(config) => {
238                let with = format!(" WITH (lists = {})", config.lists);
239                ("ivfflat", with)
240            }
241            IndexType::Hnsw(config) => {
242                let mut with_parts = Vec::new();
243                if let Some(m) = config.m {
244                    with_parts.push(format!("m = {m}"));
245                }
246                if let Some(ef) = config.ef_construction {
247                    with_parts.push(format!("ef_construction = {ef}"));
248                }
249                let with = if with_parts.is_empty() {
250                    String::new()
251                } else {
252                    format!(" WITH ({})", with_parts.join(", "))
253                };
254                ("hnsw", with)
255            }
256        };
257
258        format!(
259            "CREATE INDEX{}{} {} ON {} USING {} ({} {}){}",
260            concurrent,
261            if_not_exists,
262            self.name,
263            self.table,
264            method,
265            self.column,
266            self.metric.ops_class(),
267            with_clause
268        )
269    }
270
271    /// Generate the DROP INDEX SQL statement.
272    pub fn to_drop_sql(&self) -> String {
273        let concurrent = if self.concurrent { " CONCURRENTLY" } else { "" };
274        format!("DROP INDEX{} IF EXISTS {}", concurrent, self.name)
275    }
276
277    /// Generate SQL to check if this index exists.
278    pub fn to_exists_sql(&self) -> String {
279        format!(
280            "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = '{}')",
281            self.name
282        )
283    }
284
285    /// Generate SQL to get the index size.
286    pub fn to_size_sql(&self) -> String {
287        format!("SELECT pg_size_pretty(pg_relation_size('{}'))", self.name)
288    }
289}
290
291/// Builder for [`VectorIndex`].
292#[derive(Debug, Clone)]
293pub struct VectorIndexBuilder {
294    name: String,
295    table: String,
296    column: String,
297    metric: DistanceMetric,
298    index_type: IndexType,
299    concurrent: bool,
300    if_not_exists: bool,
301}
302
303impl VectorIndexBuilder {
304    /// Set the distance metric.
305    pub fn metric(mut self, metric: DistanceMetric) -> Self {
306        self.metric = metric;
307        self
308    }
309
310    /// Set the HNSW configuration (only effective for HNSW indexes).
311    pub fn config(mut self, config: HnswConfig) -> Self {
312        self.index_type = IndexType::Hnsw(config);
313        self
314    }
315
316    /// Set the IVFFlat configuration (only effective for IVFFlat indexes).
317    pub fn ivfflat_config(mut self, config: IvfFlatConfig) -> Self {
318        self.index_type = IndexType::IvfFlat(config);
319        self
320    }
321
322    /// Create the index concurrently (non-blocking).
323    pub fn concurrent(mut self) -> Self {
324        self.concurrent = true;
325        self
326    }
327
328    /// Add IF NOT EXISTS clause.
329    pub fn if_not_exists(mut self) -> Self {
330        self.if_not_exists = true;
331        self
332    }
333
334    /// Build the index definition.
335    ///
336    /// # Errors
337    ///
338    /// Returns an error if the configuration is invalid.
339    pub fn build(self) -> VectorResult<VectorIndex> {
340        if self.name.is_empty() {
341            return Err(VectorError::index("index name cannot be empty"));
342        }
343        if self.table.is_empty() {
344            return Err(VectorError::index("table name cannot be empty"));
345        }
346        if self.column.is_empty() {
347            return Err(VectorError::index("column name cannot be empty"));
348        }
349
350        Ok(VectorIndex {
351            name: self.name,
352            table: self.table,
353            column: self.column,
354            metric: self.metric,
355            index_type: self.index_type,
356            concurrent: self.concurrent,
357            if_not_exists: self.if_not_exists,
358        })
359    }
360}
361
362/// A binary vector index definition.
363#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
364pub struct BinaryVectorIndex {
365    /// Index name.
366    pub name: String,
367    /// Table name.
368    pub table: String,
369    /// Column name.
370    pub column: String,
371    /// Distance metric.
372    pub metric: BinaryDistanceMetric,
373    /// HNSW configuration (only HNSW is supported for bit vectors).
374    pub hnsw_config: HnswConfig,
375    /// Whether to create concurrently.
376    pub concurrent: bool,
377}
378
379impl BinaryVectorIndex {
380    /// Create a new binary vector index builder.
381    #[allow(clippy::new_ret_no_self)]
382    pub fn new(
383        name: impl Into<String>,
384        table: impl Into<String>,
385        column: impl Into<String>,
386    ) -> BinaryVectorIndexBuilder {
387        BinaryVectorIndexBuilder {
388            name: name.into(),
389            table: table.into(),
390            column: column.into(),
391            metric: BinaryDistanceMetric::Hamming,
392            hnsw_config: HnswConfig::default(),
393            concurrent: false,
394        }
395    }
396
397    /// Generate the CREATE INDEX SQL.
398    pub fn to_create_sql(&self) -> String {
399        let concurrent = if self.concurrent { " CONCURRENTLY" } else { "" };
400
401        let mut with_parts = Vec::new();
402        if let Some(m) = self.hnsw_config.m {
403            with_parts.push(format!("m = {m}"));
404        }
405        if let Some(ef) = self.hnsw_config.ef_construction {
406            with_parts.push(format!("ef_construction = {ef}"));
407        }
408        let with = if with_parts.is_empty() {
409            String::new()
410        } else {
411            format!(" WITH ({})", with_parts.join(", "))
412        };
413
414        format!(
415            "CREATE INDEX{} {} ON {} USING hnsw ({} {}){}",
416            concurrent,
417            self.name,
418            self.table,
419            self.column,
420            self.metric.ops_class(),
421            with
422        )
423    }
424}
425
426/// Builder for [`BinaryVectorIndex`].
427#[derive(Debug, Clone)]
428pub struct BinaryVectorIndexBuilder {
429    name: String,
430    table: String,
431    column: String,
432    metric: BinaryDistanceMetric,
433    hnsw_config: HnswConfig,
434    concurrent: bool,
435}
436
437impl BinaryVectorIndexBuilder {
438    /// Set the distance metric.
439    pub fn metric(mut self, metric: BinaryDistanceMetric) -> Self {
440        self.metric = metric;
441        self
442    }
443
444    /// Set the HNSW configuration.
445    pub fn config(mut self, config: HnswConfig) -> Self {
446        self.hnsw_config = config;
447        self
448    }
449
450    /// Create the index concurrently.
451    pub fn concurrent(mut self) -> Self {
452        self.concurrent = true;
453        self
454    }
455
456    /// Build the index definition.
457    pub fn build(self) -> VectorResult<BinaryVectorIndex> {
458        if self.name.is_empty() {
459            return Err(VectorError::index("index name cannot be empty"));
460        }
461        Ok(BinaryVectorIndex {
462            name: self.name,
463            table: self.table,
464            column: self.column,
465            metric: self.metric,
466            hnsw_config: self.hnsw_config,
467            concurrent: self.concurrent,
468        })
469    }
470}
471
/// SQL helpers for pgvector extension management.
///
/// Schema, table and column names are interpolated verbatim (no quoting or
/// escaping), so pass only trusted, valid identifiers.
pub mod extension {
    /// Statement that installs pgvector unless it is already installed.
    pub fn create_extension_sql() -> &'static str {
        "CREATE EXTENSION IF NOT EXISTS vector"
    }

    /// Same as [`create_extension_sql`], but places the extension's objects in `schema`.
    pub fn create_extension_in_schema_sql(schema: &str) -> String {
        format!("{} SCHEMA {schema}", create_extension_sql())
    }

    /// Statement that removes pgvector if it is installed.
    pub fn drop_extension_sql() -> &'static str {
        "DROP EXTENSION IF EXISTS vector"
    }

    /// Query returning a single boolean: whether pgvector is installed.
    pub fn check_extension_sql() -> &'static str {
        "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector')"
    }

    /// Query returning the installed pgvector version (no row when not installed).
    pub fn version_sql() -> &'static str {
        "SELECT extversion FROM pg_extension WHERE extname = 'vector'"
    }

    /// `ALTER TABLE <table> ADD COLUMN <column> vector(<dimensions>)`.
    pub fn add_vector_column_sql(table: &str, column: &str, dimensions: usize) -> String {
        add_typed_column_sql(table, column, "vector", dimensions)
    }

    /// `ALTER TABLE <table> ADD COLUMN <column> halfvec(<dimensions>)`.
    pub fn add_halfvec_column_sql(table: &str, column: &str, dimensions: usize) -> String {
        add_typed_column_sql(table, column, "halfvec", dimensions)
    }

    /// `ALTER TABLE <table> ADD COLUMN <column> sparsevec(<dimensions>)`.
    pub fn add_sparsevec_column_sql(table: &str, column: &str, dimensions: usize) -> String {
        add_typed_column_sql(table, column, "sparsevec", dimensions)
    }

    /// `ALTER TABLE <table> ADD COLUMN <column> bit(<dimensions>)`.
    pub fn add_bit_column_sql(table: &str, column: &str, dimensions: usize) -> String {
        add_typed_column_sql(table, column, "bit", dimensions)
    }

    /// Render `ALTER TABLE ... ADD COLUMN` for a pgvector type parameterised by dimension.
    fn add_typed_column_sql(
        table: &str,
        column: &str,
        type_name: &str,
        dimensions: usize,
    ) -> String {
        format!("ALTER TABLE {table} ADD COLUMN {column} {type_name}({dimensions})")
    }
}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Every dense-vector metric; each must map to its own operator class.
    const DENSE_METRICS: [DistanceMetric; 4] = [
        DistanceMetric::L2,
        DistanceMetric::InnerProduct,
        DistanceMetric::Cosine,
        DistanceMetric::L1,
    ];

    /// Default HNSW index `idx_emb` on `docs.emb`.
    fn default_hnsw() -> VectorIndex {
        VectorIndex::hnsw("idx_emb", "docs", "emb").build().unwrap()
    }

    /// Assert that `sql` contains every fragment listed.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(sql.contains(fragment), "expected {fragment:?} in {sql:?}");
        }
    }

    /// Assert that each dense metric's operator class shows up in the SQL
    /// produced by builders from `start`.
    fn check_all_metrics(start: impl Fn() -> VectorIndexBuilder) {
        for metric in DENSE_METRICS {
            let sql = start().metric(metric).build().unwrap().to_create_sql();
            assert!(sql.contains(metric.ops_class()));
        }
    }

    #[test]
    fn test_hnsw_index_create_sql() {
        let sql = VectorIndex::hnsw("idx_embedding", "documents", "embedding")
            .metric(DistanceMetric::Cosine)
            .config(HnswConfig::new().m(16).ef_construction(64))
            .build()
            .unwrap()
            .to_create_sql();

        assert_contains_all(
            &sql,
            &[
                "CREATE INDEX",
                "idx_embedding",
                "documents",
                "USING hnsw",
                "vector_cosine_ops",
                "m = 16",
                "ef_construction = 64",
            ],
        );
    }

    #[test]
    fn test_hnsw_index_default_config() {
        let sql = default_hnsw().to_create_sql();
        // L2 is the default metric, and an untouched config emits no WITH clause.
        assert_contains_all(&sql, &["USING hnsw", "vector_l2_ops"]);
        assert!(!sql.contains("WITH"));
    }

    #[test]
    fn test_ivfflat_index_create_sql() {
        let sql = VectorIndex::ivfflat("idx_embedding", "documents", "embedding")
            .metric(DistanceMetric::L2)
            .ivfflat_config(IvfFlatConfig::new(200))
            .build()
            .unwrap()
            .to_create_sql();

        assert_contains_all(&sql, &["USING ivfflat", "vector_l2_ops", "lists = 200"]);
    }

    #[test]
    fn test_ivfflat_for_row_count() {
        // rows / 1000 up to 1M rows, sqrt(rows) beyond that.
        let cases: [(usize, usize); 2] = [(500_000, 500), (5_000_000, 2236)];
        for (rows, lists) in cases {
            assert_eq!(IvfFlatConfig::for_row_count(rows).lists, lists);
        }
    }

    #[test]
    fn test_concurrent_index() {
        let sql = VectorIndex::hnsw("idx_emb", "docs", "emb")
            .concurrent()
            .if_not_exists()
            .build()
            .unwrap()
            .to_create_sql();

        assert_contains_all(&sql, &["CONCURRENTLY", "IF NOT EXISTS"]);
    }

    #[test]
    fn test_drop_index() {
        assert_eq!(default_hnsw().to_drop_sql(), "DROP INDEX IF EXISTS idx_emb");
    }

    #[test]
    fn test_concurrent_drop_index() {
        let index = VectorIndex::hnsw("idx_emb", "docs", "emb")
            .concurrent()
            .build()
            .unwrap();

        assert!(index.to_drop_sql().contains("CONCURRENTLY"));
    }

    #[test]
    fn test_index_exists_sql() {
        assert_contains_all(&default_hnsw().to_exists_sql(), &["pg_indexes", "idx_emb"]);
    }

    #[test]
    fn test_index_size_sql() {
        assert_contains_all(&default_hnsw().to_size_sql(), &["pg_size_pretty", "idx_emb"]);
    }

    #[test]
    fn test_empty_name_error() {
        assert!(VectorIndex::hnsw("", "docs", "emb").build().is_err());
    }

    #[test]
    fn test_hnsw_high_recall() {
        let HnswConfig { m, ef_construction } = HnswConfig::high_recall();
        assert_eq!((m, ef_construction), (Some(32), Some(128)));
    }

    #[test]
    fn test_hnsw_fast_build() {
        let HnswConfig { m, ef_construction } = HnswConfig::fast_build();
        assert_eq!((m, ef_construction), (Some(8), Some(32)));
    }

    #[test]
    fn test_binary_vector_index() {
        let sql = BinaryVectorIndex::new("idx_bits", "docs", "binary_emb")
            .metric(BinaryDistanceMetric::Hamming)
            .build()
            .unwrap()
            .to_create_sql();

        assert_contains_all(&sql, &["USING hnsw", "bit_hamming_ops"]);
    }

    #[test]
    fn test_extension_create_sql() {
        let expected = "CREATE EXTENSION IF NOT EXISTS vector";
        assert_eq!(extension::create_extension_sql(), expected);
    }

    #[test]
    fn test_extension_in_schema() {
        assert!(extension::create_extension_in_schema_sql("public").contains("SCHEMA public"));
    }

    #[test]
    fn test_add_vector_column() {
        let expected = "ALTER TABLE documents ADD COLUMN embedding vector(1536)";
        assert_eq!(
            extension::add_vector_column_sql("documents", "embedding", 1536),
            expected
        );
    }

    #[test]
    fn test_add_sparsevec_column() {
        let sql = extension::add_sparsevec_column_sql("documents", "sparse_emb", 30000);
        assert!(sql.contains("sparsevec(30000)"));
    }

    #[test]
    fn test_add_bit_column() {
        let sql = extension::add_bit_column_sql("documents", "binary_emb", 1024);
        assert!(sql.contains("bit(1024)"));
    }

    #[test]
    fn test_check_extension_sql() {
        assert!(extension::check_extension_sql().contains("pg_extension"));
    }

    #[test]
    fn test_version_sql() {
        assert!(extension::version_sql().contains("extversion"));
    }

    #[test]
    fn test_index_type_display() {
        let cases = [
            (IndexType::IvfFlat(IvfFlatConfig::default()), "ivfflat"),
            (IndexType::Hnsw(HnswConfig::default()), "hnsw"),
        ];
        for (index_type, expected) in cases {
            assert_eq!(index_type.to_string(), expected);
        }
    }

    #[test]
    fn test_all_metrics_with_ivfflat() {
        check_all_metrics(|| VectorIndex::ivfflat("idx", "t", "c"));
    }

    #[test]
    fn test_all_metrics_with_hnsw() {
        check_all_metrics(|| VectorIndex::hnsw("idx", "t", "c"));
    }
}