Skip to main content

serdes_ai_a2a/
task.rs

1//! Task types for the A2A protocol.
2//!
3//! Tasks represent units of work submitted to an agent.
4
5use crate::schema::{Artifact, Message};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10/// Errors related to task operations.
11#[derive(Debug, Error)]
12pub enum TaskError {
13    /// Invalid state transition attempted.
14    #[error("Invalid state transition from {from} to {to}")]
15    InvalidStateTransition { from: TaskStatus, to: TaskStatus },
16}
17
18/// Unique identifier for a task.
19pub type TaskId = String;
20
21/// A task in the A2A protocol.
22///
23/// Tasks are the primary unit of work in A2A. They are submitted
24/// by clients and processed by agents.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct Task {
27    /// Unique identifier for this task.
28    pub id: TaskId,
29    /// Thread ID for grouping related tasks.
30    pub thread_id: String,
31    /// Current status of the task.
32    pub status: TaskStatus,
33    /// The original message that created this task.
34    pub message: Message,
35    /// Messages in the conversation.
36    #[serde(default)]
37    pub messages: Vec<Message>,
38    /// Artifacts produced by the task.
39    #[serde(default)]
40    pub artifacts: Vec<Artifact>,
41    /// When the task was created.
42    pub created_at: DateTime<Utc>,
43    /// When the task was last updated.
44    pub updated_at: DateTime<Utc>,
45    /// Optional error message if the task failed.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub error: Option<String>,
48    /// Optional metadata.
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub metadata: Option<serde_json::Value>,
51}
52
53impl Task {
54    /// Create a new task.
55    ///
56    /// The original message is automatically added to the messages list.
57    pub fn new(thread_id: impl Into<String>, message: Message) -> Self {
58        let now = Utc::now();
59        Self {
60            id: uuid::Uuid::new_v4().to_string(),
61            thread_id: thread_id.into(),
62            status: TaskStatus::Pending,
63            message: message.clone(),
64            messages: vec![message], // Include the original message
65            artifacts: Vec::new(),
66            created_at: now,
67            updated_at: now,
68            error: None,
69            metadata: None,
70        }
71    }
72
73    /// Create a new task with a specific ID.
74    ///
75    /// The original message is automatically added to the messages list.
76    pub fn with_id(id: impl Into<String>, thread_id: impl Into<String>, message: Message) -> Self {
77        let now = Utc::now();
78        Self {
79            id: id.into(),
80            thread_id: thread_id.into(),
81            status: TaskStatus::Pending,
82            message: message.clone(),
83            messages: vec![message], // Include the original message
84            artifacts: Vec::new(),
85            created_at: now,
86            updated_at: now,
87            error: None,
88            metadata: None,
89        }
90    }
91
92    /// Check if the task is pending.
93    pub fn is_pending(&self) -> bool {
94        self.status == TaskStatus::Pending
95    }
96
97    /// Check if the task is running.
98    pub fn is_running(&self) -> bool {
99        self.status == TaskStatus::Running
100    }
101
102    /// Check if the task is completed (success or failure).
103    pub fn is_completed(&self) -> bool {
104        matches!(
105            self.status,
106            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
107        )
108    }
109
110    /// Check if the task succeeded.
111    pub fn is_success(&self) -> bool {
112        self.status == TaskStatus::Completed
113    }
114
115    /// Check if the task failed.
116    pub fn is_failed(&self) -> bool {
117        self.status == TaskStatus::Failed
118    }
119
120    /// Mark the task as running.
121    ///
122    /// # Errors
123    ///
124    /// Returns `TaskError::InvalidStateTransition` if the task is not pending.
125    pub fn start(&mut self) -> Result<(), TaskError> {
126        if self.status != TaskStatus::Pending {
127            return Err(TaskError::InvalidStateTransition {
128                from: self.status,
129                to: TaskStatus::Running,
130            });
131        }
132        self.status = TaskStatus::Running;
133        self.updated_at = Utc::now();
134        Ok(())
135    }
136
137    /// Mark the task as completed.
138    ///
139    /// # Errors
140    ///
141    /// Returns `TaskError::InvalidStateTransition` if the task is not running.
142    pub fn complete(&mut self) -> Result<(), TaskError> {
143        if self.status != TaskStatus::Running {
144            return Err(TaskError::InvalidStateTransition {
145                from: self.status,
146                to: TaskStatus::Completed,
147            });
148        }
149        self.status = TaskStatus::Completed;
150        self.updated_at = Utc::now();
151        Ok(())
152    }
153
154    /// Mark the task as failed.
155    ///
156    /// # Errors
157    ///
158    /// Returns `TaskError::InvalidStateTransition` if the task is not running.
159    pub fn fail(&mut self, error: impl Into<String>) -> Result<(), TaskError> {
160        if self.status != TaskStatus::Running {
161            return Err(TaskError::InvalidStateTransition {
162                from: self.status,
163                to: TaskStatus::Failed,
164            });
165        }
166        self.status = TaskStatus::Failed;
167        self.error = Some(error.into());
168        self.updated_at = Utc::now();
169        Ok(())
170    }
171
172    /// Mark the task as cancelled.
173    ///
174    /// Cancellation is allowed from pending or running states.
175    ///
176    /// # Errors
177    ///
178    /// Returns `TaskError::InvalidStateTransition` if the task is already completed.
179    pub fn cancel(&mut self) -> Result<(), TaskError> {
180        if self.is_completed() {
181            return Err(TaskError::InvalidStateTransition {
182                from: self.status,
183                to: TaskStatus::Cancelled,
184            });
185        }
186        self.status = TaskStatus::Cancelled;
187        self.updated_at = Utc::now();
188        Ok(())
189    }
190
191    /// Force set status without validation (for internal/recovery use).
192    ///
193    /// Use with caution - this bypasses state transition validation.
194    pub fn force_status(&mut self, status: TaskStatus) {
195        self.status = status;
196        self.updated_at = Utc::now();
197    }
198
199    /// Add a message to the task.
200    pub fn add_message(&mut self, message: Message) {
201        self.messages.push(message);
202        self.updated_at = Utc::now();
203    }
204
205    /// Add an artifact to the task.
206    pub fn add_artifact(&mut self, artifact: Artifact) {
207        self.artifacts.push(artifact);
208        self.updated_at = Utc::now();
209    }
210
211    /// Set metadata on the task.
212    pub fn set_metadata(&mut self, metadata: serde_json::Value) {
213        self.metadata = Some(metadata);
214        self.updated_at = Utc::now();
215    }
216}
217
218/// Status of a task.
219#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
220#[serde(rename_all = "lowercase")]
221pub enum TaskStatus {
222    /// Task is waiting to be processed.
223    #[default]
224    Pending,
225    /// Task is currently being processed.
226    Running,
227    /// Task completed successfully.
228    Completed,
229    /// Task failed.
230    Failed,
231    /// Task was cancelled.
232    Cancelled,
233}
234
235impl std::fmt::Display for TaskStatus {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        match self {
238            TaskStatus::Pending => write!(f, "pending"),
239            TaskStatus::Running => write!(f, "running"),
240            TaskStatus::Completed => write!(f, "completed"),
241            TaskStatus::Failed => write!(f, "failed"),
242            TaskStatus::Cancelled => write!(f, "cancelled"),
243        }
244    }
245}
246
247/// Result of a task execution.
248#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct TaskResult {
250    /// The task ID.
251    pub task_id: TaskId,
252    /// Final status.
253    pub status: TaskStatus,
254    /// Response messages.
255    pub messages: Vec<Message>,
256    /// Produced artifacts.
257    pub artifacts: Vec<Artifact>,
258    /// Error message if failed.
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub error: Option<String>,
261    /// Execution duration in milliseconds.
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub duration_ms: Option<u64>,
264}
265
266impl TaskResult {
267    /// Create a successful result.
268    pub fn success(task_id: impl Into<String>, messages: Vec<Message>) -> Self {
269        Self {
270            task_id: task_id.into(),
271            status: TaskStatus::Completed,
272            messages,
273            artifacts: Vec::new(),
274            error: None,
275            duration_ms: None,
276        }
277    }
278
279    /// Create a failed result.
280    pub fn failure(task_id: impl Into<String>, error: impl Into<String>) -> Self {
281        Self {
282            task_id: task_id.into(),
283            status: TaskStatus::Failed,
284            messages: Vec::new(),
285            artifacts: Vec::new(),
286            error: Some(error.into()),
287            duration_ms: None,
288        }
289    }
290
291    /// Add artifacts to the result.
292    pub fn with_artifacts(mut self, artifacts: Vec<Artifact>) -> Self {
293        self.artifacts = artifacts;
294        self
295    }
296
297    /// Set the duration.
298    pub fn with_duration(mut self, duration_ms: u64) -> Self {
299        self.duration_ms = Some(duration_ms);
300        self
301    }
302
303    /// Check if the result is successful.
304    pub fn is_success(&self) -> bool {
305        self.status == TaskStatus::Completed
306    }
307
308    /// Check if the result is a failure.
309    pub fn is_failure(&self) -> bool {
310        self.status == TaskStatus::Failed
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::Message;

    /// Build a fresh pending task used by most lifecycle tests.
    fn new_task() -> Task {
        Task::new("thread-1", Message::user("Hello"))
    }

    #[test]
    fn test_task_creation() {
        let task = new_task();
        assert!(task.is_pending());
        assert!(!task.is_running());
        assert!(!task.is_completed());
        assert_eq!(task.created_at, task.updated_at);
    }

    #[test]
    fn test_new_tasks_get_distinct_ids() {
        assert_ne!(new_task().id, new_task().id);
    }

    #[test]
    fn test_task_includes_original_message() {
        let task = new_task();

        // Original message should be in the messages list
        assert_eq!(task.messages.len(), 1);
        assert_eq!(task.messages[0].text_content(), "Hello");
        assert_eq!(task.message.text_content(), "Hello");
    }

    #[test]
    fn test_task_with_id_includes_original_message() {
        let task = Task::with_id("my-id", "thread-1", Message::user("Test"));

        assert_eq!(task.id, "my-id");
        assert_eq!(task.messages.len(), 1);
        assert_eq!(task.messages[0].text_content(), "Test");
    }

    #[test]
    fn test_task_lifecycle() {
        let mut task = new_task();

        task.start().unwrap();
        assert!(task.is_running());

        task.complete().unwrap();
        assert!(task.is_completed());
        assert!(task.is_success());
    }

    #[test]
    fn test_task_failure() {
        let mut task = new_task();

        task.start().unwrap();
        task.fail("Something went wrong").unwrap();

        assert!(task.is_failed());
        assert_eq!(task.error, Some("Something went wrong".to_string()));
    }

    #[test]
    fn test_invalid_state_transition_start_from_running() {
        let mut task = new_task();
        task.start().unwrap();

        // Exhaustive match (no catch-all) so a new error variant forces this
        // test to be revisited.
        match task.start() {
            Err(TaskError::InvalidStateTransition { from, to }) => {
                assert_eq!(from, TaskStatus::Running);
                assert_eq!(to, TaskStatus::Running);
            }
            Ok(()) => panic!("Expected InvalidStateTransition error"),
        }
    }

    #[test]
    fn test_invalid_state_transition_complete_from_pending() {
        let mut task = new_task();

        let result = task.complete();
        assert!(result.is_err());
    }

    #[test]
    fn test_rejected_transition_does_not_touch_updated_at() {
        let mut task = new_task();
        let before = task.updated_at;

        assert!(task.complete().is_err());
        assert_eq!(task.updated_at, before);
        assert!(task.is_pending());
    }

    #[test]
    fn test_fail_from_pending_is_rejected() {
        let mut task = new_task();

        assert!(task.fail("too early").is_err());
        assert!(task.is_pending());
        assert_eq!(task.error, None);
    }

    #[test]
    fn test_invalid_state_transition_complete_after_cancel() {
        let mut task = new_task();
        task.cancel().unwrap();

        // Can't complete an already cancelled task
        let result = task.complete();
        assert!(result.is_err());
    }

    #[test]
    fn test_cancel_from_pending() {
        let mut task = new_task();
        assert!(task.cancel().is_ok());
        assert_eq!(task.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_cancel_from_running() {
        let mut task = new_task();
        task.start().unwrap();
        assert!(task.cancel().is_ok());
        assert_eq!(task.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_cannot_cancel_twice() {
        let mut task = new_task();
        task.cancel().unwrap();

        assert!(task.cancel().is_err());
        assert_eq!(task.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_cancelled_task_is_terminal_but_not_success() {
        let mut task = new_task();
        task.cancel().unwrap();

        assert!(task.is_completed());
        assert!(!task.is_success());
        assert!(!task.is_failed());
        assert!(task.start().is_err());
    }

    #[test]
    fn test_cannot_cancel_completed_task() {
        let mut task = new_task();
        task.start().unwrap();
        task.complete().unwrap();

        let result = task.cancel();
        assert!(result.is_err());
    }

    #[test]
    fn test_force_status() {
        let mut task = new_task();
        task.force_status(TaskStatus::Completed);
        assert_eq!(task.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_result_success() {
        let result = TaskResult::success("task-1", vec![Message::agent("Done!")]);
        assert!(result.is_success());
        assert!(!result.is_failure());
    }

    #[test]
    fn test_task_result_failure() {
        let result = TaskResult::failure("task-1", "Error occurred");
        assert!(result.is_failure());
        assert_eq!(result.error, Some("Error occurred".to_string()));
    }

    #[test]
    fn test_task_result_with_duration() {
        let result = TaskResult::failure("task-1", "Error occurred").with_duration(42);
        assert_eq!(result.task_id, "task-1");
        assert_eq!(result.duration_ms, Some(42));
        assert!(result.messages.is_empty());
        assert!(result.artifacts.is_empty());
    }

    #[test]
    fn test_status_display() {
        assert_eq!(TaskStatus::Pending.to_string(), "pending");
        assert_eq!(TaskStatus::Running.to_string(), "running");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Cancelled.to_string(), "cancelled");
    }

    #[test]
    fn test_status_default_is_pending() {
        assert_eq!(TaskStatus::default(), TaskStatus::Pending);
    }

    #[test]
    fn test_status_serde_uses_lowercase_names() {
        assert_eq!(
            serde_json::to_string(&TaskStatus::Cancelled).unwrap(),
            "\"cancelled\""
        );
        let parsed: TaskStatus = serde_json::from_str("\"running\"").unwrap();
        assert_eq!(parsed, TaskStatus::Running);
    }
}