#![allow(dead_code)]
use std::collections::HashMap;
use std::f32;
/// Embedding vector length, in `f32` components, exported as a public constant.
///
/// NOTE(review): nothing visible in this file reads this constant, and the
/// Candle-backed provider reports a dimension of 768 instead. Presumably 1536
/// matches a remote embedding model used elsewhere; confirm before relying on it.
pub const EMBEDDING_DIM: usize = 1536;
/// Source of `f32` embedding vectors for pieces of text.
///
/// The `Send + Sync` supertraits require every implementor to be shareable
/// across threads.
pub trait EmbeddingProvider: Send + Sync {
/// Embeds a single `text`, returning its vector or an [`EmbeddingError`].
fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
/// Embeds every entry of `texts`.
///
/// The implementations in this file return one vector per input, in input
/// order, and stop at the first input that fails.
fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
/// Length of the vectors this provider reports producing.
fn dimension(&self) -> usize;
}
/// Errors reported while producing embeddings.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so values can be
/// shown to users and converted into `Box<dyn std::error::Error>` with `?`.
#[derive(Debug)]
pub enum EmbeddingError {
    /// A backend call failed; the payload carries the backend's message.
    ApiError(String),
    /// An embedding did not have the expected number of components.
    InvalidDimension,
    /// The embedding backend could not be reached.
    NetworkError,
}

impl std::fmt::Display for EmbeddingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EmbeddingError::ApiError(detail) => write!(f, "embedding API error: {}", detail),
            EmbeddingError::InvalidDimension => f.write_str("embedding has an invalid dimension"),
            EmbeddingError::NetworkError => {
                f.write_str("network error while computing embedding")
            }
        }
    }
}

// No variant wraps another error value, so the default `source()` (None) is correct.
impl std::error::Error for EmbeddingError {}
/// In-memory map from a text to its embedding, restricted to one vector length.
pub struct EmbeddingCache {
    /// Stored vectors, keyed by the exact text they were computed for.
    embeddings: HashMap<String, Vec<f32>>,
    /// The only vector length `insert` accepts.
    dimension: usize,
}

impl EmbeddingCache {
    /// Creates an empty cache that accepts vectors of exactly `dimension` components.
    pub fn new(dimension: usize) -> Self {
        EmbeddingCache {
            dimension,
            embeddings: HashMap::new(),
        }
    }

    /// Stores `embedding` under `text`, replacing any previous entry for that text.
    ///
    /// A vector whose length differs from the cache dimension is dropped
    /// silently and the cache is left unchanged.
    pub fn insert(&mut self, text: String, embedding: Vec<f32>) {
        if embedding.len() != self.dimension {
            return;
        }
        self.embeddings.insert(text, embedding);
    }

    /// Returns the vector stored for `text`, if there is one.
    pub fn get(&self, text: &str) -> Option<&Vec<f32>> {
        let stored = &self.embeddings;
        stored.get(text)
    }

    /// Reports whether a vector is stored for `text`.
    pub fn contains(&self, text: &str) -> bool {
        self.embeddings.get(text).is_some()
    }
}
/// Computes the cosine similarity of `a` and `b`, using AVX + FMA when available.
///
/// CPU support is detected at runtime; when AVX or FMA is missing the call is
/// forwarded to [`cosine_similarity_scalar`].
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes empty input).
#[cfg(target_arch = "x86_64")]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    if !(is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma")) {
        return cosine_similarity_scalar(a, b);
    }
    // SAFETY: AVX and FMA support was verified at runtime just above, which is
    // the only requirement of `accumulate_avx_fma`.
    let (dot, sq_a, sq_b) = unsafe { accumulate_avx_fma(a, b) };
    cosine_from_sums(dot, sq_a, sq_b)
}

/// Computes the cosine similarity of `a` and `b`.
///
/// On targets other than x86_64 there is no vectorised path, so this simply
/// forwards to [`cosine_similarity_scalar`].
#[cfg(not(target_arch = "x86_64"))]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    cosine_similarity_scalar(a, b)
}

