use std::collections::HashMap;
use crate::error::Result;
/// Thresholds that control which repeated phrases are tracked and abbreviated.
#[derive(Debug, Clone)]
pub struct AbbreviatorConfig {
    /// Minimum observed count a phrase needs before it is considered for abbreviation.
    pub min_occurrences: usize,
    /// Smallest n-gram size, in whitespace-separated words, that is tracked.
    pub min_phrase_tokens: usize,
    /// Largest n-gram size, in whitespace-separated words, that is tracked.
    pub max_phrase_tokens: usize,
    /// Upper bound on the number of candidate phrases used per abbreviation pass.
    pub max_abbreviations: usize,
}

impl Default for AbbreviatorConfig {
    /// Builds the stock configuration: phrases of 3 to 8 words seen at least
    /// 3 times, with at most 20 abbreviations.
    fn default() -> Self {
        AbbreviatorConfig {
            min_occurrences: 3,
            min_phrase_tokens: 3,
            max_phrase_tokens: 8,
            max_abbreviations: 20,
        }
    }
}
/// One phrase-to-symbol substitution recorded by `NgramAbbreviator::abbreviate`.
#[derive(Debug, Clone)]
pub struct Abbreviation {
/// Placeholder written into the text in place of repeated occurrences of `phrase`.
pub symbol: String,
/// The original phrase (words joined by single spaces) that `symbol` stands for.
pub phrase: String,
/// How many times the phrase was counted by `observe` (not the count in the abbreviated text).
pub occurrences: usize,
/// Replacements made multiplied by (words in the phrase - 1); legend cost is not subtracted.
pub tokens_saved: usize,
}
/// Outcome of a single `NgramAbbreviator::abbreviate` call.
#[derive(Debug, Clone)]
pub struct AbbreviationResult {
/// The rewritten text; prefixed with an `[Abbreviations]` legend when any abbreviation was applied.
pub text: String,
/// Abbreviations applied during this call, in the order they were applied.
pub abbreviations: Vec<Abbreviation>,
/// Sum of `tokens_saved` over `abbreviations`.
pub total_tokens_saved: usize,
}
/// Counts repeated word n-grams across observed text and replaces repeats with short symbols.
pub struct NgramAbbreviator {
// Thresholds governing which phrases are tracked and abbreviated.
config: AbbreviatorConfig,
// Phrase -> number of times it has been seen by `observe` since construction or the last `reset`.
phrase_counts: HashMap<String, usize>,
// Every abbreviation applied by `abbreviate` since construction or the last `reset`.
active_abbreviations: Vec<Abbreviation>,
// Number embedded in the next symbol; starts at 1 and only grows until `reset`.
next_index: usize,
}
impl NgramAbbreviator {
    /// Phrases whose joined form is shorter than this many bytes are never counted.
    const MIN_PHRASE_BYTES: usize = 10;

    /// Creates an abbreviator using [`AbbreviatorConfig::default`].
    pub fn new() -> Self {
        Self::with_config(AbbreviatorConfig::default())
    }

    /// Creates an abbreviator with the given configuration and no recorded state.
    pub fn with_config(config: AbbreviatorConfig) -> Self {
        Self {
            config,
            phrase_counts: HashMap::new(),
            active_abbreviations: Vec::new(),
            next_index: 1,
        }
    }

    /// Records every word n-gram of `text` whose size lies between
    /// `min_phrase_tokens` and `max_phrase_tokens` (inclusive).
    ///
    /// Words are split on whitespace and re-joined with single spaces, so the
    /// stored phrases are whitespace-normalised. Phrases shorter than
    /// 10 bytes are skipped. Text with fewer words than `min_phrase_tokens`
    /// is a no-op.
    pub fn observe(&mut self, text: &str) {
        let words: Vec<&str> = text.split_whitespace().collect();

        // `slice::windows` panics on a zero window size, so a configured minimum
        // of 0 is treated as 1. Capping at `words.len()` makes the range empty
        // when the text is too short for any n-gram.
        let shortest = self.config.min_phrase_tokens.max(1);
        let longest = self.config.max_phrase_tokens.min(words.len());

        for n in shortest..=longest {
            for window in words.windows(n) {
                let phrase = window.join(" ");
                if phrase.len() >= Self::MIN_PHRASE_BYTES {
                    *self.phrase_counts.entry(phrase).or_insert(0) += 1;
                }
            }
        }
    }

