1use std::collections::HashMap;
4use async_trait::async_trait;
5use sqlx::{PgPool, Row, postgres::PgPoolOptions};
6use serde_json::Value as JsonValue;
7use tracing::{debug, instrument, warn};
8
9use lumosai_vector_core::prelude::*;
10use crate::{PostgresConfig, PostgresError, PostgresResult};
11
12pub struct PostgresVectorStorage {
14 pool: PgPool,
15 config: PostgresConfig,
16}
17
18impl PostgresVectorStorage {
19 pub async fn new(database_url: &str) -> Result<Self> {
21 let config = PostgresConfig::new(database_url);
22 Self::with_config(config).await
23 }
24
25 pub async fn with_config(config: PostgresConfig) -> Result<Self> {
27 let pool = PgPoolOptions::new()
28 .max_connections(config.pool.max_connections)
29 .min_connections(config.pool.min_connections)
30 .acquire_timeout(config.pool.connect_timeout)
31 .idle_timeout(config.pool.idle_timeout)
32 .max_lifetime(config.pool.max_lifetime)
33 .connect(&config.database_url)
34 .await
35 .map_err(PostgresError::from)?;
36
37 let storage = Self { pool, config };
38
39 storage.ensure_pgvector_extension().await?;
41
42 Ok(storage)
43 }
44
45 async fn ensure_pgvector_extension(&self) -> PostgresResult<()> {
47 let result = sqlx::query("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
48 .fetch_optional(&self.pool)
49 .await?;
50
51 if result.is_none() {
52 return Err(crate::error::pgvector_extension_error());
53 }
54
55 Ok(())
56 }
57
58 async fn ensure_table(&self, index_name: &str, dimension: usize) -> PostgresResult<()> {
60 let table_name = self.config.table_name(index_name);
61
62 let create_table_sql = format!(
63 r#"
64 CREATE TABLE IF NOT EXISTS {} (
65 id TEXT PRIMARY KEY,
66 content TEXT,
67 embedding vector({}),
68 metadata JSONB DEFAULT '{{}}',
69 created_at TIMESTAMPTZ DEFAULT NOW(),
70 updated_at TIMESTAMPTZ DEFAULT NOW()
71 )
72 "#,
73 table_name, dimension
74 );
75
76 sqlx::query(&create_table_sql)
77 .execute(&self.pool)
78 .await?;
79
80 let trigger_sql = format!(
82 r#"
83 CREATE OR REPLACE FUNCTION update_updated_at_column()
84 RETURNS TRIGGER AS $$
85 BEGIN
86 NEW.updated_at = NOW();
87 RETURN NEW;
88 END;
89 $$ language 'plpgsql';
90
91 DROP TRIGGER IF EXISTS update_{}_updated_at ON {};
92 CREATE TRIGGER update_{}_updated_at
93 BEFORE UPDATE ON {}
94 FOR EACH ROW
95 EXECUTE FUNCTION update_updated_at_column();
96 "#,
97 index_name, table_name, index_name, table_name
98 );
99
100 sqlx::query(&trigger_sql)
101 .execute(&self.pool)
102 .await?;
103
104 debug!("Ensured table exists: {}", table_name);
105 Ok(())
106 }
107
108 async fn ensure_vector_index(&self, index_name: &str) -> PostgresResult<()> {
110 if !self.config.table.auto_create_indexes {
111 return Ok(());
112 }
113
114 let table_name = self.config.table_name(index_name);
115 let idx_name = self.config.index_name(index_name, "embedding");
116
117 let exists = sqlx::query(
119 "SELECT 1 FROM pg_indexes WHERE tablename = $1 AND indexname = $2"
120 )
121 .bind(format!("{}{}", self.config.table.table_prefix.as_deref().unwrap_or(""), index_name))
122 .bind(&idx_name)
123 .fetch_optional(&self.pool)
124 .await?;
125
126 if exists.is_some() {
127 return Ok(());
128 }
129
130 let index_sql = self.config.performance.index_type
131 .create_index_sql(&table_name, &idx_name, &self.config.performance.index_params);
132
133 if !index_sql.is_empty() {
134 sqlx::query(&index_sql)
135 .execute(&self.pool)
136 .await
137 .map_err(|e| crate::error::index_creation_error(&idx_name, &e.to_string()))?;
138
139 debug!("Created vector index: {}", idx_name);
140 }
141
142 Ok(())
143 }
144
145 fn similarity_operator(metric: SimilarityMetric) -> &'static str {
147 match metric {
148 SimilarityMetric::Cosine => "<=>",
149 SimilarityMetric::Euclidean => "<->",
150 SimilarityMetric::DotProduct => "<#>",
151 _ => "<=>", }
153 }
154
155 fn metadata_to_jsonb(metadata: &Metadata) -> PostgresResult<JsonValue> {
157 let mut json_map = serde_json::Map::new();
158
159 for (key, value) in metadata {
160 let json_value = match value {
161 MetadataValue::String(s) => JsonValue::String(s.clone()),
162 MetadataValue::Integer(i) => JsonValue::Number((*i).into()),
163 MetadataValue::Float(f) => {
164 JsonValue::Number(serde_json::Number::from_f64(*f).unwrap_or_else(|| 0.into()))
165 },
166 MetadataValue::Boolean(b) => JsonValue::Bool(*b),
167 MetadataValue::Array(arr) => {
168 let json_arr: std::result::Result<Vec<_>, PostgresError> = arr.iter()
169 .map(|v| Self::metadata_value_to_json(v))
170 .collect();
171 JsonValue::Array(json_arr?)
172 },
173 MetadataValue::Object(obj) => {
174 let mut json_obj = serde_json::Map::new();
175 for (k, v) in obj {
176 json_obj.insert(k.clone(), Self::metadata_value_to_json(v)?);
177 }
178 JsonValue::Object(json_obj)
179 },
180 MetadataValue::Null => JsonValue::Null,
181 };
182 json_map.insert(key.clone(), json_value);
183 }
184
185 Ok(JsonValue::Object(json_map))
186 }
187
188 fn metadata_value_to_json(value: &MetadataValue) -> PostgresResult<JsonValue> {
190 match value {
191 MetadataValue::String(s) => Ok(JsonValue::String(s.clone())),
192 MetadataValue::Integer(i) => Ok(JsonValue::Number((*i).into())),
193 MetadataValue::Float(f) => {
194 Ok(JsonValue::Number(serde_json::Number::from_f64(*f).unwrap_or_else(|| 0.into())))
195 },
196 MetadataValue::Boolean(b) => Ok(JsonValue::Bool(*b)),
197 MetadataValue::Array(arr) => {
198 let json_arr: std::result::Result<Vec<_>, PostgresError> = arr.iter()
199 .map(Self::metadata_value_to_json)
200 .collect();
201 Ok(JsonValue::Array(json_arr?))
202 },
203 MetadataValue::Object(obj) => {
204 let mut json_obj = serde_json::Map::new();
205 for (k, v) in obj {
206 json_obj.insert(k.clone(), Self::metadata_value_to_json(v)?);
207 }
208 Ok(JsonValue::Object(json_obj))
209 },
210 MetadataValue::Null => Ok(JsonValue::Null),
211 }
212 }
213
214 fn jsonb_to_metadata(json: JsonValue) -> Metadata {
216 match json {
217 JsonValue::Object(map) => {
218 map.into_iter()
219 .filter_map(|(k, v)| {
220 Self::json_value_to_metadata_value(v).map(|mv| (k, mv))
221 })
222 .collect()
223 },
224 _ => HashMap::new(),
225 }
226 }
227
228 fn json_value_to_metadata_value(value: JsonValue) -> Option<MetadataValue> {
230 match value {
231 JsonValue::String(s) => Some(MetadataValue::String(s)),
232 JsonValue::Number(n) => {
233 if let Some(i) = n.as_i64() {
234 Some(MetadataValue::Integer(i))
235 } else if let Some(f) = n.as_f64() {
236 Some(MetadataValue::Float(f))
237 } else {
238 None
239 }
240 },
241 JsonValue::Bool(b) => Some(MetadataValue::Boolean(b)),
242 JsonValue::Array(arr) => {
243 let metadata_arr: Option<Vec<_>> = arr.into_iter()
244 .map(Self::json_value_to_metadata_value)
245 .collect();
246 metadata_arr.map(MetadataValue::Array)
247 },
248 JsonValue::Object(obj) => {
249 let metadata_obj: Option<HashMap<_, _>> = obj.into_iter()
250 .map(|(k, v)| Self::json_value_to_metadata_value(v).map(|mv| (k, mv)))
251 .collect();
252 metadata_obj.map(MetadataValue::Object)
253 },
254 JsonValue::Null => Some(MetadataValue::Null),
255 }
256 }
257
258 async fn set_search_params(&self) -> PostgresResult<()> {
260 let params = self.config.performance.index_type
261 .search_params_sql(&self.config.performance.index_params);
262
263 for param_sql in params {
264 sqlx::query(¶m_sql)
265 .execute(&self.pool)
266 .await?;
267 }
268
269 Ok(())
270 }
271}
272
273#[async_trait]
274impl VectorStorage for PostgresVectorStorage {
275 type Config = PostgresConfig;
276
277 #[instrument(skip(self))]
278 async fn create_index(&self, config: IndexConfig) -> Result<()> {
279 self.ensure_table(&config.name, config.dimension).await?;
280 self.ensure_vector_index(&config.name).await?;
281
282 debug!("Created PostgreSQL index: {}", config.name);
283 Ok(())
284 }
285
286 #[instrument(skip(self))]
287 async fn list_indexes(&self) -> Result<Vec<String>> {
288 let prefix = self.config.table.table_prefix.as_deref().unwrap_or("");
289 let schema = &self.config.table.schema;
290
291 let query = format!(
292 "SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name LIKE $2"
293 );
294
295 let rows = sqlx::query(&query)
296 .bind(schema)
297 .bind(format!("{}%", prefix))
298 .fetch_all(&self.pool)
299 .await
300 .map_err(PostgresError::from)?;
301
302 let mut indexes = Vec::new();
303 for row in rows {
304 let table_name: String = row.try_get("table_name").map_err(PostgresError::from)?;
305 if let Some(stripped) = table_name.strip_prefix(prefix) {
306 indexes.push(stripped.to_string());
307 } else {
308 indexes.push(table_name);
309 }
310 }
311
312 Ok(indexes)
313 }
314
315 #[instrument(skip(self))]
316 async fn describe_index(&self, index_name: &str) -> Result<IndexInfo> {
317 let table_name = self.config.table_name(index_name);
318
319 let table_info = sqlx::query(
321 r#"
322 SELECT
323 column_name,
324 data_type,
325 character_maximum_length
326 FROM information_schema.columns
327 WHERE table_schema = $1 AND table_name = $2 AND column_name = 'embedding'
328 "#
329 )
330 .bind(&self.config.table.schema)
331 .bind(format!("{}{}", self.config.table.table_prefix.as_deref().unwrap_or(""), index_name))
332 .fetch_optional(&self.pool)
333 .await
334 .map_err(PostgresError::from)?;
335
336 let dimension = if let Some(row) = table_info {
337 let data_type: String = row.try_get("data_type").map_err(PostgresError::from)?;
339 if data_type.contains("vector") {
340 384 } else {
343 return Err(VectorError::index_not_found(index_name));
344 }
345 } else {
346 return Err(VectorError::index_not_found(index_name));
347 };
348
349 let count_query = format!("SELECT COUNT(*) as count FROM {}", table_name);
351 let count_row = sqlx::query(&count_query)
352 .fetch_one(&self.pool)
353 .await
354 .map_err(PostgresError::from)?;
355 let vector_count: i64 = count_row.try_get("count").map_err(PostgresError::from)?;
356
357 let info = IndexInfo {
358 name: index_name.to_string(),
359 dimension,
360 metric: SimilarityMetric::Cosine, vector_count: vector_count as usize,
362 size_bytes: 0, created_at: None,
364 updated_at: None,
365 metadata: HashMap::new(),
366 };
367
368 Ok(info)
369 }
370
371 #[instrument(skip(self))]
372 async fn delete_index(&self, index_name: &str) -> Result<()> {
373 let table_name = self.config.table_name(index_name);
374
375 let drop_sql = format!("DROP TABLE IF EXISTS {} CASCADE", table_name);
376 sqlx::query(&drop_sql)
377 .execute(&self.pool)
378 .await
379 .map_err(PostgresError::from)?;
380
381 debug!("Deleted PostgreSQL table: {}", table_name);
382 Ok(())
383 }
384
385 async fn upsert_documents(&self, index_name: &str, documents: Vec<Document>) -> Result<Vec<DocumentId>> {
386 let table_name = self.config.table_name(index_name);
387 let mut ids = Vec::new();
388
389 for chunk in documents.chunks(self.config.performance.batch_size) {
391 let mut query_builder = sqlx::QueryBuilder::new(
392 format!("INSERT INTO {} (id, content, embedding, metadata) ", table_name)
393 );
394
395 query_builder.push_values(chunk, |mut b, doc| {
396 let embedding = doc.embedding.as_ref()
397 .ok_or_else(|| VectorError::InvalidVector("Document must have embedding".to_string()))
398 .unwrap();
399
400 let metadata_json = Self::metadata_to_jsonb(&doc.metadata).unwrap();
401
402 b.push_bind(&doc.id)
403 .push_bind(&doc.content)
404 .push_bind(embedding)
405 .push_bind(metadata_json);
406
407 ids.push(doc.id.clone());
408 });
409
410 query_builder.push(" ON CONFLICT (id) DO UPDATE SET content = EXCLUDED.content, embedding = EXCLUDED.embedding, metadata = EXCLUDED.metadata, updated_at = NOW()");
411
412 let query = query_builder.build();
413 query.execute(&self.pool).await.map_err(PostgresError::from)?;
414 }
415
416 debug!("Upserted {} documents to table: {}", ids.len(), table_name);
417 Ok(ids)
418 }
419
420 #[instrument(skip(self, request))]
421 async fn search(&self, request: SearchRequest) -> Result<SearchResponse> {
422 let table_name = self.config.table_name(&request.index_name);
423
424 self.set_search_params().await?;
426
427 let query_vector = match &request.query {
428 SearchQuery::Vector(vec) => vec.clone(),
429 SearchQuery::Text(_) => {
430 return Err(VectorError::NotSupported("Text search not implemented for PostgreSQL backend".to_string()));
431 },
432 };
433
434 let operator = Self::similarity_operator(SimilarityMetric::Cosine); let mut query = format!(
437 "SELECT id, content, embedding, metadata, (embedding {} $1) as distance FROM {} ",
438 operator, table_name
439 );
440
441 let mut bind_index = 2;
442
443 if let Some(_filter) = &request.filter {
445 warn!("Filters not yet implemented for PostgreSQL backend");
447 }
448
449 query.push_str(&format!(" ORDER BY distance LIMIT {}", request.top_k));
450
451 let rows = sqlx::query(&query)
452 .bind(&query_vector)
453 .fetch_all(&self.pool)
454 .await
455 .map_err(PostgresError::from)?;
456
457 let mut results = Vec::new();
458 for row in rows {
459 let id: String = row.try_get("id").map_err(PostgresError::from)?;
460 let content: String = row.try_get("content").map_err(PostgresError::from)?;
461 let distance: f32 = row.try_get("distance").map_err(PostgresError::from)?;
462 let metadata_json: JsonValue = row.try_get("metadata").map_err(PostgresError::from)?;
463
464 let embedding = if request.include_vectors {
465 let embedding_data: Vec<f32> = row.try_get("embedding").map_err(PostgresError::from)?;
466 Some(embedding_data)
467 } else {
468 None
469 };
470
471 let metadata = if request.include_metadata {
472 Self::jsonb_to_metadata(metadata_json)
473 } else {
474 HashMap::new()
475 };
476
477 let result = SearchResult {
478 id,
479 content: Some(content),
480 vector: embedding,
481 metadata: Some(metadata),
482 score: 1.0 - distance, };
484
485 results.push(result);
486 }
487
488 Ok(SearchResponse {
489 results,
490 total_count: None, execution_time_ms: None,
492 metadata: HashMap::new(),
493 })
494 }
495
496 #[instrument(skip(self))]
497 async fn update_document(&self, index_name: &str, document: Document) -> Result<()> {
498 self.upsert_documents(index_name, vec![document]).await?;
500 Ok(())
501 }
502
503 #[instrument(skip(self))]
504 async fn delete_documents(&self, index_name: &str, ids: Vec<DocumentId>) -> Result<()> {
505 let table_name = self.config.table_name(index_name);
506
507 if ids.is_empty() {
508 return Ok(());
509 }
510
511 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
512 let query = format!(
513 "DELETE FROM {} WHERE id IN ({})",
514 table_name,
515 placeholders.join(", ")
516 );
517
518 let mut sqlx_query = sqlx::query(&query);
519 for id in &ids {
520 sqlx_query = sqlx_query.bind(id);
521 }
522
523 let result = sqlx_query.execute(&self.pool).await.map_err(PostgresError::from)?;
524 let deleted_count = result.rows_affected() as usize;
525
526 debug!("Deleted {} documents from table: {}", deleted_count, table_name);
527 Ok(())
528 }
529
530 #[instrument(skip(self))]
531 async fn get_documents(&self, index_name: &str, ids: Vec<DocumentId>, include_vectors: bool) -> Result<Vec<Document>> {
532 let table_name = self.config.table_name(index_name);
533
534 if ids.is_empty() {
535 return Ok(vec![]);
536 }
537
538 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
539 let vector_select = if include_vectors { ", embedding" } else { "" };
540 let query = format!(
541 "SELECT id, content, metadata{} FROM {} WHERE id IN ({})",
542 vector_select,
543 table_name,
544 placeholders.join(", ")
545 );
546
547 let mut sqlx_query = sqlx::query(&query);
548 for id in &ids {
549 sqlx_query = sqlx_query.bind(id);
550 }
551
552 let rows = sqlx_query.fetch_all(&self.pool).await.map_err(PostgresError::from)?;
553
554 let mut documents = Vec::new();
555 for row in rows {
556 let id: String = row.try_get("id").map_err(PostgresError::from)?;
557 let content: String = row.try_get("content").map_err(PostgresError::from)?;
558 let metadata_json: JsonValue = row.try_get("metadata").map_err(PostgresError::from)?;
559
560 let embedding = if include_vectors {
561 let embedding_data: Vec<f32> = row.try_get("embedding").map_err(PostgresError::from)?;
562 Some(embedding_data)
563 } else {
564 None
565 };
566
567 let metadata = Self::jsonb_to_metadata(metadata_json);
568
569 let document = Document {
570 id,
571 content,
572 embedding,
573 metadata,
574 };
575
576 documents.push(document);
577 }
578
579 Ok(documents)
580 }
581
582 #[instrument(skip(self))]
583 async fn health_check(&self) -> Result<()> {
584 sqlx::query("SELECT 1")
585 .fetch_one(&self.pool)
586 .await
587 .map_err(PostgresError::from)?;
588
589 self.ensure_pgvector_extension().await?;
591
592 Ok(())
593 }
594
595 fn backend_info(&self) -> BackendInfo {
596 BackendInfo {
597 name: "PostgreSQL".to_string(),
598 version: "1.0.0".to_string(),
599 features: vec![
600 "persistent".to_string(),
601 "transactions".to_string(),
602 "sql_queries".to_string(),
603 "metadata_filtering".to_string(),
604 "vector_indexes".to_string(),
605 ],
606 metadata: HashMap::new(),
607 }
608 }
609}