1use std::fmt::Write as _;
9
10use rusqlite::{OptionalExtension, params};
11use sha2::{Digest, Sha256};
12
13use super::{Db, StorageError};
14
/// One cached page as stored in the `pages` table.
///
/// Written by [`upsert`] and read back by [`get_by_url_hash`] / [`get_by_url`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Page {
    /// Row key; [`upsert`] uses it as the `ON CONFLICT` target.
    pub url_hash: String,
    /// URL the page was requested under; [`get_by_url`] looks rows up by it.
    pub url: String,
    /// Canonical form of the URL (as determined by the caller).
    pub canonical_url: String,
    /// Page title, if one was extracted.
    pub title: Option<String>,
    /// Fetch timestamp; [`list_paginated`] orders by it, newest first.
    /// NOTE(review): units look like Unix seconds (tests use ~1.7e9) — confirm.
    pub fetched_at: i64,
    /// Expiry timestamp; rows with `None` are never counted as expired by [`stats`].
    pub expires_at: Option<i64>,
    /// Stored validator string; presumably the HTTP `ETag` header — TODO confirm.
    pub etag: Option<String>,
    /// Stored validator string; presumably the HTTP `Last-Modified` header — TODO confirm.
    pub last_modified: Option<String>,
    /// Hash of the content, stored verbatim (tests use a `"sha256:..."` form).
    pub content_hash: String,
    /// Extracted Markdown body.
    pub extracted_md: String,
    /// Optional metadata serialized as a JSON string.
    pub metadata_json: Option<String>,
    /// Uncompressed raw HTML. Effectively write-only: [`upsert`] zstd-compresses it
    /// into the `raw_html_zstd` column, but the read helpers always return `None`
    /// here. Use [`raw_html_bytes`] to fetch the (still compressed) blob.
    pub raw_html: Option<Vec<u8>>,
    /// Why this page was rendered, if it was (tests use `"bot_challenge"`).
    pub render_reason: Option<String>,
}
46
/// Aggregate figures about the whole `pages` table, produced by [`stats`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of rows in `pages`.
    pub entry_count: u64,
    /// Total size of all `extracted_md` values, as computed by [`stats`].
    pub total_extracted_bytes: u64,
    /// Rows whose `expires_at` is set and is at or before the `now` passed to [`stats`].
    pub expired_count: u64,
}
54
/// Lightweight summary row returned by [`list_paginated`] (no page body).
#[derive(Debug, Clone)]
pub struct CacheListEntry {
    /// URL the page was requested under.
    pub url: String,
    /// Canonical form of the URL.
    pub canonical_url: String,
    /// Page title, if one was extracted.
    pub title: Option<String>,
    /// Fetch timestamp (the list is ordered by this, newest first).
    pub fetched_at: i64,
    /// Expiry timestamp, if any.
    pub expires_at: Option<i64>,
    /// Size of the stored `extracted_md`, as computed by [`list_paginated`].
    pub size_bytes: i64,
}
65
66pub fn url_hash(url: &str) -> String {
71 let mut h = Sha256::new();
72 h.update(url.as_bytes());
73 let out = h.finalize();
74 let mut s = String::with_capacity(out.len() * 2);
75 for b in out {
76 write!(s, "{b:02x}").expect("write to String never fails");
78 }
79 s
80}
81
/// Column list shared by the page lookups.
///
/// The order here must match the positional indices read in [`row_to_page`]
/// (0 = `url_hash` … 11 = `render_reason`). `raw_html_zstd` is not included,
/// so page lookups never load the compressed HTML blob.
const SELECT_COLUMNS: &str = "url_hash, url, canonical_url, title, fetched_at, expires_at, \
    etag, last_modified, content_hash, extracted_md, metadata_json, render_reason";
