1use crate::error::{EngramError, Result};
10use crate::types::{CrossReference, EdgeType, Memory, MemoryScope, MemoryTier, Visibility};
11use chrono::{DateTime, Utc};
12use rusqlite::{params, Connection, OptionalExtension};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
/// Optional time filters for temporal queries over memories.
///
/// Construct with [`TemporalQueryOptions::as_of`], [`TemporalQueryOptions::time_range`],
/// or `Default::default()` (every filter unset, `include_deleted == false`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TemporalQueryOptions {
    /// Point in time the caller is interested in; set by [`TemporalQueryOptions::as_of`].
    pub as_of: Option<DateTime<Utc>>,
    /// Lower bound on a memory's `created_at` (compared with `>=`).
    pub created_after: Option<DateTime<Utc>>,
    /// Upper bound on a memory's `created_at` (compared with `<=`).
    pub created_before: Option<DateTime<Utc>>,
    /// Lower bound on a memory's `updated_at` (compared with `>=`).
    pub updated_after: Option<DateTime<Utc>>,
    /// Upper bound on a memory's `updated_at` (compared with `<=`).
    pub updated_before: Option<DateTime<Utc>>,
    /// Whether deleted memories should be included. Defaults to `false`, also when the
    /// field is missing from serialized input.
    /// NOTE(review): no query in this module currently reads this flag.
    #[serde(default)]
    pub include_deleted: bool,
}
33
34impl TemporalQueryOptions {
35 pub fn as_of(timestamp: DateTime<Utc>) -> Self {
37 Self {
38 as_of: Some(timestamp),
39 ..Default::default()
40 }
41 }
42
43 pub fn time_range(start: DateTime<Utc>, end: DateTime<Utc>) -> Self {
45 Self {
46 created_after: Some(start),
47 created_before: Some(end),
48 ..Default::default()
49 }
50 }
51}
52
/// A memory as reconstructed for a specific instant by `TemporalQueryEngine::get_memory_at`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalMemory {
    /// The memory state; content, tags and version may come from a historical snapshot
    /// while the remaining fields are taken from the current row.
    pub memory: Memory,
    /// Version number the memory had at `queried_at`.
    pub version_at_time: i32,
    /// `true` when the reconstructed version is the memory's current version.
    pub is_current: bool,
    /// The instant the memory was reconstructed for.
    pub queried_at: DateTime<Utc>,
}
65
/// One stored version of a memory, read from the `memory_versions` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemorySnapshot {
    /// Id of the memory this snapshot belongs to.
    pub memory_id: i64,
    /// Version number of this snapshot.
    pub version: i32,
    /// Memory content at this version.
    pub content: String,
    /// Tags at this version (decoded from the JSON `tags` column; empty if undecodable).
    pub tags: Vec<String>,
    /// Metadata at this version (decoded from the JSON `metadata` column; empty if undecodable).
    pub metadata: HashMap<String, serde_json::Value>,
    /// When this snapshot was recorded.
    pub created_at: DateTime<Utc>,
    /// Who recorded this snapshot, if known.
    pub created_by: Option<String>,
    /// Optional free-text description of the change.
    pub change_summary: Option<String>,
}
86
/// Runs point-in-time and time-range queries against a borrowed SQLite connection.
pub struct TemporalQueryEngine<'a> {
    /// Connection used for every query; the engine only issues `SELECT` statements.
    conn: &'a Connection,
}
91
92impl<'a> TemporalQueryEngine<'a> {
93 pub fn new(conn: &'a Connection) -> Self {
95 Self { conn }
96 }
97
98 pub fn get_memory_at(
100 &self,
101 memory_id: i64,
102 as_of: DateTime<Utc>,
103 ) -> Result<Option<TemporalMemory>> {
104 let memory_existed: Option<(String, String)> = self
106 .conn
107 .query_row(
108 r#"
109 SELECT created_at, content
110 FROM memories
111 WHERE id = ?1 AND created_at <= ?2
112 "#,
113 params![memory_id, as_of.to_rfc3339()],
114 |row| Ok((row.get(0)?, row.get(1)?)),
115 )
116 .optional()?;
117
118 if memory_existed.is_none() {
119 return Ok(None);
120 }
121
122 let version_result: Option<(i32, String, String)> = self
124 .conn
125 .query_row(
126 r#"
127 SELECT version, content, tags
128 FROM memory_versions
129 WHERE memory_id = ?1 AND created_at <= ?2
130 ORDER BY version DESC
131 LIMIT 1
132 "#,
133 params![memory_id, as_of.to_rfc3339()],
134 |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
135 )
136 .optional()?;
137
138 let current: Option<Memory> = self.get_current_memory(memory_id)?;
140
141 if let Some((version, content, tags_json)) = version_result {
142 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
143
144 if let Some(mut memory) = current.clone() {
146 memory.content = content;
147 memory.tags = tags;
148 memory.version = version;
149
150 let is_current = current.map(|c| c.version == version).unwrap_or(false);
151
152 return Ok(Some(TemporalMemory {
153 memory,
154 version_at_time: version,
155 is_current,
156 queried_at: as_of,
157 }));
158 }
159 }
160
161 if let Some(memory) = current {
163 if memory.created_at <= as_of {
164 return Ok(Some(TemporalMemory {
165 memory: memory.clone(),
166 version_at_time: memory.version,
167 is_current: true,
168 queried_at: as_of,
169 }));
170 }
171 }
172
173 Ok(None)
174 }
175
176 fn get_current_memory(&self, memory_id: i64) -> Result<Option<Memory>> {
178 self.conn
179 .query_row(
180 r#"
181 SELECT id, content, type, importance, access_count, created_at, updated_at,
182 last_accessed_at, owner_id, visibility, version, has_embedding
183 FROM memories
184 WHERE id = ?1
185 "#,
186 params![memory_id],
187 |row| {
188 let memory_type_str: String = row.get(2)?;
189 let visibility_str: String = row.get(9)?;
190
191 Ok(Memory {
192 id: row.get(0)?,
193 content: row.get(1)?,
194 memory_type: memory_type_str.parse().unwrap_or_default(),
195 tags: vec![], metadata: HashMap::new(),
197 importance: row.get(3)?,
198 access_count: row.get(4)?,
199 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
200 .map(|dt| dt.with_timezone(&Utc))
201 .unwrap_or_else(|_| Utc::now()),
202 updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
203 .map(|dt| dt.with_timezone(&Utc))
204 .unwrap_or_else(|_| Utc::now()),
205 last_accessed_at: row
206 .get::<_, Option<String>>(7)?
207 .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
208 .map(|dt| dt.with_timezone(&Utc)),
209 owner_id: row.get(8)?,
210 visibility: match visibility_str.as_str() {
211 "shared" => Visibility::Shared,
212 "public" => Visibility::Public,
213 _ => Visibility::Private,
214 },
215 scope: MemoryScope::Global,
216 workspace: "default".to_string(),
217 tier: MemoryTier::Permanent,
218 version: row.get(10)?,
219 has_embedding: row.get(11)?,
220 expires_at: None,
221 content_hash: None,
222 event_time: None,
223 event_duration_seconds: None,
224 trigger_pattern: None,
225 procedure_success_count: 0,
226 procedure_failure_count: 0,
227 summary_of_id: None,
228 lifecycle_state: crate::types::LifecycleState::Active,
229 })
230 },
231 )
232 .optional()
233 .map_err(EngramError::from)
234 }
235
236 pub fn query_time_range(
238 &self,
239 options: &TemporalQueryOptions,
240 limit: i64,
241 ) -> Result<Vec<Memory>> {
242 let mut conditions = vec!["1=1".to_string()];
243 let mut params: Vec<Box<dyn rusqlite::ToSql>> = vec![];
244
245 if let Some(ref after) = options.created_after {
246 conditions.push(format!("created_at >= ?{}", params.len() + 1));
247 params.push(Box::new(after.to_rfc3339()));
248 }
249
250 if let Some(ref before) = options.created_before {
251 conditions.push(format!("created_at <= ?{}", params.len() + 1));
252 params.push(Box::new(before.to_rfc3339()));
253 }
254
255 if let Some(ref after) = options.updated_after {
256 conditions.push(format!("updated_at >= ?{}", params.len() + 1));
257 params.push(Box::new(after.to_rfc3339()));
258 }
259
260 if let Some(ref before) = options.updated_before {
261 conditions.push(format!("updated_at <= ?{}", params.len() + 1));
262 params.push(Box::new(before.to_rfc3339()));
263 }
264
265 let sql = format!(
266 r#"
267 SELECT id, content, type, importance, access_count, created_at, updated_at,
268 last_accessed_at, owner_id, visibility, version, has_embedding
269 FROM memories
270 WHERE {}
271 ORDER BY created_at DESC
272 LIMIT ?{}
273 "#,
274 conditions.join(" AND "),
275 params.len() + 1
276 );
277
278 params.push(Box::new(limit));
279
280 let params_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
281
282 let mut stmt = self.conn.prepare(&sql)?;
283 let memories = stmt
284 .query_map(params_refs.as_slice(), |row| {
285 let memory_type_str: String = row.get(2)?;
286 let visibility_str: String = row.get(9)?;
287
288 Ok(Memory {
289 id: row.get(0)?,
290 content: row.get(1)?,
291 memory_type: memory_type_str.parse().unwrap_or_default(),
292 tags: vec![],
293 metadata: HashMap::new(),
294 importance: row.get(3)?,
295 access_count: row.get(4)?,
296 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
297 .map(|dt| dt.with_timezone(&Utc))
298 .unwrap_or_else(|_| Utc::now()),
299 updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
300 .map(|dt| dt.with_timezone(&Utc))
301 .unwrap_or_else(|_| Utc::now()),
302 last_accessed_at: row
303 .get::<_, Option<String>>(7)?
304 .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
305 .map(|dt| dt.with_timezone(&Utc)),
306 owner_id: row.get(8)?,
307 visibility: match visibility_str.as_str() {
308 "shared" => Visibility::Shared,
309 "public" => Visibility::Public,
310 _ => Visibility::Private,
311 },
312 scope: MemoryScope::Global, workspace: "default".to_string(),
314 tier: MemoryTier::Permanent,
315 version: row.get(10)?,
316 has_embedding: row.get(11)?,
317 expires_at: None, content_hash: None, event_time: None,
320 event_duration_seconds: None,
321 trigger_pattern: None,
322 procedure_success_count: 0,
323 procedure_failure_count: 0,
324 summary_of_id: None,
325 lifecycle_state: crate::types::LifecycleState::Active,
326 })
327 })?
328 .collect::<std::result::Result<Vec<_>, _>>()?;
329
330 Ok(memories)
331 }
332
333 pub fn get_crossrefs_at(
335 &self,
336 memory_id: i64,
337 as_of: DateTime<Utc>,
338 ) -> Result<Vec<CrossReference>> {
339 let mut stmt = self.conn.prepare(
340 r#"
341 SELECT from_id, to_id, edge_type, score, confidence, strength, source,
342 source_context, created_at, valid_from, valid_to, pinned
343 FROM crossrefs
344 WHERE (from_id = ?1 OR to_id = ?1)
345 AND valid_from <= ?2
346 AND (valid_to IS NULL OR valid_to > ?2)
347 ORDER BY score DESC
348 "#,
349 )?;
350
351 let crossrefs = stmt
352 .query_map(params![memory_id, as_of.to_rfc3339()], |row| {
353 let edge_type_str: String = row.get(2)?;
354 let source_str: String = row.get(6)?;
355
356 Ok(CrossReference {
357 from_id: row.get(0)?,
358 to_id: row.get(1)?,
359 edge_type: edge_type_str.parse().unwrap_or_default(),
360 score: row.get(3)?,
361 confidence: row.get(4)?,
362 strength: row.get(5)?,
363 source: match source_str.as_str() {
364 "manual" => crate::types::RelationSource::Manual,
365 "llm" => crate::types::RelationSource::Llm,
366 _ => crate::types::RelationSource::Auto,
367 },
368 source_context: row.get(7)?,
369 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
370 .map(|dt| dt.with_timezone(&Utc))
371 .unwrap_or_else(|_| Utc::now()),
372 valid_from: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
373 .map(|dt| dt.with_timezone(&Utc))
374 .unwrap_or_else(|_| Utc::now()),
375 valid_to: row
376 .get::<_, Option<String>>(10)?
377 .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
378 .map(|dt| dt.with_timezone(&Utc)),
379 pinned: row.get(11)?,
380 metadata: HashMap::new(),
381 })
382 })?
383 .collect::<std::result::Result<Vec<_>, _>>()?;
384
385 Ok(crossrefs)
386 }
387
388 pub fn get_version_history(&self, memory_id: i64) -> Result<Vec<MemorySnapshot>> {
390 let mut stmt = self.conn.prepare(
391 r#"
392 SELECT memory_id, version, content, tags, metadata, created_at, created_by, change_summary
393 FROM memory_versions
394 WHERE memory_id = ?1
395 ORDER BY version DESC
396 "#,
397 )?;
398
399 let snapshots = stmt
400 .query_map(params![memory_id], |row| {
401 let tags_json: String = row.get(3)?;
402 let metadata_json: String = row.get(4)?;
403
404 Ok(MemorySnapshot {
405 memory_id: row.get(0)?,
406 version: row.get(1)?,
407 content: row.get(2)?,
408 tags: serde_json::from_str(&tags_json).unwrap_or_default(),
409 metadata: serde_json::from_str(&metadata_json).unwrap_or_default(),
410 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
411 .map(|dt| dt.with_timezone(&Utc))
412 .unwrap_or_else(|_| Utc::now()),
413 created_by: row.get(6)?,
414 change_summary: row.get(7)?,
415 })
416 })?
417 .collect::<std::result::Result<Vec<_>, _>>()?;
418
419 Ok(snapshots)
420 }
421
422 pub fn get_memory_version(
424 &self,
425 memory_id: i64,
426 version: i32,
427 ) -> Result<Option<MemorySnapshot>> {
428 self.conn
429 .query_row(
430 r#"
431 SELECT memory_id, version, content, tags, metadata, created_at, created_by, change_summary
432 FROM memory_versions
433 WHERE memory_id = ?1 AND version = ?2
434 "#,
435 params![memory_id, version],
436 |row| {
437 let tags_json: String = row.get(3)?;
438 let metadata_json: String = row.get(4)?;
439
440 Ok(MemorySnapshot {
441 memory_id: row.get(0)?,
442 version: row.get(1)?,
443 content: row.get(2)?,
444 tags: serde_json::from_str(&tags_json).unwrap_or_default(),
445 metadata: serde_json::from_str(&metadata_json).unwrap_or_default(),
446 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
447 .map(|dt| dt.with_timezone(&Utc))
448 .unwrap_or_else(|_| Utc::now()),
449 created_by: row.get(6)?,
450 change_summary: row.get(7)?,
451 })
452 },
453 )
454 .optional()
455 .map_err(EngramError::from)
456 }
457
458 pub fn traverse_graph_at(
460 &self,
461 start_id: i64,
462 as_of: DateTime<Utc>,
463 depth: usize,
464 edge_types: Option<Vec<EdgeType>>,
465 ) -> Result<Vec<(Memory, CrossReference)>> {
466 let mut visited = std::collections::HashSet::new();
467 let mut results = Vec::new();
468 let mut to_visit = vec![(start_id, 0usize)];
469
470 while let Some((current_id, current_depth)) = to_visit.pop() {
471 if current_depth >= depth || visited.contains(¤t_id) {
472 continue;
473 }
474 visited.insert(current_id);
475
476 let crossrefs = self.get_crossrefs_at(current_id, as_of)?;
478
479 for crossref in crossrefs {
480 if let Some(ref types) = edge_types {
482 if !types.contains(&crossref.edge_type) {
483 continue;
484 }
485 }
486
487 let other_id = if crossref.from_id == current_id {
489 crossref.to_id
490 } else {
491 crossref.from_id
492 };
493
494 if let Some(temporal_memory) = self.get_memory_at(other_id, as_of)? {
495 results.push((temporal_memory.memory, crossref.clone()));
496 to_visit.push((other_id, current_depth + 1));
497 }
498 }
499 }
500
501 Ok(results)
502 }
503
504 pub fn compare_states(
506 &self,
507 memory_id: i64,
508 time1: DateTime<Utc>,
509 time2: DateTime<Utc>,
510 ) -> Result<StateDiff> {
511 let state1 = self.get_memory_at(memory_id, time1)?;
512 let state2 = self.get_memory_at(memory_id, time2)?;
513
514 let crossrefs1 = self.get_crossrefs_at(memory_id, time1)?;
515 let crossrefs2 = self.get_crossrefs_at(memory_id, time2)?;
516
517 Ok(StateDiff {
518 memory_id,
519 time1,
520 time2,
521 memory_state1: state1.map(|t| t.memory),
522 memory_state2: state2.map(|t| t.memory),
523 crossrefs_added: crossrefs2
524 .iter()
525 .filter(|c| {
526 !crossrefs1
527 .iter()
528 .any(|c1| c1.to_id == c.to_id && c1.from_id == c.from_id)
529 })
530 .cloned()
531 .collect(),
532 crossrefs_removed: crossrefs1
533 .iter()
534 .filter(|c| {
535 !crossrefs2
536 .iter()
537 .any(|c2| c2.to_id == c.to_id && c2.from_id == c.from_id)
538 })
539 .cloned()
540 .collect(),
541 })
542 }
543}
544
/// Differences in one memory and its cross-references between two instants,
/// as produced by `TemporalQueryEngine::compare_states`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateDiff {
    /// The memory that was compared.
    pub memory_id: i64,
    /// First instant of the comparison.
    pub time1: DateTime<Utc>,
    /// Second instant of the comparison.
    pub time2: DateTime<Utc>,
    /// The memory as of `time1`, or `None` if it did not exist then.
    pub memory_state1: Option<Memory>,
    /// The memory as of `time2`, or `None` if it did not exist then.
    pub memory_state2: Option<Memory>,
    /// Cross-references valid at `time2` with no matching edge at `time1`.
    pub crossrefs_added: Vec<CrossReference>,
    /// Cross-references valid at `time1` with no matching edge at `time2`.
    pub crossrefs_removed: Vec<CrossReference>,
}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` leaves every filter unset and excludes deleted memories.
    #[test]
    fn test_temporal_query_options_default() {
        let options = TemporalQueryOptions::default();
        assert!(options.as_of.is_none());
        assert!(options.created_after.is_none());
        assert!(options.created_before.is_none());
        assert!(options.updated_after.is_none());
        assert!(options.updated_before.is_none());
        assert!(!options.include_deleted);
    }

    /// `as_of` sets only the point-in-time field.
    #[test]
    fn test_temporal_query_options_as_of() {
        let now = Utc::now();
        let options = TemporalQueryOptions::as_of(now);
        assert_eq!(options.as_of, Some(now));
        assert!(options.created_after.is_none());
        assert!(options.created_before.is_none());
        assert!(options.updated_after.is_none());
        assert!(options.updated_before.is_none());
        assert!(!options.include_deleted);
    }

    /// `time_range` sets only the creation-time bounds.
    #[test]
    fn test_temporal_query_options_time_range() {
        let end = Utc::now();
        let start = end - chrono::Duration::days(7);
        let options = TemporalQueryOptions::time_range(start, end);
        assert_eq!(options.created_after, Some(start));
        assert_eq!(options.created_before, Some(end));
        assert!(options.as_of.is_none());
        assert!(options.updated_after.is_none());
        assert!(options.updated_before.is_none());
        assert!(!options.include_deleted);
    }

    /// Missing fields deserialize to their defaults (`#[serde(default)]` on `include_deleted`).
    #[test]
    fn test_temporal_query_options_deserialize_empty_object() {
        let options: TemporalQueryOptions =
            serde_json::from_str("{}").expect("an empty object must deserialize");
        assert!(options.as_of.is_none());
        assert!(options.created_after.is_none());
        assert!(!options.include_deleted);
    }

    /// Options survive a JSON round trip unchanged.
    #[test]
    fn test_temporal_query_options_serde_round_trip() {
        let now = Utc::now();
        let mut original = TemporalQueryOptions::as_of(now);
        original.include_deleted = true;

        let json = serde_json::to_string(&original).expect("serialize");
        let decoded: TemporalQueryOptions = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(decoded.as_of, Some(now));
        assert!(decoded.created_after.is_none());
        assert!(decoded.include_deleted);
    }
}