// stakpak_shared/task_manager.rs

1use crate::helper::generate_simple_id;
2use crate::remote_connection::{RemoteConnectionInfo, RemoteConnectionManager};
3use chrono::{DateTime, Utc};
4use std::{collections::HashMap, process::Stdio, sync::Arc, time::Duration};
5use tokio::{
6    io::{AsyncBufReadExt, BufReader},
7    process::Command,
8    sync::{broadcast, mpsc, oneshot},
9    time::timeout,
10};
11
/// Delay [`TaskManagerHandle::start_task`] waits after launching a task before reading its
/// status back, so that commands failing immediately surface as start errors.
const START_TASK_WAIT_TIME: Duration = Duration::from_millis(300);

/// Identifier of a managed task (caller-supplied or generated by the manager).
pub type TaskId = String;
15
16#[derive(Debug, Clone, PartialEq, serde::Serialize)]
17pub enum TaskStatus {
18    Pending,
19    Running,
20    Completed,
21    Failed,
22    Cancelled,
23    TimedOut,
24}
25
26#[derive(Debug, Clone)]
27pub struct Task {
28    pub id: TaskId,
29    pub status: TaskStatus,
30    pub command: String,
31    pub remote_connection: Option<RemoteConnectionInfo>,
32    pub output: Option<String>,
33    pub error: Option<String>,
34    pub start_time: DateTime<Utc>,
35    pub duration: Option<Duration>,
36    pub timeout: Option<Duration>,
37}
38
39pub struct TaskEntry {
40    pub task: Task,
41    pub handle: tokio::task::JoinHandle<()>,
42    pub process_id: Option<u32>,
43    pub cancel_tx: Option<oneshot::Sender<()>>,
44}
45
46#[derive(Debug, Clone, serde::Serialize)]
47pub struct TaskInfo {
48    pub id: TaskId,
49    pub status: TaskStatus,
50    pub command: String,
51    pub output: Option<String>,
52    pub start_time: DateTime<Utc>,
53    pub duration: Option<Duration>,
54}
55
56impl From<&Task> for TaskInfo {
57    fn from(task: &Task) -> Self {
58        let duration = if matches!(task.status, TaskStatus::Running) {
59            // For running tasks, calculate duration from start time to now
60            Some(
61                Utc::now()
62                    .signed_duration_since(task.start_time)
63                    .to_std()
64                    .unwrap_or_default(),
65            )
66        } else {
67            // For completed/failed/cancelled tasks, use the stored duration
68            task.duration
69        };
70
71        TaskInfo {
72            id: task.id.clone(),
73            status: task.status.clone(),
74            command: task.command.clone(),
75            output: task.output.clone(),
76            start_time: task.start_time,
77            duration,
78        }
79    }
80}
81
82pub struct TaskCompletion {
83    pub output: String,
84    pub error: Option<String>,
85    pub final_status: TaskStatus,
86}
87
88#[derive(Debug, thiserror::Error)]
89pub enum TaskError {
90    #[error("Task not found: {0}")]
91    TaskNotFound(TaskId),
92    #[error("Task already running: {0}")]
93    TaskAlreadyRunning(TaskId),
94    #[error("Manager shutdown")]
95    ManagerShutdown,
96    #[error("Command execution failed: {0}")]
97    ExecutionFailed(String),
98    #[error("Task timeout")]
99    TaskTimeout,
100    #[error("Task cancelled")]
101    TaskCancelled,
102    #[error("Task failed on start: {0}")]
103    TaskFailedOnStart(String),
104}
105
106pub enum TaskMessage {
107    Start {
108        id: Option<TaskId>,
109        command: String,
110        remote_connection: Option<RemoteConnectionInfo>,
111        timeout: Option<Duration>,
112        response_tx: oneshot::Sender<Result<TaskId, TaskError>>,
113    },
114    Cancel {
115        id: TaskId,
116        response_tx: oneshot::Sender<Result<(), TaskError>>,
117    },
118    GetStatus {
119        id: TaskId,
120        response_tx: oneshot::Sender<Option<TaskStatus>>,
121    },
122    GetTaskDetails {
123        id: TaskId,
124        response_tx: oneshot::Sender<Option<TaskInfo>>,
125    },
126    GetAllTasks {
127        response_tx: oneshot::Sender<Vec<TaskInfo>>,
128    },
129    Shutdown {
130        response_tx: oneshot::Sender<()>,
131    },
132    TaskUpdate {
133        id: TaskId,
134        completion: TaskCompletion,
135    },
136    PartialUpdate {
137        id: TaskId,
138        output: String,
139    },
140}
141
142pub struct TaskManager {
143    tasks: HashMap<TaskId, TaskEntry>,
144    tx: mpsc::UnboundedSender<TaskMessage>,
145    rx: mpsc::UnboundedReceiver<TaskMessage>,
146    #[allow(dead_code)]
147    shutdown_tx: broadcast::Sender<()>,
148    shutdown_rx: broadcast::Receiver<()>,
149}
150
151impl Default for TaskManager {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157impl TaskManager {
158    pub fn new() -> Self {
159        let (tx, rx) = mpsc::unbounded_channel();
160        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
161
162        Self {
163            tasks: HashMap::new(),
164            tx,
165            rx,
166            shutdown_tx,
167            shutdown_rx,
168        }
169    }
170
171    pub fn handle(&self) -> Arc<TaskManagerHandle> {
172        Arc::new(TaskManagerHandle {
173            tx: self.tx.clone(),
174        })
175    }
176
177    pub async fn run(mut self) {
178        loop {
179            tokio::select! {
180                msg = self.rx.recv() => {
181                    match msg {
182                        Some(msg) => {
183                            if self.handle_message(msg).await {
184                                break;
185                            }
186                        }
187                        None => {
188                            break;
189                        }
190                    }
191                }
192                _ = self.shutdown_rx.recv() => {
193                    self.shutdown_all_tasks().await;
194                    break;
195                }
196            }
197        }
198    }
199
200    async fn handle_message(&mut self, msg: TaskMessage) -> bool {
201        match msg {
202            TaskMessage::Start {
203                id,
204                command,
205                remote_connection,
206                timeout,
207                response_tx,
208            } => {
209                let task_id = id.unwrap_or_else(|| generate_simple_id(6));
210                let result = self
211                    .start_task(task_id.clone(), command, timeout, remote_connection)
212                    .await;
213                let _ = response_tx.send(result.map(|_| task_id.clone()));
214                false
215            }
216            TaskMessage::Cancel { id, response_tx } => {
217                let result = self.cancel_task(&id).await;
218                let _ = response_tx.send(result);
219                false
220            }
221            TaskMessage::GetStatus { id, response_tx } => {
222                let status = self.tasks.get(&id).map(|entry| entry.task.status.clone());
223                let _ = response_tx.send(status);
224                false
225            }
226            TaskMessage::GetTaskDetails { id, response_tx } => {
227                let task_info = self.tasks.get(&id).map(|entry| TaskInfo::from(&entry.task));
228                let _ = response_tx.send(task_info);
229                false
230            }
231            TaskMessage::GetAllTasks { response_tx } => {
232                let mut tasks: Vec<TaskInfo> = self
233                    .tasks
234                    .values()
235                    .map(|entry| TaskInfo::from(&entry.task))
236                    .collect();
237                tasks.sort_by(|a, b| b.start_time.cmp(&a.start_time));
238                let _ = response_tx.send(tasks);
239                false
240            }
241            TaskMessage::TaskUpdate { id, completion } => {
242                if let Some(entry) = self.tasks.get_mut(&id) {
243                    entry.task.status = completion.final_status;
244                    entry.task.output = Some(completion.output);
245                    entry.task.error = completion.error;
246                    entry.task.duration = Some(
247                        Utc::now()
248                            .signed_duration_since(entry.task.start_time)
249                            .to_std()
250                            .unwrap_or_default(),
251                    );
252
253                    // Keep completed tasks in the list so they can be viewed with get_all_tasks
254                    // TODO: Consider implementing a cleanup mechanism for old completed tasks
255                    // if matches!(entry.task.status, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::TimedOut) {
256                    //     self.tasks.remove(&id);
257                    // }
258                }
259                false
260            }
261            TaskMessage::PartialUpdate { id, output } => {
262                if let Some(entry) = self.tasks.get_mut(&id) {
263                    match &entry.task.output {
264                        Some(existing) => {
265                            entry.task.output = Some(format!("{}{}", existing, output));
266                        }
267                        None => {
268                            entry.task.output = Some(output);
269                        }
270                    }
271                }
272                false
273            }
274            TaskMessage::Shutdown { response_tx } => {
275                self.shutdown_all_tasks().await;
276                let _ = response_tx.send(());
277                true
278            }
279        }
280    }
281
282    async fn start_task(
283        &mut self,
284        id: TaskId,
285        command: String,
286        timeout: Option<Duration>,
287        remote_connection: Option<RemoteConnectionInfo>,
288    ) -> Result<(), TaskError> {
289        if self.tasks.contains_key(&id) {
290            return Err(TaskError::TaskAlreadyRunning(id));
291        }
292
293        let task = Task {
294            id: id.clone(),
295            status: TaskStatus::Running,
296            command: command.clone(),
297            remote_connection: remote_connection.clone(),
298            output: None,
299            error: None,
300            start_time: Utc::now(),
301            duration: None,
302            timeout,
303        };
304
305        let (cancel_tx, cancel_rx) = oneshot::channel();
306        let (process_tx, process_rx) = oneshot::channel();
307        let task_tx: mpsc::UnboundedSender<TaskMessage> = self.tx.clone();
308
309        let is_remote_task = remote_connection.is_some();
310
311        // Spawn task immediately - SSH connection happens inside the task
312        let handle = tokio::spawn(Self::execute_task(
313            id.clone(),
314            command,
315            remote_connection,
316            timeout,
317            cancel_rx,
318            process_tx,
319            task_tx,
320        ));
321
322        let entry = TaskEntry {
323            task,
324            handle,
325            process_id: None,
326            cancel_tx: Some(cancel_tx),
327        };
328
329        self.tasks.insert(id.clone(), entry);
330
331        // Wait for the process ID for local tasks only
332        if !is_remote_task {
333            // Local task - wait for process ID for proper cleanup
334            if let Ok(process_id) = process_rx.await
335                && let Some(entry) = self.tasks.get_mut(&id)
336            {
337                entry.process_id = Some(process_id);
338            }
339        }
340        // Remote tasks don't have local process IDs, so we skip waiting
341
342        Ok(())
343    }
344
345    async fn cancel_task(&mut self, id: &TaskId) -> Result<(), TaskError> {
346        if let Some(mut entry) = self.tasks.remove(id) {
347            entry.task.status = TaskStatus::Cancelled;
348
349            if let Some(cancel_tx) = entry.cancel_tx.take() {
350                let _ = cancel_tx.send(());
351            }
352
353            if let Some(process_id) = entry.process_id {
354                // Kill the process by PID only if it still exists
355                #[cfg(unix)]
356                {
357                    use std::process::Command;
358                    // First check if the process exists
359                    let check_result = Command::new("kill")
360                        .arg("-0") // Signal 0 just checks if process exists
361                        .arg(process_id.to_string())
362                        .output();
363
364                    // Only kill if the process actually exists
365                    if check_result
366                        .map(|output| output.status.success())
367                        .unwrap_or(false)
368                    {
369                        let _ = Command::new("kill")
370                            .arg("-9")
371                            .arg(process_id.to_string())
372                            .output(); // Use output() to capture stderr and avoid printing
373                    }
374                }
375
376                #[cfg(windows)]
377                {
378                    use std::process::Command;
379                    // On Windows, tasklist can check if process exists
380                    let check_result = Command::new("tasklist")
381                        .arg("/FI")
382                        .arg(format!("PID eq {}", process_id))
383                        .arg("/FO")
384                        .arg("CSV")
385                        .output();
386
387                    // Only kill if the process actually exists
388                    if let Ok(output) = check_result {
389                        let output_str = String::from_utf8_lossy(&output.stdout);
390                        if output_str.lines().count() > 1 {
391                            // More than just header line
392                            let _ = Command::new("taskkill")
393                                .arg("/F")
394                                .arg("/PID")
395                                .arg(process_id.to_string())
396                                .output(); // Use output() to capture stderr and avoid printing
397                        }
398                    }
399                }
400            }
401
402            entry.handle.abort();
403            Ok(())
404        } else {
405            Err(TaskError::TaskNotFound(id.clone()))
406        }
407    }
408
409    async fn execute_task(
410        id: TaskId,
411        command: String,
412        remote_connection: Option<RemoteConnectionInfo>,
413        task_timeout: Option<Duration>,
414        mut cancel_rx: oneshot::Receiver<()>,
415        process_tx: oneshot::Sender<u32>,
416        task_tx: mpsc::UnboundedSender<TaskMessage>,
417    ) {
418        let completion = if let Some(remote_info) = remote_connection {
419            // Remote execution
420            Self::execute_remote_task(
421                id.clone(),
422                command,
423                remote_info,
424                task_timeout,
425                &mut cancel_rx,
426                &task_tx,
427            )
428            .await
429        } else {
430            // Local execution (existing logic)
431            Self::execute_local_task(
432                id.clone(),
433                command,
434                task_timeout,
435                &mut cancel_rx,
436                process_tx,
437                &task_tx,
438            )
439            .await
440        };
441
442        // Send task completion back to manager
443        let _ = task_tx.send(TaskMessage::TaskUpdate {
444            id: id.clone(),
445            completion,
446        });
447    }
448
449    async fn execute_local_task(
450        id: TaskId,
451        command: String,
452        task_timeout: Option<Duration>,
453        cancel_rx: &mut oneshot::Receiver<()>,
454        process_tx: oneshot::Sender<u32>,
455        task_tx: &mpsc::UnboundedSender<TaskMessage>,
456    ) -> TaskCompletion {
457        let mut cmd = Command::new("sh");
458        cmd.arg("-c")
459            .arg(&command)
460            .stdin(Stdio::null())
461            .stdout(Stdio::piped())
462            .stderr(Stdio::piped());
463        #[cfg(unix)]
464        {
465            cmd.env("DEBIAN_FRONTEND", "noninteractive")
466                .env("SUDO_ASKPASS", "/bin/false")
467                .process_group(0);
468        }
469        #[cfg(windows)]
470        {
471            // On Windows, create a new process group
472            cmd.creation_flags(0x00000200); // CREATE_NEW_PROCESS_GROUP
473        }
474
475        let mut child = match cmd.spawn() {
476            Ok(child) => child,
477            Err(err) => {
478                return TaskCompletion {
479                    output: String::new(),
480                    error: Some(format!("Failed to spawn command: {}", err)),
481                    final_status: TaskStatus::Failed,
482                };
483            }
484        };
485
486        // Send the process ID back to the manager for tracking
487        if let Some(process_id) = child.id() {
488            let _ = process_tx.send(process_id);
489        }
490
491        // Take stdout and stderr for streaming
492        let stdout = child.stdout.take().unwrap();
493        let stderr = child.stderr.take().unwrap();
494
495        let stdout_reader = BufReader::new(stdout);
496        let stderr_reader = BufReader::new(stderr);
497
498        let mut stdout_lines = stdout_reader.lines();
499        let mut stderr_lines = stderr_reader.lines();
500
501        // Helper function to stream output and handle cancellation
502        let stream_output = async {
503            let mut final_output = String::new();
504            let mut final_error: Option<String> = None;
505
506            loop {
507                tokio::select! {
508                    line = stdout_lines.next_line() => {
509                        match line {
510                            Ok(Some(line)) => {
511                                let output_line = format!("{}\n", line);
512                                final_output.push_str(&output_line);
513                                let _ = task_tx.send(TaskMessage::PartialUpdate {
514                                    id: id.clone(),
515                                    output: output_line,
516                                });
517                            }
518                            Ok(None) => {
519                                // stdout stream ended
520                            }
521                            Err(err) => {
522                                final_error = Some(format!("Error reading stdout: {}", err));
523                                break;
524                            }
525                        }
526                    }
527                    line = stderr_lines.next_line() => {
528                        match line {
529                            Ok(Some(line)) => {
530                                let output_line = format!("{}\n", line);
531                                final_output.push_str(&output_line);
532                                let _ = task_tx.send(TaskMessage::PartialUpdate {
533                                    id: id.clone(),
534                                    output: output_line,
535                                });
536                            }
537                            Ok(None) => {
538                                // stderr stream ended
539                            }
540                            Err(err) => {
541                                final_error = Some(format!("Error reading stderr: {}", err));
542                                break;
543                            }
544                        }
545                    }
546                    status = child.wait() => {
547                        match status {
548                            Ok(exit_status) => {
549                                if final_output.is_empty() {
550                                    final_output = "No output".to_string();
551                                }
552
553                                let completion = if exit_status.success() {
554                                    TaskCompletion {
555                                        output: final_output,
556                                        error: final_error,
557                                        final_status: TaskStatus::Completed,
558                                    }
559                                } else {
560                                    TaskCompletion {
561                                        output: final_output,
562                                        error: final_error.or_else(|| Some(format!("Command failed with exit code: {:?}", exit_status.code()))),
563                                        final_status: TaskStatus::Failed,
564                                    }
565                                };
566                                return completion;
567                            }
568                            Err(err) => {
569                                return TaskCompletion {
570                                    output: final_output,
571                                    error: Some(err.to_string()),
572                                    final_status: TaskStatus::Failed,
573                                };
574                            }
575                        }
576                    }
577                    _ = &mut *cancel_rx => {
578                        return TaskCompletion {
579                            output: final_output,
580                            error: Some("Tool call was cancelled and don't try to run it again".to_string()),
581                            final_status: TaskStatus::Cancelled,
582                        };
583                    }
584                }
585            }
586
587            TaskCompletion {
588                output: final_output,
589                error: final_error,
590                final_status: TaskStatus::Failed,
591            }
592        };
593
594        // Execute with timeout if provided
595        if let Some(timeout_duration) = task_timeout {
596            match timeout(timeout_duration, stream_output).await {
597                Ok(result) => result,
598                Err(_) => TaskCompletion {
599                    output: String::new(),
600                    error: Some("Task timed out".to_string()),
601                    final_status: TaskStatus::TimedOut,
602                },
603            }
604        } else {
605            stream_output.await
606        }
607    }
608
609    async fn execute_remote_task(
610        id: TaskId,
611        command: String,
612        remote_info: RemoteConnectionInfo,
613        task_timeout: Option<Duration>,
614        cancel_rx: &mut oneshot::Receiver<()>,
615        task_tx: &mpsc::UnboundedSender<TaskMessage>,
616    ) -> TaskCompletion {
617        // Use RemoteConnectionManager to get a connection
618        let connection_manager = RemoteConnectionManager::new();
619        let connection = match connection_manager.get_connection(&remote_info).await {
620            Ok(conn) => conn,
621            Err(e) => {
622                return TaskCompletion {
623                    output: String::new(),
624                    error: Some(format!("Failed to establish remote connection: {}", e)),
625                    final_status: TaskStatus::Failed,
626                };
627            }
628        };
629
630        // Create progress callback for streaming updates
631        let task_tx_clone = task_tx.clone();
632        let id_clone = id.clone();
633        let progress_callback = move |output: String| {
634            if !output.trim().is_empty() {
635                let _ = task_tx_clone.send(TaskMessage::PartialUpdate {
636                    id: id_clone.clone(),
637                    output,
638                });
639            }
640        };
641
642        // Use unified execution with proper cancellation and timeout
643        let options = crate::remote_connection::CommandOptions {
644            timeout: task_timeout,
645            with_progress: false,
646            simple: false,
647        };
648
649        match connection
650            .execute_command_unified(&command, options, cancel_rx, Some(progress_callback), None)
651            .await
652        {
653            Ok((output, exit_code)) => TaskCompletion {
654                output,
655                error: if exit_code != 0 {
656                    Some(format!("Command exited with code {}", exit_code))
657                } else {
658                    None
659                },
660                final_status: TaskStatus::Completed,
661            },
662            Err(e) => {
663                let error_msg = e.to_string();
664                let status = if error_msg.contains("timed out") {
665                    TaskStatus::TimedOut
666                } else if error_msg.contains("cancelled") {
667                    TaskStatus::Cancelled
668                } else {
669                    TaskStatus::Failed
670                };
671
672                TaskCompletion {
673                    output: String::new(),
674                    error: Some(if error_msg.contains("cancelled") {
675                        "Tool call was cancelled and don't try to run it again".to_string()
676                    } else {
677                        format!("Remote command failed: {}", error_msg)
678                    }),
679                    final_status: status,
680                }
681            }
682        }
683    }
684
685    async fn shutdown_all_tasks(&mut self) {
686        for (_id, mut entry) in self.tasks.drain() {
687            if let Some(cancel_tx) = entry.cancel_tx.take() {
688                let _ = cancel_tx.send(());
689            }
690
691            if let Some(process_id) = entry.process_id {
692                // Kill the process by PID only if it still exists
693                #[cfg(unix)]
694                {
695                    use std::process::Command;
696                    // First check if the process exists
697                    let check_result = Command::new("kill")
698                        .arg("-0") // Signal 0 just checks if process exists
699                        .arg(process_id.to_string())
700                        .output();
701
702                    // Only kill if the process actually exists
703                    if check_result
704                        .map(|output| output.status.success())
705                        .unwrap_or(false)
706                    {
707                        let _ = Command::new("kill")
708                            .arg("-9")
709                            .arg(process_id.to_string())
710                            .output(); // Use output() to capture stderr and avoid printing
711                    }
712                }
713
714                #[cfg(windows)]
715                {
716                    use std::process::Command;
717                    // On Windows, tasklist can check if process exists
718                    let check_result = Command::new("tasklist")
719                        .arg("/FI")
720                        .arg(format!("PID eq {}", process_id))
721                        .arg("/FO")
722                        .arg("CSV")
723                        .output();
724
725                    // Only kill if the process actually exists
726                    if let Ok(output) = check_result {
727                        let output_str = String::from_utf8_lossy(&output.stdout);
728                        if output_str.lines().count() > 1 {
729                            // More than just header line
730                            let _ = Command::new("taskkill")
731                                .arg("/F")
732                                .arg("/PID")
733                                .arg(process_id.to_string())
734                                .output(); // Use output() to capture stderr and avoid printing
735                        }
736                    }
737                }
738            }
739
740            entry.handle.abort();
741        }
742    }
743}
744
745pub struct TaskManagerHandle {
746    tx: mpsc::UnboundedSender<TaskMessage>,
747}
748
749impl TaskManagerHandle {
750    pub async fn start_task(
751        &self,
752        command: String,
753        timeout: Option<Duration>,
754        remote_connection: Option<RemoteConnectionInfo>,
755    ) -> Result<TaskInfo, TaskError> {
756        let (response_tx, response_rx) = oneshot::channel();
757
758        self.tx
759            .send(TaskMessage::Start {
760                id: None,
761                command: command.clone(),
762                remote_connection: remote_connection.clone(),
763                timeout,
764                response_tx,
765            })
766            .map_err(|_| TaskError::ManagerShutdown)?;
767
768        let task_id = response_rx
769            .await
770            .map_err(|_| TaskError::ManagerShutdown)??;
771
772        // Wait for the task to start and get its status
773        tokio::time::sleep(START_TASK_WAIT_TIME).await;
774
775        let task_info = self
776            .get_task_details(task_id.clone())
777            .await
778            .map_err(|_| TaskError::ManagerShutdown)?
779            .ok_or_else(|| TaskError::TaskNotFound(task_id.clone()))?;
780
781        // If the task failed or was cancelled during start, return an error
782        if matches!(task_info.status, TaskStatus::Failed | TaskStatus::Cancelled) {
783            return Err(TaskError::TaskFailedOnStart(
784                task_info
785                    .output
786                    .unwrap_or_else(|| "Unknown reason".to_string()),
787            ));
788        }
789
790        // Return the task info with updated status
791        Ok(task_info)
792    }
793
794    pub async fn cancel_task(&self, id: TaskId) -> Result<TaskInfo, TaskError> {
795        // Get the task info before cancelling
796        let task_info = self
797            .get_all_tasks()
798            .await?
799            .into_iter()
800            .find(|task| task.id == id)
801            .ok_or_else(|| TaskError::TaskNotFound(id.clone()))?;
802
803        let (response_tx, response_rx) = oneshot::channel();
804
805        self.tx
806            .send(TaskMessage::Cancel { id, response_tx })
807            .map_err(|_| TaskError::ManagerShutdown)?;
808
809        response_rx
810            .await
811            .map_err(|_| TaskError::ManagerShutdown)??;
812
813        // Return the task info with updated status
814        Ok(TaskInfo {
815            status: TaskStatus::Cancelled,
816            duration: Some(
817                Utc::now()
818                    .signed_duration_since(task_info.start_time)
819                    .to_std()
820                    .unwrap_or_default(),
821            ),
822            ..task_info
823        })
824    }
825
826    pub async fn get_task_status(&self, id: TaskId) -> Result<Option<TaskStatus>, TaskError> {
827        let (response_tx, response_rx) = oneshot::channel();
828
829        self.tx
830            .send(TaskMessage::GetStatus { id, response_tx })
831            .map_err(|_| TaskError::ManagerShutdown)?;
832
833        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
834    }
835
836    pub async fn get_task_details(&self, id: TaskId) -> Result<Option<TaskInfo>, TaskError> {
837        let (response_tx, response_rx) = oneshot::channel();
838
839        self.tx
840            .send(TaskMessage::GetTaskDetails { id, response_tx })
841            .map_err(|_| TaskError::ManagerShutdown)?;
842
843        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
844    }
845
846    pub async fn get_all_tasks(&self) -> Result<Vec<TaskInfo>, TaskError> {
847        let (response_tx, response_rx) = oneshot::channel();
848
849        self.tx
850            .send(TaskMessage::GetAllTasks { response_tx })
851            .map_err(|_| TaskError::ManagerShutdown)?;
852
853        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
854    }
855
856    pub async fn shutdown(&self) -> Result<(), TaskError> {
857        let (response_tx, response_rx) = oneshot::channel();
858
859        self.tx
860            .send(TaskMessage::Shutdown { response_tx })
861            .map_err(|_| TaskError::ManagerShutdown)?;
862
863        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
864    }
865}
866
#[cfg(test)]
mod tests {
    // Integration-style tests that drive a real `TaskManager` loop through its handle.
    // Waits are deadline-based rather than fixed sleeps, so the tests don't flake on
    // slow machines and don't waste time on fast ones.
    use super::*;
    use tokio::time::{Duration, sleep, timeout};

    /// Upper bound for anything the manager should do "promptly": stop after shutdown,
    /// finish a trivial command, or answer a request. The tests return as soon as the
    /// condition holds, so a generous limit costs nothing on the happy path.
    const WAIT_LIMIT: Duration = Duration::from_secs(5);

    /// Delay between status polls while waiting for a task to leave `Pending`/`Running`.
    const POLL_INTERVAL: Duration = Duration::from_millis(20);

    #[tokio::test]
    async fn test_task_manager_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a background task
        let task_info = handle
            .start_task("sleep 5".to_string(), None, None)
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // Await the manager loop itself, instead of sleeping a fixed time and checking
        // `is_finished()`. The old approach raced the scheduler.
        timeout(WAIT_LIMIT, manager_handle)
            .await
            .expect("Task manager did not stop after shutdown")
            .expect("Task manager panicked");
    }

    #[tokio::test]
    async fn test_task_manager_cancels_tasks_on_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a long-running background task
        let task_info = handle
            .start_task("sleep 10".to_string(), None, None)
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // The manager must stop well before the 10 s task would have finished on its own.
        timeout(WAIT_LIMIT, manager_handle)
            .await
            .expect("Task manager did not stop after shutdown")
            .expect("Task manager panicked");

        // Once the loop has exited, the handle must report the shutdown instead of
        // pretending the task still exists. The timeout turns a hang into a failure.
        let after_shutdown = timeout(WAIT_LIMIT, handle.get_task_status(task_info.id.clone()))
            .await
            .expect("Status request hung after manager shutdown");
        assert!(matches!(after_shutdown, Err(TaskError::ManagerShutdown)));
    }

    #[tokio::test]
    async fn test_task_manager_start_and_complete_task() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a simple task
        let task_info = handle
            .start_task("echo 'Hello, World!'".to_string(), None, None)
            .await
            .expect("Failed to start task");

        // Poll until the task reaches a terminal status, instead of a fixed sleep that
        // could expire before a slow machine finishes the command.
        let status = timeout(WAIT_LIMIT, async {
            loop {
                let status = handle
                    .get_task_status(task_info.id.clone())
                    .await
                    .expect("Failed to get task status");
                if !matches!(
                    status,
                    Some(TaskStatus::Pending | TaskStatus::Running)
                ) {
                    break status;
                }
                sleep(POLL_INTERVAL).await;
            }
        })
        .await
        .expect("Task did not finish in time");
        assert_eq!(status, Some(TaskStatus::Completed));

        // Get all tasks
        let tasks = handle
            .get_all_tasks()
            .await
            .expect("Failed to get all tasks");
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status, TaskStatus::Completed);

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    #[tokio::test]
    async fn test_task_manager_detects_immediate_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that will fail immediately
        let result = handle
            .start_task("nonexistent_command_12345".to_string(), None, None)
            .await;

        // Should get a TaskFailedOnStart error
        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    #[tokio::test]
    async fn test_task_manager_detects_immediate_exit_code_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that will exit with non-zero code immediately
        let result = handle.start_task("exit 1".to_string(), None, None).await;

        // Should get a TaskFailedOnStart error
        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
}