use std::collections::{HashMap, HashSet};
use anyhow::{anyhow, bail, Context, Result};
use arrow::array::{
Array, BooleanArray, Float32Array, Float32Builder, Int32Array, ListArray, ListBuilder,
RecordBatch, StringArray, TimestampMicrosecondArray,
};
use arrow::datatypes::{DataType, FieldRef, Schema as ArrowSchema};
use chrono::Utc;
use futures::TryStreamExt;
use iceberg::arrow::schema_to_arrow_schema;
use iceberg::expr::Reference;
use iceberg::spec::Datum;
use iceberg::Catalog;
use sha2::{Digest, Sha256};
use std::sync::Arc;
use uuid::Uuid;
use super::chunk::Chunk;
use super::VectorIndex;
use crate::warehouse::iceberg::{
append_batch, IcebergWarehouse, TABLE_EMBEDDINGS, TABLE_EMBEDDING_MANIFEST,
TABLE_EMBEDDING_SNAPSHOTS,
};
/// Version tag mixed into every [`chunker_hash`].
///
/// Changing it changes every chunker hash. Snapshots recorded under the old value then no
/// longer match `index_snapshot`'s reuse check, which compares chunker hashes.
const CHUNKER_VERSION: &str = "v1";
/// Everything that determines what vectors an embedding model produces.
///
/// The profile is hashed by [`ModelProfile::id`] into the `model_profile` key under which
/// vectors, manifests and snapshots are stored and looked up.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelProfile {
    /// Human-readable model name.
    pub model_name: String,
    /// Identifier of the model weights (presumably a content hash — confirm with producers).
    pub weights_sha: String,
    /// Identifier of the tokenizer (presumably a content hash — confirm with producers).
    pub tokenizer_sha: String,
    /// Pooling strategy name (e.g. `"mean"` in the tests).
    pub pooling: String,
    /// Whether output vectors are normalized; written to the embeddings table per vector.
    pub normalize: bool,
    /// Embedding dimension; `index_snapshot` rejects vectors of any other length.
    pub dim: usize,
    /// Element type name (e.g. `"f32"` in the tests).
    pub dtype: String,
}
impl ModelProfile {
    /// Returns the stable storage key for this profile.
    ///
    /// Every field is rendered as `key=value` in a fixed order, and the parts are joined with
    /// NUL bytes. The key is the lowercase hex SHA-256 of that string. Profiles whose fields
    /// all render identically share a key. Changing any field changes the rendered string and
    /// therefore (barring a hash collision) the key.
    pub fn id(&self) -> String {
        let parts = [
            format!("model={}", self.model_name),
            format!("weights={}", self.weights_sha),
            format!("tokenizer={}", self.tokenizer_sha),
            format!("pooling={}", self.pooling),
            format!("normalize={}", self.normalize),
            format!("dim={}", self.dim),
            format!("dtype={}", self.dtype),
        ];
        hex_sha256(parts.join("\0").as_bytes())
    }
}
/// Hashes the chunker configuration into the key stored on each snapshot.
///
/// The key covers [`CHUNKER_VERSION`] plus the `max_lines` and `overlap` options, so any
/// change to them yields a different hash.
pub fn chunker_hash(opts: &super::chunk::ChunkOptions) -> String {
    let canonical = format!(
        "{}\0max_lines={}\0overlap={}",
        CHUNKER_VERSION, opts.max_lines, opts.overlap
    );
    hex_sha256(canonical.as_bytes())
}
/// A text-embedding backend used by [`index_repo`].
pub trait Embedder: Send + Sync {
    /// Describes the model.
    ///
    /// Its [`ModelProfile::id`] keys every vector this embedder produces.
    fn profile(&self) -> ModelProfile;
    /// Embeds `texts`, returning one vector per input in the same order.
    ///
    /// `index_snapshot` pairs inputs and outputs positionally. It rejects the result if the
    /// count differs from `texts.len()` or if any vector's length differs from
    /// `profile().dim`.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
/// Result of [`index_snapshot`] / [`index_repo`]: identifies the (possibly pre-existing)
/// snapshot and reports how much work was done.
#[derive(Debug, Clone)]
pub struct SnapshotRef {
    /// Id of the snapshot row, either newly written or reused.
    pub snapshot_id: Uuid,
    /// Repository name the snapshot was indexed for.
    pub repo: String,
    /// Git sha the snapshot was indexed for.
    pub git_sha: String,
    /// [`ModelProfile::id`] of the model used.
    pub model_profile: String,
    /// [`chunker_hash`] of the chunking options used.
    pub chunker_hash: String,
    /// Hash over every chunk's file, line span and content hash.
    pub fileset_hash: String,
    /// Number of chunks (manifest rows) in the snapshot.
    pub occurrences: usize,
    /// Number of vectors embedded by this call; 0 when the snapshot already existed.
    pub new_vectors: usize,
}
/// Identity and metadata of one snapshot to record via [`index_snapshot`].
pub struct IndexParams<'a> {
    /// Workspace name; written to the snapshot row.
    pub workspace: &'a str,
    /// Repository name; part of the snapshot identity.
    pub repo: &'a str,
    /// Git sha; part of the snapshot identity.
    pub git_sha: &'a str,
    /// Branch name; written to the snapshot row.
    pub branch: &'a str,
    /// Embedding model; its id and `dim` key and validate the vectors.
    pub model: &'a ModelProfile,
    /// [`chunker_hash`] of the options that produced the chunks.
    pub chunker_hash: &'a str,
    /// Written to the snapshot row as-is. NOTE(review): presumably marks a full (not partial)
    /// index of the repo — confirm with readers of the snapshots table.
    pub complete: bool,
}
/// Repository coordinates for [`index_repo`].
///
/// The model and chunker identity are derived from the embedder and chunk options instead of
/// being passed here.
pub struct RepoRef<'a> {
    /// Workspace name; written to the snapshot row.
    pub workspace: &'a str,
    /// Repository name.
    pub repo: &'a str,
    /// Git sha being indexed.
    pub git_sha: &'a str,
    /// Branch name; written to the snapshot row.
    pub branch: &'a str,
    /// Forwarded to [`IndexParams::complete`].
    pub complete: bool,
}
/// One place where a piece of content appears in a snapshot, as returned by [`search`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Occurrence {
    /// Hash of the chunk text; several occurrences may share one stored vector.
    pub content_hash: String,
    /// File path as recorded at index time (relative to the repo root when produced by
    /// [`collect_rust_sources`]).
    pub file: String,
    /// First line of the chunk. The numbering convention comes from the chunker (`chunk_file`),
    /// not from this module.
    pub start_line: usize,
    /// Last line of the chunk, using the same convention as `start_line`.
    pub end_line: usize,
}
/// An in-memory index rebuilt from a stored snapshot by [`reconstruct`].
pub struct Reconstructed {
    /// One entry per manifest occurrence, keyed by that occurrence's manifest ordinal.
    pub index: VectorIndex,
    /// The snapshot that was resolved and loaded.
    pub snapshot_id: Uuid,
    /// Full git sha recorded on that snapshot (the caller may have supplied only a prefix).
    pub git_sha: String,
    /// Maps each index id (manifest ordinal) back to its source location.
    pub by_id: HashMap<u64, Occurrence>,
}
/// Recursively collects every `.rs` file under `root` as `(relative_path, contents)` pairs.
///
/// Directories named `target` *below* `root` (Cargo build output) are pruned and never
/// descended into. Only path components beneath `root` are considered, so a checkout that
/// itself lives inside some directory called `target` is still collected. Unreadable walk
/// entries and files that cannot be read as UTF-8 text are skipped (best-effort). Paths are
/// made relative to `root`, and the result is sorted so the output does not depend on walk
/// order.
pub fn collect_rust_sources(root: &std::path::Path) -> Vec<(String, String)> {
    let walker = walkdir::WalkDir::new(root)
        .into_iter()
        // Depth 0 is `root` itself; never prune it, even if its own name is `target`.
        .filter_entry(|e| e.depth() == 0 || e.file_name() != "target");
    let mut out = Vec::new();
    for entry in walker.filter_map(|e| e.ok()) {
        let p = entry.path();
        if !p.is_file() || p.extension().and_then(|e| e.to_str()) != Some("rs") {
            continue;
        }
        let Ok(s) = std::fs::read_to_string(p) else {
            continue;
        };
        let rel = p
            .strip_prefix(root)
            .unwrap_or(p)
            .to_string_lossy()
            .into_owned();
        out.push((rel, s));
    }
    out.sort();
    out
}
/// Chunks `files`, embeds any content not yet stored for the embedder's model, and records a
/// snapshot for `r.repo` at `r.git_sha`.
///
/// Chunking runs on scoped worker threads. The model profile comes from
/// `embedder.profile()` and the chunker hash from `opts`. Only chunks whose content hash is
/// missing from the warehouse are sent to `embedder.embed`; see [`index_snapshot`].
///
/// # Errors
/// Propagates every error from [`index_snapshot`], including embedder failures.
///
/// # Panics
/// If the chunker panics on a worker thread, that panic is re-raised here with its original
/// payload.
pub fn index_repo<E: Embedder + ?Sized>(
    wh: &IcebergWarehouse,
    r: &RepoRef,
    files: &[(String, String)],
    opts: &super::chunk::ChunkOptions,
    embedder: &E,
) -> Result<SnapshotRef> {
    let chunks = chunk_files_parallel(files, opts);
    let model = embedder.profile();
    let ch = chunker_hash(opts);
    let p = IndexParams {
        workspace: r.workspace,
        repo: r.repo,
        git_sha: r.git_sha,
        branch: r.branch,
        model: &model,
        chunker_hash: &ch,
        complete: r.complete,
    };
    index_snapshot(wh, &p, &chunks, |missing| {
        let texts: Vec<String> = missing.iter().map(|c| c.text.clone()).collect();
        embedder.embed(&texts)
    })
}

