use flodl::{cross_entropy_loss, Graph, HasGraph, Module, Result, TensorError, Variable};
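/// A single extracted question-answering span.
///
/// `start` and `end` are inclusive token indices into the encoded sequence,
/// `text` is the decoded span, and `score` is the sum of the start and end
/// probabilities used to select it.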
#[derive(Debug, Clone)]
pub struct Answer {
pub text: String,
pub start: usize,
pub end: usize,
pub score: f32,
}
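/// Per-token output of [`TaggingHead::tag`]: the decoded token, the
/// highest-probability label, its probability, and whether the token is
/// covered by the attention mask (`attends == false` for padding).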
#[cfg(feature = "tokenizer")]
#[derive(Debug, Clone)]
pub struct TokenPrediction {
pub token: String,
pub label: String,
pub score: f32,
pub attends: bool,
}
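/// Fallback label names `LABEL_0..LABEL_{n-1}`, used when no id2label map is supplied.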
pub(crate) fn default_labels(n: i64) -> Vec<String> {
(0..n).map(|i| format!("LABEL_{i}")).collect()
}
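/// Validates that a classification head is built with at least one label.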
pub(crate) fn check_num_labels(n: i64) -> Result<i64> {
if n <= 0 {
return Err(TensorError::new(&format!(
"num_labels must be > 0, got {n}",
)));
}
Ok(n)
}
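/// Unwraps an optional tokenizer, or returns an error naming the calling
/// `method` and how to attach one.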
#[cfg(feature = "tokenizer")]
pub(crate) fn require_tokenizer<'a>(
tokenizer: Option<&'a crate::tokenizer::HfTokenizer>,
method: &str,
) -> Result<&'a crate::tokenizer::HfTokenizer> {
tokenizer.ok_or_else(|| {
TensorError::new(&format!(
"{method} requires a tokenizer; \
use from_pretrained or .with_tokenizer(...) first",
))
})
}
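/// Index and value of the largest element in `slice` (ties resolve to the
/// last maximum, per `Iterator::max_by`). Panics on an empty slice.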
pub(crate) fn argmax_f32(slice: &[f32]) -> (usize, f32) {
let (idx, &val) = slice
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.expect("argmax_f32 called on empty slice");
(idx, val)
}
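/// Cross-entropy loss for sequence classification.
///
/// `logits` must be `[batch, num_labels]`; `labels` holds one class index
/// per example.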
pub fn sequence_classification_loss(
logits: &Variable,
labels: &Variable,
) -> Result<Variable> {
let shape = logits.shape();
if shape.len() != 2 {
return Err(TensorError::new(&format!(
"sequence_classification_loss: logits must be [batch, num_labels], got {shape:?}",
)));
}
cross_entropy_loss(logits, labels)
}
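/// Cross-entropy loss for token classification.
///
/// Flattens `[batch, seq_len, num_labels]` logits to
/// `[batch * seq_len, num_labels]` and `labels` to `[batch * seq_len]` before
/// delegating to [`cross_entropy_loss`]; label positions set to `-100` are
/// ignored (see `token_classification_loss_flattens_and_ignores_minus_100`).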
pub fn token_classification_loss(
logits: &Variable,
labels: &Variable,
) -> Result<Variable> {
let shape = logits.shape();
if shape.len() != 3 {
return Err(TensorError::new(&format!(
"token_classification_loss: logits must be [batch, seq_len, num_labels], got {shape:?}",
)));
}
let num_labels = shape[2];
let flat_logits = logits.reshape(&[-1, num_labels])?;
let flat_labels = labels.reshape(&[-1])?;
cross_entropy_loss(&flat_logits, &flat_labels)
}
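/// Cross-entropy loss for masked language modelling.
///
/// Same flattening as [`token_classification_loss`], but over
/// `[batch, seq_len, vocab_size]` logits; `-100` labels mark positions that
/// should not contribute to the loss and are ignored.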
pub fn masked_lm_loss(
logits: &Variable,
labels: &Variable,
) -> Result<Variable> {
let shape = logits.shape();
if shape.len() != 3 {
return Err(TensorError::new(&format!(
"masked_lm_loss: logits must be [batch, seq_len, vocab_size], got {shape:?}",
)));
}
let vocab_size = shape[2];
let flat_logits = logits.reshape(&[-1, vocab_size])?;
let flat_labels = labels.reshape(&[-1])?;
cross_entropy_loss(&flat_logits, &flat_labels)
}
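/// Question-answering loss: the mean of the start- and end-position
/// cross-entropy losses.
///
/// `logits` must be `[batch, seq_len, 2]`; the last dimension is split into
/// start and end logits.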
pub fn question_answering_loss(
logits: &Variable,
start_positions: &Variable,
end_positions: &Variable,
) -> Result<Variable> {
let shape = logits.shape();
if shape.len() != 3 || shape[2] != 2 {
return Err(TensorError::new(&format!(
"question_answering_loss: logits must be [batch, seq_len, 2], got {shape:?}",
)));
}
let start_logits = logits.narrow(-1, 0, 1)?.squeeze(-1)?;
let end_logits = logits.narrow(-1, 1, 1)?.squeeze(-1)?;
let start_loss = cross_entropy_loss(&start_logits, start_positions)?;
let end_loss = cross_entropy_loss(&end_logits, end_positions)?;
start_loss.add(&end_loss)?.mul_scalar(0.5)
}
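/// Softmaxes `[batch, num_labels]` logits and returns, for each example,
/// `(label, probability)` pairs sorted by descending probability.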
pub(crate) fn logits_to_sorted_labels(
logits: &Variable,
id2label: &[String],
) -> Result<Vec<Vec<(String, f32)>>> {
let probs = logits.softmax(-1)?;
let shape = probs.shape();
assert_eq!(shape.len(), 2, "expected [batch, num_labels], got {shape:?}");
let batch = shape[0] as usize;
let n = shape[1] as usize;
assert_eq!(
n,
id2label.len(),
"classifier output width {n} != id2label count {}",
id2label.len(),
);
let flat = probs.data().to_f32_vec()?;
let mut out = Vec::with_capacity(batch);
for b in 0..batch {
let mut row: Vec<(String, f32)> = (0..n)
.map(|k| (id2label[k].clone(), flat[b * n + k]))
.collect();
row.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
out.push(row);
}
Ok(out)
}
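/// Picks the best answer span per batch element from QA logits.
///
/// Start/end logits are softmaxed over the sequence dimension, candidates are
/// restricted to context tokens (`sequence_id == 1`), and the span `(i, j)`
/// with `j >= i` maximising `start[i] + end[j]` is decoded back to text with
/// the tokenizer.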
#[cfg(feature = "tokenizer")]
pub(crate) fn extract_best_span(
logits: &Variable,
enc: &crate::tokenizer::EncodedBatch,
tokenizer: &crate::tokenizer::HfTokenizer,
) -> Result<Vec<Answer>> {
let shape = logits.shape();
assert_eq!(shape.len(), 3, "expected [B, S, 2], got {shape:?}");
let batch = shape[0] as usize;
let seq = shape[1] as usize;
assert_eq!(shape[2], 2, "QA head must be 2-wide, got {}", shape[2]);
let starts = logits.narrow(-1, 0, 1)?.softmax(1)?;
let ends = logits.narrow(-1, 1, 1)?.softmax(1)?;
let starts_flat = starts.data().to_f32_vec()?;
let ends_flat = ends.data().to_f32_vec()?;
let sequence_ids: Vec<i64> = enc.sequence_ids.data().to_i64_vec()?;
let input_ids: Vec<i64> = enc.input_ids.data().to_i64_vec()?;
let mut answers = Vec::with_capacity(batch);
for b in 0..batch {
let offset = b * seq;
let valid: Vec<usize> = (0..seq)
.filter(|&s| sequence_ids[offset + s] == 1)
.collect();
if valid.is_empty() {
return Err(TensorError::new(
"QA extract: no context tokens (sequence_id == 1) found; \
tokenizer did not produce a pair encoding",
));
}
let mut best = (valid[0], valid[0], f32::NEG_INFINITY);
for &i in &valid {
let sp = starts_flat[offset + i];
for &j in valid.iter().filter(|&&j| j >= i) {
let ep = ends_flat[offset + j];
let score = sp + ep;
if score > best.2 {
best = (i, j, score);
}
}
}
let (start, end, score) = best;
let span_ids: Vec<u32> = input_ids[offset + start..=offset + end]
.iter()
.map(|&x| x as u32)
.collect();
let text = tokenizer
.inner()
.decode(&span_ids, true)
.map_err(|e| TensorError::new(&format!("qa decode: {e}")))?;
answers.push(Answer { text, start, end, score });
}
Ok(answers)
}
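/// Implemented by model configs so the task heads can stay model-agnostic:
/// it names the model family and its mask token, and maps an
/// `EncodedBatch` onto the graph's forward inputs.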
#[cfg(feature = "tokenizer")]
pub trait EncoderInputs {
const FAMILY_NAME: &'static str;
const MASK_TOKEN: &'static str;
fn encoder_inputs(enc: &crate::tokenizer::EncodedBatch) -> Result<Vec<Variable>>;
}
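/// Sequence-classification head: a finished [`Graph`], the model config, the
/// `id2label` map, and (with the `tokenizer` feature) an optional tokenizer.
///
/// A minimal usage sketch, assuming a head that already has a tokenizer
/// attached (e.g. via `with_tokenizer`):
///
/// ```ignore
/// let scores = head.predict(&["great movie!"])?;
/// // scores[0] lists (label, probability) pairs, best first
/// println!("{:?}", scores[0][0]);
/// ```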
pub struct ClassificationHead<C: Clone> {
pub(crate) graph: Graph,
pub(crate) config: C,
pub(crate) id2label: Vec<String>,
#[cfg(feature = "tokenizer")]
pub(crate) tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl<C: Clone> ClassificationHead<C> {
pub(crate) fn from_graph(
graph: Graph,
config: &C,
num_labels: i64,
id2label: Option<Vec<String>>,
) -> Self {
let id2label = id2label.unwrap_or_else(|| default_labels(num_labels));
Self {
graph,
config: config.clone(),
id2label,
#[cfg(feature = "tokenizer")]
tokenizer: None,
}
}
pub fn graph(&self) -> &Graph { &self.graph }
pub fn into_graph(self) -> Graph { self.graph }
pub fn config(&self) -> &C { &self.config }
pub fn labels(&self) -> &[String] { &self.id2label }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
}
#[cfg(feature = "tokenizer")]
impl<C: Clone + EncoderInputs> ClassificationHead<C> {
pub fn forward_encoded(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Variable> {
let inputs = C::encoder_inputs(enc)?;
self.graph.forward_multi(&inputs)
}
pub fn classify(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Vec<(String, f32)>>> {
self.graph.eval();
let logits = self.forward_encoded(enc)?;
logits_to_sorted_labels(&logits, &self.id2label)
}
pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<(String, f32)>>> {
let name = format!("{}ForSequenceClassification::predict", C::FAMILY_NAME);
let tok = require_tokenizer(self.tokenizer.as_ref(), &name)?;
let enc = tok.encode(texts)?;
self.classify(&enc)
}
pub fn compute_loss(
&self,
enc: &crate::tokenizer::EncodedBatch,
labels: &Variable,
) -> Result<Variable> {
let logits = self.forward_encoded(enc)?;
sequence_classification_loss(&logits, labels)
}
}
impl<C: Clone> HasGraph for ClassificationHead<C> {
fn graph(&self) -> &Graph { &self.graph }
}
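/// Token-classification (tagging) head over a backbone [`Graph`].
///
/// A minimal usage sketch, assuming a tokenizer is attached:
///
/// ```ignore
/// for pred in &tagger.predict(&["Ada Lovelace lived in London"])?[0] {
///     if pred.attends {
///         println!("{}\t{}\t{:.3}", pred.token, pred.label, pred.score);
///     }
/// }
/// ```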
pub struct TaggingHead<C: Clone> {
pub(crate) graph: Graph,
pub(crate) config: C,
pub(crate) id2label: Vec<String>,
#[cfg(feature = "tokenizer")]
pub(crate) tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl<C: Clone> TaggingHead<C> {
pub(crate) fn from_graph(
graph: Graph,
config: &C,
num_labels: i64,
id2label: Option<Vec<String>>,
) -> Self {
let id2label = id2label.unwrap_or_else(|| default_labels(num_labels));
Self {
graph,
config: config.clone(),
id2label,
#[cfg(feature = "tokenizer")]
tokenizer: None,
}
}
pub fn graph(&self) -> &Graph { &self.graph }
pub fn into_graph(self) -> Graph { self.graph }
pub fn config(&self) -> &C { &self.config }
pub fn labels(&self) -> &[String] { &self.id2label }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
}
#[cfg(feature = "tokenizer")]
impl<C: Clone + EncoderInputs> TaggingHead<C> {
pub fn forward_encoded(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Variable> {
let inputs = C::encoder_inputs(enc)?;
self.graph.forward_multi(&inputs)
}
pub fn tag(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Vec<TokenPrediction>>> {
let name = format!("{}ForTokenClassification::tag", C::FAMILY_NAME);
let tok = require_tokenizer(self.tokenizer.as_ref(), &name)?;
self.graph.eval();
let logits = self.forward_encoded(enc)?;
let probs = logits.softmax(-1)?;
let shape = probs.shape();
assert_eq!(shape.len(), 3, "expected [B, S, num_labels], got {shape:?}");
let batch = shape[0] as usize;
let seq = shape[1] as usize;
let n = shape[2] as usize;
let flat = probs.data().to_f32_vec()?;
let input_ids: Vec<i64> = enc.input_ids.data().to_i64_vec()?;
let attn_ids: Vec<i64> = enc.attention_mask.data().to_i64_vec()?;
let mut out = Vec::with_capacity(batch);
for b in 0..batch {
let mut row = Vec::with_capacity(seq);
for s in 0..seq {
let base = (b * seq + s) * n;
let (best_k, best_p) = argmax_f32(&flat[base..base + n]);
let id = input_ids[b * seq + s] as u32;
let token = tok
.inner()
.id_to_token(id)
.unwrap_or_else(|| format!("<unk_id={id}>"));
row.push(TokenPrediction {
token,
label: self.id2label[best_k].clone(),
score: best_p,
attends: attn_ids[b * seq + s] != 0,
});
}
out.push(row);
}
Ok(out)
}
pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<TokenPrediction>>> {
let name = format!("{}ForTokenClassification::predict", C::FAMILY_NAME);
let tok = require_tokenizer(self.tokenizer.as_ref(), &name)?;
let enc = tok.encode(texts)?;
self.tag(&enc)
}
pub fn compute_loss(
&self,
enc: &crate::tokenizer::EncodedBatch,
labels: &Variable,
) -> Result<Variable> {
let logits = self.forward_encoded(enc)?;
token_classification_loss(&logits, labels)
}
}
impl<C: Clone> HasGraph for TaggingHead<C> {
fn graph(&self) -> &Graph { &self.graph }
}
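/// Extractive question-answering head over a backbone [`Graph`].
///
/// A minimal usage sketch, assuming a tokenizer is attached:
///
/// ```ignore
/// let ans = qa.answer("Who wrote it?", "The essay was written by Ada.")?;
/// println!("{} ({:.2})", ans.text, ans.score);
/// ```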
pub struct QaHead<C: Clone> {
pub(crate) graph: Graph,
pub(crate) config: C,
#[cfg(feature = "tokenizer")]
pub(crate) tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl<C: Clone> QaHead<C> {
pub(crate) fn from_graph(graph: Graph, config: &C) -> Self {
Self {
graph,
config: config.clone(),
#[cfg(feature = "tokenizer")]
tokenizer: None,
}
}
pub fn graph(&self) -> &Graph { &self.graph }
pub fn into_graph(self) -> Graph { self.graph }
pub fn config(&self) -> &C { &self.config }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
}
#[cfg(feature = "tokenizer")]
impl<C: Clone + EncoderInputs> QaHead<C> {
pub fn forward_encoded(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Variable> {
let inputs = C::encoder_inputs(enc)?;
self.graph.forward_multi(&inputs)
}
pub fn answer(&self, question: &str, context: &str) -> Result<Answer> {
let mut out = self.answer_batch(&[(question, context)])?;
Ok(out.pop().expect("answer_batch returns one per input"))
}
pub fn answer_batch(&self, pairs: &[(&str, &str)]) -> Result<Vec<Answer>> {
let name = format!("{}ForQuestionAnswering::answer", C::FAMILY_NAME);
let tok = require_tokenizer(self.tokenizer.as_ref(), &name)?;
let enc = tok.encode_pairs(pairs)?;
self.extract(&enc)
}
pub fn extract(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Answer>> {
let name = format!("{}ForQuestionAnswering::extract", C::FAMILY_NAME);
let tok = require_tokenizer(self.tokenizer.as_ref(), &name)?;
self.graph.eval();
let logits = self.forward_encoded(enc)?;
extract_best_span(&logits, enc, tok)
}
pub fn compute_loss(
&self,
enc: &crate::tokenizer::EncodedBatch,
start_positions: &Variable,
end_positions: &Variable,
) -> Result<Variable> {
let logits = self.forward_encoded(enc)?;
question_answering_loss(&logits, start_positions, end_positions)
}
}
impl<C: Clone> HasGraph for QaHead<C> {
fn graph(&self) -> &Graph { &self.graph }
}
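/// Masked-language-modelling head over a backbone [`Graph`], used for MLM
/// training losses and mask-filling inference.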
pub struct MaskedLmHead<C: Clone> {
pub(crate) graph: Graph,
pub(crate) config: C,
#[cfg(feature = "tokenizer")]
pub(crate) tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl<C: Clone> MaskedLmHead<C> {
pub(crate) fn from_graph(graph: Graph, config: &C) -> Self {
Self {
graph,
config: config.clone(),
#[cfg(feature = "tokenizer")]
tokenizer: None,
}
}
pub fn graph(&self) -> &Graph { &self.graph }
pub fn into_graph(self) -> Graph { self.graph }
pub fn config(&self) -> &C { &self.config }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
}
#[cfg(feature = "tokenizer")]
impl<C: Clone + EncoderInputs> MaskedLmHead<C> {
pub fn forward_encoded(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Variable> {
let inputs = C::encoder_inputs(enc)?;
self.graph.forward_multi(&inputs)
}
pub fn compute_loss(
&self,
enc: &crate::tokenizer::EncodedBatch,
labels: &Variable,
) -> Result<Variable> {
let logits = self.forward_encoded(enc)?;
masked_lm_loss(&logits, labels)
}
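    /// For every occurrence of `C::MASK_TOKEN` in `text`, returns the `top_k`
    /// `(token, probability)` candidates.
    ///
    /// Errors if `top_k == 0`, if the tokenizer does not know the mask token,
    /// or if the input contains no mask token.
    ///
    /// A minimal usage sketch (the mask string depends on the model family,
    /// e.g. `[MASK]` for BERT-style tokenizers):
    ///
    /// ```ignore
    /// let picks = mlm.fill_mask("Paris is the [MASK] of France.", 5)?;
    /// for (token, p) in &picks[0] {
    ///     println!("{token}: {p:.3}");
    /// }
    /// ```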
pub fn fill_mask(
&self,
text: &str,
top_k: usize,
) -> Result<Vec<Vec<(String, f32)>>> {
if top_k == 0 {
return Err(TensorError::new("fill_mask: top_k must be > 0"));
}
let name = format!("{}ForMaskedLM::fill_mask", C::FAMILY_NAME);
let tok = require_tokenizer(self.tokenizer.as_ref(), &name)?;
let mask_tok = C::MASK_TOKEN;
let mask_id = tok.inner().token_to_id(mask_tok).ok_or_else(|| {
TensorError::new(&format!(
"fill_mask: tokenizer has no {mask_tok} token",
))
})? as i64;
self.graph.eval();
let enc = tok.encode(&[text])?;
let logits = self.forward_encoded(&enc)?;
let probs = logits.data().softmax(-1)?;
let ids_row = enc.input_ids.data().select(0, 0)?.to_i64_vec()?;
let mut out = Vec::new();
for (pos, id) in ids_row.iter().enumerate() {
if *id != mask_id {
continue;
}
let row = probs.select(0, 0)?.select(0, pos as i64)?;
let (vals, idxs) = row.topk(top_k as i64, 0, true, true)?;
let score_vec = vals.to_f32_vec()?;
let id_vec = idxs.to_i64_vec()?;
let picks: Vec<(String, f32)> = id_vec
.iter()
.zip(score_vec.iter())
.map(|(i, s)| {
let tok_str = tok
.inner()
.id_to_token(*i as u32)
.unwrap_or_else(|| format!("[UNK_{i}]"));
(tok_str, *s)
})
.collect();
out.push(picks);
}
if out.is_empty() {
return Err(TensorError::new(&format!(
"fill_mask: input contains no {mask_tok} token",
)));
}
Ok(out)
}
}
impl<C: Clone> HasGraph for MaskedLmHead<C> {
fn graph(&self) -> &Graph { &self.graph }
}
#[cfg(test)]
mod tests {
use super::*;
    use flodl::{Device, Tensor};
fn cpu() -> Device { Device::CPU }
#[test]
fn default_labels_generates_label_k_fallback() {
assert_eq!(default_labels(3), vec!["LABEL_0", "LABEL_1", "LABEL_2"]);
assert!(default_labels(0).is_empty());
}
#[test]
fn check_num_labels_rejects_nonpositive() {
assert_eq!(check_num_labels(3).unwrap(), 3);
assert!(check_num_labels(0).is_err());
assert!(check_num_labels(-1).is_err());
}
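    // Added sanity check for the shared argmax helper; pure Rust, no tensor
    // API involved.
    #[test]
    fn argmax_f32_returns_index_and_value_of_max() {
        assert_eq!(argmax_f32(&[0.1, 0.7, 0.2]), (1, 0.7));
        assert_eq!(argmax_f32(&[3.0]), (0, 3.0));
    }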
fn logits_2d(data: &[f32], rows: i64, cols: i64) -> Variable {
Variable::new(
Tensor::from_f32(data, &[rows, cols], cpu()).unwrap(),
true,
)
}
fn labels_1d(data: &[i64], n: i64) -> Variable {
Variable::new(
Tensor::from_i64(data, &[n], cpu()).unwrap(),
false,
)
}
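    // Added coverage for the label-sorting helper; reuses only the tensor
    // helpers above and the softmax/data accessors already exercised in the
    // other tests.
    #[test]
    fn logits_to_sorted_labels_sorts_by_descending_probability() {
        let logits = logits_2d(&[0.0, 2.0, 1.0], 1, 3);
        let labels = vec!["neg".to_string(), "pos".to_string(), "neu".to_string()];
        let rows = logits_to_sorted_labels(&logits, &labels).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].0, "pos");
        assert!(rows[0][0].1 >= rows[0][1].1);
        assert!(rows[0][1].1 >= rows[0][2].1);
    }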
#[test]
fn sequence_classification_loss_rejects_wrong_rank() {
let logits = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4], cpu()).unwrap(),
true,
);
let labels = labels_1d(&[0, 1, 0, 1], 4);
let err = sequence_classification_loss(&logits, &labels).unwrap_err();
assert!(err.to_string().contains("must be [batch, num_labels]"));
}
#[test]
fn sequence_classification_loss_backward_flows() {
let logits = logits_2d(&[5.0, 0.1, 0.1, 0.1, 5.0, 0.1], 2, 3);
let labels = labels_1d(&[0, 1], 2);
let loss = sequence_classification_loss(&logits, &labels).unwrap();
loss.backward().unwrap();
assert!(logits.grad().is_some(), "logits must receive grad");
let loss_val = loss.data().to_f32_vec().unwrap()[0];
assert!(loss_val < 0.1, "expected small loss, got {loss_val}");
}
#[test]
fn token_classification_loss_flattens_and_ignores_minus_100() {
        let logits_data = [
            5.0, 0.0, 0.0, 0.0, 0.0, 5.0,
            0.0, 5.0, 5.0, 0.0, 0.0, 0.0,
        ];
let logits = Variable::new(
Tensor::from_f32(&logits_data, &[2, 3, 2], cpu()).unwrap(),
true,
);
let labels = Variable::new(
Tensor::from_i64(&[0, -100, 1, 1, 0, -100], &[2, 3], cpu()).unwrap(),
false,
);
let loss = token_classification_loss(&logits, &labels).unwrap();
loss.backward().unwrap();
assert!(logits.grad().is_some(), "logits must receive grad");
let loss_val = loss.data().to_f32_vec().unwrap()[0];
assert!(loss_val < 0.1, "expected small loss (all correct), got {loss_val}");
}
#[test]
fn token_classification_loss_rejects_wrong_rank() {
let logits = logits_2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
let labels = labels_1d(&[0, 1], 2);
let err = token_classification_loss(&logits, &labels).unwrap_err();
assert!(err.to_string().contains("[batch, seq_len, num_labels]"));
}
#[test]
fn question_answering_loss_averages_two_heads() {
let raw: Vec<f32> = {
let mut v = vec![0.0_f32; 2 * 4 * 2];
let ix = |b: usize, s: usize, k: usize| (b * 4 + s) * 2 + k;
v[ix(0, 1, 0)] = 10.0;
v[ix(0, 2, 1)] = 10.0;
v[ix(1, 0, 0)] = 10.0;
v[ix(1, 3, 1)] = 10.0;
v
};
let logits = Variable::new(
Tensor::from_f32(&raw, &[2, 4, 2], cpu()).unwrap(),
true,
);
let starts = labels_1d(&[1, 0], 2);
let ends = labels_1d(&[2, 3], 2);
let loss = question_answering_loss(&logits, &starts, &ends).unwrap();
loss.backward().unwrap();
assert!(logits.grad().is_some(), "logits must receive grad");
let loss_val = loss.data().to_f32_vec().unwrap()[0];
assert!(loss_val < 0.01, "expected tiny loss at peaked logits, got {loss_val}");
}
#[test]
fn masked_lm_loss_flattens_and_ignores_minus_100() {
let logits_data = [
0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0,
];
let logits = Variable::new(
Tensor::from_f32(&logits_data, &[2, 3, 4], cpu()).unwrap(),
true,
);
let labels = Variable::new(
Tensor::from_i64(&[2, -100, 0, -100, 1, 3], &[2, 3], cpu()).unwrap(),
false,
);
let loss = masked_lm_loss(&logits, &labels).unwrap();
loss.backward().unwrap();
assert!(logits.grad().is_some(), "logits must receive grad");
let loss_val = loss.data().to_f32_vec().unwrap()[0];
assert!(loss_val < 0.1, "expected small loss (all targets peaked), got {loss_val}");
}
#[test]
fn masked_lm_loss_rejects_wrong_rank() {
let logits = logits_2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
let labels = labels_1d(&[0, 1], 2);
let err = masked_lm_loss(&logits, &labels).unwrap_err();
assert!(err.to_string().contains("[batch, seq_len, vocab_size]"));
}
#[test]
fn question_answering_loss_rejects_wrong_last_dim() {
let logits = Variable::new(
Tensor::from_f32(&[0.0_f32; 12], &[2, 3, 2], cpu()).unwrap(),
true,
);
let starts = labels_1d(&[0, 1], 2);
let ends = labels_1d(&[2, 2], 2);
assert!(question_answering_loss(&logits, &starts, &ends).is_ok());
let bad = Variable::new(
Tensor::from_f32(&[0.0_f32; 18], &[2, 3, 3], cpu()).unwrap(),
true,
);
let err = question_answering_loss(&bad, &starts, &ends).unwrap_err();
assert!(err.to_string().contains("[batch, seq_len, 2]"));
}
}