use std::collections::HashSet;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
/// A candidate memory record that the helpers in this file score and preview.
///
/// Both optional fields carry `#[serde(default)]`, so they deserialise to
/// `None` when absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MemoryHandle {
/// Identifier echoed back as `"id"` in the Jaccard and cosine payloads.
pub id: String,
/// Text body; tokenised for Jaccard overlap and truncated for previews.
pub body: String,
/// Optional embedding compared against the content embedding by the cosine
/// pre-filter; a candidate without one gets a null score there.
#[serde(default)]
pub embedding: Option<Vec<f32>>,
/// Optional namespace of the memory. Not read by any helper visible in this
/// file — NOTE(review): presumably consumed by callers; confirm.
#[serde(default)]
pub namespace: Option<String>,
}
/// Selects which helper `run_helper` / `run_helper_with` dispatches to.
///
/// Serialised by serde in `snake_case` (e.g. `JaccardOverlap` becomes
/// `"jaccard_overlap"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HelperKind {
/// Token-set Jaccard overlap between the content and each candidate body.
JaccardOverlap,
/// Cosine similarity between the content embedding and candidate embeddings.
CosinePreFilter,
/// Keyword-based classification of the content into a fact kind.
FtsClassifier,
}
impl HelperKind {
#[must_use]
pub const fn as_str(self) -> &'static str {
match self {
Self::JaccardOverlap => "jaccard_overlap",
Self::CosinePreFilter => "cosine_pre_filter",
Self::FtsClassifier => "fts_classifier",
}
}
}
/// Per-call inputs for a helper.
///
/// Every field except `content` carries `#[serde(default)]`. In the `*_with`
/// helper variants an empty or `None` field falls back to the matching value
/// on the supplied `HelperContext`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HelperParams {
/// Text to analyse; when empty, `HelperContext::effective_content` uses the
/// context's content instead.
pub content: String,
/// Candidates to score; when empty, the `*_with` helpers use the context's
/// candidates instead.
#[serde(default)]
pub candidates: Vec<MemoryHandle>,
/// Minimum cosine score for a candidate to count as kept; the cosine
/// pre-filter uses `0.20` when this is `None`.
#[serde(default)]
pub cosine_threshold: Option<f32>,
/// Embedding of `content` used by the cosine pre-filter.
#[serde(default)]
pub content_embedding: Option<Vec<f32>>,
/// Namespace reported by the FTS classifier; falls back to the context's
/// namespace and then to `crate::DEFAULT_NAMESPACE`.
#[serde(default)]
pub namespace: Option<String>,
}
/// Borrowed fallback inputs consulted by the `*_with` helpers when the
/// corresponding `HelperParams` field is empty or `None`.
#[derive(Debug, Clone, Copy)]
pub struct HelperContext<'a> {
/// Fallback text used when `HelperParams::content` is empty.
pub content: &'a str,
/// Fallback candidates used when `HelperParams::candidates` is empty.
pub candidates: &'a [MemoryHandle],
/// Fallback content embedding for the cosine pre-filter.
pub content_embedding: Option<&'a [f32]>,
/// Fallback namespace for the FTS classifier.
pub namespace: Option<&'a str>,
}
impl<'a> HelperContext<'a> {
#[must_use]
pub fn new(
content: &'a str,
candidates: &'a [MemoryHandle],
content_embedding: Option<&'a [f32]>,
namespace: Option<&'a str>,
) -> Self {
Self {
content,
candidates,
content_embedding,
namespace,
}
}
#[must_use]
pub fn effective_content<'p>(&self, params: &'p HelperParams) -> &'p str
where
'a: 'p,
{
if params.content.is_empty() {
self.content
} else {
params.content.as_str()
}
}
}
/// Result returned by every helper in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HelperOutput {
/// Which helper produced this output.
pub kind: HelperKind,
/// One-line human-readable description of the result.
pub summary: String,
/// Helper-specific JSON details; each helper includes a `"helper"` key
/// naming itself alongside its own result fields.
pub payload: Value,
}
/// Runs the Jaccard-overlap helper using only the values in `params`.
///
/// Builds a `HelperContext` that borrows `params.content` and
/// `params.candidates` (with no embedding and no namespace) and delegates to
/// `jaccard_overlap_with`.
#[must_use]
pub fn jaccard_overlap(params: &HelperParams) -> HelperOutput {
    let ctx = HelperContext::new(&params.content, &params.candidates, None, None);
    jaccard_overlap_with(params, &ctx)
}
/// Scores each candidate body by Jaccard token overlap with the content.
///
/// The content comes from `ctx.effective_content(params)`; the candidates are
/// `params.candidates`, or `ctx.candidates` when the former is empty.
///
/// The returned payload lists at most the ten highest-scoring candidates
/// (id, overlap, 120-character preview), sorted by descending overlap. The
/// summary reports how many of *all* scored candidates reach an overlap of
/// 0.40, so the count is taken before the list is cut down to the top ten.
#[must_use]
pub fn jaccard_overlap_with(params: &HelperParams, ctx: &HelperContext<'_>) -> HelperOutput {
    /// Overlap at or above which a candidate counts towards the summary.
    const OVERLAP_THRESHOLD: f32 = 0.40;
    /// Maximum number of candidates reported in `top_candidates`.
    const MAX_REPORTED: usize = 10;

    let content = ctx.effective_content(params);
    let candidates: &[MemoryHandle] = if params.candidates.is_empty() {
        ctx.candidates
    } else {
        params.candidates.as_slice()
    };
    let content_tokens = tokenise(content);
    let mut scored: Vec<(&str, f32, &str)> = candidates
        .iter()
        .map(|c| {
            let candidate_tokens = tokenise(&c.body);
            let overlap = jaccard(&content_tokens, &candidate_tokens);
            (c.id.as_str(), overlap, c.body.as_str())
        })
        .collect();

    // Count over the full list: the summary's denominator is every candidate,
    // so counting after truncation would silently cap the numerator.
    let over_threshold = scored
        .iter()
        .filter(|(_, score, _)| *score >= OVERLAP_THRESHOLD)
        .count();

    // Stable descending sort keeps input order among equal scores.
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));
    scored.truncate(MAX_REPORTED);

    let summary = format!(
        "jaccard: {}/{} candidates over {:.2} overlap",
        over_threshold,
        candidates.len(),
        OVERLAP_THRESHOLD
    );
    let payload = json!({
        "helper": "jaccard_overlap",
        "candidates_scored": candidates.len(),
        "top_candidates": scored
            .iter()
            .map(|(id, score, body)| json!({
                "id": id,
                "overlap": score,
                "preview": preview(body, 120),
            }))
            .collect::<Vec<_>>(),
    });
    HelperOutput {
        kind: HelperKind::JaccardOverlap,
        summary,
        payload,
    }
}
/// Runs the cosine pre-filter using only the values in `params`.
///
/// Builds a `HelperContext` that borrows `params.content`,
/// `params.candidates` and `params.content_embedding` (with no namespace) and
/// delegates to `cosine_pre_filter_with`.
#[must_use]
pub fn cosine_pre_filter(params: &HelperParams) -> HelperOutput {
    let ctx = HelperContext::new(
        &params.content,
        &params.candidates,
        params.content_embedding.as_deref(),
        None,
    );
    cosine_pre_filter_with(params, &ctx)
}
/// Scores each candidate embedding by cosine similarity to the content
/// embedding and flags those at or above the threshold.
///
/// The threshold is `params.cosine_threshold`, or `0.20` when unset. The
/// content embedding and the candidates come from `params`, falling back to
/// `ctx` when `params` has no embedding or no candidates respectively.
///
/// A candidate receives a null score (and `above_threshold: false`) when
/// either the content embedding or its own embedding is missing. Every
/// candidate is listed in the payload regardless of whether it passed.
#[must_use]
pub fn cosine_pre_filter_with(params: &HelperParams, ctx: &HelperContext<'_>) -> HelperOutput {
    let threshold = params.cosine_threshold.unwrap_or(0.20);
    let query: Option<&[f32]> = params.content_embedding.as_deref().or(ctx.content_embedding);
    let pool: &[MemoryHandle] = if params.candidates.is_empty() {
        ctx.candidates
    } else {
        params.candidates.as_slice()
    };

    // Tally the passing candidates while the rows are being built.
    let mut kept = 0_usize;
    let mut rows: Vec<Value> = Vec::with_capacity(pool.len());
    for handle in pool {
        let score = query
            .zip(handle.embedding.as_deref())
            .map(|(lhs, rhs)| cosine(lhs, rhs));
        let passes = score.is_some_and(|s| s >= threshold);
        if passes {
            kept += 1;
        }
        rows.push(json!({
            "id": handle.id,
            "score": score,
            "above_threshold": passes,
            "preview": preview(&handle.body, 120),
        }));
    }

    let total = rows.len();
    let summary = format!("cosine: {kept}/{total} candidates over {threshold:.2} threshold");
    let payload = json!({
        "helper": "cosine_pre_filter",
        "threshold": threshold,
        "candidates_scored": total,
        "candidates_kept": kept,
        "candidates": rows,
    });
    HelperOutput {
        kind: HelperKind::CosinePreFilter,
        summary,
        payload,
    }
}
/// Runs the FTS classifier using only the values in `params`.
///
/// Builds a `HelperContext` that borrows `params.content`,
/// `params.candidates` and `params.namespace` (with no embedding) and
/// delegates to `fts_classifier_with`.
#[must_use]
pub fn fts_classifier(params: &HelperParams) -> HelperOutput {
    let ctx = HelperContext::new(
        &params.content,
        &params.candidates,
        None,
        params.namespace.as_deref(),
    );
    fts_classifier_with(params, &ctx)
}
#[must_use]
pub fn fts_classifier_with(params: &HelperParams, ctx: &HelperContext<'_>) -> HelperOutput {
let content = ctx.effective_content(params);
let namespace: &str = params
.namespace
.as_deref()
.or(ctx.namespace)
.unwrap_or(crate::DEFAULT_NAMESPACE);
let body_lower = content.to_lowercase();
let kind = if body_lower.contains("step ")
|| body_lower.contains("first, ")
|| body_lower.contains("then ")
{
"procedural"
} else if body_lower.contains("yesterday")
|| body_lower.contains("today")
|| body_lower.contains("happened")
|| body_lower.contains("event")
{
"episodic"
} else {
"declarative"
};
let summary = format!("fts_classifier: kind={kind} (namespace={namespace})");
let payload = json!({
"helper": HelperKind::FtsClassifier.as_str(),
"fact_kind": kind,
"namespace": namespace,
"tokens": tokenise(content).len(),
});
HelperOutput {
kind: HelperKind::FtsClassifier,
summary,
payload,
}
}
#[must_use]
pub fn run_helper(kind: HelperKind, params: &HelperParams) -> HelperOutput {
match kind {
HelperKind::JaccardOverlap => jaccard_overlap(params),
HelperKind::CosinePreFilter => cosine_pre_filter(params),
HelperKind::FtsClassifier => fts_classifier(params),
}
}
#[must_use]
pub fn run_helper_with(
kind: HelperKind,
params: &HelperParams,
ctx: &HelperContext<'_>,
) -> HelperOutput {
match kind {
HelperKind::JaccardOverlap => jaccard_overlap_with(params, ctx),
HelperKind::CosinePreFilter => cosine_pre_filter_with(params, ctx),
HelperKind::FtsClassifier => fts_classifier_with(params, ctx),
}
}
/// Splits `body` on whitespace into a set of lower-cased tokens.
///
/// Non-alphanumeric characters are stripped from both ends of each word
/// (interior punctuation is kept), and words that become empty are dropped.
fn tokenise(body: &str) -> HashSet<String> {
    let mut tokens = HashSet::new();
    for word in body.split_whitespace() {
        let core = word.trim_matches(|ch: char| !ch.is_alphanumeric());
        if !core.is_empty() {
            tokens.insert(core.to_lowercase());
        }
    }
    tokens
}
/// Jaccard similarity of two token sets: shared tokens divided by distinct
/// tokens across both sets.
///
/// Returns `0.0` when both sets are empty.
fn jaccard(a: &HashSet<String>, b: &HashSet<String>) -> f32 {
    let shared = a.iter().filter(|token| b.contains(*token)).count();
    // Size of the union by inclusion–exclusion.
    let combined = a.len() + b.len() - shared;
    match combined {
        0 => 0.0,
        total => shared as f32 / total as f32,
    }
}
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` when either slice is empty, when the lengths differ, or when
/// either vector's squared norm is at or below `f32::EPSILON`.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Accumulate the dot product and both squared norms in one pass.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    if norm_a_sq <= f32::EPSILON || norm_b_sq <= f32::EPSILON {
        return 0.0;
    }
    dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())
}
/// Returns `body` shortened to at most `max` characters.
///
/// Text that already fits is returned unchanged; longer text is cut after
/// `max` characters (not bytes) and an ellipsis character is appended.
fn preview(body: &str, max: usize) -> String {
    // A character at index `max` exists only when the body is too long; its
    // byte offset is where the cut must fall.
    match body.char_indices().nth(max) {
        None => body.to_string(),
        Some((cut, _)) => format!("{}…", &body[..cut]),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate with no embedding and no namespace.
    fn mh(id: &str, body: &str) -> MemoryHandle {
        MemoryHandle {
            id: id.to_string(),
            body: body.to_string(),
            embedding: None,
            namespace: None,
        }
    }

    /// Builds a candidate carrying the given embedding and no namespace.
    fn mh_emb(id: &str, body: &str, embedding: Vec<f32>) -> MemoryHandle {
        MemoryHandle {
            id: id.to_string(),
            body: body.to_string(),
            embedding: Some(embedding),
            namespace: None,
        }
    }

    #[test]
    fn jaccard_overlap_returns_non_empty_for_overlapping_text() {
        let params = HelperParams {
            content: "the quick brown fox jumps over the lazy dog".to_string(),
            candidates: vec![
                mh("a", "a quick brown dog"),
                mh("b", "completely unrelated content here"),
            ],
            ..Default::default()
        };
        let out = jaccard_overlap(&params);
        assert_eq!(out.kind, HelperKind::JaccardOverlap);
        let top = out.payload["top_candidates"].as_array().unwrap();
        assert_eq!(top.len(), 2);
        assert_eq!(top[0]["id"].as_str(), Some("a"));
        let top_score = top[0]["overlap"].as_f64().unwrap();
        let bot_score = top[1]["overlap"].as_f64().unwrap();
        assert!(top_score > bot_score);
    }

    #[test]
    fn jaccard_overlap_handles_empty_candidates_cleanly() {
        let params = HelperParams {
            content: "hello world".to_string(),
            candidates: vec![],
            ..Default::default()
        };
        let out = jaccard_overlap(&params);
        assert_eq!(out.payload["candidates_scored"], 0);
        assert_eq!(out.payload["top_candidates"].as_array().unwrap().len(), 0);
    }

    #[test]
    fn cosine_pre_filter_drops_below_threshold() {
        let params = HelperParams {
            content: "x".to_string(),
            candidates: vec![
                mh_emb("near", "near body", vec![1.0, 0.0, 0.0]),
                mh_emb("far", "far body", vec![0.0, 1.0, 0.0]),
            ],
            content_embedding: Some(vec![1.0, 0.05, 0.0]),
            cosine_threshold: Some(0.50),
            ..Default::default()
        };
        let out = cosine_pre_filter(&params);
        let kept = out.payload["candidates_kept"].as_u64().unwrap();
        assert_eq!(kept, 1, "only the 'near' candidate should pass");
    }

    #[test]
    fn cosine_pre_filter_no_embedding_degrades_to_null_scores() {
        let params = HelperParams {
            content: "x".to_string(),
            candidates: vec![mh("a", "a")],
            content_embedding: None,
            ..Default::default()
        };
        let out = cosine_pre_filter(&params);
        let candidates = out.payload["candidates"].as_array().unwrap();
        assert!(candidates[0]["score"].is_null());
        assert_eq!(candidates[0]["above_threshold"], false);
    }

    #[test]
    fn fts_classifier_labels_procedural_text() {
        let params = HelperParams {
            content: "Step 1: open the door. Then walk through.".to_string(),
            ..Default::default()
        };
        let out = fts_classifier(&params);
        assert_eq!(out.payload["fact_kind"], "procedural");
    }

    #[test]
    fn fts_classifier_labels_episodic_text() {
        let params = HelperParams {
            content: "Yesterday I went to the store.".to_string(),
            ..Default::default()
        };
        let out = fts_classifier(&params);
        assert_eq!(out.payload["fact_kind"], "episodic");
    }

    #[test]
    fn fts_classifier_default_is_declarative() {
        let params = HelperParams {
            content: "The capital of France is Paris.".to_string(),
            ..Default::default()
        };
        let out = fts_classifier(&params);
        assert_eq!(out.payload["fact_kind"], "declarative");
    }

    #[test]
    fn run_helper_dispatches_correctly() {
        let params = HelperParams {
            content: "anything".to_string(),
            ..Default::default()
        };
        let out = run_helper(HelperKind::FtsClassifier, &params);
        assert_eq!(out.kind, HelperKind::FtsClassifier);
    }

    #[test]
    fn helper_kind_serialisation_is_snake_case() {
        assert_eq!(HelperKind::JaccardOverlap.as_str(), "jaccard_overlap");
        assert_eq!(HelperKind::CosinePreFilter.as_str(), "cosine_pre_filter");
        assert_eq!(HelperKind::FtsClassifier.as_str(), "fts_classifier");

        // The serde representation must agree with `as_str`, otherwise the
        // payload's "helper" string and a serialised `kind` would diverge.
        for kind in [
            HelperKind::JaccardOverlap,
            HelperKind::CosinePreFilter,
            HelperKind::FtsClassifier,
        ] {
            assert_eq!(
                serde_json::to_value(kind).unwrap(),
                serde_json::json!(kind.as_str())
            );
        }
    }
}