86
87fn row_to_page(row: &rusqlite::Row<'_>) -> rusqlite::Result<Page> {
88 Ok(Page {
89 url_hash: row.get(0)?,
90 url: row.get(1)?,
91 canonical_url: row.get(2)?,
92 title: row.get(3)?,
93 fetched_at: row.get(4)?,
94 expires_at: row.get(5)?,
95 etag: row.get(6)?,
96 last_modified: row.get(7)?,
97 content_hash: row.get(8)?,
98 extracted_md: row.get(9)?,
99 metadata_json: row.get(10)?,
100 render_reason: row.get(11)?,
101 raw_html: None,
104 })
105}
106
107pub async fn get_by_url_hash(db: &Db, hash: &str) -> Result<Option<Page>, StorageError> {
109 let hash = hash.to_owned();
110 let page = db
111 .conn
112 .call(move |c| {
113 c.query_row(
114 &format!("SELECT {SELECT_COLUMNS} FROM pages WHERE url_hash = ?1"),
115 params![hash],
116 row_to_page,
117 )
118 .optional()
119 })
120 .await?;
121 Ok(page)
122}
123
124pub async fn get_by_url(db: &Db, url: &str) -> Result<Option<Page>, StorageError> {
130 let url = url.to_owned();
131 let page = db
132 .conn
133 .call(move |c| {
134 c.query_row(
135 &format!("SELECT {SELECT_COLUMNS} FROM pages WHERE url = ?1 LIMIT 1"),
136 params![url],
137 row_to_page,
138 )
139 .optional()
140 })
141 .await?;
142 Ok(page)
143}
144
145pub async fn upsert(db: &Db, page: Page) -> Result<(), StorageError> {
147 let raw_zstd: Option<Vec<u8>> = match page.raw_html.as_ref() {
151 Some(bytes) => Some(zstd::stream::encode_all(bytes.as_slice(), 3).map_err(|e| {
152 StorageError::from(rusqlite::Error::ToSqlConversionFailure(Box::new(e)))
157 })?),
158 None => None,
159 };
160 db.conn
161 .call(move |c| {
162 c.execute(
163 "INSERT INTO pages (url_hash, url, canonical_url, title, fetched_at, \
164 expires_at, etag, last_modified, content_hash, \
165 extracted_md, metadata_json, raw_html_zstd, render_reason) \
166 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13) \
167 ON CONFLICT(url_hash) DO UPDATE SET \
168 url = excluded.url, \
169 canonical_url = excluded.canonical_url, \
170 title = excluded.title, \
171 fetched_at = excluded.fetched_at, \
172 expires_at = excluded.expires_at, \
173 etag = excluded.etag, \
174 last_modified = excluded.last_modified, \
175 content_hash = excluded.content_hash, \
176 extracted_md = excluded.extracted_md, \
177 metadata_json = excluded.metadata_json, \
178 raw_html_zstd = excluded.raw_html_zstd, \
179 render_reason = excluded.render_reason",
180 params![
181 page.url_hash,
182 page.url,
183 page.canonical_url,
184 page.title,
185 page.fetched_at,
186 page.expires_at,
187 page.etag,
188 page.last_modified,
189 page.content_hash,
190 page.extracted_md,
191 page.metadata_json,
192 raw_zstd,
193 page.render_reason,
194 ],
195 )?;
196 Ok(())
197 })
198 .await?;
199 Ok(())
200}
201
202pub async fn raw_html_bytes(db: &Db, url_hash: &str) -> Result<Option<Vec<u8>>, StorageError> {
209 let uh = url_hash.to_string();
210 let blob = db
211 .conn
212 .call(move |c| {
213 let r: rusqlite::Result<Option<Vec<u8>>> = c.query_row(
214 "SELECT raw_html_zstd FROM pages WHERE url_hash = ?1",
215 rusqlite::params![uh],
216 |row| row.get::<_, Option<Vec<u8>>>(0),
217 );
218 match r {
219 Ok(v) => Ok(v),
220 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
221 Err(e) => Err(e),
222 }
223 })
224 .await?;
225 Ok(blob)
226}
227
228pub async fn touch(
233 db: &Db,
234 url_hash: &str,
235 fetched_at: i64,
236 expires_at: Option<i64>,
237) -> Result<(), StorageError> {
238 let url_hash = url_hash.to_owned();
239 db.conn
240 .call(move |c| {
241 c.execute(
242 "UPDATE pages SET fetched_at = ?2, expires_at = ?3 WHERE url_hash = ?1",
243 params![url_hash, fetched_at, expires_at],
244 )?;
245 Ok(())
246 })
247 .await?;
248 Ok(())
249}
250
251pub async fn delete_by_url_like(db: &Db, like: &str) -> Result<u64, StorageError> {
255 let like = like.to_owned();
256 let n = db
257 .conn
258 .call(move |c| {
259 Ok(c.execute(
260 "DELETE FROM pages WHERE url LIKE ?1 ESCAPE '\\'",
261 params![like],
262 )? as u64)
263 })
264 .await?;
265 Ok(n)
266}
267
268pub async fn list_paginated(
270 db: &Db,
271 offset: u64,
272 limit: u64,
273) -> Result<Vec<CacheListEntry>, StorageError> {
274 let entries = db
275 .conn
276 .call(move |c| {
277 let mut stmt = c.prepare(
278 "SELECT url, canonical_url, title, fetched_at, expires_at, length(extracted_md) \
279 FROM pages \
280 ORDER BY fetched_at DESC \
281 LIMIT ?1 OFFSET ?2",
282 )?;
283 let rows = stmt
284 .query_map(params![limit as i64, offset as i64], |r| {
285 Ok(CacheListEntry {
286 url: r.get(0)?,
287 canonical_url: r.get(1)?,
288 title: r.get(2)?,
289 fetched_at: r.get(3)?,
290 expires_at: r.get(4)?,
291 size_bytes: r.get(5)?,
292 })
293 })?
294 .collect::<rusqlite::Result<Vec<_>>>()?;
295 Ok(rows)
296 })
297 .await?;
298 Ok(entries)
299}
300
301pub async fn stats(db: &Db, now: i64) -> Result<CacheStats, StorageError> {
306 let stats = db
307 .conn
308 .call(move |c| {
309 let entry_count: i64 = c.query_row("SELECT COUNT(*) FROM pages", [], |r| r.get(0))?;
310 let total_bytes: i64 = c.query_row(
311 "SELECT COALESCE(SUM(length(extracted_md)), 0) FROM pages",
312 [],
313 |r| r.get(0),
314 )?;
315 let expired_count: i64 = c.query_row(
316 "SELECT COUNT(*) FROM pages WHERE expires_at IS NOT NULL AND expires_at <= ?1",
317 params![now],
318 |r| r.get(0),
319 )?;
320 Ok(CacheStats {
321 entry_count: entry_count as u64,
322 total_extracted_bytes: total_bytes as u64,
323 expired_count: expired_count as u64,
324 })
325 })
326 .await?;
327 Ok(stats)
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully-populated page (no raw HTML, no render reason) for round-trip tests.
    fn sample(hash: &str, url: &str) -> Page {
        Page {
            url_hash: hash.to_owned(),
            url: url.to_owned(),
            canonical_url: url.to_owned(),
            title: Some("Sample".to_owned()),
            fetched_at: 1_700_000_000,
            expires_at: Some(1_700_003_600),
            etag: Some("\"abc\"".to_owned()),
            last_modified: None,
            content_hash: "sha256:deadbeef".to_owned(),
            extracted_md: "# Hello\n\nbody".to_owned(),
            metadata_json: None,
            raw_html: None,
            render_reason: None,
        }
    }

    /// Opens a fresh database inside a new temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the `Db` because dropping it deletes
    /// the directory, database file included. Callers must bind it (e.g. `_dir`, not
    /// `_`) so the directory outlives every use of the database in the test.
    async fn fresh_db() -> (Db, tempfile::TempDir) {
        let tmp = tempfile::tempdir().unwrap();
        let db = Db::open(tmp.path().join("rover.db")).await.unwrap();
        (db, tmp)
    }

    #[test]
    fn url_hash_is_hex_64() {
        let h = url_hash("https://example.com/");
        assert_eq!(h.len(), 64);
        assert!(h.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[tokio::test]
    async fn upsert_then_get() {
        let (db, _dir) = fresh_db().await;
        let page = sample("hash1", "https://example.com/page");
        upsert(&db, page.clone()).await.unwrap();
        let got = get_by_url_hash(&db, "hash1").await.unwrap().unwrap();
        assert_eq!(got, page);
    }

    #[tokio::test]
    async fn render_reason_round_trips() {
        let (db, _dir) = fresh_db().await;
        let mut page = sample("hash1", "https://example.com/spa");
        page.render_reason = Some("bot_challenge".to_owned());
        upsert(&db, page.clone()).await.unwrap();
        let got = get_by_url_hash(&db, "hash1").await.unwrap().unwrap();
        assert_eq!(got.render_reason.as_deref(), Some("bot_challenge"));
        assert_eq!(got, page);
    }

    #[tokio::test]
    async fn upsert_replaces_existing() {
        let (db, _dir) = fresh_db().await;
        let p1 = sample("hash1", "https://example.com/v1");
        let mut p2 = p1.clone();
        p2.url = "https://example.com/v2".to_owned();
        p2.fetched_at = 1_700_000_999;
        upsert(&db, p1).await.unwrap();
        upsert(&db, p2.clone()).await.unwrap();
        let got = get_by_url_hash(&db, "hash1").await.unwrap().unwrap();
        assert_eq!(got, p2);
    }

    #[tokio::test]
    async fn get_by_url_finds_secondary_lookup() {
        let (db, _dir) = fresh_db().await;
        upsert(&db, sample("hash1", "https://example.com/article"))
            .await
            .unwrap();
        let got = get_by_url(&db, "https://example.com/article")
            .await
            .unwrap();
        assert!(got.is_some());
    }

    #[tokio::test]
    async fn touch_updates_timestamps() {
        let (db, _dir) = fresh_db().await;
        upsert(&db, sample("hash1", "https://example.com/"))
            .await
            .unwrap();
        touch(&db, "hash1", 1_700_999_999, Some(1_700_999_999 + 3600))
            .await
            .unwrap();
        let got = get_by_url_hash(&db, "hash1").await.unwrap().unwrap();
        assert_eq!(got.fetched_at, 1_700_999_999);
        assert_eq!(got.expires_at, Some(1_700_999_999 + 3600));
    }

    #[tokio::test]
    async fn delete_by_url_like() {
        let (db, _dir) = fresh_db().await;
        upsert(&db, sample("h1", "https://docs.example.com/a"))
            .await
            .unwrap();
        upsert(&db, sample("h2", "https://docs.example.com/b"))
            .await
            .unwrap();
        upsert(&db, sample("h3", "https://other.com/c"))
            .await
            .unwrap();
        // `super::` is required: this test fn shadows the function under test.
        let n = super::delete_by_url_like(&db, "https://docs.example.com/%")
            .await
            .unwrap();
        assert_eq!(n, 2);
        assert!(get_by_url_hash(&db, "h1").await.unwrap().is_none());
        assert!(get_by_url_hash(&db, "h3").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn list_paginated_orders_by_recency() {
        let (db, _dir) = fresh_db().await;
        let mut a = sample("h_a", "https://a/");
        a.fetched_at = 100;
        let mut b = sample("h_b", "https://b/");
        b.fetched_at = 200;
        upsert(&db, a).await.unwrap();
        upsert(&db, b).await.unwrap();
        let rows = list_paginated(&db, 0, 10).await.unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].url, "https://b/");
        assert_eq!(rows[1].url, "https://a/");
    }

    #[tokio::test]
    async fn upsert_writes_raw_html_when_provided() {
        let (db, _dir) = fresh_db().await;
        let raw = b"<html>body</html>".to_vec();
        let mut page = sample("uhash", "https://example.com/p");
        page.raw_html = Some(raw.clone());
        upsert(&db, page).await.unwrap();

        let blob = raw_html_bytes(&db, "uhash")
            .await
            .unwrap()
            .expect("blob written");
        assert!(!blob.is_empty());
        let decoded = zstd::stream::decode_all(blob.as_slice()).unwrap();
        assert_eq!(decoded, raw);
    }

    #[tokio::test]
    async fn upsert_leaves_raw_html_null_when_none() {
        let (db, _dir) = fresh_db().await;
        let page = sample("uhash", "https://example.com/p");
        upsert(&db, page).await.unwrap();

        assert!(raw_html_bytes(&db, "uhash").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn stats_counts_expired() {
        let (db, _dir) = fresh_db().await;
        let mut fresh = sample("h_fresh", "https://a/");
        fresh.expires_at = Some(2_000_000_000);
        let mut stale = sample("h_stale", "https://b/");
        stale.expires_at = Some(1);
        upsert(&db, fresh).await.unwrap();
        upsert(&db, stale).await.unwrap();
        let s = stats(&db, 1_700_000_000).await.unwrap();
        assert_eq!(s.entry_count, 2);
        assert!(s.total_extracted_bytes > 0);
        assert_eq!(s.expired_count, 1);
    }
}