Skip to main content

stakpak_shared/
task_manager.rs

1use crate::helper::generate_simple_id;
2use crate::remote_connection::{RemoteConnectionInfo, RemoteConnectionManager};
3use chrono::{DateTime, Utc};
4use std::{collections::HashMap, process::Stdio, sync::Arc, time::Duration};
5use tokio::{
6    io::{AsyncBufReadExt, BufReader},
7    process::Command,
8    sync::{broadcast, mpsc, oneshot},
9    time::timeout,
10};
11
/// Fixed 300 ms wait associated with starting a task.
///
/// NOTE(review): not referenced by the task-manager code in this part of the
/// file; presumably used by callers/handles to give a freshly started task time
/// to produce initial output or fail fast — confirm against its usages.
const START_TASK_WAIT_TIME: Duration = Duration::from_millis(300);
13
14/// Kill a process and its entire process group.
15///
16/// Uses process group kill (`kill -9 -{pid}`) on Unix and `taskkill /F /T` on
17/// Windows to ensure child processes spawned by shells (node, vite, esbuild, etc.)
18/// are also terminated.
19///
20/// This is safe to call even if the process has already exited.
21fn terminate_process_group(process_id: u32) {
22    #[cfg(unix)]
23    {
24        use std::process::Command;
25        // First check if the process exists
26        let check_result = Command::new("kill")
27            .arg("-0") // Signal 0 just checks if process exists
28            .arg(process_id.to_string())
29            .output();
30
31        // Only kill if the process actually exists
32        if check_result
33            .map(|output| output.status.success())
34            .unwrap_or(false)
35        {
36            // Kill the entire process group using negative PID
37            // Since we spawn with .process_group(0), the shell becomes the process group leader
38            // Using -{pid} kills all processes in that group (shell + children like node/vite/esbuild)
39            let _ = Command::new("kill")
40                .arg("-9")
41                .arg(format!("-{}", process_id))
42                .output();
43
44            // Also try to kill the individual process in case it's not a group leader
45            let _ = Command::new("kill")
46                .arg("-9")
47                .arg(process_id.to_string())
48                .output();
49        }
50    }
51
52    #[cfg(windows)]
53    {
54        use std::process::Command;
55        // On Windows, use taskkill with /T flag to kill the process tree
56        let check_result = Command::new("tasklist")
57            .arg("/FI")
58            .arg(format!("PID eq {}", process_id))
59            .arg("/FO")
60            .arg("CSV")
61            .output();
62
63        // Only kill if the process actually exists
64        if let Ok(output) = check_result {
65            let output_str = String::from_utf8_lossy(&output.stdout);
66            if output_str.lines().count() > 1 {
67                // More than just header line - use /T to kill process tree
68                let _ = Command::new("taskkill")
69                    .arg("/F")
70                    .arg("/T") // Kill process tree
71                    .arg("/PID")
72                    .arg(process_id.to_string())
73                    .output();
74            }
75        }
76    }
77}
78
/// Identifier of a task; the key of the manager's task table.
pub type TaskId = String;
80
/// Lifecycle state of a task tracked by [`TaskManager`].
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub enum TaskStatus {
    /// NOTE(review): never assigned by the manager code in this file; new and
    /// resumed tasks go straight to `Running`.
    Pending,
    /// The command is executing (set on start and on resume).
    Running,
    /// The command finished; for local runs this means a successful exit status.
    /// Completed tasks can be resumed.
    Completed,
    /// The task failed (e.g. spawn error, unsuccessful local exit, output read
    /// error, remote connection/command error).
    Failed,
    /// The task was cancelled.
    Cancelled,
    /// The task exceeded its configured timeout.
    TimedOut,
    /// A local command exited with code 10 to signal a pause; it can be resumed.
    Paused,
}
91
/// Full record of a task tracked by [`TaskManager`].
#[derive(Debug, Clone)]
pub struct Task {
    /// Task identifier (key in the manager's task table).
    pub id: TaskId,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Shell command being run; replaced when the task is resumed.
    pub command: String,
    /// Optional human-readable description supplied when the task was started.
    pub description: Option<String>,
    /// Remote host to run on; `None` runs the command locally via `sh -c`.
    pub remote_connection: Option<RemoteConnectionInfo>,
    /// Output streamed in so far via `PartialUpdate`; set from the final output
    /// when a completion is recorded.
    pub output: Option<String>,
    /// Error text reported by the last completion, if any.
    pub error: Option<String>,
    /// When the task was first started (not reset by resume).
    pub start_time: DateTime<Utc>,
    /// Time from `start_time` until the last completion was recorded.
    pub duration: Option<Duration>,
    /// Optional execution time limit, applied to every run of the task.
    pub timeout: Option<Duration>,
    /// Checkpoint data captured when the task ends as `Paused` or `Completed`.
    pub pause_info: Option<PauseInfo>,
    /// Extra environment variables for the local child process (not passed to
    /// remote executions).
    pub child_env: HashMap<String, String>,
}
107
/// A task together with the runtime handles needed to control its execution.
pub struct TaskEntry {
    /// The task's record.
    pub task: Task,
    /// Handle of the spawned execution; aborted when the task is cancelled.
    pub handle: tokio::task::JoinHandle<()>,
    /// PID of the local `sh` process (on Unix also its process-group ID, since it
    /// is spawned with `process_group(0)`); `None` for remote tasks or until the
    /// PID has been reported.
    pub process_id: Option<u32>,
    /// Signals the running execution to stop; taken (left `None`) once used.
    pub cancel_tx: Option<oneshot::Sender<()>>,
}
114
/// Serializable snapshot of a [`Task`] handed out to callers.
///
/// Omits the remote connection, error, timeout and child environment.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TaskInfo {
    pub id: TaskId,
    pub status: TaskStatus,
    pub command: String,
    pub description: Option<String>,
    pub output: Option<String>,
    pub start_time: DateTime<Utc>,
    /// Live elapsed time for running tasks, otherwise the stored duration.
    pub duration: Option<Duration>,
    pub pause_info: Option<PauseInfo>,
}
126
127impl From<&Task> for TaskInfo {
128    fn from(task: &Task) -> Self {
129        let duration = if matches!(task.status, TaskStatus::Running) {
130            // For running tasks, calculate duration from start time to now
131            Some(
132                Utc::now()
133                    .signed_duration_since(task.start_time)
134                    .to_std()
135                    .unwrap_or_default(),
136            )
137        } else {
138            // For completed/failed/cancelled tasks, use the stored duration
139            task.duration
140        };
141
142        TaskInfo {
143            id: task.id.clone(),
144            status: task.status.clone(),
145            command: task.command.clone(),
146            description: task.description.clone(),
147            output: task.output.clone(),
148            start_time: task.start_time,
149            duration,
150            pause_info: task.pause_info.clone(),
151        }
152    }
153}
154
/// Result of one execution of a task, sent back to the manager as a `TaskUpdate`.
pub struct TaskCompletion {
    /// Output collected by the execution.
    pub output: String,
    /// Error description, if the execution reported one.
    pub error: Option<String>,
    /// Status the task transitions to.
    pub final_status: TaskStatus,
}
160
161fn task_completion_from_exit_status(
162    mut output: String,
163    error: Option<String>,
164    exit_status: std::process::ExitStatus,
165) -> TaskCompletion {
166    if output.is_empty() {
167        output = "No output".to_string();
168    }
169
170    if exit_status.success() {
171        TaskCompletion {
172            output,
173            error,
174            final_status: TaskStatus::Completed,
175        }
176    } else if exit_status.code() == Some(10) {
177        TaskCompletion {
178            output,
179            error: None,
180            final_status: TaskStatus::Paused,
181        }
182    } else {
183        TaskCompletion {
184            output,
185            error: error.or_else(|| {
186                Some(format!(
187                    "Command failed with exit code: {:?}",
188                    exit_status.code()
189                ))
190            }),
191            final_status: TaskStatus::Failed,
192        }
193    }
194}
195
/// Everything a spawned execution needs; moved into `TaskManager::execute_task`.
struct TaskExecution {
    /// Id used when reporting updates back to the manager.
    id: TaskId,
    /// Command to run.
    command: String,
    /// Remote target; `None` means run locally.
    remote_connection: Option<RemoteConnectionInfo>,
    /// Optional execution time limit.
    task_timeout: Option<Duration>,
    /// Extra environment variables (only used for local executions).
    child_env: HashMap<String, String>,
}
203
/// Optional settings for starting a task. `Default` yields a local task with no
/// description, no timeout and no extra environment.
#[derive(Debug, Clone, Default)]
pub struct StartTaskOptions {
    /// Human-readable description stored on the task.
    pub description: Option<String>,
    /// Maximum execution time, if any.
    pub timeout: Option<Duration>,
    /// Run on this remote host instead of locally.
    pub remote_connection: Option<RemoteConnectionInfo>,
    /// Extra environment variables for the local child process.
    pub child_env: HashMap<String, String>,
}
211
/// Data captured when a task finishes as `Paused` or `Completed`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PauseInfo {
    /// The `checkpoint_id` string field of the final output, if that output parses
    /// as JSON and contains one.
    pub checkpoint_id: Option<String>,
    /// The final output, verbatim.
    pub raw_output: Option<String>,
}
217
/// Errors returned by task-manager operations.
#[derive(Debug, thiserror::Error)]
pub enum TaskError {
    /// No task is registered under this id.
    #[error("Task not found: {0}")]
    TaskNotFound(TaskId),
    /// A task is already registered under this id.
    /// NOTE(review): `start_task` returns this for any existing entry, including
    /// finished tasks, despite the "running" wording.
    #[error("Task already running: {0}")]
    TaskAlreadyRunning(TaskId),
    /// The manager has shut down (presumably reported when it can no longer be
    /// reached — confirm against the handle implementation).
    #[error("Manager shutdown")]
    ManagerShutdown,
    /// A command could not be executed; carries a description.
    #[error("Command execution failed: {0}")]
    ExecutionFailed(String),
    /// The task exceeded its timeout.
    #[error("Task timeout")]
    TaskTimeout,
    /// The task was cancelled.
    #[error("Task cancelled")]
    TaskCancelled,
    /// The task failed while being started; carries a description.
    #[error("Task failed on start: {0}")]
    TaskFailedOnStart(String),
    /// A resume was requested for a task that is not `Paused` or `Completed`.
    #[error("Task not paused: {0}")]
    TaskNotPaused(TaskId),
}
237
/// Messages processed by [`TaskManager::run`].
///
/// Variants carrying a `response_tx` are requests and are answered on that
/// channel. `TaskUpdate` and `PartialUpdate` are sent by the spawned task
/// executions themselves.
pub enum TaskMessage {
    /// Start a new task. When `id` is `None` an id is generated with
    /// `generate_simple_id(6)`. Replies with the task id or the start error.
    Start {
        id: Option<TaskId>,
        command: String,
        options: StartTaskOptions,
        response_tx: oneshot::Sender<Result<TaskId, TaskError>>,
    },
    /// Cancel a task: stop its execution, kill its process group and remove it
    /// from the task table.
    Cancel {
        id: TaskId,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
    /// Query a task's current status (`None` if the id is unknown).
    GetStatus {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskStatus>>,
    },
    /// Fetch a [`TaskInfo`] snapshot of a task (`None` if the id is unknown).
    GetTaskDetails {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskInfo>>,
    },
    /// Fetch snapshots of all tasks, newest `start_time` first.
    GetAllTasks {
        response_tx: oneshot::Sender<Vec<TaskInfo>>,
    },
    /// Shut down all tasks and stop the manager loop; replies once done.
    Shutdown {
        response_tx: oneshot::Sender<()>,
    },
    /// Final result of one execution of a task.
    TaskUpdate {
        id: TaskId,
        completion: TaskCompletion,
    },
    /// A chunk of streamed output to append to the task's output.
    PartialUpdate {
        id: TaskId,
        output: String,
    },
    /// Re-run a `Paused` or `Completed` task with a new `command`.
    Resume {
        id: TaskId,
        command: String,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
}
277
/// Actor that owns all tasks and processes [`TaskMessage`]s in [`TaskManager::run`].
///
/// Callers talk to it through the handle returned by [`TaskManager::handle`].
pub struct TaskManager {
    /// Task table keyed by id. Finished tasks are kept; cancelled ones are removed.
    tasks: HashMap<TaskId, TaskEntry>,
    /// Sender side of the message queue; cloned into handles and into every
    /// spawned execution so it can report updates.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Message queue drained by `run`.
    rx: mpsc::UnboundedReceiver<TaskMessage>,
    /// Shutdown broadcaster; cloned into handles.
    shutdown_tx: broadcast::Sender<()>,
    /// Subscription watched by `run`; any result from it triggers shutdown.
    shutdown_rx: broadcast::Receiver<()>,
}
285
286impl Default for TaskManager {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
292impl TaskManager {
293    pub fn new() -> Self {
294        let (tx, rx) = mpsc::unbounded_channel();
295        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
296
297        Self {
298            tasks: HashMap::new(),
299            tx,
300            rx,
301            shutdown_tx,
302            shutdown_rx,
303        }
304    }
305
306    pub fn handle(&self) -> Arc<TaskManagerHandle> {
307        Arc::new(TaskManagerHandle {
308            tx: self.tx.clone(),
309            shutdown_tx: self.shutdown_tx.clone(),
310        })
311    }
312
313    pub async fn run(mut self) {
314        loop {
315            tokio::select! {
316                msg = self.rx.recv() => {
317                    match msg {
318                        Some(msg) => {
319                            if self.handle_message(msg).await {
320                                break;
321                            }
322                        }
323                        None => {
324                            // All senders (TaskManagerHandles) have been dropped.
325                            // Clean up all running tasks and child processes.
326                            self.shutdown_all_tasks().await;
327                            break;
328                        }
329                    }
330                }
331                _ = self.shutdown_rx.recv() => {
332                    self.shutdown_all_tasks().await;
333                    break;
334                }
335            }
336        }
337    }
338
339    async fn handle_message(&mut self, msg: TaskMessage) -> bool {
340        match msg {
341            TaskMessage::Start {
342                id,
343                command,
344                options,
345                response_tx,
346            } => {
347                let task_id = id.unwrap_or_else(|| generate_simple_id(6));
348                let result = self.start_task(task_id.clone(), command, options).await;
349                let _ = response_tx.send(result.map(|_| task_id.clone()));
350                false
351            }
352            TaskMessage::Cancel { id, response_tx } => {
353                let result = self.cancel_task(&id).await;
354                let _ = response_tx.send(result);
355                false
356            }
357            TaskMessage::GetStatus { id, response_tx } => {
358                let status = self.tasks.get(&id).map(|entry| entry.task.status.clone());
359                let _ = response_tx.send(status);
360                false
361            }
362            TaskMessage::GetTaskDetails { id, response_tx } => {
363                let task_info = self.tasks.get(&id).map(|entry| TaskInfo::from(&entry.task));
364                let _ = response_tx.send(task_info);
365                false
366            }
367            TaskMessage::GetAllTasks { response_tx } => {
368                let mut tasks: Vec<TaskInfo> = self
369                    .tasks
370                    .values()
371                    .map(|entry| TaskInfo::from(&entry.task))
372                    .collect();
373                tasks.sort_by(|a, b| b.start_time.cmp(&a.start_time));
374                let _ = response_tx.send(tasks);
375                false
376            }
377            TaskMessage::TaskUpdate { id, completion } => {
378                if let Some(entry) = self.tasks.get_mut(&id) {
379                    entry.task.status = completion.final_status.clone();
380                    entry.task.output = Some(completion.output.clone());
381                    entry.task.error = completion.error;
382                    entry.task.duration = Some(
383                        Utc::now()
384                            .signed_duration_since(entry.task.start_time)
385                            .to_std()
386                            .unwrap_or_default(),
387                    );
388
389                    // Extract checkpoint info for paused and completed tasks
390                    if matches!(
391                        completion.final_status,
392                        TaskStatus::Paused | TaskStatus::Completed
393                    ) {
394                        let checkpoint_id =
395                            serde_json::from_str::<serde_json::Value>(&completion.output)
396                                .ok()
397                                .and_then(|v| {
398                                    v.get("checkpoint_id")
399                                        .and_then(|c| c.as_str())
400                                        .map(|s| s.to_string())
401                                });
402                        entry.task.pause_info = Some(PauseInfo {
403                            checkpoint_id,
404                            raw_output: Some(completion.output),
405                        });
406                    }
407
408                    // Keep completed tasks in the list so they can be viewed with get_all_tasks
409                    // TODO: Consider implementing a cleanup mechanism for old completed tasks
410                    // if matches!(entry.task.status, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::TimedOut) {
411                    //     self.tasks.remove(&id);
412                    // }
413                }
414                false
415            }
416            TaskMessage::PartialUpdate { id, output } => {
417                if let Some(entry) = self.tasks.get_mut(&id) {
418                    match &entry.task.output {
419                        Some(existing) => {
420                            entry.task.output = Some(format!("{}{}", existing, output));
421                        }
422                        None => {
423                            entry.task.output = Some(output);
424                        }
425                    }
426                }
427                false
428            }
429            TaskMessage::Resume {
430                id,
431                command,
432                response_tx,
433            } => {
434                let result = self.resume_task(id, command).await;
435                let _ = response_tx.send(result);
436                false
437            }
438            TaskMessage::Shutdown { response_tx } => {
439                self.shutdown_all_tasks().await;
440                let _ = response_tx.send(());
441                true
442            }
443        }
444    }
445
446    async fn start_task(
447        &mut self,
448        id: TaskId,
449        command: String,
450        options: StartTaskOptions,
451    ) -> Result<(), TaskError> {
452        if self.tasks.contains_key(&id) {
453            return Err(TaskError::TaskAlreadyRunning(id));
454        }
455
456        let StartTaskOptions {
457            description,
458            timeout,
459            remote_connection,
460            child_env,
461        } = options;
462
463        let task = Task {
464            id: id.clone(),
465            status: TaskStatus::Running,
466            command: command.clone(),
467            description,
468            remote_connection: remote_connection.clone(),
469            output: None,
470            error: None,
471            start_time: Utc::now(),
472            duration: None,
473            timeout,
474            pause_info: None,
475            child_env: child_env.clone(),
476        };
477
478        let (cancel_tx, cancel_rx) = oneshot::channel();
479        let (process_tx, process_rx) = oneshot::channel();
480        let task_tx: mpsc::UnboundedSender<TaskMessage> = self.tx.clone();
481
482        let is_remote_task = remote_connection.is_some();
483
484        // Spawn task immediately - SSH connection happens inside the task
485        let execution = TaskExecution {
486            id: id.clone(),
487            command,
488            remote_connection,
489            task_timeout: timeout,
490            child_env,
491        };
492
493        let handle = tokio::spawn(Self::execute_task(
494            execution, cancel_rx, process_tx, task_tx,
495        ));
496
497        let entry = TaskEntry {
498            task,
499            handle,
500            process_id: None,
501            cancel_tx: Some(cancel_tx),
502        };
503
504        self.tasks.insert(id.clone(), entry);
505
506        // Wait for the process ID for local tasks only
507        if !is_remote_task {
508            // Local task - wait for process ID for proper cleanup
509            if let Ok(process_id) = process_rx.await
510                && let Some(entry) = self.tasks.get_mut(&id)
511            {
512                entry.process_id = Some(process_id);
513            }
514        }
515        // Remote tasks don't have local process IDs, so we skip waiting
516
517        Ok(())
518    }
519
520    async fn resume_task(&mut self, id: TaskId, command: String) -> Result<(), TaskError> {
521        // Verify the task exists and is in a resumable state
522        if let Some(entry) = self.tasks.get(&id) {
523            if !matches!(
524                entry.task.status,
525                TaskStatus::Paused | TaskStatus::Completed
526            ) {
527                return Err(TaskError::TaskNotPaused(id));
528            }
529        } else {
530            return Err(TaskError::TaskNotFound(id));
531        }
532
533        // Update the task to Running and start a new execution
534        let entry = self.tasks.get_mut(&id).unwrap();
535        entry.task.status = TaskStatus::Running;
536        entry.task.command = command.clone();
537        entry.task.pause_info = None;
538        entry.task.output = None;
539        entry.task.error = None;
540
541        let (cancel_tx, cancel_rx) = oneshot::channel();
542        let (process_tx, process_rx) = oneshot::channel();
543        let task_tx = self.tx.clone();
544
545        let remote_connection = entry.task.remote_connection.clone();
546        let timeout = entry.task.timeout;
547        let child_env = entry.task.child_env.clone();
548
549        let execution = TaskExecution {
550            id: id.clone(),
551            command,
552            remote_connection: remote_connection.clone(),
553            task_timeout: timeout,
554            child_env,
555        };
556
557        let handle = tokio::spawn(Self::execute_task(
558            execution, cancel_rx, process_tx, task_tx,
559        ));
560
561        entry.handle = handle;
562        entry.cancel_tx = Some(cancel_tx);
563        entry.process_id = None;
564
565        // Wait for process ID for local tasks
566        if remote_connection.is_none()
567            && let Ok(process_id) = process_rx.await
568            && let Some(entry) = self.tasks.get_mut(&id)
569        {
570            entry.process_id = Some(process_id);
571        }
572
573        Ok(())
574    }
575
576    async fn cancel_task(&mut self, id: &TaskId) -> Result<(), TaskError> {
577        if let Some(mut entry) = self.tasks.remove(id) {
578            entry.task.status = TaskStatus::Cancelled;
579
580            if let Some(cancel_tx) = entry.cancel_tx.take() {
581                let _ = cancel_tx.send(());
582            }
583
584            if let Some(process_id) = entry.process_id {
585                terminate_process_group(process_id);
586            }
587
588            entry.handle.abort();
589            Ok(())
590        } else {
591            Err(TaskError::TaskNotFound(id.clone()))
592        }
593    }
594
595    async fn execute_task(
596        execution: TaskExecution,
597        mut cancel_rx: oneshot::Receiver<()>,
598        process_tx: oneshot::Sender<u32>,
599        task_tx: mpsc::UnboundedSender<TaskMessage>,
600    ) {
601        let TaskExecution {
602            id,
603            command,
604            remote_connection,
605            task_timeout,
606            child_env,
607        } = execution;
608        let completion = if let Some(remote_info) = remote_connection {
609            // Remote execution
610            Self::execute_remote_task(
611                id.clone(),
612                command,
613                remote_info,
614                task_timeout,
615                &mut cancel_rx,
616                &task_tx,
617            )
618            .await
619        } else {
620            // Local execution (existing logic)
621            Self::execute_local_task(
622                id.clone(),
623                command,
624                task_timeout,
625                &mut cancel_rx,
626                process_tx,
627                &task_tx,
628                child_env,
629            )
630            .await
631        };
632
633        // Send task completion back to manager
634        let _ = task_tx.send(TaskMessage::TaskUpdate {
635            id: id.clone(),
636            completion,
637        });
638    }
639
640    async fn execute_local_task(
641        id: TaskId,
642        command: String,
643        task_timeout: Option<Duration>,
644        cancel_rx: &mut oneshot::Receiver<()>,
645        process_tx: oneshot::Sender<u32>,
646        task_tx: &mpsc::UnboundedSender<TaskMessage>,
647        child_env: HashMap<String, String>,
648    ) -> TaskCompletion {
649        let mut cmd = Command::new("sh");
650        cmd.arg("-c")
651            .arg(&command)
652            .stdin(Stdio::null())
653            .stdout(Stdio::piped())
654            .stderr(Stdio::piped());
655        for (key, value) in child_env {
656            cmd.env(key, value);
657        }
658        #[cfg(unix)]
659        {
660            cmd.env("DEBIAN_FRONTEND", "noninteractive")
661                .env("SUDO_ASKPASS", "/bin/false")
662                .process_group(0);
663        }
664        #[cfg(windows)]
665        {
666            // On Windows, create a new process group
667            cmd.creation_flags(0x00000200); // CREATE_NEW_PROCESS_GROUP
668        }
669
670        let mut child = match cmd.spawn() {
671            Ok(child) => child,
672            Err(err) => {
673                return TaskCompletion {
674                    output: String::new(),
675                    error: Some(format!("Failed to spawn command: {}", err)),
676                    final_status: TaskStatus::Failed,
677                };
678            }
679        };
680
681        // Send the process ID back to the manager for tracking
682        if let Some(process_id) = child.id() {
683            let _ = process_tx.send(process_id);
684        }
685
686        // Take stdout and stderr for streaming
687        let stdout = child.stdout.take().unwrap();
688        let stderr = child.stderr.take().unwrap();
689
690        let stdout_reader = BufReader::new(stdout);
691        let stderr_reader = BufReader::new(stderr);
692
693        let mut stdout_lines = stdout_reader.lines();
694        let mut stderr_lines = stderr_reader.lines();
695
696        // Helper function to stream output and handle cancellation
697        let stream_output = async {
698            let mut final_output = String::new();
699            let mut final_error: Option<String> = None;
700            let mut stdout_done = false;
701            let mut stderr_done = false;
702
703            loop {
704                tokio::select! {
705                    line = stdout_lines.next_line(), if !stdout_done => {
706                        match line {
707                            Ok(Some(line)) => {
708                                let output_line = format!("{}\n", line);
709                                final_output.push_str(&output_line);
710                                let _ = task_tx.send(TaskMessage::PartialUpdate {
711                                    id: id.clone(),
712                                    output: output_line,
713                                });
714                            }
715                            Ok(None) => {
716                                // stdout stream ended
717                                stdout_done = true;
718                            }
719                            Err(err) => {
720                                final_error = Some(format!("Error reading stdout: {}", err));
721                                break;
722                            }
723                        }
724                    }
725                    line = stderr_lines.next_line(), if !stderr_done => {
726                        match line {
727                            Ok(Some(line)) => {
728                                let output_line = format!("{}\n", line);
729                                final_output.push_str(&output_line);
730                                let _ = task_tx.send(TaskMessage::PartialUpdate {
731                                    id: id.clone(),
732                                    output: output_line,
733                                });
734                            }
735                            Ok(None) => {
736                                // stderr stream ended
737                                stderr_done = true;
738                            }
739                            Err(err) => {
740                                final_error = Some(format!("Error reading stderr: {}", err));
741                                break;
742                            }
743                        }
744                    }
745                    status = child.wait() => {
746                        match status {
747                            Ok(exit_status) => {
748                                while !stdout_done {
749                                    match stdout_lines.next_line().await {
750                                        Ok(Some(line)) => {
751                                            let output_line = format!("{}\n", line);
752                                            final_output.push_str(&output_line);
753                                            let _ = task_tx.send(TaskMessage::PartialUpdate {
754                                                id: id.clone(),
755                                                output: output_line,
756                                            });
757                                        }
758                                        Ok(None) => stdout_done = true,
759                                        Err(err) => {
760                                            final_error = Some(format!("Error reading stdout: {}", err));
761                                            break;
762                                        }
763                                    }
764                                }
765
766                                while !stderr_done {
767                                    match stderr_lines.next_line().await {
768                                        Ok(Some(line)) => {
769                                            let output_line = format!("{}\n", line);
770                                            final_output.push_str(&output_line);
771                                            let _ = task_tx.send(TaskMessage::PartialUpdate {
772                                                id: id.clone(),
773                                                output: output_line,
774                                            });
775                                        }
776                                        Ok(None) => stderr_done = true,
777                                        Err(err) => {
778                                            final_error = Some(format!("Error reading stderr: {}", err));
779                                            break;
780                                        }
781                                    }
782                                }
783
784                                return task_completion_from_exit_status(
785                                    final_output,
786                                    final_error,
787                                    exit_status,
788                                );
789                            }
790                            Err(err) => {
791                                return TaskCompletion {
792                                    output: final_output,
793                                    error: Some(err.to_string()),
794                                    final_status: TaskStatus::Failed,
795                                };
796                            }
797                        }
798                    }
799                    _ = &mut *cancel_rx => {
800                        return TaskCompletion {
801                            output: final_output,
802                            error: Some("Tool call was cancelled and don't try to run it again".to_string()),
803                            final_status: TaskStatus::Cancelled,
804                        };
805                    }
806                }
807            }
808
809            TaskCompletion {
810                output: final_output,
811                error: final_error,
812                final_status: TaskStatus::Failed,
813            }
814        };
815
816        // Execute with timeout if provided
817        if let Some(timeout_duration) = task_timeout {
818            match timeout(timeout_duration, stream_output).await {
819                Ok(result) => result,
820                Err(_) => TaskCompletion {
821                    output: String::new(),
822                    error: Some("Task timed out".to_string()),
823                    final_status: TaskStatus::TimedOut,
824                },
825            }
826        } else {
827            stream_output.await
828        }
829    }
830
831    async fn execute_remote_task(
832        id: TaskId,
833        command: String,
834        remote_info: RemoteConnectionInfo,
835        task_timeout: Option<Duration>,
836        cancel_rx: &mut oneshot::Receiver<()>,
837        task_tx: &mpsc::UnboundedSender<TaskMessage>,
838    ) -> TaskCompletion {
839        // Use RemoteConnectionManager to get a connection
840        let connection_manager = RemoteConnectionManager::new();
841        let connection = match connection_manager.get_connection(&remote_info).await {
842            Ok(conn) => conn,
843            Err(e) => {
844                return TaskCompletion {
845                    output: String::new(),
846                    error: Some(format!("Failed to establish remote connection: {}", e)),
847                    final_status: TaskStatus::Failed,
848                };
849            }
850        };
851
852        // Create progress callback for streaming updates
853        let task_tx_clone = task_tx.clone();
854        let id_clone = id.clone();
855        let progress_callback = move |output: String| {
856            if !output.trim().is_empty() {
857                let _ = task_tx_clone.send(TaskMessage::PartialUpdate {
858                    id: id_clone.clone(),
859                    output,
860                });
861            }
862        };
863
864        // Use unified execution with proper cancellation and timeout
865        let options = crate::remote_connection::CommandOptions {
866            timeout: task_timeout,
867            with_progress: false,
868            simple: false,
869        };
870
871        match connection
872            .execute_command_unified(&command, options, cancel_rx, Some(progress_callback), None)
873            .await
874        {
875            Ok((output, exit_code)) => TaskCompletion {
876                output,
877                error: if exit_code != 0 {
878                    Some(format!("Command exited with code {}", exit_code))
879                } else {
880                    None
881                },
882                final_status: TaskStatus::Completed,
883            },
884            Err(e) => {
885                let error_msg = e.to_string();
886                let status = if error_msg.contains("timed out") {
887                    TaskStatus::TimedOut
888                } else if error_msg.contains("cancelled") {
889                    TaskStatus::Cancelled
890                } else {
891                    TaskStatus::Failed
892                };
893
894                TaskCompletion {
895                    output: String::new(),
896                    error: Some(if error_msg.contains("cancelled") {
897                        "Tool call was cancelled and don't try to run it again".to_string()
898                    } else {
899                        format!("Remote command failed: {}", error_msg)
900                    }),
901                    final_status: status,
902                }
903            }
904        }
905    }
906
907    async fn shutdown_all_tasks(&mut self) {
908        for (_id, mut entry) in self.tasks.drain() {
909            if let Some(cancel_tx) = entry.cancel_tx.take() {
910                let _ = cancel_tx.send(());
911            }
912
913            if let Some(process_id) = entry.process_id {
914                terminate_process_group(process_id);
915            }
916
917            entry.handle.abort();
918        }
919    }
920}
921
922pub struct TaskManagerHandle {
923    tx: mpsc::UnboundedSender<TaskMessage>,
924    shutdown_tx: broadcast::Sender<()>,
925}
926
927impl Drop for TaskManagerHandle {
928    fn drop(&mut self) {
929        // Signal the TaskManager to shut down all tasks and kill child processes.
930        // This fires on the broadcast channel that TaskManager::run() listens on,
931        // triggering shutdown_all_tasks() which kills every process group.
932        //
933        // This is a last-resort safety net — callers should prefer calling
934        // handle.shutdown().await for a clean async shutdown. But if the handle
935        // is dropped without that (e.g., panic, std::process::exit, unexpected
936        // scope exit), this ensures child processes don't leak.
937        let _ = self.shutdown_tx.send(());
938    }
939}
940
941impl TaskManagerHandle {
942    pub async fn start_task(
943        &self,
944        command: String,
945        options: StartTaskOptions,
946    ) -> Result<TaskInfo, TaskError> {
947        let (response_tx, response_rx) = oneshot::channel();
948
949        self.tx
950            .send(TaskMessage::Start {
951                id: None,
952                command: command.clone(),
953                options,
954                response_tx,
955            })
956            .map_err(|_| TaskError::ManagerShutdown)?;
957
958        let task_id = response_rx
959            .await
960            .map_err(|_| TaskError::ManagerShutdown)??;
961
962        // Wait for the task to start and get its status
963        tokio::time::sleep(START_TASK_WAIT_TIME).await;
964
965        let task_info = self
966            .get_task_details(task_id.clone())
967            .await
968            .map_err(|_| TaskError::ManagerShutdown)?
969            .ok_or_else(|| TaskError::TaskNotFound(task_id.clone()))?;
970
971        // If the task failed or was cancelled during start, return an error
972        if matches!(task_info.status, TaskStatus::Failed | TaskStatus::Cancelled) {
973            return Err(TaskError::TaskFailedOnStart(
974                task_info
975                    .output
976                    .unwrap_or_else(|| "Unknown reason".to_string()),
977            ));
978        }
979
980        // Return the task info with updated status
981        Ok(task_info)
982    }
983
984    pub async fn cancel_task(&self, id: TaskId) -> Result<TaskInfo, TaskError> {
985        // Get the task info before cancelling
986        let task_info = self
987            .get_all_tasks()
988            .await?
989            .into_iter()
990            .find(|task| task.id == id)
991            .ok_or_else(|| TaskError::TaskNotFound(id.clone()))?;
992
993        let (response_tx, response_rx) = oneshot::channel();
994
995        self.tx
996            .send(TaskMessage::Cancel { id, response_tx })
997            .map_err(|_| TaskError::ManagerShutdown)?;
998
999        response_rx
1000            .await
1001            .map_err(|_| TaskError::ManagerShutdown)??;
1002
1003        // Return the task info with updated status
1004        Ok(TaskInfo {
1005            status: TaskStatus::Cancelled,
1006            duration: Some(
1007                Utc::now()
1008                    .signed_duration_since(task_info.start_time)
1009                    .to_std()
1010                    .unwrap_or_default(),
1011            ),
1012            ..task_info
1013        })
1014    }
1015
1016    pub async fn resume_task(&self, id: TaskId, command: String) -> Result<TaskInfo, TaskError> {
1017        let (response_tx, response_rx) = oneshot::channel();
1018
1019        self.tx
1020            .send(TaskMessage::Resume {
1021                id: id.clone(),
1022                command,
1023                response_tx,
1024            })
1025            .map_err(|_| TaskError::ManagerShutdown)?;
1026
1027        response_rx
1028            .await
1029            .map_err(|_| TaskError::ManagerShutdown)??;
1030
1031        // Wait for the task to start
1032        tokio::time::sleep(START_TASK_WAIT_TIME).await;
1033
1034        let task_info = self
1035            .get_task_details(id.clone())
1036            .await
1037            .map_err(|_| TaskError::ManagerShutdown)?
1038            .ok_or(TaskError::TaskNotFound(id))?;
1039
1040        Ok(task_info)
1041    }
1042
1043    pub async fn get_task_status(&self, id: TaskId) -> Result<Option<TaskStatus>, TaskError> {
1044        let (response_tx, response_rx) = oneshot::channel();
1045
1046        self.tx
1047            .send(TaskMessage::GetStatus { id, response_tx })
1048            .map_err(|_| TaskError::ManagerShutdown)?;
1049
1050        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1051    }
1052
1053    pub async fn get_task_details(&self, id: TaskId) -> Result<Option<TaskInfo>, TaskError> {
1054        let (response_tx, response_rx) = oneshot::channel();
1055
1056        self.tx
1057            .send(TaskMessage::GetTaskDetails { id, response_tx })
1058            .map_err(|_| TaskError::ManagerShutdown)?;
1059
1060        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1061    }
1062
1063    pub async fn get_all_tasks(&self) -> Result<Vec<TaskInfo>, TaskError> {
1064        let (response_tx, response_rx) = oneshot::channel();
1065
1066        self.tx
1067            .send(TaskMessage::GetAllTasks { response_tx })
1068            .map_err(|_| TaskError::ManagerShutdown)?;
1069
1070        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1071    }
1072
1073    pub async fn shutdown(&self) -> Result<(), TaskError> {
1074        let (response_tx, response_rx) = oneshot::channel();
1075
1076        self.tx
1077            .send(TaskMessage::Shutdown { response_tx })
1078            .map_err(|_| TaskError::ManagerShutdown)?;
1079
1080        response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1081    }
1082}
1083
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    /// An explicit `shutdown()` makes `TaskManager::run()` return.
    #[tokio::test]
    async fn test_task_manager_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a background task
        let task_info = handle
            .start_task("sleep 5".to_string(), StartTaskOptions::default())
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // Wait a bit for the shutdown to complete
        sleep(Duration::from_millis(100)).await;

        // Verify the manager task has completed
        assert!(manager_handle.is_finished());
    }

    /// Shutdown with a long-running task still in flight makes the manager exit.
    #[tokio::test]
    async fn test_task_manager_cancels_tasks_on_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a long-running background task
        let task_info = handle
            .start_task("sleep 10".to_string(), StartTaskOptions::default())
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // Wait a bit for the shutdown to complete
        sleep(Duration::from_millis(100)).await;

        // Verify the manager task has completed
        assert!(manager_handle.is_finished());
    }

