use super::*;
/// Span-representation head: projects token embeddings into per-span
/// embeddings by combining a start-token MLP, an end-token MLP, and an
/// output MLP over their concatenation.
///
/// Field names mirror the checkpoint's `Sequential` indices (`0` = first
/// linear, `3` = second linear; slots 1-2 presumably hold the activation /
/// dropout and carry no weights — TODO confirm against the checkpoint).
// NOTE(review): gated on the "candle" feature for consistency with the
// cfg-gated `impl` below and the other candle-only types in this file;
// without the gate, the `Linear`-typed fields break non-candle builds.
#[cfg(feature = "candle")]
pub struct SpanRepLayer {
    // project_start: Linear(h, 4h) -> ReLU -> Linear(4h, h), start tokens.
    project_start_0: Linear,
    project_start_3: Linear,
    // project_end: same shape, applied to end tokens.
    project_end_0: Linear,
    project_end_3: Linear,
    // out_project: Linear(2h, 4h) -> ReLU -> Linear(4h, h) over [start; end].
    out_project_0: Linear,
    out_project_3: Linear,
    // Model hidden size; stored at construction, not read in `forward`.
    hidden_size: usize,
}
#[cfg(feature = "candle")]
impl SpanRepLayer {
    /// Builds the span head, loading the six linear layers from `vb` using
    /// the checkpoint's `Sequential` key layout (`<name>.0` / `<name>.3`).
    ///
    /// `_max_width` is accepted for signature compatibility but unused.
    pub fn new(hidden_size: usize, _max_width: usize, vb: VarBuilder) -> Result<Self> {
        let wide = hidden_size * 4;
        let project_start_0 = linear(hidden_size, wide, vb.pp("project_start").pp("0"))
            .map_err(|e| Error::Retrieval(format!("SpanRepLayer project_start.0: {}", e)))?;
        let project_start_3 = linear(wide, hidden_size, vb.pp("project_start").pp("3"))
            .map_err(|e| Error::Retrieval(format!("SpanRepLayer project_start.3: {}", e)))?;
        let project_end_0 = linear(hidden_size, wide, vb.pp("project_end").pp("0"))
            .map_err(|e| Error::Retrieval(format!("SpanRepLayer project_end.0: {}", e)))?;
        let project_end_3 = linear(wide, hidden_size, vb.pp("project_end").pp("3"))
            .map_err(|e| Error::Retrieval(format!("SpanRepLayer project_end.3: {}", e)))?;
        let out_project_0 = linear(hidden_size * 2, wide, vb.pp("out_project").pp("0"))
            .map_err(|e| Error::Retrieval(format!("SpanRepLayer out_project.0: {}", e)))?;
        let out_project_3 = linear(wide, hidden_size, vb.pp("out_project").pp("3"))
            .map_err(|e| Error::Retrieval(format!("SpanRepLayer out_project.3: {}", e)))?;
        Ok(Self {
            project_start_0,
            project_start_3,
            project_end_0,
            project_end_3,
            out_project_0,
            out_project_3,
            hidden_size,
        })
    }

