use std::collections::HashMap;
/// The kind of work a task performs.
///
/// Variant names are deliberately lowercase: each one is spelled exactly like
/// the string returned by [`TaskType::as_str`] and accepted by
/// [`TaskType::from_str`].
#[derive(Debug, Clone, PartialEq)]
#[allow(non_camel_case_types)]
pub enum TaskType {
    local_bash,
    local_agent,
    remote_agent,
    in_process_teammate,
    local_workflow,
    monitor_mcp,
    dream,
}

impl TaskType {
    /// Every variant in declaration order. `from_str` searches this list, so
    /// the variant <-> string table is written only once (in `as_str`).
    const ALL: [TaskType; 7] = [
        TaskType::local_bash,
        TaskType::local_agent,
        TaskType::remote_agent,
        TaskType::in_process_teammate,
        TaskType::local_workflow,
        TaskType::monitor_mcp,
        TaskType::dream,
    ];

    /// Returns the canonical string form of this task type.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::local_bash => "local_bash",
            Self::local_agent => "local_agent",
            Self::remote_agent => "remote_agent",
            Self::in_process_teammate => "in_process_teammate",
            Self::local_workflow => "local_workflow",
            Self::monitor_mcp => "monitor_mcp",
            Self::dream => "dream",
        }
    }

    /// Parses the canonical string form produced by [`TaskType::as_str`].
    ///
    /// Matching is exact (case-sensitive); any other input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .find(|variant| variant.as_str() == s)
            .cloned()
    }
}
/// Lifecycle status of a task.
///
/// serde's `snake_case` renaming leaves these already-lowercase variant names
/// unchanged, so the serialized form is the same string that
/// [`TaskStatus::as_str`] returns.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
#[allow(non_camel_case_types)]
pub enum TaskStatus {
    pending,
    running,
    completed,
    failed,
    killed,
}

impl TaskStatus {
    /// Every variant in declaration order; `from_str` searches this list so the
    /// string table lives only in `as_str`.
    const ALL: [TaskStatus; 5] = [
        TaskStatus::pending,
        TaskStatus::running,
        TaskStatus::completed,
        TaskStatus::failed,
        TaskStatus::killed,
    ];

    /// Returns the canonical string form of this status.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::pending => "pending",
            Self::running => "running",
            Self::completed => "completed",
            Self::failed => "failed",
            Self::killed => "killed",
        }
    }

    /// Parses the canonical string form produced by [`TaskStatus::as_str`].
    ///
    /// Matching is exact (case-sensitive); any other input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .find(|status| status.as_str() == s)
            .cloned()
    }
}

/// Returns `true` for the terminal statuses `completed`, `failed` and `killed`,
/// and `false` for `pending` and `running`.
///
/// The match is exhaustive on purpose: adding a variant forces a decision here.
pub fn is_terminal_task_status(status: &TaskStatus) -> bool {
    match status {
        TaskStatus::completed | TaskStatus::failed | TaskStatus::killed => true,
        TaskStatus::pending | TaskStatus::running => false,
    }
}
/// Handle to a task: its id plus an optional cleanup callback.
pub struct TaskHandle {
    /// Identifier of the task this handle refers to.
    pub task_id: String,
    /// Optional callback, presumably invoked when the task is torn down — confirm with callers.
    pub cleanup: Option<Box<dyn Fn() + Send>>,
}

/// Cloning copies only `task_id`.
///
/// A boxed `dyn Fn` cannot be duplicated, so every clone has `cleanup: None`;
/// only the original handle keeps the callback.
impl Clone for TaskHandle {
    fn clone(&self) -> Self {
        let task_id = self.task_id.clone();
        TaskHandle {
            cleanup: None,
            task_id,
        }
    }
}
/// Callback handed to tasks for updating the application state (presumably — confirm with callers).
///
/// It receives a boxed factory closure that produces a `Box<dyn AppState>`.
// NOTE(review): the factory takes no arguments, so an updater cannot read the
// previous state the way a `prev -> next` reducer could. Confirm this is intended.
pub type SetAppState = Box<dyn Fn(Box<dyn Fn() -> Box<dyn AppState>>) + Send + Sync>;
/// Opaque application state that can be shared across threads (`Send + Sync`).
///
/// No methods are declared here; any behaviour must come from elsewhere — TODO confirm.
pub trait AppState: Send + Sync {
}
/// Handles passed to a running task.
pub struct TaskContext {
/// Cancellation handle for the task.
pub abort_controller: AbortController,
/// Returns a boxed `AppState`; presumably the current state — confirm.
pub get_app_state: Box<dyn Fn() -> Box<dyn AppState> + Send + Sync>,
/// Callback for updating the application state; see [`SetAppState`].
pub set_app_state: SetAppState,
}
/// Cancellation handle modelled on the web `AbortController`.
///
/// Every controller created through [`AbortController::new`], [`Default`] or
/// [`AbortController::with_signal`] owns an [`AbortSignal`]. Calling
/// [`AbortController::abort`] sets that signal's flag.
///
/// Cloning a controller, or the signal passed to `with_signal`, yields a handle
/// to the *same* flag. An abort through any handle is therefore observed by all
/// of them.
#[derive(Debug, Clone)]
pub struct AbortController {
    // Always `Some` for controllers built by this module. It stays an `Option`
    // so the public `signal()` signature is unchanged.
    signal: Option<AbortSignal>,
}

