Skip to main content

stakpak_shared/
task_manager.rs

1use crate::helper::generate_simple_id;
2use crate::remote_connection::{RemoteConnectionInfo, RemoteConnectionManager};
3use chrono::{DateTime, Utc};
4use std::{collections::HashMap, process::Stdio, sync::Arc, time::Duration};
5use tokio::{
6    io::{AsyncBufReadExt, BufReader},
7    process::Command,
8    sync::{broadcast, mpsc, oneshot},
9    time::timeout,
10};
11
/// How long `TaskManagerHandle::start_task` waits after the manager accepts a task before
/// reading its status, so that commands which fail immediately are reported as start errors.
const START_TASK_WAIT_TIME: Duration = Duration::from_millis(300);
13
/// Kill a process and its entire process group.
///
/// Uses a process-group kill (`kill -9 -- -{pid}`) on Unix and `taskkill /F /T` on Windows so
/// that child processes spawned by the shell (node, vite, esbuild, etc.) are terminated too.
///
/// Safe to call when the process has already exited. On Unix the group is signalled even if its
/// leader is gone, so background children that outlived the shell are still cleaned up. This
/// cannot reach an unrelated group: a process-group ID is not reused while any member is alive,
/// and when the group no longer exists the `kill` simply fails.
///
/// PID `0` is ignored (`kill -- -0` would signal the caller's own process group), as are, on
/// Unix, values that do not fit in a signed 32-bit `pid_t`.
///
/// NOTE: this runs external commands synchronously, briefly blocking the calling thread.
fn terminate_process_group(process_id: u32) {
    if process_id == 0 {
        return;
    }

    #[cfg(unix)]
    {
        use std::process::Command;

        // Larger values would not denote a real PID and could be misread by kill(1).
        if i32::try_from(process_id).is_err() {
            return;
        }

        let pid = process_id.to_string();
        let group = format!("-{}", process_id);

        // Probe the individual process before signalling anything; this only gates the
        // single-PID fallback kill below, exactly as before.
        let leader_alive = Command::new("kill")
            .arg("-0") // Signal 0 just checks whether the process exists
            .arg(&pid)
            .output()
            .map(|output| output.status.success())
            .unwrap_or(false);

        // We spawn with `.process_group(0)`, so the shell leads a group whose ID equals its PID;
        // `-{pid}` targets every member (shell + node/vite/esbuild...). The `--` ends option
        // parsing so the negative operand is read as a PID rather than a signal (POSIX kill(1)).
        let _ = Command::new("kill")
            .arg("-9")
            .arg("--")
            .arg(&group)
            .output();

        // Also kill the individual process in case it is not a group leader.
        if leader_alive {
            let _ = Command::new("kill").arg("-9").arg(&pid).output();
        }
    }

    #[cfg(windows)]
    {
        use std::process::Command;
        // On Windows, use taskkill with /T flag to kill the process tree
        let check_result = Command::new("tasklist")
            .arg("/FI")
            .arg(format!("PID eq {}", process_id))
            .arg("/FO")
            .arg("CSV")
            .output();

        // Only kill if the process actually exists: a match yields more than the header line.
        if let Ok(output) = check_result {
            let output_str = String::from_utf8_lossy(&output.stdout);
            if output_str.lines().count() > 1 {
                let _ = Command::new("taskkill")
                    .arg("/F")
                    .arg("/T") // Kill process tree
                    .arg("/PID")
                    .arg(process_id.to_string())
                    .output();
            }
        }
    }
}
78
/// Identifier of a task. The manager generates one with `generate_simple_id(6)` when the
/// `Start` request does not supply an id.
pub type TaskId = String;
80
/// Lifecycle state of a task.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub enum TaskStatus {
    /// Declared but not assigned by the manager code shown here (tasks are registered as
    /// `Running`).
    Pending,
    /// Registered and executing.
    Running,
    /// Finished; for local tasks this means the command exited successfully.
    Completed,
    /// Could not be spawned/connected, hit an error while running, or (locally) exited non-zero.
    Failed,
    /// Stopped by a cancellation signal.
    Cancelled,
    /// Exceeded the task's configured timeout.
    TimedOut,
}
90
/// Full manager-side record of a task.
#[derive(Debug, Clone)]
pub struct Task {
    /// Identifier; the manager rejects a second task with the same id.
    pub id: TaskId,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Command line (run locally via `sh -c`, or on the remote host).
    pub command: String,
    /// Remote target; `None` means the command runs locally.
    pub remote_connection: Option<RemoteConnectionInfo>,
    /// Output streamed so far; replaced by the final output when the task finishes.
    pub output: Option<String>,
    /// Error message attached to the final result, if any.
    pub error: Option<String>,
    /// When the manager registered the task.
    pub start_time: DateTime<Utc>,
    /// Total run time; set once the task finishes.
    pub duration: Option<Duration>,
    /// Optional time limit handed to the executor.
    pub timeout: Option<Duration>,
}
103
/// Manager-side bookkeeping for one task: its state plus the means to stop it.
pub struct TaskEntry {
    /// Latest state known to the manager.
    pub task: Task,
    /// Tokio task executing the command; aborted on cancel and on shutdown.
    pub handle: tokio::task::JoinHandle<()>,
    /// PID of the local `sh` process (also its process-group id on Unix, since it is spawned
    /// with `process_group(0)`); `None` for remote tasks or when it was never reported.
    pub process_id: Option<u32>,
    /// One-shot used to signal cancellation to the executor; taken when fired.
    pub cancel_tx: Option<oneshot::Sender<()>>,
}
110
/// Serializable snapshot of a [`Task`] returned to callers.
///
/// Omits the task's `error`, `remote_connection` and `timeout`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TaskInfo {
    pub id: TaskId,
    pub status: TaskStatus,
    pub command: String,
    /// Output collected so far, or the final output once the task has finished.
    pub output: Option<String>,
    pub start_time: DateTime<Utc>,
    /// Elapsed time for running tasks; the recorded total duration otherwise.
    pub duration: Option<Duration>,
}
120
121impl From<&Task> for TaskInfo {
122    fn from(task: &Task) -> Self {
123        let duration = if matches!(task.status, TaskStatus::Running) {
124            // For running tasks, calculate duration from start time to now
125            Some(
126                Utc::now()
127                    .signed_duration_since(task.start_time)
128                    .to_std()
129                    .unwrap_or_default(),
130            )
131        } else {
132            // For completed/failed/cancelled tasks, use the stored duration
133            task.duration
134        };
135
136        TaskInfo {
137            id: task.id.clone(),
138            status: task.status.clone(),
139            command: task.command.clone(),
140            output: task.output.clone(),
141            start_time: task.start_time,
142            duration,
143        }
144    }
145}
146
/// Final result an executor reports to the manager through `TaskMessage::TaskUpdate`.
pub struct TaskCompletion {
    /// Complete output to store on the task (replaces any streamed output).
    pub output: String,
    /// Error description, if any.
    pub error: Option<String>,
    /// Terminal status to record.
    pub final_status: TaskStatus,
}
152
/// Errors returned by the task manager and its handle.
#[derive(Debug, thiserror::Error)]
pub enum TaskError {
    /// No task with this id is known to the manager.
    #[error("Task not found: {0}")]
    TaskNotFound(TaskId),
    /// A task with this id already exists (running or finished).
    #[error("Task already running: {0}")]
    TaskAlreadyRunning(TaskId),
    /// The manager is no longer accepting or answering requests.
    #[error("Manager shutdown")]
    ManagerShutdown,
    /// Not constructed by the code shown here.
    #[error("Command execution failed: {0}")]
    ExecutionFailed(String),
    /// Not constructed by the code shown here (timeouts are reported via `TaskStatus::TimedOut`).
    #[error("Task timeout")]
    TaskTimeout,
    /// Not constructed by the code shown here (cancellation is reported via
    /// `TaskStatus::Cancelled`).
    #[error("Task cancelled")]
    TaskCancelled,
    /// The task failed or was cancelled within the start grace period; carries its output.
    #[error("Task failed on start: {0}")]
    TaskFailedOnStart(String),
}
170
/// Requests and internal notifications processed by `TaskManager::run`.
pub enum TaskMessage {
    /// Register and launch a task; `id: None` lets the manager generate one.
    Start {
        id: Option<TaskId>,
        command: String,
        remote_connection: Option<RemoteConnectionInfo>,
        timeout: Option<Duration>,
        response_tx: oneshot::Sender<Result<TaskId, TaskError>>,
    },
    /// Cancel a task, kill its processes and remove it from the table.
    Cancel {
        id: TaskId,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
    /// Query the status of one task (`None` if unknown).
    GetStatus {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskStatus>>,
    },
    /// Query a snapshot of one task (`None` if unknown).
    GetTaskDetails {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskInfo>>,
    },
    /// Query snapshots of all tasks, newest first.
    GetAllTasks {
        response_tx: oneshot::Sender<Vec<TaskInfo>>,
    },
    /// Terminate all tasks and stop the manager; replied to once cleanup is done.
    Shutdown {
        response_tx: oneshot::Sender<()>,
    },
    /// Internal: an executor finished and reports its final result.
    TaskUpdate {
        id: TaskId,
        completion: TaskCompletion,
    },
    /// Internal: an executor streamed a chunk of output.
    PartialUpdate {
        id: TaskId,
        output: String,
    },
}
206
/// Owns every task; all reads and writes go through messages handled by `run`.
///
/// Create it with `TaskManager::new`, take a handle via `TaskManager::handle`, then drive it by
/// awaiting `TaskManager::run` (for example on its own tokio task).
pub struct TaskManager {
    /// All known tasks keyed by id. Finished tasks are kept; cancelled tasks are removed.
    tasks: HashMap<TaskId, TaskEntry>,
    /// Sender cloned into handles and into executors (which report progress through it).
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Receiving end consumed by `run`.
    rx: mpsc::UnboundedReceiver<TaskMessage>,
    /// Shutdown broadcast; handles send on it when dropped.
    shutdown_tx: broadcast::Sender<()>,
    /// Listened to by `run`; any signal triggers `shutdown_all_tasks`.
    shutdown_rx: broadcast::Receiver<()>,
}
214
215impl Default for TaskManager {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
221impl TaskManager {
222    pub fn new() -> Self {
223        let (tx, rx) = mpsc::unbounded_channel();
224        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
225
226        Self {
227            tasks: HashMap::new(),
228            tx,
229            rx,
230            shutdown_tx,
231            shutdown_rx,
232        }
233    }
234
235    pub fn handle(&self) -> Arc<TaskManagerHandle> {
236        Arc::new(TaskManagerHandle {
237            tx: self.tx.clone(),
238            shutdown_tx: self.shutdown_tx.clone(),
239        })
240    }
241
242    pub async fn run(mut self) {
243        loop {
244            tokio::select! {
245                msg = self.rx.recv() => {
246                    match msg {
247                        Some(msg) => {
248                            if self.handle_message(msg).await {
249                                break;
250                            }
251                        }
252                        None => {
253                            // All senders (TaskManagerHandles) have been dropped.
254                            // Clean up all running tasks and child processes.
255                            self.shutdown_all_tasks().await;
256                            break;
257                        }
258                    }
259                }
260                _ = self.shutdown_rx.recv() => {
261                    self.shutdown_all_tasks().await;
262                    break;
263                }
264            }
265        }
266    }
267
268    async fn handle_message(&mut self, msg: TaskMessage) -> bool {
269        match msg {
270            TaskMessage::Start {
271                id,
272                command,
273                remote_connection,
274                timeout,
275                response_tx,
276            } => {
277                let task_id = id.unwrap_or_else(|| generate_simple_id(6));
278                let result = self
279                    .start_task(task_id.clone(), command, timeout, remote_connection)
280                    .await;
281                let _ = response_tx.send(result.map(|_| task_id.clone()));
282                false
283            }
284            TaskMessage::Cancel { id, response_tx } => {
285                let result = self.cancel_task(&id).await;
286                let _ = response_tx.send(result);
287                false
288            }
289            TaskMessage::GetStatus { id, response_tx } => {
290                let status = self.tasks.get(&id).map(|entry| entry.task.status.clone());
291                let _ = response_tx.send(status);
292                false
293            }
294            TaskMessage::GetTaskDetails { id, response_tx } => {
295                let task_info = self.tasks.get(&id).map(|entry| TaskInfo::from(&entry.task));
296                let _ = response_tx.send(task_info);
297                false
298            }
299            TaskMessage::GetAllTasks { response_tx } => {
300                let mut tasks: Vec<TaskInfo> = self
301                    .tasks
302                    .values()
303                    .map(|entry| TaskInfo::from(&entry.task))
304                    .collect();
305                tasks.sort_by(|a, b| b.start_time.cmp(&a.start_time));
306                let _ = response_tx.send(tasks);
307                false
308            }
309            TaskMessage::TaskUpdate { id, completion } => {
310                if let Some(entry) = self.tasks.get_mut(&id) {
311                    entry.task.status = completion.final_status;
312                    entry.task.output = Some(completion.output);
313                    entry.task.error = completion.error;
314                    entry.task.duration = Some(
315                        Utc::now()
316                            .signed_duration_since(entry.task.start_time)
317                            .to_std()
318                            .unwrap_or_default(),
319                    );
320
321                    // Keep completed tasks in the list so they can be viewed with get_all_tasks
322                    // TODO: Consider implementing a cleanup mechanism for old completed tasks
323                    // if matches!(entry.task.status, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::TimedOut) {
324                    //     self.tasks.remove(&id);
325                    // }
326                }
327                false
328            }
329            TaskMessage::PartialUpdate { id, output } => {
330                if let Some(entry) = self.tasks.get_mut(&id) {
331                    match &entry.task.output {
332                        Some(existing) => {
333                            entry.task.output = Some(format!("{}{}", existing, output));
334                        }
335                        None => {
336                            entry.task.output = Some(output);
337                        }
338                    }
339                }
340                false
341            }
342            TaskMessage::Shutdown { response_tx } => {
343                self.shutdown_all_tasks().await;
344                let _ = response_tx.send(());
345                true
346            }
347        }
348    }
349
350    async fn start_task(
351        &mut self,
352        id: TaskId,
353        command: String,
354        timeout: Option<Duration>,
355        remote_connection: Option<RemoteConnectionInfo>,
356    ) -> Result<(), TaskError> {
357        if self.tasks.contains_key(&id) {
358            return Err(TaskError::TaskAlreadyRunning(id));
359        }
360
361        let task = Task {
362            id: id.clone(),
363            status: TaskStatus::Running,
364            command: command.clone(),
365            remote_connection: remote_connection.clone(),
366            output: None,
367            error: None,
368            start_time: Utc::now(),
369            duration: None,
370            timeout,
371        };
372
373        let (cancel_tx, cancel_rx) = oneshot::channel();
374        let (process_tx, process_rx) = oneshot::channel();
375        let task_tx: mpsc::UnboundedSender<TaskMessage> = self.tx.clone();
376
377        let is_remote_task = remote_connection.is_some();
378
379        // Spawn task immediately - SSH connection happens inside the task
380        let handle = tokio::spawn(Self::execute_task(
381            id.clone(),
382            command,
383            remote_connection,
384            timeout,
385            cancel_rx,
386            process_tx,
387            task_tx,
388        ));
389
390        let entry = TaskEntry {
391            task,
392            handle,
393            process_id: None,
394            cancel_tx: Some(cancel_tx),
395        };
396
397        self.tasks.insert(id.clone(), entry);
398
399        // Wait for the process ID for local tasks only
400        if !is_remote_task {
401            // Local task - wait for process ID for proper cleanup
402            if let Ok(process_id) = process_rx.await
403                && let Some(entry) = self.tasks.get_mut(&id)
404            {
405                entry.process_id = Some(process_id);
406            }
407        }
408        // Remote tasks don't have local process IDs, so we skip waiting
409
410        Ok(())
411    }
412
413    async fn cancel_task(&mut self, id: &TaskId) -> Result<(), TaskError> {
414        if let Some(mut entry) = self.tasks.remove(id) {
415            entry.task.status = TaskStatus::Cancelled;
416
417            if let Some(cancel_tx) = entry.cancel_tx.take() {
418                let _ = cancel_tx.send(());
419            }
420
421            if let Some(process_id) = entry.process_id {
422                terminate_process_group(process_id);
423            }
424
425            entry.handle.abort();
426            Ok(())
427        } else {
428            Err(TaskError::TaskNotFound(id.clone()))
429        }
430    }
431
432    async fn execute_task(
433        id: TaskId,
434        command: String,
435        remote_connection: Option<RemoteConnectionInfo>,
436        task_timeout: Option<Duration>,
437        mut cancel_rx: oneshot::Receiver<()>,
438        process_tx: oneshot::Sender<u32>,
439        task_tx: mpsc::UnboundedSender<TaskMessage>,
440    ) {
441        let completion = if let Some(remote_info) = remote_connection {
442            // Remote execution
443            Self::execute_remote_task(
444                id.clone(),
445                command,
446                remote_info,
447                task_timeout,
448                &mut cancel_rx,
449                &task_tx,
450            )
451            .await
452        } else {
453            // Local execution (existing logic)
454            Self::execute_local_task(
455                id.clone(),
456                command,
457                task_timeout,
458                &mut cancel_rx,
459                process_tx,
460                &task_tx,
461            )
462            .await
463        };
464
465        // Send task completion back to manager
466        let _ = task_tx.send(TaskMessage::TaskUpdate {
467            id: id.clone(),
468            completion,
469        });
470    }
471
472    async fn execute_local_task(
473        id: TaskId,
474        command: String,
475        task_timeout: Option<Duration>,
476        cancel_rx: &mut oneshot::Receiver<()>,
477        process_tx: oneshot::Sender<u32>,
478        task_tx: &mpsc::UnboundedSender<TaskMessage>,
479    ) -> TaskCompletion {
480        let mut cmd = Command::new("sh");
481        cmd.arg("-c")
482            .arg(&command)
483            .stdin(Stdio::null())
484            .stdout(Stdio::piped())
485            .stderr(Stdio::piped());
486        #[cfg(unix)]
487        {
488            cmd.env("DEBIAN_FRONTEND", "noninteractive")
489                .env("SUDO_ASKPASS", "/bin/false")
490                .process_group(0);
491        }
492        #[cfg(windows)]
493        {
494            // On Windows, create a new process group
495            cmd.creation_flags(0x00000200); // CREATE_NEW_PROCESS_GROUP
496        }
497
498        let mut child = match cmd.spawn() {
499            Ok(child) => child,
500            Err(err) => {
501                return TaskCompletion {
502                    output: String::new(),
503                    error: Some(format!("Failed to spawn command: {}", err)),
504                    final_status: TaskStatus::Failed,
505                };
506            }
507        };
508
509        // Send the process ID back to the manager for tracking
510        if let Some(process_id) = child.id() {
511            let _ = process_tx.send(process_id);
512        }
513
514        // Take stdout and stderr for streaming
515        let stdout = child.stdout.take().unwrap();
516        let stderr = child.stderr.take().unwrap();
517
518        let stdout_reader = BufReader::new(stdout);
519        let stderr_reader = BufReader::new(stderr);
520
521        let mut stdout_lines = stdout_reader.lines();
522        let mut stderr_lines = stderr_reader.lines();
523
524        // Helper function to stream output and handle cancellation
525        let stream_output = async {
526            let mut final_output = String::new();
527            let mut final_error: Option<String> = None;
528
529            loop {
530                tokio::select! {
531                    line = stdout_lines.next_line() => {
532                        match line {
533                            Ok(Some(line)) => {
534                                let output_line = format!("{}\n", line);
535                                final_output.push_str(&output_line);
536                                let _ = task_tx.send(TaskMessage::PartialUpdate {
537                                    id: id.clone(),
538                                    output: output_line,
539                                });
540                            }
541                            Ok(None) => {
542                                // stdout stream ended
543                            }
544                            Err(err) => {
545                                final_error = Some(format!("Error reading stdout: {}", err));
546                                break;
547                            }
548                        }
549                    }
550                    line = stderr_lines.next_line() => {
551                        match line {
552                            Ok(Some(line)) => {
553                                let output_line = format!("{}\n", line);
554                                final_output.push_str(&output_line);
555                                let _ = task_tx.send(TaskMessage::PartialUpdate {
556                                    id: id.clone(),
557                                    output: output_line,
558                                });
559                            }
560                            Ok(None) => {
561                                // stderr stream ended
562                            }
563                            Err(err) => {
564                                final_error = Some(format!("Error reading stderr: {}", err));
565                                break;
566                            }
567                        }
568                    }
569                    status = child.wait() => {
570                        match status {
571                            Ok(exit_status) => {
572                                if final_output.is_empty() {
573                                    final_output = "No output".to_string();
574                                }
575
576                                let completion = if exit_status.success() {
577                                    TaskCompletion {
578                                        output: final_output,
579                                        error: final_error,
580                                        final_status: TaskStatus::Completed,
581                                    }
582                                } else {
583                                    TaskCompletion {
584                                        output: final_output,
585                                        error: final_error.or_else(|| Some(format!("Command failed with exit code: {:?}", exit_status.code()))),
586                                        final_status: TaskStatus::Failed,
587                                    }
588                                };
589                                return completion;
590                            }
591                            Err(err) => {
592                                return TaskCompletion {
593                                    output: final_output,
594                                    error: Some(err.to_string()),
595                                    final_status: TaskStatus::Failed,
596                                };
597                            }
598                        }
599                    }
600                    _ = &mut *cancel_rx => {
601                        return TaskCompletion {
602                            output: final_output,
603                            error: Some("Tool call was cancelled and don't try to run it again".to_string()),
604                            final_status: TaskStatus::Cancelled,
605                        };
606                    }
607                }
608            }
609
610            TaskCompletion {
611                output: final_output,
612                error: final_error,
613                final_status: TaskStatus::Failed,
614            }
615        };
616
617        // Execute with timeout if provided
618        if let Some(timeout_duration) = task_timeout {
619            match timeout(timeout_duration, stream_output).await {
620                Ok(result) => result,
621                Err(_) => TaskCompletion {
622                    output: String::new(),
623                    error: Some("Task timed out".to_string()),
624                    final_status: TaskStatus::TimedOut,
625                },
626            }
627        } else {
628            stream_output.await
629        }
630    }
631
632    async fn execute_remote_task(
633        id: TaskId,
634        command: String,
635        remote_info: RemoteConnectionInfo,
636        task_timeout: Option<Duration>,
637        cancel_rx: &mut oneshot::Receiver<()>,
638        task_tx: &mpsc::UnboundedSender<TaskMessage>,
639    ) -> TaskCompletion {
640        // Use RemoteConnectionManager to get a connection
641        let connection_manager = RemoteConnectionManager::new();
642        let connection = match connection_manager.get_connection(&remote_info).await {
643            Ok(conn) => conn,
644            Err(e) => {
645                return TaskCompletion {
646                    output: String::new(),
647                    error: Some(format!("Failed to establish remote connection: {}", e)),
648                    final_status: TaskStatus::Failed,
649                };
650            }
651        };
652
653        // Create progress callback for streaming updates
654        let task_tx_clone = task_tx.clone();
655        let id_clone = id.clone();
656        let progress_callback = move |output: String| {
657            if !output.trim().is_empty() {
658                let _ = task_tx_clone.send(TaskMessage::PartialUpdate {
659                    id: id_clone.clone(),
660                    output,
661                });
662            }
663        };
664
665        // Use unified execution with proper cancellation and timeout
666        let options = crate::remote_connection::CommandOptions {
667            timeout: task_timeout,
668            with_progress: false,
669            simple: false,
670        };
671
672        match connection
673            .execute_command_unified(&command, options, cancel_rx, Some(progress_callback), None)
674            .await
675        {
676            Ok((output, exit_code)) => TaskCompletion {
677                output,
678                error: if exit_code != 0 {
679                    Some(format!("Command exited with code {}", exit_code))
680                } else {
681                    None
682                },
683                final_status: TaskStatus::Completed,
684            },
685            Err(e) => {
686                let error_msg = e.to_string();
687                let status = if error_msg.contains("timed out") {
688                    TaskStatus::TimedOut
689                } else if error_msg.contains("cancelled") {
690                    TaskStatus::Cancelled
691                } else {
692                    TaskStatus::Failed
693                };
694
695                TaskCompletion {
696                    output: String::new(),
697                    error: Some(if error_msg.contains("cancelled") {
698                        "Tool call was cancelled and don't try to run it again".to_string()
699                    } else {
700                        format!("Remote command failed: {}", error_msg)
701                    }),
702                    final_status: status,
703                }
704            }
705        }
706    }
707
708    async fn shutdown_all_tasks(&mut self) {
709        for (_id, mut entry) in self.tasks.drain() {
710            if let Some(cancel_tx) = entry.cancel_tx.take() {
711                let _ = cancel_tx.send(());
712            }
713
714            if let Some(process_id) = entry.process_id {
715                terminate_process_group(process_id);
716            }
717
718            entry.handle.abort();
719        }
720    }
721}
722
/// Client for a running `TaskManager`, handed out behind an `Arc` by `TaskManager::handle`.
///
/// Each request is sent over the manager's channel and answered on a one-shot channel, so calls
/// fail with `TaskError::ManagerShutdown` once the manager has stopped. Dropping a handle
/// broadcasts a shutdown to the manager.
pub struct TaskManagerHandle {
    /// Request channel into the manager's event loop.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Shutdown broadcast fired from `Drop`.
    shutdown_tx: broadcast::Sender<()>,
}
727
728impl Drop for TaskManagerHandle {
729    fn drop(&mut self) {
730        // Signal the TaskManager to shut down all tasks and kill child processes.
731        // This fires on the broadcast channel that TaskManager::run() listens on,
732        // triggering shutdown_all_tasks() which kills every process group.
733        //
734        // This is a last-resort safety net — callers should prefer calling
735        // handle.shutdown().await for a clean async shutdown. But if the handle
736        // is dropped without that (e.g., panic, std::process::exit, unexpected
737        // scope exit), this ensures child processes don't leak.
738        let _ = self.shutdown_tx.send(());
739    }
740}
741
742impl TaskManagerHandle {
743    pub async fn start_task(
744        &self,
745        command: String,
746        timeout: Option<Duration>,
747        remote_connection: Option<RemoteConnectionInfo>,
748    ) -> Result<TaskInfo, TaskError> {
749        let (response_tx, response_rx) = oneshot::channel();
750
751        self.tx
752            .send(TaskMessage::Start {
753                id: None,
754                command: command.clone(),
755                remote_connection: remote_connection.clone(),
756                timeout,
757                response_tx,
758            })
759            .map_err(|_| TaskError::ManagerShutdown)?;
760
761        let task_id = response_rx
762            .await
763            .map_err(|_| TaskError::ManagerShutdown)??;
764
765        // Wait for the task to start and get its status
766        tokio::time::sleep(START_TASK_WAIT_TIME).await;
767
768        let task_info = self
769            .get_task_details(task_id.clone())
770            .await
771            .map_err(|_| TaskError::ManagerShutdown)?
772            .ok_or_else(|| TaskError::TaskNotFound(task_id.clone()))?;
773
774        // If the task failed or was cancelled during start, return an error
775        if matches!(task_info.status, TaskStatus::Failed | TaskStatus::Cancelled) {
776            return Err(TaskError::TaskFailedOnStart(
777                task_info
778                    .output
779                    .unwrap_or_else(|| "Unknown reason".to_string()),
780            ));
781        }
782
783        // Return the task info with updated status
784        Ok(task_info)
785    }
786
787    pub async fn cancel_task(&self, id: TaskId) -> Result<TaskInfo, TaskError> {
788        // Get the task info before cancelling
789        let task_info = self
790            .get_all_tasks()
791            .await?
792            .into_iter()
793            .find(|task| task.id == id)
794            .ok_or_else(|| TaskError::TaskNotFound(id.clone()))?;
795
796        let (response_tx, response_rx) = oneshot::channel();
797
798        self.tx
799            .send(TaskMessage::Cancel { id, response_tx })
800            .map_err(|_| TaskError::ManagerShutdown)?;
801
802        response_rx
803            .await
804            .map_err(|_| TaskError::ManagerShutdown)??;
805
806        // Return the task info with updated status
807        Ok(TaskInfo {
808            status: TaskStatus::Cancelled,
809            duration: Some(
810                Utc::now()
811                    .signed_duration_since(task_info.start_time)
812                    .to_std()
813                    .unwrap_or_default(),
814            ),
815            ..task_info
816        })
817    }
818
819    pub async fn get_task_status(&self, id: TaskId) -> Result<Option<TaskStatus>, TaskError> {
820        let (response_tx, response_rx) = oneshot::channel();
821
822        self.tx
823            .send(TaskMessage::GetStatus { id, response_tx })
824            .map_err(|_| TaskError::ManagerShutdown)?;
825
826        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
827    }
828
829    pub async fn get_task_details(&self, id: TaskId) -> Result<Option<TaskInfo>, TaskError> {
830        let (response_tx, response_rx) = oneshot::channel();
831
832        self.tx
833            .send(TaskMessage::GetTaskDetails { id, response_tx })
834            .map_err(|_| TaskError::ManagerShutdown)?;
835
836        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
837    }
838
839    pub async fn get_all_tasks(&self) -> Result<Vec<TaskInfo>, TaskError> {
840        let (response_tx, response_rx) = oneshot::channel();
841
842        self.tx
843            .send(TaskMessage::GetAllTasks { response_tx })
844            .map_err(|_| TaskError::ManagerShutdown)?;
845
846        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
847    }
848
849    pub async fn shutdown(&self) -> Result<(), TaskError> {
850        let (response_tx, response_rx) = oneshot::channel();
851
852        self.tx
853            .send(TaskMessage::Shutdown { response_tx })
854            .map_err(|_| TaskError::ManagerShutdown)?;
855
856        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
857    }
858}
859
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use tokio::time::{Duration, sleep, timeout};

    /// Upper bound for waiting on the manager loop or a task to reach an expected
    /// state. It is generous so a loaded CI machine does not cause spurious
    /// failures. The happy path finishes much sooner because the waits below are
    /// event-driven or polled rather than fixed sleeps.
    const WAIT_LIMIT: Duration = Duration::from_secs(5);

    /// How long the marker command sleeps before creating its marker file.
    ///
    /// This must comfortably exceed the time `start_task` takes to return
    /// (presumably on the order of `START_TASK_WAIT_TIME`). Otherwise the marker
    /// could appear before the test has had a chance to stop the task.
    const MARKER_DELAY_SECS: u64 = 2;

    /// Builds a per-test, per-process marker file path inside the OS temp directory.
    ///
    /// The label keeps concurrently running tests apart. The process id keeps
    /// concurrent test binaries apart.
    fn marker_path(label: &str) -> PathBuf {
        std::env::temp_dir().join(format!("stakpak_test_{}_{}", label, std::process::id()))
    }

    /// Builds a shell command that sleeps `MARKER_DELAY_SECS` and only then creates
    /// `marker`. If the process is killed before the sleep ends, the marker never
    /// appears.
    fn delayed_marker_command(marker: &Path) -> String {
        format!("sleep {} && touch '{}'", MARKER_DELAY_SECS, marker.display())
    }

    /// Waits past the moment a surviving marker command would have created
    /// `marker`, removes any leftover file, and fails if the marker was created.
    async fn assert_marker_never_created(marker: &Path, context: &str) {
        sleep(Duration::from_secs(MARKER_DELAY_SECS) + Duration::from_millis(500)).await;
        let created = marker.exists();
        let _ = std::fs::remove_file(marker);
        assert!(
            !created,
            "{}: task process was not killed (marker {} was created)",
            context,
            marker.display()
        );
    }

    #[tokio::test]
    async fn test_task_manager_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a background task
        let task_info = handle
            .start_task("sleep 5".to_string(), None, None)
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // The manager loop must exit. Wait for it with a bounded timeout instead of
        // a fixed sleep followed by `is_finished()`, which is timing-dependent.
        timeout(WAIT_LIMIT, manager_handle)
            .await
            .expect("TaskManager::run() did not exit after shutdown()")
            .expect("TaskManager::run() panicked");
    }

