1use std::collections::HashMap;
4
/// Kind of task tracked by the task system.
///
/// Each variant name is identical to the string returned by
/// [`TaskType::as_str`], which is why the `non_camel_case_types` lint is
/// silenced for this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
pub enum TaskType {
    local_bash,
    local_agent,
    remote_agent,
    in_process_teammate,
    local_workflow,
    monitor_mcp,
    dream,
}

impl TaskType {
    /// Returns the canonical string name of this task type.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::local_bash => "local_bash",
            Self::local_agent => "local_agent",
            Self::remote_agent => "remote_agent",
            Self::in_process_teammate => "in_process_teammate",
            Self::local_workflow => "local_workflow",
            Self::monitor_mcp => "monitor_mcp",
            Self::dream => "dream",
        }
    }

    /// Parses a canonical task-type name (exact, case-sensitive match).
    ///
    /// Inverse of [`TaskType::as_str`]; returns `None` for any other string.
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "local_bash" => Some(Self::local_bash),
            "local_agent" => Some(Self::local_agent),
            "remote_agent" => Some(Self::remote_agent),
            "in_process_teammate" => Some(Self::in_process_teammate),
            "local_workflow" => Some(Self::local_workflow),
            "monitor_mcp" => Some(Self::monitor_mcp),
            "dream" => Some(Self::dream),
            _ => None,
        }
    }
}
44
/// Lifecycle status of a task.
///
/// Each variant name is identical to the string returned by
/// [`TaskStatus::as_str`], which is why the `non_camel_case_types` lint is
/// silenced for this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
pub enum TaskStatus {
    pending,
    running,
    completed,
    failed,
    killed,
}

impl TaskStatus {
    /// Returns the canonical lowercase name of this status.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::pending => "pending",
            Self::running => "running",
            Self::completed => "completed",
            Self::failed => "failed",
            Self::killed => "killed",
        }
    }

    /// Parses a canonical status name (exact, case-sensitive match).
    ///
    /// Inverse of [`TaskStatus::as_str`]; returns `None` for any other string.
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "pending" => Some(Self::pending),
            "running" => Some(Self::running),
            "completed" => Some(Self::completed),
            "failed" => Some(Self::failed),
            "killed" => Some(Self::killed),
            _ => None,
        }
    }
}

/// Returns `true` for the terminal statuses `completed`, `failed` and
/// `killed`, and `false` for `pending` and `running`.
pub fn is_terminal_task_status(status: &TaskStatus) -> bool {
    // Exhaustive on purpose: adding a new status variant fails to compile
    // here until someone decides whether it is terminal.
    match status {
        TaskStatus::completed | TaskStatus::failed | TaskStatus::killed => true,
        TaskStatus::pending | TaskStatus::running => false,
    }
}
88
/// Handle to a spawned task.
pub struct TaskHandle {
    /// ID of the task this handle refers to.
    pub task_id: String,
    /// Optional cleanup callback. Nothing in this module invokes it.
    pub cleanup: Option<Box<dyn Fn() + Send>>,
}

/// Cloning copies `task_id` only. A boxed `Fn` cannot be cloned, so the
/// clone's `cleanup` is always `None` and the callback stays with the
/// original handle.
impl Clone for TaskHandle {
    fn clone(&self) -> Self {
        Self {
            task_id: self.task_id.clone(),
            cleanup: None,
        }
    }
}

// Manual impl because `dyn Fn` has no `Debug`; the callback is rendered only
// as present (`Some("<fn>")`) or absent (`None`).
impl std::fmt::Debug for TaskHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TaskHandle")
            .field("task_id", &self.task_id)
            .field("cleanup", &self.cleanup.as_ref().map(|_| "<fn>"))
            .finish()
    }
}
104
/// Callback used to update application state.
///
/// It is handed a boxed producer closure; calling the producer yields a
/// boxed [`AppState`] (presumably the new state — confirm with callers).
// NOTE(review): the producer receives no access to the previous state,
// unlike a typical `setState(prev => next)` updater — confirm intended.
pub type SetAppState =
    Box<dyn Fn(Box<dyn Fn() -> Box<dyn AppState>>) + Send + Sync>;

