use anyhow::{anyhow, Result};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::OnceCell;
use rusqlite::{params, Connection};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
/// Process-wide embedding model, created on first use by `model()`.
static MODEL: OnceCell<TextEmbedding> = OnceCell::new();
/// CPU-only override for model initialization, written by `set_force_cpu` and read by
/// `force_cpu`. It is consulted only inside `model()`'s one-time initializer, and only in
/// builds with a GPU feature (`cuda` / `directml`) enabled.
static FORCE_CPU: AtomicBool = AtomicBool::new(false);
/// Ask the embedding model to run on the CPU even when a GPU provider is compiled in.
///
/// The flag is read only while the model is being initialized (inside `model()`'s one-time
/// initializer), so calling this after the first successful embedding has no effect.
pub fn set_force_cpu(force: bool) {
    FORCE_CPU.store(force, Ordering::Relaxed)
}

/// Returns `true` if CPU-only execution was requested through [`set_force_cpu`].
// The only callers live behind `#[cfg(feature = "cuda")]` / `#[cfg(feature = "directml")]`,
// so this function is unused (and would warn) in CPU-only builds.
#[allow(dead_code)]
fn force_cpu() -> bool {
    FORCE_CPU.load(Ordering::Relaxed)
}
/// Directory where embedding-model files are cached: `<cache>/tokenix/models`.
///
/// `<cache>` is the platform cache directory. It falls back to the home directory and then
/// to the current directory (`.`) when neither can be determined.
pub fn model_cache_dir() -> PathBuf {
    let base = match dirs::cache_dir() {
        Some(cache) => cache,
        None => match dirs::home_dir() {
            Some(home) => home,
            None => PathBuf::from("."),
        },
    };
    base.join("tokenix").join("models")
}
/// Tokenizer loaded from the model cache on first use.
///
/// Holds `None` if no usable `tokenizer.json` was found or loaded at that time.
static TOKENIZER: OnceCell<Option<tokenizers::Tokenizer>> = OnceCell::new();

/// Count the tokens in `text` using the cached model's `tokenizer.json` when one is available.
///
/// Special tokens are not added when encoding. Falls back to `crate::chunker::count_tokens`
/// when no tokenizer could be loaded or when encoding fails.
///
/// The tokenizer lookup runs only once per process. If no `tokenizer.json` exists under
/// [`model_cache_dir`] at the first call, the fallback is used for the rest of the process.
pub fn count_tokens_accurate(text: &str) -> usize {
    let cached = TOKENIZER.get_or_init(|| {
        find_tokenizer_json(&model_cache_dir())
            .and_then(|json_path| tokenizers::Tokenizer::from_file(json_path).ok())
    });
    let Some(tokenizer) = cached else {
        return crate::chunker::count_tokens(text);
    };
    match tokenizer.encode(text, false) {
        Ok(encoding) => encoding.len(),
        Err(_) => crate::chunker::count_tokens(text),
    }
}
/// Recursively search `dir` for a file named `tokenizer.json` and return the first match.
///
/// The search is depth-first, and the entries of each directory are visited in sorted path
/// order. `read_dir` order is unspecified, so sorting makes the choice deterministic when
/// several `tokenizer.json` files exist.
///
/// Symlinked directories are still followed. Each directory is identified by its canonical
/// path and visited at most once, so a symlink cycle cannot cause repeated re-traversal.
///
/// Returns `None` if `dir` cannot be read or contains no match. Unreadable subdirectories
/// are skipped.
fn find_tokenizer_json(dir: &std::path::Path) -> Option<PathBuf> {
    // Depth-first walk that carries the set of canonical directories already visited.
    fn walk(
        dir: &std::path::Path,
        visited: &mut std::collections::HashSet<PathBuf>,
    ) -> Option<PathBuf> {
        // Canonicalize so a directory reached again through a symlink is recognized.
        let canonical = std::fs::canonicalize(dir).ok()?;
        if !visited.insert(canonical) {
            return None;
        }
        let mut entries: Vec<PathBuf> = std::fs::read_dir(dir)
            .ok()?
            .flatten()
            .map(|entry| entry.path())
            .collect();
        entries.sort();
        for path in entries {
            if path.is_dir() {
                if let Some(found) = walk(&path, visited) {
                    return Some(found);
                }
            } else if path.file_name().is_some_and(|n| n == "tokenizer.json") {
                return Some(path);
            }
        }
        None
    }
    walk(dir, &mut std::collections::HashSet::new())
}
/// Name of the GPU execution backend compiled into this build, if any.
///
/// `cuda` takes precedence over `directml`. DirectML is reported only on Windows builds.
/// Returns `None` for CPU-only builds. This reflects compile-time features only: it does not
/// consult the `set_force_cpu` override.
pub fn gpu_backend() -> Option<&'static str> {
    if cfg!(feature = "cuda") {
        Some("CUDA")
    } else if cfg!(all(feature = "directml", target_os = "windows")) {
        Some("DirectML")
    } else {
        None
    }
}
fn open_query_cache_db() -> Option<Connection> {
let dir = dirs::home_dir()?.join(".tokenix");
std::fs::create_dir_all(&dir).ok()?;
let path = dir.join("query_cache.db");
let conn = Connection::open(&path).ok()?;
conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA synchronous=NORMAL;")
.ok()?;
conn.execute(
"CREATE TABLE IF NOT EXISTS query_cache (
query_text TEXT PRIMARY KEY,
embedding BLOB NOT NULL
)",
[],
)
.ok()?;
Some(conn)
}
/// Encode `v` as a packed little-endian byte buffer, 4 bytes per `f32`.
fn serialize_vec(v: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(v.len() * 4);
    for value in v {
        out.extend_from_slice(&value.to_le_bytes());
    }
    out
}

