1use std::collections::HashMap;
4
/// Kind of background task tracked by the task system.
///
/// Variant names are snake_case and match, one-to-one, the strings returned by
/// [`TaskType::as_str`] and accepted by [`TaskType::from_str`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// Lowercase variants are kept so each one reads exactly like its string form.
#[allow(non_camel_case_types)]
pub enum TaskType {
    local_bash,
    local_agent,
    remote_agent,
    in_process_teammate,
    local_workflow,
    monitor_mcp,
    dream,
}

impl TaskType {
    /// Returns the canonical string name of this task type.
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskType::local_bash => "local_bash",
            TaskType::local_agent => "local_agent",
            TaskType::remote_agent => "remote_agent",
            TaskType::in_process_teammate => "in_process_teammate",
            TaskType::local_workflow => "local_workflow",
            TaskType::monitor_mcp => "monitor_mcp",
            TaskType::dream => "dream",
        }
    }

    /// Parses a canonical name as produced by [`TaskType::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    // Intentionally an inherent fn returning `Option` rather than `FromStr`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "local_bash" => Some(TaskType::local_bash),
            "local_agent" => Some(TaskType::local_agent),
            "remote_agent" => Some(TaskType::remote_agent),
            "in_process_teammate" => Some(TaskType::in_process_teammate),
            "local_workflow" => Some(TaskType::local_workflow),
            "monitor_mcp" => Some(TaskType::monitor_mcp),
            "dream" => Some(TaskType::dream),
            _ => None,
        }
    }
}
44
45#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
47#[serde(rename_all = "snake_case")]
48#[allow(non_camel_case_types)]
49pub enum TaskStatus {
50 pending,
51 running,
52 completed,
53 failed,
54 killed,
55}
56
57impl TaskStatus {
58 pub fn as_str(&self) -> &'static str {
59 match self {
60 TaskStatus::pending => "pending",
61 TaskStatus::running => "running",
62 TaskStatus::completed => "completed",
63 TaskStatus::failed => "failed",
64 TaskStatus::killed => "killed",
65 }
66 }
67
68 pub fn from_str(s: &str) -> Option<Self> {
69 match s {
70 "pending" => Some(TaskStatus::pending),
71 "running" => Some(TaskStatus::running),
72 "completed" => Some(TaskStatus::completed),
73 "failed" => Some(TaskStatus::failed),
74 "killed" => Some(TaskStatus::killed),
75 _ => None,
76 }
77 }
78}
79
80pub fn is_terminal_task_status(status: &TaskStatus) -> bool {
84 matches!(
85 status,
86 TaskStatus::completed | TaskStatus::failed | TaskStatus::killed
87 )
88}
89
/// Handle to a spawned task: its id plus an optional cleanup callback.
pub struct TaskHandle {
    /// Identifier of the task this handle refers to.
    pub task_id: String,
    /// Optional callback; presumably run to release the task's resources —
    /// TODO confirm against callers.
    pub cleanup: Option<Box<dyn Fn() + Send>>,
}

impl Clone for TaskHandle {
    /// Copies the task id only.
    ///
    /// A boxed `dyn Fn` cannot be cloned, so the copy always has
    /// `cleanup: None`; the original handle keeps its callback.
    fn clone(&self) -> Self {
        let task_id = self.task_id.clone();
        TaskHandle {
            task_id,
            cleanup: None,
        }
    }
}
105
/// Callback that replaces the application state.
///
/// It is handed a boxed factory closure producing the new state.
/// NOTE(review): the factory takes no arguments, so unlike a `prev => next`
/// updater it cannot read the previous state — confirm this is intended.
pub type SetAppState =
    Box<dyn Fn(Box<dyn Fn() -> Box<dyn AppState>>) + Sync + Send>;

