use crate::{Entity, EntityCategory, EntityType, Error, Language, Model, Result};
#[cfg(feature = "candle")]
use {
super::encoder_candle::{CandleEncoder, CandleTextEncoder},
candle_core::{DType, Device, Module, Tensor, D},
candle_nn::{linear, Linear, VarBuilder},
std::collections::HashMap,
tokenizers::Tokenizer,
};
/// Fallback label set (CoNLL-2003 BIO tags), used when a model's
/// `config.json` provides no `id2label` map.
const CONLL_LABELS: &[&str] = &[
"O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
];
/// BERT token-classification NER model backed by the Candle runtime.
///
/// Wraps a transformer text encoder plus a linear per-token classification
/// head and a label table decoded from the model's `config.json`.
#[cfg(feature = "candle")]
pub struct CandleNER {
// Transformer encoder producing per-token hidden states.
encoder: CandleEncoder,
// Linear head mapping hidden states to per-label logits.
classifier: Linear,
// Dense label table indexed by class id (BIO tags such as "B-PER").
id2label: Vec<String>,
// Hub id the model was loaded from (reported by `model_name`/`version`).
model_name: String,
// Device the tensors live on (CPU / Metal / CUDA).
device: Device,
}
#[cfg(feature = "candle")]
impl CandleNER {
    /// Downloads and loads a BERT token-classification NER model from the
    /// Hugging Face Hub.
    ///
    /// Fetches `config.json`, the weights (`model.safetensors`, with an
    /// automatic conversion fallback from `pytorch_model.bin`) and a
    /// tokenizer (`tokenizer.json`, or a WordPiece tokenizer rebuilt from a
    /// bare `vocab.txt`), then constructs the encoder and a linear
    /// classification head sized to the label set.
    ///
    /// # Errors
    /// Returns `Error::Retrieval` when any artifact cannot be downloaded or
    /// loaded, and `Error::Parse` when `config.json` is malformed.
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        use crate::backends::hf_loader;
        let device = super::encoder_candle::best_device()?;
        let api = hf_loader::hf_api()?;
        let repo = api.model(model_id.to_string());
        let config_path = hf_loader::download_model_file(&repo, &["config.json"])?;
        // Prefer native safetensors; otherwise try converting a PyTorch
        // checkpoint on the fly.
        let weights_path = hf_loader::download_model_file(&repo, &["model.safetensors"])
            .or_else(|_| {
                let pytorch_path = hf_loader::download_model_file(&repo, &["pytorch_model.bin"])?;
                crate::backends::gliner_candle::convert_pytorch_to_safetensors(&pytorch_path)
            })
            .map_err(|e| Error::Retrieval(format!(
                "model.safetensors not found and conversion failed. CandleNER requires safetensors format. \
                The model may only have pytorch_model.bin. Attempted automatic conversion but it failed. \
                Consider using BertNEROnnx (ONNX version) instead. \
                Original error: {}",
                e
            )))?;
        let tokenizer_path =
            hf_loader::download_model_file(&repo, &["tokenizer.json", "vocab.txt"])?;
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| Error::Retrieval(format!("read config: {}", e)))?;
        let config_json: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| Error::Parse(format!("config JSON: {}", e)))?;
        let encoder_config = CandleEncoder::parse_config(&config_str)?;
        let id2label = Self::parse_labels(&config_json)?;
        let num_labels = id2label.len();
        // SAFETY: the safetensors file was just downloaded/written locally and
        // is not expected to be mutated while memory-mapped.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
                .map_err(|e| Error::Retrieval(format!("safetensors: {}", e)))?
        };
        let encoder_tokenizer = if tokenizer_path.ends_with("tokenizer.json") {
            Tokenizer::from_file(&tokenizer_path)
                .map_err(|e| Error::Retrieval(format!("tokenizer: {}", e)))?
        } else if tokenizer_path.ends_with("vocab.txt") {
            // Older BERT checkpoints ship only vocab.txt; rebuild an
            // equivalent WordPiece tokenizer by hand.
            use tokenizers::models::wordpiece::WordPiece;
            use tokenizers::normalizers::bert::BertNormalizer;
            use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
            use tokenizers::processors::bert::BertProcessing;
            use tokenizers::Tokenizer as TokenizerImpl;
            let vocab_str = tokenizer_path
                .to_str()
                .ok_or_else(|| Error::Retrieval("Invalid tokenizer path".to_string()))?;
            let model = WordPiece::from_file(vocab_str).build().map_err(|e| {
                Error::Retrieval(format!("Failed to create WordPiece from vocab.txt: {}", e))
            })?;
            let mut tokenizer_impl = TokenizerImpl::new(model);
            // Args: clean_text=false, handle_chinese_chars=true,
            // strip_accents=None, lowercase=false (cased-model settings).
            tokenizer_impl.with_normalizer(Some(BertNormalizer::new(
                false, true, None, false,
            )));
            tokenizer_impl.with_pre_tokenizer(Some(BertPreTokenizer));
            // NOTE(review): BertProcessing::default() assumes the standard
            // [CLS]/[SEP] special-token ids — verify for custom vocabularies.
            tokenizer_impl.with_post_processor(Some(BertProcessing::default()));
            tokenizer_impl
        } else {
            return Err(Error::Retrieval("Unsupported tokenizer format".to_string()));
        };
        // NOTE(review): assumes the encoder weights live under the "bert"
        // prefix; RoBERTa-style checkpoints would need a different prefix.
        let encoder = CandleEncoder::from_vb(
            encoder_config.clone(),
            vb.pp("bert"),
            encoder_tokenizer,
            device.clone(),
        )?;
        let classifier = linear(encoder_config.hidden_size, num_labels, vb.pp("classifier"))
            .map_err(|e| Error::Retrieval(format!("classifier: {}", e)))?;
        log::info!(
            "[CandleNER] Loaded {} with {} labels on {:?}",
            model_id,
            num_labels,
            device
        );
        Ok(Self {
            encoder,
            classifier,
            id2label,
            model_name: model_id.to_string(),
            device,
        })
    }
    /// Convenience alias for [`Self::from_pretrained`].
    pub fn new(model_id: &str) -> Result<Self> {
        Self::from_pretrained(model_id)
    }
    /// Builds a dense id → label table from the config's `id2label` map.
    ///
    /// Gaps in a sparse id space are filled with `"O"`; non-numeric keys are
    /// skipped. When the config has no `id2label` entry, falls back to the
    /// CoNLL-2003 label set.
    fn parse_labels(config: &serde_json::Value) -> Result<Vec<String>> {
        if let Some(id2label) = config.get("id2label") {
            let map: HashMap<String, String> = serde_json::from_value(id2label.clone())
                .map_err(|e| Error::Parse(format!("id2label: {}", e)))?;
            let max_id = map
                .keys()
                .filter_map(|k| k.parse::<usize>().ok())
                .max()
                .unwrap_or(0);
            let mut labels = vec!["O".to_string(); max_id + 1];
            for (id_str, label) in map {
                if let Ok(id) = id_str.parse::<usize>() {
                    labels[id] = label;
                }
            }
            Ok(labels)
        } else {
            Ok(CONLL_LABELS.iter().map(|s| s.to_string()).collect())
        }
    }
    /// Runs token classification over `text` and decodes the per-token BIO
    /// predictions into entities with character offsets.
    pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
        if text.is_empty() {
            return Ok(vec![]);
        }
        let (embeddings, seq_len, offsets) = self.encoder.encode_with_offsets(text)?;
        let hidden_dim = self.encoder.hidden_dim();
        let hidden = Tensor::from_vec(embeddings, (1, seq_len, hidden_dim), &self.device)
            .map_err(|e| Error::Parse(format!("hidden tensor: {}", e)))?;
        let logits = self
            .classifier
            .forward(&hidden)
            .map_err(|e| Error::Parse(format!("classifier forward: {}", e)))?;
        // Greedy decoding: arg-max label id per token.
        let predictions = logits
            .argmax(D::Minus1)
            .map_err(|e| Error::Parse(format!("argmax: {}", e)))?
            .flatten_all()
            .map_err(|e| Error::Parse(format!("flatten: {}", e)))?
            .to_vec1::<u32>()
            .map_err(|e| Error::Parse(format!("to_vec: {}", e)))?;
        self.decode_with_offsets(text, &predictions, &offsets)
    }
    /// Finalizes a pending BIO span `(byte_start, byte_end, type, conf)` into
    /// an `Entity` and appends it (no-op when the span is invalid/empty).
    fn push_entity(
        &self,
        text: &str,
        span_converter: &crate::offset::SpanConverter,
        entities: &mut Vec<Entity>,
        pending: (usize, usize, String, f64),
    ) {
        let (start, end, etype, conf) = pending;
        if let Some(e) =
            self.create_entity_from_offsets(text, span_converter, start, end, &etype, conf)
        {
            entities.push(e);
        }
    }
    /// Decodes per-token BIO predictions into entities using the tokenizer's
    /// byte offsets into `text`.
    fn decode_with_offsets(
        &self,
        text: &str,
        predictions: &[u32],
        offsets: &[(usize, usize)],
    ) -> Result<Vec<Entity>> {
        let mut entities = Vec::with_capacity(16);
        // Pending span: (byte_start, byte_end, entity type, confidence).
        let mut current_entity: Option<(usize, usize, String, f64)> = None;
        let span_converter = crate::offset::SpanConverter::new(text);
        for (token_idx, &pred) in predictions.iter().enumerate() {
            if token_idx >= offsets.len() {
                break;
            }
            let (byte_start, byte_end) = offsets[token_idx];
            // Zero-width offsets mark special tokens ([CLS]/[SEP]/padding);
            // they terminate any span in progress.
            if byte_start == byte_end {
                if let Some(pending) = current_entity.take() {
                    self.push_entity(text, &span_converter, &mut entities, pending);
                }
                continue;
            }
            // Out-of-range class ids are treated as "O".
            let label = self
                .id2label
                .get(pred as usize)
                .map(|s| s.as_str())
                .unwrap_or("O");
            if label == "O" {
                if let Some(pending) = current_entity.take() {
                    self.push_entity(text, &span_converter, &mut entities, pending);
                }
            } else if label.starts_with("B-") {
                // New span begins; flush the previous one first.
                if let Some(pending) = current_entity.take() {
                    self.push_entity(text, &span_converter, &mut entities, pending);
                }
                let entity_type = label.strip_prefix("B-").unwrap_or("MISC");
                // NOTE(review): confidence is a fixed placeholder; a softmax
                // over the logits would give a real per-span score.
                current_entity = Some((byte_start, byte_end, entity_type.to_string(), 0.9));
            } else if label.starts_with("I-") {
                let entity_type = label.strip_prefix("I-").unwrap_or("MISC");
                if let Some((start, _, ref etype, conf)) = current_entity {
                    if entity_type == etype {
                        // Continuation: extend the current span.
                        current_entity = Some((start, byte_end, etype.clone(), conf));
                    } else {
                        // Type mismatch: close the old span, start a new one.
                        if let Some(pending) = current_entity.take() {
                            self.push_entity(text, &span_converter, &mut entities, pending);
                        }
                        current_entity =
                            Some((byte_start, byte_end, entity_type.to_string(), 0.9));
                    }
                } else {
                    // Dangling I- without a B-: treat it as a span start.
                    current_entity = Some((byte_start, byte_end, entity_type.to_string(), 0.9));
                }
            }
        }
        // Flush a span that runs to the end of the sequence.
        if let Some(pending) = current_entity.take() {
            self.push_entity(text, &span_converter, &mut entities, pending);
        }
        Ok(entities)
    }
    /// Converts a byte span of `text` into an `Entity`, trimming surrounding
    /// whitespace and translating byte offsets to character offsets.
    ///
    /// Returns `None` for empty or invalid spans, or spans that do not fall
    /// on UTF-8 character boundaries.
    fn create_entity_from_offsets(
        &self,
        text: &str,
        span_converter: &crate::offset::SpanConverter,
        byte_start: usize,
        byte_end: usize,
        entity_type: &str,
        confidence: f64,
    ) -> Option<Entity> {
        if byte_start >= byte_end || byte_end > text.len() {
            return None;
        }
        // `get` returns None if the span splits a multi-byte character.
        let raw = text.get(byte_start..byte_end)?;
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            return None;
        }
        // Fix: tighten the byte span to the trimmed text so the reported
        // offsets match the stored entity text exactly (previously the
        // offsets still covered any surrounding whitespace).
        let start = byte_start + (raw.len() - raw.trim_start().len());
        let end = start + trimmed.len();
        let char_start = span_converter.byte_to_char(start);
        let char_end = span_converter.byte_to_char(end);
        let etype = match entity_type.to_uppercase().as_str() {
            "PER" | "PERSON" => EntityType::Person,
            "ORG" | "ORGANIZATION" => EntityType::Organization,
            "LOC" | "LOCATION" | "GPE" => EntityType::Location,
            "DATE" => EntityType::Date,
            "TIME" => EntityType::Time,
            "MONEY" => EntityType::Money,
            "PERCENT" => EntityType::Percent,
            "MISC" => EntityType::custom("MISC", EntityCategory::Misc),
            other => EntityType::custom(other, EntityCategory::Misc),
        };
        Some(Entity::new(
            trimmed.to_string(),
            etype,
            char_start,
            char_end,
            confidence,
        ))
    }
    /// Hub id this model was loaded from.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
    /// Human-readable device name: "cpu", "metal", or "cuda".
    pub fn device(&self) -> String {
        match &self.device {
            Device::Cpu => "cpu".to_string(),
            Device::Metal(_) => "metal".to_string(),
            Device::Cuda(_) => "cuda".to_string(),
        }
    }
}
#[cfg(feature = "candle")]
impl Model for CandleNER {
    /// Runs extraction; the language hint is ignored by this backend.
    fn extract_entities(&self, text: &str, _language: Option<Language>) -> Result<Vec<Entity>> {
        self.extract(text)
    }
    /// Reports one entity type for each distinct `B-` tag in the label table.
    fn supported_types(&self) -> Vec<EntityType> {
        self.id2label
            .iter()
            .filter_map(|label| label.strip_prefix("B-"))
            .map(|tag| match tag.to_uppercase().as_str() {
                "PER" | "PERSON" => EntityType::Person,
                "ORG" | "ORGANIZATION" => EntityType::Organization,
                "LOC" | "LOCATION" | "GPE" => EntityType::Location,
                other => EntityType::custom(other, EntityCategory::Misc),
            })
            .collect()
    }
    fn is_available(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "CandleNER"
    }
    fn description(&self) -> &'static str {
        "BERT token classification NER using Candle (pure Rust, GPU support)"
    }
    fn version(&self) -> String {
        format!("candle-ner-{}-{}", self.model_name, self.device())
    }
    fn capabilities(&self) -> crate::ModelCapabilities {
        crate::ModelCapabilities::default()
    }
}
// When the `candle` feature is disabled, generate a stub `CandleNER` whose
// constructors return an explanatory error instead of failing to compile.
crate::backends::macros::define_feature_stub! {
struct CandleNER;
feature = "candle";
name = "CandleNER (unavailable)";
description = "BERT NER with Candle - requires 'candle' feature";
error_msg = "CandleNER requires the 'candle' feature";
methods {
pub fn from_pretrained(_model_id: &str) -> crate::Result<Self> {
Self::new("")
}
pub fn model_name(&self) -> &str {
"candle-disabled"
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_stub_without_feature() {
        // With the feature off, construction must fail with an error that
        // names the missing feature.
        #[cfg(not(feature = "candle"))]
        {
            let result = CandleNER::new("test");
            assert!(result.is_err());
            let message = result.unwrap_err().to_string();
            assert!(message.contains("candle"));
        }
    }
    #[test]
    fn test_conll_labels() {
        assert_eq!(CONLL_LABELS.len(), 9);
        assert_eq!(CONLL_LABELS[0], "O");
        assert!(CONLL_LABELS.contains(&"B-PER"));
    }
    #[test]
    fn test_conll_labels_complete() {
        // Table-driven check: every CoNLL-2003 BIO tag must be present.
        let expected = [
            "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
        ];
        for &tag in expected.iter() {
            assert!(CONLL_LABELS.contains(&tag), "missing label: {}", tag);
        }
    }
    #[test]
    fn test_conll_labels_bio_pairing() {
        // Every B- tag needs a matching I- tag for valid BIO decoding.
        for label in CONLL_LABELS {
            if let Some(tag) = label.strip_prefix("B-") {
                let i_tag = format!("I-{}", tag);
                assert!(
                    CONLL_LABELS.contains(&i_tag.as_str()),
                    "B-tag {} has no matching I-tag",
                    label
                );
            }
        }
    }
    #[cfg(feature = "candle")]
    #[test]
    fn test_parse_labels_with_config() {
        let config: serde_json::Value =
            serde_json::from_str(r#"{"id2label": {"0": "O", "1": "B-PER", "2": "I-PER"}}"#)
                .unwrap();
        let labels = CandleNER::parse_labels(&config).unwrap();
        assert_eq!(labels, vec!["O", "B-PER", "I-PER"]);
    }
    #[cfg(feature = "candle")]
    #[test]
    fn test_parse_labels_fallback() {
        // No id2label key: the CoNLL default label set is used.
        let config: serde_json::Value = serde_json::from_str("{}").unwrap();
        let labels = CandleNER::parse_labels(&config).unwrap();
        assert_eq!(labels.len(), CONLL_LABELS.len());
        assert_eq!(labels[0], "O");
    }
    #[cfg(feature = "candle")]
    #[test]
    fn test_parse_labels_sparse_ids() {
        // Gaps between sparse ids are filled with "O".
        let config: serde_json::Value =
            serde_json::from_str(r#"{"id2label": {"0": "O", "5": "B-PER"}}"#).unwrap();
        let labels = CandleNER::parse_labels(&config).unwrap();
        assert_eq!(labels.len(), 6);
        assert_eq!(labels[0], "O");
        assert_eq!(labels[3], "O");
        assert_eq!(labels[5], "B-PER");
    }
}