use crate::{duration::DurationMs, effect::Scope, error::StateError};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Storage-tier hint carried by [`StoreOptions`] and [`SearchOptions`].
///
/// `Hot` is the [`Default`] variant (marked with `#[default]`). Because of
/// `rename_all = "snake_case"`, the variants serialize as `"hot"`, `"warm"`,
/// and `"cold"`.
///
/// NOTE(review): the names suggest a fast-to-slow access ordering, but nothing
/// in this file enforces tier semantics — how (or whether) a tier is honoured
/// is up to each `StateStore` implementation; confirm against the backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryTier {
#[default]
Hot,
Warm,
Cold,
}
/// Lifetime hint for a stored value, carried in [`StoreOptions::lifetime`].
///
/// Serializes as `"transient"`, `"session"`, or `"durable"`
/// (`rename_all = "snake_case"`). Unlike [`MemoryTier`], this enum has no
/// `Default`; an unset lifetime is expressed as `None` in the options struct.
///
/// NOTE(review): presumably `Transient` values are the ones discarded by
/// `StateStore::clear_transient`, and `Durable` values are meant to outlive a
/// session — neither rule is enforced in this file; confirm per backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Lifetime {
Transient,
Session,
Durable,
}
/// Category of content being stored (via [`StoreOptions`]) or searched for
/// (via [`SearchOptions`]).
///
/// Unit variants serialize as snake_case strings: `"episodic"`, `"semantic"`,
/// `"procedural"`, `"structural"`. `Custom` carries a free-form kind name and
/// uses serde's default externally tagged form, e.g. `{"custom": "my_kind"}`.
///
/// Not `Copy` because `Custom` owns a `String`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContentKind {
Episodic,
Semantic,
Procedural,
Structural,
Custom(String),
}
/// Optional hints passed to [`StateStore::read_hinted`] and
/// [`StateStore::write_hinted`].
///
/// Every field is optional; `None` means the hint was not supplied. Each field
/// uses `#[serde(default, skip_serializing_if = "Option::is_none")]`, so unset
/// hints are omitted from serialized output and missing keys deserialize to
/// `None`.
///
/// The default trait implementations of the `*_hinted` methods ignore these
/// options and call the plain method; only backends that override them act
/// on the hints.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StoreOptions {
    /// Preferred storage tier.
    // `tier` previously lacked these attributes and was the only field
    // serialized as `null` when unset; it now matches the other fields.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tier: Option<MemoryTier>,
    /// Expected lifetime of the value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub lifetime: Option<Lifetime>,
    /// Category of the stored content.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content_kind: Option<ContentKind>,
    /// Importance weight for the value.
    /// NOTE(review): range and scale are not defined here — confirm with backends.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub salience: Option<f64>,
    /// Time-to-live for the value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ttl: Option<DurationMs>,
}
/// Optional hints passed to `search_hinted` on [`StateStore`] and
/// [`StateReader`].
///
/// All fields default to `None`, and unset fields are omitted from serialized
/// output (`skip_serializing_if = "Option::is_none"`). The default
/// `search_hinted` implementations ignore these options entirely and call the
/// plain `search`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchOptions {
// Minimum acceptable score; presumably compared against
// `SearchResult::score` by backends that honour it — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub min_score: Option<f64>,
// Restrict results to one content category.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content_kind: Option<ContentKind>,
// Restrict results to one storage tier.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tier: Option<MemoryTier>,
// NOTE(review): presumably bounds link-graph expansion (cf. the
// `max_depth` argument of `traverse`) in backends that combine search with
// traversal — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_depth: Option<u32>,
}
/// A labelled edge from one key to another, passed to [`StateStore::link`].
///
/// `metadata` is an optional free-form JSON payload; it is omitted from
/// serialized output when absent and defaults to `None` when missing on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryLink {
    /// Key the link starts from.
    pub from_key: String,
    /// Key the link points to.
    pub to_key: String,
    /// Label describing the relationship between the two keys.
    pub relation: String,
    /// Optional extra data attached to the link.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

impl MemoryLink {
    /// Creates a link `from_key --relation--> to_key` with no metadata.
    pub fn new(
        from_key: impl Into<String>,
        to_key: impl Into<String>,
        relation: impl Into<String>,
    ) -> Self {
        Self {
            from_key: from_key.into(),
            to_key: to_key.into(),
            relation: relation.into(),
            metadata: None,
        }
    }

