use anyhow::Result;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use std::thread::JoinHandle;
use super::IbdRetireWork;
/// Fans IBD retire work out to a fixed set of worker shards (selected by
/// `height % num_shards`) and exposes the globally completed height, i.e.
/// the minimum height retired across all shards.
pub(crate) struct RetireDispatcher {
// One sender + join handle per worker shard.
shards: Vec<DispatcherShard>,
// Minimum height retired by every shard; kept current by GlobalProgressPublisher.
global_last_retired: Arc<AtomicU64>,
}
/// Per-shard plumbing: the work-channel sender and the worker's join handle.
/// Both are `Option` so shutdown can drop the sender (disconnecting the
/// channel so the worker exits) and `take()` the handle exactly once to join.
struct DispatcherShard {
tx: Option<mpsc::Sender<IbdRetireWork>>,
handle: Option<JoinHandle<()>>,
}
impl RetireDispatcher {
    /// Spawns `num_shards` worker threads (clamped to at least one), each
    /// wired to a dedicated MPSC work channel and a per-shard progress cursor.
    ///
    /// `start_height_minus_one` seeds the global cursor and every local
    /// cursor so progress reporting starts from a consistent baseline.
    ///
    /// `spawn_thread` is called once per shard with the shard index, the work
    /// receiver, that shard's local cursor, and the shared progress publisher.
    pub fn spawn<F>(num_shards: usize, start_height_minus_one: u64, mut spawn_thread: F) -> Self
    where
        F: FnMut(
            usize,
            mpsc::Receiver<IbdRetireWork>,
            Arc<AtomicU64>,
            Arc<GlobalProgressPublisher>,
        ) -> JoinHandle<()>,
    {
        let n = num_shards.max(1);
        let global_last_retired = Arc::new(AtomicU64::new(start_height_minus_one));
        let local_cursors: Vec<Arc<AtomicU64>> = (0..n)
            .map(|_| Arc::new(AtomicU64::new(start_height_minus_one)))
            .collect();
        let publisher = Arc::new(GlobalProgressPublisher {
            locals: local_cursors.clone(),
            global: Arc::clone(&global_last_retired),
            recompute_lock: Mutex::new(()),
        });
        let mut shards = Vec::with_capacity(n);
        // `local_cursors` has exactly `n` entries, so the enumeration needs
        // no extra `.take(n)` bound.
        for (i, cursor) in local_cursors.iter().enumerate() {
            let (tx, rx) = mpsc::channel::<IbdRetireWork>();
            let handle = spawn_thread(i, rx, Arc::clone(cursor), Arc::clone(&publisher));
            shards.push(DispatcherShard {
                tx: Some(tx),
                handle: Some(handle),
            });
        }
        Self {
            shards,
            global_last_retired,
        }
    }

    /// Number of worker shards (always >= 1).
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    /// Routes `work` to the shard chosen by `height % num_shards`, so a given
    /// height always lands on the same worker.
    ///
    /// Returns `SendError(work)` if the target shard was already shut down
    /// (sender taken) or its receiver has hung up.
    pub fn send(
        &self,
        work: IbdRetireWork,
    ) -> std::result::Result<(), mpsc::SendError<IbdRetireWork>> {
        let i = (work.height as usize) % self.shards.len();
        match &self.shards[i].tx {
            Some(tx) => tx.send(work),
            None => Err(mpsc::SendError(work)),
        }
    }

    /// Shared cursor holding the highest height retired by *all* shards
    /// (the minimum across per-shard progress).
    pub fn global_last_retired(&self) -> &Arc<AtomicU64> {
        &self.global_last_retired
    }

