1use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ModelRole {
17 Embedding,
18 Reranker,
19}
20
21impl ModelRole {
22 pub fn as_str(&self) -> &'static str {
23 match self {
24 ModelRole::Embedding => "embedding",
25 ModelRole::Reranker => "reranker",
26 }
27 }
28}
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum ModelStatus {
34 Available,
35 Missing,
36 Invalid,
37 Installing,
38 Disabled,
39}
40
41impl ModelStatus {
42 pub fn as_str(&self) -> &'static str {
43 match self {
44 ModelStatus::Available => "available",
45 ModelStatus::Missing => "missing",
46 ModelStatus::Invalid => "invalid",
47 ModelStatus::Installing => "installing",
48 ModelStatus::Disabled => "disabled",
49 }
50 }
51}
52
/// The search mode that can be offered, as computed by `search_capability`
/// from the status of the embedding and reranker models.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchCapability {
    /// The embedding model is not available.
    KeywordOnly,
    /// The embedding model is available but the reranker is not.
    Hybrid,
    /// Both the embedding model and the reranker are available.
    HybridWithRerank,
}
64
65pub fn search_capability(
67 embedding: Option<ModelStatus>,
68 reranker: Option<ModelStatus>,
69) -> SearchCapability {
70 match (embedding, reranker) {
71 (Some(ModelStatus::Available), Some(ModelStatus::Available)) => {
72 SearchCapability::HybridWithRerank
73 }
74 (Some(ModelStatus::Available), _) => SearchCapability::Hybrid,
75 _ => SearchCapability::KeywordOnly,
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every tier of `search_capability`, including the cases where the
    /// reranker is available but the embedding model is not (must stay
    /// keyword-only) and the non-`Available` statuses.
    #[test]
    fn capability_degrades_gracefully() {
        use ModelStatus::{Available, Disabled, Installing, Invalid, Missing};
        use SearchCapability::{Hybrid, HybridWithRerank, KeywordOnly};

        let cases = [
            // No usable embedding model: keyword search only.
            (None, None, KeywordOnly),
            (Some(Missing), None, KeywordOnly),
            (Some(Invalid), None, KeywordOnly),
            (Some(Installing), None, KeywordOnly),
            (Some(Disabled), None, KeywordOnly),
            // An available reranker alone must not upgrade the capability.
            (None, Some(Available), KeywordOnly),
            (Some(Missing), Some(Available), KeywordOnly),
            (Some(Installing), Some(Available), KeywordOnly),
            // Embedding available, reranker unusable: hybrid without rerank.
            (Some(Available), None, Hybrid),
            (Some(Available), Some(Missing), Hybrid),
            (Some(Available), Some(Invalid), Hybrid),
            (Some(Available), Some(Installing), Hybrid),
            (Some(Available), Some(Disabled), Hybrid),
            // Both available: full pipeline.
            (Some(Available), Some(Available), HybridWithRerank),
        ];

        for (embedding, reranker, expected) in cases {
            assert_eq!(
                search_capability(embedding, reranker),
                expected,
                "embedding={embedding:?}, reranker={reranker:?}"
            );
        }
    }
}
105
106#[derive(Debug, Clone)]
108pub struct VectorCandidate {
109 pub chunk_id: orbok_core::ChunkId,
110 pub file_id: orbok_core::FileId,
111 pub rank: u32,
112 pub score: f32,
113}
114
115pub trait EmbeddingModel: Send + Sync {
119 fn name(&self) -> &str;
121 fn version(&self) -> &str;
123 fn dimension(&self) -> u32;
125 fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>>;
128}
129
/// Dot product of `a` and `b`.
///
/// This equals the cosine similarity **only when both inputs have unit L2
/// norm** (see `l2_normalize`); no normalization is performed here.
///
/// If the slices differ in length, the extra trailing components of the
/// longer one are ignored. Two empty slices yield zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
134
/// Scales `v` in place so that its L2 (Euclidean) norm becomes 1.
///
/// Accepts any mutable `f32` slice; a `&mut Vec<f32>` still works at call
/// sites through deref coercion.
///
/// Vectors whose norm is not greater than `1e-10` (including empty and
/// all-zero vectors) are left untouched to avoid dividing by (almost) zero.
/// A vector containing NaN is also left untouched, because the comparison
/// with the threshold is then false.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
144
/// Serializes an `f32` vector into bytes: each component is written as its
/// 4-byte little-endian representation, in order.
///
/// The result is `4 * v.len()` bytes long and can be decoded with
/// `blob_to_vec`.
pub fn vec_to_blob(v: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(v.len() * std::mem::size_of::<f32>());
    for component in v {
        blob.extend_from_slice(&component.to_le_bytes());
    }
    blob
}
150
/// Decodes a byte blob of little-endian `f32` values (the format written by
/// `vec_to_blob`) into a vector.
///
/// Returns `None` unless `blob` is exactly `4 * expected_dim` bytes long.
/// The expected length is computed with overflow checking, so a huge
/// `expected_dim` on a target with a narrow `usize` yields `None` instead of
/// panicking or wrapping around to a length that might accidentally match.
pub fn blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<f32>> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();

    let expected_len = usize::try_from(expected_dim)
        .ok()?
        .checked_mul(F32_BYTES)?;
    if blob.len() != expected_len {
        return None;
    }

    Some(
        blob.chunks_exact(F32_BYTES)
            // `chunks_exact` guarantees every chunk has exactly four bytes.
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
    )
}
163
164pub struct MockEmbeddingModel;
173
174impl EmbeddingModel for MockEmbeddingModel {
175 fn name(&self) -> &str {
176 "mock"
177 }
178 fn version(&self) -> &str {
179 "v1"
180 }
181 fn dimension(&self) -> u32 {
182 8
183 }
184 fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>> {
185 use sha2::{Digest, Sha256};
186 texts
187 .iter()
188 .map(|text| {
189 let digest = Sha256::digest(text.as_bytes());
190 let mut v: Vec<f32> = digest[..8]
191 .iter()
192 .map(|&b| b as f32 / 255.0)
193 .collect();
194 l2_normalize(&mut v);
195 Ok(v)
196 })
197 .collect()
198 }
199}
200
#[cfg(test)]
mod embedding_tests {
    use super::*;

