use candle_core::{DType, Device as CandleDevice, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use dashmap::DashMap;
use serde::Deserialize;
use std::path::Path;
use std::sync::Arc;
use tokenizers::Tokenizer;
use crate::neural::{
Device, EmbeddingConfig, ModernBertConfig, ModernBertEmbedder, ModernBertModel,
};
/// On-disk layout of a model directory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormat {
    /// `model.onnx` accompanied by `tokenizer.json`.
    Onnx,
    /// `model.safetensors` accompanied by `config.json` and `tokenizer.json`.
    SafeTensors,
}

impl ModelFormat {
    /// Determines the model format stored in the directory `path`.
    ///
    /// Detection is purely based on which well-known file names exist inside
    /// `path`. Returns `None` when neither layout is complete.
    pub fn detect(path: &Path) -> Option<Self> {
        let present = |file_name: &str| path.join(file_name).exists();

        // Both layouts need a tokenizer, so its absence rules out everything.
        if !present("tokenizer.json") {
            return None;
        }
        // ONNX wins when a directory carries both kinds of weights.
        if present("model.onnx") {
            return Some(Self::Onnx);
        }
        if present("model.safetensors") && present("config.json") {
            Some(Self::SafeTensors)
        } else {
            None
        }
    }

    /// Lists the file names that must be present for this format.
    pub fn expected_files(&self) -> &'static [&'static str] {
        const ONNX_FILES: &[&str] = &["model.onnx", "tokenizer.json"];
        const SAFETENSORS_FILES: &[&str] = &["model.safetensors", "config.json", "tokenizer.json"];

        match *self {
            Self::Onnx => ONNX_FILES,
            Self::SafeTensors => SAFETENSORS_FILES,
        }
    }
}
/// Transformer architecture family of a model checkpoint.
///
/// The value is derived from the `model_type` field of the model's
/// `config.json` (see `ModelArchitecture::from_config`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelArchitecture {
/// `model_type` is `"modernbert"`.
ModernBert,
/// `model_type` is `"roberta"`.
Roberta,
/// `model_type` is `"bert"`.
Bert,
/// `model_type` is missing, could not be parsed, or holds any other value.
Unknown,
}
impl ModelArchitecture {
    /// Classifies a model from the raw text of its `config.json`.
    ///
    /// Only the `model_type` field is inspected. Input that fails to parse,
    /// lacks the field, or names an unsupported type yields
    /// [`ModelArchitecture::Unknown`]; this function never fails.
    pub fn from_config(config_json: &str) -> Self {
        /// The single field of `config.json` this function cares about.
        #[derive(Deserialize)]
        struct MinimalConfig {
            model_type: Option<String>,
        }

        // A parse failure is folded into "no model type" rather than reported.
        let model_type = serde_json::from_str::<MinimalConfig>(config_json)
            .ok()
            .and_then(|parsed| parsed.model_type);

        match model_type.as_deref() {
            Some("modernbert") => Self::ModernBert,
            Some("roberta") => Self::Roberta,
            Some("bert") => Self::Bert,
            Some(_) | None => Self::Unknown,
        }
    }

    /// Human-readable name of the architecture, suitable for messages.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::ModernBert => "ModernBERT",
            Self::Roberta => "RoBERTa",
            Self::Bert => "BERT",
            Self::Unknown => "Unknown",
        }
    }
}
/// Pretrained code-embedding model presets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
    /// `microsoft/unixcoder-base`.
    UniXcoder,
    /// `microsoft/graphcodebert-base`.
    GraphCodeBERT,
    /// `microsoft/codebert-base`.
    CodeBERT,
    /// A user-supplied model; it has no Hugging Face id of its own.
    Custom,
}

impl EmbeddingModel {
    /// Hugging Face model identifier of the preset (empty for `Custom`).
    pub fn hf_model_id(&self) -> &str {
        match *self {
            Self::UniXcoder => "microsoft/unixcoder-base",
            Self::GraphCodeBERT => "microsoft/graphcodebert-base",
            Self::CodeBERT => "microsoft/codebert-base",
            Self::Custom => "",
        }
    }

