use crate::error::{EngramError, Result};
use crate::types::{CrossReference, EdgeType, Memory, MemoryScope, MemoryTier, Visibility};
use chrono::{DateTime, Utc};
use rusqlite::{params, Connection, OptionalExtension};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};

16#[derive(Debug, Clone, Serialize, Deserialize, Default)]
18pub struct TemporalQueryOptions {
19 pub as_of: Option<DateTime<Utc>>,
21 pub created_after: Option<DateTime<Utc>>,
23 pub created_before: Option<DateTime<Utc>>,
25 pub updated_after: Option<DateTime<Utc>>,
27 pub updated_before: Option<DateTime<Utc>>,
29 #[serde(default)]
31 pub include_deleted: bool,
32}
33
34impl TemporalQueryOptions {
35 pub fn as_of(timestamp: DateTime<Utc>) -> Self {
37 Self {
38 as_of: Some(timestamp),
39 ..Default::default()
40 }
41 }
42
43 pub fn time_range(start: DateTime<Utc>, end: DateTime<Utc>) -> Self {
45 Self {
46 created_after: Some(start),
47 created_before: Some(end),
48 ..Default::default()
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct TemporalMemory {
56 pub memory: Memory,
58 pub version_at_time: i32,
60 pub is_current: bool,
62 pub queried_at: DateTime<Utc>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct MemorySnapshot {
69 pub memory_id: i64,
71 pub version: i32,
73 pub content: String,
75 pub tags: Vec<String>,
77 pub metadata: HashMap<String, serde_json::Value>,
79 pub created_at: DateTime<Utc>,
81 pub created_by: Option<String>,
83 pub change_summary: Option<String>,
85}
86
87pub struct TemporalQueryEngine<'a> {
89 conn: &'a Connection,
90}
91
92impl<'a> TemporalQueryEngine<'a> {
93 pub fn new(conn: &'a Connection) -> Self {
95 Self { conn }
96 }
97
98 pub fn get_memory_at(
100 &self,
101 memory_id: i64,
102 as_of: DateTime<Utc>,
103 ) -> Result<Option<TemporalMemory>> {
104 let memory_existed: Option<(String, String)> = self
106 .conn
107 .query_row(
108 r#"
109 SELECT created_at, content
110 FROM memories
111 WHERE id = ?1 AND created_at <= ?2
112 "#,
113 params![memory_id, as_of.to_rfc3339()],
114 |row| Ok((row.get(0)?, row.get(1)?)),
115 )
116 .optional()?;
117
118 if memory_existed.is_none() {
119 return Ok(None);
120 }
121
122 let version_result: Option<(i32, String, String)> = self
124 .conn
125 .query_row(
126 r#"
127 SELECT version, content, tags
128 FROM memory_versions
129 WHERE memory_id = ?1 AND created_at <= ?2
130 ORDER BY version DESC
131 LIMIT 1
132 "#,
133 params![memory_id, as_of.to_rfc3339()],
134 |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
135 )
136 .optional()?;
137
138 let current: Option<Memory> = self.get_current_memory(memory_id)?;
140
141 if let Some((version, content, tags_json)) = version_result {
142 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
143
144 if let Some(mut memory) = current.clone() {
146 memory.content = content;
147 memory.tags = tags;
148 memory.version = version;
149
150 let is_current = current.map(|c| c.version == version).unwrap_or(false);
151
152 return Ok(Some(TemporalMemory {
153 memory,
154 version_at_time: version,
155 is_current,
156 queried_at: as_of,
157 }));
158 }
159 }
160
161 if let Some(memory) = current {
163 if memory.created_at <= as_of {
164 return Ok(Some(TemporalMemory {
165 memory: memory.clone(),
166 version_at_time: memory.version,
167 is_current: true,
168 queried_at: as_of,
169 }));
170 }
171 }
172
173 Ok(None)
174 }
175
176 fn get_current_memory(&self, memory_id: i64) -> Result<Option<Memory>> {
178 self.conn
179 .query_row(
180 r#"
181 SELECT id, content, type, importance, access_count, created_at, updated_at,
182 last_accessed_at, owner_id, visibility, version, has_embedding
183 FROM memories
184 WHERE id = ?1
185 "#,
186 params![memory_id],
187 |row| {
188 let memory_type_str: String = row.get(2)?;
189 let visibility_str: String = row.get(9)?;
190
191 Ok(Memory {
192 id: row.get(0)?,
193 content: row.get(1)?,
194 memory_type: memory_type_str.parse().unwrap_or_default(),
195 tags: vec![], metadata: HashMap::new(),
197 importance: row.get(3)?,
198 access_count: row.get(4)?,
199 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
200 .map(|dt| dt.with_timezone(&Utc))
201 .unwrap_or_else(|_| Utc::now()),
202 updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
203 .map(|dt| dt.with_timezone(&Utc))
204 .unwrap_or_else(|_| Utc::now()),
205 last_accessed_at: row
206 .get::<_, Option<String>>(7)?
207 .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
208 .map(|dt| dt.with_timezone(&Utc)),
209 owner_id: row.get(8)?,
210 visibility: match visibility_str.as_str() {
211 "shared" => Visibility::Shared,
212 "public" => Visibility::Public,
213 _ => Visibility::Private,
214 },
215 scope: MemoryScope::Global,
216 workspace: "default".to_string(),
217 tier: MemoryTier::Permanent,
218 version: row.get(10)?,
219 has_embedding: row.get(11)?,
220 expires_at: None,
221 content_hash: None,
222 event_time: None,
223 event_duration_seconds: None,
224 trigger_pattern: None,
225 procedure_success_count: 0,
226 procedure_failure_count: 0,
227 summary_of_id: None,
228 lifecycle_state: crate::types::LifecycleState::Active,
229 media_url: None,
230 })
231 },
232 )
233 .optional()
234 .map_err(EngramError::from)
235 }
236
237 pub fn query_time_range(
239 &self,
240 options: &TemporalQueryOptions,
241 limit: i64,
242 ) -> Result<Vec<Memory>> {
243 let mut conditions = vec!["1=1".to_string()];
244 let mut params: Vec<Box<dyn rusqlite::ToSql>> = vec![];
245
246 if let Some(ref after) = options.created_after {
247 conditions.push(format!("created_at >= ?{}", params.len() + 1));
248 params.push(Box::new(after.to_rfc3339()));
249 }
250
251 if let Some(ref before) = options.created_before {
252 conditions.push(format!("created_at <= ?{}", params.len() + 1));
253 params.push(Box::new(before.to_rfc3339()));
254 }
255
256 if let Some(ref after) = options.updated_after {
257 conditions.push(format!("updated_at >= ?{}", params.len() + 1));
258 params.push(Box::new(after.to_rfc3339()));
259 }
260
261 if let Some(ref before) = options.updated_before {
262 conditions.push(format!("updated_at <= ?{}", params.len() + 1));
263 params.push(Box::new(before.to_rfc3339()));
264 }
265
266 let sql = format!(
267 r#"
268 SELECT id, content, type, importance, access_count, created_at, updated_at,
269 last_accessed_at, owner_id, visibility, version, has_embedding
270 FROM memories
271 WHERE {}
272 ORDER BY created_at DESC
273 LIMIT ?{}
274 "#,
275 conditions.join(" AND "),
276 params.len() + 1
277 );
278
279 params.push(Box::new(limit));
280
281 let params_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
282
283 let mut stmt = self.conn.prepare(&sql)?;
284 let memories = stmt
285 .query_map(params_refs.as_slice(), |row| {
286 let memory_type_str: String = row.get(2)?;
287 let visibility_str: String = row.get(9)?;
288
289 Ok(Memory {
290 id: row.get(0)?,
291 content: row.get(1)?,
292 memory_type: memory_type_str.parse().unwrap_or_default(),
293 tags: vec![],
294 metadata: HashMap::new(),
295 importance: row.get(3)?,
296 access_count: row.get(4)?,
297 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
298 .map(|dt| dt.with_timezone(&Utc))
299 .unwrap_or_else(|_| Utc::now()),
300 updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
301 .map(|dt| dt.with_timezone(&Utc))
302 .unwrap_or_else(|_| Utc::now()),
303 last_accessed_at: row
304 .get::<_, Option<String>>(7)?
305 .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
306 .map(|dt| dt.with_timezone(&Utc)),
307 owner_id: row.get(8)?,
308 visibility: match visibility_str.as_str() {
309 "shared" => Visibility::Shared,
310 "public" => Visibility::Public,
311 _ => Visibility::Private,
312 },
313 scope: MemoryScope::Global, workspace: "default".to_string(),
315 tier: MemoryTier::Permanent,
316 version: row.get(10)?,
317 has_embedding: row.get(11)?,
318 expires_at: None, content_hash: None, event_time: None,
321 event_duration_seconds: None,
322 trigger_pattern: None,
323 procedure_success_count: 0,
324 procedure_failure_count: 0,
325 summary_of_id: None,
326 lifecycle_state: crate::types::LifecycleState::Active,
327 media_url: None,
328 })
329 })?
330 .collect::<std::result::Result<Vec<_>, _>>()?;
331
332 Ok(memories)
333 }
334
335 pub fn get_crossrefs_at(
337 &self,
338 memory_id: i64,
339 as_of: DateTime<Utc>,
340 ) -> Result<Vec<CrossReference>> {
341 let mut stmt = self.conn.prepare(
342 r#"
343 SELECT from_id, to_id, edge_type, score, confidence, strength, source,
344 source_context, created_at, valid_from, valid_to, pinned
345 FROM crossrefs
346 WHERE (from_id = ?1 OR to_id = ?1)
347 AND valid_from <= ?2
348 AND (valid_to IS NULL OR valid_to > ?2)
349 ORDER BY score DESC
350 "#,
351 )?;
352
353 let crossrefs = stmt
354 .query_map(params![memory_id, as_of.to_rfc3339()], |row| {
355 let edge_type_str: String = row.get(2)?;
356 let source_str: String = row.get(6)?;
357
358 Ok(CrossReference {
359 from_id: row.get(0)?,
360 to_id: row.get(1)?,
361 edge_type: edge_type_str.parse().unwrap_or_default(),
362 score: row.get(3)?,
363 confidence: row.get(4)?,
364 strength: row.get(5)?,
365 source: match source_str.as_str() {
366 "manual" => crate::types::RelationSource::Manual,
367 "llm" => crate::types::RelationSource::Llm,
368 _ => crate::types::RelationSource::Auto,
369 },
370 source_context: row.get(7)?,
371 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
372 .map(|dt| dt.with_timezone(&Utc))
373 .unwrap_or_else(|_| Utc::now()),
374 valid_from: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
375 .map(|dt| dt.with_timezone(&Utc))
376 .unwrap_or_else(|_| Utc::now()),
377 valid_to: row
378 .get::<_, Option<String>>(10)?
379 .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
380 .map(|dt| dt.with_timezone(&Utc)),
381 pinned: row.get(11)?,
382 metadata: HashMap::new(),
383 })
384 })?
385 .collect::<std::result::Result<Vec<_>, _>>()?;
386
387 Ok(crossrefs)
388 }
389
390 pub fn get_version_history(&self, memory_id: i64) -> Result<Vec<MemorySnapshot>> {
392 let mut stmt = self.conn.prepare(
393 r#"
394 SELECT memory_id, version, content, tags, metadata, created_at, created_by, change_summary
395 FROM memory_versions
396 WHERE memory_id = ?1
397 ORDER BY version DESC
398 "#,
399 )?;
400
401 let snapshots = stmt
402 .query_map(params![memory_id], |row| {
403 let tags_json: String = row.get(3)?;
404 let metadata_json: String = row.get(4)?;
405
406 Ok(MemorySnapshot {
407 memory_id: row.get(0)?,
408 version: row.get(1)?,
409 content: row.get(2)?,
410 tags: serde_json::from_str(&tags_json).unwrap_or_default(),
411 metadata: serde_json::from_str(&metadata_json).unwrap_or_default(),
412 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
413 .map(|dt| dt.with_timezone(&Utc))
414 .unwrap_or_else(|_| Utc::now()),
415 created_by: row.get(6)?,
416 change_summary: row.get(7)?,
417 })
418 })?
419 .collect::<std::result::Result<Vec<_>, _>>()?;
420
421 Ok(snapshots)
422 }
423
424 pub fn get_memory_version(
426 &self,
427 memory_id: i64,
428 version: i32,
429 ) -> Result<Option<MemorySnapshot>> {
430 self.conn
431 .query_row(
432 r#"
433 SELECT memory_id, version, content, tags, metadata, created_at, created_by, change_summary
434 FROM memory_versions
435 WHERE memory_id = ?1 AND version = ?2
436 "#,
437 params![memory_id, version],
438 |row| {
439 let tags_json: String = row.get(3)?;
440 let metadata_json: String = row.get(4)?;
441
442 Ok(MemorySnapshot {
443 memory_id: row.get(0)?,
444 version: row.get(1)?,
445 content: row.get(2)?,
446 tags: serde_json::from_str(&tags_json).unwrap_or_default(),
447 metadata: serde_json::from_str(&metadata_json).unwrap_or_default(),
448 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
449 .map(|dt| dt.with_timezone(&Utc))
450 .unwrap_or_else(|_| Utc::now()),
451 created_by: row.get(6)?,
452 change_summary: row.get(7)?,
453 })
454 },
455 )
456 .optional()
457 .map_err(EngramError::from)
458 }
459
460 pub fn traverse_graph_at(
462 &self,
463 start_id: i64,
464 as_of: DateTime<Utc>,
465 depth: usize,
466 edge_types: Option<Vec<EdgeType>>,
467 ) -> Result<Vec<(Memory, CrossReference)>> {
468 let mut visited = std::collections::HashSet::new();
469 let mut results = Vec::new();
470 let mut to_visit = vec![(start_id, 0usize)];
471
472 while let Some((current_id, current_depth)) = to_visit.pop() {
473 if current_depth >= depth || visited.contains(¤t_id) {
474 continue;
475 }
476 visited.insert(current_id);
477
478 let crossrefs = self.get_crossrefs_at(current_id, as_of)?;
480
481 for crossref in crossrefs {
482 if let Some(ref types) = edge_types {
484 if !types.contains(&crossref.edge_type) {
485 continue;
486 }
487 }
488
489 let other_id = if crossref.from_id == current_id {
491 crossref.to_id
492 } else {
493 crossref.from_id
494 };
495
496 if let Some(temporal_memory) = self.get_memory_at(other_id, as_of)? {
497 results.push((temporal_memory.memory, crossref.clone()));
498 to_visit.push((other_id, current_depth + 1));
499 }
500 }
501 }
502
503 Ok(results)
504 }
505
506 pub fn compare_states(
508 &self,
509 memory_id: i64,
510 time1: DateTime<Utc>,
511 time2: DateTime<Utc>,
512 ) -> Result<StateDiff> {
513 let state1 = self.get_memory_at(memory_id, time1)?;
514 let state2 = self.get_memory_at(memory_id, time2)?;
515
516 let crossrefs1 = self.get_crossrefs_at(memory_id, time1)?;
517 let crossrefs2 = self.get_crossrefs_at(memory_id, time2)?;
518
519 Ok(StateDiff {
520 memory_id,
521 time1,
522 time2,
523 memory_state1: state1.map(|t| t.memory),
524 memory_state2: state2.map(|t| t.memory),
525 crossrefs_added: crossrefs2
526 .iter()
527 .filter(|c| {
528 !crossrefs1
529 .iter()
530 .any(|c1| c1.to_id == c.to_id && c1.from_id == c.from_id)
531 })
532 .cloned()
533 .collect(),
534 crossrefs_removed: crossrefs1
535 .iter()
536 .filter(|c| {
537 !crossrefs2
538 .iter()
539 .any(|c2| c2.to_id == c.to_id && c2.from_id == c.from_id)
540 })
541 .cloned()
542 .collect(),
543 })
544 }
545}
546
547#[derive(Debug, Clone, Serialize, Deserialize)]
549pub struct StateDiff {
550 pub memory_id: i64,
551 pub time1: DateTime<Utc>,
552 pub time2: DateTime<Utc>,
553 pub memory_state1: Option<Memory>,
554 pub memory_state2: Option<Memory>,
555 pub crossrefs_added: Vec<CrossReference>,
556 pub crossrefs_removed: Vec<CrossReference>,
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_temporal_query_options_default() {
        // A default filter must leave every bound unset.
        let defaults = TemporalQueryOptions::default();
        assert!(defaults.as_of.is_none());
        assert!(defaults.created_after.is_none());
        assert!(!defaults.include_deleted);
    }

    #[test]
    fn test_temporal_query_options_as_of() {
        let instant = Utc::now();
        let pinned = TemporalQueryOptions::as_of(instant);
        assert_eq!(pinned.as_of, Some(instant));
    }

    #[test]
    fn test_temporal_query_options_time_range() {
        let week = chrono::Duration::days(7);
        let start = Utc::now() - week;
        let end = Utc::now();
        let window = TemporalQueryOptions::time_range(start, end);
        assert_eq!(window.created_after, Some(start));
        assert_eq!(window.created_before, Some(end));
    }
}