Skip to main content

a2a_rs/domain/core/
task.rs

1use crate::domain::error::A2AError;
2use serde::{Deserialize, Serialize};
3use serde_json::{Map, Value};
4
5#[cfg(feature = "tracing")]
6use tracing::instrument;
7
8#[cfg(feature = "tracing")]
9use crate::measure_duration;
10
11use super::message::{Artifact, Message};
12
13// Re-export generated types
14pub use crate::domain::generated::{Task, TaskPushNotificationConfig, TaskState, TaskStatus};
15
16#[allow(non_upper_case_globals)]
17impl TaskState {
18    pub const Submitted: Self = Self::TASK_STATE_SUBMITTED;
19    pub const Working: Self = Self::TASK_STATE_WORKING;
20    pub const InputRequired: Self = Self::TASK_STATE_INPUT_REQUIRED;
21    pub const Completed: Self = Self::TASK_STATE_COMPLETED;
22    pub const Canceled: Self = Self::TASK_STATE_CANCELED;
23    pub const Failed: Self = Self::TASK_STATE_FAILED;
24    pub const Rejected: Self = Self::TASK_STATE_REJECTED;
25    pub const AuthRequired: Self = Self::TASK_STATE_AUTH_REQUIRED;
26    pub const Unknown: Self = Self::TASK_STATE_UNSPECIFIED;
27
28    pub fn is_terminal(&self) -> bool {
29        matches!(
30            self,
31            Self::TASK_STATE_COMPLETED
32                | Self::TASK_STATE_FAILED
33                | Self::TASK_STATE_CANCELED
34                | Self::TASK_STATE_REJECTED
35        )
36    }
37}
38
39pub trait TaskStateExt {
40    fn is_terminal(&self) -> bool;
41}
42
43impl TaskStateExt for ::buffa::EnumValue<TaskState> {
44    fn is_terminal(&self) -> bool {
45        match self {
46            ::buffa::EnumValue::Known(state) => state.is_terminal(),
47            _ => false,
48        }
49    }
50}
51
52impl TaskStatus {
53    pub fn new(state: TaskState, message: Option<Message>) -> Self {
54        let timestamp = chrono::Utc::now();
55        let seconds = timestamp.timestamp();
56        let nanos = timestamp.timestamp_subsec_nanos() as i32;
57
58        Self {
59            state: ::buffa::EnumValue::from(state),
60            message: message.into(),
61            timestamp: ::buffa::MessageField::some(::buffa_types::google::protobuf::Timestamp {
62                seconds,
63                nanos,
64                ..Default::default()
65            }),
66            ..Default::default()
67        }
68    }
69
70    pub fn timestamp_utc(&self) -> Option<chrono::DateTime<chrono::Utc>> {
71        self.timestamp.as_option().and_then(|t| {
72            chrono::DateTime::<chrono::Utc>::from_timestamp(t.seconds, t.nanos as u32)
73        })
74    }
75}
76
77/// Parameters for identifying a task by ID.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct TaskIdParams {
80    pub id: String,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub metadata: Option<Map<String, Value>>,
83}
84
85/// Parameters for querying a task with optional history constraints.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct TaskQueryParams {
88    pub id: String,
89    #[serde(skip_serializing_if = "Option::is_none", rename = "historyLength")]
90    pub history_length: Option<u32>,
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub metadata: Option<Map<String, Value>>,
93}
94
95/// Configuration options for sending messages including output modes and notifications.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct MessageSendConfiguration {
98    #[serde(
99        skip_serializing_if = "Option::is_none",
100        rename = "acceptedOutputModes"
101    )]
102    pub accepted_output_modes: Option<Vec<String>>,
103    #[serde(skip_serializing_if = "Option::is_none", rename = "historyLength")]
104    pub history_length: Option<u32>,
105    #[serde(
106        skip_serializing_if = "Option::is_none",
107        rename = "pushNotificationConfig"
108    )]
109    pub push_notification_config: Option<TaskPushNotificationConfig>,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub blocking: Option<bool>,
112}
113
114/// Parameters for sending a message with optional configuration.
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct MessageSendParams {
117    pub message: Message,
118    #[serde(skip_serializing_if = "Option::is_none")]
119    pub configuration: Option<MessageSendConfiguration>,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub metadata: Option<Map<String, Value>>,
122}
123
124/// Parameters for sending a task (legacy)
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct TaskSendParams {
127    pub id: String,
128    #[serde(skip_serializing_if = "Option::is_none", rename = "sessionId")]
129    pub session_id: Option<String>,
130    pub message: Message,
131    #[serde(skip_serializing_if = "Option::is_none", rename = "pushNotification")]
132    pub push_notification: Option<TaskPushNotificationConfig>,
133    #[serde(skip_serializing_if = "Option::is_none", rename = "historyLength")]
134    pub history_length: Option<u32>,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub metadata: Option<Map<String, Value>>,
137}
138
139/// Parameters for listing tasks with filtering and pagination.
140#[derive(Debug, Clone, Serialize, Deserialize, Default)]
141pub struct ListTasksParams {
142    #[serde(skip_serializing_if = "Option::is_none", rename = "contextId")]
143    pub context_id: Option<String>,
144    #[serde(skip_serializing_if = "Option::is_none")]
145    pub status: Option<TaskState>,
146    #[serde(skip_serializing_if = "Option::is_none", rename = "pageSize")]
147    pub page_size: Option<i32>,
148    #[serde(skip_serializing_if = "Option::is_none", rename = "pageToken")]
149    pub page_token: Option<String>,
150    #[serde(skip_serializing_if = "Option::is_none", rename = "historyLength")]
151    pub history_length: Option<i32>,
152    #[serde(skip_serializing_if = "Option::is_none", rename = "includeArtifacts")]
153    pub include_artifacts: Option<bool>,
154    #[serde(
155        skip_serializing_if = "Option::is_none",
156        rename = "statusTimestampAfter"
157    )]
158    pub status_timestamp_after: Option<String>,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub metadata: Option<Map<String, Value>>,
161}
162
163/// Result object for tasks/list method.
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct ListTasksResult {
166    pub tasks: Vec<Task>,
167    #[serde(rename = "totalSize")]
168    pub total_size: i32,
169    #[serde(rename = "pageSize")]
170    pub page_size: i32,
171    #[serde(rename = "nextPageToken")]
172    pub next_page_token: String,
173}
174
175/// Parameters for getting a specific push notification config.
176#[derive(Debug, Clone, Serialize, Deserialize, Default)]
177pub struct GetTaskPushNotificationConfigParams {
178    pub id: String,
179    #[serde(
180        skip_serializing_if = "Option::is_none",
181        rename = "pushNotificationConfigId"
182    )]
183    pub push_notification_config_id: Option<String>,
184    #[serde(skip_serializing_if = "Option::is_none")]
185    pub metadata: Option<Map<String, Value>>,
186}
187
188/// Parameters for listing all push notification configs for a task.
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct ListTaskPushNotificationConfigsParams {
191    pub id: String,
192    #[serde(skip_serializing_if = "Option::is_none")]
193    pub metadata: Option<Map<String, Value>>,
194}
195
196/// Parameters for deleting a push notification config.
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct DeleteTaskPushNotificationConfigParams {
199    pub id: String,
200    #[serde(rename = "pushNotificationConfigId")]
201    pub push_notification_config_id: String,
202    #[serde(skip_serializing_if = "Option::is_none")]
203    pub metadata: Option<Map<String, Value>>,
204}
205
206pub struct TaskBuilder {
207    id: String,
208    context_id: String,
209    status: Option<TaskStatus>,
210    artifacts: Vec<Artifact>,
211    history: Vec<Message>,
212    metadata: Option<::buffa_types::google::protobuf::Struct>,
213}
214
215impl TaskBuilder {
216    pub fn new() -> Self {
217        Self {
218            id: String::new(),
219            context_id: String::new(),
220            status: None,
221            artifacts: Vec::new(),
222            history: Vec::new(),
223            metadata: None,
224        }
225    }
226
227    pub fn id(mut self, id: String) -> Self {
228        self.id = id;
229        self
230    }
231
232    pub fn context_id(mut self, context_id: String) -> Self {
233        self.context_id = context_id;
234        self
235    }
236
237    pub fn status(mut self, status: TaskStatus) -> Self {
238        self.status = Some(status);
239        self
240    }
241
242    pub fn artifacts(mut self, artifacts: Vec<Artifact>) -> Self {
243        self.artifacts = artifacts;
244        self
245    }
246
247    pub fn history(mut self, history: Vec<Message>) -> Self {
248        self.history = history;
249        self
250    }
251
252    pub fn metadata(mut self, metadata: ::buffa_types::google::protobuf::Struct) -> Self {
253        self.metadata = Some(metadata);
254        self
255    }
256
257    pub fn build(self) -> Task {
258        Task {
259            id: self.id,
260            context_id: self.context_id,
261            status: self
262                .status
263                .unwrap_or_else(|| TaskStatus::new(TaskState::TASK_STATE_SUBMITTED, None))
264                .into(),
265            artifacts: self.artifacts,
266            history: self.history,
267            metadata: self.metadata.into(),
268            ..Default::default()
269        }
270    }
271}
272
273impl Default for TaskBuilder {
274    fn default() -> Self {
275        Self::new()
276    }
277}
278
279impl Task {
280    pub fn builder() -> TaskBuilder {
281        TaskBuilder::new()
282    }
283
284    /// Create a new task with the given ID in the submitted state
285    pub fn new(id: String, context_id: String) -> Self {
286        Self {
287            id,
288            context_id,
289            status: ::buffa::MessageField::some(TaskStatus::new(
290                TaskState::TASK_STATE_SUBMITTED,
291                None,
292            )),
293            artifacts: Vec::new(),
294            history: Vec::new(),
295            metadata: ::buffa::MessageField::none(),
296            ..Default::default()
297        }
298    }
299
300    /// Create a new task with the given ID and context ID in the submitted state
301    pub fn with_context(id: String, context_id: String) -> Self {
302        Self::new(id, context_id)
303    }
304
305    /// Update the task status
306    #[cfg_attr(feature = "tracing", instrument(skip(self, message), fields(
307        task.id = %self.id,
308        task.old_state = ?self.status.as_option().map(|s| &s.state),
309        task.new_state = ?state,
310        task.has_message = message.is_some()
311    )))]
312    pub fn update_status(&mut self, state: TaskState, message: Option<Message>) {
313        #[cfg(feature = "tracing")]
314        tracing::info!("Updating task status");
315
316        self.status = ::buffa::MessageField::some(TaskStatus::new(state, message.clone()));
317
318        if let Some(msg) = message {
319            self.history.push(msg);
320        }
321
322        #[cfg(feature = "tracing")]
323        tracing::info!("Task status updated successfully");
324    }
325
326    /// Get a copy of this task with history limited to the specified length
327    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(
328        task.id = %self.id,
329        history.current_size = self.history.len(),
330        history.requested_limit = ?history_length
331    )))]
332    pub fn with_limited_history(&self, history_length: Option<u32>) -> Self {
333        if history_length.is_none() {
334            #[cfg(feature = "tracing")]
335            tracing::debug!("No history truncation needed");
336            return self.clone();
337        }
338
339        #[cfg(feature = "tracing")]
340        let _span = tracing::Span::current();
341
342        let limit: usize = history_length.unwrap().try_into().unwrap_or(usize::MAX);
343
344        #[cfg(feature = "tracing")]
345        let mut task_copy = measure_duration!(_span, "operation.duration_ms", { self.clone() });
346
347        #[cfg(not(feature = "tracing"))]
348        let mut task_copy = self.clone();
349
350        if limit == 0 {
351            #[cfg(feature = "tracing")]
352            tracing::debug!("Removing all history (limit = 0)");
353            task_copy.history.clear();
354        } else if task_copy.history.len() > limit {
355            let items_to_skip = task_copy.history.len() - limit;
356            #[cfg(feature = "tracing")]
357            tracing::debug!(
358                "Truncating history from {} to {} items (removing {} oldest)",
359                self.history.len(),
360                limit,
361                items_to_skip
362            );
363            task_copy.history = task_copy
364                .history
365                .iter()
366                .skip(items_to_skip)
367                .cloned()
368                .collect();
369        }
370
371        task_copy
372    }
373
374    /// Add an artifact to the task
375    #[cfg_attr(feature = "tracing", instrument(skip(self, artifact), fields(
376        task.id = %self.id,
377        artifact.id = %artifact.artifact_id,
378        artifacts.count = self.artifacts.len()
379    )))]
380    pub fn add_artifact(&mut self, artifact: Artifact) {
381        self.artifacts.push(artifact);
382    }
383
384    /// Validate a task (useful after building with builder)
385    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(
386        task.id = %self.id,
387        task.state = ?self.status.as_option().map(|s| &s.state),
388        history.size = self.history.len()
389    )))]
390    pub fn validate(&self) -> Result<(), A2AError> {
391        #[cfg(feature = "tracing")]
392        tracing::debug!("Validating task");
393
394        let mut message_ids = std::collections::HashSet::new();
395        for (_index, message) in self.history.iter().enumerate() {
396            #[cfg(feature = "tracing")]
397            tracing::trace!("Validating message {} in history", _index);
398
399            if !message_ids.insert(&message.message_id) {
400                #[cfg(feature = "tracing")]
401                tracing::error!("Duplicate message ID found: {}", message.message_id);
402                return Err(A2AError::InvalidParams(format!(
403                    "Duplicate message ID in history: {}",
404                    message.message_id
405                )));
406            }
407            message.validate()?;
408        }
409
410        if let Some(status) = self.status.as_option() {
411            if let Some(msg) = status.message.as_option() {
412                #[cfg(feature = "tracing")]
413                tracing::trace!("Validating status message");
414                msg.validate()?;
415            }
416        }
417
418        #[cfg(feature = "tracing")]
419        tracing::debug!("Task validation successful");
420        Ok(())
421    }
422}
423
424/// A task paired with its storage version — the optimistic-concurrency token.
425///
426/// The version is a monotonic counter the storage adapter bumps on every
427/// successful mutation of the task. A caller reads a task and its version, then
428/// passes that version back on a conditional update
429/// ([`AsyncTaskVersioning::update_status_checked`](crate::port::AsyncTaskVersioning::update_status_checked));
430/// if another writer advanced the task in between, the update fails with
431/// [`A2AError::VersionConflict`](crate::domain::A2AError::VersionConflict) instead
432/// of silently clobbering it.
433#[derive(Debug, Clone, PartialEq)]
434pub struct VersionedTask {
435    /// The task at this version.
436    pub task: Task,
437    /// The storage version this snapshot was read or written at.
438    pub version: u64,
439}
440
441impl VersionedTask {
442    /// Pair a task with a version.
443    pub fn new(task: Task, version: u64) -> Self {
444        Self { task, version }
445    }
446}