use std::path::{Path, PathBuf};
use std::sync::Mutex;
use crossbeam_channel::bounded;
use hf_hub::api::sync::Api;
use rayon::prelude::*;
use crate::chunk::CodeChunk;
use crate::embed::SearchConfig;
use crate::encoder::VectorEncoder;
use crate::encoder::ripvec::chunking::{DEFAULT_DESIRED_CHUNK_CHARS, chunk_source};
use crate::encoder::ripvec::static_model::StaticEmbedModel;
use crate::languages::config_for_extension;
use crate::profile::Profiler;
use crate::walk::collect_files_with_options;
/// Number of `(chunk, text)` pairs gathered into one batch before it is handed
/// to the embedding stage. The chunk channel is bounded at eight times this value.
const PIPELINE_BATCH_SIZE: usize = 1024;
/// Capacity of the bounded channel that carries assembled batches to the
/// embedding stage.
const PIPELINE_RING_SIZE: usize = 4;
/// Default model repository identifier.
// NOTE(review): presumably a Hugging Face Hub repo id, since repo strings are
// resolved through `hf_hub` in this file — confirm against callers.
pub const DEFAULT_MODEL_REPO: &str = "minishlab/potion-code-16M";
/// Embedding width compared against `hidden_dim()` by the ignored model-load
/// test in this file.
// NOTE(review): presumably the output dimension of `DEFAULT_MODEL_REPO` — confirm.
pub const DEFAULT_HIDDEN_DIM: usize = 256;
/// Files whose on-disk size exceeds this many bytes are skipped when chunking.
const MAX_FILE_BYTES: u64 = 1_000_000;
/// Encoder backed by a [`StaticEmbedModel`].
///
/// Built with `StaticEncoder::from_pretrained`; used to embed single queries
/// and, through the `VectorEncoder` implementation, whole directory trees.
pub struct StaticEncoder {
/// Loaded model that performs query and batch encoding.
model: StaticEmbedModel,
/// The `model_repo` string given at construction; returned by `identity()`.
model_repo: String,
/// Value of `model.hidden_dim()` captured when the model was loaded.
hidden_dim: usize,
}
impl StaticEncoder {
#[must_use]
pub fn encode_query(&self, query: &str) -> Vec<f32> {
self.model.encode_query(query)
}
pub fn from_pretrained(model_repo: &str) -> crate::Result<Self> {
let resolved = Self::resolve_model_dir(model_repo)?;
let model = StaticEmbedModel::from_path(&resolved, Some(true))
.map_err(|e| crate::Error::Other(anyhow::anyhow!("static model load failed: {e}")))?;
let hidden_dim = model.hidden_dim();
Ok(Self {
model,
model_repo: model_repo.to_string(),
hidden_dim,
})
}
fn resolve_model_dir(model_repo: &str) -> crate::Result<PathBuf> {
let local = Path::new(model_repo);
if local.is_dir() {
return Ok(local.to_path_buf());
}
let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
let repo = api.model(model_repo.to_string());
let _ = repo
.get("config.json")
.map_err(|e| crate::Error::Download(e.to_string()))?;
let _ = repo
.get("tokenizer.json")
.map_err(|e| crate::Error::Download(e.to_string()))?;
let weights_path = repo
.get("model.safetensors")
.map_err(|e| crate::Error::Download(e.to_string()))?;
weights_path
.parent()
.map(std::path::Path::to_path_buf)
.ok_or_else(|| {
crate::Error::Other(anyhow::anyhow!(
"hf-hub returned root path for {model_repo}; cannot resolve snapshot dir"
))
})
}
}
impl VectorEncoder for StaticEncoder {
    /// Chunks and embeds every file found under `root`.
    ///
    /// Files are collected using `cfg.walk_options()`, then pushed through a
    /// three-stage pipeline running on scoped threads:
    ///
    /// 1. a dedicated rayon pool chunks files in parallel and sends each
    ///    `(chunk, text)` pair into a bounded channel;
    /// 2. a batching thread groups pairs into batches of `PIPELINE_BATCH_SIZE`
    ///    (plus one final partial batch);
    /// 3. an embedding thread encodes each batch with the model and appends
    ///    the results to a shared output vector.
    ///
    /// Returns two vectors of equal length where element `i` of the second is
    /// the embedding of chunk `i` of the first. Chunks appear in whatever
    /// order the parallel chunking workers delivered them. Both vectors are
    /// empty when no files are found.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::Other` if the chunking thread pool cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if the output mutex is poisoned.
    fn embed_root(
        &self,
        root: &Path,
        cfg: &SearchConfig,
        profiler: &Profiler,
    ) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
        let walk_options = cfg.walk_options();
        let file_paths = {
            let _walk_phase = profiler.phase("walk");
            collect_files_with_options(root, &walk_options)
        };
        if file_paths.is_empty() {
            return Ok((Vec::new(), Vec::new()));
        }