/// Marker trait for application-state objects.
///
/// It declares no methods of its own; implementors only need to be
/// `Send + Sync`.
pub trait AppState: Sync + Send {}
113
114pub struct TaskContext {
116 pub abort_controller: AbortController,
117 pub get_app_state: Box<dyn Fn() -> Box<dyn AppState> + Send + Sync>,
118 pub set_app_state: SetAppState,
119}
120
/// Cooperative cancellation handle, modelled on the web `AbortController`.
///
/// The controller drives an [`AbortSignal`]: [`abort`](Self::abort) raises the
/// signal's flag and [`is_aborted`](Self::is_aborted) reads it. The flag lives
/// behind an `Arc`, so clones of the controller and clones of its signal all
/// observe the same state.
#[derive(Clone)]
pub struct AbortController {
    // Both constructors set this to `Some`; it stays an `Option` only so the
    // public `signal()` accessor keeps its signature.
    signal: Option<AbortSignal>,
}

impl AbortController {
    /// Creates a controller with its own fresh, not-yet-aborted signal.
    ///
    /// Without a signal, `abort()` would be a silent no-op and `is_aborted()`
    /// would be permanently `false`.
    pub fn new() -> Self {
        Self::with_signal(AbortSignal::new())
    }

    /// Creates a controller that drives an existing `signal`.
    ///
    /// Aborting the controller is visible through `signal` and its clones.
    pub fn with_signal(signal: AbortSignal) -> Self {
        Self {
            signal: Some(signal),
        }
    }

    /// Returns the signal driven by this controller, if any.
    pub fn signal(&self) -> Option<&AbortSignal> {
        self.signal.as_ref()
    }

    /// Marks the signal as aborted. Calling it again has no further effect.
    pub fn abort(&self) {
        if let Some(signal) = &self.signal {
            signal
                .aborted
                .store(true, std::sync::atomic::Ordering::SeqCst);
        }
    }

    /// Returns `true` once `abort` has been called on this controller or on
    /// anything sharing its signal.
    pub fn is_aborted(&self) -> bool {
        self.signal.as_ref().is_some_and(AbortSignal::aborted)
    }
}

impl Default for AbortController {
    fn default() -> Self {
        Self::new()
    }
}

/// Abort flag observed by tasks and raised by an [`AbortController`].
///
/// Cloning is cheap and yields a handle to the *same* flag, mirroring the
/// reference semantics of the web `AbortSignal`. A snapshot copy would
/// never see a later abort, which defeats cancellation.
#[derive(Clone, Debug, Default)]
pub struct AbortSignal {
    aborted: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

impl AbortSignal {
    /// Creates a signal that has not been aborted.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if the signal has been aborted.
    pub fn aborted(&self) -> bool {
        self.aborted.load(std::sync::atomic::Ordering::SeqCst)
    }
}
195
196#[derive(Debug, Clone)]
198pub struct TaskStateBase {
199 pub id: String,
200 pub task_type: TaskType,
201 pub status: TaskStatus,
202 pub description: String,
203 pub tool_use_id: Option<String>,
204 pub start_time: u64,
205 pub end_time: Option<u64>,
206 pub total_paused_ms: Option<u64>,
207 pub output_file: String,
208 pub output_offset: u64,
209 pub notified: bool,
210}
211
212#[derive(Debug, Clone)]
214pub struct LocalShellSpawnInput {
215 pub command: String,
216 pub description: String,
217 pub timeout: Option<u64>,
218 pub tool_use_id: Option<String>,
219 pub agent_id: Option<String>,
220 pub kind: Option<ShellKind>,
222}
223
/// Flavour of a local shell task.
///
/// Variant names match, one-to-one, the strings returned by
/// [`ShellKind::as_str`] and accepted by [`ShellKind::from_str`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// Without this the lowercase variants trip the `non_camel_case_types` lint,
// which is an error under `-D warnings`.
#[allow(non_camel_case_types)]
pub enum ShellKind {
    bash,
    monitor,
}

impl ShellKind {
    /// Returns the canonical string name of this shell kind.
    pub fn as_str(&self) -> &'static str {
        match self {
            ShellKind::bash => "bash",
            ShellKind::monitor => "monitor",
        }
    }

