use super::coordinator::ShardCoordinator;
use super::routing::ShardKey;
use libdictenstein::persistent_artrie::PersistentARTrie;
use rayon::prelude::*;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use thiserror::Error;
use xxhash_rust::xxh3::Xxh3DefaultBuilder;
/// `HashMap` alias whose keys are hashed with `Xxh3DefaultBuilder` (xxHash3) instead of the
/// standard library's default hasher.
type XxHashMap<K, V> = HashMap<K, V, Xxh3DefaultBuilder>;
/// Errors that can be returned by the shard merge operations in this module.
#[derive(Error, Debug)]
pub enum MergeError {
/// An I/O operation failed; converts automatically from `std::io::Error`.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// A shard-level operation failed; converts automatically from `ShardError`.
#[error("Shard error: {0}")]
Shard(#[from] super::shard::ShardError),
/// The shard coordinator reported an error; converts automatically from `CoordinatorError`.
#[error("Coordinator error: {0}")]
Coordinator(#[from] super::coordinator::CoordinatorError),
/// A failure described by a free-form message. In this file it is used for shard
/// discovery, shard iteration, and output-trie create/increment/checkpoint failures.
#[error("Trie error: {0}")]
Trie(String),
/// Shard discovery succeeded but found nothing to merge.
#[error("No shards available for merge")]
NoShards,
/// The merge was cancelled.
// NOTE(review): no code visible in this file constructs this variant — confirm it is
// produced elsewhere or reserved for future use.
#[error("Merge cancelled")]
Cancelled,
/// A shard could not be obtained from the coordinator (`get_or_create_shard` failed).
#[error("Failed to open shard '{shard_key}': {message}")]
ShardOpen {
/// Display form of the key of the shard that failed to open.
shard_key: String,
/// Display form of the underlying error.
message: String,
},
}
/// Result alias used by the merge API; the error type is always [`MergeError`].
pub type MergeResult<T> = Result<T, MergeError>;
/// Snapshot of merge progress handed to the progress callback of `merge_to_trie`.
#[derive(Clone, Debug)]
pub struct MergeProgress {
/// Current phase number; `merge_to_trie` always reports `1`.
pub phase: usize,
/// Total number of phases; `merge_to_trie` always reports `1`.
pub total_phases: usize,
/// Number of shards that have not been fully processed yet.
pub shards_remaining: usize,
/// Total number of shards taking part in the merge.
pub total_shards: usize,
/// Number of shard entries applied to the output so far (entries, not summed counts).
pub ngrams_merged: u64,
/// Completion percentage from `0.0` to `100.0`, derived from the shards processed.
pub percent_complete: f32,
}
/// Summary returned by `merge_to_trie` once a merge has completed.
#[derive(Clone, Debug)]
pub struct MergeStats {
/// Number of merge phases that ran; `merge_to_trie` always reports `1`.
pub phases: usize,
/// Number of shard entries applied to the output trie (entries, not summed counts).
pub total_ngrams: u64,
/// Size in bytes reported by the file system for the output path, or `0` when its
/// metadata could not be read.
pub bytes_written: u64,
/// Wall-clock duration of the merge in milliseconds.
pub duration_ms: u64,
/// Number of shards that were merged.
pub shards_merged: usize,
}
/// Merges the shards managed by a [`ShardCoordinator`] into a single trie or in-memory map.
pub struct MergeCoordinator<'a> {
/// Coordinator that owns the shards being merged.
coordinator: &'a ShardCoordinator,
/// Working directory; defaults to `merge_work` inside the coordinator's shard directory.
// NOTE(review): presumably meant for intermediate merge artifacts — no method visible in
// this file reads it.
work_dir: PathBuf,
/// Requested degree of parallelism; the constructor and setter keep it at 1 or more.
parallelism: usize,
/// Whether intermediate artifacts should be cleaned up; defaults to `true`.
// NOTE(review): stored but not consulted by the merge methods visible in this file.
cleanup_intermediates: bool,
}
impl<'a> MergeCoordinator<'a> {
/// Creates a merge coordinator for `coordinator` with default settings.
///
/// Defaults: the working directory is `merge_work` inside the coordinator's shard
/// directory, parallelism is the platform's available parallelism (or 4 when that cannot
/// be determined), and cleanup of intermediates is enabled.
pub fn new(coordinator: &'a ShardCoordinator) -> Self {
    // Fall back to four workers when the platform cannot report its parallelism.
    let parallelism = match std::thread::available_parallelism() {
        Ok(available) => available.get(),
        Err(_) => 4,
    };
    let work_dir = coordinator.config().shard_dir.join("merge_work");

    Self {
        coordinator,
        work_dir,
        parallelism,
        cleanup_intermediates: true,
    }
}
pub fn with_work_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.work_dir = dir.into();
self
}
pub fn with_parallelism(mut self, parallelism: usize) -> Self {
self.parallelism = parallelism.max(1);
self
}
pub fn with_cleanup(mut self, cleanup: bool) -> Self {
self.cleanup_intermediates = cleanup;
self
}
/// Merges every discovered shard into a new persistent trie created at `output_path`.
///
/// Shards are processed sequentially: each shard is obtained from the coordinator,
/// read-locked, and every `(ngram, count)` entry is added to the output trie via
/// `increment_bytes`. The trie is checkpointed once all shards have been applied.
///
/// `progress_callback` is invoked once before each shard is processed and one final time
/// (with `percent_complete == 100.0`) after the checkpoint.
///
/// # Errors
///
/// * [`MergeError::NoShards`] when shard discovery finds nothing to merge.
/// * [`MergeError::ShardOpen`] when a shard cannot be obtained from the coordinator.
/// * [`MergeError::Trie`] when discovery, trie creation, shard iteration, an increment, or
///   the checkpoint fails, or when a shard count does not fit in the signed increment
///   accepted by the output trie.
pub fn merge_to_trie<F>(
    &self,
    output_path: impl AsRef<Path>,
    mut progress_callback: F,
) -> MergeResult<MergeStats>
where
    F: FnMut(MergeProgress) + Send,
{
    let output_path = output_path.as_ref();
    let start_time = std::time::Instant::now();

    let shard_files = self
        .coordinator
        .discover_shard_files()
        .map_err(|e| MergeError::Trie(format!("Failed to discover shard files: {}", e)))?;
    if shard_files.is_empty() {
        return Err(MergeError::NoShards);
    }
    let shard_keys: Vec<ShardKey> = shard_files.into_iter().map(|(key, _)| key).collect();
    let total_shards = shard_keys.len();
    log::info!(
        "Starting merge of {} shards to {:?}",
        total_shards,
        output_path
    );

    let mut output_trie = PersistentARTrie::<u64>::create(output_path)
        .map_err(|e| MergeError::Trie(format!("Failed to create output trie: {}", e)))?;

    let mut ngrams_merged = 0u64;
    for (shards_processed, key) in shard_keys.iter().enumerate() {
        // Report progress before touching the shard, so the first update is always 0%.
        progress_callback(MergeProgress {
            phase: 1,
            total_phases: 1,
            shards_remaining: total_shards - shards_processed,
            total_shards,
            ngrams_merged,
            percent_complete: (shards_processed as f32 / total_shards as f32) * 100.0,
        });
        log::trace!(
            "Merging shard '{}' ({}/{})",
            key,
            shards_processed + 1,
            total_shards
        );

        let shard =
            self.coordinator
                .get_or_create_shard(key)
                .map_err(|e| MergeError::ShardOpen {
                    shard_key: key.to_string(),
                    message: e.to_string(),
                })?;
        let guard = shard.read();
        let iter = guard
            .iter_with_counts()
            .map_err(|e| MergeError::Trie(format!("Shard {} iteration failed: {}", key, e)))?;

        for (ngram, count) in iter {
            // A plain `as i64` cast would wrap counts above `i64::MAX` into a negative
            // increment and silently corrupt the merged total, so convert with a check.
            // The fully qualified path keeps this independent of the crate's edition.
            let delta = <i64 as std::convert::TryFrom<u64>>::try_from(count).map_err(|_| {
                MergeError::Trie(format!(
                    "Count {} in shard {} exceeds the maximum supported increment {}",
                    count,
                    key,
                    i64::MAX
                ))
            })?;
            output_trie
                .increment_bytes(&ngram, delta)
                .map_err(|e| MergeError::Trie(format!("Increment failed: {}", e)))?;
            ngrams_merged += 1;
        }
    }

    output_trie
        .checkpoint()
        .map_err(|e| MergeError::Trie(format!("Checkpoint failed: {}", e)))?;

    progress_callback(MergeProgress {
        phase: 1,
        total_phases: 1,
        shards_remaining: 0,
        total_shards,
        ngrams_merged,
        percent_complete: 100.0,
    });

    let duration = start_time.elapsed();
    // Best effort: the size is informational, so a metadata failure is reported as 0.
    let bytes_written = std::fs::metadata(output_path).map(|m| m.len()).unwrap_or(0);

    Ok(MergeStats {
        phases: 1,
        total_ngrams: ngrams_merged,
        bytes_written,
        // Milliseconds only exceed `u64` after hundreds of millions of years.
        duration_ms: duration.as_millis() as u64,
        shards_merged: total_shards,
    })
}
pub fn merge_to_memory(&self) -> MergeResult<XxHashMap<Vec<u8>, u64>> {
let shard_files = self
.coordinator
.discover_shard_files()
.map_err(|e| MergeError::Trie(format!("Failed to discover shard files: {}", e)))?;
if shard_files.is_empty() {
return Err(MergeError::NoShards);
}
let shard_keys: Vec<ShardKey> = shard_files.into_iter().map(|(key, _)| key).collect();
log::info!("Merging {} shards to memory", shard_keys.len());
let results: Result<Vec<XxHashMap<Vec<u8>, u64>>, MergeError> = shard_keys
.par_iter()
.map(|key| {
log::trace!("Merging shard '{}' to memory", key);
let shard = self.coordinator.get_or_create_shard(key).map_err(|e| {
MergeError::ShardOpen {
shard_key: key.to_string(),
message: e.to_string(),
}
})?;
let guard = shard.read();
let iter = guard.iter_with_counts().map_err(|e| {
MergeError::Trie(format!("Shard {} iteration failed: {}", key, e))
})?;
Ok(iter.into_iter().collect::<XxHashMap<_, _>>())
})
.collect();
let results = results?;
let mut merged: XxHashMap<Vec<u8>, u64> = HashMap::with_hasher(Xxh3DefaultBuilder);
for partial in results {
for (ngram, count) in partial {
*merged.entry(ngram).or_default() += count;
}
}
log::info!("Merged {} unique n-grams", merged.len());
Ok(merged)
}
/// Returns an iterator over every `(ngram, count)` entry of every discovered shard.
///
/// The entries are collected eagerly: all shards are read into one vector before this
/// method returns, and the iterator walks that vector. Entries are not deduplicated or
/// summed across shards. When no shards are discovered the iterator is empty.
///
/// # Errors
///
/// * [`MergeError::ShardOpen`] when a shard cannot be obtained from the coordinator.
/// * [`MergeError::Trie`] when discovery or shard iteration fails.
pub fn iter_all(&self) -> MergeResult<impl Iterator<Item = (Vec<u8>, u64)>> {
    let discovered = self
        .coordinator
        .discover_shard_files()
        .map_err(|e| MergeError::Trie(format!("Failed to discover shard files: {}", e)))?;

    let mut collected = Vec::new();
    for (shard_key, _) in discovered {
        log::trace!("Iterating shard '{}'", shard_key);

        let shard = match self.coordinator.get_or_create_shard(&shard_key) {
            Ok(shard) => shard,
            Err(e) => {
                return Err(MergeError::ShardOpen {
                    shard_key: shard_key.to_string(),
                    message: e.to_string(),
                })
            }
        };

        // Hold the read lock only while this shard's entries are copied out.
        let guard = shard.read();
        let entries = guard.iter_with_counts().map_err(|e| {
            MergeError::Trie(format!("Shard {} iteration failed: {}", shard_key, e))
        })?;
        collected.extend(entries);
    }

    Ok(collected.into_iter())
}
/// Returns the coordinator's total entry count, used as the size estimate for a merge.
///
/// Note that the value is an entry count, not a byte size.
pub fn estimated_final_size(&self) -> u64 {
self.coordinator.total_entry_count()
}
}
/// Builder that collects optional settings and produces a [`MergeCoordinator`].
pub struct MergeBuilder<'a> {
/// Coordinator whose shards the built merger will operate on.
coordinator: &'a ShardCoordinator,
/// Working-directory override; `None` keeps the `MergeCoordinator::new` default.
work_dir: Option<PathBuf>,
/// Parallelism override; `None` keeps the `MergeCoordinator::new` default.
parallelism: Option<usize>,
/// Value passed to `MergeCoordinator::with_cleanup`; defaults to `true`.
cleanup: bool,
}
impl<'a> MergeBuilder<'a> {
    /// Starts a builder for `coordinator` with no overrides and cleanup enabled.
    pub fn new(coordinator: &'a ShardCoordinator) -> Self {
        Self {
            coordinator,
            cleanup: true,
            parallelism: None,
            work_dir: None,
        }
    }

