Skip to main content

ai_agent/
task.rs

1//! Task types and utilities translated from TypeScript Task.ts
2
3use std::collections::HashMap;
4
/// Kinds of task that can be created.
///
/// The string form of each variant (see [`TaskType::as_str`]) is identical to
/// the variant name.
#[derive(Debug, Clone, PartialEq)]
// Variant names mirror the snake_case string literals of the TypeScript
// original one-to-one, so the usual CamelCase lint is silenced on purpose.
#[allow(non_camel_case_types)]
pub enum TaskType {
    local_bash,
    local_agent,
    remote_agent,
    in_process_teammate,
    local_workflow,
    monitor_mcp,
    dream,
}

impl TaskType {
    /// Every variant, in declaration order; the lookup table for `from_str`.
    const ALL: [TaskType; 7] = [
        TaskType::local_bash,
        TaskType::local_agent,
        TaskType::remote_agent,
        TaskType::in_process_teammate,
        TaskType::local_workflow,
        TaskType::monitor_mcp,
        TaskType::dream,
    ];

    /// Returns the canonical string name of this task type.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::local_bash => "local_bash",
            Self::local_agent => "local_agent",
            Self::remote_agent => "remote_agent",
            Self::in_process_teammate => "in_process_teammate",
            Self::local_workflow => "local_workflow",
            Self::monitor_mcp => "monitor_mcp",
            Self::dream => "dream",
        }
    }

    /// Parses a canonical name produced by [`TaskType::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .find(|candidate| candidate.as_str() == s)
            .cloned()
    }
}
44
45/// Status of a task
46#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
47#[serde(rename_all = "snake_case")]
48#[allow(non_camel_case_types)]
49pub enum TaskStatus {
50    pending,
51    running,
52    completed,
53    failed,
54    killed,
55}
56
57impl TaskStatus {
58    pub fn as_str(&self) -> &'static str {
59        match self {
60            TaskStatus::pending => "pending",
61            TaskStatus::running => "running",
62            TaskStatus::completed => "completed",
63            TaskStatus::failed => "failed",
64            TaskStatus::killed => "killed",
65        }
66    }
67
68    pub fn from_str(s: &str) -> Option<Self> {
69        match s {
70            "pending" => Some(TaskStatus::pending),
71            "running" => Some(TaskStatus::running),
72            "completed" => Some(TaskStatus::completed),
73            "failed" => Some(TaskStatus::failed),
74            "killed" => Some(TaskStatus::killed),
75            _ => None,
76        }
77    }
78}
79
80/// True when a task is in a terminal state and will not transition further.
81/// Used to guard against injecting messages into dead teammates, evicting
82/// finished tasks from AppState, and orphan-cleanup paths.
83pub fn is_terminal_task_status(status: &TaskStatus) -> bool {
84    matches!(
85        status,
86        TaskStatus::completed | TaskStatus::failed | TaskStatus::killed
87    )
88}
89
/// A task's ID together with an optional callback that cleans up after it.
pub struct TaskHandle {
    /// ID of the task this handle refers to.
    pub task_id: String,
    /// Optional cleanup hook. NOTE(review): when and by whom it is invoked is
    /// not visible in this file.
    pub cleanup: Option<Box<dyn Fn() + Send>>,
}

impl Clone for TaskHandle {
    /// Copies only the task ID.
    ///
    /// A boxed closure cannot be duplicated, so the clone's `cleanup` is always
    /// `None`; the callback stays with the original handle.
    fn clone(&self) -> Self {
        let task_id = self.task_id.clone();
        TaskHandle {
            cleanup: None,
            task_id,
        }
    }
}
105
/// Callback used to update application state.
///
/// The updater it receives takes no arguments and returns the new state, so it
/// cannot observe the previous state (unlike a `prev => next` style updater).
/// NOTE(review): confirm this is intended; changing it would alter every
/// caller's signature.
pub type SetAppState = Box<dyn Fn(Box<dyn Fn() -> Box<dyn AppState>>) + Send + Sync>;

