use super::coordinator::ShardCoordinator;
use super::routing::{compute_shard_key, ngram_order, ShardKey};
use std::collections::BTreeMap;
pub struct ShardedTrieView<'a> {
coordinator: &'a ShardCoordinator,
}
impl<'a> ShardedTrieView<'a> {
    /// Creates a view that answers queries through `coordinator`.
    pub fn new(coordinator: &'a ShardCoordinator) -> Self {
        Self { coordinator }
    }

    /// Returns the count stored for `ngram`, as reported by the coordinator.
    pub fn get(&self, ngram: &str) -> Option<u64> {
        self.coordinator.get(ngram)
    }

    /// Returns `true` if the coordinator reports an entry for `ngram`.
    pub fn contains(&self, ngram: &str) -> bool {
        self.coordinator.contains(ngram)
    }

    /// Total number of entries, as reported by the coordinator's
    /// `total_entry_count`.
    pub fn len(&self) -> u64 {
        self.coordinator.total_entry_count()
    }

    /// Returns `true` when [`len`](Self::len) is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns every stored n-gram whose bytes start with `prefix`, paired with
    /// its count and sorted by n-gram bytes.
    ///
    /// If the first word of `prefix` has at least as many (alphabetic,
    /// lowercased) characters as the granularity requires for a shard key,
    /// only that shard is scanned. Otherwise every open shard is scanned. This
    /// also applies when `prefix` is not valid UTF-8.
    pub fn prefix_search(&self, prefix: &[u8]) -> Vec<(Vec<u8>, u64)> {
        let config = self.coordinator.config();
        let granularity = &config.granularity;
        let prefix_str = std::str::from_utf8(prefix).unwrap_or("");
        let first_word = prefix_str.split('|').next().unwrap_or("");
        // Count characters, not UTF-8 bytes. A multi-byte letter must not make
        // a too-short word look routable. Assumes `prefix_len_for_order` counts
        // characters (TODO confirm). Under-counting only falls back to the
        // all-shard scan below, which still finds every match.
        let routable_chars = first_word
            .chars()
            .filter(|c| c.is_alphabetic())
            .flat_map(char::to_lowercase)
            .count();
        let order = ngram_order(prefix_str).max(1);
        let required_len = granularity.prefix_len_for_order(order);
        // NOTE(review): the key is derived from the prefix's own order. If
        // `prefix_len_for_order` differs between orders, a prefix like `the`
        // (order 1) may not route to the shard holding `the|quick` (order 2).
        // Confirm against `compute_shard_key`.
        if routable_chars >= required_len {
            let shard_key = compute_shard_key(prefix_str, order, granularity);
            self.prefix_search_shard(&shard_key, prefix)
        } else {
            self.prefix_search_all_shards(prefix)
        }
    }

    /// Scans the single shard `key` for entries starting with `prefix`.
    ///
    /// Results are sorted by n-gram bytes so both `prefix_search` paths return
    /// the same ordering.
    fn prefix_search_shard(&self, key: &ShardKey, prefix: &[u8]) -> Vec<(Vec<u8>, u64)> {
        let mut matches: Vec<_> = self
            .shard_entries(key)
            .into_iter()
            .filter(|(ngram, _)| ngram.starts_with(prefix))
            .collect();
        matches.sort_by(|a, b| a.0.cmp(&b.0));
        matches
    }

    /// Scans every open shard for entries starting with `prefix`.
    ///
    /// The `BTreeMap` orders the output by n-gram bytes. It also collapses an
    /// n-gram seen in more than one shard; the last shard visited wins.
    fn prefix_search_all_shards(&self, prefix: &[u8]) -> Vec<(Vec<u8>, u64)> {
        let mut results: BTreeMap<Vec<u8>, u64> = BTreeMap::new();
        for key in self.coordinator.open_shard_keys() {
            results.extend(
                self.shard_entries(&key)
                    .into_iter()
                    .filter(|(ngram, _)| ngram.starts_with(prefix)),
            );
        }
        results.into_iter().collect()
    }

    /// Loads all `(ngram, count)` entries of shard `key`.
    ///
    /// Best-effort: failures to open or iterate the shard are logged and yield
    /// an empty list, so one bad shard does not abort a multi-shard query.
    // NOTE(review): `get_or_create_shard` may create an empty shard for a key
    // nothing was stored under. Confirm whether a read-only lookup exists.
    fn shard_entries(&self, key: &ShardKey) -> Vec<(Vec<u8>, u64)> {
        match self.coordinator.get_or_create_shard(key) {
            Ok(shard) => {
                let guard = shard.read();
                match guard.iter_with_counts() {
                    Ok(entries) => entries,
                    Err(e) => {
                        log::warn!("Failed to iterate shard {}: {}", key, e);
                        Vec::new()
                    }
                }
            }
            Err(_) => {
                log::warn!("Failed to open shard {}", key);
                Vec::new()
            }
        }
    }

