use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Fixed 32-byte identifier for a memory entry.
///
/// Built either from raw bytes (`from_bytes`) or as the BLAKE3 hash of arbitrary data
/// (`from_data`); the hand-written `Serialize` impl below writes it as a hex string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MemoryId([u8; 32]);
impl MemoryId {
    /// Wraps 32 raw bytes as an id without hashing them.
    pub fn from_bytes(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Derives an id as the BLAKE3 hash of `data`; equal inputs always give equal ids.
    pub fn from_data(data: &[u8]) -> Self {
        Self(*blake3::hash(data).as_bytes())
    }

    /// Borrows the raw 32 id bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Encodes the id as 64 lowercase hex characters.
    pub fn to_hex(&self) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";
        // Push two ASCII digits per byte into one pre-sized buffer instead of allocating a
        // `format!` string for every byte.
        let mut out = String::with_capacity(64);
        for &byte in &self.0 {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }

    /// Parses the hex form produced by [`MemoryId::to_hex`] (upper- or lower-case digits).
    ///
    /// Returns `None` unless `hex` is exactly 64 ASCII hex digits. Each digit is validated
    /// individually: `u8::from_str_radix` on digit pairs would also accept a leading `'+'`,
    /// letting strings such as `"+f+f…"` parse as valid ids.
    pub fn from_hex(hex: &str) -> Option<Self> {
        let digits = hex.as_bytes();
        if digits.len() != 64 {
            return None;
        }
        // `char::to_digit(16)` accepts only 0-9, a-f and A-F; bytes of multi-byte UTF-8
        // sequences map to non-hex chars and are rejected as well.
        let nibble = |c: u8| char::from(c).to_digit(16);
        let mut bytes = [0u8; 32];
        for (slot, pair) in bytes.iter_mut().zip(digits.chunks_exact(2)) {
            // Both nibbles are < 16, so the combined value always fits in a u8.
            *slot = ((nibble(pair[0])? << 4) | nibble(pair[1])?) as u8;
        }
        Some(Self(bytes))
    }
}
impl Serialize for MemoryId {
    /// Writes the id as the hex string returned by `to_hex`, not as a raw byte array.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let encoded = self.to_hex();
        serializer.serialize_str(encoded.as_str())
    }
}
impl<'de> Deserialize<'de> for MemoryId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Self::from_hex(&s).ok_or_else(|| serde::de::Error::custom("invalid memory id hex"))
}
}
/// Point in time as microseconds since the Unix epoch (UTC), as computed by `Timestamp::now`.
///
/// The derived `Default` is `Timestamp(0)` (the epoch), and the derived ordering compares the
/// raw microsecond counts, i.e. chronologically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub struct Timestamp(pub u64);
impl Timestamp {
    /// Microseconds per second; the unit conversion shared by every method in this impl.
    const MICROS_PER_SEC: u64 = 1_000_000;

    /// Current UTC wall-clock time in microseconds since the Unix epoch.
    ///
    /// A clock reading before 1970 is clamped to `Timestamp(0)`.
    pub fn now() -> Self {
        let now = chrono::Utc::now();
        let secs = now.timestamp();
        // A negative seconds value cast with `as u64` would wrap to a time far in the future.
        if secs < 0 {
            return Self(0);
        }
        let micros = (secs as u64)
            .saturating_mul(Self::MICROS_PER_SEC)
            .saturating_add(u64::from(now.timestamp_subsec_micros()));
        Self(micros)
    }

    /// Timestamp `secs` whole seconds after the epoch.
    ///
    /// Saturates at `u64::MAX` microseconds instead of overflowing (a plain multiplication
    /// panics in debug builds and wraps in release builds for very large `secs`).
    pub fn from_secs(secs: u64) -> Self {
        Self(secs.saturating_mul(Self::MICROS_PER_SEC))
    }

