use std::path::{Path, PathBuf};
use std::sync::Mutex;
use crossbeam_channel::bounded;
use hf_hub::api::sync::Api;
use rayon::prelude::*;
use streaming_iterator::StreamingIterator;
use tree_sitter::{Parser, QueryCursor};
use crate::chunk::{CodeChunk, ContentKind};
use crate::embed::SearchConfig;
use crate::encoder::VectorEncoder;
use crate::encoder::ripvec::chunking::{DEFAULT_DESIRED_CHUNK_CHARS, chunk_source};
use crate::encoder::ripvec::static_model::StaticEmbedModel;
use crate::languages::{config_for_extension, lsp_symbol_kind_for_node_kind};
use crate::profile::Profiler;
use crate::walk::collect_files_with_options;
/// Number of `(chunk, text)` pairs the batching stage of `embed_root` accumulates
/// before handing one batch to the embedding stage.
const PIPELINE_BATCH_SIZE: usize = 1024;
/// Capacity of the bounded channel that carries finished batches to the
/// embedding stage in `embed_root`.
const PIPELINE_RING_SIZE: usize = 4;
/// Default model identifier. `StaticEncoder::from_pretrained` accepts either a
/// local directory or a hub repo id of this form.
///
/// NOTE(review): the ignored test in this file is named after `potion_code_16m`
/// and asserts `DEFAULT_HIDDEN_DIM`; confirm that this repo really produces
/// vectors of that width.
pub const DEFAULT_MODEL_REPO: &str = "minishlab/potion-base-32M";
/// Embedding width the ignored model-loading test expects. The encoder itself
/// reads the width from the loaded model rather than from this constant.
pub const DEFAULT_HIDDEN_DIM: usize = 256;
/// Files whose on-disk size exceeds this many bytes are skipped by `chunk_one_file`.
const MAX_FILE_BYTES: u64 = 1_000_000;
/// Encoder that wraps a `StaticEmbedModel` and implements `VectorEncoder`.
pub struct StaticEncoder {
// The loaded embedding model; all encode calls are delegated to it.
model: StaticEmbedModel,
// The identifier passed to `from_pretrained`; returned verbatim by `identity()`.
model_repo: String,
// Width reported by the model at load time; returned by `hidden_dim()`.
hidden_dim: usize,
}
impl StaticEncoder {
#[must_use]
pub fn encode_query(&self, query: &str) -> Vec<f32> {
self.model.encode_query(query)
}
pub fn from_pretrained(model_repo: &str) -> crate::Result<Self> {
let resolved = Self::resolve_model_dir(model_repo)?;
let model = StaticEmbedModel::from_path(&resolved, Some(true))
.map_err(|e| crate::Error::Other(anyhow::anyhow!("static model load failed: {e}")))?;
let hidden_dim = model.hidden_dim();
Ok(Self {
model,
model_repo: model_repo.to_string(),
hidden_dim,
})
}
fn resolve_model_dir(model_repo: &str) -> crate::Result<PathBuf> {
let local = Path::new(model_repo);
if local.is_dir() {
return Ok(local.to_path_buf());
}
let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
let repo = api.model(model_repo.to_string());
let _ = repo
.get("config.json")
.map_err(|e| crate::Error::Download(e.to_string()))?;
let _ = repo
.get("tokenizer.json")
.map_err(|e| crate::Error::Download(e.to_string()))?;
let weights_path = repo
.get("model.safetensors")
.map_err(|e| crate::Error::Download(e.to_string()))?;
weights_path
.parent()
.map(std::path::Path::to_path_buf)
.ok_or_else(|| {
crate::Error::Other(anyhow::anyhow!(
"hf-hub returned root path for {model_repo}; cannot resolve snapshot dir"
))
})
}
pub fn embed_paths(
&self,
root: &Path,
paths: &[std::path::PathBuf],
profiler: &Profiler,
) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
let _guard = profiler.phase("embed_paths");
let mut chunks_out: Vec<CodeChunk> = Vec::new();
let mut texts: Vec<String> = Vec::new();
for path in paths {
let (file_chunks, file_texts) = chunk_one_file(root, path);
chunks_out.extend(file_chunks);
texts.extend(file_texts);
}
if chunks_out.is_empty() {
return Ok((Vec::new(), Vec::new()));
}
let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
let embeddings = self.model.encode_batch(&text_refs);
debug_assert_eq!(embeddings.len(), chunks_out.len());
Ok((chunks_out, embeddings))
}
}
impl VectorEncoder for StaticEncoder {
    /// Walk `root`, chunk every collected file and embed all chunks.
    ///
    /// The work runs as a three-stage pipeline inside one thread scope:
    ///
    /// 1. a dedicated rayon pool chunks files in parallel and sends
    ///    `(chunk, text)` pairs into a bounded channel;
    /// 2. a batching thread groups pairs into batches of `PIPELINE_BATCH_SIZE`;
    /// 3. an embedding thread encodes each batch and appends to the output.
    ///
    /// Each stage ends when the sending side of its input channel is dropped,
    /// so the channels' ownership (moved into the stage closures) is what
    /// shuts the pipeline down.
    ///
    /// Returns parallel vectors: `embeddings[i]` belongs to `chunks[i]`.
    /// Because several chunking workers feed one channel, the order of the
    /// returned chunks depends on thread scheduling.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` if the chunking thread pool cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if the output mutex is poisoned.
    fn embed_root(
        &self,
        root: &Path,
        cfg: &SearchConfig,
        profiler: &Profiler,
    ) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
        let walk_options = cfg.walk_options();
        let file_paths = {
            let _guard = profiler.phase("walk");
            collect_files_with_options(root, &walk_options)
        };
        if file_paths.is_empty() {
            return Ok((Vec::new(), Vec::new()));
        }

