use serde::{Deserialize, Serialize};
/// The job a model performs for search: producing embeddings or reranking.
///
/// Serialized in `snake_case` (`"embedding"`, `"reranker"`); these are the
/// same strings returned by [`ModelRole::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelRole {
    /// Embedding model (presumably backing an [`EmbeddingModel`]).
    Embedding,
    /// Reranking model (presumably backing a [`CrossEncoderReranker`]).
    Reranker,
}

impl ModelRole {
    /// Lowercase identifier for this role; matches the serde representation.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Embedding => "embedding",
            Self::Reranker => "reranker",
        }
    }
}
/// Status of a model (embedding or reranker) as reported to search.
///
/// Serialized in `snake_case`; [`ModelStatus::as_str`] returns the same
/// strings. Only [`ModelStatus::Available`] is treated as usable by
/// [`search_capability`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelStatus {
    /// Model can be used.
    Available,
    /// Presumably: model is not present.
    Missing,
    /// Presumably: model is present but failed validation.
    Invalid,
    /// Presumably: model is being installed.
    Installing,
    /// Presumably: model was turned off.
    Disabled,
}

impl ModelStatus {
    /// Lowercase identifier for this status; matches the serde representation.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Available => "available",
            Self::Missing => "missing",
            Self::Invalid => "invalid",
            Self::Installing => "installing",
            Self::Disabled => "disabled",
        }
    }
}
/// Search modes that can run, given which models are usable.
///
/// Produced by [`search_capability`]; each variant adds one capability on top
/// of the previous one.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchCapability {
    /// No usable embedding model: keyword search only.
    KeywordOnly,
    /// Usable embedding model, but no usable reranker.
    Hybrid,
    /// Both the embedding model and the reranker are usable.
    HybridWithRerank,
}
/// Decides which search modes can run, given the status of each optional model.
///
/// * Embedding model not usable → [`SearchCapability::KeywordOnly`], whatever
///   the reranker's state.
/// * Embedding model usable, reranker absent or not usable →
///   [`SearchCapability::Hybrid`].
/// * Both usable → [`SearchCapability::HybridWithRerank`].
///
/// Only `Some(ModelStatus::Available)` counts as usable; `None` and every other
/// status are treated as unusable.
pub fn search_capability(
    embedding: Option<ModelStatus>,
    reranker: Option<ModelStatus>,
) -> SearchCapability {
    let usable = |status: Option<ModelStatus>| matches!(status, Some(ModelStatus::Available));

    if !usable(embedding) {
        // No usable embedding model: keyword-only, regardless of the reranker.
        return SearchCapability::KeywordOnly;
    }

    if usable(reranker) {
        SearchCapability::HybridWithRerank
    } else {
        SearchCapability::Hybrid
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn capability_degrades_gracefully() {
        use super::ModelStatus::{Available, Missing};
        use super::SearchCapability::{Hybrid, HybridWithRerank, KeywordOnly};

        // (embedding status, reranker status, expected capability)
        let cases = [
            (None, None, KeywordOnly),
            (Some(Missing), None, KeywordOnly),
            (Some(Available), None, Hybrid),
            (Some(Available), Some(Missing), Hybrid),
            (Some(Available), Some(Available), HybridWithRerank),
        ];

        for (embedding, reranker, expected) in cases.iter().copied() {
            assert_eq!(
                search_capability(embedding, reranker),
                expected,
                "embedding={:?}, reranker={:?}",
                embedding,
                reranker
            );
        }
    }
}
/// A chunk retrieved by vector (embedding) search.
///
/// NOTE(review): presumably one entry per nearest-neighbour hit before
/// reranking; confirm against the code that builds these.
#[derive(Debug, Clone)]
pub struct VectorCandidate {
    /// Chunk that matched.
    pub chunk_id: orbok_core::ChunkId,
    /// Presumably the file containing the chunk; confirm.
    pub file_id: orbok_core::FileId,
    /// Position in the result list. NOTE(review): 0- vs 1-based is not visible here; confirm.
    pub rank: u32,
    /// Similarity score. Presumably higher means closer; confirm with the producer.
    pub score: f32,
}
/// Text-embedding backend.
///
/// The `Send + Sync` bound lets one instance be shared across threads.
pub trait EmbeddingModel: Send + Sync {
    /// Model identifier (e.g. `"mock"` for `MockEmbeddingModel`).
    fn name(&self) -> &str;
    /// Model version string (e.g. `"v1"` for `MockEmbeddingModel`).
    fn version(&self) -> &str;
    /// Expected length of each vector returned by `embed_batch`.
    fn dimension(&self) -> u32;
    /// Embeds every text in `texts`.
    ///
    /// `MockEmbeddingModel` returns one L2-normalized vector of length
    /// `dimension()` per input, in input order. NOTE(review): whether every
    /// implementation must normalize its output is not stated here; confirm.
    fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>>;
}
/// Cosine similarity of `a` and `b`: `a·b / (‖a‖ · ‖b‖)`.
///
/// The result does not depend on vector magnitude, so inputs do not need to be
/// pre-normalized. It is clamped to `[-1.0, 1.0]` to absorb floating-point
/// rounding. For unit-length inputs (e.g. after `l2_normalize`) it equals their
/// dot product.
///
/// If either vector has a norm at or below `1e-10` (including empty and
/// all-zero vectors), the angle is undefined and `0.0` is returned instead of
/// `NaN`/`inf`.
///
/// Elements are paired with `zip`, so when the lengths differ only the common
/// prefix is compared. Callers should pass vectors of equal dimension.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // One pass: accumulate the dot product and both squared norms together.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    // Same near-zero threshold as `l2_normalize`: such vectors have no direction.
    if norm_a <= 1e-10 || norm_b <= 1e-10 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// Scales `v` in place to unit L2 (Euclidean) length.
///
/// Vectors whose norm is at or below `1e-10` (including empty and all-zero
/// vectors) are left unchanged, because dividing by a near-zero norm would
/// produce huge or non-finite values.
///
/// Takes a slice rather than `&mut Vec<f32>`, so it works on arrays and
/// sub-slices too. Existing `&mut Vec<f32>` call sites still compile through
/// deref coercion.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
/// Serializes `v` as a byte blob of little-endian `f32`s, 4 bytes per element.
///
/// `blob_to_vec` performs the inverse conversion.
pub fn vec_to_blob(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    for value in v {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
/// Decodes a blob of little-endian `f32`s (as produced by `vec_to_blob`).
///
/// Returns `None` unless the blob is exactly `expected_dim * 4` bytes long.
/// This rejects both truncated blobs and vectors stored with a different
/// dimension.
pub fn blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<f32>> {
    const WIDTH: usize = std::mem::size_of::<f32>();
    let expected_len = expected_dim as usize * WIDTH;
    (blob.len() == expected_len).then(|| {
        blob.chunks_exact(WIDTH)
            .map(|chunk| {
                let mut raw = [0_u8; WIDTH];
                raw.copy_from_slice(chunk);
                f32::from_le_bytes(raw)
            })
            .collect()
    })
}
/// Deterministic stand-in embedding model for tests.
///
/// Each text is hashed with SHA-256. The first [`MockEmbeddingModel::DIM`]
/// digest bytes are each scaled into `[0.0, 1.0]` by dividing by 255, and the
/// resulting vector is L2-normalized. Identical texts therefore always map to
/// identical vectors.
pub struct MockEmbeddingModel;

impl MockEmbeddingModel {
    /// Embedding width. A single constant keeps `dimension()` and
    /// `embed_batch()` consistent with each other.
    const DIM: usize = 8;
}

impl EmbeddingModel for MockEmbeddingModel {
    fn name(&self) -> &str {
        "mock"
    }

    fn version(&self) -> &str {
        "v1"
    }

    fn dimension(&self) -> u32 {
        Self::DIM as u32
    }

    fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>> {
        use sha2::{Digest, Sha256};

        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            let digest = Sha256::digest(text.as_bytes());
            let mut embedding: Vec<f32> = digest
                .iter()
                .take(Self::DIM)
                .map(|&byte| f32::from(byte) / 255.0)
                .collect();
            l2_normalize(&mut embedding);
            embeddings.push(embedding);
        }
        Ok(embeddings)
    }
}
#[cfg(test)]
mod embedding_tests {
    use super::*;

    /// Absolute tolerance for float comparisons in this module.
    const EPS: f32 = 1e-6;

    #[test]
    fn mock_embed_batch() {
        let model = MockEmbeddingModel;
        let dim = model.dimension() as usize;
        let embeddings = model.embed_batch(&["hello world", "foo bar"]).unwrap();
        assert_eq!(embeddings.len(), 2);
        embeddings
            .iter()
            .for_each(|embedding| assert_eq!(embedding.len(), dim));
    }

    #[test]
    fn blob_roundtrip_and_dim_mismatch() {
        let original = vec![0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let blob = vec_to_blob(&original);
        assert_eq!(blob.len(), 32);

        let decoded = blob_to_vec(&blob, 8).unwrap();
        for (want, got) in original.iter().zip(decoded.iter()) {
            assert!((want - got).abs() < EPS);
        }

        assert!(blob_to_vec(&blob, 16).is_none(), "dim mismatch must return None");
    }

    #[test]
    fn normalize_produces_unit_vector() {
        let mut v = vec![3.0_f32, 4.0];
        l2_normalize(&mut v);
        let length = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_sim_identical_vectors() {
        let mut v = vec![1.0_f32, 2.0, 3.0];
        l2_normalize(&mut v);
        let similarity = cosine_similarity(&v, &v);
        assert!((similarity - 1.0).abs() < EPS);
    }
}
/// A passage handed to a `CrossEncoderReranker` for scoring against a query.
#[derive(Debug, Clone)]
pub struct RerankCandidate {
    /// Chunk the passage came from. `MockReranker` copies it into the matching `RerankScore`.
    pub chunk_id: orbok_core::ChunkId,
    /// Passage text that is scored against the query.
    pub passage_text: String,
}
/// Reranker output for a single candidate.
#[derive(Debug, Clone)]
pub struct RerankScore {
    /// Chunk this score belongs to (taken from the corresponding `RerankCandidate`).
    pub chunk_id: orbok_core::ChunkId,
    /// Relevance score. `MockReranker` orders results so higher scores come first.
    pub score: f32,
}
/// Reranking backend that scores (query, passage) pairs.
///
/// The `Send + Sync` bound lets one instance be shared across threads.
pub trait CrossEncoderReranker: Send + Sync {
    /// Reranker identifier (e.g. `"mock-reranker"` for `MockReranker`).
    fn name(&self) -> &str;
    /// Reranker version string.
    fn version(&self) -> &str;
    /// Maximum number of candidates the reranker is intended to score per call.
    ///
    /// NOTE(review): `MockReranker::rerank` does not enforce this limit;
    /// presumably callers truncate the candidate list. Confirm.
    fn max_candidates(&self) -> u32;
    /// Scores `candidates` against `query`.
    ///
    /// `MockReranker` returns one score per candidate, sorted by descending
    /// score. NOTE(review): whether all implementations must sort is not
    /// stated here; confirm.
    fn rerank(&self, query: &str, candidates: &[RerankCandidate])
        -> orbok_core::OrbokResult<Vec<RerankScore>>;
}
/// Deterministic reranker for tests.
///
/// Each passage is scored by its length in bytes, so longer passages rank
/// higher. The query is ignored.
pub struct MockReranker;

impl CrossEncoderReranker for MockReranker {
    fn name(&self) -> &str {
        "mock-reranker"
    }

    fn version(&self) -> &str {
        "v1"
    }

    fn max_candidates(&self) -> u32 {
        20
    }

    fn rerank(
        &self,
        _query: &str,
        candidates: &[RerankCandidate],
    ) -> orbok_core::OrbokResult<Vec<RerankScore>> {
        let mut ranked = Vec::with_capacity(candidates.len());
        for candidate in candidates {
            // `String::len` counts bytes, not characters.
            let score = candidate.passage_text.len() as f32;
            ranked.push(RerankScore {
                chunk_id: candidate.chunk_id.clone(),
                score,
            });
        }
        // Highest score first. `sort_by` is stable, so passages of equal length
        // keep their input order; incomparable (NaN) scores are treated as equal.
        ranked.sort_by(|lhs: &RerankScore, rhs: &RerankScore| {
            rhs.score
                .partial_cmp(&lhs.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(ranked)
    }
}
#[cfg(test)]
mod reranker_tests {
    use super::*;
    use orbok_core::ChunkId;

    /// Builds a candidate with the given chunk id and passage text.
    fn candidate(id: &str, text: &str) -> RerankCandidate {
        RerankCandidate {
            chunk_id: ChunkId::from_string(id.to_string()),
            passage_text: text.into(),
        }
    }

    #[test]
    fn mock_reranker_orders_by_length() {
        let candidates = vec![
            candidate("c1", "short"),
            candidate("c2", "a much longer passage"),
        ];
        let scores = MockReranker.rerank("query", &candidates).unwrap();
        let top = &scores[0];
        assert_eq!(top.chunk_id.as_str(), "c2", "longer passage should rank first");
    }

    #[test]
    fn rerank_max_candidates_limit() {
        let limit = MockReranker.max_candidates();
        assert_ne!(limit, 0);
    }
}