use crate::event::{Event, EventHub, LongOperationEvent, Origin};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
};
use std::thread;
/// Lifecycle state of a long-running operation.
#[derive(Debug, Clone, PartialEq)]
pub enum OperationStatus {
    /// The operation has not reached a terminal state yet.
    Running,
    /// The operation finished successfully.
    Completed,
    /// Cancellation was requested, or the operation observed its cancel flag.
    Cancelled,
    /// The operation failed; carries a description of the failure.
    Failed(String),
}
/// Snapshot of an operation's progress.
#[derive(Debug, Clone)]
pub struct OperationProgress {
    /// Completion percentage; within `0.0..=100.0` when built via [`OperationProgress::new`].
    pub percentage: f32,
    /// Optional human-readable status text.
    pub message: Option<String>,
}

impl OperationProgress {
    /// Creates a progress snapshot, clamping `percentage` into `0.0..=100.0`.
    ///
    /// A NaN percentage (e.g. from `0.0 / 0.0` when an operation has no work items) is
    /// reported as `0.0`: `f32::clamp` passes NaN through unchanged, which would otherwise
    /// break the range guarantee above.
    pub fn new(percentage: f32, message: Option<String>) -> Self {
        let percentage = if percentage.is_nan() {
            0.0
        } else {
            percentage.clamp(0.0, 100.0)
        };
        Self {
            percentage,
            message,
        }
    }
}
/// A unit of work that [`LongOperationManager`] can run on a background thread.
pub trait LongOperation: Send + 'static {
    /// Value produced on success; the manager serializes it with `serde`.
    type Output: Send + Sync + 'static + serde::Serialize;

    /// Runs the operation synchronously (the manager calls this on a dedicated worker thread).
    ///
    /// `progress_callback` records and broadcasts progress; `cancel_flag` becomes `true`
    /// once cancellation has been requested, and implementations should poll it and stop.
    fn execute(
        &self,
        progress_callback: Box<dyn Fn(OperationProgress) + Send>,
        cancel_flag: Arc<AtomicBool>,
    ) -> Result<Self::Output>;
}
/// Type-erased view of a running operation, so handles for different `LongOperation`
/// types can be stored in one map.
trait OperationHandleTrait: Send {
    /// Current lifecycle state (a cloned snapshot).
    fn get_status(&self) -> OperationStatus;
    /// Most recently reported progress (a cloned snapshot).
    fn get_progress(&self) -> OperationProgress;
    /// Requests cancellation of the operation.
    fn cancel(&self);
    /// Whether the operation has reached a terminal state.
    fn is_finished(&self) -> bool;
}
/// Manager-side handle for one operation. The worker thread holds clones of the `Arc`s.
struct OperationHandle {
    /// Lifecycle state, shared with the worker thread.
    status: Arc<Mutex<OperationStatus>>,
    /// Latest progress, written by the progress callback.
    progress: Arc<Mutex<OperationProgress>>,
    /// Cooperative cancellation signal passed to `LongOperation::execute`.
    cancel_flag: Arc<AtomicBool>,
    /// Worker thread handle; never joined here, it is only kept alive with the handle.
    _join_handle: thread::JoinHandle<()>,
}
impl OperationHandleTrait for OperationHandle {
    /// Returns a copy of the shared status.
    fn get_status(&self) -> OperationStatus {
        let guard = self.status.lock().unwrap();
        guard.clone()
    }

    /// Returns a copy of the shared progress.
    fn get_progress(&self) -> OperationProgress {
        let guard = self.progress.lock().unwrap();
        guard.clone()
    }

    /// Raises the cancel flag, then switches the status to `Cancelled` only if it is still
    /// `Running`; a terminal status is left untouched.
    fn cancel(&self) {
        self.cancel_flag.store(true, Ordering::Relaxed);
        let mut current = self.status.lock().unwrap();
        if *current == OperationStatus::Running {
            *current = OperationStatus::Cancelled;
        }
    }

