use serde_json::Value;
/// A pluggable backend for the agent's persistent memory.
///
/// Entries live in named stores selected by a `target` string (the built-in
/// SQLite backend and this module's tests use targets such as `"memory"` and
/// `"pitfalls"`). Fallible methods return a JSON [`Value`] on success and a
/// human-readable message on failure.
///
/// Only `name`, `view`, `add`, `replace` and `remove` are required. The optional
/// operations (`restore`, `expand`, `search`) default to returning an error, and
/// the `on_*` lifecycle hooks default to doing nothing, so a minimal backend can
/// ignore them.
pub trait MemoryProvider: Send + Sync {
    /// Short identifier for this backend; the built-in SQLite store reports
    /// `"builtin"`.
    fn name(&self) -> &str;

    /// Text this provider contributes to the system prompt.
    ///
    /// The default contributes nothing (an empty string).
    fn format_for_system_prompt(&self) -> String {
        String::new()
    }

    /// Returns a JSON snapshot of the entries stored under `target`.
    ///
    /// The shape is backend-specific (the built-in store reports its entries
    /// under an `"entries"` array). Unlike the mutating methods, this returns a
    /// `Value` directly rather than a `Result`.
    fn view(&self, target: &str) -> Value;

    /// Adds `content` as a new entry under `target`.
    ///
    /// `kind` is an optional category label whose interpretation is up to the
    /// backend. Implementations must NOT call `on_memory_write` themselves: the
    /// tool layer fires that hook, and self-firing would report the write twice.
    fn add(&self, target: &str, content: &str, kind: Option<&str>) -> Result<Value, String>;

    /// Replaces the entry under `target` matched by `old_text` with `content`,
    /// optionally setting its `kind`.
    ///
    /// As with [`add`](Self::add), implementations must not fire
    /// `on_memory_write` themselves.
    fn replace(
        &self,
        target: &str,
        old_text: &str,
        content: &str,
        kind: Option<&str>,
    ) -> Result<Value, String>;

    /// Removes the entry under `target` matched by `old_text`.
    ///
    /// As with [`add`](Self::add), implementations must not fire
    /// `on_memory_write` themselves.
    fn remove(&self, target: &str, old_text: &str) -> Result<Value, String>;

    /// Restores a previously removed entry under `target` matched by
    /// `old_text`.
    ///
    /// # Errors
    ///
    /// The default implementation always returns `Err`, for backends that
    /// cannot undo a removal.
    fn restore(&self, _target: &str, _old_text: &str) -> Result<Value, String> {
        Err("This memory backend does not support restoring removed entries".to_string())
    }

    /// Expands the entry matched by `old_text`; what "expanding" returns is
    /// backend-defined.
    ///
    /// # Errors
    ///
    /// The default implementation always returns `Err` (unsupported).
    fn expand(&self, _old_text: &str) -> Result<Value, String> {
        Err("This memory backend does not support expanding entries".to_string())
    }

    /// Searches the stored entries for `query`.
    ///
    /// # Errors
    ///
    /// The default implementation always returns `Err` (unsupported).
    fn search(&self, _query: &str) -> Result<Value, String> {
        Err("This memory backend does not support searching entries".to_string())
    }

    /// Hook fired by the tool layer after it performs a memory write.
    ///
    /// Providers must not invoke this from their own `add`/`replace`/`remove`;
    /// it is always called from outside the provider.
    ///
    /// * `action` - the operation performed, e.g. `"add"` or `"remove"`.
    /// * `target` - the store that was written, e.g. `"memory"`.
    /// * `payload` - action-dependent text. It is NOT a generic "new value"
    ///   field, so hooks must interpret it according to `action`:
    ///   * on `"remove"` it is the `old_text` that selected the removed entry,
    ///     not any new content;
    ///   * on `"add"` it is presumably the added content (confirm against the
    ///     tool layer, as for any other action).
    ///
    /// The default implementation ignores the call.
    fn on_memory_write(&self, _action: &str, _target: &str, _payload: &str) {}

    /// Hook fired when a session ends, receiving the session transcript as
    /// plain text (presumably alternating `User:` / `Assistant:` turns --
    /// confirm the exact format against the caller).
    ///
    /// The default implementation does nothing.
    fn on_session_end(&self, _transcript: &str) {}

    /// Hook fired when the active session switches to `new_session_id`.
    ///
    /// NOTE(review): `parent_session_id` presumably names the session being
    /// left and `reset` presumably signals a fresh (cleared) session; confirm
    /// both against the caller. The default implementation does nothing.
    fn on_session_switch(&self, _new_session_id: &str, _parent_session_id: &str, _reset: bool) {}