impl AbortController {
    /// Creates a controller with a fresh, not-yet-aborted signal.
    ///
    /// Previously this created a controller with no signal, which made
    /// `abort()` a silent no-op and `is_aborted()` always `false`.
    pub fn new() -> Self {
        Self::with_signal(AbortSignal::new())
    }

    /// Creates a controller that aborts `signal`.
    ///
    /// Keep a clone of `signal` to observe the abort. Clones share the flag.
    pub fn with_signal(signal: AbortSignal) -> Self {
        Self {
            signal: Some(signal),
        }
    }

    /// Returns the signal this controller aborts.
    pub fn signal(&self) -> Option<&AbortSignal> {
        self.signal.as_ref()
    }

    /// Marks the signal as aborted. This is visible to every clone of the
    /// controller or its signal.
    pub fn abort(&self) {
        if let Some(signal) = &self.signal {
            signal
                .aborted
                .store(true, std::sync::atomic::Ordering::SeqCst);
        }
    }

    /// Returns `true` once `abort` has been called on this controller or any
    /// handle sharing its signal.
    pub fn is_aborted(&self) -> bool {
        self.signal.as_ref().is_some_and(AbortSignal::aborted)
    }
}

impl Default for AbortController {
    fn default() -> Self {
        Self::new()
    }
}

/// Observable abort flag.
///
/// Cloning an `AbortSignal` yields another handle to the same underlying flag,
/// not a snapshot of its current value.
#[derive(Debug, Clone)]
pub struct AbortSignal {
    // Shared so that clones observe aborts made after they were created.
    aborted: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

impl AbortSignal {
    /// Creates a signal that has not been aborted.
    pub fn new() -> Self {
        Self {
            aborted: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Returns `true` if any controller holding this signal has aborted.
    pub fn aborted(&self) -> bool {
        self.aborted.load(std::sync::atomic::Ordering::SeqCst)
    }
}

impl Default for AbortSignal {
    fn default() -> Self {
        Self::new()
    }
}
/// State fields common to every task kind.
#[derive(Debug, Clone)]
pub struct TaskStateBase {
/// Task identifier.
pub id: String,
/// Kind of task.
pub task_type: TaskType,
/// Current lifecycle status; `create_task_state_base` starts it at `pending`.
pub status: TaskStatus,
/// Description of the task.
pub description: String,
/// Id of the originating tool invocation, if any (presumably — confirm with callers).
pub tool_use_id: Option<String>,
/// Start time in milliseconds since the Unix epoch, as set by `create_task_state_base`.
pub start_time: u64,
/// End time, presumably set when the task reaches a terminal status; unit presumably
/// matches `start_time` — confirm.
pub end_time: Option<u64>,
/// Accumulated paused time in milliseconds, if tracked.
pub total_paused_ms: Option<u64>,
/// Path of the file holding the task's output (presumably — confirm).
pub output_file: String,
/// Offset into `output_file`; presumably how much output has already been consumed — confirm.
pub output_offset: u64,
/// Presumably whether the task's completion has been reported — confirm.
pub notified: bool,
}
/// Parameters for spawning a local shell task (the spawning code is not in this file).
#[derive(Debug, Clone)]
pub struct LocalShellSpawnInput {
/// Command to run.
pub command: String,
/// Description of the task.
pub description: String,
/// Optional timeout. The unit is not established here (presumably milliseconds — confirm).
pub timeout: Option<u64>,
/// Id of the requesting tool invocation, if any (presumably — confirm).
pub tool_use_id: Option<String>,
/// Id of the agent the command runs for, if any (presumably — confirm).
pub agent_id: Option<String>,
/// Shell flavour. The meaning of `None` is not defined here (presumably `bash` — confirm).
pub kind: Option<ShellKind>,
}
/// Flavour of a local shell task.
///
/// Variant names are lowercase so they match their string form, following the
/// same convention as `TaskType` and `TaskStatus`. The `allow` silences the
/// `non_camel_case_types` lint, which previously fired on these variants (and
/// breaks `-D warnings` builds).
#[derive(Debug, Clone, PartialEq)]
#[allow(non_camel_case_types)]
pub enum ShellKind {
    /// String form `"bash"`.
    bash,
    /// String form `"monitor"`.
    monitor,
}

impl ShellKind {
    /// Returns the canonical string form of this shell kind.
    pub fn as_str(&self) -> &'static str {
        match self {
            ShellKind::bash => "bash",
            ShellKind::monitor => "monitor",
        }
    }

    /// Parses the canonical string form. Matching is exact; anything else is `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "bash" => Some(ShellKind::bash),
            "monitor" => Some(ShellKind::monitor),
            _ => None,
        }
    }
}
/// A task implementation that can report its identity and kill a task.
///
/// `kill` returns `impl Future` (return-position `impl Trait` in a trait,
/// stable since Rust 1.75). That makes this trait unusable as `dyn Task`, so
/// implementations must be used generically.
pub trait Task: Send + Sync {
/// Name of this task implementation (presumably human-readable — confirm).
fn name(&self) -> &str;
/// The [`TaskType`] this implementation handles.
fn task_type(&self) -> TaskType;
/// Returns a `Send` future that kills the task identified by `task_id`.
/// `set_app_state` is presumably used to record the killed state — confirm with implementors.
fn kill(
&self,
task_id: &str,
set_app_state: SetAppState,
) -> impl std::future::Future<Output = ()> + Send;
}
/// One-letter task-ID prefix for each task type, keyed by [`TaskType::as_str`].
pub const TASK_ID_PREFIXES: &[(&str, &str)] = &[
    ("local_bash", "b"),
    ("local_agent", "a"),
    ("remote_agent", "r"),
    ("in_process_teammate", "t"),
    ("local_workflow", "w"),
    ("monitor_mcp", "m"),
    ("dream", "d"),
];