    /// Overrides the working directory of the merger to be built.
    pub fn work_dir(self, dir: impl Into<PathBuf>) -> Self {
        Self {
            work_dir: Some(dir.into()),
            ..self
        }
    }

    /// Overrides the parallelism of the merger to be built.
    ///
    /// The value is stored as given; `build` passes it to
    /// `MergeCoordinator::with_parallelism`, which raises `0` to `1`.
    pub fn parallelism(self, n: usize) -> Self {
        Self {
            parallelism: Some(n),
            ..self
        }
    }

    /// Sets whether the merger to be built cleans up intermediates.
    pub fn cleanup(self, cleanup: bool) -> Self {
        Self { cleanup, ..self }
    }

    /// Builds the [`MergeCoordinator`], applying only the overrides that were set.
    pub fn build(self) -> MergeCoordinator<'a> {
        let Self {
            coordinator,
            work_dir,
            parallelism,
            cleanup,
        } = self;

        let merger = MergeCoordinator::new(coordinator);
        let merger = match work_dir {
            Some(dir) => merger.with_work_dir(dir),
            None => merger,
        };
        let merger = match parallelism {
            Some(workers) => merger.with_parallelism(workers),
            None => merger,
        };
        merger.with_cleanup(cleanup)
    }
}
#[cfg(test)]
mod tests {
    use super::super::config::{ShardConfig, ShardGranularity};
    use super::*;
    use tempfile::TempDir;

