use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::error::{ClientError, Result};
use crate::ConnectionTrait;
/// A single mutation that can be bundled into a [`MemoryIntent`].
///
/// Each variant is turned into one or more key/value puts or deletes by
/// `AtomicMemoryWriter::apply_op`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemoryOp {
/// Stores `value` under `key` and also under a version-suffixed copy of `key`.
PutBlob {
key: Vec<u8>,
value: Vec<u8>,
},
/// Stores an embedding as two entries: a JSON metadata record at
/// `_vectors/{collection}/{id}/meta` and the little-endian `f32` bytes at
/// `_vectors/{collection}/{id}/data`.
PutEmbedding {
collection: String,
id: String,
embedding: Vec<f32>,
metadata: HashMap<String, String>,
},
/// Stores a JSON node record at `_graph/{namespace}/nodes/{node_id}`.
CreateNode {
namespace: String,
node_id: String,
node_type: String,
properties: HashMap<String, serde_json::Value>,
},
/// Stores a JSON edge record at `_graph/{namespace}/edges/{from_id}/{edge_type}/{to_id}`
/// and an index entry at `_graph/{namespace}/index/{edge_type}/{to_id}/{from_id}`.
CreateEdge {
namespace: String,
from_id: String,
edge_type: String,
to_id: String,
properties: HashMap<String, serde_json::Value>,
},
/// Deletes the entry stored under `key`; version-suffixed copies are not touched.
DeleteBlob {
key: Vec<u8>,
},
/// Deletes both the edge record and its index entry (same keys as `CreateEdge`).
DeleteEdge {
namespace: String,
from_id: String,
edge_type: String,
to_id: String,
},
}
/// Lifecycle state persisted with every [`MemoryIntent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum IntentStatus {
/// Initial state assigned by `MemoryIntent::new`; replayed by `recover`.
Pending,
/// Never assigned in this file; `recover` treats it the same as `Pending`.
Applied,
/// Every operation was applied and the intent record was updated.
Committed,
/// Applying the operations failed and the intent record was updated.
Aborted,
}
/// A recorded batch of operations, persisted as JSON before the operations are applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryIntent {
/// Identifier used to build the intent's storage key.
pub intent_id: u64,
/// Caller-supplied memory identifier, copied into the records written by the ops.
pub memory_id: String,
/// Operations to apply, in order.
pub ops: Vec<MemoryOp>,
/// Current lifecycle state.
pub status: IntentStatus,
/// Creation time in milliseconds since the Unix epoch.
pub created_at: u64,
/// Version passed to every operation; `new` sets it equal to `created_at`.
pub version: u64,
}
impl MemoryIntent {
    /// Creates a new intent in the [`IntentStatus::Pending`] state.
    ///
    /// `created_at` and `version` are both set to the current wall-clock time in
    /// milliseconds since the Unix epoch.
    ///
    /// This constructor does not panic on clock anomalies: a system clock set before
    /// the Unix epoch yields a timestamp of `0`, and a millisecond count that does not
    /// fit in `u64` saturates at `u64::MAX`.
    pub fn new(intent_id: u64, memory_id: String, ops: Vec<MemoryOp>) -> Self {
        // `duration_since` fails when the clock is earlier than the epoch; fall back
        // to a zero duration instead of unwinding inside a constructor.
        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        // `as_millis` returns u128; clamp before narrowing so the cast cannot wrap.
        let now = since_epoch.as_millis().min(u128::from(u64::MAX)) as u64;
        Self {
            intent_id,
            memory_id,
            ops,
            status: IntentStatus::Pending,
            created_at: now,
            version: now,
        }
    }
}
/// Key prefix under which serialized intents are stored; `recover` and `cleanup`
/// scan this prefix.
const INTENT_PREFIX: &str = "_intents/";
/// Applies batches of [`MemoryOp`]s through a connection, persisting each batch as a
/// [`MemoryIntent`] first so that unfinished batches can be replayed by `recover`.
pub struct AtomicMemoryWriter<C: ConnectionTrait> {
/// Connection used for every put/get/delete/scan the writer issues.
conn: C,
/// Source of intent ids; starts at 1 and is incremented once per write.
next_intent_id: AtomicU64,
}
impl<C: ConnectionTrait> AtomicMemoryWriter<C> {
/// Wraps `conn` in a writer whose intent ids start at 1.
///
/// NOTE(review): the id counter lives only in memory, so a writer created against a
/// store that already holds intents would reuse their keys — confirm this is intended.
pub fn new(conn: C) -> Self {
    let next_intent_id = AtomicU64::new(1);
    Self {
        conn,
        next_intent_id,
    }
}
fn next_id(&self) -> u64 {
self.next_intent_id.fetch_add(1, Ordering::SeqCst)
}
/// Builds the storage key for an intent: `INTENT_PREFIX` followed by the id in decimal.
fn intent_key(intent_id: u64) -> Vec<u8> {
    let mut key = String::from(INTENT_PREFIX);
    key.push_str(&intent_id.to_string());
    key.into_bytes()
}
/// Records `ops` as a pending intent, applies them in order, then marks the intent
/// committed.
///
/// # Errors
/// Returns an error if the intent cannot be persisted, if any operation fails, or if
/// the committed status cannot be stored. When an operation fails the intent is marked
/// aborted on a best-effort basis (a failure to do so is ignored) and the operation's
/// error is returned; operations applied before the failure are not undone.
pub fn write_atomic(
    &self,
    memory_id: impl Into<String>,
    ops: Vec<MemoryOp>,
) -> Result<AtomicWriteResult> {
    let memory_id = memory_id.into();
    let intent_id = self.next_id();
    let intent = MemoryIntent::new(intent_id, memory_id, ops);
    self.write_intent(&intent)?;

    let ops_applied = match self.apply_ops(&intent) {
        Ok(count) => count,
        Err(err) => {
            // Best effort: the apply error is what the caller needs to see.
            let _ = self.mark_aborted(intent_id);
            return Err(err);
        }
    };

    self.mark_committed(intent_id)?;
    Ok(AtomicWriteResult {
        intent_id,
        memory_id: intent.memory_id,
        ops_applied,
        status: IntentStatus::Committed,
    })
}
/// Serializes `intent` as JSON and stores it under its intent key.
///
/// # Errors
/// Returns `ClientError::Serialization` if encoding fails, or the connection's error
/// if the put fails.
fn write_intent(&self, intent: &MemoryIntent) -> Result<()> {
    let encoded = serde_json::to_vec(intent)
        .map_err(|err| ClientError::Serialization(err.to_string()))?;
    let key = Self::intent_key(intent.intent_id);
    self.conn.put(&key, &encoded)?;
    Ok(())
}
/// Applies every operation of `intent` in order and returns how many were applied.
///
/// # Errors
/// Stops at, and returns, the first error from an operation.
fn apply_ops(&self, intent: &MemoryIntent) -> Result<usize> {
    for op in intent.ops.iter() {
        self.apply_op(op, &intent.memory_id, intent.version)?;
    }
    // Reaching this point means no operation failed, so all of them were applied.
    Ok(intent.ops.len())
}
fn apply_op(&self, op: &MemoryOp, memory_id: &str, version: u64) -> Result<()> {
match op {
MemoryOp::PutBlob { key, value } => {
let versioned_key = Self::versioned_key(key, version);
self.conn.put(&versioned_key, value)?;
self.conn.put(key, value)?;
}
MemoryOp::PutEmbedding { collection, id, embedding, metadata } => {
let key = format!("_vectors/{}/{}/meta", collection, id).into_bytes();
let meta = EmbeddingMeta {
memory_id: memory_id.to_string(),
version,
dimensions: embedding.len(),
metadata: metadata.clone(),
};
let value = serde_json::to_vec(&meta)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
let emb_key = format!("_vectors/{}/{}/data", collection, id).into_bytes();
let emb_bytes: Vec<u8> = embedding
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
self.conn.put(&emb_key, &emb_bytes)?;
}
MemoryOp::CreateNode { namespace, node_id, node_type, properties } => {
let key = format!("_graph/{}/nodes/{}", namespace, node_id).into_bytes();
let node = GraphNodeRecord {
id: node_id.clone(),
node_type: node_type.clone(),
properties: properties.clone(),
memory_id: memory_id.to_string(),
version,
};
let value = serde_json::to_vec(&node)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
MemoryOp::CreateEdge { namespace, from_id, edge_type, to_id, properties } => {
let edge_key = format!(
"_graph/{}/edges/{}/{}/{}",
namespace, from_id, edge_type, to_id
).into_bytes();
let edge = GraphEdgeRecord {
from_id: from_id.clone(),
edge_type: edge_type.clone(),
to_id: to_id.clone(),
properties: properties.clone(),
memory_id: memory_id.to_string(),
version,
};
let value = serde_json::to_vec(&edge)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&edge_key, &value)?;
let rev_key = format!(
"_graph/{}/index/{}/{}/{}",
namespace, edge_type, to_id, from_id
).into_bytes();
self.conn.put(&rev_key, from_id.as_bytes())?;
}
MemoryOp::DeleteBlob { key } => {
self.conn.delete(key)?;
}
MemoryOp::DeleteEdge { namespace, from_id, edge_type, to_id } => {
let edge_key = format!(
"_graph/{}/edges/{}/{}/{}",
namespace, from_id, edge_type, to_id
).into_bytes();
self.conn.delete(&edge_key)?;
let rev_key = format!(
"_graph/{}/index/{}/{}/{}",
namespace, edge_type, to_id, from_id
).into_bytes();
self.conn.delete(&rev_key)?;
}
}
Ok(())
}
/// Builds the versioned variant of `key`: the key bytes, the marker `@v`, then
/// `version` as 8 little-endian bytes.
fn versioned_key(key: &[u8], version: u64) -> Vec<u8> {
    let suffix = version.to_le_bytes();
    [key, &b"@v"[..], &suffix[..]].concat()
}
/// Persists [`IntentStatus::Committed`] on the stored intent `intent_id`.
fn mark_committed(&self, intent_id: u64) -> Result<()> {
    let status = IntentStatus::Committed;
    self.update_intent_status(intent_id, status)
}
/// Persists [`IntentStatus::Aborted`] on the stored intent `intent_id`.
fn mark_aborted(&self, intent_id: u64) -> Result<()> {
    let status = IntentStatus::Aborted;
    self.update_intent_status(intent_id, status)
}
/// Loads the stored intent, replaces its status and writes it back.
///
/// If no record exists under the intent's key, nothing is written and `Ok(())` is
/// returned.
///
/// # Errors
/// Returns `ClientError::Serialization` if the stored record cannot be decoded or the
/// updated one cannot be encoded, or the connection's error if the get or put fails.
fn update_intent_status(&self, intent_id: u64, status: IntentStatus) -> Result<()> {
    let key = Self::intent_key(intent_id);
    let stored = match self.conn.get(&key)? {
        Some(bytes) => bytes,
        None => return Ok(()),
    };

    let mut intent: MemoryIntent = serde_json::from_slice(&stored)
        .map_err(|err| ClientError::Serialization(err.to_string()))?;
    intent.status = status;

    let encoded = serde_json::to_vec(&intent)
        .map_err(|err| ClientError::Serialization(err.to_string()))?;
    self.conn.put(&key, &encoded)?;
    Ok(())
}
/// Replays every stored intent that has not reached a terminal state.
///
/// All records under the intent prefix are scanned. Records that fail to decode are
/// counted as `corrupted`; committed and aborted intents are only counted. Intents in
/// the `Pending` or `Applied` state are re-applied oldest first (ordered by
/// `created_at`, then `intent_id`) so that a later intent's writes are not overwritten
/// by an earlier one. Each replayed intent is then marked committed, or aborted if
/// applying it fails.
///
/// # Errors
/// Returns an error if the scan fails or if an intent's new status cannot be stored;
/// in the latter case the remaining intents are not processed.
pub fn recover(&self) -> Result<RecoveryReport> {
    let mut report = RecoveryReport::default();
    let mut unfinished: Vec<MemoryIntent> = Vec::new();

    for (_, value) in self.conn.scan(INTENT_PREFIX.as_bytes())? {
        match serde_json::from_slice::<MemoryIntent>(&value) {
            Err(_) => report.corrupted += 1,
            Ok(intent) => match intent.status {
                IntentStatus::Pending | IntentStatus::Applied => unfinished.push(intent),
                IntentStatus::Committed => report.already_committed += 1,
                IntentStatus::Aborted => report.already_aborted += 1,
            },
        }
    }

    // Intent keys embed the id as unpadded decimal, so the order in which the scan
    // yields them need not be chronological; sort explicitly before replaying.
    unfinished.sort_by_key(|intent| (intent.created_at, intent.intent_id));

    for intent in &unfinished {
        match self.apply_ops(intent) {
            Ok(_) => {
                self.mark_committed(intent.intent_id)?;
                report.replayed += 1;
            }
            Err(_) => {
                self.mark_aborted(intent.intent_id)?;
                report.failed += 1;
            }
        }
    }
    Ok(report)
}
/// Deletes committed and aborted intents created more than `max_age_secs` seconds ago
/// and returns how many were deleted.
///
/// Pending and applied intents are never deleted, and records that fail to decode are
/// skipped. The age computation saturates instead of overflowing, so a very large
/// `max_age_secs` simply means nothing is old enough. If the system clock reports a
/// time before the Unix epoch, nothing is deleted.
///
/// # Errors
/// Returns the connection's error if the scan or a delete fails; intents deleted
/// before the failure stay deleted.
pub fn cleanup(&self, max_age_secs: u64) -> Result<usize> {
    let intents = self.conn.scan(INTENT_PREFIX.as_bytes())?;

    // A clock before the epoch yields a zero duration, hence a cutoff of 0.
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // `as_millis` returns u128; clamp before narrowing so the cast cannot wrap.
    let now = since_epoch.as_millis().min(u128::from(u64::MAX)) as u64;
    // `max_age_secs * 1000` could overflow u64; saturate so the cutoff never wraps.
    let cutoff = now.saturating_sub(max_age_secs.saturating_mul(1000));

    let mut cleaned = 0;
    for (key, value) in intents {
        let intent: MemoryIntent = match serde_json::from_slice(&value) {
            Ok(i) => i,
            Err(_) => continue,
        };
        let finished = matches!(
            intent.status,
            IntentStatus::Committed | IntentStatus::Aborted
        );
        if finished && intent.created_at < cutoff {
            self.conn.delete(&key)?;
            cleaned += 1;
        }
    }
    Ok(cleaned)
}
}
/// Outcome of a successful `AtomicMemoryWriter::write_atomic` call.
#[derive(Debug)]
pub struct AtomicWriteResult {
/// Id assigned to the intent that recorded this write.
pub intent_id: u64,
/// Memory id the caller supplied.
pub memory_id: String,
/// Number of operations that were applied.
pub ops_applied: usize,
/// Final intent status; `write_atomic` always reports `Committed` here.
pub status: IntentStatus,
}
/// Counters produced by `AtomicMemoryWriter::recover`.
#[derive(Debug, Default)]
pub struct RecoveryReport {
/// Unfinished intents that were re-applied and marked committed.
pub replayed: usize,
/// Unfinished intents whose replay failed and that were marked aborted.
pub failed: usize,
/// Intents that were already committed.
pub already_committed: usize,
/// Intents that were already aborted.
pub already_aborted: usize,
/// Stored records that could not be decoded as an intent.
pub corrupted: usize,
}
/// JSON record stored at `_vectors/{collection}/{id}/meta` for each embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct EmbeddingMeta {
/// Memory id of the intent that wrote the embedding.
memory_id: String,
/// Version of the intent that wrote the embedding.
version: u64,
/// Number of `f32` components in the stored vector.
dimensions: usize,
/// Caller-supplied metadata, stored as given.
metadata: HashMap<String, String>,
}
/// JSON record stored at `_graph/{namespace}/nodes/{node_id}` for each node.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GraphNodeRecord {
/// Node id (the same value that appears in the key).
id: String,
/// Caller-supplied node type.
node_type: String,
/// Caller-supplied properties, stored as given.
properties: HashMap<String, serde_json::Value>,
/// Memory id of the intent that wrote the node.
memory_id: String,
/// Version of the intent that wrote the node.
version: u64,
}
/// JSON record stored at `_graph/{namespace}/edges/{from_id}/{edge_type}/{to_id}` for
/// each edge.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GraphEdgeRecord {
/// Id of the node the edge starts from.
from_id: String,
/// Caller-supplied edge type.
edge_type: String,
/// Id of the node the edge points to.
to_id: String,
/// Caller-supplied properties, stored as given.
properties: HashMap<String, serde_json::Value>,
/// Memory id of the intent that wrote the edge.
memory_id: String,
/// Version of the intent that wrote the edge.
version: u64,
}
/// Fluent builder that collects [`MemoryOp`]s for a single atomic write.
pub struct MemoryWriteBuilder {
/// Memory id handed to `write_atomic` when the builder is executed.
memory_id: String,
/// Operations queued so far, in the order they were added.
ops: Vec<MemoryOp>,
}
impl MemoryWriteBuilder {
    /// Starts an empty batch for `memory_id`.
    pub fn new(memory_id: impl Into<String>) -> Self {
        Self {
            memory_id: memory_id.into(),
            ops: Vec::new(),
        }
    }

