use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc::Sender;
use tokio::sync::watch::Sender as WatchSender;
use tokio::task::JoinHandle;
/// Coarse-grained phase of the import pipeline.
///
/// The payload-carrying variants hold a snapshot of progress counters taken when the
/// phase value was built; they are not updated in place.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PipelinePhase {
    /// Initial phase, before the number of pending jobs is known.
    Initializing,
    /// Download jobs are being processed.
    Downloading {
        /// Worker count at the time of the snapshot.
        active_workers: usize,
        /// Job results counted so far.
        results_received: u64,
        /// Total number of jobs announced at initialization.
        total_pending: u64,
    },
    /// Downloading has been paused; carries the same counters as `Downloading`.
    Paused { results_received: u64, total_pending: u64 },
    /// Statistics step (labelled "Computing MKN Statistics").
    ComputingStats,
    /// Shard merge step.
    Merging { shards_processed: usize, total_shards: usize },
    /// Final cleanup step.
    CleaningUp,
    /// Successful end state with the reported n-gram total and wall-clock duration.
    Completed { total_ngrams: u64, duration: Duration },
    /// End state reached through a user cancellation.
    Cancelled,
    /// End state reached through a forced quit.
    ForceQuit,
    /// End state carrying an error description.
    Failed { error: String },
}

impl PipelinePhase {
    /// Returns `true` for the end states `Completed`, `Cancelled`, `ForceQuit` and `Failed`.
    pub fn is_terminal(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Completed { .. } | Self::Cancelled | Self::ForceQuit | Self::Failed { .. } => true,
            Self::Initializing
            | Self::Downloading { .. }
            | Self::Paused { .. }
            | Self::ComputingStats
            | Self::Merging { .. }
            | Self::CleaningUp => false,
        }
    }

    /// Human-readable label for the phase, ignoring any payload.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Initializing => "Initializing",
            Self::Downloading { .. } => "Downloading",
            Self::Paused { .. } => "Paused",
            Self::ComputingStats => "Computing MKN Statistics",
            Self::Merging { .. } => "Merging Shards",
            Self::CleaningUp => "Cleaning Up",
            Self::Completed { .. } => "Completed",
            Self::Cancelled => "Cancelled",
            Self::ForceQuit => "Force Quit",
            Self::Failed { .. } => "Failed",
        }
    }
}
/// Event fed into the import state machine (`ImportContext::transition`).
#[derive(Debug)]
pub enum ImportTrigger {
    /// Initialization finished; `total_pending` jobs are expected.
    InitComplete { total_pending: u64 },
    /// One download job reported back.
    ///
    /// NOTE(review): `order` is presumably the n-gram order and `prefix` the shard key — confirm.
    JobResult {
        order: u8,
        prefix: Arc<str>,
        ngrams: u64,
        success: bool,
    },
    /// A worker task terminated.
    WorkerExited { worker_id: usize },
    /// Every expected job result has arrived.
    AllResultsReceived,
    /// The statistics step finished.
    StatsComplete,
    /// The merge step finished.
    MergeComplete,
    /// The cleanup step finished.
    CleanupComplete,
    /// Request to pause downloading.
    Pause,
    /// Request to resume a paused download.
    Resume,
    /// Request to cancel the run.
    Cancel,
    /// Request to quit immediately.
    ForceQuit,
    /// Request to change the parallelism level.
    SetParallelism(usize),
    /// Failure with a human-readable description.
    Error(String),
}
/// A deferred, type-erased cleanup action.
///
/// Calling it (once) yields a boxed future that performs the cleanup when awaited. Both the
/// closure and the future it returns are `Send` and own all their data (`'static`).
pub type CleanupFn =
    Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> + Send + 'static>;

