
ceres_db/repository.rs

//! Dataset repository for PostgreSQL with pgvector support.
//!
//! # Testing
//!
//! TODO(#12): Improve test coverage for repository methods
//! Current tests cover only struct construction and serialization. Integration tests are needed for:
//! - `upsert()` - insert and update paths
//! - `search()` - vector similarity queries
//! - `get_hashes_for_portal()` - delta detection queries
//! - `update_timestamp_only()` - timestamp-only updates
//!
//! Consider using testcontainers-rs for isolated PostgreSQL instances:
//! <https://github.com/testcontainers/testcontainers-rs>
//!
//! See: <https://github.com/AndreaBozzo/Ceres/issues/12>
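//!
//! A rough sketch of the shape such an integration test could take (illustrative only;
//! the `TEST_DATABASE_URL` variable, the `tokio` test attribute, and the field values are
//! assumptions, not existing project conventions):
//!
//! ```rust,ignore
//! #[tokio::test]
//! async fn upsert_then_get_roundtrip() {
//!     // Point this at a disposable database, e.g. one started via testcontainers-rs.
//!     let url = std::env::var("TEST_DATABASE_URL").expect("TEST_DATABASE_URL not set");
//!     let pool = sqlx::postgres::PgPoolOptions::new().connect(&url).await.unwrap();
//!     let repo = ceres_db::DatasetRepository::new(pool);
//!
//!     let title = "Integration test dataset";
//!     let new_dataset = ceres_core::models::NewDataset {
//!         original_id: "it-1".to_string(),
//!         source_portal: "https://example.com".to_string(),
//!         url: "https://example.com/dataset/it-1".to_string(),
//!         title: title.to_string(),
//!         description: None,
//!         embedding: None,
//!         metadata: serde_json::json!({}),
//!         content_hash: ceres_core::models::NewDataset::compute_content_hash(title, None),
//!     };
//!
//!     // Exercises the insert path; calling upsert() again would exercise the update path.
//!     let id = repo.upsert(&new_dataset).await.unwrap();
//!     assert!(repo.get(id).await.unwrap().is_some());
//! }
//! ```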

use ceres_core::error::AppError;
use ceres_core::models::{DatabaseStats, Dataset, NewDataset, SearchResult};
use chrono::{DateTime, Utc};
use futures::StreamExt;
use futures::stream::BoxStream;
use pgvector::Vector;
use sqlx::types::Json;
use sqlx::{PgPool, Pool, Postgres};
use std::collections::HashMap;
use uuid::Uuid;

/// Column list for SELECT queries. Must remain a const literal: queries are assembled with
/// format!(), which bypasses sqlx's compile-time validation, so no runtime value may ever
/// be interpolated into them.
const DATASET_COLUMNS: &str = "id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash";

// Static queries for list_all_stream to avoid lifetime issues with BoxStream
const LIST_ALL_QUERY: &str = "SELECT id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash FROM datasets ORDER BY last_updated_at DESC";
const LIST_ALL_LIMIT_QUERY: &str = "SELECT id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash FROM datasets ORDER BY last_updated_at DESC LIMIT $1";
const LIST_ALL_PORTAL_QUERY: &str = "SELECT id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash FROM datasets WHERE source_portal = $1 ORDER BY last_updated_at DESC";
const LIST_ALL_PORTAL_LIMIT_QUERY: &str = "SELECT id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash FROM datasets WHERE source_portal = $1 ORDER BY last_updated_at DESC LIMIT $2";

/// Repository for dataset persistence in PostgreSQL with pgvector.
///
/// # Examples
///
/// ```no_run
/// use sqlx::postgres::PgPoolOptions;
/// use ceres_db::DatasetRepository;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let pool = PgPoolOptions::new()
///     .max_connections(5)
///     .connect("postgresql://localhost/ceres")
///     .await?;
///
/// let repo = DatasetRepository::new(pool);
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct DatasetRepository {
    pool: Pool<Postgres>,
}

impl DatasetRepository {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Inserts or updates a dataset. Returns the UUID of the affected row.
    ///
    /// TODO(robustness): Return UpsertOutcome to distinguish insert vs update
    /// Currently this returns only the UUID, without indicating which operation ran.
    /// Consider: `pub enum UpsertOutcome { Created(Uuid), Updated(Uuid) }`
    /// This would enable accurate progress reporting in sync statistics.
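    ///
    /// # Examples
    ///
    /// A sketch of a typical call site (not executed; assumes an already connected pool,
    /// and the field values shown are placeholders):
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use ceres_core::models::NewDataset;
    /// # async fn example(repo: DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// let title = "Air quality measurements";
    /// let description = Some("Hourly PM2.5 readings".to_string());
    /// let new_dataset = NewDataset {
    ///     original_id: "air-quality-2024".to_string(),
    ///     source_portal: "https://data.example.org".to_string(),
    ///     url: "https://data.example.org/dataset/air-quality-2024".to_string(),
    ///     title: title.to_string(),
    ///     description: description.clone(),
    ///     embedding: None, // COALESCE in the query keeps any existing embedding
    ///     metadata: serde_json::json!({"tags": ["environment"]}),
    ///     content_hash: NewDataset::compute_content_hash(title, description.as_deref()),
    /// };
    /// let id = repo.upsert(&new_dataset).await?;
    /// println!("upserted dataset {id}");
    /// # Ok(())
    /// # }
    /// ```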
    pub async fn upsert(&self, new_data: &NewDataset) -> Result<Uuid, AppError> {
        let embedding_vector = new_data.embedding.as_ref().cloned();

        let rec: (Uuid,) = sqlx::query_as(
            r#"
            INSERT INTO datasets (
                original_id,
                source_portal,
                url,
                title,
                description,
                embedding,
                metadata,
                content_hash,
                last_updated_at
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
            ON CONFLICT (source_portal, original_id)
            DO UPDATE SET
                title = EXCLUDED.title,
                description = EXCLUDED.description,
                url = EXCLUDED.url,
                embedding = COALESCE(EXCLUDED.embedding, datasets.embedding),
                metadata = EXCLUDED.metadata,
                content_hash = EXCLUDED.content_hash,
                last_updated_at = NOW()
            RETURNING id
            "#,
        )
        .bind(&new_data.original_id)
        .bind(&new_data.source_portal)
        .bind(&new_data.url)
        .bind(&new_data.title)
        .bind(&new_data.description)
        .bind(embedding_vector)
        .bind(serde_json::to_value(&new_data.metadata).unwrap_or(serde_json::json!({})))
        .bind(&new_data.content_hash)
        .fetch_one(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(rec.0)
    }

    /// Returns a map of original_id → content_hash for all datasets from a portal.
    ///
    /// TODO(performance): Optimize for large portals (100k+ datasets)
    /// Currently loads the entire HashMap into memory. Consider:
    /// (1) Streaming hash comparison during sync, or
    /// (2) Database-side hash check with WHERE clause, or
    /// (3) Bloom filter for approximate membership testing
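    ///
    /// # Examples
    ///
    /// A sketch of how the returned map supports delta detection during a sync
    /// (illustrative only; `harvested` stands in for records fetched from a portal):
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use ceres_core::models::NewDataset;
    /// # async fn example(repo: DatasetRepository, harvested: Vec<NewDataset>) -> Result<(), ceres_core::error::AppError> {
    /// let known = repo.get_hashes_for_portal("https://data.example.org").await?;
    /// for dataset in &harvested {
    ///     match known.get(&dataset.original_id) {
    ///         // Same hash: content unchanged, just refresh the timestamp.
    ///         Some(Some(hash)) if *hash == dataset.content_hash => {
    ///             repo.update_timestamp_only(&dataset.source_portal, &dataset.original_id).await?;
    ///         }
    ///         // New or changed dataset: write the full record.
    ///         _ => {
    ///             repo.upsert(dataset).await?;
    ///         }
    ///     }
    /// }
    /// # Ok(())
    /// # }
    /// ```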
    pub async fn get_hashes_for_portal(
        &self,
        portal_url: &str,
    ) -> Result<HashMap<String, Option<String>>, AppError> {
        let rows: Vec<HashRow> = sqlx::query_as(
            r#"
            SELECT original_id, content_hash
            FROM datasets
            WHERE source_portal = $1
            "#,
        )
        .bind(portal_url)
        .fetch_all(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        let hash_map: HashMap<String, Option<String>> = rows
            .into_iter()
            .map(|row| (row.original_id, row.content_hash))
            .collect();

        Ok(hash_map)
    }

    /// Updates only the timestamp for unchanged datasets. Returns `true` if a row was updated.
    pub async fn update_timestamp_only(
        &self,
        portal_url: &str,
        original_id: &str,
    ) -> Result<bool, AppError> {
        let result = sqlx::query(
            r#"
            UPDATE datasets
            SET last_updated_at = NOW()
            WHERE source_portal = $1 AND original_id = $2
            "#,
        )
        .bind(portal_url)
        .bind(original_id)
        .execute(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(result.rows_affected() > 0)
    }

    /// Batch updates timestamps for multiple unchanged datasets.
    ///
    /// Uses a single UPDATE with ANY() for efficiency instead of N individual updates.
    /// Returns the number of rows actually updated.
    pub async fn batch_update_timestamps(
        &self,
        portal_url: &str,
        original_ids: &[String],
    ) -> Result<u64, AppError> {
        if original_ids.is_empty() {
            return Ok(0);
        }

        let result = sqlx::query(
            r#"
            UPDATE datasets
            SET last_updated_at = NOW()
            WHERE source_portal = $1 AND original_id = ANY($2)
            "#,
        )
        .bind(portal_url)
        .bind(original_ids)
        .execute(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(result.rows_affected())
    }

    /// Retrieves a dataset by UUID.
    pub async fn get(&self, id: Uuid) -> Result<Option<Dataset>, AppError> {
        let query = format!("SELECT {} FROM datasets WHERE id = $1", DATASET_COLUMNS);
        let result = sqlx::query_as::<_, Dataset>(&query)
            .bind(id)
            .fetch_optional(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;

        Ok(result)
    }

    /// Semantic search using cosine similarity. Returns results ordered by similarity, most similar first.
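    ///
    /// # Examples
    ///
    /// A sketch of a search call (not executed; the query vector would normally come from
    /// the same embedding model used at ingest time, and the 384-dimension zero vector is
    /// just a placeholder):
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use pgvector::Vector;
    /// # async fn example(repo: DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// let query_vector = Vector::from(vec![0.0_f32; 384]);
    /// let results = repo.search(query_vector, 10).await?;
    /// for result in results {
    ///     println!("{:.3}  {}", result.similarity_score, result.dataset.title);
    /// }
    /// # Ok(())
    /// # }
    /// ```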
    pub async fn search(
        &self,
        query_vector: Vector,
        limit: usize,
    ) -> Result<Vec<SearchResult>, AppError> {
        let query = format!(
            "SELECT {}, 1 - (embedding <=> $1) as similarity_score FROM datasets WHERE embedding IS NOT NULL ORDER BY embedding <=> $1 LIMIT $2",
            DATASET_COLUMNS
        );
        let results = sqlx::query_as::<_, SearchResultRow>(&query)
            .bind(query_vector)
            .bind(limit as i64)
            .fetch_all(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;

        Ok(results
            .into_iter()
            .map(|row| SearchResult {
                dataset: Dataset {
                    id: row.id,
                    original_id: row.original_id,
                    source_portal: row.source_portal,
                    url: row.url,
                    title: row.title,
                    description: row.description,
                    embedding: row.embedding,
                    metadata: row.metadata,
                    first_seen_at: row.first_seen_at,
                    last_updated_at: row.last_updated_at,
                    content_hash: row.content_hash,
                },
                similarity_score: row.similarity_score as f32,
            })
            .collect())
    }

    /// Lists datasets with optional portal filter and limit.
    ///
    /// For memory-efficient exports of large datasets, use [`list_all_stream`] instead.
    ///
    /// TODO(config): Make default limit configurable via DEFAULT_EXPORT_LIMIT env var
    /// Currently hardcoded to 10000.
    pub async fn list_all(
        &self,
        portal_filter: Option<&str>,
        limit: Option<usize>,
    ) -> Result<Vec<Dataset>, AppError> {
        // TODO(config): Read default from DEFAULT_EXPORT_LIMIT env var
        let limit_val = limit.unwrap_or(10000) as i64;

        let datasets = if let Some(portal) = portal_filter {
            let query = format!(
                "SELECT {} FROM datasets WHERE source_portal = $1 ORDER BY last_updated_at DESC LIMIT $2",
                DATASET_COLUMNS
            );
            sqlx::query_as::<_, Dataset>(&query)
                .bind(portal)
                .bind(limit_val)
                .fetch_all(&self.pool)
                .await
                .map_err(AppError::DatabaseError)?
        } else {
            let query = format!(
                "SELECT {} FROM datasets ORDER BY last_updated_at DESC LIMIT $1",
                DATASET_COLUMNS
            );
            sqlx::query_as::<_, Dataset>(&query)
                .bind(limit_val)
                .fetch_all(&self.pool)
                .await
                .map_err(AppError::DatabaseError)?
        };

        Ok(datasets)
    }

    /// Lists datasets as a stream with optional portal filter.
    ///
    /// Unlike [`list_all`], this method streams results directly from the database
    /// without loading everything into memory. Suitable for large exports.
    ///
    /// # Arguments
    ///
    /// * `portal_filter` - Optional portal URL to filter by
    /// * `limit` - Optional maximum number of records (no default limit for streaming)
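    ///
    /// # Examples
    ///
    /// A sketch of consuming the stream one row at a time (not executed; the portal URL
    /// is a placeholder):
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use futures::StreamExt;
    /// # async fn example(repo: DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// let mut stream = repo.list_all_stream(Some("https://data.example.org"), None);
    /// while let Some(dataset) = stream.next().await {
    ///     let dataset = dataset?;
    ///     println!("{}", dataset.title);
    /// }
    /// # Ok(())
    /// # }
    /// ```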
    pub fn list_all_stream<'a>(
        &'a self,
        portal_filter: Option<&'a str>,
        limit: Option<usize>,
    ) -> BoxStream<'a, Result<Dataset, AppError>> {
        match (portal_filter, limit) {
            (Some(portal), Some(lim)) => Box::pin(
                sqlx::query_as::<_, Dataset>(LIST_ALL_PORTAL_LIMIT_QUERY)
                    .bind(portal)
                    .bind(lim as i64)
                    .fetch(&self.pool)
                    .map(|r| r.map_err(AppError::DatabaseError)),
            ),
            (Some(portal), None) => Box::pin(
                sqlx::query_as::<_, Dataset>(LIST_ALL_PORTAL_QUERY)
                    .bind(portal)
                    .fetch(&self.pool)
                    .map(|r| r.map_err(AppError::DatabaseError)),
            ),
            (None, Some(lim)) => Box::pin(
                sqlx::query_as::<_, Dataset>(LIST_ALL_LIMIT_QUERY)
                    .bind(lim as i64)
                    .fetch(&self.pool)
                    .map(|r| r.map_err(AppError::DatabaseError)),
            ),
            (None, None) => Box::pin(
                sqlx::query_as::<_, Dataset>(LIST_ALL_QUERY)
                    .fetch(&self.pool)
                    .map(|r| r.map_err(AppError::DatabaseError)),
            ),
        }
    }

    // =========================================================================
    // Portal Sync Status Methods (for incremental harvesting)
    // =========================================================================

    /// Retrieves the sync status for a portal.
    /// Returns `None` if this portal has never been synced.
    pub async fn get_sync_status(
        &self,
        portal_url: &str,
    ) -> Result<Option<PortalSyncStatus>, AppError> {
        let result = sqlx::query_as::<_, PortalSyncStatus>(
            r#"
            SELECT portal_url, last_successful_sync, last_sync_mode, sync_status, datasets_synced, created_at, updated_at
            FROM portal_sync_status
            WHERE portal_url = $1
            "#,
        )
        .bind(portal_url)
        .fetch_optional(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(result)
    }

    /// Updates or inserts the sync status for a portal.
    ///
    /// The `sync_status` parameter indicates the outcome: "completed" or "cancelled".
    /// Only updates `last_successful_sync` when status is "completed", preserving
    /// the last successful sync time for incremental harvesting after cancellations.
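    ///
    /// # Examples
    ///
    /// A sketch of recording the outcome of a sync run (not executed; the portal URL,
    /// mode string, and count are placeholders):
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use chrono::Utc;
    /// # async fn example(repo: DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// repo.upsert_sync_status(
    ///     "https://data.example.org",
    ///     Utc::now(),
    ///     "incremental",
    ///     "completed", // "cancelled" would leave last_successful_sync untouched
    ///     42,
    /// )
    /// .await?;
    /// # Ok(())
    /// # }
    /// ```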
    pub async fn upsert_sync_status(
        &self,
        portal_url: &str,
        last_sync: DateTime<Utc>,
        sync_mode: &str,
        sync_status: &str,
        datasets_synced: i32,
    ) -> Result<(), AppError> {
        sqlx::query(
            r#"
            INSERT INTO portal_sync_status (portal_url, last_successful_sync, last_sync_mode, sync_status, datasets_synced, updated_at)
            VALUES (
                $1,
                CASE WHEN $4 = 'completed' THEN $2 ELSE NULL END,
                $3,
                $4,
                $5,
                NOW()
            )
            ON CONFLICT (portal_url)
            DO UPDATE SET
                last_successful_sync = CASE
                    WHEN EXCLUDED.sync_status = 'completed' THEN $2
                    ELSE portal_sync_status.last_successful_sync
                END,
                last_sync_mode = EXCLUDED.last_sync_mode,
                sync_status = EXCLUDED.sync_status,
                datasets_synced = EXCLUDED.datasets_synced,
                updated_at = NOW()
            "#,
        )
        .bind(portal_url)
        .bind(last_sync)
        .bind(sync_mode)
        .bind(sync_status)
        .bind(datasets_synced)
        .execute(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(())
    }

    /// Checks database connectivity by executing a simple query.
    pub async fn health_check(&self) -> Result<(), AppError> {
        sqlx::query("SELECT 1")
            .execute(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;
        Ok(())
    }

    /// Returns aggregated database statistics.
    pub async fn get_stats(&self) -> Result<DatabaseStats, AppError> {
        let row: StatsRow = sqlx::query_as(
            r#"
            SELECT
                COUNT(*) as total,
                COUNT(embedding) as with_embeddings,
                COUNT(DISTINCT source_portal) as portals,
                MAX(last_updated_at) as last_update
            FROM datasets
            "#,
        )
        .fetch_one(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(DatabaseStats {
            total_datasets: row.total.unwrap_or(0),
            datasets_with_embeddings: row.with_embeddings.unwrap_or(0),
            total_portals: row.portals.unwrap_or(0),
            last_update: row.last_update,
        })
    }
}

/// Helper struct for deserializing stats query results
#[derive(sqlx::FromRow)]
struct StatsRow {
    total: Option<i64>,
    with_embeddings: Option<i64>,
    portals: Option<i64>,
    last_update: Option<DateTime<Utc>>,
}

/// Helper struct for deserializing search query results
#[derive(sqlx::FromRow)]
struct SearchResultRow {
    id: Uuid,
    original_id: String,
    source_portal: String,
    url: String,
    title: String,
    description: Option<String>,
    embedding: Option<Vector>,
    metadata: Json<serde_json::Value>,
    first_seen_at: DateTime<Utc>,
    last_updated_at: DateTime<Utc>,
    content_hash: Option<String>,
    similarity_score: f64,
}

/// Helper struct for deserializing hash lookup query results
#[derive(sqlx::FromRow)]
struct HashRow {
    original_id: String,
    content_hash: Option<String>,
}

/// Represents the sync status for a portal, used for incremental harvesting.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct PortalSyncStatus {
    pub portal_url: String,
    pub last_successful_sync: Option<DateTime<Utc>>,
    pub last_sync_mode: Option<String>,
    pub sync_status: Option<String>,
    pub datasets_synced: i32,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}

// =============================================================================
// Trait Implementation: DatasetStore
// =============================================================================

impl ceres_core::traits::DatasetStore for DatasetRepository {
    async fn get_by_id(&self, id: Uuid) -> Result<Option<Dataset>, AppError> {
        DatasetRepository::get(self, id).await
    }

    async fn get_hashes_for_portal(
        &self,
        portal_url: &str,
    ) -> Result<HashMap<String, Option<String>>, AppError> {
        DatasetRepository::get_hashes_for_portal(self, portal_url).await
    }

    async fn update_timestamp_only(
        &self,
        portal_url: &str,
        original_id: &str,
    ) -> Result<(), AppError> {
        DatasetRepository::update_timestamp_only(self, portal_url, original_id).await?;
        Ok(())
    }

    async fn batch_update_timestamps(
        &self,
        portal_url: &str,
        original_ids: &[String],
    ) -> Result<u64, AppError> {
        DatasetRepository::batch_update_timestamps(self, portal_url, original_ids).await
    }

    async fn upsert(&self, dataset: &NewDataset) -> Result<Uuid, AppError> {
        DatasetRepository::upsert(self, dataset).await
    }

    async fn search(
        &self,
        query_vector: Vector,
        limit: usize,
    ) -> Result<Vec<SearchResult>, AppError> {
        DatasetRepository::search(self, query_vector, limit).await
    }

    fn list_stream<'a>(
        &'a self,
        portal_filter: Option<&'a str>,
        limit: Option<usize>,
    ) -> BoxStream<'a, Result<Dataset, AppError>> {
        DatasetRepository::list_all_stream(self, portal_filter, limit)
    }

    async fn get_last_sync_time(
        &self,
        portal_url: &str,
    ) -> Result<Option<DateTime<Utc>>, AppError> {
        let status = DatasetRepository::get_sync_status(self, portal_url).await?;
        Ok(status.and_then(|s| s.last_successful_sync))
    }

    async fn record_sync_status(
        &self,
        portal_url: &str,
        sync_time: DateTime<Utc>,
        sync_mode: &str,
        sync_status: &str,
        datasets_synced: i32,
    ) -> Result<(), AppError> {
        DatasetRepository::upsert_sync_status(
            self,
            portal_url,
            sync_time,
            sync_mode,
            sync_status,
            datasets_synced,
        )
        .await
    }

    async fn health_check(&self) -> Result<(), AppError> {
        DatasetRepository::health_check(self).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_new_dataset_structure() {
        let title = "Test Dataset";
        let description = Some("Test description".to_string());
        let content_hash = NewDataset::compute_content_hash(title, description.as_deref());

        let new_dataset = NewDataset {
            original_id: "test-id".to_string(),
            source_portal: "https://example.com".to_string(),
            url: "https://example.com/dataset/test".to_string(),
            title: title.to_string(),
            description,
            embedding: Some(Vector::from(vec![0.1, 0.2, 0.3])),
            metadata: json!({"key": "value"}),
            content_hash,
        };

        assert_eq!(new_dataset.original_id, "test-id");
        assert_eq!(new_dataset.title, "Test Dataset");
        assert!(new_dataset.embedding.is_some());
        assert_eq!(new_dataset.content_hash.len(), 64);
    }

    #[test]
    fn test_embedding_vector_conversion() {
        let vec_f32 = vec![0.1_f32, 0.2, 0.3, 0.4];
        let vector = Vector::from(vec_f32.clone());
        assert_eq!(vector.as_slice().len(), vec_f32.len());
    }

    #[test]
    fn test_metadata_serialization() {
        let metadata = json!({
            "organization": "test-org",
            "tags": ["tag1", "tag2"]
        });

        let serialized = serde_json::to_value(&metadata).unwrap();
        assert!(serialized.is_object());
        assert_eq!(serialized["organization"], "test-org");
    }
}