use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use libdictenstein::persistent_artrie::PersistentARTrie;
use parking_lot::RwLock;
use super::storage::NgramStorage;
use crate::ngram::vocabulary::open_or_create_concurrent_vocabulary_lockfree_with_capacity;
use super::aggregator::YearAggregator;
use super::checkpoint::{CheckpointError, ImportCheckpoint, TrieCheckpointStorage};
use super::config::GoogleBooksConfig;
use super::languages::{get_prefixes, is_supported};
use super::reader::{FileNgramReader, ReaderError};
#[cfg(feature = "google-books")]
use super::task_manager::RetryAfter;
/// Estimates how many n-grams an import with this configuration will yield.
///
/// A per-order baseline is selected by language (English has its own table,
/// every other language shares a smaller one). Each order listed in
/// `config.orders` contributes its baseline scaled by a factor chosen from
/// `config.min_count`; orders that fall outside the table contribute nothing.
fn estimate_ngram_count(config: &GoogleBooksConfig) -> u64 {
    // Indexed directly by n-gram order; slot 0 is a zero placeholder.
    let baseline: [u64; 6] = match config.language.as_str() {
        "en" | "eng" => [
            0,
            13_000_000,
            314_000_000,
            977_000_000,
            1_313_000_000,
            1_176_000_000,
        ],
        _ => [
            0,
            5_000_000,
            100_000_000,
            300_000_000,
            500_000_000,
            400_000_000,
        ],
    };

    // The scaling factor depends only on the configuration, not on the order.
    let retained = match config.min_count {
        0..=1 => 1.0,
        2..=10 => 0.4,
        11..=40 => 0.2,
        41..=100 => 0.1,
        _ => 0.05,
    };

    config
        .orders
        .clone()
        .into_iter()
        .filter_map(|order| baseline.get(order as usize))
        .map(|&count| (count as f64 * retained) as u64)
        .sum()
}
/// Estimates the vocabulary size for this configuration.
///
/// Starts from a language-dependent base (13M for English, 5M otherwise) and
/// scales it by a factor chosen from `config.min_count`.
fn estimate_vocabulary_size(config: &GoogleBooksConfig) -> usize {
    let retained = match config.min_count {
        0..=1 => 1.0,
        2..=10 => 0.4,
        11..=40 => 0.2,
        41..=100 => 0.1,
        _ => 0.05,
    };
    let is_english = matches!(config.language.as_str(), "en" | "eng");
    let unfiltered: usize = if is_english { 13_000_000 } else { 5_000_000 };
    (unfiltered as f64 * retained) as usize
}
/// Decides whether an import error looks transient and is worth retrying.
///
/// Reader and I/O errors are classified by searching their lower-cased
/// display text for a fixed list of substrings; I/O errors additionally
/// qualify through a set of connection/timeout `ErrorKind`s. Every other
/// error variant is treated as non-retryable.
fn is_retryable_error(e: &ImportError) -> bool {
    // Substrings searched for in the lower-cased error message.
    const TRANSIENT_MARKERS: &[&str] = &[
        "timeout",
        "timed out",
        "elapsed",
        "deadline",
        "connection",
        "connect",
        "network",
        "temporarily",
        "reset",
        "broken pipe",
        "refused",
        "unreachable",
        "error sending request",
        "request",
        "dns",
        "resolve",
        "decoding",
        "decode",
    ];

    let looks_transient = |message: String| {
        let lowered = message.to_lowercase();
        TRANSIENT_MARKERS
            .iter()
            .any(|marker| lowered.contains(*marker))
    };

    match e {
        ImportError::Reader(reader_err) => looks_transient(reader_err.to_string()),
        ImportError::Io(io_err) => {
            looks_transient(io_err.to_string())
                || matches!(
                    io_err.kind(),
                    std::io::ErrorKind::TimedOut
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::ConnectionRefused
                        | std::io::ErrorKind::ConnectionAborted
                        | std::io::ErrorKind::NotConnected
                )
        }
        _ => false,
    }
}
/// Returns the retry hint carried by a rate-limited reader error.
///
/// Yields `None` for every other error, and also when the rate-limit error
/// itself carries no `retry_after` value.
#[cfg(feature = "google-books")]
fn extract_retry_after(error: &ImportError) -> Option<RetryAfter> {
    if let ImportError::Reader(ReaderError::RateLimited { retry_after, .. }) = error {
        retry_after.clone()
    } else {
        None
    }
}
/// Outcome of storing a single n-gram through one of the shared store helpers
/// in this file.
#[derive(Debug, Clone, Copy)]
pub struct NgramStorageResult {
/// Whether the n-gram was new: `store_ngram_shared` takes this from the
/// storage's return value, while `store_ngram_shared_legacy` sets it when the
/// key had no value in the trie before the increment.
pub is_new: bool,
}
/// Batch size (10,000) for counters.
// NOTE(review): not referenced in this portion of the file; presumably the
// number of n-grams a worker accumulates before publishing counter/progress
// updates — confirm against the worker code.
pub const COUNTER_BATCH_SIZE: u64 = 10_000;
/// Stores `ngram` with `count` in the shared storage.
///
/// The returned [`NgramStorageResult`] carries the `is_new` flag reported by
/// `NgramStorage::store_ngram`; storage failures are converted into
/// `ImportError` by `?`.
fn store_ngram_shared(
    ngram: &str,
    count: u64,
    storage: &Arc<NgramStorage>,
) -> Result<NgramStorageResult, ImportError> {
    Ok(NgramStorageResult {
        is_new: storage.store_ngram(ngram, count)?,
    })
}
/// Legacy store path that writes directly into a lock-protected trie.
///
/// Takes the write lock, records whether the key already had a value, then
/// increments the stored count by `count`. The lock is held for the whole
/// call, so the existence check and the increment happen under the same guard.
///
/// # Errors
/// Returns `ImportError::Trie` when the trie rejects the increment.
#[allow(dead_code)]
fn store_ngram_shared_legacy(
    ngram: &str,
    count: u64,
    trie: &Arc<RwLock<PersistentARTrie<u64>>>,
) -> Result<NgramStorageResult, ImportError> {
    let key = ngram.as_bytes();
    let mut guard = trie.write();
    // Must be sampled before the increment creates the entry.
    let is_new = guard.get_value_bytes(key).is_none();
    match guard.increment_bytes(key, count as i64) {
        Ok(_) => Ok(NgramStorageResult { is_new }),
        Err(e) => Err(ImportError::Trie(format!(
            "Failed to store ngram '{}': {}",
            ngram, e
        ))),
    }
}
/// Error type of the `TrieCheckpointStorage` implementation for
/// `PersistentARTrie<u64>` in this file.
#[derive(Debug, thiserror::Error)]
pub enum TrieCheckpointError {
/// An underlying trie operation failed; the payload is the stringified
/// trie error.
#[error("Trie operation failed: {0}")]
TrieError(String),
}
/// Lets a `PersistentARTrie<u64>` serve as checkpoint storage by mapping each
/// checkpoint operation onto the corresponding trie call.
impl TrieCheckpointStorage for PersistentARTrie<u64> {
    type Error = TrieCheckpointError;