    /// Hook fired before the conversation transcript is compressed.
    ///
    /// Receives the transcript verbatim and returns extra text for the caller
    /// to thread into the compression instructions (for example, facts the
    /// provider wants preserved). An empty string contributes nothing, and
    /// that is what the default implementation returns.
    fn on_pre_compress(&self, _transcript: &str) -> String {
        String::new()
    }
}
/// Exposes the built-in SQLite store through [`MemoryProvider`] under the name
/// `"builtin"`.
///
/// Every method forwards to the store's inherent method of the same name. The
/// calls use the type-relative `Self::method(self, ..)` form rather than
/// `self.method(..)`, so they explicitly name the inherent method. Inherent
/// associated items take priority over trait items, so a call can never loop
/// back into this trait implementation.
impl MemoryProvider for super::memory_db::SqliteMemoryStore {
    fn name(&self) -> &str {
        "builtin"
    }

    fn format_for_system_prompt(&self) -> String {
        Self::format_for_system_prompt(self)
    }

    fn view(&self, target: &str) -> Value {
        Self::view(self, target)
    }

    /// Converts the `kind` label with `memory_db::parse_kind` before storing.
    /// When that conversion yields `None` (presumably an unrecognized label),
    /// the store simply receives no kind: the label is dropped, not rejected.
    fn add(&self, target: &str, content: &str, kind: Option<&str>) -> Result<Value, String> {
        Self::add(
            self,
            target,
            content,
            kind.and_then(super::memory_db::parse_kind),
        )
    }

    /// Same `kind` conversion as `add`: a label `parse_kind` maps to `None`
    /// is dropped rather than reported.
    fn replace(
        &self,
        target: &str,
        old_text: &str,
        content: &str,
        kind: Option<&str>,
    ) -> Result<Value, String> {
        Self::replace(
            self,
            target,
            old_text,
            content,
            kind.and_then(super::memory_db::parse_kind),
        )
    }

    fn remove(&self, target: &str, old_text: &str) -> Result<Value, String> {
        Self::remove(self, target, old_text)
    }

    fn restore(&self, target: &str, old_text: &str) -> Result<Value, String> {
        Self::restore(self, target, old_text)
    }

    fn expand(&self, old_text: &str) -> Result<Value, String> {
        Self::expand(self, old_text)
    }