/// Application-state object shared with tasks.
///
/// Declares no methods of its own; implementors only need to be
/// `Send + Sync`.
pub trait AppState: Send + Sync {}
112
113pub struct TaskContext {
115 pub abort_controller: AbortController,
116 pub get_app_state: Box<dyn Fn() -> Box<dyn AppState> + Send + Sync>,
117 pub set_app_state: SetAppState,
118}
119
/// Cancellation handle for a task (API shaped like the web `AbortController`).
///
/// Clones of a controller, and every [`AbortSignal`] cloned out of it, share
/// a single flag: calling [`AbortController::abort`] through any of them is
/// observed by all the others.
#[derive(Debug, Clone)]
pub struct AbortController {
    // Every constructor in this module sets this to `Some`; it stays an
    // `Option` only so that `signal()` keeps its existing signature.
    signal: Option<AbortSignal>,
}

impl AbortController {
    /// Creates a controller with a fresh, not-yet-aborted signal.
    pub fn new() -> Self {
        Self::with_signal(AbortSignal::new())
    }

    /// Creates a controller that drives `signal`.
    ///
    /// Passing a clone of another controller's signal links the two:
    /// aborting either one aborts both.
    pub fn with_signal(signal: AbortSignal) -> Self {
        Self {
            signal: Some(signal),
        }
    }

    /// Returns the signal tasks can check for cancellation.
    pub fn signal(&self) -> Option<&AbortSignal> {
        self.signal.as_ref()
    }

    /// Marks the shared signal as aborted. Calling it again has no further
    /// effect.
    pub fn abort(&self) {
        if let Some(signal) = &self.signal {
            signal
                .aborted
                .store(true, std::sync::atomic::Ordering::SeqCst);
        }
    }

    /// Returns `true` once `abort` has been called on this controller or on
    /// anything sharing its signal.
    pub fn is_aborted(&self) -> bool {
        self.signal.as_ref().is_some_and(AbortSignal::aborted)
    }
}

// Hand-written rather than derived: a derived `Default` would leave
// `signal` as `None`, producing a controller whose `abort` does nothing.
impl Default for AbortController {
    fn default() -> Self {
        Self::new()
    }
}

