from candle.utils import load_safetensors, save_gguf, load_gguf
from candle.models.bert import BertModel, Config
import json
from candle import Tensor
from tqdm import tqdm
from dataclasses import fields
import os
import time
from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, AutoModel
import torch
if __name__ == "__main__":
model_name = "intfloat/e5-small-v2"
model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
config_file = hf_hub_download(repo_id=model_name, filename="config.json")
tensors = load_safetensors(model_file)
config = Config()
with open(config_file, "r") as f:
raw_config = json.load(f)
for field in fields(config):
if field.name in raw_config:
setattr(config, field.name, raw_config[field.name])
model = BertModel(config)
model.load_state_dict(tensors)
hf_model = AutoModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
]
def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
tokenized = tokenizer(sentences, padding=True)
tokens = Tensor(tokenized["input_ids"])
token_type_ids = Tensor(tokenized["token_type_ids"])
attention_mask = Tensor(tokenized["attention_mask"])
encoder_out, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)
hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"])
candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"])
loss = torch.nn.L1Loss()
error = loss(hf_pooled, candle_pooled).mean().item()
print(f"Mean error between torch-reference and candle: {error}")
quantized_tensors = {}
for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors to 5-Bit"):
if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name):
if tensor.shape[-1] % 256 == 0:
new_tensor = tensor.quantize("q4k")
else:
new_tensor = tensor.quantize("q5_0")
quantized_tensors[name] = new_tensor
else:
quantized_tensors[name] = tensor.quantize("q8_0")
print(f"Saving quantized tensors")
config_to_save = {k: v for k, v in config.__dict__.items() if v is not None}
quantized_model_file = "e5_small.gguf"
save_gguf(quantized_model_file, quantized_tensors, config_to_save)
file_size_mb = os.path.getsize(model_file) / 1024 / 1024
file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024
print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")
tensors, raw_config = load_gguf(quantized_model_file)
config = Config()
for field in fields(config):
if field.name in raw_config:
setattr(config, field.name, raw_config[field.name])
model = BertModel(config)
model.load_state_dict(tensors, strict=False)
encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids)
encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu")
candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"])
error = loss(hf_pooled, candle_pooled_2).mean().item()
print(f"Mean error between torch-reference and quantized-candle: {error}")