/// Ordered stack of [`CleanupFn`] actions, stored in registration order.
pub struct CleanupGuard {
    resources: Vec<CleanupFn>,
}
impl CleanupGuard {
pub fn new() -> Self {
Self {
resources: Vec::new(),
}
}
pub fn register<F, Fut>(&mut self, f: F)
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.resources.push(Box::new(move || Box::pin(f())));
}
pub async fn cleanup(mut self) {
while let Some(cleanup_fn) = self.resources.pop() {
cleanup_fn().await;
}
}
pub fn cleanup_blocking(mut self) {
while let Some(cleanup_fn) = self.resources.pop() {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(cleanup_fn())
});
}
}
pub fn len(&self) -> usize {
self.resources.len()
}
pub fn is_empty(&self) -> bool {
self.resources.is_empty()
}
}
impl Default for CleanupGuard {
fn default() -> Self {
Self::new()
}
}
/// Handles, channels and shared state collected for an import run, to be torn down in a
/// fixed order via [`CleanupResources::into_cleanup_guard`].
///
/// If the value is dropped without being converted, its fields are dropped in the
/// declaration order below (Rust drop order). That is why the field order is kept as-is.
pub struct CleanupResources<S> {
/// Worker task handles keyed by a `usize` id; each is awaited during cleanup.
pub worker_handles: HashMap<usize, JoinHandle<()>>,
/// Per-worker watch senders; `true` is sent on each one during cleanup.
pub worker_shutdown_txs: HashMap<usize, WatchSender<bool>>,
/// Shared state; this `Arc` is dropped during cleanup after the worker handles are awaited.
pub shared_state: Option<Arc<S>>,
/// Type-erased `Sender<T>` stored by `with_result_tx`; cleanup only ever drops it.
pub result_tx: Option<Box<dyn std::any::Any + Send>>,
/// Type-erased `Sender<T>` stored by `with_worker_exit_tx`; cleanup only ever drops it.
pub worker_exit_tx: Option<Box<dyn std::any::Any + Send>>,
/// Converter task; awaited during cleanup without being aborted.
pub worker_converter: Option<JoinHandle<()>>,
/// Stats task; aborted and then awaited during cleanup.
pub stats_task: Option<JoinHandle<()>>,
/// Command-handler task; aborted and then awaited during cleanup.
pub command_handler: Option<JoinHandle<()>>,
}
impl<S: Send + Sync + 'static> CleanupResources<S> {
pub fn new() -> Self {
Self {
worker_handles: HashMap::new(),
worker_shutdown_txs: HashMap::new(),
shared_state: None,
result_tx: None,
worker_exit_tx: None,
worker_converter: None,
stats_task: None,
command_handler: None,
}
}
pub fn with_worker_handles(mut self, handles: HashMap<usize, JoinHandle<()>>) -> Self {
self.worker_handles = handles;
self
}
pub fn with_worker_shutdown_txs(mut self, txs: HashMap<usize, WatchSender<bool>>) -> Self {
self.worker_shutdown_txs = txs;
self
}
pub fn with_shared_state(mut self, state: Arc<S>) -> Self {
self.shared_state = Some(state);
self
}
pub fn with_result_tx<T: Send + 'static>(mut self, tx: Sender<T>) -> Self {
self.result_tx = Some(Box::new(tx));
self
}
pub fn with_worker_exit_tx<T: Send + 'static>(mut self, tx: Sender<T>) -> Self {
self.worker_exit_tx = Some(Box::new(tx));
self
}
pub fn with_worker_converter(mut self, handle: JoinHandle<()>) -> Self {
self.worker_converter = Some(handle);
self
}
pub fn with_stats_task(mut self, handle: JoinHandle<()>) -> Self {
self.stats_task = Some(handle);
self
}
pub fn with_command_handler(mut self, handle: JoinHandle<()>) -> Self {
self.command_handler = Some(handle);
self
}
pub fn into_cleanup_guard(self) -> CleanupGuard {
let mut guard = CleanupGuard::new();
if let Some(handle) = self.command_handler {
guard.register(move || async move {
handle.abort();
let _ = handle.await;
});
}
if let Some(handle) = self.worker_converter {
guard.register(move || async move {
let _ = handle.await;
});
}
if let Some(handle) = self.stats_task {
guard.register(move || async move {
handle.abort();
let _ = handle.await;
});
}
let result_tx = self.result_tx;
let worker_exit_tx = self.worker_exit_tx;
guard.register(move || async move {
drop(result_tx);
drop(worker_exit_tx);
});
if let Some(state) = self.shared_state {
guard.register(move || async move {
drop(state);
});
}
let handles = self.worker_handles;
guard.register(move || async move {
for (_, handle) in handles {
let _ = handle.await;
}
});
let shutdown_txs = self.worker_shutdown_txs;
guard.register(move || async move {
for tx in shutdown_txs.values() {
let _ = tx.send(true);
}
});
guard
}
}
impl<S: Send + Sync + 'static> Default for CleanupResources<S> {
fn default() -> Self {
Self::new()
}
}
/// Bookkeeping for a single import run, advanced by [`ImportContext::transition`].
///
/// The plain fields (phase, counters, per-order maps) live in the context itself. The
/// `Arc<Atomic*>` fields can be cloned out and shared with other tasks.
/// NOTE(review): presumably workers read the flags and update the shared counters — confirm.
pub struct ImportContext {
/// Current pipeline phase; `transition` replaces it.
pub phase: PipelinePhase,
/// Worker count copied into `Downloading` snapshots. Starts at 0 and is decremented
/// (saturating) on `WorkerExited`.
/// NOTE(review): nothing in this type increments it — callers presumably set it directly.
pub active_workers: usize,
/// Number of `JobResult` triggers counted by `transition`.
pub results_received: u64,
/// Total pending job count, recorded from `InitComplete`.
pub total_pending: u64,
/// Set on `Pause`, cleared on `Resume`.
pub paused: Arc<AtomicBool>,
/// Set when the context transitions to `Cancelled`.
pub cancelled: Arc<AtomicBool>,
/// Set when the context transitions to `ForceQuit`.
pub force_quit: Arc<AtomicBool>,
/// Shared parallelism setting; `new` initializes it to 0.
pub current_parallelism: Arc<AtomicUsize>,
/// Shared n-gram counter; its value is reported in `Completed { total_ngrams, .. }`.
pub total_ngrams: Arc<AtomicU64>,
/// Shared unique-n-gram counter, starting at 0.
pub unique_ngrams: Arc<AtomicU64>,
/// Shared completed-file counter, starting at 0.
pub files_completed: Arc<AtomicU64>,
/// Completed-file counts per `u8` key.
/// NOTE(review): the key is presumably the n-gram order (cf. `JobResult::order`) — confirm.
pub order_files_completed: HashMap<u8, u64>,
/// Skipped-file counts per `u8` key (same keying as above).
pub order_files_skipped: HashMap<u8, u64>,
/// Expected file totals per `u8` key (same keying as above).
pub order_total_files: HashMap<u8, u64>,
/// Start instants per `u8` key (same keying as above).
pub order_start_times: HashMap<u8, Instant>,
/// Instant captured by `new`; `Completed.duration` is measured from it.
pub start_time: Instant,
}
impl ImportContext {
    /// Creates a context in [`PipelinePhase::Initializing`].
    ///
    /// All counters start at zero, every flag starts cleared, the per-order maps start
    /// empty, and `start_time` is captured now.
    pub fn new() -> Self {
        Self {
            phase: PipelinePhase::Initializing,
            active_workers: 0,
            results_received: 0,
            total_pending: 0,
            paused: Arc::new(AtomicBool::new(false)),
            cancelled: Arc::new(AtomicBool::new(false)),
            force_quit: Arc::new(AtomicBool::new(false)),
            current_parallelism: Arc::new(AtomicUsize::new(0)),
            total_ngrams: Arc::new(AtomicU64::new(0)),
            unique_ngrams: Arc::new(AtomicU64::new(0)),
            files_completed: Arc::new(AtomicU64::new(0)),
            order_files_completed: HashMap::new(),
            order_files_skipped: HashMap::new(),
            order_total_files: HashMap::new(),
            order_start_times: HashMap::new(),
            start_time: Instant::now(),
        }
    }

