use std::path::Path;
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
pub mod classifier;
pub mod count_lstm;
pub mod count_pred;
pub mod schema_gather;
pub mod scorer;
pub mod span_rep;
pub mod token_gather;
/// Bundle of all model heads, loaded together from a single set of weights.
///
/// Build it with [`AllHeads::from_safetensors`] (memory-mapped file) or
/// [`AllHeads::from_var_builder`] (caller-supplied `VarBuilder`).
pub struct AllHeads {
    /// Loaded from the `span_rep.span_rep_layer` tensor prefix.
    pub span_rep: span_rep::SpanRep,
    /// Loaded from the `count_embed` tensor prefix.
    pub count_lstm: count_lstm::CountLstmFixed,
    /// Loaded from the `count_pred` tensor prefix.
    pub count_pred: count_pred::CountPred,
    /// Loaded from the `classifier` tensor prefix.
    pub classifier: classifier::Classifier,
}
impl AllHeads {
pub fn from_safetensors(weights_path: &Path, device: &Device) -> crate::Result<Self> {
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, device) }
.map_err(|e| {
crate::Error::Backend(format!("gliner2_fastino_candle: heads safetensors: {e}"))
})?;
let span_rep = span_rep::SpanRep::from_var_builder(&vb.pp("span_rep").pp("span_rep_layer"))
.map_err(|e| {
crate::Error::Backend(format!("gliner2_fastino_candle: span_rep load: {e}"))
})?;
let count_lstm =
count_lstm::CountLstmFixed::from_var_builder(&vb.pp("count_embed"), device).map_err(
|e| crate::Error::Backend(format!("gliner2_fastino_candle: count_embed load: {e}")),
)?;
let count_pred =
count_pred::CountPred::from_var_builder(&vb.pp("count_pred")).map_err(|e| {
crate::Error::Backend(format!("gliner2_fastino_candle: count_pred load: {e}"))
})?;
let classifier =
classifier::Classifier::from_var_builder(&vb.pp("classifier")).map_err(|e| {
crate::Error::Backend(format!("gliner2_fastino_candle: classifier load: {e}"))
})?;
Ok(Self {
span_rep,
count_lstm,
count_pred,
classifier,
})
}
pub fn from_var_builder(vb: VarBuilder<'_>, device: &Device) -> crate::Result<Self> {
let span_rep = span_rep::SpanRep::from_var_builder(&vb.pp("span_rep").pp("span_rep_layer"))
.map_err(|e| {
crate::Error::Backend(format!("gliner2_fastino_candle: span_rep load (vb): {e}"))
})?;
let count_lstm =
count_lstm::CountLstmFixed::from_var_builder(&vb.pp("count_embed"), device).map_err(
|e| {
crate::Error::Backend(format!(
"gliner2_fastino_candle: count_embed load (vb): {e}"
))
},
)?;
let count_pred =
count_pred::CountPred::from_var_builder(&vb.pp("count_pred")).map_err(|e| {
crate::Error::Backend(format!("gliner2_fastino_candle: count_pred load (vb): {e}"))
})?;
let classifier =
classifier::Classifier::from_var_builder(&vb.pp("classifier")).map_err(|e| {
crate::Error::Backend(format!("gliner2_fastino_candle: classifier load (vb): {e}"))
})?;
Ok(Self {
span_rep,
count_lstm,
count_pred,
classifier,
})
}
}