use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Identifies who owns a memory entry. An owner is currently exactly one session.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MemoryOwner {
    session_id: crate::types::SessionId,
}
impl MemoryOwner {
    /// Builds the owner representing `session_id`.
    pub fn canonical_session(session_id: crate::types::SessionId) -> Self {
        MemoryOwner { session_id }
    }
    /// Returns the session this owner represents.
    pub fn session_id(&self) -> &crate::types::SessionId {
        let Self { session_id } = self;
        session_id
    }
    /// Reports whether `metadata` was recorded for this owner's session.
    fn includes(&self, metadata: &MemoryMetadata) -> bool {
        let candidate = &metadata.session_id;
        *candidate == self.session_id
    }
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageRange {
start: u64,
end: u64,
}
impl MessageRange {
pub fn new(start: u64, end: u64) -> Result<Self, MemoryStoreError> {
if start > end {
return Err(MemoryStoreError::SourceRange { start, end });
}
Ok(Self { start, end })
}
pub fn single(offset: u64) -> Self {
Self {
start: offset,
end: offset.saturating_add(1),
}
}
pub fn start(&self) -> u64 {
self.start
}
pub fn end(&self) -> u64 {
self.end
}
pub fn len(&self) -> u64 {
self.end - self.start
}
pub fn is_empty(&self) -> bool {
self.start == self.end
}
}
/// Origin of a memory entry.
///
/// Serialized as an internally tagged object. Its `kind` field holds the snake_case
/// variant name, e.g. `{"kind": "compaction", "source_range": {...}}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum MemorySource {
    /// Compaction source; `source_range` is the message range it covers.
    Compaction {
        source_range: MessageRange,
    },
}
impl MemorySource {
    /// Returns the message range this entry was derived from, if the source has one.
    pub fn source_range(&self) -> Option<MessageRange> {
        let range = match *self {
            Self::Compaction { source_range } => source_range,
        };
        Some(range)
    }
}
/// Metadata carried by each memory entry: owning session, origin and indexing time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryMetadata {
    /// Session the entry belongs to. Scope checks (`MemoryOwner::includes`) compare against this.
    pub session_id: crate::types::SessionId,
    /// Where the entry came from.
    pub source: MemorySource,
    /// Timestamp recorded for indexing. It is a public field set by whoever builds the
    /// metadata; nothing in this type validates it.
    pub indexed_at: crate::time_compat::SystemTime,
}
/// A single hit returned by `MemoryStore::search`.
#[derive(Debug, Clone)]
pub struct MemoryResult {
    /// Text content of the matching entry.
    pub content: String,
    /// Metadata associated with the matching entry.
    pub metadata: MemoryMetadata,
    /// Relevance score assigned by the store.
    /// NOTE(review): scale and direction (higher = better?) are implementation-defined; confirm.
    pub score: f32,
}
/// Restricts a search to the entries of a single owner.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MemorySearchScope {
    pub owner: MemoryOwner,
}
impl MemorySearchScope {
    /// Scope over the owner built by `MemoryOwner::canonical_session(session_id)`.
    pub fn for_session(session_id: crate::types::SessionId) -> Self {
        Self::for_owner(MemoryOwner::canonical_session(session_id))
    }
    /// Scope over an explicitly supplied owner.
    pub fn for_owner(owner: MemoryOwner) -> Self {
        MemorySearchScope { owner }
    }
    /// Session of the owner this scope is bound to.
    pub fn session_id(&self) -> &crate::types::SessionId {
        let owner = &self.owner;
        owner.session_id()
    }
    /// Whether an entry carrying `metadata` falls inside this scope.
    pub fn includes(&self, metadata: &MemoryMetadata) -> bool {
        MemoryOwner::includes(&self.owner, metadata)
    }
}
/// Owner under which new memory entries are indexed.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MemoryIndexScope {
    pub owner: MemoryOwner,
}
impl MemoryIndexScope {
    /// Index scope for the owner built by `MemoryOwner::canonical_session(session_id)`.
    pub fn for_session(session_id: crate::types::SessionId) -> Self {
        let owner = MemoryOwner::canonical_session(session_id);
        Self::for_owner(owner)
    }
    /// Index scope for an explicitly supplied owner.
    pub fn for_owner(owner: MemoryOwner) -> Self {
        MemoryIndexScope { owner }
    }
    /// Session of the owner this scope is bound to.
    pub fn session_id(&self) -> &crate::types::SessionId {
        MemoryOwner::session_id(&self.owner)
    }
    /// Whether `metadata` belongs to this scope's owner.
    pub fn includes(&self, metadata: &MemoryMetadata) -> bool {
        let owner = &self.owner;
        owner.includes(metadata)
    }
}
/// A request to index one piece of content. Construction checks that its metadata
/// lies inside its scope.
#[derive(Debug, Clone)]
pub struct MemoryIndexRequest {
    scope: MemoryIndexScope,
    content: crate::types::MemoryIndexableContent,
    metadata: MemoryMetadata,
}
impl MemoryIndexRequest {
    /// Builds a request after checking that `scope` covers `metadata`.
    ///
    /// # Errors
    /// Returns [`MemoryStoreError::Scope`] when the metadata's session is outside `scope`.
    pub fn new(
        scope: MemoryIndexScope,
        content: crate::types::MemoryIndexableContent,
        metadata: MemoryMetadata,
    ) -> Result<Self, MemoryStoreError> {
        if scope.includes(&metadata) {
            Ok(Self {
                scope,
                content,
                metadata,
            })
        } else {
            Err(MemoryStoreError::Scope(format!(
                "memory metadata session {} is outside indexing scope {}",
                metadata.session_id,
                scope.session_id()
            )))
        }
    }
    /// Scope the request will be indexed under.
    pub fn scope(&self) -> &MemoryIndexScope {
        let Self { scope, .. } = self;
        scope
    }
    /// Content to be indexed.
    pub fn content(&self) -> &crate::types::MemoryIndexableContent {
        let Self { content, .. } = self;
        content
    }
    /// Text extracted from the content for indexing, if it has any.
    pub fn indexable_text(&self) -> Option<&str> {
        self.content().indexable_text()
    }
    /// Metadata attached to the entry.
    pub fn metadata(&self) -> &MemoryMetadata {
        let Self { metadata, .. } = self;
        metadata
    }
    /// Consumes the request and returns `(scope, content, metadata)`.
    pub fn into_parts(
        self,
    ) -> (
        MemoryIndexScope,
        crate::types::MemoryIndexableContent,
        MemoryMetadata,
    ) {
        let Self {
            scope,
            content,
            metadata,
        } = self;
        (scope, content, metadata)
    }
}
/// A group of index requests that all share one scope.
#[derive(Debug, Clone)]
pub struct MemoryIndexBatch {
    scope: MemoryIndexScope,
    requests: Vec<MemoryIndexRequest>,
}
impl MemoryIndexBatch {
    /// Builds a batch, rejecting any request whose scope differs from `scope`.
    /// An empty `requests` vector is accepted.
    ///
    /// # Errors
    /// Returns [`MemoryStoreError::Scope`] describing the first mismatching request.
    pub fn new(
        scope: MemoryIndexScope,
        requests: Vec<MemoryIndexRequest>,
    ) -> Result<Self, MemoryStoreError> {
        let mismatch = requests.iter().find(|request| request.scope() != &scope);
        if let Some(request) = mismatch {
            return Err(MemoryStoreError::Scope(format!(
                "memory index request scope {} is outside batch scope {}",
                request.scope().session_id(),
                scope.session_id()
            )));
        }
        Ok(Self { scope, requests })
    }
    /// Wraps a single request in a batch that takes the request's scope.
    pub fn single(request: MemoryIndexRequest) -> Self {
        let scope = request.scope().clone();
        Self {
            scope,
            requests: vec![request],
        }
    }
    /// Scope shared by every request in the batch.
    pub fn scope(&self) -> &MemoryIndexScope {
        let Self { scope, .. } = self;
        scope
    }
    /// Number of requests in the batch.
    pub fn len(&self) -> usize {
        Vec::len(&self.requests)
    }
    /// Whether the batch contains no requests.
    pub fn is_empty(&self) -> bool {
        Vec::is_empty(&self.requests)
    }
    /// Consumes the batch and returns `(scope, requests)`.
    pub fn into_parts(self) -> (MemoryIndexScope, Vec<MemoryIndexRequest>) {
        let Self { scope, requests } = self;
        (scope, requests)
    }
}
/// Acknowledgement returned by `MemoryStore::index_scoped` / `index_scoped_batch`.
#[derive(Debug, Clone)]
pub struct MemoryIndexReceipt {
    /// Scope the entries were indexed under.
    pub scope: MemoryIndexScope,
    /// Number of entries the store reports as indexed.
    pub indexed_entries: usize,
}
/// Outcome of handing index requests for `scope` to a memory store.
#[derive(Debug)]
pub enum MemoryIndexDelivery {
    /// Nothing was delivered; presumably no store was configured (per the name — confirm with callers).
    NoStore {
        scope: MemoryIndexScope,
    },
    /// The store accepted the requests and returned this receipt.
    Delivered(MemoryIndexReceipt),
    /// The store returned `error` for a delivery of `attempted_entries` entries.
    Rejected {
        scope: MemoryIndexScope,
        attempted_entries: usize,
        error: MemoryStoreError,
    },
}
/// Text-to-vector embedder used for ranking memory entries.
///
/// Must be `Send + Sync`; `MemoryRankingPolicy` shares it through an `Arc`.
pub trait EmbeddingModel: Send + Sync {
    /// Length of the vectors this model produces.
    fn dimension(&self) -> usize;
    /// Embeds `text`. The returned vector is expected to have `dimension()` elements;
    /// nothing in this trait enforces that (NOTE(review): confirm callers check it).
    fn embed(&self, text: &str) -> Vec<f32>;
}
/// Tuning knobs for the HNSW vector index.
/// NOTE(review): field names mirror common HNSW library parameters; confirm how the store consumes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswParams {
    /// Maximum number of connections per graph node.
    pub max_nb_connection: usize,
    /// Maximum number of graph layers.
    pub max_layer: usize,
    /// Candidate-list breadth while building the index.
    pub ef_construction: usize,
    /// Candidate-list breadth while searching.
    pub ef_search: usize,
}
impl Default for HnswParams {
    /// 16 connections and 16 layers, with breadth 200 for both construction and search.
    fn default() -> Self {
        let breadth = 200;
        HnswParams {
            max_nb_connection: 16,
            max_layer: 16,
            ef_construction: breadth,
            ef_search: breadth,
        }
    }
}
/// Pairs a shared embedding model with the HNSW parameters used to rank memory entries.
#[derive(Clone)]
pub struct MemoryRankingPolicy {
    embedding_model: std::sync::Arc<dyn EmbeddingModel>,
    hnsw_params: HnswParams,
}
impl MemoryRankingPolicy {
    /// Creates a policy from a shared model and index parameters.
    pub fn new(
        embedding_model: std::sync::Arc<dyn EmbeddingModel>,
        hnsw_params: HnswParams,
    ) -> Self {
        MemoryRankingPolicy {
            hnsw_params,
            embedding_model,
        }
    }
    /// The shared embedding model.
    pub fn embedding_model(&self) -> &std::sync::Arc<dyn EmbeddingModel> {
        &self.embedding_model
    }
    /// A copy of the HNSW parameters.
    pub fn hnsw_params(&self) -> HnswParams {
        let params = self.hnsw_params;
        params
    }
    /// Vector dimension reported by the embedding model.
    pub fn dimension(&self) -> usize {
        self.embedding_model().dimension()
    }
    /// Embeds `text` with the policy's model.
    pub fn embed(&self, text: &str) -> Vec<f32> {
        self.embedding_model().embed(text)
    }
}
impl std::fmt::Debug for MemoryRankingPolicy {
    /// Shows the model's dimension in place of the model itself, because
    /// `dyn EmbeddingModel` does not implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dimension = self.dimension();
        let mut out = f.debug_struct("MemoryRankingPolicy");
        out.field("dimension", &dimension);
        out.field("hnsw_params", &self.hnsw_params);
        out.finish()
    }
}
/// Backend that indexes scoped memory entries and searches them.
///
/// On `wasm32` the async methods are generated with `async_trait(?Send)`, so their
/// futures need not be `Send`. On other targets they must be.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait MemoryStore: Send + Sync {
    /// Indexes a single request.
    ///
    /// The default implementation wraps it with `MemoryIndexBatch::single` and forwards
    /// to [`MemoryStore::index_scoped_batch`].
    async fn index_scoped(
        &self,
        request: MemoryIndexRequest,
    ) -> Result<MemoryIndexReceipt, MemoryStoreError> {
        self.index_scoped_batch(MemoryIndexBatch::single(request))
            .await
    }
    /// Indexes every request in `batch`. `MemoryIndexBatch`'s constructors guarantee
    /// that all of them share the batch's scope.
    async fn index_scoped_batch(
        &self,
        batch: MemoryIndexBatch,
    ) -> Result<MemoryIndexReceipt, MemoryStoreError>;
    /// Searches entries for `query` within `scope`.
    /// `limit` is the requested maximum number of results. Enforcement and ordering
    /// are up to the implementation (NOTE(review): confirm).
    async fn search(
        &self,
        scope: &MemorySearchScope,
        query: &str,
        limit: usize,
    ) -> Result<Vec<MemoryResult>, MemoryStoreError>;
}
/// Errors produced by memory types and memory-store implementations.
#[derive(Debug, thiserror::Error)]
pub enum MemoryStoreError {
    /// A request or batch fell outside its scope; the string describes the mismatch.
    #[error("Scope error: {0}")]
    Scope(String),
    /// A `MessageRange` was requested with `start > end`.
    #[error("invalid memory source range: start {start} > end {end}")]
    SourceRange { start: u64, end: u64 },
    #[error("Embedding error: {0}")]
    Embedding(String),
    #[error("Storage error: {0}")]
    Storage(String),
    #[error("memory index lock poisoned")]
    LockPoisoned,
    #[error("memory point ID out of range")]
    PointIdOutOfRange,
    #[error("memory point ID overflow")]
    PointIdOverflow,
    #[error("memory store task join failed: {0}")]
    TaskJoin(String),
    #[error("memory text corruption at point {point_id}: stored bytes are not valid UTF-8")]
    TextCorruption { point_id: i64 },
    #[error(
        "memory index/store divergence at point {point_id}: live index references a missing durable row"
    )]
    IndexDivergence { point_id: i64 },
    #[error("memory scope index is poisoned pending rebuild from durable state")]
    ScopePoisoned,
    /// Repair after a partial index failure itself failed; both errors are kept.
    #[error(
        "memory scope repair failed after partial index failure: {repair} (original failure: {original})"
    )]
    ScopeRepairFailed {
        original: Box<MemoryStoreError>,
        repair: Box<MemoryStoreError>,
    },
    /// Converted automatically from `std::io::Error` via `#[from]`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
impl MemoryStoreError {
    /// Short machine-readable code naming the variant; every code starts with `memory_`.
    pub fn error_code(&self) -> &'static str {
        match *self {
            // Caller-input / scoping problems.
            Self::Scope(_) => "memory_scope",
            Self::SourceRange { .. } => "memory_source_range",
            // Backend and embedding failures.
            Self::Embedding(_) => "memory_embedding",
            Self::Storage(_) => "memory_storage",
            Self::Io(_) => "memory_io",
            Self::TaskJoin(_) => "memory_task_join",
            // Internal index state.
            Self::LockPoisoned => "memory_lock_poisoned",
            Self::PointIdOutOfRange => "memory_point_id_out_of_range",
            Self::PointIdOverflow => "memory_point_id_overflow",
            Self::TextCorruption { .. } => "memory_text_corruption",
            Self::IndexDivergence { .. } => "memory_index_divergence",
            Self::ScopePoisoned => "memory_scope_poisoned",
            Self::ScopeRepairFailed { .. } => "memory_scope_repair_failed",
        }
    }
}