    /// Samples the coordinator's counters into a [`ViewStats`].
    ///
    /// Each counter is loaded independently with `Ordering::Relaxed`, so the
    /// fields are not guaranteed to be mutually consistent while other threads
    /// are updating them.
    pub fn stats(&self) -> ViewStats {
        use std::sync::atomic::Ordering;
        let shard_count = self.coordinator.open_shard_count();
        let coordinator_stats = self.coordinator.stats();
        ViewStats {
            shard_count,
            total_ngrams: coordinator_stats.total_ngrams.load(Ordering::Relaxed),
            unique_ngrams: coordinator_stats.unique_ngrams.load(Ordering::Relaxed),
            total_entries: self.len(),
        }
    }

    /// Returns the shard key the coordinator routes `ngram` to.
    pub fn shard_key_for(&self, ngram: &str) -> ShardKey {
        self.coordinator.route_ngram(ngram)
    }

    /// Maps each open shard's key (its `Display` form) to the entry count
    /// reported by that shard. Shards that fail to open are logged and omitted.
    pub fn shard_distribution(&self) -> BTreeMap<String, u64> {
        let mut distribution = BTreeMap::new();
        for key in self.coordinator.open_shard_keys() {
            match self.coordinator.get_or_create_shard(&key) {
                Ok(shard) => {
                    let entries = shard.read().len() as u64;
                    distribution.insert(key.to_string(), entries);
                }
                Err(_) => {
                    log::warn!("Failed to open shard {}", key);
                }
            }
        }
        distribution
    }

    /// Lazily yields every `(ngram, count)` entry of the shards that are open
    /// when this method is called.
    ///
    /// Each shard is read in full when the iterator reaches it. Unreadable
    /// shards are logged and skipped.
    pub fn iter_all(&self) -> impl Iterator<Item = (Vec<u8>, u64)> + '_ {
        self.coordinator
            .open_shard_keys()
            .into_iter()
            .flat_map(move |key| self.shard_entries(&key))
    }

    /// Returns up to `n` entries with the highest counts across all open
    /// shards, ordered by descending count.
    ///
    /// Ties are resolved deterministically in favour of the lexicographically
    /// smaller n-gram. This applies both when deciding which entries make the
    /// cut and when ordering the output.
    pub fn top_n(&self, n: usize) -> Vec<(Vec<u8>, u64)> {
        use std::cmp::Reverse;
        use std::collections::BinaryHeap;
        // The outer `Reverse` turns the max-heap into a min-heap. Its root is
        // the weakest candidate kept so far: the lowest count and, among equal
        // counts, the largest n-gram (hence the inner `Reverse`).
        let mut heap: BinaryHeap<Reverse<(u64, Reverse<Vec<u8>>)>> = BinaryHeap::new();
        for (ngram, count) in self.iter_all() {
            let candidate = Reverse((count, Reverse(ngram)));
            if heap.len() < n {
                heap.push(candidate);
            } else if let Some(mut weakest) = heap.peek_mut() {
                // Smaller in `Reverse` order means stronger. Overwriting through
                // `PeekMut` re-sifts the heap when the guard drops.
                if candidate < *weakest {
                    *weakest = candidate;
                }
            }
        }
        // Ascending `Reverse` order is descending count, then ascending n-gram.
        heap.into_sorted_vec()
            .into_iter()
            .map(|Reverse((count, Reverse(ngram)))| (ngram, count))
            .collect()
    }
}
/// Summary of coordinator counters sampled by `ShardedTrieView::stats`.
///
/// All fields are plain integers, so the struct is `Copy` and comparable.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ViewStats {
    /// Number of shards currently open in the coordinator.
    pub shard_count: usize,
    /// Value of the coordinator's `total_ngrams` counter.
    pub total_ngrams: u64,
    /// Value of the coordinator's `unique_ngrams` counter.
    pub unique_ngrams: u64,
    /// Entry count reported by the coordinator across all shards.
    pub total_entries: u64,
}
#[cfg(test)]
mod tests {
    use super::super::config::{ShardConfig, ShardGranularity};
    use super::*;
    use tempfile::TempDir;

