use std::time::Duration;
use anyhow::{Context, Result};
use rusqlite::params;
use serde::{Deserialize, Serialize};
use crate::inspect::now_unix;
use crate::store::{Store, VEC_MIRROR_TABLE};
pub const VEC_MIRROR_DIM: usize = 768;
/// Anything that can turn text into a dense vector.
///
/// `Send + Sync` so a backend can be shared across threads behind a
/// trait object.
pub trait EmbeddingBackend: Send + Sync {
    /// Embed one text; returns a vector whose length is model-defined.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;
    /// Embed many texts. The default implementation calls [`Self::embed`]
    /// once per text and short-circuits on the first error.
    fn batch_embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }
}
/// Builds an [`EmbeddingBackend`] for a given model name / Ollama URL pair.
/// Lets callers inject a mock backend in tests.
pub trait EmbeddingBackendFactory: Send + Sync {
    fn build(&self, model: &str, ollama_url: &str) -> Result<Box<dyn EmbeddingBackend>>;
}
/// Factory producing real [`OllamaClient`] backends over HTTP.
pub struct OllamaBackendFactory;

impl EmbeddingBackendFactory for OllamaBackendFactory {
    /// Note the argument order swap: the factory takes (model, url) but the
    /// client constructor takes (url, model).
    fn build(&self, model: &str, ollama_url: &str) -> Result<Box<dyn EmbeddingBackend>> {
        Ok(Box::new(OllamaClient::new(ollama_url, model)?))
    }
}
/// Deterministic, network-free backend for tests: a hashed bag-of-words
/// over `dim` buckets.
pub struct MockEmbeddingBackend {
    // Output vector length; also the number of hash buckets. Must be > 0.
    dim: usize,
}

impl MockEmbeddingBackend {
    /// Create a mock backend producing vectors of length `dim`.
    ///
    /// # Panics
    /// Panics if `dim` is zero (a zero-bucket hash is meaningless and
    /// `embed` would divide by zero).
    pub fn new(dim: usize) -> Self {
        assert!(dim > 0, "MockEmbeddingBackend dim must be > 0");
        Self { dim }
    }
}
impl Default for MockEmbeddingBackend {
    /// 64 dimensions: small, but enough buckets for test similarity checks.
    fn default() -> Self {
        Self::new(64)
    }
}
impl EmbeddingBackend for MockEmbeddingBackend {
    /// Deterministic bag-of-words embedding: each whitespace token is
    /// normalized (keep alphanumerics and `_`, lowercase the rest away),
    /// FNV-1a-hashed into one of `dim` buckets, and counted.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let mut counts = vec![0.0f32; self.dim];
        let normalized = text.split_whitespace().filter_map(|raw| {
            let token: String = raw
                .chars()
                .filter(|c| c.is_alphanumeric() || *c == '_')
                .flat_map(char::to_lowercase)
                .collect();
            (!token.is_empty()).then_some(token)
        });
        for token in normalized {
            let bucket = (fnv1a(token.as_bytes()) as usize) % self.dim;
            counts[bucket] += 1.0;
        }
        Ok(counts)
    }
}
/// Factory producing [`MockEmbeddingBackend`]s; ignores model and URL.
pub struct MockBackendFactory {
    // Dimension handed to every backend this factory builds.
    dim: usize,
}

impl MockBackendFactory {
    /// `dim` is validated later by `MockEmbeddingBackend::new` (panics on 0).
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }
}
impl Default for MockBackendFactory {
    /// Matches `MockEmbeddingBackend::default()`'s 64 dimensions.
    fn default() -> Self {
        Self::new(64)
    }
}
impl EmbeddingBackendFactory for MockBackendFactory {
    /// Model name and URL are irrelevant for the mock; only `dim` matters.
    fn build(&self, _model: &str, _ollama_url: &str) -> Result<Box<dyn EmbeddingBackend>> {
        Ok(Box::new(MockEmbeddingBackend::new(self.dim)))
    }
}
/// 64-bit FNV-1a hash (offset basis `0xcbf29ce484222325`, prime
/// `0x100000001b3`). Used to bucket tokens in the mock backend.
fn fnv1a(bytes: &[u8]) -> u64 {
    bytes.iter().fold(0xcbf29ce484222325u64, |acc, &b| {
        (acc ^ u64::from(b)).wrapping_mul(0x100000001b3)
    })
}
pub const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
pub const DEFAULT_EMBED_MODEL: &str = "nomic-embed-text";
/// Whether a text is embedded as a stored document or as a search query;
/// some models expect different prompt prefixes per role.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbedRole {
    Document,
    Query,
}
// Instruction used for Qwen3 queries when the caller supplies none.
const DEFAULT_QWEN3_QUERY_INSTRUCTION: &str =
    "Represent the query for retrieving relevant documents.";

