codex_memory/memory/
repository.rs

1use super::error::{MemoryError, Result};
2use super::models::*;
3use chrono::Utc;
4use pgvector::Vector;
5use sqlx::{PgPool, Postgres, Row, Transaction};
6use std::collections::HashMap;
7use std::time::Instant;
8use tracing::{debug, info};
9use uuid::Uuid;
10
11pub struct MemoryRepository {
12    pool: PgPool,
13}
14
15impl MemoryRepository {
16    pub fn new(pool: PgPool) -> Self {
17        Self { pool }
18    }
19
    /// Returns a shared reference to the connection pool backing this repository.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }
23
24    pub async fn create_memory(&self, request: CreateMemoryRequest) -> Result<Memory> {
25        let id = Uuid::new_v4();
26        let content_hash = Memory::calculate_content_hash(&request.content);
27        let tier = request.tier.unwrap_or(MemoryTier::Working);
28
29        // Check for duplicates (skip in test mode)
30        let skip_duplicate_check =
31            std::env::var("SKIP_DUPLICATE_CHECK").unwrap_or_else(|_| "false".to_string()) == "true";
32
33        if !skip_duplicate_check {
34            let duplicate_exists = sqlx::query_scalar::<_, bool>(
35                "SELECT EXISTS(SELECT 1 FROM memories WHERE content_hash = $1 AND tier = $2 AND status = 'active')"
36            )
37            .bind(&content_hash)
38            .bind(tier)
39            .fetch_one(&self.pool)
40            .await?;
41
42            if duplicate_exists {
43                return Err(MemoryError::DuplicateContent {
44                    tier: format!("{tier:?}"),
45                });
46            }
47        }
48
49        let embedding = request.embedding.map(Vector::from);
50
51        let memory = sqlx::query_as::<_, Memory>(
52            r#"
53            INSERT INTO memories (
54                id, content, content_hash, embedding, tier, status, 
55                importance_score, metadata, parent_id, expires_at
56            )
57            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
58            RETURNING *
59            "#,
60        )
61        .bind(id)
62        .bind(&request.content)
63        .bind(&content_hash)
64        .bind(embedding)
65        .bind(tier)
66        .bind(MemoryStatus::Active)
67        .bind(request.importance_score.unwrap_or(0.5))
68        .bind(request.metadata.unwrap_or(serde_json::json!({})))
69        .bind(request.parent_id)
70        .bind(request.expires_at)
71        .fetch_one(&self.pool)
72        .await?;
73
74        info!("Created memory {} in tier {:?}", memory.id, memory.tier);
75        Ok(memory)
76    }
77
78    pub async fn get_memory(&self, id: Uuid) -> Result<Memory> {
79        let memory = sqlx::query_as::<_, Memory>(
80            r#"
81            UPDATE memories 
82            SET access_count = access_count + 1, 
83                last_accessed_at = NOW()
84            WHERE id = $1 AND status = 'active'
85            RETURNING *
86            "#,
87        )
88        .bind(id)
89        .fetch_optional(&self.pool)
90        .await?
91        .ok_or_else(|| MemoryError::NotFound { id: id.to_string() })?;
92
93        debug!("Retrieved memory {} from tier {:?}", id, memory.tier);
94        Ok(memory)
95    }
96
97    pub async fn update_memory(&self, id: Uuid, request: UpdateMemoryRequest) -> Result<Memory> {
98        let mut tx = self.pool.begin().await?;
99
100        // Get current memory
101        let current = sqlx::query_as::<_, Memory>(
102            "SELECT * FROM memories WHERE id = $1 AND status = 'active' FOR UPDATE",
103        )
104        .bind(id)
105        .fetch_optional(&mut *tx)
106        .await?
107        .ok_or_else(|| MemoryError::NotFound { id: id.to_string() })?;
108
109        // Update fields
110        let content = request.content.as_ref().unwrap_or(&current.content);
111        let content_hash = if request.content.is_some() {
112            Memory::calculate_content_hash(content)
113        } else {
114            current.content_hash.clone()
115        };
116
117        let embedding = request.embedding.map(Vector::from).or(current.embedding);
118        let tier = request.tier.unwrap_or(current.tier);
119        let importance_score = request.importance_score.unwrap_or(current.importance_score);
120        let metadata = request.metadata.as_ref().unwrap_or(&current.metadata);
121        let expires_at = request.expires_at.or(current.expires_at);
122
123        let updated = sqlx::query_as::<_, Memory>(
124            r#"
125            UPDATE memories 
126            SET content = $2, content_hash = $3, embedding = $4, tier = $5,
127                importance_score = $6, metadata = $7, expires_at = $8,
128                updated_at = NOW()
129            WHERE id = $1
130            RETURNING *
131            "#,
132        )
133        .bind(id)
134        .bind(content)
135        .bind(&content_hash)
136        .bind(embedding)
137        .bind(tier)
138        .bind(importance_score)
139        .bind(metadata)
140        .bind(expires_at)
141        .fetch_one(&mut *tx)
142        .await?;
143
144        // Record tier migration if changed
145        if current.tier != tier {
146            self.record_migration(
147                &mut tx,
148                id,
149                current.tier,
150                tier,
151                Some("Manual update".to_string()),
152            )
153            .await?;
154        }
155
156        tx.commit().await?;
157        info!("Updated memory {}", id);
158        Ok(updated)
159    }
160
161    pub async fn delete_memory(&self, id: Uuid) -> Result<()> {
162        let result = sqlx::query(
163            "UPDATE memories SET status = 'deleted' WHERE id = $1 AND status = 'active'",
164        )
165        .bind(id)
166        .execute(&self.pool)
167        .await?;
168
169        if result.rows_affected() == 0 {
170            return Err(MemoryError::NotFound { id: id.to_string() });
171        }
172
173        info!("Soft deleted memory {}", id);
174        Ok(())
175    }
176
177    pub async fn search_memories(&self, request: SearchRequest) -> Result<SearchResponse> {
178        let start_time = Instant::now();
179
180        let search_type = request
181            .search_type
182            .as_ref()
183            .unwrap_or(&SearchType::Semantic)
184            .clone();
185        let limit = request.limit.unwrap_or(10);
186        let offset = request.offset.unwrap_or(0);
187
188        let results = match search_type {
189            SearchType::Semantic => self.semantic_search(&request).await?,
190            SearchType::Temporal => self.temporal_search(&request).await?,
191            SearchType::Hybrid => self.hybrid_search(&request).await?,
192            SearchType::FullText => self.fulltext_search(&request).await?,
193        };
194
195        let total_count = if request.include_facets.unwrap_or(false) {
196            Some(self.count_search_results(&request).await?)
197        } else {
198            None
199        };
200
201        let facets = if request.include_facets.unwrap_or(false) {
202            Some(self.generate_search_facets(&request).await?)
203        } else {
204            None
205        };
206
207        let suggestions = if request.query_text.is_some() {
208            Some(self.generate_query_suggestions(&request).await?)
209        } else {
210            None
211        };
212
213        let next_cursor = if results.len() as i32 >= limit {
214            Some(self.generate_cursor(offset + limit as i64, &request))
215        } else {
216            None
217        };
218
219        let execution_time_ms = start_time.elapsed().as_millis() as u64;
220
221        Ok(SearchResponse {
222            results,
223            total_count,
224            facets,
225            suggestions,
226            next_cursor,
227            execution_time_ms,
228        })
229    }
230
231    async fn semantic_search(&self, request: &SearchRequest) -> Result<Vec<SearchResult>> {
232        let query_embedding = if let Some(ref embedding) = request.query_embedding {
233            Vector::from(embedding.clone())
234        } else {
235            return Err(MemoryError::InvalidRequest {
236                message: "Query embedding is required for semantic search".to_string(),
237            });
238        };
239
240        let limit = request.limit.unwrap_or(10);
241        let offset = request.offset.unwrap_or(0);
242        let threshold = request.similarity_threshold.unwrap_or(0.7);
243
244        let mut query_parts = vec![
245            "SELECT m.*, 1 - (m.embedding <=> $1) as similarity_score".to_string(),
246            "FROM memories m".to_string(),
247            "WHERE m.status = 'active' AND m.embedding IS NOT NULL".to_string(),
248        ];
249
250        // Add filters
251        self.add_filters(request, &mut query_parts)?;
252
253        query_parts.push(format!("AND 1 - (m.embedding <=> $1) >= {threshold}"));
254        query_parts.push("ORDER BY similarity_score DESC".to_string());
255        query_parts.push(format!("LIMIT {limit} OFFSET {offset}"));
256
257        let query = query_parts.join(" ");
258        let rows = sqlx::query(&query)
259            .bind(&query_embedding)
260            .fetch_all(&self.pool)
261            .await?;
262
263        self.build_search_results(rows, request).await
264    }
265
266    async fn temporal_search(&self, request: &SearchRequest) -> Result<Vec<SearchResult>> {
267        let limit = request.limit.unwrap_or(10);
268        let offset = request.offset.unwrap_or(0);
269
270        let mut query_parts = vec![
271            "SELECT m.*, 0.0 as similarity_score".to_string(),
272            "FROM memories m".to_string(),
273            "WHERE m.status = 'active'".to_string(),
274        ];
275
276        // Add filters
277        self.add_filters(request, &mut query_parts)?;
278
279        query_parts.push("ORDER BY m.created_at DESC, m.updated_at DESC".to_string());
280        query_parts.push(format!("LIMIT {limit} OFFSET {offset}"));
281
282        let query = query_parts.join(" ");
283        let rows = sqlx::query(&query).fetch_all(&self.pool).await?;
284
285        self.build_search_results(rows, request).await
286    }
287
288    async fn hybrid_search(&self, request: &SearchRequest) -> Result<Vec<SearchResult>> {
289        let weights = request.hybrid_weights.as_ref().unwrap_or(&HybridWeights {
290            semantic_weight: 0.4,
291            temporal_weight: 0.3,
292            importance_weight: 0.2,
293            access_frequency_weight: 0.1,
294        });
295
296        let query_embedding = if let Some(ref embedding) = request.query_embedding {
297            Vector::from(embedding.clone())
298        } else {
299            return Err(MemoryError::InvalidRequest {
300                message: "Query embedding is required for hybrid search".to_string(),
301            });
302        };
303
304        let limit = request.limit.unwrap_or(10);
305        let offset = request.offset.unwrap_or(0);
306        let threshold = request.similarity_threshold.unwrap_or(0.5);
307
308        let query = format!(
309            r#"
310            SELECT m.*,
311                1 - (m.embedding <=> $1) as similarity_score,
312                EXTRACT(EPOCH FROM (NOW() - m.created_at))::float / 86400 as days_old,
313                m.importance_score,
314                COALESCE(m.access_count, 0) as access_count,
315                (
316                    {} * (1 - (m.embedding <=> $1)) +
317                    {} * GREATEST(0, 1 - (EXTRACT(EPOCH FROM (NOW() - m.created_at))::float / 2592000)) + -- 30 days
318                    {} * m.importance_score +
319                    {} * LEAST(1.0, COALESCE(m.access_count, 0)::float / 100.0)
320                ) as combined_score
321            FROM memories m
322            WHERE m.status = 'active'
323                AND m.embedding IS NOT NULL
324                AND 1 - (m.embedding <=> $1) >= {}
325            ORDER BY combined_score DESC
326            LIMIT {} OFFSET {}
327            "#,
328            weights.semantic_weight,
329            weights.temporal_weight,
330            weights.importance_weight,
331            weights.access_frequency_weight,
332            threshold,
333            limit,
334            offset
335        );
336
337        let rows = sqlx::query(&query)
338            .bind(&query_embedding)
339            .fetch_all(&self.pool)
340            .await?;
341
342        self.build_search_results(rows, request).await
343    }
344
345    async fn fulltext_search(&self, request: &SearchRequest) -> Result<Vec<SearchResult>> {
346        let query_text =
347            request
348                .query_text
349                .as_ref()
350                .ok_or_else(|| MemoryError::InvalidRequest {
351                    message: "Query text is required for full-text search".to_string(),
352                })?;
353
354        let limit = request.limit.unwrap_or(10);
355        let offset = request.offset.unwrap_or(0);
356
357        let query = format!(
358            r#"
359            SELECT m.*,
360                ts_rank_cd(to_tsvector('english', m.content), plainto_tsquery('english', $1)) as similarity_score
361            FROM memories m
362            WHERE m.status = 'active'
363                AND to_tsvector('english', m.content) @@ plainto_tsquery('english', $1)
364            ORDER BY similarity_score DESC
365            LIMIT {limit} OFFSET {offset}
366            "#
367        );
368
369        let rows = sqlx::query(&query)
370            .bind(query_text)
371            .fetch_all(&self.pool)
372            .await?;
373
374        self.build_search_results(rows, request).await
375    }
376
377    fn add_filters(&self, request: &SearchRequest, query_parts: &mut Vec<String>) -> Result<()> {
378        if let Some(tier) = &request.tier {
379            query_parts.push(format!("AND m.tier = '{tier:?}'"));
380        }
381
382        if let Some(date_range) = &request.date_range {
383            if let Some(start) = &date_range.start {
384                query_parts.push(format!(
385                    "AND m.created_at >= '{}'",
386                    start.format("%Y-%m-%d %H:%M:%S")
387                ));
388            }
389            if let Some(end) = &date_range.end {
390                query_parts.push(format!(
391                    "AND m.created_at <= '{}'",
392                    end.format("%Y-%m-%d %H:%M:%S")
393                ));
394            }
395        }
396
397        if let Some(importance_range) = &request.importance_range {
398            if let Some(min) = importance_range.min {
399                query_parts.push(format!("AND m.importance_score >= {min}"));
400            }
401            if let Some(max) = importance_range.max {
402                query_parts.push(format!("AND m.importance_score <= {max}"));
403            }
404        }
405
406        Ok(())
407    }
408
409    async fn build_search_results(
410        &self,
411        rows: Vec<sqlx::postgres::PgRow>,
412        request: &SearchRequest,
413    ) -> Result<Vec<SearchResult>> {
414        let mut results = Vec::new();
415        let explain_score = request.explain_score.unwrap_or(false);
416
417        for row in rows {
418            let memory = Memory {
419                id: row.try_get("id")?,
420                content: row.try_get("content")?,
421                content_hash: row.try_get("content_hash")?,
422                embedding: row.try_get("embedding")?,
423                tier: row.try_get("tier")?,
424                status: row.try_get("status")?,
425                importance_score: row.try_get("importance_score")?,
426                access_count: row.try_get("access_count")?,
427                last_accessed_at: row.try_get("last_accessed_at")?,
428                metadata: row.try_get("metadata")?,
429                parent_id: row.try_get("parent_id")?,
430                created_at: row.try_get("created_at")?,
431                updated_at: row.try_get("updated_at")?,
432                expires_at: row.try_get("expires_at")?,
433            };
434
435            let similarity_score: f32 = row.try_get("similarity_score").unwrap_or(0.0);
436            let combined_score: f32 = row.try_get("combined_score").unwrap_or(similarity_score);
437            let temporal_score: Option<f32> = row.try_get("temporal_score").ok();
438            let access_frequency_score: Option<f32> = row.try_get("access_frequency_score").ok();
439            let importance_score = memory.importance_score; // Extract before move
440
441            let score_explanation = if explain_score {
442                Some(ScoreExplanation {
443                    semantic_contribution: similarity_score * 0.4,
444                    temporal_contribution: temporal_score.unwrap_or(0.0) * 0.3,
445                    importance_contribution: (importance_score * 0.2) as f32,
446                    access_frequency_contribution: access_frequency_score.unwrap_or(0.0) * 0.1,
447                    total_score: combined_score,
448                    factors: vec![
449                        "semantic similarity".to_string(),
450                        "recency".to_string(),
451                        "importance".to_string(),
452                    ],
453                })
454            } else {
455                None
456            };
457
458            results.push(SearchResult {
459                memory,
460                similarity_score,
461                temporal_score,
462                importance_score,
463                access_frequency_score,
464                combined_score,
465                score_explanation,
466            });
467        }
468
469        debug!("Built {} search results", results.len());
470        Ok(results)
471    }
472
473    async fn count_search_results(&self, _request: &SearchRequest) -> Result<i64> {
474        // Simplified count - would implement filtering logic similar to search
475        let count: i64 =
476            sqlx::query_scalar("SELECT COUNT(*) FROM memories WHERE status = 'active'")
477                .fetch_one(&self.pool)
478                .await?;
479        Ok(count)
480    }
481
482    async fn generate_search_facets(&self, _request: &SearchRequest) -> Result<SearchFacets> {
483        // Generate tier facets
484        let tier_rows: Vec<(String, i64)> = sqlx::query_as(
485            "SELECT tier, COUNT(*) FROM memories WHERE status = 'active' GROUP BY tier",
486        )
487        .fetch_all(&self.pool)
488        .await?;
489
490        let mut tiers = HashMap::new();
491        for (tier_str, count) in tier_rows {
492            if let Ok(tier) = tier_str.parse::<MemoryTier>() {
493                tiers.insert(tier, count);
494            }
495        }
496
497        // Generate date histogram (simplified)
498        let date_histogram = vec![DateBucket {
499            date: Utc::now(),
500            count: 10,
501            interval: "day".to_string(),
502        }];
503
504        // Generate importance ranges
505        let importance_ranges = vec![
506            ImportanceRange {
507                min: 0.0,
508                max: 0.3,
509                count: 5,
510                label: "Low".to_string(),
511            },
512            ImportanceRange {
513                min: 0.3,
514                max: 0.7,
515                count: 15,
516                label: "Medium".to_string(),
517            },
518            ImportanceRange {
519                min: 0.7,
520                max: 1.0,
521                count: 8,
522                label: "High".to_string(),
523            },
524        ];
525
526        Ok(SearchFacets {
527            tiers,
528            date_histogram,
529            importance_ranges,
530            tags: HashMap::new(), // Would extract from metadata
531        })
532    }
533
534    async fn generate_query_suggestions(&self, _request: &SearchRequest) -> Result<Vec<String>> {
535        // Simplified implementation - would use ML model or query history
536        Ok(vec![
537            "recent code changes".to_string(),
538            "function definitions".to_string(),
539            "error handling patterns".to_string(),
540        ])
541    }
542
543    fn generate_cursor(&self, offset: i64, _request: &SearchRequest) -> String {
544        // Simple cursor implementation - would encode more search state in production
545        use base64::{engine::general_purpose::STANDARD, Engine};
546        STANDARD.encode(format!("offset:{offset}"))
547    }
548
549    // Legacy method for backward compatibility
550    pub async fn search_memories_simple(
551        &self,
552        request: SearchRequest,
553    ) -> Result<Vec<SearchResult>> {
554        let response = self.search_memories(request).await?;
555        Ok(response.results)
556    }
557
558    pub async fn get_memories_by_tier(
559        &self,
560        tier: MemoryTier,
561        limit: Option<i64>,
562    ) -> Result<Vec<Memory>> {
563        let limit = limit.unwrap_or(100);
564
565        let memories = sqlx::query_as::<_, Memory>(
566            r#"
567            SELECT * FROM memories 
568            WHERE tier = $1 AND status = 'active'
569            ORDER BY importance_score DESC, updated_at DESC
570            LIMIT $2
571            "#,
572        )
573        .bind(tier)
574        .bind(limit)
575        .fetch_all(&self.pool)
576        .await?;
577
578        Ok(memories)
579    }
580
581    pub async fn migrate_memory(
582        &self,
583        id: Uuid,
584        to_tier: MemoryTier,
585        reason: Option<String>,
586    ) -> Result<Memory> {
587        let mut tx = self.pool.begin().await?;
588
589        // Get current memory with lock
590        let current = sqlx::query_as::<_, Memory>(
591            "SELECT * FROM memories WHERE id = $1 AND status = 'active' FOR UPDATE",
592        )
593        .bind(id)
594        .fetch_optional(&mut *tx)
595        .await?
596        .ok_or_else(|| MemoryError::NotFound { id: id.to_string() })?;
597
598        if current.tier == to_tier {
599            return Ok(current);
600        }
601
602        // Validate tier transition
603        let valid_transition = match (current.tier, to_tier) {
604            (MemoryTier::Working, MemoryTier::Warm)
605            | (MemoryTier::Working, MemoryTier::Cold)
606            | (MemoryTier::Warm, MemoryTier::Cold)
607            | (MemoryTier::Warm, MemoryTier::Working)
608            | (MemoryTier::Cold, MemoryTier::Warm) => true,
609            _ => false,
610        };
611
612        if !valid_transition {
613            return Err(MemoryError::InvalidTierTransition {
614                from: format!("{:?}", current.tier),
615                to: format!("{to_tier:?}"),
616            });
617        }
618
619        let start = std::time::Instant::now();
620
621        // Update tier
622        let updated = sqlx::query_as::<_, Memory>(
623            r#"
624            UPDATE memories 
625            SET tier = $2, status = 'active', updated_at = NOW()
626            WHERE id = $1
627            RETURNING *
628            "#,
629        )
630        .bind(id)
631        .bind(to_tier)
632        .fetch_one(&mut *tx)
633        .await?;
634
635        let duration_ms = start.elapsed().as_millis() as i32;
636
637        // Record migration
638        self.record_migration(&mut tx, id, current.tier, to_tier, reason)
639            .await?;
640
641        tx.commit().await?;
642
643        info!(
644            "Migrated memory {} from {:?} to {:?} in {}ms",
645            id, current.tier, to_tier, duration_ms
646        );
647
648        Ok(updated)
649    }
650
651    async fn record_migration(
652        &self,
653        tx: &mut Transaction<'_, Postgres>,
654        memory_id: Uuid,
655        from_tier: MemoryTier,
656        to_tier: MemoryTier,
657        reason: Option<String>,
658    ) -> Result<()> {
659        sqlx::query(
660            r#"
661            INSERT INTO migration_history (memory_id, from_tier, to_tier, migration_reason, success)
662            VALUES ($1, $2, $3, $4, true)
663            "#,
664        )
665        .bind(memory_id)
666        .bind(from_tier)
667        .bind(to_tier)
668        .bind(reason)
669        .execute(&mut **tx)
670        .await?;
671
672        Ok(())
673    }
674
675    pub async fn get_expired_memories(&self) -> Result<Vec<Memory>> {
676        let memories = sqlx::query_as::<_, Memory>(
677            r#"
678            SELECT * FROM memories 
679            WHERE status = 'active' 
680                AND expires_at IS NOT NULL 
681                AND expires_at < NOW()
682            "#,
683        )
684        .fetch_all(&self.pool)
685        .await?;
686
687        Ok(memories)
688    }
689
690    pub async fn cleanup_expired_memories(&self) -> Result<u64> {
691        let result = sqlx::query(
692            r#"
693            UPDATE memories 
694            SET status = 'deleted' 
695            WHERE status = 'active' 
696                AND expires_at IS NOT NULL 
697                AND expires_at < NOW()
698            "#,
699        )
700        .execute(&self.pool)
701        .await?;
702
703        let count = result.rows_affected();
704        if count > 0 {
705            info!("Cleaned up {} expired memories", count);
706        }
707
708        Ok(count)
709    }
710
711    pub async fn get_migration_candidates(
712        &self,
713        tier: MemoryTier,
714        limit: i64,
715    ) -> Result<Vec<Memory>> {
716        let query = match tier {
717            MemoryTier::Working => {
718                r#"
719                SELECT * FROM memories 
720                WHERE tier = 'working' 
721                    AND status = 'active'
722                    AND (
723                        importance_score < 0.3 
724                        OR (last_accessed_at IS NOT NULL 
725                            AND last_accessed_at < NOW() - INTERVAL '24 hours')
726                    )
727                ORDER BY importance_score ASC, last_accessed_at ASC NULLS FIRST
728                LIMIT $1
729                "#
730            }
731            MemoryTier::Warm => {
732                r#"
733                SELECT * FROM memories 
734                WHERE tier = 'warm' 
735                    AND status = 'active'
736                    AND importance_score < 0.1 
737                    AND updated_at < NOW() - INTERVAL '7 days'
738                ORDER BY importance_score ASC, updated_at ASC
739                LIMIT $1
740                "#
741            }
742            MemoryTier::Cold => {
743                return Ok(Vec::new());
744            }
745        };
746
747        let memories = sqlx::query_as::<_, Memory>(query)
748            .bind(limit)
749            .fetch_all(&self.pool)
750            .await?;
751
752        Ok(memories)
753    }
754
755    pub async fn get_statistics(&self) -> Result<MemoryStatistics> {
756        let stats = sqlx::query_as::<_, MemoryStatistics>(
757            r#"
758            SELECT 
759                COUNT(*) FILTER (WHERE tier = 'working' AND status = 'active') as working_count,
760                COUNT(*) FILTER (WHERE tier = 'warm' AND status = 'active') as warm_count,
761                COUNT(*) FILTER (WHERE tier = 'cold' AND status = 'active') as cold_count,
762                COUNT(*) FILTER (WHERE status = 'active') as total_active,
763                COUNT(*) FILTER (WHERE status = 'deleted') as total_deleted,
764                AVG(importance_score) FILTER (WHERE status = 'active') as avg_importance,
765                MAX(access_count) FILTER (WHERE status = 'active') as max_access_count,
766                CAST(AVG(access_count) FILTER (WHERE status = 'active') AS FLOAT8) as avg_access_count
767            FROM memories
768            "#,
769        )
770        .fetch_one(&self.pool)
771        .await?;
772
773        Ok(stats)
774    }
775}
776
/// Aggregate counters over the `memories` table, as produced by
/// `MemoryRepository::get_statistics`.
///
/// Every field is optional because it is decoded from a SQL aggregate column.
#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize, serde::Deserialize)]
pub struct MemoryStatistics {
    /// Number of active memories in the working tier.
    pub working_count: Option<i64>,
    /// Number of active memories in the warm tier.
    pub warm_count: Option<i64>,
    /// Number of active memories in the cold tier.
    pub cold_count: Option<i64>,
    /// Number of active memories across all tiers.
    pub total_active: Option<i64>,
    /// Number of soft-deleted memories.
    pub total_deleted: Option<i64>,
    /// Mean importance score of active memories.
    pub avg_importance: Option<f64>,
    /// Highest access count among active memories.
    pub max_access_count: Option<i32>,
    /// Mean access count of active memories.
    pub avg_access_count: Option<f64>,
}
788
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashing the same content twice must give the same 64-character digest.
    #[test]
    fn test_content_hash_generation() {
        let content = "This is a test memory content";
        let first = Memory::calculate_content_hash(content);
        let second = Memory::calculate_content_hash(content);

        assert_eq!(first, second);
        assert_eq!(first.len(), 64); // SHA-256 produces 64 hex characters
    }

    /// Table-driven check of the migration decision for tier/importance pairs.
    #[test]
    fn test_should_migrate() {
        let mut memory = Memory::default();

        let cases = [
            // Working tier with low importance should migrate
            (MemoryTier::Working, 0.2, true),
            // Working tier with high importance should not migrate
            (MemoryTier::Working, 0.8, false),
            // Cold tier should never migrate
            (MemoryTier::Cold, 0.0, false),
        ];

        for (tier, importance, expected) in cases {
            memory.tier = tier;
            memory.importance_score = importance;
            assert!(memory.should_migrate() == expected);
        }
    }

    /// Each tier demotes to the next colder one; the cold tier has no successor.
    #[test]
    fn test_next_tier() {
        let mut memory = Memory::default();

        let cases = [
            (MemoryTier::Working, Some(MemoryTier::Warm)),
            (MemoryTier::Warm, Some(MemoryTier::Cold)),
            (MemoryTier::Cold, None),
        ];

        for (tier, expected) in cases {
            memory.tier = tier;
            assert_eq!(memory.next_tier(), expected);
        }
    }
}
835}