    /// The mock model returns one vector per text, each of the advertised dimension.
    #[test]
    fn mock_embed_batch() {
        let model = MockEmbeddingModel;
        let expected_dim = model.dimension() as usize;

        let embeddings = model.embed_batch(&["hello world", "foo bar"]).unwrap();

        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), expected_dim);
        }
    }

    /// fp32 vectors survive a blob round-trip; a wrong dimension is rejected.
    #[test]
    fn blob_roundtrip_and_dim_mismatch() {
        let original = vec![0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];

        let blob = vec_to_blob(&original);
        assert_eq!(blob.len(), 32);

        let decoded = blob_to_vec(&blob, 8).unwrap();
        for (expected, actual) in original.iter().zip(&decoded) {
            assert!((expected - actual).abs() < 1e-6);
        }

        assert!(blob_to_vec(&blob, 16).is_none(), "dim mismatch must return None");
    }

    /// Normalizing (3, 4) yields a vector of length 1.
    #[test]
    fn normalize_produces_unit_vector() {
        let mut v = vec![3.0_f32, 4.0];
        l2_normalize(&mut v);

        let squared_sum: f32 = v.iter().map(|x| x * x).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    /// A unit vector has similarity 1 with itself.
    #[test]
    fn cosine_sim_identical_vectors() {
        let mut v = vec![1.0_f32, 2.0, 3.0];
        l2_normalize(&mut v);

        let self_similarity = cosine_similarity(&v, &v);
        assert!((self_similarity - 1.0).abs() < 1e-6);
    }
}
247
248#[derive(Debug, Clone)]
252pub struct RerankCandidate {
253 pub chunk_id: orbok_core::ChunkId,
254 pub passage_text: String,
256}
257
258#[derive(Debug, Clone)]
260pub struct RerankScore {
261 pub chunk_id: orbok_core::ChunkId,
262 pub score: f32,
263}
264
265pub trait CrossEncoderReranker: Send + Sync {
270 fn name(&self) -> &str;
271 fn version(&self) -> &str;
272 fn max_candidates(&self) -> u32;
274 fn rerank(&self, query: &str, candidates: &[RerankCandidate])
275 -> orbok_core::OrbokResult<Vec<RerankScore>>;
276}
277
278pub struct MockReranker;
281
282impl CrossEncoderReranker for MockReranker {
283 fn name(&self) -> &str { "mock-reranker" }
284 fn version(&self) -> &str { "v1" }
285 fn max_candidates(&self) -> u32 { 20 }
286 fn rerank(
287 &self,
288 _query: &str,
289 candidates: &[RerankCandidate],
290 ) -> orbok_core::OrbokResult<Vec<RerankScore>> {
291 let mut scores: Vec<RerankScore> = candidates
292 .iter()
293 .map(|c| RerankScore {
294 chunk_id: c.chunk_id.clone(),
295 score: c.passage_text.len() as f32,
296 })
297 .collect();
298 scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
299 Ok(scores)
300 }
301}
302
#[cfg(test)]
mod reranker_tests {
    use super::*;
    use orbok_core::ChunkId;