    /// Queues a [`MemoryOp::PutBlob`] storing `value` under `key`.
    pub fn put_blob(mut self, key: impl Into<Vec<u8>>, value: impl Into<Vec<u8>>) -> Self {
        self.ops.push(MemoryOp::PutBlob {
            key: key.into(),
            value: value.into(),
        });
        self
    }

    /// Queues a [`MemoryOp::DeleteBlob`] removing the entry stored under `key`.
    pub fn delete_blob(mut self, key: impl Into<Vec<u8>>) -> Self {
        self.ops.push(MemoryOp::DeleteBlob { key: key.into() });
        self
    }

    /// Queues a [`MemoryOp::PutEmbedding`] with empty metadata.
    pub fn put_embedding(
        self,
        collection: impl Into<String>,
        id: impl Into<String>,
        embedding: Vec<f32>,
    ) -> Self {
        self.put_embedding_with_meta(collection, id, embedding, HashMap::new())
    }

    /// Queues a [`MemoryOp::PutEmbedding`] carrying `metadata`.
    pub fn put_embedding_with_meta(
        mut self,
        collection: impl Into<String>,
        id: impl Into<String>,
        embedding: Vec<f32>,
        metadata: HashMap<String, String>,
    ) -> Self {
        self.ops.push(MemoryOp::PutEmbedding {
            collection: collection.into(),
            id: id.into(),
            embedding,
            metadata,
        });
        self
    }

