embeddenator_interop/
kernel_interop.rs1use embeddenator_vsa::{ReversibleVSAConfig, SparseVec};
11use std::collections::HashMap;
12use std::fmt;
13
/// Errors reported by the kernel interop helpers in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum KernelInteropError {
    /// A vector was requested by `id`, but the store had no entry for it.
    MissingVector {
        /// The identifier whose lookup failed.
        id: usize,
    },
}
19
20impl fmt::Display for KernelInteropError {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 match self {
23 KernelInteropError::MissingVector { id } => {
24 write!(f, "missing vector for id {id}")
25 }
26 }
27 }
28}
29
30impl std::error::Error for KernelInteropError {}
31
32pub trait VsaBackend {
38 type Vector: Clone + Send + Sync + 'static;
39
40 fn zero(&self) -> Self::Vector;
41
42 fn bundle(&self, a: &Self::Vector, b: &Self::Vector) -> Self::Vector;
43
44 fn bind(&self, a: &Self::Vector, b: &Self::Vector) -> Self::Vector;
45
46 fn cosine(&self, a: &Self::Vector, b: &Self::Vector) -> f64;
47
48 fn encode_data(
49 &self,
50 data: &[u8],
51 config: &ReversibleVSAConfig,
52 path: Option<&str>,
53 ) -> Self::Vector;
54
55 fn decode_data(
56 &self,
57 vec: &Self::Vector,
58 config: &ReversibleVSAConfig,
59 path: Option<&str>,
60 expected_size: usize,
61 ) -> Vec<u8>;
62}
63
/// Stateless [`VsaBackend`] whose vector type is `SparseVec`.
///
/// The type has no fields, so it is free to copy and can be created with
/// `SparseVecBackend` or `SparseVecBackend::default()`.
#[derive(Debug, Default, Clone, Copy)]
pub struct SparseVecBackend;
67
68impl VsaBackend for SparseVecBackend {
69 type Vector = SparseVec;
70
71 fn zero(&self) -> Self::Vector {
72 SparseVec::new()
73 }
74
75 fn bundle(&self, a: &Self::Vector, b: &Self::Vector) -> Self::Vector {
76 a.bundle(b)
77 }
78
79 fn bind(&self, a: &Self::Vector, b: &Self::Vector) -> Self::Vector {
80 a.bind(b)
81 }
82
83 fn cosine(&self, a: &Self::Vector, b: &Self::Vector) -> f64 {
84 a.cosine(b)
85 }
86
87 fn encode_data(
88 &self,
89 data: &[u8],
90 config: &ReversibleVSAConfig,
91 path: Option<&str>,
92 ) -> Self::Vector {
93 SparseVec::encode_data(data, config, path)
94 }
95
96 fn decode_data(
97 &self,
98 vec: &Self::Vector,
99 config: &ReversibleVSAConfig,
100 path: Option<&str>,
101 expected_size: usize,
102 ) -> Vec<u8> {
103 vec.decode_data(config, path, expected_size)
104 }
105}
106
/// Read-only lookup of vectors of type `V` by numeric identifier.
pub trait VectorStore<V> {
    /// Returns a reference to the vector stored under `id`, or `None` when
    /// the store has no entry for that identifier.
    fn get(&self, id: usize) -> Option<&V>;
}
113
114impl VectorStore<SparseVec> for HashMap<usize, SparseVec> {
115 fn get(&self, id: usize) -> Option<&SparseVec> {
116 self.get(&id)
117 }
118}
119
/// Source of candidates for a query vector of type `V`.
pub trait CandidateGenerator<V> {
    /// The item type produced for each candidate; chosen by the implementor.
    type Candidate;

    /// Returns the candidates produced for `query`.
    ///
    /// NOTE(review): `k` is presumably the requested number of candidates;
    /// the trait does not itself bound the length of the returned `Vec`.
    fn candidates(&self, query: &V, k: usize) -> Vec<Self::Candidate>;
}
126
127pub fn rerank_top_k_by_cosine<B, S>(
134 backend: &B,
135 store: &S,
136 query: &B::Vector,
137 candidate_ids: impl IntoIterator<Item = usize>,
138 k: usize,
139) -> Result<Vec<(usize, f64)>, KernelInteropError>
140where
141 B: VsaBackend,
142 S: VectorStore<B::Vector>,
143{
144 if k == 0 {
145 return Ok(Vec::new());
146 }
147
148 let mut scored = Vec::new();
149 for id in candidate_ids {
150 let vec = store
151 .get(id)
152 .ok_or(KernelInteropError::MissingVector { id })?;
153 scored.push((id, backend.cosine(query, vec)));
154 }
155
156 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
157 scored.truncate(k);
158 Ok(scored)
159}