1use crate::helper::generate_simple_id;
2use crate::remote_connection::{RemoteConnectionInfo, RemoteConnectionManager};
3use chrono::{DateTime, Utc};
4use std::{collections::HashMap, process::Stdio, sync::Arc, time::Duration};
5use tokio::{
6 io::{AsyncBufReadExt, BufReader},
7 process::Command,
8 sync::{broadcast, mpsc, oneshot},
9 time::timeout,
10};
11
/// Grace period the handle waits after starting or resuming a task before
/// reading its details back, so that a task which fails immediately is
/// already reflected in the status that gets reported.
const START_TASK_WAIT_TIME: Duration = Duration::from_millis(300);
13
/// Force-kills the process identified by `process_id` and, on Unix, the
/// process group that shares the same id.
///
/// Both the existence probe and the kill are delegated to the platform's
/// command-line tools (`kill` on Unix, `tasklist`/`taskkill` on Windows).
/// Every failure is ignored: this is a best-effort cleanup helper.
fn terminate_process_group(process_id: u32) {
    #[cfg(unix)]
    {
        use std::process::Command;

        let pid = process_id.to_string();

        // `kill -0` delivers no signal; it only reports whether the target exists.
        let is_alive = match Command::new("kill").args(["-0", pid.as_str()]).output() {
            Ok(probe) => probe.status.success(),
            Err(_) => false,
        };

        if is_alive {
            // A negative pid addresses the whole process group.
            let group = format!("-{}", process_id);
            let _ = Command::new("kill").args(["-9", group.as_str()]).output();

            // Follow up with the process itself.
            let _ = Command::new("kill").args(["-9", pid.as_str()]).output();
        }
    }

    #[cfg(windows)]
    {
        use std::process::Command;

        let pid = process_id.to_string();
        let filter = format!("PID eq {}", process_id);
        let listing = Command::new("tasklist")
            .args(["/FI", filter.as_str(), "/FO", "CSV"])
            .output();

        if let Ok(listing) = listing {
            // More than one line means at least one row follows the CSV header.
            let row_count = String::from_utf8_lossy(&listing.stdout).lines().count();
            if row_count > 1 {
                // `/T` also terminates the child processes of the target.
                let _ = Command::new("taskkill")
                    .args(["/F", "/T", "/PID", pid.as_str()])
                    .output();
            }
        }
    }
}
78
/// Identifier under which a task is stored in the manager.
pub type TaskId = String;
80
/// Lifecycle state of a task.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub enum TaskStatus {
    /// Not assigned anywhere in this file.
    Pending,
    /// Set when a task is started or resumed.
    Running,
    /// The command finished successfully.
    Completed,
    /// The command could not be run or finished unsuccessfully.
    Failed,
    /// The task was cancelled before it finished.
    Cancelled,
    /// The task exceeded its timeout.
    TimedOut,
    /// A local command exited with code 10; such a task can be resumed.
    Paused,
}
91
/// Full state the manager keeps for one task.
#[derive(Debug, Clone)]
pub struct Task {
    /// Key of this task in the manager's map.
    pub id: TaskId,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Shell command of the most recent run (replaced on resume).
    pub command: String,
    /// Optional caller-supplied description.
    pub description: Option<String>,
    /// When set, the command runs over this remote connection instead of locally.
    pub remote_connection: Option<RemoteConnectionInfo>,
    /// Output collected so far; `None` until the first output arrives.
    pub output: Option<String>,
    /// Error text of the last completion, if any.
    pub error: Option<String>,
    /// Moment the task was first started.
    pub start_time: DateTime<Utc>,
    /// Time from `start_time` to the last completion; `None` before that.
    pub duration: Option<Duration>,
    /// Optional limit on how long one run may take.
    pub timeout: Option<Duration>,
    /// Checkpoint data recorded when a run ends as `Paused` or `Completed`.
    pub pause_info: Option<PauseInfo>,
    /// Extra environment variables for local runs; reused when resuming.
    pub child_env: HashMap<String, String>,
}
107
/// A task together with the runtime handles needed to stop it.
pub struct TaskEntry {
    /// The task's recorded state.
    pub task: Task,
    /// Join handle of the spawned executor future; aborted on cancel/shutdown.
    pub handle: tokio::task::JoinHandle<()>,
    /// OS process id of the local child, once the executor has reported it.
    pub process_id: Option<u32>,
    /// Sender used to signal cancellation to the executor; taken when used.
    pub cancel_tx: Option<oneshot::Sender<()>>,
}
114
/// Serializable snapshot of a [`Task`] handed out to callers.
///
/// It omits the error text, remote connection, timeout and child environment
/// that [`Task`] carries.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TaskInfo {
    /// Identifier of the task.
    pub id: TaskId,
    /// Lifecycle state at the time of the snapshot.
    pub status: TaskStatus,
    /// Shell command of the most recent run.
    pub command: String,
    /// Optional caller-supplied description.
    pub description: Option<String>,
    /// Output collected up to the time of the snapshot.
    pub output: Option<String>,
    /// Moment the task was first started.
    pub start_time: DateTime<Utc>,
    /// Elapsed time so far for running tasks, otherwise the recorded duration.
    pub duration: Option<Duration>,
    /// Checkpoint data, if any was recorded.
    pub pause_info: Option<PauseInfo>,
}
126
127impl From<&Task> for TaskInfo {
128 fn from(task: &Task) -> Self {
129 let duration = if matches!(task.status, TaskStatus::Running) {
130 Some(
132 Utc::now()
133 .signed_duration_since(task.start_time)
134 .to_std()
135 .unwrap_or_default(),
136 )
137 } else {
138 task.duration
140 };
141
142 TaskInfo {
143 id: task.id.clone(),
144 status: task.status.clone(),
145 command: task.command.clone(),
146 description: task.description.clone(),
147 output: task.output.clone(),
148 start_time: task.start_time,
149 duration,
150 pause_info: task.pause_info.clone(),
151 }
152 }
153}
154
/// Final result of one run of a task, sent by the executor to the manager.
pub struct TaskCompletion {
    /// Complete output of the run.
    pub output: String,
    /// Error description, if there is one.
    pub error: Option<String>,
    /// Status the task should be given.
    pub final_status: TaskStatus,
}
160
161fn task_completion_from_exit_status(
162 mut output: String,
163 error: Option<String>,
164 exit_status: std::process::ExitStatus,
165) -> TaskCompletion {
166 if output.is_empty() {
167 output = "No output".to_string();
168 }
169
170 if exit_status.success() {
171 TaskCompletion {
172 output,
173 error,
174 final_status: TaskStatus::Completed,
175 }
176 } else if exit_status.code() == Some(10) {
177 TaskCompletion {
178 output,
179 error: None,
180 final_status: TaskStatus::Paused,
181 }
182 } else {
183 TaskCompletion {
184 output,
185 error: error.or_else(|| {
186 Some(format!(
187 "Command failed with exit code: {:?}",
188 exit_status.code()
189 ))
190 }),
191 final_status: TaskStatus::Failed,
192 }
193 }
194}
195
/// Everything the executor needs for one run of a task.
struct TaskExecution {
    /// Task the run belongs to; used to tag update messages.
    id: TaskId,
    /// Shell command to run.
    command: String,
    /// Run remotely over this connection when set, locally otherwise.
    remote_connection: Option<RemoteConnectionInfo>,
    /// Optional limit on the run's duration.
    task_timeout: Option<Duration>,
    /// Extra environment variables for a local child process.
    child_env: HashMap<String, String>,
}
203
/// Optional settings for starting a task; `Default` leaves everything unset.
#[derive(Debug, Clone, Default)]
pub struct StartTaskOptions {
    /// Free-form description stored with the task.
    pub description: Option<String>,
    /// Limit on how long a run may take.
    pub timeout: Option<Duration>,
    /// Run the command over this remote connection instead of locally.
    pub remote_connection: Option<RemoteConnectionInfo>,
    /// Extra environment variables for the local child process.
    pub child_env: HashMap<String, String>,
}
211
/// Data recorded when a run ends as `Paused` or `Completed`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PauseInfo {
    /// Value of the string field `checkpoint_id` when the run's output is a
    /// JSON value that has one; `None` otherwise.
    pub checkpoint_id: Option<String>,
    /// The run's output, unparsed.
    pub raw_output: Option<String>,
}
217
/// Errors reported by the task manager and its handle.
#[derive(Debug, thiserror::Error)]
pub enum TaskError {
    /// No task is stored under the given id.
    #[error("Task not found: {0}")]
    TaskNotFound(TaskId),
    /// A task with the given id is already stored in the manager.
    #[error("Task already running: {0}")]
    TaskAlreadyRunning(TaskId),
    /// The manager could not be reached or did not answer.
    #[error("Manager shutdown")]
    ManagerShutdown,
    /// Not constructed anywhere in this file.
    #[error("Command execution failed: {0}")]
    ExecutionFailed(String),
    /// Not constructed anywhere in this file.
    #[error("Task timeout")]
    TaskTimeout,
    /// Not constructed anywhere in this file.
    #[error("Task cancelled")]
    TaskCancelled,
    /// The task was already failed or cancelled shortly after being started.
    #[error("Task failed on start: {0}")]
    TaskFailedOnStart(String),
    /// A resume was requested for a task that is neither paused nor completed.
    #[error("Task not paused: {0}")]
    TaskNotPaused(TaskId),
}
237
/// Messages processed by the [`TaskManager`] run loop.
///
/// Requests carry a `response_tx` on which the manager sends its answer;
/// `TaskUpdate` and `PartialUpdate` are sent by task executors and have none.
pub enum TaskMessage {
    /// Start a new task; an id is generated when `id` is `None`.
    Start {
        id: Option<TaskId>,
        command: String,
        options: StartTaskOptions,
        response_tx: oneshot::Sender<Result<TaskId, TaskError>>,
    },
    /// Cancel the task and remove it from the manager.
    Cancel {
        id: TaskId,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
    /// Query the status of one task.
    GetStatus {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskStatus>>,
    },
    /// Query a full snapshot of one task.
    GetTaskDetails {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskInfo>>,
    },
    /// Query snapshots of all tasks.
    GetAllTasks {
        response_tx: oneshot::Sender<Vec<TaskInfo>>,
    },
    /// Cancel every task and stop the run loop.
    Shutdown {
        response_tx: oneshot::Sender<()>,
    },
    /// A run has finished with the given result.
    TaskUpdate {
        id: TaskId,
        completion: TaskCompletion,
    },
    /// A run has produced more output to append to the task.
    PartialUpdate {
        id: TaskId,
        output: String,
    },
    /// Run `command` under an existing paused or completed task.
    Resume {
        id: TaskId,
        command: String,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
}
277
/// Owner of all task state; processes [`TaskMessage`]s one at a time in
/// [`TaskManager::run`].
pub struct TaskManager {
    /// All known tasks by id.
    tasks: HashMap<TaskId, TaskEntry>,
    /// Sender side of the message channel, cloned into handles and executors.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Receiver side of the message channel, drained by the run loop.
    rx: mpsc::UnboundedReceiver<TaskMessage>,
    /// Sender of the shutdown broadcast, cloned into handles.
    shutdown_tx: broadcast::Sender<()>,
    /// Receiver of the shutdown broadcast, watched by the run loop.
    shutdown_rx: broadcast::Receiver<()>,
}
285
impl Default for TaskManager {
    /// Equivalent to [`TaskManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
291
292impl TaskManager {
293 pub fn new() -> Self {
294 let (tx, rx) = mpsc::unbounded_channel();
295 let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
296
297 Self {
298 tasks: HashMap::new(),
299 tx,
300 rx,
301 shutdown_tx,
302 shutdown_rx,
303 }
304 }
305
306 pub fn handle(&self) -> Arc<TaskManagerHandle> {
307 Arc::new(TaskManagerHandle {
308 tx: self.tx.clone(),
309 shutdown_tx: self.shutdown_tx.clone(),
310 })
311 }
312
313 pub async fn run(mut self) {
314 loop {
315 tokio::select! {
316 msg = self.rx.recv() => {
317 match msg {
318 Some(msg) => {
319 if self.handle_message(msg).await {
320 break;
321 }
322 }
323 None => {
324 self.shutdown_all_tasks().await;
327 break;
328 }
329 }
330 }
331 _ = self.shutdown_rx.recv() => {
332 self.shutdown_all_tasks().await;
333 break;
334 }
335 }
336 }
337 }
338
339 async fn handle_message(&mut self, msg: TaskMessage) -> bool {
340 match msg {
341 TaskMessage::Start {
342 id,
343 command,
344 options,
345 response_tx,
346 } => {
347 let task_id = id.unwrap_or_else(|| generate_simple_id(6));
348 let result = self.start_task(task_id.clone(), command, options).await;
349 let _ = response_tx.send(result.map(|_| task_id.clone()));
350 false
351 }
352 TaskMessage::Cancel { id, response_tx } => {
353 let result = self.cancel_task(&id).await;
354 let _ = response_tx.send(result);
355 false
356 }
357 TaskMessage::GetStatus { id, response_tx } => {
358 let status = self.tasks.get(&id).map(|entry| entry.task.status.clone());
359 let _ = response_tx.send(status);
360 false
361 }
362 TaskMessage::GetTaskDetails { id, response_tx } => {
363 let task_info = self.tasks.get(&id).map(|entry| TaskInfo::from(&entry.task));
364 let _ = response_tx.send(task_info);
365 false
366 }
367 TaskMessage::GetAllTasks { response_tx } => {
368 let mut tasks: Vec<TaskInfo> = self
369 .tasks
370 .values()
371 .map(|entry| TaskInfo::from(&entry.task))
372 .collect();
373 tasks.sort_by(|a, b| b.start_time.cmp(&a.start_time));
374 let _ = response_tx.send(tasks);
375 false
376 }
377 TaskMessage::TaskUpdate { id, completion } => {
378 if let Some(entry) = self.tasks.get_mut(&id) {
379 entry.task.status = completion.final_status.clone();
380 entry.task.output = Some(completion.output.clone());
381 entry.task.error = completion.error;
382 entry.task.duration = Some(
383 Utc::now()
384 .signed_duration_since(entry.task.start_time)
385 .to_std()
386 .unwrap_or_default(),
387 );
388
389 if matches!(
391 completion.final_status,
392 TaskStatus::Paused | TaskStatus::Completed
393 ) {
394 let checkpoint_id =
395 serde_json::from_str::<serde_json::Value>(&completion.output)
396 .ok()
397 .and_then(|v| {
398 v.get("checkpoint_id")
399 .and_then(|c| c.as_str())
400 .map(|s| s.to_string())
401 });
402 entry.task.pause_info = Some(PauseInfo {
403 checkpoint_id,
404 raw_output: Some(completion.output),
405 });
406 }
407
408 }
414 false
415 }
416 TaskMessage::PartialUpdate { id, output } => {
417 if let Some(entry) = self.tasks.get_mut(&id) {
418 match &entry.task.output {
419 Some(existing) => {
420 entry.task.output = Some(format!("{}{}", existing, output));
421 }
422 None => {
423 entry.task.output = Some(output);
424 }
425 }
426 }
427 false
428 }
429 TaskMessage::Resume {
430 id,
431 command,
432 response_tx,
433 } => {
434 let result = self.resume_task(id, command).await;
435 let _ = response_tx.send(result);
436 false
437 }
438 TaskMessage::Shutdown { response_tx } => {
439 self.shutdown_all_tasks().await;
440 let _ = response_tx.send(());
441 true
442 }
443 }
444 }
445
446 async fn start_task(
447 &mut self,
448 id: TaskId,
449 command: String,
450 options: StartTaskOptions,
451 ) -> Result<(), TaskError> {
452 if self.tasks.contains_key(&id) {
453 return Err(TaskError::TaskAlreadyRunning(id));
454 }
455
456 let StartTaskOptions {
457 description,
458 timeout,
459 remote_connection,
460 child_env,
461 } = options;
462
463 let task = Task {
464 id: id.clone(),
465 status: TaskStatus::Running,
466 command: command.clone(),
467 description,
468 remote_connection: remote_connection.clone(),
469 output: None,
470 error: None,
471 start_time: Utc::now(),
472 duration: None,
473 timeout,
474 pause_info: None,
475 child_env: child_env.clone(),
476 };
477
478 let (cancel_tx, cancel_rx) = oneshot::channel();
479 let (process_tx, process_rx) = oneshot::channel();
480 let task_tx: mpsc::UnboundedSender<TaskMessage> = self.tx.clone();
481
482 let is_remote_task = remote_connection.is_some();
483
484 let execution = TaskExecution {
486 id: id.clone(),
487 command,
488 remote_connection,
489 task_timeout: timeout,
490 child_env,
491 };
492
493 let handle = tokio::spawn(Self::execute_task(
494 execution, cancel_rx, process_tx, task_tx,
495 ));
496
497 let entry = TaskEntry {
498 task,
499 handle,
500 process_id: None,
501 cancel_tx: Some(cancel_tx),
502 };
503
504 self.tasks.insert(id.clone(), entry);
505
506 if !is_remote_task {
508 if let Ok(process_id) = process_rx.await
510 && let Some(entry) = self.tasks.get_mut(&id)
511 {
512 entry.process_id = Some(process_id);
513 }
514 }
515 Ok(())
518 }
519
520 async fn resume_task(&mut self, id: TaskId, command: String) -> Result<(), TaskError> {
521 if let Some(entry) = self.tasks.get(&id) {
523 if !matches!(
524 entry.task.status,
525 TaskStatus::Paused | TaskStatus::Completed
526 ) {
527 return Err(TaskError::TaskNotPaused(id));
528 }
529 } else {
530 return Err(TaskError::TaskNotFound(id));
531 }
532
533 let entry = self.tasks.get_mut(&id).unwrap();
535 entry.task.status = TaskStatus::Running;
536 entry.task.command = command.clone();
537 entry.task.pause_info = None;
538 entry.task.output = None;
539 entry.task.error = None;
540
541 let (cancel_tx, cancel_rx) = oneshot::channel();
542 let (process_tx, process_rx) = oneshot::channel();
543 let task_tx = self.tx.clone();
544
545 let remote_connection = entry.task.remote_connection.clone();
546 let timeout = entry.task.timeout;
547 let child_env = entry.task.child_env.clone();
548
549 let execution = TaskExecution {
550 id: id.clone(),
551 command,
552 remote_connection: remote_connection.clone(),
553 task_timeout: timeout,
554 child_env,
555 };
556
557 let handle = tokio::spawn(Self::execute_task(
558 execution, cancel_rx, process_tx, task_tx,
559 ));
560
561 entry.handle = handle;
562 entry.cancel_tx = Some(cancel_tx);
563 entry.process_id = None;
564
565 if remote_connection.is_none()
567 && let Ok(process_id) = process_rx.await
568 && let Some(entry) = self.tasks.get_mut(&id)
569 {
570 entry.process_id = Some(process_id);
571 }
572
573 Ok(())
574 }
575
576 async fn cancel_task(&mut self, id: &TaskId) -> Result<(), TaskError> {
577 if let Some(mut entry) = self.tasks.remove(id) {
578 entry.task.status = TaskStatus::Cancelled;
579
580 if let Some(cancel_tx) = entry.cancel_tx.take() {
581 let _ = cancel_tx.send(());
582 }
583
584 if let Some(process_id) = entry.process_id {
585 terminate_process_group(process_id);
586 }
587
588 entry.handle.abort();
589 Ok(())
590 } else {
591 Err(TaskError::TaskNotFound(id.clone()))
592 }
593 }
594
595 async fn execute_task(
596 execution: TaskExecution,
597 mut cancel_rx: oneshot::Receiver<()>,
598 process_tx: oneshot::Sender<u32>,
599 task_tx: mpsc::UnboundedSender<TaskMessage>,
600 ) {
601 let TaskExecution {
602 id,
603 command,
604 remote_connection,
605 task_timeout,
606 child_env,
607 } = execution;
608 let completion = if let Some(remote_info) = remote_connection {
609 Self::execute_remote_task(
611 id.clone(),
612 command,
613 remote_info,
614 task_timeout,
615 &mut cancel_rx,
616 &task_tx,
617 )
618 .await
619 } else {
620 Self::execute_local_task(
622 id.clone(),
623 command,
624 task_timeout,
625 &mut cancel_rx,
626 process_tx,
627 &task_tx,
628 child_env,
629 )
630 .await
631 };
632
633 let _ = task_tx.send(TaskMessage::TaskUpdate {
635 id: id.clone(),
636 completion,
637 });
638 }
639
640 async fn execute_local_task(
641 id: TaskId,
642 command: String,
643 task_timeout: Option<Duration>,
644 cancel_rx: &mut oneshot::Receiver<()>,
645 process_tx: oneshot::Sender<u32>,
646 task_tx: &mpsc::UnboundedSender<TaskMessage>,
647 child_env: HashMap<String, String>,
648 ) -> TaskCompletion {
649 let mut cmd = Command::new("sh");
650 cmd.arg("-c")
651 .arg(&command)
652 .stdin(Stdio::null())
653 .stdout(Stdio::piped())
654 .stderr(Stdio::piped());
655 for (key, value) in child_env {
656 cmd.env(key, value);
657 }
658 #[cfg(unix)]
659 {
660 cmd.env("DEBIAN_FRONTEND", "noninteractive")
661 .env("SUDO_ASKPASS", "/bin/false")
662 .process_group(0);
663 }
664 #[cfg(windows)]
665 {
666 cmd.creation_flags(0x00000200); }
669
670 let mut child = match cmd.spawn() {
671 Ok(child) => child,
672 Err(err) => {
673 return TaskCompletion {
674 output: String::new(),
675 error: Some(format!("Failed to spawn command: {}", err)),
676 final_status: TaskStatus::Failed,
677 };
678 }
679 };
680
681 if let Some(process_id) = child.id() {
683 let _ = process_tx.send(process_id);
684 }
685
686 let stdout = child.stdout.take().unwrap();
688 let stderr = child.stderr.take().unwrap();
689
690 let stdout_reader = BufReader::new(stdout);
691 let stderr_reader = BufReader::new(stderr);
692
693 let mut stdout_lines = stdout_reader.lines();
694 let mut stderr_lines = stderr_reader.lines();
695
696 let stream_output = async {
698 let mut final_output = String::new();
699 let mut final_error: Option<String> = None;
700 let mut stdout_done = false;
701 let mut stderr_done = false;
702
703 loop {
704 tokio::select! {
705 line = stdout_lines.next_line(), if !stdout_done => {
706 match line {
707 Ok(Some(line)) => {
708 let output_line = format!("{}\n", line);
709 final_output.push_str(&output_line);
710 let _ = task_tx.send(TaskMessage::PartialUpdate {
711 id: id.clone(),
712 output: output_line,
713 });
714 }
715 Ok(None) => {
716 stdout_done = true;
718 }
719 Err(err) => {
720 final_error = Some(format!("Error reading stdout: {}", err));
721 break;
722 }
723 }
724 }
725 line = stderr_lines.next_line(), if !stderr_done => {
726 match line {
727 Ok(Some(line)) => {
728 let output_line = format!("{}\n", line);
729 final_output.push_str(&output_line);
730 let _ = task_tx.send(TaskMessage::PartialUpdate {
731 id: id.clone(),
732 output: output_line,
733 });
734 }
735 Ok(None) => {
736 stderr_done = true;
738 }
739 Err(err) => {
740 final_error = Some(format!("Error reading stderr: {}", err));
741 break;
742 }
743 }
744 }
745 status = child.wait() => {
746 match status {
747 Ok(exit_status) => {
748 while !stdout_done {
749 match stdout_lines.next_line().await {
750 Ok(Some(line)) => {
751 let output_line = format!("{}\n", line);
752 final_output.push_str(&output_line);
753 let _ = task_tx.send(TaskMessage::PartialUpdate {
754 id: id.clone(),
755 output: output_line,
756 });
757 }
758 Ok(None) => stdout_done = true,
759 Err(err) => {
760 final_error = Some(format!("Error reading stdout: {}", err));
761 break;
762 }
763 }
764 }
765
766 while !stderr_done {
767 match stderr_lines.next_line().await {
768 Ok(Some(line)) => {
769 let output_line = format!("{}\n", line);
770 final_output.push_str(&output_line);
771 let _ = task_tx.send(TaskMessage::PartialUpdate {
772 id: id.clone(),
773 output: output_line,
774 });
775 }
776 Ok(None) => stderr_done = true,
777 Err(err) => {
778 final_error = Some(format!("Error reading stderr: {}", err));
779 break;
780 }
781 }
782 }
783
784 return task_completion_from_exit_status(
785 final_output,
786 final_error,
787 exit_status,
788 );
789 }
790 Err(err) => {
791 return TaskCompletion {
792 output: final_output,
793 error: Some(err.to_string()),
794 final_status: TaskStatus::Failed,
795 };
796 }
797 }
798 }
799 _ = &mut *cancel_rx => {
800 return TaskCompletion {
801 output: final_output,
802 error: Some("Tool call was cancelled and don't try to run it again".to_string()),
803 final_status: TaskStatus::Cancelled,
804 };
805 }
806 }
807 }
808
809 TaskCompletion {
810 output: final_output,
811 error: final_error,
812 final_status: TaskStatus::Failed,
813 }
814 };
815
816 if let Some(timeout_duration) = task_timeout {
818 match timeout(timeout_duration, stream_output).await {
819 Ok(result) => result,
820 Err(_) => TaskCompletion {
821 output: String::new(),
822 error: Some("Task timed out".to_string()),
823 final_status: TaskStatus::TimedOut,
824 },
825 }
826 } else {
827 stream_output.await
828 }
829 }
830
831 async fn execute_remote_task(
832 id: TaskId,
833 command: String,
834 remote_info: RemoteConnectionInfo,
835 task_timeout: Option<Duration>,
836 cancel_rx: &mut oneshot::Receiver<()>,
837 task_tx: &mpsc::UnboundedSender<TaskMessage>,
838 ) -> TaskCompletion {
839 let connection_manager = RemoteConnectionManager::new();
841 let connection = match connection_manager.get_connection(&remote_info).await {
842 Ok(conn) => conn,
843 Err(e) => {
844 return TaskCompletion {
845 output: String::new(),
846 error: Some(format!("Failed to establish remote connection: {}", e)),
847 final_status: TaskStatus::Failed,
848 };
849 }
850 };
851
852 let task_tx_clone = task_tx.clone();
854 let id_clone = id.clone();
855 let progress_callback = move |output: String| {
856 if !output.trim().is_empty() {
857 let _ = task_tx_clone.send(TaskMessage::PartialUpdate {
858 id: id_clone.clone(),
859 output,
860 });
861 }
862 };
863
864 let options = crate::remote_connection::CommandOptions {
866 timeout: task_timeout,
867 with_progress: false,
868 simple: false,
869 };
870
871 match connection
872 .execute_command_unified(&command, options, cancel_rx, Some(progress_callback), None)
873 .await
874 {
875 Ok((output, exit_code)) => TaskCompletion {
876 output,
877 error: if exit_code != 0 {
878 Some(format!("Command exited with code {}", exit_code))
879 } else {
880 None
881 },
882 final_status: TaskStatus::Completed,
883 },
884 Err(e) => {
885 let error_msg = e.to_string();
886 let status = if error_msg.contains("timed out") {
887 TaskStatus::TimedOut
888 } else if error_msg.contains("cancelled") {
889 TaskStatus::Cancelled
890 } else {
891 TaskStatus::Failed
892 };
893
894 TaskCompletion {
895 output: String::new(),
896 error: Some(if error_msg.contains("cancelled") {
897 "Tool call was cancelled and don't try to run it again".to_string()
898 } else {
899 format!("Remote command failed: {}", error_msg)
900 }),
901 final_status: status,
902 }
903 }
904 }
905 }
906
907 async fn shutdown_all_tasks(&mut self) {
908 for (_id, mut entry) in self.tasks.drain() {
909 if let Some(cancel_tx) = entry.cancel_tx.take() {
910 let _ = cancel_tx.send(());
911 }
912
913 if let Some(process_id) = entry.process_id {
914 terminate_process_group(process_id);
915 }
916
917 entry.handle.abort();
918 }
919 }
920}
921
/// Client-side handle for talking to a running [`TaskManager`].
pub struct TaskManagerHandle {
    /// Channel on which requests are sent to the manager.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Broadcast used to tell the manager to shut down when the handle is dropped.
    shutdown_tx: broadcast::Sender<()>,
}
926
/// Dropping a handle broadcasts a shutdown signal to the manager.
///
/// [`TaskManager::run`] watches the matching receiver and, once it resolves,
/// cancels every task and exits. A failed send is ignored.
impl Drop for TaskManagerHandle {
    fn drop(&mut self) {
        let _ = self.shutdown_tx.send(());
    }
}
940
941impl TaskManagerHandle {
942 pub async fn start_task(
943 &self,
944 command: String,
945 options: StartTaskOptions,
946 ) -> Result<TaskInfo, TaskError> {
947 let (response_tx, response_rx) = oneshot::channel();
948
949 self.tx
950 .send(TaskMessage::Start {
951 id: None,
952 command: command.clone(),
953 options,
954 response_tx,
955 })
956 .map_err(|_| TaskError::ManagerShutdown)?;
957
958 let task_id = response_rx
959 .await
960 .map_err(|_| TaskError::ManagerShutdown)??;
961
962 tokio::time::sleep(START_TASK_WAIT_TIME).await;
964
965 let task_info = self
966 .get_task_details(task_id.clone())
967 .await
968 .map_err(|_| TaskError::ManagerShutdown)?
969 .ok_or_else(|| TaskError::TaskNotFound(task_id.clone()))?;
970
971 if matches!(task_info.status, TaskStatus::Failed | TaskStatus::Cancelled) {
973 return Err(TaskError::TaskFailedOnStart(
974 task_info
975 .output
976 .unwrap_or_else(|| "Unknown reason".to_string()),
977 ));
978 }
979
980 Ok(task_info)
982 }
983
984 pub async fn cancel_task(&self, id: TaskId) -> Result<TaskInfo, TaskError> {
985 let task_info = self
987 .get_all_tasks()
988 .await?
989 .into_iter()
990 .find(|task| task.id == id)
991 .ok_or_else(|| TaskError::TaskNotFound(id.clone()))?;
992
993 let (response_tx, response_rx) = oneshot::channel();
994
995 self.tx
996 .send(TaskMessage::Cancel { id, response_tx })
997 .map_err(|_| TaskError::ManagerShutdown)?;
998
999 response_rx
1000 .await
1001 .map_err(|_| TaskError::ManagerShutdown)??;
1002
1003 Ok(TaskInfo {
1005 status: TaskStatus::Cancelled,
1006 duration: Some(
1007 Utc::now()
1008 .signed_duration_since(task_info.start_time)
1009 .to_std()
1010 .unwrap_or_default(),
1011 ),
1012 ..task_info
1013 })
1014 }
1015
1016 pub async fn resume_task(&self, id: TaskId, command: String) -> Result<TaskInfo, TaskError> {
1017 let (response_tx, response_rx) = oneshot::channel();
1018
1019 self.tx
1020 .send(TaskMessage::Resume {
1021 id: id.clone(),
1022 command,
1023 response_tx,
1024 })
1025 .map_err(|_| TaskError::ManagerShutdown)?;
1026
1027 response_rx
1028 .await
1029 .map_err(|_| TaskError::ManagerShutdown)??;
1030
1031 tokio::time::sleep(START_TASK_WAIT_TIME).await;
1033
1034 let task_info = self
1035 .get_task_details(id.clone())
1036 .await
1037 .map_err(|_| TaskError::ManagerShutdown)?
1038 .ok_or(TaskError::TaskNotFound(id))?;
1039
1040 Ok(task_info)
1041 }
1042
1043 pub async fn get_task_status(&self, id: TaskId) -> Result<Option<TaskStatus>, TaskError> {
1044 let (response_tx, response_rx) = oneshot::channel();
1045
1046 self.tx
1047 .send(TaskMessage::GetStatus { id, response_tx })
1048 .map_err(|_| TaskError::ManagerShutdown)?;
1049
1050 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1051 }
1052
1053 pub async fn get_task_details(&self, id: TaskId) -> Result<Option<TaskInfo>, TaskError> {
1054 let (response_tx, response_rx) = oneshot::channel();
1055
1056 self.tx
1057 .send(TaskMessage::GetTaskDetails { id, response_tx })
1058 .map_err(|_| TaskError::ManagerShutdown)?;
1059
1060 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1061 }
1062
1063 pub async fn get_all_tasks(&self) -> Result<Vec<TaskInfo>, TaskError> {
1064 let (response_tx, response_rx) = oneshot::channel();
1065
1066 self.tx
1067 .send(TaskMessage::GetAllTasks { response_tx })
1068 .map_err(|_| TaskError::ManagerShutdown)?;
1069
1070 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1071 }
1072
1073 pub async fn shutdown(&self) -> Result<(), TaskError> {
1074 let (response_tx, response_rx) = oneshot::channel();
1075
1076 self.tx
1077 .send(TaskMessage::Shutdown { response_tx })
1078 .map_err(|_| TaskError::ManagerShutdown)?;
1079
1080 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
1081 }
1082}
1083
1084#[cfg(test)]
1085mod tests {
1086 use super::*;
1087 use tokio::time::{Duration, sleep};
1088
    /// An explicit `shutdown()` makes the manager's run loop finish, even
    /// while a task is still running.
    #[tokio::test]
    async fn test_task_manager_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task("sleep 5".to_string(), StartTaskOptions::default())
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // Give the spawned run loop a moment to return.
        sleep(Duration::from_millis(100)).await;

        assert!(manager_handle.is_finished());
    }
1124
    /// Shutting down with a long-running task in flight still lets the run
    /// loop finish.
    ///
    /// NOTE(review): the assertions only cover the run loop; nothing checks
    /// that the `sleep` process itself was terminated.
    #[tokio::test]
    async fn test_task_manager_cancels_tasks_on_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task("sleep 10".to_string(), StartTaskOptions::default())
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // Give the spawned run loop a moment to return.
        sleep(Duration::from_millis(100)).await;

        assert!(manager_handle.is_finished());
    }
1160
    /// A short command ends up as `Completed`, both when queried by id and
    /// when listing all tasks.
    #[tokio::test]
    async fn test_task_manager_start_and_complete_task() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task(
                "echo 'Hello, World!'".to_string(),
                StartTaskOptions::default(),
            )
            .await
            .expect("Failed to start task");

        // Leave time for the command to finish and its update to be processed.
        sleep(Duration::from_millis(500)).await;

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Completed));

        let tasks = handle
            .get_all_tasks()
            .await
            .expect("Failed to get all tasks");
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status, TaskStatus::Completed);

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
1204
    /// Variables passed through `StartTaskOptions::child_env` are visible to
    /// the local child process.
    #[tokio::test]
    async fn task_manager_local_task_receives_child_env_defaults() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let mut child_env = HashMap::new();
        child_env.insert("STAKPAK_PROFILE".to_string(), "ops".to_string());

        let task_info = handle
            .start_task(
                "printf '%s\\n' \"$STAKPAK_PROFILE\"".to_string(),
                StartTaskOptions {
                    child_env,
                    ..StartTaskOptions::default()
                },
            )
            .await
            .expect("task should start with child env defaults");

        // Leave time for the command to finish and its update to be processed.
        sleep(Duration::from_millis(500)).await;

        let details = handle
            .get_task_details(task_info.id.clone())
            .await
            .expect("task details request should succeed")
            .expect("task details should exist");

        assert_eq!(details.status, TaskStatus::Completed);
        assert_eq!(details.output.as_deref(), Some("ops\n"));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
