use crate::layers::create_projection_layer;
use candle_core::{D, Result, Tensor};
use candle_nn::{Activation, Module, Sequential, VarBuilder};
/// Span representation layer built from projected start-token and end-token states.
///
/// Each span is represented by concatenating the projected hidden state of its start
/// token with that of its end token, then projecting the pair back down
/// (see `SpanMarkerV0::forward`).
pub struct SpanMarkerV0 {
    /// Projection applied to every token state before span start positions are gathered.
    project_start: Sequential,
    /// Projection applied to every token state before span end positions are gathered.
    project_end: Sequential,
    /// Projection applied to the ReLU-activated `[start ; end]` concatenation.
    out_project: Sequential,
    /// Number of span widths per token; used as the third axis of the `forward` output.
    max_width: usize,
}
impl SpanMarkerV0 {
    /// Loads the three projection layers from `vb`.
    ///
    /// The layers are read, in this order, from the `project_start`, `project_end` and
    /// `out_project` sub-paths of `vb`.
    ///
    /// - `hidden_size` is passed as both arguments to `create_projection_layer` for the
    ///   start and end projections.
    /// - `out_project` receives `hidden_size * 2` and `hidden_size`. Presumably these are
    ///   the input and output widths; confirm against `create_projection_layer`.
    ///
    /// # Errors
    /// Propagates any error raised while building a projection layer.
    pub fn load(hidden_size: usize, max_width: usize, vb: VarBuilder) -> Result<Self> {
        let project_start =
            create_projection_layer(hidden_size, hidden_size, vb.pp("project_start"))?;
        let project_end =
            create_projection_layer(hidden_size, hidden_size, vb.pp("project_end"))?;
        let out_project =
            create_projection_layer(hidden_size * 2, hidden_size, vb.pp("out_project"))?;

        Ok(Self {
            project_start,
            project_end,
            out_project,
            max_width,
        })
    }

    /// Computes one representation per span.
    ///
    /// - `h` must be rank 3, `(batch, seq_len, hidden)`.
    /// - `span_idx` carries start positions at index 0 and end positions at index 1 of
    ///   its last axis. Assumes a shape of `(batch, num_spans, 2)` with
    ///   `num_spans == seq_len * max_width`; TODO confirm with callers.
    ///
    /// The projected start and end states of each span are concatenated on the last
    /// axis, passed through ReLU and `out_project`. The result is then reshaped to
    /// `(batch, seq_len, max_width, hidden)`.
    ///
    /// # Errors
    /// Returns an error if any of the following fails:
    /// - `h` is not rank 3;
    /// - a projection fails;
    /// - the gather fails;
    /// - the projected output cannot be reshaped to `(batch, seq_len, max_width, hidden)`,
    ///   for example when the number of spans does not match.
    pub fn forward(&self, h: &Tensor, span_idx: &Tensor) -> Result<Tensor> {
        let (batch, seq_len, hidden) = h.dims3()?;

        let start_states = self.project_start.forward(h)?;
        let end_states = self.project_end.forward(h)?;

        let start_positions = span_idx.get_on_dim(D::Minus1, 0)?;
        let end_positions = span_idx.get_on_dim(D::Minus1, 1)?;

        let span_starts = self.extract_elements(&start_states, &start_positions)?;
        let span_ends = self.extract_elements(&end_states, &end_positions)?;

        let pair = Tensor::cat(&[&span_starts, &span_ends], D::Minus1)?;
        let activated = Activation::Relu.forward(&pair)?;

        self.out_project
            .forward(&activated)?
            .reshape((batch, seq_len, self.max_width, hidden))
    }

    /// Gathers rows of `h` along axis 1 at the positions listed in `idx`.
    ///
    /// `h` must be rank 3, `(batch, len, hidden)`. `idx` is read as `(batch, num_spans)`:
    /// its second axis gives the number of positions, and a trailing axis is added and
    /// broadcast across `hidden`. The result has shape `(batch, num_spans, hidden)`.
    fn extract_elements(&self, h: &Tensor, idx: &Tensor) -> Result<Tensor> {
        let (batch, _, hidden) = h.dims3()?;
        let num_spans = idx.dim(1)?;

        // `gather` needs an index with the same rank as `h`, so every hidden channel
        // receives the same position.
        let gather_idx = idx
            .unsqueeze(2)?
            .expand((batch, num_spans, hidden))?
            .contiguous()?;

        h.contiguous()?.gather(&gather_idx, 1)
    }
}