    /// Builds a coordinator backed by a temporary directory and stores seven n-grams in it.
    ///
    /// The `TempDir` is returned so the directory outlives the coordinator in each test.
    /// `test_merge_to_trie` expects these n-grams to occupy five shards under two-character
    /// granularity.
    fn create_test_coordinator() -> (TempDir, ShardCoordinator) {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let config =
            ShardConfig::new(dir.path().join("shards")).with_granularity(ShardGranularity::TwoChar);
        let coordinator = ShardCoordinator::new(config).expect("Failed to create coordinator");

        coordinator.store_ngram("the|quick", 100).expect("store");
        coordinator.store_ngram("the|slow", 50).expect("store");
        coordinator.store_ngram("apple|pie", 30).expect("store");
        coordinator.store_ngram("apple|cider", 20).expect("store");
        coordinator.store_ngram("banana|split", 15).expect("store");
        coordinator
            .store_ngram("cherry|blossom", 10)
            .expect("store");
        coordinator.store_ngram("zebra|crossing", 5).expect("store");

        (dir, coordinator)
    }

    /// The in-memory merge contains every stored n-gram with its stored count.
    #[test]
    fn test_merge_to_memory() {
        let (_dir, coordinator) = create_test_coordinator();
        let merger = MergeCoordinator::new(&coordinator);

        let merged = merger.merge_to_memory().expect("merge");

        assert_eq!(merged.len(), 7);
        assert_eq!(merged.get(b"the|quick".as_slice()), Some(&100));
        assert_eq!(merged.get(b"apple|pie".as_slice()), Some(&30));
        assert_eq!(merged.get(b"zebra|crossing".as_slice()), Some(&5));
    }