1244
    /// A task paused via exit code 10 can be resumed, and the resumed command
    /// still sees the environment given when the task was first started.
    #[tokio::test]
    async fn resumed_local_task_reuses_child_env_defaults() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let mut child_env = HashMap::new();
        child_env.insert("STAKPAK_PROFILE".to_string(), "ops".to_string());

        // Exit code 10 is what marks a local run as paused.
        let task_info = handle
            .start_task(
                "exit 10".to_string(),
                StartTaskOptions {
                    child_env,
                    ..StartTaskOptions::default()
                },
            )
            .await
            .expect("pausing task should start");

        sleep(Duration::from_millis(500)).await;

        let paused = handle
            .get_task_details(task_info.id.clone())
            .await
            .expect("task details request should succeed")
            .expect("task details should exist");
        assert_eq!(paused.status, TaskStatus::Paused);

        handle
            .resume_task(
                task_info.id.clone(),
                "printf '%s\\n' \"$STAKPAK_PROFILE\"".to_string(),
            )
            .await
            .expect("paused task should resume");

        sleep(Duration::from_millis(500)).await;

        let details = handle
            .get_task_details(task_info.id)
            .await
            .expect("task details request should succeed")
            .expect("task details should exist");

        assert_eq!(details.status, TaskStatus::Completed);
        assert_eq!(details.output.as_deref(), Some("ops\n"));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