    /// Whole seconds elapsed between `self` and now; `0` if `self` lies in the future.
    pub fn age_secs(&self) -> u64 {
        Self::now().0.saturating_sub(self.0) / Self::MICROS_PER_SEC
    }
}
/// A single stored memory: a typed JSON payload plus bookkeeping metadata, tags and an
/// optional embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
/// Identifier; `MemoryEntry::new` derives it by hashing the entry type, creation time and data.
pub id: MemoryId,
/// Caller-chosen category label (free-form string, e.g. `"test"` in the unit tests).
pub entry_type: String,
/// Arbitrary JSON payload.
pub data: serde_json::Value,
/// Timestamps, access statistics, importance and attention.
pub metadata: MemoryMetadata,
/// Tags, normalized by `SemanticTag::new` when set through `with_tags`.
pub tags: Vec<SemanticTag>,
/// Optional vector; presumably compared against a query's embedding via
/// `Embedding::cosine_similarity` — TODO confirm against the store implementation.
pub embedding: Option<Embedding>,
}
impl MemoryEntry {
    /// Creates an entry of `entry_type` holding `data`, with default metadata, no tags and no
    /// embedding.
    ///
    /// The id is the BLAKE3 hash (via `MemoryId::from_data`) of the length-prefixed entry
    /// type, the current UTC time in nanoseconds, a per-process sequence number and the JSON
    /// encoding of `data`. Two calls with identical arguments therefore still get different
    /// ids within one process; the id is not content-addressed.
    pub fn new(entry_type: &str, data: serde_json::Value) -> Self {
        // Per-process counter mixed into the hash: without it, two entries with the same type
        // and data created within one tick of the system clock would receive the same id.
        static SEQUENCE: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        let sequence = SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed) as u64;
        // `timestamp_nanos_opt` is `None` when the time does not fit in i64 nanoseconds;
        // fall back to 0 (the sequence number still keeps ids distinct).
        let timestamp = chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0);
        let payload = serde_json::to_vec(&data).unwrap_or_default();

        let mut to_hash = Vec::with_capacity(8 + entry_type.len() + 8 + 8 + payload.len());
        // Length-prefix the type so it cannot run into the fields that follow it.
        to_hash.extend_from_slice(&(entry_type.len() as u64).to_le_bytes());
        to_hash.extend_from_slice(entry_type.as_bytes());
        to_hash.extend_from_slice(&timestamp.to_le_bytes());
        to_hash.extend_from_slice(&sequence.to_le_bytes());
        to_hash.extend_from_slice(&payload);
        let id = MemoryId::from_data(&to_hash);

        Self {
            id,
            entry_type: entry_type.to_string(),
            data,
            metadata: MemoryMetadata::default(),
            tags: Vec::new(),
            embedding: None,
        }
    }

    /// Replaces the tags with `tags`, each normalized by `SemanticTag::new`.
    pub fn with_tags(mut self, tags: &[&str]) -> Self {
        self.tags = tags.iter().copied().map(SemanticTag::new).collect();
        self
    }

    /// Sets `metadata.importance` (no range check is applied).
    pub fn with_importance(mut self, importance: f32) -> Self {
        self.metadata.importance = importance;
        self
    }

    /// Attaches `embedding`, replacing any existing one.
    pub fn with_embedding(mut self, embedding: Embedding) -> Self {
        self.embedding = Some(embedding);
        self
    }

    /// Approximate footprint in bytes.
    ///
    /// This is the inline struct size plus the heap-owned text (`entry_type`,
    /// `metadata.source`, tag strings), the embedding values, and the JSON-serialized length
    /// of `data` as a proxy for its in-memory size. `data` is serialized on every call, so the
    /// cost grows with the payload.
    pub fn size_bytes(&self) -> usize {
        std::mem::size_of::<Self>()
            + self.entry_type.len()
            + self.metadata.source.len()
            + serde_json::to_vec(&self.data).map(|v| v.len()).unwrap_or(0)
            + self.tags.iter().map(|t| t.0.len()).sum::<usize>()
            + self
                .embedding
                .as_ref()
                .map_or(0, |e| e.0.len() * std::mem::size_of::<f32>())
    }
}
/// Bookkeeping attached to memory entries (`MemoryEntry::metadata`) and entities
/// (`Entity::metadata`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryMetadata {
/// Creation time (`Timestamp::now()` in the `Default` impl).
pub created_at: Timestamp,
/// Time of the most recent `record_access`; initially equal to `created_at`.
pub last_accessed: Timestamp,
/// Number of `record_access` calls.
pub access_count: u32,
/// Caller-assigned importance, `0.5` by default.
/// NOTE(review): no range is enforced here — confirm whether callers assume [0, 1].
pub importance: f32,
/// Attention level: starts at `1.0`, `record_access` adds 0.2 (capped at 1.0) and `decay`
/// multiplies it by a factor.
pub attention: f32,
/// `false` by default and not set anywhere in this section; presumably flips once the entry
/// has been consolidated into long-term storage — TODO confirm.
pub consolidated: bool,
/// Free-form origin label; `"unknown"` by default, set via `MemoryMetadata::with_source`.
pub source: String,
}
impl Default for MemoryMetadata {
    /// Fresh metadata with both timestamps set from a single `Timestamp::now()` reading.
    ///
    /// Other defaults: no accesses, importance `0.5`, full attention `1.0`, not consolidated,
    /// and source `"unknown"`.
    fn default() -> Self {
        let stamp = Timestamp::now();
        Self {
            source: String::from("unknown"),
            consolidated: false,
            attention: 1.0,
            importance: 0.5,
            access_count: 0,
            last_accessed: stamp,
            created_at: stamp,
        }
    }
}
impl MemoryMetadata {
    /// Default metadata with `source` set to the given label.
    pub fn with_source(source: &str) -> Self {
        Self {
            source: source.to_string(),
            ..Default::default()
        }
    }

