1use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::RuntimeResult;
10use crate::runtime::KhiveRuntime;
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14 VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19#[derive(Clone, Debug)]
21pub struct SearchHit {
22 pub entity_id: Uuid,
23 pub score: DeterministicScore,
24 pub source: SearchSource,
25 pub title: Option<String>,
26 pub snippet: Option<String>,
27}
28
/// Records which retriever(s) contributed a search hit.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SearchSource {
    /// Returned only by the vector (embedding) search.
    Vector,
    /// Returned only by the full-text search.
    Text,
    /// Returned by both the full-text and the vector search.
    Both,
}
36
/// The `k` argument passed to `rrf_score` when fusing ranked lists
/// (60 is the conventional choice for reciprocal rank fusion).
const RRF_K: usize = 60;

/// Each retriever is asked for `limit * CANDIDATE_MULTIPLIER` candidates so that fusion
/// and post-filtering have more than `limit` items to choose from.
const CANDIDATE_MULTIPLIER: u32 = 4;
42
43impl KhiveRuntime {
44 pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
49 let service = self.embedder().await?;
50 let model = self
51 .config()
52 .embedding_model
53 .expect("embedder() returns Unconfigured when model is None");
54 Ok(service.embed_one(text, model).await?)
55 }
56
57 pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
65 if texts.is_empty() {
66 return Ok(vec![]);
67 }
68 let service = self.embedder().await?;
69 let model = self
70 .config()
71 .embedding_model
72 .expect("embedder() returns Unconfigured when model is None");
73 Ok(service.embed(texts, model).await?)
74 }
75
76 pub async fn hybrid_search(
84 &self,
85 namespace: Option<&str>,
86 query_text: &str,
87 query_vector: Option<Vec<f32>>,
88 limit: u32,
89 ) -> RuntimeResult<Vec<SearchHit>> {
90 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
91
92 let ns = self.ns(namespace).to_string();
93 let text_hits = self
94 .text(namespace)?
95 .search(TextSearchRequest {
96 query: query_text.to_string(),
97 mode: TextQueryMode::Plain,
98 filter: Some(TextFilter {
99 namespaces: vec![ns.clone()],
100 ..TextFilter::default()
101 }),
102 top_k: candidates,
103 snippet_chars: 200,
104 })
105 .await?;
106
107 let vector_hits = if let Some(vec) = query_vector {
108 self.vectors(namespace)?
109 .search(VectorSearchRequest {
110 query_embedding: vec,
111 top_k: candidates,
112 namespace: Some(ns.clone()),
113 kind: Some(SubstrateKind::Entity),
114 })
115 .await?
116 } else {
117 Vec::new()
118 };
119
120 let mut fused = rrf_fuse(text_hits, vector_hits, limit as usize);
121
122 if !fused.is_empty() {
125 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
126 let alive_page = self
127 .entities(namespace)?
128 .query_entities(
129 self.ns(namespace),
130 EntityFilter {
131 ids: candidate_ids,
132 ..EntityFilter::default()
133 },
134 PageRequest {
135 offset: 0,
136 limit: fused.len() as u32,
137 },
138 )
139 .await?;
140 let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
141 fused.retain(|h| alive.contains(&h.entity_id));
142 }
143
144 Ok(fused)
145 }
146
147 pub async fn knn(
153 &self,
154 namespace: Option<&str>,
155 query_vector: Vec<f32>,
156 top_k: u32,
157 ) -> RuntimeResult<Vec<VectorSearchHit>> {
158 let ns = self.ns(namespace).to_string();
159 Ok(self
160 .vectors(namespace)?
161 .search(VectorSearchRequest {
162 query_embedding: query_vector,
163 top_k,
164 namespace: Some(ns),
165 kind: Some(SubstrateKind::Entity),
166 })
167 .await?)
168 }
169
170 pub async fn rerank(
176 &self,
177 namespace: Option<&str>,
178 query_vector: &[f32],
179 candidate_ids: &[Uuid],
180 top_k: u32,
181 ) -> RuntimeResult<Vec<VectorSearchHit>> {
182 let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
183 let ns = self.ns(namespace).to_string();
184 let all_hits = self
185 .vectors(namespace)?
186 .search(VectorSearchRequest {
187 query_embedding: query_vector.to_vec(),
188 top_k: candidate_ids.len() as u32,
189 namespace: Some(ns),
190 kind: Some(SubstrateKind::Entity),
191 })
192 .await?;
193 let mut hits: Vec<VectorSearchHit> = all_hits
194 .into_iter()
195 .filter(|h| candidate_set.contains(&h.subject_id))
196 .collect();
197 hits.sort_by(|a, b| b.score.cmp(&a.score));
198 hits.truncate(top_k as usize);
199 Ok(hits)
200 }
201}
202
203fn rrf_fuse(
207 text_hits: Vec<TextSearchHit>,
208 vector_hits: Vec<VectorSearchHit>,
209 limit: usize,
210) -> Vec<SearchHit> {
211 #[derive(Default)]
212 struct Bucket {
213 score: DeterministicScore,
214 source: Option<SearchSource>,
215 title: Option<String>,
216 snippet: Option<String>,
217 }
218
219 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
220
221 for (i, hit) in text_hits.into_iter().enumerate() {
222 let rank = i + 1; let entry = buckets.entry(hit.subject_id).or_default();
224 entry.score = entry.score + rrf_score(rank, RRF_K);
225 entry.source = Some(match entry.source {
226 Some(SearchSource::Vector) => SearchSource::Both,
227 _ => SearchSource::Text,
228 });
229 if entry.title.is_none() {
230 entry.title = hit.title;
231 }
232 if entry.snippet.is_none() {
233 entry.snippet = hit.snippet;
234 }
235 }
236
237 for (i, hit) in vector_hits.into_iter().enumerate() {
238 let rank = i + 1;
239 let entry = buckets.entry(hit.subject_id).or_default();
240 entry.score = entry.score + rrf_score(rank, RRF_K);
241 entry.source = Some(match entry.source {
242 Some(SearchSource::Text) => SearchSource::Both,
243 _ => SearchSource::Vector,
244 });
245 }
246
247 let mut hits: Vec<SearchHit> = buckets
248 .into_iter()
249 .map(|(id, b)| SearchHit {
250 entity_id: id,
251 score: b.score,
252 source: b.source.expect("each bucket gets a source"),
253 title: b.title,
254 snippet: b.snippet,
255 })
256 .collect();
257
258 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
259 hits.truncate(limit);
260 hits
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use lattice_embed::EmbeddingModel;

    /// Drives `fut` to completion on a fresh Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new()
            .expect("failed to build tokio runtime")
            .block_on(fut)
    }

    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
        assert_eq!(hits[1].entity_id, b);
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vector = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vector, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5);
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vector = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vector, 10);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn embed_unconfigured_on_memory_runtime() {
        // The in-memory runtime has no embedding model, so single-text embedding must fail
        // with the same Unconfigured error that `embed_batch` reports for non-empty input.
        let rt = KhiveRuntime::memory().unwrap();
        match block_on(rt.embed("hello")) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        // Empty input short-circuits before the embedder, so it succeeds without a model.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        match block_on(rt.embed_batch(&texts)) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(model),
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }
}