use memmap2::Mmap;
use safetensors::{Dtype, SafeTensors};
use std::fs::File;
use std::io;
use std::path::Path;
use std::sync::Arc;
use tokenizers::Tokenizer;
/// Embedding dimensionality: the number of `f32` components in every embedding
/// row and in every vector produced by `Embedder::encode_one`.
pub const DIM: usize = 256;
/// Name of the tensor that `Embedder::open` looks up inside `model.safetensors`.
const EMBEDDINGS_TENSOR: &str = "embeddings";
/// Token-embedding model backed by a memory-mapped safetensors file.
///
/// Text is embedded by tokenizing it, averaging the embedding rows of its
/// tokens and L2-normalising the result (see `encode_one`).
pub struct Embedder {
// Keeps the file mapping alive for as long as `embeddings_ptr` is in use; the
// field is never read directly.
_mmap: Arc<Mmap>,
// Start of the `embeddings` tensor data inside the mapping held by `_mmap`,
// viewed as `vocab_size * DIM` consecutive `f32` values.
embeddings_ptr: *const f32,
// Number of rows in the embedding table (first dimension of the tensor).
vocab_size: usize,
// Tokenizer loaded from `tokenizer.json`; turns text into token ids.
tokenizer: Tokenizer,
}
// SAFETY: the raw `embeddings_ptr` is what stops these traits from being
// derived automatically. It points into the mapping owned by `_mmap`, which
// lives as long as the `Embedder`, and it is only ever used to build shared,
// read-only `&[f32]` slices.
// NOTE(review): these impls also vouch for `Tokenizer` being `Send + Sync` —
// confirm against the `tokenizers` crate.
unsafe impl Send for Embedder {}
unsafe impl Sync for Embedder {}
impl Embedder {
    /// Loads the embedding table and tokenizer stored in `model_dir`.
    ///
    /// Two files are read from `model_dir`:
    /// * `model.safetensors` — memory-mapped; must contain an `F32` tensor named
    ///   `embeddings` with shape `[V, DIM]` and `V > 0`.
    /// * `tokenizer.json` — loaded with `Tokenizer::from_file`.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error when the safetensors file cannot be opened or
    /// mapped. Returns `io::ErrorKind::InvalidData` when the file cannot be
    /// parsed, the tensor is missing, has the wrong dtype/shape/byte length, has
    /// zero rows, is not `f32`-aligned inside the mapping, or when the tokenizer
    /// cannot be loaded.
    pub fn open(model_dir: &Path) -> io::Result<Self> {
        let invalid = |msg: String| io::Error::new(io::ErrorKind::InvalidData, msg);

        let safetensors_path = model_dir.join("model.safetensors");
        let tokenizer_path = model_dir.join("tokenizer.json");

        let file = File::open(&safetensors_path)?;
        // SAFETY: mapping a file is only sound while nothing truncates or rewrites
        // it underneath us. NOTE(review): this assumes the model file is immutable
        // once it is on disk — confirm against the download/update path.
        let mmap = unsafe { Mmap::map(&file) }?;
        let mmap = Arc::new(mmap);

        let st = SafeTensors::deserialize(&mmap[..])
            .map_err(|e| invalid(format!("safetensors: {e}")))?;
        let tensor = st.tensor(EMBEDDINGS_TENSOR).map_err(|e| {
            // The name list is only needed for the diagnostic, so build it lazily.
            let names: Vec<&str> = st.names().into_iter().collect();
            invalid(format!(
                "safetensors missing '{EMBEDDINGS_TENSOR}' tensor (have: {names:?}): {e}"
            ))
        })?;

        if tensor.dtype() != Dtype::F32 {
            return Err(invalid(format!(
                "expected F32 embeddings, got {:?}",
                tensor.dtype()
            )));
        }

        let shape = tensor.shape();
        if shape.len() != 2 || shape[1] != DIM {
            return Err(invalid(format!("expected shape [V, {DIM}], got {shape:?}")));
        }

        let vocab_size = shape[0];
        if vocab_size == 0 {
            // `row` always hands out a DIM-long slice, which an empty table cannot
            // provide; reject it here instead of panicking on first use.
            return Err(invalid("embeddings tensor has zero rows".to_string()));
        }

        // The shape comes from the file, so guard the byte-count multiplication.
        let expected_bytes = vocab_size
            .checked_mul(DIM * std::mem::size_of::<f32>())
            .ok_or_else(|| invalid(format!("embeddings shape {shape:?} overflows usize")))?;
        let data = tensor.data();
        if data.len() != expected_bytes {
            return Err(invalid(format!(
                "tensor data length {} != {} = vocab_size * DIM * 4",
                data.len(),
                expected_bytes
            )));
        }

        let ptr = data.as_ptr() as *const f32;
        // Alignment is a hard precondition of `slice::from_raw_parts` in
        // `all_rows`, so it must be enforced in release builds too.
        if (ptr as usize) % std::mem::align_of::<f32>() != 0 {
            return Err(invalid("embedding tensor must be f32-aligned".to_string()));
        }

        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| invalid(format!("tokenizer.json: {e}")))?;