    /// Embedding width associated with the preset.
    pub fn embedding_dim(&self) -> usize {
        // Every preset currently shares the same width; the exhaustive pattern
        // forces a decision when a new variant is added.
        match *self {
            Self::UniXcoder | Self::GraphCodeBERT | Self::CodeBERT | Self::Custom => 768,
        }
    }

    /// Maximum sequence length associated with the preset.
    pub fn max_length(&self) -> usize {
        match *self {
            Self::UniXcoder | Self::GraphCodeBERT | Self::CodeBERT | Self::Custom => 512,
        }
    }
}
/// Configuration for `CodeEmbedder`.
#[derive(Debug, Clone)]
pub struct CodeEmbedderConfig {
/// Model preset this embedder is configured for; returned by `CodeEmbedder::model`.
pub model: EmbeddingModel,
/// Device used when loading a model from disk (converted with `to_candle`).
pub device: Device,
/// Whether `CodeEmbedder` keeps an in-memory cache of computed embeddings.
pub use_cache: bool,
/// Initial capacity of the cache and the entry count at which eviction starts.
pub cache_size: usize,
/// Forwarded to `EmbeddingConfig::normalize` of the ModernBERT embedder.
// NOTE(review): presumably toggles L2 normalization of the output — confirm
// against `EmbeddingConfig`.
pub normalize: bool,
/// Forwarded to `EmbeddingConfig::batch_size` of the ModernBERT embedder.
pub batch_size: usize,
}
impl Default for CodeEmbedderConfig {
fn default() -> Self {
Self {
model: EmbeddingModel::UniXcoder,
device: Device::Cpu,
use_cache: true,
cache_size: 10000,
normalize: true,
batch_size: 32,
}
}
}
/// Concrete model implementation a `CodeEmbedder` delegates to.
enum EmbedderBackend {
/// Embedding is fully delegated to a `ModernBertEmbedder`.
ModernBert(ModernBertEmbedder),
/// A candle `BertModel` driven directly; used for the BERT and RoBERTa
/// architectures.
Bert {
/// Loaded encoder weights.
model: BertModel,
/// Tokenizer loaded from the model directory's `tokenizer.json`.
tokenizer: Tokenizer,
/// Device on which input tensors are created.
device: CandleDevice,
/// `hidden_size` from the model config; every produced embedding is
/// checked against this length.
hidden_size: usize,
},
}
impl EmbedderBackend {
fn embed(&self, text: &str) -> Result<Vec<f32>, CodeEmbedderError> {
match self {
EmbedderBackend::ModernBert(embedder) => embedder
.embed(text)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string())),
EmbedderBackend::Bert {
model,
tokenizer,
device,
hidden_size,
} => {
let encoding = tokenizer.encode(text, true).map_err(|e| {
CodeEmbedderError::Embedding(format!("Tokenization error: {}", e))
})?;
let ids = encoding.get_ids();
let token_type_ids: Vec<u32> = vec![0; ids.len()];
let input_ids = Tensor::new(ids, device)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?
.unsqueeze(0)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
let token_type_tensor = Tensor::new(&token_type_ids[..], device)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?
.unsqueeze(0)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
let hidden_states = model
.forward(&input_ids, &token_type_tensor, None)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
let cls_embedding = hidden_states
.i((0, 0))
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
let embedding = cls_embedding
.to_vec1()
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
if embedding.len() != *hidden_size {
return Err(CodeEmbedderError::Embedding(format!(
"BERT embedding dimension mismatch: expected {}, got {}",
hidden_size,
embedding.len()
)));
}
Ok(embedding)
}
}
}
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodeEmbedderError> {
match self {
EmbedderBackend::ModernBert(embedder) => embedder
.embed_batch(texts)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string())),
EmbedderBackend::Bert {
model,
tokenizer,
device,
hidden_size,
} => {
if texts.is_empty() {
return Ok(vec![]);
}
let encodings = tokenizer.encode_batch(texts.to_vec(), true).map_err(|e| {
CodeEmbedderError::Embedding(format!("Tokenization error: {}", e))
})?;
let max_len = encodings.iter().map(|e| e.len()).max().unwrap_or(0);
let batch_size = encodings.len();
let mut padded_ids: Vec<u32> = Vec::with_capacity(batch_size * max_len);
let mut padded_types: Vec<u32> = Vec::with_capacity(batch_size * max_len);
for enc in &encodings {
let ids = enc.get_ids();
let len = ids.len();
padded_ids.extend(ids.iter().copied());
padded_ids.extend(std::iter::repeat(0).take(max_len - len));
padded_types.extend(std::iter::repeat(0u32).take(max_len));
}
let input_tensor = Tensor::from_vec(padded_ids, (batch_size, max_len), device)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
let type_tensor = Tensor::from_vec(padded_types, (batch_size, max_len), device)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
let hidden_states = model
.forward(&input_tensor, &type_tensor, None)
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
let mut embeddings = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let cls = hidden_states
.i((i, 0))
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
let vec: Vec<f32> = cls
.to_vec1()
.map_err(|e| CodeEmbedderError::Embedding(e.to_string()))?;
if vec.len() != *hidden_size {
return Err(CodeEmbedderError::Embedding(format!(
"BERT embedding dimension mismatch: expected {}, got {}",
hidden_size,
vec.len()
)));
}
embeddings.push(vec);
}
Ok(embeddings)
}
}
}
}
/// Computes embedding vectors for source-code snippets, with optional caching.
pub struct CodeEmbedder {
/// Configuration the embedder was created with.
config: CodeEmbedderConfig,
/// Model implementation that produces the embeddings.
backend: EmbedderBackend,
/// Embeddings keyed by the exact input text; `None` when
/// `config.use_cache` is false.
cache: Option<DashMap<String, Vec<f32>>>,
/// Architecture of the loaded model, as returned by `architecture()`.
architecture: ModelArchitecture,
}
impl CodeEmbedder {
    /// Creates an embedder using [`CodeEmbedderConfig::default`].
    ///
    /// # Errors
    /// Returns [`CodeEmbedderError::ModelLoad`] when the backend cannot be created.
    pub fn new() -> Result<Self, CodeEmbedderError> {
        Self::with_config(CodeEmbedderConfig::default())
    }

