// codex_memory/memory/models.rs

1use chrono::{DateTime, Utc};
2use pgvector::Vector;
3use serde::{Deserialize, Serialize};
4use sqlx::FromRow;
5use std::str::FromStr;
6use uuid::Uuid;
7
8#[derive(Debug, Clone)]
9pub struct SerializableVector(pub Option<Vector>);
10
11impl Serialize for SerializableVector {
12    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13    where
14        S: serde::Serializer,
15    {
16        match &self.0 {
17            Some(v) => v.as_slice().serialize(serializer),
18            None => serializer.serialize_none(),
19        }
20    }
21}
22
23impl<'de> Deserialize<'de> for SerializableVector {
24    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25    where
26        D: serde::Deserializer<'de>,
27    {
28        let opt_vec: Option<Vec<f32>> = Option::deserialize(deserializer)?;
29        Ok(SerializableVector(opt_vec.map(Vector::from)))
30    }
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
34#[sqlx(type_name = "memory_tier", rename_all = "lowercase")]
35pub enum MemoryTier {
36    Working,
37    Warm,
38    Cold,
39}
40
41impl FromStr for MemoryTier {
42    type Err = String;
43
44    fn from_str(s: &str) -> Result<Self, Self::Err> {
45        match s.to_lowercase().as_str() {
46            "working" => Ok(MemoryTier::Working),
47            "warm" => Ok(MemoryTier::Warm),
48            "cold" => Ok(MemoryTier::Cold),
49            _ => Err(format!("Invalid memory tier: {s}")),
50        }
51    }
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
55#[sqlx(type_name = "memory_status", rename_all = "lowercase")]
56pub enum MemoryStatus {
57    Active,
58    Migrating,
59    Archived,
60    Deleted,
61}
62
63#[derive(Debug, Clone, FromRow)]
64pub struct Memory {
65    pub id: Uuid,
66    pub content: String,
67    pub content_hash: String,
68    pub embedding: Option<Vector>,
69    pub tier: MemoryTier,
70    pub status: MemoryStatus,
71    pub importance_score: f32,
72    pub access_count: i32,
73    pub last_accessed_at: Option<DateTime<Utc>>,
74    pub metadata: serde_json::Value,
75    pub parent_id: Option<Uuid>,
76    pub created_at: DateTime<Utc>,
77    pub updated_at: DateTime<Utc>,
78    pub expires_at: Option<DateTime<Utc>>,
79}
80
81impl Serialize for Memory {
82    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
83    where
84        S: serde::Serializer,
85    {
86        use serde::ser::SerializeStruct;
87        let mut state = serializer.serialize_struct("Memory", 15)?;
88        state.serialize_field("id", &self.id)?;
89        state.serialize_field("content", &self.content)?;
90        state.serialize_field("content_hash", &self.content_hash)?;
91        state.serialize_field("embedding", &self.embedding.as_ref().map(|v| v.as_slice()))?;
92        state.serialize_field("tier", &self.tier)?;
93        state.serialize_field("status", &self.status)?;
94        state.serialize_field("importance_score", &self.importance_score)?;
95        state.serialize_field("access_count", &self.access_count)?;
96        state.serialize_field("last_accessed_at", &self.last_accessed_at)?;
97        state.serialize_field("metadata", &self.metadata)?;
98        state.serialize_field("parent_id", &self.parent_id)?;
99        state.serialize_field("created_at", &self.created_at)?;
100        state.serialize_field("updated_at", &self.updated_at)?;
101        state.serialize_field("expires_at", &self.expires_at)?;
102        state.end()
103    }
104}
105
106impl<'de> Deserialize<'de> for Memory {
107    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        // For now, we'll just return a default memory since we don't need to deserialize from JSON
112        Ok(Memory::default())
113    }
114}
115
116#[derive(Debug, Clone, FromRow)]
117pub struct MemorySummary {
118    pub id: Uuid,
119    pub summary_level: String,
120    pub summary_content: String,
121    pub summary_embedding: Option<Vector>,
122    pub start_time: DateTime<Utc>,
123    pub end_time: DateTime<Utc>,
124    pub memory_count: i32,
125    pub metadata: serde_json::Value,
126    pub created_at: DateTime<Utc>,
127    pub updated_at: DateTime<Utc>,
128}
129
130impl Serialize for MemorySummary {
131    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
132    where
133        S: serde::Serializer,
134    {
135        use serde::ser::SerializeStruct;
136        let mut state = serializer.serialize_struct("MemorySummary", 10)?;
137        state.serialize_field("id", &self.id)?;
138        state.serialize_field("summary_level", &self.summary_level)?;
139        state.serialize_field("summary_content", &self.summary_content)?;
140        state.serialize_field(
141            "summary_embedding",
142            &self.summary_embedding.as_ref().map(|v| v.as_slice()),
143        )?;
144        state.serialize_field("start_time", &self.start_time)?;
145        state.serialize_field("end_time", &self.end_time)?;
146        state.serialize_field("memory_count", &self.memory_count)?;
147        state.serialize_field("metadata", &self.metadata)?;
148        state.serialize_field("created_at", &self.created_at)?;
149        state.serialize_field("updated_at", &self.updated_at)?;
150        state.end()
151    }
152}
153
154impl<'de> Deserialize<'de> for MemorySummary {
155    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
156    where
157        D: serde::Deserializer<'de>,
158    {
159        unimplemented!("MemorySummary deserialization not needed")
160    }
161}
162
163#[derive(Debug, Clone, FromRow)]
164pub struct MemoryCluster {
165    pub id: Uuid,
166    pub cluster_name: String,
167    pub centroid_embedding: Vector,
168    pub concept_tags: Vec<String>,
169    pub member_count: i32,
170    pub tier: MemoryTier,
171    pub metadata: serde_json::Value,
172    pub created_at: DateTime<Utc>,
173    pub updated_at: DateTime<Utc>,
174}
175
176impl Serialize for MemoryCluster {
177    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
178    where
179        S: serde::Serializer,
180    {
181        use serde::ser::SerializeStruct;
182        let mut state = serializer.serialize_struct("MemoryCluster", 9)?;
183        state.serialize_field("id", &self.id)?;
184        state.serialize_field("cluster_name", &self.cluster_name)?;
185        state.serialize_field("centroid_embedding", &self.centroid_embedding.as_slice())?;
186        state.serialize_field("concept_tags", &self.concept_tags)?;
187        state.serialize_field("member_count", &self.member_count)?;
188        state.serialize_field("tier", &self.tier)?;
189        state.serialize_field("metadata", &self.metadata)?;
190        state.serialize_field("created_at", &self.created_at)?;
191        state.serialize_field("updated_at", &self.updated_at)?;
192        state.end()
193    }
194}
195
196impl<'de> Deserialize<'de> for MemoryCluster {
197    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
198    where
199        D: serde::Deserializer<'de>,
200    {
201        unimplemented!("MemoryCluster deserialization not needed")
202    }
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
206pub struct MigrationHistoryEntry {
207    pub id: Uuid,
208    pub memory_id: Uuid,
209    pub from_tier: MemoryTier,
210    pub to_tier: MemoryTier,
211    pub migration_reason: Option<String>,
212    pub migrated_at: DateTime<Utc>,
213    pub migration_duration_ms: Option<i32>,
214    pub success: bool,
215    pub error_message: Option<String>,
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct CreateMemoryRequest {
220    pub content: String,
221    pub embedding: Option<Vec<f32>>,
222    pub tier: Option<MemoryTier>,
223    pub importance_score: Option<f32>,
224    pub metadata: Option<serde_json::Value>,
225    pub parent_id: Option<Uuid>,
226    pub expires_at: Option<DateTime<Utc>>,
227}
228
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct UpdateMemoryRequest {
231    pub content: Option<String>,
232    pub embedding: Option<Vec<f32>>,
233    pub tier: Option<MemoryTier>,
234    pub importance_score: Option<f32>,
235    pub metadata: Option<serde_json::Value>,
236    pub expires_at: Option<DateTime<Utc>>,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct SearchRequest {
241    // Query options
242    pub query_text: Option<String>,
243    pub query_embedding: Option<Vec<f32>>,
244
245    // Search type configuration
246    pub search_type: Option<SearchType>,
247    pub hybrid_weights: Option<HybridWeights>,
248
249    // Filtering options
250    pub tier: Option<MemoryTier>,
251    pub date_range: Option<DateRange>,
252    pub importance_range: Option<RangeFilter<f32>>,
253    pub metadata_filters: Option<serde_json::Value>,
254    pub tags: Option<Vec<String>>,
255
256    // Result configuration
257    pub limit: Option<i32>,
258    pub offset: Option<i64>,    // For traditional pagination
259    pub cursor: Option<String>, // For cursor-based pagination
260    pub similarity_threshold: Option<f32>,
261    pub include_metadata: Option<bool>,
262    pub include_facets: Option<bool>,
263
264    // Ranking configuration
265    pub ranking_boost: Option<RankingBoost>,
266    pub explain_score: Option<bool>,
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub enum SearchType {
271    Semantic,
272    Temporal,
273    Hybrid,
274    FullText,
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct HybridWeights {
279    pub semantic_weight: f32,
280    pub temporal_weight: f32,
281    pub importance_weight: f32,
282    pub access_frequency_weight: f32,
283}
284
285#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct DateRange {
287    pub start: Option<DateTime<Utc>>,
288    pub end: Option<DateTime<Utc>>,
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct RangeFilter<T> {
293    pub min: Option<T>,
294    pub max: Option<T>,
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct RankingBoost {
299    pub recency_boost: Option<f32>,
300    pub importance_boost: Option<f32>,
301    pub access_frequency_boost: Option<f32>,
302    pub tier_boost: Option<std::collections::HashMap<MemoryTier, f32>>,
303}
304
305#[derive(Debug, Clone, Serialize, Deserialize)]
306pub struct SearchResult {
307    pub memory: Memory,
308    pub similarity_score: f32,
309    pub temporal_score: Option<f32>,
310    pub importance_score: f32,
311    pub access_frequency_score: Option<f32>,
312    pub combined_score: f32,
313    pub score_explanation: Option<ScoreExplanation>,
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct ScoreExplanation {
318    pub semantic_contribution: f32,
319    pub temporal_contribution: f32,
320    pub importance_contribution: f32,
321    pub access_frequency_contribution: f32,
322    pub total_score: f32,
323    pub factors: Vec<String>,
324}
325
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct SearchResponse {
328    pub results: Vec<SearchResult>,
329    pub total_count: Option<i64>,
330    pub facets: Option<SearchFacets>,
331    pub suggestions: Option<Vec<String>>,
332    pub next_cursor: Option<String>,
333    pub execution_time_ms: u64,
334}
335
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct SearchFacets {
338    pub tiers: std::collections::HashMap<MemoryTier, i64>,
339    pub date_histogram: Vec<DateBucket>,
340    pub importance_ranges: Vec<ImportanceRange>,
341    pub tags: std::collections::HashMap<String, i64>,
342}
343
344#[derive(Debug, Clone, Serialize, Deserialize)]
345pub struct DateBucket {
346    pub date: DateTime<Utc>,
347    pub count: i64,
348    pub interval: String, // "day", "week", "month"
349}
350
351#[derive(Debug, Clone, Serialize, Deserialize)]
352pub struct ImportanceRange {
353    pub min: f32,
354    pub max: f32,
355    pub count: i64,
356    pub label: String,
357}
358
359impl Default for Memory {
360    fn default() -> Self {
361        Self {
362            id: Uuid::new_v4(),
363            content: String::new(),
364            content_hash: String::new(),
365            embedding: None,
366            tier: MemoryTier::Working,
367            status: MemoryStatus::Active,
368            importance_score: 0.5,
369            access_count: 0,
370            last_accessed_at: None,
371            metadata: serde_json::json!({}),
372            parent_id: None,
373            created_at: Utc::now(),
374            updated_at: Utc::now(),
375            expires_at: None,
376        }
377    }
378}
379
380impl Memory {
381    pub fn calculate_content_hash(content: &str) -> String {
382        use sha2::{Digest, Sha256};
383        let mut hasher = Sha256::new();
384        hasher.update(content.as_bytes());
385        hex::encode(hasher.finalize())
386    }
387
388    pub fn should_migrate(&self) -> bool {
389        match self.tier {
390            MemoryTier::Working => {
391                // Migrate if importance is low and hasn't been accessed recently
392                self.importance_score < 0.3
393                    || (self.last_accessed_at.is_some()
394                        && Utc::now()
395                            .signed_duration_since(self.last_accessed_at.unwrap())
396                            .num_hours()
397                            > 24)
398            }
399            MemoryTier::Warm => {
400                // Migrate to cold if very low importance and old
401                self.importance_score < 0.1
402                    && Utc::now().signed_duration_since(self.updated_at).num_days() > 7
403            }
404            MemoryTier::Cold => false,
405        }
406    }
407
408    pub fn next_tier(&self) -> Option<MemoryTier> {
409        match self.tier {
410            MemoryTier::Working => Some(MemoryTier::Warm),
411            MemoryTier::Warm => Some(MemoryTier::Cold),
412            MemoryTier::Cold => None,
413        }
414    }
415}