    /// Closes every work channel, then joins all worker threads.
    ///
    /// Idempotent: senders and handles are `take()`n, so a second call (or a
    /// subsequent `Drop`) is a no-op. Worker panics surfaced by `join` are
    /// deliberately swallowed.
    pub fn shutdown_and_join(&mut self) -> Result<()> {
        // Drop all senders first so every worker observes a disconnected
        // channel and can exit its receive loop before we block on `join`.
        for s in self.shards.iter_mut() {
            s.tx.take();
        }
        for s in self.shards.iter_mut() {
            if let Some(h) = s.handle.take() {
                let _ = h.join();
            }
        }
        Ok(())
    }
}
impl Drop for RetireDispatcher {
    /// Ensures worker threads are stopped and joined even when the caller
    /// never invokes `shutdown_and_join` explicitly.
    fn drop(&mut self) {
        // Delegate to the explicit shutdown path so the close-then-join
        // sequence lives in exactly one place. `shutdown_and_join` uses
        // `take()` on senders and handles, so it is a no-op if the caller
        // already shut the dispatcher down manually.
        let _ = self.shutdown_and_join();
    }
}
/// Maintains the global "last retired" height as the minimum of all
/// per-shard cursors, so the global value only advances once every shard
/// has progressed past a height.
pub(crate) struct GlobalProgressPublisher {
// One progress cursor per shard, written by that shard's worker.
locals: Vec<Arc<AtomicU64>>,
// Derived minimum across `locals`; readers poll this for overall progress.
global: Arc<AtomicU64>,
// Serializes the recompute-and-store in `publish` so concurrent publishes
// cannot interleave their writes to `global`.
recompute_lock: Mutex<()>,
}
impl GlobalProgressPublisher {
    /// Records `h` as the latest height retired by one shard, then refreshes
    /// the global cursor to the minimum across all shard cursors.
    pub fn publish(&self, local: &AtomicU64, h: u64) {
        local.store(h, Ordering::Release);
        // Serialize recompute+store so two concurrent publishes cannot write
        // the global cursor out of order.
        let _guard = self.recompute_lock.lock();
        let minimum = self
            .locals
            .iter()
            .map(|cursor| cursor.load(Ordering::Acquire))
            .min();
        // An empty shard list (`None`) and an all-`u64::MAX` snapshot both
        // leave the global value untouched, matching the sentinel check.
        if let Some(m) = minimum {
            if m != u64::MAX {
                self.global.store(m, Ordering::Release);
            }
        }
    }
}
/// Number of IBD retire shards, read from `BLVM_IBD_RETIRE_SHARDS`.
///
/// Returns 1 when the variable is unset, unparseable, or <= 1. Larger
/// requests are capped at half the available parallelism (minimum 1), so a
/// misconfigured value cannot oversubscribe the machine.
pub(crate) fn configured_retire_shards() -> usize {
    let requested: usize = std::env::var("BLVM_IBD_RETIRE_SHARDS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(1);
    if requested <= 1 {
        return 1;
    }
    // Leave headroom for other subsystems: cap at half the logical CPUs.
    let cap = std::thread::available_parallelism()
        .map(|p| (p.get() / 2).max(1))
        .unwrap_or(1);
    // `requested >= 2` and `cap >= 1` here, so the minimum is already >= 1;
    // the former trailing `.max(1)` was dead code.
    requested.min(cap)
}
#[cfg(test)]
mod tests {
use super::*;
// The global cursor must track the MINIMUM across shard cursors (progress
// completed by every shard), not the maximum.
#[test]
fn publisher_global_tracks_min_not_max() {
let local0 = Arc::new(AtomicU64::new(0));
let local1 = Arc::new(AtomicU64::new(0));
let global = Arc::new(AtomicU64::new(0));
let publisher = GlobalProgressPublisher {
locals: vec![Arc::clone(&local0), Arc::clone(&local1)],
global: Arc::clone(&global),
recompute_lock: Mutex::new(()),
};
// Shard 0 races ahead; the laggard (shard 1 at 0) pins the global value.
publisher.publish(&local0, 100);
assert_eq!(global.load(Ordering::Acquire), 0);
assert_eq!(local0.load(Ordering::Acquire), 100);
// Shard 1 advances to 50 — now the minimum (and global) is 50.
publisher.publish(&local1, 50);
assert_eq!(global.load(Ordering::Acquire), 50);
// (100, 200): shard 0 is now the laggard at 100.
publisher.publish(&local1, 200);
assert_eq!(global.load(Ordering::Acquire), 100);
// (300, 200): minimum is 200.
publisher.publish(&local0, 300);
assert_eq!(global.load(Ordering::Acquire), 200);
}
// With a single shard, the global cursor must mirror the local one exactly.
#[test]
fn publisher_n1_global_equals_local() {
let local = Arc::new(AtomicU64::new(0));
let global = Arc::new(AtomicU64::new(0));
let publisher = GlobalProgressPublisher {
locals: vec![Arc::clone(&local)],
global: Arc::clone(&global),
recompute_lock: Mutex::new(()),
};
for h in [1u64, 5, 17, 100, 1_000_000].iter().copied() {
publisher.publish(&local, h);
assert_eq!(global.load(Ordering::Acquire), h);
assert_eq!(local.load(Ordering::Acquire), h);
}
}
// Exercises every clamp path of configured_retire_shards. Holds ENV_LOCK
// for the whole test: the set/remove/assert sequence below is order-
// sensitive, and the process environment is shared across parallel tests.
#[test]
fn configured_retire_shards_defaults_and_clamps() {
let _guard = ENV_LOCK.lock();
std::env::remove_var("BLVM_IBD_RETIRE_SHARDS");
assert_eq!(configured_retire_shards(), 1, "default must be 1");
std::env::set_var("BLVM_IBD_RETIRE_SHARDS", "0");
assert_eq!(configured_retire_shards(), 1, "0 must clamp to 1");
std::env::set_var("BLVM_IBD_RETIRE_SHARDS", "1");
assert_eq!(configured_retire_shards(), 1);
std::env::set_var("BLVM_IBD_RETIRE_SHARDS", "garbage");
assert_eq!(configured_retire_shards(), 1, "unparseable must clamp to 1");
std::env::set_var("BLVM_IBD_RETIRE_SHARDS", "999");
let n = configured_retire_shards();
assert!(n >= 1);
// Recompute the cap the same way the function does, so this assertion
// holds on any machine regardless of core count.
let cap = std::thread::available_parallelism()
.map(|p| (p.get() / 2).max(1))
.unwrap_or(1);
assert_eq!(n, cap, "999 must clamp to available_parallelism / 2");
// Leave the environment clean for other env-sensitive tests.
std::env::remove_var("BLVM_IBD_RETIRE_SHARDS");
}
// Serializes env-mutating tests within this module (process env is global).
static ENV_LOCK: Mutex<()> = Mutex::new(());
}