    fn search(&self, query: &str) -> Result<Value, String> {
        Self::search(self, query)
    }
}
/// Unit tests for the `MemoryProvider` contract and the built-in SQLite backend.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Expands to the four required CRUD methods as inert stubs returning
    // `Value::Null`, so hook-focused test providers only spell out the
    // behavior they actually exercise.
    macro_rules! null_crud {
        () => {
            fn view(&self, _: &str) -> Value {
                Value::Null
            }
            fn add(&self, _: &str, _: &str, _kind: Option<&str>) -> Result<Value, String> {
                Ok(Value::Null)
            }
            fn replace(
                &self,
                _: &str,
                _: &str,
                _: &str,
                _kind: Option<&str>,
            ) -> Result<Value, String> {
                Ok(Value::Null)
            }
            fn remove(&self, _: &str, _: &str) -> Result<Value, String> {
                Ok(Value::Null)
            }
        };
    }

    /// Provider that records every `on_memory_write` call it receives.
    #[derive(Default)]
    struct RecordingProvider {
        writes: Mutex<Vec<(String, String, String)>>,
    }

    impl MemoryProvider for RecordingProvider {
        fn name(&self) -> &str {
            "recording-test"
        }
        null_crud!();
        fn on_memory_write(&self, action: &str, target: &str, payload: &str) {
            self.writes
                .lock()
                .unwrap()
                .push((action.into(), target.into(), payload.into()));
        }
    }

    #[test]
    fn provider_crud_does_not_self_fire_on_memory_write() {
        let p = RecordingProvider::default();
        let _ = p.add("memory", "hello", None);
        let _ = p.replace("memory", "old", "hello", None);
        let _ = p.remove("pitfalls", "old");
        let writes = p.writes.lock().unwrap();
        assert!(
            writes.is_empty(),
            "providers must NOT self-fire on_memory_write \
             (the tool layer does); got: {:?}",
            *writes
        );
    }

    #[test]
    fn on_memory_write_contract_is_documented() {
        let src = include_str!("memory_provider.rs");
        let anchor = "fn on_memory_write";
        let pos = src
            .find(anchor)
            .expect("trait method on_memory_write must exist");
        // Inspect (up to) the 2000 bytes of doc text just before the trait
        // method. Back the window start up to a char boundary: slicing a `str`
        // inside a multi-byte character (e.g. an em dash in a doc comment)
        // panics. Index 0 is always a boundary, so the loop terminates.
        let mut start = pos.saturating_sub(2000);
        while !src.is_char_boundary(start) {
            start -= 1;
        }
        let preamble = &src[start..pos];
        assert!(
            preamble.contains("payload"),
            "doc must rename the third param meaning to 'payload'"
        );
        assert!(
            preamble.contains("`\"remove\"`"),
            "doc must describe the remove case"
        );
        assert!(
            preamble.contains("old_text"),
            "doc must say payload is old_text on remove"
        );
        assert!(
            preamble.contains("NOT a generic") || preamble.contains("not a generic"),
            "doc must warn that payload is NOT a generic new-value field"
        );
    }

    #[test]
    fn external_on_memory_write_call_records() {
        let p = RecordingProvider::default();
        p.on_memory_write("add", "memory", "hello");
        p.on_memory_write("remove", "pitfalls", "hello");
        let writes = p.writes.lock().unwrap();
        assert_eq!(writes.len(), 2);
        assert_eq!(writes[0], ("add".into(), "memory".into(), "hello".into()));
        assert_eq!(
            writes[1],
            ("remove".into(), "pitfalls".into(), "hello".into())
        );
    }

    #[test]
    fn on_pre_compress_output_threads_into_instructions() {
        #[derive(Default)]
        struct InsightProvider {
            saw_transcript: Mutex<Option<String>>,
        }
        impl MemoryProvider for InsightProvider {
            fn name(&self) -> &str {
                "insight"
            }
            null_crud!();
            fn on_pre_compress(&self, transcript: &str) -> String {
                *self.saw_transcript.lock().unwrap() = Some(transcript.to_string());
                "REMEMBER: project uses cargo not bazel".into()
            }
        }
        let p = InsightProvider::default();
        let extra = p.on_pre_compress("turn 1 transcript");
        assert_eq!(extra, "REMEMBER: project uses cargo not bazel");
        assert_eq!(
            p.saw_transcript.lock().unwrap().as_deref(),
            Some("turn 1 transcript"),
            "hook must receive the pre-compress transcript verbatim"
        );
    }

    #[test]
    fn on_session_end_fires_with_transcript() {
        #[derive(Default)]
        struct EndProvider {
            ends: Mutex<Vec<String>>,
        }
        impl MemoryProvider for EndProvider {
            fn name(&self) -> &str {
                "end"
            }
            null_crud!();
            fn on_session_end(&self, transcript: &str) {
                self.ends.lock().unwrap().push(transcript.to_string());
            }
        }
        let p = EndProvider::default();
        p.on_session_end("User: hi\n\nAssistant: hello\n");
        let ends = p.ends.lock().unwrap();
        assert_eq!(ends.len(), 1, "exactly one end-of-session fire");
        assert!(
            ends[0].contains("User: hi") && ends[0].contains("Assistant: hello"),
            "transcript must contain user + assistant turns: {:?}",
            ends[0]
        );
    }

    #[test]
    fn alternative_provider_default_hooks_are_no_ops() {
        struct MinimalProvider;
        impl MemoryProvider for MinimalProvider {
            fn name(&self) -> &str {
                "minimal"
            }
            null_crud!();
        }
        let p = MinimalProvider;
        p.on_session_end("transcript");
        assert_eq!(p.on_pre_compress("anything"), "");
        p.on_memory_write("add", "memory", "x");
    }

    #[test]
    fn builtin_store_implements_trait_and_routes_through_on_write() {
        use crate::extras::dirge_paths::ProjectPaths;

        // Deletes the scratch directory on drop, so cleanup also happens when
        // an assertion below panics.
        struct TempDirGuard(std::path::PathBuf);
        impl Drop for TempDirGuard {
            fn drop(&mut self) {
                // Best effort: a failed cleanup must not mask the test's result.
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }

        let dir = TempDirGuard(std::env::temp_dir().join(format!(
            "dirge-memprovider-test-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        )));
        std::fs::create_dir_all(dir.0.join(".git")).unwrap();
        let paths = ProjectPaths::new(&dir.0);
        // `store` is declared after `dir`, so it is dropped first and releases
        // any handles it may hold into the directory before the guard deletes it.
        let store = super::super::memory_db::SqliteMemoryStore::load(&paths).unwrap();
        let provider: &dyn MemoryProvider = &store;
        assert_eq!(provider.name(), "builtin");
        let resp = provider.add("memory", "trait-routed entry", None).unwrap();
        assert_eq!(resp["success"], true);
        let view = provider.view("memory");
        let entries = view["entries"].as_array().unwrap();
        assert!(entries.iter().any(|e| {
            e.as_str()
                .map(|s| s.contains("trait-routed"))
                .unwrap_or(false)
        }));
    }
}