1use crate::error::Result;
2use crate::models::{Memory, StorageStats};
3use sqlx::{PgPool, Row};
4use uuid::Uuid;
5
6pub struct Storage {
8 pool: PgPool,
9}
10
11impl Storage {
12 pub fn new(pool: PgPool) -> Self {
14 Self { pool }
15 }
16
17 pub async fn store(
19 &self,
20 content: &str,
21 context: String,
22 summary: String,
23 tags: Option<Vec<String>>,
24 ) -> Result<Uuid> {
25 let memory = Memory::new(content.to_string(), context, summary, tags);
26
27 let result: Uuid = sqlx::query_scalar(
29 r#"
30 INSERT INTO memories (id, content, content_hash, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
31 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
32 ON CONFLICT (content_hash) DO UPDATE SET
33 updated_at = EXCLUDED.updated_at
34 RETURNING id
35 "#
36 )
37 .bind(memory.id)
38 .bind(memory.content)
39 .bind(memory.content_hash)
40 .bind(&memory.tags)
41 .bind(&memory.context)
42 .bind(&memory.summary)
43 .bind(memory.chunk_index)
44 .bind(memory.total_chunks)
45 .bind(memory.parent_id)
46 .bind(memory.created_at)
47 .bind(memory.updated_at)
48 .fetch_one(&self.pool)
49 .await?;
50
51 Ok(result)
52 }
53
54 pub async fn store_chunk(
56 &self,
57 content: &str,
58 context: String,
59 summary: String,
60 tags: Option<Vec<String>>,
61 chunk_index: i32,
62 total_chunks: i32,
63 parent_id: Uuid,
64 ) -> Result<Uuid> {
65 let memory = Memory::new_chunk(
66 content.to_string(),
67 context,
68 summary,
69 tags,
70 chunk_index,
71 total_chunks,
72 parent_id,
73 );
74
75 let result: Uuid = sqlx::query_scalar(
77 r#"
78 INSERT INTO memories (id, content, content_hash, context_fingerprint, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
79 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
80 RETURNING id
81 "#
82 )
83 .bind(memory.id)
84 .bind(memory.content)
85 .bind(memory.content_hash)
86 .bind(memory.context_fingerprint)
87 .bind(&memory.tags)
88 .bind(&memory.context)
89 .bind(&memory.summary)
90 .bind(memory.chunk_index)
91 .bind(memory.total_chunks)
92 .bind(memory.parent_id)
93 .bind(memory.created_at)
94 .bind(memory.updated_at)
95 .fetch_one(&self.pool)
96 .await?;
97
98 Ok(result)
99 }
100
101 pub async fn get(&self, id: Uuid) -> Result<Option<Memory>> {
103 let row = sqlx::query(
104 r#"
105 SELECT
106 id,
107 content,
108 content_hash,
109 context_fingerprint,
110 tags,
111 context,
112 summary,
113 chunk_index,
114 total_chunks,
115 parent_id,
116 created_at,
117 updated_at
118 FROM memories
119 WHERE id = $1
120 "#,
121 )
122 .bind(id)
123 .fetch_optional(&self.pool)
124 .await?;
125
126 match row {
127 Some(row) => {
128 let memory = Memory {
129 id: row.get("id"),
130 content: row.get("content"),
131 content_hash: row.get("content_hash"),
132 context_fingerprint: row.get("context_fingerprint"),
133 tags: row.get("tags"),
134 context: row.get("context"),
135 summary: row.get("summary"),
136 chunk_index: row.get("chunk_index"),
137 total_chunks: row.get("total_chunks"),
138 parent_id: row.get("parent_id"),
139 created_at: row.get("created_at"),
140 updated_at: row.get("updated_at"),
141 };
142 Ok(Some(memory))
143 }
144 None => Ok(None),
145 }
146 }
147
148 pub async fn get_chunks(&self, parent_id: Uuid) -> Result<Vec<Memory>> {
150 let rows = sqlx::query(
151 r#"
152 SELECT
153 id,
154 content,
155 content_hash,
156 context_fingerprint,
157 tags,
158 context,
159 summary,
160 chunk_index,
161 total_chunks,
162 parent_id,
163 created_at,
164 updated_at
165 FROM memories
166 WHERE parent_id = $1
167 ORDER BY chunk_index ASC
168 "#,
169 )
170 .bind(parent_id)
171 .fetch_all(&self.pool)
172 .await?;
173
174 let memories = rows
175 .into_iter()
176 .map(|row| Memory {
177 id: row.get("id"),
178 content: row.get("content"),
179 content_hash: row.get("content_hash"),
180 context_fingerprint: row.get("context_fingerprint"),
181 tags: row.get("tags"),
182 context: row.get("context"),
183 summary: row.get("summary"),
184 chunk_index: row.get("chunk_index"),
185 total_chunks: row.get("total_chunks"),
186 parent_id: row.get("parent_id"),
187 created_at: row.get("created_at"),
188 updated_at: row.get("updated_at"),
189 })
190 .collect();
191
192 Ok(memories)
193 }
194
195 pub async fn delete(&self, id: Uuid) -> Result<bool> {
197 let result = sqlx::query("DELETE FROM memories WHERE id = $1")
198 .bind(id)
199 .execute(&self.pool)
200 .await?;
201
202 Ok(result.rows_affected() > 0)
203 }
204
205 pub async fn stats(&self) -> Result<StorageStats> {
207 let row = sqlx::query(
208 r#"
209 SELECT
210 COUNT(*) as total_memories,
211 pg_size_pretty(pg_total_relation_size('memories')) as table_size,
212 MAX(created_at) as last_memory_created
213 FROM memories
214 "#,
215 )
216 .fetch_one(&self.pool)
217 .await?;
218
219 let stats = StorageStats {
220 total_memories: row.get("total_memories"),
221 table_size: row.get("table_size"),
222 last_memory_created: row.get("last_memory_created"),
223 };
224
225 Ok(stats)
226 }
227
228 pub async fn list_recent(&self, limit: i64) -> Result<Vec<Memory>> {
230 let rows = sqlx::query(
231 r#"
232 SELECT
233 id,
234 content,
235 content_hash,
236 context_fingerprint,
237 tags,
238 context,
239 summary,
240 chunk_index,
241 total_chunks,
242 parent_id,
243 created_at,
244 updated_at
245 FROM memories
246 ORDER BY created_at DESC
247 LIMIT $1
248 "#,
249 )
250 .bind(limit)
251 .fetch_all(&self.pool)
252 .await?;
253
254 let memories = rows
255 .into_iter()
256 .map(|row| Memory {
257 id: row.get("id"),
258 content: row.get("content"),
259 content_hash: row.get("content_hash"),
260 context_fingerprint: row.get("context_fingerprint"),
261 tags: row.get("tags"),
262 context: row.get("context"),
263 summary: row.get("summary"),
264 chunk_index: row.get("chunk_index"),
265 total_chunks: row.get("total_chunks"),
266 parent_id: row.get("parent_id"),
267 created_at: row.get("created_at"),
268 updated_at: row.get("updated_at"),
269 })
270 .collect();
271
272 Ok(memories)
273 }
274
275 pub async fn find_similar_content(
278 &self,
279 content_hash: &str,
280 limit: i64,
281 ) -> Result<Vec<Memory>> {
282 let rows = sqlx::query(
283 r#"
284 SELECT
285 id,
286 content,
287 content_hash,
288 context_fingerprint,
289 tags,
290 context,
291 summary,
292 chunk_index,
293 total_chunks,
294 parent_id,
295 created_at,
296 updated_at
297 FROM memories
298 WHERE content_hash = $1
299 ORDER BY created_at DESC
300 LIMIT $2
301 "#,
302 )
303 .bind(content_hash)
304 .bind(limit)
305 .fetch_all(&self.pool)
306 .await?;
307
308 let memories = rows
309 .into_iter()
310 .map(|row| Memory {
311 id: row.get("id"),
312 content: row.get("content"),
313 content_hash: row.get("content_hash"),
314 context_fingerprint: row.get("context_fingerprint"),
315 tags: row.get("tags"),
316 context: row.get("context"),
317 summary: row.get("summary"),
318 chunk_index: row.get("chunk_index"),
319 total_chunks: row.get("total_chunks"),
320 parent_id: row.get("parent_id"),
321 created_at: row.get("created_at"),
322 updated_at: row.get("updated_at"),
323 })
324 .collect();
325
326 Ok(memories)
327 }
328
329 pub async fn exists_with_context(
332 &self,
333 content_hash: &str,
334 context_fingerprint: &str,
335 ) -> Result<bool> {
336 let count: i64 = sqlx::query_scalar(
337 "SELECT COUNT(*) FROM memories WHERE content_hash = $1 AND context_fingerprint = $2",
338 )
339 .bind(content_hash)
340 .bind(context_fingerprint)
341 .fetch_one(&self.pool)
342 .await?;
343
344 Ok(count > 0)
345 }
346
347 pub async fn get_context_stats(&self) -> Result<Vec<(String, i64, i64)>> {
350 let rows = sqlx::query(
351 r#"
352 SELECT
353 content_hash,
354 COUNT(*) as total_variations,
355 COUNT(DISTINCT context_fingerprint) as unique_contexts
356 FROM memories
357 GROUP BY content_hash
358 HAVING COUNT(*) > 1
359 ORDER BY total_variations DESC
360 LIMIT 50
361 "#,
362 )
363 .fetch_all(&self.pool)
364 .await?;
365
366 let stats = rows
367 .into_iter()
368 .map(|row| {
369 (
370 row.get::<String, _>("content_hash"),
371 row.get::<i64, _>("total_variations"),
372 row.get::<i64, _>("unique_contexts"),
373 )
374 })
375 .collect();
376
377 Ok(stats)
378 }
379}