//! `codex_memory` — storage.rs: simple Postgres-backed storage repository.

use crate::error::Result;
use crate::models::{Memory, StorageStats};
use sqlx::postgres::PgRow;
use sqlx::{PgPool, Row};
use uuid::Uuid;
/// Simple storage repository for text data.
///
/// Thin wrapper around a Postgres connection pool; every method issues a
/// single query against the `memories` table.
pub struct Storage {
    // Postgres connection pool shared by all queries on this instance.
    pool: PgPool,
}
10
11impl Storage {
12    /// Create a new storage instance
13    pub fn new(pool: PgPool) -> Self {
14        Self { pool }
15    }
16
17    /// Store text with context and summary (deduplication by hash)
18    pub async fn store(
19        &self,
20        content: &str,
21        context: String,
22        summary: String,
23        tags: Option<Vec<String>>,
24    ) -> Result<Uuid> {
25        let memory = Memory::new(content.to_string(), context, summary, tags);
26
27        // Simple content deduplication based on content hash
28        let result: Uuid = sqlx::query_scalar(
29            r#"
30            INSERT INTO memories (id, content, content_hash, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
31            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
32            ON CONFLICT (content_hash) DO UPDATE SET
33                updated_at = EXCLUDED.updated_at
34            RETURNING id
35            "#
36        )
37        .bind(memory.id)
38        .bind(memory.content)
39        .bind(memory.content_hash)
40        .bind(&memory.tags)
41        .bind(&memory.context)
42        .bind(&memory.summary)
43        .bind(memory.chunk_index)
44        .bind(memory.total_chunks)
45        .bind(memory.parent_id)
46        .bind(memory.created_at)
47        .bind(memory.updated_at)
48        .fetch_one(&self.pool)
49        .await?;
50
51        Ok(result)
52    }
53
54    /// Store a chunk with parent reference
55    pub async fn store_chunk(
56        &self,
57        content: &str,
58        context: String,
59        summary: String,
60        tags: Option<Vec<String>>,
61        chunk_index: i32,
62        total_chunks: i32,
63        parent_id: Uuid,
64    ) -> Result<Uuid> {
65        let memory = Memory::new_chunk(
66            content.to_string(),
67            context,
68            summary,
69            tags,
70            chunk_index,
71            total_chunks,
72            parent_id,
73        );
74
75        // Insert chunk (no deduplication for chunks to preserve order)
76        let result: Uuid = sqlx::query_scalar(
77            r#"
78            INSERT INTO memories (id, content, content_hash, context_fingerprint, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
79            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
80            RETURNING id
81            "#
82        )
83        .bind(memory.id)
84        .bind(memory.content)
85        .bind(memory.content_hash)
86        .bind(memory.context_fingerprint)
87        .bind(&memory.tags)
88        .bind(&memory.context)
89        .bind(&memory.summary)
90        .bind(memory.chunk_index)
91        .bind(memory.total_chunks)
92        .bind(memory.parent_id)
93        .bind(memory.created_at)
94        .bind(memory.updated_at)
95        .fetch_one(&self.pool)
96        .await?;
97
98        Ok(result)
99    }
100
101    /// Get memory by ID
102    pub async fn get(&self, id: Uuid) -> Result<Option<Memory>> {
103        let row = sqlx::query(
104            r#"
105            SELECT 
106                id,
107                content,
108                content_hash,
109                context_fingerprint,
110                tags,
111                context,
112                summary,
113                chunk_index,
114                total_chunks,
115                parent_id,
116                created_at,
117                updated_at
118            FROM memories
119            WHERE id = $1
120            "#,
121        )
122        .bind(id)
123        .fetch_optional(&self.pool)
124        .await?;
125
126        match row {
127            Some(row) => {
128                let memory = Memory {
129                    id: row.get("id"),
130                    content: row.get("content"),
131                    content_hash: row.get("content_hash"),
132                    context_fingerprint: row.get("context_fingerprint"),
133                    tags: row.get("tags"),
134                    context: row.get("context"),
135                    summary: row.get("summary"),
136                    chunk_index: row.get("chunk_index"),
137                    total_chunks: row.get("total_chunks"),
138                    parent_id: row.get("parent_id"),
139                    created_at: row.get("created_at"),
140                    updated_at: row.get("updated_at"),
141                };
142                Ok(Some(memory))
143            }
144            None => Ok(None),
145        }
146    }
147
148    /// Get all chunks for a parent document, ordered by chunk index
149    pub async fn get_chunks(&self, parent_id: Uuid) -> Result<Vec<Memory>> {
150        let rows = sqlx::query(
151            r#"
152            SELECT 
153                id,
154                content,
155                content_hash,
156                context_fingerprint,
157                tags,
158                context,
159                summary,
160                chunk_index,
161                total_chunks,
162                parent_id,
163                created_at,
164                updated_at
165            FROM memories
166            WHERE parent_id = $1
167            ORDER BY chunk_index ASC
168            "#,
169        )
170        .bind(parent_id)
171        .fetch_all(&self.pool)
172        .await?;
173
174        let memories = rows
175            .into_iter()
176            .map(|row| Memory {
177                id: row.get("id"),
178                content: row.get("content"),
179                content_hash: row.get("content_hash"),
180                context_fingerprint: row.get("context_fingerprint"),
181                tags: row.get("tags"),
182                context: row.get("context"),
183                summary: row.get("summary"),
184                chunk_index: row.get("chunk_index"),
185                total_chunks: row.get("total_chunks"),
186                parent_id: row.get("parent_id"),
187                created_at: row.get("created_at"),
188                updated_at: row.get("updated_at"),
189            })
190            .collect();
191
192        Ok(memories)
193    }
194
195    /// Delete memory by ID
196    pub async fn delete(&self, id: Uuid) -> Result<bool> {
197        let result = sqlx::query("DELETE FROM memories WHERE id = $1")
198            .bind(id)
199            .execute(&self.pool)
200            .await?;
201
202        Ok(result.rows_affected() > 0)
203    }
204
205    /// Get basic storage statistics
206    pub async fn stats(&self) -> Result<StorageStats> {
207        let row = sqlx::query(
208            r#"
209            SELECT 
210                COUNT(*) as total_memories,
211                pg_size_pretty(pg_total_relation_size('memories')) as table_size,
212                MAX(created_at) as last_memory_created
213            FROM memories
214            "#,
215        )
216        .fetch_one(&self.pool)
217        .await?;
218
219        let stats = StorageStats {
220            total_memories: row.get("total_memories"),
221            table_size: row.get("table_size"),
222            last_memory_created: row.get("last_memory_created"),
223        };
224
225        Ok(stats)
226    }
227
228    /// List recent memories (for basic browsing)
229    pub async fn list_recent(&self, limit: i64) -> Result<Vec<Memory>> {
230        let rows = sqlx::query(
231            r#"
232            SELECT 
233                id,
234                content,
235                content_hash,
236                context_fingerprint,
237                tags,
238                context,
239                summary,
240                chunk_index,
241                total_chunks,
242                parent_id,
243                created_at,
244                updated_at
245            FROM memories
246            ORDER BY created_at DESC
247            LIMIT $1
248            "#,
249        )
250        .bind(limit)
251        .fetch_all(&self.pool)
252        .await?;
253
254        let memories = rows
255            .into_iter()
256            .map(|row| Memory {
257                id: row.get("id"),
258                content: row.get("content"),
259                content_hash: row.get("content_hash"),
260                context_fingerprint: row.get("context_fingerprint"),
261                tags: row.get("tags"),
262                context: row.get("context"),
263                summary: row.get("summary"),
264                chunk_index: row.get("chunk_index"),
265                total_chunks: row.get("total_chunks"),
266                parent_id: row.get("parent_id"),
267                created_at: row.get("created_at"),
268                updated_at: row.get("updated_at"),
269            })
270            .collect();
271
272        Ok(memories)
273    }
274
275    /// Find memories with similar content but different contexts
276    /// Implements transfer appropriate processing - matching content with varying contexts
277    pub async fn find_similar_content(
278        &self,
279        content_hash: &str,
280        limit: i64,
281    ) -> Result<Vec<Memory>> {
282        let rows = sqlx::query(
283            r#"
284            SELECT 
285                id,
286                content,
287                content_hash,
288                context_fingerprint,
289                tags,
290                context,
291                summary,
292                chunk_index,
293                total_chunks,
294                parent_id,
295                created_at,
296                updated_at
297            FROM memories
298            WHERE content_hash = $1
299            ORDER BY created_at DESC
300            LIMIT $2
301            "#,
302        )
303        .bind(content_hash)
304        .bind(limit)
305        .fetch_all(&self.pool)
306        .await?;
307
308        let memories = rows
309            .into_iter()
310            .map(|row| Memory {
311                id: row.get("id"),
312                content: row.get("content"),
313                content_hash: row.get("content_hash"),
314                context_fingerprint: row.get("context_fingerprint"),
315                tags: row.get("tags"),
316                context: row.get("context"),
317                summary: row.get("summary"),
318                chunk_index: row.get("chunk_index"),
319                total_chunks: row.get("total_chunks"),
320                parent_id: row.get("parent_id"),
321                created_at: row.get("created_at"),
322                updated_at: row.get("updated_at"),
323            })
324            .collect();
325
326        Ok(memories)
327    }
328
329    /// Check if a specific content+context combination already exists
330    /// Used for precise deduplication while preserving context variations
331    pub async fn exists_with_context(
332        &self,
333        content_hash: &str,
334        context_fingerprint: &str,
335    ) -> Result<bool> {
336        let count: i64 = sqlx::query_scalar(
337            "SELECT COUNT(*) FROM memories WHERE content_hash = $1 AND context_fingerprint = $2",
338        )
339        .bind(content_hash)
340        .bind(context_fingerprint)
341        .fetch_one(&self.pool)
342        .await?;
343
344        Ok(count > 0)
345    }
346
347    /// Get context statistics showing how many different contexts exist for the same content
348    /// Useful for understanding encoding specificity utilization
349    pub async fn get_context_stats(&self) -> Result<Vec<(String, i64, i64)>> {
350        let rows = sqlx::query(
351            r#"
352            SELECT 
353                content_hash,
354                COUNT(*) as total_variations,
355                COUNT(DISTINCT context_fingerprint) as unique_contexts
356            FROM memories 
357            GROUP BY content_hash
358            HAVING COUNT(*) > 1
359            ORDER BY total_variations DESC
360            LIMIT 50
361            "#,
362        )
363        .fetch_all(&self.pool)
364        .await?;
365
366        let stats = rows
367            .into_iter()
368            .map(|row| {
369                (
370                    row.get::<String, _>("content_hash"),
371                    row.get::<i64, _>("total_variations"),
372                    row.get::<i64, _>("unique_contexts"),
373                )
374            })
375            .collect();
376
377        Ok(stats)
378    }
379}