    /// `iter_all` yields one entry per stored n-gram.
    #[test]
    fn test_iter_all() {
        let (_dir, coordinator) = create_test_coordinator();
        let merger = MergeCoordinator::new(&coordinator);

        let all: Vec<_> = merger.iter_all().expect("iter_all").collect();

        assert_eq!(all.len(), 7);
    }

    /// The size estimate equals the number of stored entries.
    #[test]
    fn test_estimated_size() {
        let (_dir, coordinator) = create_test_coordinator();
        let merger = MergeCoordinator::new(&coordinator);

        let size = merger.estimated_final_size();

        assert_eq!(size, 7);
    }

    /// Merging to a trie writes the output file and preserves the stored counts.
    #[test]
    fn test_merge_to_trie() {
        let (dir, coordinator) = create_test_coordinator();
        let merger =
            MergeCoordinator::new(&coordinator).with_work_dir(dir.path().join("merge_work"));
        let output_path = dir.path().join("merged.artrie");

        let stats = merger
            .merge_to_trie(&output_path, |_progress| {})
            .expect("merge");

        assert!(output_path.exists());
        assert_eq!(stats.shards_merged, 5);
        assert!(stats.total_ngrams > 0);

        let merged_trie = PersistentARTrie::<u64>::open(&output_path).expect("open");
        assert_eq!(merged_trie.get_value_bytes(b"the|quick"), Some(100));
        assert_eq!(merged_trie.get_value_bytes(b"apple|pie"), Some(30));
    }

