1use crate::event::{Event, EventHub, LongOperationEvent, Origin};
113use anyhow::Result;
114use std::collections::HashMap;
115use std::sync::{
116 Arc, Mutex,
117 atomic::{AtomicBool, Ordering},
118};
119use std::thread;
120
/// Lifecycle state of an operation tracked by the manager.
///
/// `Running` is the only non-terminal state; the other three are final outcomes.
#[derive(Clone, Debug, PartialEq)]
pub enum OperationStatus {
    /// The operation is executing, or has not yet published its outcome.
    Running,
    /// The operation returned successfully.
    Completed,
    /// Cancellation was requested while the operation was running.
    Cancelled,
    /// The operation failed; carries a human-readable description of the error.
    Failed(String),
}
129
/// Snapshot of how far an operation has progressed.
#[derive(Debug, Clone)]
pub struct OperationProgress {
    /// Completion percentage; [`OperationProgress::new`] keeps it within `0.0..=100.0`.
    pub percentage: f32,
    /// Optional human-readable description of the current step.
    pub message: Option<String>,
}

impl OperationProgress {
    /// Creates a progress snapshot with `percentage` clamped into `0.0..=100.0`.
    ///
    /// A NaN percentage is recorded as `0.0`. `f32::clamp` returns NaN unchanged for a NaN
    /// input, so without this check a NaN would be stored and published as progress.
    pub fn new(percentage: f32, message: Option<String>) -> Self {
        let percentage = if percentage.is_nan() {
            0.0
        } else {
            percentage.clamp(0.0, 100.0)
        };
        Self { percentage, message }
    }
}
145
146pub trait LongOperation: Send + 'static {
148 type Output: Send + Sync + 'static + serde::Serialize;
149
150 fn execute(
151 &self,
152 progress_callback: Box<dyn Fn(OperationProgress) + Send>,
153 cancel_flag: Arc<AtomicBool>,
154 ) -> Result<Self::Output>;
155}
156
157trait OperationHandleTrait: Send {
159 fn get_status(&self) -> OperationStatus;
160 fn get_progress(&self) -> OperationProgress;
161 fn cancel(&self);
162 fn is_finished(&self) -> bool;
163}
164
/// Locks `mutex`, still returning the guard if a previous holder panicked while holding it.
///
/// Poisoning is deliberately ignored: the caller receives whatever data the panicking holder
/// left behind instead of propagating the panic to every later user of the lock.
fn lock_or_recover<T>(mutex: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
175
176struct OperationHandle {
178 status: Arc<Mutex<OperationStatus>>,
179 progress: Arc<Mutex<OperationProgress>>,
180 cancel_flag: Arc<AtomicBool>,
181 _join_handle: thread::JoinHandle<()>,
182}
183
184impl OperationHandleTrait for OperationHandle {
185 fn get_status(&self) -> OperationStatus {
186 lock_or_recover(&self.status).clone()
187 }
188
189 fn get_progress(&self) -> OperationProgress {
190 lock_or_recover(&self.progress).clone()
191 }
192
193 fn cancel(&self) {
194 self.cancel_flag.store(true, Ordering::Relaxed);
195 let mut status = lock_or_recover(&self.status);
196 if matches!(*status, OperationStatus::Running) {
197 *status = OperationStatus::Cancelled;
198 }
199 }
200
201 fn is_finished(&self) -> bool {
202 matches!(
203 self.get_status(),
204 OperationStatus::Completed | OperationStatus::Cancelled | OperationStatus::Failed(_)
205 )
206 }
207}
208
209pub struct LongOperationManager {
211 operations: Arc<Mutex<HashMap<String, Box<dyn OperationHandleTrait>>>>,
212 next_id: Arc<Mutex<u64>>,
213 results: Arc<Mutex<HashMap<String, String>>>, event_hub: Option<Arc<EventHub>>,
215}
216
217impl LongOperationManager {
218 pub fn new() -> Self {
219 Self {
220 operations: Arc::new(Mutex::new(HashMap::new())),
221 next_id: Arc::new(Mutex::new(0)),
222 results: Arc::new(Mutex::new(HashMap::new())),
223 event_hub: None,
224 }
225 }
226
227 pub fn set_event_hub(&mut self, event_hub: &Arc<EventHub>) {
229 self.event_hub = Some(Arc::clone(event_hub));
230 }
231
232 pub fn start_operation<Op: LongOperation>(&self, operation: Op) -> String {
234 let id = {
235 let mut next_id = lock_or_recover(&self.next_id);
236 *next_id += 1;
237 format!("op_{}", *next_id)
238 };
239
240 if let Some(event_hub) = &self.event_hub {
242 event_hub.send_event(Event {
243 origin: Origin::LongOperation(LongOperationEvent::Started),
244 ids: vec![],
245 data: Some(id.clone()),
246 });
247 }
248
249 let status = Arc::new(Mutex::new(OperationStatus::Running));
250 let progress = Arc::new(Mutex::new(OperationProgress::new(0.0, None)));
251 let cancel_flag = Arc::new(AtomicBool::new(false));
252
253 let status_clone = status.clone();
254 let progress_clone = progress.clone();
255 let cancel_flag_clone = cancel_flag.clone();
256 let results_clone = self.results.clone();
257 let id_clone = id.clone();
258 let event_hub_opt = self.event_hub.clone();
259
260 let join_handle = thread::spawn(move || {
261 let progress_callback = {
262 let progress = progress_clone.clone();
263 let event_hub_opt = event_hub_opt.clone();
264 let id_for_cb = id_clone.clone();
265 Box::new(move |prog: OperationProgress| {
266 *lock_or_recover(&progress) = prog.clone();
267 if let Some(event_hub) = &event_hub_opt {
268 let payload = serde_json::json!({
269 "id": id_for_cb,
270 "percentage": prog.percentage,
271 "message": prog.message,
272 })
273 .to_string();
274 event_hub.send_event(Event {
275 origin: Origin::LongOperation(LongOperationEvent::Progress),
276 ids: vec![],
277 data: Some(payload),
278 });
279 }
280 }) as Box<dyn Fn(OperationProgress) + Send>
281 };
282
283 let operation_result = operation.execute(progress_callback, cancel_flag_clone.clone());
284
285 let final_status = if cancel_flag_clone.load(Ordering::Relaxed) {
286 OperationStatus::Cancelled
287 } else {
288 match &operation_result {
289 Ok(result) => {
290 if let Ok(serialized) = serde_json::to_string(result) {
292 let mut results = lock_or_recover(&results_clone);
293 results.insert(id_clone.clone(), serialized);
294 }
295 OperationStatus::Completed
296 }
297 Err(e) => OperationStatus::Failed(e.to_string()),
298 }
299 };
300
301 if let Some(event_hub) = &event_hub_opt {
303 let (event, data) = match &final_status {
304 OperationStatus::Completed => (
305 LongOperationEvent::Completed,
306 serde_json::json!({"id": id_clone}).to_string(),
307 ),
308 OperationStatus::Cancelled => (
309 LongOperationEvent::Cancelled,
310 serde_json::json!({"id": id_clone}).to_string(),
311 ),
312 OperationStatus::Failed(err) => (
313 LongOperationEvent::Failed,
314 serde_json::json!({"id": id_clone, "error": err}).to_string(),
315 ),
316 OperationStatus::Running => (
317 LongOperationEvent::Progress,
318 serde_json::json!({"id": id_clone}).to_string(),
319 ),
320 };
321 event_hub.send_event(Event {
322 origin: Origin::LongOperation(event),
323 ids: vec![],
324 data: Some(data),
325 });
326 }
327
328 *lock_or_recover(&status_clone) = final_status;
329 });
330
331 let handle = OperationHandle {
332 status,
333 progress,
334 cancel_flag,
335 _join_handle: join_handle,
336 };
337
338 lock_or_recover(&self.operations).insert(id.clone(), Box::new(handle));
339
340 id
341 }
342
343 pub fn get_operation_status(&self, id: &str) -> Option<OperationStatus> {
345 let operations = lock_or_recover(&self.operations);
346 operations.get(id).map(|handle| handle.get_status())
347 }
348
349 pub fn get_operation_progress(&self, id: &str) -> Option<OperationProgress> {
351 let operations = lock_or_recover(&self.operations);
352 operations.get(id).map(|handle| handle.get_progress())
353 }
354
355 pub fn cancel_operation(&self, id: &str) -> bool {
357 let operations = lock_or_recover(&self.operations);
358 if let Some(handle) = operations.get(id) {
359 handle.cancel();
360 if let Some(event_hub) = &self.event_hub {
362 let payload = serde_json::json!({"id": id}).to_string();
363 event_hub.send_event(Event {
364 origin: Origin::LongOperation(LongOperationEvent::Cancelled),
365 ids: vec![],
366 data: Some(payload),
367 });
368 }
369 true
370 } else {
371 false
372 }
373 }
374
375 pub fn is_operation_finished(&self, id: &str) -> Option<bool> {
377 let operations = lock_or_recover(&self.operations);
378 operations.get(id).map(|handle| handle.is_finished())
379 }
380
381 pub fn cleanup_finished_operations(&self) {
383 let mut operations = lock_or_recover(&self.operations);
384 operations.retain(|_, handle| !handle.is_finished());
385 }
386
387 pub fn list_operations(&self) -> Vec<String> {
389 let operations = lock_or_recover(&self.operations);
390 operations.keys().cloned().collect()
391 }
392
393 pub fn get_operations_summary(&self) -> Vec<(String, OperationStatus, OperationProgress)> {
395 let operations = lock_or_recover(&self.operations);
396 operations
397 .iter()
398 .map(|(id, handle)| (id.clone(), handle.get_status(), handle.get_progress()))
399 .collect()
400 }
401
402 pub fn store_operation_result<T: serde::Serialize>(&self, id: &str, result: T) -> Result<()> {
404 let serialized = serde_json::to_string(&result)?;
405 let mut results = lock_or_recover(&self.results);
406 results.insert(id.to_string(), serialized);
407 Ok(())
408 }
409
410 pub fn get_operation_result(&self, id: &str) -> Option<String> {
412 let results = lock_or_recover(&self.results);
413 results.get(id).cloned()
414 }
415}
416
417impl Default for LongOperationManager {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use std::time::{Duration, Instant};

    /// Simulated workload that spends 500 ms on each of `total_files` files.
    pub struct FileProcessingOperation {
        pub _file_path: String,
        pub total_files: usize,
    }

    impl LongOperation for FileProcessingOperation {
        type Output = ();

        fn execute(
            &self,
            progress_callback: Box<dyn Fn(OperationProgress) + Send>,
            cancel_flag: Arc<AtomicBool>,
        ) -> Result<Self::Output> {
            for i in 0..self.total_files {
                if cancel_flag.load(Ordering::Relaxed) {
                    return Err(anyhow!("Operation was cancelled"));
                }

                // Report before working on file `i + 1`, so the percentage (files already done)
                // and the message (the file about to be processed) describe the same moment.
                let percentage = (i as f32 / self.total_files as f32) * 100.0;
                progress_callback(OperationProgress::new(
                    percentage,
                    Some(format!("Processing file {} of {}", i + 1, self.total_files)),
                ));

                thread::sleep(Duration::from_millis(500));
            }

            progress_callback(OperationProgress::new(100.0, Some("Completed".to_string())));
            Ok(())
        }
    }

    /// Finishes immediately with the sum of its operands after reporting 100 %.
    struct SumOperation {
        a: i32,
        b: i32,
    }

    impl LongOperation for SumOperation {
        type Output = i32;

        fn execute(
            &self,
            progress_callback: Box<dyn Fn(OperationProgress) + Send>,
            _cancel_flag: Arc<AtomicBool>,
        ) -> Result<Self::Output> {
            progress_callback(OperationProgress::new(100.0, None));
            Ok(self.a + self.b)
        }
    }

    /// Fails immediately with the error message `boom`.
    struct FailingOperation;

    impl LongOperation for FailingOperation {
        type Output = ();

        fn execute(
            &self,
            _progress_callback: Box<dyn Fn(OperationProgress) + Send>,
            _cancel_flag: Arc<AtomicBool>,
        ) -> Result<Self::Output> {
            Err(anyhow!("boom"))
        }
    }

    /// Polls until operation `id` reaches a terminal status, failing the test after 5 s.
    fn wait_until_finished(manager: &LongOperationManager, id: &str) {
        let deadline = Instant::now() + Duration::from_secs(5);
        while manager.is_operation_finished(id) != Some(true) {
            assert!(Instant::now() < deadline, "operation {id} did not finish in time");
            thread::sleep(Duration::from_millis(5));
        }
    }

    #[test]
    fn test_operation_manager() {
        let manager = LongOperationManager::new();

        let operation = FileProcessingOperation {
            _file_path: "/tmp/test".to_string(),
            total_files: 5,
        };

        let op_id = manager.start_operation(operation);

        // The first simulated file takes 500 ms, so the operation is still running here.
        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Running)
        );

        thread::sleep(Duration::from_millis(100));
        let progress = manager.get_operation_progress(&op_id);
        assert!(progress.is_some());

        // Cancellation marks a running operation as Cancelled synchronously.
        assert!(manager.cancel_operation(&op_id));
        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Cancelled)
        );
    }

    #[test]
    fn completed_operation_stores_result_and_progress() {
        let manager = LongOperationManager::new();
        let id = manager.start_operation(SumOperation { a: 1, b: 2 });
        wait_until_finished(&manager, &id);

        assert_eq!(
            manager.get_operation_status(&id),
            Some(OperationStatus::Completed)
        );
        assert_eq!(manager.get_operation_result(&id).as_deref(), Some("3"));
        let progress = manager
            .get_operation_progress(&id)
            .expect("operation is still tracked");
        assert_eq!(progress.percentage, 100.0);
    }

    #[test]
    fn failing_operation_reports_error() {
        let manager = LongOperationManager::new();
        let id = manager.start_operation(FailingOperation);
        wait_until_finished(&manager, &id);

        assert_eq!(
            manager.get_operation_status(&id),
            Some(OperationStatus::Failed("boom".to_string()))
        );
        assert_eq!(manager.get_operation_result(&id), None);
    }

    #[test]
    fn cancelling_finished_operation_keeps_its_status() {
        let manager = LongOperationManager::new();
        let id = manager.start_operation(SumOperation { a: 4, b: 4 });
        wait_until_finished(&manager, &id);

        assert!(manager.cancel_operation(&id));
        assert_eq!(
            manager.get_operation_status(&id),
            Some(OperationStatus::Completed)
        );
    }

    #[test]
    fn unknown_ids_are_reported_as_absent() {
        let manager = LongOperationManager::new();

        assert_eq!(manager.get_operation_status("op_missing"), None);
        assert!(manager.get_operation_progress("op_missing").is_none());
        assert_eq!(manager.is_operation_finished("op_missing"), None);
        assert!(!manager.cancel_operation("op_missing"));
        assert_eq!(manager.get_operation_result("op_missing"), None);
    }

    #[test]
    fn cleanup_drops_finished_operations_but_keeps_results() {
        let manager = LongOperationManager::new();
        let id = manager.start_operation(SumOperation { a: 2, b: 3 });
        wait_until_finished(&manager, &id);

        manager.cleanup_finished_operations();

        assert!(manager.list_operations().is_empty());
        assert_eq!(manager.get_operation_status(&id), None);
        assert_eq!(manager.get_operation_result(&id).as_deref(), Some("5"));
    }

    #[test]
    fn operation_ids_are_sequential() {
        let manager = LongOperationManager::new();
        let first = manager.start_operation(SumOperation { a: 0, b: 0 });
        let second = manager.start_operation(SumOperation { a: 0, b: 0 });

        assert_eq!(first, "op_1");
        assert_eq!(second, "op_2");
    }
}