    /// The mock reranker puts the longest passage first.
    #[test]
    fn mock_reranker_orders_by_length() {
        let candidate = |id: &str, text: &str| RerankCandidate {
            chunk_id: ChunkId::from_string(id.to_string()),
            passage_text: text.into(),
        };
        let candidates = vec![
            candidate("c1", "short"),
            candidate("c2", "a much longer passage"),
        ];

        let reranker = MockReranker;
        let scores = reranker.rerank("query", &candidates).unwrap();

        assert_eq!(scores[0].chunk_id.as_str(), "c2", "longer passage should rank first");
    }

    /// The mock reranker advertises a non-zero candidate limit.
    #[test]
    fn rerank_max_candidates_limit() {
        let limit = MockReranker.max_candidates();
        assert!(limit > 0);
    }
}
326
/// Runtime used to execute a model.
// NOTE(review): the variant descriptions below are inferred from the variant
// names and the `as_str` identifiers only — confirm against the backends.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InferenceBackend {
    /// Candle on the CPU.
    CandleCpu,
    /// Candle with CUDA.
    CandleCuda,
    /// ONNX Runtime.
    OnnxRuntime,
    /// Mock backend (see `MockEmbeddingModel` / `MockReranker`).
    Mock,
}

impl InferenceBackend {
    /// Returns the stable kebab-case identifier of this backend
    /// (e.g. `"candle-cpu"`, `"onnx-runtime"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::CandleCpu => "candle-cpu",
            Self::CandleCuda => "candle-cuda",
            Self::OnnxRuntime => "onnx-runtime",
            Self::Mock => "mock",
        }
    }
}
352
353#[derive(Debug, Clone)]
364pub struct EmbeddingModelConfig {
365 pub weights_path: String,
367 pub tokenizer_path: Option<String>,
369 pub dimension: u32,
371 pub max_seq_len: u32,
373 pub backend: InferenceBackend,
375 pub model_name: String,
377 pub model_version: String,
379}
380
381impl EmbeddingModelConfig {
382 pub fn weights_exist(&self) -> bool {
384 std::path::Path::new(&self.weights_path).exists()
385 }
386}
387
388#[derive(Debug, Clone)]
390pub struct RerankerConfig {
391 pub weights_path: String,
392 pub tokenizer_path: Option<String>,
393 pub max_seq_len: u32,
394 pub backend: InferenceBackend,
395 pub model_name: String,
396 pub model_version: String,
397}
398
/// Quantizes `f32` components to `i8` using a fixed scale of 127.
///
/// Each component is multiplied by 127, rounded to the nearest integer (ties
/// away from zero) and clamped to `[-127, 127]`, so inputs outside
/// `[-1.0, 1.0]` saturate. NaN components become `0`. The value `-128` is
/// never produced.
pub fn quantize_to_i8(v: &[f32]) -> Vec<i8> {
    const SCALE: f32 = 127.0;

    let mut quantized = Vec::with_capacity(v.len());
    for &component in v {
        let scaled = (component * SCALE).round();
        quantized.push(scaled.clamp(-SCALE, SCALE) as i8);
    }
    quantized
}
411
/// Converts quantized `i8` components back to `f32` by dividing each by 127
/// (the inverse of the scaling applied by `quantize_to_i8`).
pub fn dequantize_from_i8(v: &[i8]) -> Vec<f32> {
    const SCALE: f32 = 127.0;
    v.iter()
        .copied()
        .map(|quantized| f32::from(quantized) / SCALE)
        .collect()
}
416
/// Serializes an `i8` vector into bytes, one byte per component
/// (two's-complement bit pattern, e.g. `-1` becomes `0xFF`).
///
/// The inverse operation is `i8_blob_to_vec`.
pub fn i8_vec_to_blob(v: &[i8]) -> Vec<u8> {
    v.iter().flat_map(|component| component.to_le_bytes()).collect()
}
422
/// Decodes a byte blob written by `i8_vec_to_blob` back into `i8` components,
/// reinterpreting each byte as a two's-complement value.
///
/// Returns `None` unless `blob` is exactly `expected_dim` bytes long.
pub fn i8_blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<i8>> {
    let length_matches = blob.len() == expected_dim as usize;
    length_matches.then(|| blob.iter().map(|&byte| byte as i8).collect())
}
430
/// Dot product of two quantized vectors, computed in `f32` after scaling each
/// component by `1/127`.
///
/// Produces the same result as dequantizing both inputs with
/// `dequantize_from_i8` and passing them to `cosine_similarity`, but
/// dequantizes on the fly instead of allocating two temporary vectors on
/// every call.
///
/// As with `cosine_similarity`, the value is a true cosine only when both
/// dequantized vectors have unit norm, and extra trailing components of the
/// longer input are ignored when the lengths differ.
pub fn cosine_similarity_i8(a: &[i8], b: &[i8]) -> f32 {
    const SCALE: f32 = 127.0;
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (f32::from(x) / SCALE) * (f32::from(y) / SCALE))
        .sum()
}
437
#[cfg(test)]
mod quantization_tests {
    use super::*;