    /// Records one access.
    ///
    /// Refreshes `last_accessed`, bumps `access_count`, and raises `attention` by 0.2,
    /// capped at 1.0.
    pub fn record_access(&mut self) {
        self.last_accessed = Timestamp::now();
        // Saturate instead of overflowing: `+= 1` panics in debug builds (and wraps to 0 in
        // release builds) once a long-lived entry reaches u32::MAX accesses.
        self.access_count = self.access_count.saturating_add(1);
        self.attention = (self.attention + 0.2).min(1.0);
    }

    /// Multiplies `attention` by `factor`.
    ///
    /// No clamping is applied. A `factor` above 1.0 raises attention past the 1.0 cap that
    /// `record_access` enforces, and a negative factor makes attention negative.
    pub fn decay(&mut self, factor: f32) {
        self.attention *= factor;
    }
}
/// Tag label. `SemanticTag::new` trims and lower-cases the input; constructing through the
/// public tuple field bypasses that normalization.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SemanticTag(pub String);
impl SemanticTag {
pub fn new(tag: &str) -> Self {
Self(tag.to_lowercase().trim().to_string())
}
}
/// Dense `f32` vector compared with `Embedding::cosine_similarity`.
///
/// The type itself enforces no dimension or normalization; `from_text_simple` produces 64
/// values, L2-normalized unless they are all zero.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding(pub Vec<f32>);
impl Embedding {
pub fn new(values: Vec<f32>) -> Self {
Self(values)
}
pub fn cosine_similarity(&self, other: &Embedding) -> f32 {
if self.0.len() != other.0.len() || self.0.is_empty() {
return 0.0;
}
let dot: f32 = self.0.iter().zip(&other.0).map(|(a, b)| a * b).sum();
let mag_a: f32 = self.0.iter().map(|x| x * x).sum::<f32>().sqrt();
let mag_b: f32 = other.0.iter().map(|x| x * x).sum::<f32>().sqrt();
if mag_a == 0.0 || mag_b == 0.0 {
0.0
} else {
dot / (mag_a * mag_b)
}
}
pub fn from_text_simple(text: &str) -> Self {
const DIM: usize = 64;
let mut values = vec![0.0f32; DIM];
for word in text.to_lowercase().split_whitespace() {
let hash = blake3::hash(word.as_bytes());
let bytes = hash.as_bytes();
for (i, &b) in bytes.iter().take(DIM).enumerate() {
values[i] += (b as f32 / 255.0) - 0.5;
}
}
let magnitude: f32 = values.iter().map(|x| x * x).sum::<f32>().sqrt();
if magnitude > 0.0 {
for v in &mut values {
*v /= magnitude;
}
}
Self(values)
}
}
/// Search criteria for memory lookups. Every field is optional; the constructors below set
/// individual criteria.
// NOTE(review): how the fields combine (AND vs OR, inclusive vs exclusive time bounds) is
// decided by the code that executes the query, which is not in this section — confirm there.
#[derive(Debug, Clone, Default)]
pub struct MemoryQuery {
/// Free-text query; `MemoryQuery::text` sets it together with `embedding`.
pub text: Option<String>,
/// Tags to match, normalized by `SemanticTag::new` when built via `MemoryQuery::tags`.
pub tags: Vec<SemanticTag>,
/// Entry type to match against `MemoryEntry::entry_type`.
pub entry_type: Option<String>,
/// Minimum `metadata.importance`; presumably inclusive — TODO confirm.
pub min_importance: Option<f32>,
/// Lower time bound; which timestamp it applies to and whether it is inclusive — TODO confirm.
pub after: Option<Timestamp>,
/// Upper time bound; same open questions as `after`.
pub before: Option<Timestamp>,
/// Maximum number of results.
pub limit: Option<usize>,
/// Query vector; `MemoryQuery::text` fills it from `Embedding::from_text_simple`.
pub embedding: Option<Embedding>,
}
impl MemoryQuery {
    /// Free-text query; also stores `Embedding::from_text_simple(query)` as the query vector.
    pub fn text(query: &str) -> Self {
        let embedding = Embedding::from_text_simple(query);
        Self {
            embedding: Some(embedding),
            text: Some(query.to_owned()),
            ..Self::default()
        }
    }