        // Bounded channels provide back-pressure between the stages.
        let (chunk_tx, chunk_rx) = bounded::<(CodeChunk, String)>(PIPELINE_BATCH_SIZE * 8);
        let (batch_tx, batch_rx) = bounded::<Vec<(CodeChunk, String)>>(PIPELINE_RING_SIZE);
        let output: Mutex<Vec<(CodeChunk, Vec<f32>)>> = Mutex::new(Vec::new());
        let model = &self.model;

        // Half of rayon's reported thread count (at least one) goes to chunking.
        let chunk_threads = (rayon::current_num_threads().max(2) / 2).max(1);
        let chunk_pool = rayon::ThreadPoolBuilder::new()
            .num_threads(chunk_threads)
            .thread_name(|i| format!("semble-chunk-{i}"))
            .build()
            .map_err(|e| crate::Error::Other(anyhow::anyhow!("chunk thread pool build: {e}")))?;

        let _phase_guard = profiler.phase("pipeline");
        std::thread::scope(|s| {
            // Stage 1: chunk files in parallel. The sender is moved into this
            // thread so the channel disconnects as soon as chunking finishes.
            s.spawn(move || {
                chunk_pool.install(|| {
                    file_paths.par_iter().for_each(|path| {
                        let (chunks, contents) = chunk_one_file(root, path);
                        for item in chunks.into_iter().zip(contents) {
                            if chunk_tx.send(item).is_err() {
                                // Receiver is gone; stop sending for this file.
                                return;
                            }
                        }
                    });
                });
            });

            // Stage 2: group individual pairs into fixed-size batches.
            s.spawn(move || {
                let mut pending: Vec<(CodeChunk, String)> =
                    Vec::with_capacity(PIPELINE_BATCH_SIZE);
                for pair in chunk_rx {
                    pending.push(pair);
                    if pending.len() < PIPELINE_BATCH_SIZE {
                        continue;
                    }
                    let full_batch =
                        std::mem::replace(&mut pending, Vec::with_capacity(PIPELINE_BATCH_SIZE));
                    if batch_tx.send(full_batch).is_err() {
                        return;
                    }
                }
                // Flush whatever is left once the chunk channel disconnects.
                if !pending.is_empty() {
                    let _ = batch_tx.send(pending);
                }
            });

            // Stage 3: embed each batch and record the results. The receiver
            // is consumed by value so it is dropped when this thread ends.
            s.spawn(|| {
                for batch in batch_rx {
                    if batch.is_empty() {
                        continue;
                    }
                    let (chunks, texts): (Vec<CodeChunk>, Vec<String>) =
                        batch.into_iter().unzip();
                    let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
                    let embeddings = model.encode_batch(&text_refs);
                    debug_assert_eq!(embeddings.len(), chunks.len());
                    output
                        .lock()
                        .expect("output mutex poisoned")
                        .extend(chunks.into_iter().zip(embeddings));
                }
            });
        });

