1use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::RuntimeResult;
10use crate::runtime::KhiveRuntime;
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14 VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
/// One result of [`KhiveRuntime::hybrid_search`]: an entity together with its
/// fused Reciprocal Rank Fusion score and whatever display text the text leg
/// supplied.
#[derive(Clone, Debug)]
pub struct SearchHit {
    /// Id of the matched entity (the `subject_id` of the underlying hits).
    pub entity_id: Uuid,
    /// Fused RRF score: the sum of `rrf_score(rank, RRF_K)` over every list the
    /// entity appeared in. Higher is better.
    pub score: DeterministicScore,
    /// Which backend list(s) the entity was found in.
    pub source: SearchSource,
    /// Title taken from the text search hit; `None` for vector-only hits.
    pub title: Option<String>,
    /// Snippet taken from the text search hit; `None` for vector-only hits.
    pub snippet: Option<String>,
}
28
/// Which result list(s) contributed a [`SearchHit`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchSource {
    /// Found only by the vector (embedding) search.
    Vector,
    /// Found only by the full-text search.
    Text,
    /// Found by both searches; its fused score includes both contributions.
    Both,
}
36
/// Smoothing constant `k` passed to `rrf_score` for every ranked hit.
/// NOTE(review): assumes the conventional RRF form `1 / (k + rank)` — confirm
/// against `khive_score::rrf_score`.
const RRF_K: usize = 60;