    /// Tag query; each tag is normalized by `SemanticTag::new`.
    pub fn tags(tags: &[&str]) -> Self {
        let normalized = tags.iter().copied().map(SemanticTag::new).collect();
        Self {
            tags: normalized,
            ..Self::default()
        }
    }

    /// Query restricted to a single entry type.
    pub fn entry_type(entry_type: &str) -> Self {
        Self {
            entry_type: Some(String::from(entry_type)),
            ..Self::default()
        }
    }

    /// Returns the query with `limit` set.
    pub fn with_limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Returns the query with `min_importance` set.
    pub fn with_min_importance(self, importance: f32) -> Self {
        Self {
            min_importance: Some(importance),
            ..self
        }
    }
}
/// A memory entry returned from a lookup, with its score and the tier it came from.
#[derive(Debug, Clone)]
pub struct MemoryResult {
/// The matched entry (an owned clone).
pub entry: MemoryEntry,
/// Relevance score; its scale is set by the code that produces results — TODO confirm
/// (e.g. whether it is a raw cosine similarity).
pub relevance: f32,
/// Which memory tier produced this result.
pub source: MemorySource,
}
/// Which memory tier a [`MemoryResult`] came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemorySource {
/// Short-term tier — presumably the in-process working memory; TODO confirm.
ShortTerm,
/// Long-term tier — presumably the persistent/consolidated store; TODO confirm.
LongTerm,
}
/// A named, typed entity that `Link`s connect through its `EntityId`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
/// Identifier; `Entity::new` hashes `"{entity_type}:{name}"`, so it is deterministic per pair.
pub id: EntityId,
/// Category label (e.g. `"sensor"` in the unit tests).
pub entity_type: String,
/// Entity name, unique within its type as far as the id derivation is concerned.
pub name: String,
/// Arbitrary JSON-valued attributes, set via `with_property`.
pub properties: HashMap<String, serde_json::Value>,
/// Optional vector representation.
pub embedding: Option<Embedding>,
/// Creation/access bookkeeping.
pub metadata: MemoryMetadata,
}
impl Entity {
    /// Creates an entity with no properties, no embedding and default metadata.
    ///
    /// The id is `EntityId::from_data` over the bytes of `"{entity_type}:{name}"`, so the same
    /// type/name pair always yields the same id.
    pub fn new(entity_type: &str, name: &str) -> Self {
        // NOTE(review): ':' is not escaped, so ("a:b", "c") and ("a", "b:c") share an id.
        // Changing the key encoding would change every existing entity id, so it is kept.
        let key = [entity_type, name].join(":");
        Self {
            id: EntityId::from_data(key.as_bytes()),
            entity_type: String::from(entity_type),
            name: String::from(name),
            properties: HashMap::new(),
            embedding: None,
            metadata: MemoryMetadata::default(),
        }
    }