1301
    /// Starting a command that does not exist is reported as
    /// `TaskFailedOnStart` rather than as a running task.
    #[tokio::test]
    async fn test_task_manager_detects_immediate_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let result = handle
            .start_task(
                "nonexistent_command_12345".to_string(),
                StartTaskOptions::default(),
            )
            .await;

        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
1329
    /// Dropping the handle, without calling `shutdown()`, is enough to make
    /// the manager's run loop finish.
    #[tokio::test]
    async fn test_task_manager_handle_drop_triggers_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let _task_info = handle
            .start_task("sleep 30".to_string(), StartTaskOptions::default())
            .await
            .expect("Failed to start task");

        // The handle's `Drop` impl broadcasts the shutdown signal.
        drop(handle);

        // Give the run loop time to receive the signal and return.
        sleep(Duration::from_millis(500)).await;

        assert!(
            manager_handle.is_finished(),
            "TaskManager::run() should have exited after handle was dropped"
        );
    }
1358
    /// Starts a task that creates a marker file and then sleeps, confirms it
    /// is running, drops the handle and removes the marker.
    ///
    /// NOTE(review): despite the name, nothing here asserts that the child
    /// process was actually killed; the only assertion is the `Running`
    /// status before the drop. The marker path is also hard-coded to `/tmp`.
    #[tokio::test]
    async fn test_task_manager_handle_drop_kills_child_processes() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // The test process id keeps the marker name unique per test run.
        let marker = format!("/tmp/stakpak_test_drop_{}", std::process::id());
        let task_info = handle
            .start_task(
                format!("touch {} && sleep 30", marker),
                StartTaskOptions::default(),
            )
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get status");
        assert_eq!(status, Some(TaskStatus::Running));

        drop(handle);
        sleep(Duration::from_millis(500)).await;

        // Best-effort cleanup of the marker file.
        let _ = std::fs::remove_file(&marker);
    }
1392
    /// A command that exits straight away with a non-zero code is reported as
    /// `TaskFailedOnStart`.
    #[tokio::test]
    async fn test_task_manager_detects_immediate_exit_code_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let result = handle
            .start_task("exit 1".to_string(), StartTaskOptions::default())
            .await;

        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
1417}