/// Read side of an [`AbortController`]: reports whether cancellation has
/// been requested.
///
/// Cloning is cheap and shares the underlying flag, so every clone observes
/// an abort requested through any controller holding the same signal.
#[derive(Debug, Clone, Default)]
pub struct AbortSignal {
    // `Arc` so clones share one flag instead of each snapshotting it.
    aborted: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

impl AbortSignal {
    /// Creates a signal that has not been aborted.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if cancellation has been requested.
    pub fn aborted(&self) -> bool {
        self.aborted.load(std::sync::atomic::Ordering::SeqCst)
    }
}
194
195#[derive(Debug, Clone)]
197pub struct TaskStateBase {
198 pub id: String,
199 pub task_type: TaskType,
200 pub status: TaskStatus,
201 pub description: String,
202 pub tool_use_id: Option<String>,
203 pub start_time: u64,
204 pub end_time: Option<u64>,
205 pub total_paused_ms: Option<u64>,
206 pub output_file: String,
207 pub output_offset: u64,
208 pub notified: bool,
209}
210
211#[derive(Debug, Clone)]
213pub struct LocalShellSpawnInput {
214 pub command: String,
215 pub description: String,
216 pub timeout: Option<u64>,
217 pub tool_use_id: Option<String>,
218 pub agent_id: Option<String>,
219 pub kind: Option<ShellKind>,
221}
222
/// Kind of local shell task.
///
/// Variant names are identical to the strings returned by
/// [`ShellKind::as_str`]. The `non_camel_case_types` allowance (present on
/// the sibling enums but previously missing here) silences the lint those
/// lowercase names would otherwise trigger.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
pub enum ShellKind {
    bash,
    monitor,
}

impl ShellKind {
    /// Returns the canonical string name of this shell kind.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::bash => "bash",
            Self::monitor => "monitor",
        }
    }

    /// Parses a canonical shell-kind name (exact, case-sensitive match).
    ///
    /// Inverse of [`ShellKind::as_str`]; returns `None` for any other string.
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "bash" => Some(Self::bash),
            "monitor" => Some(Self::monitor),
            _ => None,
        }
    }
}
246
247pub trait Task: Send + Sync {
249 fn name(&self) -> &str;
250 fn task_type(&self) -> TaskType;
251 fn kill(
252 &self,
253 task_id: &str,
254 set_app_state: SetAppState,
255 ) -> impl std::future::Future<Output = ()> + Send;
256}
257
/// One-letter ID prefix per task type, keyed by the `TaskType::as_str` name.
///
/// `get_task_id_prefix` looks prefixes up here, and `generate_task_id`
/// starts every ID with the result.
pub const TASK_ID_PREFIXES: &[(&str, &str)] = &[
    ("local_bash", "b"),          // shell commands
    ("local_agent", "a"),
    ("remote_agent", "r"),
    ("in_process_teammate", "t"),
    ("local_workflow", "w"),
    ("monitor_mcp", "m"),
    ("dream", "d"),
];
268
269pub fn get_task_id_prefix(task_type: &TaskType) -> &'static str {
271 TASK_ID_PREFIXES
272 .iter()
273 .find(|(t, _)| *t == task_type.as_str())
274 .map(|(_, p)| *p)
275 .unwrap_or("x")
276}
277
278pub const TASK_ID_ALPHABET: &str = "0123456789abcdefghijklmnopqrstuvwxyz";
281
282pub fn generate_task_id(task_type: &TaskType) -> String {
284 use std::time::{SystemTime, UNIX_EPOCH};
285
286 let prefix = get_task_id_prefix(task_type);
287 let mut rng_seed = SystemTime::now()
288 .duration_since(UNIX_EPOCH)
289 .unwrap()
290 .as_nanos() as u64;
291
292 let mut id = prefix.to_string();
293 for i in 0..8 {
294 rng_seed = rng_seed.wrapping_mul(1103515245).wrapping_add(12345);
296 let alphabet_idx = (rng_seed >> (i * 3)) as usize % TASK_ID_ALPHABET.len();
297 id.push(TASK_ID_ALPHABET.chars().nth(alphabet_idx).unwrap());
298 }
299 id
300}
301
/// Returns the path of the output file for task `task_id`.
///
/// The file is named `task_output_<task_id>.txt` and lives in the platform
/// temporary directory ([`std::env::temp_dir`]) rather than a hard-coded
/// `/tmp`, which does not exist on Windows. A non-UTF-8 temp directory is
/// converted lossily because the signature returns `String`.
// NOTE(review): `task_id` is joined verbatim; this assumes IDs come from
// `generate_task_id` (prefix + alphanumerics). Confirm callers never pass
// untrusted IDs containing path separators.
pub fn get_task_output_path(task_id: &str) -> String {
    std::env::temp_dir()
        .join(format!("task_output_{}.txt", task_id))
        .to_string_lossy()
        .into_owned()
}
307
308pub fn create_task_state_base(
310 id: String,
311 task_type: TaskType,
312 description: String,
313 tool_use_id: Option<String>,
314) -> TaskStateBase {
315 let now = std::time::SystemTime::now()
316 .duration_since(std::time::UNIX_EPOCH)
317 .unwrap()
318 .as_millis() as u64;
319
320 TaskStateBase {
321 id,
322 task_type,
323 status: TaskStatus::pending,
324 description,
325 tool_use_id,
326 start_time: now,
327 end_time: None,
328 total_paused_ms: None,
329 output_file: String::new(),
330 output_offset: 0,
331 notified: false,
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_type_strings() {
        assert_eq!(TaskType::local_bash.as_str(), "local_bash");
        assert_eq!(TaskType::remote_agent.as_str(), "remote_agent");
    }

    #[test]
    fn test_task_type_round_trip() {
        let all = [
            TaskType::local_bash,
            TaskType::local_agent,
            TaskType::remote_agent,
            TaskType::in_process_teammate,
            TaskType::local_workflow,
            TaskType::monitor_mcp,
            TaskType::dream,
        ];
        for task_type in all.iter() {
            assert_eq!(TaskType::from_str(task_type.as_str()).as_ref(), Some(task_type));
        }
        assert_eq!(TaskType::from_str("LOCAL_BASH"), None);
        assert_eq!(TaskType::from_str(""), None);
    }

    #[test]
    fn test_task_status_strings() {
        assert_eq!(TaskStatus::pending.as_str(), "pending");
        assert_eq!(TaskStatus::completed.as_str(), "completed");
    }

    #[test]
    fn test_task_status_round_trip() {
        let all = [
            TaskStatus::pending,
            TaskStatus::running,
            TaskStatus::completed,
            TaskStatus::failed,
            TaskStatus::killed,
        ];
        for status in all.iter() {
            assert_eq!(TaskStatus::from_str(status.as_str()).as_ref(), Some(status));
        }
        assert_eq!(TaskStatus::from_str("Pending"), None);
    }

    #[test]
    fn test_is_terminal_task_status() {
        assert!(!is_terminal_task_status(&TaskStatus::pending));
        assert!(!is_terminal_task_status(&TaskStatus::running));
        assert!(is_terminal_task_status(&TaskStatus::completed));
        assert!(is_terminal_task_status(&TaskStatus::failed));
        assert!(is_terminal_task_status(&TaskStatus::killed));
    }

    #[test]
    fn test_shell_kind_strings() {
        assert_eq!(ShellKind::bash.as_str(), "bash");
        assert_eq!(ShellKind::monitor.as_str(), "monitor");
    }

    #[test]
    fn test_shell_kind_round_trip() {
        assert_eq!(ShellKind::from_str("bash"), Some(ShellKind::bash));
        assert_eq!(ShellKind::from_str("monitor"), Some(ShellKind::monitor));
        assert_eq!(ShellKind::from_str("zsh"), None);
    }

    #[test]
    fn test_generate_task_id() {
        let id = generate_task_id(&TaskType::local_bash);
        assert!(id.starts_with('b'));
        assert_eq!(id.len(), 9);
        assert!(id[1..].chars().all(|c| TASK_ID_ALPHABET.contains(c)));
    }

    #[test]
    fn test_task_id_prefix() {
        assert_eq!(get_task_id_prefix(&TaskType::local_bash), "b");
        assert_eq!(get_task_id_prefix(&TaskType::local_agent), "a");
        assert_eq!(get_task_id_prefix(&TaskType::remote_agent), "r");
        assert_eq!(get_task_id_prefix(&TaskType::in_process_teammate), "t");
        assert_eq!(get_task_id_prefix(&TaskType::local_workflow), "w");
        assert_eq!(get_task_id_prefix(&TaskType::monitor_mcp), "m");
        assert_eq!(get_task_id_prefix(&TaskType::dream), "d");
    }

    #[test]
    fn test_task_output_path_names_file_after_task() {
        assert!(get_task_output_path("b12345678").ends_with("task_output_b12345678.txt"));
    }

    #[test]
    fn test_create_task_state_base_defaults() {
        let base = create_task_state_base(
            "b00000000".to_string(),
            TaskType::local_bash,
            "desc".to_string(),
            None,
        );
        assert_eq!(base.id, "b00000000");
        assert_eq!(base.task_type, TaskType::local_bash);
        assert_eq!(base.status, TaskStatus::pending);
        assert_eq!(base.description, "desc");
        assert_eq!(base.tool_use_id, None);
        assert_eq!(base.end_time, None);
        assert_eq!(base.total_paused_ms, None);
        assert_eq!(base.output_offset, 0);
        assert!(!base.notified);
        assert!(base.start_time > 0);
    }

    #[test]
    fn test_task_handle_clone_drops_cleanup() {
        let handle = TaskHandle {
            task_id: "b1".to_string(),
            cleanup: Some(Box::new(|| {})),
        };
        let copy = handle.clone();
        assert_eq!(copy.task_id, "b1");
        assert!(copy.cleanup.is_none());
        assert!(handle.cleanup.is_some());
    }

    #[test]
    fn test_new_abort_controller_is_not_aborted() {
        assert!(!AbortController::new().is_aborted());
        assert!(!AbortSignal::new().aborted());
    }
}