Skip to main content

sqlite_graphrag/
embedder.rs

1//! Embedding generation for the GraphRAG memory.
2//!
3//! v1.0.76: the default build is **LLM-only** — the binary does NOT bundle
4//! fastembed / ort / ndarray / tokenizers. All embeddings are produced
5//! by a headless invocation of `claude code` or `codex` (OAuth, no MCP,
6//! no hooks) and stored as a BLOB in `memory_embeddings(memory_id, embedding,
7//! source)`. Vector similarity is computed in pure Rust at query time.
8//!
9//! The legacy fastembed pipeline is still available behind the opt-in
10//! `embedding-legacy` feature for the transition window. It is removed
11//! in v1.1.0. New code MUST use the LLM path (`embed_passage` /
12//! `embed_query` here, which always call the LLM).
13
14use crate::constants::EMBEDDING_DIM;
15use crate::errors::AppError;
16use crate::extract::llm_embedding::LlmEmbedding;
17use parking_lot::Mutex;
18use std::path::Path;
19use std::sync::OnceLock;
20
21/// Process-wide LLM-embedding client behind a `Mutex`.
22///
23/// The client is a thin wrapper around a single in-flight `claude code` or
24/// `codex` subprocess. Each call blocks until the LLM returns the
25/// embedding; the daemon was removed in v1.0.76 to make the CLI one-shot.
26static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();
27
28/// Initialises the LLM-embedding client on first use and returns it.
29pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
30    if let Some(e) = EMBEDDER.get() {
31        return Ok(e);
32    }
33    let backend = LlmEmbedding::detect_available()?;
34    let _ = EMBEDDER.set(Mutex::new(backend));
35    Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
36}
37
38/// Embeds a single passage for storage. Delegates to the configured LLM
39/// headless (claude code / codex). Returns a 384-dim f32 vector.
40pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
41    let mut guard = embedder.lock();
42    let result = guard.embed_passage(text)?;
43    Ok(normalise_dim(result))
44}
45
46/// Embeds a single query for similarity search. Same model and dim as
47/// `embed_passage`; the only difference is the LLM-side prompt prefix
48/// that the headless invocation uses to disambiguate.
49pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
50    let mut guard = embedder.lock();
51    let result = guard.embed_query(text)?;
52    Ok(normalise_dim(result))
53}
54
55/// Embeds a batch of passages with token-count-aware batching. The
56/// `token_counts` are still used to keep the LLM invocation under
57/// the per-call context budget, but the count is now an approximation
58/// (whitespace-split words) since the `tokenizers` crate was removed.
59pub fn embed_passages_controlled(
60    embedder: &Mutex<LlmEmbedding>,
61    texts: &[&str],
62    token_counts: &[usize],
63) -> Result<Vec<Vec<f32>>, AppError> {
64    if texts.is_empty() {
65        return Ok(Vec::new());
66    }
67    let mut output: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
68    let mut group: Vec<&str> = Vec::new();
69    let mut current_padded = 0usize;
70    for (text, &tokens) in texts.iter().zip(token_counts.iter()) {
71        let padded = tokens.saturating_add(8);
72        if (current_padded + padded > crate::constants::REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS
73            || group.len() >= crate::constants::REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS)
74            && !group.is_empty()
75        {
76            flush_group(&mut output, &mut group, embedder)?;
77            current_padded = 0;
78        }
79        group.push(text);
80        current_padded += padded;
81    }
82    if !group.is_empty() {
83        flush_group(&mut output, &mut group, embedder)?;
84    }
85    Ok(output)
86}
87
88fn flush_group(
89    output: &mut Vec<Vec<f32>>,
90    group: &mut Vec<&str>,
91    embedder: &Mutex<LlmEmbedding>,
92) -> Result<(), AppError> {
93    let mut guard = embedder.lock();
94    for text in group.iter() {
95        let v = guard.embed_passage(text)?;
96        output.push(normalise_dim(v));
97    }
98    group.clear();
99    Ok(())
100}
101
102pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
103    let embedder = get_embedder(models_dir)?;
104    embed_passage(embedder, text)
105}
106
107pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
108    let embedder = get_embedder(models_dir)?;
109    embed_query(embedder, text)
110}
111
112pub fn embed_passages_controlled_local(
113    models_dir: &Path,
114    texts: &[&str],
115    token_counts: &[usize],
116) -> Result<Vec<Vec<f32>>, AppError> {
117    let embedder = get_embedder(models_dir)?;
118    embed_passages_controlled(embedder, texts, token_counts)
119}
120
/// Serialises `v` as consecutive little-endian `f32` values (4 bytes each),
/// the layout that `bytes_to_f32` reads back.
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|x| x.to_le_bytes()));
    bytes
}
128
/// Decodes consecutive little-endian `f32` values from `bytes`.
///
/// Any trailing bytes that do not form a complete 4-byte group are ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        .collect()
}
136
137/// Returns the dimensionality of the embedding space. Used to
138/// validate LLM responses and to size the in-memory cache.
139pub fn embedding_dim() -> usize {
140    EMBEDDING_DIM
141}
142
143fn normalise_dim(mut v: Vec<f32>) -> Vec<f32> {
144    if v.len() == EMBEDDING_DIM {
145        return v;
146    }
147    if v.len() > EMBEDDING_DIM {
148        v.truncate(EMBEDDING_DIM);
149    } else {
150        v.resize(EMBEDDING_DIM, 0.0);
151    }
152    v
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn f32_to_bytes_roundtrip() {
        let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
        let bytes = f32_to_bytes(&input);
        assert_eq!(bytes.len(), input.len() * 4);
        assert_eq!(bytes_to_f32(&bytes), input);
    }

    #[test]
    fn f32_to_bytes_is_little_endian() {
        // 1.0_f32 has the IEEE-754 bit pattern 0x3F80_0000.
        assert_eq!(f32_to_bytes(&[1.0]), vec![0x00, 0x00, 0x80, 0x3f]);
        assert!(f32_to_bytes(&[]).is_empty());
    }

    #[test]
    fn bytes_to_f32_ignores_trailing_partial_chunk() {
        assert_eq!(bytes_to_f32(&[0x00, 0x00, 0x80, 0x3f, 0xaa, 0xbb]), vec![1.0]);
        assert!(bytes_to_f32(&[0x01, 0x02, 0x03]).is_empty());
        assert!(bytes_to_f32(&[]).is_empty());
    }

    #[test]
    fn normalise_dim_truncates_keeping_prefix() {
        let long: Vec<f32> = (0..EMBEDDING_DIM + 10).map(|i| i as f32).collect();
        let out = normalise_dim(long.clone());
        assert_eq!(out.len(), EMBEDDING_DIM);
        assert_eq!(out[..], long[..EMBEDDING_DIM]);
    }

    #[test]
    fn normalise_dim_zero_pads_short_vectors() {
        let kept = EMBEDDING_DIM / 2;
        let out = normalise_dim(vec![1.0; kept]);
        assert_eq!(out.len(), EMBEDDING_DIM);
        assert!(out[..kept].iter().all(|&x| x == 1.0));
        assert!(out[kept..].iter().all(|&x| x == 0.0));
    }

    #[test]
    fn normalise_dim_keeps_exact_length_untouched() {
        let exact: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32).collect();
        assert_eq!(normalise_dim(exact.clone()), exact);
    }

    #[test]
    fn embedding_dim_matches_constant() {
        assert_eq!(embedding_dim(), crate::constants::EMBEDDING_DIM);
    }
}