1use crate::error::Result;
2use crate::models::{Memory, StorageStats};
3use sqlx::{PgPool, Row};
4use uuid::Uuid;
5
/// Postgres-backed persistence layer for `Memory` records.
///
/// Every query in this type targets the `memories` table through the shared
/// connection pool handed in at construction time.
pub struct Storage {
    // Shared sqlx connection pool; all queries borrow it via `&self.pool`.
    pool: PgPool,
}
10
11impl Storage {
    /// Create a `Storage` backed by the given Postgres connection pool.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }
16
17 pub async fn store(
19 &self,
20 content: &str,
21 context: String,
22 summary: String,
23 tags: Option<Vec<String>>,
24 ) -> Result<Uuid> {
25 let memory = Memory::new(content.to_string(), context, summary, tags);
26
27 let result: Uuid = sqlx::query_scalar(
29 r#"
30 INSERT INTO memories (id, content, content_hash, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
31 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
32 ON CONFLICT (content_hash) DO UPDATE SET
33 context = EXCLUDED.context,
34 summary = EXCLUDED.summary,
35 tags = EXCLUDED.tags,
36 updated_at = EXCLUDED.updated_at
37 RETURNING id
38 "#
39 )
40 .bind(memory.id)
41 .bind(memory.content)
42 .bind(memory.content_hash)
43 .bind(&memory.tags)
44 .bind(&memory.context)
45 .bind(&memory.summary)
46 .bind(memory.chunk_index)
47 .bind(memory.total_chunks)
48 .bind(memory.parent_id)
49 .bind(memory.created_at)
50 .bind(memory.updated_at)
51 .fetch_one(&self.pool)
52 .await?;
53
54 Ok(result)
55 }
56
57 pub async fn store_chunk(
59 &self,
60 content: &str,
61 context: String,
62 summary: String,
63 tags: Option<Vec<String>>,
64 chunk_index: i32,
65 total_chunks: i32,
66 parent_id: Uuid,
67 ) -> Result<Uuid> {
68 let memory = Memory::new_chunk(
69 content.to_string(),
70 context,
71 summary,
72 tags,
73 chunk_index,
74 total_chunks,
75 parent_id,
76 );
77
78 let result: Uuid = sqlx::query_scalar(
80 r#"
81 INSERT INTO memories (id, content, content_hash, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
82 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
83 RETURNING id
84 "#
85 )
86 .bind(memory.id)
87 .bind(memory.content)
88 .bind(memory.content_hash)
89 .bind(&memory.tags)
90 .bind(&memory.context)
91 .bind(&memory.summary)
92 .bind(memory.chunk_index)
93 .bind(memory.total_chunks)
94 .bind(memory.parent_id)
95 .bind(memory.created_at)
96 .bind(memory.updated_at)
97 .fetch_one(&self.pool)
98 .await?;
99
100 Ok(result)
101 }
102
103 pub async fn get(&self, id: Uuid) -> Result<Option<Memory>> {
105 let row = sqlx::query(
106 r#"
107 SELECT
108 id,
109 content,
110 content_hash,
111 tags,
112 context,
113 summary,
114 chunk_index,
115 total_chunks,
116 parent_id,
117 created_at,
118 updated_at
119 FROM memories
120 WHERE id = $1
121 "#,
122 )
123 .bind(id)
124 .fetch_optional(&self.pool)
125 .await?;
126
127 match row {
128 Some(row) => {
129 let memory = Memory {
130 id: row.get("id"),
131 content: row.get("content"),
132 content_hash: row.get("content_hash"),
133 tags: row.get("tags"),
134 context: row.get("context"),
135 summary: row.get("summary"),
136 chunk_index: row.get("chunk_index"),
137 total_chunks: row.get("total_chunks"),
138 parent_id: row.get("parent_id"),
139 created_at: row.get("created_at"),
140 updated_at: row.get("updated_at"),
141 };
142 Ok(Some(memory))
143 }
144 None => Ok(None),
145 }
146 }
147
148 pub async fn get_chunks(&self, parent_id: Uuid) -> Result<Vec<Memory>> {
150 let rows = sqlx::query(
151 r#"
152 SELECT
153 id,
154 content,
155 content_hash,
156 tags,
157 context,
158 summary,
159 chunk_index,
160 total_chunks,
161 parent_id,
162 created_at,
163 updated_at
164 FROM memories
165 WHERE parent_id = $1
166 ORDER BY chunk_index ASC
167 "#,
168 )
169 .bind(parent_id)
170 .fetch_all(&self.pool)
171 .await?;
172
173 let memories = rows
174 .into_iter()
175 .map(|row| Memory {
176 id: row.get("id"),
177 content: row.get("content"),
178 content_hash: row.get("content_hash"),
179 tags: row.get("tags"),
180 context: row.get("context"),
181 summary: row.get("summary"),
182 chunk_index: row.get("chunk_index"),
183 total_chunks: row.get("total_chunks"),
184 parent_id: row.get("parent_id"),
185 created_at: row.get("created_at"),
186 updated_at: row.get("updated_at"),
187 })
188 .collect();
189
190 Ok(memories)
191 }
192
193 pub async fn delete(&self, id: Uuid) -> Result<bool> {
195 let result = sqlx::query("DELETE FROM memories WHERE id = $1")
196 .bind(id)
197 .execute(&self.pool)
198 .await?;
199
200 Ok(result.rows_affected() > 0)
201 }
202
203 pub async fn stats(&self) -> Result<StorageStats> {
205 let row = sqlx::query(
206 r#"
207 SELECT
208 COUNT(*) as total_memories,
209 pg_size_pretty(pg_total_relation_size('memories')) as table_size,
210 MAX(created_at) as last_memory_created
211 FROM memories
212 "#,
213 )
214 .fetch_one(&self.pool)
215 .await?;
216
217 let stats = StorageStats {
218 total_memories: row.get("total_memories"),
219 table_size: row.get("table_size"),
220 last_memory_created: row.get("last_memory_created"),
221 };
222
223 Ok(stats)
224 }
225
226 pub async fn list_recent(&self, limit: i64) -> Result<Vec<Memory>> {
228 let rows = sqlx::query(
229 r#"
230 SELECT
231 id,
232 content,
233 content_hash,
234 tags,
235 context,
236 summary,
237 chunk_index,
238 total_chunks,
239 parent_id,
240 created_at,
241 updated_at
242 FROM memories
243 ORDER BY created_at DESC
244 LIMIT $1
245 "#,
246 )
247 .bind(limit)
248 .fetch_all(&self.pool)
249 .await?;
250
251 let memories = rows
252 .into_iter()
253 .map(|row| Memory {
254 id: row.get("id"),
255 content: row.get("content"),
256 content_hash: row.get("content_hash"),
257 tags: row.get("tags"),
258 context: row.get("context"),
259 summary: row.get("summary"),
260 chunk_index: row.get("chunk_index"),
261 total_chunks: row.get("total_chunks"),
262 parent_id: row.get("parent_id"),
263 created_at: row.get("created_at"),
264 updated_at: row.get("updated_at"),
265 })
266 .collect();
267
268 Ok(memories)
269 }
270
271 pub async fn find_similar_content(
274 &self,
275 content_hash: &str,
276 limit: i64,
277 ) -> Result<Vec<Memory>> {
278 let rows = sqlx::query(
279 r#"
280 SELECT
281 id,
282 content,
283 content_hash,
284 tags,
285 context,
286 summary,
287 chunk_index,
288 total_chunks,
289 parent_id,
290 created_at,
291 updated_at
292 FROM memories
293 WHERE content_hash = $1
294 ORDER BY created_at DESC
295 LIMIT $2
296 "#,
297 )
298 .bind(content_hash)
299 .bind(limit)
300 .fetch_all(&self.pool)
301 .await?;
302
303 let memories = rows
304 .into_iter()
305 .map(|row| Memory {
306 id: row.get("id"),
307 content: row.get("content"),
308 content_hash: row.get("content_hash"),
309 tags: row.get("tags"),
310 context: row.get("context"),
311 summary: row.get("summary"),
312 chunk_index: row.get("chunk_index"),
313 total_chunks: row.get("total_chunks"),
314 parent_id: row.get("parent_id"),
315 created_at: row.get("created_at"),
316 updated_at: row.get("updated_at"),
317 })
318 .collect();
319
320 Ok(memories)
321 }
322
323 pub async fn exists_with_content(&self, content_hash: &str) -> Result<bool> {
326 let count: i64 =
327 sqlx::query_scalar("SELECT COUNT(*) FROM memories WHERE content_hash = $1")
328 .bind(content_hash)
329 .fetch_one(&self.pool)
330 .await?;
331
332 Ok(count > 0)
333 }
334
335 pub async fn get_content_stats(&self) -> Result<Vec<(String, i64)>> {
338 let rows = sqlx::query(
339 r#"
340 SELECT
341 content_hash,
342 COUNT(*) as total_count
343 FROM memories
344 GROUP BY content_hash
345 HAVING COUNT(*) > 1
346 ORDER BY total_count DESC
347 LIMIT 50
348 "#,
349 )
350 .fetch_all(&self.pool)
351 .await?;
352
353 let stats = rows
354 .into_iter()
355 .map(|row| {
356 (
357 row.get::<String, _>("content_hash"),
358 row.get::<i64, _>("total_count"),
359 )
360 })
361 .collect();
362
363 Ok(stats)
364 }
365}