1use crate::helper::generate_simple_id;
2use crate::remote_connection::{RemoteConnectionInfo, RemoteConnectionManager};
3use chrono::{DateTime, Utc};
4use std::{collections::HashMap, process::Stdio, sync::Arc, time::Duration};
5use tokio::{
6 io::{AsyncBufReadExt, BufReader},
7 process::Command,
8 sync::{broadcast, mpsc, oneshot},
9 time::timeout,
10};
11
12const START_TASK_WAIT_TIME: Duration = Duration::from_millis(300);
13
14fn terminate_process_group(process_id: u32) {
22 #[cfg(unix)]
23 {
24 use std::process::Command;
25 let check_result = Command::new("kill")
27 .arg("-0") .arg(process_id.to_string())
29 .output();
30
31 if check_result
33 .map(|output| output.status.success())
34 .unwrap_or(false)
35 {
36 let _ = Command::new("kill")
40 .arg("-9")
41 .arg(format!("-{}", process_id))
42 .output();
43
44 let _ = Command::new("kill")
46 .arg("-9")
47 .arg(process_id.to_string())
48 .output();
49 }
50 }
51
52 #[cfg(windows)]
53 {
54 use std::process::Command;
55 let check_result = Command::new("tasklist")
57 .arg("/FI")
58 .arg(format!("PID eq {}", process_id))
59 .arg("/FO")
60 .arg("CSV")
61 .output();
62
63 if let Ok(output) = check_result {
65 let output_str = String::from_utf8_lossy(&output.stdout);
66 if output_str.lines().count() > 1 {
67 let _ = Command::new("taskkill")
69 .arg("/F")
70 .arg("/T") .arg("/PID")
72 .arg(process_id.to_string())
73 .output();
74 }
75 }
76 }
77}
78
/// Identifier of a task tracked by [`TaskManager`].
///
/// Either supplied in [`TaskMessage::Start`] or, when that is `None`, generated by the
/// manager with `generate_simple_id(6)`.
pub type TaskId = String;
80
/// Lifecycle state of a task tracked by [`TaskManager`].
///
/// The manager registers new tasks directly as `Running`; the terminal states are taken
/// from the [`TaskCompletion`] a task reports when its execution ends.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub enum TaskStatus {
    /// Not started yet. NOTE(review): nothing in this file constructs this variant.
    Pending,
    /// The command is executing.
    Running,
    /// The command finished (for local commands: with a successful exit status).
    Completed,
    /// Spawning, connecting, reading output, or the command itself failed.
    Failed,
    /// The task was cancelled before it finished.
    Cancelled,
    /// The task exceeded its configured timeout.
    TimedOut,
}
90
/// Full internal record of a task as stored by [`TaskManager`].
///
/// See [`TaskInfo`] for the reduced, serializable view handed to callers.
#[derive(Debug, Clone)]
pub struct Task {
    /// Id of the task, unique within the manager.
    pub id: TaskId,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Command line being executed.
    pub command: String,
    /// Remote target to run on; `None` means the command runs locally.
    pub remote_connection: Option<RemoteConnectionInfo>,
    /// Output collected so far: extended by partial updates, replaced by the final completion.
    pub output: Option<String>,
    /// Error message reported when the task finished, if any.
    pub error: Option<String>,
    /// When the manager registered the task.
    pub start_time: DateTime<Utc>,
    /// Total run time; recorded once the task reports its completion.
    pub duration: Option<Duration>,
    /// Time limit requested for the task, if any.
    pub timeout: Option<Duration>,
}
103
/// Manager-side bookkeeping for one task: its state plus what is needed to stop it.
pub struct TaskEntry {
    /// Current state of the task.
    pub task: Task,
    /// Tokio task running the command; aborted on cancel and on shutdown.
    pub handle: tokio::task::JoinHandle<()>,
    /// PID of the local process once it has been reported; stays `None` for remote tasks
    /// and for local commands that failed to spawn.
    pub process_id: Option<u32>,
    /// Signals cancellation to the running execution; taken (left `None`) when used.
    pub cancel_tx: Option<oneshot::Sender<()>>,
}
110
/// Serializable snapshot of a task, as returned by the manager's query messages.
///
/// Unlike [`Task`], it omits the remote connection, the error message and the timeout.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TaskInfo {
    /// Id of the task.
    pub id: TaskId,
    /// Status at the time of the snapshot.
    pub status: TaskStatus,
    /// Command line being executed.
    pub command: String,
    /// Output collected so far, if any.
    pub output: Option<String>,
    /// When the manager registered the task.
    pub start_time: DateTime<Utc>,
    /// Elapsed time for running tasks; the recorded run time (if any) otherwise.
    pub duration: Option<Duration>,
}
120
121impl From<&Task> for TaskInfo {
122 fn from(task: &Task) -> Self {
123 let duration = if matches!(task.status, TaskStatus::Running) {
124 Some(
126 Utc::now()
127 .signed_duration_since(task.start_time)
128 .to_std()
129 .unwrap_or_default(),
130 )
131 } else {
132 task.duration
134 };
135
136 TaskInfo {
137 id: task.id.clone(),
138 status: task.status.clone(),
139 command: task.command.clone(),
140 output: task.output.clone(),
141 start_time: task.start_time,
142 duration,
143 }
144 }
145}
146
/// Outcome a task reports to the manager (via [`TaskMessage::TaskUpdate`]) when it ends.
pub struct TaskCompletion {
    /// Output to store for the task; replaces whatever partial output was accumulated.
    pub output: String,
    /// Error description, if any.
    pub error: Option<String>,
    /// Terminal status to record for the task.
    pub final_status: TaskStatus,
}
152
/// Errors produced by [`TaskManager`] and [`TaskManagerHandle`] operations.
#[derive(Debug, thiserror::Error)]
pub enum TaskError {
    /// No task with the given id is tracked by the manager.
    #[error("Task not found: {0}")]
    TaskNotFound(TaskId),
    /// A task with the given id is already tracked (whatever its current status).
    #[error("Task already running: {0}")]
    TaskAlreadyRunning(TaskId),
    /// The manager is not running: its request channel or the reply channel is closed.
    #[error("Manager shutdown")]
    ManagerShutdown,
    /// NOTE(review): not constructed anywhere in this file.
    #[error("Command execution failed: {0}")]
    ExecutionFailed(String),
    /// NOTE(review): not constructed anywhere in this file; timeouts surface as
    /// `TaskStatus::TimedOut` instead.
    #[error("Task timeout")]
    TaskTimeout,
    /// NOTE(review): not constructed anywhere in this file; cancellations surface as
    /// `TaskStatus::Cancelled` instead.
    #[error("Task cancelled")]
    TaskCancelled,
    /// The task was already `Failed` or `Cancelled` shortly after being started; carries
    /// the output captured so far (or a fallback reason).
    #[error("Task failed on start: {0}")]
    TaskFailedOnStart(String),
}
170
/// Requests and notifications processed by [`TaskManager::run`].
///
/// Request variants carry a `oneshot` sender for the reply. `TaskUpdate` and
/// `PartialUpdate` are sent by running tasks themselves.
pub enum TaskMessage {
    /// Start a new task. With `id: None` the manager generates an id. Replies with the id
    /// of the started task or the reason it could not be started.
    Start {
        id: Option<TaskId>,
        command: String,
        /// Run on this remote target instead of locally.
        remote_connection: Option<RemoteConnectionInfo>,
        timeout: Option<Duration>,
        response_tx: oneshot::Sender<Result<TaskId, TaskError>>,
    },
    /// Stop a task and remove it from the manager. Replies `TaskNotFound` for unknown ids.
    Cancel {
        id: TaskId,
        response_tx: oneshot::Sender<Result<(), TaskError>>,
    },
    /// Look up a task's status; replies `None` for unknown ids.
    GetStatus {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskStatus>>,
    },
    /// Look up a snapshot of a task; replies `None` for unknown ids.
    GetTaskDetails {
        id: TaskId,
        response_tx: oneshot::Sender<Option<TaskInfo>>,
    },
    /// Snapshot every tracked task, sorted newest `start_time` first.
    GetAllTasks {
        response_tx: oneshot::Sender<Vec<TaskInfo>>,
    },
    /// Stop all tasks and end the run loop; replies once the tasks have been stopped.
    Shutdown {
        response_tx: oneshot::Sender<()>,
    },
    /// Final result of a task, sent by the task when its execution ends.
    TaskUpdate {
        id: TaskId,
        completion: TaskCompletion,
    },
    /// A chunk of output to append to the task's accumulated output.
    PartialUpdate {
        id: TaskId,
        output: String,
    },
}
206
/// Actor that owns all tasks and processes [`TaskMessage`]s sent through its handles.
///
/// Create it with [`TaskManager::new`], obtain handles with [`TaskManager::handle`], then
/// drive it by awaiting [`TaskManager::run`] (the tests in this file spawn it on Tokio).
pub struct TaskManager {
    /// Every task the manager still tracks, keyed by id.
    tasks: HashMap<TaskId, TaskEntry>,
    /// Request sender; cloned into handles and into running tasks for their updates.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Request receiver drained by `run`.
    rx: mpsc::UnboundedReceiver<TaskMessage>,
    /// Shutdown broadcast sender; cloned into every handle.
    shutdown_tx: broadcast::Sender<()>,
    /// Shutdown receiver that `run` listens on.
    shutdown_rx: broadcast::Receiver<()>,
}
214
215impl Default for TaskManager {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
221impl TaskManager {
222 pub fn new() -> Self {
223 let (tx, rx) = mpsc::unbounded_channel();
224 let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
225
226 Self {
227 tasks: HashMap::new(),
228 tx,
229 rx,
230 shutdown_tx,
231 shutdown_rx,
232 }
233 }
234
235 pub fn handle(&self) -> Arc<TaskManagerHandle> {
236 Arc::new(TaskManagerHandle {
237 tx: self.tx.clone(),
238 shutdown_tx: self.shutdown_tx.clone(),
239 })
240 }
241
242 pub async fn run(mut self) {
243 loop {
244 tokio::select! {
245 msg = self.rx.recv() => {
246 match msg {
247 Some(msg) => {
248 if self.handle_message(msg).await {
249 break;
250 }
251 }
252 None => {
253 self.shutdown_all_tasks().await;
256 break;
257 }
258 }
259 }
260 _ = self.shutdown_rx.recv() => {
261 self.shutdown_all_tasks().await;
262 break;
263 }
264 }
265 }
266 }
267
268 async fn handle_message(&mut self, msg: TaskMessage) -> bool {
269 match msg {
270 TaskMessage::Start {
271 id,
272 command,
273 remote_connection,
274 timeout,
275 response_tx,
276 } => {
277 let task_id = id.unwrap_or_else(|| generate_simple_id(6));
278 let result = self
279 .start_task(task_id.clone(), command, timeout, remote_connection)
280 .await;
281 let _ = response_tx.send(result.map(|_| task_id.clone()));
282 false
283 }
284 TaskMessage::Cancel { id, response_tx } => {
285 let result = self.cancel_task(&id).await;
286 let _ = response_tx.send(result);
287 false
288 }
289 TaskMessage::GetStatus { id, response_tx } => {
290 let status = self.tasks.get(&id).map(|entry| entry.task.status.clone());
291 let _ = response_tx.send(status);
292 false
293 }
294 TaskMessage::GetTaskDetails { id, response_tx } => {
295 let task_info = self.tasks.get(&id).map(|entry| TaskInfo::from(&entry.task));
296 let _ = response_tx.send(task_info);
297 false
298 }
299 TaskMessage::GetAllTasks { response_tx } => {
300 let mut tasks: Vec<TaskInfo> = self
301 .tasks
302 .values()
303 .map(|entry| TaskInfo::from(&entry.task))
304 .collect();
305 tasks.sort_by(|a, b| b.start_time.cmp(&a.start_time));
306 let _ = response_tx.send(tasks);
307 false
308 }
309 TaskMessage::TaskUpdate { id, completion } => {
310 if let Some(entry) = self.tasks.get_mut(&id) {
311 entry.task.status = completion.final_status;
312 entry.task.output = Some(completion.output);
313 entry.task.error = completion.error;
314 entry.task.duration = Some(
315 Utc::now()
316 .signed_duration_since(entry.task.start_time)
317 .to_std()
318 .unwrap_or_default(),
319 );
320
321 }
327 false
328 }
329 TaskMessage::PartialUpdate { id, output } => {
330 if let Some(entry) = self.tasks.get_mut(&id) {
331 match &entry.task.output {
332 Some(existing) => {
333 entry.task.output = Some(format!("{}{}", existing, output));
334 }
335 None => {
336 entry.task.output = Some(output);
337 }
338 }
339 }
340 false
341 }
342 TaskMessage::Shutdown { response_tx } => {
343 self.shutdown_all_tasks().await;
344 let _ = response_tx.send(());
345 true
346 }
347 }
348 }
349
350 async fn start_task(
351 &mut self,
352 id: TaskId,
353 command: String,
354 timeout: Option<Duration>,
355 remote_connection: Option<RemoteConnectionInfo>,
356 ) -> Result<(), TaskError> {
357 if self.tasks.contains_key(&id) {
358 return Err(TaskError::TaskAlreadyRunning(id));
359 }
360
361 let task = Task {
362 id: id.clone(),
363 status: TaskStatus::Running,
364 command: command.clone(),
365 remote_connection: remote_connection.clone(),
366 output: None,
367 error: None,
368 start_time: Utc::now(),
369 duration: None,
370 timeout,
371 };
372
373 let (cancel_tx, cancel_rx) = oneshot::channel();
374 let (process_tx, process_rx) = oneshot::channel();
375 let task_tx: mpsc::UnboundedSender<TaskMessage> = self.tx.clone();
376
377 let is_remote_task = remote_connection.is_some();
378
379 let handle = tokio::spawn(Self::execute_task(
381 id.clone(),
382 command,
383 remote_connection,
384 timeout,
385 cancel_rx,
386 process_tx,
387 task_tx,
388 ));
389
390 let entry = TaskEntry {
391 task,
392 handle,
393 process_id: None,
394 cancel_tx: Some(cancel_tx),
395 };
396
397 self.tasks.insert(id.clone(), entry);
398
399 if !is_remote_task {
401 if let Ok(process_id) = process_rx.await
403 && let Some(entry) = self.tasks.get_mut(&id)
404 {
405 entry.process_id = Some(process_id);
406 }
407 }
408 Ok(())
411 }
412
413 async fn cancel_task(&mut self, id: &TaskId) -> Result<(), TaskError> {
414 if let Some(mut entry) = self.tasks.remove(id) {
415 entry.task.status = TaskStatus::Cancelled;
416
417 if let Some(cancel_tx) = entry.cancel_tx.take() {
418 let _ = cancel_tx.send(());
419 }
420
421 if let Some(process_id) = entry.process_id {
422 terminate_process_group(process_id);
423 }
424
425 entry.handle.abort();
426 Ok(())
427 } else {
428 Err(TaskError::TaskNotFound(id.clone()))
429 }
430 }
431
432 async fn execute_task(
433 id: TaskId,
434 command: String,
435 remote_connection: Option<RemoteConnectionInfo>,
436 task_timeout: Option<Duration>,
437 mut cancel_rx: oneshot::Receiver<()>,
438 process_tx: oneshot::Sender<u32>,
439 task_tx: mpsc::UnboundedSender<TaskMessage>,
440 ) {
441 let completion = if let Some(remote_info) = remote_connection {
442 Self::execute_remote_task(
444 id.clone(),
445 command,
446 remote_info,
447 task_timeout,
448 &mut cancel_rx,
449 &task_tx,
450 )
451 .await
452 } else {
453 Self::execute_local_task(
455 id.clone(),
456 command,
457 task_timeout,
458 &mut cancel_rx,
459 process_tx,
460 &task_tx,
461 )
462 .await
463 };
464
465 let _ = task_tx.send(TaskMessage::TaskUpdate {
467 id: id.clone(),
468 completion,
469 });
470 }
471
472 async fn execute_local_task(
473 id: TaskId,
474 command: String,
475 task_timeout: Option<Duration>,
476 cancel_rx: &mut oneshot::Receiver<()>,
477 process_tx: oneshot::Sender<u32>,
478 task_tx: &mpsc::UnboundedSender<TaskMessage>,
479 ) -> TaskCompletion {
480 let mut cmd = Command::new("sh");
481 cmd.arg("-c")
482 .arg(&command)
483 .stdin(Stdio::null())
484 .stdout(Stdio::piped())
485 .stderr(Stdio::piped());
486 #[cfg(unix)]
487 {
488 cmd.env("DEBIAN_FRONTEND", "noninteractive")
489 .env("SUDO_ASKPASS", "/bin/false")
490 .process_group(0);
491 }
492 #[cfg(windows)]
493 {
494 cmd.creation_flags(0x00000200); }
497
498 let mut child = match cmd.spawn() {
499 Ok(child) => child,
500 Err(err) => {
501 return TaskCompletion {
502 output: String::new(),
503 error: Some(format!("Failed to spawn command: {}", err)),
504 final_status: TaskStatus::Failed,
505 };
506 }
507 };
508
509 if let Some(process_id) = child.id() {
511 let _ = process_tx.send(process_id);
512 }
513
514 let stdout = child.stdout.take().unwrap();
516 let stderr = child.stderr.take().unwrap();
517
518 let stdout_reader = BufReader::new(stdout);
519 let stderr_reader = BufReader::new(stderr);
520
521 let mut stdout_lines = stdout_reader.lines();
522 let mut stderr_lines = stderr_reader.lines();
523
524 let stream_output = async {
526 let mut final_output = String::new();
527 let mut final_error: Option<String> = None;
528
529 loop {
530 tokio::select! {
531 line = stdout_lines.next_line() => {
532 match line {
533 Ok(Some(line)) => {
534 let output_line = format!("{}\n", line);
535 final_output.push_str(&output_line);
536 let _ = task_tx.send(TaskMessage::PartialUpdate {
537 id: id.clone(),
538 output: output_line,
539 });
540 }
541 Ok(None) => {
542 }
544 Err(err) => {
545 final_error = Some(format!("Error reading stdout: {}", err));
546 break;
547 }
548 }
549 }
550 line = stderr_lines.next_line() => {
551 match line {
552 Ok(Some(line)) => {
553 let output_line = format!("{}\n", line);
554 final_output.push_str(&output_line);
555 let _ = task_tx.send(TaskMessage::PartialUpdate {
556 id: id.clone(),
557 output: output_line,
558 });
559 }
560 Ok(None) => {
561 }
563 Err(err) => {
564 final_error = Some(format!("Error reading stderr: {}", err));
565 break;
566 }
567 }
568 }
569 status = child.wait() => {
570 match status {
571 Ok(exit_status) => {
572 if final_output.is_empty() {
573 final_output = "No output".to_string();
574 }
575
576 let completion = if exit_status.success() {
577 TaskCompletion {
578 output: final_output,
579 error: final_error,
580 final_status: TaskStatus::Completed,
581 }
582 } else {
583 TaskCompletion {
584 output: final_output,
585 error: final_error.or_else(|| Some(format!("Command failed with exit code: {:?}", exit_status.code()))),
586 final_status: TaskStatus::Failed,
587 }
588 };
589 return completion;
590 }
591 Err(err) => {
592 return TaskCompletion {
593 output: final_output,
594 error: Some(err.to_string()),
595 final_status: TaskStatus::Failed,
596 };
597 }
598 }
599 }
600 _ = &mut *cancel_rx => {
601 return TaskCompletion {
602 output: final_output,
603 error: Some("Tool call was cancelled and don't try to run it again".to_string()),
604 final_status: TaskStatus::Cancelled,
605 };
606 }
607 }
608 }
609
610 TaskCompletion {
611 output: final_output,
612 error: final_error,
613 final_status: TaskStatus::Failed,
614 }
615 };
616
617 if let Some(timeout_duration) = task_timeout {
619 match timeout(timeout_duration, stream_output).await {
620 Ok(result) => result,
621 Err(_) => TaskCompletion {
622 output: String::new(),
623 error: Some("Task timed out".to_string()),
624 final_status: TaskStatus::TimedOut,
625 },
626 }
627 } else {
628 stream_output.await
629 }
630 }
631
632 async fn execute_remote_task(
633 id: TaskId,
634 command: String,
635 remote_info: RemoteConnectionInfo,
636 task_timeout: Option<Duration>,
637 cancel_rx: &mut oneshot::Receiver<()>,
638 task_tx: &mpsc::UnboundedSender<TaskMessage>,
639 ) -> TaskCompletion {
640 let connection_manager = RemoteConnectionManager::new();
642 let connection = match connection_manager.get_connection(&remote_info).await {
643 Ok(conn) => conn,
644 Err(e) => {
645 return TaskCompletion {
646 output: String::new(),
647 error: Some(format!("Failed to establish remote connection: {}", e)),
648 final_status: TaskStatus::Failed,
649 };
650 }
651 };
652
653 let task_tx_clone = task_tx.clone();
655 let id_clone = id.clone();
656 let progress_callback = move |output: String| {
657 if !output.trim().is_empty() {
658 let _ = task_tx_clone.send(TaskMessage::PartialUpdate {
659 id: id_clone.clone(),
660 output,
661 });
662 }
663 };
664
665 let options = crate::remote_connection::CommandOptions {
667 timeout: task_timeout,
668 with_progress: false,
669 simple: false,
670 };
671
672 match connection
673 .execute_command_unified(&command, options, cancel_rx, Some(progress_callback), None)
674 .await
675 {
676 Ok((output, exit_code)) => TaskCompletion {
677 output,
678 error: if exit_code != 0 {
679 Some(format!("Command exited with code {}", exit_code))
680 } else {
681 None
682 },
683 final_status: TaskStatus::Completed,
684 },
685 Err(e) => {
686 let error_msg = e.to_string();
687 let status = if error_msg.contains("timed out") {
688 TaskStatus::TimedOut
689 } else if error_msg.contains("cancelled") {
690 TaskStatus::Cancelled
691 } else {
692 TaskStatus::Failed
693 };
694
695 TaskCompletion {
696 output: String::new(),
697 error: Some(if error_msg.contains("cancelled") {
698 "Tool call was cancelled and don't try to run it again".to_string()
699 } else {
700 format!("Remote command failed: {}", error_msg)
701 }),
702 final_status: status,
703 }
704 }
705 }
706 }
707
708 async fn shutdown_all_tasks(&mut self) {
709 for (_id, mut entry) in self.tasks.drain() {
710 if let Some(cancel_tx) = entry.cancel_tx.take() {
711 let _ = cancel_tx.send(());
712 }
713
714 if let Some(process_id) = entry.process_id {
715 terminate_process_group(process_id);
716 }
717
718 entry.handle.abort();
719 }
720 }
721}
722
/// Client for a running [`TaskManager`], handed out as an `Arc` by [`TaskManager::handle`].
///
/// Its methods send [`TaskMessage`]s to the manager and await the replies. Dropping a
/// handle broadcasts a shutdown signal to the manager.
pub struct TaskManagerHandle {
    /// Sender for requests to the manager.
    tx: mpsc::UnboundedSender<TaskMessage>,
    /// Used on drop to signal shutdown.
    shutdown_tx: broadcast::Sender<()>,
}
727
728impl Drop for TaskManagerHandle {
729 fn drop(&mut self) {
730 let _ = self.shutdown_tx.send(());
739 }
740}
741
742impl TaskManagerHandle {
743 pub async fn start_task(
744 &self,
745 command: String,
746 timeout: Option<Duration>,
747 remote_connection: Option<RemoteConnectionInfo>,
748 ) -> Result<TaskInfo, TaskError> {
749 let (response_tx, response_rx) = oneshot::channel();
750
751 self.tx
752 .send(TaskMessage::Start {
753 id: None,
754 command: command.clone(),
755 remote_connection: remote_connection.clone(),
756 timeout,
757 response_tx,
758 })
759 .map_err(|_| TaskError::ManagerShutdown)?;
760
761 let task_id = response_rx
762 .await
763 .map_err(|_| TaskError::ManagerShutdown)??;
764
765 tokio::time::sleep(START_TASK_WAIT_TIME).await;
767
768 let task_info = self
769 .get_task_details(task_id.clone())
770 .await
771 .map_err(|_| TaskError::ManagerShutdown)?
772 .ok_or_else(|| TaskError::TaskNotFound(task_id.clone()))?;
773
774 if matches!(task_info.status, TaskStatus::Failed | TaskStatus::Cancelled) {
776 return Err(TaskError::TaskFailedOnStart(
777 task_info
778 .output
779 .unwrap_or_else(|| "Unknown reason".to_string()),
780 ));
781 }
782
783 Ok(task_info)
785 }
786
787 pub async fn cancel_task(&self, id: TaskId) -> Result<TaskInfo, TaskError> {
788 let task_info = self
790 .get_all_tasks()
791 .await?
792 .into_iter()
793 .find(|task| task.id == id)
794 .ok_or_else(|| TaskError::TaskNotFound(id.clone()))?;
795
796 let (response_tx, response_rx) = oneshot::channel();
797
798 self.tx
799 .send(TaskMessage::Cancel { id, response_tx })
800 .map_err(|_| TaskError::ManagerShutdown)?;
801
802 response_rx
803 .await
804 .map_err(|_| TaskError::ManagerShutdown)??;
805
806 Ok(TaskInfo {
808 status: TaskStatus::Cancelled,
809 duration: Some(
810 Utc::now()
811 .signed_duration_since(task_info.start_time)
812 .to_std()
813 .unwrap_or_default(),
814 ),
815 ..task_info
816 })
817 }
818
819 pub async fn get_task_status(&self, id: TaskId) -> Result<Option<TaskStatus>, TaskError> {
820 let (response_tx, response_rx) = oneshot::channel();
821
822 self.tx
823 .send(TaskMessage::GetStatus { id, response_tx })
824 .map_err(|_| TaskError::ManagerShutdown)?;
825
826 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
827 }
828
829 pub async fn get_task_details(&self, id: TaskId) -> Result<Option<TaskInfo>, TaskError> {
830 let (response_tx, response_rx) = oneshot::channel();
831
832 self.tx
833 .send(TaskMessage::GetTaskDetails { id, response_tx })
834 .map_err(|_| TaskError::ManagerShutdown)?;
835
836 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
837 }
838
839 pub async fn get_all_tasks(&self) -> Result<Vec<TaskInfo>, TaskError> {
840 let (response_tx, response_rx) = oneshot::channel();
841
842 self.tx
843 .send(TaskMessage::GetAllTasks { response_tx })
844 .map_err(|_| TaskError::ManagerShutdown)?;
845
846 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
847 }
848
849 pub async fn shutdown(&self) -> Result<(), TaskError> {
850 let (response_tx, response_rx) = oneshot::channel();
851
852 self.tx
853 .send(TaskMessage::Shutdown { response_tx })
854 .map_err(|_| TaskError::ManagerShutdown)?;
855
856 response_rx.await.map_err(|_| TaskError::ManagerShutdown)
857 }
858}
859
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    #[tokio::test]
    async fn test_task_manager_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task("sleep 5".to_string(), None, None)
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        // Give the run loop a moment to exit after replying.
        sleep(Duration::from_millis(100)).await;

        assert!(manager_handle.is_finished());
    }

    #[tokio::test]
    async fn test_task_manager_cancels_tasks_on_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task("sleep 10".to_string(), None, None)
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Running));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");

        sleep(Duration::from_millis(100)).await;

        assert!(manager_handle.is_finished());
    }

    #[tokio::test]
    async fn test_task_manager_start_and_complete_task() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let task_info = handle
            .start_task("echo 'Hello, World!'".to_string(), None, None)
            .await
            .expect("Failed to start task");

        sleep(Duration::from_millis(500)).await;

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get task status");
        assert_eq!(status, Some(TaskStatus::Completed));

        let tasks = handle
            .get_all_tasks()
            .await
            .expect("Failed to get all tasks");
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].status, TaskStatus::Completed);

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    #[tokio::test]
    async fn test_task_manager_detects_immediate_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let result = handle
            .start_task("nonexistent_command_12345".to_string(), None, None)
            .await;

        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }

    #[tokio::test]
    async fn test_task_manager_handle_drop_triggers_shutdown() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let _task_info = handle
            .start_task("sleep 30".to_string(), None, None)
            .await
            .expect("Failed to start task");

        drop(handle);

        sleep(Duration::from_millis(500)).await;

        assert!(
            manager_handle.is_finished(),
            "TaskManager::run() should have exited after handle was dropped"
        );
    }

    #[tokio::test]
    async fn test_task_manager_handle_drop_kills_child_processes() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        // The command creates the marker only if it survives for a full second, so the
        // marker's absence afterwards proves the process group was killed on drop.
        let marker = format!("/tmp/stakpak_test_drop_{}", std::process::id());
        let _ = std::fs::remove_file(&marker);

        let task_info = handle
            .start_task(format!("sleep 1 && touch {}", marker), None, None)
            .await
            .expect("Failed to start task");

        let status = handle
            .get_task_status(task_info.id.clone())
            .await
            .expect("Failed to get status");
        assert_eq!(status, Some(TaskStatus::Running));

        drop(handle);

        // Outlast the command's `sleep 1` with a comfortable margin.
        sleep(Duration::from_millis(1500)).await;

        let survived = std::path::Path::new(&marker).exists();
        let _ = std::fs::remove_file(&marker);
        assert!(
            !survived,
            "child process kept running after the handle was dropped"
        );
    }

    #[tokio::test]
    async fn test_task_manager_detects_immediate_exit_code_failure() {
        let task_manager = TaskManager::new();
        let handle = task_manager.handle();

        let _manager_handle = tokio::spawn(async move {
            task_manager.run().await;
        });

        let result = handle.start_task("exit 1".to_string(), None, None).await;

        assert!(matches!(result, Err(TaskError::TaskFailedOnStart(_))));

        handle
            .shutdown()
            .await
            .expect("Failed to shutdown task manager");
    }
}