#![allow(missing_docs)]
pub mod bm25;
pub mod consolidation;
pub mod embedding;
pub mod hybrid;
pub mod in_memory;
pub mod namespaced;
pub mod pruning;
pub mod reflection;
pub mod scoring;
pub mod shared_tools;
pub mod tools;
use std::future::Future;
use std::pin::Pin;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::auth::TenantScope;
use crate::error::Error;
/// Kind of memory an entry holds.
///
/// Serialized by serde in `snake_case` (`"episodic"`, `"semantic"`, `"reflection"`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
/// The `Default` variant; also what `MemoryEntry::memory_type` becomes when the
/// field is absent from a serialized entry.
#[default]
Episodic,
/// NOTE(review): presumably distilled, general knowledge — confirm against producers.
Semantic,
/// NOTE(review): presumably a higher-level insight derived from other entries — confirm.
Reflection,
}
/// Sensitivity level attached to a memory entry.
///
/// The derived `Ord` follows declaration order, so
/// `Public < Internal < Confidential < Restricted`; do not reorder the variants.
/// Serialized by serde in `snake_case` (e.g. `"public"`, `"restricted"`).
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default,
)]
#[serde(rename_all = "snake_case")]
pub enum Confidentiality {
/// Lowest level and the `Default`; used when the field is absent from a serialized entry.
#[default]
Public,
/// Orders above `Public` and below `Confidential`.
Internal,
/// Orders above `Internal` and below `Restricted`.
Confidential,
/// Highest level in the ordering.
Restricted,
}
/// A single stored memory record, (de)serializable with serde.
///
/// `id` through `access_count` are required when deserializing. Every later field
/// carries `#[serde(default …)]`, so payloads that omit them still parse and receive
/// the defaults noted on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
/// Identifier of this entry.
pub id: String,
/// Agent name associated with the entry (presumably its author/owner — confirm).
pub agent: String,
/// The remembered text itself.
pub content: String,
/// Free-form category label; the tests in this file use `"fact"`.
pub category: String,
/// Tags attached to the entry.
pub tags: Vec<String>,
/// Creation timestamp (UTC).
pub created_at: DateTime<Utc>,
/// Timestamp of the most recent access (UTC).
pub last_accessed: DateTime<Utc>,
/// Access counter. NOTE(review): which operations increment it is not visible here.
pub access_count: u32,
/// Importance score; falls back to `default_importance()` (5) when absent.
/// NOTE(review): the valid range is not enforced or documented in this file.
#[serde(default = "default_importance")]
pub importance: u8,
/// Kind of memory; falls back to `MemoryType::Episodic` when absent.
#[serde(default)]
pub memory_type: MemoryType,
/// Keywords for the entry; empty when absent.
#[serde(default)]
pub keywords: Vec<String>,
/// Optional short summary of `content`; `None` when absent.
#[serde(default)]
pub summary: Option<String>,
/// Strength value; falls back to `default_strength()` (1.0) when absent.
/// NOTE(review): presumably decays over time and is compared against
/// `MemoryQuery::min_strength` — confirm in the backends.
#[serde(default = "default_strength")]
pub strength: f64,
/// Ids of related entries; empty when absent.
#[serde(default)]
pub related_ids: Vec<String>,
/// Ids of source entries (presumably the entries this one was derived from — confirm);
/// empty when absent.
#[serde(default)]
pub source_ids: Vec<String>,
/// Optional embedding vector; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub embedding: Option<Vec<f32>>,
/// Sensitivity level; falls back to `Confidentiality::Public` when absent.
#[serde(default)]
pub confidentiality: Confidentiality,
/// Optional id of the authoring user; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub author_user_id: Option<String>,
/// Optional id of the authoring tenant; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub author_tenant_id: Option<String>,
}
/// Fallback for `MemoryEntry::importance`, wired in through
/// `#[serde(default = "default_importance")]` for payloads that omit the field.
pub(crate) fn default_importance() -> u8 {
    const FALLBACK_IMPORTANCE: u8 = 5;
    FALLBACK_IMPORTANCE
}
/// Fallback for `MemoryEntry::strength`, wired in through
/// `#[serde(default = "default_strength")]` for payloads that omit the field.
pub(crate) fn default_strength() -> f64 {
    const FALLBACK_STRENGTH: f64 = 1.0;
    FALLBACK_STRENGTH
}
/// Returns the category label `"fact"` as an owned `String`.
///
/// Not referenced by the non-test code in this file; presumably consumed by
/// sibling modules as their category default.
pub(crate) fn default_category() -> String {
    String::from("fact")
}
/// Returns the recall limit of 10.
///
/// Not referenced in this file; presumably consumed by sibling modules. Note that
/// `MemoryQuery::default()` sets `limit` to 0 rather than calling this function.
pub(crate) fn default_recall_limit() -> usize {
    const RECALL_LIMIT: usize = 10;
    RECALL_LIMIT
}
/// Parameters for a `Memory::recall` call.
///
/// Every `Option` field left as `None` (and `tags` left empty) is the state produced
/// by `MemoryQuery::default()`. How each backend interprets the individual fields is
/// not visible in this file; the notes below that go beyond the field types are
/// hedged accordingly.
#[derive(Debug, Clone)]
pub struct MemoryQuery {
/// Optional free-text search string.
pub text: Option<String>,
/// Optional category to match.
pub category: Option<String>,
/// Tags to match. NOTE(review): any-vs-all matching is backend-defined — confirm.
pub tags: Vec<String>,
/// Optional agent name to match.
pub agent: Option<String>,
/// Optional agent-name prefix to match (presumably via `starts_with` — confirm).
pub agent_prefix: Option<String>,
/// Maximum number of results. `Default` sets this to 0.
/// NOTE(review): whether 0 means "no limit" or "no results" is backend-defined — confirm.
pub limit: usize,
/// Optional memory type to match.
pub memory_type: Option<MemoryType>,
/// Optional lower bound on `MemoryEntry::strength`.
pub min_strength: Option<f64>,
/// Optional embedding of the query (presumably for vector similarity — confirm).
pub query_embedding: Option<Vec<f32>>,
/// Optional upper bound on `MemoryEntry::confidentiality`.
pub max_confidentiality: Option<Confidentiality>,
/// `Default` sets this to `true`.
/// NOTE(review): presumably controls whether recalled entries get their access
/// statistics/strength bumped — confirm in the backends.
pub reinforce: bool,
/// `Default` sets this to `false`.
/// NOTE(review): presumably switches text matching to whole-word matching — confirm.
pub exact_words: bool,
}
impl Default for MemoryQuery {
fn default() -> Self {
Self {
text: None,
category: None,
tags: Vec::new(),
agent: None,
agent_prefix: None,
limit: 0,
memory_type: None,
min_strength: None,
query_embedding: None,
max_confidentiality: None,
reinforce: true,
exact_words: false,
}
}
}
/// Storage interface for memory entries.
///
/// Every method takes a `TenantScope` and returns a boxed, `Send` future rather than
/// being an `async fn`; the tests in this file rely on the trait being usable as
/// `dyn Memory`. Implementors must be `Send + Sync`.
///
/// `store`, `recall`, `update` and `forget` are required. `add_link` and `prune`
/// have no-op default implementations.
pub trait Memory: Send + Sync {
/// Stores `entry` under `scope`.
///
/// NOTE(review): behaviour for an already-existing `entry.id` is backend-defined — confirm.
fn store(
&self,
scope: &TenantScope,
entry: MemoryEntry,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>>;
/// Resolves to the entries selected by `query` within `scope`.
fn recall(
&self,
scope: &TenantScope,
query: MemoryQuery,
) -> Pin<Box<dyn Future<Output = Result<Vec<MemoryEntry>, Error>> + Send + '_>>;
/// Updates the entry identified by `id` with the new `content`.
///
/// NOTE(review): behaviour for an unknown `id` is backend-defined — confirm.
fn update(
&self,
scope: &TenantScope,
id: &str,
content: String,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>>;
/// Forgets the entry identified by `id`.
///
/// NOTE(review): the `bool` presumably reports whether an entry was actually
/// removed — confirm.
fn forget(
&self,
scope: &TenantScope,
id: &str,
) -> Pin<Box<dyn Future<Output = Result<bool, Error>> + Send + '_>>;
/// Links the entry `_id` to `_related_id`.
///
/// The default implementation ignores its arguments and resolves to `Ok(())`.
fn add_link(
&self,
_scope: &TenantScope,
_id: &str,
_related_id: &str,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::pin(async { Ok(()) })
}
/// Prunes entries according to `_min_strength`, `_min_age` and an optional
/// `_agent_prefix`, resolving to a `usize` count.
///
/// The default implementation ignores its arguments and resolves to `Ok(0)`.
/// NOTE(review): the count is presumably the number of entries removed, and the
/// thresholds presumably select weak, old entries — confirm in the backends.
fn prune(
&self,
_scope: &TenantScope,
_min_strength: f64,
_min_age: chrono::Duration,
_agent_prefix: Option<&str>,
) -> Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + '_>> {
Box::pin(async { Ok(0) })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload holding only the fields that have no serde default, i.e. what an entry
    /// looked like before `importance` and the later optional fields were added.
    const LEGACY_JSON: &str = r#"{"id":"m1","agent":"a","content":"test","category":"fact","tags":[],"created_at":"2024-01-01T00:00:00Z","last_accessed":"2024-01-01T00:00:00Z","access_count":0}"#;

    /// Legacy payload that spells out `importance` as 5.
    const LEGACY_JSON_IMPORTANCE_5: &str = r#"{"id":"m1","agent":"a","content":"test","category":"fact","tags":[],"created_at":"2024-01-01T00:00:00Z","last_accessed":"2024-01-01T00:00:00Z","access_count":0,"importance":5}"#;

    /// Legacy payload with an explicit, non-default `importance` of 9.
    const LEGACY_JSON_IMPORTANCE_9: &str = r#"{"id":"m1","agent":"a","content":"test","category":"fact","tags":[],"created_at":"2024-01-01T00:00:00Z","last_accessed":"2024-01-01T00:00:00Z","access_count":0,"importance":9}"#;

    /// Builds a baseline entry whose optional fields all hold their defaults.
    ///
    /// `importance`, `strength` and `category` come from the crate's `default_*`
    /// functions, so tests that inspect the baseline exercise those functions instead
    /// of a duplicated literal. Tests override individual fields with struct-update
    /// syntax: `MemoryEntry { importance: 7, ..make_entry("m1", "test") }`.
    fn make_entry(id: &str, content: &str) -> MemoryEntry {
        let now = Utc::now();
        MemoryEntry {
            id: id.into(),
            agent: "a".into(),
            content: content.into(),
            category: default_category(),
            tags: vec![],
            created_at: now,
            last_accessed: now,
            access_count: 0,
            importance: default_importance(),
            memory_type: MemoryType::default(),
            keywords: vec![],
            summary: None,
            strength: default_strength(),
            related_ids: vec![],
            source_ids: vec![],
            embedding: None,
            confidentiality: Confidentiality::default(),
            author_user_id: None,
            author_tenant_id: None,
        }
    }

    /// Serializes `entry` to JSON and parses it back, returning both the JSON text
    /// and the re-parsed entry.
    fn roundtrip(entry: &MemoryEntry) -> (String, MemoryEntry) {
        let json = serde_json::to_string(entry).expect("MemoryEntry must serialize to JSON");
        let parsed = serde_json::from_str(&json).expect("serialized MemoryEntry must parse back");
        (json, parsed)
    }

    /// Parses one of the legacy payload constants.
    fn parse_entry(json: &str) -> MemoryEntry {
        serde_json::from_str(json).expect("legacy payload must deserialize")
    }

    #[test]
    fn memory_entry_serializes() {
        let entry = MemoryEntry {
            agent: "researcher".into(),
            tags: vec!["rust".into()],
            importance: 7,
            ..make_entry("m1", "Rust is fast")
        };
        let (_, parsed) = roundtrip(&entry);
        assert_eq!(parsed.id, "m1");
        assert_eq!(parsed.agent, "researcher");
        assert_eq!(parsed.content, "Rust is fast");
        assert_eq!(parsed.tags, vec!["rust"]);
        assert_eq!(parsed.importance, 7);
    }

    #[test]
    fn memory_entry_serializes_new_fields() {
        let entry = MemoryEntry {
            importance: 7,
            memory_type: MemoryType::Reflection,
            keywords: vec!["rust".into(), "performance".into()],
            summary: Some("Rust is fast for systems programming".into()),
            strength: 0.85,
            related_ids: vec!["m2".into(), "m3".into()],
            source_ids: vec!["m0".into()],
            ..make_entry("m1", "test")
        };
        let (_, parsed) = roundtrip(&entry);
        assert_eq!(parsed.memory_type, MemoryType::Reflection);
        assert_eq!(parsed.keywords, vec!["rust", "performance"]);
        assert_eq!(
            parsed.summary.as_deref(),
            Some("Rust is fast for systems programming")
        );
        assert!((parsed.strength - 0.85).abs() < f64::EPSILON);
        assert_eq!(parsed.related_ids, vec!["m2", "m3"]);
        assert_eq!(parsed.source_ids, vec!["m0"]);
    }

    #[test]
    fn memory_entry_deserialize_without_new_fields() {
        let entry = parse_entry(LEGACY_JSON_IMPORTANCE_9);
        assert_eq!(entry.importance, 9);
        assert_eq!(entry.memory_type, MemoryType::Episodic);
        assert!(entry.keywords.is_empty());
        assert!(entry.summary.is_none());
        assert!((entry.strength - 1.0).abs() < f64::EPSILON);
        assert!(entry.related_ids.is_empty());
        assert!(entry.source_ids.is_empty());
    }

    #[test]
    fn memory_type_default_is_episodic() {
        assert_eq!(MemoryType::default(), MemoryType::Episodic);
    }

    #[test]
    fn strength_default_is_one() {
        assert!((default_strength() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn memory_type_serialization_roundtrip() {
        for mt in [
            MemoryType::Episodic,
            MemoryType::Semantic,
            MemoryType::Reflection,
        ] {
            let json = serde_json::to_string(&mt).expect("MemoryType must serialize");
            let parsed: MemoryType =
                serde_json::from_str(&json).expect("MemoryType must parse back");
            assert_eq!(parsed, mt);
        }
    }

    #[test]
    fn memory_type_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_string(&MemoryType::Episodic).unwrap(),
            "\"episodic\""
        );
        assert_eq!(
            serde_json::to_string(&MemoryType::Semantic).unwrap(),
            "\"semantic\""
        );
        assert_eq!(
            serde_json::to_string(&MemoryType::Reflection).unwrap(),
            "\"reflection\""
        );
    }

    #[test]
    fn memory_entry_default_importance() {
        // `make_entry` takes its importance from `default_importance()`, so this pins
        // the function's value rather than a literal repeated in the helper.
        assert_eq!(default_importance(), 5);
        let entry = make_entry("m1", "test");
        assert_eq!(entry.importance, 5);
    }

    #[test]
    fn memory_entry_deserialize_without_importance() {
        let entry = parse_entry(LEGACY_JSON);
        assert_eq!(entry.importance, 5);
    }

    #[test]
    fn memory_entry_deserialize_with_importance() {
        let entry = parse_entry(LEGACY_JSON_IMPORTANCE_9);
        assert_eq!(entry.importance, 9);
    }

    #[test]
    fn memory_query_default() {
        let q = MemoryQuery::default();
        assert!(q.text.is_none());
        assert!(q.category.is_none());
        assert!(q.tags.is_empty());
        assert!(q.agent.is_none());
        assert!(q.agent_prefix.is_none());
        assert_eq!(q.limit, 0);
        assert!(q.memory_type.is_none());
        assert!(q.min_strength.is_none());
        assert!(q.query_embedding.is_none());
        assert!(q.max_confidentiality.is_none());
        // `reinforce` is the one field whose default differs from its type's default.
        assert!(q.reinforce);
        assert!(!q.exact_words);
    }

    #[test]
    fn memory_trait_is_object_safe() {
        // Compiles only if `Memory` can be used as a trait object.
        fn _accepts_dyn(_m: &dyn Memory) {}
    }

    #[test]
    fn memory_entry_embedding_serde_roundtrip() {
        let entry = MemoryEntry {
            embedding: Some(vec![0.1, 0.2, 0.3]),
            ..make_entry("m1", "test")
        };
        let (json, parsed) = roundtrip(&entry);
        assert!(json.contains("\"embedding\""));
        let emb = parsed
            .embedding
            .expect("embedding must survive the round trip");
        assert_eq!(emb.len(), 3);
        assert!((emb[0] - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn memory_entry_backward_compat_no_embedding() {
        let entry = parse_entry(LEGACY_JSON_IMPORTANCE_5);
        assert!(entry.embedding.is_none());
    }

    #[test]
    fn memory_entry_none_embedding_not_serialized() {
        let entry = make_entry("m1", "test");
        let json = serde_json::to_string(&entry).unwrap();
        assert!(!json.contains("embedding"));
    }

    #[test]
    fn memory_entry_none_author_ids_not_serialized() {
        // Both author fields carry `skip_serializing_if = "Option::is_none"`.
        let entry = make_entry("m1", "test");
        let json = serde_json::to_string(&entry).unwrap();
        assert!(!json.contains("author_user_id"));
        assert!(!json.contains("author_tenant_id"));
    }

    #[test]
    fn confidentiality_default_is_public() {
        assert_eq!(Confidentiality::default(), Confidentiality::Public);
    }

    #[test]
    fn confidentiality_ordering() {
        assert!(Confidentiality::Public < Confidentiality::Internal);
        assert!(Confidentiality::Internal < Confidentiality::Confidential);
        assert!(Confidentiality::Confidential < Confidentiality::Restricted);
    }

    #[test]
    fn confidentiality_serde_roundtrip() {
        for c in [
            Confidentiality::Public,
            Confidentiality::Internal,
            Confidentiality::Confidential,
            Confidentiality::Restricted,
        ] {
            let json = serde_json::to_string(&c).expect("Confidentiality must serialize");
            let parsed: Confidentiality =
                serde_json::from_str(&json).expect("Confidentiality must parse back");
            assert_eq!(parsed, c);
        }
    }

    #[test]
    fn confidentiality_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_string(&Confidentiality::Public).unwrap(),
            "\"public\""
        );
        assert_eq!(
            serde_json::to_string(&Confidentiality::Internal).unwrap(),
            "\"internal\""
        );
        assert_eq!(
            serde_json::to_string(&Confidentiality::Confidential).unwrap(),
            "\"confidential\""
        );
        assert_eq!(
            serde_json::to_string(&Confidentiality::Restricted).unwrap(),
            "\"restricted\""
        );
    }

    #[test]
    fn memory_entry_backward_compat_no_confidentiality() {
        let entry = parse_entry(LEGACY_JSON_IMPORTANCE_5);
        assert_eq!(entry.confidentiality, Confidentiality::Public);
    }

    #[test]
    fn memory_query_max_confidentiality_default_is_none() {
        let q = MemoryQuery::default();
        assert!(q.max_confidentiality.is_none());
    }
}