    /// Creates an embedder backed by a default-constructed `ModernBertEmbedder`.
    ///
    /// # Errors
    /// Returns [`CodeEmbedderError::ModelLoad`] when the backend cannot be created.
    // NOTE(review): only `normalize`, `batch_size`, `use_cache` and `cache_size`
    // are consumed here; `config.model` and `config.device` are not forwarded to
    // the backend. Confirm whether that is intended.
    pub fn with_config(config: CodeEmbedderConfig) -> Result<Self, CodeEmbedderError> {
        let embedder = ModernBertEmbedder::new(Self::embedding_config(&config))
            .map_err(|e| CodeEmbedderError::ModelLoad(e.to_string()))?;
        let cache = Self::build_cache(&config);
        Ok(Self {
            config,
            backend: EmbedderBackend::ModernBert(embedder),
            cache,
            architecture: ModelArchitecture::ModernBert,
        })
    }

    /// Loads a model from the directory `path`.
    ///
    /// The directory layout is detected with [`ModelFormat::detect`]; only the
    /// SafeTensors layout can actually be loaded by this type.
    ///
    /// # Errors
    /// Returns [`CodeEmbedderError::ModelLoad`] when the path does not exist,
    /// the layout is not recognised, the layout is ONNX, or loading fails.
    pub fn from_path(
        path: impl AsRef<Path>,
        config: CodeEmbedderConfig,
    ) -> Result<Self, CodeEmbedderError> {
        let path = path.as_ref();
        if !path.exists() {
            return Err(CodeEmbedderError::ModelLoad(format!(
                "Model path does not exist: {}",
                path.display()
            )));
        }
        let format = ModelFormat::detect(path).ok_or_else(|| {
            CodeEmbedderError::ModelLoad(format!(
                "Could not detect model format in {}. Expected either:\n\
                 - ONNX: model.onnx + tokenizer.json\n\
                 - SafeTensors: model.safetensors + config.json + tokenizer.json",
                path.display()
            ))
        })?;
        match format {
            ModelFormat::Onnx => Self::load_onnx_model(path, config),
            ModelFormat::SafeTensors => Self::load_safetensors_model(path, config),
        }
    }

