Skip to main content

a2a_ao/
task.rs

//! Task — the stateful unit of work in the A2A protocol.
//!
//! A Task represents an interaction between a client agent and a remote agent.
//! Tasks have a full lifecycle with well-defined state transitions:
//!
//! ```text
//! SUBMITTED → WORKING → COMPLETED       (terminal)
//!                     → FAILED          (terminal)
//!                     → CANCELED        (terminal)
//!                     → REJECTED        (terminal)
//!                     → INPUT_REQUIRED  (interrupted; may return to WORKING)
//!                     → AUTH_REQUIRED   (interrupted; may return to WORKING)
//! ```
//!
//! Once a task reaches a terminal state, [`Task::transition`] rejects any further
//! state change. Transitions between non-terminal states are not otherwise checked.
14
15use chrono::{DateTime, Utc};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use uuid::Uuid;
19
20use crate::artifact::Artifact;
21use crate::message::Message;
22
23/// A Task — the fundamental unit of work in A2A.
24#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
25#[serde(rename_all = "camelCase")]
26pub struct Task {
27    /// Unique identifier for this task.
28    pub id: String,
29
30    /// Optional context ID grouping related tasks for session coherence.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub context_id: Option<String>,
33
34    /// Current state of the task.
35    pub state: TaskState,
36
37    /// Messages exchanged during the task (multi-turn conversation).
38    #[serde(default, skip_serializing_if = "Vec::is_empty")]
39    pub messages: Vec<Message>,
40
41    /// Artifacts produced by the task.
42    #[serde(default, skip_serializing_if = "Vec::is_empty")]
43    pub artifacts: Vec<Artifact>,
44
45    /// Optional metadata attached to the task.
46    #[serde(default, skip_serializing_if = "Option::is_none")]
47    pub metadata: Option<serde_json::Value>,
48
49    /// When the task was created.
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub created_at: Option<DateTime<Utc>>,
52
53    /// When the task was last updated.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub updated_at: Option<DateTime<Utc>>,
56}
57
58impl Task {
59    /// Create a new task with a generated ID.
60    pub fn new() -> Self {
61        Self {
62            id: Uuid::new_v4().to_string(),
63            context_id: None,
64            state: TaskState::Submitted,
65            messages: Vec::new(),
66            artifacts: Vec::new(),
67            metadata: None,
68            created_at: Some(Utc::now()),
69            updated_at: Some(Utc::now()),
70        }
71    }
72
73    /// Create a new task within a context (session).
74    pub fn with_context(context_id: impl Into<String>) -> Self {
75        Self {
76            context_id: Some(context_id.into()),
77            ..Self::new()
78        }
79    }
80
81    /// Check if the task is in a terminal state.
82    pub fn is_terminal(&self) -> bool {
83        matches!(
84            self.state,
85            TaskState::Completed | TaskState::Failed | TaskState::Canceled | TaskState::Rejected
86        )
87    }
88
89    /// Check if the task is in an interrupted state (needs input or auth).
90    pub fn is_interrupted(&self) -> bool {
91        matches!(
92            self.state,
93            TaskState::InputRequired | TaskState::AuthRequired
94        )
95    }
96
97    /// Transition the task to a new state.
98    pub fn transition(&mut self, new_state: TaskState) -> Result<(), InvalidTransition> {
99        if self.is_terminal() {
100            return Err(InvalidTransition {
101                from: self.state.clone(),
102                to: new_state,
103            });
104        }
105        self.state = new_state;
106        self.updated_at = Some(Utc::now());
107        Ok(())
108    }
109
110    /// Add a message to the task.
111    pub fn add_message(&mut self, message: Message) {
112        self.messages.push(message);
113        self.updated_at = Some(Utc::now());
114    }
115
116    /// Add an artifact to the task.
117    pub fn add_artifact(&mut self, artifact: Artifact) {
118        self.artifacts.push(artifact);
119        self.updated_at = Some(Utc::now());
120    }
121}
122
123impl Default for Task {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129/// The state of a task in its lifecycle.
130#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
131#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
132pub enum TaskState {
133    /// Task has been submitted but not yet started.
134    Submitted,
135
136    /// Task is actively being worked on.
137    Working,
138
139    /// Task completed successfully (terminal).
140    Completed,
141
142    /// Task failed (terminal).
143    Failed,
144
145    /// Task was canceled by the client (terminal).
146    Canceled,
147
148    /// Task was rejected by the remote agent (terminal).
149    Rejected,
150
151    /// Task is paused, waiting for additional input from the client.
152    InputRequired,
153
154    /// Task is paused, waiting for authentication/authorization.
155    AuthRequired,
156}
157
158impl std::fmt::Display for TaskState {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        match self {
161            TaskState::Submitted => write!(f, "SUBMITTED"),
162            TaskState::Working => write!(f, "WORKING"),
163            TaskState::Completed => write!(f, "COMPLETED"),
164            TaskState::Failed => write!(f, "FAILED"),
165            TaskState::Canceled => write!(f, "CANCELED"),
166            TaskState::Rejected => write!(f, "REJECTED"),
167            TaskState::InputRequired => write!(f, "INPUT_REQUIRED"),
168            TaskState::AuthRequired => write!(f, "AUTH_REQUIRED"),
169        }
170    }
171}
172
173/// Error for invalid task state transitions.
174#[derive(Debug)]
175pub struct InvalidTransition {
176    pub from: TaskState,
177    pub to: TaskState,
178}
179
180impl std::fmt::Display for InvalidTransition {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        write!(
183            f,
184            "invalid task transition from {} to {}",
185            self.from, self.to
186        )
187    }
188}
189
190impl std::error::Error for InvalidTransition {}
191
192/// Parameters for querying/listing tasks.
193#[derive(Debug, Default, Clone, Serialize, Deserialize)]
194#[serde(rename_all = "camelCase")]
195pub struct TaskQueryParams {
196    /// Filter by context ID.
197    #[serde(skip_serializing_if = "Option::is_none")]
198    pub context_id: Option<String>,
199
200    /// Filter by task state.
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub state: Option<TaskState>,
203
204    /// Maximum number of tasks to return.
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub limit: Option<u32>,
207
208    /// Cursor for pagination.
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub cursor: Option<String>,
211}
212
213/// A streaming event for a task update (sent via SSE).
214#[derive(Debug, Clone, Serialize, Deserialize)]
215#[serde(rename_all = "camelCase", tag = "type")]
216pub enum TaskEvent {
217    /// Task state changed.
218    StateChanged { task_id: String, state: TaskState },
219
220    /// New message added to the task.
221    MessageAdded { task_id: String, message: Message },
222
223    /// New artifact produced.
224    ArtifactAdded { task_id: String, artifact: Artifact },
225
226    /// Partial artifact data (streaming).
227    ArtifactChunk {
228        task_id: String,
229        artifact_id: String,
230        chunk: String,
231    },
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, used by tests that must cover the whole enum.
    fn all_states() -> [TaskState; 8] {
        [
            TaskState::Submitted,
            TaskState::Working,
            TaskState::Completed,
            TaskState::Failed,
            TaskState::Canceled,
            TaskState::Rejected,
            TaskState::InputRequired,
            TaskState::AuthRequired,
        ]
    }

    #[test]
    fn test_task_lifecycle() {
        let mut task = Task::new();
        assert_eq!(task.state, TaskState::Submitted);
        assert!(!task.is_terminal());

        task.transition(TaskState::Working).unwrap();
        assert_eq!(task.state, TaskState::Working);

        task.transition(TaskState::InputRequired).unwrap();
        assert!(task.is_interrupted());

        task.transition(TaskState::Working).unwrap();
        task.transition(TaskState::Completed).unwrap();
        assert!(task.is_terminal());

        // Cannot transition from terminal state
        assert!(task.transition(TaskState::Working).is_err());
    }

    #[test]
    fn test_task_serialization() {
        let task = Task::new();
        let json = serde_json::to_string(&task).unwrap();
        assert!(json.contains("SUBMITTED"));

        let parsed: Task = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.state, TaskState::Submitted);
    }

    #[test]
    fn test_with_context_sets_context_and_defaults() {
        let task = Task::with_context("ctx-1");
        assert_eq!(task.context_id.as_deref(), Some("ctx-1"));
        assert_eq!(task.state, TaskState::Submitted);
        assert!(task.messages.is_empty());
        assert!(task.artifacts.is_empty());
        assert!(task.created_at.is_some());
        assert!(task.updated_at.is_some());
    }

    #[test]
    fn test_every_terminal_state_blocks_further_transitions() {
        let terminals = [
            TaskState::Completed,
            TaskState::Failed,
            TaskState::Canceled,
            TaskState::Rejected,
        ];
        for terminal in terminals {
            let mut task = Task::new();
            task.transition(terminal.clone()).unwrap();
            assert!(task.is_terminal());
            assert!(!task.is_interrupted());

            let err = task.transition(TaskState::Working).unwrap_err();
            assert_eq!(err.from, terminal);
            assert_eq!(err.to, TaskState::Working);
            // A rejected transition must leave the task's state untouched.
            assert_eq!(task.state, terminal);
        }
    }

    #[test]
    fn test_auth_required_is_interrupted() {
        let mut task = Task::new();
        task.transition(TaskState::AuthRequired).unwrap();
        assert!(task.is_interrupted());
        assert!(!task.is_terminal());
    }

    #[test]
    fn test_invalid_transition_message() {
        let mut task = Task::new();
        task.transition(TaskState::Completed).unwrap();
        let err = task.transition(TaskState::Working).unwrap_err();
        assert_eq!(
            err.to_string(),
            "invalid task transition from COMPLETED to WORKING"
        );
    }

    #[test]
    fn test_display_matches_serde_name() {
        for state in all_states() {
            let json = serde_json::to_string(&state).unwrap();
            assert_eq!(json, format!("\"{state}\""));
        }
    }

    #[test]
    fn test_serialization_uses_camel_case_and_skips_empty() {
        let task = Task::with_context("ctx-1");
        let json = serde_json::to_string(&task).unwrap();
        assert!(json.contains("\"contextId\":\"ctx-1\""));
        assert!(json.contains("\"createdAt\""));
        assert!(json.contains("\"updatedAt\""));
        assert!(!json.contains("messages"));
        assert!(!json.contains("artifacts"));
        assert!(!json.contains("metadata"));

        let parsed: Task = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, task.id);
        assert_eq!(parsed.context_id, task.context_id);
    }

    #[test]
    fn test_minimal_json_deserializes_with_defaults() {
        let task: Task = serde_json::from_str(r#"{"id":"t1","state":"WORKING"}"#).unwrap();
        assert_eq!(task.id, "t1");
        assert_eq!(task.state, TaskState::Working);
        assert!(task.context_id.is_none());
        assert!(task.messages.is_empty());
        assert!(task.artifacts.is_empty());
        assert!(task.metadata.is_none());
        assert!(task.created_at.is_none());
        assert!(task.updated_at.is_none());
    }

    #[test]
    fn test_task_event_is_tagged_by_type() {
        let event = TaskEvent::StateChanged {
            task_id: "t1".to_string(),
            state: TaskState::Working,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"stateChanged\""));
        assert!(json.contains("\"state\":\"WORKING\""));
    }
}