/// Factor applied to the caller's `limit` when sizing each backend query in
/// `hybrid_search`, so that hits removed by the post-fusion entity filter can
/// still leave `limit` results.
const CANDIDATE_MULTIPLIER: u32 = 4;
42
43impl KhiveRuntime {
44 pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
49 let service = self.embedder().await?;
50 let model = self
51 .config()
52 .embedding_model
53 .expect("embedder() returns Unconfigured when model is None");
54 Ok(service.embed_one(text, model).await?)
55 }
56
57 pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
65 if texts.is_empty() {
66 return Ok(vec![]);
67 }
68 let service = self.embedder().await?;
69 let model = self
70 .config()
71 .embedding_model
72 .expect("embedder() returns Unconfigured when model is None");
73 Ok(service.embed(texts, model).await?)
74 }
75
76 pub async fn hybrid_search(
91 &self,
92 namespace: Option<&str>,
93 query_text: &str,
94 query_vector: Option<Vec<f32>>,
95 limit: u32,
96 entity_kind: Option<&str>,
97 ) -> RuntimeResult<Vec<SearchHit>> {
98 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
99
100 let ns = self.ns(namespace).to_string();
101 let text_hits = self
102 .text(namespace)?
103 .search(TextSearchRequest {
104 query: query_text.to_string(),
105 mode: TextQueryMode::Plain,
106 filter: Some(TextFilter {
107 namespaces: vec![ns.clone()],
108 ..TextFilter::default()
109 }),
110 top_k: candidates,
111 snippet_chars: 200,
112 })
113 .await?;
114
115 let vector_hits = if let Some(vec) = query_vector {
116 self.vectors(namespace)?
117 .search(VectorSearchRequest {
118 query_embedding: vec,
119 top_k: candidates,
120 namespace: Some(ns.clone()),
121 kind: Some(SubstrateKind::Entity),
122 })
123 .await?
124 } else {
125 Vec::new()
126 };
127
128 let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize);
131
132 if !fused.is_empty() {
136 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
137 let alive_page = self
138 .entities(namespace)?
139 .query_entities(
140 self.ns(namespace),
141 EntityFilter {
142 ids: candidate_ids,
143 kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
144 ..EntityFilter::default()
145 },
146 PageRequest {
147 offset: 0,
148 limit: fused.len() as u32,
149 },
150 )
151 .await?;
152 let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
153 fused.retain(|h| alive.contains(&h.entity_id));
154 }
155
156 fused.truncate(limit as usize);
157 Ok(fused)
158 }
159
160 pub async fn knn(
166 &self,
167 namespace: Option<&str>,
168 query_vector: Vec<f32>,
169 top_k: u32,
170 ) -> RuntimeResult<Vec<VectorSearchHit>> {
171 let ns = self.ns(namespace).to_string();
172 Ok(self
173 .vectors(namespace)?
174 .search(VectorSearchRequest {
175 query_embedding: query_vector,
176 top_k,
177 namespace: Some(ns),
178 kind: Some(SubstrateKind::Entity),
179 })
180 .await?)
181 }
182
183 pub async fn rerank(
189 &self,
190 namespace: Option<&str>,
191 query_vector: &[f32],
192 candidate_ids: &[Uuid],
193 top_k: u32,
194 ) -> RuntimeResult<Vec<VectorSearchHit>> {
195 let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
196 let ns = self.ns(namespace).to_string();
197 let all_hits = self
198 .vectors(namespace)?
199 .search(VectorSearchRequest {
200 query_embedding: query_vector.to_vec(),
201 top_k: candidate_ids.len() as u32,
202 namespace: Some(ns),
203 kind: Some(SubstrateKind::Entity),
204 })
205 .await?;
206 let mut hits: Vec<VectorSearchHit> = all_hits
207 .into_iter()
208 .filter(|h| candidate_set.contains(&h.subject_id))
209 .collect();
210 hits.sort_by(|a, b| b.score.cmp(&a.score));
211 hits.truncate(top_k as usize);
212 Ok(hits)
213 }
214}
215
216fn rrf_fuse(
220 text_hits: Vec<TextSearchHit>,
221 vector_hits: Vec<VectorSearchHit>,
222 limit: usize,
223) -> Vec<SearchHit> {
224 #[derive(Default)]
225 struct Bucket {
226 score: DeterministicScore,
227 source: Option<SearchSource>,
228 title: Option<String>,
229 snippet: Option<String>,
230 }
231
232 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
233
234 for (i, hit) in text_hits.into_iter().enumerate() {
235 let rank = i + 1; let entry = buckets.entry(hit.subject_id).or_default();
237 entry.score = entry.score + rrf_score(rank, RRF_K);
238 entry.source = Some(match entry.source {
239 Some(SearchSource::Vector) => SearchSource::Both,
240 _ => SearchSource::Text,
241 });
242 if entry.title.is_none() {
243 entry.title = hit.title;
244 }
245 if entry.snippet.is_none() {
246 entry.snippet = hit.snippet;
247 }
248 }
249
250 for (i, hit) in vector_hits.into_iter().enumerate() {
251 let rank = i + 1;
252 let entry = buckets.entry(hit.subject_id).or_default();
253 entry.score = entry.score + rrf_score(rank, RRF_K);
254 entry.source = Some(match entry.source {
255 Some(SearchSource::Text) => SearchSource::Both,
256 _ => SearchSource::Vector,
257 });
258 }
259
260 let mut hits: Vec<SearchHit> = buckets
261 .into_iter()
262 .map(|(id, b)| SearchHit {
263 entity_id: id,
264 score: b.score,
265 source: b.source.expect("each bucket gets a source"),
266 title: b.title,
267 snippet: b.snippet,
268 })
269 .collect();
270
271 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
272 hits.truncate(limit);
273 hits
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use lattice_embed::EmbeddingModel;

    /// Drive a future to completion on a fresh Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(fut)
    }

    /// Build a text hit. `rrf_fuse` ranks by list position, so `rank` and
    /// `score` here are filler.
    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    /// Build a vector hit with a fixed filler score.
    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vec = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vec, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5);
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        // `a` appears in both lists, so it accumulates two RRF contributions.
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vec, 10);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn embed_unconfigured_on_memory_runtime() {
        // The memory runtime has no embedding model, so `embed` must surface
        // the `Unconfigured` error from `embedder()` rather than panic.
        let rt = KhiveRuntime::memory().unwrap();
        match block_on(rt.embed("hello")) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        // Empty input short-circuits before the embedder is consulted, so it
        // succeeds even without a configured model.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        match block_on(rt.embed_batch(&texts)) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }
}