1use std::fmt;
18
19use serde::{Deserialize, Serialize};
20
21use crate::error::{VectorError, VectorResult};
22use crate::ops::{BinaryDistanceMetric, DistanceMetric};
23
24#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26#[non_exhaustive]
27pub enum IndexType {
28 IvfFlat(IvfFlatConfig),
30
31 Hnsw(HnswConfig),
33}
34
35impl fmt::Display for IndexType {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 match self {
38 Self::IvfFlat(_) => write!(f, "ivfflat"),
39 Self::Hnsw(_) => write!(f, "hnsw"),
40 }
41 }
42}
43
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
54pub struct IvfFlatConfig {
55 pub lists: usize,
60}
61
62impl IvfFlatConfig {
63 pub fn new(lists: usize) -> Self {
65 Self { lists }
66 }
67
68 pub fn for_row_count(rows: usize) -> Self {
70 let lists = if rows <= 1_000_000 {
71 (rows / 1000).max(1)
72 } else {
73 (rows as f64).sqrt() as usize
74 };
75 Self { lists }
76 }
77}
78
79impl Default for IvfFlatConfig {
80 fn default() -> Self {
81 Self { lists: 100 }
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
97pub struct HnswConfig {
98 pub m: Option<usize>,
103
104 pub ef_construction: Option<usize>,
109}
110
111impl HnswConfig {
112 pub fn new() -> Self {
114 Self {
115 m: None,
116 ef_construction: None,
117 }
118 }
119
120 pub fn m(mut self, m: usize) -> Self {
122 self.m = Some(m);
123 self
124 }
125
126 pub fn ef_construction(mut self, ef: usize) -> Self {
128 self.ef_construction = Some(ef);
129 self
130 }
131
132 pub fn high_recall() -> Self {
134 Self {
135 m: Some(32),
136 ef_construction: Some(128),
137 }
138 }
139
140 pub fn fast_build() -> Self {
142 Self {
143 m: Some(8),
144 ef_construction: Some(32),
145 }
146 }
147}
148
149impl Default for HnswConfig {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
175pub struct VectorIndex {
176 pub name: String,
178 pub table: String,
180 pub column: String,
182 pub metric: DistanceMetric,
184 pub index_type: IndexType,
186 pub concurrent: bool,
188 pub if_not_exists: bool,
190}
191
192impl VectorIndex {
193 pub fn hnsw(
195 name: impl Into<String>,
196 table: impl Into<String>,
197 column: impl Into<String>,
198 ) -> VectorIndexBuilder {
199 VectorIndexBuilder {
200 name: name.into(),
201 table: table.into(),
202 column: column.into(),
203 metric: DistanceMetric::L2,
204 index_type: IndexType::Hnsw(HnswConfig::default()),
205 concurrent: false,
206 if_not_exists: false,
207 }
208 }
209
210 pub fn ivfflat(
212 name: impl Into<String>,
213 table: impl Into<String>,
214 column: impl Into<String>,
215 ) -> VectorIndexBuilder {
216 VectorIndexBuilder {
217 name: name.into(),
218 table: table.into(),
219 column: column.into(),
220 metric: DistanceMetric::L2,
221 index_type: IndexType::IvfFlat(IvfFlatConfig::default()),
222 concurrent: false,
223 if_not_exists: false,
224 }
225 }
226
227 pub fn to_create_sql(&self) -> String {
229 let concurrent = if self.concurrent { " CONCURRENTLY" } else { "" };
230 let if_not_exists = if self.if_not_exists {
231 " IF NOT EXISTS"
232 } else {
233 ""
234 };
235
236 let (method, with_clause) = match &self.index_type {
237 IndexType::IvfFlat(config) => {
238 let with = format!(" WITH (lists = {})", config.lists);
239 ("ivfflat", with)
240 }
241 IndexType::Hnsw(config) => {
242 let mut with_parts = Vec::new();
243 if let Some(m) = config.m {
244 with_parts.push(format!("m = {m}"));
245 }
246 if let Some(ef) = config.ef_construction {
247 with_parts.push(format!("ef_construction = {ef}"));
248 }
249 let with = if with_parts.is_empty() {
250 String::new()
251 } else {
252 format!(" WITH ({})", with_parts.join(", "))
253 };
254 ("hnsw", with)
255 }
256 };
257
258 format!(
259 "CREATE INDEX{}{} {} ON {} USING {} ({} {}){}",
260 concurrent,
261 if_not_exists,
262 self.name,
263 self.table,
264 method,
265 self.column,
266 self.metric.ops_class(),
267 with_clause
268 )
269 }
270
271 pub fn to_drop_sql(&self) -> String {
273 let concurrent = if self.concurrent { " CONCURRENTLY" } else { "" };
274 format!("DROP INDEX{} IF EXISTS {}", concurrent, self.name)
275 }
276
277 pub fn to_exists_sql(&self) -> String {
279 format!(
280 "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = '{}')",
281 self.name
282 )
283 }
284
285 pub fn to_size_sql(&self) -> String {
287 format!("SELECT pg_size_pretty(pg_relation_size('{}'))", self.name)
288 }
289}
290
291#[derive(Debug, Clone)]
293pub struct VectorIndexBuilder {
294 name: String,
295 table: String,
296 column: String,
297 metric: DistanceMetric,
298 index_type: IndexType,
299 concurrent: bool,
300 if_not_exists: bool,
301}
302
303impl VectorIndexBuilder {
304 pub fn metric(mut self, metric: DistanceMetric) -> Self {
306 self.metric = metric;
307 self
308 }
309
310 pub fn config(mut self, config: HnswConfig) -> Self {
312 self.index_type = IndexType::Hnsw(config);
313 self
314 }
315
316 pub fn ivfflat_config(mut self, config: IvfFlatConfig) -> Self {
318 self.index_type = IndexType::IvfFlat(config);
319 self
320 }
321
322 pub fn concurrent(mut self) -> Self {
324 self.concurrent = true;
325 self
326 }
327
328 pub fn if_not_exists(mut self) -> Self {
330 self.if_not_exists = true;
331 self
332 }
333
334 pub fn build(self) -> VectorResult<VectorIndex> {
340 if self.name.is_empty() {
341 return Err(VectorError::index("index name cannot be empty"));
342 }
343 if self.table.is_empty() {
344 return Err(VectorError::index("table name cannot be empty"));
345 }
346 if self.column.is_empty() {
347 return Err(VectorError::index("column name cannot be empty"));
348 }
349
350 Ok(VectorIndex {
351 name: self.name,
352 table: self.table,
353 column: self.column,
354 metric: self.metric,
355 index_type: self.index_type,
356 concurrent: self.concurrent,
357 if_not_exists: self.if_not_exists,
358 })
359 }
360}
361
362#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
364pub struct BinaryVectorIndex {
365 pub name: String,
367 pub table: String,
369 pub column: String,
371 pub metric: BinaryDistanceMetric,
373 pub hnsw_config: HnswConfig,
375 pub concurrent: bool,
377}
378
379impl BinaryVectorIndex {
380 #[allow(clippy::new_ret_no_self)]
382 pub fn new(
383 name: impl Into<String>,
384 table: impl Into<String>,
385 column: impl Into<String>,
386 ) -> BinaryVectorIndexBuilder {
387 BinaryVectorIndexBuilder {
388 name: name.into(),
389 table: table.into(),
390 column: column.into(),
391 metric: BinaryDistanceMetric::Hamming,
392 hnsw_config: HnswConfig::default(),
393 concurrent: false,
394 }
395 }
396
397 pub fn to_create_sql(&self) -> String {
399 let concurrent = if self.concurrent { " CONCURRENTLY" } else { "" };
400
401 let mut with_parts = Vec::new();
402 if let Some(m) = self.hnsw_config.m {
403 with_parts.push(format!("m = {m}"));
404 }
405 if let Some(ef) = self.hnsw_config.ef_construction {
406 with_parts.push(format!("ef_construction = {ef}"));
407 }
408 let with = if with_parts.is_empty() {
409 String::new()
410 } else {
411 format!(" WITH ({})", with_parts.join(", "))
412 };
413
414 format!(
415 "CREATE INDEX{} {} ON {} USING hnsw ({} {}){}",
416 concurrent,
417 self.name,
418 self.table,
419 self.column,
420 self.metric.ops_class(),
421 with
422 )
423 }
424}
425
426#[derive(Debug, Clone)]
428pub struct BinaryVectorIndexBuilder {
429 name: String,
430 table: String,
431 column: String,
432 metric: BinaryDistanceMetric,
433 hnsw_config: HnswConfig,
434 concurrent: bool,
435}
436
437impl BinaryVectorIndexBuilder {
438 pub fn metric(mut self, metric: BinaryDistanceMetric) -> Self {
440 self.metric = metric;
441 self
442 }
443
444 pub fn config(mut self, config: HnswConfig) -> Self {
446 self.hnsw_config = config;
447 self
448 }
449
450 pub fn concurrent(mut self) -> Self {
452 self.concurrent = true;
453 self
454 }
455
456 pub fn build(self) -> VectorResult<BinaryVectorIndex> {
458 if self.name.is_empty() {
459 return Err(VectorError::index("index name cannot be empty"));
460 }
461 Ok(BinaryVectorIndex {
462 name: self.name,
463 table: self.table,
464 column: self.column,
465 metric: self.metric,
466 hnsw_config: self.hnsw_config,
467 concurrent: self.concurrent,
468 })
469 }
470}
471
/// SQL snippets for managing the `vector` extension and adding vector-typed columns.
///
/// Every identifier passed in (schema, table, column) is inserted into the SQL verbatim,
/// without quoting or escaping.
pub mod extension {
    /// Statement that installs the `vector` extension unless it already exists.
    pub fn create_extension_sql() -> &'static str {
        "CREATE EXTENSION IF NOT EXISTS vector"
    }

    /// Like [`create_extension_sql`], with a trailing `SCHEMA <schema>` clause.
    pub fn create_extension_in_schema_sql(schema: &str) -> String {
        let mut sql = String::from("CREATE EXTENSION IF NOT EXISTS vector SCHEMA ");
        sql.push_str(schema);
        sql
    }

    /// Statement that drops the `vector` extension if it exists.
    pub fn drop_extension_sql() -> &'static str {
        "DROP EXTENSION IF EXISTS vector"
    }

    /// Query selecting whether `pg_extension` has a row named `vector`.
    pub fn check_extension_sql() -> &'static str {
        "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector')"
    }

    /// Query selecting `extversion` of the `vector` row in `pg_extension`.
    pub fn version_sql() -> &'static str {
        "SELECT extversion FROM pg_extension WHERE extname = 'vector'"
    }

    /// `ALTER TABLE ... ADD COLUMN ... vector(<dimensions>)`.
    pub fn add_vector_column_sql(table: &str, column: &str, dimensions: usize) -> String {
        add_column_sql(table, column, "vector", dimensions)
    }

    /// `ALTER TABLE ... ADD COLUMN ... halfvec(<dimensions>)`.
    pub fn add_halfvec_column_sql(table: &str, column: &str, dimensions: usize) -> String {
        add_column_sql(table, column, "halfvec", dimensions)
    }

    /// `ALTER TABLE ... ADD COLUMN ... sparsevec(<dimensions>)`.
    pub fn add_sparsevec_column_sql(table: &str, column: &str, dimensions: usize) -> String {
        add_column_sql(table, column, "sparsevec", dimensions)
    }

    /// `ALTER TABLE ... ADD COLUMN ... bit(<dimensions>)`.
    pub fn add_bit_column_sql(table: &str, column: &str, dimensions: usize) -> String {
        add_column_sql(table, column, "bit", dimensions)
    }

    /// Shared template behind the `add_*_column_sql` functions; only the type name varies.
    fn add_column_sql(table: &str, column: &str, type_name: &str, dimensions: usize) -> String {
        format!("ALTER TABLE {table} ADD COLUMN {column} {type_name}({dimensions})")
    }
}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `sql` contains every entry of `fragments`, naming the missing one on
    /// failure.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                sql.contains(*fragment),
                "expected `{fragment}` in `{sql}`"
            );
        }
    }

    /// The four metrics exercised by the per-metric tests.
    fn all_metrics() -> [DistanceMetric; 4] {
        [
            DistanceMetric::L2,
            DistanceMetric::InnerProduct,
            DistanceMetric::Cosine,
            DistanceMetric::L1,
        ]
    }

    /// An HNSW index on `docs.emb` named `idx_emb` with every builder default.
    fn default_hnsw_index() -> VectorIndex {
        VectorIndex::hnsw("idx_emb", "docs", "emb").build().unwrap()
    }

    #[test]
    fn test_hnsw_index_create_sql() {
        let config = HnswConfig::new().m(16).ef_construction(64);
        let index = VectorIndex::hnsw("idx_embedding", "documents", "embedding")
            .metric(DistanceMetric::Cosine)
            .config(config)
            .build()
            .unwrap();

        assert_contains_all(
            &index.to_create_sql(),
            &[
                "CREATE INDEX",
                "idx_embedding",
                "documents",
                "USING hnsw",
                "vector_cosine_ops",
                "m = 16",
                "ef_construction = 64",
            ],
        );
    }

    #[test]
    fn test_hnsw_index_default_config() {
        let sql = default_hnsw_index().to_create_sql();

        assert_contains_all(&sql, &["USING hnsw", "vector_l2_ops"]);
        // With no HNSW option set there must be no WITH clause at all.
        assert!(!sql.contains("WITH"));
    }

    #[test]
    fn test_ivfflat_index_create_sql() {
        let index = VectorIndex::ivfflat("idx_embedding", "documents", "embedding")
            .metric(DistanceMetric::L2)
            .ivfflat_config(IvfFlatConfig::new(200))
            .build()
            .unwrap();

        assert_contains_all(
            &index.to_create_sql(),
            &["USING ivfflat", "vector_l2_ops", "lists = 200"],
        );
    }

    #[test]
    fn test_ivfflat_for_row_count() {
        // rows / 1000 branch.
        assert_eq!(IvfFlatConfig::for_row_count(500_000).lists, 500);

        // Square-root branch: sqrt(5_000_000) truncates to 2236.
        assert_eq!(IvfFlatConfig::for_row_count(5_000_000).lists, 2236);
    }

    #[test]
    fn test_concurrent_index() {
        let index = VectorIndex::hnsw("idx_emb", "docs", "emb")
            .concurrent()
            .if_not_exists()
            .build()
            .unwrap();

        assert_contains_all(&index.to_create_sql(), &["CONCURRENTLY", "IF NOT EXISTS"]);
    }

    #[test]
    fn test_drop_index() {
        assert_eq!(
            default_hnsw_index().to_drop_sql(),
            "DROP INDEX IF EXISTS idx_emb"
        );
    }

    #[test]
    fn test_concurrent_drop_index() {
        let index = VectorIndex::hnsw("idx_emb", "docs", "emb")
            .concurrent()
            .build()
            .unwrap();

        assert!(index.to_drop_sql().contains("CONCURRENTLY"));
    }

    #[test]
    fn test_index_exists_sql() {
        assert_contains_all(
            &default_hnsw_index().to_exists_sql(),
            &["pg_indexes", "idx_emb"],
        );
    }

    #[test]
    fn test_index_size_sql() {
        assert_contains_all(
            &default_hnsw_index().to_size_sql(),
            &["pg_size_pretty", "idx_emb"],
        );
    }

    #[test]
    fn test_empty_name_error() {
        assert!(VectorIndex::hnsw("", "docs", "emb").build().is_err());
    }

    #[test]
    fn test_hnsw_high_recall() {
        let HnswConfig { m, ef_construction } = HnswConfig::high_recall();
        assert_eq!(m, Some(32));
        assert_eq!(ef_construction, Some(128));
    }

    #[test]
    fn test_hnsw_fast_build() {
        let HnswConfig { m, ef_construction } = HnswConfig::fast_build();
        assert_eq!(m, Some(8));
        assert_eq!(ef_construction, Some(32));
    }

    #[test]
    fn test_binary_vector_index() {
        let index = BinaryVectorIndex::new("idx_bits", "docs", "binary_emb")
            .metric(BinaryDistanceMetric::Hamming)
            .build()
            .unwrap();

        assert_contains_all(&index.to_create_sql(), &["USING hnsw", "bit_hamming_ops"]);
    }

    #[test]
    fn test_extension_create_sql() {
        assert_eq!(
            extension::create_extension_sql(),
            "CREATE EXTENSION IF NOT EXISTS vector"
        );
    }

    #[test]
    fn test_extension_in_schema() {
        assert!(extension::create_extension_in_schema_sql("public").contains("SCHEMA public"));
    }

    #[test]
    fn test_add_vector_column() {
        assert_eq!(
            extension::add_vector_column_sql("documents", "embedding", 1536),
            "ALTER TABLE documents ADD COLUMN embedding vector(1536)"
        );
    }

    #[test]
    fn test_add_sparsevec_column() {
        let sql = extension::add_sparsevec_column_sql("documents", "sparse_emb", 30000);
        assert!(sql.contains("sparsevec(30000)"));
    }

    #[test]
    fn test_add_bit_column() {
        let sql = extension::add_bit_column_sql("documents", "binary_emb", 1024);
        assert!(sql.contains("bit(1024)"));
    }

    #[test]
    fn test_check_extension_sql() {
        assert!(extension::check_extension_sql().contains("pg_extension"));
    }

    #[test]
    fn test_version_sql() {
        assert!(extension::version_sql().contains("extversion"));
    }

    #[test]
    fn test_index_type_display() {
        let ivf = IndexType::IvfFlat(IvfFlatConfig::default());
        assert_eq!(ivf.to_string(), "ivfflat");

        let hnsw = IndexType::Hnsw(HnswConfig::default());
        assert_eq!(hnsw.to_string(), "hnsw");
    }

    #[test]
    fn test_all_metrics_with_ivfflat() {
        for metric in all_metrics() {
            let sql = VectorIndex::ivfflat("idx", "t", "c")
                .metric(metric)
                .build()
                .unwrap()
                .to_create_sql();
            assert!(sql.contains(metric.ops_class()));
        }
    }

    #[test]
    fn test_all_metrics_with_hnsw() {
        for metric in all_metrics() {
            let sql = VectorIndex::hnsw("idx", "t", "c")
                .metric(metric)
                .build()
                .unwrap()
                .to_create_sql();
            assert!(sql.contains(metric.ops_class()));
        }
    }
}