/// Application state shared with tasks.
///
/// It has no methods yet. The `Send + Sync` supertraits let state objects be
/// shared across threads.
pub trait AppState: Send + Sync {
    // Basic trait for state management
}

/// Context handed to a task: its cancellation handle plus app-state access.
pub struct TaskContext {
    /// Cancels the task. Clones of it share the same abort flag.
    pub abort_controller: AbortController,
    /// Returns the current application state (presumably a snapshot; confirm with implementors).
    pub get_app_state: Box<dyn Fn() -> Box<dyn AppState> + Send + Sync>,
    /// Applies a state update; see [`SetAppState`].
    pub set_app_state: SetAppState,
}

/// Controller that aborts an [`AbortSignal`], modelled on the JavaScript
/// `AbortController`.
///
/// Every controller built with [`AbortController::new`], [`Default`] or
/// [`AbortController::with_signal`] owns a signal, so
/// [`abort`](AbortController::abort) always takes effect. Cloning a controller
/// shares its signal: aborting any clone is visible through all of them and
/// through every clone of the signal itself.
#[derive(Debug, Clone)]
pub struct AbortController {
    // Always `Some` for controllers built via this type's constructors; the
    // `Option` is kept because `signal()` exposes it in its return type.
    signal: Option<AbortSignal>,
}

impl AbortController {
    /// Creates a controller with a fresh, not-yet-aborted signal.
    ///
    /// (Previously this created a controller with no signal at all, which made
    /// `abort()` a silent no-op and `is_aborted()` permanently `false`.)
    pub fn new() -> Self {
        Self::with_signal(AbortSignal::new())
    }

    /// Creates a controller that aborts `signal`, and therefore every clone of it.
    pub fn with_signal(signal: AbortSignal) -> Self {
        Self {
            signal: Some(signal),
        }
    }

    /// Returns the signal this controller aborts.
    pub fn signal(&self) -> Option<&AbortSignal> {
        self.signal.as_ref()
    }

    /// Marks the signal as aborted. Calling it again has no further effect.
    pub fn abort(&self) {
        if let Some(signal) = &self.signal {
            signal.set_aborted();
        }
    }

    /// Returns `true` once [`abort`](AbortController::abort) has been called on
    /// this controller or on any controller sharing its signal.
    pub fn is_aborted(&self) -> bool {
        self.signal.as_ref().is_some_and(AbortSignal::aborted)
    }
}

impl Default for AbortController {
    fn default() -> Self {
        Self::new()
    }
}