    /// Always fails: ONNX models are not handled by this type. The error
    /// message points the caller at the dedicated ONNX embedders.
    fn load_onnx_model(
        path: &Path,
        _config: CodeEmbedderConfig,
    ) -> Result<Self, CodeEmbedderError> {
        Err(CodeEmbedderError::ModelLoad(format!(
            "ONNX model detected at {}. For ONNX code models (UniXcoder, GraphCodeBERT), \
             please use the specialized embedders from `crate::neural::code`:\n\
             - `UniXcoderEmbedder::from_directory(path)` for UniXcoder models\n\
             - `GraphCodeBertEmbedder::from_directory(path)` for GraphCodeBERT models",
            path.display()
        )))
    }

    /// Reads `config.json`, determines the architecture and dispatches to the
    /// matching loader.
    fn load_safetensors_model(
        path: &Path,
        config: CodeEmbedderConfig,
    ) -> Result<Self, CodeEmbedderError> {
        let config_path = path.join("config.json");
        let config_json = std::fs::read_to_string(&config_path).map_err(|e| {
            CodeEmbedderError::ModelLoad(format!(
                "Failed to read config.json from {}: {}",
                path.display(),
                e
            ))
        })?;
        let architecture = ModelArchitecture::from_config(&config_json);
        match architecture {
            ModelArchitecture::ModernBert => Self::load_modernbert(path, config),
            ModelArchitecture::Roberta | ModelArchitecture::Bert => {
                Self::load_bert_family(path, config, architecture)
            }
            ModelArchitecture::Unknown => Err(CodeEmbedderError::ModelLoad(format!(
                "Unknown model architecture in {}. Supported architectures:\n\
                 - modernbert (ModernBERT)\n\
                 - roberta (RoBERTa, UniXcoder, GraphCodeBERT, CodeBERT)\n\
                 - bert (BERT)",
                path.display()
            ))),
        }
    }

    /// Loads a ModernBERT checkpoint from `path` and wraps it in a
    /// `ModernBertEmbedder`.
    fn load_modernbert(path: &Path, config: CodeEmbedderConfig) -> Result<Self, CodeEmbedderError> {
        let model_path = path.join("model.safetensors");
        let config_path = path.join("config.json");
        let tokenizer_path = path.join("tokenizer.json");
        let bert_config = ModernBertConfig {
            model_id: path.display().to_string(),
            device: config.device,
            ..Default::default()
        };
        let device = config
            .device
            .to_candle()
            .map_err(|e| CodeEmbedderError::ModelLoad(format!("Device error: {}", e)))?;
        let model = ModernBertModel::load_from_files(
            &model_path,
            &config_path,
            &tokenizer_path,
            bert_config,
            device,
        )
        .map_err(|e| {
            CodeEmbedderError::ModelLoad(format!(
                "Failed to load ModernBERT model from {}: {}",
                path.display(),
                e
            ))
        })?;
        let embedder =
            ModernBertEmbedder::from_model(Arc::new(model), Self::embedding_config(&config));
        let cache = Self::build_cache(&config);
        Ok(Self {
            config,
            backend: EmbedderBackend::ModernBert(embedder),
            cache,
            architecture: ModelArchitecture::ModernBert,
        })
    }

