Skip to main content

ai_agent/
task.rs

1//! Task types and utilities translated from TypeScript Task.ts
2
3use std::collections::HashMap;
4
/// Kinds of background task that can be created.
///
/// Variant names are deliberately lowercase so they read exactly like the
/// canonical string identifiers returned by [`TaskType::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// Lowercase variants mirror the string identifiers; silence the camel-case lint.
#[allow(non_camel_case_types)]
pub enum TaskType {
    local_bash,
    local_agent,
    remote_agent,
    in_process_teammate,
    local_workflow,
    monitor_mcp,
    dream,
}

impl TaskType {
    /// Every task type, in declaration order.
    pub const ALL: [TaskType; 7] = [
        TaskType::local_bash,
        TaskType::local_agent,
        TaskType::remote_agent,
        TaskType::in_process_teammate,
        TaskType::local_workflow,
        TaskType::monitor_mcp,
        TaskType::dream,
    ];

    /// Returns the canonical string identifier for this task type.
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskType::local_bash => "local_bash",
            TaskType::local_agent => "local_agent",
            TaskType::remote_agent => "remote_agent",
            TaskType::in_process_teammate => "in_process_teammate",
            TaskType::local_workflow => "local_workflow",
            TaskType::monitor_mcp => "monitor_mcp",
            TaskType::dream => "dream",
        }
    }

    /// Parses a canonical identifier produced by [`TaskType::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    /// It is implemented on top of `as_str` so the two mappings cannot drift apart.
    // Kept as an inherent `Option`-returning method for API compatibility.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().copied().find(|t| t.as_str() == s)
    }
}
44
/// Lifecycle status of a task.
///
/// Variant names are deliberately lowercase so they read exactly like the
/// canonical strings returned by [`TaskStatus::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// Lowercase variants mirror the string identifiers; silence the camel-case lint.
#[allow(non_camel_case_types)]
pub enum TaskStatus {
    pending,
    running,
    completed,
    failed,
    killed,
}

impl TaskStatus {
    /// Every status, in declaration order.
    pub const ALL: [TaskStatus; 5] = [
        TaskStatus::pending,
        TaskStatus::running,
        TaskStatus::completed,
        TaskStatus::failed,
        TaskStatus::killed,
    ];

    /// Returns the canonical string form of this status.
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskStatus::pending => "pending",
            TaskStatus::running => "running",
            TaskStatus::completed => "completed",
            TaskStatus::failed => "failed",
            TaskStatus::killed => "killed",
        }
    }

    /// Parses a canonical string produced by [`TaskStatus::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    /// It is implemented on top of `as_str` so the two mappings cannot drift apart.
    // Kept as an inherent `Option`-returning method for API compatibility.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().copied().find(|status| status.as_str() == s)
    }
}
78
79/// True when a task is in a terminal state and will not transition further.
80/// Used to guard against injecting messages into dead teammates, evicting
81/// finished tasks from AppState, and orphan-cleanup paths.
82pub fn is_terminal_task_status(status: &TaskStatus) -> bool {
83    matches!(
84        status,
85        TaskStatus::completed | TaskStatus::failed | TaskStatus::killed
86    )
87}
88
/// Handle to a task: its ID plus an optional cleanup callback.
pub struct TaskHandle {
    /// Identifier of the task this handle refers to.
    pub task_id: String,
    /// Optional callback associated with the task's teardown.
    pub cleanup: Option<Box<dyn Fn() + Send>>,
}

impl Clone for TaskHandle {
    /// Clones the task ID only.
    ///
    /// A boxed closure cannot be cloned, so the copy's `cleanup` is always `None`.
    /// The original handle keeps sole ownership of the callback.
    fn clone(&self) -> Self {
        Self {
            task_id: self.task_id.clone(),
            cleanup: None,
        }
    }
}

impl std::fmt::Debug for TaskHandle {
    /// Formats the handle, showing only whether a cleanup callback is present
    /// because closures have no `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TaskHandle")
            .field("task_id", &self.task_id)
            .field("has_cleanup", &self.cleanup.is_some())
            .finish()
    }
}
104
/// Callback type for updating application state.
///
/// The setter receives a zero-argument closure that yields a boxed
/// [`AppState`]; what the setter does with it is up to the implementation.
pub type SetAppState =
    Box<dyn Fn(Box<dyn Fn() -> Box<dyn AppState>>) + Send + Sync>;