    /// Applies `trigger` to the current phase and returns the resulting phase.
    ///
    /// Accepted transitions:
    /// - `Initializing` + `InitComplete` → `Downloading` (records `total_pending`).
    /// - `Downloading` / `Paused` + `JobResult` → same phase, `results_received` + 1.
    /// - `Downloading` / `Paused` + `WorkerExited` → same phase, `active_workers` − 1
    ///   (saturating).
    /// - `Downloading` + `AllResultsReceived` → `ComputingStats`.
    /// - `Downloading` + `Pause` → `Paused` (sets `paused`).
    /// - `Paused` + `Resume` → `Downloading` (clears `paused`).
    /// - `Initializing` / `Downloading` / `Paused` + `Cancel` → `Cancelled` (sets `cancelled`).
    /// - `ComputingStats` + `StatsComplete` → `Merging`.
    /// - `Merging` + `MergeComplete` → `CleaningUp`.
    /// - `CleaningUp` + `CleanupComplete` → `Completed`.
    /// - any non-terminal phase + `ForceQuit` → `ForceQuit` (sets `force_quit`).
    /// - any non-terminal phase + `Error(msg)` → `Failed { error: msg }`.
    /// - any non-terminal phase + `SetParallelism(n)` → phase unchanged, `n` stored in
    ///   `current_parallelism`.
    ///
    /// Every other combination, including anything while in a terminal phase, is logged at
    /// `warn` level and leaves the phase unchanged.
    pub fn transition(&mut self, trigger: ImportTrigger) -> &PipelinePhase {
        let new_phase = match (&self.phase, trigger) {
            (PipelinePhase::Initializing, ImportTrigger::InitComplete { total_pending }) => {
                self.total_pending = total_pending;
                PipelinePhase::Downloading {
                    active_workers: self.active_workers,
                    results_received: 0,
                    total_pending,
                }
            }
            (PipelinePhase::Downloading { total_pending, .. }, ImportTrigger::JobResult { .. }) => {
                self.results_received += 1;
                PipelinePhase::Downloading {
                    active_workers: self.active_workers,
                    results_received: self.results_received,
                    total_pending: *total_pending,
                }
            }
            // Jobs already in flight when the pause was requested still report back; count
            // them so the counters stay accurate across pause/resume.
            (PipelinePhase::Paused { total_pending, .. }, ImportTrigger::JobResult { .. }) => {
                self.results_received += 1;
                PipelinePhase::Paused {
                    results_received: self.results_received,
                    total_pending: *total_pending,
                }
            }
            (PipelinePhase::Downloading { .. }, ImportTrigger::WorkerExited { .. }) => {
                self.active_workers = self.active_workers.saturating_sub(1);
                PipelinePhase::Downloading {
                    active_workers: self.active_workers,
                    results_received: self.results_received,
                    total_pending: self.total_pending,
                }
            }
            // `Paused` carries no worker count, so only the context counter changes.
            (
                PipelinePhase::Paused {
                    results_received,
                    total_pending,
                },
                ImportTrigger::WorkerExited { .. },
            ) => {
                self.active_workers = self.active_workers.saturating_sub(1);
                PipelinePhase::Paused {
                    results_received: *results_received,
                    total_pending: *total_pending,
                }
            }
            (PipelinePhase::Downloading { .. }, ImportTrigger::AllResultsReceived) => {
                PipelinePhase::ComputingStats
            }
            (PipelinePhase::Downloading { total_pending, .. }, ImportTrigger::Pause) => {
                self.paused.store(true, Ordering::SeqCst);
                PipelinePhase::Paused {
                    results_received: self.results_received,
                    total_pending: *total_pending,
                }
            }
            (PipelinePhase::Paused { total_pending, .. }, ImportTrigger::Resume) => {
                self.paused.store(false, Ordering::SeqCst);
                PipelinePhase::Downloading {
                    active_workers: self.active_workers,
                    results_received: self.results_received,
                    total_pending: *total_pending,
                }
            }
            (
                PipelinePhase::Initializing
                | PipelinePhase::Downloading { .. }
                | PipelinePhase::Paused { .. },
                ImportTrigger::Cancel,
            ) => {
                self.cancelled.store(true, Ordering::SeqCst);
                PipelinePhase::Cancelled
            }
            (PipelinePhase::ComputingStats, ImportTrigger::StatsComplete) => {
                PipelinePhase::Merging {
                    shards_processed: 0,
                    total_shards: 0,
                }
            }
            (PipelinePhase::Merging { .. }, ImportTrigger::MergeComplete) => {
                PipelinePhase::CleaningUp
            }
            (PipelinePhase::CleaningUp, ImportTrigger::CleanupComplete) => {
                let duration = self.start_time.elapsed();
                let total = self.total_ngrams.load(Ordering::Relaxed);
                PipelinePhase::Completed {
                    total_ngrams: total,
                    duration,
                }
            }
            // A forced quit must be honoured from every live phase, not only while downloading.
            (current, ImportTrigger::ForceQuit) if !current.is_terminal() => {
                self.force_quit.store(true, Ordering::SeqCst);
                PipelinePhase::ForceQuit
            }
            // Errors can occur in any live phase (including init and cleanup) and must end the run.
            (current, ImportTrigger::Error(msg)) if !current.is_terminal() => {
                PipelinePhase::Failed { error: msg }
            }
            (current, ImportTrigger::SetParallelism(parallelism)) if !current.is_terminal() => {
                self.current_parallelism.store(parallelism, Ordering::SeqCst);
                current.clone()
            }
            (current, trigger) => {
                log::warn!(
                    "Invalid state transition: {:?} + {:?}",
                    current.name(),
                    trigger
                );
                return &self.phase;
            }
        };
        self.phase = new_phase;
        &self.phase
    }

