ceres_db/repository.rs

//! Dataset repository for PostgreSQL with pgvector support.
//!
//! # Testing
//!
//! TODO(#12): Improve test coverage for repository methods
//! Current tests cover only struct construction and serialization. Integration tests are needed for:
//! - `upsert()` - insert and update paths
//! - `search()` - vector similarity queries
//! - `get_hashes_for_portal()` - delta detection queries
//! - `update_timestamp_only()` - timestamp-only updates
//!
//! Consider using testcontainers-rs for isolated PostgreSQL instances:
//! <https://github.com/testcontainers/testcontainers-rs>
//!
//! See: <https://github.com/AndreaBozzo/Ceres/issues/12>

use ceres_core::error::AppError;
use ceres_core::models::{DatabaseStats, Dataset, NewDataset, SearchResult};
use chrono::{DateTime, Utc};
use pgvector::Vector;
use sqlx::types::Json;
use sqlx::{PgPool, Pool, Postgres};
use std::collections::HashMap;
use uuid::Uuid;

/// Column list for SELECT queries. Must remain a const literal to keep the generated
/// SQL safe, since queries built with format!() bypass sqlx's compile-time validation.
const DATASET_COLUMNS: &str = "id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash";

/// Repository for dataset persistence in PostgreSQL with pgvector.
///
/// # Examples
///
/// ```no_run
/// use sqlx::postgres::PgPoolOptions;
/// use ceres_db::DatasetRepository;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let pool = PgPoolOptions::new()
///     .max_connections(5)
///     .connect("postgresql://localhost/ceres")
///     .await?;
///
/// let repo = DatasetRepository::new(pool);
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct DatasetRepository {
    pool: Pool<Postgres>,
}

impl DatasetRepository {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Inserts or updates a dataset. Returns the UUID of the affected row.
    ///
    /// TODO(robustness): Return UpsertOutcome to distinguish insert vs update
    /// Currently returns only the UUID without indicating the operation type.
    /// Consider: `pub enum UpsertOutcome { Created(Uuid), Updated(Uuid) }`
    /// This enables accurate progress reporting in sync statistics.
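    ///
    /// # Examples
    ///
    /// A minimal usage sketch; the portal URL and field values are illustrative, and
    /// the example is `no_run` because it needs a live database connection:
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use ceres_core::models::NewDataset;
    /// # use pgvector::Vector;
    /// # async fn example(repo: &DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// let title = "Example Dataset";
    /// let description = Some("An example record".to_string());
    /// let content_hash = NewDataset::compute_content_hash(title, description.as_deref());
    ///
    /// let new_dataset = NewDataset {
    ///     original_id: "example-id".to_string(),
    ///     source_portal: "https://example.com".to_string(),
    ///     url: "https://example.com/dataset/example".to_string(),
    ///     title: title.to_string(),
    ///     description,
    ///     embedding: Some(Vector::from(vec![0.1, 0.2, 0.3])),
    ///     metadata: serde_json::json!({"organization": "example-org"}),
    ///     content_hash,
    /// };
    ///
    /// let id = repo.upsert(&new_dataset).await?;
    /// println!("upserted dataset {}", id);
    /// # Ok(())
    /// # }
    /// ```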
    pub async fn upsert(&self, new_data: &NewDataset) -> Result<Uuid, AppError> {
        let embedding_vector = new_data.embedding.clone();

        let rec: (Uuid,) = sqlx::query_as(
            r#"
            INSERT INTO datasets (
                original_id,
                source_portal,
                url,
                title,
                description,
                embedding,
                metadata,
                content_hash,
                last_updated_at
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
            ON CONFLICT (source_portal, original_id)
            DO UPDATE SET
                title = EXCLUDED.title,
                description = EXCLUDED.description,
                url = EXCLUDED.url,
                embedding = COALESCE(EXCLUDED.embedding, datasets.embedding),
                metadata = EXCLUDED.metadata,
                content_hash = EXCLUDED.content_hash,
                last_updated_at = NOW()
            RETURNING id
            "#,
        )
        .bind(&new_data.original_id)
        .bind(&new_data.source_portal)
        .bind(&new_data.url)
        .bind(&new_data.title)
        .bind(&new_data.description)
        .bind(embedding_vector)
        .bind(serde_json::to_value(&new_data.metadata).unwrap_or(serde_json::json!({})))
        .bind(&new_data.content_hash)
        .fetch_one(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(rec.0)
    }

    /// Returns a map of original_id → content_hash for all datasets from a portal.
    ///
    /// TODO(performance): Optimize for large portals (100k+ datasets)
    /// Currently loads the entire `HashMap` into memory. Consider:
    /// (1) Streaming hash comparison during sync, or
    /// (2) Database-side hash check with WHERE clause, or
    /// (3) Bloom filter for approximate membership testing
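    ///
    /// # Examples
    ///
    /// A rough sketch of delta detection against the returned map; the portal URL
    /// and record values are illustrative, and it is `no_run` since it needs a live
    /// database:
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use ceres_core::models::NewDataset;
    /// # async fn example(repo: &DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// let known_hashes = repo.get_hashes_for_portal("https://example.com").await?;
    ///
    /// // Hash the incoming record and compare it against the stored hash.
    /// let incoming_id = "example-id";
    /// let incoming_hash = NewDataset::compute_content_hash("Example Dataset", None);
    /// let unchanged = known_hashes
    ///     .get(incoming_id)
    ///     .map(|stored| stored.as_deref() == Some(incoming_hash.as_str()))
    ///     .unwrap_or(false);
    /// if unchanged {
    ///     // Record is unchanged: only bump the timestamp instead of re-upserting.
    ///     repo.update_timestamp_only("https://example.com", incoming_id).await?;
    /// }
    /// # Ok(())
    /// # }
    /// ```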
    pub async fn get_hashes_for_portal(
        &self,
        portal_url: &str,
    ) -> Result<HashMap<String, Option<String>>, AppError> {
        let rows: Vec<HashRow> = sqlx::query_as(
            r#"
            SELECT original_id, content_hash
            FROM datasets
            WHERE source_portal = $1
            "#,
        )
        .bind(portal_url)
        .fetch_all(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        let hash_map: HashMap<String, Option<String>> = rows
            .into_iter()
            .map(|row| (row.original_id, row.content_hash))
            .collect();

        Ok(hash_map)
    }

    /// Updates only the timestamp for unchanged datasets. Returns `true` if a row was updated.
    pub async fn update_timestamp_only(
        &self,
        portal_url: &str,
        original_id: &str,
    ) -> Result<bool, AppError> {
        let result = sqlx::query(
            r#"
            UPDATE datasets
            SET last_updated_at = NOW()
            WHERE source_portal = $1 AND original_id = $2
            "#,
        )
        .bind(portal_url)
        .bind(original_id)
        .execute(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(result.rows_affected() > 0)
    }

    /// Retrieves a dataset by UUID.
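    ///
    /// # Examples
    ///
    /// A minimal usage sketch (`no_run` since it needs a live database):
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use uuid::Uuid;
    /// # async fn example(repo: &DatasetRepository, id: Uuid) -> Result<(), ceres_core::error::AppError> {
    /// match repo.get(id).await? {
    ///     Some(dataset) => println!("{}: {}", dataset.id, dataset.title),
    ///     None => println!("no dataset with id {}", id),
    /// }
    /// # Ok(())
    /// # }
    /// ```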
    pub async fn get(&self, id: Uuid) -> Result<Option<Dataset>, AppError> {
        let query = format!("SELECT {} FROM datasets WHERE id = $1", DATASET_COLUMNS);
        let result = sqlx::query_as::<_, Dataset>(&query)
            .bind(id)
            .fetch_optional(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;

        Ok(result)
    }

    /// Semantic search using cosine similarity. Returns results ordered by descending similarity.
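    ///
    /// # Examples
    ///
    /// A minimal usage sketch; the query embedding must match the dimensionality of
    /// the stored vectors (the values here are illustrative), and it is `no_run`
    /// because it needs a live database:
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # use pgvector::Vector;
    /// # async fn example(repo: &DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// let query_vector = Vector::from(vec![0.1, 0.2, 0.3]);
    /// for result in repo.search(query_vector, 10).await? {
    ///     println!("{:.3}  {}", result.similarity_score, result.dataset.title);
    /// }
    /// # Ok(())
    /// # }
    /// ```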
    pub async fn search(
        &self,
        query_vector: Vector,
        limit: usize,
    ) -> Result<Vec<SearchResult>, AppError> {
        let query = format!(
            "SELECT {}, 1 - (embedding <=> $1) as similarity_score FROM datasets WHERE embedding IS NOT NULL ORDER BY embedding <=> $1 LIMIT $2",
            DATASET_COLUMNS
        );
        let results = sqlx::query_as::<_, SearchResultRow>(&query)
            .bind(query_vector)
            .bind(limit as i64)
            .fetch_all(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;

        Ok(results
            .into_iter()
            .map(|row| SearchResult {
                dataset: Dataset {
                    id: row.id,
                    original_id: row.original_id,
                    source_portal: row.source_portal,
                    url: row.url,
                    title: row.title,
                    description: row.description,
                    embedding: row.embedding,
                    metadata: row.metadata,
                    first_seen_at: row.first_seen_at,
                    last_updated_at: row.last_updated_at,
                    content_hash: row.content_hash,
                },
                similarity_score: row.similarity_score as f32,
            })
            .collect())
    }

    /// Lists datasets with optional portal filter and limit.
    ///
    /// TODO(config): Make default limit configurable via DEFAULT_EXPORT_LIMIT env var
    /// Currently hardcoded to 10000. For large exports, consider streaming instead.
    ///
    /// TODO(performance): Implement streaming/pagination for memory efficiency
    /// Loading all datasets into memory doesn't scale. Consider returning
    /// `impl Stream<Item = Result<Dataset, AppError>>` or cursor-based pagination.
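    ///
    /// # Examples
    ///
    /// A minimal usage sketch (the portal URL is illustrative; `no_run` since it
    /// needs a live database):
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # async fn example(repo: &DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// // Up to 100 most recently updated datasets from one portal.
    /// let recent = repo.list_all(Some("https://example.com"), Some(100)).await?;
    ///
    /// // All portals, default limit (currently 10000).
    /// let everything = repo.list_all(None, None).await?;
    /// println!("{} + {} datasets", recent.len(), everything.len());
    /// # Ok(())
    /// # }
    /// ```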
    pub async fn list_all(
        &self,
        portal_filter: Option<&str>,
        limit: Option<usize>,
    ) -> Result<Vec<Dataset>, AppError> {
        // TODO(config): Read default from DEFAULT_EXPORT_LIMIT env var
        let limit_val = limit.unwrap_or(10000) as i64;

        let datasets = if let Some(portal) = portal_filter {
            let query = format!(
                "SELECT {} FROM datasets WHERE source_portal = $1 ORDER BY last_updated_at DESC LIMIT $2",
                DATASET_COLUMNS
            );
            sqlx::query_as::<_, Dataset>(&query)
                .bind(portal)
                .bind(limit_val)
                .fetch_all(&self.pool)
                .await
                .map_err(AppError::DatabaseError)?
        } else {
            let query = format!(
                "SELECT {} FROM datasets ORDER BY last_updated_at DESC LIMIT $1",
                DATASET_COLUMNS
            );
            sqlx::query_as::<_, Dataset>(&query)
                .bind(limit_val)
                .fetch_all(&self.pool)
                .await
                .map_err(AppError::DatabaseError)?
        };

        Ok(datasets)
    }

    /// Returns aggregated database statistics.
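    ///
    /// # Examples
    ///
    /// A minimal usage sketch (`no_run` since it needs a live database):
    ///
    /// ```no_run
    /// # use ceres_db::DatasetRepository;
    /// # async fn example(repo: &DatasetRepository) -> Result<(), ceres_core::error::AppError> {
    /// let stats = repo.get_stats().await?;
    /// println!(
    ///     "{} datasets ({} with embeddings) across {} portals",
    ///     stats.total_datasets, stats.datasets_with_embeddings, stats.total_portals
    /// );
    /// # Ok(())
    /// # }
    /// ```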
    pub async fn get_stats(&self) -> Result<DatabaseStats, AppError> {
        let row: StatsRow = sqlx::query_as(
            r#"
            SELECT
                COUNT(*) as total,
                COUNT(embedding) as with_embeddings,
                COUNT(DISTINCT source_portal) as portals,
                MAX(last_updated_at) as last_update
            FROM datasets
            "#,
        )
        .fetch_one(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(DatabaseStats {
            total_datasets: row.total.unwrap_or(0),
            datasets_with_embeddings: row.with_embeddings.unwrap_or(0),
            total_portals: row.portals.unwrap_or(0),
            last_update: row.last_update,
        })
    }
}

/// Helper struct for deserializing stats query results
#[derive(sqlx::FromRow)]
struct StatsRow {
    total: Option<i64>,
    with_embeddings: Option<i64>,
    portals: Option<i64>,
    last_update: Option<DateTime<Utc>>,
}

/// Helper struct for deserializing search query results
#[derive(sqlx::FromRow)]
struct SearchResultRow {
    id: Uuid,
    original_id: String,
    source_portal: String,
    url: String,
    title: String,
    description: Option<String>,
    embedding: Option<Vector>,
    metadata: Json<serde_json::Value>,
    first_seen_at: DateTime<Utc>,
    last_updated_at: DateTime<Utc>,
    content_hash: Option<String>,
    similarity_score: f64,
}

/// Helper struct for deserializing hash lookup query results
#[derive(sqlx::FromRow)]
struct HashRow {
    original_id: String,
    content_hash: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_new_dataset_structure() {
        let title = "Test Dataset";
        let description = Some("Test description".to_string());
        let content_hash = NewDataset::compute_content_hash(title, description.as_deref());

        let new_dataset = NewDataset {
            original_id: "test-id".to_string(),
            source_portal: "https://example.com".to_string(),
            url: "https://example.com/dataset/test".to_string(),
            title: title.to_string(),
            description,
            embedding: Some(Vector::from(vec![0.1, 0.2, 0.3])),
            metadata: json!({"key": "value"}),
            content_hash,
        };

        assert_eq!(new_dataset.original_id, "test-id");
        assert_eq!(new_dataset.title, "Test Dataset");
        assert!(new_dataset.embedding.is_some());
        assert_eq!(new_dataset.content_hash.len(), 64);
    }

    #[test]
    fn test_embedding_vector_conversion() {
        let vec_f32 = vec![0.1_f32, 0.2, 0.3, 0.4];
        let vector = Vector::from(vec_f32.clone());
        assert_eq!(vector.as_slice().len(), vec_f32.len());
    }

    #[test]
    fn test_metadata_serialization() {
        let metadata = json!({
            "organization": "test-org",
            "tags": ["tag1", "tag2"]
        });

        let serialized = serde_json::to_value(&metadata).unwrap();
        assert!(serialized.is_object());
        assert_eq!(serialized["organization"], "test-org");
    }
}