/// Application state marker trait.
///
/// It declares no methods of its own; implementors only need to be
/// `Send + Sync`.
pub trait AppState: Send + Sync {}
112
113/// Context passed to tasks containing abort controller and state access
114pub struct TaskContext {
115    pub abort_controller: AbortController,
116    pub get_app_state: Box<dyn Fn() -> Box<dyn AppState> + Send + Sync>,
117    pub set_app_state: SetAppState,
118}
119
/// Abort controller for cancelling operations.
///
/// A controller drives an [`AbortSignal`]; calling [`AbortController::abort`]
/// sets the signal's flag. The flag is shared, so every clone of the controller
/// or of the signal observes the abort.
#[derive(Debug, Clone)]
pub struct AbortController {
    signal: Option<AbortSignal>,
}

impl AbortController {
    /// Creates a controller with a fresh, not-yet-aborted signal.
    ///
    /// Without a signal `abort()` would silently do nothing, so one is always created.
    pub fn new() -> Self {
        Self {
            signal: Some(AbortSignal::new()),
        }
    }

    /// Creates a controller that drives an existing signal.
    ///
    /// Any clone of `signal` taken beforehand shares its flag and therefore
    /// observes aborts made through this controller.
    pub fn with_signal(signal: AbortSignal) -> Self {
        Self {
            signal: Some(signal),
        }
    }

    /// Returns the controller's signal, if it has one.
    pub fn signal(&self) -> Option<&AbortSignal> {
        self.signal.as_ref()
    }

    /// Marks the signal as aborted. Calling it more than once has no further effect.
    pub fn abort(&self) {
        if let Some(signal) = &self.signal {
            signal
                .aborted
                .store(true, std::sync::atomic::Ordering::SeqCst);
        }
    }

    /// Returns `true` once [`abort`](Self::abort) has been called on this
    /// controller, any of its clones, or any controller sharing its signal.
    pub fn is_aborted(&self) -> bool {
        self.signal.as_ref().is_some_and(|s| s.aborted())
    }
}

impl Default for AbortController {
    fn default() -> Self {
        Self::new()
    }
}

