use std::sync::Arc;
use std::time::SystemTime;
use async_trait::async_trait;
use thiserror::Error;
use uuid7::uuid7;
use super::mutation::MemoryMutation;
/// Backend that turns text into a dense `f32` vector.
///
/// Implementors must be `Send + Sync` so they can be shared behind an
/// `Arc<dyn Embedder>`, which is how [`LocalMaterializer`] holds them.
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Identifier of the model behind [`Embedder::embed`]. It is recorded next to
    /// every embedding this backend produces.
    fn model_id(&self) -> String;

    /// Embeds `text` into a vector.
    ///
    /// # Errors
    /// Failures are reported as a [`MaterializeError`].
    async fn embed(&self, text: &str) -> Result<Vec<f32>, MaterializeError>;
}
/// Backend that pulls entity strings out of free text.
#[async_trait]
pub trait EntityExtractor: Send + Sync {
    /// Returns the entities found in `text`, in whatever order the
    /// implementation produces them.
    ///
    /// # Errors
    /// Failures are reported as a [`MaterializeError`].
    async fn extract(
        &self,
        text: &str,
    ) -> Result<Vec<String>, MaterializeError>;
}
/// Everything that can go wrong while materializing a request.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum MaterializeError {
    /// The embedding backend failed to produce a vector.
    #[error("embed failed: {message}")]
    EmbedFailure {
        /// Backend-supplied failure detail.
        message: String,
    },

    /// The entity-extraction backend failed.
    #[error("entity extraction failed: {message}")]
    NerFailure {
        /// Backend-supplied failure detail.
        message: String,
    },

    /// The request itself was rejected before any backend work.
    #[error("invalid request: {message}")]
    InvalidRequest {
        /// Why the request was rejected.
        message: String,
    },
}
/// Input to [`Materializer::materialize_remember`].
///
/// Apart from `text` and `client_embedding`, every field is copied unchanged
/// into the resulting mutation by [`LocalMaterializer`].
#[derive(Debug, Clone, PartialEq)]
pub struct RememberRequest {
    /// Content to remember. [`LocalMaterializer`] rejects an empty string.
    pub text: String,
    /// Free-form type label (the tests use `"semantic"`).
    pub memory_type: String,
    /// Importance score. NOTE(review): the valid range is not enforced here; confirm with consumers.
    pub importance: f64,
    /// Emotional valence score. The range is not enforced here.
    pub valence: f64,
    /// Decay half-life. NOTE(review): the unit is not established here; the tests use
    /// `86400.0`, which presumably means seconds. Confirm.
    pub half_life: f64,
    /// Namespace the memory belongs to.
    pub namespace: String,
    /// Certainty score. The range is not enforced here.
    pub certainty: f64,
    /// Domain label.
    pub domain: String,
    /// Origin label for the memory.
    pub source: String,
    /// Optional emotional-state label.
    pub emotional_state: Option<String>,
    /// Arbitrary JSON metadata, passed through untouched.
    pub metadata: serde_json::Value,
    /// Pre-computed embedding. When `Some`, it is used instead of calling the
    /// embedder.
    pub client_embedding: Option<Vec<f32>>,
}
/// Converts a [`RememberRequest`] into a [`MemoryMutation`] ready to apply.
///
/// Declared through `async_trait`, so it can be used as `Arc<dyn Materializer>`.
#[async_trait]
pub trait Materializer: Send + Sync {
    /// Materializes one "remember" request.
    ///
    /// # Errors
    /// Returns a [`MaterializeError`] when the request is invalid or a backend
    /// fails.
    async fn materialize_remember(&self, req: RememberRequest)
        -> Result<MemoryMutation, MaterializeError>;
}
/// [`Materializer`] that computes embeddings and entities with the injected
/// [`Embedder`] and [`EntityExtractor`] backends.
pub struct LocalMaterializer {
    /// Used when the request carries no client embedding. It also supplies the
    /// recorded model id.
    embedder: Arc<dyn Embedder>,
    /// Runs on every request's text.
    extractor: Arc<dyn EntityExtractor>,
}
impl LocalMaterializer {
pub fn new(embedder: Arc<dyn Embedder>, extractor: Arc<dyn EntityExtractor>) -> Self {
Self {
embedder,
extractor,
}
}
}
#[async_trait]
impl Materializer for LocalMaterializer {
    /// Builds a [`MemoryMutation::UpsertMemory`] from `req`.
    ///
    /// Steps, in order:
    /// 1. Validate: `text` must be non-empty. A client-supplied embedding must be
    ///    non-empty and contain only finite values.
    /// 2. Take the client embedding if present; otherwise ask the embedder.
    /// 3. Run entity extraction on `text`.
    /// 4. Stamp a fresh UUIDv7 `rid`, the current wall-clock time in
    ///    microseconds since the Unix epoch, and the embedder's model id.
    ///
    /// # Errors
    /// - [`MaterializeError::InvalidRequest`] for the validation failures in step 1.
    /// - Any error returned by the embedder or extractor, propagated unchanged.
    async fn materialize_remember(
        &self,
        req: RememberRequest,
    ) -> Result<MemoryMutation, MaterializeError> {
        if req.text.is_empty() {
            return Err(MaterializeError::InvalidRequest {
                message: "text is empty".into(),
            });
        }
        // Validate a client-supplied vector before doing any backend work.
        // Without this, an empty or NaN/inf vector would be stored verbatim as
        // the memory's embedding, and such vectors are unusable for similarity.
        if let Some(client) = &req.client_embedding {
            if client.is_empty() {
                return Err(MaterializeError::InvalidRequest {
                    message: "client embedding is empty".into(),
                });
            }
            if client.iter().any(|x| !x.is_finite()) {
                return Err(MaterializeError::InvalidRequest {
                    message: "client embedding contains non-finite values".into(),
                });
            }
        }
        let embedding = match req.client_embedding {
            Some(emb) => emb,
            None => self.embedder.embed(&req.text).await?,
        };
        let extracted_entities = self.extractor.extract(&req.text).await?;
        // NOTE(review): this stamps the local embedder's model id even when the
        // vector came from the client. Confirm that clients are required to use
        // the same model.
        let model_id = self.embedder.model_id();
        // A clock set before the Unix epoch falls back to 0. `as_micros` is a
        // u128, so clamp it rather than silently truncating with `as i64`.
        let now_micros = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|d| d.as_micros().min(i64::MAX as u128) as i64)
            .unwrap_or(0);
        let rid = uuid7().to_string();
        Ok(MemoryMutation::UpsertMemory {
            rid,
            text: req.text,
            memory_type: req.memory_type,
            importance: req.importance,
            valence: req.valence,
            half_life: req.half_life,
            namespace: req.namespace,
            certainty: req.certainty,
            domain: req.domain,
            source: req.source,
            emotional_state: req.emotional_state,
            embedding: Some(embedding),
            metadata: req.metadata,
            extracted_entities,
            created_at_unix_micros: Some(now_micros),
            embedding_model: Some(model_id),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;

    /// Deterministic embedder double. It derives a 4-element vector from the
    /// input bytes and counts how often it was invoked.
    struct FakeEmbedder {
        model: String,
        call_count: Mutex<usize>,
    }

    impl FakeEmbedder {
        fn new(model: &str) -> Arc<Self> {
            let embedder = FakeEmbedder {
                model: model.to_owned(),
                call_count: Mutex::new(0),
            };
            Arc::new(embedder)
        }

        /// Number of `embed` calls observed so far.
        fn calls(&self) -> usize {
            let guard = self.call_count.lock();
            *guard
        }
    }

    #[async_trait]
    impl Embedder for FakeEmbedder {
        async fn embed(&self, text: &str) -> Result<Vec<f32>, MaterializeError> {
            {
                let mut count = self.call_count.lock();
                *count += 1;
            }
            let bytes = text.as_bytes();
            let byte_sum = bytes.iter().map(|&b| u32::from(b)).sum::<u32>();
            let first = bytes.first().copied().unwrap_or(0);
            let last = bytes.last().copied().unwrap_or(0);
            Ok(vec![
                byte_sum as f32,
                bytes.len() as f32,
                first as f32,
                last as f32,
            ])
        }

        fn model_id(&self) -> String {
            self.model.clone()
        }
    }

    /// Extractor double. Every whitespace-separated word that starts with an
    /// uppercase character counts as an entity, with trailing punctuation removed.
    struct FakeExtractor;

    #[async_trait]
    impl EntityExtractor for FakeExtractor {
        async fn extract(&self, text: &str) -> Result<Vec<String>, MaterializeError> {
            let entities: Vec<String> = text
                .split_whitespace()
                .filter(|word| word.starts_with(char::is_uppercase))
                .map(|word| {
                    word.trim_end_matches(|c: char| !c.is_alphanumeric())
                        .to_owned()
                })
                .collect();
            Ok(entities)
        }
    }

    /// A materializer wired to the fakes, plus a handle to the embedder so
    /// tests can inspect its call count.
    fn build_materializer() -> (LocalMaterializer, Arc<FakeEmbedder>) {
        let embedder = FakeEmbedder::new("test-model.v1");
        let materializer = LocalMaterializer::new(embedder.clone(), Arc::new(FakeExtractor));
        (materializer, embedder)
    }

    /// Baseline request for `text`. Tests override individual fields with
    /// struct-update syntax.
    fn req(text: &str) -> RememberRequest {
        RememberRequest {
            text: text.to_owned(),
            memory_type: String::from("semantic"),
            importance: 0.7,
            valence: 0.0,
            half_life: 86400.0,
            namespace: String::from("test"),
            certainty: 1.0,
            domain: String::from("general"),
            source: String::from("test"),
            emotional_state: None,
            metadata: serde_json::json!({}),
            client_embedding: None,
        }
    }

    #[tokio::test]
    async fn materialize_populates_all_v1_1_fields() {
        let (materializer, _) = build_materializer();
        let mutation = materializer
            .materialize_remember(req("Alice met Bob in Paris"))
            .await
            .unwrap();
        if let MemoryMutation::UpsertMemory {
            rid,
            embedding,
            extracted_entities,
            created_at_unix_micros,
            embedding_model,
            ..
        } = mutation
        {
            assert!(!rid.is_empty(), "rid stamped");
            assert!(embedding.is_some(), "embedding stamped");
            assert_eq!(extracted_entities, vec!["Alice", "Bob", "Paris"], "NER ran");
            assert!(created_at_unix_micros.is_some(), "timestamp stamped");
            assert_eq!(embedding_model.as_deref(), Some("test-model.v1"));
        } else {
            panic!("expected UpsertMemory");
        }
    }

    #[tokio::test]
    async fn client_supplied_embedding_skips_embedder() {
        let (materializer, embedder) = build_materializer();
        let request = RememberRequest {
            client_embedding: Some(vec![9.0, 9.0, 9.0]),
            ..req("hello")
        };
        let mutation = materializer.materialize_remember(request).await.unwrap();
        assert_eq!(embedder.calls(), 0, "client embedding bypasses embedder");
        if let MemoryMutation::UpsertMemory { embedding, .. } = mutation {
            assert_eq!(embedding, Some(vec![9.0, 9.0, 9.0]));
        } else {
            panic!("expected UpsertMemory");
        }
    }

    #[tokio::test]
    async fn empty_text_is_invalid() {
        let (materializer, _) = build_materializer();
        let request = RememberRequest {
            client_embedding: Some(vec![]),
            ..req("")
        };
        let err = materializer.materialize_remember(request).await.unwrap_err();
        assert!(matches!(err, MaterializeError::InvalidRequest { .. }));
    }

    #[tokio::test]
    async fn rid_is_fresh_each_call() {
        fn rid_of(mutation: &MemoryMutation) -> &str {
            match mutation {
                MemoryMutation::UpsertMemory { rid, .. } => rid.as_str(),
                _ => panic!("expected UpsertMemory"),
            }
        }

        let (materializer, _) = build_materializer();
        let first = materializer.materialize_remember(req("hello")).await.unwrap();
        let second = materializer.materialize_remember(req("hello")).await.unwrap();
        assert_ne!(rid_of(&first), rid_of(&second));
    }

    #[tokio::test]
    async fn embedder_failure_surfaces_as_embed_failure() {
        /// Embedder double that always fails.
        struct AngryEmbedder;

        #[async_trait]
        impl Embedder for AngryEmbedder {
            async fn embed(&self, _: &str) -> Result<Vec<f32>, MaterializeError> {
                Err(MaterializeError::EmbedFailure {
                    message: "ONNX refused".into(),
                })
            }

            fn model_id(&self) -> String {
                "angry.v1".into()
            }
        }

        let materializer =
            LocalMaterializer::new(Arc::new(AngryEmbedder), Arc::new(FakeExtractor));
        let outcome = materializer.materialize_remember(req("hello")).await;
        assert!(matches!(
            outcome.unwrap_err(),
            MaterializeError::EmbedFailure { .. }
        ));
    }

    #[tokio::test]
    async fn ner_failure_surfaces_as_ner_failure() {
        /// Extractor double that always fails.
        struct AngryExtractor;

        #[async_trait]
        impl EntityExtractor for AngryExtractor {
            async fn extract(&self, _: &str) -> Result<Vec<String>, MaterializeError> {
                Err(MaterializeError::NerFailure {
                    message: "NER pipeline crashed".into(),
                })
            }
        }

        let materializer =
            LocalMaterializer::new(FakeEmbedder::new("test.v1"), Arc::new(AngryExtractor));
        let outcome = materializer.materialize_remember(req("hello")).await;
        assert!(matches!(
            outcome.unwrap_err(),
            MaterializeError::NerFailure { .. }
        ));
    }

    // Never called. It exists only so the build fails if `Materializer` stops
    // being usable as a trait object; hence the dead_code allowance.
    #[allow(dead_code)]
    fn _dyn_materializer_compile_check(_m: Arc<dyn Materializer>) {}
}