use std::collections::{BTreeMap, HashMap};
use std::sync::{Mutex, OnceLock};
use chrono::Utc;
use hnsw_rs::prelude::{DistL2, Hnsw};
use uuid::Uuid;
use crate::error::CorpFinanceError;
use crate::self_learning::types::{EvalGrade, SurfaceEventRef, Trajectory};
use crate::surface::Surface;
use crate::CorpFinanceResult;
/// Optional criteria for narrowing trajectory retrieval.
///
/// Every field defaults to `None`, which places no restriction on that
/// dimension. All criteria that are set must hold for a trajectory to match
/// (see `filter_matches`).
#[derive(Debug, Clone, Default)]
pub struct TrajectoryFilter {
/// When set, only trajectories whose `surface` equals this value match.
pub surface: Option<Surface>,
/// When set, only trajectories graded at or above this value match;
/// trajectories without a grade are excluded.
pub eval_grade_min: Option<EvalGrade>,
/// When set, only trajectories carrying exactly this tenant id match.
pub tenant_id: Option<String>,
}
impl TrajectoryFilter {
    /// Builds a filter with no criteria set.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this filter with the surface criterion set to `s`.
    pub fn with_surface(self, s: Surface) -> Self {
        Self {
            surface: Some(s),
            ..self
        }
    }

    /// Returns this filter with the minimum-grade criterion set to `g`.
    pub fn with_eval_grade_min(self, g: EvalGrade) -> Self {
        Self {
            eval_grade_min: Some(g),
            ..self
        }
    }

    /// Returns this filter with the tenant criterion set to `id`.
    pub fn with_tenant_id(self, id: impl Into<String>) -> Self {
        Self {
            tenant_id: Some(id.into()),
            ..self
        }
    }
}
/// Accumulating state for a trajectory that has been started but not yet
/// completed.
struct PartialState {
/// Surface the trajectory is being captured on.
surface: Surface,
/// Caller-supplied event id; together with `surface` it forms the in-flight key.
surface_event_id: String,
/// Steps captured so far, in the order they were pushed.
steps: Vec<SurfaceEventRef>,
/// Tenant set through `attach_tenant`; `None` until then.
tenant_id: Option<String>,
/// Time at which the entry was created (first captured step); becomes the
/// completed trajectory's `ts`.
started_at: chrono::DateTime<Utc>,
}
/// HNSW index over trajectory embeddings, plus a side table that maps the
/// index's integer data ids back to trajectory ids.
struct TrajectoryHnswIndex {
/// Underlying `hnsw_rs` graph over `f32` vectors with the `DistL2` distance.
hnsw: Hnsw<'static, f32, DistL2>,
/// Position `i` holds the trajectory id that was inserted under data id `i`.
id_to_uuid: Vec<Uuid>,
/// Length every inserted embedding must have; queries of a different
/// length return no results.
embedding_dim: usize,
/// The `ef_construction` value the graph was built with; also used as the
/// lower bound for the search-time `ef`.
ef_construction: usize,
}
// Tuning values handed to `Hnsw::new` in `TrajectoryHnswIndex::new`.
// NOTE(review): the parameter meanings below follow hnsw_rs naming — confirm
// against the crate documentation before changing any of them.
/// Passed as the first argument (presumably the max connections per node).
const HNSW_M: usize = 16;
/// Passed as the fourth argument (construction-time `ef`); also stored on the
/// index and reused as the floor for the search-time `ef`.
const HNSW_EF_CONSTRUCTION: usize = 200;
/// Passed as the third argument (presumably the max number of graph layers).
const HNSW_MAX_LAYER: usize = 16;
/// Passed as the second argument (presumably the expected element count).
const HNSW_MAX_ELEMENTS: usize = 100_000;
impl TrajectoryHnswIndex {
    /// Creates an empty index that only accepts embeddings of length
    /// `embedding_dim`.
    fn new(embedding_dim: usize) -> Self {
        let hnsw = Hnsw::<f32, DistL2>::new(
            HNSW_M,
            HNSW_MAX_ELEMENTS,
            HNSW_MAX_LAYER,
            HNSW_EF_CONSTRUCTION,
            DistL2 {},
        );
        Self {
            hnsw,
            id_to_uuid: Vec::new(),
            embedding_dim,
            ef_construction: HNSW_EF_CONSTRUCTION,
        }
    }