/// Returns the ID prefix for `task_type`.
///
/// Falls back to `"x"` if the type has no entry in [`TASK_ID_PREFIXES`].
pub fn get_task_id_prefix(task_type: &TaskType) -> &'static str {
    let wanted = task_type.as_str();
    for &(name, prefix) in TASK_ID_PREFIXES {
        if name == wanted {
            return prefix;
        }
    }
    "x"
}
/// Characters used for the random part of a task ID: ASCII digits and lowercase letters.
pub const TASK_ID_ALPHABET: &str = "0123456789abcdefghijklmnopqrstuvwxyz";

/// Generates a new task ID.
///
/// The ID is the type's prefix (see [`get_task_id_prefix`]) followed by eight
/// characters drawn from [`TASK_ID_ALPHABET`].
///
/// The generator seeds from wall-clock nanoseconds mixed with a process-wide
/// call counter. Two calls in the same clock tick (coarse clocks, concurrent
/// threads) therefore still get different seeds; the old time-only seed
/// returned identical IDs in that case.
///
/// IDs are not cryptographically random. Collisions are very unlikely but not
/// impossible, so callers that need hard uniqueness must check.
pub fn generate_task_id(task_type: &TaskType) -> String {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    const SUFFIX_LEN: usize = 8;
    // SplitMix64 increment (2^64 / golden ratio), also used to spread the counter.
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    static CALL_COUNTER: AtomicUsize = AtomicUsize::new(0);

    let prefix = get_task_id_prefix(task_type);
    let alphabet = TASK_ID_ALPHABET.as_bytes();

    // A clock set before the epoch is not worth panicking over; the counter
    // still differentiates calls. Truncating to 64 bits is fine for seeding.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos() as u64)
        .unwrap_or(0);
    let call = CALL_COUNTER.fetch_add(1, Ordering::Relaxed) as u64;
    let mut state = nanos ^ call.wrapping_mul(GOLDEN_GAMMA);

    let mut id = String::with_capacity(prefix.len() + SUFFIX_LEN);
    id.push_str(prefix);
    for _ in 0..SUFFIX_LEN {
        // One SplitMix64 step. Unlike the low bits of an LCG, every output bit
        // is well mixed.
        state = state.wrapping_add(GOLDEN_GAMMA);
        let mut z = state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;
        let idx = (z % alphabet.len() as u64) as usize;
        // The alphabet is ASCII, so each byte is a complete char.
        id.push(char::from(alphabet[idx]));
    }
    id
}
/// Returns the output-file path for `task_id`: `/tmp/task_output_<task_id>.txt`.
///
/// `task_id` is inserted verbatim. It is expected not to contain path
/// separators.
pub fn get_task_output_path(task_id: &str) -> String {
    let mut path = String::from("/tmp/task_output_");
    path.push_str(task_id);
    path.push_str(".txt");
    path
}
pub fn create_task_state_base(
id: String,
task_type: TaskType,
description: String,
tool_use_id: Option<String>,
) -> TaskStateBase {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
TaskStateBase {
id,
task_type,
status: TaskStatus::pending,
description,
tool_use_id,
start_time: now,
end_time: None,
total_paused_ms: None,
output_file: String::new(),
output_offset: 0,
notified: false,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_type_strings() {
        assert_eq!(TaskType::local_bash.as_str(), "local_bash");
        assert_eq!(TaskType::remote_agent.as_str(), "remote_agent");
    }

    #[test]
    fn test_task_type_round_trip() {
        let all = [
            TaskType::local_bash,
            TaskType::local_agent,
            TaskType::remote_agent,
            TaskType::in_process_teammate,
            TaskType::local_workflow,
            TaskType::monitor_mcp,
            TaskType::dream,
        ];
        for task_type in all {
            let parsed = TaskType::from_str(task_type.as_str());
            assert_eq!(parsed, Some(task_type));
        }
        assert_eq!(TaskType::from_str(""), None);
        assert_eq!(TaskType::from_str("LOCAL_BASH"), None);
    }

    #[test]
    fn test_task_status_strings() {
        assert_eq!(TaskStatus::pending.as_str(), "pending");
        assert_eq!(TaskStatus::completed.as_str(), "completed");
    }

    #[test]
    fn test_task_status_round_trip() {
        let all = [
            TaskStatus::pending,
            TaskStatus::running,
            TaskStatus::completed,
            TaskStatus::failed,
            TaskStatus::killed,
        ];
        for status in all {
            let parsed = TaskStatus::from_str(status.as_str());
            assert_eq!(parsed, Some(status));
        }
        assert_eq!(TaskStatus::from_str("done"), None);
    }

    #[test]
    fn test_is_terminal_task_status() {
        assert!(!is_terminal_task_status(&TaskStatus::pending));
        assert!(!is_terminal_task_status(&TaskStatus::running));
        assert!(is_terminal_task_status(&TaskStatus::completed));
        assert!(is_terminal_task_status(&TaskStatus::failed));
        assert!(is_terminal_task_status(&TaskStatus::killed));
    }

    #[test]
    fn test_shell_kind_strings() {
        assert_eq!(ShellKind::bash.as_str(), "bash");
        assert_eq!(ShellKind::monitor.as_str(), "monitor");
        assert_eq!(ShellKind::from_str("bash"), Some(ShellKind::bash));
        assert_eq!(ShellKind::from_str("monitor"), Some(ShellKind::monitor));
        assert_eq!(ShellKind::from_str("sh"), None);
    }

    #[test]
    fn test_generate_task_id() {
        let id = generate_task_id(&TaskType::local_bash);
        assert!(id.starts_with('b'));
        assert_eq!(id.len(), 9);
        assert!(id[1..].chars().all(|c| TASK_ID_ALPHABET.contains(c)));
    }

    #[test]
    fn test_task_id_prefix() {
        assert_eq!(get_task_id_prefix(&TaskType::local_bash), "b");
        assert_eq!(get_task_id_prefix(&TaskType::local_agent), "a");
        assert_eq!(get_task_id_prefix(&TaskType::remote_agent), "r");
        assert_eq!(get_task_id_prefix(&TaskType::in_process_teammate), "t");
        assert_eq!(get_task_id_prefix(&TaskType::local_workflow), "w");
        assert_eq!(get_task_id_prefix(&TaskType::monitor_mcp), "m");
        assert_eq!(get_task_id_prefix(&TaskType::dream), "d");
    }

    #[test]
    fn test_task_output_path() {
        assert_eq!(
            get_task_output_path("b12345678"),
            "/tmp/task_output_b12345678.txt"
        );
    }

    #[test]
    fn test_create_task_state_base_defaults() {
        let base = create_task_state_base(
            "a1".to_string(),
            TaskType::local_agent,
            "desc".to_string(),
            None,
        );
        assert_eq!(base.id, "a1");
        assert_eq!(base.task_type, TaskType::local_agent);
        assert_eq!(base.status, TaskStatus::pending);
        assert_eq!(base.description, "desc");
        assert_eq!(base.tool_use_id, None);
        assert_eq!(base.end_time, None);
        assert_eq!(base.total_paused_ms, None);
        assert_eq!(base.output_offset, 0);
        assert!(!base.notified);
    }

    #[test]
    fn test_task_handle_clone_drops_cleanup() {
        let handle = TaskHandle {
            task_id: "b1".to_string(),
            cleanup: Some(Box::new(|| {})),
        };
        let copy = handle.clone();
        assert_eq!(copy.task_id, "b1");
        assert!(copy.cleanup.is_none());
    }
}