    /// Progress starts at 0%, never goes backwards, and finishes at 100% with no shards left.
    #[test]
    fn test_merge_progress() {
        let (dir, coordinator) = create_test_coordinator();
        let merger =
            MergeCoordinator::new(&coordinator).with_work_dir(dir.path().join("merge_work"));
        let output_path = dir.path().join("merged.artrie");

        let mut progress_updates = Vec::new();
        let _stats = merger
            .merge_to_trie(&output_path, |progress| {
                // The callback already owns the snapshot, so no clone is needed.
                progress_updates.push(progress);
            })
            .expect("merge");

        assert!(!progress_updates.is_empty());

        let first = progress_updates
            .first()
            .expect("at least one progress update");
        assert_eq!(first.percent_complete, 0.0);

        assert!(progress_updates
            .windows(2)
            .all(|pair| pair[0].percent_complete <= pair[1].percent_complete));

        let last = progress_updates.last().unwrap();
        assert_eq!(last.percent_complete, 100.0);
        assert_eq!(last.shards_remaining, 0);
    }

    /// Builder overrides are carried through to the built coordinator.
    #[test]
    fn test_merge_builder() {
        let (_dir, coordinator) = create_test_coordinator();

        let merger = MergeBuilder::new(&coordinator)
            .parallelism(4)
            .cleanup(false)
            .build();

        assert_eq!(merger.parallelism, 4);
        assert!(!merger.cleanup_intermediates);
    }

    /// Merging a coordinator that holds no shards reports `NoShards`.
    #[test]
    fn test_merge_empty_coordinator() {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let config =
            ShardConfig::new(dir.path().join("shards")).with_granularity(ShardGranularity::TwoChar);
        let coordinator = ShardCoordinator::new(config).expect("create");
        let merger = MergeCoordinator::new(&coordinator);

        let result = merger.merge_to_memory();

        assert!(matches!(result, Err(MergeError::NoShards)));
    }
}