use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use serde_json::Value;
use super::memory_db::SqliteMemoryStore;
use super::memory_provider::MemoryProvider;
/// Maximum number of rows returned by a hybrid `search`.
const HYBRID_RESULT_LIMIT: usize = 8;
/// How many BM25 hits are requested from the store as fusion candidates.
const FUSION_BM25_POOL: usize = 50;
/// Smoothing constant `k` passed to `rrf_fuse`; 60 is the conventional
/// value from the original reciprocal-rank-fusion paper.
const RRF_K: f64 = 60.0;
/// Source of dense vector embeddings for text.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait Embedder: Send + Sync {
    /// Embeds each entry of `texts`.
    ///
    /// Callers in this module pair the output positionally with `texts`
    /// (via `zip`), so implementations should return one slot per input, in
    /// order, with `None` for any text that could not be embedded.
    fn embed(&self, texts: &[String]) -> Vec<Option<Vec<f32>>>;
}
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` instead of NaN when the lengths differ, either vector is
/// empty, or either vector has zero magnitude.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Accumulate the dot product and both squared norms in one pass.
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, sa, sb), (&x, &y)| {
            (d + x * y, sa + x * x, sb + y * y)
        });
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        0.0
    } else {
        dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())
    }
}
/// Reciprocal-rank fusion of several best-first rankings.
///
/// Each key scores `sum over rankings of 1 / (k + rank)`, where `rank` is the
/// key's 1-based position in that ranking (the top hit contributes
/// `1 / (k + 1)`). The result lists every key seen in any ranking, highest
/// score first; ties keep the order in which keys were first encountered
/// (the sort is stable).
///
/// Using 1-based ranks matches the standard RRF formula and keeps every
/// score finite for any `k >= 0`. With 0-based ranks and `k == 0`, each list's
/// top hit would score infinity and agreement between lists would be lost.
pub fn rrf_fuse(rankings: &[Vec<String>], k: f64) -> Vec<String> {
    let mut score: HashMap<&str, f64> = HashMap::new();
    // First-seen order, so the stable sort below breaks ties deterministically.
    let mut order: Vec<&str> = Vec::new();
    for ranking in rankings {
        for (pos, key) in ranking.iter().enumerate() {
            let e = score.entry(key.as_str()).or_insert_with(|| {
                order.push(key.as_str());
                0.0
            });
            let rank = (pos + 1) as f64;
            *e += 1.0 / (k + rank);
        }
    }
    order.sort_by(|a, b| score[b].total_cmp(&score[a]));
    order.into_iter().map(str::to_string).collect()
}
/// Embedding-cache key for an entry's text: its bytes hashed with
/// `crate::hash::fnv64`. Texts whose hashes collide would share a cached
/// vector.
fn content_key(content: &str) -> u64 {
    let bytes = content.as_bytes();
    crate::hash::fnv64(bytes)
}
/// Upper bound on cached entry embeddings; the cache is cleared wholesale
/// when it fills up.
const MAX_CACHE_ENTRIES: usize = 4 * 1024;
/// Memory provider backed by a [`SqliteMemoryStore`] whose `search` fuses the
/// store's BM25 ranking with dense-embedding similarity.
pub struct HybridMemoryProvider {
    /// Backing store; reads and writes are delegated to it.
    inner: Arc<SqliteMemoryStore>,
    /// Produces the query and entry embeddings for the dense ranking.
    embedder: Arc<dyn Embedder>,
    /// Entry embeddings keyed by `content_key` of the entry text.
    cache: Mutex<HashMap<u64, Vec<f32>>>,
}
impl HybridMemoryProvider {
    /// Wraps `inner`, using `embedder` for the dense half of hybrid search.
    /// The embedding cache starts empty.
    pub fn new(inner: Arc<SqliteMemoryStore>, embedder: Arc<dyn Embedder>) -> Self {
        Self {
            inner,
            embedder,
            cache: Mutex::new(HashMap::new()),
        }
    }

    /// Returns one embedding slot per entry of `contents`, in order.
    ///
    /// Cached vectors are reused. All misses go to the embedder in a single
    /// call made with the cache lock released. Vectors it returns are cached;
    /// a slot stays `None` when the embedder produced nothing for it. The
    /// cache never holds more than `MAX_CACHE_ENTRIES` vectors: it is cleared
    /// wholesale whenever an insert would exceed the cap.
    fn embed_cached(&self, contents: &[String]) -> Vec<Option<Vec<f32>>> {
        let mut out = vec![None; contents.len()];
        let mut miss_idx = Vec::new();
        {
            let cache = self.cache.lock().unwrap_or_else(|p| p.into_inner());
            for (i, c) in contents.iter().enumerate() {
                match cache.get(&content_key(c)) {
                    Some(v) => out[i] = Some(v.clone()),
                    None => miss_idx.push(i),
                }
            }
        }
        if miss_idx.is_empty() {
            return out;
        }
        let miss_texts: Vec<String> = miss_idx.iter().map(|&i| contents[i].clone()).collect();
        // The embedder may be slow (network-bound); don't hold the lock across it.
        let fresh = self.embedder.embed(&miss_texts);
        let mut cache = self.cache.lock().unwrap_or_else(|p| p.into_inner());
        for (slot, vec) in miss_idx.into_iter().zip(fresh) {
            let Some(v) = vec else { continue };
            // Check the cap per insert: checking once per batch let a large
            // batch (or a nearly-full cache) overshoot MAX_CACHE_ENTRIES.
            if cache.len() >= MAX_CACHE_ENTRIES {
                cache.clear();
            }
            cache.insert(content_key(&contents[slot]), v.clone());
            out[slot] = Some(v);
        }
        out
    }

    #[cfg(test)]
    fn cache_len(&self) -> usize {
        self.cache.lock().unwrap_or_else(|p| p.into_inner()).len()
    }

    /// Ranks `rows` by cosine similarity between the embedding of `query` and
    /// the embedding of each row's `"content"`, best first, returning the
    /// rows' `"id"`s.
    ///
    /// Returns an empty ranking when the query cannot be embedded. Rows
    /// without a string `"id"`, with missing or empty `"content"`, or whose
    /// content could not be embedded are left out.
    fn dense_ranking(&self, query: &str, rows: &[Value]) -> Vec<String> {
        let qvec = match self.embedder.embed(&[query.to_string()]).into_iter().next() {
            Some(Some(v)) => v,
            _ => return Vec::new(),
        };
        // Embed only rows that can appear in the ranking. Empty content
        // carries no signal, and OpenAI-style endpoints reject empty strings
        // (failing the whole batch).
        let (ids, contents): (Vec<&str>, Vec<String>) = rows
            .iter()
            .filter_map(|r| {
                let id = r["id"].as_str()?;
                let content = r["content"].as_str().filter(|c| !c.is_empty())?;
                Some((id, content.to_string()))
            })
            .unzip();
        let embs = self.embed_cached(&contents);
        let mut scored: Vec<(f32, &str)> = ids
            .into_iter()
            .zip(&embs)
            .filter_map(|(id, emb)| emb.as_ref().map(|v| (cosine(&qvec, v), id)))
            .collect();
        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
        scored.into_iter().map(|(_, id)| id.to_string()).collect()
    }
}
impl MemoryProvider for HybridMemoryProvider {
    fn name(&self) -> &str {
        "hybrid"
    }

    // Everything except `search` (and the cache reset in `on_session_switch`)
    // is forwarded unchanged to the SQLite store.

    fn format_for_system_prompt(&self) -> String {
        self.inner.format_for_system_prompt()
    }

    fn view(&self, target: &str) -> Value {
        MemoryProvider::view(&*self.inner, target)
    }

    fn add(&self, target: &str, content: &str, kind: Option<&str>) -> Result<Value, String> {
        MemoryProvider::add(&*self.inner, target, content, kind)
    }

    fn replace(
        &self,
        target: &str,
        old_text: &str,
        content: &str,
        kind: Option<&str>,
    ) -> Result<Value, String> {
        MemoryProvider::replace(&*self.inner, target, old_text, content, kind)
    }

    fn supersede(
        &self,
        target: &str,
        old_text: &str,
        content: &str,
        kind: Option<&str>,
        harsh: bool,
    ) -> Result<Value, String> {
        MemoryProvider::supersede(&*self.inner, target, old_text, content, kind, harsh)
    }

    fn remove(&self, target: &str, old_text: &str) -> Result<Value, String> {
        MemoryProvider::remove(&*self.inner, target, old_text)
    }

    fn restore(&self, target: &str, old_text: &str) -> Result<Value, String> {
        MemoryProvider::restore(&*self.inner, target, old_text)
    }

    fn expand(&self, old_text: &str) -> Result<Value, String> {
        MemoryProvider::expand(&*self.inner, old_text)
    }

    fn record_outcome(&self, target: &str, old_text: &str, success: bool) -> Result<Value, String> {
        MemoryProvider::record_outcome(&*self.inner, target, old_text, success)
    }

    /// Hybrid search: reciprocal-rank fusion (`RRF_K`) of the store's BM25
    /// ranking (up to `FUSION_BM25_POOL` hits) with a dense cosine ranking of
    /// the active rows. Uses the BM25 order alone when the dense ranking is
    /// empty. Returns the matching active rows, at most
    /// `HYBRID_RESULT_LIMIT`; ids absent from the active rows are dropped.
    fn search(&self, query: &str) -> Result<Value, String> {
        let lexical = self.inner.search_entries_limited(query, FUSION_BM25_POOL)?;
        let lexical_ids: Vec<String> = match lexical["results"].as_array() {
            Some(hits) => hits
                .iter()
                .filter_map(|hit| hit["id"].as_str())
                .map(String::from)
                .collect(),
            None => Vec::new(),
        };
        let rows = self.inner.active_search_rows()?;
        let semantic_ids = self.dense_ranking(query, &rows);

        let mut row_by_id: HashMap<&str, &Value> = HashMap::new();
        for row in rows.iter() {
            if let Some(id) = row["id"].as_str() {
                row_by_id.insert(id, row);
            }
        }

        let ordered = if semantic_ids.is_empty() {
            lexical_ids
        } else {
            rrf_fuse(&[lexical_ids, semantic_ids], RRF_K)
        };
        let results: Vec<Value> = ordered
            .iter()
            .filter_map(|id| row_by_id.get(id.as_str()).map(|&row| row.clone()))
            .take(HYBRID_RESULT_LIMIT)
            .collect();

        Ok(serde_json::json!({
            "success": true,
            "query": query,
            "count": results.len(),
            "results": results,
        }))
    }

    fn on_memory_write(&self, action: &str, target: &str, payload: &str) {
        self.inner.on_memory_write(action, target, payload);
    }

    fn on_session_end(&self, transcript: &str) {
        self.inner.on_session_end(transcript);
    }

    /// Drops the embedding cache when the session is `reset` (continuations
    /// keep it), then forwards the switch to the store.
    fn on_session_switch(&self, new_session_id: &str, parent_session_id: &str, reset: bool) {
        if reset {
            let mut cache = self.cache.lock().unwrap_or_else(|p| p.into_inner());
            cache.clear();
        }
        self.inner.on_session_switch(new_session_id, parent_session_id, reset);
    }

    fn on_pre_compress(&self, transcript: &str) -> String {
        self.inner.on_pre_compress(transcript)
    }
}
/// Embedding model name used when none is configured.
pub const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";
/// Per-request HTTP timeout for the embeddings endpoint, in seconds.
const EMBED_TIMEOUT_SECS: u64 = 10;
/// Starts an HTTP embedder that calls `url` with `model` (and `api_key` as a
/// bearer token, if given).
///
/// Returns `None` after logging a warning when the embedder cannot be set
/// up; hybrid search then stays BM25-only.
pub fn api_embedder(
    url: String,
    model: String,
    api_key: Option<String>,
) -> Option<Arc<dyn Embedder>> {
    let embedder = ApiEmbedder::new(url, model, api_key)
        .inspect_err(|err| {
            tracing::warn!(target: "dirge::memory_hybrid", error = %err, "embedder unavailable — staying BM25-only");
        })
        .ok()?;
    Some(Arc::new(embedder))
}
/// Converts an embeddings response body into `n` slots, one per input.
///
/// Each item of `body["data"]` is placed at its `"index"`, falling back to its
/// position in the array when `"index"` is absent. A slot stays `None` when
/// no item targets it, when the index is out of range, or when the item's
/// `"embedding"` is missing, empty, or contains any non-numeric element.
/// A partially parsed vector would have the wrong dimension and corrupt
/// cosine scores, so such an embedding is rejected outright.
fn parse_embeddings(body: &Value, n: usize) -> Vec<Option<Vec<f32>>> {
    let mut out = vec![None; n];
    let Some(data) = body["data"].as_array() else {
        return out;
    };
    for (pos, item) in data.iter().enumerate() {
        let idx = match item["index"].as_u64() {
            // An index that doesn't fit in usize is out of range by definition.
            Some(i) => usize::try_from(i).unwrap_or(usize::MAX),
            None => pos,
        };
        // `collect` into `Option` yields `None` if any element isn't a number.
        let emb: Option<Vec<f32>> = item["embedding"]
            .as_array()
            .and_then(|a| a.iter().map(|x| x.as_f64().map(|f| f as f32)).collect());
        if idx < n
            && let Some(e) = emb
            && !e.is_empty()
        {
            out[idx] = Some(e);
        }
    }
    out
}
/// POSTs `texts` to the embeddings endpoint at `url` and parses the reply.
///
/// Always returns `texts.len()` slots. On a transport error, a non-success
/// HTTP status, or an unparsable body, a warning is logged and every slot is
/// `None`.
async fn fetch_embeddings(
    client: &reqwest::Client,
    url: &str,
    model: &str,
    api_key: &Option<String>,
    texts: &[String],
) -> Vec<Option<Vec<f32>>> {
    // Cap on how much of an error response body is copied into the log.
    const MAX_ERROR_BODY_CHARS: usize = 512;

    let mut req = client
        .post(url)
        .json(&serde_json::json!({ "model": model, "input": texts }));
    if let Some(k) = api_key {
        req = req.bearer_auth(k);
    }
    let resp = match req.send().await {
        Ok(resp) => resp,
        Err(e) => {
            tracing::warn!(target: "dirge::memory_hybrid", error = %e, "embeddings request failed");
            return vec![None; texts.len()];
        }
    };
    let status = resp.status();
    if !status.is_success() {
        // Error bodies carry no `data`, so without this check a rejected
        // request (bad key, too many inputs, ...) would silently yield
        // all-`None` with nothing logged.
        let detail: String = resp
            .text()
            .await
            .unwrap_or_default()
            .chars()
            .take(MAX_ERROR_BODY_CHARS)
            .collect();
        tracing::warn!(target: "dirge::memory_hybrid", status = %status, body = %detail, "embeddings request rejected");
        return vec![None; texts.len()];
    }
    match resp.json::<Value>().await {
        Ok(body) => parse_embeddings(&body, texts.len()),
        Err(e) => {
            tracing::warn!(target: "dirge::memory_hybrid", error = %e, "embeddings response parse failed");
            vec![None; texts.len()]
        }
    }
}
/// HTTP-backed [`Embedder`]: a handle to the `dirge-embedder` worker thread.
struct ApiEmbedder {
    /// Queue of jobs served by the worker; dropping it lets the worker exit.
    tx: std::sync::mpsc::Sender<EmbedJob>,
}
/// One embedding request handed to the worker thread.
struct EmbedJob {
    /// Texts to embed, in order.
    texts: Vec<String>,
    /// Channel on which the worker sends one slot per text.
    reply: std::sync::mpsc::Sender<Vec<Option<Vec<f32>>>>,
}
impl ApiEmbedder {
    /// Builds the HTTP client and spawns the `dirge-embedder` worker thread.
    ///
    /// The worker builds a current-thread tokio runtime, then serves queued
    /// [`EmbedJob`]s one at a time and replies on each job's channel. It exits
    /// when the job channel closes.
    ///
    /// # Errors
    /// Fails if the HTTP client cannot be built or the thread cannot be
    /// spawned. A runtime build failure inside the worker is only logged: the
    /// worker exits, and later `embed` calls see a closed channel.
    fn new(url: String, model: String, api_key: Option<String>) -> std::io::Result<Self> {
        let (tx, rx) = std::sync::mpsc::channel::<EmbedJob>();
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(EMBED_TIMEOUT_SECS))
            .build()
            .map_err(std::io::Error::other)?;
        let worker = move || {
            let runtime = match tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
            {
                Ok(runtime) => runtime,
                Err(e) => {
                    tracing::error!(target: "dirge::memory_hybrid", error = %e, "embedder runtime build failed");
                    return;
                }
            };
            // Ends once every sender (i.e. the ApiEmbedder) has been dropped.
            for job in rx.iter() {
                let embeddings = runtime.block_on(fetch_embeddings(
                    &client,
                    &url,
                    &model,
                    &api_key,
                    &job.texts,
                ));
                // The requester may already have timed out and gone away.
                let _ = job.reply.send(embeddings);
            }
        };
        std::thread::Builder::new()
            .name(String::from("dirge-embedder"))
            .spawn(worker)?;
        Ok(Self { tx })
    }
}
impl Embedder for ApiEmbedder {
    /// Embeds `texts` on the worker thread, returning one slot per input in
    /// input order.
    ///
    /// Non-empty texts are sent in batches of at most `EMBED_BATCH_SIZE`,
    /// one [`EmbedJob`] per batch, each awaited for up to
    /// `EMBED_TIMEOUT_SECS + 5` seconds.
    ///
    /// A slot is `None` when:
    /// - its text is empty (such texts are never sent);
    /// - the endpoint returned no vector for it;
    /// - its batch was not answered.
    ///
    /// Once a batch is lost (worker gone, wait timed out, or every slot came
    /// back `None`), the remaining batches are skipped. A failing endpoint
    /// therefore costs at most one failed wait per call.
    fn embed(&self, texts: &[String]) -> Vec<Option<Vec<f32>>> {
        // OpenAI-style endpoints reject requests with more than 2048 inputs
        // and cap total tokens per request. Smaller batches also finish
        // comfortably inside the per-request HTTP timeout.
        const EMBED_BATCH_SIZE: usize = 256;

        let mut out: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        // Empty strings are rejected by such endpoints and would fail the
        // whole request, so they are never sent.
        let sendable: Vec<usize> = texts
            .iter()
            .enumerate()
            .filter(|(_, text)| !text.is_empty())
            .map(|(i, _)| i)
            .collect();
        let wait = std::time::Duration::from_secs(EMBED_TIMEOUT_SECS + 5);

        for batch in sendable.chunks(EMBED_BATCH_SIZE) {
            let (reply, rx) = std::sync::mpsc::channel();
            let job = EmbedJob {
                texts: batch.iter().map(|&i| texts[i].clone()).collect(),
                reply,
            };
            if self.tx.send(job).is_err() {
                break; // worker thread has exited
            }
            let Ok(vecs) = rx.recv_timeout(wait) else {
                break; // endpoint too slow; don't stall once per batch
            };
            // Request-level failures (transport, HTTP status, parse) come
            // back from the worker as all `None`.
            if vecs.iter().all(Option::is_none) {
                break;
            }
            for (&slot, vec) in batch.iter().zip(vecs) {
                out[slot] = vec;
            }
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extras::dirge_paths::ProjectPaths;
    use crate::extras::memory_db::MemoryKind;

    #[test]
    fn parse_embeddings_maps_by_index_and_tolerates_gaps() {
        let body = serde_json::json!({
            "data": [
                {"index": 1, "embedding": [0.1, 0.2]},
                {"index": 0, "embedding": [0.3, 0.4]},
            ]
        });
        let out = parse_embeddings(&body, 3);
        assert_eq!(out[0], Some(vec![0.3, 0.4]), "index 0 placed correctly");
        assert_eq!(out[1], Some(vec![0.1, 0.2]), "out-of-order index respected");
        assert_eq!(out[2], None, "missing index stays None");
        assert_eq!(
            parse_embeddings(&serde_json::json!({}), 2),
            vec![None, None]
        );
    }

    #[test]
    fn cosine_basics() {
        assert!((cosine(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!(cosine(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
        assert!((cosine(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
        assert_eq!(cosine(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn rrf_rewards_agreement_and_merges_disjoint() {
        let l1 = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        let l2 = vec!["B".to_string(), "D".to_string(), "A".to_string()];
        let fused = rrf_fuse(&[l1, l2], RRF_K);
        assert_eq!(fused[0], "B", "agreed-upon key ranks first: {fused:?}");
        for k in ["A", "B", "C", "D"] {
            assert!(fused.contains(&k.to_string()), "{k} missing: {fused:?}");
        }
    }

    /// Toy embedder with one dimension per hand-picked "concept", so
    /// paraphrases ("build" / "compile") share an axis without sharing words.
    struct ConceptEmbedder;

    impl ConceptEmbedder {
        fn concept(word: &str) -> Option<usize> {
            let w = word.to_lowercase();
            let dim = match w.as_str() {
                "build" | "compile" | "compiling" => 0,
                "project" | "binary" | "executable" => 1,
                "test" | "tests" | "testing" | "suite" => 2,
                "format" | "formatting" | "tidy" | "whitespace" | "indentation" => 3,
                "memory" | "recall" | "recollections" | "remember" => 4,
                "store" | "stored" | "persist" | "persists" | "saved" | "sqlite" => 5,
                _ => return None,
            };
            Some(dim)
        }
    }

    impl Embedder for ConceptEmbedder {
        fn embed(&self, texts: &[String]) -> Vec<Option<Vec<f32>>> {
            texts
                .iter()
                .map(|t| {
                    let mut v = vec![0.0f32; 6];
                    let mut any = false;
                    for word in t.split(|c: char| !c.is_alphanumeric()) {
                        if let Some(d) = Self::concept(word) {
                            v[d] += 1.0;
                            any = true;
                        }
                    }
                    any.then_some(v)
                })
                .collect()
        }
    }

    /// Removes the temporary project directory on drop, so it is cleaned up
    /// even when an assertion panics mid-test.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a fresh store in a unique temp directory marked as a git
    /// project. Keep the returned guard alive for the whole test.
    fn temp_store() -> (Arc<SqliteMemoryStore>, TempDirGuard) {
        let dir = std::env::temp_dir().join(format!(
            "dirge-hybrid-test-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));
        let _ = std::fs::remove_dir_all(&dir);
        // Guard first, so the directory is removed even if setup below panics.
        let guard = TempDirGuard(dir);
        std::fs::create_dir_all(guard.0.join(".git")).unwrap();
        let store = SqliteMemoryStore::load(&ProjectPaths::new(&guard.0)).unwrap();
        (Arc::new(store), guard)
    }

    /// The `"content"` strings of a search response's results, in order.
    fn contents(resp: &Value) -> Vec<String> {
        resp["results"]
            .as_array()
            .unwrap()
            .iter()
            .map(|r| r["content"].as_str().unwrap().to_string())
            .collect()
    }

    #[test]
    fn hybrid_recovers_a_paraphrase_bm25_misses() {
        let (store, _dir) = temp_store();
        store
            .add_entry("memory", "build the project", Some(MemoryKind::Procedural))
            .unwrap();
        store
            .add_entry("memory", "run the test suite", Some(MemoryKind::Procedural))
            .unwrap();
        let bm25 = store.search_entries("compile the binary").unwrap();
        assert!(
            !contents(&bm25).iter().any(|c| c == "build the project"),
            "precondition: BM25 misses the paraphrase",
        );
        let hybrid = HybridMemoryProvider::new(store.clone(), Arc::new(ConceptEmbedder));
        let resp = hybrid.search("compile the binary").unwrap();
        assert!(
            contents(&resp).iter().any(|c| c == "build the project"),
            "hybrid must recover the paraphrase: {:?}",
            contents(&resp),
        );
    }

    #[test]
    fn hybrid_preserves_lexical_hits() {
        let (store, _dir) = temp_store();
        store
            .add_entry("memory", "build the project", Some(MemoryKind::Procedural))
            .unwrap();
        store
            .add_entry("memory", "run the test suite", Some(MemoryKind::Procedural))
            .unwrap();
        let hybrid = HybridMemoryProvider::new(store.clone(), Arc::new(ConceptEmbedder));
        let resp = hybrid.search("test suite").unwrap();
        assert!(
            contents(&resp).iter().any(|c| c == "run the test suite"),
            "lexical hit preserved under fusion: {:?}",
            contents(&resp),
        );
    }

    #[test]
    fn falls_back_to_bm25_without_embedder_signal() {
        struct NullEmbedder;
        impl Embedder for NullEmbedder {
            fn embed(&self, texts: &[String]) -> Vec<Option<Vec<f32>>> {
                vec![None; texts.len()]
            }
        }
        let (store, _dir) = temp_store();
        store
            .add_entry("memory", "build the project", Some(MemoryKind::Procedural))
            .unwrap();
        let hybrid = HybridMemoryProvider::new(store.clone(), Arc::new(NullEmbedder));
        let bm25 = store.search_entries("build the project").unwrap();
        let hybrid_resp = hybrid.search("build the project").unwrap();
        assert_eq!(
            contents(&bm25),
            contents(&hybrid_resp),
            "null embedder → BM25 verbatim"
        );
    }

    #[test]
    fn delegates_non_search_to_inner() {
        let (store, _dir) = temp_store();
        let hybrid = HybridMemoryProvider::new(store.clone(), Arc::new(ConceptEmbedder));
        assert_eq!(hybrid.name(), "hybrid");
        hybrid.add("memory", "a delegated fact", None).unwrap();
        let view = MemoryProvider::view(&hybrid, "memory");
        assert!(
            view["entries"]
                .as_array()
                .unwrap()
                .iter()
                .any(|e| e.as_str().unwrap().contains("delegated fact")),
            "add routed to inner and is visible via view",
        );
    }

    /// Embeds every text as the same one-dimensional vector.
    struct OnesEmbedder;

    impl Embedder for OnesEmbedder {
        fn embed(&self, texts: &[String]) -> Vec<Option<Vec<f32>>> {
            texts.iter().map(|_| Some(vec![1.0f32])).collect()
        }
    }

    #[test]
    fn embedding_cache_is_bounded() {
        let (store, _dir) = temp_store();
        let hybrid = HybridMemoryProvider::new(store, Arc::new(OnesEmbedder));
        let full: Vec<String> = (0..MAX_CACHE_ENTRIES).map(|i| format!("c{i}")).collect();
        hybrid.embed_cached(&full);
        assert_eq!(
            hybrid.cache_len(),
            MAX_CACHE_ENTRIES,
            "cache filled to the cap"
        );
        let more: Vec<String> = (0..10).map(|i| format!("d{i}")).collect();
        hybrid.embed_cached(&more);
        assert_eq!(
            hybrid.cache_len(),
            10,
            "cap cleared the cache before the new batch"
        );
    }

    #[test]
    fn reset_clears_cache_continuation_keeps_it() {
        let (store, _dir) = temp_store();
        let hybrid = HybridMemoryProvider::new(store, Arc::new(OnesEmbedder));
        hybrid.embed_cached(&["a".to_string(), "b".to_string()]);
        assert_eq!(hybrid.cache_len(), 2);
        hybrid.on_session_switch("s2", "s1", false);
        assert_eq!(hybrid.cache_len(), 2, "continuation keeps the cache");
        hybrid.on_session_switch("s3", "", true);
        assert_eq!(hybrid.cache_len(), 0, "reset clears the cache");
    }
}