        let (chunk_tx, chunk_rx) = bounded::<(CodeChunk, String)>(PIPELINE_BATCH_SIZE * 8);
        let (batch_tx, batch_rx) = bounded::<Vec<(CodeChunk, String)>>(PIPELINE_RING_SIZE);
        let output: Mutex<Vec<(CodeChunk, Vec<f32>)>> = Mutex::new(Vec::new());
        let model = &self.model;

        // Give chunking half of the available threads (at least one).
        let num_cores = rayon::current_num_threads().max(2);
        let chunk_threads = (num_cores / 2).max(1);
        let chunk_pool = rayon::ThreadPoolBuilder::new()
            .num_threads(chunk_threads)
            .thread_name(|i| format!("semble-chunk-{i}"))
            .build()
            .map_err(|e| crate::Error::Other(anyhow::anyhow!("chunk thread pool build: {e}")))?;

        let _phase_guard = profiler.phase("pipeline");
        std::thread::scope(|scope| {
            // Stage 1: parallel chunking. `chunk_tx` is moved in so the channel
            // closes as soon as this stage finishes.
            scope.spawn(move || {
                chunk_pool.install(|| {
                    // `try_for_each` stops scheduling further files once a send
                    // fails (the receiver is gone), instead of reading and
                    // chunking every remaining file for nothing.
                    let _ = file_paths.par_iter().try_for_each(|full| {
                        let (chunks, contents) = chunk_one_file(root, full);
                        for pair in chunks.into_iter().zip(contents) {
                            chunk_tx.send(pair).ok()?;
                        }
                        Some(())
                    });
                });
            });

            // Stage 2: group pairs into fixed-size batches. `batch_tx` is moved
            // in so the embedding stage sees end-of-input when this one ends.
            scope.spawn(move || {
                let mut buf: Vec<(CodeChunk, String)> = Vec::with_capacity(PIPELINE_BATCH_SIZE);
                for pair in chunk_rx {
                    buf.push(pair);
                    if buf.len() >= PIPELINE_BATCH_SIZE {
                        let full_batch =
                            std::mem::replace(&mut buf, Vec::with_capacity(PIPELINE_BATCH_SIZE));
                        if batch_tx.send(full_batch).is_err() {
                            return;
                        }
                    }
                }
                // Flush the final, partially filled batch.
                if !buf.is_empty() {
                    let _ = batch_tx.send(buf);
                }
            });

            // Stage 3: embed each batch and append the results.
            scope.spawn(|| {
                for batch in batch_rx {
                    if batch.is_empty() {
                        continue;
                    }
                    let (chunks, texts): (Vec<CodeChunk>, Vec<String>) = batch.into_iter().unzip();
                    let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
                    let embeddings = model.encode_batch(&text_refs);
                    debug_assert_eq!(embeddings.len(), chunks.len());
                    // Lock only for the append, not while encoding.
                    output
                        .lock()
                        .expect("output mutex poisoned")
                        .extend(chunks.into_iter().zip(embeddings));
                }
            });
        });

