1use std::collections::HashSet;
21
22use serde::{Deserialize, Serialize};
23use serde_json::{Value, json};
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32pub struct MemoryHandle {
33 pub id: String,
36 pub body: String,
38 #[serde(default)]
41 pub embedding: Option<Vec<f32>>,
42 #[serde(default)]
45 pub namespace: Option<String>,
46}
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51pub enum HelperKind {
52 JaccardOverlap,
55 CosinePreFilter,
58 FtsClassifier,
62}
63
64impl HelperKind {
65 #[must_use]
68 pub const fn as_str(self) -> &'static str {
69 match self {
70 Self::JaccardOverlap => "jaccard_overlap",
71 Self::CosinePreFilter => "cosine_pre_filter",
72 Self::FtsClassifier => "fts_classifier",
73 }
74 }
75}
76
77#[derive(Debug, Clone, Default, Serialize, Deserialize)]
80pub struct HelperParams {
81 pub content: String,
87 #[serde(default)]
91 pub candidates: Vec<MemoryHandle>,
92 #[serde(default)]
95 pub cosine_threshold: Option<f32>,
96 #[serde(default)]
100 pub content_embedding: Option<Vec<f32>>,
101 #[serde(default)]
103 pub namespace: Option<String>,
104}
105
106#[derive(Debug, Clone, Copy)]
115pub struct HelperContext<'a> {
116 pub content: &'a str,
123 pub candidates: &'a [MemoryHandle],
125 pub content_embedding: Option<&'a [f32]>,
128 pub namespace: Option<&'a str>,
130}
131
132impl<'a> HelperContext<'a> {
133 #[must_use]
135 pub fn new(
136 content: &'a str,
137 candidates: &'a [MemoryHandle],
138 content_embedding: Option<&'a [f32]>,
139 namespace: Option<&'a str>,
140 ) -> Self {
141 Self {
142 content,
143 candidates,
144 content_embedding,
145 namespace,
146 }
147 }
148
149 #[must_use]
154 pub fn effective_content<'p>(&self, params: &'p HelperParams) -> &'p str
155 where
156 'a: 'p,
157 {
158 if params.content.is_empty() {
159 self.content
160 } else {
161 params.content.as_str()
162 }
163 }
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct HelperOutput {
170 pub kind: HelperKind,
172 pub summary: String,
175 pub payload: Value,
179}
180
181#[must_use]
193pub fn jaccard_overlap(params: &HelperParams) -> HelperOutput {
194 let ctx = HelperContext::new(¶ms.content, ¶ms.candidates, None, None);
195 jaccard_overlap_with(params, &ctx)
196}
197
198#[must_use]
203pub fn jaccard_overlap_with(params: &HelperParams, ctx: &HelperContext<'_>) -> HelperOutput {
204 let content = ctx.effective_content(params);
205 let candidates: &[MemoryHandle] = if params.candidates.is_empty() {
206 ctx.candidates
207 } else {
208 params.candidates.as_slice()
209 };
210
211 let content_tokens = tokenise(content);
212 let mut scored: Vec<(&str, f32, &str)> = candidates
213 .iter()
214 .map(|c| {
215 let candidate_tokens = tokenise(&c.body);
216 let overlap = jaccard(&content_tokens, &candidate_tokens);
217 (c.id.as_str(), overlap, c.body.as_str())
218 })
219 .collect();
220 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
221 scored.truncate(10);
222
223 let over_threshold: usize = scored.iter().filter(|(_, score, _)| *score >= 0.40).count();
224
225 let summary = format!(
226 "jaccard: {}/{} candidates over 0.40 overlap",
227 over_threshold,
228 candidates.len()
229 );
230
231 let payload = json!({
232 "helper": "jaccard_overlap",
233 "candidates_scored": candidates.len(),
234 "top_candidates": scored
235 .iter()
236 .map(|(id, score, body)| json!({
237 "id": id,
238 "overlap": score,
239 "preview": preview(body, 120),
240 }))
241 .collect::<Vec<_>>(),
242 });
243
244 HelperOutput {
245 kind: HelperKind::JaccardOverlap,
246 summary,
247 payload,
248 }
249}
250
251#[must_use]
258pub fn cosine_pre_filter(params: &HelperParams) -> HelperOutput {
259 let ctx = HelperContext::new(
260 ¶ms.content,
261 ¶ms.candidates,
262 params.content_embedding.as_deref(),
263 None,
264 );
265 cosine_pre_filter_with(params, &ctx)
266}
267
268#[must_use]
271pub fn cosine_pre_filter_with(params: &HelperParams, ctx: &HelperContext<'_>) -> HelperOutput {
272 let threshold = params.cosine_threshold.unwrap_or(0.20);
273 let content_emb: Option<&[f32]> = if params.content_embedding.is_some() {
274 params.content_embedding.as_deref()
275 } else {
276 ctx.content_embedding
277 };
278 let candidates: &[MemoryHandle] = if params.candidates.is_empty() {
279 ctx.candidates
280 } else {
281 params.candidates.as_slice()
282 };
283
284 let scored: Vec<Value> = candidates
285 .iter()
286 .map(|c| {
287 let score = match (content_emb, c.embedding.as_deref()) {
288 (Some(a), Some(b)) => Some(cosine(a, b)),
289 _ => None,
290 };
291 json!({
292 "id": c.id,
293 "score": score,
294 "above_threshold": score.is_some_and(|s| s >= threshold),
295 "preview": preview(&c.body, 120),
296 })
297 })
298 .collect();
299
300 let kept = scored
301 .iter()
302 .filter(|v| v["above_threshold"].as_bool().unwrap_or(false))
303 .count();
304 let total = scored.len();
305
306 let summary = format!("cosine: {kept}/{total} candidates over {threshold:.2} threshold");
307
308 let payload = json!({
309 "helper": "cosine_pre_filter",
310 "threshold": threshold,
311 "candidates_scored": total,
312 "candidates_kept": kept,
313 "candidates": scored,
314 });
315
316 HelperOutput {
317 kind: HelperKind::CosinePreFilter,
318 summary,
319 payload,
320 }
321}
322
323#[must_use]
334pub fn fts_classifier(params: &HelperParams) -> HelperOutput {
335 let ctx = HelperContext::new(
336 ¶ms.content,
337 ¶ms.candidates,
338 None,
339 params.namespace.as_deref(),
340 );
341 fts_classifier_with(params, &ctx)
342}
343
344#[must_use]
347pub fn fts_classifier_with(params: &HelperParams, ctx: &HelperContext<'_>) -> HelperOutput {
348 let content = ctx.effective_content(params);
349 let namespace: &str = params
350 .namespace
351 .as_deref()
352 .or(ctx.namespace)
353 .unwrap_or(crate::DEFAULT_NAMESPACE);
354
355 let body_lower = content.to_lowercase();
356 let kind = if body_lower.contains("step ")
357 || body_lower.contains("first, ")
358 || body_lower.contains("then ")
359 {
360 "procedural"
361 } else if body_lower.contains("yesterday")
362 || body_lower.contains("today")
363 || body_lower.contains("happened")
364 || body_lower.contains("event")
365 {
366 "episodic"
367 } else {
368 "declarative"
369 };
370
371 let summary = format!("fts_classifier: kind={kind} (namespace={namespace})");
372
373 let payload = json!({
374 "helper": HelperKind::FtsClassifier.as_str(),
375 "fact_kind": kind,
376 "namespace": namespace,
377 "tokens": tokenise(content).len(),
378 });
379
380 HelperOutput {
381 kind: HelperKind::FtsClassifier,
382 summary,
383 payload,
384 }
385}
386
387#[must_use]
393pub fn run_helper(kind: HelperKind, params: &HelperParams) -> HelperOutput {
394 match kind {
395 HelperKind::JaccardOverlap => jaccard_overlap(params),
396 HelperKind::CosinePreFilter => cosine_pre_filter(params),
397 HelperKind::FtsClassifier => fts_classifier(params),
398 }
399}
400
401#[must_use]
404pub fn run_helper_with(
405 kind: HelperKind,
406 params: &HelperParams,
407 ctx: &HelperContext<'_>,
408) -> HelperOutput {
409 match kind {
410 HelperKind::JaccardOverlap => jaccard_overlap_with(params, ctx),
411 HelperKind::CosinePreFilter => cosine_pre_filter_with(params, ctx),
412 HelperKind::FtsClassifier => fts_classifier_with(params, ctx),
413 }
414}
415
/// Splits `body` into a set of normalised tokens.
///
/// For each whitespace-separated word, leading and trailing non-alphanumeric
/// characters are stripped and the rest is lowercased. Words made only of
/// punctuation are dropped, and duplicates collapse because the result is a set.
fn tokenise(body: &str) -> HashSet<String> {
    let mut tokens = HashSet::new();
    for word in body.split_whitespace() {
        let core = word.trim_matches(|ch: char| !ch.is_alphanumeric());
        if !core.is_empty() {
            tokens.insert(core.to_lowercase());
        }
    }
    tokens
}
429
/// Jaccard similarity `|A ∩ B| / |A ∪ B|` of two token sets; 0.0 when both are empty.
fn jaccard(a: &HashSet<String>, b: &HashSet<String>) -> f32 {
    let shared = a.intersection(b).count();
    // Inclusion–exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|. Zero only when both are empty.
    let combined = a.len() + b.len() - shared;
    if combined == 0 {
        return 0.0;
    }
    shared as f32 / combined as f32
}
442
/// Cosine similarity between two equal-length vectors, clamped to `[-1.0, 1.0]`.
///
/// Returns 0.0 when either vector is empty, the lengths differ, or the product of
/// the norms is zero, because the angle is undefined in those cases. NaN inputs
/// propagate as NaN.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || b.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let (dot, na, nb) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    // Cosine is scale-invariant, so a small but non-zero norm is perfectly valid.
    // Only a truly zero denominator is degenerate; comparing the squared norm against
    // an epsilon would reject legitimate low-magnitude vectors.
    let denom = na.sqrt() * nb.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    // Rounding can push the ratio marginally past ±1 for (anti)parallel vectors.
    (dot / denom).clamp(-1.0, 1.0)
}
460
/// Returns at most `max` characters of `body`, appending `…` when anything was cut.
///
/// Counts `char`s (Unicode scalar values) rather than bytes, so multi-byte text is
/// never split mid-character.
fn preview(body: &str, max: usize) -> String {
    match body.char_indices().nth(max) {
        // A (max+1)-th character exists, so keep everything before its byte offset.
        Some((cut, _)) => format!("{}…", &body[..cut]),
        None => body.to_owned(),
    }
}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate with no embedding or namespace.
    fn mh(id: &str, body: &str) -> MemoryHandle {
        MemoryHandle {
            id: id.to_string(),
            body: body.to_string(),
            embedding: None,
            namespace: None,
        }
    }

    /// Builds a candidate carrying an embedding.
    fn mh_emb(id: &str, body: &str, embedding: Vec<f32>) -> MemoryHandle {
        MemoryHandle {
            id: id.to_string(),
            body: body.to_string(),
            embedding: Some(embedding),
            namespace: None,
        }
    }

    #[test]
    fn jaccard_overlap_returns_non_empty_for_overlapping_text() {
        let params = HelperParams {
            content: "the quick brown fox jumps over the lazy dog".to_string(),
            candidates: vec![
                mh("a", "a quick brown dog"),
                mh("b", "completely unrelated content here"),
            ],
            ..Default::default()
        };
        let out = jaccard_overlap(&params);
        assert_eq!(out.kind, HelperKind::JaccardOverlap);
        let top = out.payload["top_candidates"].as_array().unwrap();
        assert_eq!(top.len(), 2);
        assert_eq!(top[0]["id"].as_str(), Some("a"));
        // The overlapping candidate must outrank the unrelated one.
        let top_score = top[0]["overlap"].as_f64().unwrap();
        let bot_score = top[1]["overlap"].as_f64().unwrap();
        assert!(top_score > bot_score);
    }

    #[test]
    fn jaccard_overlap_handles_empty_candidates_cleanly() {
        let params = HelperParams {
            content: "hello world".to_string(),
            candidates: vec![],
            ..Default::default()
        };
        let out = jaccard_overlap(&params);
        assert_eq!(out.payload["candidates_scored"], 0);
        assert_eq!(out.payload["top_candidates"].as_array().unwrap().len(), 0);
    }

    #[test]
    fn cosine_pre_filter_drops_below_threshold() {
        let params = HelperParams {
            content: "x".to_string(),
            candidates: vec![
                mh_emb("near", "near body", vec![1.0, 0.0, 0.0]),
                mh_emb("far", "far body", vec![0.0, 1.0, 0.0]),
            ],
            content_embedding: Some(vec![1.0, 0.05, 0.0]),
            cosine_threshold: Some(0.50),
            ..Default::default()
        };
        let out = cosine_pre_filter(&params);
        let kept = out.payload["candidates_kept"].as_u64().unwrap();
        assert_eq!(kept, 1, "only the 'near' candidate should pass");
    }

    #[test]
    fn cosine_pre_filter_no_embedding_degrades_to_null_scores() {
        let params = HelperParams {
            content: "x".to_string(),
            candidates: vec![mh("a", "a")],
            content_embedding: None,
            ..Default::default()
        };
        let out = cosine_pre_filter(&params);
        let candidates = out.payload["candidates"].as_array().unwrap();
        assert!(candidates[0]["score"].is_null());
        assert_eq!(candidates[0]["above_threshold"], false);
    }

    #[test]
    fn fts_classifier_labels_procedural_text() {
        let params = HelperParams {
            content: "Step 1: open the door. Then walk through.".to_string(),
            ..Default::default()
        };
        let out = fts_classifier(&params);
        assert_eq!(out.payload["fact_kind"], "procedural");
    }

    #[test]
    fn fts_classifier_labels_episodic_text() {
        let params = HelperParams {
            content: "Yesterday I went to the store.".to_string(),
            ..Default::default()
        };
        let out = fts_classifier(&params);
        assert_eq!(out.payload["fact_kind"], "episodic");
    }

    #[test]
    fn fts_classifier_default_is_declarative() {
        let params = HelperParams {
            content: "The capital of France is Paris.".to_string(),
            ..Default::default()
        };
        let out = fts_classifier(&params);
        assert_eq!(out.payload["fact_kind"], "declarative");
    }

    #[test]
    fn run_helper_dispatches_correctly() {
        let params = HelperParams {
            content: "anything".to_string(),
            ..Default::default()
        };
        let out = run_helper(HelperKind::FtsClassifier, &params);
        assert_eq!(out.kind, HelperKind::FtsClassifier);
    }

    #[test]
    fn helper_kind_serialisation_is_snake_case() {
        let kinds = [
            (HelperKind::JaccardOverlap, "jaccard_overlap"),
            (HelperKind::CosinePreFilter, "cosine_pre_filter"),
            (HelperKind::FtsClassifier, "fts_classifier"),
        ];
        for &(kind, name) in &kinds {
            assert_eq!(kind.as_str(), name);
            // The serde wire form must agree with `as_str`, which payloads rely on.
            assert_eq!(serde_json::to_value(kind).unwrap(), serde_json::json!(name));
            let parsed: HelperKind = serde_json::from_value(serde_json::json!(name)).unwrap();
            assert_eq!(parsed, kind);
        }
    }
}