/// Computes the cosine similarity of `a` and `b` with plain scalar arithmetic.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes empty input).
pub fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    cosine_from_sums(dot, sq_a, sq_b)
}

/// Turns a dot product and the two sums of squares into a cosine similarity.
///
/// A zero denominator (a zero-magnitude vector) yields `0.0` instead of the
/// NaN that `0.0 / 0.0` would produce. NaN inputs still propagate as NaN.
fn cosine_from_sums(dot: f32, sq_a: f32, sq_b: f32) -> f32 {
    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}

/// Accumulates `(a·b, Σa², Σb²)` using 8-lane AVX loads and fused multiply-add.
///
/// Full groups of eight are processed with vectors; the remaining elements are
/// added to the same running sums, so each square root is taken exactly once
/// over the complete sum of squares.
///
/// Callers are expected to pass slices of equal length; only the common
/// prefix of the two slices is read.
///
/// # Safety
///
/// The running CPU must support the `avx` and `fma` features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx,fma")]
unsafe fn accumulate_avx_fma(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
    use std::arch::x86_64::*;

    let a_groups = a.chunks_exact(8);
    let b_groups = b.chunks_exact(8);
    let a_tail = a_groups.remainder();
    let b_tail = b_groups.remainder();

    let mut dot_acc = _mm256_setzero_ps();
    let mut sq_a_acc = _mm256_setzero_ps();
    let mut sq_b_acc = _mm256_setzero_ps();
    for (group_a, group_b) in a_groups.zip(b_groups) {
        // Each group is exactly eight `f32`s, so an unaligned 256-bit load
        // stays inside the slice.
        let va = _mm256_loadu_ps(group_a.as_ptr());
        let vb = _mm256_loadu_ps(group_b.as_ptr());
        dot_acc = _mm256_fmadd_ps(va, vb, dot_acc);
        sq_a_acc = _mm256_fmadd_ps(va, va, sq_a_acc);
        sq_b_acc = _mm256_fmadd_ps(vb, vb, sq_b_acc);
    }

    let mut dot = horizontal_sum_ps256(dot_acc);
    let mut sq_a = horizontal_sum_ps256(sq_a_acc);
    let mut sq_b = horizontal_sum_ps256(sq_b_acc);
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    (dot, sq_a, sq_b)
}

/// Adds the eight `f32` lanes of `v` into a single scalar.
///
/// # Safety
///
/// The running CPU must support the `avx` feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn horizontal_sum_ps256(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // 8 lanes -> 4: add the upper 128-bit half onto the lower half.
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let quad = _mm_add_ps(high, low);
    // 4 lanes -> 2: lanes 2 and 3 are folded onto lanes 0 and 1.
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    // 2 lanes -> 1: bring lane 1 down to lane 0 and add it.
    let total = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 0x01));
    _mm_cvtss_f32(total)
}
/// Embedding provider that derives its vectors from a hash of the text.
pub struct MockEmbeddingProvider {
    /// Number of components in every vector this provider produces.
    dimension: usize,
}

impl MockEmbeddingProvider {
    /// Creates a provider whose vectors have `dimension` components.
    pub fn new(dimension: usize) -> Self {
        MockEmbeddingProvider { dimension }
    }
}
impl EmbeddingProvider for MockEmbeddingProvider {
    /// Builds a deterministic vector for `text`; this implementation never returns `Err`.
    ///
    /// The bytes of `text` are folded into a 64-bit hash (multiply by 31, add the
    /// byte, wrapping on overflow). Component `i` is derived from
    /// `hash * (i + 1) mod 1000`, mapped into `[-1.0, 1.0)`, and the vector is then
    /// scaled to unit length when its norm is positive.
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        let mut seed: u64 = 0;
        for byte in text.bytes() {
            seed = seed.wrapping_mul(31).wrapping_add(byte as u64);
        }

        let mut embedding: Vec<f32> = (0..self.dimension)
            .map(|idx| {
                let bucket = seed.wrapping_mul(idx as u64 + 1) % 1000;
                let unit = bucket as f32 / 1000.0;
                unit * 2.0 - 1.0
            })
            .collect();