    /// Replaces repeated phrases in `text` with short symbols and prepends a
    /// legend mapping each symbol back to its phrase.
    ///
    /// Candidates are the observed phrases seen at least `min_occurrences`
    /// times, ranked by estimated savings (ties broken alphabetically so the
    /// result is deterministic) and capped at `max_abbreviations`. For each
    /// candidate the first occurrence in the text is kept verbatim and every
    /// later one is replaced. Candidates that occur fewer than twice in the
    /// current text are skipped without consuming a symbol number.
    ///
    /// Matching is plain substring search on the text as given, so it is not
    /// word-boundary aware and does not normalise whitespace.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn abbreviate(&mut self, text: &str) -> Result<AbbreviationResult> {
        let min_occurrences = self.config.min_occurrences;

        // (estimated savings, phrase, observed count); the phrase is borrowed
        // from `phrase_counts` and only cloned if it is actually applied.
        let mut candidates: Vec<(usize, &str, usize)> = self
            .phrase_counts
            .iter()
            .filter(|(_, count)| **count >= min_occurrences)
            .map(|(phrase, &count)| (estimate_savings(phrase, count), phrase.as_str(), count))
            .collect();

        // Highest savings first. `HashMap` iteration order is arbitrary, so
        // ties need an explicit secondary key to keep output reproducible.
        candidates.sort_unstable_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(b.1)));
        candidates.truncate(self.config.max_abbreviations);

        let mut abbreviations: Vec<Abbreviation> = Vec::new();
        let mut result_text = text.to_string();
        let mut total_saved = 0usize;

        for (_, phrase, occurrences) in candidates {
            // The first occurrence stays in place; only the rest are replaced.
            let replacements_made = count_occurrences(&result_text, phrase).saturating_sub(1);
            if replacements_made == 0 {
                continue;
            }

            // Allocate the symbol only once the candidate is known to apply,
            // so symbol numbers stay contiguous.
            let symbol = format!("Β«A{}Β»", self.next_index);
            self.next_index += 1;
            result_text = replace_after_first(&result_text, phrase, &symbol);

            let phrase_tokens = phrase.split_whitespace().count();
            let saved = replacements_made * phrase_tokens.saturating_sub(1);
            total_saved += saved;

            abbreviations.push(Abbreviation {
                symbol,
                phrase: phrase.to_string(),
                occurrences,
                tokens_saved: saved,
            });
        }

        if !abbreviations.is_empty() {
            let legend = format_legend(&abbreviations);
            result_text = format!("{legend}\n\n{result_text}");
            self.active_abbreviations
                .extend(abbreviations.iter().cloned());
        }

        Ok(AbbreviationResult {
            text: result_text,
            abbreviations,
            total_tokens_saved: total_saved,
        })
    }

    /// Forgets all observed phrases and applied abbreviations and restarts
    /// symbol numbering at 1. The configuration is kept.
    pub fn reset(&mut self) {
        self.phrase_counts.clear();
        self.active_abbreviations.clear();
        self.next_index = 1;
    }

    /// Returns the observed phrases with their counts.
    pub fn phrase_counts(&self) -> &HashMap<String, usize> {
        &self.phrase_counts
    }

    /// Returns every abbreviation applied since construction or the last reset.
    pub fn active_abbreviations(&self) -> &[Abbreviation] {
        &self.active_abbreviations
    }
}
impl Default for NgramAbbreviator {
fn default() -> Self {
Self::new()
}
}
/// Estimates the net tokens saved by abbreviating `phrase` given how often it occurs.
///
/// Tokens are whitespace-separated words. The first occurrence is assumed to
/// stay in place, so the gross saving is `(occurrences - 1) * words`. The cost
/// charged against it is a legend entry (`words + 2`) plus one symbol token
/// per occurrence. Returns 0 when the phrase occurs at most once, has at most
/// one word, or the cost outweighs the saving.
fn estimate_savings(phrase: &str, occurrences: usize) -> usize {
    const SYMBOL_TOKENS: usize = 1;
    const LEGEND_OVERHEAD: usize = 2;

    let word_count = phrase.split_whitespace().count();
    let worth_considering = occurrences > 1 && word_count > SYMBOL_TOKENS;
    if !worth_considering {
        return 0;
    }

    let saved = (occurrences - 1) * word_count;
    let spent = (word_count + LEGEND_OVERHEAD) + occurrences * SYMBOL_TOKENS;
    saved.saturating_sub(spent)
}
/// Counts non-overlapping occurrences of `needle` in `haystack`, scanning left to right.
///
/// An empty `needle` yields 0.
fn count_occurrences(haystack: &str, needle: &str) -> usize {
    match needle {
        "" => 0,
        pattern => haystack.match_indices(pattern).count(),
    }
}
/// Returns `text` with every occurrence of `needle` except the first replaced by `replacement`.
///
/// Matches are non-overlapping and found left to right. If `needle` is empty
/// or does not occur, the text is returned unchanged.
fn replace_after_first(text: &str, needle: &str, replacement: &str) -> String {
    if needle.is_empty() {
        return text.to_owned();
    }

    match text.find(needle) {
        None => text.to_owned(),
        Some(first_at) => {
            // Everything up to and including the first match is kept verbatim;
            // only the remainder is subject to replacement.
            let (kept, rest) = text.split_at(first_at + needle.len());
            let mut out = String::with_capacity(text.len());
            out.push_str(kept);
            out.push_str(&rest.replace(needle, replacement));
            out
        }
    }
}
/// Renders the legend: an `[Abbreviations]` header followed by one
/// `symbol=phrase` line per abbreviation, in the given order.
fn format_legend(abbreviations: &[Abbreviation]) -> String {
    let mut legend = String::from("[Abbreviations]");
    for entry in abbreviations {
        legend.push('\n');
        legend.push_str(&entry.symbol);
        legend.push('=');
        legend.push_str(&entry.phrase);
    }
    legend
}
// Unit and property tests for the abbreviator and its private helpers.
#[cfg(test)]
mod tests {
use super::*;
// A single observation of a short text never reaches the default
// `min_occurrences` of 3, so nothing is abbreviated.
#[test]
fn test_no_abbreviations_below_threshold() {
let mut abbr = NgramAbbreviator::new();
abbr.observe("hello world foo bar");
let result = abbr.abbreviate("hello world foo bar").unwrap();
assert!(result.abbreviations.is_empty());
assert_eq!(result.total_tokens_saved, 0);
}
// With a lowered threshold, a phrase repeated three times may be abbreviated.
// NOTE(review): the assertions only run when at least one abbreviation was
// produced, so this test passes vacuously if none is.
#[test]
fn test_abbreviation_after_threshold() {
let mut abbr = NgramAbbreviator::with_config(AbbreviatorConfig {
min_occurrences: 2,
min_phrase_tokens: 3,
max_phrase_tokens: 8,
max_abbreviations: 20,
});
let text = "error mismatched types found. Then error mismatched types found again. And error mismatched types found once more.";
abbr.observe(text);
let result = abbr.abbreviate(text).unwrap();
if !result.abbreviations.is_empty() {
assert!(result.text.contains("[Abbreviations]"));
assert!(result.text.contains("Β«A"));
}
}
// The first match is kept; every later match is replaced.
#[test]
fn test_replace_after_first() {
let text = "abc def abc def abc";
let result = replace_after_first(text, "abc", "X");
assert_eq!(result, "abc def X def X");
}
// A lone match is the "first" one, so the text is unchanged.
#[test]
fn test_replace_after_first_single_occurrence() {
let text = "abc def ghi";
let result = replace_after_first(text, "abc", "X");
assert_eq!(result, "abc def ghi");
}
// An empty needle is rejected up front and leaves the text unchanged.
#[test]
fn test_replace_after_first_empty_needle() {
let text = "abc def";
let result = replace_after_first(text, "", "X");
assert_eq!(result, "abc def");
}
// Covers the normal, no-match, empty-haystack and empty-needle cases.
#[test]
fn test_count_occurrences() {
assert_eq!(count_occurrences("abc abc abc", "abc"), 3);
assert_eq!(count_occurrences("abc def ghi", "xyz"), 0);
assert_eq!(count_occurrences("", "abc"), 0);
assert_eq!(count_occurrences("abc", ""), 0);
}
// 3-word phrase, 5 occurrences: 4 * 3 gross minus (3 + 2 legend + 5 symbols) = 2.
// A single occurrence can never save anything.
#[test]
fn test_estimate_savings() {
assert_eq!(estimate_savings("error mismatched types", 5), 2);
assert_eq!(estimate_savings("error mismatched types", 1), 0);
}
// The legend has a header line and a `symbol=phrase` line per entry.
#[test]
fn test_format_legend() {
let abbrevs = vec![Abbreviation {
symbol: "Β«A1Β»".to_string(),
phrase: "error mismatched types".to_string(),
occurrences: 5,
tokens_saved: 8,
}];
let legend = format_legend(&abbrevs);
assert!(legend.contains("[Abbreviations]"));
assert!(legend.contains("Β«A1Β»=error mismatched types"));
}
// `reset` empties both collections and restarts symbol numbering at 1.
#[test]
fn test_reset_clears_state() {
let mut abbr = NgramAbbreviator::new();
abbr.observe("some text with words and more words");
assert!(!abbr.phrase_counts.is_empty());
abbr.reset();
assert!(abbr.phrase_counts.is_empty());
assert!(abbr.active_abbreviations.is_empty());
assert_eq!(abbr.next_index, 1);
}
// Text with fewer words than `min_phrase_tokens` records nothing.
#[test]
fn test_observe_short_text_noop() {
let mut abbr = NgramAbbreviator::new();
abbr.observe("hi");
assert!(abbr.phrase_counts.is_empty());
}
use proptest::prelude::*;
proptest! {
// However many times a phrase repeats, the output still contains it verbatim
// (the first occurrence in the body and/or the legend entry).
#[test]
fn prop_first_occurrence_preserved(
phrase in "[a-z]{3,6} [a-z]{3,6} [a-z]{3,6}",
repeat in 3usize..=6usize,
) {
let mut abbr = NgramAbbreviator::with_config(AbbreviatorConfig {
min_occurrences: 2,
min_phrase_tokens: 3,
max_phrase_tokens: 8,
max_abbreviations: 20,
});
let text = std::iter::repeat(phrase.as_str())
.take(repeat)
.collect::<Vec<_>>()
.join(". ");
abbr.observe(&text);
let result = abbr.abbreviate(&text).unwrap();
prop_assert!(
result.text.contains(&phrase),
"first occurrence of '{}' should be preserved in:\n{}",
phrase, result.text
);
}
// Smoke test: observe + abbreviate must not panic or return Err on arbitrary
// lowercase text. NOTE(review): nothing is asserted about the value itself;
// `total_tokens_saved` is a `usize` and cannot be negative.
#[test]
fn prop_savings_non_negative(
text in "[a-z ]{20,200}"
) {
let mut abbr = NgramAbbreviator::new();
abbr.observe(&text);
let result = abbr.abbreviate(&text).unwrap();
let _ = result.total_tokens_saved;
}
}
}