Skip to main content

lean_ctx/core/a2a/
task.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7pub enum TaskState {
8    Created,
9    Working,
10    InputRequired,
11    Completed,
12    Failed,
13    Canceled,
14}
15
16impl std::fmt::Display for TaskState {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self {
19            TaskState::Created => write!(f, "created"),
20            TaskState::Working => write!(f, "working"),
21            TaskState::InputRequired => write!(f, "input-required"),
22            TaskState::Completed => write!(f, "completed"),
23            TaskState::Failed => write!(f, "failed"),
24            TaskState::Canceled => write!(f, "canceled"),
25        }
26    }
27}
28
29impl TaskState {
30    pub fn parse_str(s: &str) -> Option<Self> {
31        match s {
32            "created" => Some(Self::Created),
33            "working" => Some(Self::Working),
34            "input-required" | "input_required" => Some(Self::InputRequired),
35            "completed" => Some(Self::Completed),
36            "failed" => Some(Self::Failed),
37            "canceled" | "cancelled" => Some(Self::Canceled),
38            _ => None,
39        }
40    }
41
42    pub fn is_terminal(&self) -> bool {
43        matches!(self, Self::Completed | Self::Failed | Self::Canceled)
44    }
45
46    pub fn can_transition_to(&self, next: &TaskState) -> bool {
47        match self {
48            TaskState::Created => matches!(
49                next,
50                TaskState::Working | TaskState::Canceled | TaskState::Failed
51            ),
52            TaskState::Working => matches!(
53                next,
54                TaskState::InputRequired
55                    | TaskState::Completed
56                    | TaskState::Failed
57                    | TaskState::Canceled
58            ),
59            TaskState::InputRequired => matches!(
60                next,
61                TaskState::Working | TaskState::Canceled | TaskState::Failed
62            ),
63            TaskState::Completed | TaskState::Failed | TaskState::Canceled => false,
64        }
65    }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct TaskMessage {
70    pub role: String,
71    pub parts: Vec<TaskPart>,
72    pub timestamp: DateTime<Utc>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76#[serde(tag = "type")]
77pub enum TaskPart {
78    #[serde(rename = "text")]
79    Text { text: String },
80    #[serde(rename = "data")]
81    Data { mime_type: String, data: String },
82    #[serde(rename = "file")]
83    File {
84        name: String,
85        mime_type: Option<String>,
86        data: Option<String>,
87        uri: Option<String>,
88    },
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct TaskTransition {
93    pub from: TaskState,
94    pub to: TaskState,
95    pub timestamp: DateTime<Utc>,
96    pub reason: Option<String>,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct Task {
101    pub id: String,
102    pub from_agent: String,
103    pub to_agent: String,
104    pub state: TaskState,
105    pub description: String,
106    pub messages: Vec<TaskMessage>,
107    pub artifacts: Vec<TaskPart>,
108    pub history: Vec<TaskTransition>,
109    pub metadata: HashMap<String, String>,
110    pub created_at: DateTime<Utc>,
111    pub updated_at: DateTime<Utc>,
112}
113
114impl Task {
115    pub fn new(from_agent: &str, to_agent: &str, description: &str) -> Self {
116        let now = Utc::now();
117        let id = format!("task-{}", generate_task_id());
118
119        Self {
120            id,
121            from_agent: from_agent.to_string(),
122            to_agent: to_agent.to_string(),
123            state: TaskState::Created,
124            description: description.to_string(),
125            messages: vec![TaskMessage {
126                role: from_agent.to_string(),
127                parts: vec![TaskPart::Text {
128                    text: description.to_string(),
129                }],
130                timestamp: now,
131            }],
132            artifacts: Vec::new(),
133            history: vec![TaskTransition {
134                from: TaskState::Created,
135                to: TaskState::Created,
136                timestamp: now,
137                reason: Some("task created".to_string()),
138            }],
139            metadata: HashMap::new(),
140            created_at: now,
141            updated_at: now,
142        }
143    }
144
145    pub fn transition(&mut self, new_state: TaskState, reason: Option<&str>) -> Result<(), String> {
146        if !self.state.can_transition_to(&new_state) {
147            return Err(format!(
148                "invalid transition: {} → {}",
149                self.state, new_state
150            ));
151        }
152
153        self.history.push(TaskTransition {
154            from: self.state.clone(),
155            to: new_state.clone(),
156            timestamp: Utc::now(),
157            reason: reason.map(|s| s.to_string()),
158        });
159
160        self.state = new_state;
161        self.updated_at = Utc::now();
162        Ok(())
163    }
164
165    pub fn add_message(&mut self, role: &str, parts: Vec<TaskPart>) {
166        self.messages.push(TaskMessage {
167            role: role.to_string(),
168            parts,
169            timestamp: Utc::now(),
170        });
171        self.updated_at = Utc::now();
172    }
173
174    pub fn add_artifact(&mut self, artifact: TaskPart) {
175        self.artifacts.push(artifact);
176        self.updated_at = Utc::now();
177    }
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize, Default)]
181pub struct TaskStore {
182    pub tasks: Vec<Task>,
183    pub updated_at: Option<DateTime<Utc>>,
184}
185
186impl TaskStore {
187    pub fn load() -> Self {
188        let path = match task_store_path() {
189            Some(p) => p,
190            None => return Self::default(),
191        };
192        match std::fs::read_to_string(&path) {
193            Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
194            Err(_) => Self::default(),
195        }
196    }
197
198    pub fn save(&self) -> std::io::Result<()> {
199        let path = match task_store_path() {
200            Some(p) => p,
201            None => {
202                return Err(std::io::Error::new(
203                    std::io::ErrorKind::NotFound,
204                    "no home dir",
205                ))
206            }
207        };
208
209        if let Some(parent) = path.parent() {
210            std::fs::create_dir_all(parent)?;
211        }
212
213        let json = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
214        let tmp = path.with_extension("tmp");
215        std::fs::write(&tmp, &json)?;
216        std::fs::rename(&tmp, &path)?;
217        Ok(())
218    }
219
220    pub fn create_task(&mut self, from: &str, to: &str, description: &str) -> String {
221        let task = Task::new(from, to, description);
222        let id = task.id.clone();
223        self.tasks.push(task);
224        self.updated_at = Some(Utc::now());
225        id
226    }
227
228    pub fn get_task(&self, task_id: &str) -> Option<&Task> {
229        self.tasks.iter().find(|t| t.id == task_id)
230    }
231
232    pub fn get_task_mut(&mut self, task_id: &str) -> Option<&mut Task> {
233        self.tasks.iter_mut().find(|t| t.id == task_id)
234    }
235
236    pub fn tasks_for_agent(&self, agent_id: &str) -> Vec<&Task> {
237        self.tasks
238            .iter()
239            .filter(|t| t.to_agent == agent_id || t.from_agent == agent_id)
240            .collect()
241    }
242
243    pub fn pending_tasks_for(&self, agent_id: &str) -> Vec<&Task> {
244        self.tasks
245            .iter()
246            .filter(|t| t.to_agent == agent_id && !t.state.is_terminal())
247            .collect()
248    }
249
250    pub fn cleanup_old(&mut self, max_age_hours: u64) {
251        let cutoff = Utc::now() - chrono::Duration::hours(max_age_hours as i64);
252        self.tasks
253            .retain(|t| !t.state.is_terminal() || t.updated_at > cutoff);
254    }
255}
256
257fn task_store_path() -> Option<PathBuf> {
258    dirs::home_dir().map(|h| h.join(".lean-ctx/agents/tasks.json"))
259}
260
/// Generates the variable part of a task id: `<unix millis, hex>-<8 hex digits>`.
///
/// The 8-digit suffix comes from a per-process sequence number passed through a
/// bijection and XOR-ed with a random per-process salt. Within one process,
/// suffixes therefore never repeat until the 32-bit counter wraps, even for ids
/// created in the same millisecond. Separate processes use different salts, so
/// cross-process collisions within one millisecond are unlikely but not impossible.
fn generate_task_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::OnceLock;
    use std::time::{SystemTime, UNIX_EPOCH};

    // Random per-process salt so different processes produce different suffixes.
    static SALT: OnceLock<u32> = OnceLock::new();
    // Per-process sequence so ids stay distinct within a single millisecond.
    static SEQ: AtomicU32 = AtomicU32::new(0);

    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();

    let salt = *SALT.get_or_init(|| {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u32(std::process::id());
        hasher.write_u128(ts);
        // Truncation is intentional: only 32 bits of salt are needed.
        hasher.finish() as u32
    });
    let seq = SEQ.fetch_add(1, Ordering::Relaxed);
    // Multiplying by an odd constant and XOR-ing with a constant are both
    // bijections on u32, so distinct sequence numbers yield distinct suffixes.
    let rand = seq.wrapping_mul(2_654_435_761) ^ salt;

    format!("{ts:x}-{rand:08x}")
}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Every state, in declaration order.
    fn all_states() -> [TaskState; 6] {
        [
            TaskState::Created,
            TaskState::Working,
            TaskState::InputRequired,
            TaskState::Completed,
            TaskState::Failed,
            TaskState::Canceled,
        ]
    }

    #[test]
    fn task_lifecycle_happy_path() {
        let mut task = Task::new("agent-a", "agent-b", "fix the bug");
        assert_eq!(task.state, TaskState::Created);

        task.transition(TaskState::Working, Some("started")).unwrap();
        assert_eq!(task.state, TaskState::Working);

        task.transition(TaskState::Completed, Some("done")).unwrap();
        assert_eq!(task.state, TaskState::Completed);
        // Creation entry plus two transitions.
        assert_eq!(task.history.len(), 3);
    }

    #[test]
    fn task_lifecycle_with_input_required() {
        let mut task = Task::new("a", "b", "deploy");
        task.transition(TaskState::Working, None).unwrap();
        task.transition(TaskState::InputRequired, Some("need credentials"))
            .unwrap();
        task.transition(TaskState::Working, Some("got them")).unwrap();
        task.transition(TaskState::Completed, None).unwrap();
        assert_eq!(task.history.len(), 5);
    }

    #[test]
    fn invalid_transitions_rejected() {
        let mut task = Task::new("a", "b", "test");
        task.transition(TaskState::Working, None).unwrap();
        task.transition(TaskState::Completed, None).unwrap();

        let err = task.transition(TaskState::Working, None);
        assert!(err.is_err());
    }

    #[test]
    fn rejected_transition_leaves_task_unchanged() {
        let mut task = Task::new("a", "b", "test");
        let err = task.transition(TaskState::Completed, None).unwrap_err();
        assert!(err.contains("created"));
        assert_eq!(task.state, TaskState::Created);
        assert_eq!(task.history.len(), 1);
    }

    #[test]
    fn new_task_seeds_message_and_history() {
        let task = Task::new("agent-a", "agent-b", "review PR");
        assert!(task.id.starts_with("task-"));
        assert_eq!(task.messages.len(), 1);
        assert_eq!(task.messages[0].role, "agent-a");
        assert!(matches!(
            &task.messages[0].parts[..],
            [TaskPart::Text { text }] if text == "review PR"
        ));
        assert_eq!(task.history.len(), 1);
        assert_eq!(task.history[0].from, TaskState::Created);
        assert_eq!(task.history[0].to, TaskState::Created);
        assert_eq!(task.history[0].reason.as_deref(), Some("task created"));
        assert_eq!(task.created_at, task.updated_at);
    }

    #[test]
    fn task_store_operations() {
        let mut store = TaskStore::default();
        let id = store.create_task("agent-a", "agent-b", "review PR");
        assert_eq!(store.tasks.len(), 1);

        let task = store.get_task(&id).unwrap();
        assert_eq!(task.from_agent, "agent-a");

        let pending = store.pending_tasks_for("agent-b");
        assert_eq!(pending.len(), 1);

        store
            .get_task_mut(&id)
            .unwrap()
            .transition(TaskState::Working, None)
            .unwrap();
        store
            .get_task_mut(&id)
            .unwrap()
            .transition(TaskState::Completed, None)
            .unwrap();

        let pending = store.pending_tasks_for("agent-b");
        assert_eq!(pending.len(), 0);
    }

    #[test]
    fn tasks_for_agent_matches_sender_and_receiver() {
        let mut store = TaskStore::default();
        store.create_task("a", "b", "one");
        store.create_task("b", "c", "two");
        store.create_task("c", "d", "three");

        assert_eq!(store.tasks_for_agent("b").len(), 2);
        assert_eq!(store.tasks_for_agent("z").len(), 0);
        // Only tasks *addressed to* the agent count as pending.
        assert_eq!(store.pending_tasks_for("b").len(), 1);
    }

    #[test]
    fn cleanup_old_drops_only_stale_terminal_tasks() {
        let mut store = TaskStore::default();
        store.create_task("a", "b", "stale");
        store.create_task("a", "b", "fresh");
        store.create_task("a", "b", "active");
        for task in store.tasks.iter_mut().take(2) {
            task.transition(TaskState::Working, None).unwrap();
            task.transition(TaskState::Completed, None).unwrap();
        }
        store.tasks[0].updated_at = Utc::now() - chrono::Duration::hours(2);

        store.cleanup_old(1);

        let remaining: Vec<&str> = store.tasks.iter().map(|t| t.description.as_str()).collect();
        assert_eq!(remaining, ["fresh", "active"]);
    }

    #[test]
    fn terminal_states_correct() {
        assert!(TaskState::Completed.is_terminal());
        assert!(TaskState::Failed.is_terminal());
        assert!(TaskState::Canceled.is_terminal());
        assert!(!TaskState::Created.is_terminal());
        assert!(!TaskState::Working.is_terminal());
        assert!(!TaskState::InputRequired.is_terminal());
    }

    #[test]
    fn terminal_states_allow_no_transitions() {
        for from in all_states().iter().filter(|s| s.is_terminal()) {
            for to in &all_states() {
                assert!(!from.can_transition_to(to), "{from} -> {to} must be rejected");
            }
        }
    }

    #[test]
    fn display_round_trips_through_parse_str() {
        for state in all_states() {
            assert_eq!(TaskState::parse_str(&state.to_string()), Some(state));
        }
        assert_eq!(
            TaskState::parse_str("input_required"),
            Some(TaskState::InputRequired)
        );
        assert_eq!(TaskState::parse_str("cancelled"), Some(TaskState::Canceled));
        assert_eq!(TaskState::parse_str("unknown"), None);
    }
}