// lean_ctx/core/a2a/task.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7pub enum TaskState {
8    Created,
9    Working,
10    InputRequired,
11    Completed,
12    Failed,
13    Canceled,
14}
15
16impl std::fmt::Display for TaskState {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self {
19            TaskState::Created => write!(f, "created"),
20            TaskState::Working => write!(f, "working"),
21            TaskState::InputRequired => write!(f, "input-required"),
22            TaskState::Completed => write!(f, "completed"),
23            TaskState::Failed => write!(f, "failed"),
24            TaskState::Canceled => write!(f, "canceled"),
25        }
26    }
27}
28
29impl TaskState {
30    pub fn parse_str(s: &str) -> Option<Self> {
31        match s {
32            "created" => Some(Self::Created),
33            "working" => Some(Self::Working),
34            "input-required" | "input_required" => Some(Self::InputRequired),
35            "completed" => Some(Self::Completed),
36            "failed" => Some(Self::Failed),
37            "canceled" | "cancelled" => Some(Self::Canceled),
38            _ => None,
39        }
40    }
41
42    pub fn is_terminal(&self) -> bool {
43        matches!(self, Self::Completed | Self::Failed | Self::Canceled)
44    }
45
46    pub fn can_transition_to(&self, next: &TaskState) -> bool {
47        match self {
48            TaskState::Created => matches!(
49                next,
50                TaskState::Working | TaskState::Canceled | TaskState::Failed
51            ),
52            TaskState::Working => matches!(
53                next,
54                TaskState::InputRequired
55                    | TaskState::Completed
56                    | TaskState::Failed
57                    | TaskState::Canceled
58            ),
59            TaskState::InputRequired => matches!(
60                next,
61                TaskState::Working | TaskState::Canceled | TaskState::Failed
62            ),
63            TaskState::Completed | TaskState::Failed | TaskState::Canceled => false,
64        }
65    }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct TaskMessage {
70    pub role: String,
71    pub parts: Vec<TaskPart>,
72    pub timestamp: DateTime<Utc>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76#[serde(tag = "type")]
77pub enum TaskPart {
78    #[serde(rename = "text")]
79    Text { text: String },
80    #[serde(rename = "data")]
81    Data { mime_type: String, data: String },
82    #[serde(rename = "file")]
83    File {
84        name: String,
85        mime_type: Option<String>,
86        data: Option<String>,
87        uri: Option<String>,
88    },
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct TaskTransition {
93    pub from: TaskState,
94    pub to: TaskState,
95    pub timestamp: DateTime<Utc>,
96    pub reason: Option<String>,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct Task {
101    pub id: String,
102    pub from_agent: String,
103    pub to_agent: String,
104    pub state: TaskState,
105    pub description: String,
106    pub messages: Vec<TaskMessage>,
107    pub artifacts: Vec<TaskPart>,
108    pub history: Vec<TaskTransition>,
109    pub metadata: HashMap<String, String>,
110    pub created_at: DateTime<Utc>,
111    pub updated_at: DateTime<Utc>,
112}
113
114impl Task {
115    pub fn new(from_agent: &str, to_agent: &str, description: &str) -> Self {
116        let now = Utc::now();
117        let id = format!("task-{}", generate_task_id());
118
119        Self {
120            id,
121            from_agent: from_agent.to_string(),
122            to_agent: to_agent.to_string(),
123            state: TaskState::Created,
124            description: description.to_string(),
125            messages: vec![TaskMessage {
126                role: from_agent.to_string(),
127                parts: vec![TaskPart::Text {
128                    text: description.to_string(),
129                }],
130                timestamp: now,
131            }],
132            artifacts: Vec::new(),
133            history: vec![TaskTransition {
134                from: TaskState::Created,
135                to: TaskState::Created,
136                timestamp: now,
137                reason: Some("task created".to_string()),
138            }],
139            metadata: HashMap::new(),
140            created_at: now,
141            updated_at: now,
142        }
143    }
144
145    pub fn transition(&mut self, new_state: TaskState, reason: Option<&str>) -> Result<(), String> {
146        if !self.state.can_transition_to(&new_state) {
147            return Err(format!(
148                "invalid transition: {} → {}",
149                self.state, new_state
150            ));
151        }
152
153        self.history.push(TaskTransition {
154            from: self.state.clone(),
155            to: new_state.clone(),
156            timestamp: Utc::now(),
157            reason: reason.map(std::string::ToString::to_string),
158        });
159
160        self.state = new_state;
161        self.updated_at = Utc::now();
162        Ok(())
163    }
164
165    pub fn add_message(&mut self, role: &str, parts: Vec<TaskPart>) {
166        self.messages.push(TaskMessage {
167            role: role.to_string(),
168            parts,
169            timestamp: Utc::now(),
170        });
171        self.updated_at = Utc::now();
172    }
173
174    pub fn add_artifact(&mut self, artifact: TaskPart) {
175        self.artifacts.push(artifact);
176        self.updated_at = Utc::now();
177    }
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize, Default)]
181pub struct TaskStore {
182    pub tasks: Vec<Task>,
183    pub updated_at: Option<DateTime<Utc>>,
184}
185
186impl TaskStore {
187    pub fn load() -> Self {
188        let Some(path) = task_store_path() else {
189            return Self::default();
190        };
191        match std::fs::read_to_string(&path) {
192            Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
193            Err(_) => Self::default(),
194        }
195    }
196
197    pub fn save(&self) -> std::io::Result<()> {
198        let Some(path) = task_store_path() else {
199            return Err(std::io::Error::new(
200                std::io::ErrorKind::NotFound,
201                "no home dir",
202            ));
203        };
204
205        if let Some(parent) = path.parent() {
206            std::fs::create_dir_all(parent)?;
207        }
208
209        let json = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
210        let tmp = path.with_extension("tmp");
211        std::fs::write(&tmp, &json)?;
212        std::fs::rename(&tmp, &path)?;
213        Ok(())
214    }
215
216    pub fn create_task(&mut self, from: &str, to: &str, description: &str) -> String {
217        let task = Task::new(from, to, description);
218        let id = task.id.clone();
219        self.tasks.push(task);
220        self.updated_at = Some(Utc::now());
221        id
222    }
223
224    pub fn get_task(&self, task_id: &str) -> Option<&Task> {
225        self.tasks.iter().find(|t| t.id == task_id)
226    }
227
228    pub fn get_task_mut(&mut self, task_id: &str) -> Option<&mut Task> {
229        self.tasks.iter_mut().find(|t| t.id == task_id)
230    }
231
232    pub fn tasks_for_agent(&self, agent_id: &str) -> Vec<&Task> {
233        self.tasks
234            .iter()
235            .filter(|t| t.to_agent == agent_id || t.from_agent == agent_id)
236            .collect()
237    }
238
239    pub fn pending_tasks_for(&self, agent_id: &str) -> Vec<&Task> {
240        self.tasks
241            .iter()
242            .filter(|t| t.to_agent == agent_id && !t.state.is_terminal())
243            .collect()
244    }
245
246    pub fn cleanup_old(&mut self, max_age_hours: u64) {
247        let cutoff = Utc::now() - chrono::Duration::hours(max_age_hours as i64);
248        self.tasks
249            .retain(|t| !t.state.is_terminal() || t.updated_at > cutoff);
250    }
251}
252
253fn task_store_path() -> Option<PathBuf> {
254    dirs::home_dir().map(|h| h.join(".lean-ctx/agents/tasks.json"))
255}
256
/// Generates the unique part of a task id: `<unix-millis-hex>-<8 hex digits>`.
///
/// The suffix is derived from a process-wide counter rather than from the
/// timestamp alone. Previously it was a pure function of the millisecond, so
/// two tasks created within the same millisecond received identical ids.
/// Within one process the suffix cannot repeat until the 32-bit counter wraps.
/// The process id is mixed in so concurrently running processes start at
/// different points of the sequence.
fn generate_task_id() -> String {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Relaxed is sufficient: each `fetch_add` only needs to return a distinct
    // value; no other memory accesses are ordered by it.
    static SEQUENCE: AtomicU32 = AtomicU32::new(0);

    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let salt = std::process::id().wrapping_mul(0x9E37_79B9);
    // Adding a constant and multiplying by an odd constant are both bijections
    // on u32, so distinct `seq` values always give distinct suffixes.
    let rand = seq.wrapping_add(salt).wrapping_mul(2_654_435_761);
    format!("{ts:x}-{rand:08x}")
}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Created → Working → Completed, checking the state after every step.
    #[test]
    fn task_lifecycle_happy_path() {
        let mut task = Task::new("agent-a", "agent-b", "fix the bug");
        assert_eq!(task.state, TaskState::Created);

        for (next, reason) in [(TaskState::Working, "started"), (TaskState::Completed, "done")] {
            task.transition(next.clone(), Some(reason)).unwrap();
            assert_eq!(task.state, next);
        }

        // Creation entry + two transitions.
        assert_eq!(task.history.len(), 3);
    }

    /// A task may pause for input and resume before completing.
    #[test]
    fn task_lifecycle_with_input_required() {
        let mut task = Task::new("a", "b", "deploy");
        let steps = [
            (TaskState::Working, None),
            (TaskState::InputRequired, Some("need credentials")),
            (TaskState::Working, Some("got them")),
            (TaskState::Completed, None),
        ];
        for (next, reason) in steps {
            task.transition(next, reason).unwrap();
        }

        // Creation entry + four transitions.
        assert_eq!(task.history.len(), 5);
    }

    /// Terminal states cannot be left.
    #[test]
    fn invalid_transitions_rejected() {
        let mut task = Task::new("a", "b", "test");
        task.transition(TaskState::Working, None).unwrap();
        task.transition(TaskState::Completed, None).unwrap();

        assert!(task.transition(TaskState::Working, None).is_err());
    }

    /// Creating, finding, and completing a task through the store.
    #[test]
    fn task_store_operations() {
        let mut store = TaskStore::default();
        let id = store.create_task("agent-a", "agent-b", "review PR");
        assert_eq!(store.tasks.len(), 1);
        assert_eq!(store.get_task(&id).unwrap().from_agent, "agent-a");
        assert_eq!(store.pending_tasks_for("agent-b").len(), 1);

        let task = store.get_task_mut(&id).unwrap();
        task.transition(TaskState::Working, None).unwrap();
        task.transition(TaskState::Completed, None).unwrap();

        assert!(store.pending_tasks_for("agent-b").is_empty());
    }

    /// Exactly Completed, Failed and Canceled are terminal.
    #[test]
    fn terminal_states_correct() {
        for state in [TaskState::Completed, TaskState::Failed, TaskState::Canceled] {
            assert!(state.is_terminal());
        }
        for state in [TaskState::Created, TaskState::Working, TaskState::InputRequired] {
            assert!(!state.is_terminal());
        }
    }
}
339        assert!(!TaskState::Created.is_terminal());
340        assert!(!TaskState::Working.is_terminal());
341        assert!(!TaskState::InputRequired.is_terminal());
342    }
343}