    /// A short command reaches `Completed` and is listed by `get_all_tasks`.
    #[tokio::test]
    async fn test_task_manager_start_and_complete_task() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a simple task
        let task_info = handle
            .start_task(
                "echo 'Hello, World!'".to_string(),
                StartTaskOptions::default(),
            )
            .await
            .expect("Failed to start task");

        // Wait for the task to complete
        sleep(Duration::from_millis(500)).await;

        // Get task status
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Completed));

        // Get all tasks
        let tasks = handle
            .get_all_tasks()
            .await
            .expect("Failed to get all tasks");
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status, TaskStatus::Completed);

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    /// Variables in `StartTaskOptions::child_env` are visible to the spawned command.
    #[tokio::test]
    async fn task_manager_local_task_receives_child_env_defaults() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let mut child_env = HashMap::new();
        child_env.insert("STAKPAK_PROFILE".to_string(), "ops".to_string());

        let task_info = handle
            .start_task(
                "printf '%s\\n' \"$STAKPAK_PROFILE\"".to_string(),
                StartTaskOptions {
                    child_env,
                    ..StartTaskOptions::default()
                },
            )
            .await
            .expect("task should start with child env defaults");

        sleep(Duration::from_millis(500)).await;

        let details = handle
            .get_task_details(task_info.id.clone())
            .await
            .expect("task details request should succeed")
            .expect("task details should exist");

        assert_eq!(details.status, TaskStatus::Completed);
        assert_eq!(details.output.as_deref(), Some("ops\n"));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    /// A task paused via `exit 10` keeps its `child_env` when it is resumed.
    #[tokio::test]
    async fn resumed_local_task_reuses_child_env_defaults() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let mut child_env = HashMap::new();
        child_env.insert("STAKPAK_PROFILE".to_string(), "ops".to_string());

        let task_info = handle
            .start_task(
                "exit 10".to_string(),
                StartTaskOptions {
                    child_env,
                    ..StartTaskOptions::default()
                },
            )
            .await
            .expect("pausing task should start");

        sleep(Duration::from_millis(500)).await;

        let paused = handle
            .get_task_details(task_info.id.clone())
            .await
            .expect("task details request should succeed")
            .expect("task details should exist");
        assert_eq!(paused.status, TaskStatus::Paused);

        handle
            .resume_task(
                task_info.id.clone(),
                "printf '%s\\n' \"$STAKPAK_PROFILE\"".to_string(),
            )
            .await
            .expect("paused task should resume");

        sleep(Duration::from_millis(500)).await;

        let details = handle
            .get_task_details(task_info.id)
            .await
            .expect("task details request should succeed")
            .expect("task details should exist");

        assert_eq!(details.status, TaskStatus::Completed);
        assert_eq!(details.output.as_deref(), Some("ops\n"));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    /// A command that cannot be found is reported as `TaskFailedOnStart`.
    #[tokio::test]
    async fn test_task_manager_detects_immediate_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that will fail immediately
        let result = handle
            .start_task(
                "nonexistent_command_12345".to_string(),
                StartTaskOptions::default(),
            )
            .await;

        // Should get a TaskFailedOnStart error
        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    /// Dropping the handle without `shutdown()` still makes `TaskManager::run()` exit.
    #[tokio::test]
    async fn test_task_manager_handle_drop_triggers_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a long-running task
        let _task_info = handle
            .start_task("sleep 30".to_string(), StartTaskOptions::default())
            .await
            .expect("Failed to start task");

        // Drop the handle WITHOUT calling shutdown()
        drop(handle);

        // The Drop impl sends on the broadcast shutdown channel,
        // which causes TaskManager::run() to call shutdown_all_tasks() and exit.
        // Give it a moment to process.
        sleep(Duration::from_millis(500)).await;

        assert!(
            manager_handle.is_finished(),
            "TaskManager::run() should have exited after handle was dropped"
        );
    }

    /// Dropping the handle kills the running task's processes, not just the manager.
    ///
    /// The command writes its marker file only after `sleep 1`. The handle is
    /// dropped well before that, so if the process group is killed the marker
    /// never appears. If the marker does appear, the shell survived the drop.
    #[tokio::test]
    async fn test_task_manager_handle_drop_kills_child_processes() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let marker =
            std::env::temp_dir().join(format!("stakpak_test_drop_{}", std::process::id()));
        // Clear any leftover marker so the final check only reflects this run.
        let _ = std::fs::remove_file(&marker);

        let task_info = handle
            .start_task(
                format!("sleep 1 && touch '{}'", marker.display()),
                StartTaskOptions::default(),
            )
            .await
            .expect("Failed to start task");

        // Verify task is running
        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get status");
        assert_eq!(status, Some(TaskStatus::Running));

        // Drop handle without explicit shutdown — Drop should kill the process
        drop(handle);

        // Wait past the point where a surviving shell would have created the marker.
        sleep(Duration::from_millis(2000)).await;

        let marker_created = marker.exists();
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&marker);
        assert!(
            !marker_created,
            "child process survived TaskManagerHandle drop"
        );
    }

    /// A command that exits non-zero right away is reported as `TaskFailedOnStart`.
    #[tokio::test]
    async fn test_task_manager_detects_immediate_exit_code_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        // Spawn the task manager
        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // Start a task that will exit with non-zero code immediately
        let result = handle
            .start_task("exit 1".to_string(), StartTaskOptions::default())
            .await;

        // Should get a TaskFailedOnStart error
        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        // Shutdown the task manager
        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
}