use std::path::Path;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::debertav2::{Config as DebertaV2Config, DebertaV2Model};
/// A DeBERTa-v2 model bundled with the configuration it was loaded with.
///
/// Instances are built through `Encoder::from_safetensors` or
/// `Encoder::from_var_builder`; both store the same config that was handed
/// to `DebertaV2Model::load`.
pub struct Encoder {
/// The loaded candle DeBERTa-v2 model that `forward` delegates to.
pub(crate) model: DebertaV2Model,
/// Configuration used to load `model`; `hidden_size()` reads from it.
pub(crate) config: DebertaV2Config,
}
impl Encoder {
pub fn from_safetensors(
weights_path: &Path,
config_path: &Path,
device: &Device,
) -> crate::Result<Self> {
let cfg_str = std::fs::read_to_string(config_path).map_err(|e| {
crate::Error::Backend(format!(
"encoder config read {}: {e}",
config_path.display()
))
})?;
let config: DebertaV2Config = serde_json::from_str(&cfg_str).map_err(|e| {
crate::Error::Backend(format!(
"encoder config parse {}: {e}",
config_path.display()
))
})?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, device)
}
.map_err(|e| crate::Error::Backend(format!("encoder safetensors: {e}")))?;
let model = DebertaV2Model::load(vb.pp("encoder"), &config)
.map_err(|e| crate::Error::Backend(format!("encoder DebertaV2Model::load: {e}")))?;
Ok(Self { model, config })
}
pub fn from_var_builder(vb: VarBuilder<'_>, config: &DebertaV2Config) -> crate::Result<Self> {
let model = DebertaV2Model::load(vb, config).map_err(|e| {
crate::Error::Backend(format!("encoder DebertaV2Model::load (vb): {e}"))
})?;
Ok(Self {
model,
config: config.clone(),
})
}
pub fn forward(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
token_type_ids: Option<&Tensor>,
) -> candle_core::Result<Tensor> {
self.model.forward(
input_ids,
token_type_ids.cloned(),
Some(attention_mask.clone()),
)
}
pub fn hidden_size(&self) -> usize {
self.config.hidden_size
}
}