1use std::time::Duration;
5
6use sha2::{Digest, Sha256};
7use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
8use sqlx::SqlitePool;
9
10use crate::types::{WebSearchError, WebSearchResult};
11
12const SCHEMA_VERSION: u32 = 1;
15
16pub struct WebSearchCache {
17 pool: SqlitePool,
18 ttl: Duration,
19}
20
21impl WebSearchCache {
22 pub async fn open(path: &str, ttl: Duration) -> Result<Self, WebSearchError> {
23 let opts = SqliteConnectOptions::new()
24 .filename(path)
25 .create_if_missing(true);
26 let max_conns = if path == ":memory:" { 1 } else { 4 };
30 let pool = SqlitePoolOptions::new()
31 .max_connections(max_conns)
32 .connect_with(opts)
33 .await
34 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
35 sqlx::query(
36 "CREATE TABLE IF NOT EXISTS web_search_cache (\
37 key TEXT PRIMARY KEY,\
38 provider TEXT NOT NULL,\
39 query TEXT NOT NULL,\
40 result_json TEXT NOT NULL,\
41 inserted_at INTEGER NOT NULL\
42 )",
43 )
44 .execute(&pool)
45 .await
46 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
47 sqlx::query("CREATE INDEX IF NOT EXISTS idx_web_search_cache_inserted ON web_search_cache(inserted_at)")
48 .execute(&pool)
49 .await
50 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
51 Ok(Self { pool, ttl })
52 }
53
54 pub async fn open_memory(ttl: Duration) -> Result<Self, WebSearchError> {
56 Self::open(":memory:", ttl).await
57 }
58
59 pub fn ttl(&self) -> Duration {
60 self.ttl
61 }
62
63 pub fn key(provider: &str, query: &str, params_canonical: &str) -> String {
64 let mut h = Sha256::new();
65 h.update(SCHEMA_VERSION.to_le_bytes());
66 h.update(provider.as_bytes());
67 h.update(b"\0");
68 h.update(query.as_bytes());
69 h.update(b"\0");
70 h.update(params_canonical.as_bytes());
71 hex::encode(h.finalize())
72 }
73
74 pub async fn get(&self, key: &str) -> Result<Option<WebSearchResult>, WebSearchError> {
75 if self.ttl.as_secs() == 0 {
76 return Ok(None);
77 }
78 let now = chrono::Utc::now().timestamp();
79 let cutoff = now - self.ttl.as_secs() as i64;
80 let row: Option<(String, i64)> =
81 sqlx::query_as("SELECT result_json, inserted_at FROM web_search_cache WHERE key = ?")
82 .bind(key)
83 .fetch_optional(&self.pool)
84 .await
85 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
86 match row {
87 Some((json, inserted_at)) if inserted_at >= cutoff => {
88 let mut parsed: WebSearchResult = serde_json::from_str(&json)
89 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
90 parsed.from_cache = true;
91 Ok(Some(parsed))
92 }
93 _ => Ok(None),
94 }
95 }
96
97 pub async fn put(&self, key: &str, value: &WebSearchResult) -> Result<(), WebSearchError> {
98 let json =
99 serde_json::to_string(value).map_err(|e| WebSearchError::Cache(e.to_string()))?;
100 sqlx::query(
101 "INSERT OR REPLACE INTO web_search_cache(key, provider, query, result_json, inserted_at) VALUES(?, ?, ?, ?, ?)",
102 )
103 .bind(key)
104 .bind(&value.provider)
105 .bind(&value.query)
106 .bind(json)
107 .bind(chrono::Utc::now().timestamp())
108 .execute(&self.pool)
109 .await
110 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
111 Ok(())
112 }
113
114 pub async fn purge_expired(&self) -> Result<u64, WebSearchError> {
115 let cutoff = chrono::Utc::now().timestamp() - self.ttl.as_secs() as i64;
116 let res = sqlx::query("DELETE FROM web_search_cache WHERE inserted_at < ?")
117 .bind(cutoff)
118 .execute(&self.pool)
119 .await
120 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
121 Ok(res.rows_affected())
122 }
123
124 pub async fn clear(&self) -> Result<u64, WebSearchError> {
130 let res = sqlx::query("DELETE FROM web_search_cache")
131 .execute(&self.pool)
132 .await
133 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
134 Ok(res.rows_affected())
135 }
136
137 pub async fn stats(&self) -> Result<CacheStats, WebSearchError> {
142 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM web_search_cache")
143 .fetch_one(&self.pool)
144 .await
145 .map_err(|e| WebSearchError::Cache(e.to_string()))?;
146 Ok(CacheStats {
147 entries: row.0.max(0) as u64,
148 })
149 }
150}
151
152#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
154pub struct CacheStats {
155 pub entries: u64,
156}