lumosai_vector_postgres/
storage.rs

1//! PostgreSQL vector storage implementation
2
3use std::collections::HashMap;
4use async_trait::async_trait;
5use sqlx::{PgPool, Row, postgres::PgPoolOptions};
6use serde_json::Value as JsonValue;
7use tracing::{debug, instrument, warn};
8
9use lumosai_vector_core::prelude::*;
10use crate::{PostgresConfig, PostgresError, PostgresResult};
11
12/// PostgreSQL vector storage implementation using pgvector
13pub struct PostgresVectorStorage {
14    pool: PgPool,
15    config: PostgresConfig,
16}
17
18impl PostgresVectorStorage {
19    /// Create a new PostgreSQL vector storage instance
20    pub async fn new(database_url: &str) -> Result<Self> {
21        let config = PostgresConfig::new(database_url);
22        Self::with_config(config).await
23    }
24    
25    /// Create a new PostgreSQL vector storage instance with configuration
26    pub async fn with_config(config: PostgresConfig) -> Result<Self> {
27        let pool = PgPoolOptions::new()
28            .max_connections(config.pool.max_connections)
29            .min_connections(config.pool.min_connections)
30            .acquire_timeout(config.pool.connect_timeout)
31            .idle_timeout(config.pool.idle_timeout)
32            .max_lifetime(config.pool.max_lifetime)
33            .connect(&config.database_url)
34            .await
35            .map_err(PostgresError::from)?;
36        
37        let storage = Self { pool, config };
38        
39        // Check pgvector extension
40        storage.ensure_pgvector_extension().await?;
41        
42        Ok(storage)
43    }
44    
45    /// Ensure pgvector extension is installed
46    async fn ensure_pgvector_extension(&self) -> PostgresResult<()> {
47        let result = sqlx::query("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
48            .fetch_optional(&self.pool)
49            .await?;
50        
51        if result.is_none() {
52            return Err(crate::error::pgvector_extension_error());
53        }
54        
55        Ok(())
56    }
57    
58    /// Create table for an index if it doesn't exist
59    async fn ensure_table(&self, index_name: &str, dimension: usize) -> PostgresResult<()> {
60        let table_name = self.config.table_name(index_name);
61        
62        let create_table_sql = format!(
63            r#"
64            CREATE TABLE IF NOT EXISTS {} (
65                id TEXT PRIMARY KEY,
66                content TEXT,
67                embedding vector({}),
68                metadata JSONB DEFAULT '{{}}',
69                created_at TIMESTAMPTZ DEFAULT NOW(),
70                updated_at TIMESTAMPTZ DEFAULT NOW()
71            )
72            "#,
73            table_name, dimension
74        );
75        
76        sqlx::query(&create_table_sql)
77            .execute(&self.pool)
78            .await?;
79        
80        // Create updated_at trigger
81        let trigger_sql = format!(
82            r#"
83            CREATE OR REPLACE FUNCTION update_updated_at_column()
84            RETURNS TRIGGER AS $$
85            BEGIN
86                NEW.updated_at = NOW();
87                RETURN NEW;
88            END;
89            $$ language 'plpgsql';
90            
91            DROP TRIGGER IF EXISTS update_{}_updated_at ON {};
92            CREATE TRIGGER update_{}_updated_at
93                BEFORE UPDATE ON {}
94                FOR EACH ROW
95                EXECUTE FUNCTION update_updated_at_column();
96            "#,
97            index_name, table_name, index_name, table_name
98        );
99        
100        sqlx::query(&trigger_sql)
101            .execute(&self.pool)
102            .await?;
103        
104        debug!("Ensured table exists: {}", table_name);
105        Ok(())
106    }
107    
108    /// Create vector index if configured
109    async fn ensure_vector_index(&self, index_name: &str) -> PostgresResult<()> {
110        if !self.config.table.auto_create_indexes {
111            return Ok(());
112        }
113        
114        let table_name = self.config.table_name(index_name);
115        let idx_name = self.config.index_name(index_name, "embedding");
116        
117        // Check if index already exists
118        let exists = sqlx::query(
119            "SELECT 1 FROM pg_indexes WHERE tablename = $1 AND indexname = $2"
120        )
121        .bind(format!("{}{}", self.config.table.table_prefix.as_deref().unwrap_or(""), index_name))
122        .bind(&idx_name)
123        .fetch_optional(&self.pool)
124        .await?;
125        
126        if exists.is_some() {
127            return Ok(());
128        }
129        
130        let index_sql = self.config.performance.index_type
131            .create_index_sql(&table_name, &idx_name, &self.config.performance.index_params);
132        
133        if !index_sql.is_empty() {
134            sqlx::query(&index_sql)
135                .execute(&self.pool)
136                .await
137                .map_err(|e| crate::error::index_creation_error(&idx_name, &e.to_string()))?;
138            
139            debug!("Created vector index: {}", idx_name);
140        }
141        
142        Ok(())
143    }
144    
145    /// Convert similarity metric to PostgreSQL operator
146    fn similarity_operator(metric: SimilarityMetric) -> &'static str {
147        match metric {
148            SimilarityMetric::Cosine => "<=>",
149            SimilarityMetric::Euclidean => "<->",
150            SimilarityMetric::DotProduct => "<#>",
151            _ => "<=>", // Default to cosine
152        }
153    }
154    
155    /// Convert metadata to JSONB
156    fn metadata_to_jsonb(metadata: &Metadata) -> PostgresResult<JsonValue> {
157        let mut json_map = serde_json::Map::new();
158        
159        for (key, value) in metadata {
160            let json_value = match value {
161                MetadataValue::String(s) => JsonValue::String(s.clone()),
162                MetadataValue::Integer(i) => JsonValue::Number((*i).into()),
163                MetadataValue::Float(f) => {
164                    JsonValue::Number(serde_json::Number::from_f64(*f).unwrap_or_else(|| 0.into()))
165                },
166                MetadataValue::Boolean(b) => JsonValue::Bool(*b),
167                MetadataValue::Array(arr) => {
168                    let json_arr: std::result::Result<Vec<_>, PostgresError> = arr.iter()
169                        .map(|v| Self::metadata_value_to_json(v))
170                        .collect();
171                    JsonValue::Array(json_arr?)
172                },
173                MetadataValue::Object(obj) => {
174                    let mut json_obj = serde_json::Map::new();
175                    for (k, v) in obj {
176                        json_obj.insert(k.clone(), Self::metadata_value_to_json(v)?);
177                    }
178                    JsonValue::Object(json_obj)
179                },
180                MetadataValue::Null => JsonValue::Null,
181            };
182            json_map.insert(key.clone(), json_value);
183        }
184        
185        Ok(JsonValue::Object(json_map))
186    }
187    
188    /// Convert single metadata value to JSON
189    fn metadata_value_to_json(value: &MetadataValue) -> PostgresResult<JsonValue> {
190        match value {
191            MetadataValue::String(s) => Ok(JsonValue::String(s.clone())),
192            MetadataValue::Integer(i) => Ok(JsonValue::Number((*i).into())),
193            MetadataValue::Float(f) => {
194                Ok(JsonValue::Number(serde_json::Number::from_f64(*f).unwrap_or_else(|| 0.into())))
195            },
196            MetadataValue::Boolean(b) => Ok(JsonValue::Bool(*b)),
197            MetadataValue::Array(arr) => {
198                let json_arr: std::result::Result<Vec<_>, PostgresError> = arr.iter()
199                    .map(Self::metadata_value_to_json)
200                    .collect();
201                Ok(JsonValue::Array(json_arr?))
202            },
203            MetadataValue::Object(obj) => {
204                let mut json_obj = serde_json::Map::new();
205                for (k, v) in obj {
206                    json_obj.insert(k.clone(), Self::metadata_value_to_json(v)?);
207                }
208                Ok(JsonValue::Object(json_obj))
209            },
210            MetadataValue::Null => Ok(JsonValue::Null),
211        }
212    }
213    
214    /// Convert JSONB to metadata
215    fn jsonb_to_metadata(json: JsonValue) -> Metadata {
216        match json {
217            JsonValue::Object(map) => {
218                map.into_iter()
219                    .filter_map(|(k, v)| {
220                        Self::json_value_to_metadata_value(v).map(|mv| (k, mv))
221                    })
222                    .collect()
223            },
224            _ => HashMap::new(),
225        }
226    }
227    
228    /// Convert JSON value to metadata value
229    fn json_value_to_metadata_value(value: JsonValue) -> Option<MetadataValue> {
230        match value {
231            JsonValue::String(s) => Some(MetadataValue::String(s)),
232            JsonValue::Number(n) => {
233                if let Some(i) = n.as_i64() {
234                    Some(MetadataValue::Integer(i))
235                } else if let Some(f) = n.as_f64() {
236                    Some(MetadataValue::Float(f))
237                } else {
238                    None
239                }
240            },
241            JsonValue::Bool(b) => Some(MetadataValue::Boolean(b)),
242            JsonValue::Array(arr) => {
243                let metadata_arr: Option<Vec<_>> = arr.into_iter()
244                    .map(Self::json_value_to_metadata_value)
245                    .collect();
246                metadata_arr.map(MetadataValue::Array)
247            },
248            JsonValue::Object(obj) => {
249                let metadata_obj: Option<HashMap<_, _>> = obj.into_iter()
250                    .map(|(k, v)| Self::json_value_to_metadata_value(v).map(|mv| (k, mv)))
251                    .collect();
252                metadata_obj.map(MetadataValue::Object)
253            },
254            JsonValue::Null => Some(MetadataValue::Null),
255        }
256    }
257    
258    /// Set search parameters for the current session
259    async fn set_search_params(&self) -> PostgresResult<()> {
260        let params = self.config.performance.index_type
261            .search_params_sql(&self.config.performance.index_params);
262
263        for param_sql in params {
264            sqlx::query(&param_sql)
265                .execute(&self.pool)
266                .await?;
267        }
268
269        Ok(())
270    }
271}
272
273#[async_trait]
274impl VectorStorage for PostgresVectorStorage {
275    type Config = PostgresConfig;
276
277    #[instrument(skip(self))]
278    async fn create_index(&self, config: IndexConfig) -> Result<()> {
279        self.ensure_table(&config.name, config.dimension).await?;
280        self.ensure_vector_index(&config.name).await?;
281
282        debug!("Created PostgreSQL index: {}", config.name);
283        Ok(())
284    }
285
286    #[instrument(skip(self))]
287    async fn list_indexes(&self) -> Result<Vec<String>> {
288        let prefix = self.config.table.table_prefix.as_deref().unwrap_or("");
289        let schema = &self.config.table.schema;
290
291        let query = format!(
292            "SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name LIKE $2"
293        );
294
295        let rows = sqlx::query(&query)
296            .bind(schema)
297            .bind(format!("{}%", prefix))
298            .fetch_all(&self.pool)
299            .await
300            .map_err(PostgresError::from)?;
301
302        let mut indexes = Vec::new();
303        for row in rows {
304            let table_name: String = row.try_get("table_name").map_err(PostgresError::from)?;
305            if let Some(stripped) = table_name.strip_prefix(prefix) {
306                indexes.push(stripped.to_string());
307            } else {
308                indexes.push(table_name);
309            }
310        }
311
312        Ok(indexes)
313    }
314
315    #[instrument(skip(self))]
316    async fn describe_index(&self, index_name: &str) -> Result<IndexInfo> {
317        let table_name = self.config.table_name(index_name);
318
319        // Get table info
320        let table_info = sqlx::query(
321            r#"
322            SELECT
323                column_name,
324                data_type,
325                character_maximum_length
326            FROM information_schema.columns
327            WHERE table_schema = $1 AND table_name = $2 AND column_name = 'embedding'
328            "#
329        )
330        .bind(&self.config.table.schema)
331        .bind(format!("{}{}", self.config.table.table_prefix.as_deref().unwrap_or(""), index_name))
332        .fetch_optional(&self.pool)
333        .await
334        .map_err(PostgresError::from)?;
335
336        let dimension = if let Some(row) = table_info {
337            // Extract dimension from vector type
338            let data_type: String = row.try_get("data_type").map_err(PostgresError::from)?;
339            if data_type.contains("vector") {
340                // Parse dimension from vector(n) format
341                384 // Default for now, would need to parse from type
342            } else {
343                return Err(VectorError::index_not_found(index_name));
344            }
345        } else {
346            return Err(VectorError::index_not_found(index_name));
347        };
348
349        // Get row count
350        let count_query = format!("SELECT COUNT(*) as count FROM {}", table_name);
351        let count_row = sqlx::query(&count_query)
352            .fetch_one(&self.pool)
353            .await
354            .map_err(PostgresError::from)?;
355        let vector_count: i64 = count_row.try_get("count").map_err(PostgresError::from)?;
356
357        let info = IndexInfo {
358            name: index_name.to_string(),
359            dimension,
360            metric: SimilarityMetric::Cosine, // Default, could be stored in metadata
361            vector_count: vector_count as usize,
362            size_bytes: 0, // Would need to calculate
363            created_at: None,
364            updated_at: None,
365            metadata: HashMap::new(),
366        };
367
368        Ok(info)
369    }
370
371    #[instrument(skip(self))]
372    async fn delete_index(&self, index_name: &str) -> Result<()> {
373        let table_name = self.config.table_name(index_name);
374
375        let drop_sql = format!("DROP TABLE IF EXISTS {} CASCADE", table_name);
376        sqlx::query(&drop_sql)
377            .execute(&self.pool)
378            .await
379            .map_err(PostgresError::from)?;
380
381        debug!("Deleted PostgreSQL table: {}", table_name);
382        Ok(())
383    }
384
385    async fn upsert_documents(&self, index_name: &str, documents: Vec<Document>) -> Result<Vec<DocumentId>> {
386        let table_name = self.config.table_name(index_name);
387        let mut ids = Vec::new();
388
389        // Process in batches
390        for chunk in documents.chunks(self.config.performance.batch_size) {
391            let mut query_builder = sqlx::QueryBuilder::new(
392                format!("INSERT INTO {} (id, content, embedding, metadata) ", table_name)
393            );
394
395            query_builder.push_values(chunk, |mut b, doc| {
396                let embedding = doc.embedding.as_ref()
397                    .ok_or_else(|| VectorError::InvalidVector("Document must have embedding".to_string()))
398                    .unwrap();
399
400                let metadata_json = Self::metadata_to_jsonb(&doc.metadata).unwrap();
401
402                b.push_bind(&doc.id)
403                    .push_bind(&doc.content)
404                    .push_bind(embedding)
405                    .push_bind(metadata_json);
406
407                ids.push(doc.id.clone());
408            });
409
410            query_builder.push(" ON CONFLICT (id) DO UPDATE SET content = EXCLUDED.content, embedding = EXCLUDED.embedding, metadata = EXCLUDED.metadata, updated_at = NOW()");
411
412            let query = query_builder.build();
413            query.execute(&self.pool).await.map_err(PostgresError::from)?;
414        }
415
416        debug!("Upserted {} documents to table: {}", ids.len(), table_name);
417        Ok(ids)
418    }
419
420    #[instrument(skip(self, request))]
421    async fn search(&self, request: SearchRequest) -> Result<SearchResponse> {
422        let table_name = self.config.table_name(&request.index_name);
423
424        // Set search parameters
425        self.set_search_params().await?;
426
427        let query_vector = match &request.query {
428            SearchQuery::Vector(vec) => vec.clone(),
429            SearchQuery::Text(_) => {
430                return Err(VectorError::NotSupported("Text search not implemented for PostgreSQL backend".to_string()));
431            },
432        };
433
434        // Build the search query
435        let operator = Self::similarity_operator(SimilarityMetric::Cosine); // TODO: Get from index config
436        let mut query = format!(
437            "SELECT id, content, embedding, metadata, (embedding {} $1) as distance FROM {} ",
438            operator, table_name
439        );
440
441        let mut bind_index = 2;
442
443        // Add filter conditions if present
444        if let Some(_filter) = &request.filter {
445            // TODO: Implement filter conversion to SQL WHERE clause
446            warn!("Filters not yet implemented for PostgreSQL backend");
447        }
448
449        query.push_str(&format!(" ORDER BY distance LIMIT {}", request.top_k));
450
451        let rows = sqlx::query(&query)
452            .bind(&query_vector)
453            .fetch_all(&self.pool)
454            .await
455            .map_err(PostgresError::from)?;
456
457        let mut results = Vec::new();
458        for row in rows {
459            let id: String = row.try_get("id").map_err(PostgresError::from)?;
460            let content: String = row.try_get("content").map_err(PostgresError::from)?;
461            let distance: f32 = row.try_get("distance").map_err(PostgresError::from)?;
462            let metadata_json: JsonValue = row.try_get("metadata").map_err(PostgresError::from)?;
463
464            let embedding = if request.include_vectors {
465                let embedding_data: Vec<f32> = row.try_get("embedding").map_err(PostgresError::from)?;
466                Some(embedding_data)
467            } else {
468                None
469            };
470
471            let metadata = if request.include_metadata {
472                Self::jsonb_to_metadata(metadata_json)
473            } else {
474                HashMap::new()
475            };
476
477            let result = SearchResult {
478                id,
479                content: Some(content),
480                vector: embedding,
481                metadata: Some(metadata),
482                score: 1.0 - distance, // Convert distance to similarity score
483            };
484
485            results.push(result);
486        }
487
488        Ok(SearchResponse {
489            results,
490            total_count: None, // Could implement with separate count query
491            execution_time_ms: None,
492            metadata: HashMap::new(),
493        })
494    }
495
496    #[instrument(skip(self))]
497    async fn update_document(&self, index_name: &str, document: Document) -> Result<()> {
498        // For PostgreSQL, update is the same as upsert
499        self.upsert_documents(index_name, vec![document]).await?;
500        Ok(())
501    }
502
503    #[instrument(skip(self))]
504    async fn delete_documents(&self, index_name: &str, ids: Vec<DocumentId>) -> Result<()> {
505        let table_name = self.config.table_name(index_name);
506
507        if ids.is_empty() {
508            return Ok(());
509        }
510
511        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
512        let query = format!(
513            "DELETE FROM {} WHERE id IN ({})",
514            table_name,
515            placeholders.join(", ")
516        );
517
518        let mut sqlx_query = sqlx::query(&query);
519        for id in &ids {
520            sqlx_query = sqlx_query.bind(id);
521        }
522
523        let result = sqlx_query.execute(&self.pool).await.map_err(PostgresError::from)?;
524        let deleted_count = result.rows_affected() as usize;
525
526        debug!("Deleted {} documents from table: {}", deleted_count, table_name);
527        Ok(())
528    }
529
530    #[instrument(skip(self))]
531    async fn get_documents(&self, index_name: &str, ids: Vec<DocumentId>, include_vectors: bool) -> Result<Vec<Document>> {
532        let table_name = self.config.table_name(index_name);
533
534        if ids.is_empty() {
535            return Ok(vec![]);
536        }
537
538        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
539        let vector_select = if include_vectors { ", embedding" } else { "" };
540        let query = format!(
541            "SELECT id, content, metadata{} FROM {} WHERE id IN ({})",
542            vector_select,
543            table_name,
544            placeholders.join(", ")
545        );
546
547        let mut sqlx_query = sqlx::query(&query);
548        for id in &ids {
549            sqlx_query = sqlx_query.bind(id);
550        }
551
552        let rows = sqlx_query.fetch_all(&self.pool).await.map_err(PostgresError::from)?;
553
554        let mut documents = Vec::new();
555        for row in rows {
556            let id: String = row.try_get("id").map_err(PostgresError::from)?;
557            let content: String = row.try_get("content").map_err(PostgresError::from)?;
558            let metadata_json: JsonValue = row.try_get("metadata").map_err(PostgresError::from)?;
559
560            let embedding = if include_vectors {
561                let embedding_data: Vec<f32> = row.try_get("embedding").map_err(PostgresError::from)?;
562                Some(embedding_data)
563            } else {
564                None
565            };
566
567            let metadata = Self::jsonb_to_metadata(metadata_json);
568
569            let document = Document {
570                id,
571                content,
572                embedding,
573                metadata,
574            };
575
576            documents.push(document);
577        }
578
579        Ok(documents)
580    }
581
582    #[instrument(skip(self))]
583    async fn health_check(&self) -> Result<()> {
584        sqlx::query("SELECT 1")
585            .fetch_one(&self.pool)
586            .await
587            .map_err(PostgresError::from)?;
588
589        // Check pgvector extension
590        self.ensure_pgvector_extension().await?;
591
592        Ok(())
593    }
594
595    fn backend_info(&self) -> BackendInfo {
596        BackendInfo {
597            name: "PostgreSQL".to_string(),
598            version: "1.0.0".to_string(),
599            features: vec![
600                "persistent".to_string(),
601                "transactions".to_string(),
602                "sql_queries".to_string(),
603                "metadata_filtering".to_string(),
604                "vector_indexes".to_string(),
605            ],
606            metadata: HashMap::new(),
607        }
608    }
609}