use crate::constants::EMBEDDING_DIM;
use crate::errors::AppError;
use crate::extract::llm_embedding::LlmEmbedding;
use parking_lot::Mutex;
use std::path::Path;
use std::sync::OnceLock;
/// Process-wide embedding backend shared by every caller in this module.
///
/// Starts empty and is filled at most once, by the first successful call to
/// `get_embedder`. The backend sits behind a `Mutex` because the embedding
/// calls made through it take the lock guard mutably.
static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();
/// Returns the shared embedding backend, detecting and storing it on first use.
///
/// `_models_dir` is accepted for interface compatibility but is not read here.
///
/// # Errors
///
/// Propagates the failure from `LlmEmbedding::detect_available` when no backend
/// can be detected. Nothing is stored in that case, so a later call retries.
pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
    match EMBEDDER.get() {
        // Fast path: already initialised by an earlier call.
        Some(shared) => Ok(shared),
        None => {
            let detected = LlmEmbedding::detect_available()?;
            // If another thread stored a backend in the meantime, `set` is
            // rejected and the one detected here is simply dropped; either way
            // the cell is populated once this statement has run.
            let _ = EMBEDDER.set(Mutex::new(detected));
            Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
        }
    }
}
/// Embeds `text` as a passage (document side) and returns a vector of exactly
/// `EMBEDDING_DIM` components.
///
/// The backend is locked for the duration of the embedding call.
///
/// # Errors
///
/// Propagates any error reported by the backend's `embed_passage`.
pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
    let raw = embedder.lock().embed_passage(text)?;
    // Force the vector to the fixed dimension whatever the backend returned.
    Ok(normalise_dim(raw))
}
/// Embeds `text` as a query (search side) and returns a vector of exactly
/// `EMBEDDING_DIM` components.
///
/// The backend is locked for the duration of the embedding call.
///
/// # Errors
///
/// Propagates any error reported by the backend's `embed_query`.
pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
    let raw = embedder.lock().embed_query(text)?;
    // Force the vector to the fixed dimension whatever the backend returned.
    Ok(normalise_dim(raw))
}
/// Embeds `texts` as passages, processing them in bounded groups.
///
/// `token_counts[i]` is the token count of `texts[i]`. Each text is charged its
/// token count plus a fixed allowance of 8. A group is flushed before adding a
/// text that would push the running total past
/// `REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS`, or once it already holds
/// `REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS` texts. A single text that exceeds the
/// token budget on its own is still embedded, alone in its group.
///
/// The result holds one vector per entry of `texts`, in the same order, each
/// of length `EMBEDDING_DIM`.
///
/// If `token_counts` is shorter than `texts`, the texts without a count are
/// treated as maximally large, so each is embedded in its own group instead of
/// being silently dropped. Extra entries in `token_counts` are ignored.
///
/// # Errors
///
/// Propagates the first error raised while embedding a group; vectors computed
/// before the failure are discarded.
pub fn embed_passages_controlled(
    embedder: &Mutex<LlmEmbedding>,
    texts: &[&str],
    token_counts: &[usize],
) -> Result<Vec<Vec<f32>>, AppError> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    let mut output: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
    let mut group: Vec<&str> = Vec::new();
    let mut current_padded = 0usize;
    // Pad the counts with `usize::MAX` so that `zip` is driven by `texts`: a
    // plain zip would stop at the shorter slice and lose the remaining texts.
    let counts = token_counts
        .iter()
        .copied()
        .chain(std::iter::repeat(usize::MAX));
    for (text, tokens) in texts.iter().zip(counts) {
        let padded = tokens.saturating_add(8);
        // Saturating arithmetic throughout: a huge count must read as "over
        // budget", never wrap around to a small total or panic.
        let over_token_budget = current_padded.saturating_add(padded)
            > crate::constants::REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS;
        let group_full = group.len() >= crate::constants::REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS;
        if (over_token_budget || group_full) && !group.is_empty() {
            flush_group(&mut output, &mut group, embedder)?;
            current_padded = 0;
        }
        group.push(text);
        current_padded = current_padded.saturating_add(padded);
    }
    // Embed whatever is left after the last full group.
    if !group.is_empty() {
        flush_group(&mut output, &mut group, embedder)?;
    }
    Ok(output)
}
/// Embeds every text currently in `group`, appends the dimension-normalised
/// vectors to `output` in order, then empties `group`.
///
/// The backend lock is taken once and held for the whole group. On error the
/// function returns immediately: `output` keeps the vectors produced so far and
/// `group` is left untouched.
fn flush_group(
    output: &mut Vec<Vec<f32>>,
    group: &mut Vec<&str>,
    embedder: &Mutex<LlmEmbedding>,
) -> Result<(), AppError> {
    let mut backend = embedder.lock();
    group
        .iter()
        .try_for_each(|passage| -> Result<(), AppError> {
            let raw = backend.embed_passage(passage)?;
            output.push(normalise_dim(raw));
            Ok(())
        })?;
    // Only reached when every text was embedded successfully.
    group.clear();
    Ok(())
}
/// Convenience wrapper: obtains the shared backend via `get_embedder` and
/// embeds `text` as a passage with it.
///
/// # Errors
///
/// Fails if the backend cannot be obtained or if embedding fails.
pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
    get_embedder(models_dir).and_then(|shared| embed_passage(shared, text))
}
/// Convenience wrapper: obtains the shared backend via `get_embedder` and
/// embeds `text` as a query with it.
///
/// # Errors
///
/// Fails if the backend cannot be obtained or if embedding fails.
pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
    get_embedder(models_dir).and_then(|shared| embed_query(shared, text))
}
/// Convenience wrapper: obtains the shared backend via `get_embedder` and runs
/// `embed_passages_controlled` on `texts` / `token_counts` with it.
///
/// # Errors
///
/// Fails if the backend cannot be obtained or if any embedding fails.
pub fn embed_passages_controlled_local(
    models_dir: &Path,
    texts: &[&str],
    token_counts: &[usize],
) -> Result<Vec<Vec<f32>>, AppError> {
    get_embedder(models_dir)
        .and_then(|shared| embed_passages_controlled(shared, texts, token_counts))
}
/// Serialises a slice of `f32` values into a contiguous byte vector, four
/// little-endian bytes per value, in input order.
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(std::mem::size_of_val(v));
    encoded.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    encoded
}
/// Decodes a byte slice into `f32` values, reading four little-endian bytes
/// per value.
///
/// Any trailing bytes that do not form a complete four-byte group are ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        .collect()
}
/// Returns the fixed embedding dimensionality, `EMBEDDING_DIM`.
///
/// Every vector produced by the embedding functions in this file is resized to
/// this length before being returned.
pub fn embedding_dim() -> usize {
EMBEDDING_DIM
}
fn normalise_dim(mut v: Vec<f32>) -> Vec<f32> {
if v.len() == EMBEDDING_DIM {
return v;
}
if v.len() > EMBEDDING_DIM {
v.truncate(EMBEDDING_DIM);
} else {
v.resize(EMBEDDING_DIM, 0.0);
}
v
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding must reproduce every value exactly, including
    /// the extremes of the `f32` range.
    #[test]
    fn f32_to_bytes_roundtrip() {
        let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
        let bytes = f32_to_bytes(&input);
        assert_eq!(bytes.len(), input.len() * 4);
        assert_eq!(bytes_to_f32(&bytes), input);
    }

    /// Bytes that do not complete a four-byte group are dropped by the decoder.
    #[test]
    fn bytes_to_f32_ignores_trailing_partial_chunk() {
        let mut bytes = f32_to_bytes(&[1.5_f32]);
        bytes.extend_from_slice(&[0xAA, 0xBB]);
        assert_eq!(bytes_to_f32(&bytes), vec![1.5_f32]);
    }

    /// Checks contents as well as length, so truncating the wrong end or
    /// padding with a non-zero value would be caught.
    #[test]
    fn normalise_dim_truncates_and_pads() {
        // Distinct component values make the kept prefix identifiable.
        let long: Vec<f32> = (0..EMBEDDING_DIM + 10).map(|i| i as f32).collect();
        let truncated = normalise_dim(long.clone());
        assert_eq!(truncated.as_slice(), &long[..EMBEDDING_DIM]);

        // Sized relative to the constant so the case stays "short" for any
        // embedding dimension.
        let short_len = EMBEDDING_DIM / 2;
        let padded = normalise_dim(vec![1.5_f32; short_len]);
        assert_eq!(padded.len(), EMBEDDING_DIM);
        assert!(padded[..short_len].iter().all(|&x| x == 1.5));
        assert!(padded[short_len..].iter().all(|&x| x == 0.0));

        let exact: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32).collect();
        assert_eq!(normalise_dim(exact.clone()), exact);
    }

    #[test]
    fn embedding_dim_matches_constant() {
        assert_eq!(embedding_dim(), crate::constants::EMBEDDING_DIM);
    }
}