    /// Adds `embedding` to the index under `trajectory_id`.
    ///
    /// # Errors
    ///
    /// Returns `CorpFinanceError::InvalidInput` when the embedding length
    /// differs from the index dimension; the index is left untouched.
    fn insert(&mut self, trajectory_id: Uuid, embedding: &[f32]) -> CorpFinanceResult<()> {
        if embedding.len() != self.embedding_dim {
            return Err(CorpFinanceError::InvalidInput {
                field: "embedding".into(),
                reason: format!(
                    "trajectory hnsw expects dim {}, got {}",
                    self.embedding_dim,
                    embedding.len()
                ),
            });
        }
        // The data id handed to the graph is the slot the uuid is about to
        // occupy in the side table, so the two stay aligned.
        let data_id = self.id_to_uuid.len();
        self.hnsw.insert((embedding, data_id));
        self.id_to_uuid.push(trajectory_id);
        Ok(())
    }

    /// Returns the trajectory ids of up to `k` nearest neighbours of `query`,
    /// in the order the index reports them.
    ///
    /// A query of the wrong dimension, or an empty index, yields an empty
    /// vector. `k == 0` is treated as 1. `k` is capped at the number of
    /// indexed points: the index cannot return more than it holds, and an
    /// uncapped value (e.g. `usize::MAX` from a "no limit" caller) would
    /// otherwise be forwarded as the search `ef` as well.
    fn query_top_k(&self, query: &[f32], k: usize) -> Vec<Uuid> {
        let indexed = self.id_to_uuid.len();
        if query.len() != self.embedding_dim || indexed == 0 {
            return Vec::new();
        }
        // `indexed >= 1` here, so the clamp bounds are well-formed.
        let knbn = k.clamp(1, indexed);
        let ef_search = self.ef_construction.max(knbn);
        self.hnsw
            .search(query, knbn, ef_search)
            .into_iter()
            .filter_map(|n| self.id_to_uuid.get(n.d_id).copied())
            .collect()
    }
}
/// In-memory state shared by every function in this module through the
/// `store()` mutex.
struct Store {
/// Trajectories still being captured, keyed by `(surface name, surface_event_id)`.
in_flight: HashMap<(String, String), PartialState>,
/// Completed trajectories with their embedding; `complete_trajectory` stores
/// an empty embedding, `persist_with_embedding` stores the one it is given.
by_id: BTreeMap<Uuid, (Trajectory, Vec<f32>)>,
/// Vector index; `None` until the first non-empty embedding is persisted.
hnsw: Option<TrajectoryHnswIndex>,
}
/// Returns the lazily-initialised, process-wide store.
///
/// The first call creates an empty `Store`; every later call hands back the
/// same mutex.
fn store() -> &'static Mutex<Store> {
    static STORE: OnceLock<Mutex<Store>> = OnceLock::new();

    fn empty_store() -> Mutex<Store> {
        let initial = Store {
            in_flight: HashMap::new(),
            by_id: BTreeMap::new(),
            hnsw: None,
        };
        Mutex::new(initial)
    }

    STORE.get_or_init(empty_store)
}
/// Builds the `in_flight` map key for a surface / event-id pair.
fn key(surface: Surface, surface_event_id: &str) -> (String, String) {
    let surface_name = String::from(surface.as_str());
    let event_id = String::from(surface_event_id);
    (surface_name, event_id)
}
/// Maximum number of steps one in-flight trajectory may hold;
/// `capture_trajectory_step` rejects further steps once this many are stored.
pub const MAX_TRAJECTORY_STEPS: usize = 1024;
/// Appends `step` to the in-flight trajectory identified by
/// `(surface, surface_event_id)`, starting that trajectory if it does not
/// exist yet.
///
/// # Errors
///
/// * `CorpFinanceError::InvalidInput` when `surface_event_id` is empty.
/// * `CorpFinanceError::FinancialImpossibility` when the trajectory already
///   holds `MAX_TRAJECTORY_STEPS` steps; the step is not recorded.
///
/// # Panics
///
/// Panics if the store mutex is poisoned.
pub fn capture_trajectory_step(
    surface: Surface,
    surface_event_id: &str,
    step: SurfaceEventRef,
) -> CorpFinanceResult<()> {
    if surface_event_id.is_empty() {
        return Err(CorpFinanceError::InvalidInput {
            field: "surface_event_id".into(),
            reason: "trajectory step requires non-empty surface_event_id".into(),
        });
    }

    let mut guard = store().lock().expect("self_learning store poisoned");
    let partial = guard
        .in_flight
        .entry(key(surface, surface_event_id))
        .or_insert_with(|| PartialState {
            surface,
            surface_event_id: surface_event_id.to_string(),
            steps: Vec::new(),
            tenant_id: None,
            // Only evaluated when the entry is created, i.e. on the first step.
            started_at: Utc::now(),
        });

    if partial.steps.len() < MAX_TRAJECTORY_STEPS {
        partial.steps.push(step);
        return Ok(());
    }

    Err(CorpFinanceError::FinancialImpossibility(format!(
        "trajectory step cap {MAX_TRAJECTORY_STEPS} exceeded for {}/{}",
        surface.as_str(),
        surface_event_id
    )))
}
/// Records `tenant_id` on the in-flight trajectory identified by
/// `(surface, surface_event_id)`.
///
/// When no such trajectory is in flight the call does nothing and still
/// returns `Ok(())`.
///
/// # Panics
///
/// Panics if the store mutex is poisoned.
pub fn attach_tenant(
    surface: Surface,
    surface_event_id: &str,
    tenant_id: impl Into<String>,
) -> CorpFinanceResult<()> {
    let mut guard = store().lock().expect("self_learning store poisoned");
    let lookup = key(surface, surface_event_id);
    let Some(partial) = guard.in_flight.get_mut(&lookup) else {
        return Ok(());
    };
    partial.tenant_id = Some(tenant_id.into());
    Ok(())
}
/// Finalises the in-flight trajectory identified by
/// `(surface, surface_event_id)` and returns it.
///
/// The in-flight entry is removed, a fresh v7 UUID is assigned, and the
/// trajectory is stored in `by_id` with an empty embedding. Its `ts` is the
/// time the first step was captured, not the completion time.
///
/// # Errors
///
/// * `CorpFinanceError::InsufficientData` when nothing is in flight for the key.
/// * `CorpFinanceError::InvalidInput` when the in-flight entry holds no steps
///   (the entry has already been removed at that point).
///
/// # Panics
///
/// Panics if the store mutex is poisoned.
pub fn complete_trajectory(
    surface: Surface,
    surface_event_id: &str,
    eval_grade: Option<EvalGrade>,
) -> CorpFinanceResult<Trajectory> {
    let mut s = store().lock().expect("self_learning store poisoned");
    // `ok_or_else` so the error message is only formatted on the miss path.
    let partial = s
        .in_flight
        .remove(&key(surface, surface_event_id))
        .ok_or_else(|| {
            CorpFinanceError::InsufficientData(format!(
                "no in-flight trajectory for {}/{}",
                surface.as_str(),
                surface_event_id
            ))
        })?;
    if partial.steps.is_empty() {
        return Err(CorpFinanceError::InvalidInput {
            field: "steps".into(),
            reason: "trajectory must have at least one step".into(),
        });
    }
    let trajectory = Trajectory {
        trajectory_id: Uuid::now_v7(),
        surface: partial.surface,
        surface_event_id: partial.surface_event_id,
        steps: partial.steps,
        eval_grade,
        tenant_id: partial.tenant_id,
        ts: partial.started_at,
    };
    // Stored without an embedding; `persist_with_embedding` supplies one later.
    s.by_id
        .insert(trajectory.trajectory_id, (trajectory.clone(), Vec::new()));
    Ok(trajectory)
}
/// Stores `trajectory` together with `embedding`, replacing any record with
/// the same id, and adds a non-empty embedding to the HNSW index.
///
/// The index is created on the first non-empty embedding and adopts that
/// embedding's length as its dimension. An empty embedding is stored but
/// never indexed.
///
/// The index is updated *before* the record is written, so a rejected
/// embedding leaves the stored record exactly as it was. Re-persisting a
/// trajectory with the embedding that is already stored for it does not add
/// a second node to the index.
///
/// # Errors
///
/// Returns `CorpFinanceError::InvalidInput` when a non-empty embedding's
/// length differs from the index dimension; nothing is modified in that case.
///
/// # Panics
///
/// Panics if the store mutex is poisoned.
pub fn persist_with_embedding(
    trajectory: Trajectory,
    embedding: Vec<f32>,
) -> CorpFinanceResult<()> {
    let mut s = store().lock().expect("self_learning store poisoned");
    let trajectory_id = trajectory.trajectory_id;

    if !embedding.is_empty() {
        // A non-empty embedding only reaches `by_id` after it has been
        // indexed below, so an identical stored embedding means the index
        // already has a node for it.
        let already_indexed = s
            .by_id
            .get(&trajectory_id)
            .is_some_and(|(_, stored)| *stored == embedding);
        if !already_indexed {
            let dim = embedding.len();
            let idx = s
                .hnsw
                .get_or_insert_with(|| TrajectoryHnswIndex::new(dim));
            idx.insert(trajectory_id, &embedding)?;
        }
    }

    s.by_id.insert(trajectory_id, (trajectory, embedding));
    Ok(())
}
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, are empty, either has
/// zero norm, or the result is not a finite number (NaN / infinite
/// components). The function therefore never returns NaN.
///
/// Sums are accumulated in `f64` so that squaring large `f32` components
/// cannot overflow to infinity and turn a well-defined similarity into NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    let similarity = (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())) as f32;
    if similarity.is_finite() {
        similarity
    } else {
        0.0
    }
}
/// Multiplier applied to `limit` when `retrieve_similar` queries the HNSW
/// index — presumably headroom for candidates that the filter drops afterwards.
const HNSW_OVERFETCH: usize = 4;
/// Returns up to `limit` stored trajectories that satisfy `filter`, ranked
/// by evaluation grade (highest first, ungraded treated as
/// `EvalGrade::Failed`) and then by cosine similarity to `query_embedding`
/// (highest first).
///
/// Candidates come from the HNSW index when one exists, is non-empty and has
/// the same dimension as a non-empty query; otherwise every stored
/// trajectory is scanned. On the index path at most
/// `limit * HNSW_OVERFETCH` neighbours are fetched before filtering, so
/// fewer than `limit` results can come back even if more matches exist.
///
/// Each trajectory appears at most once in the result, and a similarity that
/// is not a number ranks below every real similarity within its grade.
///
/// # Panics
///
/// Panics if the store mutex is poisoned.
pub fn retrieve_similar(
    query_embedding: &[f32],
    filter: &TrajectoryFilter,
    limit: usize,
) -> CorpFinanceResult<Vec<Trajectory>> {
    if limit == 0 {
        return Ok(Vec::new());
    }
    let s = store().lock().expect("self_learning store poisoned");

    // `sort_by` needs a total order; NaN would break it, so undefined
    // similarities are pushed to the bottom instead.
    let rankable = |sim: f32| if sim.is_nan() { f32::NEG_INFINITY } else { sim };

    let index = s.hnsw.as_ref().filter(|idx| {
        !query_embedding.is_empty()
            && idx.embedding_dim == query_embedding.len()
            && !idx.id_to_uuid.is_empty()
    });

    let mut scored: Vec<(f32, &Trajectory)> = match index {
        Some(idx) => {
            // Never ask for more neighbours than the index holds; an
            // unbounded `limit` would otherwise inflate the search size.
            let knbn = limit
                .saturating_mul(HNSW_OVERFETCH)
                .max(limit)
                .min(idx.id_to_uuid.len());
            // The index keeps one node per insert, so a trajectory persisted
            // more than once can be reported several times.
            let mut seen = std::collections::HashSet::new();
            idx.query_top_k(query_embedding, knbn)
                .into_iter()
                .filter(|tid| seen.insert(*tid))
                .filter_map(|tid| s.by_id.get(&tid))
                .filter(|(t, _)| filter_matches(t, filter))
                .map(|(t, e)| (rankable(cosine_similarity(query_embedding, e)), t))
                .collect()
        }
        None => s
            .by_id
            .values()
            .filter(|(t, _)| filter_matches(t, filter))
            .map(|(t, e)| (rankable(cosine_similarity(query_embedding, e)), t))
            .collect(),
    };

    scored.sort_by(|a, b| {
        let ag = a.1.eval_grade.unwrap_or(EvalGrade::Failed);
        let bg = b.1.eval_grade.unwrap_or(EvalGrade::Failed);
        bg.cmp(&ag).then_with(|| b.0.total_cmp(&a.0))
    });

    Ok(scored
        .into_iter()
        .take(limit)
        .map(|(_, t)| t.clone())
        .collect())
}
fn filter_matches(t: &Trajectory, filter: &TrajectoryFilter) -> bool {
if let Some(surf) = filter.surface {
if t.surface != surf {
return false;
}
}
if let Some(min) = filter.eval_grade_min {
match t.eval_grade {
Some(g) if g >= min => {}
_ => return false,
}
}
if let Some(ref tid) = filter.tenant_id {
if t.tenant_id.as_deref() != Some(tid.as_str()) {
return false;
}
}
true
}
/// Test helper: empties the shared store (in-flight entries, completed
/// trajectories and the HNSW index).
#[cfg(test)]
pub(crate) fn reset_store_for_tests() {
    let mut guard = store().lock().expect("self_learning store poisoned");
    let Store {
        in_flight,
        by_id,
        hnsw,
    } = &mut *guard;
    in_flight.clear();
    by_id.clear();
    *hnsw = None;
}
/// Test helper: acquires a process-wide lock that tests hold while they use
/// the shared store.
///
/// A poisoned lock (a previous holder panicked) is recovered rather than
/// propagated, so one failing test does not block the others.
#[cfg(test)]
pub(crate) fn lock_test_store() -> std::sync::MutexGuard<'static, ()> {
    static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
    LOCK.get_or_init(|| std::sync::Mutex::new(()))
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}
// Unit tests for the trajectory store. Every test first takes the shared
// test lock and then resets the store, because all tests operate on the same
// process-wide state.
#[cfg(test)]
mod tests {
use super::*;
use crate::self_learning::types::SurfaceEventKind;
// Acquires the lock that keeps tests from touching the store concurrently.
fn lock_for_test() -> std::sync::MutexGuard<'static, ()> {
super::lock_test_store()
}
// Builds a minimal step with fixed hashes and duration; only `name` varies.
fn step(name: &str) -> SurfaceEventRef {
SurfaceEventRef {
kind: SurfaceEventKind::McpTool,
name: name.into(),
input_hash: "h".into(),
output_hash: Some("o".into()),
duration_ms: 10,
}
}
// Two captured steps come back on the completed trajectory with its grade.
#[test]
fn capture_then_complete_returns_trajectory_with_steps() {
let _guard = lock_for_test();
reset_store_for_tests();
capture_trajectory_step(Surface::Cli, "test1", step("a")).unwrap();
capture_trajectory_step(Surface::Cli, "test1", step("b")).unwrap();
let t = complete_trajectory(Surface::Cli, "test1", Some(EvalGrade::Good)).unwrap();
assert_eq!(t.steps.len(), 2);
assert_eq!(t.eval_grade, Some(EvalGrade::Good));
}
// An empty event id is rejected as invalid input.
#[test]
fn capture_with_empty_event_id_errors() {
let _guard = lock_for_test();
reset_store_for_tests();
let err = capture_trajectory_step(Surface::Cli, "", step("a")).unwrap_err();
assert!(matches!(err, CorpFinanceError::InvalidInput { .. }));
}
// Completing a trajectory that was never started reports insufficient data.
#[test]
fn complete_without_capture_errors() {
let _guard = lock_for_test();
reset_store_for_tests();
let err = complete_trajectory(Surface::Cli, "missing", None).unwrap_err();
assert!(matches!(err, CorpFinanceError::InsufficientData(_)));
}
// Captures a single step under `name` and completes it with `grade`.
// Callers must already hold the test lock.
fn finalised(name: &str, grade: EvalGrade) -> Trajectory {
capture_trajectory_step(Surface::Cli, name, step("a")).unwrap();
complete_trajectory(Surface::Cli, name, Some(grade)).unwrap()
}
// With only empty embeddings stored, retrieval uses the linear scan and
// still returns the trajectory.
#[test]
fn persist_with_empty_embedding_does_not_build_hnsw() {
let _guard = lock_for_test();
reset_store_for_tests();
let t = finalised("empty-emb", EvalGrade::Good);
persist_with_embedding(t.clone(), Vec::new()).unwrap();
let hits = retrieve_similar(&[], &TrajectoryFilter::new(), 5).unwrap();
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].trajectory_id, t.trajectory_id);
}
// Grade outranks similarity: the Excellent trajectory comes first even
// though the Good one matches the query vector exactly.
#[test]
fn persist_with_non_empty_embedding_builds_hnsw_lazily() {
let _guard = lock_for_test();
reset_store_for_tests();
let t1 = finalised("emb-1", EvalGrade::Good);
persist_with_embedding(t1.clone(), vec![1.0, 0.0, 0.0]).unwrap();
let t2 = finalised("emb-2", EvalGrade::Excellent);
persist_with_embedding(t2.clone(), vec![0.0, 1.0, 0.0]).unwrap();
let hits = retrieve_similar(&[1.0, 0.0, 0.0], &TrajectoryFilter::new(), 2).unwrap();
assert_eq!(hits.len(), 2);
assert_eq!(hits[0].trajectory_id, t2.trajectory_id);
assert_eq!(hits[1].trajectory_id, t1.trajectory_id);
}
// The first embedding fixes the index dimension (3); a 2-element embedding
// is then rejected.
#[test]
fn hnsw_dim_mismatch_on_insert_errors() {
let _guard = lock_for_test();
reset_store_for_tests();
let t1 = finalised("dim-1", EvalGrade::Good);
persist_with_embedding(t1, vec![1.0, 0.0, 0.0]).unwrap();
let t2 = finalised("dim-2", EvalGrade::Good);
let err = persist_with_embedding(t2, vec![1.0, 0.0]).unwrap_err();
assert!(matches!(err, CorpFinanceError::InvalidInput { .. }));
}
// A query whose dimension differs from the index bypasses it and scans all
// stored trajectories instead.
#[test]
fn retrieve_dim_mismatch_falls_back_to_linear() {
let _guard = lock_for_test();
reset_store_for_tests();
let t1 = finalised("falls-back-1", EvalGrade::Good);
persist_with_embedding(t1.clone(), vec![1.0, 0.0, 0.0]).unwrap();
let hits = retrieve_similar(&[1.0, 0.0, 0.0, 0.0], &TrajectoryFilter::new(), 5).unwrap();
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].trajectory_id, t1.trajectory_id);
}
// A minimum grade of Acceptable drops the Poor trajectory and keeps the
// Good one.
#[test]
fn retrieve_filter_by_grade_min_excludes_below() {
let _guard = lock_for_test();
reset_store_for_tests();
let bad = finalised("low-grade", EvalGrade::Poor);
persist_with_embedding(bad.clone(), vec![1.0, 0.0, 0.0]).unwrap();
let good = finalised("good-grade", EvalGrade::Good);
persist_with_embedding(good.clone(), vec![0.9, 0.1, 0.0]).unwrap();
let filter = TrajectoryFilter::new().with_eval_grade_min(EvalGrade::Acceptable);
let hits = retrieve_similar(&[1.0, 0.0, 0.0], &filter, 5).unwrap();
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].trajectory_id, good.trajectory_id);
}
// Two tenants store identical embeddings; the tenant filter returns only
// the matching tenant's trajectory.
#[test]
fn retrieve_filter_by_tenant_excludes_other_tenants() {
let _guard = lock_for_test();
reset_store_for_tests();
capture_trajectory_step(Surface::Cli, "tenant-a-traj", step("a")).unwrap();
attach_tenant(Surface::Cli, "tenant-a-traj", "tenant-a").unwrap();
let a = complete_trajectory(Surface::Cli, "tenant-a-traj", Some(EvalGrade::Good)).unwrap();
persist_with_embedding(a.clone(), vec![1.0, 0.0, 0.0]).unwrap();
capture_trajectory_step(Surface::Cli, "tenant-b-traj", step("b")).unwrap();
attach_tenant(Surface::Cli, "tenant-b-traj", "tenant-b").unwrap();
let b = complete_trajectory(Surface::Cli, "tenant-b-traj", Some(EvalGrade::Good)).unwrap();
persist_with_embedding(b.clone(), vec![1.0, 0.0, 0.0]).unwrap();
let filter = TrajectoryFilter::new().with_tenant_id("tenant-a");
let hits = retrieve_similar(&[1.0, 0.0, 0.0], &filter, 5).unwrap();
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].trajectory_id, a.trajectory_id);
}
}