use ceres_core::error::AppError;
use ceres_core::models::{DatabaseStats, Dataset, NewDataset, SearchResult};
use chrono::{DateTime, Utc};
use futures::StreamExt;
use futures::stream::BoxStream;
use pgvector::Vector;
use sqlx::types::Json;
use sqlx::{PgPool, Pool, Postgres};
use std::collections::HashMap;
use uuid::Uuid;

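/// Column list shared by every `SELECT` that maps rows into [`Dataset`].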
const DATASET_COLUMNS: &str = "id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash";

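// Pre-built query strings for `list_all_stream`. The column list is duplicated
// here rather than assembled from `DATASET_COLUMNS` because `format!` cannot
// run in const context.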
const LIST_ALL_QUERY: &str = "SELECT id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash FROM datasets ORDER BY last_updated_at DESC";
const LIST_ALL_LIMIT_QUERY: &str = "SELECT id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash FROM datasets ORDER BY last_updated_at DESC LIMIT $1";
const LIST_ALL_PORTAL_QUERY: &str = "SELECT id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash FROM datasets WHERE source_portal = $1 ORDER BY last_updated_at DESC";
const LIST_ALL_PORTAL_LIMIT_QUERY: &str = "SELECT id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash FROM datasets WHERE source_portal = $1 ORDER BY last_updated_at DESC LIMIT $2";

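/// Postgres-backed repository for dataset storage, pgvector similarity
/// search, and per-portal sync bookkeeping.
///
/// Illustrative usage (a sketch; the connection string is a placeholder):
///
/// ```ignore
/// let pool = sqlx::PgPool::connect("postgres://localhost/ceres").await?;
/// let repo = DatasetRepository::new(pool);
/// repo.health_check().await?;
/// ```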
#[derive(Clone)]
pub struct DatasetRepository {
    pool: Pool<Postgres>,
}

impl DatasetRepository {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

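    /// Inserts a dataset or, on a `(source_portal, original_id)` conflict,
    /// updates the existing row in place. Returns the row's `id` either way.
    ///
    /// A `NULL` incoming embedding keeps the stored embedding via `COALESCE`,
    /// so re-syncing metadata never erases a previously computed vector.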
    pub async fn upsert(&self, new_data: &NewDataset) -> Result<Uuid, AppError> {
        // A missing embedding binds as NULL; the COALESCE in the UPDATE arm
        // then preserves whatever embedding is already stored.
        let embedding_vector = new_data.embedding.clone();

        let rec: (Uuid,) = sqlx::query_as(
            r#"
            INSERT INTO datasets (
                original_id,
                source_portal,
                url,
                title,
                description,
                embedding,
                metadata,
                content_hash,
                last_updated_at
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
            ON CONFLICT (source_portal, original_id)
            DO UPDATE SET
                title = EXCLUDED.title,
                description = EXCLUDED.description,
                url = EXCLUDED.url,
                embedding = COALESCE(EXCLUDED.embedding, datasets.embedding),
                metadata = EXCLUDED.metadata,
                content_hash = EXCLUDED.content_hash,
                last_updated_at = NOW()
            RETURNING id
            "#,
        )
        .bind(&new_data.original_id)
        .bind(&new_data.source_portal)
        .bind(&new_data.url)
        .bind(&new_data.title)
        .bind(&new_data.description)
        .bind(embedding_vector)
        .bind(serde_json::to_value(&new_data.metadata).unwrap_or(serde_json::json!({})))
        .bind(&new_data.content_hash)
        .fetch_one(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(rec.0)
    }

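    /// Returns every `original_id -> content_hash` pair for a portal, letting
    /// callers detect unchanged records before doing a full upsert.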
    pub async fn get_hashes_for_portal(
        &self,
        portal_url: &str,
    ) -> Result<HashMap<String, Option<String>>, AppError> {
        let rows: Vec<HashRow> = sqlx::query_as(
            r#"
            SELECT original_id, content_hash
            FROM datasets
            WHERE source_portal = $1
            "#,
        )
        .bind(portal_url)
        .fetch_all(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        let hash_map: HashMap<String, Option<String>> = rows
            .into_iter()
            .map(|row| (row.original_id, row.content_hash))
            .collect();

        Ok(hash_map)
    }

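    /// Touches `last_updated_at` for a single dataset without rewriting its
    /// content. Returns `true` if a matching row existed.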
    pub async fn update_timestamp_only(
        &self,
        portal_url: &str,
        original_id: &str,
    ) -> Result<bool, AppError> {
        let result = sqlx::query(
            r#"
            UPDATE datasets
            SET last_updated_at = NOW()
            WHERE source_portal = $1 AND original_id = $2
            "#,
        )
        .bind(portal_url)
        .bind(original_id)
        .execute(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(result.rows_affected() > 0)
    }

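    /// Batch variant of [`update_timestamp_only`](Self::update_timestamp_only):
    /// touches `last_updated_at` for many datasets in one statement and
    /// returns the number of rows affected. A no-op for an empty slice.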
    pub async fn batch_update_timestamps(
        &self,
        portal_url: &str,
        original_ids: &[String],
    ) -> Result<u64, AppError> {
        if original_ids.is_empty() {
            return Ok(0);
        }

        let result = sqlx::query(
            r#"
            UPDATE datasets
            SET last_updated_at = NOW()
            WHERE source_portal = $1 AND original_id = ANY($2)
            "#,
        )
        .bind(portal_url)
        .bind(original_ids)
        .execute(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(result.rows_affected())
    }

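    /// Fetches a single dataset by primary key, or `None` if absent.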
    pub async fn get(&self, id: Uuid) -> Result<Option<Dataset>, AppError> {
        let query = format!("SELECT {} FROM datasets WHERE id = $1", DATASET_COLUMNS);
        let result = sqlx::query_as::<_, Dataset>(&query)
            .bind(id)
            .fetch_optional(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;

        Ok(result)
    }

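    /// Nearest-neighbour search over the pgvector `embedding` column using
    /// cosine distance (`<=>`). Rows without embeddings are skipped, and
    /// `similarity_score` is reported as `1 - distance`.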
    pub async fn search(
        &self,
        query_vector: Vector,
        limit: usize,
    ) -> Result<Vec<SearchResult>, AppError> {
        let query = format!(
            "SELECT {}, 1 - (embedding <=> $1) as similarity_score FROM datasets WHERE embedding IS NOT NULL ORDER BY embedding <=> $1 LIMIT $2",
            DATASET_COLUMNS
        );
        let results = sqlx::query_as::<_, SearchResultRow>(&query)
            .bind(query_vector)
            .bind(limit as i64)
            .fetch_all(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;

        Ok(results
            .into_iter()
            .map(|row| SearchResult {
                dataset: Dataset {
                    id: row.id,
                    original_id: row.original_id,
                    source_portal: row.source_portal,
                    url: row.url,
                    title: row.title,
                    description: row.description,
                    embedding: row.embedding,
                    metadata: row.metadata,
                    first_seen_at: row.first_seen_at,
                    last_updated_at: row.last_updated_at,
                    content_hash: row.content_hash,
                },
                similarity_score: row.similarity_score as f32,
            })
            .collect())
    }

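    /// Lists datasets ordered by `last_updated_at` (newest first), optionally
    /// filtered to one portal. For bounded memory on large tables, prefer
    /// [`list_all_stream`](Self::list_all_stream).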
    pub async fn list_all(
        &self,
        portal_filter: Option<&str>,
        limit: Option<usize>,
    ) -> Result<Vec<Dataset>, AppError> {
        // Default cap of 10,000 rows when the caller does not supply a limit.
        let limit_val = limit.unwrap_or(10000) as i64;

        let datasets = if let Some(portal) = portal_filter {
            let query = format!(
                "SELECT {} FROM datasets WHERE source_portal = $1 ORDER BY last_updated_at DESC LIMIT $2",
                DATASET_COLUMNS
            );
            sqlx::query_as::<_, Dataset>(&query)
                .bind(portal)
                .bind(limit_val)
                .fetch_all(&self.pool)
                .await
                .map_err(AppError::DatabaseError)?
        } else {
            let query = format!(
                "SELECT {} FROM datasets ORDER BY last_updated_at DESC LIMIT $1",
                DATASET_COLUMNS
            );
            sqlx::query_as::<_, Dataset>(&query)
                .bind(limit_val)
                .fetch_all(&self.pool)
                .await
                .map_err(AppError::DatabaseError)?
        };

        Ok(datasets)
    }

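    /// Streaming variant of [`list_all`](Self::list_all): yields rows one at
    /// a time instead of buffering the whole result set.
    ///
    /// The `'static` `LIST_ALL_*` constants are required here because the
    /// returned stream borrows its query string, so a locally `format!`ed
    /// `String` would not live long enough.
    ///
    /// Illustrative usage (a sketch, assuming a `repo` built with [`new`](Self::new)):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut rows = repo.list_all_stream(None, Some(100));
    /// while let Some(dataset) = rows.next().await {
    ///     println!("{}", dataset?.title);
    /// }
    /// ```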
    pub fn list_all_stream<'a>(
        &'a self,
        portal_filter: Option<&'a str>,
        limit: Option<usize>,
    ) -> BoxStream<'a, Result<Dataset, AppError>> {
        match (portal_filter, limit) {
            (Some(portal), Some(lim)) => Box::pin(
                sqlx::query_as::<_, Dataset>(LIST_ALL_PORTAL_LIMIT_QUERY)
                    .bind(portal)
                    .bind(lim as i64)
                    .fetch(&self.pool)
                    .map(|r| r.map_err(AppError::DatabaseError)),
            ),
            (Some(portal), None) => Box::pin(
                sqlx::query_as::<_, Dataset>(LIST_ALL_PORTAL_QUERY)
                    .bind(portal)
                    .fetch(&self.pool)
                    .map(|r| r.map_err(AppError::DatabaseError)),
            ),
            (None, Some(lim)) => Box::pin(
                sqlx::query_as::<_, Dataset>(LIST_ALL_LIMIT_QUERY)
                    .bind(lim as i64)
                    .fetch(&self.pool)
                    .map(|r| r.map_err(AppError::DatabaseError)),
            ),
            (None, None) => Box::pin(
                sqlx::query_as::<_, Dataset>(LIST_ALL_QUERY)
                    .fetch(&self.pool)
                    .map(|r| r.map_err(AppError::DatabaseError)),
            ),
        }
    }

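    /// Looks up the sync bookkeeping row for a portal, if one exists.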
    pub async fn get_sync_status(
        &self,
        portal_url: &str,
    ) -> Result<Option<PortalSyncStatus>, AppError> {
        let result = sqlx::query_as::<_, PortalSyncStatus>(
            r#"
            SELECT portal_url, last_successful_sync, last_sync_mode, sync_status, datasets_synced, created_at, updated_at
            FROM portal_sync_status
            WHERE portal_url = $1
            "#,
        )
        .bind(portal_url)
        .fetch_optional(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(result)
    }

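    /// Records the outcome of a sync run. `last_successful_sync` advances
    /// only when `sync_status` is `"completed"`; otherwise the previously
    /// recorded success timestamp is kept.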
    pub async fn upsert_sync_status(
        &self,
        portal_url: &str,
        last_sync: DateTime<Utc>,
        sync_mode: &str,
        sync_status: &str,
        datasets_synced: i32,
    ) -> Result<(), AppError> {
        sqlx::query(
            r#"
            INSERT INTO portal_sync_status (portal_url, last_successful_sync, last_sync_mode, sync_status, datasets_synced, updated_at)
            VALUES (
                $1,
                CASE WHEN $4 = 'completed' THEN $2 ELSE NULL END,
                $3,
                $4,
                $5,
                NOW()
            )
            ON CONFLICT (portal_url)
            DO UPDATE SET
                last_successful_sync = CASE
                    WHEN EXCLUDED.sync_status = 'completed' THEN $2
                    ELSE portal_sync_status.last_successful_sync
                END,
                last_sync_mode = EXCLUDED.last_sync_mode,
                sync_status = EXCLUDED.sync_status,
                datasets_synced = EXCLUDED.datasets_synced,
                updated_at = NOW()
            "#,
        )
        .bind(portal_url)
        .bind(last_sync)
        .bind(sync_mode)
        .bind(sync_status)
        .bind(datasets_synced)
        .execute(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(())
    }

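    /// Verifies database connectivity with a trivial `SELECT 1`.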
    pub async fn health_check(&self) -> Result<(), AppError> {
        sqlx::query("SELECT 1")
            .execute(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;
        Ok(())
    }

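    /// Aggregates dataset totals, embedding coverage, distinct portal count,
    /// and the most recent update timestamp in a single query.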
    pub async fn get_stats(&self) -> Result<DatabaseStats, AppError> {
        let row: StatsRow = sqlx::query_as(
            r#"
            SELECT
                COUNT(*) as total,
                COUNT(embedding) as with_embeddings,
                COUNT(DISTINCT source_portal) as portals,
                MAX(last_updated_at) as last_update
            FROM datasets
            "#,
        )
        .fetch_one(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(DatabaseStats {
            total_datasets: row.total.unwrap_or(0),
            datasets_with_embeddings: row.with_embeddings.unwrap_or(0),
            total_portals: row.portals.unwrap_or(0),
            last_update: row.last_update,
        })
    }
}

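/// Row shape for the aggregate query in `get_stats`.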
#[derive(sqlx::FromRow)]
struct StatsRow {
    total: Option<i64>,
    with_embeddings: Option<i64>,
    portals: Option<i64>,
    last_update: Option<DateTime<Utc>>,
}

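/// Flat row shape for `search`: all [`Dataset`] columns plus the computed
/// `similarity_score`.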
#[derive(sqlx::FromRow)]
struct SearchResultRow {
    id: Uuid,
    original_id: String,
    source_portal: String,
    url: String,
    title: String,
    description: Option<String>,
    embedding: Option<Vector>,
    metadata: Json<serde_json::Value>,
    first_seen_at: DateTime<Utc>,
    last_updated_at: DateTime<Utc>,
    content_hash: Option<String>,
    similarity_score: f64,
}

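/// Row shape for `get_hashes_for_portal`.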
#[derive(sqlx::FromRow)]
struct HashRow {
    original_id: String,
    content_hash: Option<String>,
}

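/// One row of the `portal_sync_status` bookkeeping table.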
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct PortalSyncStatus {
    pub portal_url: String,
    pub last_successful_sync: Option<DateTime<Utc>>,
    pub last_sync_mode: Option<String>,
    pub sync_status: Option<String>,
    pub datasets_synced: i32,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}

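// `DatasetStore` implementation: each trait method delegates to the inherent
// method of the same (or closely related) name above.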
impl ceres_core::traits::DatasetStore for DatasetRepository {
    async fn get_by_id(&self, id: Uuid) -> Result<Option<Dataset>, AppError> {
        DatasetRepository::get(self, id).await
    }

    async fn get_hashes_for_portal(
        &self,
        portal_url: &str,
    ) -> Result<HashMap<String, Option<String>>, AppError> {
        DatasetRepository::get_hashes_for_portal(self, portal_url).await
    }

    async fn update_timestamp_only(
        &self,
        portal_url: &str,
        original_id: &str,
    ) -> Result<(), AppError> {
        DatasetRepository::update_timestamp_only(self, portal_url, original_id).await?;
        Ok(())
    }

    async fn batch_update_timestamps(
        &self,
        portal_url: &str,
        original_ids: &[String],
    ) -> Result<u64, AppError> {
        DatasetRepository::batch_update_timestamps(self, portal_url, original_ids).await
    }

    async fn upsert(&self, dataset: &NewDataset) -> Result<Uuid, AppError> {
        DatasetRepository::upsert(self, dataset).await
    }

    async fn search(
        &self,
        query_vector: Vector,
        limit: usize,
    ) -> Result<Vec<SearchResult>, AppError> {
        DatasetRepository::search(self, query_vector, limit).await
    }

    fn list_stream<'a>(
        &'a self,
        portal_filter: Option<&'a str>,
        limit: Option<usize>,
    ) -> BoxStream<'a, Result<Dataset, AppError>> {
        DatasetRepository::list_all_stream(self, portal_filter, limit)
    }

    async fn get_last_sync_time(
        &self,
        portal_url: &str,
    ) -> Result<Option<DateTime<Utc>>, AppError> {
        let status = DatasetRepository::get_sync_status(self, portal_url).await?;
        Ok(status.and_then(|s| s.last_successful_sync))
    }

    async fn record_sync_status(
        &self,
        portal_url: &str,
        sync_time: DateTime<Utc>,
        sync_mode: &str,
        sync_status: &str,
        datasets_synced: i32,
    ) -> Result<(), AppError> {
        DatasetRepository::upsert_sync_status(
            self,
            portal_url,
            sync_time,
            sync_mode,
            sync_status,
            datasets_synced,
        )
        .await
    }

    async fn health_check(&self) -> Result<(), AppError> {
        DatasetRepository::health_check(self).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_new_dataset_structure() {
        let title = "Test Dataset";
        let description = Some("Test description".to_string());
        let content_hash = NewDataset::compute_content_hash(title, description.as_deref());

        let new_dataset = NewDataset {
            original_id: "test-id".to_string(),
            source_portal: "https://example.com".to_string(),
            url: "https://example.com/dataset/test".to_string(),
            title: title.to_string(),
            description,
            embedding: Some(Vector::from(vec![0.1, 0.2, 0.3])),
            metadata: json!({"key": "value"}),
            content_hash,
        };

        assert_eq!(new_dataset.original_id, "test-id");
        assert_eq!(new_dataset.title, "Test Dataset");
        assert!(new_dataset.embedding.is_some());
        assert_eq!(new_dataset.content_hash.len(), 64);
    }

    #[test]
    fn test_embedding_vector_conversion() {
        let vec_f32 = vec![0.1_f32, 0.2, 0.3, 0.4];
        let vector = Vector::from(vec_f32.clone());
        assert_eq!(vector.as_slice().len(), vec_f32.len());
    }

    #[test]
    fn test_metadata_serialization() {
        let metadata = json!({
            "organization": "test-org",
            "tags": ["tag1", "tag2"]
        });

        let serialized = serde_json::to_value(&metadata).unwrap();
        assert!(serialized.is_object());
        assert_eq!(serialized["organization"], "test-org");
    }
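
    // A small added guard (no database required): the four streaming query
    // constants hand-duplicate the shared column list because `format!` is
    // unavailable in const context, so assert they never drift apart.
    #[test]
    fn test_list_query_constants_match_dataset_columns() {
        for query in [
            LIST_ALL_QUERY,
            LIST_ALL_LIMIT_QUERY,
            LIST_ALL_PORTAL_QUERY,
            LIST_ALL_PORTAL_LIMIT_QUERY,
        ] {
            assert!(query.contains(DATASET_COLUMNS));
        }
    }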
}