        let collected = output.into_inner().expect("output mutex poisoned");
        let (chunks_out, embs_out): (Vec<CodeChunk>, Vec<Vec<f32>>) =
            collected.into_iter().unzip();
        Ok((chunks_out, embs_out))
    }

    /// Embedding width reported by the model when it was loaded.
    fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// The model identifier this encoder was created from.
    fn identity(&self) -> &str {
        &self.model_repo
    }
}
/// One symbol name found by the language query, with its byte span in the source.
struct NameCapture {
// Byte offset at which the `name` capture starts.
start_byte: usize,
// Byte offset just past the end of the `name` capture.
end_byte: usize,
// Source text covered by the `name` capture.
name: String,
// Symbol kind derived from the match's `def` capture, or
// `lsp_symbol_kind::VARIABLE` when the match had no `def` capture.
lsp_kind: u32,
}
/// Parse `source` with the language in `lang_cfg` and collect one
/// [`NameCapture`] for every query match that has a `name` capture.
///
/// The symbol kind comes from the match's `def` capture when present and
/// falls back to `lsp_symbol_kind::VARIABLE` otherwise. The result is sorted
/// by `start_byte`.
///
/// Returns an empty vector when the language cannot be set on the parser or
/// the source cannot be parsed.
fn extract_name_captures(
    source: &str,
    lang_cfg: &crate::languages::LangConfig,
) -> Vec<NameCapture> {
    let mut parser = Parser::new();
    if parser.set_language(&lang_cfg.language).is_err() {
        return Vec::new();
    }
    let Some(tree) = parser.parse(source, None) else {
        return Vec::new();
    };

    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&lang_cfg.query, tree.root_node(), source.as_bytes());
    let capture_names = lang_cfg.query.capture_names();
    let mut result: Vec<NameCapture> = Vec::new();

    // `matches` is advanced through the `StreamingIterator` trait, so it is
    // driven with `while let` rather than a `for` loop.
    while let Some(m) = matches.next() {
        let mut name: Option<(usize, usize, &str)> = None;
        let mut def_kind: Option<&str> = None;

        for cap in m.captures {
            let cap_name = &capture_names[cap.index as usize];
            if *cap_name == "name" {
                let start = cap.node.start_byte();
                let end = cap.node.end_byte();
                // Checked slicing: a span that is out of range, inverted, or
                // not on a UTF-8 boundary is skipped instead of panicking.
                if let Some(text) = source.get(start..end) {
                    name = Some((start, end, text));
                }
            } else if *cap_name == "def" {
                def_kind = Some(cap.node.kind());
            }
        }

        let Some((start_byte, end_byte, text)) = name else {
            continue;
        };
        let lsp_kind = match def_kind {
            Some(kind) => lsp_symbol_kind_for_node_kind(kind),
            None => crate::languages::lsp_symbol_kind::VARIABLE,
        };
        result.push(NameCapture {
            start_byte,
            end_byte,
            name: text.to_string(),
            lsp_kind,
        });
    }

    result.sort_unstable_by_key(|c| c.start_byte);
    result
}
/// Pick the name and symbol kind for the chunk spanning
/// `chunk_start..chunk_end` (byte offsets).
///
/// `captures` must be sorted by `start_byte`. The first capture that lies
/// entirely inside the chunk wins. When none does, the result is an empty
/// name paired with `lsp_symbol_kind::VARIABLE`.
fn name_for_chunk(captures: &[NameCapture], chunk_start: usize, chunk_end: usize) -> (&str, u32) {
    // Binary-search past every capture that starts before the chunk; those
    // can never match, and rescanning them for each chunk is quadratic.
    let first = captures.partition_point(|cap| cap.start_byte < chunk_start);
    for cap in &captures[first..] {
        // Every capture from `first` on starts at or after `chunk_start`, so
        // only the end still needs checking.
        if cap.end_byte <= chunk_end {
            return (cap.name.as_str(), cap.lsp_kind);
        }
        if cap.start_byte >= chunk_end {
            break;
        }
    }
    ("", crate::languages::lsp_symbol_kind::VARIABLE)
}
/// Read `full`, split it into chunks and return them together with the text
/// to embed for each chunk (the two vectors are index-aligned).
///
/// Files that cannot be stat'ed, exceed `MAX_FILE_BYTES`, or cannot be read
/// as UTF-8 text yield two empty vectors. Chunks whose text is only
/// whitespace are dropped. Paths stored in the chunks are relative to `root`
/// when `full` lies under it, otherwise `full` is stored unchanged.
fn chunk_one_file(root: &Path, full: &Path) -> (Vec<CodeChunk>, Vec<String>) {
    // Treat "cannot stat" the same as "too large": skip the file.
    let skip = std::fs::metadata(full).map_or(true, |meta| meta.len() > MAX_FILE_BYTES);
    if skip {
        return (Vec::new(), Vec::new());
    }
    let Ok(source) = std::fs::read_to_string(full) else {
        return (Vec::new(), Vec::new());
    };

    let ext = full
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or_default();
    let lang_cfg = config_for_extension(ext);
    let cfg = lang_cfg.as_deref();
    let language = cfg.map(|c| &c.language);
    // Without a language config there are no names to attach.
    let name_captures: Vec<NameCapture> = cfg
        .map(|c| extract_name_captures(&source, c))
        .unwrap_or_default();

    let rel_path = full.strip_prefix(root).unwrap_or(full).display().to_string();
    let content_kind = ContentKind::from_extension(ext);

    let boundaries = chunk_source(&source, language, DEFAULT_DESIRED_CHUNK_CHARS);
    let mut chunks = Vec::with_capacity(boundaries.len());
    let mut contents = Vec::with_capacity(boundaries.len());

    for boundary in boundaries {
        let text = boundary.content(&source).to_string();
        if text.trim().is_empty() {
            continue;
        }

        let (symbol, symbol_kind) =
            name_for_chunk(&name_captures, boundary.start_byte, boundary.end_byte);
        // A chunk without a name gets an empty kind as well.
        let kind = if symbol.is_empty() {
            String::new()
        } else {
            symbol_kind.to_string()
        };

        contents.push(text.clone());
        chunks.push(CodeChunk {
            file_path: rel_path.clone(),
            name: symbol.to_string(),
            kind,
            content_kind,
            start_line: boundary.start_line,
            end_line: boundary.end_line,
            symbol_line: boundary.start_line,
            content: text.clone(),
            enriched_content: text,
            qualified_name: None,
        });
    }

    (chunks, contents)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::VectorEncoder;

    /// Write `source` to `filename` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the path is
    /// used; dropping it removes the directory.
    fn write_temp(source: &str, filename: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join(filename);
        std::fs::write(&path, source).expect("write");
        (dir, path)
    }

    /// A Rust function definition yields a chunk carrying the function's name.
    #[test]
    fn chunk_one_file_populates_name_from_tree_sitter() {
        let source = "pub fn add(a: i32, b: i32) -> i32 { a + b }\n";
        let (dir, path) = write_temp(source, "add.rs");
        let (chunks, _) = chunk_one_file(dir.path(), &path);
        assert!(
            !chunks.is_empty(),
            "expected at least one chunk from Rust source"
        );
        assert!(
            chunks.iter().any(|c| c.name == "add"),
            "expected at least one chunk with name 'add'; got names: {:?}",
            chunks.iter().map(|c| c.name.as_str()).collect::<Vec<_>>()
        );
    }

    /// Source without any definitions produces only unnamed chunks.
    #[test]
    fn chunk_one_file_leaves_name_empty_when_no_identifier() {
        let source = "// just a comment\n \n// another comment\n";
        let (dir, path) = write_temp(source, "comments.rs");
        let (chunks, _) = chunk_one_file(dir.path(), &path);
        for c in &chunks {
            assert!(
                c.name.is_empty(),
                "expected empty name for comment-only source; got {:?}",
                c.name
            );
        }
    }

    /// Compile-time check that the encoder satisfies the trait and is thread-safe.
    #[test]
    fn static_encoder_implements_vector_encoder() {
        fn assert_trait_object<T: VectorEncoder + Send + Sync>() {}
        assert_trait_object::<StaticEncoder>();
    }

    /// A struct definition is tagged with kind `"23"`.
    #[test]
    fn chunk_one_file_populates_kind_for_rust_struct() {
        let source = "pub struct Foo { x: i32 }\n";
        let (dir, path) = write_temp(source, "foo.rs");
        let (chunks, _) = chunk_one_file(dir.path(), &path);
        let struct_chunk = chunks.iter().find(|c| c.name == "Foo");
        assert!(
            struct_chunk.is_some(),
            "expected a chunk named 'Foo'; got: {:?}",
            chunks.iter().map(|c| c.name.as_str()).collect::<Vec<_>>()
        );
        let kind = &struct_chunk.unwrap().kind;
        assert_eq!(
            kind.as_str(),
            "23",
            "struct_item must emit LSP SymbolKind::Struct (23); got: {kind:?}"
        );
    }

    /// A trait definition is tagged with kind `"11"`.
    #[test]
    fn chunk_one_file_populates_kind_for_rust_trait() {
        let source = "pub trait MyTrait { fn method(&self); }\n";
        let (dir, path) = write_temp(source, "trait.rs");
        let (chunks, _) = chunk_one_file(dir.path(), &path);
        let trait_chunk = chunks.iter().find(|c| c.name == "MyTrait");
        assert!(
            trait_chunk.is_some(),
            "expected a chunk named 'MyTrait'; got: {:?}",
            chunks.iter().map(|c| c.name.as_str()).collect::<Vec<_>>()
        );
        let kind = &trait_chunk.unwrap().kind;
        assert_eq!(
            kind.as_str(),
            "11",
            "trait_item must emit LSP SymbolKind::Interface (11); got: {kind:?}"
        );
    }

    /// Every named chunk carries a non-empty kind; `Qux`, when present, is a struct.
    #[test]
    fn chunk_one_file_kind_distinct_from_variable_default() {
        let source = "pub struct Qux { x: i32, y: i32 }\n";
        let (dir, path) = write_temp(source, "qux.rs");
        let (chunks, _) = chunk_one_file(dir.path(), &path);
        let named_chunks: Vec<_> = chunks.iter().filter(|c| !c.name.is_empty()).collect();
        assert!(
            !named_chunks.is_empty(),
            "expected at least one named chunk from Rust source with struct definition"
        );
        for c in &named_chunks {
            assert!(
                !c.kind.is_empty(),
                "named chunk '{}' must have non-empty kind (pre-B2 regression); got empty",
                c.name
            );
        }
        // The struct-kind check only applies when a chunk named `Qux` exists.
        let qux = named_chunks.iter().find(|c| c.name == "Qux");
        if let Some(c) = qux {
            assert_eq!(
                c.kind.as_str(),
                "23",
                "Qux (struct_item) must emit LSP SymbolKind::Struct (23); got: {:?}",
                c.kind
            );
        }
    }

    /// Loads a real model from `RIPVEC_SEMBLE_MODEL_PATH` and checks the
    /// width, identity and unit norm of a query embedding.
    #[test]
    #[ignore = "requires local model files at RIPVEC_SEMBLE_MODEL_PATH"]
    fn static_encoder_loads_potion_code_16m() {
        let Ok(path) = std::env::var("RIPVEC_SEMBLE_MODEL_PATH") else {
            eprintln!("RIPVEC_SEMBLE_MODEL_PATH not set; skipping");
            return;
        };
        let enc = StaticEncoder::from_pretrained(&path).expect("model load should succeed");
        assert_eq!(enc.hidden_dim(), DEFAULT_HIDDEN_DIM);
        assert_eq!(enc.identity(), path);
        let row = enc.encode_query("hello world");
        let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-3,
            "expected L2-normalized output; got norm={norm}"
        );
    }
}