        let norm = embedding.iter().map(|c| c * c).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|c| *c /= norm);
        }
        Ok(embedding)
    }

    /// Embeds each text in order, returning the first error encountered, if any.
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }

    /// Returns the vector length chosen at construction.
    fn dimension(&self) -> usize {
        self.dimension
    }
}
/// Fixed-layout header record for an embedded block.
///
/// `#[repr(C, packed)]` keeps the fields in declaration order with no padding,
/// so the header occupies 37 bytes. Packed fields may be unaligned: copy a
/// field out by value instead of taking a reference to it.
///
/// NOTE(review): the field meanings noted below are inferred from the names
/// only. Confirm them against the code that reads and writes this header.
#[repr(C, packed)]
pub struct EmbeddedBlockHeader {
// x / y / z: presumably the block's position; zoom: presumably a zoom level (TODO confirm).
pub x: f32,
pub y: f32,
pub z: f32,
pub zoom: f32,
// Presumably the nesting depth and the id of the layer the block belongs to (TODO confirm).
pub depth: u8,
pub layer_id: u8,
// Presumably the location and size of the block's payload (TODO confirm units and base).
pub data_offset: u32,
pub data_len: u16,
// Presumably the index of the parent block and the number of children (TODO confirm).
pub parent_idx: u32,
pub child_count: u16,
// Two raw checksum bytes; the covered range and byte order are not shown in this file.
pub crc16: [u8; 2],
// embedding_offset: presumably where the block's embedding is stored;
// has_embedding: presumably whether that offset is meaningful (TODO confirm).
// NOTE(review): a `bool` must hold 0 or 1, so raw bytes need validating before
// they are reinterpreted as this struct.
pub embedding_offset: u32, pub has_embedding: bool, }
/// Embedding provider backed by a Candle BERT model.
///
/// Only compiled when the `embeddings` feature is enabled.
#[cfg(feature = "embeddings")]
pub struct CandleEmbeddingProvider {
// BERT model whose forward pass produces the per-token outputs that get pooled.
model: candle_transformers::models::bert::BertModel,
// Tokenizer loaded from the model repository's `tokenizer.json`.
tokenizer: tokenizers::Tokenizer,
// Vector length reported by `dimension()`; `new` sets it to the configured hidden size (768).
dim: usize,
// Device the input tensors are created on; `new` always selects the CPU.
device: candle_core::Device,
}
#[cfg(feature = "embeddings")]
impl CandleEmbeddingProvider {
    /// Builds a provider for the hub repository named by `model_id`.
    ///
    /// Fetches `tokenizer.json` and `model.safetensors` through the hub API,
    /// memory-maps the weights as `f32` on the CPU, and loads a BERT model with
    /// a fixed configuration (hidden size 768, 12 layers, 12 attention heads).
    ///
    /// # Errors
    ///
    /// Every failure is reported as [`EmbeddingError::ApiError`]; apart from hub
    /// API creation, the message is prefixed with the step that failed.
    pub fn new(model_id: &str) -> Result<Self, EmbeddingError> {
        use candle_core::{DType, Device};
        use candle_transformers::models::bert;
        use hf_hub::api::sync::Api;

        let hub = Api::new().map_err(|err| EmbeddingError::ApiError(err.to_string()))?;
        let repo = hub.model(model_id.to_string());

        let tokenizer_file = repo
            .get("tokenizer.json")
            .map_err(|err| EmbeddingError::ApiError(format!("tokenizer download: {}", err)))?;
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file)
            .map_err(|err| EmbeddingError::ApiError(format!("tokenizer load: {}", err)))?;

        let weights_file = repo
            .get("model.safetensors")
            .map_err(|err| EmbeddingError::ApiError(format!("weights download: {}", err)))?;

