use candle_core::Tensor;
use candle_nn::{linear, Linear, Module, VarBuilder};
/// Two-layer feed-forward head that maps an embedding to a single logit.
///
/// The layers are applied as `linear_0 -> ReLU -> linear_2`.
#[derive(Debug, Clone)]
pub struct Classifier {
    // Field names mirror the checkpoint sub-paths ("0" and "2") they are loaded
    // from; presumably index 1 was a parameter-free activation in the exporting
    // model — confirm against the checkpoint producer.
    linear_0: Linear,
    linear_2: Linear,
}
impl Classifier {
    /// Loads the classifier weights from `vb`.
    ///
    /// Builds a 768 -> 1536 projection from the `"0"` sub-path and a 1536 -> 1
    /// projection from the `"2"` sub-path. The index-style prefixes presumably
    /// mirror a PyTorch `nn.Sequential(Linear, ReLU, Linear)` checkpoint, where
    /// index 1 is the parameter-free activation — confirm against the exporter.
    ///
    /// # Errors
    ///
    /// Returns an error if either layer's weight or bias cannot be loaded from
    /// `vb` (e.g. a missing tensor or a shape mismatch).
    pub fn from_var_builder(vb: &VarBuilder) -> candle_core::Result<Self> {
        let linear_0 = linear(768, 1536, vb.pp("0"))?;
        let linear_2 = linear(1536, 1, vb.pp("2"))?;
        Ok(Self { linear_0, linear_2 })
    }

    /// Computes one logit per input embedding.
    ///
    /// Applies `linear_0`, a ReLU, then `linear_2`, and finally drops the
    /// trailing width-1 output axis. An input of shape `(batch, 768)` yields
    /// `(batch,)`; a `(batch, seq, 768)` input yields `(batch, seq)`.
    ///
    /// # Errors
    ///
    /// Propagates any shape, dtype, or device error raised by the underlying
    /// tensor operations.
    pub fn forward(&self, field_embs: &Tensor) -> candle_core::Result<Tensor> {
        let hidden = self.linear_0.forward(field_embs)?.relu()?;
        let logits = self.linear_2.forward(&hidden)?;
        // Squeeze the last axis (the width-1 output) rather than axis 1. For
        // rank-3 inputs axis 1 is the sequence axis, and candle's `squeeze`
        // silently no-ops on non-unit axes: `squeeze(1)` would keep the
        // trailing 1 when seq > 1, or drop the sequence axis when seq == 1.
        logits.squeeze(candle_core::D::Minus1)
    }
}