        Ok(Self {
            _mmap: mmap,
            embeddings_ptr: ptr,
            vocab_size,
            tokenizer,
        })
    }

    /// Number of rows in the embedding table.
    // NOTE(review): the lint is silenced presumably because only tests call this
    // accessor — confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// The whole embedding table as one flat, row-major slice of
    /// `vocab_size * DIM` floats.
    fn all_rows(&self) -> &[f32] {
        // SAFETY: `open` verified that `embeddings_ptr` is f32-aligned and that
        // the tensor holds exactly `vocab_size * DIM` f32 values. The pointer
        // targets the mapping owned by `self._mmap`, which outlives the returned
        // borrow, and nothing in this type writes through it.
        unsafe { std::slice::from_raw_parts(self.embeddings_ptr, self.vocab_size * DIM) }
    }

    /// Embedding row for `token_id`.
    ///
    /// Ids at or beyond `vocab_size` are clamped to the last row rather than
    /// rejected.
    fn row(&self, token_id: u32) -> &[f32] {
        let id = (token_id as usize).min(self.vocab_size.saturating_sub(1));
        let start = id * DIM;
        &self.all_rows()[start..start + DIM]
    }

    /// Embeds `text` as the L2-normalised mean of its token embeddings.
    ///
    /// Returns the all-zero vector when `text` is empty, when tokenization fails
    /// (the error is deliberately swallowed) or when it yields no tokens. The
    /// result is also left un-normalised if the mean has zero norm.
    pub fn encode_one(&self, text: &str) -> [f32; DIM] {
        let mut out = [0.0f32; DIM];
        if text.is_empty() {
            return out;
        }
        let encoding = match self.tokenizer.encode(text, false) {
            Ok(e) => e,
            Err(_) => return out,
        };
        let ids = encoding.get_ids();
        if ids.is_empty() {
            return out;
        }

        // Sum the rows of all tokens component-wise.
        for &id in ids {
            for (acc, &x) in out.iter_mut().zip(self.row(id)) {
                *acc += x;
            }
        }

        // Mean pooling.
        let inv_n = 1.0 / ids.len() as f32;
        for v in &mut out {
            *v *= inv_n;
        }

        // L2 normalisation; skipped for a zero vector to avoid dividing by zero.
        let norm: f32 = out.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > 0.0 {
            let inv = 1.0 / norm;
            for v in &mut out {
                *v *= inv;
            }
        }
        out
    }
}
/// Minimum number of embedding rows at which `cosine_topk` scores rows with a
/// rayon parallel iterator; smaller inputs are scored sequentially.
const PAR_THRESHOLD: usize = 4096;
/// Returns up to `k` rows of `embeddings` with the highest dot product against
/// `query`, best first, as `(row_index, score)` pairs.
///
/// `embeddings` is a flat row-major matrix of `DIM`-wide rows. The score is a
/// plain dot product, so it equals cosine similarity only when both the query
/// and the rows are unit-normalised.
///
/// When `mask` is given, rows whose entry is `false` are excluded. Rows whose
/// score is not finite (NaN or infinite) are excluded as well, so fewer than
/// `k` results may be returned.
///
/// # Panics
///
/// Panics if the matrix has more than `u32::MAX` rows, or if `mask` is shorter
/// than the number of rows.
pub fn cosine_topk(
    query: &[f32; DIM],
    embeddings: &[f32],
    mask: Option<&[bool]>,
    k: usize,
) -> Vec<(u32, f32)> {
    use rayon::prelude::*;

    let n = embeddings.len() / DIM;
    debug_assert_eq!(embeddings.len() % DIM, 0, "embeddings length not a multiple of DIM");
    if let Some(m) = mask {
        debug_assert_eq!(m.len(), n, "mask length must equal row count");
    }
    if n == 0 || k == 0 {
        return Vec::new();
    }
    // Row indices are reported as u32; refuse to truncate silently.
    assert!(n <= u32::MAX as usize, "row count {n} does not fit in u32");

    let q_lanes = load_query_lanes(query);

    // Masked rows already score NEG_INFINITY. Fold every other non-finite score
    // (NaN, +inf) into the same sentinel so that excluded rows all sort last and
    // never take a top-k slot away from a valid row.
    let score = |i: usize| -> f32 {
        let s = score_row(i, embeddings, &q_lanes, mask);
        if s.is_finite() {
            s
        } else {
            f32::NEG_INFINITY
        }
    };
    let scores: Vec<f32> = if n >= PAR_THRESHOLD {
        (0..n).into_par_iter().with_min_len(256).map(score).collect()
    } else {
        (0..n).map(score).collect()
    };

    // Descending by score. `total_cmp` is a total order, which the slice
    // selection/sort routines require of their comparator.
    let by_score_desc =
        |a: &u32, b: &u32| scores[*b as usize].total_cmp(&scores[*a as usize]);

    let take = k.min(n);
    let mut idx: Vec<u32> = (0..n as u32).collect();
    // Partition so the `take` best rows come first, then order just those.
    idx.select_nth_unstable_by(take - 1, by_score_desc);
    idx.truncate(take);
    idx.sort_by(by_score_desc);

    idx.into_iter()
        .map(|i| (i, scores[i as usize]))
        .filter(|&(_, s)| s.is_finite())
        .collect()
}
/// Scores row `i` of the flat `embeddings` matrix against the preloaded query.
///
/// Yields `f32::NEG_INFINITY` when `mask` is present and disables the row;
/// otherwise the dot product of the query lanes with the row.
#[inline]
fn score_row(
    i: usize,
    embeddings: &[f32],
    q_lanes: &[wide::f32x8; DIM / 8],
    mask: Option<&[bool]>,
) -> f32 {
    match mask {
        Some(m) if !m[i] => f32::NEG_INFINITY,
        _ => {
            let start = i * DIM;
            dot_simd(q_lanes, &embeddings[start..start + DIM])
        }
    }
}
/// Packs the query vector into `DIM / 8` SIMD registers of eight floats each,
/// in order, so they can be reused for every row that gets scored.
#[inline]
fn load_query_lanes(query: &[f32; DIM]) -> [wide::f32x8; DIM / 8] {
    let mut lanes = [wide::f32x8::splat(0.0); DIM / 8];
    for (lane, chunk) in lanes.iter_mut().zip(query.chunks_exact(8)) {
        let mut buf = [0.0f32; 8];
        buf.copy_from_slice(chunk);
        *lane = wide::f32x8::from(buf);
    }
    lanes
}
/// Dot product of the preloaded query lanes with one `DIM`-wide row.
///
/// Accumulates eight partial sums in a SIMD register and adds them up at the
/// end.
#[inline]
fn dot_simd(q_lanes: &[wide::f32x8; DIM / 8], row: &[f32]) -> f32 {
    debug_assert_eq!(row.len(), DIM);
    // Use exactly the first DIM values: panics on a short row, ignores extras.
    let row = &row[..DIM];
    let acc = q_lanes
        .iter()
        .zip(row.chunks_exact(8))
        .fold(wide::f32x8::splat(0.0), |mut acc, (q, chunk)| {
            let mut buf = [0.0f32; 8];
            buf.copy_from_slice(chunk);
            acc += *q * wide::f32x8::from(buf);
            acc
        });
    let partials: [f32; 8] = acc.into();
    partials.iter().sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::download::{ensure_model, ModelInfo};

    /// Resolves the real model directory via `ensure_model`; panics if that fails.
    fn ensure_real_model() -> std::path::PathBuf {
        ensure_model(&ModelInfo::potion_code_16m())
            .expect("model download failed; see network-security wiki")
    }

    /// Pads (or truncates) `values` to `DIM` components and scales to unit length.
    fn unit(values: &[f32]) -> Vec<f32> {
        let mut padded: Vec<f32> = values
            .iter()
            .copied()
            .chain(std::iter::repeat(0.0))
            .take(DIM)
            .collect();
        let length: f32 = padded.iter().map(|x| x * x).sum::<f32>().sqrt();
        if length > 0.0 {
            padded.iter_mut().for_each(|x| *x /= length);
        }
        padded
    }

    /// `unit`, converted into the fixed-size array `cosine_topk` takes as query.
    fn unit_query(values: &[f32]) -> [f32; DIM] {
        unit(values).try_into().unwrap()
    }

    // ---- tests that need the real model (ignored by default) ----

    #[test]
    #[ignore]
    fn network_loads_potion_model() {
        let model_dir = ensure_real_model();
        let embedder = Embedder::open(&model_dir).expect("Embedder::open failed");
        assert!(embedder.vocab_size() > 1000, "vocab implausibly small");
    }

    #[test]
    #[ignore]
    fn network_encodes_to_unit_vector() {
        let embedder = Embedder::open(&ensure_real_model()).unwrap();
        let vector = embedder.encode_one("def parse_json(s): return json.loads(s)");
        vector
            .iter()
            .for_each(|x| assert!(x.is_finite(), "non-finite component: {x}"));
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "norm = {norm}, expected ≈ 1.0");
    }

    #[test]
    #[ignore]
    fn network_similar_strings_have_high_cosine() {
        // Inputs are unit vectors, so the dot product is the cosine.
        fn dot(u: &[f32], v: &[f32]) -> f32 {
            u.iter().zip(v.iter()).map(|(x, y)| x * y).sum::<f32>()
        }

        let embedder = Embedder::open(&ensure_real_model()).unwrap();
        let add_fn = embedder.encode_one("def add(a, b): return a + b");
        let sum_fn = embedder.encode_one("def sum(x, y): return x + y");
        let server = embedder.encode_one("class HttpServer: def listen(self, port): pass");

        let ab = dot(&add_fn, &sum_fn);
        let ac = dot(&add_fn, &server);
        assert!(
            ab > ac,
            "expected related code (cos {ab}) > unrelated code (cos {ac})"
        );
    }

    // ---- cosine_topk on synthetic data ----

    #[test]
    fn cosine_topk_empty_returns_empty() {
        let zero_query = [0.0f32; DIM];
        assert!(cosine_topk(&zero_query, &[], None, 5).is_empty());
    }

    #[test]
    fn cosine_topk_zero_k_returns_empty() {
        let zero_query = [0.0f32; DIM];
        let single_row = vec![0.0f32; DIM];
        assert!(cosine_topk(&zero_query, &single_row, None, 0).is_empty());
    }

    #[test]
    fn cosine_topk_orders_by_similarity() {
        // Row 0 is parallel to the query, row 1 orthogonal, row 2 opposite.
        let directions: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]];
        let rows: Vec<f32> = directions.iter().flat_map(|d| unit(&d[..])).collect();
        let query = unit_query(&[1.0, 0.0, 0.0]);

        let top = cosine_topk(&query, &rows, None, 3);

        assert_eq!(top.len(), 3);
        let expected: [(u32, f32); 3] = [(0, 1.0), (1, 0.0), (2, -1.0)];
        for (got, &(want_id, want_score)) in top.iter().zip(expected.iter()) {
            assert_eq!(got.0, want_id);
            assert!((got.1 - want_score).abs() < 1e-5);
        }
    }

    #[test]
    fn cosine_topk_respects_k() {
        // Ten rows drifting steadily away from the query direction.
        let rows: Vec<f32> = (0..10)
            .flat_map(|i| unit(&[1.0 - (i as f32) * 0.1, 0.1, 0.0]))
            .collect();
        let query = unit_query(&[1.0, 0.0, 0.0]);

        let top = cosine_topk(&query, &rows, None, 3);

        assert_eq!(top.len(), 3);
        let ids: Vec<u32> = top.iter().map(|&(id, _)| id).collect();
        assert_eq!(ids, [0, 1, 2]);
    }

    #[test]
    fn cosine_topk_mask_excludes_filtered_rows() {
        let directions: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]];
        let rows: Vec<f32> = directions.iter().flat_map(|d| unit(&d[..])).collect();
        let query = unit_query(&[1.0, 0.0, 0.0]);
        // The best-matching row is the one switched off.
        let mask = vec![false, true, true];

        let top = cosine_topk(&query, &rows, Some(&mask), 5);

        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, 1);
        assert_eq!(top[1].0, 2);
    }

    #[test]
    fn cosine_topk_handles_large_matrix() {
        // Enough rows to cross PAR_THRESHOLD and take the parallel path.
        const ROWS: usize = 5000;
        let rows: Vec<f32> = (0..ROWS)
            .flat_map(|i| {
                let t = (i as f32) / (ROWS as f32);
                unit(&[1.0 - t, t, 0.0])
            })
            .collect();
        let query = unit_query(&[1.0, 0.0, 0.0]);

        let top = cosine_topk(&query, &rows, None, 5);

        assert_eq!(top.len(), 5);
        assert_eq!(top[0].0, 0);
        for pair in top.windows(2) {
            assert!(pair[0].1 >= pair[1].1, "scores not monotone: {:?}", top);
        }
    }

    #[test]
    #[ignore]
    fn network_empty_returns_zero_vector() {
        let embedder = Embedder::open(&ensure_real_model()).unwrap();
        let vector = embedder.encode_one("");
        assert!(!vector.iter().any(|&x| x != 0.0));
    }
}