use std::io::Write;
use std::path::Path;
use libdictenstein::double_array_trie_char::DoubleArrayTrieChar;
use libdictenstein::persistent_artrie_char::PersistentARTrieChar;
/// Returns `true` when `term` is a single word that belongs in the dictionary.
///
/// A term is rejected when it:
/// * is empty — an empty key is not a word and must not become a valid
///   dictionary entry;
/// * starts with a NUL character (`\0`) — such keys (e.g. `\0metadata`,
///   `\0__checkpoint__`) are treated as internal, non-word entries;
/// * contains any Unicode whitespace — i.e. it is a multi-word n-gram.
fn is_unigram(term: &str) -> bool {
    !term.is_empty() && !term.starts_with('\0') && !term.contains(char::is_whitespace)
}
/// Counters describing one dictionary-extraction run.
///
/// Three fields are duplicates: the extractor methods in this module always set
/// `words_extracted == vocabulary_size`, `words_filtered == filtered_unigrams`
/// and `dict_size_bytes == dictionary_size_bytes`.
/// NOTE(review): the aliases are presumably kept for older callers — confirm
/// before removing either name.
#[derive(Clone, Debug, Default)]
pub struct ExtractionStats {
    /// Number of unigram keys seen in the model (keys without whitespace and
    /// without a leading `\0`), whether or not they passed `min_count`.
    pub total_unigrams: u64,
    /// Unigrams dropped because their count was below `min_count`.
    pub filtered_unigrams: u64,
    /// Number of words kept in the extracted vocabulary.
    pub vocabulary_size: u64,
    /// Size of the model file on disk, in bytes.
    pub source_size_bytes: u64,
    /// Size of the serialized dictionary in bytes; `0` when only words were
    /// extracted and no dictionary was written.
    pub dictionary_size_bytes: u64,
    /// Wall-clock duration of the extraction call, in seconds.
    pub elapsed_seconds: f64,
    /// Same value as `vocabulary_size`.
    pub words_extracted: u64,
    /// Same value as `filtered_unigrams`.
    pub words_filtered: u64,
    /// Same value as `dictionary_size_bytes`.
    pub dict_size_bytes: u64,
}
/// Errors returned by the [`DictionaryExtractor`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ExtractionError {
    /// No model file could be found at the given path; carries the path as
    /// displayed text.
    #[error("Model file not found: {0}")]
    ModelNotFound(String),
    /// Extraction produced no words. Note that this is also returned when the
    /// model does contain unigrams but none of them reach `min_count`.
    #[error("No 1-grams found in model")]
    NoUnigrams,
    /// Underlying I/O failure. A failure to open the model trie is also
    /// reported through this variant.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The built dictionary could not be serialized.
    #[error("Serialization error: {0}")]
    Serialization(String),
    /// Dictionary construction failed.
    /// NOTE(review): not produced by the extractor methods in this module.
    #[error("Dictionary construction error: {0}")]
    Construction(String),
}
pub struct DictionaryExtractor;
impl DictionaryExtractor {
pub fn extract_words<P: AsRef<Path>>(
model_path: P,
min_count: u64,
) -> Result<(Vec<String>, ExtractionStats), ExtractionError> {
use std::time::Instant;
let start = Instant::now();
let model_path = model_path.as_ref();
if !model_path.exists() {
return Err(ExtractionError::ModelNotFound(
model_path.display().to_string(),
));
}
let source_size = std::fs::metadata(model_path)?.len();
log::info!(
"Extracting dictionary from {:?} with min_count={}",
model_path,
min_count
);
let trie: PersistentARTrieChar<u64> =
PersistentARTrieChar::open(model_path).map_err(|e| {
ExtractionError::Io(std::io::Error::other(format!("Failed to open trie: {}", e)))
})?;
let mut words = Vec::new();
let mut total_unigrams = 0u64;
let mut filtered_unigrams = 0u64;
for (term, count) in trie.iter_with_values() {
if is_unigram(&term) {
total_unigrams += 1;
if count >= min_count {
words.push(term);
} else {
filtered_unigrams += 1;
}
}
}
words.sort();
log::info!(
"Extracted {} words from {} unigrams ({} filtered) in {:.2}s",
words.len(),
total_unigrams,
filtered_unigrams,
start.elapsed().as_secs_f64()
);
let stats = ExtractionStats {
total_unigrams,
filtered_unigrams,
vocabulary_size: words.len() as u64,
source_size_bytes: source_size,
dictionary_size_bytes: 0,
elapsed_seconds: start.elapsed().as_secs_f64(),
words_extracted: words.len() as u64,
words_filtered: filtered_unigrams,
dict_size_bytes: 0,
};
Ok((words, stats))
}
pub fn extract_to_file<P: AsRef<Path>>(
model_path: P,
output_path: P,
min_count: u64,
) -> Result<ExtractionStats, ExtractionError> {
use std::time::Instant;
let start = Instant::now();
let output_path = output_path.as_ref();
let (words, mut stats) = Self::extract_words(&model_path, min_count)?;
if words.is_empty() {
return Err(ExtractionError::NoUnigrams);
}
log::info!("Building DoubleArrayTrieChar from {} words", words.len());
let dict: DoubleArrayTrieChar<()> = DoubleArrayTrieChar::from_terms(&words);
log::info!("Serializing dictionary to {:?}", output_path);
let bytes = bincode::serialize(&dict).map_err(|e| {
ExtractionError::Serialization(format!("Failed to serialize dictionary: {}", e))
})?;
let mut file = std::fs::File::create(output_path)?;
file.write_all(&bytes)?;
file.flush()?;
let dictionary_size = bytes.len() as u64;
stats.dictionary_size_bytes = dictionary_size;
stats.dict_size_bytes = dictionary_size;
stats.elapsed_seconds = start.elapsed().as_secs_f64();
log::info!(
"Dictionary extraction complete: {} words, {} bytes in {:.2}s",
stats.vocabulary_size,
stats.dictionary_size_bytes,
stats.elapsed_seconds
);
Ok(stats)
}
pub fn extract_with_progress<P1, P2, F>(
model_path: P1,
output_path: P2,
min_count: u64,
mut progress: F,
) -> Result<ExtractionStats, ExtractionError>
where
P1: AsRef<Path>,
P2: AsRef<Path>,
F: FnMut(ExtractionProgress),
{
use std::time::Instant;
let start = Instant::now();
let model_path = model_path.as_ref();
let output_path = output_path.as_ref();
if !model_path.exists() {
return Err(ExtractionError::ModelNotFound(
model_path.display().to_string(),
));
}
let source_size = std::fs::metadata(model_path)?.len();
progress(ExtractionProgress {
phase: ExtractionPhase::Loading,
items_processed: 0,
items_total: None,
items_accepted: 0,
elapsed_seconds: start.elapsed().as_secs_f64(),
words_processed: 0,
});
let trie: PersistentARTrieChar<u64> =
PersistentARTrieChar::open(model_path).map_err(|e| {
ExtractionError::Io(std::io::Error::other(format!("Failed to open trie: {}", e)))
})?;
let mut words = Vec::new();
let mut total_unigrams = 0u64;
let mut filtered_unigrams = 0u64;
let mut last_progress = 0u64;
for (term, count) in trie.iter_with_values() {
if is_unigram(&term) {
total_unigrams += 1;
if count >= min_count {
words.push(term);
} else {
filtered_unigrams += 1;
}
if total_unigrams - last_progress >= 100_000 {
last_progress = total_unigrams;
progress(ExtractionProgress {
phase: ExtractionPhase::Filtering,
items_processed: total_unigrams,
items_total: None,
items_accepted: words.len() as u64,
elapsed_seconds: start.elapsed().as_secs_f64(),
words_processed: total_unigrams,
});
}
}
}
words.sort();
if words.is_empty() {
return Err(ExtractionError::NoUnigrams);
}
progress(ExtractionProgress {
phase: ExtractionPhase::Building,
items_processed: total_unigrams,
items_total: Some(total_unigrams),
items_accepted: words.len() as u64,
elapsed_seconds: start.elapsed().as_secs_f64(),
words_processed: total_unigrams,
});
let dict: DoubleArrayTrieChar<()> = DoubleArrayTrieChar::from_terms(&words);
progress(ExtractionProgress {
phase: ExtractionPhase::Saving,
items_processed: total_unigrams,
items_total: Some(total_unigrams),
items_accepted: words.len() as u64,
elapsed_seconds: start.elapsed().as_secs_f64(),
words_processed: total_unigrams,
});
let bytes = bincode::serialize(&dict).map_err(|e| {
ExtractionError::Serialization(format!("Failed to serialize dictionary: {}", e))
})?;
let mut file = std::fs::File::create(output_path)?;
file.write_all(&bytes)?;
file.flush()?;
let dictionary_size = bytes.len() as u64;
progress(ExtractionProgress {
phase: ExtractionPhase::Complete,
items_processed: total_unigrams,
items_total: Some(total_unigrams),
items_accepted: words.len() as u64,
elapsed_seconds: start.elapsed().as_secs_f64(),
words_processed: total_unigrams,
});
let stats = ExtractionStats {
total_unigrams,
filtered_unigrams,
vocabulary_size: words.len() as u64,
source_size_bytes: source_size,
dictionary_size_bytes: dictionary_size,
elapsed_seconds: start.elapsed().as_secs_f64(),
words_extracted: words.len() as u64,
words_filtered: filtered_unigrams,
dict_size_bytes: dictionary_size,
};
Ok(stats)
}
pub fn extract_to_file_with_progress<P1, P2, F>(
model_path: P1,
output_path: P2,
min_count: u64,
progress: F,
) -> Result<ExtractionStats, ExtractionError>
where
P1: AsRef<Path>,
P2: AsRef<Path>,
F: FnMut(ExtractionProgress),
{
Self::extract_with_progress(model_path, output_path, min_count, progress)
}
}
/// Snapshot handed to the progress callback of
/// [`DictionaryExtractor::extract_with_progress`].
#[derive(Clone, Debug)]
pub struct ExtractionProgress {
    /// Stage the extraction has reached.
    pub phase: ExtractionPhase,
    /// Unigrams examined so far.
    pub items_processed: u64,
    /// Total number of unigrams; `None` while loading and filtering, and
    /// `Some` once the scan has finished.
    pub items_total: Option<u64>,
    /// Unigrams accepted into the vocabulary so far.
    pub items_accepted: u64,
    /// Seconds elapsed since the extraction call started.
    pub elapsed_seconds: f64,
    /// Same value as `items_processed`.
    pub words_processed: u64,
}
/// Stage reported in [`ExtractionProgress::phase`].
///
/// [`DictionaryExtractor::extract_with_progress`] emits the phases in
/// declaration order:
/// * `Loading` once, before the model is opened.
/// * `Filtering` every 100,000 scanned unigrams, which means not at all for
///   small models.
/// * `Building`, `Saving` and `Complete` once each.
///
/// The enum is fieldless, so it is `Copy` and can be compared or used as a
/// map/set key cheaply.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ExtractionPhase {
    /// About to open the model file.
    Loading,
    /// Scanning the model and filtering unigrams by `min_count`.
    Filtering,
    /// Constructing the in-memory dictionary from the accepted words.
    Building,
    /// Serializing the dictionary and writing it to disk.
    Saving,
    /// Dictionary written; extraction finished.
    Complete,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed stats value starts with the checked counters at zero.
    #[test]
    fn test_extraction_stats_default() {
        let ExtractionStats {
            total_unigrams,
            vocabulary_size,
            ..
        } = ExtractionStats::default();
        assert_eq!(total_unigrams, 0);
        assert_eq!(vocabulary_size, 0);
    }

    /// A path that does not exist surfaces as `ModelNotFound`.
    #[test]
    fn test_extraction_model_not_found() {
        let outcome = DictionaryExtractor::extract_words("/nonexistent/path.artrie", 100);
        let reported_missing = matches!(outcome, Err(ExtractionError::ModelNotFound(_)));
        assert!(reported_missing);
    }

    /// Progress snapshots are plain data and keep the values they were built with.
    #[test]
    fn test_extraction_progress() {
        let snapshot = ExtractionProgress {
            phase: ExtractionPhase::Filtering,
            items_processed: 1000,
            items_total: Some(10000),
            items_accepted: 500,
            elapsed_seconds: 1.5,
            words_processed: 1000,
        };
        let ExtractionProgress {
            phase,
            items_processed,
            items_accepted,
            ..
        } = snapshot;
        assert_eq!(phase, ExtractionPhase::Filtering);
        assert_eq!(items_processed, 1000);
        assert_eq!(items_accepted, 500);
    }

    /// Single words (including non-ASCII ones) are unigrams; multi-word n-grams
    /// and NUL-prefixed internal keys are not.
    #[test]
    fn test_is_unigram() {
        let cases: [(&str, bool); 7] = [
            ("hello", true),
            ("world123", true),
            ("café", true),
            ("hello world", false),
            ("the quick", false),
            ("\x00metadata", false),
            ("\x00__checkpoint__", false),
        ];
        for (term, expected) in cases {
            assert_eq!(is_unigram(term), expected);
        }
    }
}