#[cfg(feature = "tokenizer")]
use flodl::Variable;
use flodl::{Result, TensorError};
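
/// An extracted question-answering span for one batch element.
///
/// `start` and `end` are inclusive token indices into the encoded input;
/// `score` is the sum of the span's start and end probabilities.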
#[derive(Debug, Clone)]
pub struct Answer {
    pub text: String,
    pub start: usize,
    pub end: usize,
    pub score: f32,
}
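
/// A single token-level prediction (e.g. for NER-style tagging).
///
/// `attends` appears to mirror the attention mask for the token (false for
/// padding), letting callers filter non-attended positions.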
#[cfg(feature = "tokenizer")]
#[derive(Debug, Clone)]
pub struct TokenPrediction {
    pub token: String,
    pub label: String,
    pub score: f32,
    pub attends: bool,
}
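
/// Fallback label names (`LABEL_0`, `LABEL_1`, ...) used when a model config
/// provides no `id2label` mapping.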
pub(crate) fn default_labels(n: i64) -> Vec<String> {
    (0..n).map(|i| format!("LABEL_{i}")).collect()
}
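
/// Validates that a configured label count is positive, passing it through
/// on success.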
pub(crate) fn check_num_labels(n: i64) -> Result<i64> {
    if n <= 0 {
        return Err(TensorError::new(&format!("num_labels must be > 0, got {n}")));
    }
    Ok(n)
}
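
/// Softmaxes `[batch, num_labels]` logits and returns, per batch element,
/// every `(label, probability)` pair sorted by descending probability.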
pub(crate) fn logits_to_sorted_labels(
    logits: &Variable,
    id2label: &[String],
) -> Result<Vec<Vec<(String, f32)>>> {
    let probs = logits.softmax(-1)?;
    let shape = probs.shape();
    assert_eq!(shape.len(), 2, "expected [batch, num_labels], got {shape:?}");
    let batch = shape[0] as usize;
    let n = shape[1] as usize;
    assert_eq!(
        n,
        id2label.len(),
        "classifier output width {n} != id2label count {}",
        id2label.len()
    );
    let flat = probs.data().to_f32_vec()?;
    let mut out = Vec::with_capacity(batch);
    for b in 0..batch {
        // Pair every label with its probability for this batch row.
        let mut row: Vec<(String, f32)> = (0..n)
            .map(|k| (id2label[k].clone(), flat[b * n + k]))
            .collect();
        // Sort descending by probability; `total_cmp` yields a total order
        // even in the presence of NaN.
        row.sort_by(|a, b| b.1.total_cmp(&a.1));
        out.push(row);
    }
    Ok(out)
}
}
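
/// Extracts the best-scoring answer span per batch element for extractive QA.
///
/// Each head is softmaxed over the sequence dimension; candidate spans are
/// limited to context tokens (`sequence_id == 1`) with `start <= end`, and
/// the winner is decoded back to text.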
#[cfg(feature = "tokenizer")]
pub(crate) fn extract_best_span(
    logits: &Variable,
    enc: &crate::tokenizer::EncodedBatch,
    tokenizer: &crate::tokenizer::HfTokenizer,
) -> Result<Vec<Answer>> {
    let shape = logits.shape();
    assert_eq!(shape.len(), 3, "expected [B, S, 2], got {shape:?}");
    let batch = shape[0] as usize;
    let seq = shape[1] as usize;
    assert_eq!(shape[2], 2, "QA head must be 2-wide, got {}", shape[2]);
    // Split the two-wide head into start and end logits, then softmax over
    // the sequence dimension so each position gets a start/end probability.
    let starts = logits.narrow(-1, 0, 1)?.softmax(1)?;
    let ends = logits.narrow(-1, 1, 1)?.softmax(1)?;
    let starts_flat = starts.data().to_f32_vec()?;
    let ends_flat = ends.data().to_f32_vec()?;
    let sequence_ids: Vec<i64> = enc.sequence_ids.data().to_i64_vec()?;
    let input_ids: Vec<i64> = enc.input_ids.data().to_i64_vec()?;
    let mut answers = Vec::with_capacity(batch);
    for b in 0..batch {
        let offset = b * seq;
        // Only context tokens (the second sequence of the question/context
        // pair) are answer candidates.
        let valid: Vec<usize> = (0..seq)
            .filter(|&s| sequence_ids[offset + s] == 1)
            .collect();
        if valid.is_empty() {
            return Err(TensorError::new(
                "QA extract: no context tokens (sequence_id == 1) found; \
                 tokenizer did not produce a pair encoding",
            ));
        }
        // Exhaustive O(S^2) search over candidate (start, end) pairs with
        // start <= end, scored by the sum of start and end probabilities.
        let mut best = (valid[0], valid[0], f32::NEG_INFINITY);
        for &i in &valid {
            let sp = starts_flat[offset + i];
            for &j in valid.iter().filter(|&&j| j >= i) {
                let ep = ends_flat[offset + j];
                let score = sp + ep;
                if score > best.2 {
                    best = (i, j, score);
                }
            }
        }
        let (start, end, score) = best;
        // Decode the winning token span back to text, skipping special tokens.
        let span_ids: Vec<u32> = input_ids[offset + start..=offset + end]
            .iter()
            .map(|&x| x as u32)
            .collect();
        let text = tokenizer
            .inner()
            .decode(&span_ids, true)
            .map_err(|e| TensorError::new(&format!("qa decode: {e}")))?;
        answers.push(Answer { text, start, end, score });
    }
    Ok(answers)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_labels_generates_label_k_fallback() {
        assert_eq!(default_labels(3), vec!["LABEL_0", "LABEL_1", "LABEL_2"]);
        assert!(default_labels(0).is_empty());
    }

    #[test]
    fn check_num_labels_rejects_nonpositive() {
        assert_eq!(check_num_labels(3).unwrap(), 3);
        assert!(check_num_labels(0).is_err());
        assert!(check_num_labels(-1).is_err());
    }
}