    /// Loads a BERT or RoBERTa checkpoint from `path` into a candle `BertModel`.
    fn load_bert_family(
        path: &Path,
        config: CodeEmbedderConfig,
        architecture: ModelArchitecture,
    ) -> Result<Self, CodeEmbedderError> {
        let model_path = path.join("model.safetensors");
        let config_path = path.join("config.json");
        let tokenizer_path = path.join("tokenizer.json");
        let device = config
            .device
            .to_candle()
            .map_err(|e| CodeEmbedderError::ModelLoad(format!("Device error: {}", e)))?;
        let config_json = std::fs::read_to_string(&config_path)
            .map_err(|e| CodeEmbedderError::ModelLoad(format!("Failed to read config: {}", e)))?;
        let bert_config: BertConfig = serde_json::from_str(&config_json).map_err(|e| {
            CodeEmbedderError::ModelLoad(format!(
                "Failed to parse {} config from {}: {}",
                architecture.name(),
                path.display(),
                e
            ))
        })?;
        let hidden_size = bert_config.hidden_size;
        let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
            CodeEmbedderError::ModelLoad(format!(
                "Failed to load tokenizer from {}: {}",
                tokenizer_path.display(),
                e
            ))
        })?;
        // SAFETY: the weights file is memory-mapped. This is sound only as long
        // as the file is not modified or truncated while the mapping is alive;
        // model files are assumed to be immutable for the process lifetime.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
        }
        .map_err(|e| {
            CodeEmbedderError::ModelLoad(format!(
                "Failed to load model weights from {}: {}",
                model_path.display(),
                e
            ))
        })?;
        let model = BertModel::load(vb, &bert_config).map_err(|e| {
            CodeEmbedderError::ModelLoad(format!(
                "Failed to initialize {} model: {}",
                architecture.name(),
                e
            ))
        })?;
        let cache = Self::build_cache(&config);
        Ok(Self {
            config,
            backend: EmbedderBackend::Bert {
                model,
                tokenizer,
                device,
                hidden_size,
            },
            cache,
            architecture,
        })
    }

    /// Length of the vectors this embedder produces.
    ///
    /// For a BERT-family model loaded from disk this is the checkpoint's
    /// `hidden_size`; otherwise it is the width of the configured preset.
    // NOTE(review): for the ModernBERT backend the preset width is reported
    // because the backend's real width is not visible from this file — confirm
    // they agree for non-default checkpoints.
    pub fn embedding_dim(&self) -> usize {
        match &self.backend {
            EmbedderBackend::Bert { hidden_size, .. } => *hidden_size,
            EmbedderBackend::ModernBert(_) => self.config.model.embedding_dim(),
        }
    }

    /// Embeds one code snippet, consulting and updating the cache when enabled.
    ///
    /// # Errors
    /// Returns [`CodeEmbedderError::Embedding`] when the backend fails.
    pub fn embed(&self, code: &str) -> Result<Vec<f32>, CodeEmbedderError> {
        if let Some(hit) = self.cache_lookup(code) {
            return Ok(hit);
        }
        let embedding = self.finalize(self.backend.embed(code)?);
        self.cache_store(code, &embedding);
        Ok(embedding)
    }

    /// Embeds several snippets, returning one vector per input in input order.
    ///
    /// Cached snippets are served from the cache; the rest are sent to the
    /// backend in a single batch and then cached.
    ///
    /// # Errors
    /// Returns [`CodeEmbedderError::Embedding`] when the backend fails or
    /// returns a different number of embeddings than requested.
    pub fn embed_batch(&self, codes: &[&str]) -> Result<Vec<Vec<f32>>, CodeEmbedderError> {
        let mut results: Vec<Option<Vec<f32>>> = codes
            .iter()
            .map(|&code| self.cache_lookup(code))
            .collect();
        let missing: Vec<usize> = results
            .iter()
            .enumerate()
            .filter(|(_, slot)| slot.is_none())
            .map(|(idx, _)| idx)
            .collect();

        if !missing.is_empty() {
            let pending: Vec<&str> = missing.iter().map(|&idx| codes[idx]).collect();
            let computed = self.backend.embed_batch(&pending)?;
            // Guard against a short result: zipping would silently leave holes.
            if computed.len() != pending.len() {
                return Err(CodeEmbedderError::Embedding(format!(
                    "Backend returned {} embeddings for {} inputs",
                    computed.len(),
                    pending.len()
                )));
            }
            for (idx, raw) in missing.into_iter().zip(computed) {
                let embedding = self.finalize(raw);
                self.cache_store(codes[idx], &embedding);
                results[idx] = Some(embedding);
            }
        }

        results
            .into_iter()
            .collect::<Option<Vec<_>>>()
            .ok_or_else(|| {
                CodeEmbedderError::Embedding("Missing embedding in batch result".to_string())
            })
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns `0.0` when the lengths differ or either vector has zero norm.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }

    /// Cosine similarity between the embeddings of two snippets.
    pub fn score_similarity(&self, code_a: &str, code_b: &str) -> Result<f32, CodeEmbedderError> {
        let embed_a = self.embed(code_a)?;
        let embed_b = self.embed(code_b)?;
        Ok(Self::cosine_similarity(&embed_a, &embed_b))
    }

    /// Scores how well `candidate` continues `context`.
    ///
    /// The score is the cosine similarity between the embedding of `context`
    /// and the embedding of `context` followed by `candidate`, rescaled from
    /// `[-1, 1]` to `[0, 1]`.
    pub fn score_completion(
        &self,
        context: &str,
        candidate: &str,
    ) -> Result<f64, CodeEmbedderError> {
        let combined = format!("{}{}", context, candidate);
        let context_embed = self.embed(context)?;
        let combined_embed = self.embed(&combined)?;
        let similarity = Self::cosine_similarity(&context_embed, &combined_embed);
        Ok((similarity as f64 + 1.0) / 2.0)
    }

    /// Removes every cached embedding. No-op when caching is disabled.
    pub fn clear_cache(&self) {
        if let Some(ref cache) = self.cache {
            cache.clear();
        }
    }

    /// Number of embeddings currently cached (0 when caching is disabled).
    pub fn cache_size(&self) -> usize {
        self.cache.as_ref().map(|c| c.len()).unwrap_or(0)
    }

    /// Model preset from the configuration.
    pub fn model(&self) -> EmbeddingModel {
        self.config.model
    }

    /// Architecture of the loaded model.
    pub fn architecture(&self) -> ModelArchitecture {
        self.architecture
    }

    /// Builds the `EmbeddingConfig` handed to the ModernBERT embedder.
    fn embedding_config(config: &CodeEmbedderConfig) -> EmbeddingConfig {
        EmbeddingConfig {
            normalize: config.normalize,
            batch_size: config.batch_size,
            ..Default::default()
        }
    }

    /// Creates the embedding cache, or `None` when caching is disabled.
    fn build_cache(config: &CodeEmbedderConfig) -> Option<DashMap<String, Vec<f32>>> {
        if config.use_cache {
            Some(DashMap::with_capacity(config.cache_size))
        } else {
            None
        }
    }

    /// Returns a copy of the cached embedding for `code`, if any.
    fn cache_lookup(&self, code: &str) -> Option<Vec<f32>> {
        self.cache
            .as_ref()
            .and_then(|cache| cache.get(code).map(|entry| entry.value().clone()))
    }

    /// Stores `embedding` under `code`, evicting entries first when the cache
    /// is full.
    ///
    /// At least one entry is always evicted when the cache is full, so small
    /// `cache_size` values (below 10) stay bounded too. A `cache_size` of 0
    /// stores nothing.
    fn cache_store(&self, code: &str, embedding: &[f32]) {
        let cache = match self.cache.as_ref() {
            Some(cache) => cache,
            None => return,
        };
        let capacity = self.config.cache_size;
        if capacity == 0 {
            return;
        }
        let len = cache.len();
        if len >= capacity {
            // Drop any overshoot plus roughly 10% of the capacity.
            let evict = (len - capacity) + (capacity / 10).max(1);
            // Keys are collected first so no map guard is held while removing.
            let victims: Vec<String> = cache
                .iter()
                .take(evict)
                .map(|entry| entry.key().clone())
                .collect();
            for key in victims {
                cache.remove(&key);
            }
        }
        cache.insert(code.to_string(), embedding.to_vec());
    }

    /// Applies `config.normalize` to embeddings from the BERT-family backend.
    ///
    /// The ModernBERT embedder receives the flag through its own
    /// `EmbeddingConfig`, so its output is passed through untouched.
    fn finalize(&self, mut embedding: Vec<f32>) -> Vec<f32> {
        let handled_here = matches!(self.backend, EmbedderBackend::Bert { .. });
        if self.config.normalize && handled_here {
            Self::l2_normalize(&mut embedding);
        }
        embedding
    }

    /// Scales `embedding` to unit L2 norm in place; zero vectors are left as is.
    fn l2_normalize(embedding: &mut [f32]) {
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for value in embedding.iter_mut() {
                *value /= norm;
            }
        }
    }
}
impl Default for CodeEmbedder {
/// Builds an embedder through `CodeEmbedder::new`.
///
/// # Panics
/// Panics with the message "Failed to create default CodeEmbedder" when
/// `CodeEmbedder::new` returns an error. Call `CodeEmbedder::new` directly
/// to handle that error instead.
fn default() -> Self {
Self::new().expect("Failed to create default CodeEmbedder")
}
}
/// Errors produced by the code embedder.
///
/// Every variant carries a human-readable description of what went wrong.
#[derive(Debug, Clone)]
pub enum CodeEmbedderError {
    /// A model, its configuration or its tokenizer could not be loaded.
    ModelLoad(String),
    /// Tokenization or the forward pass failed while computing an embedding.
    Embedding(String),
    /// The caller supplied input that cannot be processed.
    InvalidInput(String),
    /// The embedding cache reported a problem.
    Cache(String),
}