    /// Returns this link with `metadata` attached, replacing any existing value.
    ///
    /// `new` always leaves `metadata` unset. This builder lets callers create a
    /// fully populated link in one expression, e.g.
    /// `MemoryLink::new("a", "b", "cites").with_metadata(value)`.
    #[must_use]
    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
        self.metadata = Some(metadata);
        self
    }
}
/// Asynchronous key/value store for JSON values, partitioned by [`Scope`].
///
/// Implementors must provide `read`, `write`, `delete`, `list`, and `search`.
/// Every other method has a default implementation:
/// - `read_hinted`, `write_hinted`, and `search_hinted` forward to their
///   un-hinted counterparts and ignore the supplied options;
/// - `link` and `unlink` do nothing and return `Ok(())`, so a backend without
///   graph support silently accepts link requests;
/// - `traverse` returns an empty list;
/// - `clear_transient` does nothing.
///
/// Backends that support hints or a link graph should override the relevant
/// defaults. The `Send + Sync` supertraits allow a store to be shared across
/// threads. Every `StateStore` is also a [`StateReader`] through the blanket
/// `impl<T: StateStore> StateReader for T` in this module.
#[async_trait]
pub trait StateStore: Send + Sync {
/// Reads the value stored at `key` within `scope`.
///
/// `Ok(None)` presumably indicates the key is absent — confirm per backend.
async fn read(&self, scope: &Scope, key: &str)
-> Result<Option<serde_json::Value>, StateError>;
/// Stores `value` at `key` within `scope`.
///
/// NOTE(review): overwrite-vs-error behaviour for an existing key is
/// backend-defined.
async fn write(
&self,
scope: &Scope,
key: &str,
value: serde_json::Value,
) -> Result<(), StateError>;
/// Removes `key` from `scope`.
///
/// NOTE(review): whether deleting a missing key is an error is backend-defined.
async fn delete(&self, scope: &Scope, key: &str) -> Result<(), StateError>;
/// Returns the keys in `scope` that match `prefix` (presumably a key-prefix match).
async fn list(&self, scope: &Scope, prefix: &str) -> Result<Vec<String>, StateError>;
/// Searches `scope` for `query`; `limit` presumably caps the number of results.
async fn search(
&self,
scope: &Scope,
query: &str,
limit: usize,
) -> Result<Vec<SearchResult>, StateError>;
/// Like [`read`](Self::read), but with [`StoreOptions`] hints.
///
/// The default ignores the options and calls `read`.
async fn read_hinted(
&self,
scope: &Scope,
key: &str,
_options: &StoreOptions,
) -> Result<Option<serde_json::Value>, StateError> {
self.read(scope, key).await
}
/// Like [`write`](Self::write), but with [`StoreOptions`] hints
/// (tier, lifetime, content kind, salience, TTL).
///
/// The default ignores the options and calls `write`.
async fn write_hinted(
&self,
scope: &Scope,
key: &str,
value: serde_json::Value,
_options: &StoreOptions,
) -> Result<(), StateError> {
self.write(scope, key, value).await
}
/// Synchronous hook for discarding transient state; the default is a no-op.
///
/// NOTE(review): presumably drops values written with
/// `Lifetime::Transient` — confirm against implementations.
fn clear_transient(&self) {}
/// Records a [`MemoryLink`] within `scope`.
///
/// The default does nothing and returns `Ok(())`.
async fn link(&self, _scope: &Scope, _link: &MemoryLink) -> Result<(), StateError> {
Ok(())
}
/// Removes the link `from_key --relation--> to_key` within `scope`.
///
/// The default does nothing and returns `Ok(())`.
async fn unlink(
&self,
_scope: &Scope,
_from_key: &str,
_to_key: &str,
_relation: &str,
) -> Result<(), StateError> {
Ok(())
}
/// Returns keys reached by following links from `from_key`, optionally
/// restricted to one `relation`, presumably up to `max_depth` hops.
///
/// The default returns an empty list.
async fn traverse(
&self,
_scope: &Scope,
_from_key: &str,
_relation: Option<&str>,
_max_depth: u32,
) -> Result<Vec<String>, StateError> {
Ok(vec![])
}
/// Like [`search`](Self::search), but with [`SearchOptions`] hints.
///
/// The default ignores the options and calls `search`.
async fn search_hinted(
&self,
scope: &Scope,
query: &str,
limit: usize,
_options: &SearchOptions,
) -> Result<Vec<SearchResult>, StateError> {
self.search(scope, query, limit).await
}
}
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
pub key: String,
pub score: f64,
pub snippet: Option<String>,
}
impl SearchResult {
pub fn new(key: impl Into<String>, score: f64) -> Self {
Self {
key: key.into(),
score,
snippet: None,
}
}
}
/// Read-side subset of [`StateStore`].
///
/// It has the same `read`, `list`, `search`, hinted-read, hinted-search, and
/// `traverse` methods, without `write`, `delete`, `link`, or `unlink`. It does
/// still expose `clear_transient`. Defaults match `StateStore`'s: hinted
/// methods ignore their options, `traverse` returns an empty list, and
/// `clear_transient` is a no-op.
///
/// Implemented automatically for every `StateStore` by the blanket
/// `impl<T: StateStore> StateReader for T` in this module.
#[async_trait]
pub trait StateReader: Send + Sync {
/// Reads the value stored at `key` within `scope`.
///
/// `Ok(None)` presumably indicates the key is absent — confirm per backend.
async fn read(&self, scope: &Scope, key: &str)
-> Result<Option<serde_json::Value>, StateError>;
/// Returns the keys in `scope` that match `prefix` (presumably a key-prefix match).
async fn list(&self, scope: &Scope, prefix: &str) -> Result<Vec<String>, StateError>;
/// Searches `scope` for `query`; `limit` presumably caps the number of results.
async fn search(
&self,
scope: &Scope,
query: &str,
limit: usize,
) -> Result<Vec<SearchResult>, StateError>;
/// Like [`read`](Self::read), but with [`StoreOptions`] hints.
///
/// The default ignores the options and calls `read`.
async fn read_hinted(
&self,
scope: &Scope,
key: &str,
_options: &StoreOptions,
) -> Result<Option<serde_json::Value>, StateError> {
self.read(scope, key).await
}
/// Synchronous hook for discarding transient state; the default is a no-op.
fn clear_transient(&self) {}
/// Returns keys reached by following links from `from_key`, optionally
/// restricted to one `relation`, presumably up to `max_depth` hops.
///
/// The default returns an empty list.
async fn traverse(
&self,
_scope: &Scope,
_from_key: &str,
_relation: Option<&str>,
_max_depth: u32,
) -> Result<Vec<String>, StateError> {
Ok(vec![])
}
/// Like [`search`](Self::search), but with [`SearchOptions`] hints.
///
/// The default ignores the options and calls `search`.
async fn search_hinted(
&self,
scope: &Scope,
query: &str,
limit: usize,
_options: &SearchOptions,
) -> Result<Vec<SearchResult>, StateError> {
self.search(scope, query, limit).await
}
}
/// Blanket bridge: every [`StateStore`] is also a [`StateReader`].
///
/// Each method forwards to the `StateStore` method of the same name. Calls use
/// fully-qualified `<T as StateStore>::method(...)` paths because `T`
/// implements both traits, which declare identically named methods, so a plain
/// `self.read(...)` would be ambiguous.
///
/// `StateReader`'s provided methods are forwarded too, not just the required
/// ones. As a result, callers going through `StateReader` reach a backend's
/// overrides of `read_hinted`, `traverse`, `search_hinted`, and
/// `clear_transient`.
#[async_trait]
impl<T: StateStore> StateReader for T {
    async fn read(
        &self,
        scope: &Scope,
        key: &str,
    ) -> Result<Option<serde_json::Value>, StateError> {
        <T as StateStore>::read(self, scope, key).await
    }