/// Observable cancellation flag, modelled on the JavaScript `AbortSignal`.
///
/// The flag lives behind an [`Arc`](std::sync::Arc), so a clone is another
/// handle to the same flag rather than a snapshot of its current value. Once
/// the flag is set, every clone reports `aborted() == true`.
#[derive(Debug, Clone)]
pub struct AbortSignal {
    aborted: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

impl AbortSignal {
    /// Creates a new, not-yet-aborted signal.
    pub fn new() -> Self {
        Self {
            aborted: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Returns whether the shared flag has been set.
    pub fn aborted(&self) -> bool {
        self.aborted.load(std::sync::atomic::Ordering::SeqCst)
    }

    /// Sets the shared flag. Not public: signals are aborted through an
    /// [`AbortController`].
    fn set_aborted(&self) {
        self.aborted.store(true, std::sync::atomic::Ordering::SeqCst);
    }
}

impl Default for AbortSignal {
    fn default() -> Self {
        Self::new()
    }
}
195
/// Fields shared by every task's state record.
#[derive(Debug, Clone)]
pub struct TaskStateBase {
    /// Unique task ID (e.g. as produced by `generate_task_id`).
    pub id: String,
    /// Which kind of task this is.
    pub task_type: TaskType,
    /// Current lifecycle status; `is_terminal_task_status` says which values are final.
    pub status: TaskStatus,
    /// Human-readable description of the task.
    pub description: String,
    /// ID of the tool-use request that created the task, if any
    /// (NOTE(review): inferred from the name; confirm against callers).
    pub tool_use_id: Option<String>,
    /// Creation time; `create_task_state_base` fills it with milliseconds
    /// since the Unix epoch.
    pub start_time: u64,
    /// Completion time, `None` until set; presumably the same unit as
    /// `start_time` — TODO confirm.
    pub end_time: Option<u64>,
    /// Accumulated paused time, in milliseconds judging by the name; `None`
    /// when not tracked.
    pub total_paused_ms: Option<u64>,
    /// Path of the file holding the task's output (presumably built with
    /// `get_task_output_path`; confirm against callers).
    pub output_file: String,
    /// Position within `output_file`; presumably how much output has already
    /// been consumed — TODO confirm.
    pub output_offset: u64,
    /// Whether the task's outcome has been reported
    /// (NOTE(review): inferred from the name; confirm against callers).
    pub notified: bool,
}
211
/// Input for spawning a local shell task.
#[derive(Debug, Clone)]
pub struct LocalShellSpawnInput {
    /// Shell command line to run.
    pub command: String,
    /// Human-readable description of the command.
    pub description: String,
    /// Optional timeout. Units are not established here (presumably
    /// milliseconds, as in the TypeScript original) — TODO confirm.
    pub timeout: Option<u64>,
    /// ID of the tool-use request that spawned this task, if any.
    pub tool_use_id: Option<String>,
    /// ID of the agent the task belongs to, if any
    /// (NOTE(review): inferred from the name; confirm against callers).
    pub agent_id: Option<String>,
    /// UI display variant: description-as-label, dialog title, status bar pill.
    pub kind: Option<ShellKind>,
}

/// Shell kind for UI display.
#[derive(Debug, Clone, PartialEq)]
// Lowercase variants mirror the TypeScript string literals, as with `TaskType`
// and `TaskStatus`. Those enums already carry this allow; without it every
// variant here raised a `non_camel_case_types` warning, which breaks
// `-D warnings` builds.
#[allow(non_camel_case_types)]
pub enum ShellKind {
    bash,
    monitor,
}

impl ShellKind {
    /// Returns the canonical string name of this shell kind.
    pub fn as_str(&self) -> &'static str {
        match self {
            ShellKind::bash => "bash",
            ShellKind::monitor => "monitor",
        }
    }

    /// Parses a canonical name produced by [`ShellKind::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "bash" => Some(ShellKind::bash),
            "monitor" => Some(ShellKind::monitor),
            _ => None,
        }
    }
}
247
/// Behaviour every task implementation provides; currently identification and killing.
///
/// `kill` returns `impl Future + Send` (return-position `impl Trait` in a
/// trait), so implementors may return an `async` block while callers are
/// guaranteed a `Send` future. As a consequence, `Task` is not dyn-compatible
/// and cannot be used as `dyn Task`.
pub trait Task: Send + Sync {
    /// Name of this task implementation (NOTE(review): display vs. internal use not visible here).
    fn name(&self) -> &str;
    /// The [`TaskType`] this implementation handles (presumably; confirm with implementors).
    fn task_type(&self) -> TaskType;
    /// Kills the task identified by `task_id`; `set_app_state` is handed over
    /// so the implementation can record the resulting state change
    /// (presumably — confirm with implementors).
    fn kill(
        &self,
        task_id: &str,
        set_app_state: SetAppState,
    ) -> impl std::future::Future<Output = ()> + Send;
}
258
259/// Task ID prefixes for backward compatibility
260pub const TASK_ID_PREFIXES: &[(&str, &str)] = &[
261    ("local_bash", "b"),
262    ("local_agent", "a"),
263    ("remote_agent", "r"),
264    ("in_process_teammate", "t"),
265    ("local_workflow", "w"),
266    ("monitor_mcp", "m"),
267    ("dream", "d"),
268];
269
270/// Get task ID prefix for a task type
271pub fn get_task_id_prefix(task_type: &TaskType) -> &'static str {
272    TASK_ID_PREFIXES
273        .iter()
274        .find(|(t, _)| *t == task_type.as_str())
275        .map(|(_, p)| *p)
276        .unwrap_or("x")
277}
278
279/// Case-insensitive-safe alphabet (digits + lowercase) for task IDs.
280/// 36^8 ≈ 2.8 trillion combinations, sufficient to resist brute-force symlink attacks.
281pub const TASK_ID_ALPHABET: &str = "0123456789abcdefghijklmnopqrstuvwxyz";
282
283/// Generate a unique task ID for a given task type
284pub fn generate_task_id(task_type: &TaskType) -> String {
285    use std::time::{SystemTime, UNIX_EPOCH};
286
287    let prefix = get_task_id_prefix(task_type);
288    let mut rng_seed = SystemTime::now()
289        .duration_since(UNIX_EPOCH)
290        .unwrap()
291        .as_nanos() as u64;
292
293    let mut id = prefix.to_string();
294    for i in 0..8 {
295        // Simple pseudo-random based on seed
296        rng_seed = rng_seed.wrapping_mul(1103515245).wrapping_add(12345);
297        let alphabet_idx = (rng_seed >> (i * 3)) as usize % TASK_ID_ALPHABET.len();
298        id.push(TASK_ID_ALPHABET.chars().nth(alphabet_idx).unwrap());
299    }
300    id
301}
302
/// Returns the path of the file a task's output is written to:
/// `<system temp dir>/task_output_<task_id>.txt`.
///
/// Uses [`std::env::temp_dir`] instead of a hard-coded `/tmp`. The path is
/// then valid on every platform and honours `TMPDIR` on Unix.
///
/// `task_id` is interpolated verbatim into the file name. It is expected to
/// come from `generate_task_id` (alphanumeric only); an ID containing path
/// separators or `..` could point outside the temp directory.
pub fn get_task_output_path(task_id: &str) -> String {
    std::env::temp_dir()
        .join(format!("task_output_{task_id}.txt"))
        .to_string_lossy()
        .into_owned()
}
308
309/// Create a base task state
310pub fn create_task_state_base(
311    id: String,
312    task_type: TaskType,
313    description: String,
314    tool_use_id: Option<String>,
315) -> TaskStateBase {
316    let now = std::time::SystemTime::now()
317        .duration_since(std::time::UNIX_EPOCH)
318        .unwrap()
319        .as_millis() as u64;
320
321    TaskStateBase {
322        id,
323        task_type,
324        status: TaskStatus::pending,
325        description,
326        tool_use_id,
327        start_time: now,
328        end_time: None,
329        total_paused_ms: None,
330        output_file: String::new(),
331        output_offset: 0,
332        notified: false,
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Every task type with its canonical name and expected ID prefix.
    fn all_task_types() -> Vec<(TaskType, &'static str, &'static str)> {
        vec![
            (TaskType::local_bash, "local_bash", "b"),
            (TaskType::local_agent, "local_agent", "a"),
            (TaskType::remote_agent, "remote_agent", "r"),
            (TaskType::in_process_teammate, "in_process_teammate", "t"),
            (TaskType::local_workflow, "local_workflow", "w"),
            (TaskType::monitor_mcp, "monitor_mcp", "m"),
            (TaskType::dream, "dream", "d"),
        ]
    }

    #[test]
    fn test_task_type_strings() {
        assert_eq!(TaskType::local_bash.as_str(), "local_bash");
        assert_eq!(TaskType::remote_agent.as_str(), "remote_agent");
    }

    #[test]
    fn test_task_type_round_trip() {
        for (task_type, name, _) in all_task_types() {
            assert_eq!(task_type.as_str(), name);
            assert_eq!(TaskType::from_str(name), Some(task_type));
        }
        assert_eq!(TaskType::from_str("LOCAL_BASH"), None);
        assert_eq!(TaskType::from_str(""), None);
    }

    #[test]
    fn test_task_status_strings() {
        assert_eq!(TaskStatus::pending.as_str(), "pending");
        assert_eq!(TaskStatus::completed.as_str(), "completed");
    }

    #[test]
    fn test_task_status_round_trip() {
        for name in ["pending", "running", "completed", "failed", "killed"] {
            let status = TaskStatus::from_str(name).expect("known status");
            assert_eq!(status.as_str(), name);
        }
        assert_eq!(TaskStatus::from_str("cancelled"), None);
    }

    #[test]
    fn test_is_terminal_task_status() {
        assert!(!is_terminal_task_status(&TaskStatus::pending));
        assert!(!is_terminal_task_status(&TaskStatus::running));
        assert!(is_terminal_task_status(&TaskStatus::completed));
        assert!(is_terminal_task_status(&TaskStatus::failed));
        assert!(is_terminal_task_status(&TaskStatus::killed));
    }

    #[test]
    fn test_shell_kind_strings() {
        assert_eq!(ShellKind::bash.as_str(), "bash");
        assert_eq!(ShellKind::monitor.as_str(), "monitor");
        assert_eq!(ShellKind::from_str("bash"), Some(ShellKind::bash));
        assert_eq!(ShellKind::from_str("monitor"), Some(ShellKind::monitor));
        assert_eq!(ShellKind::from_str("zsh"), None);
    }

    #[test]
    fn test_generate_task_id() {
        let id = generate_task_id(&TaskType::local_bash);
        assert!(id.starts_with('b'));
        assert_eq!(id.len(), 9); // 1 prefix + 8 chars
        assert!(id[1..].chars().all(|c| TASK_ID_ALPHABET.contains(c)));
    }

    #[test]
    fn test_generate_task_id_uses_type_prefix() {
        for (task_type, _, prefix) in all_task_types() {
            assert!(generate_task_id(&task_type).starts_with(prefix));
        }
    }

    #[test]
    fn test_task_id_prefix() {
        for (task_type, _, prefix) in all_task_types() {
            assert_eq!(get_task_id_prefix(&task_type), prefix);
        }
    }

    #[test]
    fn test_create_task_state_base_defaults() {
        let state = create_task_state_base(
            "b12345678".to_string(),
            TaskType::local_bash,
            "run ls".to_string(),
            Some("tool-1".to_string()),
        );
        assert_eq!(state.id, "b12345678");
        assert_eq!(state.task_type, TaskType::local_bash);
        assert_eq!(state.status, TaskStatus::pending);
        assert_eq!(state.description, "run ls");
        assert_eq!(state.tool_use_id.as_deref(), Some("tool-1"));
        assert_eq!(state.end_time, None);
        assert_eq!(state.total_paused_ms, None);
        assert_eq!(state.output_offset, 0);
        assert!(!state.notified);
    }

    #[test]
    fn test_task_handle_clone_drops_cleanup() {
        let original = TaskHandle {
            task_id: "b1".to_string(),
            cleanup: Some(Box::new(|| {})),
        };
        let copy = original.clone();
        assert_eq!(copy.task_id, "b1");
        assert!(copy.cleanup.is_none());
        assert!(original.cleanup.is_some());
    }

    #[test]
    fn test_abort_controller_with_signal() {
        let controller = AbortController::with_signal(AbortSignal::new());
        assert!(!controller.is_aborted());
        controller.abort();
        assert!(controller.is_aborted());
        assert!(controller.signal().map_or(false, AbortSignal::aborted));
    }
}