use ceres_core::error::AppError;
use ceres_core::models::{DatabaseStats, Dataset, NewDataset, SearchResult};
use chrono::{DateTime, Utc};
use pgvector::Vector;
use sqlx::types::Json;
use sqlx::{PgPool, Pool, Postgres};
use std::collections::HashMap;
use uuid::Uuid;

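// Column list shared by every SELECT below; it must stay in sync with the
// `Dataset` and `SearchResultRow` `FromRow` mappings.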
const DATASET_COLUMNS: &str = "id, original_id, source_portal, url, title, description, embedding, metadata, first_seen_at, last_updated_at, content_hash";

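/// Repository of dataset records stored in Postgres with pgvector embeddings.
///
/// A minimal usage sketch (assumes an already-configured `PgPool` and a
/// populated `NewDataset`; error handling elided):
///
/// ```ignore
/// let repo = DatasetRepository::new(pool);
/// let id = repo.upsert(&new_dataset).await?;
/// let found = repo.get(id).await?;
/// ```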
#[derive(Clone)]
pub struct DatasetRepository {
    pool: Pool<Postgres>,
}

impl DatasetRepository {
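    /// Creates a repository over the given connection pool.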
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

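    /// Inserts a dataset or updates the existing row keyed on
    /// `(source_portal, original_id)`. An existing embedding is preserved when
    /// the incoming record has none (`COALESCE` in the conflict branch).
    /// Returns the row's `id`.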
    pub async fn upsert(&self, new_data: &NewDataset) -> Result<Uuid, AppError> {
        let embedding_vector = new_data.embedding.clone();

        let rec: (Uuid,) = sqlx::query_as(
            r#"
            INSERT INTO datasets (
                original_id,
                source_portal,
                url,
                title,
                description,
                embedding,
                metadata,
                content_hash,
                last_updated_at
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
            ON CONFLICT (source_portal, original_id)
            DO UPDATE SET
                title = EXCLUDED.title,
                description = EXCLUDED.description,
                url = EXCLUDED.url,
                embedding = COALESCE(EXCLUDED.embedding, datasets.embedding),
                metadata = EXCLUDED.metadata,
                content_hash = EXCLUDED.content_hash,
                last_updated_at = NOW()
            RETURNING id
            "#,
        )
        .bind(&new_data.original_id)
        .bind(&new_data.source_portal)
        .bind(&new_data.url)
        .bind(&new_data.title)
        .bind(&new_data.description)
        .bind(embedding_vector)
        .bind(serde_json::to_value(&new_data.metadata).unwrap_or(serde_json::json!({})))
        .bind(&new_data.content_hash)
        .fetch_one(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(rec.0)
    }

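    /// Returns a map from `original_id` to `content_hash` for every dataset
    /// stored for the given portal.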
    pub async fn get_hashes_for_portal(
        &self,
        portal_url: &str,
    ) -> Result<HashMap<String, Option<String>>, AppError> {
        let rows: Vec<HashRow> = sqlx::query_as(
            r#"
            SELECT original_id, content_hash
            FROM datasets
            WHERE source_portal = $1
            "#,
        )
        .bind(portal_url)
        .fetch_all(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        let hash_map: HashMap<String, Option<String>> = rows
            .into_iter()
            .map(|row| (row.original_id, row.content_hash))
            .collect();

        Ok(hash_map)
    }

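    /// Touches `last_updated_at` for a single dataset without rewriting its
    /// content. Returns `true` if a matching row was updated.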
    pub async fn update_timestamp_only(
        &self,
        portal_url: &str,
        original_id: &str,
    ) -> Result<bool, AppError> {
        let result = sqlx::query(
            r#"
            UPDATE datasets
            SET last_updated_at = NOW()
            WHERE source_portal = $1 AND original_id = $2
            "#,
        )
        .bind(portal_url)
        .bind(original_id)
        .execute(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(result.rows_affected() > 0)
    }

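    /// Fetches a single dataset by primary key, or `None` if it does not exist.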
    pub async fn get(&self, id: Uuid) -> Result<Option<Dataset>, AppError> {
        let query = format!("SELECT {} FROM datasets WHERE id = $1", DATASET_COLUMNS);
        let result = sqlx::query_as::<_, Dataset>(&query)
            .bind(id)
            .fetch_optional(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;

        Ok(result)
    }

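    /// Vector-similarity search: ranks datasets by cosine distance to the
    /// query embedding (`<=>`) and reports `1 - distance` as the similarity
    /// score. Rows without an embedding are excluded.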
    pub async fn search(
        &self,
        query_vector: Vector,
        limit: usize,
    ) -> Result<Vec<SearchResult>, AppError> {
        let query = format!(
            "SELECT {}, 1 - (embedding <=> $1) as similarity_score FROM datasets WHERE embedding IS NOT NULL ORDER BY embedding <=> $1 LIMIT $2",
            DATASET_COLUMNS
        );
        let results = sqlx::query_as::<_, SearchResultRow>(&query)
            .bind(query_vector)
            .bind(limit as i64)
            .fetch_all(&self.pool)
            .await
            .map_err(AppError::DatabaseError)?;

        Ok(results
            .into_iter()
            .map(|row| SearchResult {
                dataset: Dataset {
                    id: row.id,
                    original_id: row.original_id,
                    source_portal: row.source_portal,
                    url: row.url,
                    title: row.title,
                    description: row.description,
                    embedding: row.embedding,
                    metadata: row.metadata,
                    first_seen_at: row.first_seen_at,
                    last_updated_at: row.last_updated_at,
                    content_hash: row.content_hash,
                },
                similarity_score: row.similarity_score as f32,
            })
            .collect())
    }

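    /// Lists datasets ordered by `last_updated_at` (newest first), optionally
    /// filtered by source portal. `limit` defaults to 10,000 when not given.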
    pub async fn list_all(
        &self,
        portal_filter: Option<&str>,
        limit: Option<usize>,
    ) -> Result<Vec<Dataset>, AppError> {
        let limit_val = limit.unwrap_or(10000) as i64;

        let datasets = if let Some(portal) = portal_filter {
            let query = format!(
                "SELECT {} FROM datasets WHERE source_portal = $1 ORDER BY last_updated_at DESC LIMIT $2",
                DATASET_COLUMNS
            );
            sqlx::query_as::<_, Dataset>(&query)
                .bind(portal)
                .bind(limit_val)
                .fetch_all(&self.pool)
                .await
                .map_err(AppError::DatabaseError)?
        } else {
            let query = format!(
                "SELECT {} FROM datasets ORDER BY last_updated_at DESC LIMIT $1",
                DATASET_COLUMNS
            );
            sqlx::query_as::<_, Dataset>(&query)
                .bind(limit_val)
                .fetch_all(&self.pool)
                .await
                .map_err(AppError::DatabaseError)?
        };

        Ok(datasets)
    }

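    /// Aggregate counts over the whole table: total datasets, datasets with an
    /// embedding, distinct portals, and the most recent update timestamp.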
    pub async fn get_stats(&self) -> Result<DatabaseStats, AppError> {
        let row: StatsRow = sqlx::query_as(
            r#"
            SELECT
                COUNT(*) as total,
                COUNT(embedding) as with_embeddings,
                COUNT(DISTINCT source_portal) as portals,
                MAX(last_updated_at) as last_update
            FROM datasets
            "#,
        )
        .fetch_one(&self.pool)
        .await
        .map_err(AppError::DatabaseError)?;

        Ok(DatabaseStats {
            total_datasets: row.total.unwrap_or(0),
            datasets_with_embeddings: row.with_embeddings.unwrap_or(0),
            total_portals: row.portals.unwrap_or(0),
            last_update: row.last_update,
        })
    }
}

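/// Row shape for the aggregate query in [`DatasetRepository::get_stats`].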
#[derive(sqlx::FromRow)]
struct StatsRow {
    total: Option<i64>,
    with_embeddings: Option<i64>,
    portals: Option<i64>,
    last_update: Option<DateTime<Utc>>,
}

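/// Flat row returned by the similarity-search query: the full dataset columns
/// plus the computed `similarity_score`.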
#[derive(sqlx::FromRow)]
struct SearchResultRow {
    id: Uuid,
    original_id: String,
    source_portal: String,
    url: String,
    title: String,
    description: Option<String>,
    embedding: Option<Vector>,
    metadata: Json<serde_json::Value>,
    first_seen_at: DateTime<Utc>,
    last_updated_at: DateTime<Utc>,
    content_hash: Option<String>,
    similarity_score: f64,
}

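/// Minimal projection used by [`DatasetRepository::get_hashes_for_portal`].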
#[derive(sqlx::FromRow)]
struct HashRow {
    original_id: String,
    content_hash: Option<String>,
}

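// These tests only exercise in-memory construction and serialization; none of
// them require a running database.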
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_new_dataset_structure() {
        let title = "Test Dataset";
        let description = Some("Test description".to_string());
        let content_hash = NewDataset::compute_content_hash(title, description.as_deref());

        let new_dataset = NewDataset {
            original_id: "test-id".to_string(),
            source_portal: "https://example.com".to_string(),
            url: "https://example.com/dataset/test".to_string(),
            title: title.to_string(),
            description,
            embedding: Some(Vector::from(vec![0.1, 0.2, 0.3])),
            metadata: json!({"key": "value"}),
            content_hash,
        };

        assert_eq!(new_dataset.original_id, "test-id");
        assert_eq!(new_dataset.title, "Test Dataset");
        assert!(new_dataset.embedding.is_some());
        assert_eq!(new_dataset.content_hash.len(), 64);
    }

    #[test]
    fn test_embedding_vector_conversion() {
        let vec_f32 = vec![0.1_f32, 0.2, 0.3, 0.4];
        let vector = Vector::from(vec_f32.clone());
        assert_eq!(vector.as_slice().len(), vec_f32.len());
    }

    #[test]
    fn test_metadata_serialization() {
        let metadata = json!({
            "organization": "test-org",
            "tags": ["tag1", "tag2"]
        });

        let serialized = serde_json::to_value(&metadata).unwrap();
        assert!(serialized.is_object());
        assert_eq!(serialized["organization"], "test-org");
    }
}