impl std::fmt::Display for CodeEmbedderError {
    /// Renders the error as `<category>: <description>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ModelLoad(detail) => write!(f, "Model load error: {}", detail),
            Self::Embedding(detail) => write!(f, "Embedding error: {}", detail),
            Self::InvalidInput(detail) => write!(f, "Invalid input: {}", detail),
            Self::Cache(detail) => write!(f, "Cache error: {}", detail),
        }
    }
}

impl std::error::Error for CodeEmbedderError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// The UniXcoder preset exposes the expected static metadata.
    #[test]
    fn test_embedding_model_config() {
        let model = EmbeddingModel::UniXcoder;
        assert_eq!(model.embedding_dim(), 768);
        assert_eq!(model.max_length(), 512);
        assert_eq!(model.hf_model_id(), "microsoft/unixcoder-base");
    }

    /// Identical vectors score 1, orthogonal vectors score 0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        let identical = CodeEmbedder::cosine_similarity(&a, &b);
        assert!((identical - 1.0).abs() < 1e-6);

        let orthogonal = CodeEmbedder::cosine_similarity(&a, &c);
        assert!((orthogonal - 0.0).abs() < 1e-6);
    }

    /// The default configuration selects UniXcoder with caching and
    /// normalization enabled.
    #[test]
    fn test_code_embedder_config_default() {
        let defaults = CodeEmbedderConfig::default();
        assert_eq!(defaults.model, EmbeddingModel::UniXcoder);
        assert!(defaults.use_cache);
        assert!(defaults.normalize);
    }

    /// `from_config` maps each `model_type` to its architecture and falls back
    /// to `Unknown` for unsupported, missing or malformed input.
    #[test]
    fn test_architecture_detection() {
        let cases = [
            (r#"{"model_type": "modernbert"}"#, ModelArchitecture::ModernBert),
            (r#"{"model_type": "roberta"}"#, ModelArchitecture::Roberta),
            (r#"{"model_type": "bert"}"#, ModelArchitecture::Bert),
            (r#"{"model_type": "gpt2"}"#, ModelArchitecture::Unknown),
            (r#"{"hidden_size": 768}"#, ModelArchitecture::Unknown),
            ("not valid json", ModelArchitecture::Unknown),
        ];
        for (config_json, expected) in cases.iter() {
            assert_eq!(ModelArchitecture::from_config(config_json), *expected);
        }
    }

    /// Each format advertises the files it requires.
    #[test]
    fn test_model_format_detection() {
        let onnx_files = ModelFormat::Onnx.expected_files();
        assert_eq!(onnx_files, &["model.onnx", "tokenizer.json"]);

        let safetensors_files = ModelFormat::SafeTensors.expected_files();
        assert_eq!(
            safetensors_files,
            &["model.safetensors", "config.json", "tokenizer.json"]
        );
    }
}