    /// Builds a two-character-granularity coordinator holding five bigrams
    /// with distinct counts.
    ///
    /// The returned `TempDir` must outlive the coordinator, because dropping
    /// it deletes the shard directory.
    fn create_test_coordinator() -> (TempDir, ShardCoordinator) {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let config =
            ShardConfig::new(dir.path().join("shards")).with_granularity(ShardGranularity::TwoChar);
        let coordinator = ShardCoordinator::new(config).expect("Failed to create coordinator");
        coordinator.store_ngram("the|quick", 100).expect("store");
        coordinator.store_ngram("the|slow", 50).expect("store");
        coordinator.store_ngram("apple|pie", 30).expect("store");
        coordinator.store_ngram("apple|cider", 20).expect("store");
        coordinator
            .store_ngram("zebra|crossing", 10)
            .expect("store");
        (dir, coordinator)
    }

    #[test]
    fn test_view_basic_queries() {
        let (_dir, coordinator) = create_test_coordinator();
        let view = ShardedTrieView::new(&coordinator);
        assert_eq!(view.get("the|quick"), Some(100));
        assert_eq!(view.get("apple|pie"), Some(30));
        assert_eq!(view.get("nonexistent"), None);
        assert!(view.contains("the|quick"));
        assert!(!view.contains("nonexistent"));
        assert_eq!(view.len(), 5);
        assert!(!view.is_empty());
    }

    #[test]
    fn test_view_prefix_search() {
        let (_dir, coordinator) = create_test_coordinator();
        let view = ShardedTrieView::new(&coordinator);
        let results = view.prefix_search(b"the|");
        assert_eq!(results.len(), 2);
        let ngrams: Vec<_> = results.iter().map(|(n, _)| n.as_slice()).collect();
        assert!(ngrams.contains(&b"the|quick".as_slice()));
        assert!(ngrams.contains(&b"the|slow".as_slice()));
        let results = view.prefix_search(b"apple|");
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_view_prefix_search_all_shards_fallback() {
        let (_dir, coordinator) = create_test_coordinator();
        let view = ShardedTrieView::new(&coordinator);
        // An empty prefix cannot select a shard, so every open shard is
        // scanned and the merged result comes back sorted by n-gram.
        let results = view.prefix_search(b"");
        assert_eq!(
            results,
            vec![
                (b"apple|cider".to_vec(), 20),
                (b"apple|pie".to_vec(), 30),
                (b"the|quick".to_vec(), 100),
                (b"the|slow".to_vec(), 50),
                (b"zebra|crossing".to_vec(), 10),
            ]
        );
    }

    #[test]
    fn test_view_stats() {
        let (_dir, coordinator) = create_test_coordinator();
        let view = ShardedTrieView::new(&coordinator);
        let stats = view.stats();
        assert_eq!(stats.total_entries, 5);
        assert!(stats.shard_count > 0);
    }

    #[test]
    fn test_view_distribution() {
        let (_dir, coordinator) = create_test_coordinator();
        let view = ShardedTrieView::new(&coordinator);
        let dist = view.shard_distribution();
        assert!(dist.len() >= 2);
        assert_eq!(dist.get("th"), Some(&2));
        assert_eq!(dist.get("ap"), Some(&2));
    }

    #[test]
    fn test_view_top_n() {
        let (_dir, coordinator) = create_test_coordinator();
        let view = ShardedTrieView::new(&coordinator);
        let top = view.top_n(3);
        assert_eq!(top.len(), 3);
        assert_eq!(top[0], (b"the|quick".to_vec(), 100));
        assert_eq!(top[1], (b"the|slow".to_vec(), 50));
        assert_eq!(top[2], (b"apple|pie".to_vec(), 30));
    }

    #[test]
    fn test_view_top_n_edge_cases() {
        let (_dir, coordinator) = create_test_coordinator();
        let view = ShardedTrieView::new(&coordinator);
        assert!(view.top_n(0).is_empty());
        // Asking for more entries than exist returns all of them, best first.
        let all = view.top_n(10);
        let counts: Vec<u64> = all.iter().map(|(_, count)| *count).collect();
        assert_eq!(counts, vec![100, 50, 30, 20, 10]);
    }

    #[test]
    fn test_view_shard_key() {
        let (_dir, coordinator) = create_test_coordinator();
        let view = ShardedTrieView::new(&coordinator);
        let key = view.shard_key_for("the|quick");
        assert_eq!(key.prefix, "th");
        let key = view.shard_key_for("apple|pie");
        assert_eq!(key.prefix, "ap");
    }
}