use crate::StateId;
use crate::error::{Result, ZiporaError};
use crate::fsa::traits::PrefixIterable;
use crate::fsa::{ZiporaTrie, Trie};
use rayon::prelude::*;
use std::sync::Arc;
use tokio::sync::Mutex;
pub struct ParallelTrieBuilder {
chunk_size: usize,
max_workers: usize,
}
impl ParallelTrieBuilder {
pub fn new() -> Self {
Self {
chunk_size: 10000,
max_workers: std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1),
}
}
pub fn chunk_size(mut self, chunk_size: usize) -> Self {
self.chunk_size = chunk_size;
self
}
pub fn max_workers(mut self, max_workers: usize) -> Self {
self.max_workers = max_workers;
self
}
pub async fn build_louds_trie<I>(&self, keys: I) -> Result<ParallelLoudsTrie>
where
I: IntoIterator<Item = Vec<u8>>,
I::IntoIter: Send,
{
let keys: Vec<Vec<u8>> = keys.into_iter().collect();
if keys.is_empty() {
return Ok(ParallelLoudsTrie::new());
}
let chunks: Vec<Vec<Vec<u8>>> = keys
.chunks(self.chunk_size)
.map(|chunk| chunk.to_vec())
.collect();
let mut partial_tries = Vec::new();
for chunk in chunks {
let trie = tokio::task::spawn_blocking(move || -> Result<ZiporaTrie> {
let mut trie = ZiporaTrie::new();
for key in chunk {
trie.insert(&key)?;
}
Ok(trie)
})
.await
.map_err(|e| {
ZiporaError::configuration(&format!("parallel build failed: {}", e))
})??;
partial_tries.push(trie);
}
self.merge_tries(partial_tries).await
}
async fn merge_tries(&self, tries: Vec<ZiporaTrie>) -> Result<ParallelLoudsTrie> {
if tries.is_empty() {
return Ok(ParallelLoudsTrie::new());
}
if tries.len() == 1 {
return Ok(ParallelLoudsTrie::from_trie(
tries.into_iter().next().expect("non-empty tries"),
));
}
let mut all_keys = Vec::new();
for trie in tries {
let keys = self.extract_keys_from_trie(&trie).await?;
all_keys.extend(keys);
}
all_keys.sort();
all_keys.dedup();
let final_trie =
tokio::task::spawn_blocking(move || -> Result<ZiporaTrie> {
let mut trie = ZiporaTrie::new();
for key in all_keys {
trie.insert(&key)?;
}
Ok(trie)
})
.await
.map_err(|e| ZiporaError::configuration(&format!("final build failed: {}", e)))??;
Ok(ParallelLoudsTrie::from_trie(final_trie))
}
async fn extract_keys_from_trie(&self, trie: &ZiporaTrie) -> Result<Vec<Vec<u8>>> {
let keys: Vec<Vec<u8>> = trie.iter_all().collect();
Ok(keys)
}
}
impl Default for ParallelTrieBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct ParallelLoudsTrie {
inner: Arc<Mutex<ZiporaTrie>>,
read_replicas: Arc<Mutex<Vec<Arc<ZiporaTrie>>>>,
replica_count: usize,
}
impl ParallelLoudsTrie {
pub fn new() -> Self {
let replica_count = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1);
let empty_trie = ZiporaTrie::new();
let mut read_replicas = Vec::with_capacity(replica_count);
for _ in 0..replica_count {
read_replicas.push(Arc::new(empty_trie.clone()));
}
Self {
inner: Arc::new(Mutex::new(empty_trie)),
read_replicas: Arc::new(Mutex::new(read_replicas)),
replica_count,
}
}
pub fn from_trie(trie: ZiporaTrie) -> Self {
let replica_count = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1);
let mut read_replicas = Vec::with_capacity(replica_count);
for _ in 0..replica_count {
read_replicas.push(Arc::new(trie.clone()));
}
Self {
inner: Arc::new(Mutex::new(trie)),
read_replicas: Arc::new(Mutex::new(read_replicas)),
replica_count,
}
}
pub async fn insert(&self, key: &[u8]) -> Result<StateId> {
let mut trie = self.inner.lock().await;
let result = trie.insert_and_get_node_id(key);
drop(trie);
if result.is_ok() {
self.refresh_replicas().await?;
}
result
}
pub async fn contains(&self, key: &[u8]) -> bool {
let replica_id = self.select_replica();
let replicas = self.read_replicas.lock().await;
let replica = &replicas[replica_id];
replica.contains(key)
}
pub async fn parallel_contains<I>(&self, keys: I) -> Vec<bool>
where
I: IntoIterator<Item = Vec<u8>>,
I::IntoIter: Send,
{
let keys: Vec<Vec<u8>> = keys.into_iter().collect();
let replicas = self.read_replicas.lock().await;
let results: Vec<bool> = keys
.par_iter()
.map(|key| {
let replica_id = self.select_replica();
replicas[replica_id].contains(key)
})
.collect();
results
}
pub async fn parallel_prefix_search<I>(&self, prefixes: I) -> Vec<Vec<Vec<u8>>>
where
I: IntoIterator<Item = Vec<u8>>,
I::IntoIter: Send,
{
let prefixes: Vec<Vec<u8>> = prefixes.into_iter().collect();
let replicas = self.read_replicas.lock().await;
let results: Vec<Vec<Vec<u8>>> = prefixes
.par_iter()
.map(|prefix| {
let replica_id = self.select_replica();
let replica = &replicas[replica_id];
replica.iter_prefix(prefix).collect()
})
.collect();
results
}
fn select_replica(&self) -> usize {
use std::sync::atomic::{AtomicUsize, Ordering};
static COUNTER: AtomicUsize = AtomicUsize::new(0);
COUNTER.fetch_add(1, Ordering::Relaxed) % self.replica_count
}
pub async fn refresh_replicas(&self) -> Result<()> {
let trie = self.inner.lock().await;
let trie_clone = trie.clone();
drop(trie);
let mut replicas = self.read_replicas.lock().await;
replicas.clear();
for _ in 0..self.replica_count {
replicas.push(Arc::new(trie_clone.clone()));
}
Ok(())
}
pub async fn len(&self) -> usize {
let trie = self.inner.lock().await;
trie.len()
}
pub async fn is_empty(&self) -> bool {
self.len().await == 0
}
pub async fn bulk_insert<I>(&self, keys: I) -> Result<Vec<StateId>>
where
I: IntoIterator<Item = Vec<u8>>,
I::IntoIter: Send,
{
let keys: Vec<Vec<u8>> = keys.into_iter().collect();
let mut results = Vec::with_capacity(keys.len());
let mut trie = self.inner.lock().await;
for key in keys {
let state_id = trie.insert_and_get_node_id(&key)?;
results.push(state_id);
}
drop(trie);
self.refresh_replicas().await?;
Ok(results)
}
pub async fn parallel_process<F, T>(&self, operations: Vec<F>) -> Vec<Result<T>>
where
F: Fn(&ZiporaTrie) -> Result<T> + Send + Sync,
T: Send,
{
let replicas = self.read_replicas.lock().await;
let results: Vec<Result<T>> = operations
.par_iter()
.map(|op| {
let replica_id = self.select_replica();
let replica = &replicas[replica_id];
op(replica)
})
.collect();
results
}
}
impl Default for ParallelLoudsTrie {
fn default() -> Self {
Self::new()
}
}
pub struct ParallelTrieOps;
impl ParallelTrieOps {
pub async fn find_common_prefixes(
keys: Vec<Vec<u8>>,
min_support: usize,
) -> Result<Vec<Vec<u8>>> {
if keys.is_empty() {
return Ok(Vec::new());
}
use std::collections::HashMap;
use std::sync::Mutex;
let prefix_counts = Arc::new(Mutex::new(HashMap::<Vec<u8>, usize>::new()));
keys.par_iter().for_each(|key| {
for i in 1..=key.len() {
let prefix = key[..i].to_vec();
let mut counts = prefix_counts.lock().unwrap_or_else(|e| e.into_inner());
*counts.entry(prefix).or_insert(0) += 1;
}
});
let counts = prefix_counts.lock().unwrap_or_else(|e| e.into_inner());
let mut common_prefixes: Vec<Vec<u8>> = counts
.iter()
.filter(|&(_, &count)| count >= min_support)
.map(|(prefix, _)| prefix.clone())
.collect();
common_prefixes.sort();
Ok(common_prefixes)
}
pub async fn compute_similarity(
trie1: &ParallelLoudsTrie,
trie2: &ParallelLoudsTrie,
_sample_size: usize,
) -> Result<f64> {
let len1 = trie1.len().await;
let len2 = trie2.len().await;
if len1 == 0 && len2 == 0 {
return Ok(1.0);
}
if len1 == 0 || len2 == 0 {
return Ok(0.0);
}
let trie1_inner = trie1.inner.lock().await;
let keys1: Vec<Vec<u8>> = trie1_inner.iter_all().collect();
drop(trie1_inner);
let trie2_inner = trie2.inner.lock().await;
let keys2: Vec<Vec<u8>> = trie2_inner.iter_all().collect();
drop(trie2_inner);
use std::collections::HashSet;
let set1: HashSet<Vec<u8>> = keys1.into_iter().collect();
let set2: HashSet<Vec<u8>> = keys2.into_iter().collect();
let intersection_size = set1.intersection(&set2).count();
let union_size = set1.union(&set2).count();
if union_size == 0 {
Ok(0.0)
} else {
Ok(intersection_size as f64 / union_size as f64)
}
}
pub async fn merge_tries(tries: Vec<ParallelLoudsTrie>) -> Result<ParallelLoudsTrie> {
if tries.is_empty() {
return Ok(ParallelLoudsTrie::new());
}
if tries.len() == 1 {
return Ok(tries.into_iter().next().expect("non-empty tries"));
}
let mut all_keys = Vec::new();
for trie in tries {
let trie_inner = trie.inner.lock().await;
let keys: Vec<Vec<u8>> = trie_inner.iter_all().collect();
all_keys.extend(keys);
}
all_keys.sort();
all_keys.dedup();
let builder = ParallelTrieBuilder::new();
builder.build_louds_trie(all_keys).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio;
#[tokio::test]
async fn test_parallel_trie_builder() {
let builder = ParallelTrieBuilder::new().chunk_size(1000).max_workers(2);
let keys = vec![
b"cat".to_vec(),
b"car".to_vec(),
b"card".to_vec(),
b"dog".to_vec(),
];
let trie = builder.build_louds_trie(keys).await.unwrap();
assert!(!trie.is_empty().await);
}
#[tokio::test]
async fn test_parallel_louds_trie() {
let trie = ParallelLoudsTrie::new();
let _state_id = trie.insert(b"test").await.unwrap();
assert_eq!(trie.len().await, 1);
assert!(!trie.is_empty().await);
assert!(trie.contains(b"test").await);
assert!(!trie.contains(b"missing").await);
}
#[tokio::test]
async fn test_parallel_contains() {
let trie = ParallelLoudsTrie::new();
let _id1 = trie.insert(b"cat").await.unwrap();
let _id2 = trie.insert(b"car").await.unwrap();
let _id3 = trie.insert(b"dog").await.unwrap();
let test_keys = vec![
b"cat".to_vec(),
b"car".to_vec(),
b"bird".to_vec(),
b"dog".to_vec(),
];
let results = trie.parallel_contains(test_keys).await;
assert_eq!(results, vec![true, true, false, true]);
}
#[tokio::test]
async fn test_bulk_insert() {
let trie = ParallelLoudsTrie::new();
let keys = vec![b"apple".to_vec(), b"banana".to_vec(), b"cherry".to_vec()];
let state_ids = trie.bulk_insert(keys).await.unwrap();
assert_eq!(state_ids.len(), 3);
assert_eq!(trie.len().await, 3);
}
#[tokio::test]
async fn test_find_common_prefixes() {
let keys = vec![
b"prefix_one".to_vec(),
b"prefix_two".to_vec(),
b"prefix_three".to_vec(),
b"other".to_vec(),
];
let prefixes = ParallelTrieOps::find_common_prefixes(keys, 2)
.await
.unwrap();
assert!(prefixes.iter().any(|p| p.starts_with(b"prefix")));
}
#[tokio::test]
async fn test_replica_selection() {
let trie = ParallelLoudsTrie::new();
let replica1 = trie.select_replica();
let replica2 = trie.select_replica();
if trie.replica_count > 1 {
assert_ne!(replica1, replica2);
}
}
#[tokio::test]
async fn test_trie_similarity() {
let trie1 = ParallelLoudsTrie::new();
let trie2 = ParallelLoudsTrie::new();
let similarity = ParallelTrieOps::compute_similarity(&trie1, &trie2, 100)
.await
.unwrap();
assert_eq!(similarity, 1.0);
let _id1 = trie1.insert(b"cat").await.unwrap();
let _id2 = trie1.insert(b"car").await.unwrap();
let _id3 = trie2.insert(b"cat").await.unwrap();
let _id4 = trie2.insert(b"car").await.unwrap();
let similarity = ParallelTrieOps::compute_similarity(&trie1, &trie2, 100)
.await
.unwrap();
assert_eq!(similarity, 1.0);
let _id5 = trie2.insert(b"dog").await.unwrap();
let similarity = ParallelTrieOps::compute_similarity(&trie1, &trie2, 100)
.await
.unwrap();
assert!(similarity < 1.0);
assert!(similarity > 0.0);
}
#[tokio::test]
async fn test_basic_trie_operations() {
let trie = ParallelLoudsTrie::new();
let _id1 = trie.insert(b"a").await.unwrap();
let _id2 = trie.insert(b"b").await.unwrap();
assert_eq!(trie.len().await, 2);
assert!(!trie.is_empty().await);
}
#[tokio::test]
async fn test_merge_tries() {
let trie1 = ParallelLoudsTrie::new();
let trie2 = ParallelLoudsTrie::new();
let _id1 = trie1.insert(b"x").await.unwrap();
let _id2 = trie2.insert(b"y").await.unwrap();
let len1 = trie1.len().await;
let len2 = trie2.len().await;
let merged = ParallelTrieOps::merge_tries(vec![trie1, trie2])
.await
.unwrap();
let merged_len = merged.len().await;
assert!(merged_len > 0);
assert!(merged_len <= len1 + len2);
}
}