    /// Both the fp32 and the int8 blob can be produced from the same vector.
    #[test]
    fn fp32_and_i8_both_available() {
        let v = vec![0.6f32, 0.8, 0.0, -0.5];

        let blob_fp32 = vec_to_blob(&v);
        let blob_i8 = i8_vec_to_blob(&quantize_to_i8(&v));

        assert_eq!(blob_i8.len() * 4, blob_fp32.len());
    }

    /// An int8 blob needs a quarter of the bytes of the fp32 blob.
    #[test]
    fn int8_is_4x_smaller_than_fp32() {
        let mut vn: Vec<f32> = (0..384).map(|i| (i as f32 / 384.0) - 0.5).collect();
        l2_normalize(&mut vn);

        let fp32_bytes = vec_to_blob(&vn).len();
        let int8_bytes = i8_vec_to_blob(&quantize_to_i8(&vn)).len();

        assert_eq!(fp32_bytes, 384 * 4);
        assert_eq!(int8_bytes, 384);
        assert_eq!(fp32_bytes / int8_bytes, 4);
    }

    /// Self-similarity of a quantized unit vector stays close to the fp32 value.
    #[test]
    fn quantization_quality_loss_is_small() {
        let mut v: Vec<f32> = (0..384).map(|i| (i as f32 * 0.017).sin()).collect();
        l2_normalize(&mut v);
        let q = quantize_to_i8(&v);

        let original_self_sim = cosine_similarity(&v, &v);
        let quantized_self_sim = cosine_similarity_i8(&q, &q);

        assert!((quantized_self_sim - original_self_sim).abs() < 0.02,
            "quantization quality loss too high: {:.4}", (quantized_self_sim - original_self_sim).abs());
    }

    /// Quantize followed by dequantize reproduces each component within 0.01.
    #[test]
    fn fp32_int8_roundtrip_within_tolerance() {
        let mut v: Vec<f32> = vec![0.3, -0.7, 0.5, 0.1, -0.2, 0.8, -0.4, 0.6];
        l2_normalize(&mut v);

        let dequantized = dequantize_from_i8(&quantize_to_i8(&v));

        for (orig, deq) in v.iter().zip(&dequantized) {
            assert!((orig - deq).abs() < 0.01,
                "round-trip error too large: {orig:.4} → {deq:.4}");
        }
    }
}