use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use hnsw_rs::prelude::{DistL2, Hnsw};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::CorpFinanceError;
use crate::memory::types::RunSummary;
use crate::CorpFinanceResult;
/// Default `m` handed to `Hnsw::new` by [`HnswMemoryIndex::new`].
pub const DEFAULT_M: usize = 16;

/// Default `ef_construction` handed to `Hnsw::new` by [`HnswMemoryIndex::new`].
pub const DEFAULT_EF_CONSTRUCTION: usize = 200;

/// Default `max_elements` handed to `Hnsw::new` by [`HnswMemoryIndex::new`];
/// `load_from` uses the larger of this and the stored summary count.
pub const DEFAULT_MAX_ELEMENTS: usize = 100_000;

/// Layer limit passed as the `max_layer` argument of every `Hnsw::new` call.
pub const DEFAULT_MAX_LAYER: usize = 16;

/// Envelope format version written by `save_to`; `load_from` rejects any other.
const ENVELOPE_VERSION: u32 = 1;
/// Graph construction settings persisted next to the summaries so that a
/// saved index can be rebuilt with the same configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct HnswParams {
    /// `m` value the graph was built with.
    m: usize,
    /// `ef_construction` value the graph was built with.
    ef_construction: usize,
    /// Length every stored embedding must have.
    embedding_dim: usize,
}
/// Top-level document that `save_to` serializes to JSON and gzip-compresses.
#[derive(Debug, Deserialize, Serialize)]
struct HnswEnvelope {
    /// Format version; loading requires it to equal `ENVELOPE_VERSION`.
    version: u32,
    /// Settings used to rebuild the graph on load.
    params: HnswParams,
    /// Stored summaries, in the order they were ingested.
    summaries: Vec<RunSummary>,
}
/// Approximate nearest-neighbour index over [`RunSummary`] embeddings, backed
/// by an `hnsw_rs` graph using `DistL2`.
pub struct HnswMemoryIndex {
    /// Underlying graph; each point is tagged with its position in `id_to_uuid`.
    hnsw: Hnsw<'static, f32, DistL2>,
    /// Graph data id (assigned sequentially at ingestion) -> run id.
    id_to_uuid: Vec<Uuid>,
    /// Full summaries keyed by run id.
    summaries: HashMap<Uuid, RunSummary>,
    /// Length every ingested or queried embedding must have.
    embedding_dim: usize,
    /// Construction `m`, retained so it can be persisted.
    m: usize,
    /// Construction `ef_construction`, retained for persistence and used as the
    /// minimum search width in `query`.
    ef_construction: usize,
}
impl HnswMemoryIndex {
    /// Candidate over-fetch factor used by [`Self::query`] so that post-hoc
    /// filtering still has something to choose from.
    const QUERY_OVERFETCH: usize = 4;

    /// Creates an empty index for `embedding_dim`-dimensional embeddings using
    /// [`DEFAULT_M`], [`DEFAULT_EF_CONSTRUCTION`] and [`DEFAULT_MAX_ELEMENTS`].
    pub fn new(embedding_dim: usize) -> Self {
        Self::with_params(
            embedding_dim,
            DEFAULT_M,
            DEFAULT_EF_CONSTRUCTION,
            DEFAULT_MAX_ELEMENTS,
        )
    }

    /// Creates an empty index with explicit graph parameters.
    ///
    /// `m`, `max_elements` and `ef_construction` are forwarded to `Hnsw::new`
    /// together with [`DEFAULT_MAX_LAYER`] and `DistL2`. `m`, `ef_construction`
    /// and `embedding_dim` are also recorded so [`Self::save_to`] can persist
    /// them; `max_elements` is not persisted.
    pub fn with_params(
        embedding_dim: usize,
        m: usize,
        ef_construction: usize,
        max_elements: usize,
    ) -> Self {
        let hnsw = Hnsw::<f32, DistL2>::new(
            m,
            max_elements,
            DEFAULT_MAX_LAYER,
            ef_construction,
            DistL2 {},
        );
        Self {
            hnsw,
            id_to_uuid: Vec::new(),
            summaries: HashMap::new(),
            embedding_dim,
            m,
            ef_construction,
        }
    }

    /// Number of summaries currently stored.
    pub fn len(&self) -> usize {
        self.summaries.len()
    }

    /// Returns `true` when nothing has been ingested.
    pub fn is_empty(&self) -> bool {
        self.summaries.is_empty()
    }

