Skip to main content

daimon_plugin_pgvector/
builder.rs

1//! Builder for [`PgVectorStore`].
2
3use deadpool_postgres::{Config, Pool, Runtime};
4use daimon_core::{DaimonError, Result};
5use tokio_postgres::NoTls;
6
7use crate::migrations;
8use crate::store::PgVectorStore;
9use crate::DistanceMetric;
10
11/// Builds a [`PgVectorStore`] with connection pooling and optional auto-migration.
12///
13/// # Example
14///
15/// ```ignore
16/// use daimon_plugin_pgvector::{PgVectorStoreBuilder, DistanceMetric};
17///
18/// let store = PgVectorStoreBuilder::new("host=localhost dbname=mydb", 1536)
19///     .table("embeddings")
20///     .distance_metric(DistanceMetric::Cosine)
21///     .hnsw_m(16)
22///     .hnsw_ef_construction(64)
23///     .auto_migrate(true)
24///     .build()
25///     .await?;
26/// ```
27pub struct PgVectorStoreBuilder {
28    connection_string: String,
29    dimensions: usize,
30    table: String,
31    distance_metric: DistanceMetric,
32    auto_migrate: bool,
33    hnsw_m: Option<usize>,
34    hnsw_ef_construction: Option<usize>,
35    pool_size: usize,
36}
37
38impl PgVectorStoreBuilder {
39    /// Creates a new builder.
40    ///
41    /// - `connection_string`: PostgreSQL connection string
42    ///   (e.g. `"host=localhost dbname=mydb user=postgres"` or
43    ///   `"postgresql://user:pass@host/db"`)
44    /// - `dimensions`: the fixed vector dimension count (must match your embedding model)
45    pub fn new(connection_string: impl Into<String>, dimensions: usize) -> Self {
46        Self {
47            connection_string: connection_string.into(),
48            dimensions,
49            table: "daimon_vectors".into(),
50            distance_metric: DistanceMetric::Cosine,
51            auto_migrate: true,
52            hnsw_m: None,
53            hnsw_ef_construction: None,
54            pool_size: 16,
55        }
56    }
57
58    /// Sets the table name. Default: `"daimon_vectors"`.
59    pub fn table(mut self, table: impl Into<String>) -> Self {
60        self.table = table.into();
61        self
62    }
63
64    /// Sets the distance metric. Default: [`DistanceMetric::Cosine`].
65    pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
66        self.distance_metric = metric;
67        self
68    }
69
70    /// Enables or disables automatic schema creation on first connection.
71    /// Default: `true`.
72    ///
73    /// When disabled, use the SQL from [`crate::migrations`] to set up
74    /// the schema manually.
75    pub fn auto_migrate(mut self, enabled: bool) -> Self {
76        self.auto_migrate = enabled;
77        self
78    }
79
80    /// Sets the HNSW `m` parameter (max connections per layer).
81    /// `None` uses the PostgreSQL default (16).
82    pub fn hnsw_m(mut self, m: usize) -> Self {
83        self.hnsw_m = Some(m);
84        self
85    }
86
87    /// Sets the HNSW `ef_construction` parameter (build-time search width).
88    /// `None` uses the PostgreSQL default (64).
89    pub fn hnsw_ef_construction(mut self, ef: usize) -> Self {
90        self.hnsw_ef_construction = Some(ef);
91        self
92    }
93
94    /// Sets the maximum number of connections in the pool. Default: `16`.
95    pub fn pool_size(mut self, size: usize) -> Self {
96        self.pool_size = size;
97        self
98    }
99
100    /// Builds the [`PgVectorStore`], optionally running migrations.
101    pub async fn build(self) -> Result<PgVectorStore> {
102        let pool = self.create_pool()?;
103
104        if self.auto_migrate {
105            self.run_migrations(&pool).await?;
106        }
107
108        Ok(PgVectorStore {
109            pool,
110            table: self.table,
111            dimensions: self.dimensions,
112            distance_metric: self.distance_metric,
113        })
114    }
115
116    fn create_pool(&self) -> Result<Pool> {
117        let mut cfg = Config::new();
118        cfg.url = Some(self.connection_string.clone());
119        cfg.pool = Some(deadpool_postgres::PoolConfig {
120            max_size: self.pool_size,
121            ..Default::default()
122        });
123
124        cfg.create_pool(Some(Runtime::Tokio1), NoTls)
125            .map_err(|e| DaimonError::Other(format!("pgvector pool creation error: {e}")))
126    }
127
128    async fn run_migrations(&self, pool: &Pool) -> Result<()> {
129        let client = pool.get().await.map_err(|e| {
130            DaimonError::Other(format!("pgvector migration pool error: {e}"))
131        })?;
132
133        tracing::info!("pgvector: creating extension and table '{}'", self.table);
134
135        client
136            .execute(migrations::CREATE_EXTENSION, &[])
137            .await
138            .map_err(|e| DaimonError::Other(format!("pgvector CREATE EXTENSION error: {e}")))?;
139
140        let create_table = migrations::create_table_sql(&self.table, self.dimensions);
141        client
142            .execute(&create_table as &str, &[])
143            .await
144            .map_err(|e| DaimonError::Other(format!("pgvector CREATE TABLE error: {e}")))?;
145
146        let ops_class = self.distance_metric.ops_class();
147        let create_index = migrations::create_hnsw_index_sql(
148            &self.table,
149            ops_class,
150            self.hnsw_m,
151            self.hnsw_ef_construction,
152        );
153        client
154            .execute(&create_index as &str, &[])
155            .await
156            .map_err(|e| DaimonError::Other(format!("pgvector CREATE INDEX error: {e}")))?;
157
158        tracing::info!("pgvector: migration complete for '{}'", self.table);
159        Ok(())
160    }
161}