use crate::collection::Collection;
use crate::{Database, DistanceMetric, Point};
use parking_lot::RwLock;
use std::collections::HashSet;
use super::error::AgentMemoryError;
/// Payload key under which a memory's `u64` expiry value is stored.
///
/// `attach_expiry` writes this key into object payloads and
/// `rebuild_ttl_from_payloads` reads it back.
pub(super) const EXPIRES_AT_KEY: &str = "_veles_expires_at";
/// Resolves `name` to its underlying [`Collection`], whatever kind it is.
///
/// The lookup tries `get_vector_collection` first, then `get_graph_collection`,
/// then `get_metadata_collection`, and returns the `inner` collection of the
/// first one that matches.
///
/// # Errors
///
/// Returns [`AgentMemoryError::CollectionError`] when none of the three lookups
/// finds `name`.
pub(super) fn get_collection(db: &Database, name: &str) -> Result<Collection, AgentMemoryError> {
    if let Some(vector) = db.get_vector_collection(name) {
        return Ok(vector.inner);
    }
    if let Some(graph) = db.get_graph_collection(name) {
        return Ok(graph.inner);
    }
    if let Some(metadata) = db.get_metadata_collection(name) {
        return Ok(metadata.inner);
    }
    Err(AgentMemoryError::CollectionError(
        "Collection not found".to_string(),
    ))
}
/// Checks that an embedding length matches the configured dimension.
///
/// # Errors
///
/// Returns [`AgentMemoryError::DimensionMismatch`] carrying both values when
/// `actual` differs from `expected`.
pub(super) fn validate_dimension(expected: usize, actual: usize) -> Result<(), AgentMemoryError> {
    if expected == actual {
        Ok(())
    } else {
        Err(AgentMemoryError::DimensionMismatch { expected, actual })
    }
}
/// Opens the vector collection `collection_name`, creating it when absent.
///
/// A new collection is created with the requested `dimension` and
/// [`DistanceMetric::Cosine`]. An existing collection is accepted only if its
/// configured dimension equals `dimension`.
///
/// Returns the dimension of the collection that is now available.
///
/// # Errors
///
/// * [`AgentMemoryError::DimensionMismatch`] when the collection exists with a
///   different dimension (`expected` is the stored one, `actual` the requested one).
/// * Whatever error `Database::create_collection` produces, converted via `?`.
pub(super) fn open_or_create_collection(
    db: &Database,
    collection_name: &str,
    dimension: usize,
) -> Result<usize, AgentMemoryError> {
    match db.get_vector_collection(collection_name) {
        Some(existing) => {
            let stored = existing.config().dimension;
            if stored == dimension {
                Ok(stored)
            } else {
                Err(AgentMemoryError::DimensionMismatch {
                    expected: stored,
                    actual: dimension,
                })
            }
        }
        None => {
            db.create_collection(collection_name, dimension, DistanceMetric::Cosine)?;
            Ok(dimension)
        }
    }
}
/// Collects every id currently held by the vector collection `collection_name`.
///
/// Returns an empty set when no vector collection with that name exists.
pub(super) fn load_stored_ids(db: &Database, collection_name: &str) -> HashSet<u64> {
    match db.get_vector_collection(collection_name) {
        Some(collection) => collection.all_ids().into_iter().collect(),
        None => HashSet::default(),
    }
}
/// Deletes every point id reported by `collection.all_ids()`.
///
/// The delete call is skipped entirely when the collection has no ids.
///
/// # Errors
///
/// Returns [`AgentMemoryError::CollectionError`] with the stringified cause
/// when the delete fails.
pub(super) fn clear_collection(collection: &Collection) -> Result<(), AgentMemoryError> {
    let ids_to_remove = collection.all_ids();
    if ids_to_remove.is_empty() {
        return Ok(());
    }
    collection
        .delete(&ids_to_remove)
        .map_err(|err| AgentMemoryError::CollectionError(err.to_string()))
}
/// Replaces the tracked id set with exactly the ids of `points`.
///
/// The write lock is held for the whole clear-and-refill, so readers never
/// observe a partially rebuilt set.
pub(super) fn rebuild_stored_ids(stored_ids: &RwLock<HashSet<u64>>, points: &[Point]) {
    let mut tracked = stored_ids.write();
    tracked.clear();
    tracked.extend(points.iter().map(|point| point.id));
}
/// Serialized form of a memory collection: its points plus the graph edges
/// between them. Written as JSON by `serialize_points` and read back by
/// `deserialize_into_collection`.
#[derive(serde::Serialize, serde::Deserialize)]
struct MemorySnapshot {
/// Points captured in the snapshot.
points: Vec<Point>,
/// Edges captured in the snapshot; `#[serde(default)]` makes input without
/// an `edges` field deserialize to an empty list.
#[serde(default)]
edges: Vec<crate::collection::graph::GraphEdge>,
}
/// Serializes the requested points, and the edges among them, to JSON bytes.
///
/// Only ids that `get_raw` actually returns a point for are included. An edge
/// is kept only when both its source and its target are among those points.
///
/// # Errors
///
/// Returns [`AgentMemoryError::IoError`] when JSON encoding fails.
pub(super) fn serialize_points(
    collection: &Collection,
    ids: &[u64],
) -> Result<Vec<u8>, AgentMemoryError> {
    let mut points = Vec::new();
    let mut present: HashSet<u64> = HashSet::new();
    for point in collection.get_raw(ids).into_iter().flatten() {
        present.insert(point.id);
        points.push(point);
    }

    let edges: Vec<_> = collection
        .get_all_edges()
        .into_iter()
        .filter(|edge| present.contains(&edge.source()) && present.contains(&edge.target()))
        .collect();

    let snapshot = MemorySnapshot { points, edges };
    serde_json::to_vec(&snapshot).map_err(|err| AgentMemoryError::IoError(err.to_string()))
}
/// Replaces the contents of `collection` with the snapshot encoded in `data`.
///
/// `data` is parsed as a `MemorySnapshot` first; if that fails it is parsed as
/// a bare JSON array of points (presumably an older format without edges).
/// On success the collection is cleared, the points are upserted, and any
/// edges are added in one batch.
///
/// Returns `Ok(None)` without touching the collection when `data` is empty,
/// otherwise `Ok(Some(points))` with the restored points.
///
/// # Errors
///
/// * [`AgentMemoryError::IoError`] when `data` matches neither format. The
///   message carries both parse errors, so a corrupt snapshot is not reported
///   only as a failure of the bare-array fallback.
/// * [`AgentMemoryError::CollectionError`] when clearing, upserting, or adding
///   edges fails. The collection is cleared before the upsert, so a failure
///   after that point can leave it partially populated.
pub(super) fn deserialize_into_collection(
    data: &[u8],
    collection: &Collection,
) -> Result<Option<Vec<Point>>, AgentMemoryError> {
    if data.is_empty() {
        return Ok(None);
    }

    let snapshot = match serde_json::from_slice::<MemorySnapshot>(data) {
        Ok(snapshot) => snapshot,
        Err(snapshot_err) => match serde_json::from_slice::<Vec<Point>>(data) {
            Ok(points) => MemorySnapshot {
                points,
                edges: Vec::new(),
            },
            Err(array_err) => {
                return Err(AgentMemoryError::IoError(format!(
                    "invalid memory snapshot: {snapshot_err} \
                     (bare point-array fallback also failed: {array_err})"
                )));
            }
        },
    };
    let MemorySnapshot { points, edges } = snapshot;

    clear_collection(collection)?;
    // `upsert` consumes its input, and the caller still needs the points back.
    upsert_points(collection, points.clone())?;
    if !edges.is_empty() {
        collection
            .add_edges_batch(edges)
            .map_err(|e| AgentMemoryError::CollectionError(e.to_string()))?;
    }
    Ok(Some(points))
}
/// Deletes `ids` from `collection`.
///
/// # Errors
///
/// Returns [`AgentMemoryError::CollectionError`] with the stringified cause
/// when the delete fails.
pub(super) fn delete_from_collection(
    collection: &Collection,
    ids: &[u64],
) -> Result<(), AgentMemoryError> {
    match collection.delete(ids) {
        Ok(()) => Ok(()),
        Err(err) => Err(AgentMemoryError::CollectionError(err.to_string())),
    }
}
/// Upserts `points` into `collection`.
///
/// # Errors
///
/// Returns [`AgentMemoryError::CollectionError`] with the stringified cause
/// when the upsert fails.
pub(super) fn upsert_points(
    collection: &Collection,
    points: Vec<Point>,
) -> Result<(), AgentMemoryError> {
    match collection.upsert(points) {
        Ok(()) => Ok(()),
        Err(err) => Err(AgentMemoryError::CollectionError(err.to_string())),
    }
}
/// Runs a search for `query` on `collection`, asking for `k` results.
///
/// # Errors
///
/// Returns [`AgentMemoryError::CollectionError`] with the stringified cause
/// when the search fails.
pub(super) fn search_collection(
    collection: &Collection,
    query: &[f32],
    k: usize,
) -> Result<Vec<crate::SearchResult>, AgentMemoryError> {
    match collection.search(query, k) {
        Ok(hits) => Ok(hits),
        Err(err) => Err(AgentMemoryError::CollectionError(err.to_string())),
    }
}
/// Deletes one point and drops all bookkeeping that refers to it.
///
/// The point is removed from the collection first. Only when that succeeds is
/// `id` removed from `stored_ids` and its TTL entry for `kind` removed.
///
/// # Errors
///
/// Returns the error from resolving the collection or from the delete; in
/// either case the tracked id set and the TTL state are left untouched.
pub(super) fn delete_tracked_point(
    db: &Database,
    collection_name: &str,
    id: u64,
    stored_ids: &RwLock<HashSet<u64>>,
    ttl: &super::ttl::MemoryTtl,
    kind: super::ttl::MemoryKind,
) -> Result<(), AgentMemoryError> {
    let target = [id];
    let collection = get_collection(db, collection_name)?;
    delete_from_collection(&collection, &target)?;
    {
        // Scope the write guard so it is released before the TTL update.
        let mut tracked = stored_ids.write();
        tracked.remove(&id);
    }
    ttl.remove(kind, id);
    Ok(())
}
/// Serializes every tracked point of `collection_name` to JSON bytes.
///
/// The tracked ids are copied out of `stored_ids` under a read lock, which is
/// released before serialization starts.
///
/// # Errors
///
/// Returns the error from resolving the collection or from `serialize_points`.
pub(super) fn serialize_tracked_points(
    db: &Database,
    collection_name: &str,
    stored_ids: &RwLock<HashSet<u64>>,
) -> Result<Vec<u8>, AgentMemoryError> {
    let collection = get_collection(db, collection_name)?;
    let tracked = stored_ids.read().iter().copied().collect::<Vec<u64>>();
    serialize_points(&collection, &tracked)
}
/// Restores `collection_name` from serialized `data` and re-syncs `stored_ids`.
///
/// When `deserialize_into_collection` yields no points (it returns `None` for
/// empty `data`), the tracked id set is left as it was.
///
/// # Errors
///
/// Returns the error from resolving the collection or from deserialization.
pub(super) fn deserialize_tracked_points(
    db: &Database,
    collection_name: &str,
    data: &[u8],
    stored_ids: &RwLock<HashSet<u64>>,
) -> Result<(), AgentMemoryError> {
    let collection = get_collection(db, collection_name)?;
    let restored = deserialize_into_collection(data, &collection)?;
    if let Some(points) = restored.as_deref() {
        rebuild_stored_ids(stored_ids, points);
    }
    Ok(())
}
/// Searches `collection_name` and returns at most `k` non-expired hits.
///
/// The search over-fetches by the number of entries `ttl` currently reports as
/// expired for `kind`, so that dropping expired hits can still leave `k` results.
///
/// # Errors
///
/// * [`AgentMemoryError::DimensionMismatch`] when `query_embedding` does not
///   have `dimension` components.
/// * The error from resolving the collection or from the search itself.
pub(super) fn search_filtered(
    db: &Database,
    collection_name: &str,
    dimension: usize,
    query_embedding: &[f32],
    k: usize,
    ttl: &super::ttl::MemoryTtl,
    kind: super::ttl::MemoryKind,
) -> Result<Vec<crate::SearchResult>, AgentMemoryError> {
    validate_dimension(dimension, query_embedding.len())?;
    let collection = get_collection(db, collection_name)?;

    let expired = ttl.expired_count(kind);
    let candidates = search_collection(&collection, query_embedding, k.saturating_add(expired))?;

    let mut live = Vec::new();
    for hit in candidates {
        // Stop before inspecting further hits once `k` live ones are collected.
        if live.len() == k {
            break;
        }
        if !ttl.is_expired(kind, hit.point.id) {
            live.push(hit);
        }
    }
    Ok(live)
}
#[allow(clippy::cast_possible_truncation)] pub(super) fn validate_binary_header(data: &[u8], entry_size: usize) -> Option<usize> {
if data.len() < 8 {
return None;
}
let count = u64::from_le_bytes(data[0..8].try_into().ok()?) as usize;
let payload = data.len() - 8;
if entry_size == 0 || count > payload / entry_size {
return None;
}
let total = 8usize.checked_add(count.checked_mul(entry_size)?)?;
if data.len() != total {
return None;
}
Some(count)
}
/// Prepares the state for a tracked memory store.
///
/// Opens or creates the collection, then seeds the tracked id set from the ids
/// it already contains.
///
/// Returns `(owned collection name, collection dimension, tracked id set)`.
///
/// # Errors
///
/// Returns the error from `open_or_create_collection`.
pub(super) fn init_tracked_memory(
    db: &Database,
    collection_name: &str,
    dimension: usize,
) -> Result<(String, usize, RwLock<HashSet<u64>>), AgentMemoryError> {
    let actual_dimension = open_or_create_collection(db, collection_name, dimension)?;
    let existing_ids = load_stored_ids(db, collection_name);
    Ok((
        collection_name.to_string(),
        actual_dimension,
        RwLock::new(existing_ids),
    ))
}
/// Produces the embedding to store for a memory.
///
/// A supplied embedding is length-checked and copied; when none is supplied, an
/// all-zero vector of `dimension` components is returned instead.
///
/// # Errors
///
/// Returns [`AgentMemoryError::DimensionMismatch`] when a supplied embedding
/// does not have `dimension` components.
pub(super) fn resolve_embedding(
    dimension: usize,
    embedding: Option<&[f32]>,
) -> Result<Vec<f32>, AgentMemoryError> {
    match embedding {
        Some(provided) => {
            validate_dimension(dimension, provided.len())?;
            Ok(provided.to_vec())
        }
        None => Ok(vec![0.0; dimension]),
    }
}
/// Writes `expires_at` into `payload` under [`EXPIRES_AT_KEY`].
///
/// Nothing happens when `expires_at` is `None` or when `payload` is not a JSON
/// object; callers that need the write to happen must check afterwards.
pub(super) fn attach_expiry(payload: &mut serde_json::Value, expires_at: Option<u64>) {
    if let Some(expiry) = expires_at {
        if let Some(fields) = payload.as_object_mut() {
            fields.insert(EXPIRES_AT_KEY.to_string(), serde_json::Value::from(expiry));
        }
    }
}
/// Fetches point `id`, treating an expired memory as absent.
///
/// # Errors
///
/// Returns [`AgentMemoryError::NotFound`] when `ttl` reports `id` as expired
/// for `kind`, or when the collection holds no point with that id.
pub(super) fn ensure_live(
    collection: &Collection,
    collection_name: &str,
    ttl: &super::ttl::MemoryTtl,
    kind: super::ttl::MemoryKind,
    id: u64,
) -> Result<Point, AgentMemoryError> {
    if !ttl.is_expired(kind, id) {
        return get_point_or_not_found(collection, collection_name, id);
    }
    Err(AgentMemoryError::NotFound(format!(
        "memory id {id} is expired in {collection_name}"
    )))
}
/// Builds the edge-id counter for `collection`.
///
/// The counter starts one past the collection's current maximum edge id
/// (saturating at `u64::MAX`), or at `1` when `max_edge_id()` returns `None`.
pub(super) fn seed_edge_counter(collection: &Collection) -> std::sync::atomic::AtomicU64 {
    let first_free = match collection.max_edge_id() {
        Some(highest) => highest.saturating_add(1),
        None => 1,
    };
    std::sync::atomic::AtomicU64::new(first_free)
}
/// Adds a `rel_type` edge between `endpoints` (`(from, to)`) and returns its id.
///
/// Ids are drawn from `next_edge_id`. A candidate id is skipped when the
/// collection already has an edge with it, or when the insert is rejected with
/// `Error::EdgeExists`; the loop then retries with the next id.
///
/// # Errors
///
/// Returns [`AgentMemoryError::CollectionError`] when the edge cannot be
/// constructed or when the insert fails for any reason other than `EdgeExists`.
pub(super) fn add_relation_edge(
    collection: &Collection,
    next_edge_id: &std::sync::atomic::AtomicU64,
    endpoints: (u64, u64),
    rel_type: &str,
    properties: Option<&serde_json::Map<String, serde_json::Value>>,
) -> Result<u64, AgentMemoryError> {
    use std::sync::atomic::Ordering;

    let (from_id, to_id) = endpoints;
    loop {
        let candidate = next_edge_id.fetch_add(1, Ordering::Relaxed);
        if !collection.edge_exists(candidate) {
            let bare = crate::collection::graph::GraphEdge::new(candidate, from_id, to_id, rel_type)
                .map_err(|err| AgentMemoryError::CollectionError(err.to_string()))?;
            let edge = match properties {
                Some(props) => {
                    let map: std::collections::HashMap<String, serde_json::Value> = props
                        .iter()
                        .map(|(key, value)| (key.clone(), value.clone()))
                        .collect();
                    bare.with_properties(map)
                }
                None => bare,
            };
            match collection.add_edge(edge) {
                Ok(()) => return Ok(candidate),
                // The id was taken between the existence check and the insert:
                // fall through and try the next id.
                Err(crate::error::Error::EdgeExists(_)) => {}
                Err(other) => return Err(AgentMemoryError::CollectionError(other.to_string())),
            }
        }
    }
}
/// Confirms that both endpoints of edge `edge_id` still exist.
///
/// If either endpoint can no longer be fetched, the edge is removed on a
/// best-effort basis (the removal result is ignored) and an error is returned.
///
/// # Errors
///
/// Returns [`AgentMemoryError::NotFound`] when fewer than two endpoint points
/// are present.
pub(super) fn verify_relation_endpoints(
    collection: &Collection,
    edge_id: u64,
    endpoints: (u64, u64),
) -> Result<(), AgentMemoryError> {
    let (from_id, to_id) = endpoints;
    let present = collection
        .get(&[from_id, to_id])
        .iter()
        .flatten()
        .count();
    if present != 2 {
        let _ = collection.remove_edge(edge_id);
        return Err(AgentMemoryError::NotFound(format!(
            "a relation endpoint ({from_id} or {to_id}) was deleted concurrently"
        )));
    }
    Ok(())
}
/// Fetches the point with `id` from `collection`.
///
/// # Errors
///
/// Returns [`AgentMemoryError::NotFound`], naming `id` and `collection_name`,
/// when the collection returns no point for that id.
pub(super) fn get_point_or_not_found(
    collection: &Collection,
    collection_name: &str,
    id: u64,
) -> Result<Point, AgentMemoryError> {
    match collection.get(&[id]).into_iter().flatten().next() {
        Some(point) => Ok(point),
        None => Err(AgentMemoryError::NotFound(format!(
            "memory id {id} not found in {collection_name}"
        ))),
    }
}
/// Sets a TTL on memory `id` and persists it inside the point's payload.
///
/// The expiry is `MemoryTtl::now()` plus `ttl_seconds` (saturating). It is
/// written into the payload under [`EXPIRES_AT_KEY`], the point is upserted
/// with its vector and sparse vectors unchanged, and only then is the
/// in-memory TTL entry for `kind` updated. A point without a payload gets a
/// fresh empty object to hold the key.
///
/// # Errors
///
/// * The error from resolving the collection.
/// * [`AgentMemoryError::NotFound`] when the memory is expired or missing.
/// * [`AgentMemoryError::CollectionError`] when the existing payload is not a
///   JSON object (so the expiry cannot be stored) or when the upsert fails.
pub(super) fn set_ttl_durable(
    db: &Database,
    collection_name: &str,
    ttl: &super::ttl::MemoryTtl,
    kind: super::ttl::MemoryKind,
    id: u64,
    ttl_seconds: u64,
) -> Result<(), AgentMemoryError> {
    let collection = get_collection(db, collection_name)?;
    let Point {
        id: _,
        vector,
        payload: stored_payload,
        sparse_vectors,
    } = ensure_live(&collection, collection_name, ttl, kind, id)?;
    let expires_at = super::ttl::MemoryTtl::now().saturating_add(ttl_seconds);

    let mut payload = match stored_payload {
        Some(existing) => existing,
        None => serde_json::Value::Object(serde_json::Map::new()),
    };
    attach_expiry(&mut payload, Some(expires_at));

    // `attach_expiry` silently skips non-object payloads; detect that here.
    let persisted = payload.get(EXPIRES_AT_KEY).is_some();
    if !persisted {
        return Err(AgentMemoryError::CollectionError(format!(
            "memory id {id} in {collection_name} has a non-object payload; cannot persist TTL"
        )));
    }

    let updated = Point {
        id,
        vector,
        payload: Some(payload),
        sparse_vectors,
    };
    upsert_points(&collection, vec![updated])?;
    ttl.set_expiry(kind, id, expires_at);
    Ok(())
}
/// Re-populates `ttl` from the expiry values stored in point payloads.
///
/// Every point in the collection is read. Those whose payload holds a `u64`
/// under [`EXPIRES_AT_KEY`] get that value registered as their expiry for
/// `kind`; all other points are skipped.
///
/// # Errors
///
/// Returns the error from resolving the collection.
pub(super) fn rebuild_ttl_from_payloads(
    db: &Database,
    collection_name: &str,
    ttl: &super::ttl::MemoryTtl,
    kind: super::ttl::MemoryKind,
) -> Result<(), AgentMemoryError> {
    let collection = get_collection(db, collection_name)?;
    let all_ids = collection.all_ids();

    let stored_expiries = collection
        .get_raw(&all_ids)
        .into_iter()
        .flatten()
        .filter_map(|point| {
            let expires_at = point.payload.as_ref()?.get(EXPIRES_AT_KEY)?.as_u64()?;
            Some((point.id, expires_at))
        });
    for (id, expires_at) in stored_expiries {
        ttl.set_expiry(kind, id, expires_at);
    }
    Ok(())
}
/// Runs the query string `sql` with `params` against `collection_name` and
/// drops every result whose point `ttl` reports as expired for `kind`.
///
/// # Errors
///
/// * The error from resolving the collection.
/// * [`AgentMemoryError::DatabaseError`] wrapping the query failure.
pub(super) fn execute_velesql(
    db: &Database,
    collection_name: &str,
    sql: &str,
    params: &std::collections::HashMap<String, serde_json::Value>,
    ttl: &super::ttl::MemoryTtl,
    kind: super::ttl::MemoryKind,
) -> Result<Vec<crate::SearchResult>, AgentMemoryError> {
    let collection = get_collection(db, collection_name)?;
    let rows = collection
        .execute_query_str(sql, params)
        .map_err(AgentMemoryError::DatabaseError)?;

    let mut live = Vec::new();
    for row in rows {
        if !ttl.is_expired(kind, row.point.id) {
            live.push(row);
        }
    }
    Ok(live)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a payload-free 4-component point whose components all equal `fill`.
    fn flat_point(id: u64, fill: f32) -> Point {
        Point::without_payload(id, vec![fill; 4])
    }

    // --- validate_dimension ---

    #[test]
    fn validate_dimension_matching_returns_ok() {
        assert!(matches!(validate_dimension(128, 128), Ok(())));
    }

    #[test]
    fn validate_dimension_zero_matches_zero() {
        assert!(matches!(validate_dimension(0, 0), Ok(())));
    }

    #[test]
    fn validate_dimension_mismatch_returns_error() {
        let err = validate_dimension(128, 64).unwrap_err();
        let is_expected = matches!(
            err,
            AgentMemoryError::DimensionMismatch {
                expected: 128,
                actual: 64
            }
        );
        assert!(is_expected, "Expected DimensionMismatch, got: {err:?}");
    }

    #[test]
    fn validate_dimension_swapped_values_are_distinct() {
        // Swapping the arguments must swap the fields reported in the error.
        let err = validate_dimension(64, 128).unwrap_err();
        let is_expected = matches!(
            err,
            AgentMemoryError::DimensionMismatch {
                expected: 64,
                actual: 128
            }
        );
        assert!(is_expected);
    }

    // --- rebuild_stored_ids ---

    #[test]
    fn rebuild_stored_ids_populates_from_points() {
        let stored_ids = RwLock::new(HashSet::new());
        let points = [flat_point(10, 0.0), flat_point(20, 0.0), flat_point(30, 0.0)];

        rebuild_stored_ids(&stored_ids, &points);

        assert_eq!(*stored_ids.read(), HashSet::from([10, 20, 30]));
    }

    #[test]
    fn rebuild_stored_ids_clears_previous_ids() {
        let stored_ids = RwLock::new(HashSet::from([1, 2]));
        let points = [flat_point(99, 0.0)];

        rebuild_stored_ids(&stored_ids, &points);

        // Exactly {99}: the new id is present and the stale ids 1 and 2 are gone.
        assert_eq!(*stored_ids.read(), HashSet::from([99]));
    }

    #[test]
    fn rebuild_stored_ids_empty_points_clears_all() {
        let stored_ids = RwLock::new(HashSet::from([5]));

        rebuild_stored_ids(&stored_ids, &[]);

        assert_eq!(stored_ids.read().len(), 0);
    }

    #[test]
    fn rebuild_stored_ids_deduplicates() {
        let stored_ids = RwLock::new(HashSet::new());
        // Same id twice, with different vectors.
        let points = [flat_point(1, 0.0), flat_point(1, 1.0)];

        rebuild_stored_ids(&stored_ids, &points);

        assert_eq!(*stored_ids.read(), HashSet::from([1]));
    }

    #[cfg(feature = "persistence")]
    mod persistence_tests {
        use super::*;
        use tempfile::TempDir;

        /// Opens a database in a fresh temporary directory. The directory
        /// handle is returned so it outlives the database in each test.
        fn temp_db() -> (TempDir, Database) {
            let dir = TempDir::new().unwrap();
            let db = Database::open(dir.path()).unwrap();
            (dir, db)
        }

        #[test]
        fn open_or_create_creates_new_collection() {
            let (_dir, db) = temp_db();

            let dim = open_or_create_collection(&db, "test_coll", 64).unwrap();

            assert_eq!(dim, 64);
            assert!(db.get_vector_collection("test_coll").is_some());
        }

        #[test]
        fn open_or_create_returns_existing_with_matching_dim() {
            let (_dir, db) = temp_db();
            open_or_create_collection(&db, "my_coll", 128).unwrap();

            let dim = open_or_create_collection(&db, "my_coll", 128).unwrap();

            assert_eq!(dim, 128);
        }

        #[test]
        fn open_or_create_errors_on_dimension_mismatch() {
            let (_dir, db) = temp_db();
            open_or_create_collection(&db, "dim_coll", 64).unwrap();

            let err = open_or_create_collection(&db, "dim_coll", 128).unwrap_err();

            let is_expected = matches!(
                err,
                AgentMemoryError::DimensionMismatch {
                    expected: 64,
                    actual: 128
                }
            );
            assert!(is_expected, "Expected DimensionMismatch, got: {err:?}");
        }
    }
}