#[cfg(feature = "huggingface")]
use anyhow::{Context, Result};
#[cfg(feature = "huggingface")]
use candle_core::{DType, Device, Tensor};
#[cfg(feature = "huggingface")]
use candle_nn::VarBuilder;
#[cfg(feature = "huggingface")]
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
#[cfg(feature = "huggingface")]
use candle_transformers::models::jina_bert::Config as JinaBertConfig;
#[cfg(feature = "huggingface")]
use hf_hub::{api::tokio::Api, Repo, RepoType};
#[cfg(feature = "huggingface")]
use std::collections::HashMap;
#[cfg(feature = "huggingface")]
use std::sync::Arc;
#[cfg(feature = "huggingface")]
use tokenizers::Tokenizer;
#[cfg(feature = "huggingface")]
use tokio::sync::RwLock;
#[cfg(feature = "huggingface")]
/// A BERT-style embedding model loaded from the HuggingFace Hub, together
/// with its tokenizer and the device its tensors live on.
pub struct HuggingFaceModel {
    // Candle BERT weights used for the forward pass.
    model: BertModel,
    // Tokenizer matching the model's vocabulary (loaded or rebuilt in `load`).
    tokenizer: Tokenizer,
    // Compute device; `load` always selects `Device::Cpu`.
    device: Device,
}
#[cfg(feature = "huggingface")]
impl HuggingFaceModel {
/// Downloads (or reuses from the local cache) a model's config, tokenizer and
/// weights from the HuggingFace Hub and loads them into a CPU-backed BERT model.
///
/// Tokenizer resolution: prefers a ready-made `tokenizer.json`; otherwise
/// falls back to assembling a RoBERTa-style BPE tokenizer from `vocab.json` +
/// `merges.txt`. Weights: only the safetensors format is supported;
/// `pytorch_model.bin` is located but rejected with an explicit error.
///
/// # Errors
/// Fails when the cache dir cannot be resolved, required repo files are
/// missing, or the config/weights cannot be parsed/loaded.
pub async fn load(model_name: &str) -> Result<Self> {
    // Inference runs on CPU only.
    let device = Device::Cpu;
    let cache_dir = crate::storage::get_huggingface_cache_dir()
        .context("Failed to get HuggingFace cache directory")?;
    // Must happen before `Api::new()` so hf_hub picks up the cache location.
    // NOTE(review): mutating process-wide env from async code can race with
    // other threads reading env vars — confirm this only runs at startup.
    std::env::set_var("HF_HOME", &cache_dir);
    let api = Api::new().context("Failed to initialize HuggingFace API")?;
    let repo = api.repo(Repo::new(model_name.to_string(), RepoType::Model));
    // config.json is mandatory; its contents are parsed further below.
    let config_path = repo
        .get("config.json")
        .await
        .with_context(|| format!("Failed to download config.json for model: {}", model_name))?;
    // Preferred path: the repo ships a complete tokenizer.json.
    let tokenizer = if let Ok(tokenizer_json_path) = repo.get("tokenizer.json").await {
        Tokenizer::from_file(tokenizer_json_path)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?
    } else {
        // Fallback: rebuild a RoBERTa-style BPE tokenizer from raw files.
        if let (Ok(vocab_path), Ok(merges_path)) =
            (repo.get("vocab.json").await, repo.get("merges.txt").await)
        {
            use tokenizers::{
                models::bpe::BPE, normalizers, pre_tokenizers::byte_level::ByteLevel,
                processors::roberta::RobertaProcessing,
            };
            let bpe = BPE::from_file(
                vocab_path
                    .to_str()
                    .ok_or_else(|| anyhow::anyhow!("Invalid vocab path"))?,
                merges_path
                    .to_str()
                    .ok_or_else(|| anyhow::anyhow!("Invalid merges path"))?,
            )
            .unk_token("<unk>".to_string())
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build BPE tokenizer: {:?}", e))?;
            let mut tokenizer = Tokenizer::new(bpe);
            tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
            // RoBERTa special tokens: </s> (id 2) as SEP, <s> (id 0) as CLS.
            let post_processor = RobertaProcessing::new(
                ("</s>".to_string(), 2), ("<s>".to_string(), 0), )
                .trim_offsets(false)
                .add_prefix_space(true);
            tokenizer.with_post_processor(Some(post_processor));
            // Strip leading/trailing whitespace before tokenization.
            let normalizer =
                normalizers::Sequence::new(vec![normalizers::Strip::new(true, true).into()]);
            tokenizer.with_normalizer(Some(normalizer));
            tokenizer
        } else {
            return Err(anyhow::anyhow!(
                "Could not find tokenizer files for model: {}. \
                Expected either tokenizer.json or (vocab.json + merges.txt). \
                This model may not be compatible.",
                model_name
            ));
        }
    };
    // Locate weights; .bin is accepted here but rejected below when loading.
    let weights_path = if let Ok(path) = repo.get("model.safetensors").await {
        path
    } else if let Ok(path) = repo.get("pytorch_model.bin").await {
        path
    } else {
        return Err(anyhow::anyhow!(
            "Could not find model weights in safetensors or pytorch format"
        ));
    };
    let config_content = std::fs::read_to_string(config_path)?;
    let config: BertConfig = serde_json::from_str(&config_content)?;
    let weights = if weights_path.to_string_lossy().ends_with(".safetensors") {
        candle_core::safetensors::load(&weights_path, &device)?
    } else {
        return Err(anyhow::anyhow!("PyTorch .bin format not supported in this implementation. Please use a model with safetensors format."));
    };
    let var_builder = VarBuilder::from_tensors(weights, DType::F32, &device);
    let model = BertModel::load(var_builder, &config)?;
    Ok(Self {
        model,
        tokenizer,
        device,
    })
}
/// Embeds a single text by delegating to `encode_batch`.
pub fn encode(&self, text: &str) -> Result<Vec<f32>> {
    let batch = vec![text.to_string()];
    let mut embeddings = self.encode_batch(&batch)?;
    // A successful call yields exactly one embedding; fall back to an empty
    // vector defensively, mirroring the original `unwrap_or_default`.
    if embeddings.is_empty() {
        Ok(Vec::new())
    } else {
        Ok(embeddings.remove(0))
    }
}
/// Embeds each text independently: tokenize, BERT forward pass, masked mean
/// pooling, then L2 normalization.
///
/// Returns one embedding vector per input text, in order. An empty `texts`
/// slice yields an empty result.
pub fn encode_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    let mut all_embeddings = Vec::with_capacity(texts.len());
    for text in texts {
        let encoding = self
            .tokenizer
            .encode(text.as_str(), true)
            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
        let tokens = encoding.get_ids();
        // Shape (1, seq_len): a batch of one sequence.
        let token_ids = Tensor::new(tokens, &self.device)?.unsqueeze(0)?;
        // Each text is encoded alone, so there is no padding and the mask
        // is all ones.
        let attention_mask = Tensor::ones((1, tokens.len()), DType::U8, &self.device)?;
        let output = self.model.forward(&token_ids, &attention_mask, None)?;
        let embeddings = self.mean_pooling(&output, &attention_mask)?;
        let normalized = self.normalize(&embeddings)?;
        // BUG FIX: `normalized` still carries the leading batch axis
        // (shape (1, hidden)), and `to_vec1` requires a rank-1 tensor —
        // squeeze the batch dimension before extracting.
        let embedding_vec = normalized.squeeze(0)?.to_vec1::<f32>()?;
        all_embeddings.push(embedding_vec);
    }
    Ok(all_embeddings)
}
/// Masked mean over the token dimension: zero out masked positions, sum the
/// hidden states per sequence, and divide by the per-sequence token count.
fn mean_pooling(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    // Promote the mask to f32 and append a unit axis so it broadcasts
    // across the hidden dimension.
    let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
    // Zero out masked positions, then sum over the token axis (dim 1).
    let token_sum = hidden_states.mul(&mask)?.sum(1)?;
    // Same reduction on the mask yields the count of unmasked tokens.
    let token_count = mask.sum(1)?;
    Ok(token_sum.div(&token_count)?)
}
/// L2-normalizes along dim 1, so downstream cosine similarity reduces to a
/// dot product.
fn normalize(&self, embeddings: &Tensor) -> Result<Tensor> {
    // ||x||_2 kept as a trailing unit axis for broadcasting.
    let l2 = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?;
    // NOTE(review): an all-zero row would divide by zero — assumed not to
    // occur for real model outputs.
    let unit = embeddings.div(&l2)?;
    Ok(unit)
}
}
#[cfg(feature = "huggingface")]
lazy_static::lazy_static! {
    /// Process-wide cache of loaded models, keyed by model name, so repeated
    /// embedding calls reuse the expensive-to-load weights and tokenizer.
    // NOTE(review): the outer Arc is redundant — a static is already globally
    // shared, so `RwLock<HashMap<...>>` alone would suffice. Kept as-is.
    static ref MODEL_CACHE: Arc<RwLock<HashMap<String, Arc<HuggingFaceModel>>>> =
        Arc::new(RwLock::new(HashMap::new()));
}
#[cfg(feature = "huggingface")]
pub struct HuggingFaceProvider;
#[cfg(feature = "huggingface")]
impl HuggingFaceProvider {
/// Returns the cached model for `model_name`, loading and caching it on
/// first use.
///
/// Concurrent first calls may each load the model; the last insert wins,
/// which is wasteful but harmless (matches the original behavior).
async fn get_model(model_name: &str) -> Result<Arc<HuggingFaceModel>> {
    // Fast path: shared read lock, released before any loading starts.
    let cached = {
        let guard = MODEL_CACHE.read().await;
        guard.get(model_name).cloned()
    };
    if let Some(model) = cached {
        return Ok(model);
    }
    // Slow path: load outside any lock, then publish under the write lock.
    let loaded = HuggingFaceModel::load(model_name)
        .await
        .with_context(|| format!("Failed to load HuggingFace model: {}", model_name))?;
    let shared = Arc::new(loaded);
    MODEL_CACHE
        .write()
        .await
        .insert(model_name.to_string(), Arc::clone(&shared));
    Ok(shared)
}
/// Embeds a single string with `model`, running the CPU-heavy forward pass
/// on the blocking thread pool so the async executor is not stalled.
pub async fn generate_embeddings(contents: &str, model: &str) -> Result<Vec<f32>> {
    let model_instance = Self::get_model(model).await?;
    let owned = contents.to_owned();
    let task = tokio::task::spawn_blocking(move || model_instance.encode(&owned));
    // First `?` surfaces a JoinError, the second the encoding Result.
    Ok(task.await??)
}
/// Embeds a batch of strings with `model` on the blocking thread pool.
pub async fn generate_embeddings_batch(
    texts: Vec<String>,
    model: &str,
) -> Result<Vec<Vec<f32>>> {
    let model_instance = Self::get_model(model).await?;
    let task = tokio::task::spawn_blocking(move || model_instance.encode_batch(&texts));
    // First `?` surfaces a JoinError, the second the encoding Result.
    Ok(task.await??)
}
}
#[cfg(not(feature = "huggingface"))]
use anyhow::Result;
#[cfg(not(feature = "huggingface"))]
pub struct HuggingFaceProvider;
#[cfg(not(feature = "huggingface"))]
impl HuggingFaceProvider {
    /// Always fails: the binary was built without the `huggingface` feature.
    pub async fn generate_embeddings(_contents: &str, _model: &str) -> Result<Vec<f32>> {
        anyhow::bail!(
            "HuggingFace support is not compiled in. Please rebuild with --features huggingface"
        )
    }

