use crate::error::{PiiError, PiiResult};
use crate::nlp::NlpEngine;
use crate::types::{Language, NerSpan, NlpArtifacts};
use candle::Device;
/// A named-entity-recognition model that can be driven by [`CandleNerEngine`].
///
/// Implementors must be `Send + Sync` so that a boxed model can be shared
/// across threads.
pub trait CandleNerModel: Send + Sync {
    /// Returns the model's name as a borrowed string.
    fn model_name(&self) -> &str;

    /// Runs inference over `text` for the given `language` on `device`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports through
    /// [`PiiResult`].
    fn infer(
        &self,
        device: &Device,
        text: &str,
        language: &Language,
    ) -> PiiResult<Vec<NerSpan>>;
}
/// An [`NlpEngine`] that layers NER spans from a [`CandleNerModel`] on top of
/// the artifacts produced by another engine.
pub struct CandleNerEngine {
    /// Engine that is run first and supplies the initial artifacts.
    base: Box<dyn NlpEngine>,
    /// Model whose `infer` output becomes the NER spans.
    model: Box<dyn CandleNerModel>,
    /// Device handle passed to the model on every inference call.
    device: Device,
}
impl CandleNerEngine {
    /// Builds an engine from a base NLP engine and an NER model.
    ///
    /// The device is always [`Device::Cpu`].
    ///
    /// # Errors
    ///
    /// As written this constructor never returns `Err`; the `PiiResult`
    /// return type is part of the public signature and is kept as is.
    pub fn new(base: Box<dyn NlpEngine>, model: Box<dyn CandleNerModel>) -> PiiResult<Self> {
        Ok(Self {
            base,
            model,
            device: Device::Cpu,
        })
    }
}
impl NlpEngine for CandleNerEngine {
    /// Analyzes `text` with the base engine, then replaces the NER spans in
    /// the resulting artifacts with the model's inference output.
    ///
    /// `capabilities.ner` is set to `true` only when the model returned at
    /// least one span.
    ///
    /// # Errors
    ///
    /// - Propagates any error from the base engine unchanged.
    /// - A model inference failure is stringified and returned as
    ///   [`PiiError::NlpEngine`], so the original error value is not kept.
    fn analyze(&self, text: &str, language: &Language) -> PiiResult<NlpArtifacts> {
        // The base engine runs first; its failure short-circuits before inference.
        let mut artifacts = self.base.analyze(text, language)?;

        let inference = self.model.infer(&self.device, text, language);
        let ner_spans = inference.map_err(|cause| PiiError::NlpEngine(cause.to_string()))?;

        // NOTE(review): the flag reflects "spans were found", not "NER ran" —
        // confirm this is what consumers of `capabilities.ner` expect.
        artifacts.capabilities.ner = !ner_spans.is_empty();
        artifacts.ner = ner_spans;

        Ok(artifacts)
    }
}