1use crate::event::{Event, EventHub, LongOperationEvent, Origin};
114use anyhow::Result;
115use std::collections::HashMap;
116use std::sync::{
117 Arc, Mutex,
118 atomic::{AtomicBool, Ordering},
119};
120use std::thread;
121
/// Lifecycle state of an operation tracked by `LongOperationManager`.
///
/// `Running` is the only non-terminal state; `Completed`, `Cancelled` and `Failed`
/// are all treated as finished.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OperationStatus {
    /// The worker thread has started and no terminal state has been recorded yet.
    Running,
    /// The operation finished successfully.
    Completed,
    /// Cancellation was requested before the operation finished.
    Cancelled,
    /// The operation failed; the payload is a human-readable error message.
    Failed(String),
}
130
/// Snapshot of how far a long-running operation has progressed.
#[derive(Debug, Clone)]
pub struct OperationProgress {
    /// Completion percentage. `OperationProgress::new` clamps it to the 0–100 range;
    /// building the struct literally bypasses that clamp.
    pub percentage: f32,
    /// Optional human-readable status text supplied by the operation.
    pub message: Option<String>,
}
137
138impl OperationProgress {
139 pub fn new(percentage: f32, message: Option<String>) -> Self {
140 Self {
141 percentage: percentage.clamp(0.0, 100.0),
142 message,
143 }
144 }
145}
146
147pub trait LongOperation: Send + 'static {
149 type Output: Send + Sync + 'static + serde::Serialize;
150
151 fn execute(
152 &self,
153 progress_callback: Box<dyn Fn(OperationProgress) + Send>,
154 cancel_flag: Arc<AtomicBool>,
155 ) -> Result<Self::Output>;
156}
157
158trait OperationHandleTrait: Send {
160 fn get_status(&self) -> OperationStatus;
161 fn get_progress(&self) -> OperationProgress;
162 fn cancel(&self);
163 fn is_finished(&self) -> bool;
164}
165
/// Locks `mutex`, taking the guard even if a previous holder panicked.
///
/// Poisoning is ignored on purpose, so that one panic while a lock is held does not
/// make every later call on the same mutex panic as well.
fn lock_or_recover<T>(mutex: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
176
177struct OperationHandle {
179 status: Arc<Mutex<OperationStatus>>,
180 progress: Arc<Mutex<OperationProgress>>,
181 cancel_flag: Arc<AtomicBool>,
182 _join_handle: thread::JoinHandle<()>,
183}
184
185impl OperationHandleTrait for OperationHandle {
186 fn get_status(&self) -> OperationStatus {
187 lock_or_recover(&self.status).clone()
188 }
189
190 fn get_progress(&self) -> OperationProgress {
191 lock_or_recover(&self.progress).clone()
192 }
193
194 fn cancel(&self) {
195 self.cancel_flag.store(true, Ordering::Relaxed);
196 let mut status = lock_or_recover(&self.status);
197 if matches!(*status, OperationStatus::Running) {
198 *status = OperationStatus::Cancelled;
199 }
200 }
201
202 fn is_finished(&self) -> bool {
203 matches!(
204 self.get_status(),
205 OperationStatus::Completed | OperationStatus::Cancelled | OperationStatus::Failed(_)
206 )
207 }
208}
209
210pub struct LongOperationManager {
212 operations: Arc<Mutex<HashMap<String, Box<dyn OperationHandleTrait>>>>,
213 next_id: Arc<Mutex<u64>>,
214 results: Arc<Mutex<HashMap<String, String>>>, event_hub: Option<Arc<EventHub>>,
216}
217
218impl std::fmt::Debug for LongOperationManager {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 let operations_len = lock_or_recover(&self.operations).len();
221 let next_id = *lock_or_recover(&self.next_id);
222 let results_len = lock_or_recover(&self.results).len();
223
224 f.debug_struct("LongOperationManager")
225 .field("operations_len", &operations_len)
226 .field("next_id", &next_id)
227 .field("results_len", &results_len)
228 .field("event_hub_set", &self.event_hub.is_some())
229 .finish()
230 }
231}
232
233impl LongOperationManager {
234 pub fn new() -> Self {
235 Self {
236 operations: Arc::new(Mutex::new(HashMap::new())),
237 next_id: Arc::new(Mutex::new(0)),
238 results: Arc::new(Mutex::new(HashMap::new())),
239 event_hub: None,
240 }
241 }
242
243 pub fn set_event_hub(&mut self, event_hub: &Arc<EventHub>) {
245 self.event_hub = Some(Arc::clone(event_hub));
246 }
247
248 pub fn start_operation<Op: LongOperation>(&self, operation: Op) -> String {
250 let id = {
251 let mut next_id = lock_or_recover(&self.next_id);
252 *next_id += 1;
253 format!("op_{}", *next_id)
254 };
255
256 if let Some(event_hub) = &self.event_hub {
258 event_hub.send_event(Event {
259 origin: Origin::LongOperation(LongOperationEvent::Started),
260 ids: vec![],
261 data: Some(id.clone()),
262 });
263 }
264
265 let status = Arc::new(Mutex::new(OperationStatus::Running));
266 let progress = Arc::new(Mutex::new(OperationProgress::new(0.0, None)));
267 let cancel_flag = Arc::new(AtomicBool::new(false));
268
269 let status_clone = status.clone();
270 let progress_clone = progress.clone();
271 let cancel_flag_clone = cancel_flag.clone();
272 let results_clone = self.results.clone();
273 let id_clone = id.clone();
274 let event_hub_opt = self.event_hub.clone();
275
276 let join_handle = thread::spawn(move || {
277 let progress_callback = {
278 let progress = progress_clone.clone();
279 let event_hub_opt = event_hub_opt.clone();
280 let id_for_cb = id_clone.clone();
281 Box::new(move |prog: OperationProgress| {
282 *lock_or_recover(&progress) = prog.clone();
283 if let Some(event_hub) = &event_hub_opt {
284 let payload = serde_json::json!({
285 "id": id_for_cb,
286 "percentage": prog.percentage,
287 "message": prog.message,
288 })
289 .to_string();
290 event_hub.send_event(Event {
291 origin: Origin::LongOperation(LongOperationEvent::Progress),
292 ids: vec![],
293 data: Some(payload),
294 });
295 }
296 }) as Box<dyn Fn(OperationProgress) + Send>
297 };
298
299 let operation_result = operation.execute(progress_callback, cancel_flag_clone.clone());
300
301 let final_status = if cancel_flag_clone.load(Ordering::Relaxed) {
302 OperationStatus::Cancelled
303 } else {
304 match &operation_result {
305 Ok(result) => {
306 if let Ok(serialized) = serde_json::to_string(result) {
308 let mut results = lock_or_recover(&results_clone);
309 results.insert(id_clone.clone(), serialized);
310 }
311 OperationStatus::Completed
312 }
313 Err(e) => OperationStatus::Failed(e.to_string()),
314 }
315 };
316
317 if let Some(event_hub) = &event_hub_opt {
319 let (event, data) = match &final_status {
320 OperationStatus::Completed => (
321 LongOperationEvent::Completed,
322 serde_json::json!({"id": id_clone}).to_string(),
323 ),
324 OperationStatus::Cancelled => (
325 LongOperationEvent::Cancelled,
326 serde_json::json!({"id": id_clone}).to_string(),
327 ),
328 OperationStatus::Failed(err) => (
329 LongOperationEvent::Failed,
330 serde_json::json!({"id": id_clone, "error": err}).to_string(),
331 ),
332 OperationStatus::Running => (
333 LongOperationEvent::Progress,
334 serde_json::json!({"id": id_clone}).to_string(),
335 ),
336 };
337 event_hub.send_event(Event {
338 origin: Origin::LongOperation(event),
339 ids: vec![],
340 data: Some(data),
341 });
342 }
343
344 *lock_or_recover(&status_clone) = final_status;
345 });
346
347 let handle = OperationHandle {
348 status,
349 progress,
350 cancel_flag,
351 _join_handle: join_handle,
352 };
353
354 lock_or_recover(&self.operations).insert(id.clone(), Box::new(handle));
355
356 id
357 }
358
359 pub fn get_operation_status(&self, id: &str) -> Option<OperationStatus> {
361 let operations = lock_or_recover(&self.operations);
362 operations.get(id).map(|handle| handle.get_status())
363 }
364
365 pub fn get_operation_progress(&self, id: &str) -> Option<OperationProgress> {
367 let operations = lock_or_recover(&self.operations);
368 operations.get(id).map(|handle| handle.get_progress())
369 }
370
371 pub fn cancel_operation(&self, id: &str) -> bool {
373 let operations = lock_or_recover(&self.operations);
374 if let Some(handle) = operations.get(id) {
375 handle.cancel();
376 if let Some(event_hub) = &self.event_hub {
378 let payload = serde_json::json!({"id": id}).to_string();
379 event_hub.send_event(Event {
380 origin: Origin::LongOperation(LongOperationEvent::Cancelled),
381 ids: vec![],
382 data: Some(payload),
383 });
384 }
385 true
386 } else {
387 false
388 }
389 }
390
391 pub fn is_operation_finished(&self, id: &str) -> Option<bool> {
393 let operations = lock_or_recover(&self.operations);
394 operations.get(id).map(|handle| handle.is_finished())
395 }
396
397 pub fn cleanup_finished_operations(&self) {
399 let mut operations = lock_or_recover(&self.operations);
400 operations.retain(|_, handle| !handle.is_finished());
401 }
402
403 pub fn list_operations(&self) -> Vec<String> {
405 let operations = lock_or_recover(&self.operations);
406 operations.keys().cloned().collect()
407 }
408
409 pub fn get_operations_summary(&self) -> Vec<(String, OperationStatus, OperationProgress)> {
411 let operations = lock_or_recover(&self.operations);
412 operations
413 .iter()
414 .map(|(id, handle)| (id.clone(), handle.get_status(), handle.get_progress()))
415 .collect()
416 }
417
418 pub fn store_operation_result<T: serde::Serialize>(&self, id: &str, result: T) -> Result<()> {
420 let serialized = serde_json::to_string(&result)?;
421 let mut results = lock_or_recover(&self.results);
422 results.insert(id.to_string(), serialized);
423 Ok(())
424 }
425
426 pub fn get_operation_result(&self, id: &str) -> Option<String> {
428 let results = lock_or_recover(&self.results);
429 results.get(id).cloned()
430 }
431}
432
433impl Default for LongOperationManager {
434 fn default() -> Self {
435 Self::new()
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use std::time::{Duration, Instant};

    /// Simulates processing `total_files` files, reporting progress after each one and
    /// checking the cancel flag before each file.
    pub struct FileProcessingOperation {
        pub _file_path: String,
        pub total_files: usize,
    }

    impl LongOperation for FileProcessingOperation {
        type Output = ();

        fn execute(
            &self,
            progress_callback: Box<dyn Fn(OperationProgress) + Send>,
            cancel_flag: Arc<AtomicBool>,
        ) -> Result<Self::Output> {
            for i in 0..self.total_files {
                if cancel_flag.load(Ordering::Relaxed) {
                    return Err(anyhow!("Operation was cancelled"));
                }

                // Stand-in for real per-file work.
                thread::sleep(Duration::from_millis(500));

                // Reported after file `i + 1` is done, so the last file reaches 100%.
                let percentage = ((i + 1) as f32 / self.total_files as f32) * 100.0;
                progress_callback(OperationProgress::new(
                    percentage,
                    Some(format!("Processing file {} of {}", i + 1, self.total_files)),
                ));
            }

            progress_callback(OperationProgress::new(100.0, Some("Completed".to_string())));
            Ok(())
        }
    }

    /// Reports 100% progress and succeeds immediately with the wrapped value.
    struct ImmediateOperation(u32);

    impl LongOperation for ImmediateOperation {
        type Output = u32;

        fn execute(
            &self,
            progress_callback: Box<dyn Fn(OperationProgress) + Send>,
            _cancel_flag: Arc<AtomicBool>,
        ) -> Result<Self::Output> {
            progress_callback(OperationProgress::new(100.0, Some("done".to_string())));
            Ok(self.0)
        }
    }

    /// Fails immediately with the error message `"boom"`.
    struct FailingOperation;

    impl LongOperation for FailingOperation {
        type Output = ();

        fn execute(
            &self,
            _progress_callback: Box<dyn Fn(OperationProgress) + Send>,
            _cancel_flag: Arc<AtomicBool>,
        ) -> Result<Self::Output> {
            Err(anyhow!("boom"))
        }
    }

    /// Blocks until the cancel flag is raised, then returns an error.
    struct WaitForCancelOperation;

    impl LongOperation for WaitForCancelOperation {
        type Output = ();

        fn execute(
            &self,
            _progress_callback: Box<dyn Fn(OperationProgress) + Send>,
            cancel_flag: Arc<AtomicBool>,
        ) -> Result<Self::Output> {
            while !cancel_flag.load(Ordering::Relaxed) {
                thread::sleep(Duration::from_millis(1));
            }
            Err(anyhow!("Operation was cancelled"))
        }
    }

    /// Polls until `id` reaches a terminal status; panics after a generous timeout.
    fn wait_until_finished(manager: &LongOperationManager, id: &str) {
        let deadline = Instant::now() + Duration::from_secs(5);
        while manager.is_operation_finished(id) != Some(true) {
            assert!(Instant::now() < deadline, "operation {id} did not finish in time");
            thread::sleep(Duration::from_millis(5));
        }
    }

    #[test]
    fn test_operation_manager() {
        let manager = LongOperationManager::new();

        let operation = FileProcessingOperation {
            _file_path: "/tmp/test".to_string(),
            total_files: 5,
        };

        let op_id = manager.start_operation(operation);

        // The first file takes 500 ms, so the operation is still running here.
        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Running)
        );

        thread::sleep(Duration::from_millis(100));
        assert!(manager.get_operation_progress(&op_id).is_some());

        // Cancelling a running operation flips its status synchronously.
        assert!(manager.cancel_operation(&op_id));
        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Cancelled)
        );
    }

    #[test]
    fn completed_operation_stores_serialized_result() {
        let manager = LongOperationManager::new();
        let op_id = manager.start_operation(ImmediateOperation(42));
        wait_until_finished(&manager, &op_id);

        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Completed)
        );
        assert_eq!(manager.get_operation_result(&op_id).as_deref(), Some("42"));
        let progress = manager
            .get_operation_progress(&op_id)
            .expect("operation is still tracked");
        assert_eq!(progress.percentage, 100.0);
        assert_eq!(progress.message.as_deref(), Some("done"));
    }

    #[test]
    fn failed_operation_reports_error_message() {
        let manager = LongOperationManager::new();
        let op_id = manager.start_operation(FailingOperation);
        wait_until_finished(&manager, &op_id);

        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Failed("boom".to_string()))
        );
        assert_eq!(manager.get_operation_result(&op_id), None);
    }

    #[test]
    fn cancel_stops_cooperative_operation() {
        let manager = LongOperationManager::new();
        let op_id = manager.start_operation(WaitForCancelOperation);

        assert_eq!(manager.is_operation_finished(&op_id), Some(false));
        assert!(manager.cancel_operation(&op_id));
        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Cancelled)
        );
        assert_eq!(manager.is_operation_finished(&op_id), Some(true));
    }

    #[test]
    fn cancelling_finished_operation_keeps_its_status() {
        let manager = LongOperationManager::new();
        let op_id = manager.start_operation(ImmediateOperation(7));
        wait_until_finished(&manager, &op_id);

        // The id is known, so the call reports `true`, but a terminal status is kept.
        assert!(manager.cancel_operation(&op_id));
        assert_eq!(
            manager.get_operation_status(&op_id),
            Some(OperationStatus::Completed)
        );
        assert_eq!(manager.get_operation_result(&op_id).as_deref(), Some("7"));
    }

    #[test]
    fn unknown_ids_are_reported_as_absent() {
        let manager = LongOperationManager::new();

        assert!(!manager.cancel_operation("op_missing"));
        assert_eq!(manager.get_operation_status("op_missing"), None);
        assert!(manager.get_operation_progress("op_missing").is_none());
        assert_eq!(manager.is_operation_finished("op_missing"), None);
        assert_eq!(manager.get_operation_result("op_missing"), None);
    }

    #[test]
    fn cleanup_drops_only_finished_operations() {
        let manager = LongOperationManager::new();
        let done = manager.start_operation(ImmediateOperation(1));
        let pending = manager.start_operation(WaitForCancelOperation);
        wait_until_finished(&manager, &done);

        manager.cleanup_finished_operations();
        assert_eq!(manager.list_operations(), vec![pending.clone()]);
        // Results outlive their handles.
        assert_eq!(manager.get_operation_result(&done).as_deref(), Some("1"));

        assert!(manager.cancel_operation(&pending));
        manager.cleanup_finished_operations();
        assert!(manager.list_operations().is_empty());
    }

    #[test]
    fn ids_are_sequential_and_results_can_be_stored_manually() {
        let manager = LongOperationManager::new();
        let first = manager.start_operation(ImmediateOperation(0));
        let second = manager.start_operation(ImmediateOperation(0));
        assert_eq!(first, "op_1");
        assert_eq!(second, "op_2");

        manager
            .store_operation_result("manual", vec![1, 2, 3])
            .expect("a Vec<i32> always serializes");
        assert_eq!(
            manager.get_operation_result("manual").as_deref(),
            Some("[1,2,3]")
        );
    }
}