1use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ModelRole {
17 Embedding,
18 Reranker,
19}
20
21impl ModelRole {
22 pub fn as_str(&self) -> &'static str {
23 match self {
24 ModelRole::Embedding => "embedding",
25 ModelRole::Reranker => "reranker",
26 }
27 }
28}
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum ModelStatus {
34 Available,
35 Missing,
36 Invalid,
37 Installing,
38 Disabled,
39}
40
41impl ModelStatus {
42 pub fn as_str(&self) -> &'static str {
43 match self {
44 ModelStatus::Available => "available",
45 ModelStatus::Missing => "missing",
46 ModelStatus::Invalid => "invalid",
47 ModelStatus::Installing => "installing",
48 ModelStatus::Disabled => "disabled",
49 }
50 }
51}
52
/// Search modes that can be offered, derived from model availability by
/// [`search_capability`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchCapability {
    /// Keyword search only; chosen when the embedding model is absent or not `Available`.
    KeywordOnly,
    /// Hybrid search (presumably keyword + vector); chosen when the embedding model is
    /// `Available` but the reranker is not.
    Hybrid,
    /// Hybrid search plus reranking; chosen when both models are `Available`.
    HybridWithRerank,
}
64
65pub fn search_capability(
67 embedding: Option<ModelStatus>,
68 reranker: Option<ModelStatus>,
69) -> SearchCapability {
70 match (embedding, reranker) {
71 (Some(ModelStatus::Available), Some(ModelStatus::Available)) => {
72 SearchCapability::HybridWithRerank
73 }
74 (Some(ModelStatus::Available), _) => SearchCapability::Hybrid,
75 _ => SearchCapability::KeywordOnly,
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Capability must step down as models become unavailable.
    #[test]
    fn capability_degrades_gracefully() {
        let table = [
            (None, None, SearchCapability::KeywordOnly),
            (Some(ModelStatus::Missing), None, SearchCapability::KeywordOnly),
            (Some(ModelStatus::Available), None, SearchCapability::Hybrid),
            (
                Some(ModelStatus::Available),
                Some(ModelStatus::Missing),
                SearchCapability::Hybrid,
            ),
            (
                Some(ModelStatus::Available),
                Some(ModelStatus::Available),
                SearchCapability::HybridWithRerank,
            ),
        ];
        for (embedding, reranker, expected) in table {
            assert_eq!(search_capability(embedding, reranker), expected);
        }
    }
}
105
/// One hit from a vector (embedding) search — NOTE(review): producer not visible here.
#[derive(Debug, Clone)]
pub struct VectorCandidate {
    /// Chunk that matched.
    pub chunk_id: orbok_core::ChunkId,
    /// File the chunk belongs to (presumably; confirm against orbok_core).
    pub file_id: orbok_core::FileId,
    /// Position in the result list — TODO confirm whether 0- or 1-based.
    pub rank: u32,
    /// Similarity score; presumably from `cosine_similarity`, higher = closer (confirm).
    pub score: f32,
}
114
/// A text-embedding model that turns strings into `f32` vectors of a fixed dimension.
///
/// `Send + Sync` is required so implementations can be shared across threads.
pub trait EmbeddingModel: Send + Sync {
    /// Model identifier (for example `"mock"` for `MockEmbeddingModel`).
    fn name(&self) -> &str;
    /// Model version string.
    fn version(&self) -> &str;
    /// Length of every vector produced by `embed_batch`.
    fn dimension(&self) -> u32;
    /// Embeds `texts`; expected to return one vector per input, presumably in input order
    /// (the mock implementation does; confirm for real backends).
    fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>>;
}
129
/// Dot product of `a` and `b`.
///
/// This equals the cosine similarity only when both inputs are already
/// L2-normalized (see [`l2_normalize`]); no normalization happens here.
/// If the slices differ in length, the extra trailing elements are ignored.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
134
/// Scales `v` in place so that its Euclidean (L2) length is 1.
///
/// If the norm is not greater than `1e-10` (including the all-zero vector), `v` is left
/// unchanged; this avoids dividing by a near-zero norm. For the same reason, a vector
/// whose norm is NaN is also left unchanged, because the comparison is false.
///
/// Takes any mutable `f32` slice. Existing `&mut Vec<f32>` arguments still work
/// through deref coercion.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
144
/// Serializes `v` as packed little-endian IEEE-754 `f32`s, 4 bytes per element.
///
/// Inverse of [`blob_to_vec`].
pub fn vec_to_blob(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    for value in v {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
150
/// Decodes a blob produced by [`vec_to_blob`] back into `expected_dim` `f32`s.
///
/// Returns `None` in two cases:
/// - `blob.len()` is not exactly `expected_dim * 4` bytes, for example when the
///   dimension is wrong or the blob is truncated;
/// - `expected_dim * 4` does not fit in `usize`, which can happen on 32-bit targets.
///
/// The second check uses `checked_mul`. A plain multiplication would wrap there in
/// release builds, and a wrapped length could match a short blob and return a vector
/// whose length differs from `expected_dim`.
pub fn blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<f32>> {
    let expected_len = (expected_dim as usize).checked_mul(4)?;
    if blob.len() != expected_len {
        return None;
    }
    Some(
        blob.chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
    )
}
163
/// Deterministic stand-in embedding model for tests.
///
/// Hashes each text with SHA-256, takes the first 8 bytes as an 8-dimensional
/// vector, and L2-normalizes it.
pub struct MockEmbeddingModel;
173
174impl EmbeddingModel for MockEmbeddingModel {
175 fn name(&self) -> &str {
176 "mock"
177 }
178 fn version(&self) -> &str {
179 "v1"
180 }
181 fn dimension(&self) -> u32 {
182 8
183 }
184 fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>> {
185 use sha2::{Digest, Sha256};
186 texts
187 .iter()
188 .map(|text| {
189 let digest = Sha256::digest(text.as_bytes());
190 let mut v: Vec<f32> = digest[..8].iter().map(|&b| b as f32 / 255.0).collect();
191 l2_normalize(&mut v);
192 Ok(v)
193 })
194 .collect()
195 }
196}
197
#[cfg(test)]
mod embedding_tests {
    use super::*;

    /// The mock returns one vector per input, each of the advertised dimension.
    #[test]
    fn mock_embed_batch() {
        let model = MockEmbeddingModel;
        let inputs = ["hello world", "foo bar"];
        let embeddings = model.embed_batch(&inputs).unwrap();
        assert_eq!(embeddings.len(), 2);
        let want_dim = model.dimension() as usize;
        for embedding in embeddings.iter() {
            assert_eq!(embedding.len(), want_dim);
        }
    }

    /// Byte encoding round-trips, and a wrong dimension is rejected.
    #[test]
    fn blob_roundtrip_and_dim_mismatch() {
        let original = vec![0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let encoded = vec_to_blob(&original);
        assert_eq!(encoded.len(), 32);

        let decoded = blob_to_vec(&encoded, 8).unwrap();
        original
            .iter()
            .zip(decoded.iter())
            .for_each(|(want, got)| assert!((want - got).abs() < 1e-6));

        let wrong_dim = blob_to_vec(&encoded, 16);
        assert!(wrong_dim.is_none(), "dim mismatch must return None");
    }

    #[test]
    fn normalize_produces_unit_vector() {
        let mut v = vec![3.0_f32, 4.0];
        l2_normalize(&mut v);
        let length = v.iter().map(|c| c * c).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_sim_identical_vectors() {
        let mut v = vec![1.0_f32, 2.0, 3.0];
        l2_normalize(&mut v);
        let self_similarity = cosine_similarity(&v, &v);
        assert!((self_similarity - 1.0).abs() < 1e-6);
    }
}
247
/// A passage submitted to a [`CrossEncoderReranker`] for scoring against a query.
#[derive(Debug, Clone)]
pub struct RerankCandidate {
    /// Chunk the passage came from; echoed back in the matching `RerankScore`.
    pub chunk_id: orbok_core::ChunkId,
    /// Text the reranker scores.
    pub passage_text: String,
}
257
/// Score a reranker assigned to one candidate chunk.
#[derive(Debug, Clone)]
pub struct RerankScore {
    /// Chunk this score belongs to.
    pub chunk_id: orbok_core::ChunkId,
    /// Relevance score; `MockReranker` sorts these highest-first.
    pub score: f32,
}
264
/// A reranker that scores (query, passage) pairs.
///
/// `Send + Sync` is required so implementations can be shared across threads.
pub trait CrossEncoderReranker: Send + Sync {
    /// Reranker identifier (for example `"mock-reranker"`).
    fn name(&self) -> &str;
    /// Reranker version string.
    fn version(&self) -> &str;
    /// Presumably the maximum number of candidates callers should pass to `rerank`.
    /// NOTE(review): `MockReranker` does not enforce it; confirm who does.
    fn max_candidates(&self) -> u32;
    /// Scores `candidates` against `query`.
    /// NOTE(review): ordering of the result is not specified here; `MockReranker`
    /// returns scores sorted highest-first.
    fn rerank(
        &self,
        query: &str,
        candidates: &[RerankCandidate],
    ) -> orbok_core::OrbokResult<Vec<RerankScore>>;
}
280
/// Deterministic stand-in reranker for tests: scores each passage by its byte length
/// and ignores the query.
pub struct MockReranker;
284
285impl CrossEncoderReranker for MockReranker {
286 fn name(&self) -> &str {
287 "mock-reranker"
288 }
289 fn version(&self) -> &str {
290 "v1"
291 }
292 fn max_candidates(&self) -> u32 {
293 20
294 }
295 fn rerank(
296 &self,
297 _query: &str,
298 candidates: &[RerankCandidate],
299 ) -> orbok_core::OrbokResult<Vec<RerankScore>> {
300 let mut scores: Vec<RerankScore> = candidates
301 .iter()
302 .map(|c| RerankScore {
303 chunk_id: c.chunk_id.clone(),
304 score: c.passage_text.len() as f32,
305 })
306 .collect();
307 scores.sort_by(|a, b| {
308 b.score
309 .partial_cmp(&a.score)
310 .unwrap_or(std::cmp::Ordering::Equal)
311 });
312 Ok(scores)
313 }
314}
315
#[cfg(test)]
mod reranker_tests {
    use super::*;
    use orbok_core::ChunkId;

    /// Builds a rerank candidate from plain string parts.
    fn candidate(id: &str, text: &str) -> RerankCandidate {
        RerankCandidate {
            chunk_id: ChunkId::from_string(id.to_string()),
            passage_text: text.into(),
        }
    }

    #[test]
    fn mock_reranker_orders_by_length() {
        let reranker = MockReranker;
        let candidates = [
            candidate("c1", "short"),
            candidate("c2", "a much longer passage"),
        ];
        let scores = reranker.rerank("query", &candidates).unwrap();
        let best = &scores[0];
        assert_eq!(
            best.chunk_id.as_str(),
            "c2",
            "longer passage should rank first"
        );
    }

    #[test]
    fn rerank_max_candidates_limit() {
        let limit = MockReranker.max_candidates();
        assert!(limit > 0);
    }
}
349
/// Runtime used to execute model inference.
///
/// Derives `Copy` like the other fieldless enums in this module (`ModelRole`,
/// `ModelStatus`), so a backend value can be passed around or read out of a config
/// without `.clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InferenceBackend {
    /// Identified as `"candle-cpu"`; presumably the Candle framework on CPU.
    CandleCpu,
    /// Identified as `"candle-cuda"`; presumably the Candle framework on a CUDA GPU.
    CandleCuda,
    /// Identified as `"onnx-runtime"`.
    OnnxRuntime,
    /// Identified as `"mock"`; for tests.
    Mock,
}

impl InferenceBackend {
    /// Stable kebab-case identifier for this backend.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::CandleCpu => "candle-cpu",
            Self::CandleCuda => "candle-cuda",
            Self::OnnxRuntime => "onnx-runtime",
            Self::Mock => "mock",
        }
    }
}
375
376#[derive(Debug, Clone)]
387pub struct EmbeddingModelConfig {
388 pub weights_path: String,
390 pub tokenizer_path: Option<String>,
392 pub dimension: u32,
394 pub max_seq_len: u32,
396 pub backend: InferenceBackend,
398 pub model_name: String,
400 pub model_version: String,
402}
403
404impl EmbeddingModelConfig {
405 pub fn weights_exist(&self) -> bool {
407 std::path::Path::new(&self.weights_path).exists()
408 }
409}
410
/// Settings for loading a local reranker model.
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Filesystem path to the model weights.
    pub weights_path: String,
    /// Optional tokenizer path; what `None` implies is decided by the loader (not shown here).
    pub tokenizer_path: Option<String>,
    /// Maximum input sequence length — presumably in tokens for the (query, passage) pair; confirm.
    pub max_seq_len: u32,
    /// Runtime used to execute the model.
    pub backend: InferenceBackend,
    /// Model name.
    pub model_name: String,
    /// Model version.
    pub model_version: String,
}
421
/// Quantizes each component to `round(x * 127)`, clamped to `[-127, 127]`.
///
/// Intended for L2-normalized vectors, whose components lie in `[-1, 1]`; larger
/// magnitudes saturate. NaN becomes 0 because a float→int `as` cast maps NaN to 0.
/// `-128` is never produced, so the range stays symmetric.
pub fn quantize_to_i8(v: &[f32]) -> Vec<i8> {
    const SCALE: f32 = 127.0;
    let mut quantized = Vec::with_capacity(v.len());
    for &component in v {
        let scaled = (component * SCALE).round();
        quantized.push(scaled.clamp(-SCALE, SCALE) as i8);
    }
    quantized
}
434
/// Maps quantized components back to `f32` by dividing by 127 (inverse of
/// [`quantize_to_i8`], up to rounding error).
pub fn dequantize_from_i8(v: &[i8]) -> Vec<f32> {
    const SCALE: f32 = 127.0;
    v.iter().copied().map(f32::from).map(|x| x / SCALE).collect()
}
439
/// Stores each `i8` as one byte with the same bit pattern (two's complement),
/// so the blob is exactly `v.len()` bytes.
pub fn i8_vec_to_blob(v: &[i8]) -> Vec<u8> {
    v.iter().map(|&value| value.to_le_bytes()[0]).collect()
}
445
/// Decodes a blob from [`i8_vec_to_blob`].
///
/// Returns `None` unless the blob holds exactly `expected_dim` bytes.
pub fn i8_blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<i8>> {
    let length_matches = blob.len() == expected_dim as usize;
    length_matches.then(|| blob.iter().map(|&byte| i8::from_le_bytes([byte])).collect())
}
453
/// Dot product of two `i8`-quantized vectors, rescaled to the `f32` domain.
///
/// Equivalent to `cosine_similarity(&dequantize_from_i8(a), &dequantize_from_i8(b))`:
/// it approximates cosine similarity only when the original vectors were
/// L2-normalized before quantization. As with `cosine_similarity`, extra trailing
/// elements of the longer slice are ignored.
///
/// The old form allocated two `Vec<f32>` per call. This version accumulates the
/// integer products exactly in `i64` (no overflow for any realistic dimension) and
/// divides by `127²` once, so it does not allocate.
pub fn cosine_similarity_i8(a: &[i8], b: &[i8]) -> f32 {
    const SCALE_SQUARED: f32 = 127.0 * 127.0;
    let dot: i64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| i64::from(x) * i64::from(y))
        .sum();
    dot as f32 / SCALE_SQUARED
}
460
#[cfg(test)]
mod quantization_tests {
    use super::*;