        // All stages have been joined; split the pairs into parallel vectors.
        let collected = output.into_inner().expect("output mutex poisoned");
        let (chunks_out, embs_out): (Vec<CodeChunk>, Vec<Vec<f32>>) =
            collected.into_iter().unzip();
        Ok((chunks_out, embs_out))
    }

    /// Returns the embedding dimension recorded when the model was loaded.
    fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Returns the `model_repo` string this encoder was constructed with.
    fn identity(&self) -> &str {
        &self.model_repo
    }
}
/// Reads `full` and splits it into chunks.
///
/// Returns two parallel vectors: the chunk records and the text of each
/// chunk. Both are empty when the file's metadata cannot be read, the file is
/// larger than `MAX_FILE_BYTES`, or it cannot be read into a `String`.
/// Chunks whose text is blank after trimming are dropped.
///
/// Each chunk's `file_path` is `full` relative to `root` when `full` lies
/// under `root`, otherwise `full` itself. `name` and `kind` are left empty,
/// and `content` and `enriched_content` both hold the chunk text.
fn chunk_one_file(root: &Path, full: &Path) -> (Vec<CodeChunk>, Vec<String>) {
    // Only proceed when the size is known and within the limit.
    let size_ok = matches!(
        std::fs::metadata(full),
        Ok(meta) if meta.len() <= MAX_FILE_BYTES
    );
    if !size_ok {
        return (Vec::new(), Vec::new());
    }

    let source = match std::fs::read_to_string(full) {
        Ok(text) => text,
        Err(_) => return (Vec::new(), Vec::new()),
    };

    // A missing or non-UTF-8 extension falls back to the empty string.
    let ext = full
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or_default();
    let lang_cfg = config_for_extension(ext);
    let language = lang_cfg.as_ref().map(|c| &c.language);

    let relative = match full.strip_prefix(root) {
        Ok(rel) => rel,
        Err(_) => full,
    };
    let rel_path = relative.display().to_string();

    let boundaries = chunk_source(&source, language, DEFAULT_DESIRED_CHUNK_CHARS);
    let mut chunks = Vec::with_capacity(boundaries.len());
    let mut contents = Vec::with_capacity(boundaries.len());
    for boundary in boundaries {
        let text = boundary.content(&source).to_string();
        if !text.trim().is_empty() {
            chunks.push(CodeChunk {
                file_path: rel_path.clone(),
                name: String::new(),
                kind: String::new(),
                start_line: boundary.start_line,
                end_line: boundary.end_line,
                content: text.clone(),
                enriched_content: text.clone(),
            });
            contents.push(text);
        }
    }
    (chunks, contents)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::VectorEncoder;

    /// Compile-time check that `StaticEncoder` satisfies
    /// `VectorEncoder + Send + Sync`.
    #[test]
    fn static_encoder_implements_vector_encoder() {
        fn requires_encoder<E>()
        where
            E: VectorEncoder + Send + Sync,
        {
        }
        requires_encoder::<StaticEncoder>();
    }

    /// Loads a model from the path in `RIPVEC_SEMBLE_MODEL_PATH` and checks
    /// its dimension, identity, and that a query embedding has unit L2 norm.
    /// Returns early when the variable is not set.
    #[test]
    #[ignore = "requires local model files at RIPVEC_SEMBLE_MODEL_PATH"]
    fn static_encoder_loads_potion_code_16m() {
        let model_path = match std::env::var("RIPVEC_SEMBLE_MODEL_PATH") {
            Ok(value) => value,
            Err(_) => {
                eprintln!("RIPVEC_SEMBLE_MODEL_PATH not set; skipping");
                return;
            }
        };

        let encoder =
            StaticEncoder::from_pretrained(&model_path).expect("model load should succeed");
        assert_eq!(encoder.hidden_dim(), DEFAULT_HIDDEN_DIM);
        assert_eq!(encoder.identity(), model_path);

        let embedding = encoder.encode_query("hello world");
        let squared_sum = embedding.iter().map(|x| x * x).sum::<f32>();
        let norm: f32 = squared_sum.sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-3,
            "expected L2-normalized output; got norm={norm}"
        );
    }
}