    /// Length every ingested and queried embedding must have.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    /// Iterates over all stored summaries in unspecified (hash-map) order.
    pub fn summaries_iter(&self) -> impl Iterator<Item = &RunSummary> + '_ {
        self.summaries.values()
    }

    /// Number of stored summaries; same value as [`Self::len`].
    pub fn summaries_count(&self) -> usize {
        self.summaries.len()
    }

    /// Adds `summary` to the index.
    ///
    /// The embedding is inserted into the graph under the next sequential data
    /// id, and a clone of the summary is stored keyed by its `run_id`.
    ///
    /// # Errors
    ///
    /// Returns [`CorpFinanceError::InvalidInput`] if the embedding length differs
    /// from [`Self::embedding_dim`], or if a summary with the same `run_id` was
    /// already ingested. The index is left unchanged in both cases.
    pub fn ingest(&mut self, summary: &RunSummary) -> CorpFinanceResult<()> {
        if summary.embedding.len() != self.embedding_dim {
            return Err(CorpFinanceError::InvalidInput {
                field: "embedding".into(),
                reason: format!(
                    "expected dim {}, got {}",
                    self.embedding_dim,
                    summary.embedding.len()
                ),
            });
        }
        if self.summaries.contains_key(&summary.run_id) {
            return Err(CorpFinanceError::InvalidInput {
                field: "run_id".into(),
                reason: format!("duplicate run_id {}", summary.run_id),
            });
        }
        // The graph point's data id doubles as the index into `id_to_uuid`.
        let data_id = self.id_to_uuid.len();
        self.hnsw.insert((&summary.embedding, data_id));
        self.id_to_uuid.push(summary.run_id);
        self.summaries.insert(summary.run_id, summary.clone());
        Ok(())
    }

    /// Returns up to `limit` stored summaries close to `embedding` that satisfy
    /// `filter`, each paired with the `DistL2` distance reported by the graph
    /// search. Results keep the order the search returned them in (presumably
    /// nearest first — confirm against `hnsw_rs`).
    ///
    /// The graph is asked for `4 * limit` candidates (capped at the number of
    /// indexed points) and `filter` is applied afterwards. As a result, a
    /// selective filter may yield fewer than `limit` results even when more
    /// matching summaries exist. The search is approximate.
    ///
    /// Returns an empty vector when `limit` is 0, when `embedding` has the wrong
    /// length, or when the index is empty.
    pub fn query<F>(&self, embedding: &[f32], limit: usize, filter: F) -> Vec<(RunSummary, f32)>
    where
        F: Fn(&RunSummary) -> bool,
    {
        // `limit == 0` must short-circuit explicitly; a search for at least one
        // candidate would otherwise still produce a hit.
        if limit == 0 || embedding.len() != self.embedding_dim || self.summaries.is_empty() {
            return Vec::new();
        }
        // Asking for more candidates than the graph holds cannot return more
        // points. An unbounded `limit` would only inflate `knbn`/`ef_search`
        // and the allocations sized from them.
        let indexed = self.id_to_uuid.len();
        let knbn = limit
            .saturating_mul(Self::QUERY_OVERFETCH)
            .min(indexed)
            .max(1);
        let ef_search = self.ef_construction.max(knbn);
        let neighbours = self.hnsw.search(embedding, knbn, ef_search);

        // Lazy chain: `filter` runs only until `limit` matches are found, and
        // only matching summaries are cloned.
        neighbours
            .iter()
            .filter_map(|n| {
                let uuid = self.id_to_uuid.get(n.d_id)?;
                let summary = self.summaries.get(uuid)?;
                Some((summary, n.distance))
            })
            .filter(|&(summary, _)| filter(summary))
            .take(limit)
            .map(|(summary, distance)| (summary.clone(), distance))
            .collect()
    }

    /// Writes the index to `path` as gzip-compressed JSON, overwriting any
    /// existing file.
    ///
    /// Only the envelope version, the construction parameters and the summaries
    /// (in ingestion order) are written. The graph itself is not saved;
    /// [`Self::load_from`] rebuilds it by re-ingesting.
    ///
    /// # Errors
    ///
    /// JSON serialization failures are propagated via `?`. I/O failures
    /// (create, write, gzip finish, final flush) are mapped through `io_to_cf`.
    pub fn save_to(&self, path: &Path) -> CorpFinanceResult<()> {
        let envelope = HnswEnvelope {
            version: ENVELOPE_VERSION,
            params: HnswParams {
                m: self.m,
                ef_construction: self.ef_construction,
                embedding_dim: self.embedding_dim,
            },
            summaries: self
                .id_to_uuid
                .iter()
                .filter_map(|u| self.summaries.get(u).cloned())
                .collect(),
        };
        let json = serde_json::to_vec(&envelope)?;
        let file = File::create(path).map_err(io_to_cf)?;
        let mut encoder = GzEncoder::new(BufWriter::new(file), Compression::default());
        encoder.write_all(&json).map_err(io_to_cf)?;
        // `finish` writes the gzip trailer into the `BufWriter` and hands it
        // back. Flush it explicitly: dropping a `BufWriter` silently discards
        // flush errors, which would report a truncated file as saved.
        let mut writer = encoder.finish().map_err(io_to_cf)?;
        writer.flush().map_err(io_to_cf)?;
        Ok(())
    }

    /// Reads an index previously written by [`Self::save_to`].
    ///
    /// A fresh graph is built from the stored parameters. Its `max_elements` is
    /// the larger of [`DEFAULT_MAX_ELEMENTS`] and the stored summary count.
    /// Every summary is then re-ingested in stored order.
    ///
    /// # Errors
    ///
    /// - I/O failures, mapped through `io_to_cf`.
    /// - JSON decoding failures, propagated via `?`.
    /// - [`CorpFinanceError::InvalidInput`] on field `envelope.version` when the
    ///   version does not match.
    /// - Any error [`Self::ingest`] returns for a stored summary.
    pub fn load_from(path: &Path) -> CorpFinanceResult<Self> {
        let file = File::open(path).map_err(io_to_cf)?;
        let mut decoder = GzDecoder::new(BufReader::new(file));
        let mut buf = Vec::new();
        decoder.read_to_end(&mut buf).map_err(io_to_cf)?;
        let envelope: HnswEnvelope = serde_json::from_slice(&buf)?;
        if envelope.version != ENVELOPE_VERSION {
            return Err(CorpFinanceError::InvalidInput {
                field: "envelope.version".into(),
                reason: format!(
                    "unsupported envelope version {} (expected {})",
                    envelope.version, ENVELOPE_VERSION
                ),
            });
        }
        let mut idx = Self::with_params(
            envelope.params.embedding_dim,
            envelope.params.m,
            envelope.params.ef_construction,
            DEFAULT_MAX_ELEMENTS.max(envelope.summaries.len()),
        );
        for s in envelope.summaries {
            idx.ingest(&s)?;
        }
        Ok(idx)
    }
}
fn io_to_cf(e: std::io::Error) -> CorpFinanceError {
CorpFinanceError::SerializationError(format!("io: {e}"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::types::Surface;

    /// Builds a summary with fixed metadata and the given text/embedding.
    fn mk_summary(text: &str, emb: Vec<f32>) -> RunSummary {
        RunSummary::new(Surface::Mcp, "dcf_calc", "djb2:0xaaaa", text, emb)
    }

    #[test]
    fn ingest_then_query_returns_inserted() {
        let mut idx = HnswMemoryIndex::new(3);
        let s = mk_summary("hello", vec![1.0, 0.0, 0.0]);
        idx.ingest(&s).unwrap();
        let res = idx.query(&[1.0, 0.0, 0.0], 1, |_| true);
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].0.run_id, s.run_id);
    }

    #[test]
    fn dim_mismatch_rejected() {
        let mut idx = HnswMemoryIndex::new(3);
        let s = mk_summary("x", vec![1.0, 0.0]);
        assert!(idx.ingest(&s).is_err());
        // A rejected ingest must not leave partial state behind.
        assert!(idx.is_empty());
    }

    #[test]
    fn duplicate_run_id_rejected() {
        let mut idx = HnswMemoryIndex::new(3);
        let s = mk_summary("dup", vec![1.0, 0.0, 0.0]);
        idx.ingest(&s).unwrap();
        assert!(idx.ingest(&s).is_err());
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn query_with_wrong_dim_is_empty() {
        let mut idx = HnswMemoryIndex::new(3);
        idx.ingest(&mk_summary("a", vec![1.0, 0.0, 0.0])).unwrap();
        assert!(idx.query(&[1.0, 0.0], 1, |_| true).is_empty());
    }

    #[test]
    fn query_applies_filter() {
        let mut idx = HnswMemoryIndex::new(3);
        let a = mk_summary("a", vec![1.0, 0.0, 0.0]);
        let b = mk_summary("b", vec![0.0, 1.0, 0.0]);
        idx.ingest(&a).unwrap();
        idx.ingest(&b).unwrap();
        // `a` is the nearest point, but the filter excludes it.
        let res = idx.query(&[1.0, 0.0, 0.0], 1, |s| s.run_id != a.run_id);
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].0.run_id, b.run_id);
    }

    #[test]
    fn save_then_load_round_trips() {
        let mut idx = HnswMemoryIndex::new(3);
        let a = mk_summary("a", vec![1.0, 0.0, 0.0]);
        let b = mk_summary("b", vec![0.0, 1.0, 0.0]);
        idx.ingest(&a).unwrap();
        idx.ingest(&b).unwrap();

        // Use the (per-summary unique) run id to avoid clashes between runs.
        let path = std::env::temp_dir().join(format!("hnsw_index_{}.json.gz", a.run_id));
        idx.save_to(&path).unwrap();
        let loaded = HnswMemoryIndex::load_from(&path);
        let _ = std::fs::remove_file(&path);
        let loaded = loaded.unwrap();

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.embedding_dim(), 3);
        let res = loaded.query(&[0.0, 1.0, 0.0], 1, |_| true);
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].0.run_id, b.run_id);
    }

    #[test]
    fn summaries_iter_yields_all_ingested() {
        let mut idx = HnswMemoryIndex::new(3);
        idx.ingest(&mk_summary("first", vec![1.0, 0.0, 0.0]))
            .unwrap();
        idx.ingest(&mk_summary("second", vec![0.0, 1.0, 0.0]))
            .unwrap();
        assert_eq!(idx.summaries_count(), 2);
        assert_eq!(idx.summaries_iter().count(), 2);
        let texts: std::collections::HashSet<&str> = idx
            .summaries_iter()
            .map(|s| s.summary_text.as_str())
            .collect();
        assert!(texts.contains("first"));
        assert!(texts.contains("second"));
    }
}