    /// Every status other than `Running` is terminal.
    fn is_finished(&self) -> bool {
        match self.get_status() {
            OperationStatus::Running => false,
            OperationStatus::Completed
            | OperationStatus::Cancelled
            | OperationStatus::Failed(_) => true,
        }
    }
}
/// Runs [`LongOperation`]s on background threads and tracks them by string id.
pub struct LongOperationManager {
    /// Handles of tracked operations, keyed by id (`"op_<n>"`).
    operations: Arc<Mutex<HashMap<String, Box<dyn OperationHandleTrait>>>>,
    /// Last issued numeric id; incremented before use, so the first id is `op_1`.
    next_id: Arc<Mutex<u64>>,
    /// JSON-serialized results, keyed by operation id.
    results: Arc<Mutex<HashMap<String, String>>>,
    /// Optional sink for lifecycle events.
    event_hub: Option<Arc<EventHub>>,
}
impl LongOperationManager {
pub fn new() -> Self {
Self {
operations: Arc::new(Mutex::new(HashMap::new())),
next_id: Arc::new(Mutex::new(0)),
results: Arc::new(Mutex::new(HashMap::new())),
event_hub: None,
}
}
pub fn set_event_hub(&mut self, event_hub: &Arc<EventHub>) {
self.event_hub = Some(Arc::clone(event_hub));
}
pub fn start_operation<Op: LongOperation>(&self, operation: Op) -> String {
let id = {
let mut next_id = self.next_id.lock().unwrap();
*next_id += 1;
format!("op_{}", *next_id)
};
if let Some(event_hub) = &self.event_hub {
event_hub.send_event(Event {
origin: Origin::LongOperation(LongOperationEvent::Started),
ids: vec![],
data: Some(id.clone()),
});
}
let status = Arc::new(Mutex::new(OperationStatus::Running));
let progress = Arc::new(Mutex::new(OperationProgress::new(0.0, None)));
let cancel_flag = Arc::new(AtomicBool::new(false));
let status_clone = status.clone();
let progress_clone = progress.clone();
let cancel_flag_clone = cancel_flag.clone();
let results_clone = self.results.clone();
let id_clone = id.clone();
let event_hub_opt = self.event_hub.clone();
let join_handle = thread::spawn(move || {
let progress_callback = {
let progress = progress_clone.clone();
let event_hub_opt = event_hub_opt.clone();
let id_for_cb = id_clone.clone();
Box::new(move |prog: OperationProgress| {
*progress.lock().unwrap() = prog.clone();
if let Some(event_hub) = &event_hub_opt {
let payload = serde_json::json!({
"id": id_for_cb,
"percentage": prog.percentage,
"message": prog.message,
})
.to_string();
event_hub.send_event(Event {
origin: Origin::LongOperation(LongOperationEvent::Progress),
ids: vec![],
data: Some(payload),
});
}
}) as Box<dyn Fn(OperationProgress) + Send>
};
let operation_result = operation.execute(progress_callback, cancel_flag_clone.clone());
let final_status = if cancel_flag_clone.load(Ordering::Relaxed) {
OperationStatus::Cancelled
} else {
match &operation_result {
Ok(result) => {
if let Ok(serialized) = serde_json::to_string(result) {
let mut results = results_clone.lock().unwrap();
results.insert(id_clone.clone(), serialized);
}
OperationStatus::Completed
}
Err(e) => OperationStatus::Failed(e.to_string()),
}
};
if let Some(event_hub) = &event_hub_opt {
let (event, data) = match &final_status {
OperationStatus::Completed => (
LongOperationEvent::Completed,
serde_json::json!({"id": id_clone}).to_string(),
),
OperationStatus::Cancelled => (
LongOperationEvent::Cancelled,
serde_json::json!({"id": id_clone}).to_string(),
),
OperationStatus::Failed(err) => (
LongOperationEvent::Failed,
serde_json::json!({"id": id_clone, "error": err}).to_string(),
),
OperationStatus::Running => (
LongOperationEvent::Progress,
serde_json::json!({"id": id_clone}).to_string(),
),
};
event_hub.send_event(Event {
origin: Origin::LongOperation(event),
ids: vec![],
data: Some(data),
});
}
*status_clone.lock().unwrap() = final_status;
});
let handle = OperationHandle {
status,
progress,
cancel_flag,
_join_handle: join_handle,
};
self.operations
.lock()
.unwrap()
.insert(id.clone(), Box::new(handle));
id
}
pub fn get_operation_status(&self, id: &str) -> Option<OperationStatus> {
let operations = self.operations.lock().unwrap();
operations.get(id).map(|handle| handle.get_status())
}
pub fn get_operation_progress(&self, id: &str) -> Option<OperationProgress> {
let operations = self.operations.lock().unwrap();
operations.get(id).map(|handle| handle.get_progress())
}
pub fn cancel_operation(&self, id: &str) -> bool {
let operations = self.operations.lock().unwrap();
if let Some(handle) = operations.get(id) {
handle.cancel();
if let Some(event_hub) = &self.event_hub {
let payload = serde_json::json!({"id": id}).to_string();
event_hub.send_event(Event {
origin: Origin::LongOperation(LongOperationEvent::Cancelled),
ids: vec![],
data: Some(payload),
});
}
true
} else {
false
}
}
pub fn is_operation_finished(&self, id: &str) -> Option<bool> {
let operations = self.operations.lock().unwrap();
operations.get(id).map(|handle| handle.is_finished())
}
pub fn cleanup_finished_operations(&self) {
let mut operations = self.operations.lock().unwrap();
operations.retain(|_, handle| !handle.is_finished());
}
pub fn list_operations(&self) -> Vec<String> {
let operations = self.operations.lock().unwrap();
operations.keys().cloned().collect()
}
pub fn get_operations_summary(&self) -> Vec<(String, OperationStatus, OperationProgress)> {
let operations = self.operations.lock().unwrap();
operations
.iter()
.map(|(id, handle)| (id.clone(), handle.get_status(), handle.get_progress()))
.collect()
}
pub fn store_operation_result<T: serde::Serialize>(&self, id: &str, result: T) -> Result<()> {
let serialized = serde_json::to_string(&result)?;
let mut results = self.results.lock().unwrap();
results.insert(id.to_string(), serialized);
Ok(())
}
pub fn get_operation_result(&self, id: &str) -> Option<String> {
let results = self.results.lock().unwrap();
results.get(id).cloned()
}
}
impl Default for LongOperationManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use std::time::Duration;

    /// Fake operation that "processes" `total_files` files, sleeping 500 ms per file.
    pub struct FileProcessingOperation {
        pub _file_path: String,
        pub total_files: usize,
    }

    impl LongOperation for FileProcessingOperation {
        type Output = ();

        /// Checks the cancel flag before each file and reports progress after each one.
        /// It finishes with a 100% "Completed" update.
        fn execute(
            &self,
            progress_callback: Box<dyn Fn(OperationProgress) + Send>,
            cancel_flag: Arc<AtomicBool>,
        ) -> Result<Self::Output> {
            let total = self.total_files;
            for done in 0..total {
                if cancel_flag.load(Ordering::Relaxed) {
                    return Err(anyhow!("Operation was cancelled".to_string()));
                }
                thread::sleep(Duration::from_millis(500));
                let percentage = done as f32 / total as f32 * 100.0;
                let message = format!("Processing file {} of {}", done + 1, total);
                progress_callback(OperationProgress::new(percentage, Some(message)));
            }
            progress_callback(OperationProgress::new(
                100.0,
                Some("Completed".to_string()),
            ));
            Ok(())
        }
    }

    /// Starts an operation, checks that it is running and has progress, then cancels it.
    #[test]
    fn test_operation_manager() {
        let manager = LongOperationManager::new();
        let op_id = manager.start_operation(FileProcessingOperation {
            _file_path: "/tmp/test".to_string(),
            total_files: 5,
        });
        // The first file takes 500 ms, so the operation is still running here.
        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Running)
        );

        thread::sleep(Duration::from_millis(100));
        assert!(manager.get_operation_progress(&op_id).is_some());

        assert!(manager.cancel_operation(&op_id));
        thread::sleep(Duration::from_millis(100));
        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Cancelled)
        );
    }
}