/// Abort signal for cancellation.
///
/// The flag lives behind an `Arc`, so cloning a signal yields a second view of
/// the same flag rather than an independent snapshot.
#[derive(Debug, Clone, Default)]
pub struct AbortSignal {
    aborted: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

impl AbortSignal {
    /// Creates a signal that has not been aborted.
    pub fn new() -> Self {
        Self {
            aborted: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Returns `true` if the shared flag has been set.
    pub fn aborted(&self) -> bool {
        self.aborted.load(std::sync::atomic::Ordering::SeqCst)
    }
}
194
195/// Base fields shared by all task states
196#[derive(Debug, Clone)]
197pub struct TaskStateBase {
198    pub id: String,
199    pub task_type: TaskType,
200    pub status: TaskStatus,
201    pub description: String,
202    pub tool_use_id: Option<String>,
203    pub start_time: u64,
204    pub end_time: Option<u64>,
205    pub total_paused_ms: Option<u64>,
206    pub output_file: String,
207    pub output_offset: u64,
208    pub notified: bool,
209}
210
211/// Input for spawning a local shell task
212#[derive(Debug, Clone)]
213pub struct LocalShellSpawnInput {
214    pub command: String,
215    pub description: String,
216    pub timeout: Option<u64>,
217    pub tool_use_id: Option<String>,
218    pub agent_id: Option<String>,
219    /// UI display variant: description-as-label, dialog title, status bar pill.
220    pub kind: Option<ShellKind>,
221}
222
/// Shell kind, used to select the UI display variant.
///
/// Variant names are deliberately lowercase so they read exactly like the
/// strings returned by [`ShellKind::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// Without this, rustc warns `variant should have an upper camel case name`
// for `bash`/`monitor`, which fails `-D warnings` builds.
#[allow(non_camel_case_types)]
pub enum ShellKind {
    bash,
    monitor,
}

impl ShellKind {
    /// Every shell kind, in declaration order.
    pub const ALL: [ShellKind; 2] = [ShellKind::bash, ShellKind::monitor];

    /// Returns the canonical string form of this kind.
    pub fn as_str(&self) -> &'static str {
        match self {
            ShellKind::bash => "bash",
            ShellKind::monitor => "monitor",
        }
    }

    /// Parses a canonical string produced by [`ShellKind::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    // Kept as an inherent `Option`-returning method for API compatibility.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| kind.as_str() == s)
    }
}
246
247/// Task trait for kill operations
248pub trait Task: Send + Sync {
249    fn name(&self) -> &str;
250    fn task_type(&self) -> TaskType;
251    fn kill(
252        &self,
253        task_id: &str,
254        set_app_state: SetAppState,
255    ) -> impl std::future::Future<Output = ()> + Send;
256}
257
258/// Task ID prefixes for backward compatibility
259pub const TASK_ID_PREFIXES: &[(&str, &str)] = &[
260    ("local_bash", "b"),
261    ("local_agent", "a"),
262    ("remote_agent", "r"),
263    ("in_process_teammate", "t"),
264    ("local_workflow", "w"),
265    ("monitor_mcp", "m"),
266    ("dream", "d"),
267];
268
269/// Get task ID prefix for a task type
270pub fn get_task_id_prefix(task_type: &TaskType) -> &'static str {
271    TASK_ID_PREFIXES
272        .iter()
273        .find(|(t, _)| *t == task_type.as_str())
274        .map(|(_, p)| *p)
275        .unwrap_or("x")
276}
277
278/// Case-insensitive-safe alphabet (digits + lowercase) for task IDs.
279/// 36^8 ≈ 2.8 trillion combinations, sufficient to resist brute-force symlink attacks.
280pub const TASK_ID_ALPHABET: &str = "0123456789abcdefghijklmnopqrstuvwxyz";
281
282/// Generate a unique task ID for a given task type
283pub fn generate_task_id(task_type: &TaskType) -> String {
284    use std::time::{SystemTime, UNIX_EPOCH};
285
286    let prefix = get_task_id_prefix(task_type);
287    let mut rng_seed = SystemTime::now()
288        .duration_since(UNIX_EPOCH)
289        .unwrap()
290        .as_nanos() as u64;
291
292    let mut id = prefix.to_string();
293    for i in 0..8 {
294        // Simple pseudo-random based on seed
295        rng_seed = rng_seed.wrapping_mul(1103515245).wrapping_add(12345);
296        let alphabet_idx = (rng_seed >> (i * 3)) as usize % TASK_ID_ALPHABET.len();
297        id.push(TASK_ID_ALPHABET.chars().nth(alphabet_idx).unwrap());
298    }
299    id
300}
301
/// Returns the path of the file that receives a task's output.
///
/// The file lives in the platform temporary directory (`std::env::temp_dir`,
/// which honours `TMPDIR` on Unix) rather than a hard-coded `/tmp`, which does
/// not exist on Windows.
///
/// NOTE(review): `task_id` is interpolated verbatim; callers must pass an ID
/// without path separators. IDs produced by `generate_task_id` (prefix letter
/// plus `[0-9a-z]`) satisfy this.
pub fn get_task_output_path(task_id: &str) -> String {
    std::env::temp_dir()
        .join(format!("task_output_{task_id}.txt"))
        .to_string_lossy()
        .into_owned()
}
307
308/// Create a base task state
309pub fn create_task_state_base(
310    id: String,
311    task_type: TaskType,
312    description: String,
313    tool_use_id: Option<String>,
314) -> TaskStateBase {
315    let now = std::time::SystemTime::now()
316        .duration_since(std::time::UNIX_EPOCH)
317        .unwrap()
318        .as_millis() as u64;
319
320    TaskStateBase {
321        id,
322        task_type,
323        status: TaskStatus::pending,
324        description,
325        tool_use_id,
326        start_time: now,
327        end_time: None,
328        total_paused_ms: None,
329        output_file: String::new(),
330        output_offset: 0,
331        notified: false,
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_type_strings() {
        assert_eq!(TaskType::local_bash.as_str(), "local_bash");
        assert_eq!(TaskType::remote_agent.as_str(), "remote_agent");
    }

    #[test]
    fn test_task_type_from_str_round_trip() {
        let all = [
            TaskType::local_bash,
            TaskType::local_agent,
            TaskType::remote_agent,
            TaskType::in_process_teammate,
            TaskType::local_workflow,
            TaskType::monitor_mcp,
            TaskType::dream,
        ];
        for t in all.iter() {
            assert_eq!(TaskType::from_str(t.as_str()).as_ref(), Some(t));
        }
        assert_eq!(TaskType::from_str("LOCAL_BASH"), None);
        assert_eq!(TaskType::from_str(""), None);
    }

    #[test]
    fn test_task_status_strings() {
        assert_eq!(TaskStatus::pending.as_str(), "pending");
        assert_eq!(TaskStatus::completed.as_str(), "completed");
    }

    #[test]
    fn test_task_status_from_str_round_trip() {
        let all = [
            TaskStatus::pending,
            TaskStatus::running,
            TaskStatus::completed,
            TaskStatus::failed,
            TaskStatus::killed,
        ];
        for s in all.iter() {
            assert_eq!(TaskStatus::from_str(s.as_str()).as_ref(), Some(s));
        }
        assert_eq!(TaskStatus::from_str("done"), None);
    }

    #[test]
    fn test_is_terminal_task_status() {
        assert!(!is_terminal_task_status(&TaskStatus::pending));
        assert!(!is_terminal_task_status(&TaskStatus::running));
        assert!(is_terminal_task_status(&TaskStatus::completed));
        assert!(is_terminal_task_status(&TaskStatus::failed));
        assert!(is_terminal_task_status(&TaskStatus::killed));
    }

    #[test]
    fn test_shell_kind_strings() {
        assert_eq!(ShellKind::bash.as_str(), "bash");
        assert_eq!(ShellKind::monitor.as_str(), "monitor");
        assert_eq!(ShellKind::from_str("bash"), Some(ShellKind::bash));
        assert_eq!(ShellKind::from_str("monitor"), Some(ShellKind::monitor));
        assert_eq!(ShellKind::from_str("zsh"), None);
    }

    #[test]
    fn test_generate_task_id() {
        let id = generate_task_id(&TaskType::local_bash);
        assert!(id.starts_with('b'));
        assert_eq!(id.len(), 9); // 1 prefix + 8 chars
    }

    #[test]
    fn test_generate_task_id_uses_alphabet() {
        let id = generate_task_id(&TaskType::dream);
        assert!(id.starts_with('d'));
        assert!(id[1..].chars().all(|c| TASK_ID_ALPHABET.contains(c)));
    }

    #[test]
    fn test_task_id_prefix() {
        assert_eq!(get_task_id_prefix(&TaskType::local_bash), "b");
        assert_eq!(get_task_id_prefix(&TaskType::local_agent), "a");
        assert_eq!(get_task_id_prefix(&TaskType::remote_agent), "r");
        assert_eq!(get_task_id_prefix(&TaskType::in_process_teammate), "t");
        assert_eq!(get_task_id_prefix(&TaskType::local_workflow), "w");
        assert_eq!(get_task_id_prefix(&TaskType::monitor_mcp), "m");
        assert_eq!(get_task_id_prefix(&TaskType::dream), "d");
    }

    #[test]
    fn test_task_output_path_contains_id() {
        assert!(get_task_output_path("b12345678").contains("task_output_b12345678"));
    }

    #[test]
    fn test_create_task_state_base_defaults() {
        let state = create_task_state_base(
            "b0000abcd".to_string(),
            TaskType::local_bash,
            "desc".to_string(),
            None,
        );
        assert_eq!(state.id, "b0000abcd");
        assert_eq!(state.task_type, TaskType::local_bash);
        assert_eq!(state.status, TaskStatus::pending);
        assert_eq!(state.description, "desc");
        assert!(state.tool_use_id.is_none());
        assert!(state.start_time > 0);
        assert!(state.end_time.is_none());
        assert!(state.total_paused_ms.is_none());
        assert_eq!(state.output_offset, 0);
        assert!(!state.notified);
    }
}