/// Prompt-formatting profile, selected from the model name prefix.
#[derive(Debug, Clone, Copy)]
enum EmbedProfile {
    Nomic, // nomic-embed-text*: "search_document:"/"search_query:" prefixes
    Qwen3, // qwen3-embedding*: instruction-prefixed queries, raw documents
    Raw,   // anything else: pass the text through untouched
}
impl EmbedProfile {
    /// Pick the formatting profile from the model-name prefix; unknown
    /// models fall back to raw pass-through.
    fn for_model(model: &str) -> Self {
        match model {
            m if m.starts_with("qwen3-embedding") => Self::Qwen3,
            m if m.starts_with("nomic-embed-text") => Self::Nomic,
            _ => Self::Raw,
        }
    }

    /// Apply the profile's role-specific template to `text`.
    /// `instruction` only affects Qwen3 queries; elsewhere it is ignored.
    fn format(&self, role: EmbedRole, text: &str, instruction: Option<&str>) -> String {
        match (self, role) {
            (Self::Nomic, EmbedRole::Document) => format!("search_document: {text}"),
            (Self::Nomic, EmbedRole::Query) => format!("search_query: {text}"),
            (Self::Qwen3, EmbedRole::Query) => {
                let instruct = instruction.unwrap_or(DEFAULT_QWEN3_QUERY_INSTRUCTION);
                format!("Instruct: {instruct}\nQuery: {text}")
            }
            // Qwen3 documents and all Raw texts are embedded verbatim.
            (Self::Qwen3, EmbedRole::Document) | (Self::Raw, _) => text.to_string(),
        }
    }
}
/// Format `text` for embedding under `model`, applying whatever role-specific
/// prompt template the model family expects. `query_instruction` overrides
/// the default Qwen3 query instruction and is ignored by other profiles.
pub fn prepare_embedding_text(
    model: &str,
    role: EmbedRole,
    text: &str,
    query_instruction: Option<&str>,
) -> String {
    EmbedProfile::for_model(model).format(role, text, query_instruction)
}
/// Options controlling an embedding run.
#[derive(Debug, Clone)]
pub struct EmbedOptions {
    /// Embedding model name (passed to Ollama verbatim).
    pub model: String,
    /// Base URL of the Ollama server.
    pub ollama_url: String,
    /// Cap on how many pending chunks to embed; `None` = all.
    pub limit: Option<usize>,
}
impl Default for EmbedOptions {
    /// Default model/URL constants, no limit.
    fn default() -> Self {
        Self {
            model: DEFAULT_EMBED_MODEL.to_string(),
            ollama_url: DEFAULT_OLLAMA_URL.to_string(),
            limit: None,
        }
    }
}
/// Summary of one embedding run, serializable for `--json` output.
#[derive(Debug, Clone, Serialize)]
pub struct EmbedReport {
    /// Model the run used.
    pub model: String,
    /// Observed vector dimension; `None` if nothing was embedded this run.
    pub dim: Option<usize>,
    /// Chunks newly embedded in this run.
    pub embedded: usize,
    /// Chunks that already had an embedding for this model before the run.
    pub already_had: usize,
    /// Chunks whose batch failed after the run had made some progress.
    pub failed: usize,
}
/// Blocking HTTP client for the Ollama embeddings API.
pub struct OllamaClient {
    // Base URL with any trailing '/' stripped (see `new`).
    base_url: String,
    model: String,
    http: reqwest::blocking::Client,
}
impl OllamaClient {
    /// Build a blocking client for the Ollama HTTP API.
    ///
    /// The generous 120 s timeout accommodates cold model loads. Trailing
    /// slashes on `base_url` are stripped so later `format!` joins don't
    /// produce `//`.
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn new(base_url: &str, model: &str) -> Result<Self> {
        let http = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .context("building reqwest client")?;
        Ok(Self {
            base_url: base_url.trim_end_matches('/').to_string(),
            model: model.to_string(),
            http,
        })
    }
}
impl EmbeddingBackend for OllamaClient {
    /// Embed one text via the single-input `/api/embeddings` endpoint.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Request/response wire shapes, local to this call.
        #[derive(Serialize)]
        struct Req<'a> {
            model: &'a str,
            prompt: &'a str,
        }
        #[derive(Deserialize)]
        struct Resp {
            embedding: Vec<f32>,
        }
        let url = format!("{}/api/embeddings", self.base_url);
        let resp = self
            .http
            .post(&url)
            .json(&Req {
                model: &self.model,
                prompt: text,
            })
            .send()
            // Transport errors get a friendlier, actionable message.
            .map_err(|e| humanize_request_error(&url, &self.model, e))?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().unwrap_or_default();
            // Truncate the body so a large error page doesn't flood the log.
            anyhow::bail!(
                "ollama returned {status} from {url}: {}",
                body.trim().chars().take(200).collect::<String>()
            );
        }
        let parsed: Resp = resp.json().context("parsing ollama response body")?;
        // Treat an empty vector as an error; the hint names the likely cause.
        if parsed.embedding.is_empty() {
            anyhow::bail!(
                "ollama returned an empty embedding (is model '{}' pulled? try `ollama pull {}`)",
                self.model,
                self.model,
            );
        }
        Ok(parsed.embedding)
    }

    /// Embed many texts in one request via the batch `/api/embed` endpoint.
    fn batch_embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        #[derive(Serialize)]
        struct Req<'a> {
            model: &'a str,
            input: &'a [&'a str],
        }
        #[derive(Deserialize)]
        struct Resp {
            embeddings: Vec<Vec<f32>>,
        }
        let url = format!("{}/api/embed", self.base_url);
        let resp = self
            .http
            .post(&url)
            .json(&Req {
                model: &self.model,
                input: texts,
            })
            .send()
            .map_err(|e| humanize_request_error(&url, &self.model, e))?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().unwrap_or_default();
            anyhow::bail!(
                "ollama returned {status} from {url}: {}",
                body.trim().chars().take(200).collect::<String>()
            );
        }
        let parsed: Resp = resp.json().context("parsing ollama response body")?;
        // The server must answer with exactly one vector per input text.
        if parsed.embeddings.len() != texts.len() {
            anyhow::bail!(
                "ollama returned {} embeddings for {} inputs",
                parsed.embeddings.len(),
                texts.len()
            );
        }
        if parsed.embeddings.iter().any(|e| e.is_empty()) {
            anyhow::bail!(
                "ollama returned an empty embedding (is model '{}' pulled? try `ollama pull {}`)",
                self.model,
                self.model,
            );
        }
        Ok(parsed.embeddings)
    }
}
/// Convert a transport-level reqwest failure into an actionable message,
/// suggesting `ollama serve` / `ollama pull` when the server looks down.
fn humanize_request_error(url: &str, model: &str, err: reqwest::Error) -> anyhow::Error {
    let server_unreachable = err.is_connect() || err.is_timeout();
    if !server_unreachable {
        return anyhow::anyhow!("request to {url} failed: {err}");
    }
    anyhow::anyhow!(
        "could not reach ollama at {url}: {err}. \
        Is ollama running? Start it with `ollama serve`, \
        then pull the embedding model with `ollama pull {model}`.",
    )
}
/// Embed all chunks missing an embedding for `opts.model`, using a real
/// Ollama backend at `opts.ollama_url`. Convenience wrapper over
/// [`embed_missing_with`].
pub fn embed_missing(store: &mut Store, opts: &EmbedOptions) -> Result<EmbedReport> {
    let client = OllamaClient::new(&opts.ollama_url, &opts.model)?;
    embed_missing_with(store, opts, &client)
}
/// Embed every chunk lacking an embedding for `opts.model` with `backend`,
/// writing results inside a single SQLite transaction.
///
/// Behavior:
/// - Pending chunks are processed in batches of 32.
/// - If a batch fails before anything was embedded, the error is returned
///   and the transaction is dropped (rolled back). After any progress,
///   later batch failures are logged to stderr and counted in `failed`.
/// - Bails if the backend returns vectors of inconsistent dimension.
/// - Vectors are mirrored into the vec table only for the default model
///   and only at the expected mirror dimension.
pub fn embed_missing_with(
    store: &mut Store,
    opts: &EmbedOptions,
    backend: &dyn EmbeddingBackend,
) -> Result<EmbedReport> {
    // Pre-count rows that already have this model's embedding (report only).
    let already_had: i64 = store.conn().query_row(
        "SELECT COUNT(*) FROM embeddings WHERE model = ?1",
        params![opts.model],
        |row| row.get(0),
    )?;
    let pending = pending_chunks(store, &opts.model, opts.limit)?;
    let mut embedded = 0usize;
    let mut failed = 0usize;
    // Dimension observed from the first returned vector; later vectors must match.
    let mut dim: Option<usize> = None;
    const BATCH_SIZE: usize = 32;
    let mirror_model = opts.model == DEFAULT_EMBED_MODEL;
    let tx = store.conn_mut().transaction()?;
    for batch in pending.chunks(BATCH_SIZE) {
        // Apply the model's document prompt template before embedding.
        let prepared_texts: Vec<String> = batch
            .iter()
            .map(|(_, t)| prepare_embedding_text(&opts.model, EmbedRole::Document, t, None))
            .collect();
        let texts: Vec<&str> = prepared_texts.iter().map(|s| s.as_str()).collect();
        match backend.batch_embed(&texts) {
            Ok(vectors) => {
                for ((chunk_id, _), vec) in batch.iter().zip(vectors) {
                    let this_dim = vec.len();
                    if let Some(existing) = dim {
                        if existing != this_dim {
                            anyhow::bail!(
                                "ollama returned inconsistent embedding dimensions: {existing} then {this_dim}"
                            );
                        }
                    } else {
                        dim = Some(this_dim);
                    }
                    let blob = f32s_to_blob(&vec);
                    tx.execute(
                        "INSERT OR REPLACE INTO embeddings (chunk_id, model, dim, embedding, created_at)
                        VALUES (?1, ?2, ?3, ?4, ?5)",
                        params![chunk_id, opts.model, this_dim as i64, blob, now_unix()],
                    )?;
                    // Mirror only when the dimension matches the vec table's schema.
                    if mirror_model && this_dim == VEC_MIRROR_DIM {
                        mirror_write(&tx, chunk_id, &blob)?;
                    }
                    embedded += 1;
                }
            }
            Err(err) => {
                // No progress yet: abort; dropping `tx` rolls everything back.
                if embedded == 0 {
                    return Err(err);
                }
                eprintln!(
                    "warning: embedding batch of {} chunks failed: {err:#}",
                    batch.len()
                );
                failed += batch.len();
            }
        }
    }
    tx.commit()?;
    Ok(EmbedReport {
        model: opts.model.clone(),
        dim,
        embedded,
        already_had: already_had as usize,
        failed,
    })
}
/// Write `blob` into the vec mirror table, keyed by the chunk's rowid.
///
/// NOTE(review): uses DELETE-then-INSERT rather than INSERT OR REPLACE —
/// presumably because the mirror is a virtual (vec) table where UPSERT
/// semantics may not be supported; confirm against the table's module.
fn mirror_write(tx: &rusqlite::Transaction, chunk_id: &str, blob: &[u8]) -> Result<()> {
    // Resolve the chunk's rowid; errors if the chunk id is unknown.
    let rowid: i64 = tx.query_row(
        "SELECT rowid FROM chunks WHERE id = ?1",
        params![chunk_id],
        |row| row.get(0),
    )?;
    tx.execute(
        &format!("DELETE FROM {VEC_MIRROR_TABLE} WHERE rowid = ?1"),
        params![rowid],
    )?;
    tx.execute(
        &format!("INSERT INTO {VEC_MIRROR_TABLE} (rowid, embedding) VALUES (?1, ?2)"),
        params![rowid, blob],
    )?;
    Ok(())
}
/// Return `(chunk_id, text)` for chunks that have no embedding under `model`,
/// ordered by insertion order (`chunks.rowid`), optionally capped at `limit`.
///
/// SQLite specifies that a negative LIMIT means "no limit", so a single
/// prepared statement covers both the capped and uncapped cases; this also
/// removes the duplicated SQL and the `limit.unwrap()` of the original.
fn pending_chunks(
    store: &Store,
    model: &str,
    limit: Option<usize>,
) -> Result<Vec<(String, String)>> {
    // None => -1 => unlimited (per SQLite's LIMIT semantics).
    let limit: i64 = limit.map_or(-1, |n| n as i64);
    let conn = store.conn();
    let mut stmt = conn.prepare(
        "SELECT c.id, c.text FROM chunks c
        LEFT JOIN embeddings e ON e.chunk_id = c.id AND e.model = ?1
        WHERE e.chunk_id IS NULL
        ORDER BY c.rowid
        LIMIT ?2",
    )?;
    let rows = stmt
        .query_map(params![model, limit], |row| {
            Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
        })?
        .collect::<Result<Vec<_>, _>>()?;
    Ok(rows)
}
/// Serialize an `f32` slice into packed little-endian bytes (4 bytes per
/// value), the on-disk format of the `embeddings.embedding` column.
pub fn f32s_to_blob(v: &[f32]) -> Vec<u8> {
    // Exact final size is known up front, so reserve it once.
    let mut bytes = Vec::with_capacity(v.len() * 4);
    v.iter()
        .for_each(|x| bytes.extend_from_slice(&x.to_le_bytes()));
    bytes
}
/// Deserialize packed little-endian bytes back into `f32`s (inverse of
/// [`f32s_to_blob`]).
///
/// # Errors
/// Fails when the byte count is not a multiple of 4.
pub fn blob_to_f32s(bytes: &[u8]) -> Result<Vec<f32>> {
    let chunks = bytes.chunks_exact(4);
    // chunks_exact leaves any trailing partial chunk in remainder().
    if !chunks.remainder().is_empty() {
        anyhow::bail!(
            "embedding blob length {} is not a multiple of 4",
            bytes.len()
        );
    }
    let values: Vec<f32> = chunks
        .map(|c| f32::from_le_bytes(c.try_into().expect("chunk is exactly 4 bytes")))
        .collect();
    Ok(values)
}
/// Cosine similarity of two vectors; returns 0.0 for mismatched lengths or
/// when either vector has zero magnitude (rather than NaN).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Accumulate dot product and both squared norms in one pass,
    // in the same element order as a plain index loop.
    let (dot, norm_a, norm_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, na, nb), (&x, &y)| {
            (d + x * y, na + x * x, nb + y * y)
        });
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a.sqrt() * norm_b.sqrt())
}
/// One row of the per-(model, dim) embedding inventory.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingStats {
    pub model: String,
    pub dim: i64,
    /// Number of stored embeddings for this (model, dim) pair.
    pub count: i64,
}
/// Summarize stored embeddings grouped by (model, dim), most numerous first.
pub fn embedding_stats(store: &Store) -> Result<Vec<EmbeddingStats>> {
    let conn = store.conn();
    let mut stmt = conn.prepare(
        "SELECT model, dim, COUNT(*) FROM embeddings GROUP BY model, dim ORDER BY COUNT(*) DESC",
    )?;
    let rows = stmt.query_map([], |row| {
        Ok(EmbeddingStats {
            model: row.get(0)?,
            dim: row.get(1)?,
            count: row.get(2)?,
        })
    })?;
    Ok(rows.collect::<Result<Vec<_>, _>>()?)
}
/// Print a one-line human-readable summary of an embedding run to stdout.
pub fn print_text(report: &EmbedReport) {
    // "-" stands in for the dimension when nothing was embedded this run.
    let dim = match report.dim {
        Some(d) => d.to_string(),
        None => "-".into(),
    };
    println!(
        "embedded {} chunks (already had {}, failed {}) model={} dim={}",
        report.embedded, report.already_had, report.failed, report.model, dim,
    );
}
/// Print the report as pretty-printed JSON to stdout (for `--json` output).
pub fn print_json(report: &EmbedReport) -> Result<()> {
    println!("{}", serde_json::to_string_pretty(report)?);
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Unit tests for prompt formatting, blob (de)serialization, cosine
    //! similarity edge cases, and the deterministic mock backend.
    use super::*;

    #[test]
    fn nomic_profile_prefixes_both_roles() {
        assert_eq!(
            prepare_embedding_text("nomic-embed-text", EmbedRole::Document, "hello", None),
            "search_document: hello"
        );
        assert_eq!(
            prepare_embedding_text("nomic-embed-text", EmbedRole::Query, "hello", None),
            "search_query: hello"
        );
    }

    #[test]
    fn qwen3_query_uses_default_instruction_and_override() {
        // Documents are passed through raw for Qwen3.
        assert_eq!(
            prepare_embedding_text("qwen3-embedding-4b", EmbedRole::Document, "hello", None),
            "hello"
        );
        // Queries get the default instruction unless one is supplied.
        assert_eq!(
            prepare_embedding_text("qwen3-embedding-4b", EmbedRole::Query, "hello", None),
            format!(
                "Instruct: {}\nQuery: hello",
                DEFAULT_QWEN3_QUERY_INSTRUCTION
            )
        );
        assert_eq!(
            prepare_embedding_text(
                "qwen3-embedding-4b",
                EmbedRole::Query,
                "hello",
                Some("find the relevant thing"),
            ),
            "Instruct: find the relevant thing\nQuery: hello"
        );
    }

    #[test]
    fn unknown_models_stay_raw() {
        assert_eq!(
            prepare_embedding_text("other-model", EmbedRole::Document, "hello", None),
            "hello"
        );
        assert_eq!(
            prepare_embedding_text("other-model", EmbedRole::Query, "hello", None),
            "hello"
        );
    }

    #[test]
    fn blob_roundtrip_preserves_values() {
        // Includes extreme values to verify bit-exact round-tripping.
        let v = vec![0.0, 1.0, -1.5, std::f32::consts::PI, f32::MIN, f32::MAX];
        let round = blob_to_f32s(&f32s_to_blob(&v)).unwrap();
        assert_eq!(round, v);
    }

    #[test]
    fn cosine_identical_is_one() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_handles_zero_and_mismatched() {
        // Zero-magnitude and length-mismatch inputs return 0.0, not NaN.
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 1.0], &[1.0]), 0.0);
    }

    #[test]
    fn blob_rejects_odd_length() {
        assert!(blob_to_f32s(&[0, 1, 2]).is_err());
    }

    #[test]
    fn mock_backend_is_deterministic() {
        let b = MockEmbeddingBackend::new(32);
        let a = b.embed("hello world").unwrap();
        let c = b.embed("hello world").unwrap();
        assert_eq!(a, c);
        assert_eq!(a.len(), 32);
    }

    #[test]
    fn mock_backend_shared_tokens_have_higher_cosine() {
        // Token overlap should dominate similarity for the bag-of-words mock.
        let b = MockEmbeddingBackend::new(128);
        let q = b.embed("rust programming").unwrap();
        let near = b.embed("rust is a systems programming language").unwrap();
        let far = b.embed("apples fall from trees in autumn").unwrap();
        assert!(cosine_similarity(&q, &near) > cosine_similarity(&q, &far));
    }
}