use crate::parser::bio_tags::BioTag;
use candle_core::{DType, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use candle_transformers::models::distilbert::{Config, DistilBertModel};
pub struct CrfModel {
pub distilbert: DistilBertModel,
pub emission: Linear,
pub transitions: Tensor,
}
impl CrfModel {
pub fn load(vb: VarBuilder, config: Config) -> Result<Self> {
let distilbert = DistilBertModel::load(vb.clone(), &config)?;
let hidden_size = 768; let num_labels = BioTag::NUM_TAGS;
let emission = candle_nn::linear(hidden_size, num_labels, vb.clone())?;
let dev = vb.device();
let transitions = match vb.get((num_labels, num_labels), "crf_transitions.weight") {
Ok(t) => t,
Err(_) => Tensor::zeros((num_labels, num_labels), DType::F32, dev)?,
};
Ok(Self {
distilbert,
emission,
transitions,
})
}
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let hidden_states = self.distilbert.forward(input_ids, attention_mask)?;
let emissions = self.emission.forward(&hidden_states)?;
Ok(emissions)
}
}