use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Identifies who owns a piece of indexed memory.
///
/// Ownership is keyed solely by a session identifier, so two owners compare
/// equal exactly when their session ids compare equal.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MemoryOwner {
    session_id: crate::types::SessionId,
}

impl MemoryOwner {
    /// Builds the owner corresponding to `session_id`.
    pub fn canonical_session(session_id: crate::types::SessionId) -> Self {
        MemoryOwner { session_id }
    }

    /// Borrows the session id this owner is keyed on.
    pub fn session_id(&self) -> &crate::types::SessionId {
        let Self { session_id } = self;
        session_id
    }

    /// Reports whether `metadata` was recorded under this owner's session.
    fn includes(&self, metadata: &MemoryMetadata) -> bool {
        &metadata.session_id == self.session_id()
    }
}
/// Metadata stored alongside each memory entry.
///
/// Scope membership checks in this module compare only `session_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryMetadata {
/// Session the entry belongs to; this is the field scope checks match on.
pub session_id: crate::types::SessionId,
/// Optional turn number. NOTE(review): presumably the conversation turn that
/// produced the entry — confirm with producers.
pub turn: Option<u32>,
/// Indexing timestamp. NOTE(review): presumably set when the entry was
/// indexed — confirm with producers.
pub indexed_at: crate::time_compat::SystemTime,
}
/// A single hit returned by [`MemoryStore::search`].
#[derive(Debug, Clone)]
pub struct MemoryResult {
/// Content of the matching entry.
pub content: String,
/// Metadata of the matching entry.
pub metadata: MemoryMetadata,
/// Match score. NOTE(review): the scale and ordering (is higher better?) are
/// defined by the store implementation — confirm.
pub score: f32,
}
/// Restricts a memory search to entries owned by a single [`MemoryOwner`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MemorySearchScope {
    pub owner: MemoryOwner,
}

impl MemorySearchScope {
    /// Scopes a search to the canonical owner of `session_id`.
    pub fn for_session(session_id: crate::types::SessionId) -> Self {
        Self::for_owner(MemoryOwner::canonical_session(session_id))
    }

    /// Scopes a search to an explicit owner.
    pub fn for_owner(owner: MemoryOwner) -> Self {
        MemorySearchScope { owner }
    }

    /// Session id of the owner this scope is bound to.
    pub fn session_id(&self) -> &crate::types::SessionId {
        let owner = &self.owner;
        owner.session_id()
    }

    /// Returns `true` when `metadata` belongs to this scope's owner.
    pub fn includes(&self, metadata: &MemoryMetadata) -> bool {
        let owner = &self.owner;
        owner.includes(metadata)
    }
}
/// Declares which [`MemoryOwner`] newly indexed memory is written under.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MemoryIndexScope {
    pub owner: MemoryOwner,
}

impl MemoryIndexScope {
    /// Creates an indexing scope for the canonical owner of `session_id`.
    pub fn for_session(session_id: crate::types::SessionId) -> Self {
        let owner = MemoryOwner::canonical_session(session_id);
        Self::for_owner(owner)
    }

    /// Creates an indexing scope for an explicit owner.
    pub fn for_owner(owner: MemoryOwner) -> Self {
        MemoryIndexScope { owner }
    }

    /// Session id of the owner this scope writes under.
    pub fn session_id(&self) -> &crate::types::SessionId {
        MemoryOwner::session_id(&self.owner)
    }

    /// Returns `true` when `metadata` falls inside this indexing scope.
    pub fn includes(&self, metadata: &MemoryMetadata) -> bool {
        MemoryOwner::includes(&self.owner, metadata)
    }
}
/// One piece of content to index, paired with the scope it is indexed under.
///
/// The public constructor, [`MemoryIndexRequest::new`], rejects metadata whose
/// session lies outside `scope`.
#[derive(Debug, Clone)]
pub struct MemoryIndexRequest {
    scope: MemoryIndexScope,
    content: String,
    metadata: MemoryMetadata,
}

impl MemoryIndexRequest {
    /// Validates and builds an indexing request.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryStoreError::Scope`] when `metadata` is not included in
    /// `scope`.
    pub fn new(
        scope: MemoryIndexScope,
        content: String,
        metadata: MemoryMetadata,
    ) -> Result<Self, MemoryStoreError> {
        if scope.includes(&metadata) {
            Ok(Self {
                scope,
                content,
                metadata,
            })
        } else {
            Err(MemoryStoreError::Scope(format!(
                "memory metadata session {} is outside indexing scope {}",
                metadata.session_id,
                scope.session_id()
            )))
        }
    }

    /// Scope this request will be indexed under.
    pub fn scope(&self) -> &MemoryIndexScope {
        &self.scope
    }

    /// Content to be indexed.
    pub fn content(&self) -> &str {
        self.content.as_str()
    }

    /// Metadata attached to the content.
    pub fn metadata(&self) -> &MemoryMetadata {
        &self.metadata
    }