    /// Always fails: the binary was built without the `huggingface` feature.
    pub async fn generate_embeddings_batch(
        _texts: Vec<String>,
        _model: &str,
    ) -> Result<Vec<Vec<f32>>> {
        anyhow::bail!(
            "HuggingFace support is not compiled in. Please rebuild with --features huggingface"
        )
    }
}
use super::super::types::InputType;
use super::EmbeddingProvider;
#[cfg(feature = "huggingface")]
/// `EmbeddingProvider` implementation backed by a HuggingFace Hub model.
pub struct HuggingFaceProviderImpl {
    // Hub model identifier, e.g. "sentence-transformers/all-MiniLM-L6-v2".
    model_name: String,
    // Embedding dimension discovered from the model's config.json in `new`.
    dimension: usize,
}
#[cfg(feature = "huggingface")]
impl HuggingFaceProviderImpl {
pub async fn new(model: &str) -> Result<Self> {
#[cfg(not(feature = "huggingface"))]
{
Err(anyhow::anyhow!("HuggingFace provider requires 'huggingface' feature to be enabled. Cannot validate model '{}' without Hub API access.", model))
}
#[cfg(feature = "huggingface")]
{
let dimension = Self::get_model_dimension(model).await?;
Ok(Self {
model_name: model.to_string(),
dimension,
})
}
}
#[cfg(feature = "huggingface")]
/// Resolves the embedding dimension for `model` (thin wrapper around
/// `get_dimension_from_config`).
async fn get_model_dimension(model: &str) -> Result<usize> {
    Self::get_dimension_from_config(model).await
}
#[cfg(feature = "huggingface")]
/// Downloads the model's config.json and tries progressively more lenient
/// parsers to find the embedding dimension.
async fn get_dimension_from_config(model_name: &str) -> Result<usize> {
    let config_json = Self::download_config_direct(model_name).await?;
    // Most specific first: Jina-BERT, then plain BERT, then a raw JSON scan
    // over common dimension field names.
    match Self::parse_as_jina_bert_config(&config_json) {
        Ok(config) => Ok(config.hidden_size),
        Err(_) => match Self::parse_as_bert_config(&config_json) {
            Ok(config) => Ok(config.hidden_size),
            Err(_) => Self::parse_hidden_size_from_json(&config_json, model_name),
        },
    }
}
#[cfg(feature = "huggingface")]
/// Attempts to deserialize config.json as a Jina-BERT configuration.
fn parse_as_jina_bert_config(config_json: &str) -> Result<JinaBertConfig> {
    let config: JinaBertConfig = serde_json::from_str(config_json)
        .map_err(|e| anyhow::anyhow!("Failed to parse as JinaBertConfig: {}", e))?;
    Ok(config)
}
#[cfg(feature = "huggingface")]
/// Attempts to deserialize config.json as a standard BERT configuration.
fn parse_as_bert_config(
    config_json: &str,
) -> Result<candle_transformers::models::bert::Config> {
    use candle_transformers::models::bert::Config as BertConfig;
    let config: BertConfig = serde_json::from_str(config_json)
        .map_err(|e| anyhow::anyhow!("Failed to parse as BertConfig: {}", e))?;
    Ok(config)
}
#[cfg(feature = "huggingface")]
/// Last-resort dimension lookup: parse config.json as untyped JSON and scan
/// a list of well-known dimension field names.
fn parse_hidden_size_from_json(config_json: &str, model_name: &str) -> Result<usize> {
    use serde_json::Value;
    let config: Value = serde_json::from_str(config_json).with_context(|| {
        format!(
            "Failed to parse config.json as JSON for model: {}",
            model_name
        )
    })?;
    // Checked in priority order; first numeric match wins.
    let dimension_fields = ["hidden_size", "d_model", "embedding_size", "dim"];
    let found = dimension_fields.iter().find_map(|field| {
        config
            .get(field)
            .and_then(Value::as_u64)
            .map(|dim| (field, dim))
    });
    if let Some((field, dim)) = found {
        tracing::debug!(
            "Found dimension {} for model {} from config.json field '{}'",
            dim,
            model_name,
            field
        );
        return Ok(dim as usize);
    }
    Err(anyhow::anyhow!(
        "No dimension field found in config.json for model '{}'. \
        Searched for fields: {:?}. Available fields: {:?}",
        model_name,
        dimension_fields,
        config
            .as_object()
            .map(|obj| obj.keys().collect::<Vec<_>>())
            .unwrap_or_default()
    ))
}
#[cfg(feature = "huggingface")]
/// Fetches `config.json` for `model_name` straight from huggingface.co over
/// HTTPS (bypassing the hf_hub cache) and returns its raw text.
///
/// Used only for dimension discovery; the full model download still goes
/// through `hf_hub` in `HuggingFaceModel::load`.
async fn download_config_direct(model_name: &str) -> Result<String> {
    use reqwest;
    // "raw/main" serves the file contents directly from the default branch.
    let config_url = format!("https://huggingface.co/{}/raw/main/config.json", model_name);
    tracing::debug!("Downloading config from: {}", config_url);
    let client = reqwest::Client::new();
    let response = client
        .get(&config_url)
        .header("User-Agent", "octocode/0.7.1")
        .send()
        .await
        .with_context(|| format!("Failed to download config.json from {}", config_url))?;
    // Turn non-2xx responses (missing/private model, no config.json) into a
    // descriptive error instead of surfacing the raw response body.
    if !response.status().is_success() {
        return Err(anyhow::anyhow!(
            "Failed to download config.json for model '{}'. HTTP status: {}. \
            This could be due to:\n\
            1. Model doesn't exist on HuggingFace Hub\n\
            2. Network connectivity issues\n\
            3. Model is private and requires authentication\n\
            4. Model doesn't have a config.json file",
            model_name,
            response.status()
        ));
    }
    let config_text = response.text().await.with_context(|| {
        format!(
            "Failed to read config.json response for model: {}",
            model_name
        )
    })?;
    Ok(config_text)
}
}
#[cfg(feature = "huggingface")]
#[async_trait::async_trait]
impl EmbeddingProvider for HuggingFaceProviderImpl {
    /// Embeds one text using the globally cached model for `self.model_name`.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        HuggingFaceProvider::generate_embeddings(text, &self.model_name).await
    }

    /// Applies the input-type prefix to every text, then embeds the batch.
    async fn generate_embeddings_batch(
        &self,
        texts: Vec<String>,
        input_type: InputType,
    ) -> Result<Vec<Vec<f32>>> {
        let mut processed_texts = Vec::with_capacity(texts.len());
        for text in texts {
            processed_texts.push(input_type.apply_prefix(&text));
        }
        HuggingFaceProvider::generate_embeddings_batch(processed_texts, &self.model_name).await
    }

    /// Dimension resolved from the model's config.json at construction time.
    fn get_dimension(&self) -> usize {
        self.dimension
    }

    /// Any Hub model that constructed successfully is considered supported.
    fn is_model_supported(&self) -> bool {
        true
    }
}
#[cfg(all(test, feature = "huggingface"))]
mod tests {
    /// Builds a minimal RoBERTa-style BPE tokenizer from temp vocab/merges
    /// files, mirroring the fallback path in `HuggingFaceModel::load`.
    #[test]
    fn test_roberta_tokenizer_building() {
        use tokenizers::{
            models::bpe::BPE, pre_tokenizers::byte_level::ByteLevel,
            processors::roberta::RobertaProcessing, Tokenizer,
        };
        // NOTE(review): fixed file names in the shared temp dir could collide
        // if this test runs concurrently in multiple processes.
        let vocab_file = std::env::temp_dir().join("test_vocab.json");
        let merges_file = std::env::temp_dir().join("test_merges.txt");
        let vocab_content = r#"{"<s>":0,"<pad>":1,"</s>":2,"<unk>":3,"h":4,"e":5,"l":6,"o":7,"r":8,"he":9,"ll":10,"or":11,"hello":12,"world":13}"#;
        std::fs::write(&vocab_file, vocab_content).expect("Failed to write vocab");
        let merges_content = "#version: 0.2\nh e\nl l\no r";
        std::fs::write(&merges_file, merges_content).expect("Failed to write merges");
        let bpe = BPE::from_file(vocab_file.to_str().unwrap(), merges_file.to_str().unwrap())
            .unk_token("<unk>".to_string())
            .build()
            .expect("Failed to build BPE tokenizer");
        let mut tokenizer = Tokenizer::new(bpe);
        tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
        // Same special-token ids as the production fallback: </s>=2, <s>=0.
        let post_processor = RobertaProcessing::new(
            ("</s>".to_string(), 2), ("<s>".to_string(), 0), )
            .trim_offsets(false)
            .add_prefix_space(true);
        tokenizer.with_post_processor(Some(post_processor));
        let test_text = "hello world";
        let encoding = tokenizer
            .encode(test_text, false)
            .expect("Failed to encode");
        assert!(
            !encoding.get_ids().is_empty(),
            "Encoding should produce tokens"
        );
        println!("✓ RoBERTa-style tokenizer built successfully using BPE::from_file");
        // Best-effort cleanup; failures are deliberately ignored.
        let _ = std::fs::remove_file(vocab_file);
        let _ = std::fs::remove_file(merges_file);
    }

    /// Checks that merges.txt parsing skips the "#version" header line and
    /// keeps only well-formed two-token merge lines.
    #[test]
    fn test_merges_parsing() {
        let merges_content = r#"#version: 0.2
Ġ t
Ġ a
h e
Ġt he
i n"#;
        let merges: Vec<(String, String)> = merges_content
            .lines()
            // First line is the version header, not a merge rule.
            .skip(1) .filter_map(|line| {
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() == 2 {
                    Some((parts[0].to_string(), parts[1].to_string()))
                } else {
                    None
                }
            })
            .collect();
        assert_eq!(merges.len(), 5);
        assert_eq!(merges[0], ("Ġ".to_string(), "t".to_string()));
        println!("✓ Merges parsing works correctly");
    }
}