    #[tokio::test]
    async fn test_task_manager_cancels_tasks_on_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that would create a marker file only if it survives long enough.
        let marker = marker_path("shutdown_cancel");
        let _ = std::fs::remove_file(&marker);
        let task_info = handle
            .start_task(delayed_marker_command(&marker), None, None)
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        timeout(WAIT_LIMIT, manager_handle)
            .await
            .expect("TaskManager::run() did not exit after shutdown()")
            .expect("TaskManager::run() panicked");

        // The running task must have been killed as part of shutdown.
        assert_marker_never_created(&marker, "shutdown()").await;
    }

    #[tokio::test]
    async fn test_task_manager_start_and_complete_task() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a simple task
        let task_info = handle
            .start_task("echo 'Hello, World!'".to_string(), None, None)
            .await
            .expect("Failed to start task");

        // Poll until the task reports Completed, bounded by WAIT_LIMIT, rather than
        // hoping a single fixed sleep was long enough.
        timeout(WAIT_LIMIT, async {
            loop {
                let status = handle
                    .get_task_status(task_info.id.clone())
                    .await
                    .expect("Failed to get task status");
                if status == Some(TaskStatus::Completed) {
                    break;
                }
                sleep(Duration::from_millis(50)).await;
            }
        })
        .await
        .expect("Task did not reach TaskStatus::Completed in time");

        // Get all tasks
        let tasks = handle
            .get_all_tasks()
            .await
            .expect("Failed to get all tasks");
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status, TaskStatus::Completed);

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    #[tokio::test]
    async fn test_task_manager_detects_immediate_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that will fail immediately
        let result = handle
            .start_task("nonexistent_command_12345".to_string(), None, None)
            .await;

        // Should get a TaskFailedOnStart error
        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    #[tokio::test]
    async fn test_task_manager_handle_drop_triggers_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a long-running task
        let _task_info = handle
            .start_task("sleep 30".to_string(), None, None)
            .await
            .expect("Failed to start task");

        // Drop the handle WITHOUT calling shutdown()
        drop(handle);

        // The Drop impl sends on the broadcast shutdown channel,
        // which causes TaskManager::run() to call shutdown_all_tasks() and exit.
        timeout(WAIT_LIMIT, manager_handle)
            .await
            .expect("TaskManager::run() should have exited after handle was dropped")
            .expect("TaskManager::run() panicked");
    }

    #[tokio::test]
    async fn test_task_manager_handle_drop_kills_child_processes() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that creates its marker file only after a delay, so the
        // marker's absence later proves the process was killed.
        let marker = marker_path("drop_kill");
        let _ = std::fs::remove_file(&marker);
        let task_info = handle
            .start_task(delayed_marker_command(&marker), None, None)
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Drop handle without explicit shutdown — Drop should kill the process
        drop(handle);

        timeout(WAIT_LIMIT, manager_handle)
            .await
            .expect("TaskManager::run() should have exited after handle was dropped")
            .expect("TaskManager::run() panicked");

        assert_marker_never_created(&marker, "dropping the handle").await;
    }

    #[tokio::test]
    async fn test_task_manager_detects_immediate_exit_code_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that will exit with non-zero code immediately
        let result = handle.start_task("exit 1".to_string(), None, None).await;

        // Should get a TaskFailedOnStart error
        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
}