ai_agent/tasks/
stop_task.rs1#![allow(dead_code)]
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use thiserror::Error;
8
9use crate::tasks::guards::is_local_shell_task_from_value;
10
11#[derive(Debug, Error)]
13pub enum StopTaskError {
14 #[error("No task found with ID: {0}")]
15 NotFound(String),
16 #[error("Task {0} is not running (status: {1})")]
17 NotRunning(String, String),
18 #[error("Unsupported task type: {0}")]
19 UnsupportedType(String),
20}
21
22pub struct StopTaskContext {
24 pub get_app_state: Box<dyn Fn() -> serde_json::Value>,
25 pub set_app_state: Box<dyn Fn(Box<dyn Fn(&serde_json::Value) -> serde_json::Value>)>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct StopTaskResult {
31 pub task_id: String,
32 #[serde(rename = "taskType")]
33 pub task_type: String,
34 pub command: Option<String>,
35}
36
37pub async fn stop_task(
43 task_id: &str,
44 context: &StopTaskContext,
45) -> Result<StopTaskResult, StopTaskError> {
46 let app_state = (context.get_app_state)();
47
48 let task = app_state
49 .get("tasks")
50 .and_then(|t| t.get("tasks").or_else(|| t.get(task_id)))
51 .or_else(|| app_state.get(task_id));
52
53 let task = match task {
54 Some(t) => t,
55 None => return Err(StopTaskError::NotFound(task_id.to_string())),
56 };
57
58 let status = task
59 .get("status")
60 .and_then(|s| s.as_str())
61 .unwrap_or("")
62 .to_string();
63
64 if status != "running" {
65 return Err(StopTaskError::NotRunning(task_id.to_string(), status));
66 }
67
68 let task_type = task
69 .get("type")
70 .and_then(|t| t.as_str())
71 .unwrap_or("")
72 .to_string();
73
74 let task_impl = get_task_by_type(&task_type);
75 if task_impl.is_none() {
76 return Err(StopTaskError::UnsupportedType(task_type.clone()));
77 }
78
79 let task_impl = task_impl.unwrap();
81 task_impl.kill(task_id, &context.set_app_state);
82
83 let is_shell_task = is_local_shell_task_from_value(task);
87 if is_shell_task {
88 let task_id_owned = task_id.to_string();
89 let suppressed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
90 let suppressed_clone = suppressed.clone();
91
92 (context.set_app_state)(Box::new(move |prev: &serde_json::Value| {
93 let prev_task = prev
94 .get("tasks")
95 .and_then(|t| t.get(task_id_owned.as_str()));
96 if let Some(prev_task) = prev_task {
97 if prev_task.get("notified").and_then(|n| n.as_bool()) == Some(false) {
98 suppressed_clone.store(true, std::sync::atomic::Ordering::SeqCst);
99 let mut new_prev = prev.clone();
100 if let Some(obj) = new_prev.as_object_mut() {
101 if let Some(tasks) = obj.get_mut("tasks") {
102 if let Some(tasks_obj) = tasks.as_object_mut() {
103 if let Some(task) = tasks_obj.get_mut(task_id_owned.as_str()) {
104 if let Some(task_obj) = task.as_object_mut() {
105 task_obj.insert(
106 "notified".to_string(),
107 serde_json::json!(true),
108 );
109 }
110 }
111 }
112 }
113 }
114 return new_prev;
115 }
116 }
117 prev.clone()
118 }));
119
120 if suppressed.load(std::sync::atomic::Ordering::SeqCst) {
124 let tool_use_id = task
125 .get("toolUseId")
126 .and_then(|v| v.as_str())
127 .map(|s| s.to_string());
128 let summary = task
129 .get("description")
130 .and_then(|v| v.as_str())
131 .map(|s| s.to_string());
132 emit_task_terminated_sdk(task_id, tool_use_id, summary);
133 }
134 }
135
136 let command = if is_shell_task {
137 task.get("command")
138 .and_then(|v| v.as_str())
139 .map(|s| s.to_string())
140 } else {
141 task.get("description")
142 .and_then(|v| v.as_str())
143 .map(|s| s.to_string())
144 };
145
146 Ok(StopTaskResult {
147 task_id: task_id.to_string(),
148 task_type,
149 command,
150 })
151}
152
153pub trait Task: Send + Sync {
155 fn name(&self) -> &str;
156 fn task_type(&self) -> &str;
157 fn kill(
158 &self,
159 task_id: &str,
160 set_app_state: &dyn Fn(Box<dyn Fn(&serde_json::Value) -> serde_json::Value>),
161 );
162}
163
164fn get_task_by_type(_task_type: &str) -> Option<Box<dyn Task>> {
166 None
168}
169
170fn emit_task_terminated_sdk(task_id: &str, tool_use_id: Option<String>, summary: Option<String>) {
172 crate::utils::sdk_event_queue::emit_task_terminated_sdk(
173 task_id,
174 tool_use_id,
175 "stopped",
176 summary,
177 None,
178 None,
179 );
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    // The assertions pin the exact `Display` output. Substring checks would keep
    // passing even if the surrounding message format were broken.

    #[test]
    fn test_stop_task_error_not_found() {
        let error = StopTaskError::NotFound("test-id".to_string());
        assert_eq!(error.to_string(), "No task found with ID: test-id");
    }

    #[test]
    fn test_stop_task_error_not_running() {
        let error = StopTaskError::NotRunning("test-id".to_string(), "pending".to_string());
        assert_eq!(
            error.to_string(),
            "Task test-id is not running (status: pending)"
        );
    }

    #[test]
    fn test_stop_task_error_unsupported_type() {
        let error = StopTaskError::UnsupportedType("unknown".to_string());
        assert_eq!(error.to_string(), "Unsupported task type: unknown");
    }
}