    /// Queues a [`MemoryOp::CreateNode`] with no properties.
    pub fn create_node(
        self,
        namespace: impl Into<String>,
        node_id: impl Into<String>,
        node_type: impl Into<String>,
    ) -> Self {
        self.create_node_with_props(namespace, node_id, node_type, HashMap::new())
    }

    /// Queues a [`MemoryOp::CreateNode`] carrying `properties`.
    pub fn create_node_with_props(
        mut self,
        namespace: impl Into<String>,
        node_id: impl Into<String>,
        node_type: impl Into<String>,
        properties: HashMap<String, serde_json::Value>,
    ) -> Self {
        self.ops.push(MemoryOp::CreateNode {
            namespace: namespace.into(),
            node_id: node_id.into(),
            node_type: node_type.into(),
            properties,
        });
        self
    }

    /// Queues a [`MemoryOp::CreateEdge`] with no properties.
    pub fn create_edge(
        self,
        namespace: impl Into<String>,
        from_id: impl Into<String>,
        edge_type: impl Into<String>,
        to_id: impl Into<String>,
    ) -> Self {
        self.create_edge_with_props(namespace, from_id, edge_type, to_id, HashMap::new())
    }

    /// Queues a [`MemoryOp::CreateEdge`] carrying `properties`.
    pub fn create_edge_with_props(
        mut self,
        namespace: impl Into<String>,
        from_id: impl Into<String>,
        edge_type: impl Into<String>,
        to_id: impl Into<String>,
        properties: HashMap<String, serde_json::Value>,
    ) -> Self {
        self.ops.push(MemoryOp::CreateEdge {
            namespace: namespace.into(),
            from_id: from_id.into(),
            edge_type: edge_type.into(),
            to_id: to_id.into(),
            properties,
        });
        self
    }

    /// Queues a [`MemoryOp::DeleteEdge`] for the edge identified by the arguments.
    pub fn delete_edge(
        mut self,
        namespace: impl Into<String>,
        from_id: impl Into<String>,
        edge_type: impl Into<String>,
        to_id: impl Into<String>,
    ) -> Self {
        self.ops.push(MemoryOp::DeleteEdge {
            namespace: namespace.into(),
            from_id: from_id.into(),
            edge_type: edge_type.into(),
            to_id: to_id.into(),
        });
        self
    }

    /// Consumes the builder and submits the queued operations through `writer`.
    ///
    /// # Errors
    /// Returns whatever `AtomicMemoryWriter::write_atomic` returns for this batch.
    pub fn execute<C: ConnectionTrait>(
        self,
        writer: &AtomicMemoryWriter<C>,
    ) -> Result<AtomicWriteResult> {
        writer.write_atomic(self.memory_id, self.ops)
    }

    /// Returns the operations queued so far, in insertion order.
    pub fn ops(&self) -> &[MemoryOp] {
        &self.ops
    }
}
// Placeholder test module: it imports the parent module's items but defines no tests.
#[cfg(test)]
mod tests {
use super::*;
}