use libdictenstein::persistent_artrie::eviction::EvictionConfig;
use std::path::PathBuf;
/// Lower bound, in bytes (64 MiB), applied to the per-shard overlay budget
/// computed by `ShardConfig::overlay_eviction_config`. Because it is a floor,
/// the sum of per-shard budgets may exceed the configured global budget.
const MIN_PER_SHARD_OVERLAY_BUDGET_BYTES: usize = 64 * 1024 * 1024;
/// Value stored in `EvictionConfig::resident_budget_eviction_cap` by
/// `ShardConfig::overlay_eviction_config`.
/// NOTE(review): by its name this is a node count bounding a single eviction
/// pass — confirm against the `EvictionConfig` documentation.
const OVERLAY_EVICTION_CAP_NODES: usize = 200_000;
/// Strategy that determines how many shards exist and how long a key prefix
/// is used to pick one.
///
/// The prefix-based variants assume 26 possible symbols per prefix position
/// (presumably the letters `a`-`z` — confirm against the routing code), while
/// [`ShardGranularity::CpuProportional`] derives the shard count from the
/// machine's available parallelism and reports a prefix length of `0`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShardGranularity {
    /// One-symbol prefix for every order; at most 26 shards.
    FirstChar,
    /// Two-symbol prefix for every order; at most 676 (26²) shards.
    TwoChar,
    /// One-symbol prefix when `order == 1`, two-symbol prefix otherwise;
    /// at most 676 shards.
    Adaptive,
    /// Caller-chosen prefix length; at most `26^prefix_len` shards
    /// (saturating at `usize::MAX`).
    Custom {
        /// Number of leading symbols used as the shard prefix.
        prefix_len: usize,
    },
    /// Shard count proportional to the number of available CPUs. This is the
    /// only variant for which [`ShardGranularity::is_hash_based`] is `true`.
    CpuProportional {
        /// Factor applied to the detected parallelism.
        multiplier: usize,
        /// Lower bound on the resulting shard count.
        minimum: usize,
    },
}

impl Default for ShardGranularity {
    /// Returns `CpuProportional { multiplier: 2, minimum: 8 }`.
    fn default() -> Self {
        Self::CpuProportional {
            multiplier: 2,
            minimum: 8,
        }
    }
}

impl ShardGranularity {
    /// Returns the prefix length to use for keys of the given `order`
    /// (presumably the n-gram order — confirm with callers).
    ///
    /// Only `Adaptive` looks at `order`; `CpuProportional` always yields `0`
    /// because it does not shard by prefix.
    pub fn prefix_len_for_order(&self, order: u8) -> usize {
        match self {
            Self::FirstChar => 1,
            Self::TwoChar => 2,
            Self::Adaptive => {
                if order == 1 {
                    1
                } else {
                    2
                }
            }
            Self::Custom { prefix_len } => *prefix_len,
            Self::CpuProportional { .. } => 0,
        }
    }

    /// Returns the maximum number of shards this granularity can produce.
    ///
    /// For `Custom` the result is `26^prefix_len`, saturating at `usize::MAX`
    /// instead of overflowing (which would panic in debug builds and wrap to
    /// a meaningless value in release builds).
    pub fn max_shards(&self) -> usize {
        match self {
            Self::FirstChar => 26,
            Self::TwoChar | Self::Adaptive => 676,
            Self::Custom { prefix_len } => {
                // Clamp before narrowing so a huge `prefix_len` cannot be
                // truncated into a small exponent by the `as` cast.
                let exponent = (*prefix_len).min(u32::MAX as usize) as u32;
                26_usize.saturating_pow(exponent)
            }
            Self::CpuProportional {
                multiplier,
                minimum,
            } => Self::compute_cpu_proportional_shards(*multiplier, *minimum),
        }
    }

    /// Returns the number of shards; currently identical to
    /// [`ShardGranularity::max_shards`].
    pub fn num_shards(&self) -> usize {
        self.max_shards()
    }

    /// Returns `true` only for `CpuProportional`.
    pub fn is_hash_based(&self) -> bool {
        matches!(self, Self::CpuProportional { .. })
    }

    /// Returns `true` for every variant that is not hash-based.
    pub fn is_prefix_based(&self) -> bool {
        !self.is_hash_based()
    }

    /// Computes `max(cpus * multiplier, minimum)`, where `cpus` is the
    /// available parallelism reported by the standard library (falling back
    /// to `4` when it cannot be determined).
    ///
    /// The multiplication saturates at `usize::MAX` rather than overflowing.
    /// The result is `0` only when both `multiplier` and `minimum` are `0`.
    fn compute_cpu_proportional_shards(multiplier: usize, minimum: usize) -> usize {
        let cpus = std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(4);
        cpus.saturating_mul(multiplier).max(minimum)
    }
}
/// Selects the shard-merging strategy.
///
/// NOTE(review): the behaviour attached to each variant lives in the merge
/// implementation, which is not visible from this file; the per-variant notes
/// below are inferred from the names and should be confirmed there.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MergeMode {
    /// Presumably: merge once, after the import has finished. This is the
    /// default mode.
    PostImport,
    /// Presumably: merge repeatedly while the import is running.
    Periodic,
    /// Presumably: merge in several tiers.
    Hierarchical {
        /// Number of levels (by name; semantics defined by the merge code).
        levels: usize,
    },
}

