use crate::core::traits::{Model, Tokenizer};
use crate::error::{Result, TrustformersError};
use crate::pipeline::{BasePipeline, Pipeline, PipelineOutput, QuestionAnsweringOutput};
use crate::{AutoModel, AutoTokenizer};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QAConfig {
pub max_length: usize,
pub max_answer_length: usize,
pub handle_impossible_answer: bool,
pub doc_stride: usize,
}
impl Default for QAConfig {
fn default() -> Self {
Self {
max_length: 384,
max_answer_length: 15,
handle_impossible_answer: false,
doc_stride: 128,
}
}
}
#[derive(Debug, Clone)]
pub struct QAInput {
pub question: String,
pub context: String,
}
#[derive(Clone)]
pub struct QuestionAnsweringPipeline {
base: BasePipeline<AutoModel, AutoTokenizer>,
max_answer_length: usize,
handle_impossible_answer: bool,
}
impl QuestionAnsweringPipeline {
pub fn new(model: AutoModel, tokenizer: AutoTokenizer) -> Result<Self> {
Ok(Self {
base: BasePipeline::new(model, tokenizer),
max_answer_length: 15,
handle_impossible_answer: false,
})
}
pub fn with_max_answer_length(mut self, length: usize) -> Self {
self.max_answer_length = length;
self
}
pub fn with_handle_impossible_answer(mut self, handle: bool) -> Self {
self.handle_impossible_answer = handle;
self
}
fn answer_question(&self, question: &str, context: &str) -> Result<QuestionAnsweringOutput> {
match &self.base.model.model_type {
#[cfg(feature = "bert")]
crate::automodel::AutoModelType::BertForSequenceClassification(model) => {
self.extract_answer_with_bert(model, question, context)
},
#[cfg(feature = "bert")]
crate::automodel::AutoModelType::Bert(_model) => {
let answer_result = self.extract_answer_simple(question, context);
Ok(answer_result)
},
_ => Err(TrustformersError::runtime_error(
"Model does not support question answering",
)),
}
}
fn answer_questions_batch(&self, inputs: &[QAInput]) -> Result<Vec<QuestionAnsweringOutput>> {
inputs
.iter()
.map(|input| self.answer_question(&input.question, &input.context))
.collect()
}
fn extract_answer_simple(&self, question: &str, context: &str) -> QuestionAnsweringOutput {
let question_lower = question.to_lowercase();
let context_words: Vec<&str> = context.split_whitespace().collect();
let mut best_start = 0;
let mut best_score = 0.5;
let mut answer_length = 3;
if question_lower.contains("what") || question_lower.contains("who") {
for (i, window) in context_words.windows(3).enumerate() {
let window_text = window.join(" ").to_lowercase();
if window_text.contains("is")
|| window_text.contains("was")
|| window_text.contains("are")
{
best_start = i + 1; best_score = 0.8;
answer_length = 2;
break;
}
}
} else if question_lower.contains("when") {
for (i, word) in context_words.iter().enumerate() {
if word.chars().all(|c| c.is_ascii_digit()) && word.len() == 4 {
best_start = i;
best_score = 0.9;
answer_length = 1;
break;
} else if word.to_lowercase().contains("day")
|| word.to_lowercase().contains("month")
{
best_start = i.saturating_sub(1);
best_score = 0.7;
answer_length = 2;
break;
}
}
} else if question_lower.contains("where") {
for (i, window) in context_words.windows(2).enumerate() {
let window_text = window.join(" ").to_lowercase();
if window_text.contains("in ")
|| window_text.contains("at ")
|| window_text.contains("on ")
{
best_start = i + 1; best_score = 0.75;
answer_length = 2;
break;
}
}
}
let end_idx = (best_start + answer_length).min(context_words.len());
let answer = if best_start < context_words.len() {
context_words[best_start..end_idx].join(" ")
} else {
"Unable to find answer".to_string()
};
let char_start = context_words[..best_start]
.iter()
.map(|w| w.len() + 1)
.sum::<usize>()
.saturating_sub(1);
let char_end = char_start + answer.len();
QuestionAnsweringOutput {
answer,
score: best_score,
start: char_start,
end: char_end,
}
}
#[cfg(feature = "bert")]
fn extract_answer_with_bert(
&self,
model: &crate::models::bert::BertForSequenceClassification,
question: &str,
context: &str,
) -> Result<QuestionAnsweringOutput> {
let input_text = format!("[CLS] {} [SEP] {} [SEP]", question, context);
let tokenized = self.base.tokenizer.encode(&input_text)?;
let sep_token_id = self.base.tokenizer.token_to_id("[SEP]").unwrap_or(102);
let sep_positions: Vec<usize> = tokenized
.input_ids
.iter()
.enumerate()
.filter(|(_, &id)| id == sep_token_id)
.map(|(pos, _)| pos)
.collect();
if sep_positions.len() < 2 {
return Err(TrustformersError::invalid_input_simple(
"Could not find proper separator tokens in input".to_string(),
));
}
let context_start = sep_positions[0] + 1;
let context_end = sep_positions[1];
let output = model.forward(tokenized.clone())?;
let logits = &output.logits;
let logits_data = logits.data()?;
let class_probs = self.softmax(&logits_data);
let confidence = class_probs.iter().fold(0.0f32, |a, &b| a.max(b));
if confidence < 0.3 {
return Ok(QuestionAnsweringOutput {
answer: "".to_string(),
score: confidence,
start: 0,
end: 0,
});
}
let answer_result = self.extract_answer_contextual(question, context, confidence);
Ok(answer_result)
}
fn softmax(&self, logits: &[f32]) -> Vec<f32> {
if logits.is_empty() {
return Vec::new();
}
let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
if sum_exp > 0.0 {
exp_logits.iter().map(|&x| x / sum_exp).collect()
} else {
vec![1.0 / logits.len() as f32; logits.len()] }
}
fn extract_answer_contextual(
&self,
question: &str,
context: &str,
confidence: f32,
) -> QuestionAnsweringOutput {
let question_lower = question.to_lowercase();
let context_words: Vec<&str> = context.split_whitespace().collect();
let mut best_start = 0;
let mut best_score = confidence * 0.7; let mut answer_length = 3;
if question_lower.contains("what") {
for (i, window) in context_words.windows(4).enumerate() {
let window_text = window.join(" ").to_lowercase();
if window_text.contains(" is ")
|| window_text.contains(" was ")
|| window_text.contains(" are ")
|| window_text.contains(" were ")
{
best_start = self.find_relevant_phrase_start(&context_words, i, 2);
best_score = confidence * 0.9;
answer_length = self.determine_answer_length(&window_text);
break;
}
}
} else if question_lower.contains("who") {
for (i, window) in context_words.windows(3).enumerate() {
if window
.iter()
.any(|w| w.chars().next().map(|c| c.is_uppercase()).unwrap_or(false))
{
best_start = i;
best_score = confidence * 0.85;
answer_length = 2;
break;
}
}
} else if question_lower.contains("when") {
for (i, word) in context_words.iter().enumerate() {
if self.is_time_expression(word) {
best_start = i.saturating_sub(1);
best_score = confidence * 0.95;
answer_length = self.get_time_expression_length(word, &context_words, i);
break;
}
}
} else if question_lower.contains("where") {
for (i, window) in context_words.windows(3).enumerate() {
let window_text = window.join(" ").to_lowercase();
if self.is_location_phrase(&window_text) {
best_start = i;
best_score = confidence * 0.8;
answer_length = 2;
break;
}
}
} else if question_lower.contains("how") {
for (i, window) in context_words.windows(5).enumerate() {
let window_text = window.join(" ").to_lowercase();
if window_text.contains("by")
|| window_text.contains("through")
|| window_text.contains("using")
|| window_text.contains("method")
{
best_start = i;
best_score = confidence * 0.75;
answer_length = 4;
break;
}
}
}
let end_idx = (best_start + answer_length).min(context_words.len());
let answer = if best_start < context_words.len() && end_idx > best_start {
context_words[best_start..end_idx].join(" ")
} else {
"Unable to find answer".to_string()
};
let char_start = self.calculate_char_position(&context_words, best_start, context);
let char_end = char_start + answer.len();
QuestionAnsweringOutput {
answer,
score: best_score,
start: char_start,
end: char_end.min(context.len()),
}
}
fn is_time_expression(&self, word: &str) -> bool {
if word.len() == 4 && word.chars().all(|c| c.is_ascii_digit()) {
if let Ok(year) = word.parse::<u32>() {
return (1000..=2100).contains(&year);
}
}
let time_words = [
"january",
"february",
"march",
"april",
"may",
"june",
"july",
"august",
"september",
"october",
"november",
"december",
"monday",
"tuesday",
"wednesday",
"thursday",
"friday",
"saturday",
"sunday",
"morning",
"afternoon",
"evening",
"night",
"today",
"yesterday",
"tomorrow",
];
time_words.contains(&word.to_lowercase().as_str())
}
fn is_location_phrase(&self, phrase: &str) -> bool {
phrase.contains("in ")
|| phrase.contains("at ")
|| phrase.contains("on ")
|| phrase.contains("near ")
|| phrase.contains("city")
|| phrase.contains("country")
|| phrase.contains("state")
|| phrase.contains("town")
|| phrase.contains("village")
}
fn find_relevant_phrase_start(
&self,
words: &[&str],
window_pos: usize,
offset: usize,
) -> usize {
(window_pos + offset).min(words.len().saturating_sub(1))
}
fn determine_answer_length(&self, content: &str) -> usize {
if content.contains("definition") || content.contains("means") {
5 } else if content.contains("name") || content.contains("called") {
2 } else {
3 }
}
fn get_time_expression_length(&self, word: &str, words: &[&str], pos: usize) -> usize {
if word.len() == 4 && word.chars().all(|c| c.is_ascii_digit()) {
1 } else if pos + 1 < words.len()
&& (words[pos + 1].contains(",") || words[pos + 1].parse::<u32>().is_ok())
{
2 } else {
1 }
}
fn calculate_char_position(
&self,
words: &[&str],
word_pos: usize,
original_context: &str,
) -> usize {
if word_pos == 0 {
return 0;
}
let target_words = &words[..word_pos.min(words.len())];
let partial_text = target_words.join(" ");
original_context.find(&partial_text)
.map(|pos| pos + partial_text.len() + 1) .unwrap_or(word_pos * 5) }
}
impl Pipeline for QuestionAnsweringPipeline {
type Input = String; type Output = PipelineOutput;
fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
let parts: Vec<&str> = input.split("\ncontext:").collect();
if parts.len() != 2 {
return Err(TrustformersError::invalid_input_simple(
"Expected format: 'question\ncontext:text'".to_string(),
));
}
let question = parts[0].trim();
let context = parts[1].trim();
let result = self.answer_question(question, context)?;
Ok(PipelineOutput::QuestionAnswering(result))
}
fn batch(&self, inputs: Vec<Self::Input>) -> Result<Vec<Self::Output>> {
let parsed_inputs: Result<Vec<QAInput>> = inputs
.iter()
.map(|input| {
let parts: Vec<&str> = input.split("\ncontext:").collect();
if parts.len() != 2 {
return Err(TrustformersError::invalid_input_simple(
"Expected format: 'question\ncontext:text'".to_string(),
));
}
Ok(QAInput {
question: parts[0].trim().to_string(),
context: parts[1].trim().to_string(),
})
})
.collect();
let parsed = parsed_inputs?;
let results = self.answer_questions_batch(&parsed)?;
Ok(results.into_iter().map(PipelineOutput::QuestionAnswering).collect())
}
}
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl crate::pipeline::AsyncPipeline for QuestionAnsweringPipeline {
type Input = String;
type Output = PipelineOutput;
async fn __call_async__(&self, input: Self::Input) -> Result<Self::Output> {
let pipeline = self.clone();
tokio::task::spawn_blocking(move || pipeline.__call__(input))
.await
.map_err(|e| TrustformersError::pipeline(e.to_string(), "runtime"))?
}
}
#[cfg(test)]
mod tests {
use super::*;
struct QAHelpers;
impl QAHelpers {
fn softmax(logits: &[f32]) -> Vec<f32> {
if logits.is_empty() {
return Vec::new();
}
let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
if sum_exp > 0.0 {
exp_logits.iter().map(|&x| x / sum_exp).collect()
} else {
vec![1.0 / logits.len() as f32; logits.len()]
}
}
fn is_time_expression(word: &str) -> bool {
if word.len() == 4 && word.chars().all(|c| c.is_ascii_digit()) {
if let Ok(year) = word.parse::<u32>() {
return (1000..=2100).contains(&year);
}
}
let time_words = [
"january",
"february",
"march",
"april",
"may",
"june",
"july",
"august",
"september",
"october",
"november",
"december",
"monday",
"tuesday",
"wednesday",
"thursday",
"friday",
"saturday",
"sunday",
"morning",
"afternoon",
"evening",
"night",
"today",
"yesterday",
"tomorrow",
];
time_words.contains(&word.to_lowercase().as_str())
}
fn is_location_phrase(phrase: &str) -> bool {
phrase.contains("in ")
|| phrase.contains("at ")
|| phrase.contains("on ")
|| phrase.contains("near ")
|| phrase.contains("city")
|| phrase.contains("country")
|| phrase.contains("state")
|| phrase.contains("town")
|| phrase.contains("village")
}
}
#[test]
fn test_qa_config_default_values() {
let cfg = QAConfig::default();
assert_eq!(cfg.max_length, 384);
assert_eq!(cfg.max_answer_length, 15);
assert!(!cfg.handle_impossible_answer);
assert_eq!(cfg.doc_stride, 128);
}
#[test]
fn test_qa_config_clone_and_debug() {
let cfg = QAConfig::default();
let cloned = cfg.clone();
assert_eq!(cfg.max_length, cloned.max_length);
let dbg = format!("{:?}", cloned);
assert!(dbg.contains("QAConfig"));
}
#[test]
fn test_qa_input_construction() {
let input = QAInput {
question: "What is the capital?".to_string(),
context: "The capital of France is Paris.".to_string(),
};
assert_eq!(input.question, "What is the capital?");
assert!(!input.context.is_empty());
}
#[test]
fn test_softmax_single_element() {
let result = QAHelpers::softmax(&[1.0]);
assert!((result[0] - 1.0).abs() < 1e-6);
}
#[test]
fn test_softmax_two_equal_elements() {
let result = QAHelpers::softmax(&[0.0, 0.0]);
assert_eq!(result.len(), 2);
assert!((result[0] - 0.5).abs() < 1e-6);
assert!((result[1] - 0.5).abs() < 1e-6);
}
#[test]
fn test_softmax_probabilities_sum_to_one() {
let logits: Vec<f32> = (0..8u32)
.map(|i| {
let lcg = 1664525u32.wrapping_mul(i).wrapping_add(1013904223);
(lcg % 1000) as f32 / 100.0 - 5.0
})
.collect();
let probs = QAHelpers::softmax(&logits);
let sum: f32 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "softmax sum = {}", sum);
}
#[test]
fn test_softmax_empty_input() {
let result = QAHelpers::softmax(&[]);
assert!(result.is_empty());
}
#[test]
fn test_softmax_maximum_probability_dominates() {
let logits = vec![0.0, 0.0, 100.0];
let probs = QAHelpers::softmax(&logits);
assert!(probs[2] > 0.99, "dominant logit should win: {:?}", probs);
}
#[test]
fn test_is_time_expression_year() {
assert!(QAHelpers::is_time_expression("2024"));
assert!(QAHelpers::is_time_expression("1990"));
}
#[test]
fn test_is_time_expression_month_name() {
assert!(QAHelpers::is_time_expression("January"));
assert!(QAHelpers::is_time_expression("december"));
}
#[test]
fn test_is_time_expression_day_name() {
assert!(QAHelpers::is_time_expression("Monday"));
assert!(QAHelpers::is_time_expression("friday"));
}
#[test]
fn test_is_time_expression_time_of_day() {
assert!(QAHelpers::is_time_expression("morning"));
assert!(QAHelpers::is_time_expression("afternoon"));
assert!(QAHelpers::is_time_expression("evening"));
assert!(QAHelpers::is_time_expression("night"));
}
#[test]
fn test_is_time_expression_relative() {
assert!(QAHelpers::is_time_expression("today"));
assert!(QAHelpers::is_time_expression("yesterday"));
assert!(QAHelpers::is_time_expression("tomorrow"));
}
#[test]
fn test_is_not_time_expression() {
assert!(!QAHelpers::is_time_expression("Paris"));
assert!(!QAHelpers::is_time_expression("apple"));
assert!(!QAHelpers::is_time_expression("99")); assert!(!QAHelpers::is_time_expression("12345")); }
#[test]
fn test_is_location_phrase_prepositions() {
assert!(QAHelpers::is_location_phrase("in Berlin"));
assert!(QAHelpers::is_location_phrase("at the airport"));
assert!(QAHelpers::is_location_phrase("on the hill"));
assert!(QAHelpers::is_location_phrase("near the river"));
}
#[test]
fn test_is_location_phrase_place_words() {
assert!(QAHelpers::is_location_phrase("the city center"));
assert!(QAHelpers::is_location_phrase("the country side"));
assert!(QAHelpers::is_location_phrase("the state of Maine"));
assert!(QAHelpers::is_location_phrase("the small town"));
assert!(QAHelpers::is_location_phrase("a village green"));
}
#[test]
fn test_is_not_location_phrase() {
assert!(!QAHelpers::is_location_phrase("running fast"));
assert!(!QAHelpers::is_location_phrase("the quick brown fox"));
}
#[test]
fn test_qa_output_fields_present() {
let out = QuestionAnsweringOutput {
answer: "Paris".to_string(),
score: 0.92,
start: 10,
end: 15,
};
assert_eq!(out.answer, "Paris");
assert!((out.score - 0.92).abs() < 1e-6);
assert!(out.start < out.end);
}
#[test]
fn test_qa_output_score_range() {
let out = QuestionAnsweringOutput {
answer: "test".to_string(),
score: 0.75,
start: 0,
end: 4,
};
assert!(out.score >= 0.0 && out.score <= 1.0);
}
#[test]
fn test_qa_output_empty_answer_for_no_answer() {
let out = QuestionAnsweringOutput {
answer: String::new(),
score: 0.1,
start: 0,
end: 0,
};
assert!(out.answer.is_empty());
assert_eq!(out.start, out.end);
}
#[test]
fn test_span_extraction_start_before_end() {
let context = "The quick brown fox jumps over the lazy dog.";
let words: Vec<&str> = context.split_whitespace().collect();
let start_word = 1;
let end_word = 4;
let answer = words[start_word..end_word].join(" ");
let char_start: usize = words[..start_word].iter().map(|w| w.len() + 1).sum();
let char_end = char_start + answer.len();
assert!(char_start <= char_end);
assert!(!answer.is_empty());
}
#[test]
fn test_doc_stride_chunks() {
let cfg = QAConfig {
max_length: 384,
doc_stride: 128,
..QAConfig::default()
};
assert!(cfg.doc_stride < cfg.max_length);
}
#[test]
fn test_max_answer_length_config_roundtrip() {
let cfg = QAConfig {
max_answer_length: 30,
..QAConfig::default()
};
assert_eq!(cfg.max_answer_length, 30);
}
#[test]
fn test_answer_ranking_highest_score_wins() {
let mut candidates = [
QuestionAnsweringOutput {
answer: "London".to_string(),
score: 0.5,
start: 0,
end: 6,
},
QuestionAnsweringOutput {
answer: "Paris".to_string(),
score: 0.9,
start: 7,
end: 12,
},
QuestionAnsweringOutput {
answer: "Berlin".to_string(),
score: 0.3,
start: 13,
end: 19,
},
];
candidates
.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
assert_eq!(candidates[0].answer, "Paris");
}
#[test]
fn test_no_answer_detection_low_score() {
let threshold = 0.3_f32;
let out = QuestionAnsweringOutput {
answer: String::new(),
score: 0.15,
start: 0,
end: 0,
};
assert!(out.score < threshold, "should be flagged as no-answer");
}
fn compute_em(prediction: &str, ground_truth: &str) -> f32 {
if prediction.trim().to_lowercase() == ground_truth.trim().to_lowercase() {
1.0
} else {
0.0
}
}
fn compute_f1(prediction: &str, ground_truth: &str) -> f32 {
let pred_tokens: std::collections::HashSet<&str> = prediction.split_whitespace().collect();
let truth_tokens: std::collections::HashSet<&str> =
ground_truth.split_whitespace().collect();
let common: usize = pred_tokens.intersection(&truth_tokens).count();
if common == 0 {
return 0.0;
}
let precision = common as f32 / pred_tokens.len() as f32;
let recall = common as f32 / truth_tokens.len() as f32;
2.0 * precision * recall / (precision + recall)
}
#[test]
fn test_exact_match_identical_strings() {
assert!((compute_em("Paris", "Paris") - 1.0).abs() < 1e-6);
}
#[test]
fn test_exact_match_different_strings() {
assert!((compute_em("London", "Paris") - 0.0).abs() < 1e-6);
}
#[test]
fn test_exact_match_case_insensitive() {
assert!((compute_em("paris", "Paris") - 1.0).abs() < 1e-6);
}
#[test]
fn test_f1_perfect_overlap() {
let f1 = compute_f1("the quick brown fox", "the quick brown fox");
assert!((f1 - 1.0).abs() < 1e-6);
}
#[test]
fn test_f1_no_overlap() {
let f1 = compute_f1("cat", "dog");
assert!((f1 - 0.0).abs() < 1e-6);
}
#[test]
fn test_f1_partial_overlap() {
let f1 = compute_f1("quick brown fox", "the quick brown");
assert!(f1 > 0.0 && f1 < 1.0, "f1 should be partial: {}", f1);
}
#[test]
fn test_qa_config_impossible_answer_flag() {
let cfg = QAConfig {
handle_impossible_answer: true,
..QAConfig::default()
};
assert!(cfg.handle_impossible_answer);
}
}