1use crate::error::Result;
2use crate::models::{Memory, StorageStats};
3use sqlx::{PgPool, Row};
4use uuid::Uuid;
5
6pub struct Storage {
8 pool: PgPool,
9}
10
11impl Storage {
12 pub fn new(pool: PgPool) -> Self {
14 Self { pool }
15 }
16
17 pub async fn store(
19 &self,
20 content: &str,
21 context: String,
22 summary: String,
23 tags: Option<Vec<String>>,
24 ) -> Result<Uuid> {
25 let memory = Memory::new(content.to_string(), context, summary, tags);
26
27 let result: Uuid = sqlx::query_scalar(
29 r#"
30 INSERT INTO memories (id, content, content_hash, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
31 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
32 ON CONFLICT (content_hash) DO UPDATE SET
33 updated_at = EXCLUDED.updated_at
34 RETURNING id
35 "#
36 )
37 .bind(memory.id)
38 .bind(memory.content)
39 .bind(memory.content_hash)
40 .bind(&memory.tags)
41 .bind(&memory.context)
42 .bind(&memory.summary)
43 .bind(memory.chunk_index)
44 .bind(memory.total_chunks)
45 .bind(memory.parent_id)
46 .bind(memory.created_at)
47 .bind(memory.updated_at)
48 .fetch_one(&self.pool)
49 .await?;
50
51 Ok(result)
52 }
53
54 pub async fn store_chunk(
56 &self,
57 content: &str,
58 context: String,
59 summary: String,
60 tags: Option<Vec<String>>,
61 chunk_index: i32,
62 total_chunks: i32,
63 parent_id: Uuid,
64 ) -> Result<Uuid> {
65 let memory = Memory::new_chunk(
66 content.to_string(),
67 context,
68 summary,
69 tags,
70 chunk_index,
71 total_chunks,
72 parent_id,
73 );
74
75 let result: Uuid = sqlx::query_scalar(
77 r#"
78 INSERT INTO memories (id, content, content_hash, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
79 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
80 RETURNING id
81 "#
82 )
83 .bind(memory.id)
84 .bind(memory.content)
85 .bind(memory.content_hash)
86 .bind(&memory.tags)
87 .bind(&memory.context)
88 .bind(&memory.summary)
89 .bind(memory.chunk_index)
90 .bind(memory.total_chunks)
91 .bind(memory.parent_id)
92 .bind(memory.created_at)
93 .bind(memory.updated_at)
94 .fetch_one(&self.pool)
95 .await?;
96
97 Ok(result)
98 }
99
100 pub async fn get(&self, id: Uuid) -> Result<Option<Memory>> {
102 let row = sqlx::query(
103 r#"
104 SELECT
105 id,
106 content,
107 content_hash,
108 tags,
109 context,
110 summary,
111 chunk_index,
112 total_chunks,
113 parent_id,
114 created_at,
115 updated_at
116 FROM memories
117 WHERE id = $1
118 "#,
119 )
120 .bind(id)
121 .fetch_optional(&self.pool)
122 .await?;
123
124 match row {
125 Some(row) => {
126 let memory = Memory {
127 id: row.get("id"),
128 content: row.get("content"),
129 content_hash: row.get("content_hash"),
130 tags: row.get("tags"),
131 context: row.get("context"),
132 summary: row.get("summary"),
133 chunk_index: row.get("chunk_index"),
134 total_chunks: row.get("total_chunks"),
135 parent_id: row.get("parent_id"),
136 created_at: row.get("created_at"),
137 updated_at: row.get("updated_at"),
138 };
139 Ok(Some(memory))
140 }
141 None => Ok(None),
142 }
143 }
144
145 pub async fn get_chunks(&self, parent_id: Uuid) -> Result<Vec<Memory>> {
147 let rows = sqlx::query(
148 r#"
149 SELECT
150 id,
151 content,
152 content_hash,
153 tags,
154 context,
155 summary,
156 chunk_index,
157 total_chunks,
158 parent_id,
159 created_at,
160 updated_at
161 FROM memories
162 WHERE parent_id = $1
163 ORDER BY chunk_index ASC
164 "#,
165 )
166 .bind(parent_id)
167 .fetch_all(&self.pool)
168 .await?;
169
170 let memories = rows
171 .into_iter()
172 .map(|row| Memory {
173 id: row.get("id"),
174 content: row.get("content"),
175 content_hash: row.get("content_hash"),
176 tags: row.get("tags"),
177 context: row.get("context"),
178 summary: row.get("summary"),
179 chunk_index: row.get("chunk_index"),
180 total_chunks: row.get("total_chunks"),
181 parent_id: row.get("parent_id"),
182 created_at: row.get("created_at"),
183 updated_at: row.get("updated_at"),
184 })
185 .collect();
186
187 Ok(memories)
188 }
189
190 pub async fn delete(&self, id: Uuid) -> Result<bool> {
192 let result = sqlx::query("DELETE FROM memories WHERE id = $1")
193 .bind(id)
194 .execute(&self.pool)
195 .await?;
196
197 Ok(result.rows_affected() > 0)
198 }
199
200 pub async fn stats(&self) -> Result<StorageStats> {
202 let row = sqlx::query(
203 r#"
204 SELECT
205 COUNT(*) as total_memories,
206 pg_size_pretty(pg_total_relation_size('memories')) as table_size,
207 MAX(created_at) as last_memory_created
208 FROM memories
209 "#,
210 )
211 .fetch_one(&self.pool)
212 .await?;
213
214 let stats = StorageStats {
215 total_memories: row.get("total_memories"),
216 table_size: row.get("table_size"),
217 last_memory_created: row.get("last_memory_created"),
218 };
219
220 Ok(stats)
221 }
222
223 pub async fn list_recent(&self, limit: i64) -> Result<Vec<Memory>> {
225 let rows = sqlx::query(
226 r#"
227 SELECT
228 id,
229 content,
230 content_hash,
231 tags,
232 context,
233 summary,
234 chunk_index,
235 total_chunks,
236 parent_id,
237 created_at,
238 updated_at
239 FROM memories
240 ORDER BY created_at DESC
241 LIMIT $1
242 "#,
243 )
244 .bind(limit)
245 .fetch_all(&self.pool)
246 .await?;
247
248 let memories = rows
249 .into_iter()
250 .map(|row| Memory {
251 id: row.get("id"),
252 content: row.get("content"),
253 content_hash: row.get("content_hash"),
254 tags: row.get("tags"),
255 context: row.get("context"),
256 summary: row.get("summary"),
257 chunk_index: row.get("chunk_index"),
258 total_chunks: row.get("total_chunks"),
259 parent_id: row.get("parent_id"),
260 created_at: row.get("created_at"),
261 updated_at: row.get("updated_at"),
262 })
263 .collect();
264
265 Ok(memories)
266 }
267
268 pub async fn find_similar_content(
271 &self,
272 content_hash: &str,
273 limit: i64,
274 ) -> Result<Vec<Memory>> {
275 let rows = sqlx::query(
276 r#"
277 SELECT
278 id,
279 content,
280 content_hash,
281 tags,
282 context,
283 summary,
284 chunk_index,
285 total_chunks,
286 parent_id,
287 created_at,
288 updated_at
289 FROM memories
290 WHERE content_hash = $1
291 ORDER BY created_at DESC
292 LIMIT $2
293 "#,
294 )
295 .bind(content_hash)
296 .bind(limit)
297 .fetch_all(&self.pool)
298 .await?;
299
300 let memories = rows
301 .into_iter()
302 .map(|row| Memory {
303 id: row.get("id"),
304 content: row.get("content"),
305 content_hash: row.get("content_hash"),
306 tags: row.get("tags"),
307 context: row.get("context"),
308 summary: row.get("summary"),
309 chunk_index: row.get("chunk_index"),
310 total_chunks: row.get("total_chunks"),
311 parent_id: row.get("parent_id"),
312 created_at: row.get("created_at"),
313 updated_at: row.get("updated_at"),
314 })
315 .collect();
316
317 Ok(memories)
318 }
319
320 pub async fn exists_with_content(&self, content_hash: &str) -> Result<bool> {
323 let count: i64 =
324 sqlx::query_scalar("SELECT COUNT(*) FROM memories WHERE content_hash = $1")
325 .bind(content_hash)
326 .fetch_one(&self.pool)
327 .await?;
328
329 Ok(count > 0)
330 }
331
332 pub async fn get_content_stats(&self) -> Result<Vec<(String, i64)>> {
335 let rows = sqlx::query(
336 r#"
337 SELECT
338 content_hash,
339 COUNT(*) as total_count
340 FROM memories
341 GROUP BY content_hash
342 HAVING COUNT(*) > 1
343 ORDER BY total_count DESC
344 LIMIT 50
345 "#,
346 )
347 .fetch_all(&self.pool)
348 .await?;
349
350 let stats = rows
351 .into_iter()
352 .map(|row| {
353 (
354 row.get::<String, _>("content_hash"),
355 row.get::<i64, _>("total_count"),
356 )
357 })
358 .collect();
359
360 Ok(stats)
361 }
362}