    /// Both encodings can be produced from the same vector; int8 uses 1 byte per element vs 4.
    #[test]
    fn fp32_and_i8_both_available() {
        let sample = vec![0.6f32, 0.8, 0.0, -0.5];
        let fp32_blob = vec_to_blob(&sample);
        let int8_blob = i8_vec_to_blob(&quantize_to_i8(&sample));
        assert_eq!(int8_blob.len() * 4, fp32_blob.len());
    }

    #[test]
    fn int8_is_4x_smaller_than_fp32() {
        const DIM: usize = 384;
        let mut normalized: Vec<f32> = (0..DIM).map(|i| (i as f32 / DIM as f32) - 0.5).collect();
        l2_normalize(&mut normalized);

        let fp32_bytes = vec_to_blob(&normalized).len();
        let int8_bytes = i8_vec_to_blob(&quantize_to_i8(&normalized)).len();

        assert_eq!(fp32_bytes, DIM * 4);
        assert_eq!(int8_bytes, DIM);
        assert_eq!(fp32_bytes / int8_bytes, 4);
    }

    #[test]
    fn quantization_quality_loss_is_small() {
        let mut v: Vec<f32> = (0..384).map(|i| (i as f32 * 0.017).sin()).collect();
        l2_normalize(&mut v);
        let q = quantize_to_i8(&v);

        let loss = (cosine_similarity_i8(&q, &q) - cosine_similarity(&v, &v)).abs();
        assert!(loss < 0.02, "quantization quality loss too high: {:.4}", loss);
    }

    #[test]
    fn fp32_int8_roundtrip_within_tolerance() {
        let mut v: Vec<f32> = vec![0.3, -0.7, 0.5, 0.1, -0.2, 0.8, -0.4, 0.6];
        l2_normalize(&mut v);
        let restored = dequantize_from_i8(&quantize_to_i8(&v));
        for (orig, deq) in v.iter().zip(restored.iter()) {
            assert!(
                (orig - deq).abs() < 0.01,
                "round-trip error too large: {orig:.4} → {deq:.4}"
            );
        }
    }
}