    /// Returns `true` once the current phase is an end state.
    pub fn is_terminal(&self) -> bool {
        self.phase.is_terminal()
    }
}

impl Default for ImportContext {
    /// Same as [`ImportContext::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_phase_is_terminal() {
        assert!(!PipelinePhase::Initializing.is_terminal());
        assert!(!PipelinePhase::Downloading {
            active_workers: 4,
            results_received: 0,
            total_pending: 100
        }
        .is_terminal());
        assert!(!PipelinePhase::ComputingStats.is_terminal());
        assert!(PipelinePhase::Completed {
            total_ngrams: 1000,
            duration: Duration::from_secs(60)
        }
        .is_terminal());
        assert!(PipelinePhase::Cancelled.is_terminal());
        assert!(PipelinePhase::ForceQuit.is_terminal());
        assert!(PipelinePhase::Failed {
            error: "test".to_string()
        }
        .is_terminal());
    }

    #[test]
    fn test_phase_names() {
        assert_eq!(PipelinePhase::Initializing.name(), "Initializing");
        assert_eq!(
            PipelinePhase::Downloading {
                active_workers: 4,
                results_received: 0,
                total_pending: 100
            }
            .name(),
            "Downloading"
        );
        assert_eq!(
            PipelinePhase::ComputingStats.name(),
            "Computing MKN Statistics"
        );
    }

    #[test]
    fn test_context_transitions() {
        let mut ctx = ImportContext::new();
        assert!(matches!(ctx.phase, PipelinePhase::Initializing));
        ctx.transition(ImportTrigger::InitComplete { total_pending: 100 });
        assert!(matches!(ctx.phase, PipelinePhase::Downloading { .. }));
        ctx.transition(ImportTrigger::Pause);
        assert!(matches!(ctx.phase, PipelinePhase::Paused { .. }));
        ctx.transition(ImportTrigger::Resume);
        assert!(matches!(ctx.phase, PipelinePhase::Downloading { .. }));
        ctx.transition(ImportTrigger::Cancel);
        assert!(matches!(ctx.phase, PipelinePhase::Cancelled));
        assert!(ctx.is_terminal());
    }

    /// Actions must run in reverse registration order. Each action records its own
    /// registration index, so any order other than LIFO fails the assertion. A shared
    /// counter would yield `[0, 1, 2]` regardless of order.
    #[tokio::test]
    async fn test_cleanup_guard_lifo_order() {
        let executed = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let mut guard = CleanupGuard::new();
        for id in 0..3usize {
            let executed = Arc::clone(&executed);
            guard.register(move || async move {
                executed.lock().push(id);
            });
        }
        assert_eq!(guard.len(), 3);
        guard.cleanup().await;
        let executed = executed.lock();
        assert_eq!(*executed, vec![2, 1, 0]);
    }

    #[tokio::test]
    async fn test_cleanup_resources_builder() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let order = Arc::new(AtomicUsize::new(0));
        let results = Arc::new(parking_lot::Mutex::new(Vec::new()));
        struct MockSharedState {
            order: Arc<AtomicUsize>,
            results: Arc<parking_lot::Mutex<Vec<(usize, &'static str)>>>,
        }
        impl Drop for MockSharedState {
            fn drop(&mut self) {
                let n = self.order.fetch_add(1, Ordering::SeqCst);
                self.results.lock().push((n, "shared_state"));
            }
        }
        let shared_state = Arc::new(MockSharedState {
            order: Arc::clone(&order),
            results: Arc::clone(&results),
        });
        let (shutdown_tx, _shutdown_rx) = tokio::sync::watch::channel(false);
        let mut shutdown_txs = HashMap::new();
        shutdown_txs.insert(0, shutdown_tx);
        let order_clone = Arc::clone(&order);
        let results_clone = Arc::clone(&results);
        let worker_handle = tokio::spawn(async move {
            let n = order_clone.fetch_add(1, Ordering::SeqCst);
            results_clone.lock().push((n, "worker"));
        });
        let mut handles = HashMap::new();
        handles.insert(0, worker_handle);
        let resources: CleanupResources<MockSharedState> = CleanupResources::new()
            .with_worker_handles(handles)
            .with_worker_shutdown_txs(shutdown_txs)
            .with_shared_state(shared_state);
        let guard = resources.into_cleanup_guard();
        guard.cleanup().await;
        let results = results.lock();
        assert_eq!(
            results.len(),
            2,
            "Expected exactly the worker and shared_state entries: {:?}",
            *results
        );
        let worker_order = results
            .iter()
            .find(|(_, name)| *name == "worker")
            .map(|(n, _)| *n)
            .expect("worker task should have run during cleanup");
        let shared_order = results
            .iter()
            .find(|(_, name)| *name == "shared_state")
            .map(|(n, _)| *n)
            .expect("shared state should have been dropped during cleanup");
        assert!(
            worker_order < shared_order,
            "Worker should complete before shared_state is dropped"
        );
    }

    #[test]
    fn test_cleanup_resources_default() {
        struct DummyState;
        let resources: CleanupResources<DummyState> = CleanupResources::default();
        assert!(resources.worker_handles.is_empty());
        assert!(resources.worker_shutdown_txs.is_empty());
        assert!(resources.shared_state.is_none());
    }
}