use super::FillMaskResult;
use crate::error::{Result, TextError};
/// A stateless, rule-based pipeline that suggests replacements for a `[MASK]`
/// token in a piece of text.
///
/// The type carries no data, so it is `Copy` and all instances compare equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FillMaskPipeline;
impl Default for FillMaskPipeline {
fn default() -> Self {
Self::new()
}
}
impl FillMaskPipeline {
    /// Number of candidates returned by [`FillMaskPipeline::fill_mask`].
    const DEFAULT_TOP_K: usize = 5;

    /// Creates a new pipeline. The pipeline holds no state.
    pub fn new() -> Self {
        Self
    }

    /// Suggests the five best-scoring replacements for the `[MASK]` token in `text`.
    ///
    /// Equivalent to [`FillMaskPipeline::fill_mask_top_k`] with `top_k = 5`; see
    /// that method for how candidates are scored.
    ///
    /// # Errors
    ///
    /// Returns [`TextError::InvalidInput`] if `text` does not contain `[MASK]`.
    pub fn fill_mask(&self, text: &str) -> Result<Vec<FillMaskResult>> {
        self.fill_mask_top_k(text, Self::DEFAULT_TOP_K)
    }

    /// Suggests up to `top_k` replacements for the `[MASK]` token in `text`,
    /// highest score first.
    ///
    /// This is a rule-based heuristic, not a language model. Each word in a fixed
    /// vocabulary starts from a hand-set prior score, which is then adjusted using
    /// the words immediately before and after the first `[MASK]`:
    ///
    /// * after "a"/"an": determiners and auxiliary/modal verbs ("a the", "an is"),
    ///   and candidates whose initial letter disagrees with the article
    ///   ("a important", "an good"), are scaled by 0.3;
    /// * after a word ending in "ly": "good" and "important" are scaled by 1.2;
    /// * "a" before a vowel-initial word and "an" before a consonant-initial word
    ///   are scaled by 0.2.
    ///
    /// Context words are compared case-insensitively with surrounding punctuation
    /// stripped. Every `[MASK]` occurrence is replaced in each result's `sequence`,
    /// but only the first mask's neighbours influence the score. `token` is the
    /// candidate's 1-based position in the vocabulary. Equal scores keep
    /// vocabulary order.
    ///
    /// # Errors
    ///
    /// Returns [`TextError::InvalidInput`] if `text` does not contain `[MASK]`.
    pub fn fill_mask_top_k(&self, text: &str, top_k: usize) -> Result<Vec<FillMaskResult>> {
        let words: Vec<&str> = text.split_whitespace().collect();

        // "[MASK]" contains no whitespace, so `text` contains it exactly when some
        // whitespace-separated word does. Matching with `contains` rather than `==`
        // also locates masks glued to punctuation such as "[MASK]." or "([MASK]";
        // exact matching used to fall back to index 0 and score the wrong context.
        let mask_index = words
            .iter()
            .position(|w| w.contains("[MASK]"))
            .ok_or_else(|| TextError::InvalidInput("No [MASK] token found".to_string()))?;

        // Only the directly adjacent words are used for scoring.
        let prev_word = words[..mask_index].last().map(|w| Self::normalize(w));
        let next_word = words.get(mask_index + 1).map(|w| Self::normalize(w));

        // (candidate, prior score) in vocabulary order; the position is the token id.
        let vocabulary = [
            ("the", 0.85),
            ("a", 0.75),
            ("an", 0.65),
            ("is", 0.80),
            ("was", 0.75),
            ("are", 0.70),
            ("will", 0.68),
            ("can", 0.72),
            ("would", 0.70),
            ("should", 0.65),
            ("very", 0.60),
            ("more", 0.68),
            ("most", 0.65),
            ("good", 0.60),
            ("great", 0.58),
            ("important", 0.55),
            ("significant", 0.52),
            ("major", 0.50),
        ];

        let mut candidates: Vec<FillMaskResult> = vocabulary
            .iter()
            .enumerate()
            .map(|(index, &(word, prior))| {
                let mut score = prior;

                if let Some(prev) = prev_word.as_deref() {
                    if prev == "a" || prev == "an" {
                        // "an" must precede a vowel-initial word, "a" a consonant-initial one.
                        let article_agrees = (prev == "an") == Self::starts_with_vowel(word);
                        if Self::is_function_word(word) || !article_agrees {
                            score *= 0.3;
                        }
                    } else if prev.ends_with("ly") && (word == "good" || word == "important") {
                        score *= 1.2;
                    }
                }

                if let Some(next) = next_word.as_deref() {
                    let next_is_vowel = Self::starts_with_vowel(next);
                    if (word == "a" && next_is_vowel) || (word == "an" && !next_is_vowel) {
                        score *= 0.2;
                    }
                }

                FillMaskResult {
                    token_str: word.to_string(),
                    sequence: text.replace("[MASK]", word),
                    score,
                    token: index + 1,
                }
            })
            .collect();

        // `sort_by` is stable, so ties keep vocabulary order; `total_cmp` is a total
        // order on floats, so no NaN case needs a panicking `expect`.
        candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
        candidates.truncate(top_k);
        Ok(candidates)
    }

    /// Lowercases `word` and trims leading/trailing non-alphanumeric characters,
    /// so "A" compares equal to "a" and "apple," is seen as "apple".
    fn normalize(word: &str) -> String {
        word.trim_matches(|c: char| !c.is_alphanumeric()).to_lowercase()
    }

    /// Whether `word` starts with one of the letters a, e, i, o, u (any case).
    /// This is spelling-based, so words such as "hour" or "user" are misjudged.
    fn starts_with_vowel(word: &str) -> bool {
        matches!(word.chars().next(), Some(c) if "aeiou".contains(c.to_ascii_lowercase()))
    }

    /// Whether `word` is one of the vocabulary's determiners or auxiliary/modal
    /// verbs, none of which can directly follow "a"/"an".
    fn is_function_word(word: &str) -> bool {
        matches!(
            word,
            "the" | "a" | "an" | "is" | "was" | "are" | "will" | "can" | "would" | "should"
        )
    }
}