    /// Parses a canonical name as produced by [`ShellKind::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    // Intentionally an inherent fn returning `Option` rather than `FromStr`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "bash" => Some(ShellKind::bash),
            "monitor" => Some(ShellKind::monitor),
            _ => None,
        }
    }
}
247
248pub trait Task: Send + Sync {
250 fn name(&self) -> &str;
251 fn task_type(&self) -> TaskType;
252 fn kill(
253 &self,
254 task_id: &str,
255 set_app_state: SetAppState,
256 ) -> impl std::future::Future<Output = ()> + Send;
257}
258
/// Maps each task type's string name (as returned by `TaskType::as_str`) to the
/// one-letter prefix used at the start of that type's task ids.
pub const TASK_ID_PREFIXES: &[(&str, &str)] = &[
    ("local_bash", "b"),
    ("local_agent", "a"),
    ("remote_agent", "r"),
    ("in_process_teammate", "t"),
    ("local_workflow", "w"),
    ("monitor_mcp", "m"),
    ("dream", "d"),
];
269
270pub fn get_task_id_prefix(task_type: &TaskType) -> &'static str {
272 TASK_ID_PREFIXES
273 .iter()
274 .find(|(t, _)| *t == task_type.as_str())
275 .map(|(_, p)| *p)
276 .unwrap_or("x")
277}
278
/// Characters used for the random part of a task id: digits followed by
/// lowercase ASCII letters (36 characters in total).
pub const TASK_ID_ALPHABET: &str = "0123456789abcdefghijklmnopqrstuvwxyz";
282
283pub fn generate_task_id(task_type: &TaskType) -> String {
285 use std::time::{SystemTime, UNIX_EPOCH};
286
287 let prefix = get_task_id_prefix(task_type);
288 let mut rng_seed = SystemTime::now()
289 .duration_since(UNIX_EPOCH)
290 .unwrap()
291 .as_nanos() as u64;
292
293 let mut id = prefix.to_string();
294 for i in 0..8 {
295 rng_seed = rng_seed.wrapping_mul(1103515245).wrapping_add(12345);
297 let alphabet_idx = (rng_seed >> (i * 3)) as usize % TASK_ID_ALPHABET.len();
298 id.push(TASK_ID_ALPHABET.chars().nth(alphabet_idx).unwrap());
299 }
300 id
301}
302
/// Builds the path of the file that holds a task's output:
/// `/tmp/task_output_<task_id>.txt`.
///
/// NOTE(review): the directory is hard-coded to `/tmp`. That does not exist
/// on Windows and ignores `TMPDIR`; `std::env::temp_dir()` may be more
/// appropriate — confirm with callers before changing.
pub fn get_task_output_path(task_id: &str) -> String {
    format!("/tmp/task_output_{task_id}.txt")
}
308
309pub fn create_task_state_base(
311 id: String,
312 task_type: TaskType,
313 description: String,
314 tool_use_id: Option<String>,
315) -> TaskStateBase {
316 let now = std::time::SystemTime::now()
317 .duration_since(std::time::UNIX_EPOCH)
318 .unwrap()
319 .as_millis() as u64;
320
321 TaskStateBase {
322 id,
323 task_type,
324 status: TaskStatus::pending,
325 description,
326 tool_use_id,
327 start_time: now,
328 end_time: None,
329 total_paused_ms: None,
330 output_file: String::new(),
331 output_offset: 0,
332 notified: false,
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `TaskType` variant, for exhaustive round-trip checks.
    fn all_task_types() -> [TaskType; 7] {
        [
            TaskType::local_bash,
            TaskType::local_agent,
            TaskType::remote_agent,
            TaskType::in_process_teammate,
            TaskType::local_workflow,
            TaskType::monitor_mcp,
            TaskType::dream,
        ]
    }

    #[test]
    fn test_task_type_strings() {
        assert_eq!(TaskType::local_bash.as_str(), "local_bash");
        assert_eq!(TaskType::remote_agent.as_str(), "remote_agent");
    }

    #[test]
    fn test_task_type_from_str_round_trip() {
        for task_type in all_task_types().iter() {
            assert_eq!(
                TaskType::from_str(task_type.as_str()),
                Some(task_type.clone())
            );
        }
        assert_eq!(TaskType::from_str(""), None);
        assert_eq!(TaskType::from_str("Local_Bash"), None);
    }

    #[test]
    fn test_task_status_strings() {
        assert_eq!(TaskStatus::pending.as_str(), "pending");
        assert_eq!(TaskStatus::completed.as_str(), "completed");
    }

    #[test]
    fn test_task_status_from_str_round_trip() {
        let all = [
            TaskStatus::pending,
            TaskStatus::running,
            TaskStatus::completed,
            TaskStatus::failed,
            TaskStatus::killed,
        ];
        for status in all.iter() {
            assert_eq!(TaskStatus::from_str(status.as_str()), Some(status.clone()));
        }
        assert_eq!(TaskStatus::from_str("done"), None);
    }

    #[test]
    fn test_is_terminal_task_status() {
        assert!(!is_terminal_task_status(&TaskStatus::pending));
        assert!(!is_terminal_task_status(&TaskStatus::running));
        assert!(is_terminal_task_status(&TaskStatus::completed));
        assert!(is_terminal_task_status(&TaskStatus::failed));
        assert!(is_terminal_task_status(&TaskStatus::killed));
    }

    #[test]
    fn test_shell_kind_strings() {
        assert_eq!(ShellKind::bash.as_str(), "bash");
        assert_eq!(ShellKind::monitor.as_str(), "monitor");
    }

    #[test]
    fn test_shell_kind_from_str_round_trip() {
        for kind in [ShellKind::bash, ShellKind::monitor].iter() {
            assert_eq!(ShellKind::from_str(kind.as_str()), Some(kind.clone()));
        }
        assert_eq!(ShellKind::from_str("zsh"), None);
    }

    #[test]
    fn test_generate_task_id() {
        let id = generate_task_id(&TaskType::local_bash);
        assert!(id.starts_with('b'));
        assert_eq!(id.len(), 9);
    }

    #[test]
    fn test_generate_task_id_shape_for_every_type() {
        for task_type in all_task_types().iter() {
            let prefix = get_task_id_prefix(task_type);
            let id = generate_task_id(task_type);
            assert!(id.starts_with(prefix), "{} should start with {}", id, prefix);
            let suffix = &id[prefix.len()..];
            assert_eq!(suffix.len(), 8);
            assert!(suffix.chars().all(|c| TASK_ID_ALPHABET.contains(c)));
        }
    }

    #[test]
    fn test_task_id_prefix() {
        assert_eq!(get_task_id_prefix(&TaskType::local_bash), "b");
        assert_eq!(get_task_id_prefix(&TaskType::local_agent), "a");
        assert_eq!(get_task_id_prefix(&TaskType::remote_agent), "r");
        assert_eq!(get_task_id_prefix(&TaskType::in_process_teammate), "t");
        assert_eq!(get_task_id_prefix(&TaskType::local_workflow), "w");
        assert_eq!(get_task_id_prefix(&TaskType::monitor_mcp), "m");
        assert_eq!(get_task_id_prefix(&TaskType::dream), "d");
    }

    #[test]
    fn test_task_id_prefixes_are_unique() {
        use std::collections::HashSet;
        let names: HashSet<&str> = TASK_ID_PREFIXES.iter().map(|(name, _)| *name).collect();
        let prefixes: HashSet<&str> =
            TASK_ID_PREFIXES.iter().map(|(_, prefix)| *prefix).collect();
        assert_eq!(names.len(), TASK_ID_PREFIXES.len());
        assert_eq!(prefixes.len(), TASK_ID_PREFIXES.len());
    }

    #[test]
    fn test_get_task_output_path() {
        assert_eq!(
            get_task_output_path("b12345678"),
            "/tmp/task_output_b12345678.txt"
        );
    }
}