1use chrono::{DateTime, Utc};
2use pgvector::Vector;
3use serde::{Deserialize, Serialize};
4use sqlx::FromRow;
5use std::str::FromStr;
6use uuid::Uuid;
7
8#[derive(Debug, Clone)]
9pub struct SerializableVector(pub Option<Vector>);
10
11impl Serialize for SerializableVector {
12 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13 where
14 S: serde::Serializer,
15 {
16 match &self.0 {
17 Some(v) => v.as_slice().serialize(serializer),
18 None => serializer.serialize_none(),
19 }
20 }
21}
22
23impl<'de> Deserialize<'de> for SerializableVector {
24 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25 where
26 D: serde::Deserializer<'de>,
27 {
28 let opt_vec: Option<Vec<f32>> = Option::deserialize(deserializer)?;
29 Ok(SerializableVector(opt_vec.map(Vector::from)))
30 }
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
34#[sqlx(type_name = "varchar", rename_all = "lowercase")]
35pub enum MemoryTier {
36 Working,
37 Warm,
38 Cold,
39}
40
41impl FromStr for MemoryTier {
42 type Err = String;
43
44 fn from_str(s: &str) -> Result<Self, Self::Err> {
45 match s.to_lowercase().as_str() {
46 "working" => Ok(MemoryTier::Working),
47 "warm" => Ok(MemoryTier::Warm),
48 "cold" => Ok(MemoryTier::Cold),
49 _ => Err(format!("Invalid memory tier: {s}")),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
55#[sqlx(type_name = "varchar", rename_all = "lowercase")]
56pub enum MemoryStatus {
57 Active,
58 Migrating,
59 Archived,
60 Deleted,
61}
62
63#[derive(Debug, Clone, FromRow)]
64pub struct Memory {
65 pub id: Uuid,
66 pub content: String,
67 pub content_hash: String,
68 pub embedding: Option<Vector>,
69 pub tier: MemoryTier,
70 pub status: MemoryStatus,
71 pub importance_score: f64,
72 pub access_count: i32,
73 pub last_accessed_at: Option<DateTime<Utc>>,
74 pub metadata: serde_json::Value,
75 pub parent_id: Option<Uuid>,
76 pub created_at: DateTime<Utc>,
77 pub updated_at: DateTime<Utc>,
78 pub expires_at: Option<DateTime<Utc>>,
79}
80
81impl Serialize for Memory {
82 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
83 where
84 S: serde::Serializer,
85 {
86 use serde::ser::SerializeStruct;
87 let mut state = serializer.serialize_struct("Memory", 15)?;
88 state.serialize_field("id", &self.id)?;
89 state.serialize_field("content", &self.content)?;
90 state.serialize_field("content_hash", &self.content_hash)?;
91 state.serialize_field("embedding", &self.embedding.as_ref().map(|v| v.as_slice()))?;
92 state.serialize_field("tier", &self.tier)?;
93 state.serialize_field("status", &self.status)?;
94 state.serialize_field("importance_score", &self.importance_score)?;
95 state.serialize_field("access_count", &self.access_count)?;
96 state.serialize_field("last_accessed_at", &self.last_accessed_at)?;
97 state.serialize_field("metadata", &self.metadata)?;
98 state.serialize_field("parent_id", &self.parent_id)?;
99 state.serialize_field("created_at", &self.created_at)?;
100 state.serialize_field("updated_at", &self.updated_at)?;
101 state.serialize_field("expires_at", &self.expires_at)?;
102 state.end()
103 }
104}
105
106impl<'de> Deserialize<'de> for Memory {
107 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
108 where
109 D: serde::Deserializer<'de>,
110 {
111 Ok(Memory::default())
113 }
114}
115
116#[derive(Debug, Clone, FromRow)]
117pub struct MemorySummary {
118 pub id: Uuid,
119 pub summary_level: String,
120 pub summary_content: String,
121 pub summary_embedding: Option<Vector>,
122 pub start_time: DateTime<Utc>,
123 pub end_time: DateTime<Utc>,
124 pub memory_count: i32,
125 pub metadata: serde_json::Value,
126 pub created_at: DateTime<Utc>,
127 pub updated_at: DateTime<Utc>,
128}
129
130impl Serialize for MemorySummary {
131 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
132 where
133 S: serde::Serializer,
134 {
135 use serde::ser::SerializeStruct;
136 let mut state = serializer.serialize_struct("MemorySummary", 10)?;
137 state.serialize_field("id", &self.id)?;
138 state.serialize_field("summary_level", &self.summary_level)?;
139 state.serialize_field("summary_content", &self.summary_content)?;
140 state.serialize_field(
141 "summary_embedding",
142 &self.summary_embedding.as_ref().map(|v| v.as_slice()),
143 )?;
144 state.serialize_field("start_time", &self.start_time)?;
145 state.serialize_field("end_time", &self.end_time)?;
146 state.serialize_field("memory_count", &self.memory_count)?;
147 state.serialize_field("metadata", &self.metadata)?;
148 state.serialize_field("created_at", &self.created_at)?;
149 state.serialize_field("updated_at", &self.updated_at)?;
150 state.end()
151 }
152}
153
154impl<'de> Deserialize<'de> for MemorySummary {
155 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
156 where
157 D: serde::Deserializer<'de>,
158 {
159 unimplemented!("MemorySummary deserialization not needed")
160 }
161}
162
163#[derive(Debug, Clone, FromRow)]
164pub struct MemoryCluster {
165 pub id: Uuid,
166 pub cluster_name: String,
167 pub centroid_embedding: Vector,
168 pub concept_tags: Vec<String>,
169 pub member_count: i32,
170 pub tier: MemoryTier,
171 pub metadata: serde_json::Value,
172 pub created_at: DateTime<Utc>,
173 pub updated_at: DateTime<Utc>,
174}
175
176impl Serialize for MemoryCluster {
177 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
178 where
179 S: serde::Serializer,
180 {
181 use serde::ser::SerializeStruct;
182 let mut state = serializer.serialize_struct("MemoryCluster", 9)?;
183 state.serialize_field("id", &self.id)?;
184 state.serialize_field("cluster_name", &self.cluster_name)?;
185 state.serialize_field("centroid_embedding", &self.centroid_embedding.as_slice())?;
186 state.serialize_field("concept_tags", &self.concept_tags)?;
187 state.serialize_field("member_count", &self.member_count)?;
188 state.serialize_field("tier", &self.tier)?;
189 state.serialize_field("metadata", &self.metadata)?;
190 state.serialize_field("created_at", &self.created_at)?;
191 state.serialize_field("updated_at", &self.updated_at)?;
192 state.end()
193 }
194}
195
196impl<'de> Deserialize<'de> for MemoryCluster {
197 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
198 where
199 D: serde::Deserializer<'de>,
200 {
201 unimplemented!("MemoryCluster deserialization not needed")
202 }
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
206pub struct MigrationHistoryEntry {
207 pub id: Uuid,
208 pub memory_id: Uuid,
209 pub from_tier: MemoryTier,
210 pub to_tier: MemoryTier,
211 pub migration_reason: Option<String>,
212 pub migrated_at: DateTime<Utc>,
213 pub migration_duration_ms: Option<i32>,
214 pub success: bool,
215 pub error_message: Option<String>,
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct CreateMemoryRequest {
220 pub content: String,
221 pub embedding: Option<Vec<f32>>,
222 pub tier: Option<MemoryTier>,
223 pub importance_score: Option<f64>,
224 pub metadata: Option<serde_json::Value>,
225 pub parent_id: Option<Uuid>,
226 pub expires_at: Option<DateTime<Utc>>,
227}
228
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct UpdateMemoryRequest {
231 pub content: Option<String>,
232 pub embedding: Option<Vec<f32>>,
233 pub tier: Option<MemoryTier>,
234 pub importance_score: Option<f64>,
235 pub metadata: Option<serde_json::Value>,
236 pub expires_at: Option<DateTime<Utc>>,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct SearchRequest {
241 pub query_text: Option<String>,
243 pub query_embedding: Option<Vec<f32>>,
244
245 pub search_type: Option<SearchType>,
247 pub hybrid_weights: Option<HybridWeights>,
248
249 pub tier: Option<MemoryTier>,
251 pub date_range: Option<DateRange>,
252 pub importance_range: Option<RangeFilter<f32>>,
253 pub metadata_filters: Option<serde_json::Value>,
254 pub tags: Option<Vec<String>>,
255
256 pub limit: Option<i32>,
258 pub offset: Option<i64>, pub cursor: Option<String>, pub similarity_threshold: Option<f32>,
261 pub include_metadata: Option<bool>,
262 pub include_facets: Option<bool>,
263
264 pub ranking_boost: Option<RankingBoost>,
266 pub explain_score: Option<bool>,
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub enum SearchType {
271 Semantic,
272 Temporal,
273 Hybrid,
274 FullText,
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct HybridWeights {
279 pub semantic_weight: f32,
280 pub temporal_weight: f32,
281 pub importance_weight: f32,
282 pub access_frequency_weight: f32,
283}
284
285#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct DateRange {
287 pub start: Option<DateTime<Utc>>,
288 pub end: Option<DateTime<Utc>>,
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct RangeFilter<T> {
293 pub min: Option<T>,
294 pub max: Option<T>,
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct RankingBoost {
299 pub recency_boost: Option<f32>,
300 pub importance_boost: Option<f32>,
301 pub access_frequency_boost: Option<f32>,
302 pub tier_boost: Option<std::collections::HashMap<MemoryTier, f32>>,
303}
304
305#[derive(Debug, Clone, Serialize, Deserialize)]
306pub struct SearchResult {
307 pub memory: Memory,
308 pub similarity_score: f32,
309 pub temporal_score: Option<f32>,
310 pub importance_score: f64,
311 pub access_frequency_score: Option<f32>,
312 pub combined_score: f32,
313 pub score_explanation: Option<ScoreExplanation>,
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct ScoreExplanation {
318 pub semantic_contribution: f32,
319 pub temporal_contribution: f32,
320 pub importance_contribution: f32,
321 pub access_frequency_contribution: f32,
322 pub total_score: f32,
323 pub factors: Vec<String>,
324}
325
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct SearchResponse {
328 pub results: Vec<SearchResult>,
329 pub total_count: Option<i64>,
330 pub facets: Option<SearchFacets>,
331 pub suggestions: Option<Vec<String>>,
332 pub next_cursor: Option<String>,
333 pub execution_time_ms: u64,
334}
335
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct SearchFacets {
338 pub tiers: std::collections::HashMap<MemoryTier, i64>,
339 pub date_histogram: Vec<DateBucket>,
340 pub importance_ranges: Vec<ImportanceRange>,
341 pub tags: std::collections::HashMap<String, i64>,
342}
343
344#[derive(Debug, Clone, Serialize, Deserialize)]
345pub struct DateBucket {
346 pub date: DateTime<Utc>,
347 pub count: i64,
348 pub interval: String, }
350
351#[derive(Debug, Clone, Serialize, Deserialize)]
352pub struct ImportanceRange {
353 pub min: f32,
354 pub max: f32,
355 pub count: i64,
356 pub label: String,
357}
358
359impl Default for Memory {
360 fn default() -> Self {
361 Self {
362 id: Uuid::new_v4(),
363 content: String::new(),
364 content_hash: String::new(),
365 embedding: None,
366 tier: MemoryTier::Working,
367 status: MemoryStatus::Active,
368 importance_score: 0.5,
369 access_count: 0,
370 last_accessed_at: None,
371 metadata: serde_json::json!({}),
372 parent_id: None,
373 created_at: Utc::now(),
374 updated_at: Utc::now(),
375 expires_at: None,
376 }
377 }
378}
379
380impl Memory {
381 pub fn calculate_content_hash(content: &str) -> String {
382 use sha2::{Digest, Sha256};
383 let mut hasher = Sha256::new();
384 hasher.update(content.as_bytes());
385 hex::encode(hasher.finalize())
386 }
387
388 pub fn should_migrate(&self) -> bool {
389 match self.tier {
390 MemoryTier::Working => {
391 self.importance_score < 0.3
393 || (self.last_accessed_at.is_some()
394 && Utc::now()
395 .signed_duration_since(self.last_accessed_at.unwrap())
396 .num_hours()
397 > 24)
398 }
399 MemoryTier::Warm => {
400 self.importance_score < 0.1
402 && Utc::now().signed_duration_since(self.updated_at).num_days() > 7
403 }
404 MemoryTier::Cold => false,
405 }
406 }
407
408 pub fn next_tier(&self) -> Option<MemoryTier> {
409 match self.tier {
410 MemoryTier::Working => Some(MemoryTier::Warm),
411 MemoryTier::Warm => Some(MemoryTier::Cold),
412 MemoryTier::Cold => None,
413 }
414 }
415}