    /// Inserts or overwrites the value stored under `key`.
    fn store_checkpoint_u64(&mut self, key: &str, value: u64) -> Result<(), Self::Error> {
        match self.upsert_bytes(key.as_bytes(), value) {
            Ok(_) => Ok(()),
            Err(err) => Err(TrieCheckpointError::TrieError(err.to_string())),
        }
    }

    /// Looks up the value stored under `key`, if any.
    fn load_checkpoint_u64(&self, key: &str) -> Result<Option<u64>, Self::Error> {
        let stored = self.get_value_bytes(key.as_bytes());
        Ok(stored)
    }

    /// Removes `key`, returning whatever the trie's `remove` reports.
    fn delete_checkpoint_key(&mut self, key: &str) -> Result<bool, Self::Error> {
        let removed = self.remove(key);
        Ok(removed)
    }

    /// Removes every key under `prefix`, returning the trie's removal count.
    fn delete_checkpoint_prefix(&mut self, prefix: &str) -> Result<usize, Self::Error> {
        let removed = self.remove_prefix(prefix.as_bytes());
        Ok(removed)
    }

    /// Collects every `(key, value)` pair under `prefix`.
    ///
    /// Keys are converted with lossy UTF-8 decoding. When the trie yields no
    /// iterator for the prefix, the result is an empty vector.
    fn iter_checkpoint_prefix(&self, prefix: &str) -> Result<Vec<(String, u64)>, Self::Error> {
        let entries: Vec<(String, u64)> = match self.iter_prefix_with_values(prefix.as_bytes()) {
            None => Vec::new(),
            Some(found) => found
                .map(|(key, value)| (String::from_utf8_lossy(&key).into_owned(), value))
                .collect(),
        };
        Ok(entries)
    }
}
/// Progress snapshot handed to progress callbacks (for example the
/// `FnMut(ImportProgress)` accepted by `run_import_with_shutdown`).
// NOTE(review): this struct is not constructed in this portion of the file;
// the field notes below are inferred from names and types — confirm against
// the code that builds it.
#[derive(Clone, Debug)]
pub struct ImportProgress {
/// N-gram order currently being imported.
pub current_order: u8,
/// File prefix currently being imported.
pub current_prefix: String,
/// N-gram count for the current file.
pub ngrams_in_file: u64,
/// N-gram count across the whole import.
pub total_ngrams: u64,
/// Number of files finished so far.
pub files_completed: u32,
/// Number of files the import expects to process.
pub total_files: u32,
/// Bytes downloaded so far.
pub bytes_downloaded: u64,
/// Throughput in n-grams per second.
pub ngrams_per_second: f64,
/// Estimated seconds remaining; `None` when no estimate is available.
pub eta_seconds: Option<u64>,
/// Phase the import is currently in.
pub phase: ImportPhase,
}
/// Phase of an import, as reported through `ImportProgress::phase`.
// NOTE(review): the transitions between phases are not visible in this part
// of the file; `MknPass1`/`MknPass2` presumably correspond to the two passes
// implemented in the `mkn` submodule — confirm.
#[derive(Clone, Debug, PartialEq)]
pub enum ImportPhase {
/// N-gram files are being imported.
Importing,
/// First MKN pass.
MknPass1,
/// Second MKN pass.
MknPass2,
/// Final steps before completion.
Finalizing,
/// The import has finished.
Complete,
}
/// Event describing what an import worker is doing; every variant identifies
/// the worker by `worker_id`.
// NOTE(review): producers and consumers of these events are outside this
// portion of the file; variant notes below are inferred from names and field
// types — confirm against the worker pool.
// NOTE(review): `attempt` is `u8` in `Started` but `u32` in `Retrying` and
// `Deferred`; changing either would alter the public interface, so it is only
// flagged here.
#[derive(Clone, Debug)]
pub enum WorkerUpdate {
/// A worker began processing the file for `order`/`prefix`.
Started {
worker_id: usize,
order: u8,
prefix: Arc<str>,
attempt: u8,
},
/// A worker finished a file, with the n-gram count and time taken.
Finished {
worker_id: usize,
order: u8,
prefix: Arc<str>,
ngram_count: u64,
duration: Duration,
},
/// Intermediate n-gram count reported by a worker.
NgramProgress {
worker_id: usize,
ngram_count: u64,
},
/// A worker is retrying a file after `error`.
Retrying {
worker_id: usize,
order: u8,
prefix: Arc<str>,
attempt: u32,
error: Arc<str>,
},
/// A file was postponed for `delay_seconds` after `error`.
Deferred {
worker_id: usize,
order: u8,
prefix: Arc<str>,
attempt: u32,
delay_seconds: u64,
error: Arc<str>,
},
/// A worker has exited.
Exited {
worker_id: usize,
},
}
/// Summary statistics for an import, as assembled by
/// `GoogleBooksImporter::build_stats`.
#[derive(Clone, Debug, Default)]
pub struct ImportStats {
/// Value of the importer's running n-gram counter.
pub total_ngrams: u64,
/// Per-order n-gram counts copied from the checkpoint statistics.
// NOTE(review): five slots — presumably one per order 1..=5; confirm how the
// checkpoint indexes this array.
pub ngrams_by_order: [u64; 5],
/// Value of the importer's unique n-gram counter.
pub unique_ngrams: u64,
/// Bytes downloaded, copied from the checkpoint statistics.
pub bytes_downloaded: u64,
/// Files processed, copied from the checkpoint statistics.
pub files_processed: u32,
/// Whole seconds elapsed since the importer was constructed.
pub elapsed_seconds: u64,
/// `total_ngrams / elapsed_seconds`, or `0.0` when no whole second has
/// elapsed.
pub ngrams_per_second: f64,
}
/// Errors produced by the importer.
///
/// Variants marked `#[from]` are created automatically by `?` from the
/// wrapped error type.
#[derive(Debug, thiserror::Error)]
pub enum ImportError {
/// Invalid configuration, described by the message.
#[error("Configuration error: {0}")]
Config(String),
/// The configured language is rejected by `is_supported`
/// (see `GoogleBooksImporter::new`).
#[error("Unsupported language: {0}")]
UnsupportedLanguage(String),
/// Failure reported by the n-gram file reader.
#[error("Reader error: {0}")]
Reader(#[from] ReaderError),
/// Failure loading or saving a checkpoint.
#[error("Checkpoint error: {0}")]
Checkpoint(#[from] CheckpointError),
/// Plain I/O failure.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// The import was interrupted.
#[error("Import interrupted (checkpoint saved)")]
Interrupted,
/// Trie or storage operation failure carried as a formatted message; this
/// file also uses it for vocabulary and storage setup failures.
#[error("Trie error: {0}")]
Trie(String),
/// Failure reported by the storage layer.
#[error("Storage error: {0}")]
Storage(#[from] super::storage::StorageError),
}
/// Imports Google Books n-gram data into `NgramStorage`, tracking progress in
/// an `ImportCheckpoint` so an import can be resumed.
pub struct GoogleBooksImporter {
/// Configuration the importer was built with (exposed via `config()`).
config: GoogleBooksConfig,
/// Current checkpoint; starts empty in `new` and is replaced by a loaded
/// checkpoint in `resume_or_start`.
checkpoint: ImportCheckpoint,
/// `config.output_path` with its extension replaced by `checkpoint.json`.
checkpoint_path: PathBuf,
/// Running count of n-grams processed; bumped once per processed file.
total_ngrams: AtomicU64,
/// Count of n-grams that the storage reported as new.
unique_ngrams: AtomicU64,
/// Set by `interrupt()`, read by `is_interrupted()`.
interrupted: AtomicBool,
/// Captured in `new`; basis for rate and ETA calculations.
start_time: Instant,
/// Shared n-gram storage backend.
storage: Arc<NgramStorage>,
/// Flush threshold in entries per shard; defaults to 50,000 when
/// `parallel_downloads >= 8`, otherwise 100,000.
lockfree_flush_threshold: u64,
}
impl GoogleBooksImporter {
/// Builds an importer for `config` with an empty checkpoint.
///
/// Opens (or creates) the vocabulary sized by the vocabulary estimate, then
/// opens the n-gram storage sized by the n-gram estimate. The flush threshold
/// defaults to 50,000 when at least 8 parallel downloads are configured and
/// to 100,000 otherwise.
///
/// # Errors
/// * `ImportError::UnsupportedLanguage` when the language is not supported.
/// * `ImportError::Trie` when the vocabulary or the storage cannot be opened.
pub fn new(config: GoogleBooksConfig) -> Result<Self, ImportError> {
    if !is_supported(&config.language) {
        return Err(ImportError::UnsupportedLanguage(config.language.clone()));
    }
    let checkpoint_path = config.output_path.with_extension("checkpoint.json");

    let ngram_capacity = estimate_ngram_count(&config);
    log::info!("Estimated n-gram count: {}", ngram_capacity);
    let vocab_capacity = estimate_vocabulary_size(&config);
    log::info!("Estimated vocabulary size: {}", vocab_capacity);

    let vocab_location = config.vocabulary_path();
    log::info!("Using vocabulary at {:?}", vocab_location);
    let vocabulary = match open_or_create_concurrent_vocabulary_lockfree_with_capacity(
        &vocab_location,
        vocab_capacity,
    ) {
        Ok(vocabulary) => vocabulary,
        Err(e) => {
            return Err(ImportError::Trie(format!(
                "Failed to create/open vocabulary: {}",
                e
            )))
        }
    };

    let storage = match NgramStorage::resume_or_start_with_vocabulary(
        &config,
        ngram_capacity,
        Some(vocabulary),
    ) {
        Ok(storage) => storage,
        Err(e) => {
            return Err(ImportError::Trie(format!(
                "Failed to create storage: {}",
                e
            )))
        }
    };
    match storage.is_sharded() {
        true => log::info!("Using sharded storage with vocabulary-indexed encoding"),
        false => log::info!("Using single-trie storage with vocabulary-indexed encoding"),
    }

    // Many parallel downloads get a smaller per-shard flush threshold.
    let heavy_parallelism = config.parallel_downloads >= 8;
    let lockfree_flush_threshold = if heavy_parallelism { 50_000 } else { 100_000 };

    Ok(Self {
        config,
        checkpoint: ImportCheckpoint::new(),
        checkpoint_path,
        total_ngrams: AtomicU64::new(0),
        unique_ngrams: AtomicU64::new(0),
        interrupted: AtomicBool::new(false),
        start_time: Instant::now(),
        storage: Arc::new(storage),
        lockfree_flush_threshold,
    })
}
/// Overrides the flush threshold (entries per shard) chosen in `new` and logs
/// the new value.
pub fn set_lockfree_flush_threshold(&mut self, threshold: u64) {
self.lockfree_flush_threshold = threshold;
log::info!(
"Lock-free flush threshold set to {} entries per shard",
threshold
);
}
/// Returns the current flush threshold (entries per shard).
pub fn lockfree_flush_threshold(&self) -> u64 {
self.lockfree_flush_threshold
}
pub fn resume_or_start(config: GoogleBooksConfig) -> Result<Self, ImportError> {
let checkpoint_path = config.output_path.with_extension("checkpoint.json");
let vocabulary_path = config.vocabulary_path();
Self::check_vocabulary_wal_consistency(&vocabulary_path, &checkpoint_path);
let mut importer = Self::new(config)?;
let trie_checkpoint = importer.storage.load_import_checkpoint()?;
if let Some(checkpoint) = trie_checkpoint {
log::info!(
"Resuming from trie checkpoint: {} orders in progress, {} total prefixes completed",
checkpoint.orders_in_progress().len(),
checkpoint.total_completed_prefix_count()
);
importer.checkpoint = checkpoint;
for order in importer.config.orders.clone() {
let in_progress = importer.checkpoint.in_progress_prefixes(order);
if !in_progress.is_empty() {
log::warn!(
"Order {}: recovering {} in-progress prefixes as failed for retry: {:?}",
order,
in_progress.len(),
in_progress
);
importer.checkpoint.recover_in_progress_as_failed(order);
}
}
if let Some(coordinator) = importer.storage.as_sharded() {
let mut reconciled_count = 0usize;
for order in importer.config.orders.clone() {
let shard_completed = coordinator.completed_prefixes_for_order(order);
let importer_completed: Vec<String> = importer
.checkpoint
.order_progress
.get(&order)
.map(|p| p.completed_prefixes().cloned().collect())
.unwrap_or_default();
for prefix in importer_completed {
if !shard_completed.contains(&prefix) {
log::warn!(
"Order {}: prefix '{}' marked complete in importer checkpoint but \
not found in shard state - marking for retry",
order,
prefix
);
importer.checkpoint.fail_prefix(order, &prefix);
reconciled_count += 1;
}
}
}
if reconciled_count > 0 {
log::warn!(
"Reconciliation: {} prefixes marked for retry due to missing shard data",
reconciled_count
);
}
}
importer.total_ngrams.store(
importer.checkpoint.stats.ngrams_processed,
Ordering::Relaxed,
);
importer
.unique_ngrams
.store(importer.checkpoint.stats.unique_ngrams, Ordering::Relaxed);
if ImportCheckpoint::exists(&checkpoint_path) {
if let Err(e) = ImportCheckpoint::delete(&checkpoint_path) {
log::warn!("Failed to delete legacy JSON checkpoint: {}", e);
} else {
log::info!("Deleted legacy JSON checkpoint (migrated to trie)");
}
}
return Ok(importer);
}
if ImportCheckpoint::exists(&checkpoint_path) {
let checkpoint = ImportCheckpoint::load(&checkpoint_path)?;
log::info!(
"Resuming from JSON checkpoint: {} orders in progress, {} total prefixes completed",
checkpoint.orders_in_progress().len(),
checkpoint.total_completed_prefix_count()
);
importer.checkpoint = checkpoint;
for order in importer.config.orders.clone() {
let in_progress = importer.checkpoint.in_progress_prefixes(order);
if !in_progress.is_empty() {
log::warn!(
"Order {}: recovering {} in-progress prefixes as failed for retry: {:?}",
order,
in_progress.len(),
in_progress
);
importer.checkpoint.recover_in_progress_as_failed(order);
}
}
if let Some(coordinator) = importer.storage.as_sharded() {
let mut reconciled_count = 0usize;
for order in importer.config.orders.clone() {
let shard_completed = coordinator.completed_prefixes_for_order(order);
let importer_completed: Vec<String> = importer
.checkpoint
.order_progress
.get(&order)
.map(|p| p.completed_prefixes().cloned().collect())
.unwrap_or_default();
for prefix in importer_completed {
if !shard_completed.contains(&prefix) {
log::warn!(
"Order {}: prefix '{}' marked complete in importer checkpoint but \
not found in shard state - marking for retry",
order,
prefix
);
importer.checkpoint.fail_prefix(order, &prefix);
reconciled_count += 1;
}
}
}
if reconciled_count > 0 {
log::warn!(
"Reconciliation: {} prefixes marked for retry due to missing shard data",
reconciled_count
);
}
}
importer.total_ngrams.store(
importer.checkpoint.stats.ngrams_processed,
Ordering::Relaxed,
);
importer
.unique_ngrams
.store(importer.checkpoint.stats.unique_ngrams, Ordering::Relaxed);
log::info!("Migrating JSON checkpoint to trie-based storage...");
importer
.storage
.save_import_checkpoint_async(&importer.checkpoint)
.map_err(|e| {
ImportError::Trie(format!("Failed to migrate checkpoint to trie: {}", e))
})?;
return Ok(importer);
}
Ok(importer)
}
/// Warns when a vocabulary WAL looks too large for a resumable import.
///
/// Does nothing unless a checkpoint exists, either the JSON file at
/// `checkpoint_path` or that path with its extension replaced by
/// `checkpoint.artrie`. Otherwise two candidate WAL paths next to the
/// vocabulary (`vocab.wal` and `wal` extensions) are inspected, and a group of
/// warnings is logged for each one larger than 1,000,000 bytes. Nothing is
/// modified or returned.
fn check_vocabulary_wal_consistency(vocabulary_path: &Path, checkpoint_path: &Path) {
    const WARNING_THRESHOLD: u64 = 1_000_000;

    let trie_checkpoint = checkpoint_path.with_extension("checkpoint.artrie");
    if !checkpoint_path.exists() && !trie_checkpoint.exists() {
        return;
    }

    let candidates = [
        vocabulary_path.with_extension("vocab.wal"),
        vocabulary_path.with_extension("wal"),
    ];
    for wal_path in &candidates {
        if !wal_path.exists() {
            continue;
        }
        let size = match std::fs::metadata(wal_path) {
            Ok(metadata) => metadata.len(),
            Err(_) => continue,
        };
        if size <= WARNING_THRESHOLD {
            continue;
        }
        log::warn!(
            "VOCABULARY WAL INCONSISTENCY DETECTED: {} is {} bytes",
            wal_path.display(),
            size
        );
        log::warn!("This indicates a previous checkpoint did not properly flush the vocabulary.");
        log::warn!("Resume may result in index inconsistency and duplicated n-gram counts.");
        log::warn!("Consider starting a fresh import or manually checkpointing the vocabulary.");
    }
}
/// Sets the interrupt flag (release ordering) so that `is_interrupted`
/// subsequently returns `true`.
pub fn interrupt(&self) {
self.interrupted.store(true, Ordering::Release);
}
/// Reports whether `interrupt` has been called (acquire load of the flag).
pub fn is_interrupted(&self) -> bool {
self.interrupted.load(Ordering::Acquire)
}
/// Returns the file prefixes to import for `order`.
///
/// Without a configured prefix filter this is every prefix for the order.
/// With a filter it is that single prefix when it is valid for the order, and
/// an empty list otherwise.
fn get_filtered_prefixes(&self, order: u8) -> Vec<String> {
    let available = get_prefixes(order);
    match self.config.prefix.as_ref() {
        None => available,
        Some(wanted) if available.contains(wanted) => vec![wanted.clone()],
        Some(_) => Vec::new(),
    }
}
/// Imports one local n-gram file and returns how many aggregated n-grams it
/// produced.
///
/// Non-sharded storage is delegated to `process_file_with_transaction`. For
/// sharded storage each record is fed through a `YearAggregator`; every
/// aggregate it emits (including the final one from `flush`) is stored, and
/// the running total is bumped once at the end.
///
/// # Errors
/// Propagates reader errors and storage errors; the total counter is not
/// updated when an error occurs.
fn process_file(&mut self, path: &Path) -> Result<u64, ImportError> {
    if !self.storage.is_sharded() {
        return self.process_file_with_transaction(path);
    }

    let reader = FileNgramReader::open_with_options(
        path,
        self.config.skip_pos_tags,
        self.config.min_count,
    )?;
    let mut aggregator = YearAggregator::new(self.config.year_range);
    let mut stored = 0u64;

    for entry in reader {
        match aggregator.push(entry?) {
            Some(done) => {
                self.store_ngram(&done.ngram, done.total_count)?;
                stored += 1;
            }
            None => continue,
        }
    }
    // The aggregator may still hold the last n-gram of the file.
    if let Some(last) = aggregator.flush() {
        self.store_ngram(&last.ngram, last.total_count)?;
        stored += 1;
    }

    self.total_ngrams.fetch_add(stored, Ordering::Relaxed);
    Ok(stored)
}
/// Imports one file inside a storage file transaction.
///
/// The transaction is identified by the file name (`"unknown"` when the path
/// has no UTF-8 file name). On success it is committed and the running total
/// is increased; on failure it is aborted (the abort result is ignored) and
/// the original error is returned.
///
/// # Errors
/// `ImportError::Trie` when the transaction cannot be started or committed;
/// otherwise whatever `process_file_inner` returned.
fn process_file_with_transaction(&self, path: &Path) -> Result<u64, ImportError> {
    let file_id = match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => "unknown",
    };

    let mut tx = match self.storage.begin_file_tx(file_id) {
        Ok(tx) => tx,
        Err(e) => {
            return Err(ImportError::Trie(format!(
                "Failed to begin file tx: {}",
                e
            )))
        }
    };

    let ngrams_in_file = match self.process_file_inner(&mut tx, path) {
        Ok(count) => count,
        Err(e) => {
            // Best-effort rollback; the import error is the one reported.
            let _ = self.storage.abort_file_tx(tx);
            return Err(e);
        }
    };

    if let Err(e) = self.storage.commit_file_tx(tx) {
        return Err(ImportError::Trie(format!(
            "Failed to commit file tx: {}",
            e
        )));
    }
    self.total_ngrams
        .fetch_add(ngrams_in_file, Ordering::Relaxed);
    Ok(ngrams_in_file)
}
/// Reads `path` and applies every aggregated n-gram to the open transaction.
///
/// Records are passed through a `YearAggregator`; each aggregate it emits,
/// plus the final one from `flush`, is added to `tx` as an increment. Returns
/// the number of aggregates applied. Committing or aborting `tx` is left to
/// the caller.
///
/// # Errors
/// Propagates reader errors; increment failures become `ImportError::Trie`.
fn process_file_inner(
    &self,
    tx: &mut super::storage::StorageFileTx,
    path: &Path,
) -> Result<u64, ImportError> {
    let reader = FileNgramReader::open_with_options(
        path,
        self.config.skip_pos_tags,
        self.config.min_count,
    )?;
    let mut aggregator = YearAggregator::new(self.config.year_range);
    let mut applied = 0u64;

    for entry in reader {
        match aggregator.push(entry?) {
            Some(done) => {
                self.storage
                    .tx_increment_ngram(tx, &done.ngram, done.total_count)
                    .map_err(|e| ImportError::Trie(format!("Failed to increment ngram: {}", e)))?;
                applied += 1;
            }
            None => continue,
        }
    }
    // The aggregator may still hold the last n-gram of the file.
    if let Some(last) = aggregator.flush() {
        self.storage
            .tx_increment_ngram(tx, &last.ngram, last.total_count)
            .map_err(|e| ImportError::Trie(format!("Failed to increment ngram: {}", e)))?;
        applied += 1;
    }
    Ok(applied)
}
/// Stores one n-gram and bumps the unique counter when the storage reports it
/// as new.
///
/// # Errors
/// `ImportError::Trie` when the storage rejects the n-gram.
fn store_ngram(&self, ngram: &str, count: u64) -> Result<(), ImportError> {
    match self.storage.store(ngram, count) {
        Ok(true) => {
            self.unique_ngrams.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }
        Ok(false) => Ok(()),
        Err(e) => Err(ImportError::Trie(format!(
            "Failed to store ngram '{}': {}",
            ngram, e
        ))),
    }
}
/// Average throughput in n-grams per second since the importer was created,
/// or `0.0` when no time has elapsed yet.
fn calculate_rate(&self) -> f64 {
    let seconds = self.start_time.elapsed().as_secs_f64();
    if seconds <= 0.0 {
        return 0.0;
    }
    self.total_ngrams.load(Ordering::Relaxed) as f64 / seconds
}
/// Estimates the seconds remaining from the average completion rate so far.
///
/// Returns `None` when nothing has completed yet or when `completed` has
/// reached `total`. Otherwise the remaining item count is divided by
/// `completed / elapsed` and truncated to whole seconds.
fn estimate_eta(&self, completed: u32, total: u32) -> Option<u64> {
    let nothing_to_estimate = completed == 0 || completed >= total;
    if nothing_to_estimate {
        return None;
    }
    let seconds_so_far = self.start_time.elapsed().as_secs_f64();
    let items_per_second = completed as f64 / seconds_so_far;
    let items_left = (total - completed) as f64;
    Some((items_left / items_per_second) as u64)
}
/// Assembles the current import statistics.
///
/// Counters come from the importer's atomics; per-order counts, bytes
/// downloaded and files processed come from the checkpoint statistics. The
/// rate uses whole elapsed seconds and is `0.0` until one second has passed.
/// This function always returns `Ok`.
fn build_stats(&self) -> Result<ImportStats, ImportError> {
    let elapsed_seconds = self.start_time.elapsed().as_secs();
    let total_ngrams = self.total_ngrams.load(Ordering::Relaxed);
    let ngrams_per_second = match elapsed_seconds {
        0 => 0.0,
        seconds => total_ngrams as f64 / seconds as f64,
    };
    let recorded = &self.checkpoint.stats;

    Ok(ImportStats {
        total_ngrams,
        ngrams_by_order: recorded.ngrams_by_order,
        unique_ngrams: self.unique_ngrams.load(Ordering::Relaxed),
        bytes_downloaded: recorded.bytes_downloaded,
        files_processed: recorded.files_processed,
        elapsed_seconds,
        ngrams_per_second,
    })
}
/// Borrows the importer's current checkpoint.
pub fn checkpoint(&self) -> &ImportCheckpoint {
&self.checkpoint
}
/// Borrows the configuration the importer was built with.
pub fn config(&self) -> &GoogleBooksConfig {
&self.config
}
}
impl Drop for GoogleBooksImporter {
fn drop(&mut self) {
if let Err(e) = self.storage.rotate_vocabulary_wal() {
log::error!("Failed to rotate vocabulary WAL on drop: {}", e);
}
}
}
/// Completes when the process receives Ctrl+C or, on Unix, SIGTERM.
///
/// On non-Unix targets only Ctrl+C is observed; the terminate branch never
/// resolves there.
///
/// # Panics
/// Panics if either signal handler cannot be installed.
#[cfg(feature = "google-books")]
pub async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut sigterm =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("Failed to install signal handler");
        sigterm.recv().await;
    };
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Whichever signal arrives first ends the wait.
    tokio::select! {
        _ = interrupt => {},
        _ = terminate => {},
    }
}
/// Runs `import_http` on `importer` while a background task waits for a
/// shutdown signal and tries to flag the importer as interrupted.
///
/// `progress` is forwarded to `import_http`. The signal task is aborted once
/// the import returns, and the import's result is returned unchanged.
#[cfg(feature = "google-books")]
pub async fn run_import_with_shutdown<F>(
importer: GoogleBooksImporter,
progress: F,
) -> Result<ImportStats, ImportError>
where
F: FnMut(ImportProgress) + Send + 'static,
{
// The importer is shared between the import itself and the signal task.
let importer_ref = Arc::new(parking_lot::Mutex::new(importer));
let importer_clone = Arc::clone(&importer_ref);
let shutdown_handle = tokio::spawn(async move {
shutdown_signal().await;
log::warn!("Received shutdown signal, saving checkpoint...");
// NOTE(review): the import below keeps the mutex locked for the whole
// `import_http` await, so this `try_lock` can only succeed if the signal
// arrives before that lock is taken. While the import is running it fails
// and `interrupt()` is never called. Confirm whether interruption is
// delivered another way; a fix likely needs an interrupt flag that is
// reachable without the lock (e.g. a shared `Arc<AtomicBool>`).
// NOTE(review): this task only sets the interrupt flag; the "saving
// checkpoint" in the log message above is not performed here.
if let Some(importer) = importer_clone.try_lock() {
importer.interrupt();
}
});
let result = {
// NOTE(review): a `parking_lot` guard is held across an `.await` here.
let mut importer = importer_ref.lock();
importer.import_http(progress).await
};
// The import is done; stop waiting for signals.
shutdown_handle.abort();
result
}
/// Default checkpoint interval: 5000 milliseconds.
// NOTE(review): not referenced in this portion of the file; presumably the
// default period for `run_import_with_periodic_checkpoints` in the `cron`
// submodule — confirm.
pub const DEFAULT_CHECKPOINT_INTERVAL_MS: u64 = 5000;
#[cfg(feature = "google-books")]
mod checkpoint_ops;
mod cron;
pub use cron::{run_import_with_periodic_checkpoints, CheckpointState};
mod finalize;
#[cfg(feature = "google-books")]
mod import_ops;
#[cfg(feature = "google-books")]
mod mkn;
#[cfg(feature = "google-books")]
mod worker_pool;
#[cfg(feature = "google-books")]
mod cache;
#[cfg(feature = "google-books")]
use cache::{cleanup_cache_file, download_to_cache};
#[cfg(test)]
mod tests;