    /// Computes one embedding per candidate span.
    ///
    /// Assumes `token_embeddings` is (batch, seq_len, hidden) and
    /// `span_indices` is (batch, num_spans, 2) holding [start, end] token
    /// positions — both shapes enforced by the `dims3()` calls below.
    /// Indices are clamped into `0..seq_len`. Returns a
    /// (batch, num_spans, hidden) tensor.
    pub fn forward(&self, token_embeddings: &Tensor, span_indices: &Tensor) -> Result<Tensor> {
        let (batch_size, seq_len, _hidden) = token_embeddings
            .dims3()
            .map_err(|e| Error::Parse(format!("token_embeddings dims: {}", e)))?;
        let (_, _num_spans, _) = span_indices
            .dims3()
            .map_err(|e| Error::Parse(format!("span_indices dims: {}", e)))?;

        // Run the start/end MLPs (linear -> ReLU -> linear) over every token.
        let start_rep = self
            .project_start_3
            .forward(&self.project_start_0.forward(token_embeddings)?.relu()?)?;
        let end_rep = self
            .project_end_3
            .forward(&self.project_end_0.forward(token_embeddings)?.relu()?)?;

        // Split span indices into start/end position tensors.
        let start_idx = span_indices.i((.., .., 0))?.to_dtype(DType::U32)?;
        let end_idx = span_indices.i((.., .., 1))?.to_dtype(DType::U32)?;
        let max_idx = (seq_len - 1) as u32;

        // Gather and combine per batch element, then stack the results.
        let per_batch = (0..batch_size)
            .map(|b| {
                let starts = start_idx.i(b)?.clamp(0f64, max_idx as f64)?;
                let ends = end_idx.i(b)?.clamp(0f64, max_idx as f64)?;
                let gathered_starts = start_rep
                    .i(b)?
                    .index_select(&starts.to_dtype(DType::U32)?, 0)
                    .map_err(|e| Error::Parse(format!("start index_select: {}", e)))?;
                let gathered_ends = end_rep
                    .i(b)?
                    .index_select(&ends.to_dtype(DType::U32)?, 0)
                    .map_err(|e| Error::Parse(format!("end index_select: {}", e)))?;
                // ReLU on the concatenation, then the output MLP.
                let joined = Tensor::cat(&[&gathered_starts, &gathered_ends], D::Minus1)?.relu()?;
                let projected = self.out_project_0.forward(&joined)?.relu()?;
                Ok(self.out_project_3.forward(&projected)?)
            })
            .collect::<Result<Vec<_>>>()?;
        Tensor::stack(&per_batch, 0).map_err(|e| Error::Parse(format!("stack span_embs: {}", e)))
    }
}
/// Two-layer MLP (h -> 4h -> ReLU -> h) that refines label embeddings
/// before span/label matching.
///
/// Field names mirror the checkpoint's `Sequential` indices: `0` is the
/// first linear and `3` the second; the intermediate slots presumably hold
/// the activation/dropout and carry no weights — TODO confirm.
#[cfg(feature = "candle")]
pub struct LabelEncoder {
linear_0: Linear,
linear_3: Linear,
}
#[cfg(feature = "candle")]
impl LabelEncoder {
    /// Loads the encoder's two linear layers from `vb` (keys "0" and "3").
    pub fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
        let wide = hidden_size * 4;
        let linear_0 = linear(hidden_size, wide, vb.pp("0"))
            .map_err(|e| Error::Retrieval(format!("LabelEncoder.0: {}", e)))?;
        let linear_3 = linear(wide, hidden_size, vb.pp("3"))
            .map_err(|e| Error::Retrieval(format!("LabelEncoder.3: {}", e)))?;
        Ok(Self { linear_0, linear_3 })
    }

    /// Applies linear -> ReLU -> linear to `label_embeddings`; output has
    /// the same trailing dimension as the input.
    pub fn forward(&self, label_embeddings: &Tensor) -> Result<Tensor> {
        let hidden = self
            .linear_0
            .forward(label_embeddings)
            .map_err(|e| Error::Parse(format!("label projection 0: {}", e)))?
            .relu()
            .map_err(|e| Error::Parse(format!("label relu: {}", e)))?;
        self.linear_3
            .forward(&hidden)
            .map_err(|e| Error::Parse(format!("label projection 3: {}", e)))
    }
}
/// Scores span embeddings against label embeddings with scaled cosine
/// similarity followed by a sigmoid.
#[cfg(feature = "candle")]
pub struct SpanLabelMatcher {
// Multiplier applied to the cosine scores before the sigmoid.
// NOTE(review): the scores are *multiplied* by this value, i.e. it behaves
// as an inverse temperature — confirm the intended convention at call sites.
temperature: f64,
}
#[cfg(feature = "candle")]
impl SpanLabelMatcher {
    /// Creates a matcher with the given score multiplier.
    pub fn new(temperature: f64) -> Self {
        Self { temperature }
    }

    /// Returns `sigmoid(temperature * cosine(span, label))` for every
    /// span/label pair.
    ///
    /// Assumes `span_embeddings` is batched (first dim = batch) and
    /// `label_embeddings` is a 2-D (num_labels, hidden) matrix shared
    /// across the batch — TODO confirm against callers.
    pub fn forward(&self, span_embeddings: &Tensor, label_embeddings: &Tensor) -> Result<Tensor> {
        let spans = l2_normalize(span_embeddings, D::Minus1)?;
        let labels = l2_normalize(label_embeddings, D::Minus1)?;
        // Tile the transposed label matrix across the batch dimension so a
        // plain matmul yields per-batch (num_spans, num_labels) similarities.
        let batch = spans.dims()[0];
        let labels_t = labels.t()?;
        let (rows, cols) = (labels_t.dims()[0], labels_t.dims()[1]);
        let tiled = labels_t.unsqueeze(0)?.broadcast_as((batch, rows, cols))?;
        let logits = (spans.matmul(&tiled)? * self.temperature)?;
        candle_nn::ops::sigmoid(&logits).map_err(|e| Error::Parse(format!("sigmoid: {}", e)))
    }
}
/// L2-normalizes `tensor` along `dim`, guarding against division by zero.
///
/// The norm is clamped below at 1e-12 before dividing, so all-zero vectors
/// come back as zeros instead of NaNs.
#[cfg(feature = "candle")]
pub(crate) fn l2_normalize(tensor: &Tensor, dim: D) -> Result<Tensor> {
    // FIX: use sum_keepdim so the reduced axis stays in place (size 1) and
    // the division broadcasts correctly for ANY `dim`. The previous
    // `sum(dim)` + `unsqueeze(D::Minus1)` re-inserted the axis at the LAST
    // position regardless of which axis was reduced, producing a wrong
    // shape/result for any `dim` other than the last (e.g. D::Minus2).
    // Behavior is unchanged for the existing D::Minus1 call sites.
    let norm = tensor.sqr()?.sum_keepdim(dim)?.sqrt()?;
    let norm_clamped = norm
        .clamp(1e-12, f32::MAX)
        .map_err(|e| Error::Parse(format!("clamp: {}", e)))?;
    tensor
        .broadcast_div(&norm_clamped)
        .map_err(|e| Error::Parse(format!("l2_normalize: {}", e)))
}