/// Chunks `files` on up to `available_parallelism()` scoped threads.
///
/// The output preserves input order: files are split into contiguous runs, and the runs are
/// joined back in order.
fn chunk_files_parallel(
    files: &[(String, String)],
    opts: &super::chunk::ChunkOptions,
) -> Vec<Chunk> {
    if files.is_empty() {
        return Vec::new();
    }
    let nthreads = std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1)
        .min(files.len());
    let part = files.len().div_ceil(nthreads).max(1);
    std::thread::scope(|s| {
        let handles: Vec<_> = files
            .chunks(part)
            .map(|fs| {
                s.spawn(move || {
                    let mut v = Vec::new();
                    for (path, content) in fs {
                        v.extend(super::chunk::chunk_file(path, content, opts));
                    }
                    v
                })
            })
            .collect();
        handles
            .into_iter()
            // Re-raise a worker panic with its original payload instead of a generic
            // "unwrap on Err(Any)" message.
            .flat_map(|h| h.join().unwrap_or_else(|payload| std::panic::resume_unwind(payload)))
            .collect()
    })
}
/// Records an embedding snapshot of `chunks` for `p.repo` at `p.git_sha`.
///
/// A snapshot's identity is (repo, git sha, model profile id, chunker hash, fileset hash).
/// The fileset hash covers every chunk's file, line span and content hash. If a snapshot with
/// that identity already exists, its id is returned with `new_vectors == 0`, and nothing is
/// embedded or written.
///
/// Otherwise, the function collects chunks whose `content_hash` has no stored vector under
/// this model profile, keeping each hash once (first occurrence wins). It hands them to
/// `embed_missing`. That closure is called at most once, and not at all when every hash is
/// already stored. The function then appends the new vectors, the manifest (one row per
/// chunk, in slice order) and the snapshot row, under a fresh random snapshot id.
///
/// # Errors
/// Returns an error if:
/// - a warehouse read or append fails;
/// - `embed_missing` fails;
/// - `embed_missing` returns a different number of vectors than chunks it was given;
/// - any vector's length differs from `p.model.dim`.
pub fn index_snapshot<F>(
    wh: &IcebergWarehouse,
    p: &IndexParams,
    chunks: &[Chunk],
    embed_missing: F,
) -> Result<SnapshotRef>
where
    F: FnOnce(&[Chunk]) -> Result<Vec<Vec<f32>>>,
{
    let dim = p.model.dim;
    let model_profile = p.model.id();
    let fileset_hash = fileset_hash_of(chunks);
    wh.block_on(async {
        // Same identity already recorded: reuse it rather than writing a duplicate snapshot.
        if let Some(existing) =
            find_snapshot(wh, p.repo, p.git_sha, &model_profile, p.chunker_hash, &fileset_hash)
                .await?
        {
            return Ok(SnapshotRef {
                snapshot_id: existing,
                repo: p.repo.into(),
                git_sha: p.git_sha.into(),
                model_profile,
                chunker_hash: p.chunker_hash.into(),
                fileset_hash,
                occurrences: chunks.len(),
                new_vectors: 0,
            });
        }
        // Content hashes already embedded under this model profile, from any repo or commit.
        let existing = existing_content_hashes(wh, &model_profile).await?;
        let mut missing: Vec<Chunk> = Vec::new();
        // Dedup within this call: identical content in several chunks is embedded once.
        let mut seen: HashSet<&str> = HashSet::new();
        for c in chunks {
            let h = c.content_hash.as_str();
            if !existing.contains(h) && seen.insert(h) {
                missing.push(c.clone());
            }
        }
        let new_vectors = missing.len();
        if !missing.is_empty() {
            let vectors = embed_missing(&missing)?;
            // Vectors are paired with `missing` positionally, so counts must agree.
            if vectors.len() != missing.len() {
                bail!(
                    "embedder returned {} vectors for {} chunks",
                    vectors.len(),
                    missing.len()
                );
            }
            let mut to_write: Vec<(&str, &[f32])> = Vec::with_capacity(missing.len());
            for (c, v) in missing.iter().zip(&vectors) {
                ensure_dim(v, dim, c)?;
                to_write.push((c.content_hash.as_str(), v.as_slice()));
            }
            append_embeddings(wh, &model_profile, dim, p.model.normalize, &to_write).await?;
        }
        let snapshot_id = Uuid::new_v4();
        // NOTE(review): the manifest is appended before the snapshot row — presumably so a
        // reader that resolves the snapshot row always finds its manifest; confirm.
        append_manifest(wh, snapshot_id, &model_profile, chunks).await?;
        append_snapshot_row(wh, snapshot_id, p, &model_profile, &fileset_hash, chunks.len())
            .await?;
        Ok(SnapshotRef {
            snapshot_id,
            repo: p.repo.into(),
            git_sha: p.git_sha.into(),
            model_profile,
            chunker_hash: p.chunker_hash.into(),
            fileset_hash,
            occurrences: chunks.len(),
            new_vectors,
        })
    })
}
/// Row counts of the embedding tables, as reported by [`warehouse_stats`].
#[derive(Debug, Default, Clone)]
pub struct WarehouseStats {
    /// Rows in the embeddings table (one per stored vector).
    pub embeddings: usize,
    /// Rows in the snapshots table.
    pub snapshots: usize,
    /// Rows in the manifest table (one per chunk occurrence across all snapshots).
    pub occurrences: usize,
}
/// Counts the rows in the embeddings, snapshots and manifest tables.
///
/// The tables are scanned one after another in that order, and the first failing scan aborts
/// with its error.
pub fn warehouse_stats(wh: &IcebergWarehouse) -> Result<WarehouseStats> {
    wh.block_on(async {
        let embeddings = count_rows(wh, TABLE_EMBEDDINGS).await?;
        let snapshots = count_rows(wh, TABLE_EMBEDDING_SNAPSHOTS).await?;
        let occurrences = count_rows(wh, TABLE_EMBEDDING_MANIFEST).await?;
        Ok(WarehouseStats {
            embeddings,
            snapshots,
            occurrences,
        })
    })
}
/// Returns the total number of rows in `table`.
///
/// Only the `model_profile` column is projected, to keep the scan narrow. This assumes every
/// table counted here has that column. Batches are summed as they arrive instead of being
/// buffered first.
async fn count_rows(wh: &IcebergWarehouse, table: &str) -> Result<usize> {
    let loaded = wh.catalog().load_table(&wh.table_ident(table)).await?;
    let scan = loaded.scan().select(["model_profile"]).build()?;
    let mut batches = scan.to_arrow().await?;
    let mut total = 0usize;
    while let Some(batch) = batches.try_next().await? {
        total += batch.num_rows();
    }
    Ok(total)
}
pub fn reconstruct(
wh: &IcebergWarehouse,
repo: &str,
git_sha: Option<&str>,
model_profile: &str,
) -> Result<Reconstructed> {
wh.block_on(async {
let (snapshot_id, sha) = resolve_snapshot(wh, repo, git_sha, model_profile).await?;
let manifest = read_manifest(wh, snapshot_id).await?;
if manifest.is_empty() {
bail!("snapshot {snapshot_id} has an empty manifest");
}
let needed: HashSet<&str> = manifest.iter().map(|m| m.content_hash.as_str()).collect();
let (vectors, dim) = read_vectors(wh, model_profile, &needed).await?;
let mut index = VectorIndex::new(dim)?;
let mut by_id = HashMap::with_capacity(manifest.len());
let mut flat: Vec<f32> = Vec::with_capacity(manifest.len() * dim);
let mut ids: Vec<u64> = Vec::with_capacity(manifest.len());
for occ in &manifest {
let v = vectors
.get(occ.content_hash.as_str())
.ok_or_else(|| anyhow!("manifest references missing vector {}", occ.content_hash))?;
flat.extend_from_slice(v);
ids.push(occ.ordinal as u64);
by_id.insert(
occ.ordinal as u64,
Occurrence {
content_hash: occ.content_hash.clone(),
file: occ.file.clone(),
start_line: occ.start_line,
end_line: occ.end_line,
},
);
}
index.add(&flat, &ids)?;
Ok(Reconstructed {
index,
snapshot_id,
git_sha: sha,
by_id,
})
})
}
/// Finds the `k` nearest stored chunks to `query` in the resolved snapshot.
///
/// The index is rebuilt with [`reconstruct`] on every call. Hits are returned as
/// `(score, occurrence)` pairs in the order `VectorIndex::search` yields them. Score
/// semantics (similarity vs. distance) are defined by `VectorIndex`.
///
/// # Errors
/// Propagates [`reconstruct`] errors, and fails if a hit id has no recorded occurrence.
pub fn search(
    wh: &IcebergWarehouse,
    repo: &str,
    git_sha: Option<&str>,
    model_profile: &str,
    query: &[f32],
    k: usize,
) -> Result<Vec<(f32, Occurrence)>> {
    let rebuilt = reconstruct(wh, repo, git_sha, model_profile)?;
    rebuilt
        .index
        .search(query, k)
        .into_iter()
        .map(|(id, score)| {
            rebuilt
                .by_id
                .get(&id)
                .cloned()
                .map(|occ| (score, occ))
                .ok_or_else(|| anyhow!("hit id {id} has no occurrence"))
        })
        .collect()
}
/// One row of the manifest table as read back by `read_manifest`.
struct ManifestRow {
    /// Position of the chunk in the slice passed to `index_snapshot`; also the index id.
    ordinal: i32,
    /// Key of the stored vector for this occurrence.
    content_hash: String,
    /// File the chunk came from.
    file: String,
    /// First line of the chunk (stored as i32, widened on read).
    start_line: usize,
    /// Last line of the chunk (stored as i32, widened on read).
    end_line: usize,
}
/// Looks up an existing snapshot whose full identity matches, returning its id if found.
///
/// The identity is repo, git sha, model profile, chunker hash and fileset hash. Repo, git sha
/// and model profile are pushed into the scan filter. Every identity column is then re-checked
/// row by row, so the pushed-down filter acts only as a pruning hint. If several rows match,
/// the first one scanned wins.
async fn find_snapshot(
    wh: &IcebergWarehouse,
    repo: &str,
    git_sha: &str,
    model_profile: &str,
    chunker_hash: &str,
    fileset_hash: &str,
) -> Result<Option<Uuid>> {
    let ident = wh.table_ident(TABLE_EMBEDDING_SNAPSHOTS);
    let table = wh.catalog().load_table(&ident).await?;
    let by_repo = Reference::new("repo").equal_to(Datum::string(repo));
    let by_sha = Reference::new("git_sha").equal_to(Datum::string(git_sha));
    let by_profile = Reference::new("model_profile").equal_to(Datum::string(model_profile));
    let scan = table
        .scan()
        .with_filter(by_repo.and(by_sha).and(by_profile))
        .select([
            "snapshot_id",
            "repo",
            "git_sha",
            "model_profile",
            "chunker_hash",
            "fileset_hash",
        ])
        .build()?;
    let batches: Vec<RecordBatch> = scan.to_arrow().await?.try_collect().await?;
    for batch in &batches {
        let ids = col::<StringArray>(batch, "snapshot_id")?;
        let repos = col::<StringArray>(batch, "repo")?;
        let shas = col::<StringArray>(batch, "git_sha")?;
        let profiles = col::<StringArray>(batch, "model_profile")?;
        let chunkers = col::<StringArray>(batch, "chunker_hash")?;
        let filesets = col::<StringArray>(batch, "fileset_hash")?;
        let matching = (0..batch.num_rows()).find(|&i| {
            repos.value(i) == repo
                && shas.value(i) == git_sha
                && profiles.value(i) == model_profile
                && chunkers.value(i) == chunker_hash
                && filesets.value(i) == fileset_hash
        });
        if let Some(i) = matching {
            return Ok(Some(Uuid::parse_str(ids.value(i))?));
        }
    }
    Ok(None)
}
/// Returns every content hash that already has a stored vector under `model_profile`.
///
/// The hashes come from any repo or commit, since only the model profile is filtered on.
/// The profile is pushed into the scan filter and re-checked per row.
async fn existing_content_hashes(
    wh: &IcebergWarehouse,
    model_profile: &str,
) -> Result<HashSet<String>> {
    let ident = wh.table_ident(TABLE_EMBEDDINGS);
    let table = wh.catalog().load_table(&ident).await?;
    let scan = table
        .scan()
        .with_filter(Reference::new("model_profile").equal_to(Datum::string(model_profile)))
        .select(["content_hash", "model_profile"])
        .build()?;
    let batches: Vec<RecordBatch> = scan.to_arrow().await?.try_collect().await?;
    let mut known = HashSet::new();
    for batch in &batches {
        let hashes = col::<StringArray>(batch, "content_hash")?;
        let profiles = col::<StringArray>(batch, "model_profile")?;
        known.extend(
            (0..batch.num_rows())
                .filter(|&i| profiles.value(i) == model_profile)
                .map(|i| hashes.value(i).to_string()),
        );
    }
    Ok(known)
}
/// Appends one embeddings-table row per `(content_hash, vector)` pair in `rows`.
///
/// Columns are written positionally, in the order:
/// 1. content hash
/// 2. model profile
/// 3. dim
/// 4. vector (a list of f32)
/// 5. normalized flag
/// 6. vector hash (hex SHA-256 of the vector's little-endian bytes)
///
/// This order is assumed to match the table's current schema.
///
/// # Errors
/// Fails if `dim` does not fit the table's 32-bit `dim` column, if the schema lacks a list
/// `vector` column, if the batch does not match the schema, or if the append fails.
async fn append_embeddings(
    wh: &IcebergWarehouse,
    model_profile: &str,
    dim: usize,
    normalized: bool,
    rows: &[(&str, &[f32])],
) -> Result<()> {
    let table = wh
        .catalog()
        .load_table(&wh.table_ident(TABLE_EMBEDDINGS))
        .await?;
    let schema = Arc::new(schema_to_arrow_schema(table.metadata().current_schema())?);
    let elem = list_element_field(&schema, "vector")?;
    // `dim` is stored as i32; refuse to write a silently truncated value.
    let dim32 =
        i32::try_from(dim).with_context(|| format!("model dim {dim} does not fit in i32"))?;
    let n = rows.len();
    let content: Vec<&str> = rows.iter().map(|(h, _)| *h).collect();
    let vector_hashes: Vec<String> = rows
        .iter()
        .map(|(_, v)| hex_sha256(&f32_le_bytes(v)))
        .collect();
    let mut vb = ListBuilder::new(Float32Builder::new()).with_field(elem);
    for (_, v) in rows {
        vb.values().append_slice(v);
        vb.append(true);
    }
    let cols: Vec<Arc<dyn Array>> = vec![
        Arc::new(StringArray::from(content)),
        Arc::new(StringArray::from(vec![model_profile; n])),
        Arc::new(Int32Array::from(vec![dim32; n])),
        Arc::new(vb.finish()),
        Arc::new(BooleanArray::from(vec![normalized; n])),
        Arc::new(StringArray::from(vector_hashes)),
    ];
    let batch = RecordBatch::try_new(schema, cols)?;
    append_batch(wh.catalog(), table, batch).await
}
/// Appends one manifest row per chunk for `snapshot_id`.
///
/// Each row's ordinal is the chunk's position in `chunks`. Columns are written positionally,
/// in the order:
/// 1. snapshot id
/// 2. ordinal
/// 3. content hash
/// 4. model profile
/// 5. file
/// 6. start line
/// 7. end line
///
/// This order is assumed to match the table's current schema.
///
/// # Errors
/// Fails if an ordinal or line number does not fit the table's 32-bit columns, if the batch
/// does not match the schema, or if the append fails.
async fn append_manifest(
    wh: &IcebergWarehouse,
    snapshot_id: Uuid,
    model_profile: &str,
    chunks: &[Chunk],
) -> Result<()> {
    let table = wh
        .catalog()
        .load_table(&wh.table_ident(TABLE_EMBEDDING_MANIFEST))
        .await?;
    let schema = Arc::new(schema_to_arrow_schema(table.metadata().current_schema())?);
    let sid = snapshot_id.to_string();
    let n = chunks.len();
    let mut ordinals = Vec::with_capacity(n);
    let mut starts = Vec::with_capacity(n);
    let mut ends = Vec::with_capacity(n);
    // Ordinals and line numbers are stored as i32; a silent `as` truncation would corrupt
    // the ordinal → occurrence mapping, so convert with checks.
    for (i, c) in chunks.iter().enumerate() {
        ordinals.push(
            i32::try_from(i)
                .with_context(|| format!("manifest ordinal {i} does not fit in i32"))?,
        );
        starts.push(i32::try_from(c.start_line).with_context(|| {
            format!("{}: start_line {} does not fit in i32", c.file, c.start_line)
        })?);
        ends.push(i32::try_from(c.end_line).with_context(|| {
            format!("{}: end_line {} does not fit in i32", c.file, c.end_line)
        })?);
    }
    let hashes: Vec<&str> = chunks.iter().map(|c| c.content_hash.as_str()).collect();
    let files: Vec<&str> = chunks.iter().map(|c| c.file.as_str()).collect();
    let cols: Vec<Arc<dyn Array>> = vec![
        Arc::new(StringArray::from(vec![sid.as_str(); n])),
        Arc::new(Int32Array::from(ordinals)),
        Arc::new(StringArray::from(hashes)),
        Arc::new(StringArray::from(vec![model_profile; n])),
        Arc::new(StringArray::from(files)),
        Arc::new(Int32Array::from(starts)),
        Arc::new(Int32Array::from(ends)),
    ];
    let batch = RecordBatch::try_new(schema, cols)?;
    append_batch(wh.catalog(), table, batch).await
}
/// Appends the single snapshots-table row describing `snapshot_id`.
///
/// Columns are written positionally, in the order:
/// 1. snapshot id
/// 2. workspace
/// 3. repo
/// 4. git sha
/// 5. branch
/// 6. model profile
/// 7. chunker hash
/// 8. fileset hash
/// 9. complete flag
/// 10. row count
/// 11. current UTC time in microseconds
///
/// This order is assumed to match the table's current schema.
///
/// # Errors
/// Fails if `row_count` does not fit the table's 32-bit column, if the batch does not match
/// the schema, or if the append fails.
async fn append_snapshot_row(
    wh: &IcebergWarehouse,
    snapshot_id: Uuid,
    p: &IndexParams<'_>,
    model_profile: &str,
    fileset_hash: &str,
    row_count: usize,
) -> Result<()> {
    // Stored as i32; refuse to record a silently truncated count.
    let row_count_i32 = i32::try_from(row_count)
        .with_context(|| format!("snapshot row_count {row_count} does not fit in i32"))?;
    let table = wh
        .catalog()
        .load_table(&wh.table_ident(TABLE_EMBEDDING_SNAPSHOTS))
        .await?;
    let schema = Arc::new(schema_to_arrow_schema(table.metadata().current_schema())?);
    let cols: Vec<Arc<dyn Array>> = vec![
        Arc::new(StringArray::from(vec![snapshot_id.to_string()])),
        Arc::new(StringArray::from(vec![p.workspace.to_string()])),
        Arc::new(StringArray::from(vec![p.repo.to_string()])),
        Arc::new(StringArray::from(vec![p.git_sha.to_string()])),
        Arc::new(StringArray::from(vec![p.branch.to_string()])),
        Arc::new(StringArray::from(vec![model_profile.to_string()])),
        Arc::new(StringArray::from(vec![p.chunker_hash.to_string()])),
        Arc::new(StringArray::from(vec![fileset_hash.to_string()])),
        Arc::new(BooleanArray::from(vec![p.complete])),
        Arc::new(Int32Array::from(vec![row_count_i32])),
        Arc::new(
            TimestampMicrosecondArray::from(vec![Utc::now().timestamp_micros()])
                .with_timezone("+00:00"),
        ),
    ];
    let batch = RecordBatch::try_new(schema, cols)?;
    append_batch(wh.catalog(), table, batch).await
}
async fn resolve_snapshot(
wh: &IcebergWarehouse,
repo: &str,
git_sha: Option<&str>,
model_profile: &str,
) -> Result<(Uuid, String)> {
let table = wh
.catalog()
.load_table(&wh.table_ident(TABLE_EMBEDDING_SNAPSHOTS))
.await?;
let predicate = Reference::new("repo")
.equal_to(Datum::string(repo))
.and(Reference::new("model_profile").equal_to(Datum::string(model_profile)));
let scan = table
.scan()
.with_filter(predicate)
.select(["snapshot_id", "repo", "git_sha", "model_profile", "ts_micros"])
.build()?;
let batches: Vec<RecordBatch> = scan.to_arrow().await?.try_collect().await?;
let mut candidates: Vec<(Uuid, String, i64)> = Vec::new();
for b in &batches {
let ids = col::<StringArray>(b, "snapshot_id")?;
let repos = col::<StringArray>(b, "repo")?;
let shas = col::<StringArray>(b, "git_sha")?;
let mps = col::<StringArray>(b, "model_profile")?;
let ts = col::<TimestampMicrosecondArray>(b, "ts_micros")?;
for i in 0..b.num_rows() {
if repos.value(i) != repo || mps.value(i) != model_profile {
continue;
}
if let Some(sha) = git_sha {
if !shas.value(i).starts_with(sha) {
continue;
}
}
candidates.push((Uuid::parse_str(ids.value(i))?, shas.value(i).to_string(), ts.value(i)));
}
}
if candidates.is_empty() {
bail!(
"no embedding snapshot for repo `{repo}`{}",
git_sha.map(|s| format!(" at sha {s}")).unwrap_or_default()
);
}
candidates.sort_by_key(|c| c.2);
let chosen = candidates.pop().unwrap();
Ok((chosen.0, chosen.1))
}
/// Reads every manifest row of `snapshot_id`, sorted by ordinal.
///
/// The snapshot id is pushed into the scan filter and re-checked per row. Stored i32 line
/// numbers are widened to `usize` with an `as` cast.
async fn read_manifest(wh: &IcebergWarehouse, snapshot_id: Uuid) -> Result<Vec<ManifestRow>> {
    let table = wh
        .catalog()
        .load_table(&wh.table_ident(TABLE_EMBEDDING_MANIFEST))
        .await?;
    let wanted = snapshot_id.to_string();
    let scan = table
        .scan()
        .with_filter(Reference::new("snapshot_id").equal_to(Datum::string(wanted.clone())))
        .select(["snapshot_id", "ordinal", "content_hash", "file", "start_line", "end_line"])
        .build()?;
    let batches: Vec<RecordBatch> = scan.to_arrow().await?.try_collect().await?;
    let mut rows: Vec<ManifestRow> = Vec::new();
    for batch in &batches {
        let sids = col::<StringArray>(batch, "snapshot_id")?;
        let ordinals = col::<Int32Array>(batch, "ordinal")?;
        let hashes = col::<StringArray>(batch, "content_hash")?;
        let files = col::<StringArray>(batch, "file")?;
        let starts = col::<Int32Array>(batch, "start_line")?;
        let ends = col::<Int32Array>(batch, "end_line")?;
        rows.extend(
            (0..batch.num_rows())
                .filter(|&i| sids.value(i) == wanted)
                .map(|i| ManifestRow {
                    ordinal: ordinals.value(i),
                    content_hash: hashes.value(i).to_string(),
                    file: files.value(i).to_string(),
                    start_line: starts.value(i) as usize,
                    end_line: ends.value(i) as usize,
                }),
        );
    }
    rows.sort_by_key(|row| row.ordinal);
    Ok(rows)
}
async fn read_vectors(
wh: &IcebergWarehouse,
model_profile: &str,
needed: &HashSet<&str>,
) -> Result<(HashMap<String, Vec<f32>>, usize)> {
let table = wh
.catalog()
.load_table(&wh.table_ident(TABLE_EMBEDDINGS))
.await?;
let scan = table
.scan()
.with_filter(Reference::new("model_profile").equal_to(Datum::string(model_profile)))
.select(["content_hash", "model_profile", "dim", "vector"])
.build()?;
let batches: Vec<RecordBatch> = scan.to_arrow().await?.try_collect().await?;
let mut out: HashMap<String, Vec<f32>> = HashMap::with_capacity(needed.len());
let mut dim = 0usize;
for b in &batches {
let hs = col::<StringArray>(b, "content_hash")?;
let mps = col::<StringArray>(b, "model_profile")?;
let dims = col::<Int32Array>(b, "dim")?;
let vecs = col::<ListArray>(b, "vector")?;
for i in 0..b.num_rows() {
if mps.value(i) != model_profile {
continue;
}
let h = hs.value(i);
if !needed.contains(h) || out.contains_key(h) {
continue;
}
let row = vecs.value(i);
let fa = row
.as_any()
.downcast_ref::<Float32Array>()
.ok_or_else(|| anyhow!("vector element not Float32Array"))?;
let v: Vec<f32> = fa.values().to_vec();
dim = dims.value(i) as usize;
if v.len() != dim {
bail!("stored vector len {} != dim {dim} for {h}", v.len());
}
out.insert(h.to_string(), v);
}
}
if out.is_empty() {
bail!("no embeddings found for the snapshot's content hashes");
}
Ok((out, dim))
}
/// Checks that an embedding has exactly `dim` components.
///
/// # Errors
/// When the lengths differ, returns an error that names the chunk's `file:start-end`
/// location and both lengths.
fn ensure_dim(v: &[f32], dim: usize, chunk: &Chunk) -> Result<()> {
    if v.len() == dim {
        return Ok(());
    }
    let (file, start, end) = (&chunk.file, chunk.start_line, chunk.end_line);
    bail!(
        "embedding for {file}:{start}-{end} has dim {} != model dim {dim}",
        v.len()
    )
}
/// Hashes the set of chunks into an order-independent fingerprint.
///
/// Each chunk is rendered as `file:start:end:content_hash`. The rendered lines are sorted,
/// each is fed to SHA-256 followed by a newline, and the digest is returned as lowercase hex.
fn fileset_hash_of(chunks: &[Chunk]) -> String {
    let mut lines: Vec<String> = chunks
        .iter()
        .map(|c| format!("{}:{}:{}:{}", c.file, c.start_line, c.end_line, c.content_hash))
        .collect();
    // Equal strings are indistinguishable, so an unstable sort yields the same sequence.
    lines.sort_unstable();
    let digest = lines
        .iter()
        .fold(Sha256::new(), |mut hasher, line| {
            hasher.update(line.as_bytes());
            hasher.update(b"\n");
            hasher
        })
        .finalize();
    hex_encode(&digest)
}
/// Serializes `v` as consecutive little-endian IEEE-754 bytes, 4 per element.
fn f32_le_bytes(v: &[f32]) -> Vec<u8> {
    v.iter().flat_map(|x| x.to_le_bytes()).collect()
}
/// Returns the lowercase hex SHA-256 digest of `bytes`.
fn hex_sha256(bytes: &[u8]) -> String {
    hex_encode(&Sha256::digest(bytes))
}
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte, high nibble first.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&b| [DIGITS[usize::from(b >> 4)], DIGITS[usize::from(b & 0x0f)]])
        .map(char::from)
        .collect()
}
fn col<'a, T: 'static>(batch: &'a RecordBatch, name: &str) -> Result<&'a T> {
batch
.column_by_name(name)
.ok_or_else(|| anyhow!("projected batch missing column `{name}`"))?
.as_any()
.downcast_ref::<T>()
.ok_or_else(|| anyhow!("column `{name}` has unexpected arrow type"))
}
fn list_element_field(schema: &ArrowSchema, name: &str) -> Result<FieldRef> {
let field = schema
.field_with_name(name)
.with_context(|| format!("schema missing list column `{name}`"))?;
match field.data_type() {
DataType::List(elem) | DataType::LargeList(elem) => Ok(elem.clone()),
other => bail!("column `{name}` expected List, got {other:?}"),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::chunk::{hash_text, ChunkOptions};
    use crate::warehouse::iceberg::IcebergWarehouse;

    /// 8-dimensional test profile matching the `unit` vectors below.
    fn profile() -> ModelProfile {
        ModelProfile {
            model_name: "test-embed".into(),
            weights_sha: "w0".into(),
            tokenizer_sha: "t0".into(),
            pooling: "mean".into(),
            normalize: true,
            dim: 8,
            dtype: "f32".into(),
        }
    }

    /// Profile derived from a registry model, with per-model fake weight/tokenizer ids.
    fn registry_profile(m: &crate::vector::embed_registry::EmbedModel) -> ModelProfile {
        ModelProfile {
            model_name: m.model_name.into(),
            weights_sha: format!("weights-of-{}", m.id),
            tokenizer_sha: format!("tok-of-{}", m.id),
            pooling: "mean".into(),
            normalize: true,
            dim: m.dim,
            dtype: "f32".into(),
        }
    }

    #[test]
    fn registry_models_produce_distinct_profiles_and_keys() {
        use crate::vector::embed_registry as reg;
        let jina = reg::by_id("jina-v2-base-code").expect("default registered");
        let mini = reg::by_id("minilm-l6-v2").expect("alt registered");
        let pj = registry_profile(jina);
        let pm = registry_profile(mini);
        assert_eq!(pj.dim, 768, "jina is 768-dim");
        assert_eq!(pm.dim, 384, "minilm is 384-dim");
        assert_ne!(pj.dim, pm.dim, "different models, different dims");
        assert_ne!(
            pj.id(),
            pm.id(),
            "different models MUST hash to different model_profile keys"
        );
        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let ch = chunker_hash(&ChunkOptions::default());
        let chunks = vec![mk("a.rs", 1, "alpha")];
        index_snapshot(&wh, &params("repo", "sha1", &pj, &ch), &chunks, |missing| {
            Ok(missing.iter().map(|_| vec![0.0f32; pj.dim]).collect())
        })
        .unwrap();
        let q = vec![0.0f32; pm.dim];
        let err = search(&wh, "repo", Some("sha1"), &pm.id(), &q, 1).unwrap_err();
        assert!(
            err.to_string().contains("no embedding snapshot"),
            "minilm profile must not see jina's vectors: {err}"
        );
        let qj = {
            let mut v = vec![0.0f32; pj.dim];
            v[0] = 1.0;
            v
        };
        let hits = search(&wh, "repo", Some("sha1"), &pj.id(), &qj, 1).unwrap();
        assert_eq!(hits.len(), 1, "jina profile still finds its own vector");
    }

    /// 8-dim one-hot vector along `axis`.
    fn unit(axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; 8];
        v[axis] = 1.0;
        v
    }

    /// Deterministic text → axis mapping used by the fake embedders.
    fn axis_of(text: &str) -> usize {
        match text {
            "alpha" | "same" => 0,
            "bravo" => 1,
            "charlie" => 2,
            "bravo-2" => 3,
            "delta" => 4,
            _ => 7,
        }
    }

    fn embed(missing: &[Chunk]) -> Result<Vec<Vec<f32>>> {
        Ok(missing.iter().map(|c| unit(axis_of(&c.text))).collect())
    }

    /// Single-line chunk of `text` at `file:line`.
    fn mk(file: &str, line: usize, text: &str) -> Chunk {
        Chunk {
            file: file.into(),
            start_line: line,
            end_line: line,
            content_hash: hash_text(text),
            text: text.into(),
        }
    }

    fn params<'a>(
        repo: &'a str,
        sha: &'a str,
        model: &'a ModelProfile,
        ch: &'a str,
    ) -> IndexParams<'a> {
        IndexParams {
            workspace: "ws",
            repo,
            git_sha: sha,
            branch: "main",
            model,
            chunker_hash: ch,
            complete: true,
        }
    }

    struct TestEmbedder;

    impl Embedder for TestEmbedder {
        fn profile(&self) -> ModelProfile {
            profile()
        }
        fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
            Ok(texts.iter().map(|t| unit(axis_of(t))).collect())
        }
    }

    #[test]
    fn index_then_search_roundtrips() {
        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let m = profile();
        let ch = chunker_hash(&ChunkOptions::default());
        let chunks = vec![mk("a.rs", 1, "alpha"), mk("b.rs", 1, "bravo"), mk("c.rs", 1, "charlie")];
        let snap = index_snapshot(&wh, &params("repo", "sha1", &m, &ch), &chunks, embed).unwrap();
        assert_eq!(snap.occurrences, 3);
        assert_eq!(snap.new_vectors, 3);
        let mut q = unit(1);
        q[2] = 0.1;
        let hits = search(&wh, "repo", Some("sha1"), &m.id(), &q, 1).unwrap();
        assert_eq!(hits[0].1.file, "b.rs");
    }

    #[test]
    fn reindex_same_snapshot_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let m = profile();
        let ch = chunker_hash(&ChunkOptions::default());
        let chunks = vec![mk("a.rs", 1, "alpha"), mk("b.rs", 1, "bravo")];
        let s1 = index_snapshot(&wh, &params("repo", "sha1", &m, &ch), &chunks, embed).unwrap();
        let s2 = index_snapshot(&wh, &params("repo", "sha1", &m, &ch), &chunks, embed).unwrap();
        assert_eq!(s1.snapshot_id, s2.snapshot_id, "same identity → same snapshot");
        assert_eq!(s2.new_vectors, 0, "nothing re-embedded");
    }

    #[test]
    fn unchanged_content_dedups_across_commits() {
        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let m = profile();
        let ch = chunker_hash(&ChunkOptions::default());
        let b1 = vec![mk("a.rs", 1, "alpha"), mk("b.rs", 1, "bravo")];
        index_snapshot(&wh, &params("repo", "sha1", &m, &ch), &b1, embed).unwrap();
        let b2 = vec![
            mk("a.rs", 1, "alpha"),
            mk("b.rs", 1, "bravo-2"),
            mk("d.rs", 1, "delta"),
        ];
        let s2 = index_snapshot(&wh, &params("repo", "sha2", &m, &ch), &b2, |missing| {
            assert_eq!(missing.len(), 2, "only changed+new chunks embedded");
            embed(missing)
        })
        .unwrap();
        assert_eq!(s2.new_vectors, 2, "only the changed + new chunk are embedded");
        let hits = search(&wh, "repo", Some("sha1"), &m.id(), &unit(0), 1).unwrap();
        assert_eq!(hits[0].1.file, "a.rs");
        assert_eq!(hits[0].1.content_hash, hash_text("alpha"));
    }

    #[test]
    fn duplicate_content_in_two_files_returns_both() {
        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let m = profile();
        let ch = chunker_hash(&ChunkOptions::default());
        let chunks = vec![mk("x.rs", 1, "same"), mk("y.rs", 9, "same")];
        let snap = index_snapshot(&wh, &params("repo", "sha1", &m, &ch), &chunks, embed).unwrap();
        assert_eq!(snap.new_vectors, 1, "deduped to one vector");
        assert_eq!(snap.occurrences, 2);
        let hits = search(&wh, "repo", Some("sha1"), &m.id(), &unit(0), 2).unwrap();
        let files: HashSet<&str> = hits.iter().map(|(_, o)| o.file.as_str()).collect();
        assert!(files.contains("x.rs") && files.contains("y.rs"), "both occurrences");
    }

    #[test]
    fn index_repo_chunks_embeds_and_searches() {
        let dir = tempfile::tempdir().unwrap();
        let wh = IcebergWarehouse::open(dir.path()).unwrap();
        let files = vec![
            ("a.rs".to_string(), "alpha".to_string()),
            ("b.rs".to_string(), "bravo".to_string()),
        ];
        let snap = index_repo(
            &wh,
            &RepoRef {
                workspace: "ws",
                repo: "repo",
                git_sha: "sha1",
                branch: "main",
                complete: true,
            },
            &files,
            &ChunkOptions::default(),
            &TestEmbedder,
        )
        .unwrap();
        assert_eq!(snap.occurrences, 2);
        assert_eq!(snap.new_vectors, 2);
        let hits = search(&wh, "repo", Some("sha1"), &profile().id(), &unit(1), 1).unwrap();
        assert_eq!(hits[0].1.file, "b.rs");
    }
}