    async fn list(&self, scope: &Scope, key_prefix: &str) -> Result<Vec<String>, StateError> {
        <T as StateStore>::list(self, scope, key_prefix).await
    }

    async fn search(
        &self,
        scope: &Scope,
        query: &str,
        max_results: usize,
    ) -> Result<Vec<SearchResult>, StateError> {
        <T as StateStore>::search(self, scope, query, max_results).await
    }

    async fn read_hinted(
        &self,
        scope: &Scope,
        key: &str,
        hints: &StoreOptions,
    ) -> Result<Option<serde_json::Value>, StateError> {
        <T as StateStore>::read_hinted(self, scope, key, hints).await
    }

    fn clear_transient(&self) {
        <T as StateStore>::clear_transient(self)
    }

    async fn traverse(
        &self,
        scope: &Scope,
        origin: &str,
        relation_filter: Option<&str>,
        depth_limit: u32,
    ) -> Result<Vec<String>, StateError> {
        <T as StateStore>::traverse(self, scope, origin, relation_filter, depth_limit).await
    }

    async fn search_hinted(
        &self,
        scope: &Scope,
        query: &str,
        max_results: usize,
        hints: &SearchOptions,
    ) -> Result<Vec<SearchResult>, StateError> {
        <T as StateStore>::search_hinted(self, scope, query, max_results, hints).await
    }
}