/// Decode a buffer produced by [`serialize_vec`] back into `f32`s.
///
/// Reads 4-byte little-endian words. Any trailing bytes that do not form a whole word are
/// ignored.
fn deserialize_vec(bytes: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        // `chunks_exact(4)` guarantees exactly four bytes per chunk.
        floats.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    floats
}
/// Return the process-wide embedding model (`NomicEmbedTextV15Q`), initializing it on first use.
///
/// Initialization steps:
/// - Defaults `OMP_NUM_THREADS` to `1` unless the variable is already set.
/// - Uses [`model_cache_dir`] as the model cache directory, creating it on a best-effort basis.
/// - Selects GPU execution providers, unless CPU-only execution was requested. CUDA is tried
///   first, with CPU as fallback. DirectML (Windows only, and only when CUDA is not compiled
///   in) also falls back to CPU. This precedence mirrors what [`gpu_backend`] reports.
///
/// `OnceCell::get_or_try_init` leaves the cell empty when initialization fails, so a failed
/// attempt is retried on the next call.
///
/// # Errors
/// Returns an error wrapping the underlying failure if `TextEmbedding::try_new` fails.
fn model() -> Result<&'static TextEmbedding> {
    MODEL
        .get_or_try_init(|| {
            // `std::env::set_var` is an `unsafe fn` only on edition 2024; on older editions
            // the `unsafe` block below is redundant.
            #[allow(unused_unsafe)]
            if std::env::var("OMP_NUM_THREADS").is_err() {
                // SAFETY: `set_var` is unsound if another thread reads or writes the
                // environment concurrently. NOTE(review): this runs lazily inside the
                // OnceCell initializer, which is not guaranteed to happen before other
                // threads start; consider moving it to program startup.
                unsafe { std::env::set_var("OMP_NUM_THREADS", "1") };
            }
            let cache_dir = model_cache_dir();
            // Best-effort: if this fails, model construction reports the real error.
            std::fs::create_dir_all(&cache_dir).ok();
            // `mut` is only needed when a GPU feature is compiled in.
            #[allow(unused_mut)]
            let mut options =
                InitOptions::new(EmbeddingModel::NomicEmbedTextV15Q).with_cache_dir(cache_dir);
            #[cfg(feature = "cuda")]
            if !force_cpu() {
                options = options.with_execution_providers(vec![
                    ort::execution_providers::CUDAExecutionProvider::default().build(),
                    ort::execution_providers::CPUExecutionProvider::default().build(),
                ]);
            }
            // Gate DirectML exactly like `gpu_backend()` does (Windows only, CUDA wins).
            // Without `not(feature = "cuda")`, enabling both features would silently replace
            // the CUDA provider list above while `gpu_backend()` still reports "CUDA".
            #[cfg(all(feature = "directml", target_os = "windows", not(feature = "cuda")))]
            if !force_cpu() {
                options = options.with_execution_providers(vec![
                    ort::execution_providers::DirectMLExecutionProvider::default().build(),
                    ort::execution_providers::CPUExecutionProvider::default().build(),
                ]);
            }
            TextEmbedding::try_new(options)
        })
        .map_err(|e| anyhow!("Embedding model init failed: {e}"))
}
/// Embed document texts for retrieval.
///
/// Each text is prefixed with `search_document: `, the document-side counterpart of the
/// `search_query: ` prefix used by `embed_query`. Empty input returns an empty result
/// without touching (or initializing) the model.
///
/// # Errors
/// Returns an error if the model fails to initialize or the embedding call fails.
pub fn embed_documents(texts: &[String]) -> Result<Vec<Vec<f32>>> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    let mut inputs = Vec::with_capacity(texts.len());
    for text in texts {
        inputs.push(format!("search_document: {text}"));
    }
    let embedder = model()?;
    embedder.embed(inputs, None).map_err(|e| anyhow!("{e}"))
}
pub fn embed_query(text: &str) -> Result<Vec<f32>> {
if let Some(conn) = open_query_cache_db() {
if let Ok(mut stmt) =
conn.prepare("SELECT embedding FROM query_cache WHERE query_text = ?1")
{
if let Ok(mut rows) = stmt.query(params![text]) {
if let Ok(Some(row)) = rows.next() {
if let Ok(blob) = row.get::<_, Vec<u8>>(0) {
return Ok(deserialize_vec(&blob));
}
}
}
}
}
let prefixed = format!("search_query: {text}");
let vec = model()?
.embed(vec![prefixed], None)
.map_err(|e| anyhow!("{e}"))?
.into_iter()
.next()
.ok_or_else(|| anyhow!("Empty embedding response"))?;
if let Some(conn) = open_query_cache_db() {
let blob = serialize_vec(&vec);
let _ = conn.execute(
"INSERT OR REPLACE INTO query_cache (query_text, embedding) VALUES (?1, ?2)",
params![text, blob],
);
}
Ok(vec)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg_attr(
        not(feature = "model-tests"),
        ignore = "needs model download; run with --features model-tests"
    )]
    fn embed_query_returns_768_dims() {
        let vec = embed_query("hello world").expect("embed_query failed");
        assert_eq!(
            vec.len(),
            768,
            "nomic-embed-text-v1.5 produces 768-dim vectors"
        );
    }

    #[test]
    fn embed_documents_empty_returns_empty() {
        let result = embed_documents(&[]).expect("empty embed should succeed");
        assert!(result.is_empty());
    }

    /// The cache blob encoding must round-trip every value exactly.
    #[test]
    fn serialize_roundtrip_preserves_values() {
        let v = vec![0.0f32, -1.5, 3.25, f32::MAX, f32::MIN_POSITIVE];
        let bytes = serialize_vec(&v);
        assert_eq!(bytes.len(), v.len() * 4);
        assert_eq!(deserialize_vec(&bytes), v);
    }

    #[test]
    #[cfg_attr(
        not(feature = "model-tests"),
        ignore = "needs model download; run with --features model-tests"
    )]
    fn embed_documents_returns_correct_count() {
        let texts = vec![
            "fn main() {}".to_string(),
            "struct Foo { x: i32 }".to_string(),
        ];
        let vecs = embed_documents(&texts).expect("embed_documents failed");
        assert_eq!(vecs.len(), 2);
        for v in &vecs {
            assert_eq!(v.len(), 768);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "model-tests"),
        ignore = "needs model download; run with --features model-tests"
    )]
    fn similar_texts_have_higher_cosine_similarity() {
        let q = embed_query("database connection pool").unwrap();
        let doc_similar =
            embed_documents(&["database connection pooling strategy".to_string()]).unwrap();
        let doc_different =
            embed_documents(&["sorting algorithms bubble sort".to_string()]).unwrap();
        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(x, y)| x * y).sum() };
        let norm = |a: &[f32]| -> f32 { a.iter().map(|x| x * x).sum::<f32>().sqrt() };
        let cosine = |a: &[f32], b: &[f32]| dot(a, b) / (norm(a) * norm(b));
        let sim_similar = cosine(&q, &doc_similar[0]);
        let sim_different = cosine(&q, &doc_different[0]);
        assert!(
            sim_similar > sim_different,
            "similar={sim_similar:.3} should > different={sim_different:.3}"
        );
    }

    /// The first `embed_query` call must persist its result, and the second must return it.
    ///
    /// The model is deterministic, so `vec1 == vec2` alone would pass even with caching
    /// broken. The test therefore also reads the row back from the cache table. It removes
    /// its row afterwards so the user's real cache is not polluted.
    #[test]
    #[cfg_attr(
        not(feature = "model-tests"),
        ignore = "needs model download; run with --features model-tests"
    )]
    fn test_query_cache_persistence() {
        let query = "test_persistent_cache_query_string_12345";
        let conn = open_query_cache_db();
        if let Some(c) = &conn {
            let _ = c.execute("DELETE FROM query_cache WHERE query_text = ?1", [query]);
        }
        let vec1 = embed_query(query).expect("First embed failed");
        if let Some(c) = &conn {
            let blob: Vec<u8> = c
                .query_row(
                    "SELECT embedding FROM query_cache WHERE query_text = ?1",
                    [query],
                    |row| row.get(0),
                )
                .expect("query embedding was not persisted to the cache");
            assert_eq!(deserialize_vec(&blob), vec1);
        }
        let vec2 = embed_query(query).expect("Second embed failed");
        assert_eq!(vec1.len(), 768);
        assert_eq!(vec1, vec2);
        if let Some(c) = &conn {
            let _ = c.execute("DELETE FROM query_cache WHERE query_text = ?1", [query]);
        }
    }
}