    /// Inserts `key -> value` into `properties`, overwriting any previous value for `key`.
    pub fn with_property(self, key: &str, value: serde_json::Value) -> Self {
        let mut entity = self;
        entity.properties.insert(key.to_owned(), value);
        entity
    }
}
/// 32-byte identifier for an [`Entity`]; the hand-written `Serialize` impl below writes it as
/// a hex string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EntityId([u8; 32]);
impl EntityId {
    /// Wraps 32 raw bytes as an id without hashing them.
    pub fn from_bytes(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Derives an id as the BLAKE3 hash of `data`; equal inputs always give equal ids.
    pub fn from_data(data: &[u8]) -> Self {
        Self(*blake3::hash(data).as_bytes())
    }

    /// Borrows the raw 32 id bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Encodes the id as 64 lowercase hex characters.
    pub fn to_hex(&self) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";
        // One pre-sized buffer instead of a `format!` allocation per byte.
        let mut out = String::with_capacity(64);
        for &byte in &self.0 {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }

    /// Parses the hex form produced by [`EntityId::to_hex`] (upper- or lower-case digits).
    ///
    /// Mirrors `MemoryId::from_hex`, so callers can turn the serialized string back into an
    /// id. Returns `None` unless `hex` is exactly 64 ASCII hex digits.
    pub fn from_hex(hex: &str) -> Option<Self> {
        let digits = hex.as_bytes();
        if digits.len() != 64 {
            return None;
        }
        // `char::to_digit(16)` accepts only 0-9, a-f and A-F (no sign characters).
        let nibble = |c: u8| char::from(c).to_digit(16);
        let mut bytes = [0u8; 32];
        for (slot, pair) in bytes.iter_mut().zip(digits.chunks_exact(2)) {
            // Both nibbles are < 16, so the combined value always fits in a u8.
            *slot = ((nibble(pair[0])? << 4) | nibble(pair[1])?) as u8;
        }
        Some(Self(bytes))
    }
}
impl Serialize for EntityId {
    /// Writes the id as the lowercase hex string returned by `to_hex`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let encoded = self.to_hex();
        serializer.serialize_str(encoded.as_str())
    }
}
impl<'de> Deserialize<'de> for EntityId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let mut bytes = [0u8; 32];
for (i, chunk) in s.as_bytes().chunks(2).enumerate() {
if i >= 32 {
break;
}
let hex_str = std::str::from_utf8(chunk).map_err(serde::de::Error::custom)?;
bytes[i] = u8::from_str_radix(hex_str, 16).map_err(serde::de::Error::custom)?;
}
Ok(Self(bytes))
}
}
/// A weighted, labelled edge from the `source` entity to the `target` entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Link {
/// Entity the link starts from.
pub source: EntityId,
/// Entity the link points to.
pub target: EntityId,
/// Edge label, e.g. `Relation::is_a()`.
pub relation: Relation,
/// Edge weight; `Link::new` sets `1.0` and no range is enforced.
pub weight: f32,
/// Arbitrary JSON-valued edge attributes.
pub properties: HashMap<String, serde_json::Value>,
/// Creation time, set by `Link::new`.
pub created_at: Timestamp,
}
impl Link {
    /// Builds the link `source -[relation]-> target`.
    ///
    /// The link starts with weight `1.0`, no properties, and `created_at` set to now.
    pub fn new(source: EntityId, relation: Relation, target: EntityId) -> Self {
        let created_at = Timestamp::now();
        Self {
            created_at,
            weight: 1.0,
            properties: HashMap::new(),
            source,
            relation,
            target,
        }
    }

    /// Returns the link with its weight replaced by `weight` (no range check is applied).
    pub fn with_weight(self, weight: f32) -> Self {
        Self { weight, ..self }
    }
}
/// Relation label for a `Link`. `Relation::new` upper-cases the name; constructing through the
/// public tuple field bypasses that normalization.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Relation(pub String);
impl Relation {
pub fn new(name: &str) -> Self {
Self(name.to_uppercase())
}
pub fn is_a() -> Self {
Self::new("IS_A")
}
pub fn has() -> Self {
Self::new("HAS")
}
pub fn related_to() -> Self {
Self::new("RELATED_TO")
}
pub fn caused_by() -> Self {
Self::new("CAUSED_BY")
}
pub fn located_at() -> Self {
Self::new("LOCATED_AT")
}
pub fn observed() -> Self {
Self::new("OBSERVED")
}
}
/// Alias for [`Relation`]; presumably kept for callers that use the older name — TODO confirm.
pub type LinkType = Relation;
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashing the same bytes twice must yield the same id.
    #[test]
    fn test_memory_id() {
        let payload: &[u8] = b"test data";
        let first = MemoryId::from_data(payload);
        assert_eq!(first, MemoryId::from_data(payload));
    }

    /// Parallel vectors score ~1.0; orthogonal vectors score ~0.0.
    #[test]
    fn test_embedding_similarity() {
        let tolerance = 0.001;
        let x_axis = Embedding::new(vec![1.0, 0.0, 0.0]);
        let y_axis = Embedding::new(vec![0.0, 1.0, 0.0]);
        let parallel = x_axis.cosine_similarity(&Embedding::new(vec![1.0, 0.0, 0.0]));
        let orthogonal = x_axis.cosine_similarity(&y_axis);
        assert!((parallel - 1.0).abs() < tolerance);
        assert!(orthogonal.abs() < tolerance);
    }

    /// Builder methods set the type, tags and importance.
    #[test]
    fn test_memory_entry() {
        let entry = MemoryEntry::new("test", serde_json::json!({ "key": "value" }))
            .with_tags(&["tag1", "tag2"])
            .with_importance(0.8);
        assert_eq!(entry.entry_type.as_str(), "test");
        assert_eq!(entry.tags.len(), 2);
        assert_eq!(entry.metadata.importance, 0.8);
    }

    /// `Entity::new` stores type and name; `with_property` inserts the key.
    #[test]
    fn test_entity() {
        let entity = Entity::new("sensor", "temp_001")
            .with_property("location", serde_json::json!("room_a"));
        assert_eq!(
            (entity.entity_type.as_str(), entity.name.as_str()),
            ("sensor", "temp_001")
        );
        assert!(entity.properties.contains_key("location"));
    }
}