        let device = Device::Cpu;
        // NOTE(review): presumably unsafe because the weights file is memory-mapped
        // and must not change while mapped; confirm against the candle docs.
        let mapped = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(&[weights_file], DType::F32, &device)
        };
        let var_builder =
            mapped.map_err(|err| EmbeddingError::ApiError(format!("varbuilder: {}", err)))?;

        // The architecture is hard-coded here rather than read from the repository.
        let config = bert::Config {
            vocab_size: 30522,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: bert::Activation::Gelu,
            hidden_dropout_prob: 0.1,
            attention_probs_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
            model_type: Some("bert".to_string()),
        };
        let model = bert::BertModel::load(var_builder, &config)
            .map_err(|err| EmbeddingError::ApiError(format!("model load: {}", err)))?;

        Ok(CandleEmbeddingProvider {
            dim: config.hidden_size,
            model,
            tokenizer,
            device,
        })
    }

    /// Embeds `text` as the unit-length mean of the model's per-token outputs.
    ///
    /// The text is tokenized with special tokens enabled and run through the model
    /// as a batch of one. The output is averaged over dimension 1, and the result
    /// is divided by its L2 norm when that norm is positive.
    fn embed_inner(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        use candle_core::Tensor;

        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|err| EmbeddingError::ApiError(format!("tokenize: {}", err)))?;
        let seq_len = encoding.get_ids().len();

        // Both inputs are reshaped to a single row of `seq_len` entries.
        let input_ids = Tensor::new(encoding.get_ids(), &self.device)
            .and_then(|flat| flat.reshape((1, seq_len)))
            .map_err(|err| EmbeddingError::ApiError(err.to_string()))?;
        let token_type_ids = Tensor::new(encoding.get_type_ids(), &self.device)
            .and_then(|flat| flat.reshape((1, seq_len)))
            .map_err(|err| EmbeddingError::ApiError(err.to_string()))?;

        let hidden = self
            .model
            .forward(&input_ids, &token_type_ids, None)
            .map_err(|err| EmbeddingError::ApiError(format!("forward: {}", err)))?;

        let averaged = hidden
            .mean(1)
            .map_err(|err| EmbeddingError::ApiError(format!("mean pool: {}", err)))?;
        let pooled = averaged
            .squeeze(0)
            .map_err(|err| EmbeddingError::ApiError(format!("squeeze: {}", err)))?;
        let mut embedding: Vec<f32> = pooled
            .to_vec1()
            .map_err(|err| EmbeddingError::ApiError(format!("to_vec: {}", err)))?;

        let norm = embedding.iter().map(|c| c * c).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|c| *c /= norm);
        }
        Ok(embedding)
    }
}
#[cfg(feature = "embeddings")]
impl EmbeddingProvider for CandleEmbeddingProvider {
    /// Embeds one text by delegating to `embed_inner`.
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        self.embed_inner(text)
    }

    /// Embeds the texts one at a time, in order, returning the first error if any.
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed_inner(text)?);
        }
        Ok(vectors)
    }

    /// Returns the vector length recorded when the model was loaded.
    fn dimension(&self) -> usize {
        self.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOLERANCE: f32 = 0.001;

    /// True when `actual` lies within `TOLERANCE` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Parallel, orthogonal and opposite vectors score 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        let unit_y = vec![0.0, 1.0, 0.0];
        let opposite = vec![-1.0, 0.0, 0.0];

        assert!(close(cosine_similarity_scalar(&unit_x, &same), 1.0));
        assert!(close(cosine_similarity_scalar(&unit_x, &unit_y), 0.0));
        assert!(close(cosine_similarity_scalar(&unit_x, &opposite), -1.0));
    }

    /// The mock provider returns a unit-length vector of the requested size.
    #[test]
    fn test_mock_embeddings() {
        let provider = MockEmbeddingProvider::new(128);
        let vector = provider.embed("test text").unwrap();
        assert_eq!(vector.len(), 128);

        let length: f32 = vector.iter().map(|c| c * c).sum::<f32>().sqrt();
        assert!(close(length, 1.0));
    }

    /// A correctly sized vector can be stored and read back.
    #[test]
    fn test_embedding_cache() {
        let embedding = vec![1.0, 0.0, 0.0];
        let mut cache = EmbeddingCache::new(3);
        cache.insert("test".to_string(), embedding.clone());

        assert!(cache.contains("test"));
        assert_eq!(cache.get("test"), Some(&embedding));
    }
}