impl Default for MergeMode {
    /// The default merge mode is [`MergeMode::PostImport`].
    fn default() -> Self {
        MergeMode::PostImport
    }
}
/// Tunables for merging shards.
///
/// NOTE(review): none of these fields is consumed in this file, so the
/// per-field descriptions are inferred from names and defaults — confirm
/// against the merge implementation.
#[derive(Clone, Debug)]
pub struct MergeConfig {
    /// Merge strategy. Defaults to [`MergeMode::PostImport`].
    pub mode: MergeMode,
    /// Presumably the shard count that triggers a merge. Defaults to `10`.
    pub merge_trigger_shards: usize,
    /// Presumably the size at which a shard counts as full. Defaults to
    /// `1_000_000`.
    pub shard_full_threshold: u64,
    /// Presumably the number of merges allowed to run concurrently.
    /// Defaults to `4`.
    pub merge_parallelism: usize,
    /// Presumably whether merged inputs are removed afterwards. Defaults to
    /// `true`.
    pub cleanup_after_merge: bool,
}

impl Default for MergeConfig {
    /// Builds the default merge configuration: post-import merging, trigger
    /// at 10 shards, full threshold of 1,000,000, parallelism of 4, and
    /// cleanup enabled.
    fn default() -> Self {
        let mode = MergeMode::PostImport;
        MergeConfig {
            mode,
            merge_trigger_shards: 10,
            shard_full_threshold: 1_000_000,
            merge_parallelism: 4,
            cleanup_after_merge: true,
        }
    }
}
/// Configuration for sharded storage: where shard files live, how keys are
/// partitioned, and the budgets/thresholds that apply to them.
#[derive(Clone, Debug)]
pub struct ShardConfig {
/// Partitioning strategy; its `num_shards()` and `is_hash_based()` results
/// feed `overlay_eviction_config`.
pub granularity: ShardGranularity,
/// Directory under which `shard_path` and `global_checkpoint_path` build
/// their file paths. Defaults to `"shards"`.
pub shard_dir: PathBuf,
/// Defaults to `4`. NOTE(review): presumably the maximum number of
/// concurrent writers — not used in this file, confirm with callers.
pub max_writers: usize,
/// Upper bound on the resident shard count used by
/// `overlay_eviction_config` for prefix-based granularities; `0` means no
/// cap is applied there. Defaults to `32`.
pub max_open_shards: usize,
/// Defaults to `64 * 1024 * 1024`. NOTE(review): presumably a per-shard
/// memory budget in bytes — not used in this file, confirm with callers.
pub shard_memory_budget: usize,
/// Defaults to `30_000`. NOTE(review): by name, the checkpoint interval in
/// milliseconds — not used in this file, confirm with callers.
pub checkpoint_interval_ms: u64,
/// `should_shard` returns `true` once the estimated n-gram count reaches
/// this value. Defaults to `10_000_000`.
pub auto_shard_threshold: u64,
/// Merge settings. Defaults to `MergeConfig::default()`.
pub merge: MergeConfig,
/// Global overlay budget in bytes that `overlay_eviction_config` divides
/// across resident shards; when `None`, that method returns `None`.
/// Defaults to `None`.
pub overlay_budget_bytes: Option<usize>,
}
impl Default for ShardConfig {
fn default() -> Self {
Self {
granularity: ShardGranularity::default(),
shard_dir: PathBuf::from("shards"),
max_writers: 4,
max_open_shards: 32, shard_memory_budget: 64 * 1024 * 1024, checkpoint_interval_ms: 30_000, auto_shard_threshold: 10_000_000, merge: MergeConfig::default(),
overlay_budget_bytes: None,
}
}
}
impl ShardConfig {
pub fn new(shard_dir: impl Into<PathBuf>) -> Self {
Self {
shard_dir: shard_dir.into(),
..Default::default()
}
}
pub fn with_overlay_budget_bytes(mut self, budget: Option<usize>) -> Self {
self.overlay_budget_bytes = budget;
self
}
pub fn overlay_eviction_config(&self) -> Option<EvictionConfig> {
let global = self.overlay_budget_bytes?;
let num_shards = self.granularity.num_shards().max(1);
let resident = if self.granularity.is_hash_based() || self.max_open_shards == 0 {
num_shards
} else {
self.max_open_shards.min(num_shards)
};
let per_shard = (global / resident).max(MIN_PER_SHARD_OVERLAY_BUDGET_BYTES);
let mut config = EvictionConfig::without_memory_monitor();
config.resident_budget_bytes = Some(per_shard);
config.resident_budget_eviction_cap = Some(OVERLAY_EVICTION_CAP_NODES);
Some(config)
}
pub fn with_granularity(mut self, granularity: ShardGranularity) -> Self {
self.granularity = granularity;
self
}
pub fn with_max_writers(mut self, max_writers: usize) -> Self {
self.max_writers = max_writers;
self
}
pub fn with_max_open_shards(mut self, max_open_shards: usize) -> Self {
self.max_open_shards = max_open_shards;
self
}
pub fn with_auto_shard_threshold(mut self, threshold: u64) -> Self {
self.auto_shard_threshold = threshold;
self
}
pub fn with_merge_config(mut self, merge: MergeConfig) -> Self {
self.merge = merge;
self
}
pub fn shard_path(&self, prefix: &str) -> PathBuf {
self.shard_dir.join(format!("shard_{}.artrie", prefix))
}
pub fn global_checkpoint_path(&self) -> PathBuf {
self.shard_dir.join("global_checkpoint.json")
}
pub fn should_shard(&self, estimated_ngrams: u64) -> bool {
estimated_ngrams >= self.auto_shard_threshold
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The `CpuProportional` settings shared by several tests (same values as
    /// the default granularity).
    fn cpu_proportional() -> ShardGranularity {
        ShardGranularity::CpuProportional {
            multiplier: 2,
            minimum: 8,
        }
    }

    /// Each granularity reports the expected prefix length per order.
    #[test]
    fn test_granularity_prefix_len() {
        let cases = [
            (ShardGranularity::FirstChar, 1u8, 1usize),
            (ShardGranularity::FirstChar, 2, 1),
            (ShardGranularity::TwoChar, 1, 2),
            (ShardGranularity::TwoChar, 2, 2),
            (ShardGranularity::Adaptive, 1, 1),
            (ShardGranularity::Adaptive, 2, 2),
            (ShardGranularity::Adaptive, 5, 2),
            (ShardGranularity::Custom { prefix_len: 3 }, 1, 3),
            (cpu_proportional(), 1, 0),
            (cpu_proportional(), 5, 0),
        ];
        for (granularity, order, expected) in cases.iter() {
            assert_eq!(granularity.prefix_len_for_order(*order), *expected);
        }
    }

    /// Fixed variants report constant shard counts; the CPU-proportional
    /// variant tracks the machine's parallelism.
    #[test]
    fn test_granularity_max_shards() {
        let fixed = [
            (ShardGranularity::FirstChar, 26usize),
            (ShardGranularity::TwoChar, 676),
            (ShardGranularity::Adaptive, 676),
            (ShardGranularity::Custom { prefix_len: 3 }, 17576),
        ];
        for (granularity, expected) in fixed.iter() {
            assert_eq!(granularity.max_shards(), *expected);
        }

        let cpus = std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(4);
        let expected = cpus * 2;
        assert_eq!(cpu_proportional().max_shards(), expected.max(8));
    }

    /// The `minimum` field is honoured as a lower bound.
    #[test]
    fn test_cpu_proportional_minimum() {
        let granularity = ShardGranularity::CpuProportional {
            multiplier: 1,
            minimum: 16,
        };
        assert!(granularity.max_shards() >= 16);
    }

    /// Only the CPU-proportional variant is hash-based.
    #[test]
    fn test_is_hash_based() {
        let prefix_based = [
            ShardGranularity::FirstChar,
            ShardGranularity::TwoChar,
            ShardGranularity::Adaptive,
            ShardGranularity::Custom { prefix_len: 2 },
        ];
        for granularity in prefix_based.iter() {
            assert!(!granularity.is_hash_based());
        }

        let hashed = cpu_proportional();
        assert!(hashed.is_hash_based());
        assert!(!hashed.is_prefix_based());
    }

    /// The default granularity is `CpuProportional { 2, 8 }`.
    #[test]
    fn test_default_is_cpu_proportional() {
        let default = ShardGranularity::default();
        assert!(default.is_hash_based());
        assert_eq!(
            default,
            ShardGranularity::CpuProportional {
                multiplier: 2,
                minimum: 8
            }
        );
    }

    /// Shard and checkpoint paths are built under `shard_dir`.
    #[test]
    fn test_shard_config_paths() {
        let config = ShardConfig::new("/tmp/shards");
        let shard = config.shard_path("th");
        let checkpoint = config.global_checkpoint_path();
        assert_eq!(shard, PathBuf::from("/tmp/shards/shard_th.artrie"));
        assert_eq!(
            checkpoint,
            PathBuf::from("/tmp/shards/global_checkpoint.json")
        );
    }

    /// Sharding kicks in at (and above) the default threshold.
    #[test]
    fn test_should_shard() {
        let config = ShardConfig::default();
        assert!(!config.should_shard(1_000_000));
        assert!(config.should_shard(10_000_000));
        assert!(config.should_shard(100_000_000));
    }
}