    /// Consumes the request, returning `(scope, content, metadata)`.
    pub fn into_parts(self) -> (MemoryIndexScope, String, MemoryMetadata) {
        let Self {
            scope,
            content,
            metadata,
        } = self;
        (scope, content, metadata)
    }
}
/// A group of [`MemoryIndexRequest`]s that all share one [`MemoryIndexScope`].
///
/// [`MemoryIndexBatch::new`] rejects any request whose scope differs from the
/// batch scope. [`MemoryIndexBatch::single`] takes the scope from its request.
#[derive(Debug, Clone)]
pub struct MemoryIndexBatch {
    scope: MemoryIndexScope,
    requests: Vec<MemoryIndexRequest>,
}

impl MemoryIndexBatch {
    /// Builds a batch, checking that every request shares `scope`.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryStoreError::Scope`] naming the first request whose scope
    /// differs from `scope`.
    pub fn new(
        scope: MemoryIndexScope,
        requests: Vec<MemoryIndexRequest>,
    ) -> Result<Self, MemoryStoreError> {
        let mismatched = requests.iter().find(|request| request.scope() != &scope);
        if let Some(request) = mismatched {
            return Err(MemoryStoreError::Scope(format!(
                "memory index request scope {} is outside batch scope {}",
                request.scope().session_id(),
                scope.session_id()
            )));
        }
        Ok(Self { scope, requests })
    }

    /// Wraps one request in a batch that uses the request's own scope.
    pub fn single(request: MemoryIndexRequest) -> Self {
        let scope = request.scope().clone();
        Self {
            scope,
            requests: vec![request],
        }
    }

    /// Scope shared by every request in the batch.
    pub fn scope(&self) -> &MemoryIndexScope {
        &self.scope
    }

    /// Number of requests in the batch.
    pub fn len(&self) -> usize {
        self.requests.len()
    }

    /// Returns `true` when the batch holds no requests.
    pub fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

    /// Consumes the batch, returning `(scope, requests)`.
    pub fn into_parts(self) -> (MemoryIndexScope, Vec<MemoryIndexRequest>) {
        let Self { scope, requests } = self;
        (scope, requests)
    }
}
/// Result of a successful indexing call on [`MemoryStore`].
#[derive(Debug, Clone)]
pub struct MemoryIndexReceipt {
/// Scope the entries were indexed under.
pub scope: MemoryIndexScope,
/// Number of entries the store reports as indexed.
pub indexed_entries: usize,
}
/// Outcome of handing an indexing batch to a memory store.
///
/// Note: [`MemoryStoreError`] is not `Clone` (its `Io` variant wraps
/// `std::io::Error`), so this enum cannot derive `Clone` either.
#[derive(Debug)]
pub enum MemoryIndexDelivery {
/// No delivery took place. NOTE(review): presumably means no `MemoryStore`
/// was configured — confirm with callers.
NoStore {
/// Scope that would have been indexed.
scope: MemoryIndexScope,
},
/// Delivery succeeded; carries the store's receipt.
Delivered(MemoryIndexReceipt),
/// The store returned an error.
Rejected {
/// Scope of the rejected batch.
scope: MemoryIndexScope,
/// Number of entries whose indexing was attempted.
attempted_entries: usize,
/// Error reported for the attempt.
error: MemoryStoreError,
},
}
/// Asynchronous backend for scoped memory indexing and search.
///
/// Implementors must provide batch indexing and search. Single-request
/// indexing has a default that forwards a one-element batch. On `wasm32` the
/// generated futures are not required to be `Send`.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait MemoryStore: Send + Sync {
    /// Indexes a single request.
    ///
    /// The default implementation wraps `request` with
    /// [`MemoryIndexBatch::single`] and delegates to
    /// [`MemoryStore::index_scoped_batch`].
    async fn index_scoped(
        &self,
        request: MemoryIndexRequest,
    ) -> Result<MemoryIndexReceipt, MemoryStoreError> {
        let batch = MemoryIndexBatch::single(request);
        self.index_scoped_batch(batch).await
    }

    /// Indexes every request in `batch` under the batch's scope.
    async fn index_scoped_batch(
        &self,
        batch: MemoryIndexBatch,
    ) -> Result<MemoryIndexReceipt, MemoryStoreError>;

    /// Searches memory within `scope` for `query`.
    ///
    /// NOTE(review): `limit` presumably caps the number of returned results;
    /// ranking is implementation-defined — confirm with implementors.
    async fn search(
        &self,
        scope: &MemorySearchScope,
        query: &str,
        limit: usize,
    ) -> Result<Vec<MemoryResult>, MemoryStoreError>;
}
/// Errors produced by memory-store operations and by request/batch validation.
#[derive(Debug, thiserror::Error)]
pub enum MemoryStoreError {
/// A request or its metadata fell outside the expected scope.
/// Returned by `MemoryIndexRequest::new` and `MemoryIndexBatch::new`.
#[error("Scope error: {0}")]
Scope(String),
/// Embedding failure. NOTE(review): not raised in this file; presumably
/// produced by store implementations — confirm.
#[error("Embedding error: {0}")]
Embedding(String),
/// Index failure. NOTE(review): not raised in this file; presumably produced
/// by store implementations — confirm.
#[error("Index error: {0}")]
Index(String),
/// Wrapped I/O failure. `#[from]` lets `?` convert `std::io::Error` automatically.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}